diff --git a/cipher/method_none.go b/cipher/method_none.go index 22775d2..948b79b 100644 --- a/cipher/method_none.go +++ b/cipher/method_none.go @@ -39,9 +39,8 @@ func (m *noneMethod) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.C } func (m *noneMethod) DialPacketConn(conn net.Conn) N.NetPacketConn { - return &nonePacketConn{ - ExtendedConn: bufio.NewExtendedConn(conn), - } + extendedConn := bufio.NewExtendedConn(conn) + return &nonePacketConn{extendedConn, extendedConn} } var ( @@ -113,11 +112,12 @@ var ( ) type nonePacketConn struct { - N.ExtendedConn + N.AbstractConn + conn N.ExtendedConn } func (c *nonePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - n, err = c.ExtendedConn.Read(p) + n, err = c.conn.Read(p) if err != nil { return } @@ -144,7 +144,7 @@ func (c *nonePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { return } common.Must1(buffer.Write(p)) - _, err = c.ExtendedConn.Write(buffer.Bytes()) + _, err = c.conn.Write(buffer.Bytes()) if err != nil { return } @@ -153,7 +153,7 @@ func (c *nonePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { } func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { - err = c.ExtendedConn.ReadBuffer(buffer) + err = c.conn.ReadBuffer(buffer) if err != nil { return } @@ -166,7 +166,7 @@ func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr if err != nil { return err } - return c.ExtendedConn.WriteBuffer(buffer) + return c.conn.WriteBuffer(buffer) } func (c *nonePacketConn) FrontHeadroom() int { @@ -174,5 +174,5 @@ func (c *nonePacketConn) FrontHeadroom() int { } func (c *nonePacketConn) Upstream() any { - return c.ExtendedConn + return c.conn } diff --git a/cipher/method_none_wait.go b/cipher/method_none_wait.go new file mode 100644 index 0000000..51a581f --- /dev/null +++ b/cipher/method_none_wait.go @@ -0,0 +1,41 @@ +package cipher + +import ( + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +var _ N.PacketReadWaitCreator = (*nonePacketConn)(nil) + +func (c *nonePacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) { + readWaiter, isReadWaiter := bufio.CreateReadWaiter(c.conn) + if !isReadWaiter { + return nil, false + } + return &nonePacketReadWaiter{readWaiter}, true +} + +var _ N.PacketReadWaiter = (*nonePacketReadWaiter)(nil) + +type nonePacketReadWaiter struct { + readWaiter N.ReadWaiter +} + +func (w *nonePacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + return w.readWaiter.InitializeReadWaiter(options) +} + +func (w *nonePacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + buffer, err = w.readWaiter.WaitReadBuffer() + if err != nil { + return + } + destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer) + if err != nil { + buffer.Release() + return nil, M.Socksaddr{}, err + } + return +} diff --git a/go.mod b/go.mod index 8334d88..3293ef0 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,12 @@ module github.com/sagernet/sing-shadowsocks2 go 1.18 require ( - github.com/sagernet/sing v0.2.17 + github.com/sagernet/sing v0.2.19-0.20231207034108-445cd4f41e3f golang.org/x/crypto v0.15.0 lukechampine.com/blake3 v1.2.1 ) require ( github.com/klauspost/cpuid/v2 v2.0.9 // indirect - golang.org/x/sys v0.14.0 // indirect + golang.org/x/sys v0.15.0 // indirect ) diff --git a/go.sum b/go.sum index 0512153..752e035 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,10 @@ github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/sagernet/sing v0.2.17 h1:vMPKb3MV0Aa5ws4dCJkRI8XEjrsUcDn810czd0FwmzI= -github.com/sagernet/sing v0.2.17/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo= +github.com/sagernet/sing v0.2.19-0.20231207034108-445cd4f41e3f h1:hYkBnmJjVphGc4b02b4jN46ojh05vACYZI3ciD/V3pA= +github.com/sagernet/sing v0.2.19-0.20231207034108-445cd4f41e3f/go.mod h1:Ce5LNojQOgOiWhiD8pPD6E9H7e2KgtOe3Zxx4Ou5u80= golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= -golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= -golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= lukechampine.com/blake3 v1.2.1 h1:YuqqRuaqsGV71BV/nm9xlI0MKUv4QC54jQnBChWbGnI= lukechampine.com/blake3 v1.2.1/go.mod h1:0OFRp7fBtAylGVCO40o87sbupkyIGgbpv1+M1k1LM6k= diff --git a/internal/shadowio/reader.go b/internal/shadowio/reader.go index 85199de..c755a6a 100644 --- a/internal/shadowio/reader.go +++ b/internal/shadowio/reader.go @@ -6,6 +6,7 @@ import ( "io" "github.com/sagernet/sing/common/buf" + N "github.com/sagernet/sing/common/network" ) const PacketLengthBufferSize = 2 @@ -17,11 +18,17 @@ const ( Overhead = 16 ) +var ( + _ N.ExtendedReader = (*Reader)(nil) + _ N.ReadWaiter = (*Reader)(nil) +) + type Reader struct { - reader io.Reader - cipher cipher.AEAD - nonce []byte - cache *buf.Buffer + reader io.Reader + cipher cipher.AEAD + nonce []byte + cache *buf.Buffer + readWaitOptions N.ReadWaitOptions } func NewReader(upstream io.Reader, cipher cipher.AEAD) *Reader { @@ -102,13 +109,45 @@ func (r *Reader) ReadBuffer(buffer *buf.Buffer) error { } } -func (r *Reader) ReadBufferThreadSafe() (buffer *buf.Buffer, err error) { - cache := r.cache - if cache != nil { - r.cache = nil - return cache, nil +func (r *Reader) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + r.readWaitOptions = options + return options.NeedHeadroom() +} + +func (r *Reader) WaitReadBuffer() (buffer *buf.Buffer, err error) { + if r.readWaitOptions.NeedHeadroom() { + for { + if r.cache != nil { + if r.cache.IsEmpty() { + r.cache.Release() + r.cache = nil + } else { + buffer = r.readWaitOptions.NewBuffer() + var n int + n, err = buffer.Write(r.cache.Bytes()) + if err != nil { + buffer.Release() + return + } + buffer.Truncate(n) + r.cache.Advance(n) + r.readWaitOptions.PostReturn(buffer) + return + } + } + r.cache, err = r.readBuffer() + if err != nil { + return + } + } + } else { + cache := r.cache + if cache != nil { + r.cache = nil + return cache, nil + } + return r.readBuffer() } - return r.readBuffer() } func (r *Reader) readBuffer() (*buf.Buffer, error) { diff --git a/shadowaead/method.go b/shadowaead/method.go index ef24bfa..c8cb2d0 100644 --- a/shadowaead/method.go +++ b/shadowaead/method.go @@ -107,12 +107,15 @@ func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn { } } +var _ N.ExtendedConn = (*clientConn)(nil) + type clientConn struct { net.Conn - method *Method - destination M.Socksaddr - reader *shadowio.Reader - writer *shadowio.Writer + method *Method + destination M.Socksaddr + reader *shadowio.Reader + readWaitOptions N.ReadWaitOptions + writer *shadowio.Writer shadowio.WriterInterface } @@ -160,7 +163,9 @@ func (c *clientConn) readResponse() error { if err != nil { return err } - c.reader = shadowio.NewReader(c.Conn, readCipher) + reader := shadowio.NewReader(c.Conn, readCipher) + reader.InitializeReadWaiter(c.readWaitOptions) + c.reader = reader return nil } @@ -184,16 +189,6 @@ func (c *clientConn) ReadBuffer(buffer *buf.Buffer) error { return c.reader.ReadBuffer(buffer) } -func (c *clientConn) ReadBufferThreadSafe() (buffer *buf.Buffer, err error) { - if c.reader == nil { - err = c.readResponse() - if err != nil { - return - } - } - return c.reader.ReadBufferThreadSafe() -} - func (c *clientConn) Write(p []byte) (n int, err error) { if c.writer == nil { err = c.writeRequest(p) @@ -237,6 +232,10 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksad if err != nil { return } + return c.readPacket(buffer) +} + +func (c *clientPacketConn) readPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { if buffer.Len() < c.method.keySaltLength { return M.Socksaddr{}, C.ErrPacketTooShort } @@ -350,6 +349,14 @@ func (c *clientPacketConn) RearHeadroom() int { return shadowio.Overhead } +func (c *clientPacketConn) ReaderMTU() int { + return MaxPacketSize +} + +func (c *clientPacketConn) WriterMTU() int { + return MaxPacketSize +} + func (c *clientPacketConn) Upstream() any { return c.AbstractConn } diff --git a/shadowaead/method_wait.go b/shadowaead/method_wait.go new file mode 100644 index 0000000..26ec4bb --- /dev/null +++ b/shadowaead/method_wait.go @@ -0,0 +1,62 @@ +package shadowaead + +import ( + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +var _ N.ReadWaiter = (*clientConn)(nil) + +func (c *clientConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + if c.reader == nil { + c.readWaitOptions = options + return options.NeedHeadroom() + } + return c.reader.InitializeReadWaiter(options) +} + +func (c *clientConn) WaitReadBuffer() (buffer *buf.Buffer, err error) { + if c.reader == nil { + err = c.readResponse() + if err != nil { + return + } + } + return c.reader.WaitReadBuffer() +} + +var _ N.PacketReadWaitCreator = (*clientPacketConn)(nil) + +func (c *clientPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) { + readWaiter, isReadWaiter := bufio.CreateReadWaiter(c.reader) + if !isReadWaiter { + return nil, false + } + return &clientPacketReadWaiter{c, readWaiter}, true +} + +var _ N.PacketReadWaiter = (*clientPacketReadWaiter)(nil) + +type clientPacketReadWaiter struct { + *clientPacketConn + readWaiter N.ReadWaiter +} + +func (w *clientPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + return w.readWaiter.InitializeReadWaiter(options) +} + +func (w *clientPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + buffer, err = w.readWaiter.WaitReadBuffer() + if err != nil { + return + } + destination, err = w.readPacket(buffer) + if err != nil { + buffer.Release() + return nil, M.Socksaddr{}, err + } + return +} diff --git a/shadowaead_2022/method.go b/shadowaead_2022/method.go index 7e6bfcb..428e9cf 100644 --- a/shadowaead_2022/method.go +++ b/shadowaead_2022/method.go @@ -163,11 +163,12 @@ func (m *Method) time() time.Time { type clientConn struct { net.Conn - method *Method - destination M.Socksaddr - requestSalt []byte - reader *shadowio.Reader - writer *shadowio.Writer + method *Method + destination M.Socksaddr + requestSalt []byte + reader *shadowio.Reader + readWaitOptions N.ReadWaitOptions + writer *shadowio.Writer shadowio.WriterInterface } @@ -302,6 +303,7 @@ func (c *clientConn) readResponse() error { if err != nil { return err } + reader.InitializeReadWaiter(c.readWaitOptions) c.reader = reader return nil } @@ -325,17 +327,6 @@ func (c *clientConn) ReadBuffer(buffer *buf.Buffer) error { return c.reader.ReadBuffer(buffer) } -func (c *clientConn) ReadBufferThreadSafe() (buffer *buf.Buffer, err error) { - if c.reader == nil { - err = c.readResponse() - if err != nil { - return - } - - } - return c.reader.ReadBufferThreadSafe() -} - func (c *clientConn) Write(p []byte) (n int, err error) { if c.writer == nil { err = c.writeRequest(p) diff --git a/shadowaead_2022/method_wait.go b/shadowaead_2022/method_wait.go new file mode 100644 index 0000000..6ac99b9 --- /dev/null +++ b/shadowaead_2022/method_wait.go @@ -0,0 +1,62 @@ +package shadowaead_2022 + +import ( + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +var _ N.ReadWaiter = (*clientConn)(nil) + +func (c *clientConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + if c.reader == nil { + c.readWaitOptions = options + return options.NeedHeadroom() + } + return c.reader.InitializeReadWaiter(options) +} + +func (c *clientConn) WaitReadBuffer() (buffer *buf.Buffer, err error) { + if c.reader == nil { + err = c.readResponse() + if err != nil { + return + } + } + return c.reader.WaitReadBuffer() +} + +var _ N.PacketReadWaitCreator = (*clientPacketConn)(nil) + +func (c *clientPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) { + readWaiter, isReadWaiter := bufio.CreateReadWaiter(c.reader) + if !isReadWaiter { + return nil, false + } + return &clientPacketReadWaiter{c, readWaiter}, true +} + +var _ N.PacketReadWaiter = (*clientPacketReadWaiter)(nil) + +type clientPacketReadWaiter struct { + *clientPacketConn + readWaiter N.ReadWaiter +} + +func (w *clientPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + return w.readWaiter.InitializeReadWaiter(options) +} + +func (w *clientPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + buffer, err = w.readWaiter.WaitReadBuffer() + if err != nil { + return + } + destination, err = w.readPacket(buffer) + if err != nil { + buffer.Release() + return nil, M.Socksaddr{}, err + } + return +}