diff --git a/src/lib.rs b/src/lib.rs index f0108f0..bf37dc2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,3 +3,4 @@ mod common; pub mod tpm2; +pub mod tpm2net; diff --git a/src/tpm2net.rs b/src/tpm2net.rs new file mode 100644 index 0000000..3786b97 --- /dev/null +++ b/src/tpm2net.rs @@ -0,0 +1,119 @@ +pub const PACKET_START_BYTE: u8 = 0x9C; + +pub const PAYLOAD_SIZE_MAX: usize = 1490; + +pub use crate::common::{PacketType, PACKET_END_BYTE}; + +#[derive(PartialEq)] +#[non_exhaustive] +pub enum Tpm2NetError { + IndexOutOfBounds, + PayloadTooLarge, +} + +pub struct Packet<'a> { + packet_type: PacketType, + index: u8, + total: u8, + payload: &'a [u8], +} + +impl<'a> Packet<'a> { + pub fn new(packet_type: PacketType) -> Self { + Packet { + packet_type, + index: 1, + total: 1, + payload: &[] as &[u8], + } + } + + pub fn with_payload(packet_type: PacketType, payload: &'a [u8]) -> Result { + Self::with_payload_and_index(packet_type, payload, 1, 1) + } + + pub fn with_payload_and_index( + packet_type: PacketType, + payload: &'a [u8], + index: u8, + total: u8, + ) -> Result { + if payload.len() > PAYLOAD_SIZE_MAX { + return Err(Tpm2NetError::PayloadTooLarge); + } + if index == 0 { + return Err(Tpm2NetError::IndexOutOfBounds); + } + if total == 0 { + return Err(Tpm2NetError::IndexOutOfBounds); + } + if total < index { + return Err(Tpm2NetError::IndexOutOfBounds); + } + + Ok(Packet { + packet_type, + index, + total, + payload, + }) + } + + pub fn packet_type(&self) -> PacketType { + self.packet_type + } + + pub fn size(&self) -> usize { + self.payload.len() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn empty_packet() { + let packet = Packet::new(PacketType::Command); + assert_eq!(packet.size(), 0); + } + + #[test] + fn packet_with_payload() { + let payload: [u8; 4] = [0xca, 0xfe, 0xba, 0xbe]; + let result = Packet::with_payload(PacketType::Data, &payload); + assert!(result.is_ok()); + } + + #[test] + fn payload_too_large() { + let payload = [0u8; PAYLOAD_SIZE_MAX + 1]; + let result = Packet::with_payload(PacketType::Response, &payload); + assert!(result.is_err()); + assert!(result.err().unwrap() == Tpm2NetError::PayloadTooLarge); + } + + #[test] + fn index_is_zero() { + let payload: [u8; 4] = [0xca, 0xfe, 0xba, 0xbe]; + let result = Packet::with_payload_and_index(PacketType::Data, &payload, 0, 1); + assert!(result.is_err()); + assert!(result.err().unwrap() == Tpm2NetError::IndexOutOfBounds); + } + + #[test] + fn total_is_zero() { + let payload: [u8; 4] = [0xca, 0xfe, 0xba, 0xbe]; + let result = Packet::with_payload_and_index(PacketType::Data, &payload, 1, 0); + assert!(result.is_err()); + assert!(result.err().unwrap() == Tpm2NetError::IndexOutOfBounds); + } + + #[test] + fn total_is_less_than_index() { + let payload: [u8; 4] = [0xca, 0xfe, 0xba, 0xbe]; + let result = Packet::with_payload_and_index(PacketType::Data, &payload, 43, 42); + assert!(result.is_err()); + assert!(result.err().unwrap() == Tpm2NetError::IndexOutOfBounds); + } +}