Skip to content

Commit

Permalink
Use thiserror Errors instead of relying on anyhow
Browse files Browse the repository at this point in the history
This improves matching on particular errors when we need to handle
different conditions downstream.

It's still possible to convert a anyhow::Error to a DecodeError in this
change, but every other error this crate expsoes is now in a variant.
  • Loading branch information
miguelfrde committed Jan 3, 2025
1 parent 97c4a4b commit 0135bf2
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 79 deletions.
50 changes: 28 additions & 22 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,29 +32,35 @@ impl From<anyhow::Error> for EncodeError {
}

#[derive(Debug, Error)]
#[error("Decode error occurred: {inner}")]
pub struct DecodeError {
inner: anyhow::Error,
}
pub enum DecodeError {
#[error(
"Invalid MAC address. Expected 6 bytes, received {received} bytes"
)]
InvalidMACAddress { received: usize },

impl From<&'static str> for DecodeError {
fn from(msg: &'static str) -> Self {
DecodeError {
inner: anyhow!(msg),
}
}
}
#[error(
"Invalid IP address. Expected 4 or 16 bytes, received {received} bytes"
)]
InvalidIPAddress { received: usize },

impl From<String> for DecodeError {
fn from(msg: String) -> Self {
DecodeError {
inner: anyhow!(msg),
}
}
}
#[error("Invalid string")]
Utf8Error(#[from] std::string::FromUtf8Error),

impl From<anyhow::Error> for DecodeError {
fn from(inner: anyhow::Error) -> DecodeError {
DecodeError { inner }
}
#[error(
"Invalid number. Expected {expected} bytes, received {received} bytes"
)]
InvalidNumber { expected: usize, received: usize },

#[error("Invalid buffer {name}. Expected at least {minimum_length} bytes, received {received} bytes")]
InvalidBuffer {
name: &'static str,
received: usize,
minimum_length: usize,
},

#[error(transparent)]
Nla(#[from] crate::nla::NlaError),

#[error(transparent)]
Other(#[from] Box<dyn std::error::Error>),
}
14 changes: 5 additions & 9 deletions src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,11 @@ macro_rules! buffer_check_length {
fn check_buffer_length(&self) -> Result<(), DecodeError> {
let len = self.buffer.as_ref().len();
if len < $buffer_len {
Err(format!(
concat!(
"invalid ",
stringify!($name),
": length {} < {}"
),
len, $buffer_len
)
.into())
Err(DecodeError::InvalidBuffer {
name: stringify!($name),
received: len,
minimum_length: $buffer_len,
})
} else {
Ok(())
}
Expand Down
67 changes: 36 additions & 31 deletions src/nla.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
// SPDX-License-Identifier: MIT

use core::ops::Range;

use anyhow::Context;
use byteorder::{ByteOrder, NativeEndian};
/ SPDX-License-Identifier: MIT

use crate::{
traits::{Emitable, Parseable},
DecodeError,
};
use byteorder::{ByteOrder, NativeEndian};
use core::ops::Range;
use thiserror::Error;

/// Represent a multi-bytes field with a fixed size in a packet
type Field = Range<usize>;
Expand All @@ -25,6 +23,20 @@ pub const NLA_ALIGNTO: usize = 4;
/// NlA(RTA) header size. (unsigned short rta_len) + (unsigned short rta_type)
pub const NLA_HEADER_SIZE: usize = 4;

#[derive(Debug, Error)]
pub enum NlaError {
#[error("buffer has length {buffer_len}, but an NLA header is {} bytes", TYPE.end)]
BufferTooSmall { buffer_len: usize },

#[error("buffer has length: {buffer_len}, but the NLA is {nla_len} bytes")]
LengthMismatch { buffer_len: usize, nla_len: u16 },

#[error(
"NLA has invalid length: {nla_len} (should be at least {} bytes", TYPE.end
)]
InvalidLength { nla_len: u16 },
}

#[macro_export]
macro_rules! nla_align {
($len: expr) => {
Expand Down Expand Up @@ -52,34 +64,25 @@ impl<T: AsRef<[u8]>> NlaBuffer<T> {
NlaBuffer { buffer }
}

pub fn new_checked(buffer: T) -> Result<NlaBuffer<T>, DecodeError> {
pub fn new_checked(buffer: T) -> Result<NlaBuffer<T>, NlaError> {
let buffer = Self::new(buffer);
buffer.check_buffer_length().context("invalid NLA buffer")?;
buffer.check_buffer_length()?;
Ok(buffer)
}

pub fn check_buffer_length(&self) -> Result<(), DecodeError> {
pub fn check_buffer_length(&self) -> Result<(), NlaError> {
let len = self.buffer.as_ref().len();
if len < TYPE.end {
Err(format!(
"buffer has length {}, but an NLA header is {} bytes",
len, TYPE.end
)
.into())
Err(NlaError::BufferTooSmall { buffer_len: len })
} else if len < self.length() as usize {
Err(format!(
"buffer has length: {}, but the NLA is {} bytes",
len,
self.length()
)
.into())
Err(NlaError::LengthMismatch {
buffer_len: len,
nla_len: self.length(),
})
} else if (self.length() as usize) < TYPE.end {
Err(format!(
"NLA has invalid length: {} (should be at least {} bytes",
self.length(),
TYPE.end,
)
.into())
Err(NlaError::InvalidLength {
nla_len: self.length(),
})
} else {
Ok(())
}
Expand Down Expand Up @@ -162,14 +165,14 @@ impl<T: AsRef<[u8]> + AsMut<[u8]>> NlaBuffer<T> {
}
}

impl<'buffer, T: AsRef<[u8]> + ?Sized> NlaBuffer<&'buffer T> {
impl<T: AsRef<[u8]> + ?Sized> NlaBuffer<&T> {
/// Return the `value` field
pub fn value(&self) -> &[u8] {
&self.buffer.as_ref()[VALUE(self.value_length())]
}
}

impl<'buffer, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> NlaBuffer<&'buffer mut T> {
impl<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> NlaBuffer<&mut T> {
/// Return the `value` field
pub fn value_mut(&mut self) -> &mut [u8] {
let length = VALUE(self.value_length());
Expand Down Expand Up @@ -204,7 +207,9 @@ impl Nla for DefaultNla {
impl<'buffer, T: AsRef<[u8]> + ?Sized> Parseable<NlaBuffer<&'buffer T>>
for DefaultNla
{
fn parse(buf: &NlaBuffer<&'buffer T>) -> Result<Self, DecodeError> {
type Error = DecodeError;

fn parse(buf: &NlaBuffer<&'buffer T>) -> Result<Self, Self::Error> {
let mut kind = buf.kind();

if buf.network_byte_order_flag() {
Expand Down Expand Up @@ -273,7 +278,7 @@ impl<T: Nla> Emitable for T {
// The reason this does not work today is because it conflicts with
//
// impl<T: Nla> Emitable for T { ... }
impl<'a, T: Nla> Emitable for &'a [T] {
impl<T: Nla> Emitable for &[T] {
fn buffer_len(&self) -> usize {
self.iter().fold(0, |acc, nla| {
assert_eq!(nla.buffer_len() % NLA_ALIGNTO, 0);
Expand Down Expand Up @@ -314,7 +319,7 @@ impl<T> NlasIterator<T> {
impl<'buffer, T: AsRef<[u8]> + ?Sized + 'buffer> Iterator
for NlasIterator<&'buffer T>
{
type Item = Result<NlaBuffer<&'buffer [u8]>, DecodeError>;
type Item = Result<NlaBuffer<&'buffer [u8]>, NlaError>;

fn next(&mut self) -> Option<Self::Item> {
if self.position >= self.buffer.as_ref().len() {
Expand Down
53 changes: 40 additions & 13 deletions src/parsers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
};

use anyhow::Context;
use byteorder::{BigEndian, ByteOrder, NativeEndian};

use crate::DecodeError;

pub fn parse_mac(payload: &[u8]) -> Result<[u8; 6], DecodeError> {
if payload.len() != 6 {
return Err(format!("invalid MAC address: {payload:?}").into());
return Err(DecodeError::InvalidMACAddress {
received: payload.len(),
});
}
let mut address: [u8; 6] = [0; 6];
for (i, byte) in payload.iter().enumerate() {
Expand All @@ -23,7 +24,9 @@ pub fn parse_mac(payload: &[u8]) -> Result<[u8; 6], DecodeError> {

pub fn parse_ipv6(payload: &[u8]) -> Result<[u8; 16], DecodeError> {
if payload.len() != 16 {
return Err(format!("invalid IPv6 address: {payload:?}").into());
return Err(DecodeError::InvalidIPAddress {
received: payload.len(),
});
}
let mut address: [u8; 16] = [0; 16];
for (i, byte) in payload.iter().enumerate() {
Expand Down Expand Up @@ -57,7 +60,7 @@ pub fn parse_ip(payload: &[u8]) -> Result<IpAddr, DecodeError> {
payload[15],
])
.into()),
_ => Err(format!("invalid IPv6 address: {payload:?}").into()),
other => Err(DecodeError::InvalidIPAddress { received: other }),
}
}

Expand All @@ -71,62 +74,86 @@ pub fn parse_string(payload: &[u8]) -> Result<String, DecodeError> {
} else {
&payload[..payload.len()]
};
let s = String::from_utf8(slice.to_vec()).context("invalid string")?;
let s = String::from_utf8(slice.to_vec())?;
Ok(s)
}

pub fn parse_u8(payload: &[u8]) -> Result<u8, DecodeError> {
if payload.len() != 1 {
return Err(format!("invalid u8: {payload:?}").into());
return Err(DecodeError::InvalidNumber {
expected: 1,
received: payload.len(),
});
}
Ok(payload[0])
}

pub fn parse_u32(payload: &[u8]) -> Result<u32, DecodeError> {
if payload.len() != size_of::<u32>() {
return Err(format!("invalid u32: {payload:?}").into());
return Err(DecodeError::InvalidNumber {
expected: size_of::<u32>(),
received: payload.len(),
});
}
Ok(NativeEndian::read_u32(payload))
}

pub fn parse_u64(payload: &[u8]) -> Result<u64, DecodeError> {
if payload.len() != size_of::<u64>() {
return Err(format!("invalid u64: {payload:?}").into());
return Err(DecodeError::InvalidNumber {
expected: size_of::<u64>(),
received: payload.len(),
});
}
Ok(NativeEndian::read_u64(payload))
}

pub fn parse_u128(payload: &[u8]) -> Result<u128, DecodeError> {
if payload.len() != size_of::<u128>() {
return Err(format!("invalid u128: {payload:?}").into());
return Err(DecodeError::InvalidNumber {
expected: size_of::<u128>(),
received: payload.len(),
});
}
Ok(NativeEndian::read_u128(payload))
}

pub fn parse_u16(payload: &[u8]) -> Result<u16, DecodeError> {
if payload.len() != size_of::<u16>() {
return Err(format!("invalid u16: {payload:?}").into());
return Err(DecodeError::InvalidNumber {
expected: size_of::<u16>(),
received: payload.len(),
});
}
Ok(NativeEndian::read_u16(payload))
}

pub fn parse_i32(payload: &[u8]) -> Result<i32, DecodeError> {
if payload.len() != 4 {
return Err(format!("invalid u32: {payload:?}").into());
return Err(DecodeError::InvalidNumber {
expected: 4,
received: payload.len(),
});
}
Ok(NativeEndian::read_i32(payload))
}

pub fn parse_u16_be(payload: &[u8]) -> Result<u16, DecodeError> {
if payload.len() != size_of::<u16>() {
return Err(format!("invalid u16: {payload:?}").into());
return Err(DecodeError::InvalidNumber {
expected: size_of::<u16>(),
received: payload.len(),
});
}
Ok(BigEndian::read_u16(payload))
}

pub fn parse_u32_be(payload: &[u8]) -> Result<u32, DecodeError> {
if payload.len() != size_of::<u32>() {
return Err(format!("invalid u32: {payload:?}").into());
return Err(DecodeError::InvalidNumber {
expected: size_of::<u32>(),
received: payload.len(),
});
}
Ok(BigEndian::read_u32(payload))
}
10 changes: 6 additions & 4 deletions src/traits.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
// SPDX-License-Identifier: MIT

use crate::DecodeError;

/// A type that implements `Emitable` can be serialized.
pub trait Emitable {
/// Return the length of the serialized data.
Expand All @@ -26,8 +24,10 @@ where
Self: Sized,
T: ?Sized,
{
type Error;

/// Deserialize the current type.
fn parse(buf: &T) -> Result<Self, DecodeError>;
fn parse(buf: &T) -> Result<Self, Self::Error>;
}

/// A `Parseable` type can be used to deserialize data from the type `T` for
Expand All @@ -37,6 +37,8 @@ where
Self: Sized,
T: ?Sized,
{
type Error;

/// Deserialize the current type.
fn parse_with_param(buf: &T, params: P) -> Result<Self, DecodeError>;
fn parse_with_param(buf: &T, params: P) -> Result<Self, Self::Error>;
}

0 comments on commit 0135bf2

Please sign in to comment.