From 2c88cc3a8004505f6c3054637f63cb408c99cdf2 Mon Sep 17 00:00:00 2001 From: Gris Ge Date: Sun, 9 Jul 2023 15:07:47 +0800 Subject: [PATCH] Revert "Revert "Merge `NetlinkPayload::{Ack,Error}`"" This reverts commit 16300f56d390b8128fdf37a86b0c0e2bc9d019e3. Signed-off-by: Gris Ge --- src/error.rs | 79 +++++++++++++++++++++++++++++++++++++++++++------- src/message.rs | 42 +++++++++++++++++++++------ src/payload.rs | 5 ++-- 3 files changed, 104 insertions(+), 22 deletions(-) diff --git a/src/error.rs b/src/error.rs index 3f4f5a5..f7951f7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -use std::{fmt, io, mem::size_of}; +use std::{fmt, io, mem::size_of, num::NonZeroI32}; use byteorder::{ByteOrder, NativeEndian}; use netlink_packet_utils::DecodeError; @@ -46,10 +46,14 @@ impl> ErrorBuffer { } } - /// Return the error code - pub fn code(&self) -> i32 { + /// Return the error code. + /// + /// Returns `None` when there is no error to report (the message is an ACK), + /// or a `Some(e)` if there is a non-zero error code `e` to report (the + /// message is a NACK). + pub fn code(&self) -> Option { let data = self.buffer.as_ref(); - NativeEndian::read_i32(&data[CODE]) + NonZeroI32::new(NativeEndian::read_i32(&data[CODE])) } } @@ -77,22 +81,36 @@ impl + AsMut<[u8]>> ErrorBuffer { } } +/// An `NLMSG_ERROR` message. +/// +/// Per [RFC 3549 section 2.3.2.2], this message carries the return code for a +/// request which will indicate either success (an ACK) or failure (a NACK). +/// +/// [RFC 3549 section 2.3.2.2]: https://datatracker.ietf.org/doc/html/rfc3549#section-2.3.2.2 #[derive(Debug, Default, Clone, PartialEq, Eq)] #[non_exhaustive] pub struct ErrorMessage { - pub code: i32, + /// The error code. + /// + /// Holds `None` when there is no error to report (the message is an ACK), + /// or a `Some(e)` if there is a non-zero error code `e` to report (the + /// message is a NACK). + /// + /// See [Netlink message types] for details. + /// + /// [Netlink message types]: https://kernel.org/doc/html/next/userspace-api/netlink/intro.html#netlink-message-types + pub code: Option, + /// The original request's header. pub header: Vec, } -pub type AckMessage = ErrorMessage; - impl Emitable for ErrorMessage { fn buffer_len(&self) -> usize { size_of::() + self.header.len() } fn emit(&self, buffer: &mut [u8]) { let mut buffer = ErrorBuffer::new(buffer); - buffer.set_code(self.code); + buffer.set_code(self.raw_code()); buffer.payload_mut().copy_from_slice(&self.header) } } @@ -119,13 +137,18 @@ impl<'buffer, T: AsRef<[u8]> + 'buffer> Parseable> } impl ErrorMessage { + /// Returns the raw error code. + pub fn raw_code(&self) -> i32 { + self.code.map_or(0, NonZeroI32::get) + } + /// According to [`netlink(7)`](https://linux.die.net/man/7/netlink) /// the `NLMSG_ERROR` return Negative errno or 0 for acknowledgements. /// /// convert into [`std::io::Error`](https://doc.rust-lang.org/std/io/struct.Error.html) /// using the absolute value from errno code pub fn to_io(&self) -> io::Error { - io::Error::from_raw_os_error(self.code.abs()) + io::Error::from_raw_os_error(self.raw_code().abs()) } } @@ -149,7 +172,7 @@ mod tests { fn into_io_error() { let io_err = io::Error::from_raw_os_error(95); let err_msg = ErrorMessage { - code: -95, + code: NonZeroI32::new(-95), header: vec![], }; @@ -158,4 +181,40 @@ mod tests { assert_eq!(err_msg.to_string(), io_err.to_string()); assert_eq!(to_io.raw_os_error(), io_err.raw_os_error()); } + + #[test] + fn parse_ack() { + let bytes = vec![0, 0, 0, 0]; + let msg = ErrorBuffer::new_checked(&bytes) + .and_then(|buf| ErrorMessage::parse(&buf)) + .expect("failed to parse NLMSG_ERROR"); + assert_eq!( + ErrorMessage { + code: None, + header: Vec::new() + }, + msg + ); + assert_eq!(msg.raw_code(), 0); + } + + #[test] + fn parse_nack() { + // SAFETY: value is non-zero. + const ERROR_CODE: NonZeroI32 = + unsafe { NonZeroI32::new_unchecked(-1234) }; + let mut bytes = vec![0, 0, 0, 0]; + NativeEndian::write_i32(&mut bytes, ERROR_CODE.get()); + let msg = ErrorBuffer::new_checked(&bytes) + .and_then(|buf| ErrorMessage::parse(&buf)) + .expect("failed to parse NLMSG_ERROR"); + assert_eq!( + ErrorMessage { + code: Some(ERROR_CODE), + header: Vec::new() + }, + msg + ); + assert_eq!(msg.raw_code(), ERROR_CODE.get()); + } } diff --git a/src/message.rs b/src/message.rs index 80fba79..4bc7dda 100644 --- a/src/message.rs +++ b/src/message.rs @@ -7,7 +7,7 @@ use netlink_packet_utils::DecodeError; use crate::{ payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN}, - AckMessage, DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage, + DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage, NetlinkBuffer, NetlinkDeserializable, NetlinkHeader, NetlinkPayload, NetlinkSerializable, Parseable, }; @@ -101,11 +101,7 @@ where let msg = ErrorBuffer::new_checked(&bytes) .and_then(|buf| ErrorMessage::parse(&buf)) .context("failed to parse NLMSG_ERROR")?; - if msg.code >= 0 { - Ack(msg as AckMessage) - } else { - Error(msg) - } + Error(msg) } NLMSG_NOOP => Noop, NLMSG_DONE => { @@ -138,7 +134,6 @@ where Done(ref msg) => msg.buffer_len(), Overrun(ref bytes) => bytes.len(), Error(ref msg) => msg.buffer_len(), - Ack(ref msg) => msg.buffer_len(), InnerMessage(ref msg) => msg.buffer_len(), }; @@ -157,7 +152,6 @@ where Done(ref msg) => msg.emit(buffer), Overrun(ref bytes) => buffer.copy_from_slice(bytes), Error(ref msg) => msg.emit(buffer), - Ack(ref msg) => msg.emit(buffer), InnerMessage(ref msg) => msg.serialize(buffer), } } @@ -179,7 +173,7 @@ where mod tests { use super::*; - use std::{convert::Infallible, mem::size_of}; + use std::{convert::Infallible, mem::size_of, num::NonZeroI32}; #[derive(Clone, Debug, Default, PartialEq)] struct FakeNetlinkInnerMessage; @@ -240,4 +234,34 @@ mod tests { let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap(); assert_eq!(got, want); } + + #[test] + fn test_error() { + // SAFETY: value is non-zero. + const ERROR_CODE: NonZeroI32 = + unsafe { NonZeroI32::new_unchecked(-8765) }; + + let header = NetlinkHeader::default(); + let error_msg = ErrorMessage { + code: Some(ERROR_CODE), + header: vec![], + }; + let mut want = NetlinkMessage::new( + header, + NetlinkPayload::::Error(error_msg.clone()), + ); + want.finalize(); + + let len = want.buffer_len(); + assert_eq!(len, header.buffer_len() + error_msg.buffer_len()); + + let mut buf = vec![1; len]; + want.emit(&mut buf); + + let error_buf = ErrorBuffer::new(&buf[header.buffer_len()..]); + assert_eq!(error_buf.code(), error_msg.code); + + let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap(); + assert_eq!(got, want); + } } diff --git a/src/payload.rs b/src/payload.rs index d98e4d2..34c6a33 100644 --- a/src/payload.rs +++ b/src/payload.rs @@ -2,7 +2,7 @@ use std::fmt::Debug; -use crate::{AckMessage, DoneMessage, ErrorMessage, NetlinkSerializable}; +use crate::{DoneMessage, ErrorMessage, NetlinkSerializable}; /// The message is ignored. pub const NLMSG_NOOP: u16 = 1; @@ -20,7 +20,6 @@ pub const NLMSG_ALIGNTO: u16 = 4; pub enum NetlinkPayload { Done(DoneMessage), Error(ErrorMessage), - Ack(AckMessage), Noop, Overrun(Vec), InnerMessage(I), @@ -33,7 +32,7 @@ where pub fn message_type(&self) -> u16 { match self { NetlinkPayload::Done(_) => NLMSG_DONE, - NetlinkPayload::Error(_) | NetlinkPayload::Ack(_) => NLMSG_ERROR, + NetlinkPayload::Error(_) => NLMSG_ERROR, NetlinkPayload::Noop => NLMSG_NOOP, NetlinkPayload::Overrun(_) => NLMSG_OVERRUN, NetlinkPayload::InnerMessage(message) => message.message_type(),