diff --git a/.gitmodules b/.gitmodules index bf824a5a..0dc9a72d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,7 @@ [submodule "proto"] path = proto url = ../proto.git + +[submodule "wireguard-rs"] + path = wireguard-rs + url = ../wireguard-rs.git diff --git a/Cargo.lock b/Cargo.lock index 28a5f551..b6e31ed1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -496,6 +496,7 @@ dependencies = [ "toml", "tonic", "tonic-build", + "wireguard_rs", "x25519-dalek", ] @@ -786,7 +787,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2", + "socket2 0.4.9", "tokio", "tower-service", "tracing", @@ -1628,6 +1629,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "socket2" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2538b18701741680e0322a2302176d3253a35388e2e62f172f64f4f16605f877" +dependencies = [ + "libc", + "windows-sys", +] + [[package]] name = "spin" version = "0.5.2" @@ -1792,18 +1803,17 @@ dependencies = [ [[package]] name = "tokio" -version = "1.29.1" +version = "1.32.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "532826ff75199d5833b9d2c5fe410f29235e25704ee5f0ef599fb51c21f4a4da" +checksum = "17ed6077ed6cd6c74735e21f37eb16dc3935f96878b1fe961074089cc80893f9" dependencies = [ - "autocfg", "backtrace", "bytes", "libc", "mio", "num_cpus", "pin-project-lite", - "socket2", + "socket2 0.5.3", "tokio-macros", "windows-sys", ] @@ -2308,6 +2318,22 @@ dependencies = [ "memchr", ] +[[package]] +name = "wireguard_rs" +version = "0.1.0" +dependencies = [ + "base64 0.13.1", + "log", + "netlink-packet-core", + "netlink-packet-generic", + "netlink-packet-route", + "netlink-packet-wireguard", + "netlink-sys", + "nix", + "thiserror", + "tokio", +] + [[package]] name = "x25519-dalek" version = "2.0.0-rc.3" diff --git a/Cargo.toml b/Cargo.toml index 0c9ca128..ba8e0eb0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ tokio = { version = "1", features = ["macros", "rt-multi-thread"] } tokio-stream = { version = "0.1", features = [] } toml = "0.7" serde = { version = "1.0", features = ["derive"] } +wireguard_rs = { path = "wireguard-rs" } [dev-dependencies] tokio = { version = "1", features = ["io-std", "io-util"] } diff --git a/examples/api.rs b/examples/api.rs index 0c63c235..4f2f3940 100644 --- a/examples/api.rs +++ b/examples/api.rs @@ -3,8 +3,8 @@ use std::str::FromStr; use x25519_dalek::{EphemeralSecret, PublicKey, StaticSecret}; #[cfg(target_os = "linux")] -use defguard_gateway::wireguard::netlink::{address_interface, create_interface}; -use defguard_gateway::wireguard::{wgapi::WGApi, Host, IpAddrMask, Key, Peer}; +use wireguard_rs::netlink::{address_interface, create_interface}; +use wireguard_rs::{wgapi::WGApi, Host, IpAddrMask, Key, Peer}; #[tokio::main] async fn main() -> Result<(), Box> { diff --git a/examples/server.rs b/examples/server.rs index e3eef7e2..d7110ba7 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -1,8 +1,4 @@ -use defguard_gateway::proto::ConfigurationRequest; -use defguard_gateway::{ - proto, - wireguard::{Host, IpAddrMask, Key, Peer}, -}; +use defguard_gateway::proto; use std::{ collections::HashMap, io::{stdout, Write}, @@ -18,6 +14,7 @@ use tokio::{ }; use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::{transport::Server, Request, Response, Status, Streaming}; +use wireguard_rs::{Host, IpAddrMask, Key, Peer}; pub struct HostConfig { name: String, @@ -84,7 +81,7 @@ impl proto::gateway_service_server::GatewayService for GatewayServer { async fn config( &self, - request: Request, + request: Request, ) -> Result, Status> { let address = request.remote_addr().unwrap(); eprintln!("CONFIG connected from: {}", address); diff --git a/src/error.rs b/src/error.rs index e40af504..d67f224f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,5 @@ use thiserror::Error; +use wireguard_rs::{error::WireguardError, IpAddrParseError}; #[derive(Debug, Error)] pub enum GatewayError { @@ -16,7 +17,7 @@ pub enum GatewayError { KeyDecode(#[from] base64::DecodeError), #[error("IP address/mask error")] - IpAddrMask(#[from] super::wireguard::IpAddrParseError), + IpAddrMask(#[from] IpAddrParseError), #[error("Logger error")] Logger(#[from] log::SetLoggerError), @@ -35,4 +36,7 @@ pub enum GatewayError { #[error("Invalid config file. Error: {0}")] InvalidConfigFile(String), + + #[error("Wireguard error")] + WireguardError(#[from] WireguardError), } diff --git a/src/gateway.rs b/src/gateway.rs index 95940707..cd1b4139 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -16,8 +16,6 @@ use tonic::{ Request, Status, Streaming, }; -#[cfg(target_os = "linux")] -use crate::wireguard::netlink::delete_interface; use crate::{ config::Config, error::GatewayError, @@ -26,9 +24,11 @@ use crate::{ gateway_service_client::GatewayServiceClient, update, Configuration, ConfigurationRequest, Peer, Update, }, - wireguard::{setup_interface, wgapi::WGApi}, + wireguard_rs::{setup_interface, wgapi::WGApi}, VERSION, }; +#[cfg(target_os = "linux")] +use wireguard_rs::netlink::delete_interface; // helper struct which stores just the interface config without peers #[derive(Clone, PartialEq)] @@ -208,7 +208,7 @@ impl Gateway { setup_interface( &self.config.ifname, self.config.userspace, - &new_configuration, + &new_configuration.clone().into(), )?; info!( "Reconfigured WireGuard interface: {:?}", diff --git a/src/lib.rs b/src/lib.rs index f6aa3d0e..f1614377 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,8 +3,8 @@ pub mod config; pub mod error; pub mod gateway; -pub mod wireguard; +#[allow(non_snake_case)] pub mod proto { tonic::include_proto!("gateway"); } @@ -12,11 +12,14 @@ pub mod proto { #[macro_use] extern crate log; -use std::{process, str::FromStr}; +extern crate wireguard_rs; + +use std::{process, str::FromStr, time::SystemTime}; use config::Config; use error::GatewayError; use syslog::{BasicLogger, Facility, Formatter3164}; +use wireguard_rs::{InterfaceConfiguration, IpAddrMask, Peer}; pub const VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -69,3 +72,56 @@ pub fn execute_command(command: &str) -> Result<(), GatewayError> { } Ok(()) } + +impl From for InterfaceConfiguration { + fn from(config: proto::Configuration) -> Self { + let peers = config.peers.into_iter().map(Peer::from).collect(); + InterfaceConfiguration { + name: config.name, + prvkey: config.prvkey, + address: config.address, + port: config.port, + peers, + } + } +} + +impl From for Peer { + fn from(proto_peer: proto::Peer) -> Self { + let mut peer = Self::new(proto_peer.pubkey.as_str().try_into().unwrap_or_default()); + peer.allowed_ips = proto_peer + .allowed_ips + .iter() + .filter_map(|entry| IpAddrMask::from_str(entry).ok()) + .collect(); + peer + } +} + +impl From<&Peer> for proto::Peer { + fn from(peer: &Peer) -> Self { + Self { + pubkey: peer.public_key.to_string(), + allowed_ips: peer.allowed_ips.iter().map(ToString::to_string).collect(), + } + } +} + +impl From<&Peer> for proto::PeerStats { + fn from(peer: &Peer) -> Self { + Self { + public_key: peer.public_key.to_string(), + endpoint: peer + .endpoint + .map_or(String::new(), |endpoint| endpoint.to_string()), + allowed_ips: peer.allowed_ips.iter().map(ToString::to_string).collect(), + latest_handshake: peer.last_handshake.map_or(0, |ts| { + ts.duration_since(SystemTime::UNIX_EPOCH) + .map_or(0, |duration| duration.as_secs() as i64) + }), + download: peer.rx_bytes as i64, + upload: peer.tx_bytes as i64, + keepalive_interval: i64::from(peer.persistent_keepalive_interval.unwrap_or_default()), + } + } +} diff --git a/src/wireguard/bsd/mod.rs b/src/wireguard/bsd/mod.rs deleted file mode 100644 index f020f64c..00000000 --- a/src/wireguard/bsd/mod.rs +++ /dev/null @@ -1,267 +0,0 @@ -mod nvlist; -mod sockaddr; -mod timespec; -mod wgio; - -use std::{collections::HashMap, mem::size_of, net::IpAddr, ptr::addr_of, slice::from_raw_parts}; - -use self::{ - nvlist::NvList, - sockaddr::{pack_sockaddr, unpack_sockaddr}, - timespec::{pack_timespec, unpack_timespec}, - wgio::{WgDataIo, WgIoError}, -}; -use super::{Host, IpAddrMask, Peer}; - -// nvlist key names -static NV_LISTEN_PORT: &str = "listen-port"; -static NV_FWMARK: &str = "user-cookie"; -static NV_PUBLIC_KEY: &str = "public-key"; -static NV_PRIVATE_KEY: &str = "private-key"; -static NV_PEERS: &str = "peers"; -static NV_REPLACE_PEERS: &str = "replace-peers"; - -static NV_PRESHARED_KEY: &str = "preshared-key"; -static NV_KEEPALIVE_INTERVAL: &str = "persistent-keepalive-interval"; -static NV_ENDPOINT: &str = "endpoint"; -static NV_RX_BYTES: &str = "rx-bytes"; -static NV_TX_BYTES: &str = "tx-bytes"; -static NV_LAST_HANDSHAKE: &str = "last-handshake-time"; -static NV_ALLOWED_IPS: &str = "allowed-ips"; -static NV_REPLACE_ALLOWED_IPS: &str = "replace-allowed-ips"; -static NV_REMOVE: &str = "remove"; - -static NV_CIDR: &str = "cidr"; -static NV_IPV4: &str = "ipv4"; -static NV_IPV6: &str = "ipv6"; - -/// Cast bytes to `T`. -unsafe fn cast_ref(bytes: &[u8]) -> &T { - let ptr: *const u8 = bytes.as_ptr(); - ptr.cast::().as_ref().unwrap() -} - -/// Cast `T' to bytes. -unsafe fn cast_bytes(p: &T) -> &[u8] { - let ptr = addr_of!(p).cast::(); - from_raw_parts(ptr, size_of::()) -} - -impl IpAddrMask { - #[must_use] - fn try_from_nvlist(nvlist: &NvList) -> Option { - // cidr is mendatory - nvlist.get_number(NV_CIDR).and_then(|cidr| { - match nvlist.get_binary(NV_IPV4) { - Some(ipv4) => <[u8; 4]>::try_from(ipv4).ok().map(IpAddr::from), - None => nvlist - .get_binary(NV_IPV6) - .and_then(|ipv6| <[u8; 16]>::try_from(ipv6).ok().map(IpAddr::from)), - } - .map(|ip| Self { - ip, - cidr: cidr as u8, - }) - }) - } -} - -impl<'a> IpAddrMask { - #[must_use] - fn as_nvlist(&'a self) -> NvList<'a> { - let mut nvlist = NvList::new(); - - nvlist.append_number(NV_CIDR, u64::from(self.cidr)); - - match self.ip { - IpAddr::V4(ipv4) => nvlist.append_bytes(NV_IPV4, ipv4.octets().into()), - IpAddr::V6(ipv6) => nvlist.append_bytes(NV_IPV6, ipv6.octets().into()), - } - - nvlist.append_nvlist_array_next(); - nvlist - } -} - -impl Host { - #[must_use] - fn from_nvlist(nvlist: &NvList) -> Self { - let listen_port = nvlist.get_number(NV_LISTEN_PORT).unwrap_or_default(); - let private_key = nvlist - .get_binary(NV_PRIVATE_KEY) - .and_then(|value| (*value).try_into().ok()); - - let mut peers = HashMap::new(); - if let Some(peer_array) = nvlist.get_nvlist_array(NV_PEERS) { - for peer_list in peer_array { - if let Some(peer) = Peer::try_from_nvlist(peer_list) { - peers.insert(peer.public_key.clone(), peer); - } - } - } - - Self { - listen_port: listen_port as u16, - private_key, - fwmark: nvlist.get_number(NV_FWMARK).map(|num| num as u32), - peers, - } - } -} - -impl<'a> Host { - #[must_use] - fn as_nvlist(&'a self) -> NvList<'a> { - let mut nvlist = NvList::new(); - - nvlist.append_number(NV_LISTEN_PORT, u64::from(self.listen_port)); - if let Some(private_key) = self.private_key.as_ref() { - nvlist.append_binary(NV_PRIVATE_KEY, private_key.as_slice()); - } - if let Some(fwmark) = self.fwmark { - nvlist.append_number(NV_FWMARK, u64::from(fwmark)); - } - - nvlist.append_bool(NV_REPLACE_PEERS, true); - if !self.peers.is_empty() { - let peers = self.peers.values().map(Peer::as_nvlist).collect(); - nvlist.append_nvlist_array(NV_PEERS, peers); - } - - nvlist - } -} - -impl Peer { - #[must_use] - fn try_from_nvlist(nvlist: &NvList) -> Option { - if let Some(public_key) = nvlist - .get_binary(NV_PUBLIC_KEY) - .and_then(|value| (*value).try_into().ok()) - { - let preshared_key = nvlist - .get_binary(NV_PRESHARED_KEY) - .and_then(|value| (*value).try_into().ok()); - let mut allowed_ips = Vec::new(); - if let Some(ip_array) = nvlist.get_nvlist_array(NV_ALLOWED_IPS) { - for ip_list in ip_array { - if let Some(ip) = IpAddrMask::try_from_nvlist(ip_list) { - allowed_ips.push(ip); - } - } - } - - Some(Self { - public_key, - preshared_key, - protocol_version: None, - endpoint: nvlist.get_binary(NV_ENDPOINT).and_then(unpack_sockaddr), - last_handshake: nvlist - .get_binary(NV_LAST_HANDSHAKE) - .and_then(unpack_timespec), - tx_bytes: nvlist.get_number(NV_TX_BYTES).unwrap_or_default(), - rx_bytes: nvlist.get_number(NV_RX_BYTES).unwrap_or_default(), - persistent_keepalive_interval: nvlist - .get_number(NV_KEEPALIVE_INTERVAL) - .map(|value| value as u16), - allowed_ips, - }) - } else { - None - } - } -} - -impl<'a> Peer { - #[must_use] - fn as_nvlist(&'a self) -> NvList<'a> { - let mut nvlist = NvList::new(); - - nvlist.append_binary(NV_PUBLIC_KEY, self.public_key.as_slice()); - if let Some(preshared_key) = self.preshared_key.as_ref() { - nvlist.append_binary(NV_PRESHARED_KEY, preshared_key.as_slice()); - } - if let Some(endpoint) = self.endpoint.as_ref() { - nvlist.append_bytes(NV_ENDPOINT, pack_sockaddr(endpoint)); - } - if let Some(last_handshake) = self.last_handshake.as_ref() { - nvlist.append_bytes(NV_LAST_HANDSHAKE, pack_timespec(last_handshake)); - } - nvlist.append_number(NV_TX_BYTES, self.tx_bytes); - nvlist.append_number(NV_RX_BYTES, self.rx_bytes); - if let Some(keepalive_interval) = self.persistent_keepalive_interval { - nvlist.append_number(NV_KEEPALIVE_INTERVAL, u64::from(keepalive_interval)); - } - - nvlist.append_bool(NV_REPLACE_ALLOWED_IPS, true); - if !self.allowed_ips.is_empty() { - let allowed_ips = self.allowed_ips.iter().map(IpAddrMask::as_nvlist).collect(); - nvlist.append_nvlist_array(NV_ALLOWED_IPS, allowed_ips); - } - - nvlist.append_nvlist_array_next(); - nvlist - } - - #[must_use] - fn as_nvlist_for_removal(&'a self) -> NvList<'a> { - let mut nvlist = NvList::new(); - - nvlist.append_binary(NV_PUBLIC_KEY, self.public_key.as_slice()); - nvlist.append_bool(NV_REMOVE, true); - - nvlist.append_nvlist_array_next(); - nvlist - } -} - -pub fn get_host(if_name: &str) -> Result { - let mut wg_data = WgDataIo::new(if_name); - wg_data.read_data()?; - - let mut nvlist = NvList::new(); - // FIXME: use proper error, here and above - nvlist - .unpack(wg_data.as_slice()) - .map_err(|_| WgIoError::MemAlloc)?; - - Ok(Host::from_nvlist(&nvlist)) -} - -pub fn set_host(if_name: &str, host: &Host) -> Result<(), WgIoError> { - let mut wg_data = WgDataIo::new(if_name); - - let nvlist = host.as_nvlist(); - // FIXME: use proper error, here and above - let mut buf = nvlist.pack().map_err(|_| WgIoError::MemAlloc)?; - - wg_data.wgd_data = buf.as_mut_ptr(); - wg_data.wgd_size = buf.len(); - wg_data.write_data() -} - -pub fn set_peer(if_name: &str, peer: &Peer) -> Result<(), WgIoError> { - let mut wg_data = WgDataIo::new(if_name); - - let mut nvlist = NvList::new(); - nvlist.append_nvlist_array(NV_PEERS, vec![peer.as_nvlist()]); - // FIXME: use proper error, here and above - let mut buf = nvlist.pack().map_err(|_| WgIoError::MemAlloc)?; - - wg_data.wgd_data = buf.as_mut_ptr(); - wg_data.wgd_size = buf.len(); - wg_data.write_data() -} - -pub fn delete_peer(if_name: &str, peer: &Peer) -> Result<(), WgIoError> { - let mut wg_data = WgDataIo::new(if_name); - - let mut nvlist = NvList::new(); - nvlist.append_nvlist_array(NV_PEERS, vec![peer.as_nvlist_for_removal()]); - // FIXME: use proper error, here and above - let mut buf = nvlist.pack().map_err(|_| WgIoError::MemAlloc)?; - - wg_data.wgd_data = buf.as_mut_ptr(); - wg_data.wgd_size = buf.len(); - wg_data.write_data() -} diff --git a/src/wireguard/bsd/nvlist.rs b/src/wireguard/bsd/nvlist.rs deleted file mode 100644 index e81f897e..00000000 --- a/src/wireguard/bsd/nvlist.rs +++ /dev/null @@ -1,825 +0,0 @@ -// https://github.com/freebsd/freebsd-src/tree/main/sys/contrib/libnv -// https://github.com/freebsd/freebsd-src/blob/main/sys/sys/nv.h -use std::{error::Error, ffi::CStr, fmt}; - -/// `NV_HEADER_SIZE` is for both: `nvlist_header` and `nvpair_header`. -const NV_HEADER_SIZE: usize = 19; -const NV_NAME_MAX: usize = 2048; -const NVLIST_HEADER_MAGIC: u8 = 0x6c; // 'l' -const NVLIST_HEADER_VERSION: u8 = 0; -// Public flags -// Perform case-insensitive lookups of provided names. -// const NV_FLAG_IGNORE_CASE: u8 = 1; -// Names don't have to be unique. -// const NV_FLAG_NO_UNIQUE: u8 = 2; -// Private flags -const NV_FLAG_BIG_ENDIAN: u8 = 0x80; -// const NV_FLAG_IN_ARRAY: u8 = 0x100; - -#[derive(Debug)] -#[repr(u8)] -pub enum NvType { - None, - Null, - Bool, - Number, - String, - NvList, - _Descriptor, - Binary, - BoolArray, - NumberArray, - StringArray, - NvListArray, - _DescriptorArray, - // must have a parent - NvListArrayNext = 254, - NvListAUp, -} - -impl From for NvType { - fn from(val: u8) -> Self { - match val { - 1 => Self::Null, - 2 => Self::Bool, - 3 => Self::Number, - 4 => Self::String, - 5 => Self::NvList, - 6 => Self::_Descriptor, - 7 => Self::Binary, - 8 => Self::BoolArray, - 9 => Self::NumberArray, - 10 => Self::StringArray, - 11 => Self::NvListArray, - 12 => Self::_DescriptorArray, - 254 => Self::NvListArrayNext, - 255 => Self::NvListAUp, - _ => Self::None, - } - } -} - -#[derive(Debug)] -pub enum NvValue<'a> { - Null, - Bool(bool), - Number(u64), - String(&'a str), - NvList(NvList<'a>), - _Descriptor, // not implemented - Binary(&'a [u8]), - Bytes(Vec), // similar to `Binary`, but owned - BoolArray(Vec), - NumberArray(Vec), - StringArray(Vec<&'a str>), - NvListArray(Vec>), - _DescriptorArray, // not implemented - NvListArrayNext, - // NvListAUp, -} - -impl<'a> NvValue<'a> { - /// Return number of bytes this value occupies when packed. - #[must_use] - pub fn byte_size(&self) -> usize { - match self { - Self::Null | Self::_Descriptor | Self::_DescriptorArray | Self::NvListArrayNext => 0, - Self::Bool(_) => 1, - Self::Number(_) => 8, - Self::String(string) => string.len() + 1, // +1 for NUL - Self::NvList(list) => list.byte_size(), // FIXME: not sure about this - Self::Binary(binary) => binary.len(), - Self::Bytes(bytes) => bytes.len(), - Self::BoolArray(array) => array.len(), - Self::NumberArray(array) => array.len() * 8, - Self::StringArray(array) => array.iter().fold(0, |size, el| size + el.len() + 1), - Self::NvListArray(array) => array.iter().fold(0, |size, el| size + el.byte_size()), - } - } - - #[must_use] - pub fn nv_type(&self) -> NvType { - match self { - Self::Null => NvType::Null, - Self::Bool(_) => NvType::Bool, - Self::Number(_) => NvType::Number, - Self::String(_) => NvType::String, - Self::NvList(_) => NvType::NvList, - Self::_Descriptor => NvType::_Descriptor, - Self::Binary(_) | Self::Bytes(_) => NvType::Binary, - Self::BoolArray(_) => NvType::BoolArray, - Self::NumberArray(_) => NvType::NumberArray, - Self::StringArray(_) => NvType::StringArray, - Self::NvListArray(_) => NvType::NvListArray, - Self::_DescriptorArray => NvType::_DescriptorArray, - Self::NvListArrayNext => NvType::NvListArrayNext, - } - } - - #[must_use] - pub fn number_of_items(&self) -> usize { - match self { - Self::BoolArray(v) => v.len(), - Self::NumberArray(v) => v.len(), - Self::StringArray(v) => v.len(), - Self::NvListArray(v) => v.len(), - _ => 0, // non-array - } - } -} - -#[derive(Debug)] -pub enum NvListError { - NameTooLong, - NotEnoughBytes, - WrongHeader, - WrongName, - WrongPair, - WrongPairData, -} - -impl Error for NvListError {} - -impl fmt::Display for NvListError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::NameTooLong => write!(f, "name is too long"), - Self::NotEnoughBytes => write!(f, "not enough bytes"), - Self::WrongHeader => write!(f, "wrong header"), - Self::WrongName => write!(f, "wrong name"), - Self::WrongPair => write!(f, "wrong name-value pair"), - Self::WrongPairData => write!(f, "wrong name-value pair data"), - } - } -} - -/// `NvList` is a name-value list. -/// It is meant to live shortly. Just build the list and serialize it to bytes. -type NameValue<'a> = (&'a str, NvValue<'a>); -#[derive(Debug)] -pub struct NvList<'a> { - items: Vec>, - is_big_endian: bool, -} - -impl<'a> Default for NvList<'a> { - fn default() -> Self { - Self::new() - } -} - -impl<'a> NvList<'a> { - /// Create new `NvList`. - #[must_use] - pub fn new() -> Self { - Self { - items: Vec::new(), - #[cfg(target_endian = "big")] - is_big_endian: true, - #[cfg(target_endian = "little")] - is_big_endian: false, - } - } - - /// Get value for a given `name`. - fn get(&self, name: &str) -> Option<&NvValue> { - self.items.iter().find(|(n, _)| n == &name).map(|(_, v)| v) - } - - /// Get value as `bool`. - // pub fn get_bool(&self, name: &str) -> Option { - // self.get(name).and_then(|value| match value { - // NvValue::Bool(boolean) => Some(*boolean), - // _ => None, - // }) - // } - - /// Get value as `u64`. - pub fn get_number(&self, name: &str) -> Option { - self.get(name).and_then(|value| match value { - NvValue::Number(number) => Some(*number), - _ => None, - }) - } - - /// Get value as `&str`. - // pub fn get_string(&self, name: &str) -> Option<&str> { - // self.get(name).and_then(|value| match value { - // NvValue::String(string) => Some(*string), - // _ => None, - // }) - // } - - /// Get value as `&[u8]`. - pub fn get_binary(&self, name: &str) -> Option<&[u8]> { - self.get(name).and_then(|value| match value { - NvValue::Binary(binary) => Some(*binary), - _ => None, - }) - } - - /// Get value as `Vec` - pub fn get_nvlist_array(&self, name: &str) -> Option<&[NvList]> { - self.get(name).and_then(|value| match value { - NvValue::NvListArray(array) => Some(array.as_slice()), - _ => None, - }) - } - - /// Append `Null` value to the list. - #[cfg(test)] - pub fn append_null(&mut self, name: &'a str) { - self.items.push((name, NvValue::Null)); - } - - /// Append `Bool` value to the list. - pub fn append_bool(&mut self, name: &'a str, boolean: bool) { - self.items.push((name, NvValue::Bool(boolean))); - } - - /// Append `Number` value to the list. - pub fn append_number(&mut self, name: &'a str, number: u64) { - self.items.push((name, NvValue::Number(number))); - } - - /// Append `String` value to the list. - // pub fn append_string(&mut self, name: &'a str, string: &'a str) { - // self.items.push((name, NvValue::String(string))); - // } - - /// Append `Binary` value to the list. - pub fn append_binary(&mut self, name: &'a str, binary: &'a [u8]) { - self.items.push((name, NvValue::Binary(binary))); - } - - /// Append `Bytes` value to the list. - pub fn append_bytes(&mut self, name: &'a str, bytes: Vec) { - self.items.push((name, NvValue::Bytes(bytes))); - } - - /// Append `NvListArray` value to the list. - pub fn append_nvlist_array(&mut self, name: &'a str, array: Vec>) { - self.items.push((name, NvValue::NvListArray(array))); - } - - /// Append `NvListArrayNext` value to the list. - pub fn append_nvlist_array_next(&mut self) { - self.items.push(("", NvValue::NvListArrayNext)); - } - - fn load_u16(&self, buf: &[u8]) -> Result { - if let Ok(bytes) = <[u8; 2]>::try_from(buf) { - Ok(if self.is_big_endian { - u16::from_be_bytes(bytes) - } else { - u16::from_le_bytes(bytes) - }) - } else { - Err(NvListError::NotEnoughBytes) - } - } - - fn load_u64(&self, buf: &[u8]) -> Result { - if let Ok(bytes) = <[u8; 8]>::try_from(buf) { - Ok(if self.is_big_endian { - u64::from_be_bytes(bytes) - } else { - u64::from_le_bytes(bytes) - }) - } else { - Err(NvListError::NotEnoughBytes) - } - } - - fn load_name(buf: &'a [u8]) -> Result<&'a str, NvListError> { - CStr::from_bytes_with_nul(buf) - .map_err(|_| NvListError::WrongName)? - .to_str() - .map_err(|_| NvListError::WrongName) - } - - fn load_string(buf: &'a [u8]) -> Result<&'a str, NvListError> { - CStr::from_bytes_until_nul(buf) - .map_err(|_| NvListError::WrongPairData)? - .to_str() - .map_err(|_| NvListError::WrongPairData) - } - - fn store_u16(&self, value: u16, buf: &mut Vec) { - buf.extend_from_slice(&if self.is_big_endian { - value.to_be_bytes() - } else { - value.to_le_bytes() - }); - } - - fn store_u64(&self, value: u64, buf: &mut Vec) { - buf.extend_from_slice(&if self.is_big_endian { - value.to_be_bytes() - } else { - value.to_le_bytes() - }); - } - - /// Return number of bytes this list occupies when packed. - #[must_use] - fn byte_size(&self) -> usize { - let mut size = NV_HEADER_SIZE; - for (name, value) in &self.items { - size += NV_HEADER_SIZE + name.len() + 1; // +1 for NUL - size += value.byte_size(); - } - - size - } - - /// Pack name-value list to binary representation. - pub fn pack(&self) -> Result, NvListError> { - let size = self.byte_size(); - let mut buf = Vec::with_capacity(size); - self.pack_with_size(&mut buf, size)?; - Ok(buf) - } - - /// Pack nvlist with pre-calculated buffer size. - /// This is needed for list arrays where lists have fishy size. - fn pack_with_size( - &self, - buf: &mut Vec, - mut byte_size: usize, - ) -> Result { - // pack header - buf.push(NVLIST_HEADER_MAGIC); - buf.push(NVLIST_HEADER_VERSION); - buf.push(if self.is_big_endian { - NV_FLAG_BIG_ENDIAN - } else { - 0 - }); - self.store_u64(0, buf); - byte_size -= NV_HEADER_SIZE; - self.store_u64(byte_size as u64, buf); - - for (name, value) in &self.items { - buf.push(value.nv_type() as u8); - // name length - let name_len = name.len() + 1; - if name_len > NV_NAME_MAX { - return Err(NvListError::NameTooLong); - } - self.store_u16(name_len as u16, buf); - - let value_size = match value { - NvValue::NvListArray(_) => 0, - _ => value.byte_size(), - }; - self.store_u64(value_size as u64, buf); - - let number_of_items = value.number_of_items(); - self.store_u64(number_of_items as u64, buf); - - // name - buf.extend_from_slice(name.as_bytes()); - buf.push(0); // NUL - - byte_size -= NV_HEADER_SIZE + name_len + value_size; - - match value { - NvValue::Bool(boolean) => buf.push(u8::from(*boolean)), - NvValue::Number(number) => self.store_u64(*number, buf), - NvValue::String(string) => { - buf.extend_from_slice(string.as_bytes()); - buf.push(0); // NUL - } - NvValue::Binary(bytes) => buf.extend_from_slice(bytes), - NvValue::Bytes(bytes) => buf.extend_from_slice(bytes.as_slice()), - NvValue::BoolArray(array) => { - array.iter().for_each(|boolean| buf.push((*boolean).into())); - } - NvValue::NumberArray(array) => { - array.iter().for_each(|number| self.store_u64(*number, buf)); - } - NvValue::StringArray(array) => { - for string in array.iter() { - buf.extend_from_slice(string.as_bytes()); - buf.push(0); // NUL - } - } - NvValue::NvListArray(nvlist_array) => { - for nvlist in nvlist_array { - byte_size = nvlist.pack_with_size(buf, byte_size)?; - } - } - NvValue::Null | NvValue::NvListArrayNext => (), - _ => unimplemented!(), - } - } - - Ok(byte_size) - } - - /// Unpack binary representation of name-value list. - /// - /// # Errors - /// Return `Err` when buffer contains invalid data. - pub fn unpack(&mut self, buf: &'a [u8]) -> Result { - let length = buf.len(); - // check header - if length < NV_HEADER_SIZE { - return Err(NvListError::NotEnoughBytes); - } - if buf[0] != NVLIST_HEADER_MAGIC || buf[1] != NVLIST_HEADER_VERSION { - return Err(NvListError::WrongHeader); - } - self.is_big_endian = buf[2] & NV_FLAG_BIG_ENDIAN != 0; - - let _descriptors = self.load_u64(&buf[3..11])?; - let size = self.load_u64(&buf[11..19])? as usize; - - // check total size - if length < NV_HEADER_SIZE + size { - return Err(NvListError::NotEnoughBytes); - } - - let mut index = NV_HEADER_SIZE; - while index < size { - match self.nvpair_unpack(&buf[index..]) { - Ok((count, last_element)) => { - index += count; - if last_element { - break; - } - } - Err(err) => return Err(err), - } - } - - Ok(index) - } - - /// Unpack binary name-value pair and return number of consumed bytes and - /// a flag indicating if array processing should be stopped (`true`), or not (`false`). - /// - /// # Errors - /// Return `Err` when buffer contains invalid data. - fn nvpair_unpack(&mut self, buf: &'a [u8]) -> Result<(usize, bool), NvListError> { - let pair_type = NvType::from(buf[0]); - let name_size = self.load_u16(&buf[1..3])? as usize; - if name_size > NV_NAME_MAX { - return Err(NvListError::WrongPair); - } - let size = self.load_u64(&buf[3..11])? as usize; - // Used only for array types. - let mut item_count = self.load_u64(&buf[11..NV_HEADER_SIZE])?; - let mut index = NV_HEADER_SIZE + name_size; - - let name = Self::load_name(&buf[NV_HEADER_SIZE..index])?; - let mut last_element = false; - - let value = match pair_type { - NvType::Null => { - if size != 0 { - return Err(NvListError::WrongPairData); - } - NvValue::Null - } - NvType::Bool => { - if size != 1 { - return Err(NvListError::WrongPairData); - } - let boolean = buf[index] != 0; - NvValue::Bool(boolean) - } - NvType::Number => { - if size != 8 { - return Err(NvListError::WrongPairData); - } - let number = self.load_u64(&buf[index..index + size])?; - NvValue::Number(number) - } - NvType::String => { - if size == 0 { - return Err(NvListError::WrongPairData); - } - let string = Self::load_string(&buf[index..index + size])?; - // TODO: if string.len() + 1 != size {} - NvValue::String(string) - } - NvType::NvList => { - // TODO: read list elements - NvValue::NvList(NvList::new()) - } - NvType::Binary => { - if size == 0 { - return Err(NvListError::WrongPairData); - } - let binary = &buf[index..index + size]; - NvValue::Binary(binary) - } - NvType::BoolArray => { - if size == 0 { - return Err(NvListError::WrongPairData); - } - let array = buf[index..index + size] - .iter() - .map(|byte| *byte != 0) - .collect(); - NvValue::BoolArray(array) - } - NvType::NumberArray => { - if size == 0 { - return Err(NvListError::WrongPairData); - } - let mut array = Vec::with_capacity(item_count as usize); - for chunk in buf[index..index + size].chunks(8) { - array.push(self.load_u64(chunk)?); - } - NvValue::NumberArray(array) - } - NvType::StringArray => { - if size == 0 { - return Err(NvListError::WrongPairData); - } - let mut array = Vec::with_capacity(item_count as usize); - let mut i = index; - let mut s = size; - for _ in 0..item_count { - let string = Self::load_string(&buf[i..i + s])?; - array.push(string); - i += string.len() + 1; - s -= string.len() + 1; - } - NvValue::StringArray(array) - } - NvType::NvListArray => { - if size != 0 || item_count == 0 { - return Err(NvListError::WrongPairData); - } - let mut array = Vec::with_capacity(item_count as usize); - while item_count != 0 { - let mut list = NvList::new(); - index += list.unpack(&buf[index..])?; - array.push(list); - item_count -= 1; - } - NvValue::NvListArray(array) - } - // This is a nasty hack: this type means we've reach the last item in the array. - // Stop processing the array regardless of `nvlh_size` in (nested) NvList header. - NvType::NvListArrayNext => { - if size != 0 || item_count != 0 { - return Err(NvListError::WrongPairData); - } - last_element = true; - NvValue::NvListArrayNext - } - _ => unimplemented!(), - }; - self.items.push((name, value)); - - Ok((index + size, last_element)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[rustfmt::skip] - static TEST_DATA: [u8; 81] = [ - // *** nvlist_header (19 bytes) - 108, // nvlh_magic - 0, // nvlh_version - 0, // nvlh_flags - 0, 0, 0, 0, 0, 0, 0, 0, // nvlh_descriptors - 39 + 23, 0, 0, 0, 0, 0, 0, 0, // nvlh_size - // *** data (nvlh_size bytes) - // *** nvpair_header (19 bytes) - 3, // nvph_type = NV_TYPE_NUMBER - 12, 0, // nvph_namesize (incl. NUL) - 8, 0, 0, 0, 0, 0, 0, 0, // nvph_datasize - 0, 0, 0, 0, 0, 0, 0, 0, // nvph_nitems - 108, 105, 115, 116, 101, 110, 45, 112, 111, 114, 116, 0, // "listen-port\0" - 57, 48, 0, 0, 0, 0, 0, 0, // 12345 - - 1, // nvph_type = NV_TYPE_NULL - 4, 0, // nvph_namesize (incl. NUL) - 0, 0, 0, 0, 0, 0, 0, 0, // nvph_datasize - 0, 0, 0, 0, 0, 0, 0, 0, // nvph_nitems - 'n' as u8, 'u' as u8, 'l' as u8, 0, - ]; - - #[test] - fn unpack() { - let mut nvlist = NvList::new(); - nvlist.unpack(&TEST_DATA).unwrap(); - - let buf = nvlist.pack().unwrap(); - - let mut nvlist = NvList::new(); - nvlist.unpack(&buf).unwrap(); - - assert_eq!(TEST_DATA.as_slice(), buf.as_slice()); - } - - #[test] - fn pack() { - let mut nvlist = NvList::new(); - nvlist.append_number("listen-port", 12345); - nvlist.append_null("nul"); - - let buf = nvlist.pack().unwrap(); - assert_eq!(TEST_DATA.as_slice(), buf.as_slice()); - } - - #[test] - fn bool_array() { - #[rustfmt::skip] - let data = [ - 108,0,0, - 0,0,0,0,0,0,0,0, - 27,0,0,0,0,0,0,0, - 8,4,0, // NV_TYPE_BOOL_ARRAY - 4,0,0,0,0,0,0,0, // size - 4,0,0,0,0,0,0,0, // items - 98,117,108,0, // "bul\0" - 1,0,0,1, - ]; - let mut nvlist = NvList::new(); - nvlist.unpack(&data).unwrap(); - - let buf = nvlist.pack().unwrap(); - assert_eq!(data.as_slice(), buf.as_slice()); - } - - #[test] - fn number_array() { - #[rustfmt::skip] - let data = [ - 108,0,0, - 0,0,0,0,0,0,0,0, - 40,0,0,0,0,0,0,0, - 9,5,0, - 16,0,0,0,0,0,0,0, - 2,0,0,0,0,0,0,0, - 110,117,109,115,0, // "nums\0" - 68,51,34,17,0,0,0,0, 136,119,102,85,0,0,0,0, - ]; - let mut nvlist = NvList::new(); - nvlist.unpack(&data).unwrap(); - - let buf = nvlist.pack().unwrap(); - assert_eq!(data.as_slice(), buf.as_slice()); - } - - #[test] - fn string_array() { - #[rustfmt::skip] - let data = [ - 108,0,0, - 0,0,0,0,0,0,0,0, - 42,0,0,0,0,0,0,0, - 10,6,0, - 17,0,0,0,0,0,0,0, - 3,0,0,0,0,0,0,0, - 110,97,109,101,115,0, - 83,116,117,97,114,116,0, 75,101,118,105,110,0, 66,111,98,0, - ]; - let mut nvlist = NvList::new(); - nvlist.unpack(&data).unwrap(); - - let buf = nvlist.pack().unwrap(); - assert_eq!(data.as_slice(), buf.as_slice()); - } - - #[test] - fn two_peers() { - #[rustfmt::skip] - let data = [ - // nvlist - 108, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 121, 3, 0, 0, 0, 0, 0, 0, - // NV_TYPE_NUMBER - 3, 12, 0, - 8, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 108, 105, 115, 116, 101, 110, 45, 112, 111, 114, 116, 0, // "listen-port\0" - 133, 28, 0, 0, 0, 0, 0, 0, - // NV_TYPE_BINARY - 7, 11, 0, - 32, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 112, 117, 98, 108, 105, 99, 45, 107, 101, 121, 0, // "public-key\0" - 77, 206, 217, 13, 140, 115, 50, 63, 20, 85, 182, 151, 82, 219, 246, 40, 224, 195, 180, 210, 240, 16, 47, 189, 89, 167, 240, 131, 81, 17, 68, 111, - // NV_TYPE_NUMBER - 7, 12, 0, - 32, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 112, 114, 105, 118, 97, 116, 101, 45, 107, 101, 121, 0, // "private-key\0" - 184, 70, 130, 139, 240, 172, 115, 210, 42, 253, 145, 16, 84, 163, 217, 206, 219, 207, 194, 29, 250, 97, 48, 232, 184, 78, 19, 62, 194, 45, 133, 77, - // NV_TYPE_NVLIST_ARRAY - 11, 6, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 2, 0, 0, 0, 0, 0, 0, 0, - 112, 101, 101, 114, 115, 0, // "peers\0" - // nvlist - 108, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 169, 2, 0, 0, 0, 0, 0, 0, - // NV_TYPE_BINARY - 7, 11, 0, - 32, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 112, 117, 98, 108, 105, 99, 45, 107, 101, 121, 0, // "public-key\0" - 220, 98, 132, 114, 211, 195, 157, 56, 63, 135, 95, 253, 123, 132, 59, 218, 35, 120, 55, 169, 156, 165, 223, 184, 140, 111, 142, 164, 145, 107, 167, 17, - // NV_TYPE_BINARY - 7, 14, 0, - 32, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 112, 114, 101, 115, 104, 97, 114, 101, 100, 45, 107, 101, 121, 0, // "preshared-key\0" - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - // NV_TYPE_BINARY - 7, 20, 0, - 16, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 108, 97, 115, 116, 45, 104, 97, 110, 100, 115, 104, 97, 107, 101, 45, 116, 105, 109, 101, 0, // "last-handshake-time\0" - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - // NV_TYPE_NUMBER - 3, 30, 0, - 8, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 112, 101, 114, 115, 105, 115, 116, 101, 110, 116, 45, 107, 101, 101, 112, 97, 108, 105, 118, 101, 45, 105, 110, 116, 101, 114, 118, 97, 108, 0, // "persistent-keepalive-interval\0" - 0, 0, 0, 0, 0, 0, 0, 0, - // NV_TYPE_NUMBER - 3, 9, 0, - 8, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 114, 120, 45, 98, 121, 116, 101, 115, 0, // "rx-bytes\0" - 0, 0, 0, 0, 0, 0, 0, 0, - // NV_TYPE_NUMBER - 3, 9, 0, - 8, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 116, 120, 45, 98, 121, 116, 101, 115, 0, // "tx-bytes\0" - 0, 0, 0, 0, 0, 0, 0, 0, - // NV_TYPE_NVLIST_ARRAY_NEXT - 254, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, - // nvlist - 108, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 75, 1, 0, 0, 0, 0, 0, 0, - // NV_TYPE_BINARY - 7, 11, 0, - 32, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 112, 117, 98, 108, 105, 99, 45, 107, 101, 121, 0, // "public-key\0" - 60, 195, 52, 243, 24, 229, 218, 5, 142, 193, 30, 194, 241, 176, 169, 221, 121, 39, 172, 116, 158, 67, 46, 115, 119, 155, 107, 159, 128, 201, 79, 54, - // NV_TYPE_BINARY - 7, 14, 0, - 32, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 112, 114, 101, 115, 104, 97, 114, 101, 100, 45, 107, 101, 121, 0, // "preshared-key\0" - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - // NV_TYPE_BINARY - 7, 20, 0, - 16, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 108, 97, 115, 116, 45, 104, 97, 110, 100, 115, 104, 97, 107, 101, 45, 116, 105, 109, 101, 0, // "last-handshake-time\0" - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - // NV_TYPE_NUMBER - 3, 30, 0, - 8, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 112, 101, 114, 115, 105, 115, 116, 101, 110, 116, 45, 107, 101, 101, 112, 97, 108, 105, 118, 101, 45, 105, 110, 116, 101, 114, 118, 97, 108, 0, // "persistent-keepalive-interval\0" - 0, 0, 0, 0, 0, 0, 0, 0, - // NV_TYPE_NUMBER - 3, 9, 0, - 8, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 114, 120, 45, 98, 121, 116, 101, 115, 0, // "rx-bytes\0" - 0, 0, 0, 0, 0, 0, 0, 0, - // NV_TYPE_NUMBER - 3, 9, 0, - 8, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 116, 120, 45, 98, 121, 116, 101, 115, 0, // "tx-bytes\0" - 0, 0, 0, 0, 0, 0, 0, 0, - // NV_TYPE_NVLIST_ARRAY_NEXT - 254, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0]; - let mut nvlist = NvList::new(); - nvlist.unpack(&data).unwrap(); - - let buf = nvlist.pack().unwrap(); - assert_eq!(data.as_slice(), buf.as_slice()); - - let mut nvlist = NvList::new(); - nvlist.unpack(&buf).unwrap(); - } -} diff --git a/src/wireguard/bsd/sockaddr.rs b/src/wireguard/bsd/sockaddr.rs deleted file mode 100644 index d83394a5..00000000 --- a/src/wireguard/bsd/sockaddr.rs +++ /dev/null @@ -1,148 +0,0 @@ -//! Convert binary `sockaddr_in` or `sockaddr_in6` (see netinet/in.h) to `SocketAddr`. -use std::{ - mem::size_of, - net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, -}; - -use super::{cast_bytes, cast_ref}; - -const AF_INET: u8 = 2; // IPv4 -const AF_INET6: u8 = 30; // IPv6 -const SA_IN_SIZE: usize = size_of::(); -const SA_IN6_SIZE: usize = size_of::(); - -// netinet/in.h -#[repr(C)] -struct SockAddrIn { - len: u8, - family: u8, - port: u16, - addr: [u8; 4], - zero: [u8; 8], -} - -impl From<&SockAddrIn> for SocketAddr { - fn from(sa: &SockAddrIn) -> Self { - Self::V4(SocketAddrV4::new( - Ipv4Addr::from(sa.addr), - u16::from_be(sa.port), - )) - } -} - -impl From<&SocketAddrV4> for SockAddrIn { - fn from(sa: &SocketAddrV4) -> Self { - Self { - len: SA_IN_SIZE as u8, - family: AF_INET, - port: sa.port().to_be(), - addr: sa.ip().octets(), - zero: [0u8; 8], - } - } -} - -// netinet6/in6.h -#[repr(C)] -struct SockAddrIn6 { - len: u8, - family: u8, - port: u16, - flowinfo: u32, - addr: [u8; 16], - scope_id: u32, -} - -impl From<&SockAddrIn6> for SocketAddr { - fn from(sa: &SockAddrIn6) -> Self { - Self::V6(SocketAddrV6::new( - Ipv6Addr::from(sa.addr), - u16::from_be(sa.port), - u32::from_be(sa.flowinfo), - u32::from_be(sa.scope_id), - )) - } -} - -impl From<&SocketAddrV6> for SockAddrIn6 { - fn from(sa: &SocketAddrV6) -> Self { - Self { - len: SA_IN6_SIZE as u8, - family: AF_INET6, - port: sa.port().to_be(), - flowinfo: sa.flowinfo().to_be(), - addr: sa.ip().octets(), - scope_id: sa.scope_id().to_be(), - } - } -} - -pub(super) fn pack_sockaddr(sockaddr: &SocketAddr) -> Vec { - match sockaddr { - SocketAddr::V4(sockaddr_v4) => { - let sockaddr_in: SockAddrIn = sockaddr_v4.into(); - let bytes = unsafe { cast_bytes(&sockaddr_in) }; - Vec::from(bytes) - } - SocketAddr::V6(sockaddr_v6) => { - let sockaddr_in6: SockAddrIn6 = sockaddr_v6.into(); - let bytes = unsafe { cast_bytes(&sockaddr_in6) }; - Vec::from(bytes) - } - } -} - -pub(super) fn unpack_sockaddr(buf: &[u8]) -> Option { - match buf.len() { - SA_IN_SIZE => { - let sockaddr_in = unsafe { cast_ref::(buf) }; - // sanity checks - if sockaddr_in.len == SA_IN_SIZE as u8 && sockaddr_in.family == AF_INET { - Some(sockaddr_in.into()) - } else { - None - } - } - - SA_IN6_SIZE => { - let sockaddr_in6 = unsafe { cast_ref::(buf) }; - // sanity checks - if sockaddr_in6.len == SA_IN6_SIZE as u8 && sockaddr_in6.family == AF_INET6 { - Some(sockaddr_in6.into()) - } else { - None - } - } - - _ => None, - } -} - -#[cfg(test)] -mod tests { - use std::net::IpAddr; - - use super::*; - - #[test] - fn ip4() { - let buf = [16, 2, 28, 133, 192, 168, 12, 34, 0, 0, 0, 0, 0, 0, 0, 0]; - let addr = unpack_sockaddr(&buf).unwrap(); - assert_eq!(addr.port(), 7301); - assert_eq!(addr.ip(), IpAddr::V4(Ipv4Addr::new(192, 168, 12, 34))); - } - - #[test] - fn ip6() { - let buf = [ - 28, 30, 28, 133, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 192, 168, 12, 34, - 0, 0, 0, 0, - ]; - let addr = unpack_sockaddr(&buf).unwrap(); - assert_eq!(addr.port(), 7301); - assert_eq!( - addr.ip(), - IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc0a8, 0x0c22)) - ); - } -} diff --git a/src/wireguard/bsd/timespec.rs b/src/wireguard/bsd/timespec.rs deleted file mode 100644 index 64d6c7b7..00000000 --- a/src/wireguard/bsd/timespec.rs +++ /dev/null @@ -1,57 +0,0 @@ -use std::{ - mem::size_of, - time::{Duration, SystemTime}, -}; - -use super::{cast_bytes, cast_ref}; - -#[repr(C)] -struct TimeSpec { - tv_sec: i64, - tv_nsec: i64, -} - -impl TimeSpec { - fn duration(&self) -> Duration { - Duration::from_secs(self.tv_sec as u64) + Duration::from_nanos(self.tv_nsec as u64) - } -} - -impl From<&TimeSpec> for SystemTime { - fn from(time_spec: &TimeSpec) -> SystemTime { - SystemTime::UNIX_EPOCH + time_spec.duration() - } -} - -impl From<&SystemTime> for TimeSpec { - fn from(system_time: &SystemTime) -> Self { - if let Ok(duration) = system_time.duration_since(SystemTime::UNIX_EPOCH) { - Self { - tv_sec: duration.as_secs() as i64, - tv_nsec: duration.as_nanos() as i64, - } - } else { - Self { - tv_sec: 0, - tv_nsec: 0, - } - } - } -} - -pub(super) fn pack_timespec(system_time: &SystemTime) -> Vec { - let timespec: TimeSpec = system_time.into(); - let bytes = unsafe { cast_bytes(×pec) }; - Vec::from(bytes) -} - -pub(super) fn unpack_timespec(buf: &[u8]) -> Option { - const TS_SIZE: usize = size_of::(); - match buf.len() { - TS_SIZE => { - let ts = unsafe { cast_ref::(buf) }; - Some(ts.into()) - } - _ => None, - } -} diff --git a/src/wireguard/bsd/wgio.rs b/src/wireguard/bsd/wgio.rs deleted file mode 100644 index 874ee9a0..00000000 --- a/src/wireguard/bsd/wgio.rs +++ /dev/null @@ -1,122 +0,0 @@ -use std::{ - alloc::{alloc, dealloc, Layout}, - error::Error, - fmt, - os::fd::RawFd, - ptr::null_mut, - slice::from_raw_parts, -}; - -use nix::{errno::Errno, ioctl_readwrite, sys::socket}; - -#[derive(Debug)] -pub enum WgIoError { - MemAlloc, - ReadIo(Errno), - WriteIo(Errno), -} - -impl Error for WgIoError {} - -impl fmt::Display for WgIoError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::MemAlloc => write!(f, "memory allocation"), - Self::ReadIo(errno) => write!(f, "read error {errno}"), - Self::WriteIo(errno) => write!(f, "write error {errno}"), - } - } -} - -// FIXME: `WgDataIo` has to be declared as public -ioctl_readwrite!(write_wireguard_data, b'i', 210, WgDataIo); -ioctl_readwrite!(read_wireguard_data, b'i', 211, WgDataIo); - -/// Create socket for ioctl communication. -fn get_dgram_socket() -> Result { - socket::socket( - socket::AddressFamily::Inet, - socket::SockType::Datagram, - socket::SockFlag::empty(), - None, - ) -} - -/// Represent `struct wg_data_io` defined in -/// https://github.com/freebsd/freebsd-src/blob/main/sys/dev/wg/if_wg.h -#[repr(C)] -pub struct WgDataIo { - pub(super) wgd_name: [u8; 16], - pub(super) wgd_data: *mut u8, // *void - pub(super) wgd_size: usize, -} - -impl WgDataIo { - /// Create `WgDataIo` without data buffer. - #[must_use] - pub fn new(if_name: &str) -> Self { - let mut wgd_name = [0u8; 16]; - if_name - .bytes() - .take(15) - .enumerate() - .for_each(|(i, b)| wgd_name[i] = b); - Self { - wgd_name, - wgd_data: null_mut(), - wgd_size: 0, - } - } - - /// Allocate data buffer. - fn alloc_data(&mut self) -> Result<(), WgIoError> { - if self.wgd_data.is_null() { - if let Ok(layout) = Layout::array::(self.wgd_size) { - unsafe { - self.wgd_data = alloc(layout); - } - return Ok(()); - } - } - Err(WgIoError::MemAlloc) - } - - /// Return buffer as slice. - pub(super) fn as_slice<'a>(&self) -> &'a [u8] { - unsafe { from_raw_parts(self.wgd_data, self.wgd_size) } - } - - pub(super) fn read_data(&mut self) -> Result<(), WgIoError> { - let socket = get_dgram_socket().map_err(WgIoError::ReadIo)?; - unsafe { - // First do ioctl with empty `wg_data` to obtain buffer size. - read_wireguard_data(socket, self).map_err(WgIoError::ReadIo)?; - // Allocate buffer. - self.alloc_data()?; - // Second call to ioctl with allocated buffer. - read_wireguard_data(socket, self).map_err(WgIoError::ReadIo)?; - } - - Ok(()) - } - - pub(super) fn write_data(&mut self) -> Result<(), WgIoError> { - let socket = get_dgram_socket().map_err(WgIoError::WriteIo)?; - unsafe { - write_wireguard_data(socket, self).map_err(WgIoError::WriteIo)?; - } - - Ok(()) - } -} - -impl Drop for WgDataIo { - fn drop(&mut self) { - if self.wgd_size != 0 { - let layout = Layout::array::(self.wgd_size).expect("Bad layout"); - unsafe { - dealloc(self.wgd_data, layout); - } - } - } -} diff --git a/src/wireguard/host.rs b/src/wireguard/host.rs deleted file mode 100644 index 9425fb3d..00000000 --- a/src/wireguard/host.rs +++ /dev/null @@ -1,467 +0,0 @@ -use super::{IpAddrMask, Key}; -use crate::proto; -#[cfg(target_os = "linux")] -use netlink_packet_wireguard::{ - constants::{WGDEVICE_F_REPLACE_PEERS, WGPEER_F_REMOVE_ME, WGPEER_F_REPLACE_ALLOWEDIPS}, - nlas::{WgAllowedIpAttrs, WgDeviceAttrs, WgPeer, WgPeerAttrs}, -}; -use std::{ - collections::HashMap, - io::{self, BufRead, BufReader, Read}, - net::SocketAddr, - str::FromStr, - time::{Duration, SystemTime}, -}; - -#[derive(Debug, Default, PartialEq, Clone)] -pub struct Peer { - pub public_key: Key, - pub(super) preshared_key: Option, - pub(super) protocol_version: Option, - pub(super) endpoint: Option, - pub last_handshake: Option, - pub(super) tx_bytes: u64, - pub(super) rx_bytes: u64, - pub(super) persistent_keepalive_interval: Option, - pub allowed_ips: Vec, -} - -impl Peer { - #[must_use] - pub fn new(public_key: Key) -> Self { - Self { - public_key, - preshared_key: None, - protocol_version: None, - endpoint: None, - last_handshake: None, - tx_bytes: 0, - rx_bytes: 0, - persistent_keepalive_interval: None, - allowed_ips: Vec::new(), - } - } - - pub fn set_allowed_ips(&mut self, allowed_ips: Vec) { - self.allowed_ips = allowed_ips; - } - - #[must_use] - pub fn as_uapi_update(&self) -> String { - let mut output = format!("public_key={}\n", self.public_key.to_lower_hex()); - if let Some(key) = &self.preshared_key { - output.push_str("preshared_key="); - output.push_str(&key.to_lower_hex()); - output.push('\n'); - } - if let Some(endpoint) = &self.endpoint { - output.push_str("endpoint="); - output.push_str(&endpoint.to_string()); - output.push('\n'); - } - if let Some(interval) = &self.persistent_keepalive_interval { - output.push_str("persistent_keepalive_interval="); - output.push_str(&interval.to_string()); - output.push('\n'); - } - output.push_str("replace_allowed_ips=true\n"); - for allowed_ip in &self.allowed_ips { - output.push_str("allowed_ip="); - output.push_str(&allowed_ip.to_string()); - output.push('\n'); - } - - output - } - - #[must_use] - pub fn as_uapi_remove(&self) -> String { - format!( - "public_key={}\nremove=true\n", - self.public_key.to_lower_hex() - ) - } -} - -#[cfg(target_os = "linux")] -impl Peer { - #[must_use] - pub fn from_nlas(nlas: &[WgPeerAttrs]) -> Self { - let mut peer = Self::default(); - - for nla in nlas { - match nla { - WgPeerAttrs::PublicKey(value) => peer.public_key = Key::new(*value), - WgPeerAttrs::PresharedKey(value) => peer.preshared_key = Some(Key::new(*value)), - WgPeerAttrs::Endpoint(value) => peer.endpoint = Some(*value), - WgPeerAttrs::PersistentKeepalive(value) => { - peer.persistent_keepalive_interval = Some(*value); - } - WgPeerAttrs::LastHandshake(value) => peer.last_handshake = Some(*value), - WgPeerAttrs::RxBytes(value) => peer.rx_bytes = *value, - WgPeerAttrs::TxBytes(value) => peer.tx_bytes = *value, - WgPeerAttrs::AllowedIps(nlas) => { - for nla in nlas { - let ip = nla.iter().find_map(|nla| match nla { - WgAllowedIpAttrs::IpAddr(ip) => Some(*ip), - _ => None, - }); - let cidr = nla.iter().find_map(|nla| match nla { - WgAllowedIpAttrs::Cidr(cidr) => Some(*cidr), - _ => None, - }); - if let (Some(ip), Some(cidr)) = (ip, cidr) { - peer.allowed_ips.push(IpAddrMask::new(ip, cidr)); - } - } - } - _ => (), - } - } - - peer - } - - #[must_use] - pub fn as_nlas(&self, ifname: &str) -> Vec { - vec![ - WgDeviceAttrs::IfName(ifname.into()), - WgDeviceAttrs::Peers(vec![self.as_nlas_peer()]), - ] - } - - #[must_use] - pub fn as_nlas_remove(&self, ifname: &str) -> Vec { - vec![ - WgDeviceAttrs::IfName(ifname.into()), - WgDeviceAttrs::Peers(vec![WgPeer(vec![ - WgPeerAttrs::PublicKey(self.public_key.as_array()), - WgPeerAttrs::Flags(WGPEER_F_REMOVE_ME), - ])]), - ] - } - - #[must_use] - pub fn as_nlas_peer(&self) -> WgPeer { - let mut attrs = vec![WgPeerAttrs::PublicKey(self.public_key.as_array())]; - if let Some(keepalive) = self.persistent_keepalive_interval { - attrs.push(WgPeerAttrs::PersistentKeepalive(keepalive)); - } - attrs.push(WgPeerAttrs::Flags(WGPEER_F_REPLACE_ALLOWEDIPS)); - let allowed_ips = self - .allowed_ips - .iter() - .map(IpAddrMask::to_nlas_allowed_ip) - .collect(); - attrs.push(WgPeerAttrs::AllowedIps(allowed_ips)); - - WgPeer(attrs) - } -} - -impl From for Peer { - fn from(proto_peer: proto::Peer) -> Self { - let mut peer = Self::new(proto_peer.pubkey.as_str().try_into().unwrap_or_default()); - peer.allowed_ips = proto_peer - .allowed_ips - .iter() - .filter_map(|entry| IpAddrMask::from_str(entry).ok()) - .collect(); - peer - } -} - -impl From<&Peer> for proto::Peer { - fn from(peer: &Peer) -> Self { - Self { - pubkey: peer.public_key.to_string(), - allowed_ips: peer.allowed_ips.iter().map(ToString::to_string).collect(), - } - } -} - -impl From<&Peer> for proto::PeerStats { - fn from(peer: &Peer) -> Self { - Self { - public_key: peer.public_key.to_string(), - endpoint: peer - .endpoint - .map_or(String::new(), |endpoint| endpoint.to_string()), - allowed_ips: peer.allowed_ips.iter().map(ToString::to_string).collect(), - latest_handshake: peer.last_handshake.map_or(0, |ts| { - ts.duration_since(SystemTime::UNIX_EPOCH) - .map_or(0, |duration| duration.as_secs() as i64) - }), - download: peer.rx_bytes as i64, - upload: peer.tx_bytes as i64, - keepalive_interval: i64::from(peer.persistent_keepalive_interval.unwrap_or_default()), - } - } -} - -#[derive(Debug, Default)] -pub struct Host { - pub listen_port: u16, - pub private_key: Option, - pub(super) fwmark: Option, - pub peers: HashMap, -} - -impl Host { - #[must_use] - pub fn new(listen_port: u16, private_key: Key) -> Self { - Self { - listen_port, - private_key: Some(private_key), - fwmark: None, - peers: HashMap::new(), - } - } - - #[must_use] - pub fn as_uapi(&self) -> String { - let mut output = format!("listen_port={}\n", self.listen_port); - if let Some(key) = &self.private_key { - output.push_str("private_key="); - output.push_str(&key.to_lower_hex()); - output.push('\n'); - } - if let Some(fwmark) = &self.fwmark { - output.push_str("fwmark="); - output.push_str(&fwmark.to_string()); - output.push('\n'); - } - output.push_str("replace_peers=true\n"); - for peer in self.peers.values() { - output.push_str(peer.as_uapi_update().as_ref()); - } - - output - } - - // TODO: use custom Error - pub fn parse_uapi(buf: impl Read) -> io::Result { - let reader = BufReader::new(buf); - let mut host = Self::default(); - let mut peer_ref = None; - - for line_result in reader.lines() { - let line = match line_result { - Ok(line) => line, - Err(err) => { - error!("Error parsing buffer line: {err}"); - continue; - } - }; - if let Some((keyword, value)) = line.split_once('=') { - match keyword { - "listen_port" => host.listen_port = value.parse().unwrap_or_default(), - "fwmark" => host.fwmark = value.parse().ok(), - "private_key" => host.private_key = Key::decode(value).ok(), - // "public_key" starts new peer definition - "public_key" => { - if let Ok(key) = Key::decode(value) { - let peer = Peer::new(key.clone()); - host.peers.insert(key.clone(), peer); - peer_ref = host.peers.get_mut(&key); - } else { - peer_ref = None; - } - } - "preshared_key" => { - if let Some(ref mut peer) = peer_ref { - peer.preshared_key = Key::decode(value).ok(); - } - } - "protocol_version" => { - if let Some(ref mut peer) = peer_ref { - peer.protocol_version = value.parse().ok(); - } - } - "endpoint" => { - if let Some(ref mut peer) = peer_ref { - peer.endpoint = SocketAddr::from_str(value).ok(); - } - } - "persistent_keepalive_interval" => { - if let Some(ref mut peer) = peer_ref { - peer.persistent_keepalive_interval = value.parse().ok(); - } - } - "allowed_ip" => { - if let Some(ref mut peer) = peer_ref { - if let Ok(addr) = value.parse() { - peer.allowed_ips.push(addr); - } - } - } - "last_handshake_time_sec" => { - if let Some(ref mut peer) = peer_ref { - let handshake = - peer.last_handshake.get_or_insert(SystemTime::UNIX_EPOCH); - *handshake += Duration::from_secs(value.parse().unwrap_or_default()); - } - } - "last_handshake_time_nsec" => { - if let Some(ref mut peer) = peer_ref { - let handshake = - peer.last_handshake.get_or_insert(SystemTime::UNIX_EPOCH); - *handshake += Duration::from_nanos(value.parse().unwrap_or_default()); - } - } - "rx_bytes" => { - if let Some(ref mut peer) = peer_ref { - peer.rx_bytes = value.parse().unwrap_or_default(); - } - } - "tx_bytes" => { - if let Some(ref mut peer) = peer_ref { - peer.tx_bytes = value.parse().unwrap_or_default(); - } - } - // "errno" ends config - "errno" => { - if let Ok(errno) = value.parse::() { - if errno == 0 { - // Break here, or BufReader will wait for EOF. - break; - } - } - return Err(io::Error::new(io::ErrorKind::Other, "error reading UAPI")); - } - _ => error!("Unknown UAPI keyword {}", keyword), - } - } - } - - Ok(host) - } -} - -#[cfg(target_os = "linux")] -impl Host { - pub fn append_nlas(&mut self, nlas: &[WgDeviceAttrs]) { - for nla in nlas { - match nla { - WgDeviceAttrs::PrivateKey(value) => self.private_key = Some(Key::new(*value)), - WgDeviceAttrs::ListenPort(value) => self.listen_port = *value, - WgDeviceAttrs::Fwmark(value) => self.fwmark = Some(*value), - WgDeviceAttrs::Peers(nlas) => { - for nla in nlas { - let peer = Peer::from_nlas(nla); - self.peers.insert(peer.public_key.clone(), peer); - } - } - _ => (), - } - } - } - - #[must_use] - pub fn as_nlas(&self, ifname: &str) -> Vec { - let mut nlas = vec![ - WgDeviceAttrs::IfName(ifname.into()), - WgDeviceAttrs::ListenPort(self.listen_port), - ]; - if let Some(key) = &self.private_key { - nlas.push(WgDeviceAttrs::PrivateKey(key.as_array())); - } - if let Some(fwmark) = &self.fwmark { - nlas.push(WgDeviceAttrs::Fwmark(*fwmark)); - } - nlas.push(WgDeviceAttrs::Flags(WGDEVICE_F_REPLACE_PEERS)); - let peers = self.peers.values().map(Peer::as_nlas_peer).collect(); - nlas.push(WgDeviceAttrs::Peers(peers)); - nlas - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::io::Cursor; - - #[test] - fn test_parse_config() { - let uapi_output = - b"private_key=000102030405060708090a0b0c0d0e0ff0e1d2c3b4a5968778695a4b3c2d1e0f\n\ - listen_port=7301\n\ - public_key=100102030405060708090a0b0c0d0e0ff0e1d2c3b4a5968778695a4b3c2d1e0f\n\ - preshared_key=0000000000000000000000000000000000000000000000000000000000000000\n\ - protocol_version=1\n\ - last_handshake_time_sec=0\n\ - last_handshake_time_nsec=0\n\ - tx_bytes=0\n\ - rx_bytes=0\n\ - persistent_keepalive_interval=0\n\ - allowed_ip=10.6.0.12/32\n\ - public_key=200102030405060708090a0b0c0d0e0ff0e1d2c3b4a5968778695a4b3c2d1e0f\n\ - preshared_key=0000000000000000000000000000000000000000000000000000000000000000\n\ - protocol_version=1\n\ - endpoint=83.11.218.160:51421\n\ - last_handshake_time_sec=1654631933\n\ - last_handshake_time_nsec=862977251\n\ - tx_bytes=52759980\n\ - rx_bytes=3683056\n\ - persistent_keepalive_interval=0\n\ - allowed_ip=10.6.0.25/32\n\ - public_key=300102030405060708090a0b0c0d0e0ff0e1d2c3b4a5968778695a4b3c2d1e0f\n\ - preshared_key=0000000000000000000000000000000000000000000000000000000000000000\n\ - protocol_version=1\n\ - endpoint=31.135.163.194:37712\n\ - last_handshake_time_sec=1654776419\n\ - last_handshake_time_nsec=732507856\n\ - tx_bytes=1009094476\n\ - rx_bytes=76734328\n\ - persistent_keepalive_interval=0\n\ - allowed_ip=10.6.0.23/32\n\ - errno=0\n"; - let buf = Cursor::new(uapi_output); - let host = Host::parse_uapi(buf).unwrap(); - assert_eq!(host.listen_port, 7301); - assert_eq!(host.peers.len(), 3); - - let key = Key::decode("200102030405060708090a0b0c0d0e0ff0e1d2c3b4a5968778695a4b3c2d1e0f") - .unwrap(); - let stats = proto::PeerStats::from(&host.peers[&key]); - assert_eq!(stats.download, 3683056); - assert_eq!(stats.upload, 52759980); - assert_eq!(stats.latest_handshake, 1654631933); - } - - #[test] - fn test_host_uapi() { - let key_str = "000102030405060708090a0b0c0d0e0ff0e1d2c3b4a5968778695a4b3c2d1e0f"; - let key = Key::decode(key_str).unwrap(); - - let host = Host::new(12345, key); - assert_eq!( - "listen_port=12345\n\ - private_key=000102030405060708090a0b0c0d0e0ff0e1d2c3b4a5968778695a4b3c2d1e0f\n\ - replace_peers=true\n", - host.as_uapi() - ); - } - - #[test] - fn test_peer_uapi() { - let key_str = "000102030405060708090a0b0c0d0e0ff0e1d2c3b4a5968778695a4b3c2d1e0f"; - let key = Key::decode(key_str).unwrap(); - - let peer = Peer::new(key); - assert_eq!( - "public_key=000102030405060708090a0b0c0d0e0ff0e1d2c3b4a5968778695a4b3c2d1e0f\n\ - replace_allowed_ips=true\n", - peer.as_uapi_update() - ); - - let key_str = "00112233445566778899aaabbcbddeeff0e1d2c3b4a5968778695a4b3c2d1e0f"; - let key = Key::decode(key_str).unwrap(); - let peer = Peer::new(key); - assert_eq!( - "public_key=00112233445566778899aaabbcbddeeff0e1d2c3b4a5968778695a4b3c2d1e0f\n\ - remove=true\n", - peer.as_uapi_remove() - ); - } -} diff --git a/src/wireguard/key.rs b/src/wireguard/key.rs deleted file mode 100644 index fb164768..00000000 --- a/src/wireguard/key.rs +++ /dev/null @@ -1,199 +0,0 @@ -use base64::{decode, encode, DecodeError}; -use std::{ - error, fmt, - hash::{Hash, Hasher}, - str::FromStr, -}; - -const KEY_LENGTH: usize = 32; - -#[derive(Clone, Default)] -pub struct Key([u8; KEY_LENGTH]); - -#[derive(Debug)] -pub enum KeyError { - InvalidCharacter(u8), - InvalidStringLength(usize), -} - -impl error::Error for KeyError {} - -impl fmt::Display for KeyError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::InvalidCharacter(char) => { - write!(f, "Invalid character {char}") - } - Self::InvalidStringLength(length) => write!(f, "Invalid string length {length}"), - } - } -} - -impl Key { - #[must_use] - pub fn new(buf: [u8; KEY_LENGTH]) -> Self { - Self(buf) - } - - #[must_use] - pub fn as_array(&self) -> [u8; KEY_LENGTH] { - self.0 - } - - #[must_use] - pub fn as_slice(&self) -> &[u8] { - self.0.as_slice() - } - - #[must_use] - pub fn to_lower_hex(&self) -> String { - let mut hex = String::with_capacity(64); - let to_char = |nibble: u8| -> char { - (match nibble { - 0..=9 => b'0' + nibble, - _ => nibble + b'a' - 10, - }) as char - }; - self.0.iter().for_each(|byte| { - hex.push(to_char(*byte >> 4)); - hex.push(to_char(*byte & 0xf)); - }); - hex - } - - pub fn decode>(hex: T) -> Result { - let hex = hex.as_ref(); - let length = hex.len(); - if length != 64 { - return Err(KeyError::InvalidStringLength(length)); - } - - let hex_value = |char: u8| -> Result { - match char { - b'A'..=b'F' => Ok(char - b'A' + 10), - b'a'..=b'f' => Ok(char - b'a' + 10), - b'0'..=b'9' => Ok(char - b'0'), - _ => Err(KeyError::InvalidCharacter(char)), - } - }; - - let mut key = [0; KEY_LENGTH]; - for (index, chunk) in hex.chunks(2).enumerate() { - let msd = hex_value(chunk[0])?; - let lsd = hex_value(chunk[1])?; - key[index] = msd << 4 | lsd; - } - Ok(Self(key)) - } -} - -impl TryFrom<&str> for Key { - type Error = DecodeError; - - fn try_from(value: &str) -> Result { - let v = decode(value)?; - if v.len() == KEY_LENGTH { - let buf = v.try_into().map_err(|_| Self::Error::InvalidLength)?; - Ok(Self::new(buf)) - } else { - Err(Self::Error::InvalidLength) - } - } -} - -impl TryFrom<&[u8]> for Key { - type Error = DecodeError; - - fn try_from(value: &[u8]) -> Result { - if value.len() == KEY_LENGTH { - let buf = - <[u8; KEY_LENGTH]>::try_from(value).map_err(|_| Self::Error::InvalidLength)?; - Ok(Self::new(buf)) - } else { - Err(Self::Error::InvalidLength) - } - } -} - -impl FromStr for Key { - type Err = DecodeError; - - fn from_str(value: &str) -> Result { - let v = decode(value)?; - if v.len() == KEY_LENGTH { - let buf = v.try_into().map_err(|_| Self::Err::InvalidLength)?; - Ok(Self::new(buf)) - } else { - Err(Self::Err::InvalidLength) - } - } -} - -impl Hash for Key { - fn hash(&self, state: &mut H) { - self.0.hash(state); - } -} - -impl PartialEq for Key { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 - } -} - -impl Eq for Key {} - -impl fmt::Debug for Key { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.to_lower_hex()) - } -} - -impl fmt::Display for Key { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", encode(self.0)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn decode_key() { - let key_str = "000102030405060708090a0b0c0d0e0ff0e1d2c3b4a5968778695a4b3c2d1e0f"; - let key = Key::decode(key_str).unwrap(); - assert_eq!( - key.0, - [ - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, - 0x0e, 0x0f, 0xf0, 0xe1, 0xd2, 0xc3, 0xb4, 0xa5, 0x96, 0x87, 0x78, 0x69, 0x5a, 0x4b, - 0x3c, 0x2d, 0x1e, 0x0f - ] - ); - assert_eq!(key.to_lower_hex(), key_str); - assert_eq!( - format!("{key}"), - "AAECAwQFBgcICQoLDA0OD/Dh0sO0pZaHeGlaSzwtHg8=" - ); - } - - #[test] - fn parse_key() { - let key_str = "AAECAwQFBgcICQoLDA0OD/Dh0sO0pZaHeGlaSzwtHg8="; - let key: Key = key_str.try_into().unwrap(); - assert_eq!( - key.0, - [ - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, - 0x0e, 0x0f, 0xf0, 0xe1, 0xd2, 0xc3, 0xb4, 0xa5, 0x96, 0x87, 0x78, 0x69, 0x5a, 0x4b, - 0x3c, 0x2d, 0x1e, 0x0f - ] - ); - assert_eq!( - key.to_lower_hex(), - "000102030405060708090a0b0c0d0e0ff0e1d2c3b4a5968778695a4b3c2d1e0f" - ); - assert_eq!(format!("{key}"), key_str); - } -} diff --git a/src/wireguard/mod.rs b/src/wireguard/mod.rs deleted file mode 100644 index 4272159c..00000000 --- a/src/wireguard/mod.rs +++ /dev/null @@ -1,110 +0,0 @@ -#[cfg(target_os = "freebsd")] -pub mod bsd; -mod host; -mod key; -mod net; -#[cfg(target_os = "linux")] -pub mod netlink; -pub mod wgapi; - -use crate::{error::GatewayError, proto::Configuration}; -#[cfg(feature = "boringtun")] -use boringtun::{ - device::drop_privileges::drop_privileges, - device::{DeviceConfig, DeviceHandle}, -}; -use std::{process::Command, str::FromStr}; -use wgapi::WGApi; - -/// Creates wireguard interface using userspace implementation. -/// https://github.com/cloudflare/boringtun -/// -/// # Arguments -/// -/// * `name` - Interface name -#[cfg(feature = "boringtun")] -pub fn create_interface_userspace(ifname: &str) -> Result<(), GatewayError> { - let enable_drop_privileges = true; - - let config = DeviceConfig::default(); - - let mut device_handle = DeviceHandle::new(ifname, config).map_err(GatewayError::BorningTun)?; - - if enable_drop_privileges { - if let Err(e) = drop_privileges() { - error!("Failed to drop privileges: {:?}", e); - } - } - - tokio::spawn(async move { - device_handle.wait(); - }); - Ok(()) -} - -/// Assigns address to interface. -/// -/// # Arguments -/// -/// * `interface` - Interface name -/// * `addr` - Address to assign to interface -pub fn assign_addr(ifname: &str, addr: &IpAddrMask) -> Result<(), GatewayError> { - if cfg!(target_os = "linux") { - #[cfg(target_os = "linux")] - netlink::address_interface(ifname, addr)?; - } else if cfg!(target_os = "macos") { - // On macOS, interface is point-to-point and requires a pair of addresses - let address_string = addr.ip.to_string(); - Command::new("ifconfig") - .args([ifname, &address_string, &address_string]) - .output()?; - } else { - Command::new("ifconfig") - .args([ifname, &addr.to_string()]) - .output()?; - } - - Ok(()) -} - -/// Helper method performing interface configuration -pub fn setup_interface( - ifname: &str, - userspace: bool, - config: &Configuration, -) -> Result<(), GatewayError> { - if userspace { - #[cfg(feature = "boringtun")] - create_interface_userspace(ifname)?; - } else { - #[cfg(target_os = "linux")] - netlink::create_interface(ifname)?; - } - - let address = IpAddrMask::from_str(&config.address)?; - assign_addr(ifname, &address)?; - let key = config.prvkey.as_str().try_into()?; - let mut host = Host::new(config.port as u16, key); - for peercfg in &config.peers { - let key: Key = peercfg.pubkey.as_str().try_into()?; - let mut peer = Peer::new(key.clone()); - let allowed_ips = peercfg - .allowed_ips - .iter() - .filter_map(|entry| IpAddrMask::from_str(entry).ok()) - .collect(); - peer.set_allowed_ips(allowed_ips); - - host.peers.insert(key, peer); - } - let api = WGApi::new(ifname.into(), userspace); - api.write_host(&host)?; - - Ok(()) -} - -pub use { - host::{Host, Peer}, - key::Key, - net::{IpAddrMask, IpAddrParseError}, -}; diff --git a/src/wireguard/net.rs b/src/wireguard/net.rs deleted file mode 100644 index a8535d19..00000000 --- a/src/wireguard/net.rs +++ /dev/null @@ -1,122 +0,0 @@ -#[cfg(target_os = "linux")] -use netlink_packet_wireguard::{ - constants::{AF_INET, AF_INET6}, - nlas::{WgAllowedIp, WgAllowedIpAttrs}, -}; -use std::{error, fmt, net::IpAddr, str::FromStr}; - -#[derive(Debug, PartialEq, Clone)] -pub struct IpAddrMask { - // IP v4 or v6 - pub ip: IpAddr, - // Classless Inter-Domain Routing - pub cidr: u8, -} - -impl IpAddrMask { - #[must_use] - pub fn new(ip: IpAddr, cidr: u8) -> Self { - Self { ip, cidr } - } - - #[cfg(target_os = "linux")] - #[must_use] - pub fn to_nlas_allowed_ip(&self) -> WgAllowedIp { - let mut attrs = Vec::new(); - attrs.push(WgAllowedIpAttrs::Family(if self.ip.is_ipv4() { - AF_INET - } else { - AF_INET6 - })); - attrs.push(WgAllowedIpAttrs::IpAddr(self.ip)); - attrs.push(WgAllowedIpAttrs::Cidr(self.cidr)); - WgAllowedIp(attrs) - } -} - -impl fmt::Display for IpAddrMask { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}/{}", self.ip, self.cidr) - } -} - -#[derive(Debug, PartialEq)] -pub struct IpAddrParseError; - -impl error::Error for IpAddrParseError {} - -impl fmt::Display for IpAddrParseError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "IP address/mask parse error") - } -} - -impl FromStr for IpAddrMask { - type Err = IpAddrParseError; - - fn from_str(ip_str: &str) -> Result { - if let Some((left, right)) = ip_str.split_once('/') { - Ok(IpAddrMask { - ip: left.parse().map_err(|_| IpAddrParseError)?, - cidr: right.parse().map_err(|_| IpAddrParseError)?, - }) - } else { - let ip: IpAddr = ip_str.parse().map_err(|_| IpAddrParseError)?; - Ok(IpAddrMask { - ip, - cidr: if ip.is_ipv4() { 32 } else { 128 }, - }) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::net::{Ipv4Addr, Ipv6Addr}; - - #[test] - fn parse_ip_addr() { - assert_eq!( - "192.168.0.1/24".parse::(), - Ok(IpAddrMask::new( - IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)), - 24 - )) - ); - - assert_eq!( - "10.11.12.13".parse::(), - Ok(IpAddrMask::new( - IpAddr::V4(Ipv4Addr::new(10, 11, 12, 13)), - 32 - )) - ); - - assert_eq!( - "2001:0db8::1428:57ab/96".parse::(), - Ok(IpAddrMask::new( - IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0x1428, 0x57ab)), - 96 - )) - ); - - assert_eq!( - "::1".parse::(), - Ok(IpAddrMask::new( - IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), - 128 - )) - ); - - assert_eq!( - "172.168.0.256/24".parse::(), - Err(IpAddrParseError) - ); - - assert_eq!( - "172.168.0.0/256".parse::(), - Err(IpAddrParseError) - ); - } -} diff --git a/src/wireguard/netlink.rs b/src/wireguard/netlink.rs deleted file mode 100644 index 9a863ddb..00000000 --- a/src/wireguard/netlink.rs +++ /dev/null @@ -1,318 +0,0 @@ -use std::{ - fmt::Debug, - io, - net::{IpAddr, Ipv4Addr}, -}; - -use netlink_packet_core::{ - NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable, NLM_F_ACK, - NLM_F_CREATE, NLM_F_DUMP, NLM_F_EXCL, NLM_F_REPLACE, NLM_F_REQUEST, -}; -use netlink_packet_generic::{ - ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd}, - GenlFamily, GenlMessage, -}; -use netlink_packet_route::{ - address, - link::nlas::{Info, InfoKind, Nla}, - AddressMessage, LinkHeader, LinkMessage, RtnlMessage, AF_INET, AF_INET6, IFF_UP, -}; -use netlink_packet_wireguard::{nlas::WgDeviceAttrs, Wireguard, WireguardCmd}; -use netlink_sys::{constants::NETLINK_GENERIC, protocols::NETLINK_ROUTE, Socket, SocketAddr}; - -use super::{Host, IpAddrMask, Peer}; - -const SOCKET_BUFFER_LENGTH: usize = 12288; - -macro_rules! get_nla_value { - ($nlas:expr, $e:ident, $v:ident) => { - $nlas.iter().find_map(|attr| match attr { - $e::$v(value) => Some(value), - _ => None, - }) - }; -} - -pub fn netlink_request_genl( - mut message: GenlMessage, - flags: u16, -) -> io::Result>>> -where - F: GenlFamily + Clone + Debug + Eq, - GenlMessage: Clone + Debug + Eq + NetlinkSerializable + NetlinkDeserializable, -{ - if message.family_id() == 0 { - let genlmsg: GenlMessage = GenlMessage::from_payload(GenlCtrl { - cmd: GenlCtrlCmd::GetFamily, - nlas: vec![GenlCtrlAttrs::FamilyName(F::family_name().to_string())], - }); - let responses = netlink_request_genl::(genlmsg, NLM_F_REQUEST | NLM_F_ACK)?; - - match responses.get(0) { - Some(NetlinkMessage { - payload: - NetlinkPayload::InnerMessage(GenlMessage { - payload: GenlCtrl { nlas, .. }, - .. - }), - .. - }) => { - let family_id = get_nla_value!(nlas, GenlCtrlAttrs, FamilyId) - .ok_or_else(|| io::ErrorKind::NotFound)?; - message.set_resolved_family_id(*family_id); - } - _ => { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Unexpected netlink payload", - )) - } - }; - } - netlink_request(message, flags, NETLINK_GENERIC) -} - -pub fn netlink_request( - message: I, - flags: u16, - socket: isize, -) -> io::Result>> -where - NetlinkPayload: From, - I: Clone + Debug + Eq + NetlinkSerializable + NetlinkDeserializable, -{ - debug!("Sending Netlink request: {message:?}, flags: {flags}, socket: {socket}",); - let mut req = NetlinkMessage::from(message); - - if req.buffer_len() > SOCKET_BUFFER_LENGTH { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!( - "Serialized netlink packet ({} bytes) larger than maximum size {SOCKET_BUFFER_LENGTH}: {req:?}", - req.buffer_len(), - ), - )); - } - - req.header.flags = flags; - req.finalize(); - let mut buf = [0; SOCKET_BUFFER_LENGTH]; - req.serialize(&mut buf); - let len = req.buffer_len(); - - let socket = Socket::new(socket)?; - let kernel_addr = SocketAddr::new(0, 0); - socket.connect(&kernel_addr)?; - let n_sent = socket.send(&buf[..len], 0)?; - if n_sent != len { - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "failed to send netlink request", - )); - } - - let mut responses = Vec::new(); - loop { - let n_received = socket.recv(&mut &mut buf[..], 0)?; - let mut offset = 0; - loop { - let response = NetlinkMessage::::deserialize(&buf[offset..]) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - debug!("Read netlink response from socket: {response:?}"); - match response.payload { - // We've parsed all parts of the response and can leave the loop. - NetlinkPayload::Error(msg) if msg.code.is_none() => return Ok(responses), - NetlinkPayload::Done(_) => return Ok(responses), - NetlinkPayload::Error(msg) => return Err(msg.into()), - _ => {} - } - let header_length = response.header.length as usize; - offset += header_length; - responses.push(response); - if offset == n_received || header_length == 0 { - // We've fully parsed the datagram, but there may be further datagrams - // with additional netlink response parts. - break; - } - } - } -} - -/// Create WireGuard interface. -pub fn create_interface(ifname: &str) -> io::Result<()> { - let mut message = LinkMessage::default(); - message.header.flags = IFF_UP; - message.header.change_mask = IFF_UP; - message.nlas.push(Nla::IfName(ifname.into())); - message - .nlas - .push(Nla::Info(vec![Info::Kind(InfoKind::Wireguard)])); - - match netlink_request( - RtnlMessage::NewLink(message), - NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_EXCL, - NETLINK_ROUTE, - ) { - Err(e) if e.kind() != io::ErrorKind::AlreadyExists => Err(e), - _ => Ok(()), - } -} - -fn set_address(ifindex: u32, address: &IpAddrMask) -> io::Result<()> { - let mut message = AddressMessage::default(); - - message.header.prefix_len = address.cidr; - message.header.index = ifindex; - - let address_vec = match address.ip { - IpAddr::V4(ipv4) => { - message.header.family = AF_INET as u8; - ipv4.octets().to_vec() - } - IpAddr::V6(ipv6) => { - message.header.family = AF_INET6 as u8; - ipv6.octets().to_vec() - } - }; - - if address.ip.is_multicast() { - message.nlas.push(address::Nla::Multicast(address_vec)); - } else if address.ip.is_unspecified() { - message.nlas.push(address::Nla::Unspec(address_vec)); - } else if address.ip.is_ipv6() { - message.nlas.push(address::Nla::Address(address_vec)); - } else { - message - .nlas - .push(address::Nla::Address(address_vec.clone())); - // for IPv4 the IFA_LOCAL address can be set to the same value as IFA_ADDRESS - message.nlas.push(address::Nla::Local(address_vec.clone())); - // set the IFA_BROADCAST address as well (IPv6 does not support broadcast) - if address.cidr == 32 { - message.nlas.push(address::Nla::Broadcast(address_vec)); - } else if let IpAddrMask { - ip: IpAddr::V4(ipv4), - .. - } = address - { - let broadcast = - Ipv4Addr::from((0xffff_ffff_u32) >> u32::from(address.cidr) | u32::from(*ipv4)); - message - .nlas - .push(address::Nla::Broadcast(broadcast.octets().to_vec())); - }; - } - - netlink_request( - RtnlMessage::NewAddress(message), - NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE, - NETLINK_ROUTE, - )?; - Ok(()) -} - -pub fn address_interface(ifname: &str, address: &IpAddrMask) -> io::Result<()> { - let mut message = LinkMessage::default(); - message.nlas.push(Nla::IfName(ifname.into())); - message - .nlas - .push(Nla::Info(vec![Info::Kind(InfoKind::Wireguard)])); - - let responses = netlink_request( - RtnlMessage::GetLink(message), - NLM_F_REQUEST | NLM_F_ACK, - NETLINK_ROUTE, - )?; - - for nlmsg in responses { - match nlmsg { - NetlinkMessage { - payload: NetlinkPayload::InnerMessage(message), - .. - } => { - if let RtnlMessage::NewLink(LinkMessage { - header: LinkHeader { index, .. }, - .. - }) = message - { - return set_address(index, address); - } - } - _ => debug!("unknown nlmsg response"), - } - } - - Ok(()) -} - -/// Delete WireGuard interface. -pub fn delete_interface(ifname: &str) -> io::Result<()> { - let mut message = LinkMessage::default(); - message.nlas.push(Nla::IfName(ifname.into())); - message - .nlas - .push(Nla::Info(vec![Info::Kind(InfoKind::Wireguard)])); - - match netlink_request( - RtnlMessage::DelLink(message), - NLM_F_REQUEST | NLM_F_ACK, - NETLINK_ROUTE, - ) { - Err(e) if e.kind() != io::ErrorKind::AlreadyExists => Err(e), - _ => Ok(()), - } -} - -pub fn get_host(ifname: &str) -> Result { - debug!("Reading Netlink data for interface {ifname}"); - let genlmsg = GenlMessage::from_payload(Wireguard { - cmd: WireguardCmd::GetDevice, - nlas: vec![WgDeviceAttrs::IfName(ifname.into())], - }); - let responses = netlink_request_genl(genlmsg, NLM_F_REQUEST | NLM_F_DUMP)?; - - let mut host = Host::default(); - for nlmsg in responses { - if let NetlinkMessage { - payload: NetlinkPayload::InnerMessage(ref message), - .. - } = nlmsg - { - host.append_nlas(&message.payload.nlas); - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("unexpected netlink payload: {nlmsg:?}"), - )); - } - } - - Ok(host) -} - -pub fn set_host(ifname: &str, host: &Host) -> io::Result<()> { - let genlmsg = GenlMessage::from_payload(Wireguard { - cmd: WireguardCmd::SetDevice, - nlas: host.as_nlas(ifname), - }); - netlink_request_genl(genlmsg, NLM_F_REQUEST | NLM_F_ACK)?; - Ok(()) -} - -pub fn set_peer(ifname: &str, peer: &Peer) -> io::Result<()> { - let genlmsg = GenlMessage::from_payload(Wireguard { - cmd: WireguardCmd::SetDevice, - nlas: peer.as_nlas(ifname), - }); - netlink_request_genl(genlmsg, NLM_F_REQUEST | NLM_F_ACK)?; - Ok(()) -} - -pub fn delete_peer(ifname: &str, peer: &Peer) -> io::Result<()> { - let genlmsg = GenlMessage::from_payload(Wireguard { - cmd: WireguardCmd::SetDevice, - nlas: peer.as_nlas_remove(ifname), - }); - netlink_request_genl(genlmsg, NLM_F_REQUEST | NLM_F_ACK)?; - Ok(()) -} diff --git a/src/wireguard/wgapi.rs b/src/wireguard/wgapi.rs deleted file mode 100644 index 694b618c..00000000 --- a/src/wireguard/wgapi.rs +++ /dev/null @@ -1,195 +0,0 @@ -#[cfg(target_os = "freebsd")] -use super::bsd::{delete_peer, get_host, set_host, set_peer}; -#[cfg(target_os = "linux")] -use super::netlink::{delete_peer, get_host, set_host, set_peer}; -use super::{Host, Peer}; -use std::{ - io::{self, BufRead, BufReader, Read, Write}, - os::unix::net::UnixStream, - time::Duration, -}; - -pub struct WGApi { - ifname: String, - userspace: bool, -} - -impl WGApi { - #[must_use] - pub fn new(ifname: String, userspace: bool) -> Self { - Self { ifname, userspace } - } - - fn socket(&self) -> io::Result { - let path = format!("/var/run/wireguard/{}.sock", self.ifname); - let socket = UnixStream::connect(path)?; - socket.set_read_timeout(Some(Duration::new(3, 0)))?; - Ok(socket) - } - - // FIXME: currenty other errors are ignored and result in 0 being returned. - fn parse_errno(buf: impl Read) -> u32 { - let reader = BufReader::new(buf); - for line_result in reader.lines() { - let line = match line_result { - Ok(line) => line, - Err(err) => { - error!("Error parsing buffer line: {err}"); - continue; - } - }; - if let Some((keyword, value)) = line.split_once('=') { - if keyword == "errno" { - return value.parse().unwrap_or_default(); - } - } - } - 0 - } - - pub fn read_host(&self) -> io::Result { - debug!("Reading host interface info"); - if self.userspace { - let mut socket = self.socket()?; - socket.write_all(b"get=1\n\n")?; - Host::parse_uapi(socket) - } else { - #[cfg(target_os = "freebsd")] - { - // FIXME: use proper error - get_host(&self.ifname).map_err(|err| { - io::Error::new(io::ErrorKind::Other, format!("kernel support error: {err}")) - }) - } - #[cfg(target_os = "linux")] - { - get_host(&self.ifname) - } - #[cfg(not(any(target_os = "linux", target_os = "freebsd")))] - Err(io::Error::new( - io::ErrorKind::Other, - "kernel support is not available on this platform", - )) - } - } - - pub fn write_host(&self, host: &Host) -> io::Result<()> { - if self.userspace { - let mut socket = self.socket()?; - socket.write_all(b"set=1\n")?; - socket.write_all(host.as_uapi().as_bytes())?; - socket.write_all(b"\n")?; - - if Self::parse_errno(socket) != 0 { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "write configuration error", - )) - } else { - Ok(()) - } - } else { - #[cfg(target_os = "freebsd")] - { - // FIXME: use proper error - set_host(&self.ifname, host).map_err(|err| { - io::Error::new(io::ErrorKind::Other, format!("kernel support error: {err}")) - }) - } - #[cfg(target_os = "linux")] - { - set_host(&self.ifname, host) - } - #[cfg(not(any(target_os = "linux", target_os = "freebsd")))] - Err(io::Error::new( - io::ErrorKind::Other, - "kernel support is not available on this platform", - )) - } - } - - pub fn write_peer(&self, peer: &Peer) -> io::Result<()> { - if self.userspace { - let mut socket = self.socket()?; - socket.write_all(b"set=1\n")?; - socket.write_all(peer.as_uapi_update().as_bytes())?; - socket.write_all(b"\n")?; - - if Self::parse_errno(socket) != 0 { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "write configuration error", - )) - } else { - Ok(()) - } - } else { - #[cfg(target_os = "freebsd")] - { - // FIXME: use proper error - set_peer(&self.ifname, peer).map_err(|err| { - io::Error::new(io::ErrorKind::Other, format!("kernel support error: {err}")) - }) - } - #[cfg(target_os = "linux")] - { - set_peer(&self.ifname, peer) - } - #[cfg(not(any(target_os = "linux", target_os = "freebsd")))] - Err(io::Error::new( - io::ErrorKind::Other, - "kernel support is not available on this platform", - )) - } - } - - pub fn delete_peer(&self, peer: &Peer) -> io::Result<()> { - if self.userspace { - let mut socket = self.socket()?; - socket.write_all(b"set=1\n")?; - socket.write_all(peer.as_uapi_remove().as_bytes())?; - socket.write_all(b"\n")?; - - if Self::parse_errno(socket) != 0 { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "write configuration error", - )) - } else { - Ok(()) - } - } else { - #[cfg(target_os = "freebsd")] - { - // FIXME: use proper error - delete_peer(&self.ifname, peer).map_err(|err| { - io::Error::new(io::ErrorKind::Other, format!("kernel support error: {err}")) - }) - } - #[cfg(target_os = "linux")] - { - delete_peer(&self.ifname, peer) - } - #[cfg(not(any(target_os = "linux", target_os = "freebsd")))] - Err(io::Error::new( - io::ErrorKind::Other, - "kernel support is not available on this platform", - )) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::io::Cursor; - - #[test] - fn test_parse_errno() { - let buf = Cursor::new(b"errno=0\n"); - assert_eq!(WGApi::parse_errno(buf), 0); - - let buf = Cursor::new(b"errno=12345\n"); - assert_eq!(WGApi::parse_errno(buf), 12345); - } -} diff --git a/wireguard-rs b/wireguard-rs new file mode 160000 index 00000000..b3f40b90 --- /dev/null +++ b/wireguard-rs @@ -0,0 +1 @@ +Subproject commit b3f40b90e94e47d109cd3c4c24908d377e30c561