diff --git a/src/errors.rs b/src/errors.rs index 5dcc64d..b3d8662 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -32,29 +32,35 @@ impl From for EncodeError { } #[derive(Debug, Error)] -#[error("Decode error occurred: {inner}")] -pub struct DecodeError { - inner: anyhow::Error, -} +pub enum DecodeError { + #[error( + "Invalid MAC address. Expected 6 bytes, received {received} bytes" + )] + InvalidMACAddress { received: usize }, -impl From<&'static str> for DecodeError { - fn from(msg: &'static str) -> Self { - DecodeError { - inner: anyhow!(msg), - } - } -} + #[error( + "Invalid IP address. Expected 4 or 16 bytes, received {received} bytes" + )] + InvalidIPAddress { received: usize }, -impl From for DecodeError { - fn from(msg: String) -> Self { - DecodeError { - inner: anyhow!(msg), - } - } -} + #[error("Invalid string")] + Utf8Error(#[from] std::string::FromUtf8Error), -impl From for DecodeError { - fn from(inner: anyhow::Error) -> DecodeError { - DecodeError { inner } - } + #[error( + "Invalid number. Expected {expected} bytes, received {received} bytes" + )] + InvalidNumber { expected: usize, received: usize }, + + #[error("Invalid buffer {name}. Expected at least {minimum_length} bytes, received {received} bytes")] + InvalidBuffer { + name: &'static str, + received: usize, + minimum_length: usize, + }, + + #[error(transparent)] + Nla(#[from] crate::nla::NlaError), + + #[error(transparent)] + Other(#[from] Box), } diff --git a/src/macros.rs b/src/macros.rs index 836b5cb..3c1bc43 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -192,15 +192,11 @@ macro_rules! buffer_check_length { fn check_buffer_length(&self) -> Result<(), DecodeError> { let len = self.buffer.as_ref().len(); if len < $buffer_len { - Err(format!( - concat!( - "invalid ", - stringify!($name), - ": length {} < {}" - ), - len, $buffer_len - ) - .into()) + Err(DecodeError::InvalidBuffer { + name: stringify!($name), + received: len, + minimum_length: $buffer_len, + }) } else { Ok(()) } diff --git a/src/nla.rs b/src/nla.rs index 96f7d79..3c9f4e4 100644 --- a/src/nla.rs +++ b/src/nla.rs @@ -1,14 +1,12 @@ -// SPDX-License-Identifier: MIT - -use core::ops::Range; - -use anyhow::Context; -use byteorder::{ByteOrder, NativeEndian}; +/ SPDX-License-Identifier: MIT use crate::{ traits::{Emitable, Parseable}, DecodeError, }; +use byteorder::{ByteOrder, NativeEndian}; +use core::ops::Range; +use thiserror::Error; /// Represent a multi-bytes field with a fixed size in a packet type Field = Range; @@ -25,6 +23,20 @@ pub const NLA_ALIGNTO: usize = 4; /// NlA(RTA) header size. (unsigned short rta_len) + (unsigned short rta_type) pub const NLA_HEADER_SIZE: usize = 4; +#[derive(Debug, Error)] +pub enum NlaError { + #[error("buffer has length {buffer_len}, but an NLA header is {} bytes", TYPE.end)] + BufferTooSmall { buffer_len: usize }, + + #[error("buffer has length: {buffer_len}, but the NLA is {nla_len} bytes")] + LengthMismatch { buffer_len: usize, nla_len: u16 }, + + #[error( + "NLA has invalid length: {nla_len} (should be at least {} bytes", TYPE.end + )] + InvalidLength { nla_len: u16 }, +} + #[macro_export] macro_rules! nla_align { ($len: expr) => { @@ -52,34 +64,25 @@ impl> NlaBuffer { NlaBuffer { buffer } } - pub fn new_checked(buffer: T) -> Result, DecodeError> { + pub fn new_checked(buffer: T) -> Result, NlaError> { let buffer = Self::new(buffer); - buffer.check_buffer_length().context("invalid NLA buffer")?; + buffer.check_buffer_length()?; Ok(buffer) } - pub fn check_buffer_length(&self) -> Result<(), DecodeError> { + pub fn check_buffer_length(&self) -> Result<(), NlaError> { let len = self.buffer.as_ref().len(); if len < TYPE.end { - Err(format!( - "buffer has length {}, but an NLA header is {} bytes", - len, TYPE.end - ) - .into()) + Err(NlaError::BufferTooSmall { buffer_len: len }) } else if len < self.length() as usize { - Err(format!( - "buffer has length: {}, but the NLA is {} bytes", - len, - self.length() - ) - .into()) + Err(NlaError::LengthMismatch { + buffer_len: len, + nla_len: self.length(), + }) } else if (self.length() as usize) < TYPE.end { - Err(format!( - "NLA has invalid length: {} (should be at least {} bytes", - self.length(), - TYPE.end, - ) - .into()) + Err(NlaError::InvalidLength { + nla_len: self.length(), + }) } else { Ok(()) } @@ -162,14 +165,14 @@ impl + AsMut<[u8]>> NlaBuffer { } } -impl<'buffer, T: AsRef<[u8]> + ?Sized> NlaBuffer<&'buffer T> { +impl + ?Sized> NlaBuffer<&T> { /// Return the `value` field pub fn value(&self) -> &[u8] { &self.buffer.as_ref()[VALUE(self.value_length())] } } -impl<'buffer, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> NlaBuffer<&'buffer mut T> { +impl + AsMut<[u8]> + ?Sized> NlaBuffer<&mut T> { /// Return the `value` field pub fn value_mut(&mut self) -> &mut [u8] { let length = VALUE(self.value_length()); @@ -204,7 +207,9 @@ impl Nla for DefaultNla { impl<'buffer, T: AsRef<[u8]> + ?Sized> Parseable> for DefaultNla { - fn parse(buf: &NlaBuffer<&'buffer T>) -> Result { + type Error = DecodeError; + + fn parse(buf: &NlaBuffer<&'buffer T>) -> Result { let mut kind = buf.kind(); if buf.network_byte_order_flag() { @@ -273,7 +278,7 @@ impl Emitable for T { // The reason this does not work today is because it conflicts with // // impl Emitable for T { ... } -impl<'a, T: Nla> Emitable for &'a [T] { +impl Emitable for &[T] { fn buffer_len(&self) -> usize { self.iter().fold(0, |acc, nla| { assert_eq!(nla.buffer_len() % NLA_ALIGNTO, 0); @@ -314,7 +319,7 @@ impl NlasIterator { impl<'buffer, T: AsRef<[u8]> + ?Sized + 'buffer> Iterator for NlasIterator<&'buffer T> { - type Item = Result, DecodeError>; + type Item = Result, NlaError>; fn next(&mut self) -> Option { if self.position >= self.buffer.as_ref().len() { diff --git a/src/parsers.rs b/src/parsers.rs index f1198d3..cb6ade9 100644 --- a/src/parsers.rs +++ b/src/parsers.rs @@ -5,14 +5,15 @@ use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr}, }; -use anyhow::Context; use byteorder::{BigEndian, ByteOrder, NativeEndian}; use crate::DecodeError; pub fn parse_mac(payload: &[u8]) -> Result<[u8; 6], DecodeError> { if payload.len() != 6 { - return Err(format!("invalid MAC address: {payload:?}").into()); + return Err(DecodeError::InvalidMACAddress { + received: payload.len(), + }); } let mut address: [u8; 6] = [0; 6]; for (i, byte) in payload.iter().enumerate() { @@ -23,7 +24,9 @@ pub fn parse_mac(payload: &[u8]) -> Result<[u8; 6], DecodeError> { pub fn parse_ipv6(payload: &[u8]) -> Result<[u8; 16], DecodeError> { if payload.len() != 16 { - return Err(format!("invalid IPv6 address: {payload:?}").into()); + return Err(DecodeError::InvalidIPAddress { + received: payload.len(), + }); } let mut address: [u8; 16] = [0; 16]; for (i, byte) in payload.iter().enumerate() { @@ -57,7 +60,7 @@ pub fn parse_ip(payload: &[u8]) -> Result { payload[15], ]) .into()), - _ => Err(format!("invalid IPv6 address: {payload:?}").into()), + other => Err(DecodeError::InvalidIPAddress { received: other }), } } @@ -71,62 +74,86 @@ pub fn parse_string(payload: &[u8]) -> Result { } else { &payload[..payload.len()] }; - let s = String::from_utf8(slice.to_vec()).context("invalid string")?; + let s = String::from_utf8(slice.to_vec())?; Ok(s) } pub fn parse_u8(payload: &[u8]) -> Result { if payload.len() != 1 { - return Err(format!("invalid u8: {payload:?}").into()); + return Err(DecodeError::InvalidNumber { + expected: 1, + received: payload.len(), + }); } Ok(payload[0]) } pub fn parse_u32(payload: &[u8]) -> Result { if payload.len() != size_of::() { - return Err(format!("invalid u32: {payload:?}").into()); + return Err(DecodeError::InvalidNumber { + expected: size_of::(), + received: payload.len(), + }); } Ok(NativeEndian::read_u32(payload)) } pub fn parse_u64(payload: &[u8]) -> Result { if payload.len() != size_of::() { - return Err(format!("invalid u64: {payload:?}").into()); + return Err(DecodeError::InvalidNumber { + expected: size_of::(), + received: payload.len(), + }); } Ok(NativeEndian::read_u64(payload)) } pub fn parse_u128(payload: &[u8]) -> Result { if payload.len() != size_of::() { - return Err(format!("invalid u128: {payload:?}").into()); + return Err(DecodeError::InvalidNumber { + expected: size_of::(), + received: payload.len(), + }); } Ok(NativeEndian::read_u128(payload)) } pub fn parse_u16(payload: &[u8]) -> Result { if payload.len() != size_of::() { - return Err(format!("invalid u16: {payload:?}").into()); + return Err(DecodeError::InvalidNumber { + expected: size_of::(), + received: payload.len(), + }); } Ok(NativeEndian::read_u16(payload)) } pub fn parse_i32(payload: &[u8]) -> Result { if payload.len() != 4 { - return Err(format!("invalid u32: {payload:?}").into()); + return Err(DecodeError::InvalidNumber { + expected: 4, + received: payload.len(), + }); } Ok(NativeEndian::read_i32(payload)) } pub fn parse_u16_be(payload: &[u8]) -> Result { if payload.len() != size_of::() { - return Err(format!("invalid u16: {payload:?}").into()); + return Err(DecodeError::InvalidNumber { + expected: size_of::(), + received: payload.len(), + }); } Ok(BigEndian::read_u16(payload)) } pub fn parse_u32_be(payload: &[u8]) -> Result { if payload.len() != size_of::() { - return Err(format!("invalid u32: {payload:?}").into()); + return Err(DecodeError::InvalidNumber { + expected: size_of::(), + received: payload.len(), + }); } Ok(BigEndian::read_u32(payload)) } diff --git a/src/traits.rs b/src/traits.rs index 89c1bed..855dc60 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -1,7 +1,5 @@ // SPDX-License-Identifier: MIT -use crate::DecodeError; - /// A type that implements `Emitable` can be serialized. pub trait Emitable { /// Return the length of the serialized data. @@ -26,8 +24,10 @@ where Self: Sized, T: ?Sized, { + type Error; + /// Deserialize the current type. - fn parse(buf: &T) -> Result; + fn parse(buf: &T) -> Result; } /// A `Parseable` type can be used to deserialize data from the type `T` for @@ -37,6 +37,8 @@ where Self: Sized, T: ?Sized, { + type Error; + /// Deserialize the current type. - fn parse_with_param(buf: &T, params: P) -> Result; + fn parse_with_param(buf: &T, params: P) -> Result; }