Add EventLoop.create_connection impl (#22)
gi0baro authored Jan 23, 2025
1 parent 1bd0c96 commit 2875b8d
Showing 5 changed files with 259 additions and 13 deletions.
6 changes: 5 additions & 1 deletion rloop/_rloop.pyi
@@ -1,8 +1,10 @@
from typing import Any
from typing import Any, Callable, Tuple, TypeVar
from weakref import WeakSet

__version__: str

T = TypeVar('T')

class CBHandle:
    def cancel(self): ...
    def cancelled(self) -> bool: ...
@@ -45,6 +47,8 @@ class EventLoop:
    def _sig_clear(self): ...
    def _ssock_set(self, fd): ...
    def _ssock_del(self, fd): ...
    def _tcp_conn(self, sock, protocol_factory: Callable[[], T]) -> Tuple[Any, T]: ...
    def _tcp_server(self, socks, rsocks, protocol_factory, backlog) -> Server: ...
    def call_soon(self, callback, *args, context=None) -> CBHandle: ...
    def call_soon_threadsafe(self, callback, *args, context=None) -> CBHandle: ...

161 changes: 159 additions & 2 deletions rloop/loop.py
@@ -10,6 +10,7 @@
from asyncio.coroutines import iscoroutine as _iscoroutine, iscoroutinefunction as _iscoroutinefunction
from asyncio.events import _get_running_loop, _set_running_loop
from asyncio.futures import Future as _Future, isfuture as _isfuture, wrap_future as _wrap_future
from asyncio.staggered import staggered_race as _staggered_race
from asyncio.tasks import Task as _Task, ensure_future as _ensure_future, gather as _gather
from concurrent.futures import ThreadPoolExecutor
from contextvars import copy_context as _copy_context
@@ -27,7 +28,7 @@
    _SubProcessTransport,
    _ThreadedChildWatcher,
)
from .utils import _can_use_pidfd, _HAS_IPv6, _ipaddr_info, _noop, _set_reuseport
from .utils import _can_use_pidfd, _HAS_IPv6, _interleave_addrinfos, _ipaddr_info, _noop, _set_reuseport


class RLoop(__BaseLoop, __asyncio.AbstractEventLoop):
@@ -312,8 +313,164 @@ async def create_connection(
        ssl_shutdown_timeout=None,
        happy_eyeballs_delay=None,
        interleave=None,
        all_errors=False,
    ):
        raise NotImplementedError
        # TODO
        if ssl:
            raise NotImplementedError

        if server_hostname is not None and not ssl:
            raise ValueError('server_hostname is only meaningful with ssl')

        if server_hostname is None and ssl:
            if not host:
                raise ValueError('You must set server_hostname when using ssl without a host')
            server_hostname = host

        if ssl_handshake_timeout is not None and not ssl:
            raise ValueError('ssl_handshake_timeout is only meaningful with ssl')

        if ssl_shutdown_timeout is not None and not ssl:
            raise ValueError('ssl_shutdown_timeout is only meaningful with ssl')

        # TODO
        # if sock is not None:
        # _check_ssl_socket(sock)

        if happy_eyeballs_delay is not None and interleave is None:
            # If using happy eyeballs, default to interleave addresses by family
            interleave = 1

        if host is not None or port is not None:
            if sock is not None:
                raise ValueError('host/port and sock can not be specified at the same time')

            infos = await self._ensure_resolved(
                (host, port), family=family, type=socket.SOCK_STREAM, proto=proto, flags=flags, loop=self
            )
            if not infos:
                raise OSError('getaddrinfo() returned empty list')

            if local_addr is not None:
                laddr_infos = await self._ensure_resolved(
                    local_addr, family=family, type=socket.SOCK_STREAM, proto=proto, flags=flags, loop=self
                )
                if not laddr_infos:
                    raise OSError('getaddrinfo() returned empty list')
            else:
                laddr_infos = None

            if interleave:
                infos = _interleave_addrinfos(infos, interleave)

            exceptions = []
            if happy_eyeballs_delay is None:
                # not using happy eyeballs
                for addrinfo in infos:
                    try:
                        sock = await self._connect_sock(exceptions, addrinfo, laddr_infos)
                        break
                    except OSError:
                        continue
            else: # using happy eyeballs
                sock = (
                    await _staggered_race(
                        (
                            # can't use functools.partial as it keeps a reference
                            # to exceptions
                            lambda addrinfo=addrinfo: self._connect_sock(exceptions, addrinfo, laddr_infos)
                            for addrinfo in infos
                        ),
                        happy_eyeballs_delay,
                        loop=self,
                    )
                )[0] # can't use sock, _, _ as it keeps a reference to exceptions

            if sock is None:
                exceptions = [exc for sub in exceptions for exc in sub]
                try:
                    if all_errors:
                        raise ExceptionGroup('create_connection failed', exceptions)
                    if len(exceptions) == 1:
                        raise exceptions[0]
                    else:
                        # If they all have the same str(), raise one.
                        model = str(exceptions[0])
                        if all(str(exc) == model for exc in exceptions):
                            raise exceptions[0]
                        # Raise a combined exception so the user can see all
                        # the various error messages.
                        raise OSError('Multiple exceptions: {}'.format(', '.join(str(exc) for exc in exceptions)))
                finally:
                    exceptions = None

        else:
            if sock is None:
                raise ValueError('host and port was not specified and no sock specified')
            if sock.type != socket.SOCK_STREAM:
                # We allow AF_INET, AF_INET6, AF_UNIX as long as they
                # are SOCK_STREAM.
                # We support passing AF_UNIX sockets even though we have
                # a dedicated API for that: create_unix_connection.
                # Disallowing AF_UNIX in this method, breaks backwards
                # compatibility.
                raise ValueError(f'A Stream Socket was expected, got {sock!r}')

        sock.setblocking(False)
        rsock = (sock.fileno(), sock.family)
        sock.detach()

        # TODO: ssl
        transport, protocol = self._tcp_conn(rsock, protocol_factory)
        # transport, protocol = await self._create_connection_transport(
        # sock,
        # protocol_factory,
        # ssl,
        # server_hostname,
        # ssl_handshake_timeout=ssl_handshake_timeout,
        # ssl_shutdown_timeout=ssl_shutdown_timeout,
        # )

        return transport, protocol

    async def _connect_sock(self, exceptions, addr_info, local_addr_infos=None):
        my_exceptions = []
        exceptions.append(my_exceptions)
        family, type_, proto, _, address = addr_info
        sock = None
        try:
            sock = socket.socket(family=family, type=type_, proto=proto)
            sock.setblocking(False)
            if local_addr_infos is not None:
                for lfamily, _, _, _, laddr in local_addr_infos:
                    # skip local addresses of different family
                    if lfamily != family:
                        continue
                    try:
                        sock.bind(laddr)
                        break
                    except OSError as exc:
                        msg = f'error while attempting to bind on address {laddr!r}: {str(exc).lower()}'
                        exc = OSError(exc.errno, msg)
                        my_exceptions.append(exc)
                else: # all bind attempts failed
                    if my_exceptions:
                        raise my_exceptions.pop()
                    else:
                        raise OSError(f'no matching local address with {family=} found')
            await self.sock_connect(sock, address)
            return sock
        except OSError as exc:
            my_exceptions.append(exc)
            if sock is not None:
                sock.close()
            raise
        except:
            if sock is not None:
                sock.close()
            raise
        finally:
            exceptions = my_exceptions = None

    async def create_server(
        self,
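For reference, a minimal usage sketch of the path implemented above, assuming rloop's event loop has been installed as the running asyncio loop (the installation/policy step is not part of this diff); the host, port, and EchoClient protocol below are illustrative only:

import asyncio


class EchoClient(asyncio.Protocol):
    def __init__(self, done: asyncio.Future):
        self.done = done
        self.transport = None

    def connection_made(self, transport):
        self.transport = transport
        transport.write(b'ping')

    def data_received(self, data):
        print('received:', data)
        self.transport.close()

    def connection_lost(self, exc):
        if not self.done.done():
            self.done.set_result(True)


async def main():
    loop = asyncio.get_running_loop()
    done = loop.create_future()
    # create_connection resolves the target, connects the socket via
    # _connect_sock, then hands the detached (fd, family) pair to the
    # Rust-side _tcp_conn binding shown above.
    transport, protocol = await loop.create_connection(
        lambda: EchoClient(done),
        host='127.0.0.1',
        port=8000,
        happy_eyeballs_delay=0.25,  # exercises the _staggered_race branch
    )
    await done

# run main() with the rloop event loop installed (not shown here)

With happy_eyeballs_delay set and more than one resolved address, candidate connections are raced via asyncio.staggered.staggered_race; without it, addresses are tried sequentially. Note that the ssl branch still raises NotImplementedError in this commit, so TLS arguments cannot be exercised yet.
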
19 changes: 19 additions & 0 deletions rloop/utils.py
@@ -1,3 +1,5 @@
import collections
import itertools
import os
import socket

@@ -79,6 +81,23 @@ def _ipaddr_info(host, port, family, type, proto, flowinfo=0, scopeid=0):
    return None


def _interleave_addrinfos(addrinfos, first_address_family_count=1):
    addrinfos_by_family = collections.OrderedDict()
    for addr in addrinfos:
        family = addr[0]
        if family not in addrinfos_by_family:
            addrinfos_by_family[family] = []
        addrinfos_by_family[family].append(addr)
    addrinfos_lists = list(addrinfos_by_family.values())

    reordered = []
    if first_address_family_count > 1:
        reordered.extend(addrinfos_lists[0][: first_address_family_count - 1])
        del addrinfos_lists[0][: first_address_family_count - 1]
    reordered.extend(a for a in itertools.chain.from_iterable(itertools.zip_longest(*addrinfos_lists)) if a is not None)
    return reordered


def _set_reuseport(sock):
    if not hasattr(socket, 'SO_REUSEPORT'):
        raise ValueError('reuse_port not supported by socket module')
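To illustrate the new helper: with made-up addrinfo tuples, _interleave_addrinfos alternates the resolved addresses by family, and first_address_family_count keeps extra entries of the first family at the front. A quick sketch, not part of the commit:

import socket

from rloop.utils import _interleave_addrinfos

v6_a = (socket.AF_INET6, socket.SOCK_STREAM, 6, '', ('2001:db8::1', 443, 0, 0))
v6_b = (socket.AF_INET6, socket.SOCK_STREAM, 6, '', ('2001:db8::2', 443, 0, 0))
v4_a = (socket.AF_INET, socket.SOCK_STREAM, 6, '', ('192.0.2.1', 443))
v4_b = (socket.AF_INET, socket.SOCK_STREAM, 6, '', ('192.0.2.2', 443))

# default: families alternate -> [v6_a, v4_a, v6_b, v4_b]
print(_interleave_addrinfos([v6_a, v6_b, v4_a, v4_b]))

# keep two first-family entries up front -> [v6_a, v6_b, v4_a, v4_b]
print(_interleave_addrinfos([v6_a, v6_b, v4_a, v4_b], first_address_family_count=2))

create_connection calls this helper whenever interleave is set, and defaults interleave to 1 when happy_eyeballs_delay is given.
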
29 changes: 24 additions & 5 deletions src/event_loop.rs
Expand Up @@ -18,7 +18,7 @@ use crate::{
log::{log_exc_to_py_ctx, LogExc},
py::{copy_context, weakset},
server::Server,
tcp::{TCPReadHandle, TCPServer, TCPServerRef, TCPStream, TCPWriteHandle},
tcp::{PyTCPTransport, TCPReadHandle, TCPServer, TCPServerRef, TCPStream, TCPWriteHandle},
time::Timer,
};

@@ -413,10 +413,12 @@ impl EventLoop {
    pub(crate) fn tcp_stream_close(&self, fd: usize) {
        // println!("tcp_stream_close {:?}", fd);
        if let Some((_, stream)) = self.tcp_streams.remove(&fd) {
            self.tcp_lstreams.alter(&stream.lfd, |_, mut v| {
                v.remove(&fd);
                v
            });
            if let Some(lfd) = &stream.lfd {
                self.tcp_lstreams.alter(lfd, |_, mut v| {
                    v.remove(&fd);
                    v
                });
            }
        }
    }

@@ -988,6 +990,22 @@ impl EventLoop {
        })
    }

    fn _tcp_conn(
        pyself: Py<Self>,
        py: Python,
        sock: (i32, i32),
        protocol_factory: PyObject,
    ) -> PyResult<(Py<PyTCPTransport>, PyObject)> {
        let rself = pyself.get();
        let stream = TCPStream::from_py(py, &pyself, sock, protocol_factory);
        let transport = stream.pytransport.clone_ref(py);
        let fd = transport.get().fd;
        let proto = PyTCPTransport::attach(&transport, py)?;
        rself.tcp_streams.insert(fd, stream);
        rself.tcp_stream_add(fd, Interest::READABLE);
        Ok((transport, proto))
    }

    fn _tcp_server(
        pyself: Py<Self>,
        py: Python,
@@ -1020,6 +1038,7 @@
    fn _run(&self, py: Python) -> PyResult<()> {
        let mut state = EventLoopRunState {
            events: event::Events::with_capacity(128),
            #[allow(clippy::large_stack_arrays)]
            read_buf: [0; 262_144].into(),
            tick_last: 0,
        };
57 changes: 52 additions & 5 deletions src/tcp.rs
Expand Up @@ -129,7 +129,7 @@ impl TCPServerRef {
);

(
TCPStream::new(
TCPStream::from_listener(
self.fd,
stream,
pytransport.into(),
@@ -143,7 +143,7 @@
}

pub(crate) struct TCPStream {
    pub lfd: usize,
    pub lfd: Option<usize>,
    pub io: TcpStream,
    pub pytransport: Arc<Py<PyTCPTransport>>,
    read_buffered: bool,
@@ -153,16 +153,16 @@ pub(crate) struct TCPStream {
}

impl TCPStream {
    fn new(
        lfd: usize,
    fn from_listener(
        fd: usize,
        stream: TcpStream,
        pytransport: Arc<Py<PyTCPTransport>>,
        read_buffered: bool,
        pym_recv_data: Arc<PyObject>,
        pym_buf_get: PyObject,
    ) -> Self {
        Self {
            lfd,
            lfd: Some(fd),
            io: stream,
            pytransport,
            read_buffered,
@@ -171,6 +171,45 @@
            pym_buf_get,
        }
    }

    pub(crate) fn from_py(py: Python, pyloop: &Py<EventLoop>, pysock: (i32, i32), proto_factory: PyObject) -> Self {
        let sock = unsafe { socket2::Socket::from_raw_fd(pysock.0) };
        _ = sock.set_nonblocking(true);
        let stdl: std::net::TcpStream = sock.into();
        let stream = TcpStream::from_std(stdl);
        // let stream = TcpStream::from_raw_fd(rsock);

        let proto = proto_factory.bind(py).call0().unwrap();
        let mut buffered_proto = false;
        let pym_recv_data: PyObject;
        let pym_buf_get: PyObject;
        if proto.is_instance(asyncio_proto_buf(py).unwrap()).unwrap() {
            buffered_proto = true;
            pym_recv_data = proto.getattr(pyo3::intern!(py, "buffer_updated")).unwrap().unbind();
            pym_buf_get = proto.getattr(pyo3::intern!(py, "get_buffer")).unwrap().unbind();
        } else {
            pym_recv_data = proto.getattr(pyo3::intern!(py, "data_received")).unwrap().unbind();
            pym_buf_get = py.None();
        }
        let pyproto = proto.unbind();
        let pytransport = PyTCPTransport::new(
            py,
            stream.as_raw_fd() as usize,
            pysock.1,
            pyloop.clone_ref(py),
            pyproto.clone_ref(py),
        );

        Self {
            lfd: None,
            io: stream,
            pytransport: pytransport.into(),
            read_buffered: buffered_proto,
            write_buffer: VecDeque::new(),
            pym_recv_data: pym_recv_data.into(),
            pym_buf_get,
        }
    }
}

#[pyclass(frozen)]
@@ -217,6 +256,14 @@ impl PyTCPTransport {
            .unwrap()
    }

    pub(crate) fn attach(pyself: &Py<Self>, py: Python) -> PyResult<PyObject> {
        let rself = pyself.get();
        rself
            .proto
            .call_method1(py, pyo3::intern!(py, "connection_made"), (pyself.clone_ref(py),))?;
        Ok(rself.proto.clone_ref(py))
    }

    #[inline]
    fn write_buf_size_decr(pyself: &Py<Self>, py: Python, val: usize) {
        // println!("tcp write_buf_size_decr {:?}", val);
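As TCPStream::from_py above shows, the read path is chosen from the protocol type: an asyncio.BufferedProtocol is wired to get_buffer/buffer_updated, anything else to data_received. A minimal buffered client that would take the first branch (illustrative only, not part of the commit):

import asyncio


class BufferedEcho(asyncio.BufferedProtocol):
    def __init__(self):
        self._buf = bytearray(64 * 1024)

    def connection_made(self, transport):
        transport.write(b'ping')

    def get_buffer(self, sizehint):
        # Called to obtain a writable buffer for the next read.
        return self._buf

    def buffer_updated(self, nbytes):
        # Called with the number of bytes just read into that buffer.
        print('received:', bytes(self._buf[:nbytes]))

    def connection_lost(self, exc):
        pass


# transport, protocol = await loop.create_connection(BufferedEcho, '127.0.0.1', 8000)

The protocol object returned alongside the transport is the same one PyTCPTransport::attach passes to connection_made, so the (transport, protocol) pair mirrors what asyncio's default loop returns.
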
