From b5f8e8c89f1623eabbf2b3215057b84fd48155e4 Mon Sep 17 00:00:00 2001 From: Dusty Phillips Date: Fri, 5 Jun 2020 11:28:15 +0000 Subject: [PATCH] Prototype listening on SocketAddr instead of port The basic idea is the same, but it better mimics the normal socket interface. The Switchboard can now be thought of as a parody of the entire internet instead of a single machine. Most everything is a direct mapping, but there are two uncertain situations: * When connect() is called, there is no clear indicator as to where the client is connecting *from*. So if we ever wanted to implement e.g. peer_addr on a MemorySocket, the behaviour would be undefined. * It is not clear what address to listen on when connecting to 0.0.0.0. Both of these *could* be solved by mapping to the local machine address (using e.g the get_if_addrs crate), but I'm not sure that make sense. Another option I toyed with was having a "set_connect_address" global, but that simply lacked elegance. ( Discussion in #3 ) --- Cargo.toml | 1 + src/async.rs | 2 +- src/lib.rs | 122 +++++++++++++++++++++++++------------------------ tests/async.rs | 38 ++++++++++----- tests/sync.rs | 43 ++++++++++++----- 5 files changed, 121 insertions(+), 85 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f6922b2..e9dfb8b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ bytes = "0.5" flume = { version = "0.7", default-features = false } futures = { version = "0.3", optional = true } once_cell = "1.3" +rand = "0.7.3" [features] # Include nothing by default diff --git a/src/async.rs b/src/async.rs index 9f5771d..586460e 100644 --- a/src/async.rs +++ b/src/async.rs @@ -24,7 +24,7 @@ impl MemoryListener { /// use memory_socket::MemoryListener; /// /// # async fn work () -> ::std::io::Result<()> { - /// let mut listener = MemoryListener::bind(80).unwrap(); + /// let mut listener = MemoryListener::bind("192.51.100.2:60").unwrap(); /// let mut incoming = listener.incoming_stream(); /// /// while let Some(stream) = incoming.next().await { diff --git a/src/lib.rs b/src/lib.rs index 29138bf..84bf814 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,10 +14,11 @@ use bytes::{buf::BufExt, Buf, Bytes, BytesMut}; use flume::{Receiver, Sender}; use once_cell::sync::Lazy; +use rand::{thread_rng, Rng}; use std::{ collections::HashMap, io::{ErrorKind, Read, Result, Write}, - num::NonZeroU16, + net::{SocketAddr, ToSocketAddrs}, sync::Mutex, }; @@ -27,10 +28,11 @@ mod r#async; #[cfg(feature = "async")] pub use r#async::IncomingStream; +/// Collection of open connected sockets static SWITCHBOARD: Lazy> = - Lazy::new(|| Mutex::new(SwitchBoard(HashMap::default(), 1))); + Lazy::new(|| Mutex::new(SwitchBoard(HashMap::default()))); -struct SwitchBoard(HashMap>, u16); +struct SwitchBoard(HashMap>); /// An in-memory socket server, listening for connections. /// @@ -59,7 +61,7 @@ struct SwitchBoard(HashMap>, u16); /// } /// /// fn main() -> Result<()> { -/// let mut listener = MemoryListener::bind(16)?; +/// let mut listener = MemoryListener::bind("192.51.100.2:1337")?; /// /// // accept connections and process them serially /// for stream in listener.incoming() { @@ -70,7 +72,7 @@ struct SwitchBoard(HashMap>, u16); /// ``` pub struct MemoryListener { incoming: Receiver, - port: NonZeroU16, + address: SocketAddr, } impl Drop for MemoryListener { @@ -78,7 +80,7 @@ impl Drop for MemoryListener { let mut switchboard = (&*SWITCHBOARD).lock().unwrap(); // Remove the Sending side of the channel in the switchboard when // MemoryListener is dropped - switchboard.0.remove(&self.port); + switchboard.0.remove(&self.address); } } @@ -102,48 +104,49 @@ impl MemoryListener { /// use memory_socket::MemoryListener; /// /// # fn main () -> ::std::io::Result<()> { - /// let listener = MemoryListener::bind(16)?; + /// let listener = MemoryListener::bind("192.51.100.2:1337")?; /// # Ok(())} /// ``` - pub fn bind(port: u16) -> Result { + pub fn bind(addresses: A) -> Result { let mut switchboard = (&*SWITCHBOARD).lock().unwrap(); - // Get the port we should bind to. If 0 was given, use a random port - let port = if let Some(port) = NonZeroU16::new(port) { - if switchboard.0.contains_key(&port) { - return Err(ErrorKind::AddrInUse.into()); - } - - port - } else { - loop { - let port = NonZeroU16::new(switchboard.1).unwrap_or_else(|| unreachable!()); + let mut addresses = addresses.to_socket_addrs()?; - // The switchboard is full and all ports are in use - if switchboard.0.len() == (std::u16::MAX - 1) as usize { - return Err(ErrorKind::AddrInUse.into()); - } + let mut address = match addresses.next() { + Some(address) => address, + None => return Err(ErrorKind::AddrNotAvailable.into()), + }; - // Instead of overflowing to 0, resume searching at port 1 since port 0 isn't a - // valid port to bind to. - if switchboard.1 == std::u16::MAX { - switchboard.1 = 1; - } else { - switchboard.1 += 1; - } + // It doesn't really make sense to listen on multiple interfaces in + // this environment, so we place a restriction on the parameter. + if addresses.next().is_some() { + return Err(ErrorKind::AddrNotAvailable.into()); + } + // Similarly, it doesn't make a sense to listen on "all interfaces" + // in this environment, so return an error if they requested 0.0.0.0 + // TODO: We could use get_if_addrs and use the host's real name? + if address.ip().is_unspecified() { + return Err(ErrorKind::AddrNotAvailable.into()); + } - if !switchboard.0.contains_key(&port) { - break port; - } + // If they didn't provide a port find one that isn't in use. + if address.port() == 0 { + let mut rng = thread_rng(); + address.set_port(rng.gen()); + while switchboard.0.contains_key(&address) { + address.set_port(rng.gen()); } - }; + } else if switchboard.0.contains_key(&address) { + // Can't listen on the same address and port twice + return Err(ErrorKind::AddrInUse.into()); + } let (sender, receiver) = flume::unbounded(); - switchboard.0.insert(port, sender); + switchboard.0.insert(address, sender); Ok(Self { incoming: receiver, - port, + address, }) } @@ -156,15 +159,17 @@ impl MemoryListener { /// /// ``` /// use memory_socket::MemoryListener; + /// use std::net::SocketAddr; /// /// # fn main () -> ::std::io::Result<()> { - /// let listener = MemoryListener::bind(16)?; + /// let listener = MemoryListener::bind("192.51.100.2:1337")?; /// - /// assert_eq!(listener.local_addr(), 16); + /// let expected: SocketAddr = "192.51.100.2:1337".parse().unwrap(); + /// assert_eq!(listener.local_addr().unwrap(), expected); /// # Ok(())} /// ``` - pub fn local_addr(&self) -> u16 { - self.port.get() + pub fn local_addr(&self) -> Result { + Ok(self.address) } /// Returns an iterator over the connections being received on this @@ -181,7 +186,7 @@ impl MemoryListener { /// use memory_socket::MemoryListener; /// use std::io::{Read, Write}; /// - /// let mut listener = MemoryListener::bind(80).unwrap(); + /// let mut listener = MemoryListener::bind("192.51.100.2:1337").unwrap(); /// /// for stream in listener.incoming() { /// match stream { @@ -210,7 +215,7 @@ impl MemoryListener { /// use std::net::TcpListener; /// use memory_socket::MemoryListener; /// - /// let mut listener = MemoryListener::bind(8080).unwrap(); + /// let mut listener = MemoryListener::bind("192.51.100.2:8080").unwrap(); /// match listener.accept() { /// Ok(_socket) => println!("new client!"), /// Err(e) => println!("couldn't get client: {:?}", e), @@ -317,29 +322,26 @@ impl MemorySocket { /// use memory_socket::MemorySocket; /// /// # fn main () -> ::std::io::Result<()> { - /// # let _listener = memory_socket::MemoryListener::bind(16)?; - /// let socket = MemorySocket::connect(16)?; + /// # let _listener = memory_socket::MemoryListener::bind("192.51.100.2:60")?; + /// let socket = MemorySocket::connect("192.51.100.2:60")?; /// # Ok(())} /// ``` - pub fn connect(port: u16) -> Result { + pub fn connect(addresses: A) -> Result { let mut switchboard = (&*SWITCHBOARD).lock().unwrap(); + let addresses = addresses.to_socket_addrs()?; + for address in addresses { + if let Some(sender) = switchboard.0.get_mut(&address) { + let (socket_a, socket_b) = Self::new_pair(); + // Send the socket to the listener + sender + .send(socket_a) + .map_err(|_| ErrorKind::AddrNotAvailable)?; + + return Ok(socket_b); + } + } - // Find port to connect to - let port = NonZeroU16::new(port).ok_or_else(|| ErrorKind::AddrNotAvailable)?; - - let sender = switchboard - .0 - .get_mut(&port) - .ok_or_else(|| ErrorKind::AddrNotAvailable)?; - - let (socket_a, socket_b) = Self::new_pair(); - - // Send the socket to the listener - sender - .send(socket_a) - .map_err(|_| ErrorKind::AddrNotAvailable)?; - - Ok(socket_b) + Err(ErrorKind::AddrNotAvailable.into()) } } diff --git a/tests/async.rs b/tests/async.rs index a50a97d..d143134 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -4,7 +4,10 @@ use futures::{ stream::StreamExt, }; use memory_socket::{MemoryListener, MemorySocket}; -use std::io::Result; +use std::{ + io::Result, + net::{IpAddr, Ipv4Addr, SocketAddr}, +}; // // MemoryListener Tests @@ -12,17 +15,28 @@ use std::io::Result; #[test] fn listener_bind() -> Result<()> { - let listener = MemoryListener::bind(42)?; - assert_eq!(listener.local_addr(), 42); + let listener = MemoryListener::bind("192.51.100.2:42")?; + let expected = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 51, 100, 2)), 42); + let actual = listener + .local_addr() + .expect("Socket should have a local address"); + assert_eq!(actual, expected); Ok(()) } +#[test] +fn bind_unspecified() { + // Current implementation does not know how to handle unspecified address + let listener_result = MemoryListener::bind("0.0.0.0:0"); + assert!(listener_result.is_err()); +} + #[test] fn simple_connect() -> Result<()> { - let mut listener = MemoryListener::bind(10)?; + let mut listener = MemoryListener::bind("192.51.100.2:10")?; - let mut dialer = MemorySocket::connect(10)?; + let mut dialer = MemorySocket::connect("192.51.100.2:10")?; let mut listener_socket = block_on(listener.incoming_stream().next()).unwrap()?; block_on(dialer.write_all(b"foo"))?; @@ -37,8 +51,8 @@ fn simple_connect() -> Result<()> { #[test] fn listen_on_port_zero() -> Result<()> { - let mut listener = MemoryListener::bind(0)?; - let listener_addr = listener.local_addr(); + let mut listener = MemoryListener::bind("192.51.100.2:0")?; + let listener_addr = listener.local_addr().expect("That is a valid address"); let mut dialer = MemorySocket::connect(listener_addr)?; let mut listener_socket = block_on(listener.incoming_stream().next()).unwrap()?; @@ -62,9 +76,9 @@ fn listen_on_port_zero() -> Result<()> { #[test] fn listener_correctly_frees_port_on_drop() -> Result<()> { - fn connect_on_port(port: u16) -> Result<()> { - let mut listener = MemoryListener::bind(port)?; - let mut dialer = MemorySocket::connect(port)?; + fn connect_on_port(address: SocketAddr) -> Result<()> { + let mut listener = MemoryListener::bind(address)?; + let mut dialer = MemorySocket::connect(address)?; let mut listener_socket = block_on(listener.incoming_stream().next()).unwrap()?; block_on(dialer.write_all(b"foo"))?; @@ -77,8 +91,8 @@ fn listener_correctly_frees_port_on_drop() -> Result<()> { Ok(()) } - connect_on_port(9)?; - connect_on_port(9)?; + connect_on_port("192.51.100.2:9".parse().unwrap())?; + connect_on_port("192.51.100.2:9".parse().unwrap())?; Ok(()) } diff --git a/tests/sync.rs b/tests/sync.rs index aad8a15..75e82af 100644 --- a/tests/sync.rs +++ b/tests/sync.rs @@ -1,5 +1,8 @@ use memory_socket::{MemoryListener, MemorySocket}; -use std::io::{Read, Result, Write}; +use std::{ + io::{Read, Result, Write}, + net::{IpAddr, Ipv4Addr, SocketAddr}, +}; // // MemoryListener Tests @@ -7,17 +10,28 @@ use std::io::{Read, Result, Write}; #[test] fn listener_bind() -> Result<()> { - let listener = MemoryListener::bind(42)?; - assert_eq!(listener.local_addr(), 42); + let listener = MemoryListener::bind("192.51.100.2:42").expect("Should listen on valid address"); + let expected = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 51, 100, 2)), 42); + let actual = listener + .local_addr() + .expect("Socket should have a local address"); + assert_eq!(actual, expected); Ok(()) } +#[test] +fn bind_unspecified() { + // Current implementation does not know how to handle unspecified address + let listener_result = MemoryListener::bind("0.0.0.0:0"); + assert!(listener_result.is_err()); +} + #[test] fn simple_connect() -> Result<()> { - let listener = MemoryListener::bind(10)?; + let listener = MemoryListener::bind("192.51.100.2:1337")?; - let mut dialer = MemorySocket::connect(10)?; + let mut dialer = MemorySocket::connect("192.51.100.2:1337")?; let mut listener_socket = listener.incoming().next().unwrap()?; dialer.write_all(b"foo")?; @@ -32,8 +46,13 @@ fn simple_connect() -> Result<()> { #[test] fn listen_on_port_zero() -> Result<()> { - let listener = MemoryListener::bind(0)?; - let listener_addr = listener.local_addr(); + let listener = MemoryListener::bind("192.51.100.3:0").expect("Should listen on port 0"); + let listener_addr = listener.local_addr().expect("Should have a local address"); + assert_eq!( + listener_addr.ip(), + IpAddr::V4(Ipv4Addr::new(192, 51, 100, 3)) + ); + assert_ne!(listener_addr.port(), 0); let mut dialer = MemorySocket::connect(listener_addr)?; let mut listener_socket = listener.incoming().next().unwrap()?; @@ -57,9 +76,9 @@ fn listen_on_port_zero() -> Result<()> { #[test] fn listener_correctly_frees_port_on_drop() -> Result<()> { - fn connect_on_port(port: u16) -> Result<()> { - let listener = MemoryListener::bind(port)?; - let mut dialer = MemorySocket::connect(port)?; + fn connect_to(address: SocketAddr) -> Result<()> { + let listener = MemoryListener::bind(address)?; + let mut dialer = MemorySocket::connect(address)?; let mut listener_socket = listener.incoming().next().unwrap()?; dialer.write_all(b"foo")?; @@ -72,8 +91,8 @@ fn listener_correctly_frees_port_on_drop() -> Result<()> { Ok(()) } - connect_on_port(9)?; - connect_on_port(9)?; + connect_to("192.51.100.3:9".parse().unwrap())?; + connect_to("192.51.100.3:9".parse().unwrap())?; Ok(()) }