diff --git a/README.md b/README.md index c623682..8aa2bcc 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ Options: -identity string client identity sent to server -idle-time duration - max idle time for UDP session (default 1m30s) + max idle time for UDP session (default 30s) -key-length uint generate key with specified length (default 16) -mtu int @@ -91,6 +91,8 @@ Options: (server only) skip hello verify request. Useful to workaround DPI -stale-mode value which stale side of connection makes whole session stale (both, either, left, right) (default either) + -time-limit duration + hard time limit for each session -timeout duration network operation timeout (default 10s) ``` diff --git a/client/client.go b/client/client.go index efb203a..dc23d65 100644 --- a/client/client.go +++ b/client/client.go @@ -30,6 +30,7 @@ type Client struct { cancelCtx func() staleMode util.StaleMode workerWG sync.WaitGroup + timeLimit time.Duration } func New(cfg *Config) (*Client, error) { @@ -45,6 +46,7 @@ func New(cfg *Config) (*Client, error) { baseCtx: baseCtx, cancelCtx: cancelCtx, staleMode: cfg.StaleMode, + timeLimit: cfg.TimeLimit, } lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress) @@ -101,7 +103,14 @@ func (client *Client) serve(conn net.Conn) { defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr()) defer conn.Close() - dialCtx, cancel := context.WithTimeout(client.baseCtx, client.timeout) + ctx := client.baseCtx + if client.timeLimit != 0 { + newCtx, cancel := context.WithTimeout(ctx, client.timeLimit) + defer cancel() + ctx = newCtx + } + + dialCtx, cancel := context.WithTimeout(ctx, client.timeout) defer cancel() remoteConn, err := (&net.Dialer{}).DialContext(dialCtx, "udp", client.rAddr) if err != nil { @@ -116,7 +125,7 @@ func (client *Client) serve(conn net.Conn) { return } - util.PairConn(client.baseCtx, conn, remoteConn, client.idleTimeout, client.staleMode) + util.PairConn(ctx, conn, remoteConn, client.idleTimeout, client.staleMode) } func (client *Client) contextMaker() (context.Context, func()) { diff --git a/client/config.go b/client/config.go index e10a273..ea16b7f 100644 --- a/client/config.go +++ b/client/config.go @@ -20,6 +20,7 @@ type Config struct { CipherSuites ciphers.CipherList EllipticCurves ciphers.CurveList StaleMode util.StaleMode + TimeLimit time.Duration } func (cfg *Config) populateDefaults() *Config { diff --git a/cmd/dtlspipe/main.go b/cmd/dtlspipe/main.go index d3d889d..4e21479 100644 --- a/cmd/dtlspipe/main.go +++ b/cmd/dtlspipe/main.go @@ -72,6 +72,7 @@ var ( ciphersuites = cipherlistArg{} curves = curvelistArg{} staleMode = util.EitherStale + timeLimit = flag.Duration("time-limit", 0, "hard time limit for each session") ) func init() { @@ -139,6 +140,7 @@ func cmdClient(bindAddress, remoteAddress string) int { CipherSuites: ciphersuites.Value, EllipticCurves: curves.Value, StaleMode: staleMode, + TimeLimit: *timeLimit, } clt, err := client.New(&cfg) @@ -176,6 +178,7 @@ func cmdServer(bindAddress, remoteAddress string) int { CipherSuites: ciphersuites.Value, EllipticCurves: curves.Value, StaleMode: staleMode, + TimeLimit: *timeLimit, } srv, err := server.New(&cfg) diff --git a/server/config.go b/server/config.go index 2ac9a24..df9e19f 100644 --- a/server/config.go +++ b/server/config.go @@ -20,6 +20,7 @@ type Config struct { CipherSuites ciphers.CipherList EllipticCurves ciphers.CurveList StaleMode util.StaleMode + TimeLimit time.Duration } func (cfg *Config) populateDefaults() *Config { diff --git a/server/server.go b/server/server.go index 687504a..0e0b916 100644 --- a/server/server.go +++ b/server/server.go @@ -31,6 +31,7 @@ type Server struct { cancelCtx func() staleMode util.StaleMode workerWG sync.WaitGroup + timeLimit time.Duration } func New(cfg *Config) (*Server, error) { @@ -46,6 +47,7 @@ func New(cfg *Config) (*Server, error) { baseCtx: baseCtx, cancelCtx: cancelCtx, staleMode: cfg.StaleMode, + timeLimit: cfg.TimeLimit, } lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress) @@ -119,7 +121,14 @@ func (srv *Server) serve(conn net.Conn) { defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr()) defer conn.Close() - dialCtx, cancel := context.WithTimeout(srv.baseCtx, srv.timeout) + ctx := srv.baseCtx + if srv.timeLimit != 0 { + newCtx, cancel := context.WithTimeout(ctx, srv.timeLimit) + defer cancel() + ctx = newCtx + } + + dialCtx, cancel := context.WithTimeout(ctx, srv.timeout) defer cancel() remoteConn, err := (&net.Dialer{}).DialContext(dialCtx, "udp", srv.rAddr) if err != nil { @@ -128,7 +137,7 @@ func (srv *Server) serve(conn net.Conn) { } defer remoteConn.Close() - util.PairConn(srv.baseCtx, conn, remoteConn, srv.idleTimeout, srv.staleMode) + util.PairConn(ctx, conn, remoteConn, srv.idleTimeout, srv.staleMode) } func (srv *Server) contextMaker() (context.Context, func()) {