Skip to content

Commit

Permalink
Implement read waiter for none/aead/aead-2022
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Dec 21, 2023
1 parent e07da80 commit 1d569c2
Show file tree
Hide file tree
Showing 9 changed files with 265 additions and 59 deletions.
18 changes: 9 additions & 9 deletions cipher/method_none.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,8 @@ func (m *noneMethod) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.C
}

func (m *noneMethod) DialPacketConn(conn net.Conn) N.NetPacketConn {
return &nonePacketConn{
ExtendedConn: bufio.NewExtendedConn(conn),
}
extendedConn := bufio.NewExtendedConn(conn)
return &nonePacketConn{extendedConn, extendedConn}
}

var (
Expand Down Expand Up @@ -113,11 +112,12 @@ var (
)

type nonePacketConn struct {
N.ExtendedConn
N.AbstractConn
conn N.ExtendedConn
}

func (c *nonePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, err = c.ExtendedConn.Read(p)
n, err = c.conn.Read(p)
if err != nil {
return
}
Expand All @@ -144,7 +144,7 @@ func (c *nonePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return
}
common.Must1(buffer.Write(p))
_, err = c.ExtendedConn.Write(buffer.Bytes())
_, err = c.conn.Write(buffer.Bytes())
if err != nil {
return
}
Expand All @@ -153,7 +153,7 @@ func (c *nonePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
}

func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
err = c.ExtendedConn.ReadBuffer(buffer)
err = c.conn.ReadBuffer(buffer)
if err != nil {
return
}
Expand All @@ -166,13 +166,13 @@ func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr
if err != nil {
return err
}
return c.ExtendedConn.WriteBuffer(buffer)
return c.conn.WriteBuffer(buffer)
}

func (c *nonePacketConn) FrontHeadroom() int {
return M.MaxSocksaddrLength
}

func (c *nonePacketConn) Upstream() any {
return c.ExtendedConn
return c.conn
}
41 changes: 41 additions & 0 deletions cipher/method_none_wait.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package cipher

import (
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

var _ N.PacketReadWaitCreator = (*nonePacketConn)(nil)

func (c *nonePacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
readWaiter, isReadWaiter := bufio.CreateReadWaiter(c.conn)
if !isReadWaiter {
return nil, false
}
return &nonePacketReadWaiter{readWaiter}, true
}

var _ N.PacketReadWaiter = (*nonePacketReadWaiter)(nil)

type nonePacketReadWaiter struct {
readWaiter N.ReadWaiter
}

func (w *nonePacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return w.readWaiter.InitializeReadWaiter(options)
}

func (w *nonePacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
buffer, err = w.readWaiter.WaitReadBuffer()
if err != nil {
return
}
destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil {
buffer.Release()
return nil, M.Socksaddr{}, err
}
return
}
6 changes: 3 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ module github.com/sagernet/sing-shadowsocks2
go 1.18

require (
github.com/sagernet/sing v0.2.17
golang.org/x/crypto v0.15.0
github.com/sagernet/sing v0.3.0-rc.2
golang.org/x/crypto v0.17.0
lukechampine.com/blake3 v1.2.1
)

require (
github.com/klauspost/cpuid/v2 v2.0.9 // indirect
golang.org/x/sys v0.14.0 // indirect
golang.org/x/sys v0.15.0 // indirect
)
16 changes: 10 additions & 6 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/sagernet/sing v0.2.17 h1:vMPKb3MV0Aa5ws4dCJkRI8XEjrsUcDn810czd0FwmzI=
github.com/sagernet/sing v0.2.17/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo=
golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA=
golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g=
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/sagernet/sing v0.3.0-rc.2 h1:l5rq+bTrNhpAPd2Vjzi/sEhil4O6Bb1CKv6LdPLJKug=
github.com/sagernet/sing v0.3.0-rc.2/go.mod h1:9pfuAH6mZfgnz/YjP6xu5sxx882rfyjpcrTdUpd6w3g=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
lukechampine.com/blake3 v1.2.1 h1:YuqqRuaqsGV71BV/nm9xlI0MKUv4QC54jQnBChWbGnI=
lukechampine.com/blake3 v1.2.1/go.mod h1:0OFRp7fBtAylGVCO40o87sbupkyIGgbpv1+M1k1LM6k=
59 changes: 49 additions & 10 deletions internal/shadowio/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"

"github.com/sagernet/sing/common/buf"
N "github.com/sagernet/sing/common/network"
)

const PacketLengthBufferSize = 2
Expand All @@ -17,11 +18,17 @@ const (
Overhead = 16
)

var (
_ N.ExtendedReader = (*Reader)(nil)
_ N.ReadWaiter = (*Reader)(nil)
)

type Reader struct {
reader io.Reader
cipher cipher.AEAD
nonce []byte
cache *buf.Buffer
reader io.Reader
cipher cipher.AEAD
nonce []byte
cache *buf.Buffer
readWaitOptions N.ReadWaitOptions
}

func NewReader(upstream io.Reader, cipher cipher.AEAD) *Reader {
Expand Down Expand Up @@ -102,13 +109,45 @@ func (r *Reader) ReadBuffer(buffer *buf.Buffer) error {
}
}

func (r *Reader) ReadBufferThreadSafe() (buffer *buf.Buffer, err error) {
cache := r.cache
if cache != nil {
r.cache = nil
return cache, nil
func (r *Reader) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
r.readWaitOptions = options
return options.NeedHeadroom()
}

func (r *Reader) WaitReadBuffer() (buffer *buf.Buffer, err error) {
if r.readWaitOptions.NeedHeadroom() {
for {
if r.cache != nil {
if r.cache.IsEmpty() {
r.cache.Release()
r.cache = nil
} else {
buffer = r.readWaitOptions.NewBuffer()
var n int
n, err = buffer.Write(r.cache.Bytes())
if err != nil {
buffer.Release()
return
}
buffer.Truncate(n)
r.cache.Advance(n)
r.readWaitOptions.PostReturn(buffer)
return
}
}
r.cache, err = r.readBuffer()
if err != nil {
return
}
}
} else {
cache := r.cache
if cache != nil {
r.cache = nil
return cache, nil
}
return r.readBuffer()
}
return r.readBuffer()
}

func (r *Reader) readBuffer() (*buf.Buffer, error) {
Expand Down
37 changes: 22 additions & 15 deletions shadowaead/method.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,15 @@ func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
}
}

var _ N.ExtendedConn = (*clientConn)(nil)

type clientConn struct {
net.Conn
method *Method
destination M.Socksaddr
reader *shadowio.Reader
writer *shadowio.Writer
method *Method
destination M.Socksaddr
reader *shadowio.Reader
readWaitOptions N.ReadWaitOptions
writer *shadowio.Writer
shadowio.WriterInterface
}

Expand Down Expand Up @@ -160,7 +163,9 @@ func (c *clientConn) readResponse() error {
if err != nil {
return err
}
c.reader = shadowio.NewReader(c.Conn, readCipher)
reader := shadowio.NewReader(c.Conn, readCipher)
reader.InitializeReadWaiter(c.readWaitOptions)
c.reader = reader
return nil
}

Expand All @@ -184,16 +189,6 @@ func (c *clientConn) ReadBuffer(buffer *buf.Buffer) error {
return c.reader.ReadBuffer(buffer)
}

func (c *clientConn) ReadBufferThreadSafe() (buffer *buf.Buffer, err error) {
if c.reader == nil {
err = c.readResponse()
if err != nil {
return
}
}
return c.reader.ReadBufferThreadSafe()
}

func (c *clientConn) Write(p []byte) (n int, err error) {
if c.writer == nil {
err = c.writeRequest(p)
Expand Down Expand Up @@ -237,6 +232,10 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksad
if err != nil {
return
}
return c.readPacket(buffer)
}

func (c *clientPacketConn) readPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
if buffer.Len() < c.method.keySaltLength {
return M.Socksaddr{}, C.ErrPacketTooShort
}
Expand Down Expand Up @@ -350,6 +349,14 @@ func (c *clientPacketConn) RearHeadroom() int {
return shadowio.Overhead
}

func (c *clientPacketConn) ReaderMTU() int {
return MaxPacketSize
}

func (c *clientPacketConn) WriterMTU() int {
return MaxPacketSize
}

func (c *clientPacketConn) Upstream() any {
return c.AbstractConn
}
62 changes: 62 additions & 0 deletions shadowaead/method_wait.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package shadowaead

import (
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

var _ N.ReadWaiter = (*clientConn)(nil)

func (c *clientConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
if c.reader == nil {
c.readWaitOptions = options
return options.NeedHeadroom()
}
return c.reader.InitializeReadWaiter(options)
}

func (c *clientConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
if c.reader == nil {
err = c.readResponse()
if err != nil {
return
}
}
return c.reader.WaitReadBuffer()
}

var _ N.PacketReadWaitCreator = (*clientPacketConn)(nil)

func (c *clientPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
readWaiter, isReadWaiter := bufio.CreateReadWaiter(c.reader)
if !isReadWaiter {
return nil, false
}
return &clientPacketReadWaiter{c, readWaiter}, true
}

var _ N.PacketReadWaiter = (*clientPacketReadWaiter)(nil)

type clientPacketReadWaiter struct {
*clientPacketConn
readWaiter N.ReadWaiter
}

func (w *clientPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return w.readWaiter.InitializeReadWaiter(options)
}

func (w *clientPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
buffer, err = w.readWaiter.WaitReadBuffer()
if err != nil {
return
}
destination, err = w.readPacket(buffer)
if err != nil {
buffer.Release()
return nil, M.Socksaddr{}, err
}
return
}
Loading

0 comments on commit 1d569c2

Please sign in to comment.