Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TCP/UDP: new function is_listening: t -> ~port:int -> callback option #508

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/core/tcp.ml
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@ module type S = sig
and type write_error := write_error

val dst: flow -> ipaddr * int
val unread : flow -> Cstruct.t -> unit
val write_nodelay: flow -> Cstruct.t -> (unit, write_error) result Lwt.t
val writev_nodelay: flow -> Cstruct.t list -> (unit, write_error) result Lwt.t
val create_connection: ?keepalive:Keepalive.t -> t -> ipaddr * int -> (flow, error) result Lwt.t
val listen : t -> port:int -> ?keepalive:Keepalive.t -> (flow -> unit Lwt.t) -> unit
val is_listening : t -> port:int -> (flow -> unit Lwt.t) option
val unlisten : t -> port:int -> unit
val input: t -> src:ipaddr -> dst:ipaddr -> Cstruct.t -> unit Lwt.t
end
12 changes: 10 additions & 2 deletions src/core/tcp.mli
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ module type S = sig
(** Get the destination IP address and destination port that a
flow is currently connected to. *)

val unread : flow -> Cstruct.t -> unit
(** [unread flow buffer] puts [buffer] at the beginning of the receive queue,
so the next [read] from [flow] will receive [buffer]. *)

val write_nodelay: flow -> Cstruct.t -> (unit, write_error) result Lwt.t
(** [write_nodelay flow buffer] writes the contents of [buffer]
to the flow. The thread blocks until all data has been successfully
Expand Down Expand Up @@ -83,8 +87,12 @@ module type S = sig
executed for each flow that was established. If [keepalive] is provided,
this configuration will be applied before calling [callback].

@raise Invalid_argument if [port < 0] or [port > 65535]
*)
@raise Invalid_argument if [port < 0] or [port > 65535] *)

val is_listening : t -> port:int -> (flow -> unit Lwt.t) option
(** [is_listening t ~port] returns the [callback] on [port], if it exists.

@raise Invalid_argument if [port < 0] or [port > 65535] *)

val unlisten : t -> port:int -> unit
(** [unlisten t ~port] stops any listener on [port]. *)
Expand Down
1 change: 1 addition & 0 deletions src/core/udp.ml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ module type S = sig
val disconnect : t -> unit Lwt.t
type callback = src:ipaddr -> dst:ipaddr -> src_port:int -> Cstruct.t -> unit Lwt.t
val listen : t -> port:int -> callback -> unit
val is_listening : t -> port:int -> callback option
val unlisten : t -> port:int -> unit
val input: t -> src:ipaddr -> dst:ipaddr -> Cstruct.t -> unit Lwt.t
val write: ?src:ipaddr -> ?src_port:int -> ?ttl:int -> dst:ipaddr -> dst_port:int -> t -> Cstruct.t ->
Expand Down
5 changes: 5 additions & 0 deletions src/core/udp.mli
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ module type S = sig

@raise Invalid_argument if [port < 0] or [port > 65535] *)

val is_listening : t -> port:int -> callback option
(** [is_listening t ~port] returns the [callback] on [port], if it exists.

@raise Invalid_argument if [port < 0] or [port > 65535] *)

val unlisten : t -> port:int -> unit
(** [unlisten t ~port] stops any listeners on [port]. *)

Expand Down
2 changes: 1 addition & 1 deletion src/stack-unix/dune
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
(library
(name tcpv4v6_socket)
(public_name tcpip.tcpv4v6-socket)
(modules tcp_socket tcpv4v6_socket)
(modules tcpv4v6_socket)
(wrapped false)
(instrumentation
(backend bisect_ppx))
Expand Down
68 changes: 0 additions & 68 deletions src/stack-unix/tcp_socket.ml

This file was deleted.

112 changes: 96 additions & 16 deletions src/stack-unix/tcpv4v6_socket.ml
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@ module Log = (val Logs.src_log src : Logs.LOG)
open Lwt.Infix

type ipaddr = Ipaddr.t
type flow = Lwt_unix.file_descr
type flow = {
mutable buf : Cstruct.t;
fd : Lwt_unix.file_descr;
}

type t = {
interface: [ `Any | `Ip of Unix.inet_addr * Unix.inet_addr | `V4_only of Unix.inet_addr | `V6_only of Unix.inet_addr ]; (* source ip to bind to *)
mutable active_connections : Lwt_unix.file_descr list;
listen_sockets : (int, Lwt_unix.file_descr list) Hashtbl.t;
mutable active_connections : flow list;
listen_sockets : (int, Lwt_unix.file_descr list * (flow -> unit Lwt.t)) Hashtbl.t;
mutable switched_off : unit Lwt.t;
}

Expand All @@ -35,7 +38,75 @@ let set_switched_off t switched_off =

let any_v6 = Ipaddr_unix.V6.to_inet_addr Ipaddr.V6.unspecified

include Tcp_socket
type error = [ Tcpip.Tcp.error | `Exn of exn ]
type write_error = [ Tcpip.Tcp.write_error | `Exn of exn ]

let pp_error ppf = function
| #Tcpip.Tcp.error as e -> Tcpip.Tcp.pp_error ppf e
| `Exn e -> Fmt.exn ppf e

let pp_write_error ppf = function
| #Tcpip.Tcp.write_error as e -> Tcpip.Tcp.pp_write_error ppf e
| `Exn e -> Fmt.exn ppf e

let ignore_canceled = function
| Lwt.Canceled -> Lwt.return_unit
| exn -> raise exn

let read ({ buf ; fd } as flow) =
if Cstruct.length buf > 0 then begin
flow.buf <- Cstruct.empty;
Lwt.return (Ok (`Data buf))
end else
let buflen = 4096 in
let buf = Cstruct.create buflen in
Lwt.catch (fun () ->
Lwt_cstruct.read fd buf
>>= function
| 0 -> Lwt.return (Ok `Eof)
| n when n = buflen -> Lwt.return (Ok (`Data buf))
| n -> Lwt.return @@ Ok (`Data (Cstruct.sub buf 0 n))
)
(fun exn -> Lwt.return (Error (`Exn exn)))

let rec write ({ fd; _ } as flow) buf =
Lwt.catch
(fun () ->
Lwt_cstruct.write fd buf
>>= function
| n when n = Cstruct.length buf -> Lwt.return @@ Ok ()
| 0 -> Lwt.return @@ Error `Closed
| n -> write flow (Cstruct.sub buf n (Cstruct.length buf - n))
) (function
| Unix.Unix_error(Unix.EPIPE, _, _) -> Lwt.return @@ Error `Closed
| e -> Lwt.return (Error (`Exn e)))

let writev fd bufs =
Lwt_list.fold_left_s
(fun res buf ->
match res with
| Error _ as e -> Lwt.return e
| Ok () -> write fd buf
) (Ok ()) bufs

(* TODO make nodelay a flow option *)
let write_nodelay fd buf =
write fd buf

(* TODO make nodelay a flow option *)
let writev_nodelay fd bufs =
writev fd bufs

let close_fd fd =
Lwt.catch
(fun () -> Lwt_unix.close fd)
(function
| Unix.Unix_error (Unix.EBADF, _, _) -> Lwt.return_unit
| e -> Lwt.fail e)

let close { fd; _ } = close_fd fd

let input _t ~src:_ ~dst:_ _buf = Lwt.return_unit

let connect ~ipv4_only ~ipv6_only ipv4 ipv6 =
let interface =
Expand All @@ -62,11 +133,11 @@ let connect ~ipv4_only ~ipv6_only ipv4 ipv6 =

let disconnect t =
Lwt_list.iter_p close t.active_connections >>= fun () ->
Lwt_list.iter_p close
(Hashtbl.fold (fun _ fd acc -> fd @ acc) t.listen_sockets []) >>= fun () ->
Lwt_list.iter_p close_fd
(Hashtbl.fold (fun _ (fds, _) acc -> fds @ acc) t.listen_sockets []) >>= fun () ->
Lwt.cancel t.switched_off ; Lwt.return_unit

let dst fd =
let dst { fd; _ } =
match Lwt_unix.getpeername fd with
| Unix.ADDR_UNIX _ ->
raise (Failure "unexpected: got a unix instead of tcp sock")
Expand All @@ -78,6 +149,10 @@ let dst fd =
in
ip, port

let unread fd buf =
let buf = Cstruct.append buf fd.buf in
fd.buf <- buf
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what needs to be handled (for a complete, general API) is if a lwt task is already in Lwt_cstruct.read -- where the read should be cancelled and the buf provided here being returned to the caller.


let create_connection ?keepalive t (dst,dst_port) =
match
match dst, t.interface with
Expand All @@ -104,19 +179,23 @@ let create_connection ?keepalive t (dst,dst_port) =
| None -> ()
| Some { Tcpip.Tcp.Keepalive.after; interval; probes } ->
Tcp_socket_options.enable_keepalive ~fd ~after ~interval ~probes );
t.active_connections <- fd :: t.active_connections;
Lwt.return (Ok fd))
let flow = { buf = Cstruct.empty ; fd } in
t.active_connections <- flow :: t.active_connections;
Lwt.return (Ok flow))
(fun exn ->
close fd >>= fun () ->
close_fd fd >>= fun () ->
Lwt.return (Error (`Exn exn)))

let unlisten t ~port =
match Hashtbl.find_opt t.listen_sockets port with
| None -> ()
| Some fds ->
| Some (fds, _) ->
Hashtbl.remove t.listen_sockets port;
try List.iter (fun fd -> Unix.close (Lwt_unix.unix_file_descr fd)) fds with _ -> ()

let is_listening t ~port =
Option.map snd (Hashtbl.find_opt t.listen_sockets port)

let listen t ~port ?keepalive callback =
if port < 0 || port > 65535 then
raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port));
Expand Down Expand Up @@ -147,7 +226,7 @@ let listen t ~port ?keepalive callback =
in
List.iter (fun (fd, addr) ->
Unix.bind (Lwt_unix.unix_file_descr fd) addr;
Hashtbl.replace t.listen_sockets port (List.map fst fds);
Hashtbl.replace t.listen_sockets port (List.map fst fds, callback);
Lwt_unix.listen fd 10;
(* FIXME: we should not ignore the result *)
Lwt.async (fun () ->
Expand All @@ -156,18 +235,19 @@ let listen t ~port ?keepalive callback =
if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ;
Lwt.catch (fun () ->
Lwt_unix.accept fd >|= fun (afd, _) ->
t.active_connections <- afd :: t.active_connections;
let flow = { buf = Cstruct.empty ; fd = afd } in
t.active_connections <- flow :: t.active_connections;
(match keepalive with
| None -> ()
| Some { Tcpip.Tcp.Keepalive.after; interval; probes } ->
Tcp_socket_options.enable_keepalive ~fd:afd ~after ~interval ~probes);
Lwt.async
(fun () ->
Lwt.catch
(fun () -> callback afd)
(fun () -> callback flow)
(fun exn ->
Log.warn (fun m -> m "error %s in callback" (Printexc.to_string exn)) ;
close afd));
close flow));
`Continue)
(function
| Unix.Unix_error (Unix.EBADF, _, _) ->
Expand All @@ -179,4 +259,4 @@ let listen t ~port ?keepalive callback =
| `Continue -> loop ()
| `Stop -> Lwt.return_unit
in
Lwt.catch loop ignore_canceled >>= fun () -> close fd)) fds
Lwt.catch loop ignore_canceled >>= fun () -> close_fd fd)) fds
1 change: 0 additions & 1 deletion src/stack-unix/tcpv4v6_socket.mli
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

include Tcpip.Tcp.S
with type ipaddr = Ipaddr.t
and type flow = Lwt_unix.file_descr
Copy link
Member Author

@hannesm hannesm Apr 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any insight whether this is needed somewhere?

and type error = [ Tcpip.Tcp.error | `Exn of exn ]
and type write_error = [ Tcpip.Tcp.write_error | `Exn of exn ]

Expand Down
Loading