Skip to content

Commit

Permalink
Revert "Revert "Merge NetlinkPayload::{Ack,Error}""
Browse files Browse the repository at this point in the history
This reverts commit 16300f5.

Signed-off-by: Gris Ge <[email protected]>
  • Loading branch information
cathay4t committed Jul 9, 2023
1 parent 0e486c3 commit 2c88cc3
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 22 deletions.
79 changes: 69 additions & 10 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// SPDX-License-Identifier: MIT

use std::{fmt, io, mem::size_of};
use std::{fmt, io, mem::size_of, num::NonZeroI32};

use byteorder::{ByteOrder, NativeEndian};
use netlink_packet_utils::DecodeError;
Expand Down Expand Up @@ -46,10 +46,14 @@ impl<T: AsRef<[u8]>> ErrorBuffer<T> {
}
}

/// Return the error code
pub fn code(&self) -> i32 {
/// Return the error code.
///
/// Returns `None` when there is no error to report (the message is an ACK),
/// or a `Some(e)` if there is a non-zero error code `e` to report (the
/// message is a NACK).
pub fn code(&self) -> Option<NonZeroI32> {
let data = self.buffer.as_ref();
NativeEndian::read_i32(&data[CODE])
NonZeroI32::new(NativeEndian::read_i32(&data[CODE]))
}
}

Expand Down Expand Up @@ -77,22 +81,36 @@ impl<T: AsRef<[u8]> + AsMut<[u8]>> ErrorBuffer<T> {
}
}

/// An `NLMSG_ERROR` message.
///
/// Per [RFC 3549 section 2.3.2.2], this message carries the return code for a
/// request which will indicate either success (an ACK) or failure (a NACK).
///
/// [RFC 3549 section 2.3.2.2]: https://datatracker.ietf.org/doc/html/rfc3549#section-2.3.2.2
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct ErrorMessage {
pub code: i32,
/// The error code.
///
/// Holds `None` when there is no error to report (the message is an ACK),
/// or a `Some(e)` if there is a non-zero error code `e` to report (the
/// message is a NACK).
///
/// See [Netlink message types] for details.
///
/// [Netlink message types]: https://kernel.org/doc/html/next/userspace-api/netlink/intro.html#netlink-message-types
pub code: Option<NonZeroI32>,
/// The original request's header.
pub header: Vec<u8>,
}

pub type AckMessage = ErrorMessage;

impl Emitable for ErrorMessage {
fn buffer_len(&self) -> usize {
size_of::<i32>() + self.header.len()
}
fn emit(&self, buffer: &mut [u8]) {
let mut buffer = ErrorBuffer::new(buffer);
buffer.set_code(self.code);
buffer.set_code(self.raw_code());
buffer.payload_mut().copy_from_slice(&self.header)
}
}
Expand All @@ -119,13 +137,18 @@ impl<'buffer, T: AsRef<[u8]> + 'buffer> Parseable<ErrorBuffer<&'buffer T>>
}

impl ErrorMessage {
/// Returns the raw error code.
pub fn raw_code(&self) -> i32 {
self.code.map_or(0, NonZeroI32::get)
}

/// According to [`netlink(7)`](https://linux.die.net/man/7/netlink)
/// the `NLMSG_ERROR` return Negative errno or 0 for acknowledgements.
///
/// convert into [`std::io::Error`](https://doc.rust-lang.org/std/io/struct.Error.html)
/// using the absolute value from errno code
pub fn to_io(&self) -> io::Error {
io::Error::from_raw_os_error(self.code.abs())
io::Error::from_raw_os_error(self.raw_code().abs())
}
}

Expand All @@ -149,7 +172,7 @@ mod tests {
fn into_io_error() {
let io_err = io::Error::from_raw_os_error(95);
let err_msg = ErrorMessage {
code: -95,
code: NonZeroI32::new(-95),
header: vec![],
};

Expand All @@ -158,4 +181,40 @@ mod tests {
assert_eq!(err_msg.to_string(), io_err.to_string());
assert_eq!(to_io.raw_os_error(), io_err.raw_os_error());
}

#[test]
fn parse_ack() {
let bytes = vec![0, 0, 0, 0];
let msg = ErrorBuffer::new_checked(&bytes)
.and_then(|buf| ErrorMessage::parse(&buf))
.expect("failed to parse NLMSG_ERROR");
assert_eq!(
ErrorMessage {
code: None,
header: Vec::new()
},
msg
);
assert_eq!(msg.raw_code(), 0);
}

#[test]
fn parse_nack() {
// SAFETY: value is non-zero.
const ERROR_CODE: NonZeroI32 =
unsafe { NonZeroI32::new_unchecked(-1234) };
let mut bytes = vec![0, 0, 0, 0];
NativeEndian::write_i32(&mut bytes, ERROR_CODE.get());
let msg = ErrorBuffer::new_checked(&bytes)
.and_then(|buf| ErrorMessage::parse(&buf))
.expect("failed to parse NLMSG_ERROR");
assert_eq!(
ErrorMessage {
code: Some(ERROR_CODE),
header: Vec::new()
},
msg
);
assert_eq!(msg.raw_code(), ERROR_CODE.get());
}
}
42 changes: 33 additions & 9 deletions src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use netlink_packet_utils::DecodeError;

use crate::{
payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN},
AckMessage, DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage,
DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage,
NetlinkBuffer, NetlinkDeserializable, NetlinkHeader, NetlinkPayload,
NetlinkSerializable, Parseable,
};
Expand Down Expand Up @@ -101,11 +101,7 @@ where
let msg = ErrorBuffer::new_checked(&bytes)
.and_then(|buf| ErrorMessage::parse(&buf))
.context("failed to parse NLMSG_ERROR")?;
if msg.code >= 0 {
Ack(msg as AckMessage)
} else {
Error(msg)
}
Error(msg)
}
NLMSG_NOOP => Noop,
NLMSG_DONE => {
Expand Down Expand Up @@ -138,7 +134,6 @@ where
Done(ref msg) => msg.buffer_len(),
Overrun(ref bytes) => bytes.len(),
Error(ref msg) => msg.buffer_len(),
Ack(ref msg) => msg.buffer_len(),
InnerMessage(ref msg) => msg.buffer_len(),
};

Expand All @@ -157,7 +152,6 @@ where
Done(ref msg) => msg.emit(buffer),
Overrun(ref bytes) => buffer.copy_from_slice(bytes),
Error(ref msg) => msg.emit(buffer),
Ack(ref msg) => msg.emit(buffer),
InnerMessage(ref msg) => msg.serialize(buffer),
}
}
Expand All @@ -179,7 +173,7 @@ where
mod tests {
use super::*;

use std::{convert::Infallible, mem::size_of};
use std::{convert::Infallible, mem::size_of, num::NonZeroI32};

#[derive(Clone, Debug, Default, PartialEq)]
struct FakeNetlinkInnerMessage;
Expand Down Expand Up @@ -240,4 +234,34 @@ mod tests {
let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
assert_eq!(got, want);
}

#[test]
fn test_error() {
// SAFETY: value is non-zero.
const ERROR_CODE: NonZeroI32 =
unsafe { NonZeroI32::new_unchecked(-8765) };

let header = NetlinkHeader::default();
let error_msg = ErrorMessage {
code: Some(ERROR_CODE),
header: vec![],
};
let mut want = NetlinkMessage::new(
header,
NetlinkPayload::<FakeNetlinkInnerMessage>::Error(error_msg.clone()),
);
want.finalize();

let len = want.buffer_len();
assert_eq!(len, header.buffer_len() + error_msg.buffer_len());

let mut buf = vec![1; len];
want.emit(&mut buf);

let error_buf = ErrorBuffer::new(&buf[header.buffer_len()..]);
assert_eq!(error_buf.code(), error_msg.code);

let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
assert_eq!(got, want);
}
}
5 changes: 2 additions & 3 deletions src/payload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use std::fmt::Debug;

use crate::{AckMessage, DoneMessage, ErrorMessage, NetlinkSerializable};
use crate::{DoneMessage, ErrorMessage, NetlinkSerializable};

/// The message is ignored.
pub const NLMSG_NOOP: u16 = 1;
Expand All @@ -20,7 +20,6 @@ pub const NLMSG_ALIGNTO: u16 = 4;
pub enum NetlinkPayload<I> {
Done(DoneMessage),
Error(ErrorMessage),
Ack(AckMessage),
Noop,
Overrun(Vec<u8>),
InnerMessage(I),
Expand All @@ -33,7 +32,7 @@ where
pub fn message_type(&self) -> u16 {
match self {
NetlinkPayload::Done(_) => NLMSG_DONE,
NetlinkPayload::Error(_) | NetlinkPayload::Ack(_) => NLMSG_ERROR,
NetlinkPayload::Error(_) => NLMSG_ERROR,
NetlinkPayload::Noop => NLMSG_NOOP,
NetlinkPayload::Overrun(_) => NLMSG_OVERRUN,
NetlinkPayload::InnerMessage(message) => message.message_type(),
Expand Down

0 comments on commit 2c88cc3

Please sign in to comment.