Skip to content

Commit

Permalink
Prototype listening on SocketAddr instead of port
Browse files Browse the repository at this point in the history
The basic idea is the same, but it better mimics
the normal socket interface.

The Switchboard can now be thought of as
a parody of the entire internet
instead of a single machine.

Most everything is a direct mapping,
but there are two uncertain situations:
 *  When connect() is called, there is no clear indicator
    as to where the client is connecting *from*.
    So if we ever wanted to implement e.g. peer_addr
    on a MemorySocket,
    the behaviour would be undefined.
 *  It is not clear what address to listen on when
     connecting to 0.0.0.0.

Both of these *could* be solved by mapping to the local
machine address (using e.g the get_if_addrs crate),
but I'm not sure that make sense.
Another option I toyed with was having a "set_connect_address"
global, but that simply lacked elegance.

( Discussion in bmwill#3 )
  • Loading branch information
Dusty Phillips committed Jun 5, 2020
1 parent 83213a8 commit b5f8e8c
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 85 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ bytes = "0.5"
flume = { version = "0.7", default-features = false }
futures = { version = "0.3", optional = true }
once_cell = "1.3"
rand = "0.7.3"

[features]
# Include nothing by default
Expand Down
2 changes: 1 addition & 1 deletion src/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl MemoryListener {
/// use memory_socket::MemoryListener;
///
/// # async fn work () -> ::std::io::Result<()> {
/// let mut listener = MemoryListener::bind(80).unwrap();
/// let mut listener = MemoryListener::bind("192.51.100.2:60").unwrap();
/// let mut incoming = listener.incoming_stream();
///
/// while let Some(stream) = incoming.next().await {
Expand Down
122 changes: 62 additions & 60 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
use bytes::{buf::BufExt, Buf, Bytes, BytesMut};
use flume::{Receiver, Sender};
use once_cell::sync::Lazy;
use rand::{thread_rng, Rng};
use std::{
collections::HashMap,
io::{ErrorKind, Read, Result, Write},
num::NonZeroU16,
net::{SocketAddr, ToSocketAddrs},
sync::Mutex,
};

Expand All @@ -27,10 +28,11 @@ mod r#async;
#[cfg(feature = "async")]
pub use r#async::IncomingStream;

/// Collection of open connected sockets
static SWITCHBOARD: Lazy<Mutex<SwitchBoard>> =
Lazy::new(|| Mutex::new(SwitchBoard(HashMap::default(), 1)));
Lazy::new(|| Mutex::new(SwitchBoard(HashMap::default())));

struct SwitchBoard(HashMap<NonZeroU16, Sender<MemorySocket>>, u16);
struct SwitchBoard(HashMap<SocketAddr, Sender<MemorySocket>>);

/// An in-memory socket server, listening for connections.
///
Expand Down Expand Up @@ -59,7 +61,7 @@ struct SwitchBoard(HashMap<NonZeroU16, Sender<MemorySocket>>, u16);
/// }
///
/// fn main() -> Result<()> {
/// let mut listener = MemoryListener::bind(16)?;
/// let mut listener = MemoryListener::bind("192.51.100.2:1337")?;
///
/// // accept connections and process them serially
/// for stream in listener.incoming() {
Expand All @@ -70,15 +72,15 @@ struct SwitchBoard(HashMap<NonZeroU16, Sender<MemorySocket>>, u16);
/// ```
pub struct MemoryListener {
incoming: Receiver<MemorySocket>,
port: NonZeroU16,
address: SocketAddr,
}

impl Drop for MemoryListener {
fn drop(&mut self) {
let mut switchboard = (&*SWITCHBOARD).lock().unwrap();
// Remove the Sending side of the channel in the switchboard when
// MemoryListener is dropped
switchboard.0.remove(&self.port);
switchboard.0.remove(&self.address);
}
}

Expand All @@ -102,48 +104,49 @@ impl MemoryListener {
/// use memory_socket::MemoryListener;
///
/// # fn main () -> ::std::io::Result<()> {
/// let listener = MemoryListener::bind(16)?;
/// let listener = MemoryListener::bind("192.51.100.2:1337")?;
/// # Ok(())}
/// ```
pub fn bind(port: u16) -> Result<Self> {
pub fn bind<A: ToSocketAddrs>(addresses: A) -> Result<Self> {
let mut switchboard = (&*SWITCHBOARD).lock().unwrap();

// Get the port we should bind to. If 0 was given, use a random port
let port = if let Some(port) = NonZeroU16::new(port) {
if switchboard.0.contains_key(&port) {
return Err(ErrorKind::AddrInUse.into());
}

port
} else {
loop {
let port = NonZeroU16::new(switchboard.1).unwrap_or_else(|| unreachable!());
let mut addresses = addresses.to_socket_addrs()?;

// The switchboard is full and all ports are in use
if switchboard.0.len() == (std::u16::MAX - 1) as usize {
return Err(ErrorKind::AddrInUse.into());
}
let mut address = match addresses.next() {
Some(address) => address,
None => return Err(ErrorKind::AddrNotAvailable.into()),
};

// Instead of overflowing to 0, resume searching at port 1 since port 0 isn't a
// valid port to bind to.
if switchboard.1 == std::u16::MAX {
switchboard.1 = 1;
} else {
switchboard.1 += 1;
}
// It doesn't really make sense to listen on multiple interfaces in
// this environment, so we place a restriction on the parameter.
if addresses.next().is_some() {
return Err(ErrorKind::AddrNotAvailable.into());
}
// Similarly, it doesn't make a sense to listen on "all interfaces"
// in this environment, so return an error if they requested 0.0.0.0
// TODO: We could use get_if_addrs and use the host's real name?
if address.ip().is_unspecified() {
return Err(ErrorKind::AddrNotAvailable.into());
}

if !switchboard.0.contains_key(&port) {
break port;
}
// If they didn't provide a port find one that isn't in use.
if address.port() == 0 {
let mut rng = thread_rng();
address.set_port(rng.gen());
while switchboard.0.contains_key(&address) {
address.set_port(rng.gen());
}
};
} else if switchboard.0.contains_key(&address) {
// Can't listen on the same address and port twice
return Err(ErrorKind::AddrInUse.into());
}

let (sender, receiver) = flume::unbounded();
switchboard.0.insert(port, sender);
switchboard.0.insert(address, sender);

Ok(Self {
incoming: receiver,
port,
address,
})
}

Expand All @@ -156,15 +159,17 @@ impl MemoryListener {
///
/// ```
/// use memory_socket::MemoryListener;
/// use std::net::SocketAddr;
///
/// # fn main () -> ::std::io::Result<()> {
/// let listener = MemoryListener::bind(16)?;
/// let listener = MemoryListener::bind("192.51.100.2:1337")?;
///
/// assert_eq!(listener.local_addr(), 16);
/// let expected: SocketAddr = "192.51.100.2:1337".parse().unwrap();
/// assert_eq!(listener.local_addr().unwrap(), expected);
/// # Ok(())}
/// ```
pub fn local_addr(&self) -> u16 {
self.port.get()
pub fn local_addr(&self) -> Result<SocketAddr> {
Ok(self.address)
}

/// Returns an iterator over the connections being received on this
Expand All @@ -181,7 +186,7 @@ impl MemoryListener {
/// use memory_socket::MemoryListener;
/// use std::io::{Read, Write};
///
/// let mut listener = MemoryListener::bind(80).unwrap();
/// let mut listener = MemoryListener::bind("192.51.100.2:1337").unwrap();
///
/// for stream in listener.incoming() {
/// match stream {
Expand Down Expand Up @@ -210,7 +215,7 @@ impl MemoryListener {
/// use std::net::TcpListener;
/// use memory_socket::MemoryListener;
///
/// let mut listener = MemoryListener::bind(8080).unwrap();
/// let mut listener = MemoryListener::bind("192.51.100.2:8080").unwrap();
/// match listener.accept() {
/// Ok(_socket) => println!("new client!"),
/// Err(e) => println!("couldn't get client: {:?}", e),
Expand Down Expand Up @@ -317,29 +322,26 @@ impl MemorySocket {
/// use memory_socket::MemorySocket;
///
/// # fn main () -> ::std::io::Result<()> {
/// # let _listener = memory_socket::MemoryListener::bind(16)?;
/// let socket = MemorySocket::connect(16)?;
/// # let _listener = memory_socket::MemoryListener::bind("192.51.100.2:60")?;
/// let socket = MemorySocket::connect("192.51.100.2:60")?;
/// # Ok(())}
/// ```
pub fn connect(port: u16) -> Result<MemorySocket> {
pub fn connect<A: ToSocketAddrs>(addresses: A) -> Result<MemorySocket> {
let mut switchboard = (&*SWITCHBOARD).lock().unwrap();
let addresses = addresses.to_socket_addrs()?;
for address in addresses {
if let Some(sender) = switchboard.0.get_mut(&address) {
let (socket_a, socket_b) = Self::new_pair();
// Send the socket to the listener
sender
.send(socket_a)
.map_err(|_| ErrorKind::AddrNotAvailable)?;

return Ok(socket_b);
}
}

// Find port to connect to
let port = NonZeroU16::new(port).ok_or_else(|| ErrorKind::AddrNotAvailable)?;

let sender = switchboard
.0
.get_mut(&port)
.ok_or_else(|| ErrorKind::AddrNotAvailable)?;

let (socket_a, socket_b) = Self::new_pair();

// Send the socket to the listener
sender
.send(socket_a)
.map_err(|_| ErrorKind::AddrNotAvailable)?;

Ok(socket_b)
Err(ErrorKind::AddrNotAvailable.into())
}
}

Expand Down
38 changes: 26 additions & 12 deletions tests/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,39 @@ use futures::{
stream::StreamExt,
};
use memory_socket::{MemoryListener, MemorySocket};
use std::io::Result;
use std::{
io::Result,
net::{IpAddr, Ipv4Addr, SocketAddr},
};

//
// MemoryListener Tests
//

#[test]
fn listener_bind() -> Result<()> {
let listener = MemoryListener::bind(42)?;
assert_eq!(listener.local_addr(), 42);
let listener = MemoryListener::bind("192.51.100.2:42")?;
let expected = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 51, 100, 2)), 42);
let actual = listener
.local_addr()
.expect("Socket should have a local address");
assert_eq!(actual, expected);

Ok(())
}

#[test]
fn bind_unspecified() {
// Current implementation does not know how to handle unspecified address
let listener_result = MemoryListener::bind("0.0.0.0:0");
assert!(listener_result.is_err());
}

#[test]
fn simple_connect() -> Result<()> {
let mut listener = MemoryListener::bind(10)?;
let mut listener = MemoryListener::bind("192.51.100.2:10")?;

let mut dialer = MemorySocket::connect(10)?;
let mut dialer = MemorySocket::connect("192.51.100.2:10")?;
let mut listener_socket = block_on(listener.incoming_stream().next()).unwrap()?;

block_on(dialer.write_all(b"foo"))?;
Expand All @@ -37,8 +51,8 @@ fn simple_connect() -> Result<()> {

#[test]
fn listen_on_port_zero() -> Result<()> {
let mut listener = MemoryListener::bind(0)?;
let listener_addr = listener.local_addr();
let mut listener = MemoryListener::bind("192.51.100.2:0")?;
let listener_addr = listener.local_addr().expect("That is a valid address");

let mut dialer = MemorySocket::connect(listener_addr)?;
let mut listener_socket = block_on(listener.incoming_stream().next()).unwrap()?;
Expand All @@ -62,9 +76,9 @@ fn listen_on_port_zero() -> Result<()> {

#[test]
fn listener_correctly_frees_port_on_drop() -> Result<()> {
fn connect_on_port(port: u16) -> Result<()> {
let mut listener = MemoryListener::bind(port)?;
let mut dialer = MemorySocket::connect(port)?;
fn connect_on_port(address: SocketAddr) -> Result<()> {
let mut listener = MemoryListener::bind(address)?;
let mut dialer = MemorySocket::connect(address)?;
let mut listener_socket = block_on(listener.incoming_stream().next()).unwrap()?;

block_on(dialer.write_all(b"foo"))?;
Expand All @@ -77,8 +91,8 @@ fn listener_correctly_frees_port_on_drop() -> Result<()> {
Ok(())
}

connect_on_port(9)?;
connect_on_port(9)?;
connect_on_port("192.51.100.2:9".parse().unwrap())?;
connect_on_port("192.51.100.2:9".parse().unwrap())?;

Ok(())
}
Expand Down
43 changes: 31 additions & 12 deletions tests/sync.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,37 @@
use memory_socket::{MemoryListener, MemorySocket};
use std::io::{Read, Result, Write};
use std::{
io::{Read, Result, Write},
net::{IpAddr, Ipv4Addr, SocketAddr},
};

//
// MemoryListener Tests
//

#[test]
fn listener_bind() -> Result<()> {
let listener = MemoryListener::bind(42)?;
assert_eq!(listener.local_addr(), 42);
let listener = MemoryListener::bind("192.51.100.2:42").expect("Should listen on valid address");
let expected = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 51, 100, 2)), 42);
let actual = listener
.local_addr()
.expect("Socket should have a local address");
assert_eq!(actual, expected);

Ok(())
}

#[test]
fn bind_unspecified() {
// Current implementation does not know how to handle unspecified address
let listener_result = MemoryListener::bind("0.0.0.0:0");
assert!(listener_result.is_err());
}

#[test]
fn simple_connect() -> Result<()> {
let listener = MemoryListener::bind(10)?;
let listener = MemoryListener::bind("192.51.100.2:1337")?;

let mut dialer = MemorySocket::connect(10)?;
let mut dialer = MemorySocket::connect("192.51.100.2:1337")?;
let mut listener_socket = listener.incoming().next().unwrap()?;

dialer.write_all(b"foo")?;
Expand All @@ -32,8 +46,13 @@ fn simple_connect() -> Result<()> {

#[test]
fn listen_on_port_zero() -> Result<()> {
let listener = MemoryListener::bind(0)?;
let listener_addr = listener.local_addr();
let listener = MemoryListener::bind("192.51.100.3:0").expect("Should listen on port 0");
let listener_addr = listener.local_addr().expect("Should have a local address");
assert_eq!(
listener_addr.ip(),
IpAddr::V4(Ipv4Addr::new(192, 51, 100, 3))
);
assert_ne!(listener_addr.port(), 0);

let mut dialer = MemorySocket::connect(listener_addr)?;
let mut listener_socket = listener.incoming().next().unwrap()?;
Expand All @@ -57,9 +76,9 @@ fn listen_on_port_zero() -> Result<()> {

#[test]
fn listener_correctly_frees_port_on_drop() -> Result<()> {
fn connect_on_port(port: u16) -> Result<()> {
let listener = MemoryListener::bind(port)?;
let mut dialer = MemorySocket::connect(port)?;
fn connect_to(address: SocketAddr) -> Result<()> {
let listener = MemoryListener::bind(address)?;
let mut dialer = MemorySocket::connect(address)?;
let mut listener_socket = listener.incoming().next().unwrap()?;

dialer.write_all(b"foo")?;
Expand All @@ -72,8 +91,8 @@ fn listener_correctly_frees_port_on_drop() -> Result<()> {
Ok(())
}

connect_on_port(9)?;
connect_on_port(9)?;
connect_to("192.51.100.3:9".parse().unwrap())?;
connect_to("192.51.100.3:9".parse().unwrap())?;

Ok(())
}
Expand Down

0 comments on commit b5f8e8c

Please sign in to comment.