From 1b82d33433a6040c22a9d6432aec9551dbd4b300 Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 31 May 2024 17:03:56 -0400 Subject: [PATCH 001/119] Create a new config format so we can expand listener configuration for proxy protocol. --- cmd/outline-ss-server/config.go | 86 ++++++++++++++++ .../config_example.deprecated.yml | 15 +++ cmd/outline-ss-server/config_example.yml | 36 ++++--- cmd/outline-ss-server/config_test.go | 79 +++++++++++++++ cmd/outline-ss-server/main.go | 98 ++++++++++++------- net/address.go | 63 ++++++++++++ net/address_test.go | 82 ++++++++++++++++ 7 files changed, 408 insertions(+), 51 deletions(-) create mode 100644 cmd/outline-ss-server/config.go create mode 100644 cmd/outline-ss-server/config_example.deprecated.yml create mode 100644 cmd/outline-ss-server/config_test.go create mode 100644 net/address.go create mode 100644 net/address_test.go diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go new file mode 100644 index 00000000..1f97e182 --- /dev/null +++ b/cmd/outline-ss-server/config.go @@ -0,0 +1,86 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "fmt" + "os" + + "gopkg.in/yaml.v2" +) + +type Service struct { + Listeners []Listener + Keys []Key +} + +type Listener struct { + Type string + Address string +} + +type Key struct { + ID string + Cipher string + Secret string +} + +type Config struct { + Services []Service + + // Deprecated: Keys exists for historical compatibility. This is ignored if top-level `services` is specified. + Keys []struct { + ID string + Port int + Cipher string + Secret string + } +} + +// Reads a config from a filename and parses it as a [Config]. +func ReadConfig(filename string) (*Config, error) { + config := Config{} + configData, err := os.ReadFile(filename) + if err != nil { + return nil, fmt.Errorf("failed to read config: %w", err) + } + err = yaml.Unmarshal(configData, &config) + if err != nil { + return nil, fmt.Errorf("failed to parse config: %w", err) + } + if config.Services == nil { + // This is a deprecated config format. We need to transform it to to the new format. + ports := make(map[int][]Key) + for _, keyConfig := range config.Keys { + ports[keyConfig.Port] = append(ports[keyConfig.Port], Key{ + ID: keyConfig.ID, + Cipher: keyConfig.Cipher, + Secret: keyConfig.Secret, + }) + } + for port, keys := range ports { + s := Service{ + Listeners: []Listener{ + Listener{Type: "direct", Address: fmt.Sprintf("tcp://[::]:%d", port)}, + Listener{Type: "direct", Address: fmt.Sprintf("udp://[::]:%d", port)}, + }, + Keys: keys, + } + config.Services = append(config.Services, s) + } + } + config.Keys = nil + return &config, nil +} diff --git a/cmd/outline-ss-server/config_example.deprecated.yml b/cmd/outline-ss-server/config_example.deprecated.yml new file mode 100644 index 00000000..8895b86d --- /dev/null +++ b/cmd/outline-ss-server/config_example.deprecated.yml @@ -0,0 +1,15 @@ +keys: + - id: user-0 + port: 9000 + cipher: chacha20-ietf-poly1305 + secret: Secret0 + + - id: user-1 + port: 9000 + cipher: chacha20-ietf-poly1305 + secret: Secret1 + + - id: user-2 + port: 9001 + cipher: chacha20-ietf-poly1305 + secret: Secret2 diff --git a/cmd/outline-ss-server/config_example.yml b/cmd/outline-ss-server/config_example.yml index 8895b86d..66009c10 100644 --- a/cmd/outline-ss-server/config_example.yml +++ b/cmd/outline-ss-server/config_example.yml @@ -1,15 +1,23 @@ -keys: - - id: user-0 - port: 9000 - cipher: chacha20-ietf-poly1305 - secret: Secret0 +services: + - listeners: + - type: direct + address: "tcp://[::]:9000" + - type: direct + address: "udp://[::]:9000" + keys: + - id: user-0 + cipher: chacha20-ietf-poly1305 + secret: Secret0 + - id: user-1 + cipher: chacha20-ietf-poly1305 + secret: Secret1 - - id: user-1 - port: 9000 - cipher: chacha20-ietf-poly1305 - secret: Secret1 - - - id: user-2 - port: 9001 - cipher: chacha20-ietf-poly1305 - secret: Secret2 + - listeners: + - type: direct + address: "tcp://[::]:9001" + - type: direct + address: "udp://[::]:9001" + keys: + - id: user-2 + cipher: chacha20-ietf-poly1305 + secret: Secret2 diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go new file mode 100644 index 00000000..0d46489c --- /dev/null +++ b/cmd/outline-ss-server/config_test.go @@ -0,0 +1,79 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestReadConfig(t *testing.T) { + config, _ := ReadConfig("./config_example.yml") + + expected := Config{ + Services: []Service{ + Service{ + Listeners: []Listener{ + Listener{Type: "direct", Address: "tcp://[::]:9000"}, + Listener{Type: "direct", Address: "udp://[::]:9000"}, + }, + Keys: []Key{ + Key{"user-0", "chacha20-ietf-poly1305", "Secret0"}, + Key{"user-1", "chacha20-ietf-poly1305", "Secret1"}, + }, + }, + Service{ + Listeners: []Listener{ + Listener{Type: "direct", Address: "tcp://[::]:9001"}, + Listener{Type: "direct", Address: "udp://[::]:9001"}, + }, + Keys: []Key{ + Key{"user-2", "chacha20-ietf-poly1305", "Secret2"}, + }, + }, + }, + } + require.Equal(t, expected, *config) +} + +func TestReadConfigParsesDeprecatedFormat(t *testing.T) { + config, _ := ReadConfig("./config_example.deprecated.yml") + + expected := Config{ + Services: []Service{ + Service{ + Listeners: []Listener{ + Listener{Type: "direct", Address: "tcp://[::]:9000"}, + Listener{Type: "direct", Address: "udp://[::]:9000"}, + }, + Keys: []Key{ + Key{"user-0", "chacha20-ietf-poly1305", "Secret0"}, + Key{"user-1", "chacha20-ietf-poly1305", "Secret1"}, + }, + }, + Service{ + Listeners: []Listener{ + Listener{Type: "direct", Address: "tcp://[::]:9001"}, + Listener{Type: "direct", Address: "udp://[::]:9001"}, + }, + Keys: []Key{ + Key{"user-2", "chacha20-ietf-poly1305", "Secret2"}, + }, + }, + }, + } + require.Equal(t, expected, *config) +} diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 47b686a9..165e5217 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -28,13 +28,14 @@ import ( "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" + "github.com/Jigsaw-Code/outline-ss-server/ipinfo" + onet "github.com/Jigsaw-Code/outline-ss-server/net" "github.com/Jigsaw-Code/outline-ss-server/service" "github.com/op/go-logging" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "golang.org/x/term" - "gopkg.in/yaml.v2" ) var logger *logging.Logger @@ -48,6 +49,8 @@ const tcpReadTimeout time.Duration = 59 * time.Second // A UDP NAT timeout of at least 5 minutes is recommended in RFC 4787 Section 4.3. const defaultNatTimeout time.Duration = 5 * time.Minute +var directListenerType = "direct" + func init() { var prefix = "%{level:.1s}%{time:2006-01-02T15:04:05.000Z07:00} %{pid} %{shortfile}]" if term.IsTerminal(int(os.Stderr.Fd())) { @@ -125,26 +128,67 @@ func (s *SSServer) removePort(portNum int) error { } func (s *SSServer) loadConfig(filename string) error { - config, err := readConfig(filename) + config, err := ReadConfig(filename) if err != nil { return fmt.Errorf("failed to load config (%v): %w", filename, err) } portChanges := make(map[int]int) portCiphers := make(map[int]*list.List) // Values are *List of *CipherEntry. - for _, keyConfig := range config.Keys { - portChanges[keyConfig.Port] = 1 - cipherList, ok := portCiphers[keyConfig.Port] - if !ok { - cipherList = list.New() - portCiphers[keyConfig.Port] = cipherList + for _, serviceConfig := range config.Services { + if serviceConfig.Listeners == nil || serviceConfig.Keys == nil { + return fmt.Errorf("must specify at least 1 listener and 1 key per service") + } + addrs := []net.Addr{} + for _, listener := range serviceConfig.Listeners { + switch t := listener.Type; t { + // TODO: Support more listener types. + case directListenerType: + addr, err := onet.ResolveAddr(listener.Address) + if err != nil { + return fmt.Errorf("failed to resolve direct address: %v: %w", listener.Address, err) + } + addrs = append(addrs, addr) + port, err := onet.GetPort(addr) + if err != nil { + return err + } + portChanges[int(port)] = 1 + default: + return fmt.Errorf("unsupported listener type: %s", t) + } } - cryptoKey, err := shadowsocks.NewEncryptionKey(keyConfig.Cipher, keyConfig.Secret) - if err != nil { - return fmt.Errorf("failed to create encyption key for key %v: %w", keyConfig.ID, err) + + type key struct { + c string + s string + } + existingCipher := make(map[key]bool) + for _, keyConfig := range serviceConfig.Keys { + for _, addr := range addrs { + port, err := onet.GetPort(addr) + if err != nil { + return err + } + cipherList, ok := portCiphers[port] + if !ok { + cipherList = list.New() + portCiphers[port] = cipherList + } + _, ok = existingCipher[key{keyConfig.Cipher, keyConfig.Secret}] + if ok { + logger.Debugf("encryption key already exists for port=%v, ID=`%v`. Skipping.", port, keyConfig.ID) + continue + } + cryptoKey, err := shadowsocks.NewEncryptionKey(keyConfig.Cipher, keyConfig.Secret) + if err != nil { + return fmt.Errorf("failed to create encyption key for key %v: %w", keyConfig.ID, err) + } + entry := service.MakeCipherEntry(keyConfig.ID, cryptoKey, keyConfig.Secret) + cipherList.PushBack(&entry) + existingCipher[key{keyConfig.Cipher, keyConfig.Secret}] = true + } } - entry := service.MakeCipherEntry(keyConfig.ID, cryptoKey, keyConfig.Secret) - cipherList.PushBack(&entry) } for port := range s.ports { portChanges[port] = portChanges[port] - 1 @@ -160,11 +204,13 @@ func (s *SSServer) loadConfig(filename string) error { } } } + numServices := 0 for portNum, cipherList := range portCiphers { s.ports[portNum].cipherList.Update(cipherList) + numServices += cipherList.Len() } - logger.Infof("Loaded %v access keys over %v ports", len(config.Keys), len(s.ports)) - s.m.SetNumAccessKeys(len(config.Keys), len(portCiphers)) + logger.Infof("Loaded %v access keys over %v ports", numServices, len(s.ports)) + s.m.SetNumAccessKeys(numServices, len(s.ports)) return nil } @@ -203,28 +249,6 @@ func RunSSServer(filename string, natTimeout time.Duration, sm *outlineMetrics, return server, nil } -type Config struct { - Keys []struct { - ID string - Port int - Cipher string - Secret string - } -} - -func readConfig(filename string) (*Config, error) { - config := Config{} - configData, err := os.ReadFile(filename) - if err != nil { - return nil, fmt.Errorf("failed to read config: %w", err) - } - err = yaml.Unmarshal(configData, &config) - if err != nil { - return nil, fmt.Errorf("failed to parse config: %w", err) - } - return &config, nil -} - func main() { var flags struct { ConfigFile string diff --git a/net/address.go b/net/address.go new file mode 100644 index 00000000..a74e4f13 --- /dev/null +++ b/net/address.go @@ -0,0 +1,63 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "fmt" + "net" + "net/url" +) + +// Resolves a URL-style listen address specification as a [net.Addr] +// +// Examples: +// +// udp6://127.0.0.1:8000 +// unix:///tmp/foo.sock +// tcp://127.0.0.1:9002 +func ResolveAddr(addr string) (net.Addr, error) { + u, err := url.Parse(addr) + if err != nil { + return nil, err + } + switch u.Scheme { + case "tcp", "tcp4", "tcp6": + return net.ResolveTCPAddr(u.Scheme, u.Host) + case "udp", "udp4", "udp6": + return net.ResolveUDPAddr(u.Scheme, u.Host) + case "unix", "unixgram", "unixpacket": + var path string + if u.Opaque != "" { + path = u.Opaque + } else { + path = u.Path + } + return net.ResolveUnixAddr(u.Scheme, path) + default: + return nil, net.UnknownNetworkError(u.Scheme) + } +} + +// Returns the port from a given address. +func GetPort(addr net.Addr) (port int, err error) { + switch t := addr.(type) { + case *net.TCPAddr: + return t.Port, nil + case *net.UDPAddr: + return t.Port, nil + default: + return -1, fmt.Errorf("failed to get port from address: %v", addr) + } +} diff --git a/net/address_test.go b/net/address_test.go new file mode 100644 index 00000000..f681f056 --- /dev/null +++ b/net/address_test.go @@ -0,0 +1,82 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +type fakeAddr string + +func (a fakeAddr) String() string { return string(a) } +func (a fakeAddr) Network() string { return "" } + +func TestResolveAddrReturnsTCPAddr(t *testing.T) { + addr, err := ResolveAddr("tcp://0.0.0.0:9000") + + require.NoError(t, err) + if _, ok := addr.(*net.TCPAddr); !ok { + t.Errorf("expected a *net.TCPAddr; it is a %T", addr) + } +} + +func TestResolveAddrReturnsUDPAddr(t *testing.T) { + addr, err := ResolveAddr("udp://[::]:9001") + + require.NoError(t, err) + if _, ok := addr.(*net.UDPAddr); !ok { + t.Errorf("expected a *net.UDPAddr; it is a %T", addr) + } +} + +func TestResolveAddrReturnsUnixAddr(t *testing.T) { + addr, err := ResolveAddr("unix:///path/to/stream_socket") + + require.NoError(t, err) + if _, ok := addr.(*net.UnixAddr); !ok { + t.Errorf("expected a *net.UnixAddr; it is a %T", addr) + } +} + +func TestResolveAddrReturnsErrorForUnknownScheme(t *testing.T) { + addr, err := ResolveAddr("foobar") + + require.Nil(t, addr) + require.Error(t, err) +} + +func TestGetPortFromTCPAddr(t *testing.T) { + port, err := GetPort(&net.TCPAddr{IP: net.ParseIP("1.2.3.4"), Port: 1234}) + + require.NoError(t, err) + require.Equal(t, 1234, port) +} + +func TestGetPortFromUDPPAddr(t *testing.T) { + port, err := GetPort(&net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 5678}) + + require.NoError(t, err) + require.Equal(t, 5678, port) +} + +func TestGetPortReturnsErrorForUnsupportedAddressType(t *testing.T) { + port, err := GetPort(&net.UnixAddr{Name: "/path/to/foo", Net: "unix"}) + + require.Equal(t, -1, port) + require.Error(t, err) +} From a4c200718dad9cdf6e62161789584e0293c14035 Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 31 May 2024 17:04:48 -0400 Subject: [PATCH 002/119] Remove unused `fakeAddr`. --- net/address_test.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/net/address_test.go b/net/address_test.go index f681f056..d7fe1d01 100644 --- a/net/address_test.go +++ b/net/address_test.go @@ -21,11 +21,6 @@ import ( "github.com/stretchr/testify/require" ) -type fakeAddr string - -func (a fakeAddr) String() string { return string(a) } -func (a fakeAddr) Network() string { return "" } - func TestResolveAddrReturnsTCPAddr(t *testing.T) { addr, err := ResolveAddr("tcp://0.0.0.0:9000") From 72b27d7327bf50af84d86e03da78711623e78e4f Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 3 Jun 2024 12:33:15 -0400 Subject: [PATCH 003/119] Split `startPort` up between TCP and UDP. --- cmd/outline-ss-server/main.go | 39 +++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 165e5217..f5b247da 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -75,25 +75,16 @@ type SSServer struct { ports map[int]*ssPort } -func (s *SSServer) startPort(portNum int) error { +func (s *SSServer) startTCP(portNum int, cipherList service.CipherList) (*net.TCPListener, error) { listener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: portNum}) if err != nil { //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks TCP service failed to start on port %v: %w", portNum, err) + return nil, fmt.Errorf("Shadowsocks TCP service failed to start on port %v: %w", portNum, err) } logger.Infof("Shadowsocks TCP service listening on %v", listener.Addr().String()) - packetConn, err := net.ListenUDP("udp", &net.UDPAddr{Port: portNum}) - if err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks UDP service failed to start on port %v: %w", portNum, err) - } - logger.Infof("Shadowsocks UDP service listening on %v", packetConn.LocalAddr().String()) - port := &ssPort{tcpListener: listener, packetConn: packetConn, cipherList: service.NewCipherList()} - authFunc := service.NewShadowsocksStreamAuthenticator(port.cipherList, &s.replayCache, s.m) + authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &s.replayCache, s.m) // TODO: Register initial data metrics at zero. tcpHandler := service.NewTCPHandler(portNum, authFunc, s.m, tcpReadTimeout) - packetHandler := service.NewPacketHandler(s.natTimeout, port.cipherList, s.m) - s.ports[portNum] = port accept := func() (transport.StreamConn, error) { conn, err := listener.AcceptTCP() if err == nil { @@ -102,8 +93,19 @@ func (s *SSServer) startPort(portNum int) error { return conn, err } go service.StreamServe(accept, tcpHandler.Handle) - go packetHandler.Handle(port.packetConn) - return nil + return listener, nil +} + +func (s *SSServer) startUDP(portNum int, cipherList service.CipherList) (*net.UDPConn, error) { + packetConn, err := net.ListenUDP("udp", &net.UDPAddr{Port: portNum}) + if err != nil { + //lint:ignore ST1005 Shadowsocks is capitalized. + return nil, fmt.Errorf("Shadowsocks UDP service failed to start on port %v: %w", portNum, err) + } + logger.Infof("Shadowsocks UDP service listening on %v", packetConn.LocalAddr().String()) + packetHandler := service.NewPacketHandler(s.natTimeout, cipherList, s.m) + go packetHandler.Handle(packetConn) + return packetConn, nil } func (s *SSServer) removePort(portNum int) error { @@ -199,9 +201,16 @@ func (s *SSServer) loadConfig(filename string) error { return fmt.Errorf("failed to remove port %v: %w", portNum, err) } } else if count == +1 { - if err := s.startPort(portNum); err != nil { + cipherList := service.NewCipherList() + tcpListener, err := s.startTCP(portNum, cipherList) + if err != nil { + return err + } + packetConn, err := s.startUDP(portNum, cipherList) + if err != nil { return err } + s.ports[portNum] = &ssPort{tcpListener, packetConn, cipherList} } } numServices := 0 From fddfc5720b8c5becfbdeaef0dd14fb7d8a9575a8 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 3 Jun 2024 17:29:31 -0400 Subject: [PATCH 004/119] Use listeners to configure TCP and/or UDP services as needed. --- cmd/outline-ss-server/main.go | 169 +++++++++++++++++----------------- 1 file changed, 83 insertions(+), 86 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index f5b247da..206ff05d 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -18,6 +18,7 @@ import ( "container/list" "flag" "fmt" + "io" "net" "net/http" "os" @@ -62,29 +63,28 @@ func init() { logger = logging.MustGetLogger("") } -type ssPort struct { - tcpListener *net.TCPListener - packetConn net.PacketConn - cipherList service.CipherList +type ssListener struct { + io.Closer + cipherList service.CipherList } type SSServer struct { natTimeout time.Duration m *outlineMetrics replayCache service.ReplayCache - ports map[int]*ssPort + listeners map[string]*ssListener } -func (s *SSServer) startTCP(portNum int, cipherList service.CipherList) (*net.TCPListener, error) { - listener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: portNum}) +func (s *SSServer) startDirectTCP(addr *net.TCPAddr, cipherList service.CipherList) (*net.TCPListener, error) { + listener, err := net.ListenTCP("tcp", addr) if err != nil { //lint:ignore ST1005 Shadowsocks is capitalized. - return nil, fmt.Errorf("Shadowsocks TCP service failed to start on port %v: %w", portNum, err) + return nil, fmt.Errorf("Shadowsocks TCP service failed to start on address %v: %w", addr.String(), err) } logger.Infof("Shadowsocks TCP service listening on %v", listener.Addr().String()) authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &s.replayCache, s.m) // TODO: Register initial data metrics at zero. - tcpHandler := service.NewTCPHandler(portNum, authFunc, s.m, tcpReadTimeout) + tcpHandler := service.NewTCPHandler(addr.Port, authFunc, s.m, tcpReadTimeout) accept := func() (transport.StreamConn, error) { conn, err := listener.AcceptTCP() if err == nil { @@ -96,11 +96,11 @@ func (s *SSServer) startTCP(portNum int, cipherList service.CipherList) (*net.TC return listener, nil } -func (s *SSServer) startUDP(portNum int, cipherList service.CipherList) (*net.UDPConn, error) { - packetConn, err := net.ListenUDP("udp", &net.UDPAddr{Port: portNum}) +func (s *SSServer) startDirectUDP(addr *net.UDPAddr, cipherList service.CipherList) (*net.UDPConn, error) { + packetConn, err := net.ListenUDP("udp", addr) if err != nil { //lint:ignore ST1005 Shadowsocks is capitalized. - return nil, fmt.Errorf("Shadowsocks UDP service failed to start on port %v: %w", portNum, err) + return nil, fmt.Errorf("Shadowsocks UDP service failed to start on address %v: %w", addr.String(), err) } logger.Infof("Shadowsocks UDP service listening on %v", packetConn.LocalAddr().String()) packetHandler := service.NewPacketHandler(s.natTimeout, cipherList, s.m) @@ -108,24 +108,29 @@ func (s *SSServer) startUDP(portNum int, cipherList service.CipherList) (*net.UD return packetConn, nil } -func (s *SSServer) removePort(portNum int) error { - port, ok := s.ports[portNum] - if !ok { - return fmt.Errorf("port %v doesn't exist", portNum) +func (s *SSServer) start(addr net.Addr, cipherList service.CipherList) (io.Closer, error) { + switch t := addr.(type) { + case *net.TCPAddr: + return s.startDirectTCP(t, cipherList) + case *net.UDPAddr: + return s.startDirectUDP(t, cipherList) + default: + return nil, fmt.Errorf("unable to start address: %s", t) } - tcpErr := port.tcpListener.Close() - udpErr := port.packetConn.Close() - delete(s.ports, portNum) - if tcpErr != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks TCP service on port %v failed to stop: %w", portNum, tcpErr) +} + +func (s *SSServer) remove(addr string) error { + listener, ok := s.listeners[addr] + if !ok { + return fmt.Errorf("address %v doesn't exist", addr) } - logger.Infof("Shadowsocks TCP service on port %v stopped", portNum) - if udpErr != nil { + err := listener.Close() + delete(s.listeners, addr) + if err != nil { //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks UDP service on port %v failed to stop: %w", portNum, udpErr) + return fmt.Errorf("Shadowsocks service on address %v failed to stop: %w", addr, err) } - logger.Infof("Shadowsocks UDP service on port %v stopped", portNum) + logger.Infof("Shadowsocks service on address %v stopped", addr) return nil } @@ -135,98 +140,90 @@ func (s *SSServer) loadConfig(filename string) error { return fmt.Errorf("failed to load config (%v): %w", filename, err) } - portChanges := make(map[int]int) - portCiphers := make(map[int]*list.List) // Values are *List of *CipherEntry. + uniqueCiphers := 0 + addrChanges := make(map[string]int) + type addrWithCiphers struct { + address net.Addr + ciphers *list.List // Values are *List of *CipherEntry. + } + addrs := make(map[string]*addrWithCiphers) for _, serviceConfig := range config.Services { if serviceConfig.Listeners == nil || serviceConfig.Keys == nil { return fmt.Errorf("must specify at least 1 listener and 1 key per service") } - addrs := []net.Addr{} + + ciphers := list.New() + type cipherKey struct { + cipher string + secret string + } + existingCiphers := make(map[cipherKey]bool) + for _, keyConfig := range serviceConfig.Keys { + key := cipherKey{keyConfig.Cipher, keyConfig.Secret} + _, ok := existingCiphers[key] + if ok { + logger.Debugf("encryption key already exists for ID=`%v`. Skipping.", keyConfig.ID) + continue + } + cryptoKey, err := shadowsocks.NewEncryptionKey(keyConfig.Cipher, keyConfig.Secret) + if err != nil { + return fmt.Errorf("failed to create encyption key for key %v: %w", keyConfig.ID, err) + } + entry := service.MakeCipherEntry(keyConfig.ID, cryptoKey, keyConfig.Secret) + ciphers.PushBack(&entry) + existingCiphers[key] = true + } + uniqueCiphers += ciphers.Len() + for _, listener := range serviceConfig.Listeners { switch t := listener.Type; t { // TODO: Support more listener types. case directListenerType: + //var addr net.Addr addr, err := onet.ResolveAddr(listener.Address) if err != nil { return fmt.Errorf("failed to resolve direct address: %v: %w", listener.Address, err) } - addrs = append(addrs, addr) - port, err := onet.GetPort(addr) - if err != nil { - return err - } - portChanges[int(port)] = 1 + addrChanges[listener.Address] = 1 + addrs[listener.Address] = &addrWithCiphers{addr, ciphers} default: return fmt.Errorf("unsupported listener type: %s", t) } } - - type key struct { - c string - s string - } - existingCipher := make(map[key]bool) - for _, keyConfig := range serviceConfig.Keys { - for _, addr := range addrs { - port, err := onet.GetPort(addr) - if err != nil { - return err - } - cipherList, ok := portCiphers[port] - if !ok { - cipherList = list.New() - portCiphers[port] = cipherList - } - _, ok = existingCipher[key{keyConfig.Cipher, keyConfig.Secret}] - if ok { - logger.Debugf("encryption key already exists for port=%v, ID=`%v`. Skipping.", port, keyConfig.ID) - continue - } - cryptoKey, err := shadowsocks.NewEncryptionKey(keyConfig.Cipher, keyConfig.Secret) - if err != nil { - return fmt.Errorf("failed to create encyption key for key %v: %w", keyConfig.ID, err) - } - entry := service.MakeCipherEntry(keyConfig.ID, cryptoKey, keyConfig.Secret) - cipherList.PushBack(&entry) - existingCipher[key{keyConfig.Cipher, keyConfig.Secret}] = true - } - } } - for port := range s.ports { - portChanges[port] = portChanges[port] - 1 + for listener := range s.listeners { + addrChanges[listener] = addrChanges[listener] - 1 } - for portNum, count := range portChanges { + for addr, count := range addrChanges { if count == -1 { - if err := s.removePort(portNum); err != nil { - return fmt.Errorf("failed to remove port %v: %w", portNum, err) + if err := s.remove(addr); err != nil { + return fmt.Errorf("failed to remove address %v: %w", addr, err) } } else if count == +1 { cipherList := service.NewCipherList() - tcpListener, err := s.startTCP(portNum, cipherList) + listener, err := s.start(addrs[addr].address, cipherList) if err != nil { return err } - packetConn, err := s.startUDP(portNum, cipherList) - if err != nil { - return err - } - s.ports[portNum] = &ssPort{tcpListener, packetConn, cipherList} + s.listeners[addr] = &ssListener{Closer: listener, cipherList: cipherList} } } - numServices := 0 - for portNum, cipherList := range portCiphers { - s.ports[portNum].cipherList.Update(cipherList) - numServices += cipherList.Len() + for addr, addrWithCiphers := range addrs { + listener, ok := s.listeners[addr] + if !ok { + return fmt.Errorf("unable to find listener for address: %v", addr) + } + listener.cipherList.Update(addrWithCiphers.ciphers) } - logger.Infof("Loaded %v access keys over %v ports", numServices, len(s.ports)) - s.m.SetNumAccessKeys(numServices, len(s.ports)) + logger.Infof("Loaded %v access keys over %v listeners", uniqueCiphers, len(s.listeners)) + s.m.SetNumAccessKeys(uniqueCiphers, len(s.listeners)) return nil } // Stop serving on all ports. func (s *SSServer) Stop() error { - for portNum := range s.ports { - if err := s.removePort(portNum); err != nil { + for addr := range s.listeners { + if err := s.remove(addr); err != nil { return err } } @@ -239,7 +236,7 @@ func RunSSServer(filename string, natTimeout time.Duration, sm *outlineMetrics, natTimeout: natTimeout, m: sm, replayCache: service.NewReplayCache(replayHistory), - ports: make(map[int]*ssPort), + listeners: make(map[string]*ssListener), } err := server.loadConfig(filename) if err != nil { From c1ee12f672571754d87d8bd8e6bb4b3dd6918844 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 3 Jun 2024 17:37:36 -0400 Subject: [PATCH 005/119] Remove commented out line. --- cmd/outline-ss-server/main.go | 1 - 1 file changed, 1 deletion(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 206ff05d..f45c8f4a 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -179,7 +179,6 @@ func (s *SSServer) loadConfig(filename string) error { switch t := listener.Type; t { // TODO: Support more listener types. case directListenerType: - //var addr net.Addr addr, err := onet.ResolveAddr(listener.Address) if err != nil { return fmt.Errorf("failed to resolve direct address: %v: %w", listener.Address, err) From 354301eafdd3dd99f2b37177a79776d6d1a8d9db Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 3 Jun 2024 17:42:19 -0400 Subject: [PATCH 006/119] Use `ElementsMatch` to compare the services irrespective of element ordering. --- cmd/outline-ss-server/config_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go index 0d46489c..1718a1fc 100644 --- a/cmd/outline-ss-server/config_test.go +++ b/cmd/outline-ss-server/config_test.go @@ -75,5 +75,5 @@ func TestReadConfigParsesDeprecatedFormat(t *testing.T) { }, }, } - require.Equal(t, expected, *config) + require.ElementsMatch(t, expected.Services, config.Services) } From 751d1643f1eb26f0e40c85caf2a5caeb9d9ac876 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 12 Jun 2024 13:58:55 -0400 Subject: [PATCH 007/119] Do not ignore the `keys` field if `services` is used as well. --- cmd/outline-ss-server/config.go | 54 +++++++++++++++++---------------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index 1f97e182..21934e31 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -37,16 +37,17 @@ type Key struct { Secret string } +type LegacyKeyService struct { + Key `yaml:",inline"` + Port int +} + type Config struct { Services []Service - // Deprecated: Keys exists for historical compatibility. This is ignored if top-level `services` is specified. - Keys []struct { - ID string - Port int - Cipher string - Secret string - } + // Deprecated: `keys` exists for backward compatibility. Prefer to configure + // using the newer `services` format. + Keys []LegacyKeyService } // Reads a config from a filename and parses it as a [Config]. @@ -60,27 +61,28 @@ func ReadConfig(filename string) (*Config, error) { if err != nil { return nil, fmt.Errorf("failed to parse config: %w", err) } - if config.Services == nil { - // This is a deprecated config format. We need to transform it to to the new format. - ports := make(map[int][]Key) - for _, keyConfig := range config.Keys { - ports[keyConfig.Port] = append(ports[keyConfig.Port], Key{ - ID: keyConfig.ID, - Cipher: keyConfig.Cipher, - Secret: keyConfig.Secret, - }) - } - for port, keys := range ports { - s := Service{ - Listeners: []Listener{ - Listener{Type: "direct", Address: fmt.Sprintf("tcp://[::]:%d", port)}, - Listener{Type: "direct", Address: fmt.Sprintf("udp://[::]:%d", port)}, - }, - Keys: keys, - } - config.Services = append(config.Services, s) + + // Specifying keys in `config.Keys` is a deprecated config format. We need to + // transform it to to the new format. + ports := make(map[int][]Key) + for _, keyConfig := range config.Keys { + ports[keyConfig.Port] = append(ports[keyConfig.Port], Key{ + ID: keyConfig.ID, + Cipher: keyConfig.Cipher, + Secret: keyConfig.Secret, + }) + } + for port, keys := range ports { + s := Service{ + Listeners: []Listener{ + Listener{Type: "direct", Address: fmt.Sprintf("tcp://[::]:%d", port)}, + Listener{Type: "direct", Address: fmt.Sprintf("udp://[::]:%d", port)}, + }, + Keys: keys, } + config.Services = append(config.Services, s) } config.Keys = nil + return &config, nil } From 6297304d5170de3224ccd762f1ced7b6f9a82e63 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 12 Jun 2024 14:05:28 -0400 Subject: [PATCH 008/119] Add some more tests for failure scenarios and empty files. --- cmd/outline-ss-server/config_test.go | 33 ++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go index 1718a1fc..329c517f 100644 --- a/cmd/outline-ss-server/config_test.go +++ b/cmd/outline-ss-server/config_test.go @@ -15,14 +15,16 @@ package main import ( + "os" "testing" "github.com/stretchr/testify/require" ) func TestReadConfig(t *testing.T) { - config, _ := ReadConfig("./config_example.yml") + config, err := ReadConfig("./config_example.yml") + require.NoError(t, err) expected := Config{ Services: []Service{ Service{ @@ -50,8 +52,9 @@ func TestReadConfig(t *testing.T) { } func TestReadConfigParsesDeprecatedFormat(t *testing.T) { - config, _ := ReadConfig("./config_example.deprecated.yml") + config, err := ReadConfig("./config_example.deprecated.yml") + require.NoError(t, err) expected := Config{ Services: []Service{ Service{ @@ -77,3 +80,29 @@ func TestReadConfigParsesDeprecatedFormat(t *testing.T) { } require.ElementsMatch(t, expected.Services, config.Services) } + +func TestReadConfigFromEmptyFile(t *testing.T) { + file, _ := os.CreateTemp("", "empty.yaml") + + config, err := ReadConfig(file.Name()) + + require.NoError(t, err) + require.ElementsMatch(t, Config{}, config) +} + +func TestReadConfigFromNonExistingFileFails(t *testing.T) { + config, err := ReadConfig("./foo") + + require.Error(t, err) + require.ElementsMatch(t, nil, config) +} + +func TestReadConfigFromIncorrectFormatFails(t *testing.T) { + file, _ := os.CreateTemp("", "empty.yaml") + file.WriteString("foo") + + config, err := ReadConfig(file.Name()) + + require.Error(t, err) + require.ElementsMatch(t, Config{}, config) +} From 0ac0a724be44fb4bbcd3a3d260ae341d6d62d3d9 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 12 Jun 2024 14:12:48 -0400 Subject: [PATCH 009/119] Remove unused `GetPort()`. --- net/address.go | 13 ------------- net/address_test.go | 21 --------------------- 2 files changed, 34 deletions(-) diff --git a/net/address.go b/net/address.go index a74e4f13..7438ee16 100644 --- a/net/address.go +++ b/net/address.go @@ -15,7 +15,6 @@ package net import ( - "fmt" "net" "net/url" ) @@ -49,15 +48,3 @@ func ResolveAddr(addr string) (net.Addr, error) { return nil, net.UnknownNetworkError(u.Scheme) } } - -// Returns the port from a given address. -func GetPort(addr net.Addr) (port int, err error) { - switch t := addr.(type) { - case *net.TCPAddr: - return t.Port, nil - case *net.UDPAddr: - return t.Port, nil - default: - return -1, fmt.Errorf("failed to get port from address: %v", addr) - } -} diff --git a/net/address_test.go b/net/address_test.go index d7fe1d01..f0904781 100644 --- a/net/address_test.go +++ b/net/address_test.go @@ -54,24 +54,3 @@ func TestResolveAddrReturnsErrorForUnknownScheme(t *testing.T) { require.Nil(t, addr) require.Error(t, err) } - -func TestGetPortFromTCPAddr(t *testing.T) { - port, err := GetPort(&net.TCPAddr{IP: net.ParseIP("1.2.3.4"), Port: 1234}) - - require.NoError(t, err) - require.Equal(t, 1234, port) -} - -func TestGetPortFromUDPPAddr(t *testing.T) { - port, err := GetPort(&net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 5678}) - - require.NoError(t, err) - require.Equal(t, 5678, port) -} - -func TestGetPortReturnsErrorForUnsupportedAddressType(t *testing.T) { - port, err := GetPort(&net.UnixAddr{Name: "/path/to/foo", Net: "unix"}) - - require.Equal(t, -1, port) - require.Error(t, err) -} From 794f860fec4facedd4b45b80fde0877e1b7b1cac Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 12 Jun 2024 16:06:24 -0400 Subject: [PATCH 010/119] Move `ResolveAddr` to config.go. --- cmd/outline-ss-server/config.go | 32 ++++++++++++++++ cmd/outline-ss-server/config_test.go | 35 +++++++++++++++++ cmd/outline-ss-server/main.go | 3 +- net/address.go | 50 ------------------------- net/address_test.go | 56 ---------------------------- 5 files changed, 68 insertions(+), 108 deletions(-) delete mode 100644 net/address.go delete mode 100644 net/address_test.go diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index 21934e31..e2fcbd87 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -16,6 +16,8 @@ package main import ( "fmt" + "net" + "net/url" "os" "gopkg.in/yaml.v2" @@ -86,3 +88,33 @@ func ReadConfig(filename string) (*Config, error) { return &config, nil } + +// Resolves a URL-style listen address specification as a [net.Addr]. +// +// Examples: +// +// udp6://127.0.0.1:8000 +// unix:///tmp/foo.sock +// tcp://127.0.0.1:9002 +func ResolveAddr(addr string) (net.Addr, error) { + u, err := url.Parse(addr) + if err != nil { + return nil, err + } + switch u.Scheme { + case "tcp", "tcp4", "tcp6": + return net.ResolveTCPAddr(u.Scheme, u.Host) + case "udp", "udp4", "udp6": + return net.ResolveUDPAddr(u.Scheme, u.Host) + case "unix", "unixgram", "unixpacket": + var path string + if u.Opaque != "" { + path = u.Opaque + } else { + path = u.Path + } + return net.ResolveUnixAddr(u.Scheme, path) + default: + return nil, net.UnknownNetworkError(u.Scheme) + } +} diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go index 329c517f..247d7e68 100644 --- a/cmd/outline-ss-server/config_test.go +++ b/cmd/outline-ss-server/config_test.go @@ -15,6 +15,7 @@ package main import ( + "net" "os" "testing" @@ -106,3 +107,37 @@ func TestReadConfigFromIncorrectFormatFails(t *testing.T) { require.Error(t, err) require.ElementsMatch(t, Config{}, config) } + +func TestResolveAddrReturnsTCPAddr(t *testing.T) { + addr, err := ResolveAddr("tcp://0.0.0.0:9000") + + require.NoError(t, err) + if _, ok := addr.(*net.TCPAddr); !ok { + t.Errorf("expected a *net.TCPAddr; it is a %T", addr) + } +} + +func TestResolveAddrReturnsUDPAddr(t *testing.T) { + addr, err := ResolveAddr("udp://[::]:9001") + + require.NoError(t, err) + if _, ok := addr.(*net.UDPAddr); !ok { + t.Errorf("expected a *net.UDPAddr; it is a %T", addr) + } +} + +func TestResolveAddrReturnsUnixAddr(t *testing.T) { + addr, err := ResolveAddr("unix:///path/to/stream_socket") + + require.NoError(t, err) + if _, ok := addr.(*net.UnixAddr); !ok { + t.Errorf("expected a *net.UnixAddr; it is a %T", addr) + } +} + +func TestResolveAddrReturnsErrorForUnknownScheme(t *testing.T) { + addr, err := ResolveAddr("foobar") + + require.Nil(t, addr) + require.Error(t, err) +} diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index f45c8f4a..546bfcb5 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -31,7 +31,6 @@ import ( "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" "github.com/Jigsaw-Code/outline-ss-server/ipinfo" - onet "github.com/Jigsaw-Code/outline-ss-server/net" "github.com/Jigsaw-Code/outline-ss-server/service" "github.com/op/go-logging" "github.com/prometheus/client_golang/prometheus" @@ -179,7 +178,7 @@ func (s *SSServer) loadConfig(filename string) error { switch t := listener.Type; t { // TODO: Support more listener types. case directListenerType: - addr, err := onet.ResolveAddr(listener.Address) + addr, err := ResolveAddr(listener.Address) if err != nil { return fmt.Errorf("failed to resolve direct address: %v: %w", listener.Address, err) } diff --git a/net/address.go b/net/address.go deleted file mode 100644 index 7438ee16..00000000 --- a/net/address.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2024 Jigsaw Operations LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package net - -import ( - "net" - "net/url" -) - -// Resolves a URL-style listen address specification as a [net.Addr] -// -// Examples: -// -// udp6://127.0.0.1:8000 -// unix:///tmp/foo.sock -// tcp://127.0.0.1:9002 -func ResolveAddr(addr string) (net.Addr, error) { - u, err := url.Parse(addr) - if err != nil { - return nil, err - } - switch u.Scheme { - case "tcp", "tcp4", "tcp6": - return net.ResolveTCPAddr(u.Scheme, u.Host) - case "udp", "udp4", "udp6": - return net.ResolveUDPAddr(u.Scheme, u.Host) - case "unix", "unixgram", "unixpacket": - var path string - if u.Opaque != "" { - path = u.Opaque - } else { - path = u.Path - } - return net.ResolveUnixAddr(u.Scheme, path) - default: - return nil, net.UnknownNetworkError(u.Scheme) - } -} diff --git a/net/address_test.go b/net/address_test.go deleted file mode 100644 index f0904781..00000000 --- a/net/address_test.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2024 Jigsaw Operations LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package net - -import ( - "net" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestResolveAddrReturnsTCPAddr(t *testing.T) { - addr, err := ResolveAddr("tcp://0.0.0.0:9000") - - require.NoError(t, err) - if _, ok := addr.(*net.TCPAddr); !ok { - t.Errorf("expected a *net.TCPAddr; it is a %T", addr) - } -} - -func TestResolveAddrReturnsUDPAddr(t *testing.T) { - addr, err := ResolveAddr("udp://[::]:9001") - - require.NoError(t, err) - if _, ok := addr.(*net.UDPAddr); !ok { - t.Errorf("expected a *net.UDPAddr; it is a %T", addr) - } -} - -func TestResolveAddrReturnsUnixAddr(t *testing.T) { - addr, err := ResolveAddr("unix:///path/to/stream_socket") - - require.NoError(t, err) - if _, ok := addr.(*net.UnixAddr); !ok { - t.Errorf("expected a *net.UnixAddr; it is a %T", addr) - } -} - -func TestResolveAddrReturnsErrorForUnknownScheme(t *testing.T) { - addr, err := ResolveAddr("foobar") - - require.Nil(t, addr) - require.Error(t, err) -} From 01b7e8a20b727dffeaccdee11c79b2658ed0a4a1 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 12 Jun 2024 18:24:50 -0400 Subject: [PATCH 011/119] Remove use of `net.Addr` type. --- cmd/outline-ss-server/config.go | 32 ---------- cmd/outline-ss-server/config_test.go | 35 ---------- cmd/outline-ss-server/main.go | 95 ++++++++++++++-------------- service/tcp.go | 8 +-- 4 files changed, 53 insertions(+), 117 deletions(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index e2fcbd87..21934e31 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -16,8 +16,6 @@ package main import ( "fmt" - "net" - "net/url" "os" "gopkg.in/yaml.v2" @@ -88,33 +86,3 @@ func ReadConfig(filename string) (*Config, error) { return &config, nil } - -// Resolves a URL-style listen address specification as a [net.Addr]. -// -// Examples: -// -// udp6://127.0.0.1:8000 -// unix:///tmp/foo.sock -// tcp://127.0.0.1:9002 -func ResolveAddr(addr string) (net.Addr, error) { - u, err := url.Parse(addr) - if err != nil { - return nil, err - } - switch u.Scheme { - case "tcp", "tcp4", "tcp6": - return net.ResolveTCPAddr(u.Scheme, u.Host) - case "udp", "udp4", "udp6": - return net.ResolveUDPAddr(u.Scheme, u.Host) - case "unix", "unixgram", "unixpacket": - var path string - if u.Opaque != "" { - path = u.Opaque - } else { - path = u.Path - } - return net.ResolveUnixAddr(u.Scheme, path) - default: - return nil, net.UnknownNetworkError(u.Scheme) - } -} diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go index 247d7e68..329c517f 100644 --- a/cmd/outline-ss-server/config_test.go +++ b/cmd/outline-ss-server/config_test.go @@ -15,7 +15,6 @@ package main import ( - "net" "os" "testing" @@ -107,37 +106,3 @@ func TestReadConfigFromIncorrectFormatFails(t *testing.T) { require.Error(t, err) require.ElementsMatch(t, Config{}, config) } - -func TestResolveAddrReturnsTCPAddr(t *testing.T) { - addr, err := ResolveAddr("tcp://0.0.0.0:9000") - - require.NoError(t, err) - if _, ok := addr.(*net.TCPAddr); !ok { - t.Errorf("expected a *net.TCPAddr; it is a %T", addr) - } -} - -func TestResolveAddrReturnsUDPAddr(t *testing.T) { - addr, err := ResolveAddr("udp://[::]:9001") - - require.NoError(t, err) - if _, ok := addr.(*net.UDPAddr); !ok { - t.Errorf("expected a *net.UDPAddr; it is a %T", addr) - } -} - -func TestResolveAddrReturnsUnixAddr(t *testing.T) { - addr, err := ResolveAddr("unix:///path/to/stream_socket") - - require.NoError(t, err) - if _, ok := addr.(*net.UnixAddr); !ok { - t.Errorf("expected a *net.UnixAddr; it is a %T", addr) - } -} - -func TestResolveAddrReturnsErrorForUnknownScheme(t *testing.T) { - addr, err := ResolveAddr("foobar") - - require.Nil(t, addr) - require.Error(t, err) -} diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 546bfcb5..294934a5 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -21,6 +21,7 @@ import ( "io" "net" "net/http" + "net/url" "os" "os/signal" "strings" @@ -74,48 +75,58 @@ type SSServer struct { listeners map[string]*ssListener } -func (s *SSServer) startDirectTCP(addr *net.TCPAddr, cipherList service.CipherList) (*net.TCPListener, error) { - listener, err := net.ListenTCP("tcp", addr) - if err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return nil, fmt.Errorf("Shadowsocks TCP service failed to start on address %v: %w", addr.String(), err) - } - logger.Infof("Shadowsocks TCP service listening on %v", listener.Addr().String()) - authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &s.replayCache, s.m) - // TODO: Register initial data metrics at zero. - tcpHandler := service.NewTCPHandler(addr.Port, authFunc, s.m, tcpReadTimeout) - accept := func() (transport.StreamConn, error) { - conn, err := listener.AcceptTCP() - if err == nil { - conn.SetKeepAlive(true) +func (s *SSServer) serve(listener io.Closer, cipherList service.CipherList) error { + switch ln := listener.(type) { + case net.Listener: + authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &s.replayCache, s.m) + // TODO: Register initial data metrics at zero. + tcpHandler := service.NewTCPHandler(authFunc, s.m, tcpReadTimeout) + accept := func() (transport.StreamConn, error) { + conn, err := ln.Accept() + if err == nil { + conn.(*net.TCPConn).SetKeepAlive(true) + } + return conn.(transport.StreamConn), err } - return conn, err + go service.StreamServe(accept, tcpHandler.Handle) + case net.PacketConn: + packetHandler := service.NewPacketHandler(s.natTimeout, cipherList, s.m) + go packetHandler.Handle(ln) + default: + return fmt.Errorf("unknown listener type: %v", ln) } - go service.StreamServe(accept, tcpHandler.Handle) - return listener, nil + return nil } -func (s *SSServer) startDirectUDP(addr *net.UDPAddr, cipherList service.CipherList) (*net.UDPConn, error) { - packetConn, err := net.ListenUDP("udp", addr) +func (s *SSServer) start(addr string, cipherList service.CipherList) (io.Closer, error) { + u, err := url.Parse(addr) if err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return nil, fmt.Errorf("Shadowsocks UDP service failed to start on address %v: %w", addr.String(), err) + return nil, err } - logger.Infof("Shadowsocks UDP service listening on %v", packetConn.LocalAddr().String()) - packetHandler := service.NewPacketHandler(s.natTimeout, cipherList, s.m) - go packetHandler.Handle(packetConn) - return packetConn, nil -} -func (s *SSServer) start(addr net.Addr, cipherList service.CipherList) (io.Closer, error) { - switch t := addr.(type) { - case *net.TCPAddr: - return s.startDirectTCP(t, cipherList) - case *net.UDPAddr: - return s.startDirectUDP(t, cipherList) + var listener io.Closer + switch u.Scheme { + case "tcp", "tcp4", "tcp6": + // TODO: Validate `u` address. + listener, err = net.Listen(u.Scheme, u.Host) + case "udp", "udp4", "udp6": + // TODO: Validate `u` address. + listener, err = net.ListenPacket(u.Scheme, u.Host) default: - return nil, fmt.Errorf("unable to start address: %s", t) + return nil, fmt.Errorf("unsupported protocol: %s", u.Scheme) } + if err != nil { + //lint:ignore ST1005 Shadowsocks is capitalized. + return nil, fmt.Errorf("Shadowsocks service failed to start on address %v: %w", addr, err) + } + logger.Infof("Shadowsocks service listening on %v", addr) + + err = s.serve(listener, cipherList) + if err != nil { + return nil, fmt.Errorf("failed to serve on listener %w: %w", listener, err) + } + + return listener, nil } func (s *SSServer) remove(addr string) error { @@ -141,11 +152,7 @@ func (s *SSServer) loadConfig(filename string) error { uniqueCiphers := 0 addrChanges := make(map[string]int) - type addrWithCiphers struct { - address net.Addr - ciphers *list.List // Values are *List of *CipherEntry. - } - addrs := make(map[string]*addrWithCiphers) + addrCiphers := make(map[string]*list.List) // Values are *List of *CipherEntry. for _, serviceConfig := range config.Services { if serviceConfig.Listeners == nil || serviceConfig.Keys == nil { return fmt.Errorf("must specify at least 1 listener and 1 key per service") @@ -178,12 +185,8 @@ func (s *SSServer) loadConfig(filename string) error { switch t := listener.Type; t { // TODO: Support more listener types. case directListenerType: - addr, err := ResolveAddr(listener.Address) - if err != nil { - return fmt.Errorf("failed to resolve direct address: %v: %w", listener.Address, err) - } addrChanges[listener.Address] = 1 - addrs[listener.Address] = &addrWithCiphers{addr, ciphers} + addrCiphers[listener.Address] = ciphers default: return fmt.Errorf("unsupported listener type: %s", t) } @@ -199,19 +202,19 @@ func (s *SSServer) loadConfig(filename string) error { } } else if count == +1 { cipherList := service.NewCipherList() - listener, err := s.start(addrs[addr].address, cipherList) + listener, err := s.start(addr, cipherList) if err != nil { return err } s.listeners[addr] = &ssListener{Closer: listener, cipherList: cipherList} } } - for addr, addrWithCiphers := range addrs { + for addr, ciphers := range addrCiphers { listener, ok := s.listeners[addr] if !ok { return fmt.Errorf("unable to find listener for address: %v", addr) } - listener.cipherList.Update(addrWithCiphers.ciphers) + listener.cipherList.Update(ciphers) } logger.Infof("Loaded %v access keys over %v listeners", uniqueCiphers, len(s.listeners)) s.m.SetNumAccessKeys(uniqueCiphers, len(s.listeners)) diff --git a/service/tcp.go b/service/tcp.go index 85ab9990..484a1b90 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -170,9 +170,8 @@ type tcpHandler struct { } // NewTCPService creates a TCPService -func NewTCPHandler(port int, authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration) TCPHandler { +func NewTCPHandler(authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration) TCPHandler { return &tcpHandler{ - port: port, m: m, readTimeout: timeout, authenticate: authenticate, @@ -342,7 +341,8 @@ func (h *tcpHandler) handleConnection(ctx context.Context, outerConn transport.S id, innerConn, authErr := h.authenticate(outerConn) if authErr != nil { // Drain to protect against probing attacks. - h.absorbProbe(outerConn, authErr.Status, proxyMetrics) + port := outerConn.LocalAddr().(*net.TCPAddr).Port + h.absorbProbe(outerConn, port, authErr.Status, proxyMetrics) return id, authErr } h.m.AddAuthenticatedTCPConnection(outerConn.RemoteAddr(), id) @@ -370,7 +370,7 @@ func (h *tcpHandler) handleConnection(ctx context.Context, outerConn transport.S // Keep the connection open until we hit the authentication deadline to protect against probing attacks // `proxyMetrics` is a pointer because its value is being mutated by `clientConn`. -func (h *tcpHandler) absorbProbe(clientConn io.ReadCloser, status string, proxyMetrics *metrics.ProxyMetrics) { +func (h *tcpHandler) absorbProbe(clientConn io.ReadCloser, port int, status string, proxyMetrics *metrics.ProxyMetrics) { // This line updates proxyMetrics.ClientProxy before it's used in AddTCPProbe. _, drainErr := io.Copy(io.Discard, clientConn) // drain socket drainResult := drainErrToString(drainErr) From 87a15650e2ca5116e50199b63986f813dc3d8fa5 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 12 Jun 2024 18:26:28 -0400 Subject: [PATCH 012/119] Pull listener creation into its own function. --- cmd/outline-ss-server/main.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 294934a5..8d853ffd 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -98,23 +98,26 @@ func (s *SSServer) serve(listener io.Closer, cipherList service.CipherList) erro return nil } -func (s *SSServer) start(addr string, cipherList service.CipherList) (io.Closer, error) { +func newListener(addr string) (io.Closer, error) { u, err := url.Parse(addr) if err != nil { return nil, err } - var listener io.Closer switch u.Scheme { case "tcp", "tcp4", "tcp6": // TODO: Validate `u` address. - listener, err = net.Listen(u.Scheme, u.Host) + return net.Listen(u.Scheme, u.Host) case "udp", "udp4", "udp6": // TODO: Validate `u` address. - listener, err = net.ListenPacket(u.Scheme, u.Host) + return net.ListenPacket(u.Scheme, u.Host) default: return nil, fmt.Errorf("unsupported protocol: %s", u.Scheme) } +} + +func (s *SSServer) start(addr string, cipherList service.CipherList) (io.Closer, error) { + listener, err := newListener(addr) if err != nil { //lint:ignore ST1005 Shadowsocks is capitalized. return nil, fmt.Errorf("Shadowsocks service failed to start on address %v: %w", addr, err) From 51a13a7a52801a61d0bc06eaa2bae83eacf8fa84 Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 14 Jun 2024 10:10:36 -0400 Subject: [PATCH 013/119] Move listener validation/creation to `config.go`. --- cmd/outline-ss-server/config.go | 46 +++++++++++++++++++++++++++++++++ cmd/outline-ss-server/main.go | 21 +-------------- 2 files changed, 47 insertions(+), 20 deletions(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index 21934e31..8d8cd37e 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -15,7 +15,11 @@ package main import ( + "errors" "fmt" + "io" + "net" + "net/url" "os" "gopkg.in/yaml.v2" @@ -86,3 +90,45 @@ func ReadConfig(filename string) (*Config, error) { return &config, nil } + +// validateListener asserts that a listener URI conforms to the expected format. +func validateListener(u *url.URL) error { + if u.Opaque != "" { + return errors.New("URI cannot have an opaque part") + } + if u.User != nil { + return errors.New("URI cannot have an userdata part") + } + if u.RawQuery != "" || u.ForceQuery { + return errors.New("URI cannot have a query part") + } + if u.Fragment != "" { + return errors.New("URI cannot have a fragement") + } + if u.Path != "" && u.Path != "/" { + return errors.New("URI path not allowed") + } + return nil +} + +func NewListener(addr string) (io.Closer, error) { + u, err := url.Parse(addr) + if err != nil { + return nil, err + } + + switch u.Scheme { + case "tcp", "tcp4", "tcp6": + if err := validateListener(u); err != nil { + return nil, fmt.Errorf("invalid listener `%s`: %v", u, err) + } + return net.Listen(u.Scheme, u.Host) + case "udp", "udp4", "udp6": + if err := validateListener(u); err != nil { + return nil, fmt.Errorf("invalid listener `%s`: %v", u, err) + } + return net.ListenPacket(u.Scheme, u.Host) + default: + return nil, fmt.Errorf("unsupported protocol: %s", u.Scheme) + } +} diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 8d853ffd..323bcda7 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -21,7 +21,6 @@ import ( "io" "net" "net/http" - "net/url" "os" "os/signal" "strings" @@ -98,26 +97,8 @@ func (s *SSServer) serve(listener io.Closer, cipherList service.CipherList) erro return nil } -func newListener(addr string) (io.Closer, error) { - u, err := url.Parse(addr) - if err != nil { - return nil, err - } - - switch u.Scheme { - case "tcp", "tcp4", "tcp6": - // TODO: Validate `u` address. - return net.Listen(u.Scheme, u.Host) - case "udp", "udp4", "udp6": - // TODO: Validate `u` address. - return net.ListenPacket(u.Scheme, u.Host) - default: - return nil, fmt.Errorf("unsupported protocol: %s", u.Scheme) - } -} - func (s *SSServer) start(addr string, cipherList service.CipherList) (io.Closer, error) { - listener, err := newListener(addr) + listener, err := NewListener(addr) if err != nil { //lint:ignore ST1005 Shadowsocks is capitalized. return nil, fmt.Errorf("Shadowsocks service failed to start on address %v: %w", addr, err) From f8d7aa5b90bd13bae3c1b2542e684adb6497774f Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 14 Jun 2024 10:30:16 -0400 Subject: [PATCH 014/119] Use a custom type for listener type. --- cmd/outline-ss-server/config.go | 12 ++++++++---- cmd/outline-ss-server/config_test.go | 16 ++++++++-------- cmd/outline-ss-server/main.go | 8 +++----- 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index 8d8cd37e..112ba58f 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -30,8 +30,12 @@ type Service struct { Keys []Key } +type ListenerType string + +const listenerTypeDirect ListenerType = "direct" + type Listener struct { - Type string + Type ListenerType Address string } @@ -79,8 +83,8 @@ func ReadConfig(filename string) (*Config, error) { for port, keys := range ports { s := Service{ Listeners: []Listener{ - Listener{Type: "direct", Address: fmt.Sprintf("tcp://[::]:%d", port)}, - Listener{Type: "direct", Address: fmt.Sprintf("udp://[::]:%d", port)}, + Listener{Type: listenerTypeDirect, Address: fmt.Sprintf("tcp://[::]:%d", port)}, + Listener{Type: listenerTypeDirect, Address: fmt.Sprintf("udp://[::]:%d", port)}, }, Keys: keys, } @@ -111,7 +115,7 @@ func validateListener(u *url.URL) error { return nil } -func NewListener(addr string) (io.Closer, error) { +func newListener(addr string) (io.Closer, error) { u, err := url.Parse(addr) if err != nil { return nil, err diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go index 329c517f..88660388 100644 --- a/cmd/outline-ss-server/config_test.go +++ b/cmd/outline-ss-server/config_test.go @@ -29,8 +29,8 @@ func TestReadConfig(t *testing.T) { Services: []Service{ Service{ Listeners: []Listener{ - Listener{Type: "direct", Address: "tcp://[::]:9000"}, - Listener{Type: "direct", Address: "udp://[::]:9000"}, + Listener{Type: listenerTypeDirect, Address: "tcp://[::]:9000"}, + Listener{Type: listenerTypeDirect, Address: "udp://[::]:9000"}, }, Keys: []Key{ Key{"user-0", "chacha20-ietf-poly1305", "Secret0"}, @@ -39,8 +39,8 @@ func TestReadConfig(t *testing.T) { }, Service{ Listeners: []Listener{ - Listener{Type: "direct", Address: "tcp://[::]:9001"}, - Listener{Type: "direct", Address: "udp://[::]:9001"}, + Listener{Type: listenerTypeDirect, Address: "tcp://[::]:9001"}, + Listener{Type: listenerTypeDirect, Address: "udp://[::]:9001"}, }, Keys: []Key{ Key{"user-2", "chacha20-ietf-poly1305", "Secret2"}, @@ -59,8 +59,8 @@ func TestReadConfigParsesDeprecatedFormat(t *testing.T) { Services: []Service{ Service{ Listeners: []Listener{ - Listener{Type: "direct", Address: "tcp://[::]:9000"}, - Listener{Type: "direct", Address: "udp://[::]:9000"}, + Listener{Type: listenerTypeDirect, Address: "tcp://[::]:9000"}, + Listener{Type: listenerTypeDirect, Address: "udp://[::]:9000"}, }, Keys: []Key{ Key{"user-0", "chacha20-ietf-poly1305", "Secret0"}, @@ -69,8 +69,8 @@ func TestReadConfigParsesDeprecatedFormat(t *testing.T) { }, Service{ Listeners: []Listener{ - Listener{Type: "direct", Address: "tcp://[::]:9001"}, - Listener{Type: "direct", Address: "udp://[::]:9001"}, + Listener{Type: listenerTypeDirect, Address: "tcp://[::]:9001"}, + Listener{Type: listenerTypeDirect, Address: "udp://[::]:9001"}, }, Keys: []Key{ Key{"user-2", "chacha20-ietf-poly1305", "Secret2"}, diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 323bcda7..2af841cc 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -49,8 +49,6 @@ const tcpReadTimeout time.Duration = 59 * time.Second // A UDP NAT timeout of at least 5 minutes is recommended in RFC 4787 Section 4.3. const defaultNatTimeout time.Duration = 5 * time.Minute -var directListenerType = "direct" - func init() { var prefix = "%{level:.1s}%{time:2006-01-02T15:04:05.000Z07:00} %{pid} %{shortfile}]" if term.IsTerminal(int(os.Stderr.Fd())) { @@ -98,7 +96,7 @@ func (s *SSServer) serve(listener io.Closer, cipherList service.CipherList) erro } func (s *SSServer) start(addr string, cipherList service.CipherList) (io.Closer, error) { - listener, err := NewListener(addr) + listener, err := newListener(addr) if err != nil { //lint:ignore ST1005 Shadowsocks is capitalized. return nil, fmt.Errorf("Shadowsocks service failed to start on address %v: %w", addr, err) @@ -107,7 +105,7 @@ func (s *SSServer) start(addr string, cipherList service.CipherList) (io.Closer, err = s.serve(listener, cipherList) if err != nil { - return nil, fmt.Errorf("failed to serve on listener %w: %w", listener, err) + return nil, fmt.Errorf("failed to serve on listener %v: %w", listener, err) } return listener, nil @@ -168,7 +166,7 @@ func (s *SSServer) loadConfig(filename string) error { for _, listener := range serviceConfig.Listeners { switch t := listener.Type; t { // TODO: Support more listener types. - case directListenerType: + case listenerTypeDirect: addrChanges[listener.Address] = 1 addrCiphers[listener.Address] = ciphers default: From 19520364bcd8f06322455deb06ad28e4dd2642dc Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 14 Jun 2024 10:45:38 -0400 Subject: [PATCH 015/119] Fix accept handler. --- cmd/outline-ss-server/main.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 2af841cc..085b8c0e 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -80,10 +80,12 @@ func (s *SSServer) serve(listener io.Closer, cipherList service.CipherList) erro tcpHandler := service.NewTCPHandler(authFunc, s.m, tcpReadTimeout) accept := func() (transport.StreamConn, error) { conn, err := ln.Accept() - if err == nil { - conn.(*net.TCPConn).SetKeepAlive(true) + if err != nil { + return nil, err } - return conn.(transport.StreamConn), err + c := conn.(*net.TCPConn) + c.SetKeepAlive(true) + return c, err } go service.StreamServe(accept, tcpHandler.Handle) case net.PacketConn: From 7212265bdc10c5fac9bb957329ae019f983e6b2a Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 14 Jun 2024 10:56:48 -0400 Subject: [PATCH 016/119] Add doc comment. --- cmd/outline-ss-server/config.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index 112ba58f..8ce54066 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -115,6 +115,12 @@ func validateListener(u *url.URL) error { return nil } +// newListener creates a new listener from a URL-style address specification. +// +// Example addresses: +// +// tcp4://127.0.0.1:8000 +// udp://127.0.0.1:9000 func newListener(addr string) (io.Closer, error) { u, err := url.Parse(addr) if err != nil { From 6e2068d2491d3277ce44d5a6965a0312c123e24e Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 14 Jun 2024 11:00:29 -0400 Subject: [PATCH 017/119] Fix tests still supplying the port. --- internal/integration_test/integration_test.go | 8 ++++---- service/tcp_test.go | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index 4ca2f120..43109b7a 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -133,7 +133,7 @@ func TestTCPEcho(t *testing.T) { const testTimeout = 200 * time.Millisecond testMetrics := &service.NoOpTCPMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) + handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { @@ -202,7 +202,7 @@ func TestRestrictedAddresses(t *testing.T) { const testTimeout = 200 * time.Millisecond testMetrics := &statusMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) + handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout) done := make(chan struct{}) go func() { service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) @@ -384,7 +384,7 @@ func BenchmarkTCPThroughput(b *testing.B) { const testTimeout = 200 * time.Millisecond testMetrics := &service.NoOpTCPMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) + handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { @@ -448,7 +448,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) { const testTimeout = 200 * time.Millisecond testMetrics := &service.NoOpTCPMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) + handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { diff --git a/service/tcp_test.go b/service/tcp_test.go index 1a70ed67..14069756 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -281,7 +281,7 @@ func TestProbeRandom(t *testing.T) { require.NoError(t, err, "MakeTestCiphers failed: %v", err) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond) + handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) done := make(chan struct{}) go func() { StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) @@ -358,7 +358,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond) + handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { @@ -393,7 +393,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond) + handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { @@ -429,7 +429,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond) + handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { @@ -472,7 +472,7 @@ func TestProbeServerBytesModified(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond) + handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) done := make(chan struct{}) go func() { StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) @@ -503,7 +503,7 @@ func TestReplayDefense(t *testing.T) { testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) + handler := NewTCPHandler(authFunc, testMetrics, testTimeout) snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) cipherEntry := snapshot[0].Value.(*CipherEntry) cipher := cipherEntry.CryptoKey @@ -582,7 +582,7 @@ func TestReverseReplayDefense(t *testing.T) { testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) + handler := NewTCPHandler(authFunc, testMetrics, testTimeout) snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) cipherEntry := snapshot[0].Value.(*CipherEntry) cipher := cipherEntry.CryptoKey @@ -653,7 +653,7 @@ func probeExpectTimeout(t *testing.T, payloadSize int) { require.NoError(t, err, "MakeTestCiphers failed: %v", err) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) + handler := NewTCPHandler(authFunc, testMetrics, testTimeout) done := make(chan struct{}) go func() { From 7114434738b7369894de50586f8215527156a0ac Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 14 Jun 2024 14:36:02 -0400 Subject: [PATCH 018/119] Move old config parsing to `loadConfig`. --- cmd/outline-ss-server/config.go | 23 --------------------- cmd/outline-ss-server/config_test.go | 31 +++++++++++----------------- cmd/outline-ss-server/main.go | 20 ++++++++++++++++++ 3 files changed, 32 insertions(+), 42 deletions(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index 8ce54066..41a44d7a 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -69,29 +69,6 @@ func ReadConfig(filename string) (*Config, error) { if err != nil { return nil, fmt.Errorf("failed to parse config: %w", err) } - - // Specifying keys in `config.Keys` is a deprecated config format. We need to - // transform it to to the new format. - ports := make(map[int][]Key) - for _, keyConfig := range config.Keys { - ports[keyConfig.Port] = append(ports[keyConfig.Port], Key{ - ID: keyConfig.ID, - Cipher: keyConfig.Cipher, - Secret: keyConfig.Secret, - }) - } - for port, keys := range ports { - s := Service{ - Listeners: []Listener{ - Listener{Type: listenerTypeDirect, Address: fmt.Sprintf("tcp://[::]:%d", port)}, - Listener{Type: listenerTypeDirect, Address: fmt.Sprintf("udp://[::]:%d", port)}, - }, - Keys: keys, - } - config.Services = append(config.Services, s) - } - config.Keys = nil - return &config, nil } diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go index 88660388..2793bbc4 100644 --- a/cmd/outline-ss-server/config_test.go +++ b/cmd/outline-ss-server/config_test.go @@ -56,29 +56,22 @@ func TestReadConfigParsesDeprecatedFormat(t *testing.T) { require.NoError(t, err) expected := Config{ - Services: []Service{ - Service{ - Listeners: []Listener{ - Listener{Type: listenerTypeDirect, Address: "tcp://[::]:9000"}, - Listener{Type: listenerTypeDirect, Address: "udp://[::]:9000"}, - }, - Keys: []Key{ - Key{"user-0", "chacha20-ietf-poly1305", "Secret0"}, - Key{"user-1", "chacha20-ietf-poly1305", "Secret1"}, - }, + Keys: []LegacyKeyService{ + LegacyKeyService{ + Key: Key{ID: "user-0", Cipher: "chacha20-ietf-poly1305", Secret: "Secret0"}, + Port: 9000, }, - Service{ - Listeners: []Listener{ - Listener{Type: listenerTypeDirect, Address: "tcp://[::]:9001"}, - Listener{Type: listenerTypeDirect, Address: "udp://[::]:9001"}, - }, - Keys: []Key{ - Key{"user-2", "chacha20-ietf-poly1305", "Secret2"}, - }, + LegacyKeyService{ + Key: Key{ID: "user-1", Cipher: "chacha20-ietf-poly1305", Secret: "Secret1"}, + Port: 9000, + }, + LegacyKeyService{ + Key: Key{ID: "user-2", Cipher: "chacha20-ietf-poly1305", Secret: "Secret2"}, + Port: 9001, }, }, } - require.ElementsMatch(t, expected.Services, config.Services) + require.Equal(t, expected, *config) } func TestReadConfigFromEmptyFile(t *testing.T) { diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 085b8c0e..50c13922 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -137,6 +137,26 @@ func (s *SSServer) loadConfig(filename string) error { uniqueCiphers := 0 addrChanges := make(map[string]int) addrCiphers := make(map[string]*list.List) // Values are *List of *CipherEntry. + + for _, legacyKeyConfig := range config.Keys { + cryptoKey, err := shadowsocks.NewEncryptionKey(legacyKeyConfig.Cipher, legacyKeyConfig.Secret) + if err != nil { + return fmt.Errorf("failed to create encyption key for key %v: %w", legacyKeyConfig.ID, err) + } + entry := service.MakeCipherEntry(legacyKeyConfig.ID, cryptoKey, legacyKeyConfig.Secret) + for _, ln := range []string{"tcp", "udp"} { + addr := fmt.Sprintf("%s://[::]:%d", ln, legacyKeyConfig.Port) + addrChanges[addr] = 1 + ciphers, ok := addrCiphers[addr] + if !ok { + ciphers = list.New() + addrCiphers[addr] = ciphers + } + ciphers.PushBack(&entry) + } + uniqueCiphers += 1 + } + for _, serviceConfig := range config.Services { if serviceConfig.Listeners == nil || serviceConfig.Keys == nil { return fmt.Errorf("must specify at least 1 listener and 1 key per service") From 1b2dd42b392bd65b35ab485bf739163df0b1b86e Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 14 Jun 2024 14:43:29 -0400 Subject: [PATCH 019/119] Lowercase `readConfig`. --- cmd/outline-ss-server/config.go | 4 ++-- cmd/outline-ss-server/config_test.go | 10 +++++----- cmd/outline-ss-server/main.go | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index 41a44d7a..4df7de83 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -58,8 +58,8 @@ type Config struct { Keys []LegacyKeyService } -// Reads a config from a filename and parses it as a [Config]. -func ReadConfig(filename string) (*Config, error) { +// readConfig attempts to read a config from a filename and parses it as a [Config]. +func readConfig(filename string) (*Config, error) { config := Config{} configData, err := os.ReadFile(filename) if err != nil { diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go index 2793bbc4..fce0d8f8 100644 --- a/cmd/outline-ss-server/config_test.go +++ b/cmd/outline-ss-server/config_test.go @@ -22,7 +22,7 @@ import ( ) func TestReadConfig(t *testing.T) { - config, err := ReadConfig("./config_example.yml") + config, err := readConfig("./config_example.yml") require.NoError(t, err) expected := Config{ @@ -52,7 +52,7 @@ func TestReadConfig(t *testing.T) { } func TestReadConfigParsesDeprecatedFormat(t *testing.T) { - config, err := ReadConfig("./config_example.deprecated.yml") + config, err := readConfig("./config_example.deprecated.yml") require.NoError(t, err) expected := Config{ @@ -77,14 +77,14 @@ func TestReadConfigParsesDeprecatedFormat(t *testing.T) { func TestReadConfigFromEmptyFile(t *testing.T) { file, _ := os.CreateTemp("", "empty.yaml") - config, err := ReadConfig(file.Name()) + config, err := readConfig(file.Name()) require.NoError(t, err) require.ElementsMatch(t, Config{}, config) } func TestReadConfigFromNonExistingFileFails(t *testing.T) { - config, err := ReadConfig("./foo") + config, err := readConfig("./foo") require.Error(t, err) require.ElementsMatch(t, nil, config) @@ -94,7 +94,7 @@ func TestReadConfigFromIncorrectFormatFails(t *testing.T) { file, _ := os.CreateTemp("", "empty.yaml") file.WriteString("foo") - config, err := ReadConfig(file.Name()) + config, err := readConfig(file.Name()) require.Error(t, err) require.ElementsMatch(t, Config{}, config) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 50c13922..190657e8 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -129,7 +129,7 @@ func (s *SSServer) remove(addr string) error { } func (s *SSServer) loadConfig(filename string) error { - config, err := ReadConfig(filename) + config, err := readConfig(filename) if err != nil { return fmt.Errorf("failed to load config (%v): %w", filename, err) } From 744b2cf74d2e40ca07ef572f6bbb9f0b5a545dfe Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 17 Jun 2024 12:33:07 -0400 Subject: [PATCH 020/119] Add support for PROXY protocol on TCP. --- cmd/outline-ss-server/config.go | 5 +- cmd/outline-ss-server/config_example.yml | 4 + cmd/outline-ss-server/main.go | 106 +++++++++++++---------- go.mod | 1 + go.sum | 2 + service/tcp.go | 28 +++--- 6 files changed, 88 insertions(+), 58 deletions(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index 4df7de83..4da1ce8d 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -32,7 +32,10 @@ type Service struct { type ListenerType string -const listenerTypeDirect ListenerType = "direct" +const ( + listenerTypeDirect ListenerType = "direct" + listenerTypeProxy ListenerType = "proxy_protocol" +) type Listener struct { Type ListenerType diff --git a/cmd/outline-ss-server/config_example.yml b/cmd/outline-ss-server/config_example.yml index 66009c10..14ceb313 100644 --- a/cmd/outline-ss-server/config_example.yml +++ b/cmd/outline-ss-server/config_example.yml @@ -4,6 +4,10 @@ services: address: "tcp://[::]:9000" - type: direct address: "udp://[::]:9000" + - type: proxy_protocol + address: "tcp://[::]:9010" + - type: proxy_protocol + address: "udp://[::]:9010" keys: - id: user-0 cipher: chacha20-ietf-poly1305 diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 190657e8..75636efc 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -15,6 +15,7 @@ package main import ( + "bufio" "container/list" "flag" "fmt" @@ -27,12 +28,12 @@ import ( "syscall" "time" - "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" "github.com/Jigsaw-Code/outline-ss-server/ipinfo" "github.com/Jigsaw-Code/outline-ss-server/service" "github.com/op/go-logging" + proxyproto "github.com/pires/go-proxyproto" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "golang.org/x/term" @@ -60,6 +61,8 @@ func init() { logger = logging.MustGetLogger("") } +type ListenerConfig = Listener + type ssListener struct { io.Closer cipherList service.CipherList @@ -69,23 +72,39 @@ type SSServer struct { natTimeout time.Duration m *outlineMetrics replayCache service.ReplayCache - listeners map[string]*ssListener + listeners map[ListenerConfig]*ssListener } -func (s *SSServer) serve(listener io.Closer, cipherList service.CipherList) error { +func (s *SSServer) serve(lnType ListenerType, listener io.Closer, cipherList service.CipherList) error { switch ln := listener.(type) { case net.Listener: authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &s.replayCache, s.m) // TODO: Register initial data metrics at zero. tcpHandler := service.NewTCPHandler(authFunc, s.m, tcpReadTimeout) - accept := func() (transport.StreamConn, error) { + accept := func() (service.ClientStreamConn, error) { conn, err := ln.Accept() if err != nil { - return nil, err + return service.ClientStreamConn{}, err } c := conn.(*net.TCPConn) c.SetKeepAlive(true) - return c, err + switch lnType { + case listenerTypeDirect: + return service.ClientStreamConn{StreamConn: c, ClientAddress: c.RemoteAddr()}, err + case listenerTypeProxy: + r := bufio.NewReader(c) + h, err := proxyproto.Read(r) + if err == proxyproto.ErrNoProxyProtocol { + logger.Warningf("Received connection from %v without proxy header.", c.RemoteAddr()) + return service.ClientStreamConn{StreamConn: c, ClientAddress: c.RemoteAddr()}, nil + } + if err != nil { + return service.ClientStreamConn{}, fmt.Errorf("error parsing proxy header: %v", err) + } + return service.ClientStreamConn{StreamConn: c, ClientAddress: h.SourceAddr}, nil + default: + return service.ClientStreamConn{}, fmt.Errorf("unknown listener config: %v", lnType) + } } go service.StreamServe(accept, tcpHandler.Handle) case net.PacketConn: @@ -97,34 +116,33 @@ func (s *SSServer) serve(listener io.Closer, cipherList service.CipherList) erro return nil } -func (s *SSServer) start(addr string, cipherList service.CipherList) (io.Closer, error) { - listener, err := newListener(addr) +func (s *SSServer) start(lnConfig ListenerConfig, cipherList service.CipherList) (io.Closer, error) { + listener, err := newListener(lnConfig.Address) if err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return nil, fmt.Errorf("Shadowsocks service failed to start on address %v: %w", addr, err) + return nil, fmt.Errorf("%s service failed to start on address %v: %w", lnConfig.Type, lnConfig.Address, err) } - logger.Infof("Shadowsocks service listening on %v", addr) + logger.Infof("%s service listening on %v", lnConfig.Type, lnConfig.Address) - err = s.serve(listener, cipherList) + err = s.serve(lnConfig.Type, listener, cipherList) if err != nil { - return nil, fmt.Errorf("failed to serve on listener %v: %w", listener, err) + return nil, fmt.Errorf("failed to serve %s on listener %v: %w", lnConfig.Type, listener, err) } return listener, nil } -func (s *SSServer) remove(addr string) error { - listener, ok := s.listeners[addr] +func (s *SSServer) remove(lnConfig ListenerConfig) error { + listener, ok := s.listeners[lnConfig] if !ok { - return fmt.Errorf("address %v doesn't exist", addr) + return fmt.Errorf("address %v doesn't exist", lnConfig.Address) } err := listener.Close() - delete(s.listeners, addr) + delete(s.listeners, lnConfig) if err != nil { //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks service on address %v failed to stop: %w", addr, err) + return fmt.Errorf("Shadowsocks service on address %v failed to stop: %w", lnConfig.Address, err) } - logger.Infof("Shadowsocks service on address %v stopped", addr) + logger.Infof("Shadowsocks service on address %v stopped", lnConfig.Address) return nil } @@ -135,8 +153,8 @@ func (s *SSServer) loadConfig(filename string) error { } uniqueCiphers := 0 - addrChanges := make(map[string]int) - addrCiphers := make(map[string]*list.List) // Values are *List of *CipherEntry. + listenerChanges := make(map[ListenerConfig]int) + listenerCiphers := make(map[ListenerConfig]*list.List) // Values are *List of *CipherEntry. for _, legacyKeyConfig := range config.Keys { cryptoKey, err := shadowsocks.NewEncryptionKey(legacyKeyConfig.Cipher, legacyKeyConfig.Secret) @@ -145,12 +163,12 @@ func (s *SSServer) loadConfig(filename string) error { } entry := service.MakeCipherEntry(legacyKeyConfig.ID, cryptoKey, legacyKeyConfig.Secret) for _, ln := range []string{"tcp", "udp"} { - addr := fmt.Sprintf("%s://[::]:%d", ln, legacyKeyConfig.Port) - addrChanges[addr] = 1 - ciphers, ok := addrCiphers[addr] + lnConfig := ListenerConfig{Type: listenerTypeDirect, Address: fmt.Sprintf("%s://[::]:%d", ln, legacyKeyConfig.Port)} + listenerChanges[lnConfig] = 1 + ciphers, ok := listenerCiphers[lnConfig] if !ok { ciphers = list.New() - addrCiphers[addr] = ciphers + listenerCiphers[lnConfig] = ciphers } ciphers.PushBack(&entry) } @@ -185,38 +203,32 @@ func (s *SSServer) loadConfig(filename string) error { } uniqueCiphers += ciphers.Len() - for _, listener := range serviceConfig.Listeners { - switch t := listener.Type; t { - // TODO: Support more listener types. - case listenerTypeDirect: - addrChanges[listener.Address] = 1 - addrCiphers[listener.Address] = ciphers - default: - return fmt.Errorf("unsupported listener type: %s", t) - } + for _, lnConfig := range serviceConfig.Listeners { + listenerChanges[lnConfig] = 1 + listenerCiphers[lnConfig] = ciphers } } - for listener := range s.listeners { - addrChanges[listener] = addrChanges[listener] - 1 + for lnConfig := range s.listeners { + listenerChanges[lnConfig] = listenerChanges[lnConfig] - 1 } - for addr, count := range addrChanges { + for lnConfig, count := range listenerChanges { if count == -1 { - if err := s.remove(addr); err != nil { - return fmt.Errorf("failed to remove address %v: %w", addr, err) + if err := s.remove(lnConfig); err != nil { + return fmt.Errorf("failed to remove %s listener on address %v: %w", lnConfig.Type, lnConfig.Address, err) } } else if count == +1 { cipherList := service.NewCipherList() - listener, err := s.start(addr, cipherList) + listener, err := s.start(lnConfig, cipherList) if err != nil { return err } - s.listeners[addr] = &ssListener{Closer: listener, cipherList: cipherList} + s.listeners[lnConfig] = &ssListener{Closer: listener, cipherList: cipherList} } } - for addr, ciphers := range addrCiphers { - listener, ok := s.listeners[addr] + for lnConfig, ciphers := range listenerCiphers { + listener, ok := s.listeners[lnConfig] if !ok { - return fmt.Errorf("unable to find listener for address: %v", addr) + return fmt.Errorf("unable to find listener for address: %v", lnConfig.Address) } listener.cipherList.Update(ciphers) } @@ -227,8 +239,8 @@ func (s *SSServer) loadConfig(filename string) error { // Stop serving on all ports. func (s *SSServer) Stop() error { - for addr := range s.listeners { - if err := s.remove(addr); err != nil { + for lnConfig := range s.listeners { + if err := s.remove(lnConfig); err != nil { return err } } @@ -241,7 +253,7 @@ func RunSSServer(filename string, natTimeout time.Duration, sm *outlineMetrics, natTimeout: natTimeout, m: sm, replayCache: service.NewReplayCache(replayHistory), - listeners: make(map[string]*ssListener), + listeners: make(map[ListenerConfig]*ssListener), } err := server.loadConfig(filename) if err != nil { diff --git a/go.mod b/go.mod index 8935374c..4b06c0bd 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/goreleaser/goreleaser v1.18.2 github.com/op/go-logging v0.0.0-20160315200505-970db520ece7 github.com/oschwald/geoip2-golang v1.8.0 + github.com/pires/go-proxyproto v0.7.0 github.com/prometheus/client_golang v1.15.0 github.com/shadowsocks/go-shadowsocks2 v0.1.5 github.com/stretchr/testify v1.8.4 diff --git a/go.sum b/go.sum index 846e9c24..b8ebb49b 100644 --- a/go.sum +++ b/go.sum @@ -1884,6 +1884,8 @@ github.com/pelletier/go-toml/v2 v2.0.6/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha github.com/performancecopilot/speed/v4 v4.0.0/go.mod h1:qxrSyuDGrTOWfV+uKRFhfxw6h/4HXRGUiZiufxo49BM= github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc= +github.com/pires/go-proxyproto v0.7.0 h1:IukmRewDQFWC7kfnb66CSomk2q/seBuilHBYFwyq0Hs= +github.com/pires/go-proxyproto v0.7.0/go.mod h1:Vz/1JPY/OACxWGQNIRY2BeyDmpoaWmEP40O9LbuiFR4= github.com/pjbgf/sha1cd v0.3.0 h1:4D5XXmUUBUl/xQ6IjCkEAbqXskkq/4O7LmGn0AqMDs4= github.com/pjbgf/sha1cd v0.3.0/go.mod h1:nZ1rrWOcGJ5uZgEEVL1VUM9iRQiZvWdbZjkKyFzPPsI= github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4/go.mod h1:N6UoU20jOqggOuDwUaBQpluzLNDqif3kq9z2wpdYEfQ= diff --git a/service/tcp.go b/service/tcp.go index 484a1b90..ee38c1d5 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -188,9 +188,9 @@ func makeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator) tra }}} } -// TCPService is a Shadowsocks TCP service that can be started and stopped. +// TCPHandler is a Shadowsocks TCP service that can be started and stopped. type TCPHandler interface { - Handle(ctx context.Context, conn transport.StreamConn) + Handle(ctx context.Context, conn ClientStreamConn) // SetTargetDialer sets the [transport.StreamDialer] to be used to connect to target addresses. SetTargetDialer(dialer transport.StreamDialer) } @@ -211,15 +211,23 @@ func ensureConnectionError(err error, fallbackStatus string, fallbackMsg string) } } -type StreamListener func() (transport.StreamConn, error) +// ClientStreamConn wraps a [transport.StreamConn] and sets the client source of the connection. +// This is useful for handling the PROXY protocol where the RemoteAddr() points to the +// server/load balancer address and we need the perceived source of the connection. +type ClientStreamConn struct { + transport.StreamConn + ClientAddress net.Addr +} + +type StreamListener func() (ClientStreamConn, error) -func WrapStreamListener[T transport.StreamConn](f func() (T, error)) StreamListener { - return func() (transport.StreamConn, error) { +func WrapStreamListener[T ClientStreamConn](f func() (T, error)) StreamListener { + return func() (ClientStreamConn, error) { return f() } } -type StreamHandler func(ctx context.Context, conn transport.StreamConn) +type StreamHandler func(ctx context.Context, conn ClientStreamConn) // StreamServe repeatedly calls `accept` to obtain connections and `handle` to handle them until // accept() returns [ErrClosed]. When that happens, all connection handlers will be notified @@ -253,12 +261,12 @@ func StreamServe(accept StreamListener, handle StreamHandler) { } } -func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn) { - clientInfo, err := ipinfo.GetIPInfoFromAddr(h.m, clientConn.RemoteAddr()) +func (h *tcpHandler) Handle(ctx context.Context, clientConn ClientStreamConn) { + clientInfo, err := ipinfo.GetIPInfoFromAddr(h.m, clientConn.ClientAddress) if err != nil { logger.Warningf("Failed client info lookup: %v", err) } - logger.Debugf("Got info \"%#v\" for IP %v", clientInfo, clientConn.RemoteAddr().String()) + logger.Debugf("Got info \"%#v\" for IP %v", clientInfo, clientConn.ClientAddress.String()) h.m.AddOpenTCPConnection(clientInfo) var proxyMetrics metrics.ProxyMetrics measuredClientConn := metrics.MeasureConn(clientConn, &proxyMetrics.ProxyClient, &proxyMetrics.ClientProxy) @@ -272,7 +280,7 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn status = connError.Status logger.Debugf("TCP Error: %v: %v", connError.Message, connError.Cause) } - h.m.AddClosedTCPConnection(clientInfo, clientConn.RemoteAddr(), id, status, proxyMetrics, connDuration) + h.m.AddClosedTCPConnection(clientInfo, clientConn.ClientAddress, id, status, proxyMetrics, connDuration) measuredClientConn.Close() // Closing after the metrics are added aids integration testing. logger.Debugf("Done with status %v, duration %v", status, connDuration) } From c5fb99a0a28f7e880a2fd01a5fc4fc898bd52c2d Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 17 Jun 2024 16:04:44 -0400 Subject: [PATCH 021/119] Create a `Listen` func to create proxy or direct listeners based on the config. --- cmd/outline-ss-server/config.go | 22 +++++++++++ cmd/outline-ss-server/main.go | 42 +++++---------------- service/listeners.go | 65 +++++++++++++++++++++++++++++++++ service/tcp.go | 36 ++++++++++++------ 4 files changed, 121 insertions(+), 44 deletions(-) create mode 100644 service/listeners.go diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index 4da1ce8d..ecfbc7a2 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -22,6 +22,7 @@ import ( "net/url" "os" + "github.com/Jigsaw-Code/outline-ss-server/service" "gopkg.in/yaml.v2" ) @@ -122,3 +123,24 @@ func newListener(addr string) (io.Closer, error) { return nil, fmt.Errorf("unsupported protocol: %s", u.Scheme) } } + +// Listen creates a new listener based on a given [Listener] config. +func Listen(config Listener) (io.Closer, error) { + listener, err := newListener(config.Address) + if err != nil { + return nil, err + } + switch ln := listener.(type) { + case net.Listener: + switch config.Type { + case listenerTypeDirect: + return &service.DirectListener{Listener: ln}, nil + case listenerTypeProxy: + return &service.ProxyListener{Listener: ln}, nil + default: + return ln, err + } + default: + return listener, err + } +} diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 75636efc..29f0dc90 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -15,7 +15,6 @@ package main import ( - "bufio" "container/list" "flag" "fmt" @@ -33,7 +32,6 @@ import ( "github.com/Jigsaw-Code/outline-ss-server/ipinfo" "github.com/Jigsaw-Code/outline-ss-server/service" "github.com/op/go-logging" - proxyproto "github.com/pires/go-proxyproto" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "golang.org/x/term" @@ -63,6 +61,11 @@ func init() { type ListenerConfig = Listener +type ClientStreamListener interface { + Accept() (service.IClientStreamConn, error) + Close() error +} + type ssListener struct { io.Closer cipherList service.CipherList @@ -75,38 +78,13 @@ type SSServer struct { listeners map[ListenerConfig]*ssListener } -func (s *SSServer) serve(lnType ListenerType, listener io.Closer, cipherList service.CipherList) error { +func (s *SSServer) serve(listener io.Closer, cipherList service.CipherList) error { switch ln := listener.(type) { - case net.Listener: + case ClientStreamListener: authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &s.replayCache, s.m) // TODO: Register initial data metrics at zero. tcpHandler := service.NewTCPHandler(authFunc, s.m, tcpReadTimeout) - accept := func() (service.ClientStreamConn, error) { - conn, err := ln.Accept() - if err != nil { - return service.ClientStreamConn{}, err - } - c := conn.(*net.TCPConn) - c.SetKeepAlive(true) - switch lnType { - case listenerTypeDirect: - return service.ClientStreamConn{StreamConn: c, ClientAddress: c.RemoteAddr()}, err - case listenerTypeProxy: - r := bufio.NewReader(c) - h, err := proxyproto.Read(r) - if err == proxyproto.ErrNoProxyProtocol { - logger.Warningf("Received connection from %v without proxy header.", c.RemoteAddr()) - return service.ClientStreamConn{StreamConn: c, ClientAddress: c.RemoteAddr()}, nil - } - if err != nil { - return service.ClientStreamConn{}, fmt.Errorf("error parsing proxy header: %v", err) - } - return service.ClientStreamConn{StreamConn: c, ClientAddress: h.SourceAddr}, nil - default: - return service.ClientStreamConn{}, fmt.Errorf("unknown listener config: %v", lnType) - } - } - go service.StreamServe(accept, tcpHandler.Handle) + go service.StreamServe(ln.Accept, tcpHandler.Handle) case net.PacketConn: packetHandler := service.NewPacketHandler(s.natTimeout, cipherList, s.m) go packetHandler.Handle(ln) @@ -117,13 +95,13 @@ func (s *SSServer) serve(lnType ListenerType, listener io.Closer, cipherList ser } func (s *SSServer) start(lnConfig ListenerConfig, cipherList service.CipherList) (io.Closer, error) { - listener, err := newListener(lnConfig.Address) + listener, err := Listen(lnConfig) if err != nil { return nil, fmt.Errorf("%s service failed to start on address %v: %w", lnConfig.Type, lnConfig.Address, err) } logger.Infof("%s service listening on %v", lnConfig.Type, lnConfig.Address) - err = s.serve(lnConfig.Type, listener, cipherList) + err = s.serve(listener, cipherList) if err != nil { return nil, fmt.Errorf("failed to serve %s on listener %v: %w", lnConfig.Type, listener, err) } diff --git a/service/listeners.go b/service/listeners.go new file mode 100644 index 00000000..362b3cc1 --- /dev/null +++ b/service/listeners.go @@ -0,0 +1,65 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "bufio" + "fmt" + "net" + + "github.com/Jigsaw-Code/outline-sdk/transport" + proxyproto "github.com/pires/go-proxyproto" +) + +type DirectListener struct { + net.Listener +} + +func (l *DirectListener) Accept() (IClientStreamConn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + c.(*net.TCPConn).SetKeepAlive(true) + return &clientStreamConn{StreamConn: c.(transport.StreamConn)}, nil +} + +// ProxyListener wraps a [net.Listener] and fetches the source of the connection from the PROXY +// protocol header string. See https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt. +type ProxyListener struct { + net.Listener +} + +// Accept waits for and returns the next incoming connection. +func (l *ProxyListener) Accept() (IClientStreamConn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + conn := c.(transport.StreamConn) + r := bufio.NewReader(conn) + h, err := proxyproto.Read(r) + if err == proxyproto.ErrNoProxyProtocol { + logger.Warningf("Received connection from %v without proxy header.", conn.RemoteAddr()) + return &clientStreamConn{StreamConn: conn}, nil + } + if err != nil { + return nil, fmt.Errorf("error parsing proxy header: %v", err) + } + + conn.(*net.TCPConn).SetKeepAlive(true) + clientConn := transport.WrapConn(conn, r, conn) + return &clientStreamConn{clientConn, h.SourceAddr}, nil +} diff --git a/service/tcp.go b/service/tcp.go index ee38c1d5..c8fe205b 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -190,7 +190,7 @@ func makeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator) tra // TCPHandler is a Shadowsocks TCP service that can be started and stopped. type TCPHandler interface { - Handle(ctx context.Context, conn ClientStreamConn) + Handle(ctx context.Context, conn IClientStreamConn) // SetTargetDialer sets the [transport.StreamDialer] to be used to connect to target addresses. SetTargetDialer(dialer transport.StreamDialer) } @@ -211,23 +211,35 @@ func ensureConnectionError(err error, fallbackStatus string, fallbackMsg string) } } -// ClientStreamConn wraps a [transport.StreamConn] and sets the client source of the connection. +// IClientStreamConn wraps a [transport.StreamConn] and sets the client source of the connection. // This is useful for handling the PROXY protocol where the RemoteAddr() points to the // server/load balancer address and we need the perceived source of the connection. -type ClientStreamConn struct { +type IClientStreamConn interface { transport.StreamConn - ClientAddress net.Addr + ClientAddr() net.Addr } -type StreamListener func() (ClientStreamConn, error) +type clientStreamConn struct { + transport.StreamConn + clientAddr net.Addr +} + +func (c *clientStreamConn) ClientAddr() net.Addr { + if c.clientAddr != nil { + return c.clientAddr + } + return c.StreamConn.RemoteAddr() +} + +type StreamListener func() (IClientStreamConn, error) -func WrapStreamListener[T ClientStreamConn](f func() (T, error)) StreamListener { - return func() (ClientStreamConn, error) { +func WrapStreamListener[T IClientStreamConn](f func() (T, error)) StreamListener { + return func() (IClientStreamConn, error) { return f() } } -type StreamHandler func(ctx context.Context, conn ClientStreamConn) +type StreamHandler func(ctx context.Context, conn IClientStreamConn) // StreamServe repeatedly calls `accept` to obtain connections and `handle` to handle them until // accept() returns [ErrClosed]. When that happens, all connection handlers will be notified @@ -261,12 +273,12 @@ func StreamServe(accept StreamListener, handle StreamHandler) { } } -func (h *tcpHandler) Handle(ctx context.Context, clientConn ClientStreamConn) { - clientInfo, err := ipinfo.GetIPInfoFromAddr(h.m, clientConn.ClientAddress) +func (h *tcpHandler) Handle(ctx context.Context, clientConn IClientStreamConn) { + clientInfo, err := ipinfo.GetIPInfoFromAddr(h.m, clientConn.ClientAddr()) if err != nil { logger.Warningf("Failed client info lookup: %v", err) } - logger.Debugf("Got info \"%#v\" for IP %v", clientInfo, clientConn.ClientAddress.String()) + logger.Debugf("Got info \"%#v\" for IP %v", clientInfo, clientConn.ClientAddr().String()) h.m.AddOpenTCPConnection(clientInfo) var proxyMetrics metrics.ProxyMetrics measuredClientConn := metrics.MeasureConn(clientConn, &proxyMetrics.ProxyClient, &proxyMetrics.ClientProxy) @@ -280,7 +292,7 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn ClientStreamConn) { status = connError.Status logger.Debugf("TCP Error: %v: %v", connError.Message, connError.Cause) } - h.m.AddClosedTCPConnection(clientInfo, clientConn.ClientAddress, id, status, proxyMetrics, connDuration) + h.m.AddClosedTCPConnection(clientInfo, clientConn.ClientAddr(), id, status, proxyMetrics, connDuration) measuredClientConn.Close() // Closing after the metrics are added aids integration testing. logger.Debugf("Done with status %v, duration %v", status, connDuration) } From 802e689dd8acff670e91218e7efec484ed50fea4 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 17 Jun 2024 16:08:22 -0400 Subject: [PATCH 022/119] Update config test. --- cmd/outline-ss-server/config_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go index fce0d8f8..a487fb56 100644 --- a/cmd/outline-ss-server/config_test.go +++ b/cmd/outline-ss-server/config_test.go @@ -31,6 +31,8 @@ func TestReadConfig(t *testing.T) { Listeners: []Listener{ Listener{Type: listenerTypeDirect, Address: "tcp://[::]:9000"}, Listener{Type: listenerTypeDirect, Address: "udp://[::]:9000"}, + Listener{Type: listenerTypeProxy, Address: "tcp://[::]:9010"}, + Listener{Type: listenerTypeProxy, Address: "udp://[::]:9010"}, }, Keys: []Key{ Key{"user-0", "chacha20-ietf-poly1305", "Secret0"}, From e085d4c7387cf94ae071d1fe5ab577e61e50154a Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 17 Jun 2024 16:50:05 -0400 Subject: [PATCH 023/119] Wrap direct listener inside proxy protocol listener. --- cmd/outline-ss-server/config.go | 11 ++++------ internal/integration_test/integration_test.go | 6 +++--- service/listeners.go | 21 +++++++++---------- service/tcp.go | 6 +++--- service/tcp_test.go | 20 +++++++++--------- 5 files changed, 30 insertions(+), 34 deletions(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index ecfbc7a2..b2a83cbb 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -132,14 +132,11 @@ func Listen(config Listener) (io.Closer, error) { } switch ln := listener.(type) { case net.Listener: - switch config.Type { - case listenerTypeDirect: - return &service.DirectListener{Listener: ln}, nil - case listenerTypeProxy: - return &service.ProxyListener{Listener: ln}, nil - default: - return ln, err + streamListener := &service.StreamListener{Listener: ln} + if config.Type == listenerTypeProxy { + return &service.ProxyListener{StreamListener: *streamListener}, err } + return streamListener, err default: return listener, err } diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index 43109b7a..9f35da70 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -205,7 +205,7 @@ func TestRestrictedAddresses(t *testing.T) { handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout) done := make(chan struct{}) go func() { - service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) + service.StreamServe(service.WrapStreamAccepter(proxyListener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -388,7 +388,7 @@ func BenchmarkTCPThroughput(b *testing.B) { handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { - service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) + service.StreamServe(service.WrapStreamAccepter(proxyListener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -452,7 +452,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) { handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { - service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) + service.StreamServe(service.WrapStreamAccepter(proxyListener.AcceptTCP), handler.Handle) done <- struct{}{} }() diff --git a/service/listeners.go b/service/listeners.go index 362b3cc1..3fbe68fc 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -23,11 +23,13 @@ import ( proxyproto "github.com/pires/go-proxyproto" ) -type DirectListener struct { +// StreamListener wraps a [net.Listener]. +type StreamListener struct { net.Listener } -func (l *DirectListener) Accept() (IClientStreamConn, error) { +// Accept waits for and returns the next incoming connection. +func (l *StreamListener) Accept() (IClientStreamConn, error) { c, err := l.Listener.Accept() if err != nil { return nil, err @@ -39,27 +41,24 @@ func (l *DirectListener) Accept() (IClientStreamConn, error) { // ProxyListener wraps a [net.Listener] and fetches the source of the connection from the PROXY // protocol header string. See https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt. type ProxyListener struct { - net.Listener + StreamListener } -// Accept waits for and returns the next incoming connection. +// Accept waits for the next incoming connection, parses the client IP from the PROXY protocol +// header, and adds it to the connection. func (l *ProxyListener) Accept() (IClientStreamConn, error) { - c, err := l.Listener.Accept() + conn, err := l.StreamListener.Accept() if err != nil { return nil, err } - conn := c.(transport.StreamConn) r := bufio.NewReader(conn) h, err := proxyproto.Read(r) if err == proxyproto.ErrNoProxyProtocol { logger.Warningf("Received connection from %v without proxy header.", conn.RemoteAddr()) - return &clientStreamConn{StreamConn: conn}, nil + return conn, nil } if err != nil { return nil, fmt.Errorf("error parsing proxy header: %v", err) } - - conn.(*net.TCPConn).SetKeepAlive(true) - clientConn := transport.WrapConn(conn, r, conn) - return &clientStreamConn{clientConn, h.SourceAddr}, nil + return &clientStreamConn{StreamConn: conn, clientAddr: h.SourceAddr}, nil } diff --git a/service/tcp.go b/service/tcp.go index c8fe205b..bdaa1641 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -231,9 +231,9 @@ func (c *clientStreamConn) ClientAddr() net.Addr { return c.StreamConn.RemoteAddr() } -type StreamListener func() (IClientStreamConn, error) +type StreamAccepter func() (IClientStreamConn, error) -func WrapStreamListener[T IClientStreamConn](f func() (T, error)) StreamListener { +func WrapStreamAccepter[T IClientStreamConn](f func() (T, error)) StreamAccepter { return func() (IClientStreamConn, error) { return f() } @@ -244,7 +244,7 @@ type StreamHandler func(ctx context.Context, conn IClientStreamConn) // StreamServe repeatedly calls `accept` to obtain connections and `handle` to handle them until // accept() returns [ErrClosed]. When that happens, all connection handlers will be notified // via their [context.Context]. StreamServe will return after all pending handlers return. -func StreamServe(accept StreamListener, handle StreamHandler) { +func StreamServe(accept StreamAccepter, handle StreamHandler) { var running sync.WaitGroup defer running.Wait() ctx, contextCancel := context.WithCancel(context.Background()) diff --git a/service/tcp_test.go b/service/tcp_test.go index 14069756..51948b53 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -284,7 +284,7 @@ func TestProbeRandom(t *testing.T) { handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAccepter(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -362,7 +362,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) { handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAccepter(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -397,7 +397,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) { handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAccepter(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -433,7 +433,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) { handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAccepter(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -475,7 +475,7 @@ func TestProbeServerBytesModified(t *testing.T) { handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAccepter(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -528,7 +528,7 @@ func TestReplayDefense(t *testing.T) { done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAccepter(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -598,7 +598,7 @@ func TestReverseReplayDefense(t *testing.T) { done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAccepter(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -657,7 +657,7 @@ func probeExpectTimeout(t *testing.T, payloadSize int) { done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAccepter(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -717,14 +717,14 @@ func TestStreamServeEarlyClose(t *testing.T) { err = tcpListener.Close() require.NoError(t, err) // This should return quickly, without timing out or calling the handler. - StreamServe(WrapStreamListener(tcpListener.AcceptTCP), nil) + StreamServe(WrapStreamAccepter(tcpListener.AcceptTCP), nil) } // Makes sure the TCP listener returns [io.ErrClosed] on Close(). func TestClosedTCPListenerError(t *testing.T) { tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) require.NoError(t, err) - accept := WrapStreamListener(tcpListener.AcceptTCP) + accept := WrapStreamAccepter(tcpListener.AcceptTCP) err = tcpListener.Close() require.NoError(t, err) _, err = accept() From 98ccbce5b4fafe4f8f479dc3fca9e7b17f86129c Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 17 Jun 2024 17:07:07 -0400 Subject: [PATCH 024/119] Move listeners into `tcp.go`. --- service/listeners.go | 64 -------------------------------------------- service/tcp.go | 42 +++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 64 deletions(-) delete mode 100644 service/listeners.go diff --git a/service/listeners.go b/service/listeners.go deleted file mode 100644 index 3fbe68fc..00000000 --- a/service/listeners.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2024 Jigsaw Operations LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package service - -import ( - "bufio" - "fmt" - "net" - - "github.com/Jigsaw-Code/outline-sdk/transport" - proxyproto "github.com/pires/go-proxyproto" -) - -// StreamListener wraps a [net.Listener]. -type StreamListener struct { - net.Listener -} - -// Accept waits for and returns the next incoming connection. -func (l *StreamListener) Accept() (IClientStreamConn, error) { - c, err := l.Listener.Accept() - if err != nil { - return nil, err - } - c.(*net.TCPConn).SetKeepAlive(true) - return &clientStreamConn{StreamConn: c.(transport.StreamConn)}, nil -} - -// ProxyListener wraps a [net.Listener] and fetches the source of the connection from the PROXY -// protocol header string. See https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt. -type ProxyListener struct { - StreamListener -} - -// Accept waits for the next incoming connection, parses the client IP from the PROXY protocol -// header, and adds it to the connection. -func (l *ProxyListener) Accept() (IClientStreamConn, error) { - conn, err := l.StreamListener.Accept() - if err != nil { - return nil, err - } - r := bufio.NewReader(conn) - h, err := proxyproto.Read(r) - if err == proxyproto.ErrNoProxyProtocol { - logger.Warningf("Received connection from %v without proxy header.", conn.RemoteAddr()) - return conn, nil - } - if err != nil { - return nil, fmt.Errorf("error parsing proxy header: %v", err) - } - return &clientStreamConn{StreamConn: conn, clientAddr: h.SourceAddr}, nil -} diff --git a/service/tcp.go b/service/tcp.go index bdaa1641..9e1980c2 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -15,6 +15,7 @@ package service import ( + "bufio" "bytes" "container/list" "context" @@ -33,6 +34,7 @@ import ( onet "github.com/Jigsaw-Code/outline-ss-server/net" "github.com/Jigsaw-Code/outline-ss-server/service/metrics" logging "github.com/op/go-logging" + proxyproto "github.com/pires/go-proxyproto" "github.com/shadowsocks/go-shadowsocks2/socks" ) @@ -231,6 +233,46 @@ func (c *clientStreamConn) ClientAddr() net.Addr { return c.StreamConn.RemoteAddr() } +// StreamListener wraps a [net.Listener]. +type StreamListener struct { + net.Listener +} + +// Accept waits for and returns the next incoming connection. +func (l *StreamListener) Accept() (IClientStreamConn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + c.(*net.TCPConn).SetKeepAlive(true) + return &clientStreamConn{StreamConn: c.(transport.StreamConn)}, nil +} + +// ProxyListener wraps a [StreamListener] and fetches the source of the connection from the PROXY +// protocol header string. See https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt. +type ProxyListener struct { + StreamListener +} + +// Accept waits for the next incoming connection, parses the client IP from the PROXY protocol +// header, and adds it to the connection. +func (l *ProxyListener) Accept() (IClientStreamConn, error) { + conn, err := l.StreamListener.Accept() + if err != nil { + return nil, err + } + r := bufio.NewReader(conn) + h, err := proxyproto.Read(r) + if err == proxyproto.ErrNoProxyProtocol { + logger.Warningf("Received connection from %v without proxy header.", conn.RemoteAddr()) + return conn, nil + } + if err != nil { + return nil, fmt.Errorf("error parsing proxy header: %v", err) + } + return &clientStreamConn{StreamConn: conn, clientAddr: h.SourceAddr}, nil +} + type StreamAccepter func() (IClientStreamConn, error) func WrapStreamAccepter[T IClientStreamConn](f func() (T, error)) StreamAccepter { From d83e6ba5e60f9876d9f1eb440cddade89739e04f Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 17 Jun 2024 17:15:07 -0400 Subject: [PATCH 025/119] Fix tests. --- internal/integration_test/integration_test.go | 2 +- service/tcp.go | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index 9f35da70..d5597dc7 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -137,7 +137,7 @@ func TestTCPEcho(t *testing.T) { handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { - service.StreamServe(func() (transport.StreamConn, error) { return proxyListener.AcceptTCP() }, handler.Handle) + service.StreamServe(service.WrapStreamAccepter(proxyListener.AcceptTCP), handler.Handle) done <- struct{}{} }() diff --git a/service/tcp.go b/service/tcp.go index 9e1980c2..436f5896 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -275,9 +275,10 @@ func (l *ProxyListener) Accept() (IClientStreamConn, error) { type StreamAccepter func() (IClientStreamConn, error) -func WrapStreamAccepter[T IClientStreamConn](f func() (T, error)) StreamAccepter { +func WrapStreamAccepter[T transport.StreamConn](f func() (T, error)) StreamAccepter { return func() (IClientStreamConn, error) { - return f() + c, err := f() + return &clientStreamConn{StreamConn: c}, err } } From 870b3790a443f2baf7789fb9119017818e91b97b Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 17 Jun 2024 17:48:06 -0400 Subject: [PATCH 026/119] Rename `IClientStreamConn` to `ClientStreamConn`. --- cmd/outline-ss-server/main.go | 2 +- service/tcp.go | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 29f0dc90..9c08fc23 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -62,7 +62,7 @@ func init() { type ListenerConfig = Listener type ClientStreamListener interface { - Accept() (service.IClientStreamConn, error) + Accept() (service.ClientStreamConn, error) Close() error } diff --git a/service/tcp.go b/service/tcp.go index 436f5896..1a9fa8fa 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -192,7 +192,7 @@ func makeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator) tra // TCPHandler is a Shadowsocks TCP service that can be started and stopped. type TCPHandler interface { - Handle(ctx context.Context, conn IClientStreamConn) + Handle(ctx context.Context, conn ClientStreamConn) // SetTargetDialer sets the [transport.StreamDialer] to be used to connect to target addresses. SetTargetDialer(dialer transport.StreamDialer) } @@ -213,10 +213,10 @@ func ensureConnectionError(err error, fallbackStatus string, fallbackMsg string) } } -// IClientStreamConn wraps a [transport.StreamConn] and sets the client source of the connection. +// ClientStreamConn wraps a [transport.StreamConn] and sets the client source of the connection. // This is useful for handling the PROXY protocol where the RemoteAddr() points to the // server/load balancer address and we need the perceived source of the connection. -type IClientStreamConn interface { +type ClientStreamConn interface { transport.StreamConn ClientAddr() net.Addr } @@ -239,7 +239,7 @@ type StreamListener struct { } // Accept waits for and returns the next incoming connection. -func (l *StreamListener) Accept() (IClientStreamConn, error) { +func (l *StreamListener) Accept() (ClientStreamConn, error) { c, err := l.Listener.Accept() if err != nil { return nil, err @@ -256,7 +256,7 @@ type ProxyListener struct { // Accept waits for the next incoming connection, parses the client IP from the PROXY protocol // header, and adds it to the connection. -func (l *ProxyListener) Accept() (IClientStreamConn, error) { +func (l *ProxyListener) Accept() (ClientStreamConn, error) { conn, err := l.StreamListener.Accept() if err != nil { return nil, err @@ -273,16 +273,16 @@ func (l *ProxyListener) Accept() (IClientStreamConn, error) { return &clientStreamConn{StreamConn: conn, clientAddr: h.SourceAddr}, nil } -type StreamAccepter func() (IClientStreamConn, error) +type StreamAccepter func() (ClientStreamConn, error) func WrapStreamAccepter[T transport.StreamConn](f func() (T, error)) StreamAccepter { - return func() (IClientStreamConn, error) { + return func() (ClientStreamConn, error) { c, err := f() return &clientStreamConn{StreamConn: c}, err } } -type StreamHandler func(ctx context.Context, conn IClientStreamConn) +type StreamHandler func(ctx context.Context, conn ClientStreamConn) // StreamServe repeatedly calls `accept` to obtain connections and `handle` to handle them until // accept() returns [ErrClosed]. When that happens, all connection handlers will be notified @@ -316,7 +316,7 @@ func StreamServe(accept StreamAccepter, handle StreamHandler) { } } -func (h *tcpHandler) Handle(ctx context.Context, clientConn IClientStreamConn) { +func (h *tcpHandler) Handle(ctx context.Context, clientConn ClientStreamConn) { clientInfo, err := ipinfo.GetIPInfoFromAddr(h.m, clientConn.ClientAddr()) if err != nil { logger.Warningf("Failed client info lookup: %v", err) From 00d484b5aa4947243c75c4dd3cbcb31a1e3d48d2 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 17 Jun 2024 17:49:48 -0400 Subject: [PATCH 027/119] Rename `WrapStreamAccepter` to `WrapStreamListener`. --- internal/integration_test/integration_test.go | 8 ++++---- service/tcp.go | 2 +- service/tcp_test.go | 20 +++++++++---------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index d5597dc7..d0267640 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -137,7 +137,7 @@ func TestTCPEcho(t *testing.T) { handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { - service.StreamServe(service.WrapStreamAccepter(proxyListener.AcceptTCP), handler.Handle) + service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -205,7 +205,7 @@ func TestRestrictedAddresses(t *testing.T) { handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout) done := make(chan struct{}) go func() { - service.StreamServe(service.WrapStreamAccepter(proxyListener.AcceptTCP), handler.Handle) + service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -388,7 +388,7 @@ func BenchmarkTCPThroughput(b *testing.B) { handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { - service.StreamServe(service.WrapStreamAccepter(proxyListener.AcceptTCP), handler.Handle) + service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -452,7 +452,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) { handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { - service.StreamServe(service.WrapStreamAccepter(proxyListener.AcceptTCP), handler.Handle) + service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) done <- struct{}{} }() diff --git a/service/tcp.go b/service/tcp.go index 1a9fa8fa..69bd2a5c 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -275,7 +275,7 @@ func (l *ProxyListener) Accept() (ClientStreamConn, error) { type StreamAccepter func() (ClientStreamConn, error) -func WrapStreamAccepter[T transport.StreamConn](f func() (T, error)) StreamAccepter { +func WrapStreamListener[T transport.StreamConn](f func() (T, error)) StreamAccepter { return func() (ClientStreamConn, error) { c, err := f() return &clientStreamConn{StreamConn: c}, err diff --git a/service/tcp_test.go b/service/tcp_test.go index 51948b53..14069756 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -284,7 +284,7 @@ func TestProbeRandom(t *testing.T) { handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) done := make(chan struct{}) go func() { - StreamServe(WrapStreamAccepter(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -362,7 +362,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) { handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { - StreamServe(WrapStreamAccepter(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -397,7 +397,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) { handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { - StreamServe(WrapStreamAccepter(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -433,7 +433,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) { handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { - StreamServe(WrapStreamAccepter(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -475,7 +475,7 @@ func TestProbeServerBytesModified(t *testing.T) { handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) done := make(chan struct{}) go func() { - StreamServe(WrapStreamAccepter(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -528,7 +528,7 @@ func TestReplayDefense(t *testing.T) { done := make(chan struct{}) go func() { - StreamServe(WrapStreamAccepter(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -598,7 +598,7 @@ func TestReverseReplayDefense(t *testing.T) { done := make(chan struct{}) go func() { - StreamServe(WrapStreamAccepter(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -657,7 +657,7 @@ func probeExpectTimeout(t *testing.T, payloadSize int) { done := make(chan struct{}) go func() { - StreamServe(WrapStreamAccepter(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -717,14 +717,14 @@ func TestStreamServeEarlyClose(t *testing.T) { err = tcpListener.Close() require.NoError(t, err) // This should return quickly, without timing out or calling the handler. - StreamServe(WrapStreamAccepter(tcpListener.AcceptTCP), nil) + StreamServe(WrapStreamListener(tcpListener.AcceptTCP), nil) } // Makes sure the TCP listener returns [io.ErrClosed] on Close(). func TestClosedTCPListenerError(t *testing.T) { tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) require.NoError(t, err) - accept := WrapStreamAccepter(tcpListener.AcceptTCP) + accept := WrapStreamListener(tcpListener.AcceptTCP) err = tcpListener.Close() require.NoError(t, err) _, err = accept() From 4ce06f0ee394d1b784a3f6e1a35be3cacd19c3a0 Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 21 Jun 2024 13:37:26 -0400 Subject: [PATCH 028/119] Use `Config` suffix for config types. --- cmd/outline-ss-server/config.go | 20 ++++++------ cmd/outline-ss-server/config_test.go | 48 ++++++++++++++-------------- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index 4df7de83..666d45da 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -25,37 +25,37 @@ import ( "gopkg.in/yaml.v2" ) -type Service struct { - Listeners []Listener - Keys []Key +type ServiceConfig struct { + Listeners []ListenerConfig + Keys []KeyConfig } type ListenerType string const listenerTypeDirect ListenerType = "direct" -type Listener struct { +type ListenerConfig struct { Type ListenerType Address string } -type Key struct { +type KeyConfig struct { ID string Cipher string Secret string } -type LegacyKeyService struct { - Key `yaml:",inline"` - Port int +type LegacyKeyServiceConfig struct { + KeyConfig `yaml:",inline"` + Port int } type Config struct { - Services []Service + Services []ServiceConfig // Deprecated: `keys` exists for backward compatibility. Prefer to configure // using the newer `services` format. - Keys []LegacyKeyService + Keys []LegacyKeyServiceConfig } // readConfig attempts to read a config from a filename and parses it as a [Config]. diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go index fce0d8f8..af42c9fe 100644 --- a/cmd/outline-ss-server/config_test.go +++ b/cmd/outline-ss-server/config_test.go @@ -26,24 +26,24 @@ func TestReadConfig(t *testing.T) { require.NoError(t, err) expected := Config{ - Services: []Service{ - Service{ - Listeners: []Listener{ - Listener{Type: listenerTypeDirect, Address: "tcp://[::]:9000"}, - Listener{Type: listenerTypeDirect, Address: "udp://[::]:9000"}, + Services: []ServiceConfig{ + ServiceConfig{ + Listeners: []ListenerConfig{ + ListenerConfig{Type: listenerTypeDirect, Address: "tcp://[::]:9000"}, + ListenerConfig{Type: listenerTypeDirect, Address: "udp://[::]:9000"}, }, - Keys: []Key{ - Key{"user-0", "chacha20-ietf-poly1305", "Secret0"}, - Key{"user-1", "chacha20-ietf-poly1305", "Secret1"}, + Keys: []KeyConfig{ + KeyConfig{"user-0", "chacha20-ietf-poly1305", "Secret0"}, + KeyConfig{"user-1", "chacha20-ietf-poly1305", "Secret1"}, }, }, - Service{ - Listeners: []Listener{ - Listener{Type: listenerTypeDirect, Address: "tcp://[::]:9001"}, - Listener{Type: listenerTypeDirect, Address: "udp://[::]:9001"}, + ServiceConfig{ + Listeners: []ListenerConfig{ + ListenerConfig{Type: listenerTypeDirect, Address: "tcp://[::]:9001"}, + ListenerConfig{Type: listenerTypeDirect, Address: "udp://[::]:9001"}, }, - Keys: []Key{ - Key{"user-2", "chacha20-ietf-poly1305", "Secret2"}, + Keys: []KeyConfig{ + KeyConfig{"user-2", "chacha20-ietf-poly1305", "Secret2"}, }, }, }, @@ -56,18 +56,18 @@ func TestReadConfigParsesDeprecatedFormat(t *testing.T) { require.NoError(t, err) expected := Config{ - Keys: []LegacyKeyService{ - LegacyKeyService{ - Key: Key{ID: "user-0", Cipher: "chacha20-ietf-poly1305", Secret: "Secret0"}, - Port: 9000, + Keys: []LegacyKeyServiceConfig{ + LegacyKeyServiceConfig{ + KeyConfig: KeyConfig{ID: "user-0", Cipher: "chacha20-ietf-poly1305", Secret: "Secret0"}, + Port: 9000, }, - LegacyKeyService{ - Key: Key{ID: "user-1", Cipher: "chacha20-ietf-poly1305", Secret: "Secret1"}, - Port: 9000, + LegacyKeyServiceConfig{ + KeyConfig: KeyConfig{ID: "user-1", Cipher: "chacha20-ietf-poly1305", Secret: "Secret1"}, + Port: 9000, }, - LegacyKeyService{ - Key: Key{ID: "user-2", Cipher: "chacha20-ietf-poly1305", Secret: "Secret2"}, - Port: 9001, + LegacyKeyServiceConfig{ + KeyConfig: KeyConfig{ID: "user-2", Cipher: "chacha20-ietf-poly1305", Secret: "Secret2"}, + Port: 9001, }, }, } From 866003227e23e63aca544aaf3527fc15a81e5be4 Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 21 Jun 2024 13:39:23 -0400 Subject: [PATCH 029/119] Remove the IP version specifiers from the `newListener` config handling. --- cmd/outline-ss-server/config.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index 666d45da..95c441c9 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -96,7 +96,7 @@ func validateListener(u *url.URL) error { // // Example addresses: // -// tcp4://127.0.0.1:8000 +// tcp://127.0.0.1:8000 // udp://127.0.0.1:9000 func newListener(addr string) (io.Closer, error) { u, err := url.Parse(addr) @@ -105,12 +105,12 @@ func newListener(addr string) (io.Closer, error) { } switch u.Scheme { - case "tcp", "tcp4", "tcp6": + case "tcp": if err := validateListener(u); err != nil { return nil, fmt.Errorf("invalid listener `%s`: %v", u, err) } return net.Listen(u.Scheme, u.Host) - case "udp", "udp4", "udp6": + case "udp": if err := validateListener(u); err != nil { return nil, fmt.Errorf("invalid listener `%s`: %v", u, err) } From 26b9100320c42ea459e336df8b03db6e99b94da5 Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 21 Jun 2024 13:58:53 -0400 Subject: [PATCH 030/119] refactor: remove use of port in proving metric --- cmd/outline-ss-server/main.go | 2 +- cmd/outline-ss-server/metrics.go | 5 ++--- cmd/outline-ss-server/metrics_test.go | 4 ++-- service/tcp.go | 12 ++++++------ service/tcp_test.go | 2 +- 5 files changed, 12 insertions(+), 13 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 47b686a9..5f4b2821 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -88,7 +88,7 @@ func (s *SSServer) startPort(portNum int) error { port := &ssPort{tcpListener: listener, packetConn: packetConn, cipherList: service.NewCipherList()} authFunc := service.NewShadowsocksStreamAuthenticator(port.cipherList, &s.replayCache, s.m) // TODO: Register initial data metrics at zero. - tcpHandler := service.NewTCPHandler(portNum, authFunc, s.m, tcpReadTimeout) + tcpHandler := service.NewTCPHandler(listener.Addr().String(), authFunc, s.m, tcpReadTimeout) packetHandler := service.NewPacketHandler(s.natTimeout, port.cipherList, s.m) s.ports[portNum] = port accept := func() (transport.StreamConn, error) { diff --git a/cmd/outline-ss-server/metrics.go b/cmd/outline-ss-server/metrics.go index 531c16ba..e95ceeb3 100644 --- a/cmd/outline-ss-server/metrics.go +++ b/cmd/outline-ss-server/metrics.go @@ -18,7 +18,6 @@ import ( "fmt" "net" "net/netip" - "strconv" "sync" "time" @@ -357,8 +356,8 @@ func (m *outlineMetrics) RemoveUDPNatEntry(clientAddr net.Addr, accessKey string } } -func (m *outlineMetrics) AddTCPProbe(status, drainResult string, port int, clientProxyBytes int64) { - m.tcpProbes.WithLabelValues(strconv.Itoa(port), status, drainResult).Observe(float64(clientProxyBytes)) +func (m *outlineMetrics) AddTCPProbe(status, drainResult string, listenerId string, clientProxyBytes int64) { + m.tcpProbes.WithLabelValues(listenerId, status, drainResult).Observe(float64(clientProxyBytes)) } func (m *outlineMetrics) AddTCPCipherSearch(accessKeyFound bool, timeToCipher time.Duration) { diff --git a/cmd/outline-ss-server/metrics_test.go b/cmd/outline-ss-server/metrics_test.go index 353520e4..e2605918 100644 --- a/cmd/outline-ss-server/metrics_test.go +++ b/cmd/outline-ss-server/metrics_test.go @@ -68,7 +68,7 @@ func TestMethodsDontPanic(t *testing.T) { ssMetrics.AddUDPPacketFromTarget(ipInfo, "3", "OK", 10, 20) ssMetrics.AddUDPNatEntry(fakeAddr("127.0.0.1:9"), "key-1") ssMetrics.RemoveUDPNatEntry(fakeAddr("127.0.0.1:9"), "key-1") - ssMetrics.AddTCPProbe("ERR_CIPHER", "eof", 443, proxyMetrics.ClientProxy) + ssMetrics.AddTCPProbe("ERR_CIPHER", "eof", "127.0.0.1:443", proxyMetrics.ClientProxy) ssMetrics.AddTCPCipherSearch(true, 10*time.Millisecond) ssMetrics.AddUDPCipherSearch(true, 10*time.Millisecond) } @@ -168,7 +168,7 @@ func BenchmarkProbe(b *testing.B) { data := metrics.ProxyMetrics{} b.ResetTimer() for i := 0; i < b.N; i++ { - ssMetrics.AddTCPProbe(status, drainResult, port, data.ClientProxy) + ssMetrics.AddTCPProbe(status, drainResult, "127.0.0.1:12345", data.ClientProxy) } } diff --git a/service/tcp.go b/service/tcp.go index 85ab9990..2195480b 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -44,7 +44,7 @@ type TCPMetrics interface { AddOpenTCPConnection(clientInfo ipinfo.IPInfo) AddAuthenticatedTCPConnection(clientAddr net.Addr, accessKey string) AddClosedTCPConnection(clientInfo ipinfo.IPInfo, clientAddr net.Addr, accessKey string, status string, data metrics.ProxyMetrics, duration time.Duration) - AddTCPProbe(status, drainResult string, port int, clientProxyBytes int64) + AddTCPProbe(status, drainResult string, listenerId string, clientProxyBytes int64) } func remoteIP(conn net.Conn) netip.Addr { @@ -162,7 +162,7 @@ func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCa } type tcpHandler struct { - port int + listenerId string m TCPMetrics readTimeout time.Duration authenticate StreamAuthenticateFunc @@ -170,9 +170,9 @@ type tcpHandler struct { } // NewTCPService creates a TCPService -func NewTCPHandler(port int, authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration) TCPHandler { +func NewTCPHandler(listenerId string, authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration) TCPHandler { return &tcpHandler{ - port: port, + listenerId: listenerId, m: m, readTimeout: timeout, authenticate: authenticate, @@ -375,7 +375,7 @@ func (h *tcpHandler) absorbProbe(clientConn io.ReadCloser, status string, proxyM _, drainErr := io.Copy(io.Discard, clientConn) // drain socket drainResult := drainErrToString(drainErr) logger.Debugf("Drain error: %v, drain result: %v", drainErr, drainResult) - h.m.AddTCPProbe(status, drainResult, h.port, proxyMetrics.ClientProxy) + h.m.AddTCPProbe(status, drainResult, h.listenerId, proxyMetrics.ClientProxy) } func drainErrToString(drainErr error) string { @@ -404,6 +404,6 @@ func (m *NoOpTCPMetrics) GetIPInfo(net.IP) (ipinfo.IPInfo, error) { func (m *NoOpTCPMetrics) AddOpenTCPConnection(clientInfo ipinfo.IPInfo) {} func (m *NoOpTCPMetrics) AddAuthenticatedTCPConnection(clientAddr net.Addr, accessKey string) { } -func (m *NoOpTCPMetrics) AddTCPProbe(status, drainResult string, port int, clientProxyBytes int64) { +func (m *NoOpTCPMetrics) AddTCPProbe(status, drainResult string, listenerId string, clientProxyBytes int64) { } func (m *NoOpTCPMetrics) AddTCPCipherSearch(accessKeyFound bool, timeToCipher time.Duration) {} diff --git a/service/tcp_test.go b/service/tcp_test.go index 1a70ed67..2f66a7aa 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -239,7 +239,7 @@ func (m *probeTestMetrics) AddOpenTCPConnection(clientInfo ipinfo.IPInfo) { func (m *probeTestMetrics) AddAuthenticatedTCPConnection(clientAddr net.Addr, accessKey string) { } -func (m *probeTestMetrics) AddTCPProbe(status, drainResult string, port int, clientProxyBytes int64) { +func (m *probeTestMetrics) AddTCPProbe(status, drainResult string, listenerId string, clientProxyBytes int64) { m.mu.Lock() m.probeData = append(m.probeData, clientProxyBytes) m.probeStatus = append(m.probeStatus, status) From 1b8e9038625a6cf62e8e630b0866bcb20df8c936 Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 21 Jun 2024 14:07:36 -0400 Subject: [PATCH 031/119] Fix tests. --- cmd/outline-ss-server/metrics_test.go | 1 - internal/integration_test/integration_test.go | 8 ++++---- service/tcp_test.go | 16 ++++++++-------- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/cmd/outline-ss-server/metrics_test.go b/cmd/outline-ss-server/metrics_test.go index e2605918..80e81817 100644 --- a/cmd/outline-ss-server/metrics_test.go +++ b/cmd/outline-ss-server/metrics_test.go @@ -164,7 +164,6 @@ func BenchmarkProbe(b *testing.B) { ssMetrics := newPrometheusOutlineMetrics(nil, prometheus.NewRegistry()) status := "ERR_REPLAY" drainResult := "other" - port := 12345 data := metrics.ProxyMetrics{} b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index 4ca2f120..f98319f4 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -133,7 +133,7 @@ func TestTCPEcho(t *testing.T) { const testTimeout = 200 * time.Millisecond testMetrics := &service.NoOpTCPMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) + handler := service.NewTCPHandler(proxyListener.Addr().String(), authFunc, testMetrics, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { @@ -202,7 +202,7 @@ func TestRestrictedAddresses(t *testing.T) { const testTimeout = 200 * time.Millisecond testMetrics := &statusMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) + handler := service.NewTCPHandler(proxyListener.Addr().String(), authFunc, testMetrics, testTimeout) done := make(chan struct{}) go func() { service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) @@ -384,7 +384,7 @@ func BenchmarkTCPThroughput(b *testing.B) { const testTimeout = 200 * time.Millisecond testMetrics := &service.NoOpTCPMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) + handler := service.NewTCPHandler(proxyListener.Addr().String(), authFunc, testMetrics, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { @@ -448,7 +448,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) { const testTimeout = 200 * time.Millisecond testMetrics := &service.NoOpTCPMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := service.NewTCPHandler(proxyListener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) + handler := service.NewTCPHandler(proxyListener.Addr().String(), authFunc, testMetrics, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { diff --git a/service/tcp_test.go b/service/tcp_test.go index 2f66a7aa..5c3bc9df 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -281,7 +281,7 @@ func TestProbeRandom(t *testing.T) { require.NoError(t, err, "MakeTestCiphers failed: %v", err) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond) + handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, 200*time.Millisecond) done := make(chan struct{}) go func() { StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) @@ -358,7 +358,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond) + handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { @@ -393,7 +393,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond) + handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { @@ -429,7 +429,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond) + handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { @@ -472,7 +472,7 @@ func TestProbeServerBytesModified(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, 200*time.Millisecond) + handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, 200*time.Millisecond) done := make(chan struct{}) go func() { StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) @@ -503,7 +503,7 @@ func TestReplayDefense(t *testing.T) { testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) + handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, testTimeout) snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) cipherEntry := snapshot[0].Value.(*CipherEntry) cipher := cipherEntry.CryptoKey @@ -582,7 +582,7 @@ func TestReverseReplayDefense(t *testing.T) { testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) + handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, testTimeout) snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) cipherEntry := snapshot[0].Value.(*CipherEntry) cipher := cipherEntry.CryptoKey @@ -653,7 +653,7 @@ func probeExpectTimeout(t *testing.T, payloadSize int) { require.NoError(t, err, "MakeTestCiphers failed: %v", err) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) + handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, testTimeout) done := make(chan struct{}) go func() { From 4216ce3233825deae5f6079b51e0e91a175abbb8 Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 21 Jun 2024 14:47:50 -0400 Subject: [PATCH 032/119] Add a TODO comment to allow short-form direct listener config. --- cmd/outline-ss-server/config_example.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cmd/outline-ss-server/config_example.yml b/cmd/outline-ss-server/config_example.yml index 66009c10..f7fc2e71 100644 --- a/cmd/outline-ss-server/config_example.yml +++ b/cmd/outline-ss-server/config_example.yml @@ -1,5 +1,7 @@ services: - listeners: + # TODO(sbruens): Allow a string-based listener config, as a convenient short-form + # to create a direct listener, e.g. `- tcp://[::]:9000`. - type: direct address: "tcp://[::]:9000" - type: direct From 35c828db2fba3a5863075e674a3a404772cf7842 Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 21 Jun 2024 15:48:18 -0400 Subject: [PATCH 033/119] Make legacy key config name consistent with type. --- cmd/outline-ss-server/main.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 066299aa..f7b45dc1 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -138,14 +138,14 @@ func (s *SSServer) loadConfig(filename string) error { addrChanges := make(map[string]int) addrCiphers := make(map[string]*list.List) // Values are *List of *CipherEntry. - for _, legacyKeyConfig := range config.Keys { - cryptoKey, err := shadowsocks.NewEncryptionKey(legacyKeyConfig.Cipher, legacyKeyConfig.Secret) + for _, legacyKeyServiceConfig := range config.Keys { + cryptoKey, err := shadowsocks.NewEncryptionKey(legacyKeyServiceConfig.Cipher, legacyKeyServiceConfig.Secret) if err != nil { - return fmt.Errorf("failed to create encyption key for key %v: %w", legacyKeyConfig.ID, err) + return fmt.Errorf("failed to create encyption key for key %v: %w", legacyKeyServiceConfig.ID, err) } - entry := service.MakeCipherEntry(legacyKeyConfig.ID, cryptoKey, legacyKeyConfig.Secret) + entry := service.MakeCipherEntry(legacyKeyServiceConfig.ID, cryptoKey, legacyKeyServiceConfig.Secret) for _, ln := range []string{"tcp", "udp"} { - addr := fmt.Sprintf("%s://[::]:%d", ln, legacyKeyConfig.Port) + addr := fmt.Sprintf("%s://[::]:%d", ln, legacyKeyServiceConfig.Port) addrChanges[addr] = 1 ciphers, ok := addrCiphers[addr] if !ok { From 1322f2d124c5398d65db626baa6596244d5554f5 Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 21 Jun 2024 15:59:38 -0400 Subject: [PATCH 034/119] Move config validation out of the `loadConfig` function. --- cmd/outline-ss-server/config.go | 17 +++++++++ cmd/outline-ss-server/config_test.go | 54 ++++++++++++++++++++++++++++ cmd/outline-ss-server/main.go | 20 ++++------- 3 files changed, 77 insertions(+), 14 deletions(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index 95c441c9..73550454 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -58,6 +58,23 @@ type Config struct { Keys []LegacyKeyServiceConfig } +// Validate checks that the config is valid. +func (c *Config) Validate() error { + for _, serviceConfig := range c.Services { + if serviceConfig.Listeners == nil || serviceConfig.Keys == nil { + return errors.New("must specify at least 1 listener and 1 key per service") + } + + for _, listener := range serviceConfig.Listeners { + // TODO: Support more listener types. + if listener.Type != listenerTypeDirect { + return fmt.Errorf("unsupported listener type: %s", listener.Type) + } + } + } + return nil +} + // readConfig attempts to read a config from a filename and parses it as a [Config]. func readConfig(filename string) (*Config, error) { config := Config{} diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go index af42c9fe..d8628f65 100644 --- a/cmd/outline-ss-server/config_test.go +++ b/cmd/outline-ss-server/config_test.go @@ -21,6 +21,60 @@ import ( "github.com/stretchr/testify/require" ) +func TestValidateConfigFails(t *testing.T) { + tests := []struct { + name string + cfg *Config + }{ + { + name: "WithoutListeners", + cfg: &Config{ + Services: []ServiceConfig{ + ServiceConfig{ + Keys: []KeyConfig{ + KeyConfig{"user-0", "chacha20-ietf-poly1305", "Secret0"}, + }, + }, + }, + }, + }, + { + name: "WithoutKeys", + cfg: &Config{ + Services: []ServiceConfig{ + ServiceConfig{ + Listeners: []ListenerConfig{ + ListenerConfig{Type: listenerTypeDirect, Address: "tcp://[::]:9000"}, + }, + }, + }, + }, + }, + { + name: "WithUnknownListenerType", + cfg: &Config{ + Services: []ServiceConfig{ + ServiceConfig{ + Listeners: []ListenerConfig{ + ListenerConfig{Type: "foo", Address: "tcp://[::]:9000"}, + }, + Keys: []KeyConfig{ + KeyConfig{"user-0", "chacha20-ietf-poly1305", "Secret0"}, + }, + }, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.cfg.Validate() + require.Error(t, err) + }) + } +} + func TestReadConfig(t *testing.T) { config, err := readConfig("./config_example.yml") diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index f7b45dc1..b457b990 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -133,7 +133,9 @@ func (s *SSServer) loadConfig(filename string) error { if err != nil { return fmt.Errorf("failed to load config (%v): %w", filename, err) } - + if err := config.Validate(); err != nil { + return err + } uniqueCiphers := 0 addrChanges := make(map[string]int) addrCiphers := make(map[string]*list.List) // Values are *List of *CipherEntry. @@ -158,10 +160,6 @@ func (s *SSServer) loadConfig(filename string) error { } for _, serviceConfig := range config.Services { - if serviceConfig.Listeners == nil || serviceConfig.Keys == nil { - return fmt.Errorf("must specify at least 1 listener and 1 key per service") - } - ciphers := list.New() type cipherKey struct { cipher string @@ -186,14 +184,8 @@ func (s *SSServer) loadConfig(filename string) error { uniqueCiphers += ciphers.Len() for _, listener := range serviceConfig.Listeners { - switch t := listener.Type; t { - // TODO: Support more listener types. - case listenerTypeDirect: - addrChanges[listener.Address] = 1 - addrCiphers[listener.Address] = ciphers - default: - return fmt.Errorf("unsupported listener type: %s", t) - } + addrChanges[listener.Address] = 1 + addrCiphers[listener.Address] = ciphers } } for listener := range s.listeners { @@ -245,7 +237,7 @@ func RunSSServer(filename string, natTimeout time.Duration, sm *outlineMetrics, } err := server.loadConfig(filename) if err != nil { - return nil, fmt.Errorf("failed configure server: %w", err) + return nil, fmt.Errorf("failed to configure server: %w", err) } sigHup := make(chan os.Signal, 1) signal.Notify(sigHup, syscall.SIGHUP) From adc11f2b6b3711ea333cec1f65b757e675c8b3be Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 21 Jun 2024 16:12:03 -0400 Subject: [PATCH 035/119] Remove unused port from bad merge. --- service/tcp.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/service/tcp.go b/service/tcp.go index 297ed478..2195480b 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -342,8 +342,7 @@ func (h *tcpHandler) handleConnection(ctx context.Context, outerConn transport.S id, innerConn, authErr := h.authenticate(outerConn) if authErr != nil { // Drain to protect against probing attacks. - port := outerConn.LocalAddr().(*net.TCPAddr).Port - h.absorbProbe(outerConn, port, authErr.Status, proxyMetrics) + h.absorbProbe(outerConn, authErr.Status, proxyMetrics) return id, authErr } h.m.AddAuthenticatedTCPConnection(outerConn.RemoteAddr(), id) @@ -371,7 +370,7 @@ func (h *tcpHandler) handleConnection(ctx context.Context, outerConn transport.S // Keep the connection open until we hit the authentication deadline to protect against probing attacks // `proxyMetrics` is a pointer because its value is being mutated by `clientConn`. -func (h *tcpHandler) absorbProbe(clientConn io.ReadCloser, port int, status string, proxyMetrics *metrics.ProxyMetrics) { +func (h *tcpHandler) absorbProbe(clientConn io.ReadCloser, status string, proxyMetrics *metrics.ProxyMetrics) { // This line updates proxyMetrics.ClientProxy before it's used in AddTCPProbe. _, drainErr := io.Copy(io.Discard, clientConn) // drain socket drainResult := drainErrToString(drainErr) From 3084dfd2be13d8256607644121b193f5a50fa9a2 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 24 Jun 2024 10:57:45 -0400 Subject: [PATCH 036/119] Add comment describing keys. --- cmd/outline-ss-server/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index b457b990..ce377e79 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -69,7 +69,7 @@ type SSServer struct { natTimeout time.Duration m *outlineMetrics replayCache service.ReplayCache - listeners map[string]*ssListener + listeners map[string]*ssListener // Keys are addresses, e.g. `tcp://[::]:9000` } func (s *SSServer) serve(addr string, listener io.Closer, cipherList service.CipherList) error { From 7e5aae52b16d149057a11014b474d6a803ad1513 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 24 Jun 2024 12:02:21 -0400 Subject: [PATCH 037/119] Move validation of listeners to config's `Validate()` function. --- cmd/outline-ss-server/config.go | 29 +++++++++++++++++----------- cmd/outline-ss-server/config_test.go | 15 ++++++++++++++ cmd/outline-ss-server/main.go | 2 +- 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index 73550454..ea96d0e8 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -65,10 +65,23 @@ func (c *Config) Validate() error { return errors.New("must specify at least 1 listener and 1 key per service") } - for _, listener := range serviceConfig.Listeners { + for _, listenerConfig := range serviceConfig.Listeners { // TODO: Support more listener types. - if listener.Type != listenerTypeDirect { - return fmt.Errorf("unsupported listener type: %s", listener.Type) + if listenerConfig.Type != listenerTypeDirect { + return fmt.Errorf("unsupported listener type: %s", listenerConfig.Type) + } + + u, err := url.Parse(listenerConfig.Address) + if err != nil { + return err + } + switch u.Scheme { + case "tcp", "udp": + if err := validateListenerAddress(u); err != nil { + return fmt.Errorf("invalid listener address `%s`: %v", u, err) + } + default: + return fmt.Errorf("unsupported protocol: %s", u.Scheme) } } } @@ -89,8 +102,8 @@ func readConfig(filename string) (*Config, error) { return &config, nil } -// validateListener asserts that a listener URI conforms to the expected format. -func validateListener(u *url.URL) error { +// validateListenerAddress asserts that a listener URI conforms to the expected format. +func validateListenerAddress(u *url.URL) error { if u.Opaque != "" { return errors.New("URI cannot have an opaque part") } @@ -123,14 +136,8 @@ func newListener(addr string) (io.Closer, error) { switch u.Scheme { case "tcp": - if err := validateListener(u); err != nil { - return nil, fmt.Errorf("invalid listener `%s`: %v", u, err) - } return net.Listen(u.Scheme, u.Host) case "udp": - if err := validateListener(u); err != nil { - return nil, fmt.Errorf("invalid listener `%s`: %v", u, err) - } return net.ListenPacket(u.Scheme, u.Host) default: return nil, fmt.Errorf("unsupported protocol: %s", u.Scheme) diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go index d8628f65..72c11083 100644 --- a/cmd/outline-ss-server/config_test.go +++ b/cmd/outline-ss-server/config_test.go @@ -65,6 +65,21 @@ func TestValidateConfigFails(t *testing.T) { }, }, }, + { + name: "WithInvalidListenerAddress", + cfg: &Config{ + Services: []ServiceConfig{ + ServiceConfig{ + Listeners: []ListenerConfig{ + ListenerConfig{Type: listenerTypeDirect, Address: "tcp://[::]:9000/path"}, + }, + Keys: []KeyConfig{ + KeyConfig{"user-0", "chacha20-ietf-poly1305", "Secret0"}, + }, + }, + }, + }, + }, } for _, tc := range tests { diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index ce377e79..5e38804d 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -134,7 +134,7 @@ func (s *SSServer) loadConfig(filename string) error { return fmt.Errorf("failed to load config (%v): %w", filename, err) } if err := config.Validate(); err != nil { - return err + return fmt.Errorf("failed to validate config: %w", err) } uniqueCiphers := 0 addrChanges := make(map[string]int) From b136c79de18d163848c75d39a0da5d90b98b4689 Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 25 Jun 2024 12:03:36 -0400 Subject: [PATCH 038/119] Introduce a `NetworkAdd` to centralize parsing and creation of listeners. --- cmd/outline-ss-server/config.go | 58 +------------- cmd/outline-ss-server/config_example.yml | 10 +-- cmd/outline-ss-server/config_test.go | 14 ++-- cmd/outline-ss-server/listeners.go | 99 ++++++++++++++++++++++++ cmd/outline-ss-server/main.go | 10 ++- 5 files changed, 122 insertions(+), 69 deletions(-) create mode 100644 cmd/outline-ss-server/listeners.go diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index ea96d0e8..b970c120 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -17,9 +17,6 @@ package main import ( "errors" "fmt" - "io" - "net" - "net/url" "os" "gopkg.in/yaml.v2" @@ -71,17 +68,12 @@ func (c *Config) Validate() error { return fmt.Errorf("unsupported listener type: %s", listenerConfig.Type) } - u, err := url.Parse(listenerConfig.Address) + network, _, _, err := SplitNetworkAddr(listenerConfig.Address) if err != nil { - return err + return fmt.Errorf("invalid listener address `%s`: %v", listenerConfig.Address, err) } - switch u.Scheme { - case "tcp", "udp": - if err := validateListenerAddress(u); err != nil { - return fmt.Errorf("invalid listener address `%s`: %v", u, err) - } - default: - return fmt.Errorf("unsupported protocol: %s", u.Scheme) + if network != "tcp" && network != "udp" { + return fmt.Errorf("unsupported network: %s", network) } } } @@ -101,45 +93,3 @@ func readConfig(filename string) (*Config, error) { } return &config, nil } - -// validateListenerAddress asserts that a listener URI conforms to the expected format. -func validateListenerAddress(u *url.URL) error { - if u.Opaque != "" { - return errors.New("URI cannot have an opaque part") - } - if u.User != nil { - return errors.New("URI cannot have an userdata part") - } - if u.RawQuery != "" || u.ForceQuery { - return errors.New("URI cannot have a query part") - } - if u.Fragment != "" { - return errors.New("URI cannot have a fragement") - } - if u.Path != "" && u.Path != "/" { - return errors.New("URI path not allowed") - } - return nil -} - -// newListener creates a new listener from a URL-style address specification. -// -// Example addresses: -// -// tcp://127.0.0.1:8000 -// udp://127.0.0.1:9000 -func newListener(addr string) (io.Closer, error) { - u, err := url.Parse(addr) - if err != nil { - return nil, err - } - - switch u.Scheme { - case "tcp": - return net.Listen(u.Scheme, u.Host) - case "udp": - return net.ListenPacket(u.Scheme, u.Host) - default: - return nil, fmt.Errorf("unsupported protocol: %s", u.Scheme) - } -} diff --git a/cmd/outline-ss-server/config_example.yml b/cmd/outline-ss-server/config_example.yml index f7fc2e71..bbfd265f 100644 --- a/cmd/outline-ss-server/config_example.yml +++ b/cmd/outline-ss-server/config_example.yml @@ -1,11 +1,11 @@ services: - listeners: # TODO(sbruens): Allow a string-based listener config, as a convenient short-form - # to create a direct listener, e.g. `- tcp://[::]:9000`. + # to create a direct listener, e.g. `- tcp/[::]:9000`. - type: direct - address: "tcp://[::]:9000" + address: "tcp/[::]:9000" - type: direct - address: "udp://[::]:9000" + address: "udp/[::]:9000" keys: - id: user-0 cipher: chacha20-ietf-poly1305 @@ -16,9 +16,9 @@ services: - listeners: - type: direct - address: "tcp://[::]:9001" + address: "tcp/[::]:9001" - type: direct - address: "udp://[::]:9001" + address: "udp/[::]:9001" keys: - id: user-2 cipher: chacha20-ietf-poly1305 diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go index 72c11083..3e76fa57 100644 --- a/cmd/outline-ss-server/config_test.go +++ b/cmd/outline-ss-server/config_test.go @@ -44,7 +44,7 @@ func TestValidateConfigFails(t *testing.T) { Services: []ServiceConfig{ ServiceConfig{ Listeners: []ListenerConfig{ - ListenerConfig{Type: listenerTypeDirect, Address: "tcp://[::]:9000"}, + ListenerConfig{Type: listenerTypeDirect, Address: "tcp/[::]:9000"}, }, }, }, @@ -56,7 +56,7 @@ func TestValidateConfigFails(t *testing.T) { Services: []ServiceConfig{ ServiceConfig{ Listeners: []ListenerConfig{ - ListenerConfig{Type: "foo", Address: "tcp://[::]:9000"}, + ListenerConfig{Type: "foo", Address: "tcp/[::]:9000"}, }, Keys: []KeyConfig{ KeyConfig{"user-0", "chacha20-ietf-poly1305", "Secret0"}, @@ -71,7 +71,7 @@ func TestValidateConfigFails(t *testing.T) { Services: []ServiceConfig{ ServiceConfig{ Listeners: []ListenerConfig{ - ListenerConfig{Type: listenerTypeDirect, Address: "tcp://[::]:9000/path"}, + ListenerConfig{Type: listenerTypeDirect, Address: "tcp//[::]:9000"}, }, Keys: []KeyConfig{ KeyConfig{"user-0", "chacha20-ietf-poly1305", "Secret0"}, @@ -98,8 +98,8 @@ func TestReadConfig(t *testing.T) { Services: []ServiceConfig{ ServiceConfig{ Listeners: []ListenerConfig{ - ListenerConfig{Type: listenerTypeDirect, Address: "tcp://[::]:9000"}, - ListenerConfig{Type: listenerTypeDirect, Address: "udp://[::]:9000"}, + ListenerConfig{Type: listenerTypeDirect, Address: "tcp/[::]:9000"}, + ListenerConfig{Type: listenerTypeDirect, Address: "udp/[::]:9000"}, }, Keys: []KeyConfig{ KeyConfig{"user-0", "chacha20-ietf-poly1305", "Secret0"}, @@ -108,8 +108,8 @@ func TestReadConfig(t *testing.T) { }, ServiceConfig{ Listeners: []ListenerConfig{ - ListenerConfig{Type: listenerTypeDirect, Address: "tcp://[::]:9001"}, - ListenerConfig{Type: listenerTypeDirect, Address: "udp://[::]:9001"}, + ListenerConfig{Type: listenerTypeDirect, Address: "tcp/[::]:9001"}, + ListenerConfig{Type: listenerTypeDirect, Address: "udp/[::]:9001"}, }, Keys: []KeyConfig{ KeyConfig{"user-2", "chacha20-ietf-poly1305", "Secret2"}, diff --git a/cmd/outline-ss-server/listeners.go b/cmd/outline-ss-server/listeners.go new file mode 100644 index 00000000..09f24a10 --- /dev/null +++ b/cmd/outline-ss-server/listeners.go @@ -0,0 +1,99 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" +) + +type NetworkAddr struct { + network string + Host string + Port uint +} + +// String returns a human-readable representation of the [NetworkAddr]. +func (na *NetworkAddr) Network() string { + return na.network +} + +// String returns a human-readable representation of the [NetworkAddr]. +func (na *NetworkAddr) String() string { + return na.JoinHostPort() +} + +// JoinHostPort is a convenience wrapper around [net.JoinHostPort]. +func (na *NetworkAddr) JoinHostPort() string { + return net.JoinHostPort(na.Host, strconv.Itoa(int(na.Port))) +} + +// Listen creates a new listener for the [NetworkAddr]. +func (na *NetworkAddr) Listen() (io.Closer, error) { + address := na.JoinHostPort() + + switch na.network { + + case "tcp": + return net.Listen(na.network, address) + case "udp": + return net.ListenPacket(na.network, address) + default: + return nil, fmt.Errorf("unsupported network: %s", na.network) + } +} + +// ParseNetworkAddr parses an address into a [NetworkAddr]. The input +// string is expected to be of the form "network/host:port" where any part is +// optional. +// +// Examples: +// +// tcp/127.0.0.1:8000 +// udp/127.0.0.1:9000 +func ParseNetworkAddr(addr string) (NetworkAddr, error) { + var host, port string + network, host, port, err := SplitNetworkAddr(addr) + if err != nil { + return NetworkAddr{}, err + } + if network == "" { + return NetworkAddr{}, errors.New("missing network") + } + p, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return NetworkAddr{}, fmt.Errorf("invalid port: %v", err) + } + return NetworkAddr{ + network: network, + Host: host, + Port: uint(p), + }, nil +} + +// SplitNetworkAddr splits a into its network, host, and port components. +func SplitNetworkAddr(a string) (network, host, port string, err error) { + beforeSlash, afterSlash, slashFound := strings.Cut(a, "/") + if slashFound { + network = strings.ToLower(strings.TrimSpace(beforeSlash)) + a = afterSlash + } + host, port, err = net.SplitHostPort(a) + return +} diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 5e38804d..f6b0f8cf 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -69,7 +69,7 @@ type SSServer struct { natTimeout time.Duration m *outlineMetrics replayCache service.ReplayCache - listeners map[string]*ssListener // Keys are addresses, e.g. `tcp://[::]:9000` + listeners map[string]*ssListener // Keys are addresses, e.g. `tcp/[::]:9000` } func (s *SSServer) serve(addr string, listener io.Closer, cipherList service.CipherList) error { @@ -98,7 +98,11 @@ func (s *SSServer) serve(addr string, listener io.Closer, cipherList service.Cip } func (s *SSServer) start(addr string, cipherList service.CipherList) (io.Closer, error) { - listener, err := newListener(addr) + listenAddr, err := ParseNetworkAddr(addr) + if err != nil { + return nil, fmt.Errorf("error parsing listener address `%s`: %v", addr, err) + } + listener, err := listenAddr.Listen() if err != nil { //lint:ignore ST1005 Shadowsocks is capitalized. return nil, fmt.Errorf("Shadowsocks service failed to start on address %v: %w", addr, err) @@ -147,7 +151,7 @@ func (s *SSServer) loadConfig(filename string) error { } entry := service.MakeCipherEntry(legacyKeyServiceConfig.ID, cryptoKey, legacyKeyServiceConfig.Secret) for _, ln := range []string{"tcp", "udp"} { - addr := fmt.Sprintf("%s://[::]:%d", ln, legacyKeyServiceConfig.Port) + addr := fmt.Sprintf("%s/[::]:%d", ln, legacyKeyServiceConfig.Port) addrChanges[addr] = 1 ciphers, ok := addrCiphers[addr] if !ok { From 4bf9c272dd792e9638d37ebd6718f70b1d3b9bbf Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 24 Jun 2024 15:17:57 -0400 Subject: [PATCH 039/119] Use `net.ListenConfig` to listen. --- cmd/outline-ss-server/listeners.go | 7 ++++--- cmd/outline-ss-server/main.go | 12 ++++-------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/cmd/outline-ss-server/listeners.go b/cmd/outline-ss-server/listeners.go index 09f24a10..a5c9c8b0 100644 --- a/cmd/outline-ss-server/listeners.go +++ b/cmd/outline-ss-server/listeners.go @@ -15,6 +15,7 @@ package main import ( + "context" "errors" "fmt" "io" @@ -45,15 +46,15 @@ func (na *NetworkAddr) JoinHostPort() string { } // Listen creates a new listener for the [NetworkAddr]. -func (na *NetworkAddr) Listen() (io.Closer, error) { +func (na *NetworkAddr) Listen(ctx context.Context, config net.ListenConfig) (io.Closer, error) { address := na.JoinHostPort() switch na.network { case "tcp": - return net.Listen(na.network, address) + return config.Listen(ctx, na.network, address) case "udp": - return net.ListenPacket(na.network, address) + return config.ListenPacket(ctx, na.network, address) default: return nil, fmt.Errorf("unsupported network: %s", na.network) } diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index f6b0f8cf..ca019810 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -16,6 +16,7 @@ package main import ( "container/list" + "context" "flag" "fmt" "io" @@ -79,13 +80,8 @@ func (s *SSServer) serve(addr string, listener io.Closer, cipherList service.Cip // TODO: Register initial data metrics at zero. tcpHandler := service.NewTCPHandler(addr, authFunc, s.m, tcpReadTimeout) accept := func() (transport.StreamConn, error) { - conn, err := ln.Accept() - if err != nil { - return nil, err - } - c := conn.(*net.TCPConn) - c.SetKeepAlive(true) - return c, err + c, err := ln.Accept() + return c.(transport.StreamConn), err } go service.StreamServe(accept, tcpHandler.Handle) case net.PacketConn: @@ -102,7 +98,7 @@ func (s *SSServer) start(addr string, cipherList service.CipherList) (io.Closer, if err != nil { return nil, fmt.Errorf("error parsing listener address `%s`: %v", addr, err) } - listener, err := listenAddr.Listen() + listener, err := listenAddr.Listen(context.TODO(), net.ListenConfig{KeepAlive: 0}) if err != nil { //lint:ignore ST1005 Shadowsocks is capitalized. return nil, fmt.Errorf("Shadowsocks service failed to start on address %v: %w", addr, err) From b7bb65bf605a4a9e24c51d7174a5f7a318b0d96e Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 25 Jun 2024 15:14:11 -0400 Subject: [PATCH 040/119] Simplify how we create new listeners. This does not yet deal with reused sockets. --- cmd/outline-ss-server/main.go | 128 ++++++++++++++-------------------- 1 file changed, 52 insertions(+), 76 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index ca019810..6d48a628 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -61,24 +61,19 @@ func init() { logger = logging.MustGetLogger("") } -type ssListener struct { - io.Closer - cipherList service.CipherList -} - type SSServer struct { natTimeout time.Duration m *outlineMetrics replayCache service.ReplayCache - listeners map[string]*ssListener // Keys are addresses, e.g. `tcp/[::]:9000` + listeners []io.Closer } -func (s *SSServer) serve(addr string, listener io.Closer, cipherList service.CipherList) error { +func (s *SSServer) serve(addr NetworkAddr, listener io.Closer, cipherList service.CipherList) error { switch ln := listener.(type) { case net.Listener: authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &s.replayCache, s.m) // TODO: Register initial data metrics at zero. - tcpHandler := service.NewTCPHandler(addr, authFunc, s.m, tcpReadTimeout) + tcpHandler := service.NewTCPHandler(addr.String(), authFunc, s.m, tcpReadTimeout) accept := func() (transport.StreamConn, error) { c, err := ln.Accept() return c.(transport.StreamConn), err @@ -93,41 +88,6 @@ func (s *SSServer) serve(addr string, listener io.Closer, cipherList service.Cip return nil } -func (s *SSServer) start(addr string, cipherList service.CipherList) (io.Closer, error) { - listenAddr, err := ParseNetworkAddr(addr) - if err != nil { - return nil, fmt.Errorf("error parsing listener address `%s`: %v", addr, err) - } - listener, err := listenAddr.Listen(context.TODO(), net.ListenConfig{KeepAlive: 0}) - if err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return nil, fmt.Errorf("Shadowsocks service failed to start on address %v: %w", addr, err) - } - logger.Infof("Shadowsocks service listening on %v", addr) - - err = s.serve(addr, listener, cipherList) - if err != nil { - return nil, fmt.Errorf("failed to serve on listener %v: %w", listener, err) - } - - return listener, nil -} - -func (s *SSServer) remove(addr string) error { - listener, ok := s.listeners[addr] - if !ok { - return fmt.Errorf("address %v doesn't exist", addr) - } - err := listener.Close() - delete(s.listeners, addr) - if err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks service on address %v failed to stop: %w", addr, err) - } - logger.Infof("Shadowsocks service on address %v stopped", addr) - return nil -} - func (s *SSServer) loadConfig(filename string) error { config, err := readConfig(filename) if err != nil { @@ -137,8 +97,9 @@ func (s *SSServer) loadConfig(filename string) error { return fmt.Errorf("failed to validate config: %w", err) } uniqueCiphers := 0 - addrChanges := make(map[string]int) - addrCiphers := make(map[string]*list.List) // Values are *List of *CipherEntry. + // TODO: Clone existing listeners so we can close them after starting the new ones. + addrs := make([]NetworkAddr, 0) + addrCiphers := make(map[NetworkAddr]*list.List) // Values are *List of *CipherEntry. for _, legacyKeyServiceConfig := range config.Keys { cryptoKey, err := shadowsocks.NewEncryptionKey(legacyKeyServiceConfig.Cipher, legacyKeyServiceConfig.Secret) @@ -146,9 +107,13 @@ func (s *SSServer) loadConfig(filename string) error { return fmt.Errorf("failed to create encyption key for key %v: %w", legacyKeyServiceConfig.ID, err) } entry := service.MakeCipherEntry(legacyKeyServiceConfig.ID, cryptoKey, legacyKeyServiceConfig.Secret) - for _, ln := range []string{"tcp", "udp"} { - addr := fmt.Sprintf("%s/[::]:%d", ln, legacyKeyServiceConfig.Port) - addrChanges[addr] = 1 + for _, network := range []string{"tcp", "udp"} { + addr := NetworkAddr{ + network: network, + Host: "::", + Port: uint(legacyKeyServiceConfig.Port), + } + addrs = append(addrs, addr) ciphers, ok := addrCiphers[addr] if !ok { ciphers = list.New() @@ -168,8 +133,7 @@ func (s *SSServer) loadConfig(filename string) error { existingCiphers := make(map[cipherKey]bool) for _, keyConfig := range serviceConfig.Keys { key := cipherKey{keyConfig.Cipher, keyConfig.Secret} - _, ok := existingCiphers[key] - if ok { + if _, exists := existingCiphers[key]; exists { logger.Debugf("encryption key already exists for ID=`%v`. Skipping.", keyConfig.ID) continue } @@ -184,45 +148,57 @@ func (s *SSServer) loadConfig(filename string) error { uniqueCiphers += ciphers.Len() for _, listener := range serviceConfig.Listeners { - addrChanges[listener.Address] = 1 - addrCiphers[listener.Address] = ciphers - } - } - for listener := range s.listeners { - addrChanges[listener] = addrChanges[listener] - 1 - } - for addr, count := range addrChanges { - if count == -1 { - if err := s.remove(addr); err != nil { - return fmt.Errorf("failed to remove address %v: %w", addr, err) - } - } else if count == +1 { - cipherList := service.NewCipherList() - listener, err := s.start(addr, cipherList) + addr, err := ParseNetworkAddr(listener.Address) if err != nil { - return err + return fmt.Errorf("error parsing listener address `%s`: %v", listener.Address, err) } - s.listeners[addr] = &ssListener{Closer: listener, cipherList: cipherList} + addrs = append(addrs, addr) + addrCiphers[addr] = ciphers } } - for addr, ciphers := range addrCiphers { - listener, ok := s.listeners[addr] + + for _, addr := range addrs { + cipherList := service.NewCipherList() + ciphers, ok := addrCiphers[addr] if !ok { - return fmt.Errorf("unable to find listener for address: %v", addr) + return fmt.Errorf("unable to find ciphers for address: %v", addr) + } + cipherList.Update(ciphers) + + listener, err := addr.Listen(context.TODO(), net.ListenConfig{KeepAlive: 0}) + if err != nil { + //lint:ignore ST1005 Shadowsocks is capitalized. + return fmt.Errorf("Shadowsocks %s service failed to start on address %s: %w", addr.Network(), addr.String(), err) + } + logger.Infof("Shadowsocks %s service listening on %s", addr.Network(), addr.String()) + s.listeners = append(s.listeners, listener) + + if err = s.serve(addr, listener, cipherList); err != nil { + return fmt.Errorf("failed to serve on listener %v: %w", listener, err) } - listener.cipherList.Update(ciphers) } logger.Infof("Loaded %v access keys over %v listeners", uniqueCiphers, len(s.listeners)) s.m.SetNumAccessKeys(uniqueCiphers, len(s.listeners)) return nil } -// Stop serving on all ports. +// Stop serving on all listeners. func (s *SSServer) Stop() error { - for addr := range s.listeners { - if err := s.remove(addr); err != nil { - return err + for _, listener := range s.listeners { + var addr net.Addr + switch ln := listener.(type) { + case net.Listener: + addr = ln.Addr() + case net.PacketConn: + addr = ln.LocalAddr() + default: + return fmt.Errorf("unknown listener type: %v", ln) + } + if err := listener.Close(); err != nil { + //lint:ignore ST1005 Shadowsocks is capitalized. + return fmt.Errorf("Shadowsocks service on address %s %s failed to stop: %w", addr.Network(), addr.String(), err) } + logger.Infof("Shadowsocks service on address %s %s stopped", addr.Network(), addr.String()) } return nil } @@ -233,7 +209,7 @@ func RunSSServer(filename string, natTimeout time.Duration, sm *outlineMetrics, natTimeout: natTimeout, m: sm, replayCache: service.NewReplayCache(replayHistory), - listeners: make(map[string]*ssListener), + listeners: nil, } err := server.loadConfig(filename) if err != nil { From af3ca3155f8335159c7e11dbe167928675a01fcb Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 28 Jun 2024 16:47:33 -0400 Subject: [PATCH 041/119] Do not use `io.Closer`. --- cmd/outline-ss-server/listeners.go | 9 ++++++-- cmd/outline-ss-server/main.go | 36 ++++++++++++++++++------------ service/tcp.go | 2 +- 3 files changed, 30 insertions(+), 17 deletions(-) diff --git a/cmd/outline-ss-server/listeners.go b/cmd/outline-ss-server/listeners.go index a5c9c8b0..3457b794 100644 --- a/cmd/outline-ss-server/listeners.go +++ b/cmd/outline-ss-server/listeners.go @@ -18,7 +18,6 @@ import ( "context" "errors" "fmt" - "io" "net" "strconv" "strings" @@ -45,8 +44,14 @@ func (na *NetworkAddr) JoinHostPort() string { return net.JoinHostPort(na.Host, strconv.Itoa(int(na.Port))) } +// Key returns a representative string useful to retrieve this entity from a +// map. This is used to uniquely identify reusable listeners. +func (na *NetworkAddr) Key() string { + return na.network + "/" + na.JoinHostPort() +} + // Listen creates a new listener for the [NetworkAddr]. -func (na *NetworkAddr) Listen(ctx context.Context, config net.ListenConfig) (io.Closer, error) { +func (na *NetworkAddr) Listen(ctx context.Context, config net.ListenConfig) (Listener, error) { address := na.JoinHostPort() switch na.network { diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 6d48a628..cd96b7ea 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -19,7 +19,6 @@ import ( "context" "flag" "fmt" - "io" "net" "net/http" "os" @@ -65,15 +64,22 @@ type SSServer struct { natTimeout time.Duration m *outlineMetrics replayCache service.ReplayCache - listeners []io.Closer + listeners []Listener } -func (s *SSServer) serve(addr NetworkAddr, listener io.Closer, cipherList service.CipherList) error { +// The implementations of listeners for different network types are not +// interchangeable. The type of listener depends on the network type. +// TODO(sbruens): Create a custom `Listener` type so we can share serving logic, +// dispatching to the handlers based on connection type instead of on the +// listener type. +type Listener = any + +func (s *SSServer) serve(addr NetworkAddr, listener Listener, cipherList service.CipherList) error { switch ln := listener.(type) { case net.Listener: authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &s.replayCache, s.m) // TODO: Register initial data metrics at zero. - tcpHandler := service.NewTCPHandler(addr.String(), authFunc, s.m, tcpReadTimeout) + tcpHandler := service.NewTCPHandler(addr.Key(), authFunc, s.m, tcpReadTimeout) accept := func() (transport.StreamConn, error) { c, err := ln.Accept() return c.(transport.StreamConn), err @@ -161,7 +167,7 @@ func (s *SSServer) loadConfig(filename string) error { cipherList := service.NewCipherList() ciphers, ok := addrCiphers[addr] if !ok { - return fmt.Errorf("unable to find ciphers for address: %v", addr) + return fmt.Errorf("unable to find ciphers for address: %v", addr.Key()) } cipherList.Update(ciphers) @@ -174,7 +180,7 @@ func (s *SSServer) loadConfig(filename string) error { s.listeners = append(s.listeners, listener) if err = s.serve(addr, listener, cipherList); err != nil { - return fmt.Errorf("failed to serve on listener %v: %w", listener, err) + return fmt.Errorf("failed to serve on %s listener on address %s: %w", addr.Network(), addr.String(), err) } } logger.Infof("Loaded %v access keys over %v listeners", uniqueCiphers, len(s.listeners)) @@ -185,20 +191,22 @@ func (s *SSServer) loadConfig(filename string) error { // Stop serving on all listeners. func (s *SSServer) Stop() error { for _, listener := range s.listeners { - var addr net.Addr switch ln := listener.(type) { case net.Listener: - addr = ln.Addr() + err := ln.Close() + if err != nil { + //lint:ignore ST1005 Shadowsocks is capitalized. + return fmt.Errorf("Shadowsocks %s service on address %s failed to stop: %w", ln.Addr().Network(), ln.Addr().String(), err) + } case net.PacketConn: - addr = ln.LocalAddr() + err := ln.Close() + if err != nil { + //lint:ignore ST1005 Shadowsocks is capitalized. + return fmt.Errorf("Shadowsocks %s service on address %s failed to stop: %w", ln.LocalAddr().Network(), ln.LocalAddr().String(), err) + } default: return fmt.Errorf("unknown listener type: %v", ln) } - if err := listener.Close(); err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks service on address %s %s failed to stop: %w", addr.Network(), addr.String(), err) - } - logger.Infof("Shadowsocks service on address %s %s stopped", addr.Network(), addr.String()) } return nil } diff --git a/service/tcp.go b/service/tcp.go index 2195480b..ced85a54 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -236,7 +236,7 @@ func StreamServe(accept StreamListener, handle StreamHandler) { if errors.Is(err, net.ErrClosed) { break } - logger.Warningf("AcceptTCP failed: %v. Continuing to listen.", err) + logger.Warningf("Accept failed: %v. Continuing to listen.", err) continue } From fc725937c549d9341fae774edac7aa55ae0e54f7 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 1 Jul 2024 15:08:33 -0400 Subject: [PATCH 042/119] Use an inline error check. --- cmd/outline-ss-server/main.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index cd96b7ea..097d3653 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -193,14 +193,12 @@ func (s *SSServer) Stop() error { for _, listener := range s.listeners { switch ln := listener.(type) { case net.Listener: - err := ln.Close() - if err != nil { + if err := ln.Close(); err != nil { //lint:ignore ST1005 Shadowsocks is capitalized. return fmt.Errorf("Shadowsocks %s service on address %s failed to stop: %w", ln.Addr().Network(), ln.Addr().String(), err) } case net.PacketConn: - err := ln.Close() - if err != nil { + if err := ln.Close(); err != nil { //lint:ignore ST1005 Shadowsocks is capitalized. return fmt.Errorf("Shadowsocks %s service on address %s failed to stop: %w", ln.LocalAddr().Network(), ln.LocalAddr().String(), err) } From b24a3390d5e7e5c3544c32ec365b93c6f9744f9f Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 1 Jul 2024 14:49:14 -0400 Subject: [PATCH 043/119] Use shared listeners and packet connections. This allows us to reload a config while the existing one is still running. They share the same underlying listener, which is actually closed when the last user closes it. --- cmd/outline-ss-server/listeners.go | 195 ++++++++++++++++++++++++++++- 1 file changed, 190 insertions(+), 5 deletions(-) diff --git a/cmd/outline-ss-server/listeners.go b/cmd/outline-ss-server/listeners.go index 3457b794..d819587e 100644 --- a/cmd/outline-ss-server/listeners.go +++ b/cmd/outline-ss-server/listeners.go @@ -21,8 +21,133 @@ import ( "net" "strconv" "strings" + "sync" + "sync/atomic" + "time" ) +var ( + listeners = make(map[string]*globalListener) + listenersMu sync.Mutex +) + +type sharedListener struct { + net.Listener + key string + closed atomic.Int32 + usage *atomic.Int32 + deadline *bool + deadlineMu *sync.Mutex +} + +// Accept accepts connections until Close() is called. +func (sl *sharedListener) Accept() (net.Conn, error) { + if sl.closed.Load() == 1 { + return nil, &net.OpError{ + Op: "accept", + Net: sl.Listener.Addr().Network(), + Addr: sl.Listener.Addr(), + Err: fmt.Errorf("listener closed"), + } + } + + conn, err := sl.Listener.Accept() + if err == nil { + return conn, nil + } + + sl.deadlineMu.Lock() + if *sl.deadline { + switch ln := sl.Listener.(type) { + case *net.TCPListener: + ln.SetDeadline(time.Time{}) + } + *sl.deadline = false + } + sl.deadlineMu.Unlock() + + if sl.closed.Load() == 1 { + // In `Close()` we set a deadline in the past to force currently-blocked + // listeners to close without having to close the underlying socket. To + // avoid callers from retrying, we avoid returning timeout errors and + // instead make sure we return a fake "closed" error. + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return nil, &net.OpError{ + Op: "accept", + Net: sl.Listener.Addr().Network(), + Addr: sl.Listener.Addr(), + Err: fmt.Errorf("listener closed"), + } + } + } + + return nil, err +} + +// Close stops accepting new connections without closing the underlying socket. +// Only when the last user closes it, we actually close it. +func (sl *sharedListener) Close() error { + if sl.closed.CompareAndSwap(0, 1) { + // NOTE: In order to cancel current calls to Accept(), we set a deadline in + // the past, as we cannot actually close the listener. + sl.deadlineMu.Lock() + if !*sl.deadline { + switch ln := sl.Listener.(type) { + case *net.TCPListener: + ln.SetDeadline(time.Now().Add(-1 * time.Minute)) + } + *sl.deadline = true + } + sl.deadlineMu.Unlock() + + // See if we need to actually close the underlying listener. + if sl.usage.Add(-1) == 0 { + listenersMu.Lock() + delete(listeners, sl.key) + listenersMu.Unlock() + err := sl.Listener.Close() + if err != nil { + return err + } + } + + } + + return nil +} + +type sharedPacketConn struct { + net.PacketConn + key string + closed atomic.Int32 + usage *atomic.Int32 +} + +func (spc *sharedPacketConn) Close() error { + if spc.closed.CompareAndSwap(0, 1) { + // See if we need to actually close the underlying listener. + if spc.usage.Add(-1) == 0 { + listenersMu.Lock() + delete(listeners, spc.key) + listenersMu.Unlock() + err := spc.PacketConn.Close() + if err != nil { + return err + } + } + } + + return nil +} + +type globalListener struct { + ln net.Listener + pc net.PacketConn + usage atomic.Int32 + deadline bool + deadlineMu sync.Mutex +} + type NetworkAddr struct { network string Host string @@ -51,17 +176,77 @@ func (na *NetworkAddr) Key() string { } // Listen creates a new listener for the [NetworkAddr]. -func (na *NetworkAddr) Listen(ctx context.Context, config net.ListenConfig) (Listener, error) { - address := na.JoinHostPort() - +// +// Listeners can overlap one another, because during config changes the new +// config is started before the old config is destroyed. This is done by using +// reusable listener wrappers, which do not actually close the underlying socket +// until all uses of the shared listener have been closed. +func (na *NetworkAddr) Listen(ctx context.Context, config net.ListenConfig) (any, error) { switch na.network { case "tcp": - return config.Listen(ctx, na.network, address) + listenersMu.Lock() + defer listenersMu.Unlock() + + if lnGlobal, ok := listeners[na.Key()]; ok { + lnGlobal.usage.Add(1) + return &sharedListener{ + usage: &lnGlobal.usage, + deadline: &lnGlobal.deadline, + deadlineMu: &lnGlobal.deadlineMu, + key: na.Key(), + Listener: lnGlobal.ln, + }, nil + } + + ln, err := config.Listen(ctx, na.network, na.JoinHostPort()) + if err != nil { + return nil, err + } + + lnGlobal := &globalListener{ln: ln} + lnGlobal.usage.Store(1) + listeners[na.Key()] = lnGlobal + + return &sharedListener{ + usage: &lnGlobal.usage, + deadline: &lnGlobal.deadline, + deadlineMu: &lnGlobal.deadlineMu, + key: na.Key(), + Listener: ln, + }, nil + case "udp": - return config.ListenPacket(ctx, na.network, address) + listenersMu.Lock() + defer listenersMu.Unlock() + + if lnGlobal, ok := listeners[na.Key()]; ok { + lnGlobal.usage.Add(1) + return &sharedPacketConn{ + usage: &lnGlobal.usage, + key: na.Key(), + PacketConn: lnGlobal.pc, + }, nil + } + + pc, err := config.ListenPacket(ctx, na.network, na.JoinHostPort()) + if err != nil { + return nil, err + } + + lnGlobal := &globalListener{pc: pc} + lnGlobal.usage.Store(1) + listeners[na.Key()] = lnGlobal + + return &sharedPacketConn{ + usage: &lnGlobal.usage, + key: na.Key(), + PacketConn: pc, + }, nil + default: return nil, fmt.Errorf("unsupported network: %s", na.network) + } } From 3bc76bc5a11d72d01e506c925684b4ddcc89afee Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 1 Jul 2024 16:23:19 -0400 Subject: [PATCH 044/119] Close existing listeners once the new ones are serving. --- cmd/outline-ss-server/listeners.go | 4 ++-- cmd/outline-ss-server/main.go | 20 ++++++++++++++++---- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/cmd/outline-ss-server/listeners.go b/cmd/outline-ss-server/listeners.go index d819587e..fc83b5d7 100644 --- a/cmd/outline-ss-server/listeners.go +++ b/cmd/outline-ss-server/listeners.go @@ -47,7 +47,7 @@ func (sl *sharedListener) Accept() (net.Conn, error) { Op: "accept", Net: sl.Listener.Addr().Network(), Addr: sl.Listener.Addr(), - Err: fmt.Errorf("listener closed"), + Err: net.ErrClosed, } } @@ -76,7 +76,7 @@ func (sl *sharedListener) Accept() (net.Conn, error) { Op: "accept", Net: sl.Listener.Addr().Network(), Addr: sl.Listener.Addr(), - Err: fmt.Errorf("listener closed"), + Err: net.ErrClosed, } } } diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 097d3653..2d3ac56a 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -82,7 +82,10 @@ func (s *SSServer) serve(addr NetworkAddr, listener Listener, cipherList service tcpHandler := service.NewTCPHandler(addr.Key(), authFunc, s.m, tcpReadTimeout) accept := func() (transport.StreamConn, error) { c, err := ln.Accept() - return c.(transport.StreamConn), err + if err == nil { + return c.(transport.StreamConn), err + } + return nil, err } go service.StreamServe(accept, tcpHandler.Handle) case net.PacketConn: @@ -102,8 +105,8 @@ func (s *SSServer) loadConfig(filename string) error { if err := config.Validate(); err != nil { return fmt.Errorf("failed to validate config: %w", err) } + uniqueCiphers := 0 - // TODO: Clone existing listeners so we can close them after starting the new ones. addrs := make([]NetworkAddr, 0) addrCiphers := make(map[NetworkAddr]*list.List) // Values are *List of *CipherEntry. @@ -163,6 +166,8 @@ func (s *SSServer) loadConfig(filename string) error { } } + // Create new listeners based on the configured network addresses. + newListeners := make([]Listener, 0) for _, addr := range addrs { cipherList := service.NewCipherList() ciphers, ok := addrCiphers[addr] @@ -177,18 +182,25 @@ func (s *SSServer) loadConfig(filename string) error { return fmt.Errorf("Shadowsocks %s service failed to start on address %s: %w", addr.Network(), addr.String(), err) } logger.Infof("Shadowsocks %s service listening on %s", addr.Network(), addr.String()) - s.listeners = append(s.listeners, listener) + newListeners = append(newListeners, listener) if err = s.serve(addr, listener, cipherList); err != nil { return fmt.Errorf("failed to serve on %s listener on address %s: %w", addr.Network(), addr.String(), err) } } + + // Take down the old listeners now that the new ones are serving. + if err := s.Stop(); err != nil { + logger.Warningf("Failed to stop old listeners: %w", err) + } + s.listeners = newListeners + logger.Infof("Loaded %v access keys over %v listeners", uniqueCiphers, len(s.listeners)) s.m.SetNumAccessKeys(uniqueCiphers, len(s.listeners)) return nil } -// Stop serving on all listeners. +// Stop serving on all existing listeners. func (s *SSServer) Stop() error { for _, listener := range s.listeners { switch ln := listener.(type) { From f71b13d6bbcfaa680714b2a4924801c69c955b77 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 1 Jul 2024 16:56:46 -0400 Subject: [PATCH 045/119] Elevate failure to stop listeners to `ERROR` level. --- cmd/outline-ss-server/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 2d3ac56a..9b3d4d65 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -191,7 +191,7 @@ func (s *SSServer) loadConfig(filename string) error { // Take down the old listeners now that the new ones are serving. if err := s.Stop(); err != nil { - logger.Warningf("Failed to stop old listeners: %w", err) + logger.Errorf("Failed to stop old listeners: %w", err) } s.listeners = newListeners From 32cc180b2c50bc3755f4d83c6b11c6dfa5de324f Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 2 Jul 2024 10:30:56 -0400 Subject: [PATCH 046/119] Be more lenient in config validation to allow empty listeners or keys. --- cmd/outline-ss-server/config.go | 5 ----- cmd/outline-ss-server/config_test.go | 30 ++++++---------------------- cmd/outline-ss-server/main.go | 1 + 3 files changed, 7 insertions(+), 29 deletions(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index b970c120..d101bf26 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -15,7 +15,6 @@ package main import ( - "errors" "fmt" "os" @@ -58,10 +57,6 @@ type Config struct { // Validate checks that the config is valid. func (c *Config) Validate() error { for _, serviceConfig := range c.Services { - if serviceConfig.Listeners == nil || serviceConfig.Keys == nil { - return errors.New("must specify at least 1 listener and 1 key per service") - } - for _, listenerConfig := range serviceConfig.Listeners { // TODO: Support more listener types. if listenerConfig.Type != listenerTypeDirect { diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go index 3e76fa57..463de266 100644 --- a/cmd/outline-ss-server/config_test.go +++ b/cmd/outline-ss-server/config_test.go @@ -27,54 +27,36 @@ func TestValidateConfigFails(t *testing.T) { cfg *Config }{ { - name: "WithoutListeners", - cfg: &Config{ - Services: []ServiceConfig{ - ServiceConfig{ - Keys: []KeyConfig{ - KeyConfig{"user-0", "chacha20-ietf-poly1305", "Secret0"}, - }, - }, - }, - }, - }, - { - name: "WithoutKeys", + name: "WithUnknownListenerType", cfg: &Config{ Services: []ServiceConfig{ ServiceConfig{ Listeners: []ListenerConfig{ - ListenerConfig{Type: listenerTypeDirect, Address: "tcp/[::]:9000"}, + ListenerConfig{Type: "foo", Address: "tcp/[::]:9000"}, }, }, }, }, }, { - name: "WithUnknownListenerType", + name: "WithInvalidListenerAddress", cfg: &Config{ Services: []ServiceConfig{ ServiceConfig{ Listeners: []ListenerConfig{ - ListenerConfig{Type: "foo", Address: "tcp/[::]:9000"}, - }, - Keys: []KeyConfig{ - KeyConfig{"user-0", "chacha20-ietf-poly1305", "Secret0"}, + ListenerConfig{Type: listenerTypeDirect, Address: "tcp//[::]:9000"}, }, }, }, }, }, { - name: "WithInvalidListenerAddress", + name: "WithUnsupportedNetworkType", cfg: &Config{ Services: []ServiceConfig{ ServiceConfig{ Listeners: []ListenerConfig{ - ListenerConfig{Type: listenerTypeDirect, Address: "tcp//[::]:9000"}, - }, - Keys: []KeyConfig{ - KeyConfig{"user-0", "chacha20-ietf-poly1305", "Secret0"}, + ListenerConfig{Type: listenerTypeDirect, Address: "foo/[::]:9000"}, }, }, }, diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index b9e70497..9b3d4d65 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -27,6 +27,7 @@ import ( "syscall" "time" + "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" "github.com/Jigsaw-Code/outline-ss-server/ipinfo" From 640f80f0089e26f364f9eab8b9d97068c4d44931 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 3 Jul 2024 17:16:14 -0400 Subject: [PATCH 047/119] Ensure the address is an IP address. --- cmd/outline-ss-server/config.go | 6 +++++- cmd/outline-ss-server/config_test.go | 12 ++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index d101bf26..2341f4a5 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -16,6 +16,7 @@ package main import ( "fmt" + "net" "os" "gopkg.in/yaml.v2" @@ -63,13 +64,16 @@ func (c *Config) Validate() error { return fmt.Errorf("unsupported listener type: %s", listenerConfig.Type) } - network, _, _, err := SplitNetworkAddr(listenerConfig.Address) + network, host, _, err := SplitNetworkAddr(listenerConfig.Address) if err != nil { return fmt.Errorf("invalid listener address `%s`: %v", listenerConfig.Address, err) } if network != "tcp" && network != "udp" { return fmt.Errorf("unsupported network: %s", network) } + if ip := net.ParseIP(host); ip == nil { + return fmt.Errorf("address must be IP, found: %s", host) + } } } return nil diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go index 463de266..86d20afb 100644 --- a/cmd/outline-ss-server/config_test.go +++ b/cmd/outline-ss-server/config_test.go @@ -62,6 +62,18 @@ func TestValidateConfigFails(t *testing.T) { }, }, }, + { + name: "WithHostnameAddress", + cfg: &Config{ + Services: []ServiceConfig{ + ServiceConfig{ + Listeners: []ListenerConfig{ + ListenerConfig{Type: listenerTypeDirect, Address: "tcp/example.com:9000"}, + }, + }, + }, + }, + }, } for _, tc := range tests { From 22638c7370442990491d9ecb639714674befb250 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 3 Jul 2024 17:18:29 -0400 Subject: [PATCH 048/119] Use `yaml.v3`. --- cmd/outline-ss-server/config.go | 2 +- go.mod | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index 2341f4a5..d678679e 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -19,7 +19,7 @@ import ( "net" "os" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) type ServiceConfig struct { diff --git a/go.mod b/go.mod index 04a9ddab..5c1419d2 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/stretchr/testify v1.8.4 golang.org/x/crypto v0.17.0 golang.org/x/term v0.16.0 - gopkg.in/yaml.v2 v2.4.0 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -263,7 +263,7 @@ require ( gopkg.in/src-d/go-billy.v4 v4.3.2 // indirect gopkg.in/src-d/go-git.v4 v4.13.1 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect k8s.io/klog/v2 v2.90.0 // indirect mvdan.cc/sh/v3 v3.7.0 // indirect sigs.k8s.io/kind v0.17.0 // indirect From 2631b872ccd9d3e1d36390e76ca8f0725c0a80f4 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 8 Jul 2024 12:41:14 -0400 Subject: [PATCH 049/119] Move file reading back to `main.go`. --- cmd/outline-ss-server/config.go | 10 ++-------- cmd/outline-ss-server/config_test.go | 20 +++++++++----------- cmd/outline-ss-server/main.go | 6 +++++- 3 files changed, 16 insertions(+), 20 deletions(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index d678679e..5929f202 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -17,7 +17,6 @@ package main import ( "fmt" "net" - "os" "gopkg.in/yaml.v3" ) @@ -80,14 +79,9 @@ func (c *Config) Validate() error { } // readConfig attempts to read a config from a filename and parses it as a [Config]. -func readConfig(filename string) (*Config, error) { +func readConfig(configData []byte) (*Config, error) { config := Config{} - configData, err := os.ReadFile(filename) - if err != nil { - return nil, fmt.Errorf("failed to read config: %w", err) - } - err = yaml.Unmarshal(configData, &config) - if err != nil { + if err := yaml.Unmarshal(configData, &config); err != nil { return nil, fmt.Errorf("failed to parse config: %w", err) } return &config, nil diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go index 86d20afb..3a8b6cf8 100644 --- a/cmd/outline-ss-server/config_test.go +++ b/cmd/outline-ss-server/config_test.go @@ -85,7 +85,7 @@ func TestValidateConfigFails(t *testing.T) { } func TestReadConfig(t *testing.T) { - config, err := readConfig("./config_example.yml") + config, err := readConfigFile("./config_example.yml") require.NoError(t, err) expected := Config{ @@ -115,7 +115,7 @@ func TestReadConfig(t *testing.T) { } func TestReadConfigParsesDeprecatedFormat(t *testing.T) { - config, err := readConfig("./config_example.deprecated.yml") + config, err := readConfigFile("./config_example.deprecated.yml") require.NoError(t, err) expected := Config{ @@ -140,25 +140,23 @@ func TestReadConfigParsesDeprecatedFormat(t *testing.T) { func TestReadConfigFromEmptyFile(t *testing.T) { file, _ := os.CreateTemp("", "empty.yaml") - config, err := readConfig(file.Name()) + config, err := readConfigFile(file.Name()) require.NoError(t, err) require.ElementsMatch(t, Config{}, config) } -func TestReadConfigFromNonExistingFileFails(t *testing.T) { - config, err := readConfig("./foo") - - require.Error(t, err) - require.ElementsMatch(t, nil, config) -} - func TestReadConfigFromIncorrectFormatFails(t *testing.T) { file, _ := os.CreateTemp("", "empty.yaml") file.WriteString("foo") - config, err := readConfig(file.Name()) + config, err := readConfigFile(file.Name()) require.Error(t, err) require.ElementsMatch(t, Config{}, config) } + +func readConfigFile(filename string) (*Config, error) { + configData, _ := os.ReadFile(filename) + return readConfig(configData) +} diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 9b3d4d65..1deb3454 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -98,7 +98,11 @@ func (s *SSServer) serve(addr NetworkAddr, listener Listener, cipherList service } func (s *SSServer) loadConfig(filename string) error { - config, err := readConfig(filename) + configData, err := os.ReadFile(filename) + if err != nil { + return fmt.Errorf("failed to read config file %s: %w", filename, err) + } + config, err := readConfig(configData) if err != nil { return fmt.Errorf("failed to load config (%v): %w", filename, err) } From d76efd2f7a6a5a23e4cc465657853d99888c04aa Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 8 Jul 2024 13:25:20 -0400 Subject: [PATCH 050/119] Do not embed the `net.Listener` type. --- cmd/outline-ss-server/listeners.go | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/cmd/outline-ss-server/listeners.go b/cmd/outline-ss-server/listeners.go index fc83b5d7..fefa6920 100644 --- a/cmd/outline-ss-server/listeners.go +++ b/cmd/outline-ss-server/listeners.go @@ -32,7 +32,7 @@ var ( ) type sharedListener struct { - net.Listener + listener net.Listener key string closed atomic.Int32 usage *atomic.Int32 @@ -45,20 +45,20 @@ func (sl *sharedListener) Accept() (net.Conn, error) { if sl.closed.Load() == 1 { return nil, &net.OpError{ Op: "accept", - Net: sl.Listener.Addr().Network(), - Addr: sl.Listener.Addr(), + Net: sl.listener.Addr().Network(), + Addr: sl.listener.Addr(), Err: net.ErrClosed, } } - conn, err := sl.Listener.Accept() + conn, err := sl.listener.Accept() if err == nil { return conn, nil } sl.deadlineMu.Lock() if *sl.deadline { - switch ln := sl.Listener.(type) { + switch ln := sl.listener.(type) { case *net.TCPListener: ln.SetDeadline(time.Time{}) } @@ -74,8 +74,8 @@ func (sl *sharedListener) Accept() (net.Conn, error) { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { return nil, &net.OpError{ Op: "accept", - Net: sl.Listener.Addr().Network(), - Addr: sl.Listener.Addr(), + Net: sl.listener.Addr().Network(), + Addr: sl.listener.Addr(), Err: net.ErrClosed, } } @@ -92,7 +92,7 @@ func (sl *sharedListener) Close() error { // the past, as we cannot actually close the listener. sl.deadlineMu.Lock() if !*sl.deadline { - switch ln := sl.Listener.(type) { + switch ln := sl.listener.(type) { case *net.TCPListener: ln.SetDeadline(time.Now().Add(-1 * time.Minute)) } @@ -105,7 +105,7 @@ func (sl *sharedListener) Close() error { listenersMu.Lock() delete(listeners, sl.key) listenersMu.Unlock() - err := sl.Listener.Close() + err := sl.listener.Close() if err != nil { return err } @@ -116,6 +116,10 @@ func (sl *sharedListener) Close() error { return nil } +func (sl *sharedListener) Addr() net.Addr { + return sl.listener.Addr() +} + type sharedPacketConn struct { net.PacketConn key string @@ -195,7 +199,7 @@ func (na *NetworkAddr) Listen(ctx context.Context, config net.ListenConfig) (any deadline: &lnGlobal.deadline, deadlineMu: &lnGlobal.deadlineMu, key: na.Key(), - Listener: lnGlobal.ln, + listener: lnGlobal.ln, }, nil } @@ -213,7 +217,7 @@ func (na *NetworkAddr) Listen(ctx context.Context, config net.ListenConfig) (any deadline: &lnGlobal.deadline, deadlineMu: &lnGlobal.deadlineMu, key: na.Key(), - Listener: ln, + listener: ln, }, nil case "udp": From b8c5ab8db267099e8e5f74d0f212c87394854f38 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 8 Jul 2024 14:47:52 -0400 Subject: [PATCH 051/119] Use a `Service` object to abstract away some of the complex logic of managing listeners. --- cmd/outline-ss-server/main.go | 266 ++++++++++++++++++++-------------- 1 file changed, 161 insertions(+), 105 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 1deb3454..6093bb68 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -60,13 +60,6 @@ func init() { logger = logging.MustGetLogger("") } -type SSServer struct { - natTimeout time.Duration - m *outlineMetrics - replayCache service.ReplayCache - listeners []Listener -} - // The implementations of listeners for different network types are not // interchangeable. The type of listener depends on the network type. // TODO(sbruens): Create a custom `Listener` type so we can share serving logic, @@ -74,10 +67,18 @@ type SSServer struct { // listener type. type Listener = any -func (s *SSServer) serve(addr NetworkAddr, listener Listener, cipherList service.CipherList) error { +type Service struct { + natTimeout time.Duration + m *outlineMetrics + replayCache *service.ReplayCache + Listeners []Listener + Ciphers *list.List // Values are *List of *CipherEntry. +} + +func (s *Service) Serve(addr NetworkAddr, listener Listener, cipherList service.CipherList) error { switch ln := listener.(type) { case net.Listener: - authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &s.replayCache, s.m) + authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, s.replayCache, s.m) // TODO: Register initial data metrics at zero. tcpHandler := service.NewTCPHandler(addr.Key(), authFunc, s.m, tcpReadTimeout) accept := func() (transport.StreamConn, error) { @@ -97,131 +98,187 @@ func (s *SSServer) serve(addr NetworkAddr, listener Listener, cipherList service return nil } -func (s *SSServer) loadConfig(filename string) error { - configData, err := os.ReadFile(filename) - if err != nil { - return fmt.Errorf("failed to read config file %s: %w", filename, err) +func (s *Service) Stop() error { + for _, listener := range s.Listeners { + switch ln := listener.(type) { + case net.Listener: + if err := ln.Close(); err != nil { + //lint:ignore ST1005 Shadowsocks is capitalized. + return fmt.Errorf("Shadowsocks %s service on address %s failed to stop: %w", ln.Addr().Network(), ln.Addr().String(), err) + } + case net.PacketConn: + if err := ln.Close(); err != nil { + //lint:ignore ST1005 Shadowsocks is capitalized. + return fmt.Errorf("Shadowsocks %s service on address %s failed to stop: %w", ln.LocalAddr().Network(), ln.LocalAddr().String(), err) + } + default: + return fmt.Errorf("unknown listener type: %v", ln) + } } - config, err := readConfig(configData) + return nil +} + +// AddListener adds a new listener to the service. +func (s *Service) AddListener(addr NetworkAddr) error { + // Create new listeners based on the configured network addresses. + cipherList := service.NewCipherList() + cipherList.Update(s.Ciphers) + + listener, err := addr.Listen(context.TODO(), net.ListenConfig{KeepAlive: 0}) if err != nil { - return fmt.Errorf("failed to load config (%v): %w", filename, err) + //lint:ignore ST1005 Shadowsocks is capitalized. + return fmt.Errorf("Shadowsocks %s service failed to start on address %s: %w", addr.Network(), addr.String(), err) } - if err := config.Validate(); err != nil { - return fmt.Errorf("failed to validate config: %w", err) + s.Listeners = append(s.Listeners, listener) + logger.Infof("Shadowsocks %s service listening on %s", addr.Network(), addr.String()) + if err = s.Serve(addr, listener, cipherList); err != nil { + return fmt.Errorf("failed to serve on %s listener on address %s: %w", addr.Network(), addr.String(), err) } + return nil +} - uniqueCiphers := 0 - addrs := make([]NetworkAddr, 0) - addrCiphers := make(map[NetworkAddr]*list.List) // Values are *List of *CipherEntry. +// NewService creates a new Service. +func NewService(config ServiceConfig, natTimeout time.Duration, m *outlineMetrics, replayCache *service.ReplayCache) (*Service, error) { + s := Service{ + natTimeout: natTimeout, + m: m, + replayCache: replayCache, + Ciphers: list.New(), + } - for _, legacyKeyServiceConfig := range config.Keys { - cryptoKey, err := shadowsocks.NewEncryptionKey(legacyKeyServiceConfig.Cipher, legacyKeyServiceConfig.Secret) - if err != nil { - return fmt.Errorf("failed to create encyption key for key %v: %w", legacyKeyServiceConfig.ID, err) + type cipherKey struct { + cipher string + secret string + } + existingCiphers := make(map[cipherKey]bool) + for _, keyConfig := range config.Keys { + key := cipherKey{keyConfig.Cipher, keyConfig.Secret} + if _, exists := existingCiphers[key]; exists { + logger.Debugf("encryption key already exists for ID=`%v`. Skipping.", keyConfig.ID) + continue } - entry := service.MakeCipherEntry(legacyKeyServiceConfig.ID, cryptoKey, legacyKeyServiceConfig.Secret) - for _, network := range []string{"tcp", "udp"} { - addr := NetworkAddr{ - network: network, - Host: "::", - Port: uint(legacyKeyServiceConfig.Port), - } - addrs = append(addrs, addr) - ciphers, ok := addrCiphers[addr] - if !ok { - ciphers = list.New() - addrCiphers[addr] = ciphers - } - ciphers.PushBack(&entry) + cryptoKey, err := shadowsocks.NewEncryptionKey(keyConfig.Cipher, keyConfig.Secret) + if err != nil { + return nil, fmt.Errorf("failed to create encyption key for key %v: %w", keyConfig.ID, err) } - uniqueCiphers += 1 + entry := service.MakeCipherEntry(keyConfig.ID, cryptoKey, keyConfig.Secret) + s.Ciphers.PushBack(&entry) + existingCiphers[key] = true } - for _, serviceConfig := range config.Services { - ciphers := list.New() - type cipherKey struct { - cipher string - secret string + for _, listener := range config.Listeners { + addr, err := ParseNetworkAddr(listener.Address) + if err != nil { + return nil, fmt.Errorf("error parsing listener address `%s`: %v", listener.Address, err) } - existingCiphers := make(map[cipherKey]bool) - for _, keyConfig := range serviceConfig.Keys { - key := cipherKey{keyConfig.Cipher, keyConfig.Secret} - if _, exists := existingCiphers[key]; exists { - logger.Debugf("encryption key already exists for ID=`%v`. Skipping.", keyConfig.ID) - continue - } - cryptoKey, err := shadowsocks.NewEncryptionKey(keyConfig.Cipher, keyConfig.Secret) - if err != nil { - return fmt.Errorf("failed to create encyption key for key %v: %w", keyConfig.ID, err) - } - entry := service.MakeCipherEntry(keyConfig.ID, cryptoKey, keyConfig.Secret) - ciphers.PushBack(&entry) - existingCiphers[key] = true + if err := s.AddListener(addr); err != nil { + return nil, err } - uniqueCiphers += ciphers.Len() + } - for _, listener := range serviceConfig.Listeners { - addr, err := ParseNetworkAddr(listener.Address) - if err != nil { - return fmt.Errorf("error parsing listener address `%s`: %v", listener.Address, err) - } - addrs = append(addrs, addr) - addrCiphers[addr] = ciphers - } + return &s, nil +} + +func NewLegacyKeyService(config LegacyKeyServiceConfig, natTimeout time.Duration, m *outlineMetrics, replayCache *service.ReplayCache) (*Service, error) { + s := Service{ + natTimeout: natTimeout, + m: m, + replayCache: replayCache, + Ciphers: list.New(), } - // Create new listeners based on the configured network addresses. - newListeners := make([]Listener, 0) - for _, addr := range addrs { - cipherList := service.NewCipherList() - ciphers, ok := addrCiphers[addr] - if !ok { - return fmt.Errorf("unable to find ciphers for address: %v", addr.Key()) + cryptoKey, err := shadowsocks.NewEncryptionKey(config.Cipher, config.Secret) + if err != nil { + return nil, fmt.Errorf("failed to create encyption key for key %v: %w", config.ID, err) + } + entry := service.MakeCipherEntry(config.ID, cryptoKey, config.Secret) + s.Ciphers.PushBack(&entry) + + for _, network := range []string{"tcp", "udp"} { + addr := NetworkAddr{ + network: network, + Host: "::", + Port: uint(config.Port), } - cipherList.Update(ciphers) + if err := s.AddListener(addr); err != nil { + return nil, err + } + } - listener, err := addr.Listen(context.TODO(), net.ListenConfig{KeepAlive: 0}) + return &s, nil +} + +type SSServer struct { + natTimeout time.Duration + m *outlineMetrics + replayCache service.ReplayCache + services []Service +} + +func (s *SSServer) loadConfig(filename string) error { + configData, err := os.ReadFile(filename) + if err != nil { + return fmt.Errorf("failed to read config file %s: %w", filename, err) + } + config, err := readConfig(configData) + if err != nil { + return fmt.Errorf("failed to load config (%v): %w", filename, err) + } + if err := config.Validate(); err != nil { + return fmt.Errorf("failed to validate config: %w", err) + } + + // We hot swap the services by having them both live at the same time. This + // means we create services for the new config first, and then take down the + // services from the old config. + newServices := make([]Service, 0) + for _, legacyKeyServiceConfig := range config.Keys { + service, err := NewLegacyKeyService(legacyKeyServiceConfig, s.natTimeout, s.m, &s.replayCache) if err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks %s service failed to start on address %s: %w", addr.Network(), addr.String(), err) + return fmt.Errorf("Failed to create new service: %v", err) } - logger.Infof("Shadowsocks %s service listening on %s", addr.Network(), addr.String()) - newListeners = append(newListeners, listener) - - if err = s.serve(addr, listener, cipherList); err != nil { - return fmt.Errorf("failed to serve on %s listener on address %s: %w", addr.Network(), addr.String(), err) + newServices = append(newServices, *service) + } + for _, serviceConfig := range config.Services { + service, err := NewService(serviceConfig, s.natTimeout, s.m, &s.replayCache) + if err != nil { + return fmt.Errorf("Failed to create new service: %v", err) } + newServices = append(newServices, *service) } - // Take down the old listeners now that the new ones are serving. + // Take down the old services now that the new ones are created and serving. if err := s.Stop(); err != nil { - logger.Errorf("Failed to stop old listeners: %w", err) + logger.Errorf("Failed to stop old services: %w", err) } - s.listeners = newListeners - - logger.Infof("Loaded %v access keys over %v listeners", uniqueCiphers, len(s.listeners)) - s.m.SetNumAccessKeys(uniqueCiphers, len(s.listeners)) + s.services = newServices + + // Gather some basic stats for logging. + var ( + listenerCount int + cipherCount int + ) + for _, service := range s.services { + listenerCount += len(service.Listeners) + cipherCount += service.Ciphers.Len() + } + logger.Infof("Loaded %d services with %d access keys over %d listeners", len(s.services), cipherCount, listenerCount) + s.m.SetNumAccessKeys(cipherCount, len(s.services)) return nil } -// Stop serving on all existing listeners. +// Stop serving on all existing services. func (s *SSServer) Stop() error { - for _, listener := range s.listeners { - switch ln := listener.(type) { - case net.Listener: - if err := ln.Close(); err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks %s service on address %s failed to stop: %w", ln.Addr().Network(), ln.Addr().String(), err) - } - case net.PacketConn: - if err := ln.Close(); err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks %s service on address %s failed to stop: %w", ln.LocalAddr().Network(), ln.LocalAddr().String(), err) - } - default: - return fmt.Errorf("unknown listener type: %v", ln) + if len(s.services) == 0 { + return nil + } + + for _, service := range s.services { + if err := service.Stop(); err != nil { + return err } } + logger.Infof("Stopped %d old services", len(s.services)) return nil } @@ -231,7 +288,6 @@ func RunSSServer(filename string, natTimeout time.Duration, sm *outlineMetrics, natTimeout: natTimeout, m: sm, replayCache: service.NewReplayCache(replayHistory), - listeners: nil, } err := server.loadConfig(filename) if err != nil { From 5ac0f46d99359a36ed7baf1e709cce39625d0542 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 8 Jul 2024 16:17:44 -0400 Subject: [PATCH 052/119] Fix how we deal with legacy services. --- cmd/outline-ss-server/main.go | 201 ++++++------------------------- cmd/outline-ss-server/metrics.go | 14 +-- cmd/outline-ss-server/service.go | 167 +++++++++++++++++++++++++ 3 files changed, 211 insertions(+), 171 deletions(-) create mode 100644 cmd/outline-ss-server/service.go diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 6093bb68..2b1f7842 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -16,10 +16,8 @@ package main import ( "container/list" - "context" "flag" "fmt" - "net" "net/http" "os" "os/signal" @@ -27,9 +25,7 @@ import ( "syscall" "time" - "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" - "github.com/Jigsaw-Code/outline-ss-server/ipinfo" "github.com/Jigsaw-Code/outline-ss-server/service" "github.com/op/go-logging" @@ -60,159 +56,11 @@ func init() { logger = logging.MustGetLogger("") } -// The implementations of listeners for different network types are not -// interchangeable. The type of listener depends on the network type. -// TODO(sbruens): Create a custom `Listener` type so we can share serving logic, -// dispatching to the handlers based on connection type instead of on the -// listener type. -type Listener = any - -type Service struct { - natTimeout time.Duration - m *outlineMetrics - replayCache *service.ReplayCache - Listeners []Listener - Ciphers *list.List // Values are *List of *CipherEntry. -} - -func (s *Service) Serve(addr NetworkAddr, listener Listener, cipherList service.CipherList) error { - switch ln := listener.(type) { - case net.Listener: - authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, s.replayCache, s.m) - // TODO: Register initial data metrics at zero. - tcpHandler := service.NewTCPHandler(addr.Key(), authFunc, s.m, tcpReadTimeout) - accept := func() (transport.StreamConn, error) { - c, err := ln.Accept() - if err == nil { - return c.(transport.StreamConn), err - } - return nil, err - } - go service.StreamServe(accept, tcpHandler.Handle) - case net.PacketConn: - packetHandler := service.NewPacketHandler(s.natTimeout, cipherList, s.m) - go packetHandler.Handle(ln) - default: - return fmt.Errorf("unknown listener type: %v", ln) - } - return nil -} - -func (s *Service) Stop() error { - for _, listener := range s.Listeners { - switch ln := listener.(type) { - case net.Listener: - if err := ln.Close(); err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks %s service on address %s failed to stop: %w", ln.Addr().Network(), ln.Addr().String(), err) - } - case net.PacketConn: - if err := ln.Close(); err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks %s service on address %s failed to stop: %w", ln.LocalAddr().Network(), ln.LocalAddr().String(), err) - } - default: - return fmt.Errorf("unknown listener type: %v", ln) - } - } - return nil -} - -// AddListener adds a new listener to the service. -func (s *Service) AddListener(addr NetworkAddr) error { - // Create new listeners based on the configured network addresses. - cipherList := service.NewCipherList() - cipherList.Update(s.Ciphers) - - listener, err := addr.Listen(context.TODO(), net.ListenConfig{KeepAlive: 0}) - if err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks %s service failed to start on address %s: %w", addr.Network(), addr.String(), err) - } - s.Listeners = append(s.Listeners, listener) - logger.Infof("Shadowsocks %s service listening on %s", addr.Network(), addr.String()) - if err = s.Serve(addr, listener, cipherList); err != nil { - return fmt.Errorf("failed to serve on %s listener on address %s: %w", addr.Network(), addr.String(), err) - } - return nil -} - -// NewService creates a new Service. -func NewService(config ServiceConfig, natTimeout time.Duration, m *outlineMetrics, replayCache *service.ReplayCache) (*Service, error) { - s := Service{ - natTimeout: natTimeout, - m: m, - replayCache: replayCache, - Ciphers: list.New(), - } - - type cipherKey struct { - cipher string - secret string - } - existingCiphers := make(map[cipherKey]bool) - for _, keyConfig := range config.Keys { - key := cipherKey{keyConfig.Cipher, keyConfig.Secret} - if _, exists := existingCiphers[key]; exists { - logger.Debugf("encryption key already exists for ID=`%v`. Skipping.", keyConfig.ID) - continue - } - cryptoKey, err := shadowsocks.NewEncryptionKey(keyConfig.Cipher, keyConfig.Secret) - if err != nil { - return nil, fmt.Errorf("failed to create encyption key for key %v: %w", keyConfig.ID, err) - } - entry := service.MakeCipherEntry(keyConfig.ID, cryptoKey, keyConfig.Secret) - s.Ciphers.PushBack(&entry) - existingCiphers[key] = true - } - - for _, listener := range config.Listeners { - addr, err := ParseNetworkAddr(listener.Address) - if err != nil { - return nil, fmt.Errorf("error parsing listener address `%s`: %v", listener.Address, err) - } - if err := s.AddListener(addr); err != nil { - return nil, err - } - } - - return &s, nil -} - -func NewLegacyKeyService(config LegacyKeyServiceConfig, natTimeout time.Duration, m *outlineMetrics, replayCache *service.ReplayCache) (*Service, error) { - s := Service{ - natTimeout: natTimeout, - m: m, - replayCache: replayCache, - Ciphers: list.New(), - } - - cryptoKey, err := shadowsocks.NewEncryptionKey(config.Cipher, config.Secret) - if err != nil { - return nil, fmt.Errorf("failed to create encyption key for key %v: %w", config.ID, err) - } - entry := service.MakeCipherEntry(config.ID, cryptoKey, config.Secret) - s.Ciphers.PushBack(&entry) - - for _, network := range []string{"tcp", "udp"} { - addr := NetworkAddr{ - network: network, - Host: "::", - Port: uint(config.Port), - } - if err := s.AddListener(addr); err != nil { - return nil, err - } - } - - return &s, nil -} - type SSServer struct { natTimeout time.Duration m *outlineMetrics replayCache service.ReplayCache - services []Service + services []*Service } func (s *SSServer) loadConfig(filename string) error { @@ -231,21 +79,47 @@ func (s *SSServer) loadConfig(filename string) error { // We hot swap the services by having them both live at the same time. This // means we create services for the new config first, and then take down the // services from the old config. - newServices := make([]Service, 0) + newServices := make([]*Service, 0) + + legacyPortService := make(map[int]*Service) // Values are *List of *CipherEntry. for _, legacyKeyServiceConfig := range config.Keys { - service, err := NewLegacyKeyService(legacyKeyServiceConfig, s.natTimeout, s.m, &s.replayCache) + legacyService, ok := legacyPortService[legacyKeyServiceConfig.Port] + if !ok { + legacyService = &Service{ + natTimeout: s.natTimeout, + m: s.m, + replayCache: &s.replayCache, + ciphers: list.New(), + } + for _, network := range []string{"tcp", "udp"} { + addr := NetworkAddr{ + network: network, + Host: "::", + Port: uint(legacyKeyServiceConfig.Port), + } + if err := legacyService.AddListener(addr); err != nil { + return err + } + } + newServices = append(newServices, legacyService) + legacyPortService[legacyKeyServiceConfig.Port] = legacyService + } + cryptoKey, err := shadowsocks.NewEncryptionKey(legacyKeyServiceConfig.Cipher, legacyKeyServiceConfig.Secret) if err != nil { - return fmt.Errorf("Failed to create new service: %v", err) + return fmt.Errorf("failed to create encyption key for key %v: %w", legacyKeyServiceConfig.ID, err) } - newServices = append(newServices, *service) + entry := service.MakeCipherEntry(legacyKeyServiceConfig.ID, cryptoKey, legacyKeyServiceConfig.Secret) + legacyService.AddCipher(&entry) } + for _, serviceConfig := range config.Services { service, err := NewService(serviceConfig, s.natTimeout, s.m, &s.replayCache) if err != nil { return fmt.Errorf("Failed to create new service: %v", err) } - newServices = append(newServices, *service) + newServices = append(newServices, service) } + logger.Infof("Loaded %d new services", len(newServices)) // Take down the old services now that the new ones are created and serving. if err := s.Stop(); err != nil { @@ -253,17 +127,16 @@ func (s *SSServer) loadConfig(filename string) error { } s.services = newServices - // Gather some basic stats for logging. var ( listenerCount int cipherCount int ) for _, service := range s.services { - listenerCount += len(service.Listeners) - cipherCount += service.Ciphers.Len() + listenerCount += service.NumListeners() + cipherCount += service.NumCiphers() } - logger.Infof("Loaded %d services with %d access keys over %d listeners", len(s.services), cipherCount, listenerCount) - s.m.SetNumAccessKeys(cipherCount, len(s.services)) + logger.Infof("%d services active: %d access keys over %d listeners", len(s.services), cipherCount, listenerCount) + s.m.SetNumAccessKeys(cipherCount, listenerCount) return nil } @@ -272,13 +145,13 @@ func (s *SSServer) Stop() error { if len(s.services) == 0 { return nil } - for _, service := range s.services { if err := service.Stop(); err != nil { return err } } logger.Infof("Stopped %d old services", len(s.services)) + s.services = nil return nil } diff --git a/cmd/outline-ss-server/metrics.go b/cmd/outline-ss-server/metrics.go index e95ceeb3..600cea16 100644 --- a/cmd/outline-ss-server/metrics.go +++ b/cmd/outline-ss-server/metrics.go @@ -38,7 +38,7 @@ type outlineMetrics struct { buildInfo *prometheus.GaugeVec accessKeys prometheus.Gauge - ports prometheus.Gauge + listeners prometheus.Gauge dataBytes *prometheus.CounterVec dataBytesPerLocation *prometheus.CounterVec timeToCipherMs *prometheus.HistogramVec @@ -183,10 +183,10 @@ func newPrometheusOutlineMetrics(ip2info ipinfo.IPInfoMap, registerer prometheus Name: "keys", Help: "Count of access keys", }), - ports: prometheus.NewGauge(prometheus.GaugeOpts{ + listeners: prometheus.NewGauge(prometheus.GaugeOpts{ Namespace: namespace, - Name: "ports", - Help: "Count of open Shadowsocks ports", + Name: "listeners", + Help: "Count of open Shadowsocks listeners", }), tcpProbes: prometheus.NewHistogramVec(prometheus.HistogramOpts{ Namespace: namespace, @@ -265,7 +265,7 @@ func newPrometheusOutlineMetrics(ip2info ipinfo.IPInfoMap, registerer prometheus m.tunnelTimeCollector = newTunnelTimeCollector(ip2info, registerer) // TODO: Is it possible to pass where to register the collectors? - registerer.MustRegister(m.buildInfo, m.accessKeys, m.ports, m.tcpProbes, m.tcpOpenConnections, m.tcpClosedConnections, m.tcpConnectionDurationMs, + registerer.MustRegister(m.buildInfo, m.accessKeys, m.listeners, m.tcpProbes, m.tcpOpenConnections, m.tcpClosedConnections, m.tcpConnectionDurationMs, m.dataBytes, m.dataBytesPerLocation, m.timeToCipherMs, m.udpPacketsFromClientPerLocation, m.udpAddedNatEntries, m.udpRemovedNatEntries, m.tunnelTimeCollector) return m @@ -275,9 +275,9 @@ func (m *outlineMetrics) SetBuildInfo(version string) { m.buildInfo.WithLabelValues(version).Set(1) } -func (m *outlineMetrics) SetNumAccessKeys(numKeys int, ports int) { +func (m *outlineMetrics) SetNumAccessKeys(numKeys int, listeners int) { m.accessKeys.Set(float64(numKeys)) - m.ports.Set(float64(ports)) + m.listeners.Set(float64(listeners)) } func (m *outlineMetrics) AddOpenTCPConnection(clientInfo ipinfo.IPInfo) { diff --git a/cmd/outline-ss-server/service.go b/cmd/outline-ss-server/service.go new file mode 100644 index 00000000..dc91e5a5 --- /dev/null +++ b/cmd/outline-ss-server/service.go @@ -0,0 +1,167 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "container/list" + "context" + "fmt" + "net" + "time" + + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" + "github.com/Jigsaw-Code/outline-ss-server/service" +) + +// The implementations of listeners for different network types are not +// interchangeable. The type of listener depends on the network type. +// TODO(sbruens): Create a custom `Listener` type so we can share serving logic, +// dispatching to the handlers based on connection type instead of on the +// listener type. +type Listener = any + +type Service struct { + natTimeout time.Duration + m *outlineMetrics + replayCache *service.ReplayCache + listeners []Listener + ciphers *list.List // Values are *List of *service.CipherEntry. +} + +func (s *Service) Serve(addr NetworkAddr, listener Listener, cipherList service.CipherList) error { + switch ln := listener.(type) { + case net.Listener: + authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, s.replayCache, s.m) + // TODO: Register initial data metrics at zero. + tcpHandler := service.NewTCPHandler(addr.Key(), authFunc, s.m, tcpReadTimeout) + accept := func() (transport.StreamConn, error) { + c, err := ln.Accept() + if err == nil { + return c.(transport.StreamConn), err + } + return nil, err + } + go service.StreamServe(accept, tcpHandler.Handle) + case net.PacketConn: + packetHandler := service.NewPacketHandler(s.natTimeout, cipherList, s.m) + go packetHandler.Handle(ln) + default: + return fmt.Errorf("unknown listener type: %v", ln) + } + return nil +} + +func (s *Service) Stop() error { + for _, listener := range s.listeners { + switch ln := listener.(type) { + case net.Listener: + if err := ln.Close(); err != nil { + //lint:ignore ST1005 Shadowsocks is capitalized. + return fmt.Errorf("Shadowsocks %s service on address %s failed to stop: %w", ln.Addr().Network(), ln.Addr().String(), err) + } + case net.PacketConn: + if err := ln.Close(); err != nil { + //lint:ignore ST1005 Shadowsocks is capitalized. + return fmt.Errorf("Shadowsocks %s service on address %s failed to stop: %w", ln.LocalAddr().Network(), ln.LocalAddr().String(), err) + } + default: + return fmt.Errorf("unknown listener type: %v", ln) + } + } + return nil +} + +// AddListener adds a new listener to the service. +func (s *Service) AddListener(addr NetworkAddr) error { + // Create new listeners based on the configured network addresses. + cipherList := service.NewCipherList() + cipherList.Update(s.ciphers) + + listener, err := addr.Listen(context.TODO(), net.ListenConfig{KeepAlive: 0}) + if err != nil { + //lint:ignore ST1005 Shadowsocks is capitalized. + return fmt.Errorf("Shadowsocks %s service failed to start on address %s: %w", addr.Network(), addr.String(), err) + } + s.listeners = append(s.listeners, listener) + logger.Infof("Shadowsocks %s service listening on %s", addr.Network(), addr.String()) + if err = s.Serve(addr, listener, cipherList); err != nil { + return fmt.Errorf("failed to serve on %s listener on address %s: %w", addr.Network(), addr.String(), err) + } + return nil +} + +func (s *Service) NumListeners() int { + return len(s.listeners) +} + +func (s *Service) AddCipher(entry *service.CipherEntry) { + s.ciphers.PushBack(entry) +} + +func (s *Service) NumCiphers() int { + return s.ciphers.Len() +} + +// func NewService(natTimeout time.Duration, m *outlineMetrics, replayCache *service.ReplayCache) Service { +// return &Service{ +// natTimeout: natTimeout, +// m: m, +// replayCache: replayCache, +// ciphers: list.New(), +// } +// } + +// NewService creates a new Service based on a config +func NewService(config ServiceConfig, natTimeout time.Duration, m *outlineMetrics, replayCache *service.ReplayCache) (*Service, error) { + s := Service{ + natTimeout: natTimeout, + m: m, + replayCache: replayCache, + ciphers: list.New(), + } + + type cipherKey struct { + cipher string + secret string + } + existingCiphers := make(map[cipherKey]bool) + for _, keyConfig := range config.Keys { + key := cipherKey{keyConfig.Cipher, keyConfig.Secret} + if _, exists := existingCiphers[key]; exists { + logger.Debugf("encryption key already exists for ID=`%v`. Skipping.", keyConfig.ID) + continue + } + cryptoKey, err := shadowsocks.NewEncryptionKey(keyConfig.Cipher, keyConfig.Secret) + if err != nil { + return nil, fmt.Errorf("failed to create encyption key for key %v: %w", keyConfig.ID, err) + } + entry := service.MakeCipherEntry(keyConfig.ID, cryptoKey, keyConfig.Secret) + s.AddCipher(&entry) + existingCiphers[key] = true + } + + for _, listener := range config.Listeners { + addr, err := ParseNetworkAddr(listener.Address) + if err != nil { + return nil, fmt.Errorf("error parsing listener address `%s`: %v", listener.Address, err) + } + if err := s.AddListener(addr); err != nil { + return nil, err + } + } + + return &s, nil +} From 1f097be05a9ed4343f753a73306bab3c4b6e8aa6 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 8 Jul 2024 16:21:22 -0400 Subject: [PATCH 053/119] Remove commented out lines. --- cmd/outline-ss-server/service.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/cmd/outline-ss-server/service.go b/cmd/outline-ss-server/service.go index dc91e5a5..5635de3e 100644 --- a/cmd/outline-ss-server/service.go +++ b/cmd/outline-ss-server/service.go @@ -115,15 +115,6 @@ func (s *Service) NumCiphers() int { return s.ciphers.Len() } -// func NewService(natTimeout time.Duration, m *outlineMetrics, replayCache *service.ReplayCache) Service { -// return &Service{ -// natTimeout: natTimeout, -// m: m, -// replayCache: replayCache, -// ciphers: list.New(), -// } -// } - // NewService creates a new Service based on a config func NewService(config ServiceConfig, natTimeout time.Duration, m *outlineMetrics, replayCache *service.ReplayCache) (*Service, error) { s := Service{ From 80b25b16ac2645477b51cd759113a970421ba9e4 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 8 Jul 2024 16:55:28 -0400 Subject: [PATCH 054/119] Use `tcp` and `udp` types for direct listeners. --- cmd/outline-ss-server/config.go | 10 +-- cmd/outline-ss-server/config_example.yml | 16 ++-- cmd/outline-ss-server/config_test.go | 26 ++----- cmd/outline-ss-server/listeners.go | 99 ++++-------------------- cmd/outline-ss-server/main.go | 10 +-- cmd/outline-ss-server/service.go | 24 +++--- 6 files changed, 50 insertions(+), 135 deletions(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index 5929f202..e8a8a43f 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -28,7 +28,8 @@ type ServiceConfig struct { type ListenerType string -const listenerTypeDirect ListenerType = "direct" +const listenerTypeTCP ListenerType = "tcp" +const listenerTypeUDP ListenerType = "udp" type ListenerConfig struct { Type ListenerType @@ -59,17 +60,14 @@ func (c *Config) Validate() error { for _, serviceConfig := range c.Services { for _, listenerConfig := range serviceConfig.Listeners { // TODO: Support more listener types. - if listenerConfig.Type != listenerTypeDirect { + if listenerConfig.Type != listenerTypeTCP && listenerConfig.Type != listenerTypeUDP { return fmt.Errorf("unsupported listener type: %s", listenerConfig.Type) } - network, host, _, err := SplitNetworkAddr(listenerConfig.Address) + host, _, err := net.SplitHostPort(listenerConfig.Address) if err != nil { return fmt.Errorf("invalid listener address `%s`: %v", listenerConfig.Address, err) } - if network != "tcp" && network != "udp" { - return fmt.Errorf("unsupported network: %s", network) - } if ip := net.ParseIP(host); ip == nil { return fmt.Errorf("address must be IP, found: %s", host) } diff --git a/cmd/outline-ss-server/config_example.yml b/cmd/outline-ss-server/config_example.yml index bbfd265f..7af360b2 100644 --- a/cmd/outline-ss-server/config_example.yml +++ b/cmd/outline-ss-server/config_example.yml @@ -2,10 +2,10 @@ services: - listeners: # TODO(sbruens): Allow a string-based listener config, as a convenient short-form # to create a direct listener, e.g. `- tcp/[::]:9000`. - - type: direct - address: "tcp/[::]:9000" - - type: direct - address: "udp/[::]:9000" + - type: tcp + address: "[::]:9000" + - type: udp + address: "[::]:9000" keys: - id: user-0 cipher: chacha20-ietf-poly1305 @@ -15,10 +15,10 @@ services: secret: Secret1 - listeners: - - type: direct - address: "tcp/[::]:9001" - - type: direct - address: "udp/[::]:9001" + - type: tcp + address: "[::]:9001" + - type: udp + address: "[::]:9001" keys: - id: user-2 cipher: chacha20-ietf-poly1305 diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go index 3a8b6cf8..25895111 100644 --- a/cmd/outline-ss-server/config_test.go +++ b/cmd/outline-ss-server/config_test.go @@ -32,7 +32,7 @@ func TestValidateConfigFails(t *testing.T) { Services: []ServiceConfig{ ServiceConfig{ Listeners: []ListenerConfig{ - ListenerConfig{Type: "foo", Address: "tcp/[::]:9000"}, + ListenerConfig{Type: "foo", Address: "[::]:9000"}, }, }, }, @@ -44,19 +44,7 @@ func TestValidateConfigFails(t *testing.T) { Services: []ServiceConfig{ ServiceConfig{ Listeners: []ListenerConfig{ - ListenerConfig{Type: listenerTypeDirect, Address: "tcp//[::]:9000"}, - }, - }, - }, - }, - }, - { - name: "WithUnsupportedNetworkType", - cfg: &Config{ - Services: []ServiceConfig{ - ServiceConfig{ - Listeners: []ListenerConfig{ - ListenerConfig{Type: listenerTypeDirect, Address: "foo/[::]:9000"}, + ListenerConfig{Type: listenerTypeTCP, Address: "tcp/[::]:9000"}, }, }, }, @@ -68,7 +56,7 @@ func TestValidateConfigFails(t *testing.T) { Services: []ServiceConfig{ ServiceConfig{ Listeners: []ListenerConfig{ - ListenerConfig{Type: listenerTypeDirect, Address: "tcp/example.com:9000"}, + ListenerConfig{Type: listenerTypeTCP, Address: "example.com:9000"}, }, }, }, @@ -92,8 +80,8 @@ func TestReadConfig(t *testing.T) { Services: []ServiceConfig{ ServiceConfig{ Listeners: []ListenerConfig{ - ListenerConfig{Type: listenerTypeDirect, Address: "tcp/[::]:9000"}, - ListenerConfig{Type: listenerTypeDirect, Address: "udp/[::]:9000"}, + ListenerConfig{Type: listenerTypeTCP, Address: "[::]:9000"}, + ListenerConfig{Type: listenerTypeUDP, Address: "[::]:9000"}, }, Keys: []KeyConfig{ KeyConfig{"user-0", "chacha20-ietf-poly1305", "Secret0"}, @@ -102,8 +90,8 @@ func TestReadConfig(t *testing.T) { }, ServiceConfig{ Listeners: []ListenerConfig{ - ListenerConfig{Type: listenerTypeDirect, Address: "tcp/[::]:9001"}, - ListenerConfig{Type: listenerTypeDirect, Address: "udp/[::]:9001"}, + ListenerConfig{Type: listenerTypeTCP, Address: "[::]:9001"}, + ListenerConfig{Type: listenerTypeUDP, Address: "[::]:9001"}, }, Keys: []KeyConfig{ KeyConfig{"user-2", "chacha20-ietf-poly1305", "Secret2"}, diff --git a/cmd/outline-ss-server/listeners.go b/cmd/outline-ss-server/listeners.go index fefa6920..2f8f7c59 100644 --- a/cmd/outline-ss-server/listeners.go +++ b/cmd/outline-ss-server/listeners.go @@ -16,11 +16,8 @@ package main import ( "context" - "errors" "fmt" "net" - "strconv" - "strings" "sync" "sync/atomic" "time" @@ -152,71 +149,46 @@ type globalListener struct { deadlineMu sync.Mutex } -type NetworkAddr struct { - network string - Host string - Port uint -} - -// String returns a human-readable representation of the [NetworkAddr]. -func (na *NetworkAddr) Network() string { - return na.network -} - -// String returns a human-readable representation of the [NetworkAddr]. -func (na *NetworkAddr) String() string { - return na.JoinHostPort() -} - -// JoinHostPort is a convenience wrapper around [net.JoinHostPort]. -func (na *NetworkAddr) JoinHostPort() string { - return net.JoinHostPort(na.Host, strconv.Itoa(int(na.Port))) -} - -// Key returns a representative string useful to retrieve this entity from a -// map. This is used to uniquely identify reusable listeners. -func (na *NetworkAddr) Key() string { - return na.network + "/" + na.JoinHostPort() -} - -// Listen creates a new listener for the [NetworkAddr]. +// Listen creates a new listener for a given network and address. // // Listeners can overlap one another, because during config changes the new // config is started before the old config is destroyed. This is done by using // reusable listener wrappers, which do not actually close the underlying socket // until all uses of the shared listener have been closed. -func (na *NetworkAddr) Listen(ctx context.Context, config net.ListenConfig) (any, error) { - switch na.network { +func Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (any, error) { + lnKey := network + "/" + addr + + switch network { case "tcp": listenersMu.Lock() defer listenersMu.Unlock() - if lnGlobal, ok := listeners[na.Key()]; ok { + if lnGlobal, ok := listeners[lnKey]; ok { lnGlobal.usage.Add(1) return &sharedListener{ usage: &lnGlobal.usage, deadline: &lnGlobal.deadline, deadlineMu: &lnGlobal.deadlineMu, - key: na.Key(), + key: lnKey, listener: lnGlobal.ln, }, nil } - ln, err := config.Listen(ctx, na.network, na.JoinHostPort()) + ln, err := config.Listen(ctx, network, addr) if err != nil { return nil, err } lnGlobal := &globalListener{ln: ln} lnGlobal.usage.Store(1) - listeners[na.Key()] = lnGlobal + listeners[lnKey] = lnGlobal return &sharedListener{ usage: &lnGlobal.usage, deadline: &lnGlobal.deadline, deadlineMu: &lnGlobal.deadlineMu, - key: na.Key(), + key: lnKey, listener: ln, }, nil @@ -224,71 +196,32 @@ func (na *NetworkAddr) Listen(ctx context.Context, config net.ListenConfig) (any listenersMu.Lock() defer listenersMu.Unlock() - if lnGlobal, ok := listeners[na.Key()]; ok { + if lnGlobal, ok := listeners[lnKey]; ok { lnGlobal.usage.Add(1) return &sharedPacketConn{ usage: &lnGlobal.usage, - key: na.Key(), + key: lnKey, PacketConn: lnGlobal.pc, }, nil } - pc, err := config.ListenPacket(ctx, na.network, na.JoinHostPort()) + pc, err := config.ListenPacket(ctx, network, addr) if err != nil { return nil, err } lnGlobal := &globalListener{pc: pc} lnGlobal.usage.Store(1) - listeners[na.Key()] = lnGlobal + listeners[lnKey] = lnGlobal return &sharedPacketConn{ usage: &lnGlobal.usage, - key: na.Key(), + key: lnKey, PacketConn: pc, }, nil default: - return nil, fmt.Errorf("unsupported network: %s", na.network) - - } -} - -// ParseNetworkAddr parses an address into a [NetworkAddr]. The input -// string is expected to be of the form "network/host:port" where any part is -// optional. -// -// Examples: -// -// tcp/127.0.0.1:8000 -// udp/127.0.0.1:9000 -func ParseNetworkAddr(addr string) (NetworkAddr, error) { - var host, port string - network, host, port, err := SplitNetworkAddr(addr) - if err != nil { - return NetworkAddr{}, err - } - if network == "" { - return NetworkAddr{}, errors.New("missing network") - } - p, err := strconv.ParseUint(port, 10, 16) - if err != nil { - return NetworkAddr{}, fmt.Errorf("invalid port: %v", err) - } - return NetworkAddr{ - network: network, - Host: host, - Port: uint(p), - }, nil -} + return nil, fmt.Errorf("unsupported network: %s", network) -// SplitNetworkAddr splits a into its network, host, and port components. -func SplitNetworkAddr(a string) (network, host, port string, err error) { - beforeSlash, afterSlash, slashFound := strings.Cut(a, "/") - if slashFound { - network = strings.ToLower(strings.TrimSpace(beforeSlash)) - a = afterSlash } - host, port, err = net.SplitHostPort(a) - return } diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 2b1f7842..fbc40e89 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -18,9 +18,11 @@ import ( "container/list" "flag" "fmt" + "net" "net/http" "os" "os/signal" + "strconv" "strings" "syscall" "time" @@ -92,12 +94,8 @@ func (s *SSServer) loadConfig(filename string) error { ciphers: list.New(), } for _, network := range []string{"tcp", "udp"} { - addr := NetworkAddr{ - network: network, - Host: "::", - Port: uint(legacyKeyServiceConfig.Port), - } - if err := legacyService.AddListener(addr); err != nil { + addr := net.JoinHostPort("::", strconv.Itoa(legacyKeyServiceConfig.Port)) + if err := legacyService.AddListener(network, addr); err != nil { return err } } diff --git a/cmd/outline-ss-server/service.go b/cmd/outline-ss-server/service.go index 5635de3e..9dc98657 100644 --- a/cmd/outline-ss-server/service.go +++ b/cmd/outline-ss-server/service.go @@ -41,12 +41,12 @@ type Service struct { ciphers *list.List // Values are *List of *service.CipherEntry. } -func (s *Service) Serve(addr NetworkAddr, listener Listener, cipherList service.CipherList) error { +func (s *Service) Serve(lnKey string, listener Listener, cipherList service.CipherList) error { switch ln := listener.(type) { case net.Listener: authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, s.replayCache, s.m) // TODO: Register initial data metrics at zero. - tcpHandler := service.NewTCPHandler(addr.Key(), authFunc, s.m, tcpReadTimeout) + tcpHandler := service.NewTCPHandler(lnKey, authFunc, s.m, tcpReadTimeout) accept := func() (transport.StreamConn, error) { c, err := ln.Accept() if err == nil { @@ -85,20 +85,21 @@ func (s *Service) Stop() error { } // AddListener adds a new listener to the service. -func (s *Service) AddListener(addr NetworkAddr) error { +func (s *Service) AddListener(network string, addr string) error { // Create new listeners based on the configured network addresses. cipherList := service.NewCipherList() cipherList.Update(s.ciphers) - listener, err := addr.Listen(context.TODO(), net.ListenConfig{KeepAlive: 0}) + listener, err := Listen(context.TODO(), network, addr, net.ListenConfig{KeepAlive: 0}) if err != nil { //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks %s service failed to start on address %s: %w", addr.Network(), addr.String(), err) + return fmt.Errorf("Shadowsocks %s service failed to start on address %s: %w", network, addr, err) } s.listeners = append(s.listeners, listener) - logger.Infof("Shadowsocks %s service listening on %s", addr.Network(), addr.String()) - if err = s.Serve(addr, listener, cipherList); err != nil { - return fmt.Errorf("failed to serve on %s listener on address %s: %w", addr.Network(), addr.String(), err) + logger.Infof("Shadowsocks %s service listening on %s", network, addr) + lnKey := network + "/" + addr + if err = s.Serve(lnKey, listener, cipherList); err != nil { + return fmt.Errorf("failed to serve on %s listener on address %s: %w", network, addr, err) } return nil } @@ -145,11 +146,8 @@ func NewService(config ServiceConfig, natTimeout time.Duration, m *outlineMetric } for _, listener := range config.Listeners { - addr, err := ParseNetworkAddr(listener.Address) - if err != nil { - return nil, fmt.Errorf("error parsing listener address `%s`: %v", listener.Address, err) - } - if err := s.AddListener(addr); err != nil { + network := string(listener.Type) + if err := s.AddListener(network, listener.Address); err != nil { return nil, err } } From 2070d40ff9cb36ec4f054b766057e318f8bddb88 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 8 Jul 2024 17:41:37 -0400 Subject: [PATCH 055/119] Use a `ListenerManager` instead of globals to manage listener state. --- cmd/outline-ss-server/listeners.go | 82 +++++++++++++++++++----------- cmd/outline-ss-server/main.go | 5 +- cmd/outline-ss-server/service.go | 6 ++- 3 files changed, 59 insertions(+), 34 deletions(-) diff --git a/cmd/outline-ss-server/listeners.go b/cmd/outline-ss-server/listeners.go index 2f8f7c59..c8a5d644 100644 --- a/cmd/outline-ss-server/listeners.go +++ b/cmd/outline-ss-server/listeners.go @@ -23,13 +23,9 @@ import ( "time" ) -var ( - listeners = make(map[string]*globalListener) - listenersMu sync.Mutex -) - type sharedListener struct { listener net.Listener + manager ListenerManager key string closed atomic.Int32 usage *atomic.Int32 @@ -99,9 +95,7 @@ func (sl *sharedListener) Close() error { // See if we need to actually close the underlying listener. if sl.usage.Add(-1) == 0 { - listenersMu.Lock() - delete(listeners, sl.key) - listenersMu.Unlock() + sl.manager.Delete(sl.key) err := sl.listener.Close() if err != nil { return err @@ -119,18 +113,17 @@ func (sl *sharedListener) Addr() net.Addr { type sharedPacketConn struct { net.PacketConn - key string - closed atomic.Int32 - usage *atomic.Int32 + manager ListenerManager + key string + closed atomic.Int32 + usage *atomic.Int32 } func (spc *sharedPacketConn) Close() error { if spc.closed.CompareAndSwap(0, 1) { // See if we need to actually close the underlying listener. if spc.usage.Add(-1) == 0 { - listenersMu.Lock() - delete(listeners, spc.key) - listenersMu.Unlock() + spc.manager.Delete(spc.key) err := spc.PacketConn.Close() if err != nil { return err @@ -149,29 +142,47 @@ type globalListener struct { deadlineMu sync.Mutex } +// ListenerManager holds and manages the state of shared listeners. +type ListenerManager interface { + Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (any, error) + Delete(key string) +} + +type listenerManager struct { + listeners map[string]*globalListener + listenersMu sync.Mutex +} + +func NewListenerManager() ListenerManager { + return &listenerManager{ + listeners: make(map[string]*globalListener), + } +} + // Listen creates a new listener for a given network and address. // // Listeners can overlap one another, because during config changes the new // config is started before the old config is destroyed. This is done by using // reusable listener wrappers, which do not actually close the underlying socket // until all uses of the shared listener have been closed. -func Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (any, error) { +func (m *listenerManager) Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (any, error) { lnKey := network + "/" + addr switch network { case "tcp": - listenersMu.Lock() - defer listenersMu.Unlock() + m.listenersMu.Lock() + defer m.listenersMu.Unlock() - if lnGlobal, ok := listeners[lnKey]; ok { + if lnGlobal, ok := m.listeners[lnKey]; ok { lnGlobal.usage.Add(1) return &sharedListener{ + listener: lnGlobal.ln, + manager: m, + key: lnKey, usage: &lnGlobal.usage, deadline: &lnGlobal.deadline, deadlineMu: &lnGlobal.deadlineMu, - key: lnKey, - listener: lnGlobal.ln, }, nil } @@ -182,26 +193,28 @@ func Listen(ctx context.Context, network string, addr string, config net.ListenC lnGlobal := &globalListener{ln: ln} lnGlobal.usage.Store(1) - listeners[lnKey] = lnGlobal + m.listeners[lnKey] = lnGlobal return &sharedListener{ + listener: ln, + manager: m, + key: lnKey, usage: &lnGlobal.usage, deadline: &lnGlobal.deadline, deadlineMu: &lnGlobal.deadlineMu, - key: lnKey, - listener: ln, }, nil case "udp": - listenersMu.Lock() - defer listenersMu.Unlock() + m.listenersMu.Lock() + defer m.listenersMu.Unlock() - if lnGlobal, ok := listeners[lnKey]; ok { + if lnGlobal, ok := m.listeners[lnKey]; ok { lnGlobal.usage.Add(1) return &sharedPacketConn{ - usage: &lnGlobal.usage, - key: lnKey, PacketConn: lnGlobal.pc, + manager: m, + key: lnKey, + usage: &lnGlobal.usage, }, nil } @@ -212,12 +225,13 @@ func Listen(ctx context.Context, network string, addr string, config net.ListenC lnGlobal := &globalListener{pc: pc} lnGlobal.usage.Store(1) - listeners[lnKey] = lnGlobal + m.listeners[lnKey] = lnGlobal return &sharedPacketConn{ - usage: &lnGlobal.usage, - key: lnKey, PacketConn: pc, + manager: m, + key: lnKey, + usage: &lnGlobal.usage, }, nil default: @@ -225,3 +239,9 @@ func Listen(ctx context.Context, network string, addr string, config net.ListenC } } + +func (m *listenerManager) Delete(key string) { + m.listenersMu.Lock() + delete(m.listeners, key) + m.listenersMu.Unlock() +} diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index fbc40e89..92bb62b4 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -59,6 +59,7 @@ func init() { } type SSServer struct { + lnManager ListenerManager natTimeout time.Duration m *outlineMetrics replayCache service.ReplayCache @@ -88,6 +89,7 @@ func (s *SSServer) loadConfig(filename string) error { legacyService, ok := legacyPortService[legacyKeyServiceConfig.Port] if !ok { legacyService = &Service{ + lnManager: s.lnManager, natTimeout: s.natTimeout, m: s.m, replayCache: &s.replayCache, @@ -111,7 +113,7 @@ func (s *SSServer) loadConfig(filename string) error { } for _, serviceConfig := range config.Services { - service, err := NewService(serviceConfig, s.natTimeout, s.m, &s.replayCache) + service, err := NewService(serviceConfig, s.lnManager, s.natTimeout, s.m, &s.replayCache) if err != nil { return fmt.Errorf("Failed to create new service: %v", err) } @@ -156,6 +158,7 @@ func (s *SSServer) Stop() error { // RunSSServer starts a shadowsocks server running, and returns the server or an error. func RunSSServer(filename string, natTimeout time.Duration, sm *outlineMetrics, replayHistory int) (*SSServer, error) { server := &SSServer{ + lnManager: NewListenerManager(), natTimeout: natTimeout, m: sm, replayCache: service.NewReplayCache(replayHistory), diff --git a/cmd/outline-ss-server/service.go b/cmd/outline-ss-server/service.go index 9dc98657..a8ac9f04 100644 --- a/cmd/outline-ss-server/service.go +++ b/cmd/outline-ss-server/service.go @@ -34,6 +34,7 @@ import ( type Listener = any type Service struct { + lnManager ListenerManager natTimeout time.Duration m *outlineMetrics replayCache *service.ReplayCache @@ -90,7 +91,7 @@ func (s *Service) AddListener(network string, addr string) error { cipherList := service.NewCipherList() cipherList.Update(s.ciphers) - listener, err := Listen(context.TODO(), network, addr, net.ListenConfig{KeepAlive: 0}) + listener, err := s.lnManager.Listen(context.TODO(), network, addr, net.ListenConfig{KeepAlive: 0}) if err != nil { //lint:ignore ST1005 Shadowsocks is capitalized. return fmt.Errorf("Shadowsocks %s service failed to start on address %s: %w", network, addr, err) @@ -117,8 +118,9 @@ func (s *Service) NumCiphers() int { } // NewService creates a new Service based on a config -func NewService(config ServiceConfig, natTimeout time.Duration, m *outlineMetrics, replayCache *service.ReplayCache) (*Service, error) { +func NewService(config ServiceConfig, lnManager ListenerManager, natTimeout time.Duration, m *outlineMetrics, replayCache *service.ReplayCache) (*Service, error) { s := Service{ + lnManager: lnManager, natTimeout: natTimeout, m: m, replayCache: replayCache, From eacfa0e7fd3e5b8c7b0bc8170bfa018622d5ca13 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 8 Jul 2024 17:50:49 -0400 Subject: [PATCH 056/119] Add validation check that no two services have the same listener. --- cmd/outline-ss-server/config.go | 17 +++++++++++------ cmd/outline-ss-server/config_test.go | 17 +++++++++++++++++ cmd/outline-ss-server/listeners.go | 6 +++++- 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index e8a8a43f..1a734720 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -57,20 +57,25 @@ type Config struct { // Validate checks that the config is valid. func (c *Config) Validate() error { + existingListeners := make(map[string]bool) for _, serviceConfig := range c.Services { - for _, listenerConfig := range serviceConfig.Listeners { + for _, lnConfig := range serviceConfig.Listeners { // TODO: Support more listener types. - if listenerConfig.Type != listenerTypeTCP && listenerConfig.Type != listenerTypeUDP { - return fmt.Errorf("unsupported listener type: %s", listenerConfig.Type) + if lnConfig.Type != listenerTypeTCP && lnConfig.Type != listenerTypeUDP { + return fmt.Errorf("unsupported listener type: %s", lnConfig.Type) } - - host, _, err := net.SplitHostPort(listenerConfig.Address) + host, _, err := net.SplitHostPort(lnConfig.Address) if err != nil { - return fmt.Errorf("invalid listener address `%s`: %v", listenerConfig.Address, err) + return fmt.Errorf("invalid listener address `%s`: %v", lnConfig.Address, err) } if ip := net.ParseIP(host); ip == nil { return fmt.Errorf("address must be IP, found: %s", host) } + key := listenerKey(string(lnConfig.Type), lnConfig.Address) + if _, exists := existingListeners[key]; exists { + return fmt.Errorf("listener of type %s with address %s already exists.", lnConfig.Type, lnConfig.Address) + } + existingListeners[key] = true } } return nil diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go index 25895111..f183ff5a 100644 --- a/cmd/outline-ss-server/config_test.go +++ b/cmd/outline-ss-server/config_test.go @@ -62,6 +62,23 @@ func TestValidateConfigFails(t *testing.T) { }, }, }, + { + name: "WithDuplicateListeners", + cfg: &Config{ + Services: []ServiceConfig{ + ServiceConfig{ + Listeners: []ListenerConfig{ + ListenerConfig{Type: listenerTypeTCP, Address: "[::]:9000"}, + }, + }, + ServiceConfig{ + Listeners: []ListenerConfig{ + ListenerConfig{Type: listenerTypeTCP, Address: "[::]:9000"}, + }, + }, + }, + }, + }, } for _, tc := range tests { diff --git a/cmd/outline-ss-server/listeners.go b/cmd/outline-ss-server/listeners.go index c8a5d644..77716acf 100644 --- a/cmd/outline-ss-server/listeners.go +++ b/cmd/outline-ss-server/listeners.go @@ -166,7 +166,7 @@ func NewListenerManager() ListenerManager { // reusable listener wrappers, which do not actually close the underlying socket // until all uses of the shared listener have been closed. func (m *listenerManager) Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (any, error) { - lnKey := network + "/" + addr + lnKey := listenerKey(network, addr) switch network { @@ -245,3 +245,7 @@ func (m *listenerManager) Delete(key string) { delete(m.listeners, key) m.listenersMu.Unlock() } + +func listenerKey(network string, addr string) string { + return network + "/" + addr +} From dc1075a1ddae2d0407b91f92fe2c4d03cac9e2e3 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 10 Jul 2024 15:45:35 -0400 Subject: [PATCH 057/119] Use channels to notify shared listeners they need to stop acceoting. --- cmd/outline-ss-server/listeners.go | 118 +++++++++++------------------ 1 file changed, 45 insertions(+), 73 deletions(-) diff --git a/cmd/outline-ss-server/listeners.go b/cmd/outline-ss-server/listeners.go index 77716acf..d00ac30b 100644 --- a/cmd/outline-ss-server/listeners.go +++ b/cmd/outline-ss-server/listeners.go @@ -20,78 +20,44 @@ import ( "net" "sync" "sync/atomic" - "time" ) +type acceptResponse struct { + conn net.Conn + err error +} + type sharedListener struct { - listener net.Listener - manager ListenerManager - key string - closed atomic.Int32 - usage *atomic.Int32 - deadline *bool - deadlineMu *sync.Mutex + listener net.Listener + manager ListenerManager + key string + closed atomic.Int32 + usage *atomic.Int32 + acceptCh chan acceptResponse + closeCh chan struct{} } // Accept accepts connections until Close() is called. func (sl *sharedListener) Accept() (net.Conn, error) { if sl.closed.Load() == 1 { - return nil, &net.OpError{ - Op: "accept", - Net: sl.listener.Addr().Network(), - Addr: sl.listener.Addr(), - Err: net.ErrClosed, - } + return nil, net.ErrClosed } - - conn, err := sl.listener.Accept() - if err == nil { - return conn, nil - } - - sl.deadlineMu.Lock() - if *sl.deadline { - switch ln := sl.listener.(type) { - case *net.TCPListener: - ln.SetDeadline(time.Time{}) - } - *sl.deadline = false - } - sl.deadlineMu.Unlock() - - if sl.closed.Load() == 1 { - // In `Close()` we set a deadline in the past to force currently-blocked - // listeners to close without having to close the underlying socket. To - // avoid callers from retrying, we avoid returning timeout errors and - // instead make sure we return a fake "closed" error. - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - return nil, &net.OpError{ - Op: "accept", - Net: sl.listener.Addr().Network(), - Addr: sl.listener.Addr(), - Err: net.ErrClosed, - } + select { + case acceptResponse := <-sl.acceptCh: + if acceptResponse.err != nil { + return nil, acceptResponse.err } + return acceptResponse.conn, nil + case <-sl.closeCh: + return nil, net.ErrClosed } - - return nil, err } // Close stops accepting new connections without closing the underlying socket. // Only when the last user closes it, we actually close it. func (sl *sharedListener) Close() error { if sl.closed.CompareAndSwap(0, 1) { - // NOTE: In order to cancel current calls to Accept(), we set a deadline in - // the past, as we cannot actually close the listener. - sl.deadlineMu.Lock() - if !*sl.deadline { - switch ln := sl.listener.(type) { - case *net.TCPListener: - ln.SetDeadline(time.Now().Add(-1 * time.Minute)) - } - *sl.deadline = true - } - sl.deadlineMu.Unlock() + close(sl.closeCh) // See if we need to actually close the underlying listener. if sl.usage.Add(-1) == 0 { @@ -135,11 +101,10 @@ func (spc *sharedPacketConn) Close() error { } type globalListener struct { - ln net.Listener - pc net.PacketConn - usage atomic.Int32 - deadline bool - deadlineMu sync.Mutex + ln net.Listener + pc net.PacketConn + usage atomic.Int32 + acceptCh chan acceptResponse } // ListenerManager holds and manages the state of shared listeners. @@ -177,12 +142,12 @@ func (m *listenerManager) Listen(ctx context.Context, network string, addr strin if lnGlobal, ok := m.listeners[lnKey]; ok { lnGlobal.usage.Add(1) return &sharedListener{ - listener: lnGlobal.ln, - manager: m, - key: lnKey, - usage: &lnGlobal.usage, - deadline: &lnGlobal.deadline, - deadlineMu: &lnGlobal.deadlineMu, + listener: lnGlobal.ln, + manager: m, + key: lnKey, + usage: &lnGlobal.usage, + acceptCh: lnGlobal.acceptCh, + closeCh: make(chan struct{}), }, nil } @@ -191,17 +156,24 @@ func (m *listenerManager) Listen(ctx context.Context, network string, addr strin return nil, err } - lnGlobal := &globalListener{ln: ln} + lnGlobal := &globalListener{ln: ln, acceptCh: make(chan acceptResponse)} lnGlobal.usage.Store(1) m.listeners[lnKey] = lnGlobal + go func() { + for { + conn, err := lnGlobal.ln.Accept() + lnGlobal.acceptCh <- acceptResponse{conn, err} + } + }() + return &sharedListener{ - listener: ln, - manager: m, - key: lnKey, - usage: &lnGlobal.usage, - deadline: &lnGlobal.deadline, - deadlineMu: &lnGlobal.deadlineMu, + listener: ln, + manager: m, + key: lnKey, + usage: &lnGlobal.usage, + acceptCh: lnGlobal.acceptCh, + closeCh: make(chan struct{}), }, nil case "udp": From 2a343e2efbffbe247458e86fb8e058af79b4ce8b Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 10 Jul 2024 15:51:37 -0400 Subject: [PATCH 058/119] Pass TCP timeout to service. --- cmd/outline-ss-server/main.go | 3 ++- cmd/outline-ss-server/service.go | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 92bb62b4..111b1c2d 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -90,6 +90,7 @@ func (s *SSServer) loadConfig(filename string) error { if !ok { legacyService = &Service{ lnManager: s.lnManager, + tcpTimeout: tcpReadTimeout, natTimeout: s.natTimeout, m: s.m, replayCache: &s.replayCache, @@ -113,7 +114,7 @@ func (s *SSServer) loadConfig(filename string) error { } for _, serviceConfig := range config.Services { - service, err := NewService(serviceConfig, s.lnManager, s.natTimeout, s.m, &s.replayCache) + service, err := NewService(serviceConfig, s.lnManager, tcpReadTimeout, s.natTimeout, s.m, &s.replayCache) if err != nil { return fmt.Errorf("Failed to create new service: %v", err) } diff --git a/cmd/outline-ss-server/service.go b/cmd/outline-ss-server/service.go index a8ac9f04..d935b0de 100644 --- a/cmd/outline-ss-server/service.go +++ b/cmd/outline-ss-server/service.go @@ -35,6 +35,7 @@ type Listener = any type Service struct { lnManager ListenerManager + tcpTimeout time.Duration natTimeout time.Duration m *outlineMetrics replayCache *service.ReplayCache @@ -47,7 +48,7 @@ func (s *Service) Serve(lnKey string, listener Listener, cipherList service.Ciph case net.Listener: authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, s.replayCache, s.m) // TODO: Register initial data metrics at zero. - tcpHandler := service.NewTCPHandler(lnKey, authFunc, s.m, tcpReadTimeout) + tcpHandler := service.NewTCPHandler(lnKey, authFunc, s.m, s.tcpTimeout) accept := func() (transport.StreamConn, error) { c, err := ln.Accept() if err == nil { @@ -118,9 +119,10 @@ func (s *Service) NumCiphers() int { } // NewService creates a new Service based on a config -func NewService(config ServiceConfig, lnManager ListenerManager, natTimeout time.Duration, m *outlineMetrics, replayCache *service.ReplayCache) (*Service, error) { +func NewService(config ServiceConfig, lnManager ListenerManager, tcpTimeout time.Duration, natTimeout time.Duration, m *outlineMetrics, replayCache *service.ReplayCache) (*Service, error) { s := Service{ lnManager: lnManager, + tcpTimeout: tcpTimeout, natTimeout: natTimeout, m: m, replayCache: replayCache, From e58b79dcf55ef5a0d567d1c8682d247c8c346c48 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 10 Jul 2024 15:57:23 -0400 Subject: [PATCH 059/119] Move go routine call up. --- cmd/outline-ss-server/listeners.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cmd/outline-ss-server/listeners.go b/cmd/outline-ss-server/listeners.go index d00ac30b..03be496d 100644 --- a/cmd/outline-ss-server/listeners.go +++ b/cmd/outline-ss-server/listeners.go @@ -67,7 +67,6 @@ func (sl *sharedListener) Close() error { return err } } - } return nil @@ -157,15 +156,14 @@ func (m *listenerManager) Listen(ctx context.Context, network string, addr strin } lnGlobal := &globalListener{ln: ln, acceptCh: make(chan acceptResponse)} - lnGlobal.usage.Store(1) - m.listeners[lnKey] = lnGlobal - go func() { for { conn, err := lnGlobal.ln.Accept() lnGlobal.acceptCh <- acceptResponse{conn, err} } }() + lnGlobal.usage.Store(1) + m.listeners[lnKey] = lnGlobal return &sharedListener{ listener: ln, From c7465fbe0496e9a36a6211aa503689acb01e70c2 Mon Sep 17 00:00:00 2001 From: sbruens Date: Thu, 11 Jul 2024 12:37:17 -0400 Subject: [PATCH 060/119] Allow inserting single elements directly into the cipher list. --- service/cipher_list.go | 6 ++++++ service/cipher_list_testing.go | 7 ++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/service/cipher_list.go b/service/cipher_list.go index 3b6f1957..171b4236 100644 --- a/service/cipher_list.go +++ b/service/cipher_list.go @@ -62,6 +62,8 @@ type CipherList interface { // which is a List of *CipherEntry. Update takes ownership of `contents`, // which must not be read or written after this call. Update(contents *list.List) + // PushBack inserts a new cipher at the back of the list. + PushBack(entry *CipherEntry) *list.Element } type cipherList struct { @@ -116,3 +118,7 @@ func (cl *cipherList) Update(src *list.List) { cl.list = src cl.mu.Unlock() } + +func (cl *cipherList) PushBack(entry *CipherEntry) *list.Element { + return cl.list.PushBack(entry) +} diff --git a/service/cipher_list_testing.go b/service/cipher_list_testing.go index a77427ed..d8532f79 100644 --- a/service/cipher_list_testing.go +++ b/service/cipher_list_testing.go @@ -15,7 +15,6 @@ package service import ( - "container/list" "fmt" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" @@ -24,7 +23,7 @@ import ( // MakeTestCiphers creates a CipherList containing one fresh AEAD cipher // for each secret in `secrets`. func MakeTestCiphers(secrets []string) (CipherList, error) { - l := list.New() + cipherList := NewCipherList() for i := 0; i < len(secrets); i++ { cipherID := fmt.Sprintf("id-%v", i) cipher, err := shadowsocks.NewEncryptionKey(shadowsocks.CHACHA20IETFPOLY1305, secrets[i]) @@ -32,10 +31,8 @@ func MakeTestCiphers(secrets []string) (CipherList, error) { return nil, fmt.Errorf("failed to create cipher %v: %w", i, err) } entry := MakeCipherEntry(cipherID, cipher, secrets[i]) - l.PushBack(&entry) + cipherList.PushBack(&entry) } - cipherList := NewCipherList() - cipherList.Update(l) return cipherList, nil } From 43fa0d61bc26c2b09cae07d9149549734523549b Mon Sep 17 00:00:00 2001 From: sbruens Date: Thu, 11 Jul 2024 16:14:40 -0400 Subject: [PATCH 061/119] Add the concept of a listener set to track existing listeners and close them all. --- cmd/outline-ss-server/listeners.go | 83 ++++++++- cmd/outline-ss-server/main.go | 154 ++++++++++++----- cmd/outline-ss-server/service.go | 160 ------------------ internal/integration_test/integration_test.go | 6 +- service/cipher_list.go | 5 + service/udp.go | 5 +- service/udp_test.go | 5 +- 7 files changed, 208 insertions(+), 210 deletions(-) delete mode 100644 cmd/outline-ss-server/service.go diff --git a/cmd/outline-ss-server/listeners.go b/cmd/outline-ss-server/listeners.go index 03be496d..f45339b1 100644 --- a/cmd/outline-ss-server/listeners.go +++ b/cmd/outline-ss-server/listeners.go @@ -20,6 +20,9 @@ import ( "net" "sync" "sync/atomic" + + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/Jigsaw-Code/outline-ss-server/service" ) type acceptResponse struct { @@ -27,6 +30,10 @@ type acceptResponse struct { err error } +type SharedListener interface { + SetHandler(handler Handler) +} + type sharedListener struct { listener net.Listener manager ListenerManager @@ -37,6 +44,20 @@ type sharedListener struct { closeCh chan struct{} } +func (sl *sharedListener) SetHandler(handler Handler) { + accept := func() (transport.StreamConn, error) { + c, err := sl.Accept() + if err == nil { + return c.(transport.StreamConn), err + } + return nil, err + } + handle := func(ctx context.Context, conn transport.StreamConn) { + handler.Handle(ctx, conn) + } + go service.StreamServe(accept, handle) +} + // Accept accepts connections until Close() is called. func (sl *sharedListener) Accept() (net.Conn, error) { if sl.closed.Load() == 1 { @@ -99,6 +120,10 @@ func (spc *sharedPacketConn) Close() error { return nil } +func (spc *sharedPacketConn) SetHandler(handler Handler) { + go handler.Handle(context.TODO(), spc.PacketConn) +} + type globalListener struct { ln net.Listener pc net.PacketConn @@ -106,9 +131,56 @@ type globalListener struct { acceptCh chan acceptResponse } +type ListenerSet interface { + Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (SharedListener, error) + Close() error + Len() int +} + +type listenerSet struct { + manager ListenerManager + listeners map[string]*SharedListener +} + +func (ls *listenerSet) Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (SharedListener, error) { + lnKey := listenerKey(network, addr) + if _, exists := ls.listeners[lnKey]; exists { + return nil, fmt.Errorf("listener %s already exists", lnKey) + } + ln, err := ls.manager.Listen(ctx, network, addr, config) + if err != nil { + return nil, err + } + ls.listeners[lnKey] = &ln + return ln, nil +} + +func (ls *listenerSet) Close() error { + for _, listener := range ls.listeners { + switch ln := (*listener).(type) { + case net.Listener: + if err := ln.Close(); err != nil { + return fmt.Errorf("%s listener on address %s failed to stop: %w", ln.Addr().Network(), ln.Addr().String(), err) + } + case net.PacketConn: + if err := ln.Close(); err != nil { + return fmt.Errorf("%s listener on address %s failed to stop: %w", ln.LocalAddr().Network(), ln.LocalAddr().String(), err) + } + default: + return fmt.Errorf("unknown listener type: %v", ln) + } + } + return nil +} + +func (ls *listenerSet) Len() int { + return len(ls.listeners) +} + // ListenerManager holds and manages the state of shared listeners. type ListenerManager interface { - Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (any, error) + NewListenerSet() ListenerSet + Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (SharedListener, error) Delete(key string) } @@ -123,13 +195,20 @@ func NewListenerManager() ListenerManager { } } +func (m *listenerManager) NewListenerSet() ListenerSet { + return &listenerSet{ + manager: m, + listeners: make(map[string]*SharedListener), + } +} + // Listen creates a new listener for a given network and address. // // Listeners can overlap one another, because during config changes the new // config is started before the old config is destroyed. This is done by using // reusable listener wrappers, which do not actually close the underlying socket // until all uses of the shared listener have been closed. -func (m *listenerManager) Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (any, error) { +func (m *listenerManager) Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (SharedListener, error) { lnKey := listenerKey(network, addr) switch network { diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 111b1c2d..10375744 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -15,7 +15,7 @@ package main import ( - "container/list" + "context" "flag" "fmt" "net" @@ -27,6 +27,7 @@ import ( "syscall" "time" + "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" "github.com/Jigsaw-Code/outline-ss-server/ipinfo" "github.com/Jigsaw-Code/outline-ss-server/service" @@ -58,12 +59,80 @@ func init() { logger = logging.MustGetLogger("") } +type Handler interface { + NumCiphers() int + AddCipher(entry *service.CipherEntry) + Handle(ctx context.Context, conn any) +} + +type connHandler struct { + tcpTimeout time.Duration + natTimeout time.Duration + replayCache *service.ReplayCache + m *outlineMetrics + ciphers service.CipherList +} + +// NewHandler creates a new Handler handler based on a service config. +func NewHandler(config ServiceConfig, tcpTimeout time.Duration, natTimeout time.Duration, m *outlineMetrics, replayCache *service.ReplayCache) (Handler, error) { + type cipherKey struct { + cipher string + secret string + } + ciphers := service.NewCipherList() + existingCiphers := make(map[cipherKey]bool) + for _, keyConfig := range config.Keys { + key := cipherKey{keyConfig.Cipher, keyConfig.Secret} + if _, exists := existingCiphers[key]; exists { + logger.Debugf("encryption key already exists for ID=`%v`. Skipping.", keyConfig.ID) + continue + } + cryptoKey, err := shadowsocks.NewEncryptionKey(keyConfig.Cipher, keyConfig.Secret) + if err != nil { + return nil, fmt.Errorf("failed to create encyption key for key %v: %w", keyConfig.ID, err) + } + entry := service.MakeCipherEntry(keyConfig.ID, cryptoKey, keyConfig.Secret) + ciphers.PushBack(&entry) + existingCiphers[key] = true + } + return &connHandler{ + ciphers: ciphers, + tcpTimeout: tcpTimeout, + natTimeout: natTimeout, + m: m, + replayCache: replayCache, + }, nil +} + +func (h *connHandler) NumCiphers() int { + return h.ciphers.Len() +} + +func (h *connHandler) AddCipher(entry *service.CipherEntry) { + h.ciphers.PushBack(entry) +} + +func (h *connHandler) Handle(ctx context.Context, conn any) { + switch c := conn.(type) { + case transport.StreamConn: + authFunc := service.NewShadowsocksStreamAuthenticator(h.ciphers, h.replayCache, h.m) + // TODO: Register initial data metrics at zero. + tcpHandler := service.NewTCPHandler(c.LocalAddr().String(), authFunc, h.m, h.tcpTimeout) + tcpHandler.Handle(ctx, c) + case net.PacketConn: + packetHandler := service.NewPacketHandler(h.natTimeout, h.ciphers, h.m) + packetHandler.Handle(ctx, c) + default: + logger.Errorf("unknown connection type: %v", c) + } +} + type SSServer struct { lnManager ListenerManager + lnSet ListenerSet natTimeout time.Duration m *outlineMetrics replayCache service.ReplayCache - services []*Service } func (s *SSServer) loadConfig(filename string) error { @@ -82,77 +151,80 @@ func (s *SSServer) loadConfig(filename string) error { // We hot swap the services by having them both live at the same time. This // means we create services for the new config first, and then take down the // services from the old config. - newServices := make([]*Service, 0) + oldListenerSet := s.lnSet + s.lnSet = s.lnManager.NewListenerSet() + var totalCipherCount int - legacyPortService := make(map[int]*Service) // Values are *List of *CipherEntry. + portHandlers := make(map[int]Handler) for _, legacyKeyServiceConfig := range config.Keys { - legacyService, ok := legacyPortService[legacyKeyServiceConfig.Port] + handler, ok := portHandlers[legacyKeyServiceConfig.Port] if !ok { - legacyService = &Service{ - lnManager: s.lnManager, + handler = &connHandler{ + ciphers: service.NewCipherList(), tcpTimeout: tcpReadTimeout, natTimeout: s.natTimeout, m: s.m, replayCache: &s.replayCache, - ciphers: list.New(), } - for _, network := range []string{"tcp", "udp"} { - addr := net.JoinHostPort("::", strconv.Itoa(legacyKeyServiceConfig.Port)) - if err := legacyService.AddListener(network, addr); err != nil { - return err - } - } - newServices = append(newServices, legacyService) - legacyPortService[legacyKeyServiceConfig.Port] = legacyService + portHandlers[legacyKeyServiceConfig.Port] = handler } cryptoKey, err := shadowsocks.NewEncryptionKey(legacyKeyServiceConfig.Cipher, legacyKeyServiceConfig.Secret) if err != nil { return fmt.Errorf("failed to create encyption key for key %v: %w", legacyKeyServiceConfig.ID, err) } entry := service.MakeCipherEntry(legacyKeyServiceConfig.ID, cryptoKey, legacyKeyServiceConfig.Secret) - legacyService.AddCipher(&entry) + handler.AddCipher(&entry) + } + for portNum, handler := range portHandlers { + totalCipherCount += handler.NumCiphers() + for _, network := range []string{"tcp", "udp"} { + addr := net.JoinHostPort("::", strconv.Itoa(portNum)) + listener, err := s.lnSet.Listen(context.TODO(), network, addr, net.ListenConfig{KeepAlive: 0}) + if err != nil { + return fmt.Errorf("%s service failed to start listening on address %s: %w", network, addr, err) + } + listener.SetHandler(handler) + } } for _, serviceConfig := range config.Services { - service, err := NewService(serviceConfig, s.lnManager, tcpReadTimeout, s.natTimeout, s.m, &s.replayCache) + handler, err := NewHandler(serviceConfig, tcpReadTimeout, s.natTimeout, s.m, &s.replayCache) if err != nil { - return fmt.Errorf("Failed to create new service: %v", err) + return fmt.Errorf("failed to create service handler: %w", err) + } + totalCipherCount += handler.NumCiphers() + for _, listenerConfig := range serviceConfig.Listeners { + network := string(listenerConfig.Type) + listener, err := s.lnSet.Listen(context.TODO(), network, listenerConfig.Address, net.ListenConfig{KeepAlive: 0}) + if err != nil { + return fmt.Errorf("%s service failed to start listening on address %s: %w", network, listenerConfig.Address, err) + } + listener.SetHandler(handler) } - newServices = append(newServices, service) } - logger.Infof("Loaded %d new services", len(newServices)) + logger.Infof("Loaded %d access keys over %d listeners", totalCipherCount, s.lnSet.Len()) + s.m.SetNumAccessKeys(totalCipherCount, s.lnSet.Len()) // Take down the old services now that the new ones are created and serving. - if err := s.Stop(); err != nil { - logger.Errorf("Failed to stop old services: %w", err) + if oldListenerSet != nil { + if err := oldListenerSet.Close(); err != nil { + logger.Errorf("Failed to stop old listeners: %w", err) + } + logger.Infof("Stopped %d old listeners", s.lnSet.Len()) } - s.services = newServices - var ( - listenerCount int - cipherCount int - ) - for _, service := range s.services { - listenerCount += service.NumListeners() - cipherCount += service.NumCiphers() - } - logger.Infof("%d services active: %d access keys over %d listeners", len(s.services), cipherCount, listenerCount) - s.m.SetNumAccessKeys(cipherCount, listenerCount) return nil } // Stop serving on all existing services. func (s *SSServer) Stop() error { - if len(s.services) == 0 { + if s.lnSet == nil { return nil } - for _, service := range s.services { - if err := service.Stop(); err != nil { - return err - } + if err := s.lnSet.Close(); err != nil { + logger.Errorf("Failed to stop all listeners: %w", err) } - logger.Infof("Stopped %d old services", len(s.services)) - s.services = nil + logger.Infof("Stopped %d listeners", s.lnSet.Len()) return nil } diff --git a/cmd/outline-ss-server/service.go b/cmd/outline-ss-server/service.go deleted file mode 100644 index d935b0de..00000000 --- a/cmd/outline-ss-server/service.go +++ /dev/null @@ -1,160 +0,0 @@ -// Copyright 2024 The Outline Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "container/list" - "context" - "fmt" - "net" - "time" - - "github.com/Jigsaw-Code/outline-sdk/transport" - "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" - "github.com/Jigsaw-Code/outline-ss-server/service" -) - -// The implementations of listeners for different network types are not -// interchangeable. The type of listener depends on the network type. -// TODO(sbruens): Create a custom `Listener` type so we can share serving logic, -// dispatching to the handlers based on connection type instead of on the -// listener type. -type Listener = any - -type Service struct { - lnManager ListenerManager - tcpTimeout time.Duration - natTimeout time.Duration - m *outlineMetrics - replayCache *service.ReplayCache - listeners []Listener - ciphers *list.List // Values are *List of *service.CipherEntry. -} - -func (s *Service) Serve(lnKey string, listener Listener, cipherList service.CipherList) error { - switch ln := listener.(type) { - case net.Listener: - authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, s.replayCache, s.m) - // TODO: Register initial data metrics at zero. - tcpHandler := service.NewTCPHandler(lnKey, authFunc, s.m, s.tcpTimeout) - accept := func() (transport.StreamConn, error) { - c, err := ln.Accept() - if err == nil { - return c.(transport.StreamConn), err - } - return nil, err - } - go service.StreamServe(accept, tcpHandler.Handle) - case net.PacketConn: - packetHandler := service.NewPacketHandler(s.natTimeout, cipherList, s.m) - go packetHandler.Handle(ln) - default: - return fmt.Errorf("unknown listener type: %v", ln) - } - return nil -} - -func (s *Service) Stop() error { - for _, listener := range s.listeners { - switch ln := listener.(type) { - case net.Listener: - if err := ln.Close(); err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks %s service on address %s failed to stop: %w", ln.Addr().Network(), ln.Addr().String(), err) - } - case net.PacketConn: - if err := ln.Close(); err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks %s service on address %s failed to stop: %w", ln.LocalAddr().Network(), ln.LocalAddr().String(), err) - } - default: - return fmt.Errorf("unknown listener type: %v", ln) - } - } - return nil -} - -// AddListener adds a new listener to the service. -func (s *Service) AddListener(network string, addr string) error { - // Create new listeners based on the configured network addresses. - cipherList := service.NewCipherList() - cipherList.Update(s.ciphers) - - listener, err := s.lnManager.Listen(context.TODO(), network, addr, net.ListenConfig{KeepAlive: 0}) - if err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks %s service failed to start on address %s: %w", network, addr, err) - } - s.listeners = append(s.listeners, listener) - logger.Infof("Shadowsocks %s service listening on %s", network, addr) - lnKey := network + "/" + addr - if err = s.Serve(lnKey, listener, cipherList); err != nil { - return fmt.Errorf("failed to serve on %s listener on address %s: %w", network, addr, err) - } - return nil -} - -func (s *Service) NumListeners() int { - return len(s.listeners) -} - -func (s *Service) AddCipher(entry *service.CipherEntry) { - s.ciphers.PushBack(entry) -} - -func (s *Service) NumCiphers() int { - return s.ciphers.Len() -} - -// NewService creates a new Service based on a config -func NewService(config ServiceConfig, lnManager ListenerManager, tcpTimeout time.Duration, natTimeout time.Duration, m *outlineMetrics, replayCache *service.ReplayCache) (*Service, error) { - s := Service{ - lnManager: lnManager, - tcpTimeout: tcpTimeout, - natTimeout: natTimeout, - m: m, - replayCache: replayCache, - ciphers: list.New(), - } - - type cipherKey struct { - cipher string - secret string - } - existingCiphers := make(map[cipherKey]bool) - for _, keyConfig := range config.Keys { - key := cipherKey{keyConfig.Cipher, keyConfig.Secret} - if _, exists := existingCiphers[key]; exists { - logger.Debugf("encryption key already exists for ID=`%v`. Skipping.", keyConfig.ID) - continue - } - cryptoKey, err := shadowsocks.NewEncryptionKey(keyConfig.Cipher, keyConfig.Secret) - if err != nil { - return nil, fmt.Errorf("failed to create encyption key for key %v: %w", keyConfig.ID, err) - } - entry := service.MakeCipherEntry(keyConfig.ID, cryptoKey, keyConfig.Secret) - s.AddCipher(&entry) - existingCiphers[key] = true - } - - for _, listener := range config.Listeners { - network := string(listener.Type) - if err := s.AddListener(network, listener.Address); err != nil { - return nil, err - } - } - - return &s, nil -} diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index f98319f4..2bfb0dfa 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -293,7 +293,7 @@ func TestUDPEcho(t *testing.T) { proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { - proxy.Handle(proxyConn) + proxy.Handle(context.Background(), proxyConn) done <- struct{}{} }() @@ -525,7 +525,7 @@ func BenchmarkUDPEcho(b *testing.B) { proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { - proxy.Handle(server) + proxy.Handle(context.Background(), server) done <- struct{}{} }() @@ -569,7 +569,7 @@ func BenchmarkUDPManyKeys(b *testing.B) { proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { - proxy.Handle(proxyConn) + proxy.Handle(context.Background(), proxyConn) done <- struct{}{} }() diff --git a/service/cipher_list.go b/service/cipher_list.go index 171b4236..d84ab1ac 100644 --- a/service/cipher_list.go +++ b/service/cipher_list.go @@ -55,6 +55,7 @@ func MakeCipherEntry(id string, cryptoKey *shadowsocks.EncryptionKey, secret str // CipherList is a thread-safe collection of CipherEntry elements that allows for // snapshotting and moving to front. type CipherList interface { + Len() int // Returns a snapshot of the cipher list optimized for this client IP SnapshotForClientIP(clientIP netip.Addr) []*list.Element MarkUsedByClientIP(e *list.Element, clientIP netip.Addr) @@ -77,6 +78,10 @@ func NewCipherList() CipherList { return &cipherList{list: list.New()} } +func (cl *cipherList) Len() int { + return cl.list.Len() +} + func matchesIP(e *list.Element, clientIP netip.Addr) bool { c := e.Value.(*CipherEntry) return clientIP != netip.Addr{} && clientIP == c.lastClientIP diff --git a/service/udp.go b/service/udp.go index 4830e302..859c6c44 100644 --- a/service/udp.go +++ b/service/udp.go @@ -15,6 +15,7 @@ package service import ( + "context" "errors" "fmt" "net" @@ -101,7 +102,7 @@ type PacketHandler interface { // SetTargetIPValidator sets the function to be used to validate the target IP addresses. SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) // Handle returns after clientConn closes and all the sub goroutines return. - Handle(clientConn net.PacketConn) + Handle(ctx context.Context, clientConn net.PacketConn) } func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) { @@ -110,7 +111,7 @@ func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPVali // Listen on addr for encrypted packets and basically do UDP NAT. // We take the ciphers as a pointer because it gets replaced on config updates. -func (h *packetHandler) Handle(clientConn net.PacketConn) { +func (h *packetHandler) Handle(ctx context.Context, clientConn net.PacketConn) { var running sync.WaitGroup nm := newNATmap(h.natTimeout, h.m, &running) diff --git a/service/udp_test.go b/service/udp_test.go index f94238c5..90d880b5 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -16,6 +16,7 @@ package service import ( "bytes" + "context" "errors" "net" "net/netip" @@ -132,7 +133,7 @@ func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTest handler.SetTargetIPValidator(validator) done := make(chan struct{}) go func() { - handler.Handle(clientConn) + handler.Handle(context.Background(), clientConn) done <- struct{}{} }() @@ -488,7 +489,7 @@ func TestUDPEarlyClose(t *testing.T) { } require.Nil(t, clientConn.Close()) // This should return quickly without timing out. - s.Handle(clientConn) + s.Handle(context.Background(), clientConn) } // Makes sure the UDP listener returns [io.ErrClosed] on reads and writes after Close(). From cf9b7d2cd389c16c002f51331d076640228f59d4 Mon Sep 17 00:00:00 2001 From: sbruens Date: Thu, 11 Jul 2024 16:46:38 -0400 Subject: [PATCH 062/119] Refactor how we create listeners. We introduce shared listeners that allow us to keep an old config running while we set up a new config. This is done by keeping track of the usage of the listeners and only closing them when the last user is done with the shared listener. --- cmd/outline-ss-server/listeners.go | 300 ++++++++++++++++++ cmd/outline-ss-server/main.go | 167 +++++----- cmd/outline-ss-server/metrics.go | 14 +- go.mod | 4 +- internal/integration_test/integration_test.go | 6 +- service/cipher_list.go | 11 + service/cipher_list_testing.go | 7 +- service/tcp.go | 2 +- service/udp.go | 5 +- service/udp_test.go | 5 +- 10 files changed, 418 insertions(+), 103 deletions(-) create mode 100644 cmd/outline-ss-server/listeners.go diff --git a/cmd/outline-ss-server/listeners.go b/cmd/outline-ss-server/listeners.go new file mode 100644 index 00000000..f45339b1 --- /dev/null +++ b/cmd/outline-ss-server/listeners.go @@ -0,0 +1,300 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "net" + "sync" + "sync/atomic" + + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/Jigsaw-Code/outline-ss-server/service" +) + +type acceptResponse struct { + conn net.Conn + err error +} + +type SharedListener interface { + SetHandler(handler Handler) +} + +type sharedListener struct { + listener net.Listener + manager ListenerManager + key string + closed atomic.Int32 + usage *atomic.Int32 + acceptCh chan acceptResponse + closeCh chan struct{} +} + +func (sl *sharedListener) SetHandler(handler Handler) { + accept := func() (transport.StreamConn, error) { + c, err := sl.Accept() + if err == nil { + return c.(transport.StreamConn), err + } + return nil, err + } + handle := func(ctx context.Context, conn transport.StreamConn) { + handler.Handle(ctx, conn) + } + go service.StreamServe(accept, handle) +} + +// Accept accepts connections until Close() is called. +func (sl *sharedListener) Accept() (net.Conn, error) { + if sl.closed.Load() == 1 { + return nil, net.ErrClosed + } + select { + case acceptResponse := <-sl.acceptCh: + if acceptResponse.err != nil { + return nil, acceptResponse.err + } + return acceptResponse.conn, nil + case <-sl.closeCh: + return nil, net.ErrClosed + } +} + +// Close stops accepting new connections without closing the underlying socket. +// Only when the last user closes it, we actually close it. +func (sl *sharedListener) Close() error { + if sl.closed.CompareAndSwap(0, 1) { + close(sl.closeCh) + + // See if we need to actually close the underlying listener. + if sl.usage.Add(-1) == 0 { + sl.manager.Delete(sl.key) + err := sl.listener.Close() + if err != nil { + return err + } + } + } + + return nil +} + +func (sl *sharedListener) Addr() net.Addr { + return sl.listener.Addr() +} + +type sharedPacketConn struct { + net.PacketConn + manager ListenerManager + key string + closed atomic.Int32 + usage *atomic.Int32 +} + +func (spc *sharedPacketConn) Close() error { + if spc.closed.CompareAndSwap(0, 1) { + // See if we need to actually close the underlying listener. + if spc.usage.Add(-1) == 0 { + spc.manager.Delete(spc.key) + err := spc.PacketConn.Close() + if err != nil { + return err + } + } + } + + return nil +} + +func (spc *sharedPacketConn) SetHandler(handler Handler) { + go handler.Handle(context.TODO(), spc.PacketConn) +} + +type globalListener struct { + ln net.Listener + pc net.PacketConn + usage atomic.Int32 + acceptCh chan acceptResponse +} + +type ListenerSet interface { + Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (SharedListener, error) + Close() error + Len() int +} + +type listenerSet struct { + manager ListenerManager + listeners map[string]*SharedListener +} + +func (ls *listenerSet) Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (SharedListener, error) { + lnKey := listenerKey(network, addr) + if _, exists := ls.listeners[lnKey]; exists { + return nil, fmt.Errorf("listener %s already exists", lnKey) + } + ln, err := ls.manager.Listen(ctx, network, addr, config) + if err != nil { + return nil, err + } + ls.listeners[lnKey] = &ln + return ln, nil +} + +func (ls *listenerSet) Close() error { + for _, listener := range ls.listeners { + switch ln := (*listener).(type) { + case net.Listener: + if err := ln.Close(); err != nil { + return fmt.Errorf("%s listener on address %s failed to stop: %w", ln.Addr().Network(), ln.Addr().String(), err) + } + case net.PacketConn: + if err := ln.Close(); err != nil { + return fmt.Errorf("%s listener on address %s failed to stop: %w", ln.LocalAddr().Network(), ln.LocalAddr().String(), err) + } + default: + return fmt.Errorf("unknown listener type: %v", ln) + } + } + return nil +} + +func (ls *listenerSet) Len() int { + return len(ls.listeners) +} + +// ListenerManager holds and manages the state of shared listeners. +type ListenerManager interface { + NewListenerSet() ListenerSet + Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (SharedListener, error) + Delete(key string) +} + +type listenerManager struct { + listeners map[string]*globalListener + listenersMu sync.Mutex +} + +func NewListenerManager() ListenerManager { + return &listenerManager{ + listeners: make(map[string]*globalListener), + } +} + +func (m *listenerManager) NewListenerSet() ListenerSet { + return &listenerSet{ + manager: m, + listeners: make(map[string]*SharedListener), + } +} + +// Listen creates a new listener for a given network and address. +// +// Listeners can overlap one another, because during config changes the new +// config is started before the old config is destroyed. This is done by using +// reusable listener wrappers, which do not actually close the underlying socket +// until all uses of the shared listener have been closed. +func (m *listenerManager) Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (SharedListener, error) { + lnKey := listenerKey(network, addr) + + switch network { + + case "tcp": + m.listenersMu.Lock() + defer m.listenersMu.Unlock() + + if lnGlobal, ok := m.listeners[lnKey]; ok { + lnGlobal.usage.Add(1) + return &sharedListener{ + listener: lnGlobal.ln, + manager: m, + key: lnKey, + usage: &lnGlobal.usage, + acceptCh: lnGlobal.acceptCh, + closeCh: make(chan struct{}), + }, nil + } + + ln, err := config.Listen(ctx, network, addr) + if err != nil { + return nil, err + } + + lnGlobal := &globalListener{ln: ln, acceptCh: make(chan acceptResponse)} + go func() { + for { + conn, err := lnGlobal.ln.Accept() + lnGlobal.acceptCh <- acceptResponse{conn, err} + } + }() + lnGlobal.usage.Store(1) + m.listeners[lnKey] = lnGlobal + + return &sharedListener{ + listener: ln, + manager: m, + key: lnKey, + usage: &lnGlobal.usage, + acceptCh: lnGlobal.acceptCh, + closeCh: make(chan struct{}), + }, nil + + case "udp": + m.listenersMu.Lock() + defer m.listenersMu.Unlock() + + if lnGlobal, ok := m.listeners[lnKey]; ok { + lnGlobal.usage.Add(1) + return &sharedPacketConn{ + PacketConn: lnGlobal.pc, + manager: m, + key: lnKey, + usage: &lnGlobal.usage, + }, nil + } + + pc, err := config.ListenPacket(ctx, network, addr) + if err != nil { + return nil, err + } + + lnGlobal := &globalListener{pc: pc} + lnGlobal.usage.Store(1) + m.listeners[lnKey] = lnGlobal + + return &sharedPacketConn{ + PacketConn: pc, + manager: m, + key: lnKey, + usage: &lnGlobal.usage, + }, nil + + default: + return nil, fmt.Errorf("unsupported network: %s", network) + + } +} + +func (m *listenerManager) Delete(key string) { + m.listenersMu.Lock() + delete(m.listeners, key) + m.listenersMu.Unlock() +} + +func listenerKey(network string, addr string) string { + return network + "/" + addr +} diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 87860df7..be94cbdf 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -15,17 +15,19 @@ package main import ( - "container/list" + "context" "flag" "fmt" "net" "net/http" "os" "os/signal" + "strconv" "strings" "syscall" "time" + "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" "github.com/Jigsaw-Code/outline-ss-server/ipinfo" "github.com/Jigsaw-Code/outline-ss-server/service" @@ -58,62 +60,49 @@ func init() { logger = logging.MustGetLogger("") } -type ssPort struct { - tcpListener *net.TCPListener - packetConn net.PacketConn - cipherList service.CipherList +type Handler interface { + NumCiphers() int + AddCipher(entry *service.CipherEntry) + Handle(ctx context.Context, conn any) } -type SSServer struct { +type connHandler struct { + tcpTimeout time.Duration natTimeout time.Duration + replayCache *service.ReplayCache m *outlineMetrics - replayCache service.ReplayCache - ports map[int]*ssPort + ciphers service.CipherList } -func (s *SSServer) startPort(portNum int) error { - listener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: portNum}) - if err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks TCP service failed to start on port %v: %w", portNum, err) - } - logger.Infof("Shadowsocks TCP service listening on %v", listener.Addr().String()) - packetConn, err := net.ListenUDP("udp", &net.UDPAddr{Port: portNum}) - if err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks UDP service failed to start on port %v: %w", portNum, err) - } - logger.Infof("Shadowsocks UDP service listening on %v", packetConn.LocalAddr().String()) - port := &ssPort{tcpListener: listener, packetConn: packetConn, cipherList: service.NewCipherList()} - authFunc := service.NewShadowsocksStreamAuthenticator(port.cipherList, &s.replayCache, s.m) - // TODO: Register initial data metrics at zero. - tcpHandler := service.NewTCPHandler(listener.Addr().String(), authFunc, s.m, tcpReadTimeout) - packetHandler := service.NewPacketHandler(s.natTimeout, port.cipherList, s.m) - s.ports[portNum] = port - go service.StreamServe(service.WrapStreamListener(listener.AcceptTCP), tcpHandler.Handle) - go packetHandler.Handle(port.packetConn) - return nil +func (h *connHandler) NumCiphers() int { + return h.ciphers.Len() } -func (s *SSServer) removePort(portNum int) error { - port, ok := s.ports[portNum] - if !ok { - return fmt.Errorf("port %v doesn't exist", portNum) - } - tcpErr := port.tcpListener.Close() - udpErr := port.packetConn.Close() - delete(s.ports, portNum) - if tcpErr != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks TCP service on port %v failed to stop: %w", portNum, tcpErr) - } - logger.Infof("Shadowsocks TCP service on port %v stopped", portNum) - if udpErr != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks UDP service on port %v failed to stop: %w", portNum, udpErr) +func (h *connHandler) AddCipher(entry *service.CipherEntry) { + h.ciphers.PushBack(entry) +} + +func (h *connHandler) Handle(ctx context.Context, conn any) { + switch c := conn.(type) { + case transport.StreamConn: + authFunc := service.NewShadowsocksStreamAuthenticator(h.ciphers, h.replayCache, h.m) + // TODO: Register initial data metrics at zero. + tcpHandler := service.NewTCPHandler(c.LocalAddr().String(), authFunc, h.m, h.tcpTimeout) + tcpHandler.Handle(ctx, c) + case net.PacketConn: + packetHandler := service.NewPacketHandler(h.natTimeout, h.ciphers, h.m) + packetHandler.Handle(ctx, c) + default: + logger.Errorf("unknown connection type: %v", c) } - logger.Infof("Shadowsocks UDP service on port %v stopped", portNum) - return nil +} + +type SSServer struct { + lnManager ListenerManager + lnSet ListenerSet + natTimeout time.Duration + m *outlineMetrics + replayCache service.ReplayCache } func (s *SSServer) loadConfig(filename string) error { @@ -122,65 +111,81 @@ func (s *SSServer) loadConfig(filename string) error { return fmt.Errorf("failed to load config (%v): %w", filename, err) } - portChanges := make(map[int]int) - portCiphers := make(map[int]*list.List) // Values are *List of *CipherEntry. - for _, keyConfig := range config.Keys { - portChanges[keyConfig.Port] = 1 - cipherList, ok := portCiphers[keyConfig.Port] + // We hot swap the services by having them both live at the same time. This + // means we create services for the new config first, and then take down the + // services from the old config. + oldListenerSet := s.lnSet + s.lnSet = s.lnManager.NewListenerSet() + var totalCipherCount int + + portHandlers := make(map[int]Handler) + for _, legacyKeyServiceConfig := range config.Keys { + handler, ok := portHandlers[legacyKeyServiceConfig.Port] if !ok { - cipherList = list.New() - portCiphers[keyConfig.Port] = cipherList + handler = &connHandler{ + ciphers: service.NewCipherList(), + tcpTimeout: tcpReadTimeout, + natTimeout: s.natTimeout, + m: s.m, + replayCache: &s.replayCache, + } + portHandlers[legacyKeyServiceConfig.Port] = handler } - cryptoKey, err := shadowsocks.NewEncryptionKey(keyConfig.Cipher, keyConfig.Secret) + cryptoKey, err := shadowsocks.NewEncryptionKey(legacyKeyServiceConfig.Cipher, legacyKeyServiceConfig.Secret) if err != nil { - return fmt.Errorf("failed to create encyption key for key %v: %w", keyConfig.ID, err) + return fmt.Errorf("failed to create encyption key for key %v: %w", legacyKeyServiceConfig.ID, err) } - entry := service.MakeCipherEntry(keyConfig.ID, cryptoKey, keyConfig.Secret) - cipherList.PushBack(&entry) - } - for port := range s.ports { - portChanges[port] = portChanges[port] - 1 + entry := service.MakeCipherEntry(legacyKeyServiceConfig.ID, cryptoKey, legacyKeyServiceConfig.Secret) + handler.AddCipher(&entry) } - for portNum, count := range portChanges { - if count == -1 { - if err := s.removePort(portNum); err != nil { - return fmt.Errorf("failed to remove port %v: %w", portNum, err) - } - } else if count == +1 { - if err := s.startPort(portNum); err != nil { - return err + for portNum, handler := range portHandlers { + totalCipherCount += handler.NumCiphers() + for _, network := range []string{"tcp", "udp"} { + addr := net.JoinHostPort("::", strconv.Itoa(portNum)) + listener, err := s.lnSet.Listen(context.TODO(), network, addr, net.ListenConfig{KeepAlive: 0}) + if err != nil { + return fmt.Errorf("%s service failed to start listening on address %s: %w", network, addr, err) } + listener.SetHandler(handler) } } - for portNum, cipherList := range portCiphers { - s.ports[portNum].cipherList.Update(cipherList) + logger.Infof("Loaded %d access keys over %d listeners", totalCipherCount, s.lnSet.Len()) + s.m.SetNumAccessKeys(totalCipherCount, s.lnSet.Len()) + + // Take down the old services now that the new ones are created and serving. + if oldListenerSet != nil { + if err := oldListenerSet.Close(); err != nil { + logger.Errorf("Failed to stop old listeners: %w", err) + } + logger.Infof("Stopped %d old listeners", s.lnSet.Len()) } - logger.Infof("Loaded %v access keys over %v ports", len(config.Keys), len(s.ports)) - s.m.SetNumAccessKeys(len(config.Keys), len(portCiphers)) + return nil } -// Stop serving on all ports. +// Stop serving on all existing services. func (s *SSServer) Stop() error { - for portNum := range s.ports { - if err := s.removePort(portNum); err != nil { - return err - } + if s.lnSet == nil { + return nil + } + if err := s.lnSet.Close(); err != nil { + logger.Errorf("Failed to stop all listeners: %w", err) } + logger.Infof("Stopped %d listeners", s.lnSet.Len()) return nil } // RunSSServer starts a shadowsocks server running, and returns the server or an error. func RunSSServer(filename string, natTimeout time.Duration, sm *outlineMetrics, replayHistory int) (*SSServer, error) { server := &SSServer{ + lnManager: NewListenerManager(), natTimeout: natTimeout, m: sm, replayCache: service.NewReplayCache(replayHistory), - ports: make(map[int]*ssPort), } err := server.loadConfig(filename) if err != nil { - return nil, fmt.Errorf("failed configure server: %w", err) + return nil, fmt.Errorf("failed to configure server: %w", err) } sigHup := make(chan os.Signal, 1) signal.Notify(sigHup, syscall.SIGHUP) diff --git a/cmd/outline-ss-server/metrics.go b/cmd/outline-ss-server/metrics.go index e95ceeb3..600cea16 100644 --- a/cmd/outline-ss-server/metrics.go +++ b/cmd/outline-ss-server/metrics.go @@ -38,7 +38,7 @@ type outlineMetrics struct { buildInfo *prometheus.GaugeVec accessKeys prometheus.Gauge - ports prometheus.Gauge + listeners prometheus.Gauge dataBytes *prometheus.CounterVec dataBytesPerLocation *prometheus.CounterVec timeToCipherMs *prometheus.HistogramVec @@ -183,10 +183,10 @@ func newPrometheusOutlineMetrics(ip2info ipinfo.IPInfoMap, registerer prometheus Name: "keys", Help: "Count of access keys", }), - ports: prometheus.NewGauge(prometheus.GaugeOpts{ + listeners: prometheus.NewGauge(prometheus.GaugeOpts{ Namespace: namespace, - Name: "ports", - Help: "Count of open Shadowsocks ports", + Name: "listeners", + Help: "Count of open Shadowsocks listeners", }), tcpProbes: prometheus.NewHistogramVec(prometheus.HistogramOpts{ Namespace: namespace, @@ -265,7 +265,7 @@ func newPrometheusOutlineMetrics(ip2info ipinfo.IPInfoMap, registerer prometheus m.tunnelTimeCollector = newTunnelTimeCollector(ip2info, registerer) // TODO: Is it possible to pass where to register the collectors? - registerer.MustRegister(m.buildInfo, m.accessKeys, m.ports, m.tcpProbes, m.tcpOpenConnections, m.tcpClosedConnections, m.tcpConnectionDurationMs, + registerer.MustRegister(m.buildInfo, m.accessKeys, m.listeners, m.tcpProbes, m.tcpOpenConnections, m.tcpClosedConnections, m.tcpConnectionDurationMs, m.dataBytes, m.dataBytesPerLocation, m.timeToCipherMs, m.udpPacketsFromClientPerLocation, m.udpAddedNatEntries, m.udpRemovedNatEntries, m.tunnelTimeCollector) return m @@ -275,9 +275,9 @@ func (m *outlineMetrics) SetBuildInfo(version string) { m.buildInfo.WithLabelValues(version).Set(1) } -func (m *outlineMetrics) SetNumAccessKeys(numKeys int, ports int) { +func (m *outlineMetrics) SetNumAccessKeys(numKeys int, listeners int) { m.accessKeys.Set(float64(numKeys)) - m.ports.Set(float64(ports)) + m.listeners.Set(float64(listeners)) } func (m *outlineMetrics) AddOpenTCPConnection(clientInfo ipinfo.IPInfo) { diff --git a/go.mod b/go.mod index 04a9ddab..5c1419d2 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/stretchr/testify v1.8.4 golang.org/x/crypto v0.17.0 golang.org/x/term v0.16.0 - gopkg.in/yaml.v2 v2.4.0 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -263,7 +263,7 @@ require ( gopkg.in/src-d/go-billy.v4 v4.3.2 // indirect gopkg.in/src-d/go-git.v4 v4.13.1 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect k8s.io/klog/v2 v2.90.0 // indirect mvdan.cc/sh/v3 v3.7.0 // indirect sigs.k8s.io/kind v0.17.0 // indirect diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index f98319f4..2bfb0dfa 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -293,7 +293,7 @@ func TestUDPEcho(t *testing.T) { proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { - proxy.Handle(proxyConn) + proxy.Handle(context.Background(), proxyConn) done <- struct{}{} }() @@ -525,7 +525,7 @@ func BenchmarkUDPEcho(b *testing.B) { proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { - proxy.Handle(server) + proxy.Handle(context.Background(), server) done <- struct{}{} }() @@ -569,7 +569,7 @@ func BenchmarkUDPManyKeys(b *testing.B) { proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { - proxy.Handle(proxyConn) + proxy.Handle(context.Background(), proxyConn) done <- struct{}{} }() diff --git a/service/cipher_list.go b/service/cipher_list.go index 3b6f1957..d84ab1ac 100644 --- a/service/cipher_list.go +++ b/service/cipher_list.go @@ -55,6 +55,7 @@ func MakeCipherEntry(id string, cryptoKey *shadowsocks.EncryptionKey, secret str // CipherList is a thread-safe collection of CipherEntry elements that allows for // snapshotting and moving to front. type CipherList interface { + Len() int // Returns a snapshot of the cipher list optimized for this client IP SnapshotForClientIP(clientIP netip.Addr) []*list.Element MarkUsedByClientIP(e *list.Element, clientIP netip.Addr) @@ -62,6 +63,8 @@ type CipherList interface { // which is a List of *CipherEntry. Update takes ownership of `contents`, // which must not be read or written after this call. Update(contents *list.List) + // PushBack inserts a new cipher at the back of the list. + PushBack(entry *CipherEntry) *list.Element } type cipherList struct { @@ -75,6 +78,10 @@ func NewCipherList() CipherList { return &cipherList{list: list.New()} } +func (cl *cipherList) Len() int { + return cl.list.Len() +} + func matchesIP(e *list.Element, clientIP netip.Addr) bool { c := e.Value.(*CipherEntry) return clientIP != netip.Addr{} && clientIP == c.lastClientIP @@ -116,3 +123,7 @@ func (cl *cipherList) Update(src *list.List) { cl.list = src cl.mu.Unlock() } + +func (cl *cipherList) PushBack(entry *CipherEntry) *list.Element { + return cl.list.PushBack(entry) +} diff --git a/service/cipher_list_testing.go b/service/cipher_list_testing.go index a77427ed..d8532f79 100644 --- a/service/cipher_list_testing.go +++ b/service/cipher_list_testing.go @@ -15,7 +15,6 @@ package service import ( - "container/list" "fmt" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" @@ -24,7 +23,7 @@ import ( // MakeTestCiphers creates a CipherList containing one fresh AEAD cipher // for each secret in `secrets`. func MakeTestCiphers(secrets []string) (CipherList, error) { - l := list.New() + cipherList := NewCipherList() for i := 0; i < len(secrets); i++ { cipherID := fmt.Sprintf("id-%v", i) cipher, err := shadowsocks.NewEncryptionKey(shadowsocks.CHACHA20IETFPOLY1305, secrets[i]) @@ -32,10 +31,8 @@ func MakeTestCiphers(secrets []string) (CipherList, error) { return nil, fmt.Errorf("failed to create cipher %v: %w", i, err) } entry := MakeCipherEntry(cipherID, cipher, secrets[i]) - l.PushBack(&entry) + cipherList.PushBack(&entry) } - cipherList := NewCipherList() - cipherList.Update(l) return cipherList, nil } diff --git a/service/tcp.go b/service/tcp.go index 2195480b..ced85a54 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -236,7 +236,7 @@ func StreamServe(accept StreamListener, handle StreamHandler) { if errors.Is(err, net.ErrClosed) { break } - logger.Warningf("AcceptTCP failed: %v. Continuing to listen.", err) + logger.Warningf("Accept failed: %v. Continuing to listen.", err) continue } diff --git a/service/udp.go b/service/udp.go index 4830e302..859c6c44 100644 --- a/service/udp.go +++ b/service/udp.go @@ -15,6 +15,7 @@ package service import ( + "context" "errors" "fmt" "net" @@ -101,7 +102,7 @@ type PacketHandler interface { // SetTargetIPValidator sets the function to be used to validate the target IP addresses. SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) // Handle returns after clientConn closes and all the sub goroutines return. - Handle(clientConn net.PacketConn) + Handle(ctx context.Context, clientConn net.PacketConn) } func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) { @@ -110,7 +111,7 @@ func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPVali // Listen on addr for encrypted packets and basically do UDP NAT. // We take the ciphers as a pointer because it gets replaced on config updates. -func (h *packetHandler) Handle(clientConn net.PacketConn) { +func (h *packetHandler) Handle(ctx context.Context, clientConn net.PacketConn) { var running sync.WaitGroup nm := newNATmap(h.natTimeout, h.m, &running) diff --git a/service/udp_test.go b/service/udp_test.go index f94238c5..90d880b5 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -16,6 +16,7 @@ package service import ( "bytes" + "context" "errors" "net" "net/netip" @@ -132,7 +133,7 @@ func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTest handler.SetTargetIPValidator(validator) done := make(chan struct{}) go func() { - handler.Handle(clientConn) + handler.Handle(context.Background(), clientConn) done <- struct{}{} }() @@ -488,7 +489,7 @@ func TestUDPEarlyClose(t *testing.T) { } require.Nil(t, clientConn.Close()) // This should return quickly without timing out. - s.Handle(clientConn) + s.Handle(context.Background(), clientConn) } // Makes sure the UDP listener returns [io.ErrClosed] on reads and writes after Close(). From ae7f41d8a0632c406f7902bd81b81639d588fc16 Mon Sep 17 00:00:00 2001 From: sbruens Date: Thu, 11 Jul 2024 16:50:34 -0400 Subject: [PATCH 063/119] Update comments. --- cmd/outline-ss-server/main.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index be94cbdf..65d8acf9 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -111,9 +111,9 @@ func (s *SSServer) loadConfig(filename string) error { return fmt.Errorf("failed to load config (%v): %w", filename, err) } - // We hot swap the services by having them both live at the same time. This - // means we create services for the new config first, and then take down the - // services from the old config. + // We hot swap the config by having the old and new listeners both live at + // the same time. This means we create listeners for the new config first, + // and then close the old ones after. oldListenerSet := s.lnSet s.lnSet = s.lnManager.NewListenerSet() var totalCipherCount int @@ -152,7 +152,7 @@ func (s *SSServer) loadConfig(filename string) error { logger.Infof("Loaded %d access keys over %d listeners", totalCipherCount, s.lnSet.Len()) s.m.SetNumAccessKeys(totalCipherCount, s.lnSet.Len()) - // Take down the old services now that the new ones are created and serving. + // Take down the old listeners now that the new ones are created and serving. if oldListenerSet != nil { if err := oldListenerSet.Close(); err != nil { logger.Errorf("Failed to stop old listeners: %w", err) @@ -163,7 +163,7 @@ func (s *SSServer) loadConfig(filename string) error { return nil } -// Stop serving on all existing services. +// Stop serving on all existing listeners. func (s *SSServer) Stop() error { if s.lnSet == nil { return nil From 120db8e43f086e5fb3a1de168002d92bd50651ba Mon Sep 17 00:00:00 2001 From: sbruens Date: Thu, 11 Jul 2024 16:51:34 -0400 Subject: [PATCH 064/119] `go mod tidy`. --- go.mod | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index 5c1419d2..04a9ddab 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/stretchr/testify v1.8.4 golang.org/x/crypto v0.17.0 golang.org/x/term v0.16.0 - gopkg.in/yaml.v3 v3.0.1 + gopkg.in/yaml.v2 v2.4.0 ) require ( @@ -263,7 +263,7 @@ require ( gopkg.in/src-d/go-billy.v4 v4.3.2 // indirect gopkg.in/src-d/go-git.v4 v4.13.1 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect - gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect k8s.io/klog/v2 v2.90.0 // indirect mvdan.cc/sh/v3 v3.7.0 // indirect sigs.k8s.io/kind v0.17.0 // indirect From d705603e1a70a15fd3234fc496610071633b6f5a Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 16 Jul 2024 16:50:36 -0400 Subject: [PATCH 065/119] refactor: don't link the TCP handler to a specific listener --- cmd/outline-ss-server/main.go | 2 +- internal/integration_test/integration_test.go | 8 ++++---- service/tcp.go | 7 +++---- service/tcp_test.go | 16 ++++++++-------- 4 files changed, 16 insertions(+), 17 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 87860df7..e73506a8 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -87,7 +87,7 @@ func (s *SSServer) startPort(portNum int) error { port := &ssPort{tcpListener: listener, packetConn: packetConn, cipherList: service.NewCipherList()} authFunc := service.NewShadowsocksStreamAuthenticator(port.cipherList, &s.replayCache, s.m) // TODO: Register initial data metrics at zero. - tcpHandler := service.NewTCPHandler(listener.Addr().String(), authFunc, s.m, tcpReadTimeout) + tcpHandler := service.NewTCPHandler(authFunc, s.m, tcpReadTimeout) packetHandler := service.NewPacketHandler(s.natTimeout, port.cipherList, s.m) s.ports[portNum] = port go service.StreamServe(service.WrapStreamListener(listener.AcceptTCP), tcpHandler.Handle) diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index f98319f4..43109b7a 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -133,7 +133,7 @@ func TestTCPEcho(t *testing.T) { const testTimeout = 200 * time.Millisecond testMetrics := &service.NoOpTCPMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := service.NewTCPHandler(proxyListener.Addr().String(), authFunc, testMetrics, testTimeout) + handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { @@ -202,7 +202,7 @@ func TestRestrictedAddresses(t *testing.T) { const testTimeout = 200 * time.Millisecond testMetrics := &statusMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := service.NewTCPHandler(proxyListener.Addr().String(), authFunc, testMetrics, testTimeout) + handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout) done := make(chan struct{}) go func() { service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) @@ -384,7 +384,7 @@ func BenchmarkTCPThroughput(b *testing.B) { const testTimeout = 200 * time.Millisecond testMetrics := &service.NoOpTCPMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := service.NewTCPHandler(proxyListener.Addr().String(), authFunc, testMetrics, testTimeout) + handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { @@ -448,7 +448,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) { const testTimeout = 200 * time.Millisecond testMetrics := &service.NoOpTCPMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := service.NewTCPHandler(proxyListener.Addr().String(), authFunc, testMetrics, testTimeout) + handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { diff --git a/service/tcp.go b/service/tcp.go index 2195480b..bd459e20 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -170,9 +170,8 @@ type tcpHandler struct { } // NewTCPService creates a TCPService -func NewTCPHandler(listenerId string, authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration) TCPHandler { +func NewTCPHandler(authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration) TCPHandler { return &tcpHandler{ - listenerId: listenerId, m: m, readTimeout: timeout, authenticate: authenticate, @@ -370,12 +369,12 @@ func (h *tcpHandler) handleConnection(ctx context.Context, outerConn transport.S // Keep the connection open until we hit the authentication deadline to protect against probing attacks // `proxyMetrics` is a pointer because its value is being mutated by `clientConn`. -func (h *tcpHandler) absorbProbe(clientConn io.ReadCloser, status string, proxyMetrics *metrics.ProxyMetrics) { +func (h *tcpHandler) absorbProbe(clientConn transport.StreamConn, status string, proxyMetrics *metrics.ProxyMetrics) { // This line updates proxyMetrics.ClientProxy before it's used in AddTCPProbe. _, drainErr := io.Copy(io.Discard, clientConn) // drain socket drainResult := drainErrToString(drainErr) logger.Debugf("Drain error: %v, drain result: %v", drainErr, drainResult) - h.m.AddTCPProbe(status, drainResult, h.listenerId, proxyMetrics.ClientProxy) + h.m.AddTCPProbe(status, drainResult, clientConn.LocalAddr().String(), proxyMetrics.ClientProxy) } func drainErrToString(drainErr error) string { diff --git a/service/tcp_test.go b/service/tcp_test.go index 5c3bc9df..fbe80f7c 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -281,7 +281,7 @@ func TestProbeRandom(t *testing.T) { require.NoError(t, err, "MakeTestCiphers failed: %v", err) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, 200*time.Millisecond) + handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) done := make(chan struct{}) go func() { StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) @@ -358,7 +358,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, 200*time.Millisecond) + handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { @@ -393,7 +393,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, 200*time.Millisecond) + handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { @@ -429,7 +429,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, 200*time.Millisecond) + handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { @@ -472,7 +472,7 @@ func TestProbeServerBytesModified(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, 200*time.Millisecond) + handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) done := make(chan struct{}) go func() { StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) @@ -503,7 +503,7 @@ func TestReplayDefense(t *testing.T) { testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, testTimeout) + handler := NewTCPHandler(authFunc, testMetrics, testTimeout) snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) cipherEntry := snapshot[0].Value.(*CipherEntry) cipher := cipherEntry.CryptoKey @@ -582,7 +582,7 @@ func TestReverseReplayDefense(t *testing.T) { testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, testTimeout) + handler := NewTCPHandler(authFunc, testMetrics, testTimeout) snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) cipherEntry := snapshot[0].Value.(*CipherEntry) cipher := cipherEntry.CryptoKey @@ -653,7 +653,7 @@ func probeExpectTimeout(t *testing.T, payloadSize int) { require.NoError(t, err, "MakeTestCiphers failed: %v", err) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, testTimeout) + handler := NewTCPHandler(authFunc, testMetrics, testTimeout) done := make(chan struct{}) go func() { From d2ef46efbea2fc50a89191d1ea064327cd1f8fb3 Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 16 Jul 2024 18:12:26 -0400 Subject: [PATCH 066/119] Protect new cipher handling methods with mutex. --- service/cipher_list.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/service/cipher_list.go b/service/cipher_list.go index d84ab1ac..beda57bc 100644 --- a/service/cipher_list.go +++ b/service/cipher_list.go @@ -79,6 +79,8 @@ func NewCipherList() CipherList { } func (cl *cipherList) Len() int { + cl.mu.Lock() + defer cl.mu.Unlock() return cl.list.Len() } @@ -125,5 +127,7 @@ func (cl *cipherList) Update(src *list.List) { } func (cl *cipherList) PushBack(entry *CipherEntry) *list.Element { + cl.mu.Lock() + defer cl.mu.Unlock() return cl.list.PushBack(entry) } From ab07400909c5ceeacbdeb2515b2953f0d74fb3b5 Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 16 Jul 2024 18:04:53 -0400 Subject: [PATCH 067/119] Move `listeners.go` under `/service`. --- cmd/outline-ss-server/main.go | 14 ++++---------- {cmd/outline-ss-server => service}/listeners.go | 11 ++++++++--- 2 files changed, 12 insertions(+), 13 deletions(-) rename {cmd/outline-ss-server => service}/listeners.go (97%) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 849b72a0..6e0d0e90 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -60,12 +60,6 @@ func init() { logger = logging.MustGetLogger("") } -type Handler interface { - NumCiphers() int - AddCipher(entry *service.CipherEntry) - Handle(ctx context.Context, conn any) -} - type connHandler struct { tcpTimeout time.Duration natTimeout time.Duration @@ -98,8 +92,8 @@ func (h *connHandler) Handle(ctx context.Context, conn any) { } type SSServer struct { - lnManager ListenerManager - lnSet ListenerSet + lnManager service.ListenerManager + lnSet service.ListenerSet natTimeout time.Duration m *outlineMetrics replayCache service.ReplayCache @@ -118,7 +112,7 @@ func (s *SSServer) loadConfig(filename string) error { s.lnSet = s.lnManager.NewListenerSet() var totalCipherCount int - portHandlers := make(map[int]Handler) + portHandlers := make(map[int]service.Handler) for _, legacyKeyServiceConfig := range config.Keys { handler, ok := portHandlers[legacyKeyServiceConfig.Port] if !ok { @@ -178,7 +172,7 @@ func (s *SSServer) Stop() error { // RunSSServer starts a shadowsocks server running, and returns the server or an error. func RunSSServer(filename string, natTimeout time.Duration, sm *outlineMetrics, replayHistory int) (*SSServer, error) { server := &SSServer{ - lnManager: NewListenerManager(), + lnManager: service.NewListenerManager(), natTimeout: natTimeout, m: sm, replayCache: service.NewReplayCache(replayHistory), diff --git a/cmd/outline-ss-server/listeners.go b/service/listeners.go similarity index 97% rename from cmd/outline-ss-server/listeners.go rename to service/listeners.go index f45339b1..14615d9f 100644 --- a/cmd/outline-ss-server/listeners.go +++ b/service/listeners.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package service import ( "context" @@ -22,9 +22,14 @@ import ( "sync/atomic" "github.com/Jigsaw-Code/outline-sdk/transport" - "github.com/Jigsaw-Code/outline-ss-server/service" ) +type Handler interface { + NumCiphers() int + AddCipher(entry *CipherEntry) + Handle(ctx context.Context, conn any) +} + type acceptResponse struct { conn net.Conn err error @@ -55,7 +60,7 @@ func (sl *sharedListener) SetHandler(handler Handler) { handle := func(ctx context.Context, conn transport.StreamConn) { handler.Handle(ctx, conn) } - go service.StreamServe(accept, handle) + go StreamServe(accept, handle) } // Accept accepts connections until Close() is called. From 71d7140258031d36b4d724779bb7e8f06ebf8938 Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 16 Jul 2024 18:19:47 -0400 Subject: [PATCH 068/119] Use callback instead of passing in key and manager. --- service/listeners.go | 47 ++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 14615d9f..58169362 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -40,13 +40,12 @@ type SharedListener interface { } type sharedListener struct { - listener net.Listener - manager ListenerManager - key string - closed atomic.Int32 - usage *atomic.Int32 - acceptCh chan acceptResponse - closeCh chan struct{} + listener net.Listener + closed atomic.Int32 + usage *atomic.Int32 + acceptCh chan acceptResponse + closeCh chan struct{} + closeFunc func() } func (sl *sharedListener) SetHandler(handler Handler) { @@ -87,7 +86,7 @@ func (sl *sharedListener) Close() error { // See if we need to actually close the underlying listener. if sl.usage.Add(-1) == 0 { - sl.manager.Delete(sl.key) + sl.closeFunc() err := sl.listener.Close() if err != nil { return err @@ -104,17 +103,16 @@ func (sl *sharedListener) Addr() net.Addr { type sharedPacketConn struct { net.PacketConn - manager ListenerManager - key string - closed atomic.Int32 - usage *atomic.Int32 + closed atomic.Int32 + usage *atomic.Int32 + closeFunc func() } func (spc *sharedPacketConn) Close() error { if spc.closed.CompareAndSwap(0, 1) { // See if we need to actually close the underlying listener. if spc.usage.Add(-1) == 0 { - spc.manager.Delete(spc.key) + spc.closeFunc() err := spc.PacketConn.Close() if err != nil { return err @@ -186,7 +184,6 @@ func (ls *listenerSet) Len() int { type ListenerManager interface { NewListenerSet() ListenerSet Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (SharedListener, error) - Delete(key string) } type listenerManager struct { @@ -226,11 +223,12 @@ func (m *listenerManager) Listen(ctx context.Context, network string, addr strin lnGlobal.usage.Add(1) return &sharedListener{ listener: lnGlobal.ln, - manager: m, - key: lnKey, usage: &lnGlobal.usage, acceptCh: lnGlobal.acceptCh, closeCh: make(chan struct{}), + closeFunc: func() { + m.delete(lnKey) + }, }, nil } @@ -251,11 +249,12 @@ func (m *listenerManager) Listen(ctx context.Context, network string, addr strin return &sharedListener{ listener: ln, - manager: m, - key: lnKey, usage: &lnGlobal.usage, acceptCh: lnGlobal.acceptCh, closeCh: make(chan struct{}), + closeFunc: func() { + m.delete(lnKey) + }, }, nil case "udp": @@ -266,9 +265,10 @@ func (m *listenerManager) Listen(ctx context.Context, network string, addr strin lnGlobal.usage.Add(1) return &sharedPacketConn{ PacketConn: lnGlobal.pc, - manager: m, - key: lnKey, usage: &lnGlobal.usage, + closeFunc: func() { + m.delete(lnKey) + }, }, nil } @@ -283,9 +283,10 @@ func (m *listenerManager) Listen(ctx context.Context, network string, addr strin return &sharedPacketConn{ PacketConn: pc, - manager: m, - key: lnKey, usage: &lnGlobal.usage, + closeFunc: func() { + m.delete(lnKey) + }, }, nil default: @@ -294,7 +295,7 @@ func (m *listenerManager) Listen(ctx context.Context, network string, addr strin } } -func (m *listenerManager) Delete(key string) { +func (m *listenerManager) delete(key string) { m.listenersMu.Lock() delete(m.listeners, key) m.listenersMu.Unlock() From 9dfa4e269bd3700963347e32b99ef92048dafd47 Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 16 Jul 2024 18:02:19 -0400 Subject: [PATCH 069/119] Move config start into a go routine for easier cleanup. --- cmd/outline-ss-server/main.go | 177 ++++++++--------- internal/integration_test/integration_test.go | 6 +- service/listeners.go | 184 ++++++++---------- service/tcp.go | 24 +-- service/udp.go | 5 +- service/udp_test.go | 5 +- 6 files changed, 193 insertions(+), 208 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 6e0d0e90..13400c11 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -15,7 +15,6 @@ package main import ( - "context" "flag" "fmt" "net" @@ -60,40 +59,9 @@ func init() { logger = logging.MustGetLogger("") } -type connHandler struct { - tcpTimeout time.Duration - natTimeout time.Duration - replayCache *service.ReplayCache - m *outlineMetrics - ciphers service.CipherList -} - -func (h *connHandler) NumCiphers() int { - return h.ciphers.Len() -} - -func (h *connHandler) AddCipher(entry *service.CipherEntry) { - h.ciphers.PushBack(entry) -} - -func (h *connHandler) Handle(ctx context.Context, conn any) { - switch c := conn.(type) { - case transport.StreamConn: - authFunc := service.NewShadowsocksStreamAuthenticator(h.ciphers, h.replayCache, h.m) - // TODO: Register initial data metrics at zero. - tcpHandler := service.NewTCPHandler(authFunc, h.m, h.tcpTimeout) - tcpHandler.Handle(ctx, c) - case net.PacketConn: - packetHandler := service.NewPacketHandler(h.natTimeout, h.ciphers, h.m) - packetHandler.Handle(ctx, c) - default: - logger.Errorf("unknown connection type: %v", c) - } -} - type SSServer struct { + stopConfig func() lnManager service.ListenerManager - lnSet service.ListenerSet natTimeout time.Duration m *outlineMetrics replayCache service.ReplayCache @@ -104,74 +72,109 @@ func (s *SSServer) loadConfig(filename string) error { if err != nil { return fmt.Errorf("failed to load config (%v): %w", filename, err) } - // We hot swap the config by having the old and new listeners both live at // the same time. This means we create listeners for the new config first, // and then close the old ones after. - oldListenerSet := s.lnSet - s.lnSet = s.lnManager.NewListenerSet() - var totalCipherCount int - - portHandlers := make(map[int]service.Handler) - for _, legacyKeyServiceConfig := range config.Keys { - handler, ok := portHandlers[legacyKeyServiceConfig.Port] - if !ok { - handler = &connHandler{ - ciphers: service.NewCipherList(), - tcpTimeout: tcpReadTimeout, - natTimeout: s.natTimeout, - m: s.m, - replayCache: &s.replayCache, - } - portHandlers[legacyKeyServiceConfig.Port] = handler - } - cryptoKey, err := shadowsocks.NewEncryptionKey(legacyKeyServiceConfig.Cipher, legacyKeyServiceConfig.Secret) - if err != nil { - return fmt.Errorf("failed to create encyption key for key %v: %w", legacyKeyServiceConfig.ID, err) - } - entry := service.MakeCipherEntry(legacyKeyServiceConfig.ID, cryptoKey, legacyKeyServiceConfig.Secret) - handler.AddCipher(&entry) - } - for portNum, handler := range portHandlers { - totalCipherCount += handler.NumCiphers() - for _, network := range []string{"tcp", "udp"} { - addr := net.JoinHostPort("::", strconv.Itoa(portNum)) - listener, err := s.lnSet.Listen(context.TODO(), network, addr, net.ListenConfig{KeepAlive: 0}) - if err != nil { - return fmt.Errorf("%s service failed to start listening on address %s: %w", network, addr, err) - } - listener.SetHandler(handler) - } + stopConfig, err := s.runConfig(*config) + if err != nil { + return err } - logger.Infof("Loaded %d access keys over %d listeners", totalCipherCount, s.lnSet.Len()) - s.m.SetNumAccessKeys(totalCipherCount, s.lnSet.Len()) + s.stopConfig() + s.stopConfig = stopConfig + return nil +} - // Take down the old listeners now that the new ones are created and serving. - if oldListenerSet != nil { - if err := oldListenerSet.Close(); err != nil { - logger.Errorf("Failed to stop old listeners: %w", err) - } - logger.Infof("Stopped %d old listeners", s.lnSet.Len()) - } +func (s *SSServer) NewShadowsocksStreamHandler(ciphers service.CipherList) service.StreamHandler { + authFunc := service.NewShadowsocksStreamAuthenticator(ciphers, &s.replayCache, s.m) + // TODO: Register initial data metrics at zero. + return service.NewStreamHandler(authFunc, s.m, tcpReadTimeout) +} - return nil +func (s *SSServer) NewShadowsocksPacketHandler(ciphers service.CipherList) service.PacketHandler { + return service.NewPacketHandler(s.natTimeout, ciphers, s.m) } -// Stop serving on all existing listeners. -func (s *SSServer) Stop() error { - if s.lnSet == nil { - return nil - } - if err := s.lnSet.Close(); err != nil { - logger.Errorf("Failed to stop all listeners: %w", err) +func (s *SSServer) runConfig(config Config) (func(), error) { + startErrCh := make(chan error) + stopCh := make(chan struct{}) + + go func() { + startErrCh <- func() error { + lnSet := s.lnManager.NewListenerSet() + defer lnSet.Close() + + var totalCipherCount int + + portCiphers := make(map[int]service.CipherList) + for _, keyConfig := range config.Keys { + ciphers, ok := portCiphers[keyConfig.Port] + if !ok { + ciphers = service.NewCipherList() + portCiphers[keyConfig.Port] = ciphers + } + cryptoKey, err := shadowsocks.NewEncryptionKey(keyConfig.Cipher, keyConfig.Secret) + if err != nil { + return fmt.Errorf("failed to create encyption key for key %v: %w", keyConfig.ID, err) + } + entry := service.MakeCipherEntry(keyConfig.ID, cryptoKey, keyConfig.Secret) + ciphers.PushBack(&entry) + } + for portNum, ciphers := range portCiphers { + addr := net.JoinHostPort("::", strconv.Itoa(portNum)) + + sh := s.NewShadowsocksStreamHandler(ciphers) + ln, err := lnSet.Listen("tcp", addr) + if err != nil { + return err + } + logger.Infof("Shadowsocks TCP service listening on %v", ln.Addr().String()) + accept := func() (transport.StreamConn, error) { + c, err := ln.Accept() + if err == nil { + return c.(transport.StreamConn), err + } + return nil, err + } + go service.StreamServe(accept, sh.Handle) + + pc, err := lnSet.ListenPacket("udp", addr) + if err != nil { + return err + } + logger.Infof("Shadowsocks UDP service listening on %v", pc.LocalAddr().String()) + ph := s.NewShadowsocksPacketHandler(ciphers) + go ph.Handle(pc) + + totalCipherCount += ciphers.Len() + } + logger.Infof("Loaded %d access keys over %d listeners", totalCipherCount, lnSet.Len()) + s.m.SetNumAccessKeys(totalCipherCount, lnSet.Len()) + return nil + }() + + <-stopCh + }() + + err := <-startErrCh + if err != nil { + return nil, err } - logger.Infof("Stopped %d listeners", s.lnSet.Len()) - return nil + return func() { + logger.Infof("Stopping running config.") + stopCh <- struct{}{} + }, nil +} + +// Stop serving the current config. +func (s *SSServer) Stop() { + s.stopConfig() + logger.Info("Stopped all listeners for running config") } // RunSSServer starts a shadowsocks server running, and returns the server or an error. func RunSSServer(filename string, natTimeout time.Duration, sm *outlineMetrics, replayHistory int) (*SSServer, error) { server := &SSServer{ + stopConfig: func() {}, lnManager: service.NewListenerManager(), natTimeout: natTimeout, m: sm, diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index db1b82d5..43109b7a 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -293,7 +293,7 @@ func TestUDPEcho(t *testing.T) { proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { - proxy.Handle(context.Background(), proxyConn) + proxy.Handle(proxyConn) done <- struct{}{} }() @@ -525,7 +525,7 @@ func BenchmarkUDPEcho(b *testing.B) { proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { - proxy.Handle(context.Background(), server) + proxy.Handle(server) done <- struct{}{} }() @@ -569,7 +569,7 @@ func BenchmarkUDPManyKeys(b *testing.B) { proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { - proxy.Handle(context.Background(), proxyConn) + proxy.Handle(proxyConn) done <- struct{}{} }() diff --git a/service/listeners.go b/service/listeners.go index 58169362..355d36b3 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -15,30 +15,21 @@ package service import ( - "context" "fmt" "net" "sync" "sync/atomic" - - "github.com/Jigsaw-Code/outline-sdk/transport" ) -type Handler interface { - NumCiphers() int - AddCipher(entry *CipherEntry) - Handle(ctx context.Context, conn any) -} +// The implementations of listeners for different network types are not +// interchangeable. The type of listener depends on the network type. +type Listener = any type acceptResponse struct { conn net.Conn err error } -type SharedListener interface { - SetHandler(handler Handler) -} - type sharedListener struct { listener net.Listener closed atomic.Int32 @@ -48,20 +39,6 @@ type sharedListener struct { closeFunc func() } -func (sl *sharedListener) SetHandler(handler Handler) { - accept := func() (transport.StreamConn, error) { - c, err := sl.Accept() - if err == nil { - return c.(transport.StreamConn), err - } - return nil, err - } - handle := func(ctx context.Context, conn transport.StreamConn) { - handler.Handle(ctx, conn) - } - go StreamServe(accept, handle) -} - // Accept accepts connections until Close() is called. func (sl *sharedListener) Accept() (net.Conn, error) { if sl.closed.Load() == 1 { @@ -123,10 +100,6 @@ func (spc *sharedPacketConn) Close() error { return nil } -func (spc *sharedPacketConn) SetHandler(handler Handler) { - go handler.Handle(context.TODO(), spc.PacketConn) -} - type globalListener struct { ln net.Listener pc net.PacketConn @@ -135,32 +108,46 @@ type globalListener struct { } type ListenerSet interface { - Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (SharedListener, error) + Listen(network string, addr string) (net.Listener, error) + ListenPacket(network string, addr string) (net.PacketConn, error) Close() error Len() int } type listenerSet struct { manager ListenerManager - listeners map[string]*SharedListener + listeners map[string]Listener } -func (ls *listenerSet) Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (SharedListener, error) { +func (ls *listenerSet) Listen(network string, addr string) (net.Listener, error) { lnKey := listenerKey(network, addr) if _, exists := ls.listeners[lnKey]; exists { return nil, fmt.Errorf("listener %s already exists", lnKey) } - ln, err := ls.manager.Listen(ctx, network, addr, config) + ln, err := ls.manager.Listen(network, addr) if err != nil { return nil, err } - ls.listeners[lnKey] = &ln + ls.listeners[lnKey] = ln + return ln, nil +} + +func (ls *listenerSet) ListenPacket(network string, addr string) (net.PacketConn, error) { + lnKey := listenerKey(network, addr) + if _, exists := ls.listeners[lnKey]; exists { + return nil, fmt.Errorf("listener %s already exists", lnKey) + } + ln, err := ls.manager.ListenPacket(network, addr) + if err != nil { + return nil, err + } + ls.listeners[lnKey] = ln return ln, nil } func (ls *listenerSet) Close() error { for _, listener := range ls.listeners { - switch ln := (*listener).(type) { + switch ln := listener.(type) { case net.Listener: if err := ln.Close(); err != nil { return fmt.Errorf("%s listener on address %s failed to stop: %w", ln.Addr().Network(), ln.Addr().String(), err) @@ -183,7 +170,8 @@ func (ls *listenerSet) Len() int { // ListenerManager holds and manages the state of shared listeners. type ListenerManager interface { NewListenerSet() ListenerSet - Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (SharedListener, error) + Listen(network string, addr string) (net.Listener, error) + ListenPacket(network string, addr string) (net.PacketConn, error) } type listenerManager struct { @@ -200,55 +188,25 @@ func NewListenerManager() ListenerManager { func (m *listenerManager) NewListenerSet() ListenerSet { return &listenerSet{ manager: m, - listeners: make(map[string]*SharedListener), + listeners: make(map[string]Listener), } } -// Listen creates a new listener for a given network and address. +// ListenStream creates a new stream listener for a given network and address. // // Listeners can overlap one another, because during config changes the new // config is started before the old config is destroyed. This is done by using // reusable listener wrappers, which do not actually close the underlying socket // until all uses of the shared listener have been closed. -func (m *listenerManager) Listen(ctx context.Context, network string, addr string, config net.ListenConfig) (SharedListener, error) { - lnKey := listenerKey(network, addr) - - switch network { - - case "tcp": - m.listenersMu.Lock() - defer m.listenersMu.Unlock() - - if lnGlobal, ok := m.listeners[lnKey]; ok { - lnGlobal.usage.Add(1) - return &sharedListener{ - listener: lnGlobal.ln, - usage: &lnGlobal.usage, - acceptCh: lnGlobal.acceptCh, - closeCh: make(chan struct{}), - closeFunc: func() { - m.delete(lnKey) - }, - }, nil - } - - ln, err := config.Listen(ctx, network, addr) - if err != nil { - return nil, err - } - - lnGlobal := &globalListener{ln: ln, acceptCh: make(chan acceptResponse)} - go func() { - for { - conn, err := lnGlobal.ln.Accept() - lnGlobal.acceptCh <- acceptResponse{conn, err} - } - }() - lnGlobal.usage.Store(1) - m.listeners[lnKey] = lnGlobal +func (m *listenerManager) Listen(network string, addr string) (net.Listener, error) { + m.listenersMu.Lock() + defer m.listenersMu.Unlock() + lnKey := listenerKey(network, addr) + if lnGlobal, ok := m.listeners[lnKey]; ok { + lnGlobal.usage.Add(1) return &sharedListener{ - listener: ln, + listener: lnGlobal.ln, usage: &lnGlobal.usage, acceptCh: lnGlobal.acceptCh, closeCh: make(chan struct{}), @@ -256,43 +214,69 @@ func (m *listenerManager) Listen(ctx context.Context, network string, addr strin m.delete(lnKey) }, }, nil + } - case "udp": - m.listenersMu.Lock() - defer m.listenersMu.Unlock() - - if lnGlobal, ok := m.listeners[lnKey]; ok { - lnGlobal.usage.Add(1) - return &sharedPacketConn{ - PacketConn: lnGlobal.pc, - usage: &lnGlobal.usage, - closeFunc: func() { - m.delete(lnKey) - }, - }, nil - } + ln, err := net.Listen(network, addr) + if err != nil { + return nil, err + } - pc, err := config.ListenPacket(ctx, network, addr) - if err != nil { - return nil, err + lnGlobal := &globalListener{ln: ln, acceptCh: make(chan acceptResponse)} + go func() { + for { + conn, err := lnGlobal.ln.Accept() + lnGlobal.acceptCh <- acceptResponse{conn, err} } + }() + lnGlobal.usage.Store(1) + m.listeners[lnKey] = lnGlobal + + return &sharedListener{ + listener: ln, + usage: &lnGlobal.usage, + acceptCh: lnGlobal.acceptCh, + closeCh: make(chan struct{}), + closeFunc: func() { + m.delete(lnKey) + }, + }, nil +} - lnGlobal := &globalListener{pc: pc} - lnGlobal.usage.Store(1) - m.listeners[lnKey] = lnGlobal +// ListenPacket creates a new packet listener for a given network and address. +// +// See notes on [ListenStream]. +func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketConn, error) { + m.listenersMu.Lock() + defer m.listenersMu.Unlock() + lnKey := listenerKey(network, addr) + if lnGlobal, ok := m.listeners[lnKey]; ok { + lnGlobal.usage.Add(1) return &sharedPacketConn{ - PacketConn: pc, + PacketConn: lnGlobal.pc, usage: &lnGlobal.usage, closeFunc: func() { m.delete(lnKey) }, }, nil + } - default: - return nil, fmt.Errorf("unsupported network: %s", network) - + pc, err := net.ListenPacket(network, addr) + if err != nil { + return nil, err } + + lnGlobal := &globalListener{pc: pc} + lnGlobal.usage.Store(1) + m.listeners[lnKey] = lnGlobal + + return &sharedPacketConn{ + PacketConn: pc, + usage: &lnGlobal.usage, + closeFunc: func() { + m.delete(lnKey) + }, + }, nil } func (m *listenerManager) delete(key string) { diff --git a/service/tcp.go b/service/tcp.go index 10c2a4f2..d88f537a 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -161,7 +161,7 @@ func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCa } } -type tcpHandler struct { +type streamHandler struct { listenerId string m TCPMetrics readTimeout time.Duration @@ -169,9 +169,9 @@ type tcpHandler struct { dialer transport.StreamDialer } -// NewTCPService creates a TCPService -func NewTCPHandler(authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration) TCPHandler { - return &tcpHandler{ +// NewStreamHandler creates a StreamHandler +func NewStreamHandler(authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration) StreamHandler { + return &streamHandler{ m: m, readTimeout: timeout, authenticate: authenticate, @@ -188,14 +188,14 @@ func makeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator) tra }}} } -// TCPService is a Shadowsocks TCP service that can be started and stopped. -type TCPHandler interface { +// StreamHandler is a handler that handles stream connections. +type StreamHandler interface { Handle(ctx context.Context, conn transport.StreamConn) // SetTargetDialer sets the [transport.StreamDialer] to be used to connect to target addresses. SetTargetDialer(dialer transport.StreamDialer) } -func (s *tcpHandler) SetTargetDialer(dialer transport.StreamDialer) { +func (s *streamHandler) SetTargetDialer(dialer transport.StreamDialer) { s.dialer = dialer } @@ -219,12 +219,12 @@ func WrapStreamListener[T transport.StreamConn](f func() (T, error)) StreamListe } } -type StreamHandler func(ctx context.Context, conn transport.StreamConn) +type StreamHandleFunc func(ctx context.Context, conn transport.StreamConn) // StreamServe repeatedly calls `accept` to obtain connections and `handle` to handle them until // accept() returns [ErrClosed]. When that happens, all connection handlers will be notified // via their [context.Context]. StreamServe will return after all pending handlers return. -func StreamServe(accept StreamListener, handle StreamHandler) { +func StreamServe(accept StreamListener, handle StreamHandleFunc) { var running sync.WaitGroup defer running.Wait() ctx, contextCancel := context.WithCancel(context.Background()) @@ -253,7 +253,7 @@ func StreamServe(accept StreamListener, handle StreamHandler) { } } -func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn) { +func (h *streamHandler) Handle(ctx context.Context, clientConn transport.StreamConn) { clientInfo, err := ipinfo.GetIPInfoFromAddr(h.m, clientConn.RemoteAddr()) if err != nil { logger.Warningf("Failed client info lookup: %v", err) @@ -327,7 +327,7 @@ func proxyConnection(ctx context.Context, dialer transport.StreamDialer, tgtAddr return nil } -func (h *tcpHandler) handleConnection(ctx context.Context, outerConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) { +func (h *streamHandler) handleConnection(ctx context.Context, outerConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) { // Set a deadline to receive the address to the target. readDeadline := time.Now().Add(h.readTimeout) if deadline, ok := ctx.Deadline(); ok { @@ -369,7 +369,7 @@ func (h *tcpHandler) handleConnection(ctx context.Context, outerConn transport.S // Keep the connection open until we hit the authentication deadline to protect against probing attacks // `proxyMetrics` is a pointer because its value is being mutated by `clientConn`. -func (h *tcpHandler) absorbProbe(clientConn transport.StreamConn, status string, proxyMetrics *metrics.ProxyMetrics) { +func (h *streamHandler) absorbProbe(clientConn transport.StreamConn, status string, proxyMetrics *metrics.ProxyMetrics) { // This line updates proxyMetrics.ClientProxy before it's used in AddTCPProbe. _, drainErr := io.Copy(io.Discard, clientConn) // drain socket drainResult := drainErrToString(drainErr) diff --git a/service/udp.go b/service/udp.go index 859c6c44..4830e302 100644 --- a/service/udp.go +++ b/service/udp.go @@ -15,7 +15,6 @@ package service import ( - "context" "errors" "fmt" "net" @@ -102,7 +101,7 @@ type PacketHandler interface { // SetTargetIPValidator sets the function to be used to validate the target IP addresses. SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) // Handle returns after clientConn closes and all the sub goroutines return. - Handle(ctx context.Context, clientConn net.PacketConn) + Handle(clientConn net.PacketConn) } func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) { @@ -111,7 +110,7 @@ func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPVali // Listen on addr for encrypted packets and basically do UDP NAT. // We take the ciphers as a pointer because it gets replaced on config updates. -func (h *packetHandler) Handle(ctx context.Context, clientConn net.PacketConn) { +func (h *packetHandler) Handle(clientConn net.PacketConn) { var running sync.WaitGroup nm := newNATmap(h.natTimeout, h.m, &running) diff --git a/service/udp_test.go b/service/udp_test.go index 90d880b5..f94238c5 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -16,7 +16,6 @@ package service import ( "bytes" - "context" "errors" "net" "net/netip" @@ -133,7 +132,7 @@ func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTest handler.SetTargetIPValidator(validator) done := make(chan struct{}) go func() { - handler.Handle(context.Background(), clientConn) + handler.Handle(clientConn) done <- struct{}{} }() @@ -489,7 +488,7 @@ func TestUDPEarlyClose(t *testing.T) { } require.Nil(t, clientConn.Close()) // This should return quickly without timing out. - s.Handle(context.Background(), clientConn) + s.Handle(clientConn) } // Makes sure the UDP listener returns [io.ErrClosed] on reads and writes after Close(). From 0a63f5c01fc2e366b7d89197313faaccf8e22ab0 Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 19 Jul 2024 13:06:26 -0400 Subject: [PATCH 070/119] Make a `StreamListener` type. --- cmd/outline-ss-server/main.go | 14 +--- internal/integration_test/integration_test.go | 14 ++-- service/listeners.go | 66 ++++++++++++++----- service/tcp.go | 12 ++-- service/tcp_test.go | 36 +++++----- 5 files changed, 84 insertions(+), 58 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 13400c11..6180106e 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -26,7 +26,6 @@ import ( "syscall" "time" - "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" "github.com/Jigsaw-Code/outline-ss-server/ipinfo" "github.com/Jigsaw-Code/outline-ss-server/service" @@ -123,21 +122,14 @@ func (s *SSServer) runConfig(config Config) (func(), error) { addr := net.JoinHostPort("::", strconv.Itoa(portNum)) sh := s.NewShadowsocksStreamHandler(ciphers) - ln, err := lnSet.Listen("tcp", addr) + ln, err := lnSet.ListenStream(addr) if err != nil { return err } logger.Infof("Shadowsocks TCP service listening on %v", ln.Addr().String()) - accept := func() (transport.StreamConn, error) { - c, err := ln.Accept() - if err == nil { - return c.(transport.StreamConn), err - } - return nil, err - } - go service.StreamServe(accept, sh.Handle) + go service.StreamServe(ln.AcceptStream, sh.Handle) - pc, err := lnSet.ListenPacket("udp", addr) + pc, err := lnSet.ListenPacket(addr) if err != nil { return err } diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index 43109b7a..f72d835b 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -133,7 +133,7 @@ func TestTCPEcho(t *testing.T) { const testTimeout = 200 * time.Millisecond testMetrics := &service.NoOpTCPMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout) + handler := service.NewStreamHandler(authFunc, testMetrics, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { @@ -202,10 +202,10 @@ func TestRestrictedAddresses(t *testing.T) { const testTimeout = 200 * time.Millisecond testMetrics := &statusMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout) + handler := service.NewStreamHandler(authFunc, testMetrics, testTimeout) done := make(chan struct{}) go func() { - service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) + service.StreamServe(service.WrapStreamAcceptFunc(proxyListener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -384,11 +384,11 @@ func BenchmarkTCPThroughput(b *testing.B) { const testTimeout = 200 * time.Millisecond testMetrics := &service.NoOpTCPMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout) + handler := service.NewStreamHandler(authFunc, testMetrics, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { - service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) + service.StreamServe(service.WrapStreamAcceptFunc(proxyListener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -448,11 +448,11 @@ func BenchmarkTCPMultiplexing(b *testing.B) { const testTimeout = 200 * time.Millisecond testMetrics := &service.NoOpTCPMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout) + handler := service.NewStreamHandler(authFunc, testMetrics, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { - service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) + service.StreamServe(service.WrapStreamAcceptFunc(proxyListener.AcceptTCP), handler.Handle) done <- struct{}{} }() diff --git a/service/listeners.go b/service/listeners.go index 355d36b3..f36ecfd9 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -16,22 +16,37 @@ package service import ( "fmt" + "io" "net" "sync" "sync/atomic" + + "github.com/Jigsaw-Code/outline-sdk/transport" ) // The implementations of listeners for different network types are not // interchangeable. The type of listener depends on the network type. -type Listener = any +type Listener = io.Closer + +type StreamListener interface { + // Accept waits for and returns the next connection to the listener. + AcceptStream() (transport.StreamConn, error) + + // Close closes the listener. + // Any blocked Accept operations will be unblocked and return errors. + Close() error + + // Addr returns the listener's network address. + Addr() net.Addr +} type acceptResponse struct { - conn net.Conn + conn transport.StreamConn err error } type sharedListener struct { - listener net.Listener + listener net.TCPListener closed atomic.Int32 usage *atomic.Int32 acceptCh chan acceptResponse @@ -40,7 +55,7 @@ type sharedListener struct { } // Accept accepts connections until Close() is called. -func (sl *sharedListener) Accept() (net.Conn, error) { +func (sl *sharedListener) AcceptStream() (transport.StreamConn, error) { if sl.closed.Load() == 1 { return nil, net.ErrClosed } @@ -101,16 +116,24 @@ func (spc *sharedPacketConn) Close() error { } type globalListener struct { - ln net.Listener + ln net.TCPListener pc net.PacketConn usage atomic.Int32 acceptCh chan acceptResponse } +// ListenerSet represents a set of listeners listening on unique addresses. Trying +// to listen on the same address twice will result in an error. The set can be +// closed as a unit, which is useful if you want to bring down a group of +// listeners, such as when reloading a new config. type ListenerSet interface { - Listen(network string, addr string) (net.Listener, error) - ListenPacket(network string, addr string) (net.PacketConn, error) + // ListenStream announces on a given TCP network address. + ListenStream(addr string) (StreamListener, error) + // ListenStream announces on a given UDP network address. + ListenPacket(addr string) (net.PacketConn, error) + // Close closes all the listeners in the set. Close() error + // Len returns the number of listeners in the set. Len() int } @@ -119,12 +142,14 @@ type listenerSet struct { listeners map[string]Listener } -func (ls *listenerSet) Listen(network string, addr string) (net.Listener, error) { +// ListenStream announces on a given TCP network address. +func (ls *listenerSet) ListenStream(addr string) (StreamListener, error) { + network := "tcp" lnKey := listenerKey(network, addr) if _, exists := ls.listeners[lnKey]; exists { return nil, fmt.Errorf("listener %s already exists", lnKey) } - ln, err := ls.manager.Listen(network, addr) + ln, err := ls.manager.ListenStream(network, addr) if err != nil { return nil, err } @@ -132,7 +157,9 @@ func (ls *listenerSet) Listen(network string, addr string) (net.Listener, error) return ln, nil } -func (ls *listenerSet) ListenPacket(network string, addr string) (net.PacketConn, error) { +// ListenPacket announces on a given UDP network address. +func (ls *listenerSet) ListenPacket(addr string) (net.PacketConn, error) { + network := "udp" lnKey := listenerKey(network, addr) if _, exists := ls.listeners[lnKey]; exists { return nil, fmt.Errorf("listener %s already exists", lnKey) @@ -145,6 +172,7 @@ func (ls *listenerSet) ListenPacket(network string, addr string) (net.PacketConn return ln, nil } +// Close closes all the listeners in the set. func (ls *listenerSet) Close() error { for _, listener := range ls.listeners { switch ln := listener.(type) { @@ -163,6 +191,7 @@ func (ls *listenerSet) Close() error { return nil } +// Len returns the number of listeners in the set. func (ls *listenerSet) Len() int { return len(ls.listeners) } @@ -170,7 +199,7 @@ func (ls *listenerSet) Len() int { // ListenerManager holds and manages the state of shared listeners. type ListenerManager interface { NewListenerSet() ListenerSet - Listen(network string, addr string) (net.Listener, error) + ListenStream(network string, addr string) (StreamListener, error) ListenPacket(network string, addr string) (net.PacketConn, error) } @@ -179,6 +208,7 @@ type listenerManager struct { listenersMu sync.Mutex } +// NewListenerManager creates a new [ListenerManger]. func NewListenerManager() ListenerManager { return &listenerManager{ listeners: make(map[string]*globalListener), @@ -198,7 +228,7 @@ func (m *listenerManager) NewListenerSet() ListenerSet { // config is started before the old config is destroyed. This is done by using // reusable listener wrappers, which do not actually close the underlying socket // until all uses of the shared listener have been closed. -func (m *listenerManager) Listen(network string, addr string) (net.Listener, error) { +func (m *listenerManager) ListenStream(network string, addr string) (StreamListener, error) { m.listenersMu.Lock() defer m.listenersMu.Unlock() @@ -216,15 +246,19 @@ func (m *listenerManager) Listen(network string, addr string) (net.Listener, err }, nil } - ln, err := net.Listen(network, addr) + tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + ln, err := net.ListenTCP(network, tcpAddr) if err != nil { return nil, err } - lnGlobal := &globalListener{ln: ln, acceptCh: make(chan acceptResponse)} + lnGlobal := &globalListener{ln: *ln, acceptCh: make(chan acceptResponse)} go func() { for { - conn, err := lnGlobal.ln.Accept() + conn, err := lnGlobal.ln.AcceptTCP() lnGlobal.acceptCh <- acceptResponse{conn, err} } }() @@ -232,7 +266,7 @@ func (m *listenerManager) Listen(network string, addr string) (net.Listener, err m.listeners[lnKey] = lnGlobal return &sharedListener{ - listener: ln, + listener: lnGlobal.ln, usage: &lnGlobal.usage, acceptCh: lnGlobal.acceptCh, closeCh: make(chan struct{}), diff --git a/service/tcp.go b/service/tcp.go index d88f537a..8138a24c 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -211,9 +211,9 @@ func ensureConnectionError(err error, fallbackStatus string, fallbackMsg string) } } -type StreamListener func() (transport.StreamConn, error) +type StreamAcceptFunc func() (transport.StreamConn, error) -func WrapStreamListener[T transport.StreamConn](f func() (T, error)) StreamListener { +func WrapStreamAcceptFunc[T transport.StreamConn](f func() (T, error)) StreamAcceptFunc { return func() (transport.StreamConn, error) { return f() } @@ -224,7 +224,7 @@ type StreamHandleFunc func(ctx context.Context, conn transport.StreamConn) // StreamServe repeatedly calls `accept` to obtain connections and `handle` to handle them until // accept() returns [ErrClosed]. When that happens, all connection handlers will be notified // via their [context.Context]. StreamServe will return after all pending handlers return. -func StreamServe(accept StreamListener, handle StreamHandleFunc) { +func StreamServe(accept StreamAcceptFunc, handle StreamHandleFunc) { var running sync.WaitGroup defer running.Wait() ctx, contextCancel := context.WithCancel(context.Background()) @@ -341,7 +341,7 @@ func (h *streamHandler) handleConnection(ctx context.Context, outerConn transpor id, innerConn, authErr := h.authenticate(outerConn) if authErr != nil { // Drain to protect against probing attacks. - h.absorbProbe(outerConn, authErr.Status, proxyMetrics) + h.absorbProbe(outerConn, outerConn.LocalAddr().String(), authErr.Status, proxyMetrics) return id, authErr } h.m.AddAuthenticatedTCPConnection(outerConn.RemoteAddr(), id) @@ -369,12 +369,12 @@ func (h *streamHandler) handleConnection(ctx context.Context, outerConn transpor // Keep the connection open until we hit the authentication deadline to protect against probing attacks // `proxyMetrics` is a pointer because its value is being mutated by `clientConn`. -func (h *streamHandler) absorbProbe(clientConn transport.StreamConn, status string, proxyMetrics *metrics.ProxyMetrics) { +func (h *streamHandler) absorbProbe(clientConn io.ReadCloser, addr, status string, proxyMetrics *metrics.ProxyMetrics) { // This line updates proxyMetrics.ClientProxy before it's used in AddTCPProbe. _, drainErr := io.Copy(io.Discard, clientConn) // drain socket drainResult := drainErrToString(drainErr) logger.Debugf("Drain error: %v, drain result: %v", drainErr, drainResult) - h.m.AddTCPProbe(status, drainResult, clientConn.LocalAddr().String(), proxyMetrics.ClientProxy) + h.m.AddTCPProbe(status, drainResult, addr, proxyMetrics.ClientProxy) } func drainErrToString(drainErr error) string { diff --git a/service/tcp_test.go b/service/tcp_test.go index fbe80f7c..428f70e0 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -281,10 +281,10 @@ func TestProbeRandom(t *testing.T) { require.NoError(t, err, "MakeTestCiphers failed: %v", err) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) + handler := NewStreamHandler(authFunc, testMetrics, 200*time.Millisecond) done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAcceptFunc(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -358,11 +358,11 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) + handler := NewStreamHandler(authFunc, testMetrics, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAcceptFunc(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -393,11 +393,11 @@ func TestProbeClientBytesBasicModified(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) + handler := NewStreamHandler(authFunc, testMetrics, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAcceptFunc(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -429,11 +429,11 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) + handler := NewStreamHandler(authFunc, testMetrics, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAcceptFunc(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -472,10 +472,10 @@ func TestProbeServerBytesModified(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) + handler := NewStreamHandler(authFunc, testMetrics, 200*time.Millisecond) done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAcceptFunc(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -503,7 +503,7 @@ func TestReplayDefense(t *testing.T) { testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := NewTCPHandler(authFunc, testMetrics, testTimeout) + handler := NewStreamHandler(authFunc, testMetrics, testTimeout) snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) cipherEntry := snapshot[0].Value.(*CipherEntry) cipher := cipherEntry.CryptoKey @@ -528,7 +528,7 @@ func TestReplayDefense(t *testing.T) { done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAcceptFunc(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -582,7 +582,7 @@ func TestReverseReplayDefense(t *testing.T) { testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := NewTCPHandler(authFunc, testMetrics, testTimeout) + handler := NewStreamHandler(authFunc, testMetrics, testTimeout) snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) cipherEntry := snapshot[0].Value.(*CipherEntry) cipher := cipherEntry.CryptoKey @@ -598,7 +598,7 @@ func TestReverseReplayDefense(t *testing.T) { done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAcceptFunc(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -653,11 +653,11 @@ func probeExpectTimeout(t *testing.T, payloadSize int) { require.NoError(t, err, "MakeTestCiphers failed: %v", err) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(authFunc, testMetrics, testTimeout) + handler := NewStreamHandler(authFunc, testMetrics, testTimeout) done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAcceptFunc(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -717,14 +717,14 @@ func TestStreamServeEarlyClose(t *testing.T) { err = tcpListener.Close() require.NoError(t, err) // This should return quickly, without timing out or calling the handler. - StreamServe(WrapStreamListener(tcpListener.AcceptTCP), nil) + StreamServe(WrapStreamAcceptFunc(tcpListener.AcceptTCP), nil) } // Makes sure the TCP listener returns [io.ErrClosed] on Close(). func TestClosedTCPListenerError(t *testing.T) { tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) require.NoError(t, err) - accept := WrapStreamListener(tcpListener.AcceptTCP) + accept := WrapStreamAcceptFunc(tcpListener.AcceptTCP) err = tcpListener.Close() require.NoError(t, err) _, err = accept() From f018d175a73f54d4d56cad98f43d51529be30d0a Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 19 Jul 2024 13:35:20 -0400 Subject: [PATCH 071/119] Rename `closeFunc` to `onCloseFunc`. --- service/listeners.go | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index f36ecfd9..3a3d3a83 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -46,12 +46,12 @@ type acceptResponse struct { } type sharedListener struct { - listener net.TCPListener - closed atomic.Int32 - usage *atomic.Int32 - acceptCh chan acceptResponse - closeCh chan struct{} - closeFunc func() + listener net.TCPListener + closed atomic.Int32 + usage *atomic.Int32 + acceptCh chan acceptResponse + closeCh chan struct{} + onCloseFunc func() } // Accept accepts connections until Close() is called. @@ -78,7 +78,7 @@ func (sl *sharedListener) Close() error { // See if we need to actually close the underlying listener. if sl.usage.Add(-1) == 0 { - sl.closeFunc() + sl.onCloseFunc() err := sl.listener.Close() if err != nil { return err @@ -95,16 +95,16 @@ func (sl *sharedListener) Addr() net.Addr { type sharedPacketConn struct { net.PacketConn - closed atomic.Int32 - usage *atomic.Int32 - closeFunc func() + closed atomic.Int32 + usage *atomic.Int32 + onCloseFunc func() } func (spc *sharedPacketConn) Close() error { if spc.closed.CompareAndSwap(0, 1) { // See if we need to actually close the underlying listener. if spc.usage.Add(-1) == 0 { - spc.closeFunc() + spc.onCloseFunc() err := spc.PacketConn.Close() if err != nil { return err @@ -240,7 +240,7 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe usage: &lnGlobal.usage, acceptCh: lnGlobal.acceptCh, closeCh: make(chan struct{}), - closeFunc: func() { + onCloseFunc: func() { m.delete(lnKey) }, }, nil @@ -270,7 +270,7 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe usage: &lnGlobal.usage, acceptCh: lnGlobal.acceptCh, closeCh: make(chan struct{}), - closeFunc: func() { + onCloseFunc: func() { m.delete(lnKey) }, }, nil @@ -289,7 +289,7 @@ func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketC return &sharedPacketConn{ PacketConn: lnGlobal.pc, usage: &lnGlobal.usage, - closeFunc: func() { + onCloseFunc: func() { m.delete(lnKey) }, }, nil @@ -307,7 +307,7 @@ func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketC return &sharedPacketConn{ PacketConn: pc, usage: &lnGlobal.usage, - closeFunc: func() { + onCloseFunc: func() { m.delete(lnKey) }, }, nil From 4295c45f3133d7715bdfb5efba92de2852ba1be8 Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 19 Jul 2024 13:40:09 -0400 Subject: [PATCH 072/119] Rename `globalListener`. --- service/listeners.go | 48 ++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 3a3d3a83..ce635217 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -115,7 +115,7 @@ func (spc *sharedPacketConn) Close() error { return nil } -type globalListener struct { +type concreteListener struct { ln net.TCPListener pc net.PacketConn usage atomic.Int32 @@ -204,14 +204,14 @@ type ListenerManager interface { } type listenerManager struct { - listeners map[string]*globalListener + listeners map[string]*concreteListener listenersMu sync.Mutex } // NewListenerManager creates a new [ListenerManger]. func NewListenerManager() ListenerManager { return &listenerManager{ - listeners: make(map[string]*globalListener), + listeners: make(map[string]*concreteListener), } } @@ -233,12 +233,12 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe defer m.listenersMu.Unlock() lnKey := listenerKey(network, addr) - if lnGlobal, ok := m.listeners[lnKey]; ok { - lnGlobal.usage.Add(1) + if lnConcrete, ok := m.listeners[lnKey]; ok { + lnConcrete.usage.Add(1) return &sharedListener{ - listener: lnGlobal.ln, - usage: &lnGlobal.usage, - acceptCh: lnGlobal.acceptCh, + listener: lnConcrete.ln, + usage: &lnConcrete.usage, + acceptCh: lnConcrete.acceptCh, closeCh: make(chan struct{}), onCloseFunc: func() { m.delete(lnKey) @@ -255,20 +255,20 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe return nil, err } - lnGlobal := &globalListener{ln: *ln, acceptCh: make(chan acceptResponse)} + lnConcrete := &concreteListener{ln: *ln, acceptCh: make(chan acceptResponse)} go func() { for { - conn, err := lnGlobal.ln.AcceptTCP() - lnGlobal.acceptCh <- acceptResponse{conn, err} + conn, err := lnConcrete.ln.AcceptTCP() + lnConcrete.acceptCh <- acceptResponse{conn, err} } }() - lnGlobal.usage.Store(1) - m.listeners[lnKey] = lnGlobal + lnConcrete.usage.Store(1) + m.listeners[lnKey] = lnConcrete return &sharedListener{ - listener: lnGlobal.ln, - usage: &lnGlobal.usage, - acceptCh: lnGlobal.acceptCh, + listener: lnConcrete.ln, + usage: &lnConcrete.usage, + acceptCh: lnConcrete.acceptCh, closeCh: make(chan struct{}), onCloseFunc: func() { m.delete(lnKey) @@ -284,11 +284,11 @@ func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketC defer m.listenersMu.Unlock() lnKey := listenerKey(network, addr) - if lnGlobal, ok := m.listeners[lnKey]; ok { - lnGlobal.usage.Add(1) + if lnConcrete, ok := m.listeners[lnKey]; ok { + lnConcrete.usage.Add(1) return &sharedPacketConn{ - PacketConn: lnGlobal.pc, - usage: &lnGlobal.usage, + PacketConn: lnConcrete.pc, + usage: &lnConcrete.usage, onCloseFunc: func() { m.delete(lnKey) }, @@ -300,13 +300,13 @@ func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketC return nil, err } - lnGlobal := &globalListener{pc: pc} - lnGlobal.usage.Store(1) - m.listeners[lnKey] = lnGlobal + lnConcrete := &concreteListener{pc: pc} + lnConcrete.usage.Store(1) + m.listeners[lnKey] = lnConcrete return &sharedPacketConn{ PacketConn: pc, - usage: &lnGlobal.usage, + usage: &lnConcrete.usage, onCloseFunc: func() { m.delete(lnKey) }, From e6963f62e2b2e3fbe1f2ca7b3f2b8500917e262c Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 19 Jul 2024 14:45:09 -0400 Subject: [PATCH 073/119] Don't track usage in the shared listeners. --- service/listeners.go | 98 +++++++++++++++++++++++--------------------- 1 file changed, 51 insertions(+), 47 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index ce635217..b7a56aca 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -15,6 +15,7 @@ package service import ( + "errors" "fmt" "io" "net" @@ -46,19 +47,16 @@ type acceptResponse struct { } type sharedListener struct { - listener net.TCPListener - closed atomic.Int32 - usage *atomic.Int32 + listener net.TCPListener + once sync.Once + acceptCh chan acceptResponse closeCh chan struct{} - onCloseFunc func() + onCloseFunc func() error } // Accept accepts connections until Close() is called. func (sl *sharedListener) AcceptStream() (transport.StreamConn, error) { - if sl.closed.Load() == 1 { - return nil, net.ErrClosed - } select { case acceptResponse := <-sl.acceptCh: if acceptResponse.err != nil { @@ -73,20 +71,13 @@ func (sl *sharedListener) AcceptStream() (transport.StreamConn, error) { // Close stops accepting new connections without closing the underlying socket. // Only when the last user closes it, we actually close it. func (sl *sharedListener) Close() error { - if sl.closed.CompareAndSwap(0, 1) { - close(sl.closeCh) - - // See if we need to actually close the underlying listener. - if sl.usage.Add(-1) == 0 { - sl.onCloseFunc() - err := sl.listener.Close() - if err != nil { - return err - } - } - } + var err error + sl.once.Do(func() { - return nil + close(sl.closeCh) + err = sl.onCloseFunc() + }) + return err } func (sl *sharedListener) Addr() net.Addr { @@ -95,33 +86,43 @@ func (sl *sharedListener) Addr() net.Addr { type sharedPacketConn struct { net.PacketConn - closed atomic.Int32 - usage *atomic.Int32 - onCloseFunc func() + once sync.Once + onCloseFunc func() error } func (spc *sharedPacketConn) Close() error { - if spc.closed.CompareAndSwap(0, 1) { - // See if we need to actually close the underlying listener. - if spc.usage.Add(-1) == 0 { - spc.onCloseFunc() - err := spc.PacketConn.Close() - if err != nil { - return err - } - } - } - - return nil + var err error + spc.once.Do(func() { + err = spc.onCloseFunc() + }) + return err } type concreteListener struct { - ln net.TCPListener + ln *net.TCPListener pc net.PacketConn usage atomic.Int32 acceptCh chan acceptResponse } +func (cl *concreteListener) Close() error { + if cl.usage.Add(-1) == 0 { + if cl.ln != nil { + err := cl.ln.Close() + if err != nil { + return err + } + } + if cl.pc != nil { + err := cl.pc.Close() + if err != nil { + return err + } + } + } + return nil +} + // ListenerSet represents a set of listeners listening on unique addresses. Trying // to listen on the same address twice will result in an error. The set can be // closed as a unit, which is useful if you want to bring down a group of @@ -236,12 +237,12 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe if lnConcrete, ok := m.listeners[lnKey]; ok { lnConcrete.usage.Add(1) return &sharedListener{ - listener: lnConcrete.ln, - usage: &lnConcrete.usage, + listener: *lnConcrete.ln, acceptCh: lnConcrete.acceptCh, closeCh: make(chan struct{}), - onCloseFunc: func() { + onCloseFunc: func() error { m.delete(lnKey) + return lnConcrete.Close() }, }, nil } @@ -255,10 +256,13 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe return nil, err } - lnConcrete := &concreteListener{ln: *ln, acceptCh: make(chan acceptResponse)} + lnConcrete := &concreteListener{ln: ln, acceptCh: make(chan acceptResponse)} go func() { for { conn, err := lnConcrete.ln.AcceptTCP() + if errors.Is(err, net.ErrClosed) { + return + } lnConcrete.acceptCh <- acceptResponse{conn, err} } }() @@ -266,12 +270,12 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe m.listeners[lnKey] = lnConcrete return &sharedListener{ - listener: lnConcrete.ln, - usage: &lnConcrete.usage, + listener: *lnConcrete.ln, acceptCh: lnConcrete.acceptCh, closeCh: make(chan struct{}), - onCloseFunc: func() { + onCloseFunc: func() error { m.delete(lnKey) + return lnConcrete.Close() }, }, nil } @@ -288,9 +292,9 @@ func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketC lnConcrete.usage.Add(1) return &sharedPacketConn{ PacketConn: lnConcrete.pc, - usage: &lnConcrete.usage, - onCloseFunc: func() { + onCloseFunc: func() error { m.delete(lnKey) + return lnConcrete.Close() }, }, nil } @@ -306,9 +310,9 @@ func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketC return &sharedPacketConn{ PacketConn: pc, - usage: &lnConcrete.usage, - onCloseFunc: func() { + onCloseFunc: func() error { m.delete(lnKey) + return lnConcrete.Close() }, }, nil } From 7113f02144e696c8d0581d56f3305a4be8951fe8 Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 19 Jul 2024 14:52:02 -0400 Subject: [PATCH 074/119] Add `getAddr()` to avoid some duplicate code. --- service/listeners.go | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index b7a56aca..f109cd6b 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -73,7 +73,6 @@ func (sl *sharedListener) AcceptStream() (transport.StreamConn, error) { func (sl *sharedListener) Close() error { var err error sl.once.Do(func() { - close(sl.closeCh) err = sl.onCloseFunc() }) @@ -176,17 +175,12 @@ func (ls *listenerSet) ListenPacket(addr string) (net.PacketConn, error) { // Close closes all the listeners in the set. func (ls *listenerSet) Close() error { for _, listener := range ls.listeners { - switch ln := listener.(type) { - case net.Listener: - if err := ln.Close(); err != nil { - return fmt.Errorf("%s listener on address %s failed to stop: %w", ln.Addr().Network(), ln.Addr().String(), err) - } - case net.PacketConn: - if err := ln.Close(); err != nil { - return fmt.Errorf("%s listener on address %s failed to stop: %w", ln.LocalAddr().Network(), ln.LocalAddr().String(), err) - } - default: - return fmt.Errorf("unknown listener type: %v", ln) + addr, err := getAddr(listener) + if err != nil { + return err + } + if err := listener.Close(); err != nil { + return fmt.Errorf("%s listener on address %s failed to stop: %w", addr.Network(), addr.String(), err) } } return nil @@ -326,3 +320,14 @@ func (m *listenerManager) delete(key string) { func listenerKey(network string, addr string) string { return network + "/" + addr } + +func getAddr(listener Listener) (net.Addr, error) { + switch ln := listener.(type) { + case net.Listener: + return ln.Addr(), nil + case net.PacketConn: + return ln.LocalAddr(), nil + default: + return nil, fmt.Errorf("unknown listener type: %v", ln) + } +} From e4d679f05a4befb1f3e1bcde622fb56f13fcd5e7 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 22 Jul 2024 11:54:42 -0400 Subject: [PATCH 075/119] Move listener set creation out of the inner function. --- cmd/outline-ss-server/main.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 6180106e..b97f14a5 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -98,10 +98,10 @@ func (s *SSServer) runConfig(config Config) (func(), error) { stopCh := make(chan struct{}) go func() { - startErrCh <- func() error { - lnSet := s.lnManager.NewListenerSet() - defer lnSet.Close() + lnSet := s.lnManager.NewListenerSet() + defer lnSet.Close() + startErrCh <- func() error { var totalCipherCount int portCiphers := make(map[int]service.CipherList) From be5f9b0ab0f4ded9d472bf8f30451bf18eb7632c Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 22 Jul 2024 12:34:00 -0400 Subject: [PATCH 076/119] Remove `PushBack()` from `CipherList`. --- cmd/outline-ss-server/main.go | 24 ++++++++++++------------ service/cipher_list.go | 15 --------------- service/cipher_list_testing.go | 7 +++++-- 3 files changed, 17 insertions(+), 29 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index b97f14a5..5040f25c 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -15,6 +15,7 @@ package main import ( + "container/list" "flag" "fmt" "net" @@ -102,25 +103,26 @@ func (s *SSServer) runConfig(config Config) (func(), error) { defer lnSet.Close() startErrCh <- func() error { - var totalCipherCount int - - portCiphers := make(map[int]service.CipherList) + portCiphers := make(map[int]*list.List) // Values are *List of *CipherEntry. for _, keyConfig := range config.Keys { - ciphers, ok := portCiphers[keyConfig.Port] + cipherList, ok := portCiphers[keyConfig.Port] if !ok { - ciphers = service.NewCipherList() - portCiphers[keyConfig.Port] = ciphers + cipherList = list.New() + portCiphers[keyConfig.Port] = cipherList } cryptoKey, err := shadowsocks.NewEncryptionKey(keyConfig.Cipher, keyConfig.Secret) if err != nil { return fmt.Errorf("failed to create encyption key for key %v: %w", keyConfig.ID, err) } entry := service.MakeCipherEntry(keyConfig.ID, cryptoKey, keyConfig.Secret) - ciphers.PushBack(&entry) + cipherList.PushBack(&entry) } - for portNum, ciphers := range portCiphers { + for portNum, cipherList := range portCiphers { addr := net.JoinHostPort("::", strconv.Itoa(portNum)) + ciphers := service.NewCipherList() + ciphers.Update(cipherList) + sh := s.NewShadowsocksStreamHandler(ciphers) ln, err := lnSet.ListenStream(addr) if err != nil { @@ -136,11 +138,9 @@ func (s *SSServer) runConfig(config Config) (func(), error) { logger.Infof("Shadowsocks UDP service listening on %v", pc.LocalAddr().String()) ph := s.NewShadowsocksPacketHandler(ciphers) go ph.Handle(pc) - - totalCipherCount += ciphers.Len() } - logger.Infof("Loaded %d access keys over %d listeners", totalCipherCount, lnSet.Len()) - s.m.SetNumAccessKeys(totalCipherCount, lnSet.Len()) + logger.Infof("Loaded %d access keys over %d listeners", len(config.Keys), lnSet.Len()) + s.m.SetNumAccessKeys(len(config.Keys), lnSet.Len()) return nil }() diff --git a/service/cipher_list.go b/service/cipher_list.go index beda57bc..3b6f1957 100644 --- a/service/cipher_list.go +++ b/service/cipher_list.go @@ -55,7 +55,6 @@ func MakeCipherEntry(id string, cryptoKey *shadowsocks.EncryptionKey, secret str // CipherList is a thread-safe collection of CipherEntry elements that allows for // snapshotting and moving to front. type CipherList interface { - Len() int // Returns a snapshot of the cipher list optimized for this client IP SnapshotForClientIP(clientIP netip.Addr) []*list.Element MarkUsedByClientIP(e *list.Element, clientIP netip.Addr) @@ -63,8 +62,6 @@ type CipherList interface { // which is a List of *CipherEntry. Update takes ownership of `contents`, // which must not be read or written after this call. Update(contents *list.List) - // PushBack inserts a new cipher at the back of the list. - PushBack(entry *CipherEntry) *list.Element } type cipherList struct { @@ -78,12 +75,6 @@ func NewCipherList() CipherList { return &cipherList{list: list.New()} } -func (cl *cipherList) Len() int { - cl.mu.Lock() - defer cl.mu.Unlock() - return cl.list.Len() -} - func matchesIP(e *list.Element, clientIP netip.Addr) bool { c := e.Value.(*CipherEntry) return clientIP != netip.Addr{} && clientIP == c.lastClientIP @@ -125,9 +116,3 @@ func (cl *cipherList) Update(src *list.List) { cl.list = src cl.mu.Unlock() } - -func (cl *cipherList) PushBack(entry *CipherEntry) *list.Element { - cl.mu.Lock() - defer cl.mu.Unlock() - return cl.list.PushBack(entry) -} diff --git a/service/cipher_list_testing.go b/service/cipher_list_testing.go index d8532f79..a77427ed 100644 --- a/service/cipher_list_testing.go +++ b/service/cipher_list_testing.go @@ -15,6 +15,7 @@ package service import ( + "container/list" "fmt" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" @@ -23,7 +24,7 @@ import ( // MakeTestCiphers creates a CipherList containing one fresh AEAD cipher // for each secret in `secrets`. func MakeTestCiphers(secrets []string) (CipherList, error) { - cipherList := NewCipherList() + l := list.New() for i := 0; i < len(secrets); i++ { cipherID := fmt.Sprintf("id-%v", i) cipher, err := shadowsocks.NewEncryptionKey(shadowsocks.CHACHA20IETFPOLY1305, secrets[i]) @@ -31,8 +32,10 @@ func MakeTestCiphers(secrets []string) (CipherList, error) { return nil, fmt.Errorf("failed to create cipher %v: %w", i, err) } entry := MakeCipherEntry(cipherID, cipher, secrets[i]) - cipherList.PushBack(&entry) + l.PushBack(&entry) } + cipherList := NewCipherList() + cipherList.Update(l) return cipherList, nil } From 343e4120fafb7b33fe843a98aeb0d165a3e5ef16 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 22 Jul 2024 12:47:31 -0400 Subject: [PATCH 077/119] Move listener set to `main.go`. --- cmd/outline-ss-server/main.go | 63 ++++++++++++++++++++++++- service/listeners.go | 89 ----------------------------------- 2 files changed, 62 insertions(+), 90 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 5040f25c..90ff14fa 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -24,6 +24,7 @@ import ( "os/signal" "strconv" "strings" + "sync" "syscall" "time" @@ -94,12 +95,72 @@ func (s *SSServer) NewShadowsocksPacketHandler(ciphers service.CipherList) servi return service.NewPacketHandler(s.natTimeout, ciphers, s.m) } +type listenerSet struct { + manager service.ListenerManager + listeners map[string]service.Listener + listenersMu sync.Mutex +} + +// ListenStream announces on a given TCP network address. Trying to listen on +// the same address twice will result in an error. +func (ls *listenerSet) ListenStream(addr string) (service.StreamListener, error) { + ls.listenersMu.Lock() + defer ls.listenersMu.Unlock() + + lnKey := "tcp/" + addr + if _, exists := ls.listeners[lnKey]; exists { + return nil, fmt.Errorf("listener %s already exists", lnKey) + } + ln, err := ls.manager.ListenStream("tcp", addr) + if err != nil { + return nil, err + } + ls.listeners[lnKey] = ln + return ln, nil +} + +// ListenPacket announces on a given UDP network address. Trying to listen on +// the same address twice will result in an error. +func (ls *listenerSet) ListenPacket(addr string) (net.PacketConn, error) { + ls.listenersMu.Lock() + defer ls.listenersMu.Unlock() + + lnKey := "udp/" + addr + if _, exists := ls.listeners[lnKey]; exists { + return nil, fmt.Errorf("listener %s already exists", lnKey) + } + ln, err := ls.manager.ListenPacket("udp", addr) + if err != nil { + return nil, err + } + ls.listeners[lnKey] = ln + return ln, nil +} + +// Close closes all the listeners in the set. +func (ls *listenerSet) Close() error { + for addr, listener := range ls.listeners { + if err := listener.Close(); err != nil { + return fmt.Errorf("listener on address %s failed to stop: %w", addr, err) + } + } + return nil +} + +// Len returns the number of listeners in the set. +func (ls *listenerSet) Len() int { + return len(ls.listeners) +} + func (s *SSServer) runConfig(config Config) (func(), error) { startErrCh := make(chan error) stopCh := make(chan struct{}) go func() { - lnSet := s.lnManager.NewListenerSet() + lnSet := &listenerSet{ + manager: s.lnManager, + listeners: make(map[string]service.Listener), + } defer lnSet.Close() startErrCh <- func() error { diff --git a/service/listeners.go b/service/listeners.go index f109cd6b..62f55a0f 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -16,7 +16,6 @@ package service import ( "errors" - "fmt" "io" "net" "sync" @@ -122,78 +121,8 @@ func (cl *concreteListener) Close() error { return nil } -// ListenerSet represents a set of listeners listening on unique addresses. Trying -// to listen on the same address twice will result in an error. The set can be -// closed as a unit, which is useful if you want to bring down a group of -// listeners, such as when reloading a new config. -type ListenerSet interface { - // ListenStream announces on a given TCP network address. - ListenStream(addr string) (StreamListener, error) - // ListenStream announces on a given UDP network address. - ListenPacket(addr string) (net.PacketConn, error) - // Close closes all the listeners in the set. - Close() error - // Len returns the number of listeners in the set. - Len() int -} - -type listenerSet struct { - manager ListenerManager - listeners map[string]Listener -} - -// ListenStream announces on a given TCP network address. -func (ls *listenerSet) ListenStream(addr string) (StreamListener, error) { - network := "tcp" - lnKey := listenerKey(network, addr) - if _, exists := ls.listeners[lnKey]; exists { - return nil, fmt.Errorf("listener %s already exists", lnKey) - } - ln, err := ls.manager.ListenStream(network, addr) - if err != nil { - return nil, err - } - ls.listeners[lnKey] = ln - return ln, nil -} - -// ListenPacket announces on a given UDP network address. -func (ls *listenerSet) ListenPacket(addr string) (net.PacketConn, error) { - network := "udp" - lnKey := listenerKey(network, addr) - if _, exists := ls.listeners[lnKey]; exists { - return nil, fmt.Errorf("listener %s already exists", lnKey) - } - ln, err := ls.manager.ListenPacket(network, addr) - if err != nil { - return nil, err - } - ls.listeners[lnKey] = ln - return ln, nil -} - -// Close closes all the listeners in the set. -func (ls *listenerSet) Close() error { - for _, listener := range ls.listeners { - addr, err := getAddr(listener) - if err != nil { - return err - } - if err := listener.Close(); err != nil { - return fmt.Errorf("%s listener on address %s failed to stop: %w", addr.Network(), addr.String(), err) - } - } - return nil -} - -// Len returns the number of listeners in the set. -func (ls *listenerSet) Len() int { - return len(ls.listeners) -} - // ListenerManager holds and manages the state of shared listeners. type ListenerManager interface { - NewListenerSet() ListenerSet ListenStream(network string, addr string) (StreamListener, error) ListenPacket(network string, addr string) (net.PacketConn, error) } @@ -210,13 +139,6 @@ func NewListenerManager() ListenerManager { } } -func (m *listenerManager) NewListenerSet() ListenerSet { - return &listenerSet{ - manager: m, - listeners: make(map[string]Listener), - } -} - // ListenStream creates a new stream listener for a given network and address. // // Listeners can overlap one another, because during config changes the new @@ -320,14 +242,3 @@ func (m *listenerManager) delete(key string) { func listenerKey(network string, addr string) string { return network + "/" + addr } - -func getAddr(listener Listener) (net.Addr, error) { - switch ln := listener.(type) { - case net.Listener: - return ln.Addr(), nil - case net.PacketConn: - return ln.LocalAddr(), nil - default: - return nil, fmt.Errorf("unknown listener type: %v", ln) - } -} From 7f86ff17ae012c08f2cf96af77a2e1c7326661c1 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 22 Jul 2024 13:15:14 -0400 Subject: [PATCH 078/119] Close the accept channel with an atomic value. --- service/listeners.go | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 62f55a0f..78523ac7 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -49,7 +49,7 @@ type sharedListener struct { listener net.TCPListener once sync.Once - acceptCh chan acceptResponse + acceptCh *atomic.Value // closed by first Close() call closeCh chan struct{} onCloseFunc func() error } @@ -57,7 +57,7 @@ type sharedListener struct { // Accept accepts connections until Close() is called. func (sl *sharedListener) AcceptStream() (transport.StreamConn, error) { select { - case acceptResponse := <-sl.acceptCh: + case acceptResponse := <-sl.acceptCh.Load().(chan acceptResponse): if acceptResponse.err != nil { return nil, acceptResponse.err } @@ -70,6 +70,7 @@ func (sl *sharedListener) AcceptStream() (transport.StreamConn, error) { // Close stops accepting new connections without closing the underlying socket. // Only when the last user closes it, we actually close it. func (sl *sharedListener) Close() error { + sl.acceptCh = nil var err error sl.once.Do(func() { close(sl.closeCh) @@ -152,15 +153,17 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe lnKey := listenerKey(network, addr) if lnConcrete, ok := m.listeners[lnKey]; ok { lnConcrete.usage.Add(1) - return &sharedListener{ + sl := &sharedListener{ listener: *lnConcrete.ln, - acceptCh: lnConcrete.acceptCh, closeCh: make(chan struct{}), onCloseFunc: func() error { m.delete(lnKey) return lnConcrete.Close() }, - }, nil + } + sl.acceptCh = &atomic.Value{} + sl.acceptCh.Store(lnConcrete.acceptCh) + return sl, nil } tcpAddr, err := net.ResolveTCPAddr("tcp", addr) @@ -185,15 +188,17 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe lnConcrete.usage.Store(1) m.listeners[lnKey] = lnConcrete - return &sharedListener{ + sl := &sharedListener{ listener: *lnConcrete.ln, - acceptCh: lnConcrete.acceptCh, closeCh: make(chan struct{}), onCloseFunc: func() error { m.delete(lnKey) return lnConcrete.Close() }, - }, nil + } + sl.acceptCh = &atomic.Value{} + sl.acceptCh.Store(lnConcrete.acceptCh) + return sl, nil } // ListenPacket creates a new packet listener for a given network and address. From e80b2c51d6f14d4643a8d7458f432b0f5903c10e Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 22 Jul 2024 13:21:48 -0400 Subject: [PATCH 079/119] Update comment. --- cmd/outline-ss-server/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 90ff14fa..7c27fcdc 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -218,7 +218,7 @@ func (s *SSServer) runConfig(config Config) (func(), error) { }, nil } -// Stop serving the current config. +// Stop stops serving the current config. func (s *SSServer) Stop() { s.stopConfig() logger.Info("Stopped all listeners for running config") From b1428edca64e02e6e40232fa21041aa7fb59cc1b Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 22 Jul 2024 13:27:38 -0400 Subject: [PATCH 080/119] Address review comments. --- cmd/outline-ss-server/main.go | 2 +- cmd/outline-ss-server/metrics.go | 14 +++++++------- service/listeners.go | 19 ++++--------------- 3 files changed, 12 insertions(+), 23 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 7c27fcdc..2a1c32ab 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -161,7 +161,7 @@ func (s *SSServer) runConfig(config Config) (func(), error) { manager: s.lnManager, listeners: make(map[string]service.Listener), } - defer lnSet.Close() + defer lnSet.Close() // This closes all the listeners in the set. startErrCh <- func() error { portCiphers := make(map[int]*list.List) // Values are *List of *CipherEntry. diff --git a/cmd/outline-ss-server/metrics.go b/cmd/outline-ss-server/metrics.go index 600cea16..e95ceeb3 100644 --- a/cmd/outline-ss-server/metrics.go +++ b/cmd/outline-ss-server/metrics.go @@ -38,7 +38,7 @@ type outlineMetrics struct { buildInfo *prometheus.GaugeVec accessKeys prometheus.Gauge - listeners prometheus.Gauge + ports prometheus.Gauge dataBytes *prometheus.CounterVec dataBytesPerLocation *prometheus.CounterVec timeToCipherMs *prometheus.HistogramVec @@ -183,10 +183,10 @@ func newPrometheusOutlineMetrics(ip2info ipinfo.IPInfoMap, registerer prometheus Name: "keys", Help: "Count of access keys", }), - listeners: prometheus.NewGauge(prometheus.GaugeOpts{ + ports: prometheus.NewGauge(prometheus.GaugeOpts{ Namespace: namespace, - Name: "listeners", - Help: "Count of open Shadowsocks listeners", + Name: "ports", + Help: "Count of open Shadowsocks ports", }), tcpProbes: prometheus.NewHistogramVec(prometheus.HistogramOpts{ Namespace: namespace, @@ -265,7 +265,7 @@ func newPrometheusOutlineMetrics(ip2info ipinfo.IPInfoMap, registerer prometheus m.tunnelTimeCollector = newTunnelTimeCollector(ip2info, registerer) // TODO: Is it possible to pass where to register the collectors? - registerer.MustRegister(m.buildInfo, m.accessKeys, m.listeners, m.tcpProbes, m.tcpOpenConnections, m.tcpClosedConnections, m.tcpConnectionDurationMs, + registerer.MustRegister(m.buildInfo, m.accessKeys, m.ports, m.tcpProbes, m.tcpOpenConnections, m.tcpClosedConnections, m.tcpConnectionDurationMs, m.dataBytes, m.dataBytesPerLocation, m.timeToCipherMs, m.udpPacketsFromClientPerLocation, m.udpAddedNatEntries, m.udpRemovedNatEntries, m.tunnelTimeCollector) return m @@ -275,9 +275,9 @@ func (m *outlineMetrics) SetBuildInfo(version string) { m.buildInfo.WithLabelValues(version).Set(1) } -func (m *outlineMetrics) SetNumAccessKeys(numKeys int, listeners int) { +func (m *outlineMetrics) SetNumAccessKeys(numKeys int, ports int) { m.accessKeys.Set(float64(numKeys)) - m.listeners.Set(float64(listeners)) + m.ports.Set(float64(ports)) } func (m *outlineMetrics) AddOpenTCPConnection(clientInfo ipinfo.IPInfo) { diff --git a/service/listeners.go b/service/listeners.go index 78523ac7..15073e02 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -46,9 +46,7 @@ type acceptResponse struct { } type sharedListener struct { - listener net.TCPListener - once sync.Once - + listener net.TCPListener acceptCh *atomic.Value // closed by first Close() call closeCh chan struct{} onCloseFunc func() error @@ -71,12 +69,8 @@ func (sl *sharedListener) AcceptStream() (transport.StreamConn, error) { // Only when the last user closes it, we actually close it. func (sl *sharedListener) Close() error { sl.acceptCh = nil - var err error - sl.once.Do(func() { - close(sl.closeCh) - err = sl.onCloseFunc() - }) - return err + close(sl.closeCh) + return sl.onCloseFunc() } func (sl *sharedListener) Addr() net.Addr { @@ -85,16 +79,11 @@ func (sl *sharedListener) Addr() net.Addr { type sharedPacketConn struct { net.PacketConn - once sync.Once onCloseFunc func() error } func (spc *sharedPacketConn) Close() error { - var err error - spc.once.Do(func() { - err = spc.onCloseFunc() - }) - return err + return spc.onCloseFunc() } type concreteListener struct { From 1c16de86b18c3092e1ae9110203fd9a1b7ada4ef Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 22 Jul 2024 13:31:53 -0400 Subject: [PATCH 081/119] Close before deleting key. --- service/listeners.go | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 15073e02..96e41085 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -146,8 +146,11 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe listener: *lnConcrete.ln, closeCh: make(chan struct{}), onCloseFunc: func() error { + if err := lnConcrete.Close(); err != nil { + return err + } m.delete(lnKey) - return lnConcrete.Close() + return nil }, } sl.acceptCh = &atomic.Value{} @@ -181,8 +184,11 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe listener: *lnConcrete.ln, closeCh: make(chan struct{}), onCloseFunc: func() error { + if err := lnConcrete.Close(); err != nil { + return err + } m.delete(lnKey) - return lnConcrete.Close() + return nil }, } sl.acceptCh = &atomic.Value{} @@ -203,8 +209,11 @@ func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketC return &sharedPacketConn{ PacketConn: lnConcrete.pc, onCloseFunc: func() error { + if err := lnConcrete.Close(); err != nil { + return err + } m.delete(lnKey) - return lnConcrete.Close() + return nil }, }, nil } @@ -221,8 +230,11 @@ func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketC return &sharedPacketConn{ PacketConn: pc, onCloseFunc: func() error { + if err := lnConcrete.Close(); err != nil { + return err + } m.delete(lnKey) - return lnConcrete.Close() + return nil }, }, nil } From ebc7053c6ca1fb72403e4940ab441c203063bf19 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 22 Jul 2024 15:40:28 -0400 Subject: [PATCH 082/119] `server.Stop()` does not return a value --- cmd/outline-ss-server/server_test.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cmd/outline-ss-server/server_test.go b/cmd/outline-ss-server/server_test.go index 0b7777b2..2ba0772e 100644 --- a/cmd/outline-ss-server/server_test.go +++ b/cmd/outline-ss-server/server_test.go @@ -27,7 +27,5 @@ func TestRunSSServer(t *testing.T) { if err != nil { t.Fatalf("RunSSServer() error = %v", err) } - if err := server.Stop(); err != nil { - t.Errorf("Error while stopping server: %v", err) - } + server.Stop() } From 67fc7fbf6ea596ef7b4cef7ced9c0667736ee31e Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 22 Jul 2024 15:47:07 -0400 Subject: [PATCH 083/119] Add a comment for `StreamListener`. --- service/listeners.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/service/listeners.go b/service/listeners.go index 96e41085..2eda5481 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -28,6 +28,8 @@ import ( // interchangeable. The type of listener depends on the network type. type Listener = io.Closer +// StreamListener is a network listener for stream-oriented protocols that +// accepts [transport.StreamConn] connections. type StreamListener interface { // Accept waits for and returns the next connection to the listener. AcceptStream() (transport.StreamConn, error) From 7a15e7df8ebcfe5ed156dbe0839685b7ec1a9c38 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 22 Jul 2024 16:11:04 -0400 Subject: [PATCH 084/119] Do not delete the listener from the manager until the last user has closed it. --- service/listeners.go | 49 +++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 26 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 2eda5481..ff93e25d 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -89,10 +89,11 @@ func (spc *sharedPacketConn) Close() error { } type concreteListener struct { - ln *net.TCPListener - pc net.PacketConn - usage atomic.Int32 - acceptCh chan acceptResponse + ln *net.TCPListener + pc net.PacketConn + usage atomic.Int32 + acceptCh chan acceptResponse + onCloseFunc func() // Called when the listener's last user closes it. } func (cl *concreteListener) Close() error { @@ -109,6 +110,7 @@ func (cl *concreteListener) Close() error { return err } } + cl.onCloseFunc() } return nil } @@ -148,11 +150,7 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe listener: *lnConcrete.ln, closeCh: make(chan struct{}), onCloseFunc: func() error { - if err := lnConcrete.Close(); err != nil { - return err - } - m.delete(lnKey) - return nil + return lnConcrete.Close() }, } sl.acceptCh = &atomic.Value{} @@ -169,7 +167,13 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe return nil, err } - lnConcrete := &concreteListener{ln: ln, acceptCh: make(chan acceptResponse)} + lnConcrete := &concreteListener{ + ln: ln, + acceptCh: make(chan acceptResponse), + onCloseFunc: func() { + m.delete(lnKey) + }, + } go func() { for { conn, err := lnConcrete.ln.AcceptTCP() @@ -186,11 +190,7 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe listener: *lnConcrete.ln, closeCh: make(chan struct{}), onCloseFunc: func() error { - if err := lnConcrete.Close(); err != nil { - return err - } - m.delete(lnKey) - return nil + return lnConcrete.Close() }, } sl.acceptCh = &atomic.Value{} @@ -211,11 +211,7 @@ func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketC return &sharedPacketConn{ PacketConn: lnConcrete.pc, onCloseFunc: func() error { - if err := lnConcrete.Close(); err != nil { - return err - } - m.delete(lnKey) - return nil + return lnConcrete.Close() }, }, nil } @@ -225,18 +221,19 @@ func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketC return nil, err } - lnConcrete := &concreteListener{pc: pc} + lnConcrete := &concreteListener{ + pc: pc, + onCloseFunc: func() { + m.delete(lnKey) + }, + } lnConcrete.usage.Store(1) m.listeners[lnKey] = lnConcrete return &sharedPacketConn{ PacketConn: pc, onCloseFunc: func() error { - if err := lnConcrete.Close(); err != nil { - return err - } - m.delete(lnKey) - return nil + return lnConcrete.Close() }, }, nil } From 499829e1462a03eb82281f1efb32970345fe1d31 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 22 Jul 2024 16:28:03 -0400 Subject: [PATCH 085/119] Consolidate usage counting inside a `listenAddress` type. --- service/listeners.go | 111 +++++++++++++++++++------------------------ 1 file changed, 48 insertions(+), 63 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index ff93e25d..41aec8b0 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -88,7 +88,7 @@ func (spc *sharedPacketConn) Close() error { return spc.onCloseFunc() } -type concreteListener struct { +type listenAddr struct { ln *net.TCPListener pc net.PacketConn usage atomic.Int32 @@ -96,23 +96,42 @@ type concreteListener struct { onCloseFunc func() // Called when the listener's last user closes it. } -func (cl *concreteListener) Close() error { - if cl.usage.Add(-1) == 0 { - if cl.ln != nil { - err := cl.ln.Close() - if err != nil { - return err +func (cl *listenAddr) NewStreamListener() StreamListener { + cl.usage.Add(1) + sl := &sharedListener{ + listener: *cl.ln, + closeCh: make(chan struct{}), + onCloseFunc: func() error { + if cl.usage.Add(-1) == 0 { + err := cl.ln.Close() + if err != nil { + return err + } + cl.onCloseFunc() } - } - if cl.pc != nil { - err := cl.pc.Close() - if err != nil { - return err + return nil + }, + } + sl.acceptCh = &atomic.Value{} + sl.acceptCh.Store(cl.acceptCh) + return sl +} + +func (cl *listenAddr) NewPacketListener() net.PacketConn { + cl.usage.Add(1) + return &sharedPacketConn{ + PacketConn: cl.pc, + onCloseFunc: func() error { + if cl.usage.Add(-1) == 0 { + err := cl.pc.Close() + if err != nil { + return err + } + cl.onCloseFunc() } - } - cl.onCloseFunc() + return nil + }, } - return nil } // ListenerManager holds and manages the state of shared listeners. @@ -122,14 +141,14 @@ type ListenerManager interface { } type listenerManager struct { - listeners map[string]*concreteListener + listeners map[string]*listenAddr listenersMu sync.Mutex } // NewListenerManager creates a new [ListenerManger]. func NewListenerManager() ListenerManager { return &listenerManager{ - listeners: make(map[string]*concreteListener), + listeners: make(map[string]*listenAddr), } } @@ -144,18 +163,8 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe defer m.listenersMu.Unlock() lnKey := listenerKey(network, addr) - if lnConcrete, ok := m.listeners[lnKey]; ok { - lnConcrete.usage.Add(1) - sl := &sharedListener{ - listener: *lnConcrete.ln, - closeCh: make(chan struct{}), - onCloseFunc: func() error { - return lnConcrete.Close() - }, - } - sl.acceptCh = &atomic.Value{} - sl.acceptCh.Store(lnConcrete.acceptCh) - return sl, nil + if listenAddress, ok := m.listeners[lnKey]; ok { + return listenAddress.NewStreamListener(), nil } tcpAddr, err := net.ResolveTCPAddr("tcp", addr) @@ -167,7 +176,7 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe return nil, err } - lnConcrete := &concreteListener{ + listenAddress := &listenAddr{ ln: ln, acceptCh: make(chan acceptResponse), onCloseFunc: func() { @@ -176,26 +185,15 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe } go func() { for { - conn, err := lnConcrete.ln.AcceptTCP() + conn, err := listenAddress.ln.AcceptTCP() if errors.Is(err, net.ErrClosed) { return } - lnConcrete.acceptCh <- acceptResponse{conn, err} + listenAddress.acceptCh <- acceptResponse{conn, err} } }() - lnConcrete.usage.Store(1) - m.listeners[lnKey] = lnConcrete - - sl := &sharedListener{ - listener: *lnConcrete.ln, - closeCh: make(chan struct{}), - onCloseFunc: func() error { - return lnConcrete.Close() - }, - } - sl.acceptCh = &atomic.Value{} - sl.acceptCh.Store(lnConcrete.acceptCh) - return sl, nil + m.listeners[lnKey] = listenAddress + return listenAddress.NewStreamListener(), nil } // ListenPacket creates a new packet listener for a given network and address. @@ -206,14 +204,8 @@ func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketC defer m.listenersMu.Unlock() lnKey := listenerKey(network, addr) - if lnConcrete, ok := m.listeners[lnKey]; ok { - lnConcrete.usage.Add(1) - return &sharedPacketConn{ - PacketConn: lnConcrete.pc, - onCloseFunc: func() error { - return lnConcrete.Close() - }, - }, nil + if listenAddress, ok := m.listeners[lnKey]; ok { + return listenAddress.NewPacketListener(), nil } pc, err := net.ListenPacket(network, addr) @@ -221,21 +213,14 @@ func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketC return nil, err } - lnConcrete := &concreteListener{ + listenAddress := &listenAddr{ pc: pc, onCloseFunc: func() { m.delete(lnKey) }, } - lnConcrete.usage.Store(1) - m.listeners[lnKey] = lnConcrete - - return &sharedPacketConn{ - PacketConn: pc, - onCloseFunc: func() error { - return lnConcrete.Close() - }, - }, nil + m.listeners[lnKey] = listenAddress + return listenAddress.NewPacketListener(), nil } func (m *listenerManager) delete(key string) { From f165dbd73dfb57f111b34ba79a75cddfee54ce5f Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 22 Jul 2024 17:00:07 -0400 Subject: [PATCH 086/119] Remove `atomic.Value`. --- service/listeners.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 41aec8b0..6ee307f8 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -49,7 +49,7 @@ type acceptResponse struct { type sharedListener struct { listener net.TCPListener - acceptCh *atomic.Value // closed by first Close() call + acceptCh chan acceptResponse closeCh chan struct{} onCloseFunc func() error } @@ -57,7 +57,7 @@ type sharedListener struct { // Accept accepts connections until Close() is called. func (sl *sharedListener) AcceptStream() (transport.StreamConn, error) { select { - case acceptResponse := <-sl.acceptCh.Load().(chan acceptResponse): + case acceptResponse := <-sl.acceptCh: if acceptResponse.err != nil { return nil, acceptResponse.err } @@ -70,7 +70,6 @@ func (sl *sharedListener) AcceptStream() (transport.StreamConn, error) { // Close stops accepting new connections without closing the underlying socket. // Only when the last user closes it, we actually close it. func (sl *sharedListener) Close() error { - sl.acceptCh = nil close(sl.closeCh) return sl.onCloseFunc() } @@ -100,6 +99,7 @@ func (cl *listenAddr) NewStreamListener() StreamListener { cl.usage.Add(1) sl := &sharedListener{ listener: *cl.ln, + acceptCh: cl.acceptCh, closeCh: make(chan struct{}), onCloseFunc: func() error { if cl.usage.Add(-1) == 0 { @@ -112,8 +112,6 @@ func (cl *listenAddr) NewStreamListener() StreamListener { return nil }, } - sl.acceptCh = &atomic.Value{} - sl.acceptCh.Store(cl.acceptCh) return sl } From 2a2420a162dfc813e30b8a61c879a791b468f2e0 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 22 Jul 2024 17:18:37 -0400 Subject: [PATCH 087/119] Add some missing comments. --- service/listeners.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/service/listeners.go b/service/listeners.go index 6ee307f8..b9743ca3 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -74,6 +74,7 @@ func (sl *sharedListener) Close() error { return sl.onCloseFunc() } +// Addr returns the listener's network address. func (sl *sharedListener) Addr() net.Addr { return sl.listener.Addr() } @@ -95,6 +96,7 @@ type listenAddr struct { onCloseFunc func() // Called when the listener's last user closes it. } +// NewStreamListener creates a new [StreamListener]. func (cl *listenAddr) NewStreamListener() StreamListener { cl.usage.Add(1) sl := &sharedListener{ @@ -115,6 +117,7 @@ func (cl *listenAddr) NewStreamListener() StreamListener { return sl } +// NewStreamListener creates a new [net.PacketConn]. func (cl *listenAddr) NewPacketListener() net.PacketConn { cl.usage.Add(1) return &sharedPacketConn{ From cccba1a8312c0125785e516aa396349e48f7f3cb Mon Sep 17 00:00:00 2001 From: sbruens Date: Thu, 25 Jul 2024 16:25:42 -0400 Subject: [PATCH 088/119] address review comments --- cmd/outline-ss-server/main.go | 5 ++++- service/listeners.go | 17 ++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 2a1c32ab..5dc28daf 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -137,13 +137,16 @@ func (ls *listenerSet) ListenPacket(addr string) (net.PacketConn, error) { return ln, nil } -// Close closes all the listeners in the set. +// Close closes all the listeners in the set, after which the set can't be used again. func (ls *listenerSet) Close() error { for addr, listener := range ls.listeners { if err := listener.Close(); err != nil { return fmt.Errorf("listener on address %s failed to stop: %w", addr, err) } } + ls.listenersMu.Lock() + defer ls.listenersMu.Unlock() + ls.listeners = nil return nil } diff --git a/service/listeners.go b/service/listeners.go index b9743ca3..8c848e7d 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -35,7 +35,10 @@ type StreamListener interface { AcceptStream() (transport.StreamConn, error) // Close closes the listener. - // Any blocked Accept operations will be unblocked and return errors. + // Any blocked Accept operations will be unblocked and return errors. This + // stops the current listener from accepting new connections without closing + // the underlying socket. Only when the last user of the underlying socket + // closes it, do we actually close it. Close() error // Addr returns the listener's network address. @@ -57,18 +60,17 @@ type sharedListener struct { // Accept accepts connections until Close() is called. func (sl *sharedListener) AcceptStream() (transport.StreamConn, error) { select { - case acceptResponse := <-sl.acceptCh: - if acceptResponse.err != nil { - return nil, acceptResponse.err + case acceptResponse, ok := <-sl.acceptCh: + if !ok { + return nil, net.ErrClosed } - return acceptResponse.conn, nil + return acceptResponse.conn, acceptResponse.err case <-sl.closeCh: return nil, net.ErrClosed } } -// Close stops accepting new connections without closing the underlying socket. -// Only when the last user closes it, we actually close it. +// Close implements [StreamListener.Close]. func (sl *sharedListener) Close() error { close(sl.closeCh) return sl.onCloseFunc() @@ -188,6 +190,7 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe for { conn, err := listenAddress.ln.AcceptTCP() if errors.Is(err, net.ErrClosed) { + close(listenAddress.acceptCh) return } listenAddress.acceptCh <- acceptResponse{conn, err} From da4ccaadeef19de59a744df320b5aab0af9d2634 Mon Sep 17 00:00:00 2001 From: sbruens Date: Thu, 25 Jul 2024 16:38:11 -0400 Subject: [PATCH 089/119] Add type guard for `sharedListener`. --- service/listeners.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/service/listeners.go b/service/listeners.go index 8c848e7d..12253961 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -57,6 +57,8 @@ type sharedListener struct { onCloseFunc func() error } +var _ StreamListener = (*sharedListener)(nil) + // Accept accepts connections until Close() is called. func (sl *sharedListener) AcceptStream() (transport.StreamConn, error) { select { From d47f6120f060184eae1b70b62ae1e732a012c02c Mon Sep 17 00:00:00 2001 From: sbruens Date: Thu, 25 Jul 2024 16:40:58 -0400 Subject: [PATCH 090/119] Stop the existing config in a goroutine. --- cmd/outline-ss-server/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 5dc28daf..8b2c3049 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -80,7 +80,7 @@ func (s *SSServer) loadConfig(filename string) error { if err != nil { return err } - s.stopConfig() + go s.stopConfig() s.stopConfig = stopConfig return nil } From a928e2c9b9948a56c03a78248309f1de2a7f2be9 Mon Sep 17 00:00:00 2001 From: sbruens Date: Thu, 25 Jul 2024 16:44:36 -0400 Subject: [PATCH 091/119] Add a TODO to wait for all handlers to be stopped. --- cmd/outline-ss-server/main.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 8b2c3049..b2b28296 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -217,6 +217,8 @@ func (s *SSServer) runConfig(config Config) (func(), error) { } return func() { logger.Infof("Stopping running config.") + // TODO(sbruens): Actually wait for all handlers to be stopped, e.g. by + // using a https://pkg.go.dev/sync#WaitGroup. stopCh <- struct{}{} }, nil } From 98cc3a02181bb209869fcfe66eb8a04531c5b51e Mon Sep 17 00:00:00 2001 From: sbruens Date: Thu, 25 Jul 2024 16:48:51 -0400 Subject: [PATCH 092/119] Run `stopConfig` in a goroutine in `Stop()` as well. --- cmd/outline-ss-server/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index b2b28296..78bf2ef0 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -225,7 +225,7 @@ func (s *SSServer) runConfig(config Config) (func(), error) { // Stop stops serving the current config. func (s *SSServer) Stop() { - s.stopConfig() + go s.stopConfig() logger.Info("Stopped all listeners for running config") } From 48d0931e36e310be13b3160097b788fecbcc274b Mon Sep 17 00:00:00 2001 From: sbruens Date: Thu, 25 Jul 2024 16:55:07 -0400 Subject: [PATCH 093/119] Create a `TCPListener` that implements a `StreamListener`. --- service/listeners.go | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 12253961..79e24158 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -45,13 +45,31 @@ type StreamListener interface { Addr() net.Addr } +type TCPListener struct { + ln *net.TCPListener +} + +var _ StreamListener = (*TCPListener)(nil) + +func (t *TCPListener) AcceptStream() (transport.StreamConn, error) { + return t.ln.AcceptTCP() +} + +func (t *TCPListener) Close() error { + return t.ln.Close() +} + +func (t *TCPListener) Addr() net.Addr { + return t.ln.Addr() +} + type acceptResponse struct { conn transport.StreamConn err error } type sharedListener struct { - listener net.TCPListener + listener StreamListener acceptCh chan acceptResponse closeCh chan struct{} onCloseFunc func() error @@ -93,7 +111,7 @@ func (spc *sharedPacketConn) Close() error { } type listenAddr struct { - ln *net.TCPListener + ln StreamListener pc net.PacketConn usage atomic.Int32 acceptCh chan acceptResponse @@ -104,7 +122,7 @@ type listenAddr struct { func (cl *listenAddr) NewStreamListener() StreamListener { cl.usage.Add(1) sl := &sharedListener{ - listener: *cl.ln, + listener: cl.ln, acceptCh: cl.acceptCh, closeCh: make(chan struct{}), onCloseFunc: func() error { @@ -181,8 +199,9 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe return nil, err } + streamLn := &TCPListener{ln} listenAddress := &listenAddr{ - ln: ln, + ln: streamLn, acceptCh: make(chan acceptResponse), onCloseFunc: func() { m.delete(lnKey) @@ -190,7 +209,7 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe } go func() { for { - conn, err := listenAddress.ln.AcceptTCP() + conn, err := streamLn.AcceptStream() if errors.Is(err, net.ErrClosed) { close(listenAddress.acceptCh) return From 2dec847096b1fcfbd7f0b12e3a6539a1fa8d2e17 Mon Sep 17 00:00:00 2001 From: sbruens Date: Thu, 25 Jul 2024 17:04:14 -0400 Subject: [PATCH 094/119] Track close functions instead of the entire listener, which is not needed. --- cmd/outline-ss-server/main.go | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 78bf2ef0..ec6b66b0 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -96,9 +96,9 @@ func (s *SSServer) NewShadowsocksPacketHandler(ciphers service.CipherList) servi } type listenerSet struct { - manager service.ListenerManager - listeners map[string]service.Listener - listenersMu sync.Mutex + manager service.ListenerManager + listenerCloseFuncs map[string]func() error + listenersMu sync.Mutex } // ListenStream announces on a given TCP network address. Trying to listen on @@ -108,14 +108,14 @@ func (ls *listenerSet) ListenStream(addr string) (service.StreamListener, error) defer ls.listenersMu.Unlock() lnKey := "tcp/" + addr - if _, exists := ls.listeners[lnKey]; exists { + if _, exists := ls.listenerCloseFuncs[lnKey]; exists { return nil, fmt.Errorf("listener %s already exists", lnKey) } ln, err := ls.manager.ListenStream("tcp", addr) if err != nil { return nil, err } - ls.listeners[lnKey] = ln + ls.listenerCloseFuncs[lnKey] = ln.Close return ln, nil } @@ -126,33 +126,33 @@ func (ls *listenerSet) ListenPacket(addr string) (net.PacketConn, error) { defer ls.listenersMu.Unlock() lnKey := "udp/" + addr - if _, exists := ls.listeners[lnKey]; exists { + if _, exists := ls.listenerCloseFuncs[lnKey]; exists { return nil, fmt.Errorf("listener %s already exists", lnKey) } ln, err := ls.manager.ListenPacket("udp", addr) if err != nil { return nil, err } - ls.listeners[lnKey] = ln + ls.listenerCloseFuncs[lnKey] = ln.Close return ln, nil } // Close closes all the listeners in the set, after which the set can't be used again. func (ls *listenerSet) Close() error { - for addr, listener := range ls.listeners { - if err := listener.Close(); err != nil { + for addr, listenerCloseFunc := range ls.listenerCloseFuncs { + if err := listenerCloseFunc(); err != nil { return fmt.Errorf("listener on address %s failed to stop: %w", addr, err) } } ls.listenersMu.Lock() defer ls.listenersMu.Unlock() - ls.listeners = nil + ls.listenerCloseFuncs = nil return nil } // Len returns the number of listeners in the set. func (ls *listenerSet) Len() int { - return len(ls.listeners) + return len(ls.listenerCloseFuncs) } func (s *SSServer) runConfig(config Config) (func(), error) { @@ -161,8 +161,8 @@ func (s *SSServer) runConfig(config Config) (func(), error) { go func() { lnSet := &listenerSet{ - manager: s.lnManager, - listeners: make(map[string]service.Listener), + manager: s.lnManager, + listenerCloseFuncs: make(map[string]func() error), } defer lnSet.Close() // This closes all the listeners in the set. From ab22e47a3b68702fc9a2bd3fc6fac5f787332b8b Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 30 Jul 2024 17:22:29 -0400 Subject: [PATCH 095/119] Delegate usage tracking to a reference counter. --- service/listeners.go | 139 ++++++++++++++++++++++++-------------- service/listeners_test.go | 46 +++++++++++++ 2 files changed, 134 insertions(+), 51 deletions(-) create mode 100644 service/listeners_test.go diff --git a/service/listeners.go b/service/listeners.go index 79e24158..6fcc3b6d 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -113,50 +113,45 @@ func (spc *sharedPacketConn) Close() error { type listenAddr struct { ln StreamListener pc net.PacketConn - usage atomic.Int32 acceptCh chan acceptResponse - onCloseFunc func() // Called when the listener's last user closes it. + onCloseFunc func() error } // NewStreamListener creates a new [StreamListener]. func (cl *listenAddr) NewStreamListener() StreamListener { - cl.usage.Add(1) sl := &sharedListener{ - listener: cl.ln, - acceptCh: cl.acceptCh, - closeCh: make(chan struct{}), - onCloseFunc: func() error { - if cl.usage.Add(-1) == 0 { - err := cl.ln.Close() - if err != nil { - return err - } - cl.onCloseFunc() - } - return nil - }, + listener: cl.ln, + acceptCh: cl.acceptCh, + closeCh: make(chan struct{}), + onCloseFunc: cl.Close, } return sl } -// NewStreamListener creates a new [net.PacketConn]. +// NewPacketListener creates a new [net.PacketConn]. func (cl *listenAddr) NewPacketListener() net.PacketConn { - cl.usage.Add(1) return &sharedPacketConn{ - PacketConn: cl.pc, - onCloseFunc: func() error { - if cl.usage.Add(-1) == 0 { - err := cl.pc.Close() - if err != nil { - return err - } - cl.onCloseFunc() - } - return nil - }, + PacketConn: cl.pc, + onCloseFunc: cl.Close, } } +func (cl *listenAddr) Close() error { + if cl.ln != nil { + err := cl.ln.Close() + if err != nil { + return err + } + } + if cl.pc != nil { + err := cl.pc.Close() + if err != nil { + return err + } + } + return cl.onCloseFunc() +} + // ListenerManager holds and manages the state of shared listeners. type ListenerManager interface { ListenStream(network string, addr string) (StreamListener, error) @@ -164,14 +159,14 @@ type ListenerManager interface { } type listenerManager struct { - listeners map[string]*listenAddr + listeners map[string]RefCount[*listenAddr] listenersMu sync.Mutex } // NewListenerManager creates a new [ListenerManger]. func NewListenerManager() ListenerManager { return &listenerManager{ - listeners: make(map[string]*listenAddr), + listeners: make(map[string]RefCount[*listenAddr]), } } @@ -185,9 +180,10 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe m.listenersMu.Lock() defer m.listenersMu.Unlock() - lnKey := listenerKey(network, addr) - if listenAddress, ok := m.listeners[lnKey]; ok { - return listenAddress.NewStreamListener(), nil + lnKey := network + "/" + addr + if lnRefCount, ok := m.listeners[lnKey]; ok { + lnAddr := lnRefCount.Acquire().Get() + return lnAddr.NewStreamListener(), nil } tcpAddr, err := net.ResolveTCPAddr("tcp", addr) @@ -200,25 +196,27 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe } streamLn := &TCPListener{ln} - listenAddress := &listenAddr{ + lnRefCount := NewRefCount(&listenAddr{ ln: streamLn, acceptCh: make(chan acceptResponse), - onCloseFunc: func() { + onCloseFunc: func() error { m.delete(lnKey) + return nil }, - } + }) + lnAddr := lnRefCount.Get() go func() { for { conn, err := streamLn.AcceptStream() if errors.Is(err, net.ErrClosed) { - close(listenAddress.acceptCh) + close(lnAddr.acceptCh) return } - listenAddress.acceptCh <- acceptResponse{conn, err} + lnAddr.acceptCh <- acceptResponse{conn, err} } }() - m.listeners[lnKey] = listenAddress - return listenAddress.NewStreamListener(), nil + m.listeners[lnKey] = lnRefCount + return lnAddr.NewStreamListener(), nil } // ListenPacket creates a new packet listener for a given network and address. @@ -228,9 +226,10 @@ func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketC m.listenersMu.Lock() defer m.listenersMu.Unlock() - lnKey := listenerKey(network, addr) - if listenAddress, ok := m.listeners[lnKey]; ok { - return listenAddress.NewPacketListener(), nil + lnKey := network + "/" + addr + if lnRefCount, ok := m.listeners[lnKey]; ok { + lnAddr := lnRefCount.Acquire().Get() + return lnAddr.NewPacketListener(), nil } pc, err := net.ListenPacket(network, addr) @@ -238,14 +237,15 @@ func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketC return nil, err } - listenAddress := &listenAddr{ + lnRefCount := NewRefCount(&listenAddr{ pc: pc, - onCloseFunc: func() { + onCloseFunc: func() error { m.delete(lnKey) + return nil }, - } - m.listeners[lnKey] = listenAddress - return listenAddress.NewPacketListener(), nil + }) + m.listeners[lnKey] = lnRefCount + return lnRefCount.Get().NewPacketListener(), nil } func (m *listenerManager) delete(key string) { @@ -254,6 +254,43 @@ func (m *listenerManager) delete(key string) { m.listenersMu.Unlock() } -func listenerKey(network string, addr string) string { - return network + "/" + addr +type RefCount[T io.Closer] interface { + io.Closer + + Acquire() RefCount[T] + Get() T +} + +func NewRefCount[T io.Closer](value T) RefCount[T] { + res := &refCount[T]{ + count: &atomic.Int32{}, + value: value, + } + res.count.Store(1) + return res +} + +type refCount[T io.Closer] struct { + count *atomic.Int32 + value T +} + +func (r refCount[T]) Close() error { + if count := r.count.Add(-1); count == 0 { + return r.value.Close() + } + + return nil +} + +func (r refCount[T]) Acquire() RefCount[T] { + r.count.Add(1) + return &refCount[T]{ + count: r.count, + value: r.value, + } +} + +func (r refCount[T]) Get() T { + return r.value } diff --git a/service/listeners_test.go b/service/listeners_test.go new file mode 100644 index 00000000..9005eb34 --- /dev/null +++ b/service/listeners_test.go @@ -0,0 +1,46 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +type testRefCount struct { + onCloseFunc func() +} + +func (t *testRefCount) Close() error { + t.onCloseFunc() + return nil +} + +func TestRefCount(t *testing.T) { + var done bool + rc := NewRefCount[*testRefCount](&testRefCount{ + onCloseFunc: func() { + done = true + }, + }) + rc.Acquire() + + require.NoError(t, rc.Close()) + require.False(t, done) + + require.NoError(t, rc.Close()) + require.True(t, done) +} From 3c2a3efc5323888a752beff35cca903ff7a5e78d Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 31 Jul 2024 12:13:17 -0400 Subject: [PATCH 096/119] Remove the `Get()` method from `refCount`. --- service/listeners.go | 44 +++++++++++++++++++------------------------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 6fcc3b6d..7752fe83 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -182,7 +182,7 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe lnKey := network + "/" + addr if lnRefCount, ok := m.listeners[lnKey]; ok { - lnAddr := lnRefCount.Acquire().Get() + lnAddr := lnRefCount.Acquire() return lnAddr.NewStreamListener(), nil } @@ -196,15 +196,14 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe } streamLn := &TCPListener{ln} - lnRefCount := NewRefCount(&listenAddr{ + lnAddr := &listenAddr{ ln: streamLn, acceptCh: make(chan acceptResponse), onCloseFunc: func() error { m.delete(lnKey) return nil }, - }) - lnAddr := lnRefCount.Get() + } go func() { for { conn, err := streamLn.AcceptStream() @@ -215,7 +214,7 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe lnAddr.acceptCh <- acceptResponse{conn, err} } }() - m.listeners[lnKey] = lnRefCount + m.listeners[lnKey] = NewRefCount(lnAddr) return lnAddr.NewStreamListener(), nil } @@ -228,7 +227,7 @@ func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketC lnKey := network + "/" + addr if lnRefCount, ok := m.listeners[lnKey]; ok { - lnAddr := lnRefCount.Acquire().Get() + lnAddr := lnRefCount.Acquire() return lnAddr.NewPacketListener(), nil } @@ -237,15 +236,15 @@ func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketC return nil, err } - lnRefCount := NewRefCount(&listenAddr{ + lnAddr := &listenAddr{ pc: pc, onCloseFunc: func() error { m.delete(lnKey) return nil }, - }) - m.listeners[lnKey] = lnRefCount - return lnRefCount.Get().NewPacketListener(), nil + } + m.listeners[lnKey] = NewRefCount(lnAddr) + return lnAddr.NewPacketListener(), nil } func (m *listenerManager) delete(key string) { @@ -254,11 +253,18 @@ func (m *listenerManager) delete(key string) { m.listenersMu.Unlock() } +// RefCount is an atomic reference counter that can be used to track a shared +// [io.Closer] resource. type RefCount[T io.Closer] interface { io.Closer - Acquire() RefCount[T] - Get() T + // Acquire increases the ref count and returns the wrapped object. + Acquire() T +} + +type refCount[T io.Closer] struct { + count *atomic.Int32 + value T } func NewRefCount[T io.Closer](value T) RefCount[T] { @@ -270,11 +276,6 @@ func NewRefCount[T io.Closer](value T) RefCount[T] { return res } -type refCount[T io.Closer] struct { - count *atomic.Int32 - value T -} - func (r refCount[T]) Close() error { if count := r.count.Add(-1); count == 0 { return r.value.Close() @@ -283,14 +284,7 @@ func (r refCount[T]) Close() error { return nil } -func (r refCount[T]) Acquire() RefCount[T] { +func (r refCount[T]) Acquire() T { r.count.Add(1) - return &refCount[T]{ - count: r.count, - value: r.value, - } -} - -func (r refCount[T]) Get() T { return r.value } From 5e282f1debda6c247da82b09e52788dd0d42d3c8 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 31 Jul 2024 12:18:18 -0400 Subject: [PATCH 097/119] Return immediately. --- service/listeners.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 7752fe83..80873696 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -119,13 +119,12 @@ type listenAddr struct { // NewStreamListener creates a new [StreamListener]. func (cl *listenAddr) NewStreamListener() StreamListener { - sl := &sharedListener{ + return &sharedListener{ listener: cl.ln, acceptCh: cl.acceptCh, closeCh: make(chan struct{}), onCloseFunc: cl.Close, } - return sl } // NewPacketListener creates a new [net.PacketConn]. From 547e9e67484cad2ffc76257766eab1c29db0e804 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 31 Jul 2024 13:26:53 -0400 Subject: [PATCH 098/119] Rename `shared` to `virtual` as they are not actually shared. --- service/listeners.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 80873696..680c006b 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -68,17 +68,17 @@ type acceptResponse struct { err error } -type sharedListener struct { +type virtualStreamListener struct { listener StreamListener acceptCh chan acceptResponse closeCh chan struct{} onCloseFunc func() error } -var _ StreamListener = (*sharedListener)(nil) +var _ StreamListener = (*virtualStreamListener)(nil) // Accept accepts connections until Close() is called. -func (sl *sharedListener) AcceptStream() (transport.StreamConn, error) { +func (sl *virtualStreamListener) AcceptStream() (transport.StreamConn, error) { select { case acceptResponse, ok := <-sl.acceptCh: if !ok { @@ -91,22 +91,22 @@ func (sl *sharedListener) AcceptStream() (transport.StreamConn, error) { } // Close implements [StreamListener.Close]. -func (sl *sharedListener) Close() error { +func (sl *virtualStreamListener) Close() error { close(sl.closeCh) return sl.onCloseFunc() } // Addr returns the listener's network address. -func (sl *sharedListener) Addr() net.Addr { +func (sl *virtualStreamListener) Addr() net.Addr { return sl.listener.Addr() } -type sharedPacketConn struct { +type virtualPacketConn struct { net.PacketConn onCloseFunc func() error } -func (spc *sharedPacketConn) Close() error { +func (spc *virtualPacketConn) Close() error { return spc.onCloseFunc() } @@ -119,7 +119,7 @@ type listenAddr struct { // NewStreamListener creates a new [StreamListener]. func (cl *listenAddr) NewStreamListener() StreamListener { - return &sharedListener{ + return &virtualStreamListener{ listener: cl.ln, acceptCh: cl.acceptCh, closeCh: make(chan struct{}), @@ -129,7 +129,7 @@ func (cl *listenAddr) NewStreamListener() StreamListener { // NewPacketListener creates a new [net.PacketConn]. func (cl *listenAddr) NewPacketListener() net.PacketConn { - return &sharedPacketConn{ + return &virtualPacketConn{ PacketConn: cl.pc, onCloseFunc: cl.Close, } From c6774c8692197ab2b27ec4ddcf9a397a13d60e34 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 31 Jul 2024 13:33:30 -0400 Subject: [PATCH 099/119] Simplify `listenAddr`. --- service/listeners.go | 40 ++++++++++++++++++---------------------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 680c006b..9647285c 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -111,42 +111,38 @@ func (spc *virtualPacketConn) Close() error { } type listenAddr struct { - ln StreamListener - pc net.PacketConn + ln Listener acceptCh chan acceptResponse onCloseFunc func() error } // NewStreamListener creates a new [StreamListener]. func (cl *listenAddr) NewStreamListener() StreamListener { - return &virtualStreamListener{ - listener: cl.ln, - acceptCh: cl.acceptCh, - closeCh: make(chan struct{}), - onCloseFunc: cl.Close, + if ln, ok := cl.ln.(StreamListener); ok { + return &virtualStreamListener{ + listener: ln, + acceptCh: cl.acceptCh, + closeCh: make(chan struct{}), + onCloseFunc: cl.Close, + } } + return nil } // NewPacketListener creates a new [net.PacketConn]. func (cl *listenAddr) NewPacketListener() net.PacketConn { - return &virtualPacketConn{ - PacketConn: cl.pc, - onCloseFunc: cl.Close, + if ln, ok := cl.ln.(net.PacketConn); ok { + return &virtualPacketConn{ + PacketConn: ln, + onCloseFunc: cl.Close, + } } + return nil } func (cl *listenAddr) Close() error { - if cl.ln != nil { - err := cl.ln.Close() - if err != nil { - return err - } - } - if cl.pc != nil { - err := cl.pc.Close() - if err != nil { - return err - } + if err := cl.ln.Close(); err != nil { + return err } return cl.onCloseFunc() } @@ -236,7 +232,7 @@ func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketC } lnAddr := &listenAddr{ - pc: pc, + ln: pc, onCloseFunc: func() error { m.delete(lnKey) return nil From df2f9d0c3ecc15c07d6413460029ac9ee0af8023 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 31 Jul 2024 14:24:56 -0400 Subject: [PATCH 100/119] Fix use of the ref count. --- service/listeners.go | 174 +++++++++++++++++++++++++------------------ 1 file changed, 103 insertions(+), 71 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 9647285c..e5d48e45 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -16,6 +16,7 @@ package service import ( "errors" + "fmt" "io" "net" "sync" @@ -68,16 +69,17 @@ type acceptResponse struct { err error } +type OnCloseFunc func() error + type virtualStreamListener struct { listener StreamListener acceptCh chan acceptResponse closeCh chan struct{} - onCloseFunc func() error + onCloseFunc OnCloseFunc } var _ StreamListener = (*virtualStreamListener)(nil) -// Accept accepts connections until Close() is called. func (sl *virtualStreamListener) AcceptStream() (transport.StreamConn, error) { select { case acceptResponse, ok := <-sl.acceptCh: @@ -90,20 +92,18 @@ func (sl *virtualStreamListener) AcceptStream() (transport.StreamConn, error) { } } -// Close implements [StreamListener.Close]. func (sl *virtualStreamListener) Close() error { close(sl.closeCh) return sl.onCloseFunc() } -// Addr returns the listener's network address. func (sl *virtualStreamListener) Addr() net.Addr { return sl.listener.Addr() } type virtualPacketConn struct { net.PacketConn - onCloseFunc func() error + onCloseFunc OnCloseFunc } func (spc *virtualPacketConn) Close() error { @@ -113,38 +113,50 @@ func (spc *virtualPacketConn) Close() error { type listenAddr struct { ln Listener acceptCh chan acceptResponse - onCloseFunc func() error + onCloseFunc OnCloseFunc +} + +type canCreateStreamListener interface { + NewStreamListener(onCloseFunc OnCloseFunc) StreamListener } +var _ canCreateStreamListener = (*listenAddr)(nil) + // NewStreamListener creates a new [StreamListener]. -func (cl *listenAddr) NewStreamListener() StreamListener { - if ln, ok := cl.ln.(StreamListener); ok { +func (la *listenAddr) NewStreamListener(onCloseFunc OnCloseFunc) StreamListener { + if ln, ok := la.ln.(StreamListener); ok { return &virtualStreamListener{ listener: ln, - acceptCh: cl.acceptCh, + acceptCh: la.acceptCh, closeCh: make(chan struct{}), - onCloseFunc: cl.Close, + onCloseFunc: onCloseFunc, } } return nil } +type canCreatePacketListener interface { + NewPacketListener(onCloseFunc OnCloseFunc) net.PacketConn +} + +var _ canCreatePacketListener = (*listenAddr)(nil) + // NewPacketListener creates a new [net.PacketConn]. -func (cl *listenAddr) NewPacketListener() net.PacketConn { +func (cl *listenAddr) NewPacketListener(onCloseFunc OnCloseFunc) net.PacketConn { if ln, ok := cl.ln.(net.PacketConn); ok { return &virtualPacketConn{ PacketConn: ln, - onCloseFunc: cl.Close, + onCloseFunc: onCloseFunc, } } return nil } -func (cl *listenAddr) Close() error { - if err := cl.ln.Close(); err != nil { +func (la *listenAddr) Close() error { + if err := la.ln.Close(); err != nil { return err } - return cl.onCloseFunc() + return la.onCloseFunc() } // ListenerManager holds and manages the state of shared listeners. @@ -154,92 +166,106 @@ type ListenerManager interface { } type listenerManager struct { - listeners map[string]RefCount[*listenAddr] + listeners map[string]RefCount[Listener] listenersMu sync.Mutex } // NewListenerManager creates a new [ListenerManger]. func NewListenerManager() ListenerManager { return &listenerManager{ - listeners: make(map[string]RefCount[*listenAddr]), + listeners: make(map[string]RefCount[Listener]), } } -// ListenStream creates a new stream listener for a given network and address. -// -// Listeners can overlap one another, because during config changes the new -// config is started before the old config is destroyed. This is done by using -// reusable listener wrappers, which do not actually close the underlying socket -// until all uses of the shared listener have been closed. -func (m *listenerManager) ListenStream(network string, addr string) (StreamListener, error) { +func (m *listenerManager) getOrCreate(key string, createFunc func() (Listener, error)) (RefCount[Listener], error) { m.listenersMu.Lock() defer m.listenersMu.Unlock() - lnKey := network + "/" + addr - if lnRefCount, ok := m.listeners[lnKey]; ok { - lnAddr := lnRefCount.Acquire() - return lnAddr.NewStreamListener(), nil + if lnRefCount, exists := m.listeners[key]; exists { + return lnRefCount.Acquire(), nil } - tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + ln, err := createFunc() if err != nil { return nil, err } - ln, err := net.ListenTCP(network, tcpAddr) + lnRefCount := NewRefCount(ln) + m.listeners[key] = lnRefCount + return lnRefCount, nil +} + +// ListenStream creates a new stream listener for a given network and address. +// +// Listeners can overlap one another, because during config changes the new +// config is started before the old config is destroyed. This is done by using +// reusable listener wrappers, which do not actually close the underlying socket +// until all uses of the shared listener have been closed. +func (m *listenerManager) ListenStream(network string, addr string) (StreamListener, error) { + lnKey := network + "/" + addr + lnRefCount, err := m.getOrCreate(lnKey, func() (Listener, error) { + tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + ln, err := net.ListenTCP(network, tcpAddr) + if err != nil { + return nil, err + } + streamLn := &TCPListener{ln} + lnAddr := &listenAddr{ + ln: streamLn, + acceptCh: make(chan acceptResponse), + onCloseFunc: func() error { + m.delete(lnKey) + return nil + }, + } + go func() { + for { + conn, err := streamLn.AcceptStream() + if errors.Is(err, net.ErrClosed) { + close(lnAddr.acceptCh) + return + } + lnAddr.acceptCh <- acceptResponse{conn, err} + } + }() + return lnAddr, nil + }) if err != nil { return nil, err } - - streamLn := &TCPListener{ln} - lnAddr := &listenAddr{ - ln: streamLn, - acceptCh: make(chan acceptResponse), - onCloseFunc: func() error { - m.delete(lnKey) - return nil - }, + if lnAddr, ok := lnRefCount.Get().(canCreateStreamListener); ok { + return lnAddr.NewStreamListener(lnRefCount.Close), nil } - go func() { - for { - conn, err := streamLn.AcceptStream() - if errors.Is(err, net.ErrClosed) { - close(lnAddr.acceptCh) - return - } - lnAddr.acceptCh <- acceptResponse{conn, err} - } - }() - m.listeners[lnKey] = NewRefCount(lnAddr) - return lnAddr.NewStreamListener(), nil + return nil, fmt.Errorf("unable to create stream listener for %s", lnKey) } // ListenPacket creates a new packet listener for a given network and address. // // See notes on [ListenStream]. func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketConn, error) { - m.listenersMu.Lock() - defer m.listenersMu.Unlock() - lnKey := network + "/" + addr - if lnRefCount, ok := m.listeners[lnKey]; ok { - lnAddr := lnRefCount.Acquire() - return lnAddr.NewPacketListener(), nil - } - - pc, err := net.ListenPacket(network, addr) + lnRefCount, err := m.getOrCreate(lnKey, func() (Listener, error) { + pc, err := net.ListenPacket(network, addr) + if err != nil { + return nil, err + } + return &listenAddr{ + ln: pc, + onCloseFunc: func() error { + m.delete(lnKey) + return nil + }, + }, nil + }) if err != nil { return nil, err } - - lnAddr := &listenAddr{ - ln: pc, - onCloseFunc: func() error { - m.delete(lnKey) - return nil - }, + if lnAddr, ok := lnRefCount.Get().(canCreatePacketListener); ok { + return lnAddr.NewPacketListener(lnRefCount.Close), nil } - m.listeners[lnKey] = NewRefCount(lnAddr) - return lnAddr.NewPacketListener(), nil + return nil, fmt.Errorf("unable to create packet listener for %s", lnKey) } func (m *listenerManager) delete(key string) { @@ -254,7 +280,9 @@ type RefCount[T io.Closer] interface { io.Closer // Acquire increases the ref count and returns the wrapped object. - Acquire() T + Acquire() RefCount[T] + + Get() T } type refCount[T io.Closer] struct { @@ -279,7 +307,11 @@ func (r refCount[T]) Close() error { return nil } -func (r refCount[T]) Acquire() T { +func (r refCount[T]) Acquire() RefCount[T] { r.count.Add(1) + return r +} + +func (r refCount[T]) Get() T { return r.value } From c678372d7a9d0feeb12b5d92dd29d6caea91dfd5 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 31 Jul 2024 16:05:00 -0400 Subject: [PATCH 101/119] Add simple test case for early closing of stream listener. --- service/listeners.go | 19 ++++++++++--------- service/listeners_test.go | 12 ++++++++++++ 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index e5d48e45..44fa7ec5 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -161,7 +161,17 @@ func (la *listenAddr) Close() error { // ListenerManager holds and manages the state of shared listeners. type ListenerManager interface { + // ListenStream creates a new stream listener for a given network and address. + // + // Listeners can overlap one another, because during config changes the new + // config is started before the old config is destroyed. This is done by using + // reusable listener wrappers, which do not actually close the underlying socket + // until all uses of the shared listener have been closed. ListenStream(network string, addr string) (StreamListener, error) + + // ListenPacket creates a new packet listener for a given network and address. + // + // See notes on [ListenStream]. ListenPacket(network string, addr string) (net.PacketConn, error) } @@ -194,12 +204,6 @@ func (m *listenerManager) getOrCreate(key string, createFunc func() (Listener, e return lnRefCount, nil } -// ListenStream creates a new stream listener for a given network and address. -// -// Listeners can overlap one another, because during config changes the new -// config is started before the old config is destroyed. This is done by using -// reusable listener wrappers, which do not actually close the underlying socket -// until all uses of the shared listener have been closed. func (m *listenerManager) ListenStream(network string, addr string) (StreamListener, error) { lnKey := network + "/" + addr lnRefCount, err := m.getOrCreate(lnKey, func() (Listener, error) { @@ -241,9 +245,6 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe return nil, fmt.Errorf("unable to create stream listener for %s", lnKey) } -// ListenPacket creates a new packet listener for a given network and address. -// -// See notes on [ListenStream]. func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketConn, error) { lnKey := network + "/" + addr lnRefCount, err := m.getOrCreate(lnKey, func() (Listener, error) { diff --git a/service/listeners_test.go b/service/listeners_test.go index 9005eb34..3e2c3e64 100644 --- a/service/listeners_test.go +++ b/service/listeners_test.go @@ -15,11 +15,23 @@ package service import ( + "net" "testing" "github.com/stretchr/testify/require" ) +func TestListenerManagerStreamListenerEarlyClose(t *testing.T) { + m := NewListenerManager() + ln, err := m.ListenStream("tcp", "127.0.0.1:0") + require.NoError(t, err) + + ln.Close() + _, err = ln.AcceptStream() + + require.ErrorIs(t, err, net.ErrClosed) +} + type testRefCount struct { onCloseFunc func() } From e41ababc48eaa64c2b838a4087ca8019743541a8 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 31 Jul 2024 16:28:32 -0400 Subject: [PATCH 102/119] Add tests for creating stream listeners. --- service/listeners_test.go | 55 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/service/listeners_test.go b/service/listeners_test.go index 3e2c3e64..384126ad 100644 --- a/service/listeners_test.go +++ b/service/listeners_test.go @@ -15,6 +15,7 @@ package service import ( + "fmt" "net" "testing" @@ -32,6 +33,60 @@ func TestListenerManagerStreamListenerEarlyClose(t *testing.T) { require.ErrorIs(t, err, net.ErrClosed) } +func writeTestPayload(ln StreamListener) error { + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + return fmt.Errorf("Failed to dial %v: %v", ln.Addr().String(), err) + } + if _, err = conn.Write(makeTestPayload(50)); err != nil { + return fmt.Errorf("Failed to write to connection: %v", err) + } + conn.Close() + return nil +} + +func TestListenerManagerStreamListenerNotClosedIfStillInUse(t *testing.T) { + m := NewListenerManager() + ln, err := m.ListenStream("tcp", "127.0.0.1:0") + require.NoError(t, err) + ln2, err := m.ListenStream("tcp", "127.0.0.1:0") + require.NoError(t, err) + + // Close only the first listener. + ln.Close() + done := make(chan struct{}) + go func() { + ln2.AcceptStream() + done <- struct{}{} + }() + + err = writeTestPayload(ln2) + require.NoError(t, err) + + <-done +} + +func TestListenerManagerStreamListenerCreatesListenerOnDemand(t *testing.T) { + m := NewListenerManager() + // Create a listener and immediately close it. + ln, err := m.ListenStream("tcp", "127.0.0.1:0") + require.NoError(t, err) + ln.Close() + // Now create another listener on the same address. + ln2, err := m.ListenStream("tcp", "127.0.0.1:0") + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + ln2.AcceptStream() + done <- struct{}{} + }() + err = writeTestPayload(ln2) + require.NoError(t, err) + + <-done +} + type testRefCount struct { onCloseFunc func() } From f9432d23a51c29d7b3e7021fa7c0ac49675fb96f Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 31 Jul 2024 17:35:42 -0400 Subject: [PATCH 103/119] Create handlers on demand. --- cmd/outline-ss-server/main.go | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index b706ccb9..4a38ad40 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -256,16 +256,10 @@ func (s *SSServer) runConfig(config Config) (func(), error) { } for _, serviceConfig := range config.Services { - // TODO: Create the handlers on demand. - sh, err := s.NewShadowsocksStreamHandlerFromConfig(serviceConfig) - if err != nil { - return err - } - ph, err := s.NewShadowsocksPacketHandlerFromConfig(serviceConfig) - if err != nil { - return err - } - + var ( + sh service.StreamHandler + ph service.PacketHandler + ) for _, lnConfig := range serviceConfig.Listeners { switch lnConfig.Type { case listenerTypeTCP: @@ -274,6 +268,12 @@ func (s *SSServer) runConfig(config Config) (func(), error) { return err } logger.Infof("TCP service listening on %s", ln.Addr().String()) + if sh == nil { + sh, err = s.NewShadowsocksStreamHandlerFromConfig(serviceConfig) + if err != nil { + return err + } + } go service.StreamServe(ln.AcceptStream, sh.Handle) case listenerTypeUDP: pc, err := lnSet.ListenPacket(lnConfig.Address) @@ -281,6 +281,12 @@ func (s *SSServer) runConfig(config Config) (func(), error) { return err } logger.Infof("UDP service listening on %v", pc.LocalAddr().String()) + if ph == nil { + ph, err = s.NewShadowsocksPacketHandlerFromConfig(serviceConfig) + if err != nil { + return err + } + } go ph.Handle(pc) } } From 6b11f4ffc4bd2c9d25500a8512c35da742211715 Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 2 Aug 2024 11:51:35 -0400 Subject: [PATCH 104/119] Refactor create methods. --- service/listeners.go | 117 ++++++++++++++++++++++++------------------- 1 file changed, 66 insertions(+), 51 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 44fa7ec5..8f0882ba 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -187,86 +187,97 @@ func NewListenerManager() ListenerManager { } } -func (m *listenerManager) getOrCreate(key string, createFunc func() (Listener, error)) (RefCount[Listener], error) { +func (m *listenerManager) newStreamListener(network string, addr string) (Listener, error) { + tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + ln, err := net.ListenTCP(network, tcpAddr) + if err != nil { + return nil, err + } + streamLn := &TCPListener{ln} + lnAddr := &listenAddr{ + ln: streamLn, + acceptCh: make(chan acceptResponse), + onCloseFunc: func() error { + m.delete(listenerKey(network, addr)) + return nil + }, + } + go func() { + for { + conn, err := streamLn.AcceptStream() + if errors.Is(err, net.ErrClosed) { + close(lnAddr.acceptCh) + return + } + lnAddr.acceptCh <- acceptResponse{conn, err} + } + }() + return lnAddr, nil +} + +func (m *listenerManager) newPacketListener(network string, addr string) (Listener, error) { + pc, err := net.ListenPacket(network, addr) + if err != nil { + return nil, err + } + return &listenAddr{ + ln: pc, + onCloseFunc: func() error { + m.delete(listenerKey(network, addr)) + return nil + }, + }, nil +} + +func (m *listenerManager) getListener(network string, addr string) (RefCount[Listener], error) { m.listenersMu.Lock() defer m.listenersMu.Unlock() - if lnRefCount, exists := m.listeners[key]; exists { + lnKey := listenerKey(network, addr) + if lnRefCount, exists := m.listeners[lnKey]; exists { return lnRefCount.Acquire(), nil } - ln, err := createFunc() + var ( + ln Listener + err error + ) + if network == "tcp" { + ln, err = m.newStreamListener(network, addr) + } else { + ln, err = m.newPacketListener(network, addr) + } if err != nil { return nil, err } lnRefCount := NewRefCount(ln) - m.listeners[key] = lnRefCount + m.listeners[lnKey] = lnRefCount return lnRefCount, nil } func (m *listenerManager) ListenStream(network string, addr string) (StreamListener, error) { - lnKey := network + "/" + addr - lnRefCount, err := m.getOrCreate(lnKey, func() (Listener, error) { - tcpAddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return nil, err - } - ln, err := net.ListenTCP(network, tcpAddr) - if err != nil { - return nil, err - } - streamLn := &TCPListener{ln} - lnAddr := &listenAddr{ - ln: streamLn, - acceptCh: make(chan acceptResponse), - onCloseFunc: func() error { - m.delete(lnKey) - return nil - }, - } - go func() { - for { - conn, err := streamLn.AcceptStream() - if errors.Is(err, net.ErrClosed) { - close(lnAddr.acceptCh) - return - } - lnAddr.acceptCh <- acceptResponse{conn, err} - } - }() - return lnAddr, nil - }) + lnRefCount, err := m.getListener(network, addr) if err != nil { return nil, err } if lnAddr, ok := lnRefCount.Get().(canCreateStreamListener); ok { return lnAddr.NewStreamListener(lnRefCount.Close), nil } - return nil, fmt.Errorf("unable to create stream listener for %s", lnKey) + return nil, fmt.Errorf("unable to create stream listener for %s/%s", network, addr) } func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketConn, error) { - lnKey := network + "/" + addr - lnRefCount, err := m.getOrCreate(lnKey, func() (Listener, error) { - pc, err := net.ListenPacket(network, addr) - if err != nil { - return nil, err - } - return &listenAddr{ - ln: pc, - onCloseFunc: func() error { - m.delete(lnKey) - return nil - }, - }, nil - }) + lnRefCount, err := m.getListener(network, addr) if err != nil { return nil, err } if lnAddr, ok := lnRefCount.Get().(canCreatePacketListener); ok { return lnAddr.NewPacketListener(lnRefCount.Close), nil } - return nil, fmt.Errorf("unable to create packet listener for %s", lnKey) + return nil, fmt.Errorf("unable to create packet listener for %s/%s", network, addr) } func (m *listenerManager) delete(key string) { @@ -275,6 +286,10 @@ func (m *listenerManager) delete(key string) { m.listenersMu.Unlock() } +func listenerKey(network string, addr string) string { + return network + "/" + addr +} + // RefCount is an atomic reference counter that can be used to track a shared // [io.Closer] resource. type RefCount[T io.Closer] interface { From 992471380885abcecbc95026b70be67a3b93cdb5 Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 2 Aug 2024 12:13:14 -0400 Subject: [PATCH 105/119] Use `errors.Is()`. --- service/listeners.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/service/listeners.go b/service/listeners.go index 204996e4..387fbb3e 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -139,7 +139,7 @@ func (l *ProxyStreamListener) AcceptStream() (ClientStreamConn, error) { } r := bufio.NewReader(conn) h, err := proxyproto.Read(r) - if err == proxyproto.ErrNoProxyProtocol { + if errors.Is(err, proxyproto.ErrNoProxyProtocol) { logger.Warningf("Received connection from %v without proxy header.", conn.RemoteAddr()) return conn, nil } From c8d8332c3e4eb295e09d4ff444fb30a51b335c2e Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 2 Aug 2024 18:41:53 -0400 Subject: [PATCH 106/119] Fix proxy post-merge. --- cmd/outline-ss-server/config.go | 2 +- cmd/outline-ss-server/main.go | 97 +++++++++++++++++++++------------ service/listeners.go | 34 +++++++----- 3 files changed, 82 insertions(+), 51 deletions(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index 64e70499..5cc4c34c 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -38,7 +38,7 @@ const ( type ListenerConfig struct { Type ListenerType Address string - Listeners []*ListenerConfig + Listeners []ListenerConfig } // Validate checks that the config is valid. diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 4a38ad40..ca184209 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -153,7 +153,7 @@ type listenerSet struct { // ListenStream announces on a given TCP network address. Trying to listen on // the same address twice will result in an error. -func (ls *listenerSet) ListenStream(addr string) (service.StreamListener, error) { +func (ls *listenerSet) ListenStream(addr string, proxy bool) (service.StreamListener, error) { ls.listenersMu.Lock() defer ls.listenersMu.Unlock() @@ -161,7 +161,7 @@ func (ls *listenerSet) ListenStream(addr string) (service.StreamListener, error) if _, exists := ls.listenerCloseFuncs[lnKey]; exists { return nil, fmt.Errorf("listener %s already exists", lnKey) } - ln, err := ls.manager.ListenStream("tcp", addr) + ln, err := ls.manager.ListenStream("tcp", addr, proxy) if err != nil { return nil, err } @@ -171,7 +171,7 @@ func (ls *listenerSet) ListenStream(addr string) (service.StreamListener, error) // ListenPacket announces on a given UDP network address. Trying to listen on // the same address twice will result in an error. -func (ls *listenerSet) ListenPacket(addr string) (net.PacketConn, error) { +func (ls *listenerSet) ListenPacket(addr string, proxy bool) (net.PacketConn, error) { ls.listenersMu.Lock() defer ls.listenersMu.Unlock() @@ -179,7 +179,7 @@ func (ls *listenerSet) ListenPacket(addr string) (net.PacketConn, error) { if _, exists := ls.listenerCloseFuncs[lnKey]; exists { return nil, fmt.Errorf("listener %s already exists", lnKey) } - ln, err := ls.manager.ListenPacket("udp", addr) + ln, err := ls.manager.ListenPacket("udp", addr, proxy) if err != nil { return nil, err } @@ -239,18 +239,18 @@ func (s *SSServer) runConfig(config Config) (func(), error) { ciphers.Update(cipherList) sh := s.NewShadowsocksStreamHandler(ciphers) - ln, err := lnSet.ListenStream(addr) + ln, err := lnSet.ListenStream(addr, false) if err != nil { return err } - logger.Infof("TCP service listening on %v", ln.Addr().String()) + logger.Infof("Shadowsocks service listening on tcp/%s", ln.Addr().String()) go service.StreamServe(ln.AcceptStream, sh.Handle) - pc, err := lnSet.ListenPacket(addr) + pc, err := lnSet.ListenPacket(addr, false) if err != nil { return err } - logger.Infof("UDP service listening on %v", pc.LocalAddr().String()) + logger.Infof("Shadowsocks service listening on udp/%s", pc.LocalAddr().String()) ph := s.NewShadowsocksPacketHandler(ciphers) go ph.Handle(pc) } @@ -261,33 +261,9 @@ func (s *SSServer) runConfig(config Config) (func(), error) { ph service.PacketHandler ) for _, lnConfig := range serviceConfig.Listeners { - switch lnConfig.Type { - case listenerTypeTCP: - ln, err := lnSet.ListenStream(lnConfig.Address) - if err != nil { - return err - } - logger.Infof("TCP service listening on %s", ln.Addr().String()) - if sh == nil { - sh, err = s.NewShadowsocksStreamHandlerFromConfig(serviceConfig) - if err != nil { - return err - } - } - go service.StreamServe(ln.AcceptStream, sh.Handle) - case listenerTypeUDP: - pc, err := lnSet.ListenPacket(lnConfig.Address) - if err != nil { - return err - } - logger.Infof("UDP service listening on %v", pc.LocalAddr().String()) - if ph == nil { - ph, err = s.NewShadowsocksPacketHandlerFromConfig(serviceConfig) - if err != nil { - return err - } - } - go ph.Handle(pc) + err := s.startListenerFromConfig(lnSet, serviceConfig, lnConfig, false, sh, ph) + if err != nil { + return err } } totalCipherCount += len(serviceConfig.Keys) @@ -313,6 +289,57 @@ func (s *SSServer) runConfig(config Config) (func(), error) { }, nil } +func (s *SSServer) startListenerFromConfig(lnSet *listenerSet, serviceConfig ServiceConfig, lnConfig ListenerConfig, proxy bool, sh service.StreamHandler, ph service.PacketHandler) error { + lnLogFunc := func(key string) { + var serviceToLog string + if proxy { + serviceToLog = "Proxy" + } else { + serviceToLog = "Shadowsocks" + } + logger.Infof("%s service listening on %s", serviceToLog, key) + } + switch lnConfig.Type { + case listenerTypeTCP: + ln, err := lnSet.ListenStream(lnConfig.Address, proxy) + if err != nil { + return err + } + lnLogFunc("tcp/" + ln.Addr().String()) + if sh == nil { + sh, err = s.NewShadowsocksStreamHandlerFromConfig(serviceConfig) + if err != nil { + return err + } + } + go service.StreamServe(ln.AcceptStream, sh.Handle) + + case listenerTypeUDP: + pc, err := lnSet.ListenPacket(lnConfig.Address, proxy) + if err != nil { + return err + } + lnLogFunc("udp/" + pc.LocalAddr().String()) + if ph == nil { + ph, err = s.NewShadowsocksPacketHandlerFromConfig(serviceConfig) + if err != nil { + return err + } + } + go ph.Handle(pc) + + case listenerTypeProxy: + for _, proxyLnConfig := range lnConfig.Listeners { + err := s.startListenerFromConfig(lnSet, serviceConfig, proxyLnConfig, true, sh, ph) + if err != nil { + return err + } + } + } + + return nil +} + // Stop stops serving the current config. func (s *SSServer) Stop() { go s.stopConfig() diff --git a/service/listeners.go b/service/listeners.go index 387fbb3e..e1bc763f 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -76,7 +76,7 @@ var _ StreamListener = (*TCPListener)(nil) func (t *TCPListener) AcceptStream() (ClientStreamConn, error) { conn, err := t.ln.AcceptTCP() - return &clientStreamConn{StreamConn: conn}, err + return &clientStreamConn{StreamConn: conn, clientAddr: conn.RemoteAddr()}, err } func (t *TCPListener) Close() error { @@ -88,7 +88,7 @@ func (t *TCPListener) Addr() net.Addr { } type acceptResponse struct { - conn transport.StreamConn + conn ClientStreamConn err error } @@ -109,7 +109,7 @@ func (sl *virtualStreamListener) AcceptStream() (ClientStreamConn, error) { if !ok { return nil, net.ErrClosed } - return &clientStreamConn{StreamConn: acceptResponse.conn}, acceptResponse.err + return acceptResponse.conn, acceptResponse.err case <-sl.closeCh: return nil, net.ErrClosed } @@ -215,12 +215,12 @@ type ListenerManager interface { // config is started before the old config is destroyed. This is done by using // reusable listener wrappers, which do not actually close the underlying socket // until all uses of the shared listener have been closed. - ListenStream(network string, addr string) (StreamListener, error) + ListenStream(network string, addr string, proxy bool) (StreamListener, error) // ListenPacket creates a new packet listener for a given network and address. // // See notes on [ListenStream]. - ListenPacket(network string, addr string) (net.PacketConn, error) + ListenPacket(network string, addr string, proxy bool) (net.PacketConn, error) } type listenerManager struct { @@ -235,7 +235,7 @@ func NewListenerManager() ListenerManager { } } -func (m *listenerManager) newStreamListener(network string, addr string) (Listener, error) { +func (m *listenerManager) newStreamListener(network string, addr string, proxy bool) (Listener, error) { tcpAddr, err := net.ResolveTCPAddr("tcp", addr) if err != nil { return nil, err @@ -244,7 +244,11 @@ func (m *listenerManager) newStreamListener(network string, addr string) (Listen if err != nil { return nil, err } - streamLn := &TCPListener{ln} + var streamLn StreamListener + streamLn = &TCPListener{ln} + if proxy { + streamLn = &ProxyStreamListener{StreamListener: streamLn} + } lnAddr := &listenAddr{ ln: streamLn, acceptCh: make(chan acceptResponse), @@ -266,7 +270,7 @@ func (m *listenerManager) newStreamListener(network string, addr string) (Listen return lnAddr, nil } -func (m *listenerManager) newPacketListener(network string, addr string) (Listener, error) { +func (m *listenerManager) newPacketListener(network string, addr string, proxy bool) (Listener, error) { pc, err := net.ListenPacket(network, addr) if err != nil { return nil, err @@ -280,7 +284,7 @@ func (m *listenerManager) newPacketListener(network string, addr string) (Listen }, nil } -func (m *listenerManager) getListener(network string, addr string) (RefCount[Listener], error) { +func (m *listenerManager) getListener(network string, addr string, proxy bool) (RefCount[Listener], error) { m.listenersMu.Lock() defer m.listenersMu.Unlock() @@ -294,9 +298,9 @@ func (m *listenerManager) getListener(network string, addr string) (RefCount[Lis err error ) if network == "tcp" { - ln, err = m.newStreamListener(network, addr) + ln, err = m.newStreamListener(network, addr, proxy) } else { - ln, err = m.newPacketListener(network, addr) + ln, err = m.newPacketListener(network, addr, proxy) } if err != nil { return nil, err @@ -306,8 +310,8 @@ func (m *listenerManager) getListener(network string, addr string) (RefCount[Lis return lnRefCount, nil } -func (m *listenerManager) ListenStream(network string, addr string) (StreamListener, error) { - lnRefCount, err := m.getListener(network, addr) +func (m *listenerManager) ListenStream(network string, addr string, proxy bool) (StreamListener, error) { + lnRefCount, err := m.getListener(network, addr, proxy) if err != nil { return nil, err } @@ -317,8 +321,8 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe return nil, fmt.Errorf("unable to create stream listener for %s/%s", network, addr) } -func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketConn, error) { - lnRefCount, err := m.getListener(network, addr) +func (m *listenerManager) ListenPacket(network string, addr string, proxy bool) (net.PacketConn, error) { + lnRefCount, err := m.getListener(network, addr, proxy) if err != nil { return nil, err } From 827be3cea3dc7ca0073befb0da4fda8d28aea68d Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 5 Aug 2024 12:27:16 -0400 Subject: [PATCH 107/119] Check if the command is UNKNOWN (v1) or LOCAL (v2). --- service/listeners.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index e1bc763f..4037771e 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -24,7 +24,7 @@ import ( "sync/atomic" "github.com/Jigsaw-Code/outline-sdk/transport" - proxyproto "github.com/pires/go-proxyproto" + "github.com/pires/go-proxyproto" ) // The implementations of listeners for different network types are not @@ -138,15 +138,18 @@ func (l *ProxyStreamListener) AcceptStream() (ClientStreamConn, error) { return nil, err } r := bufio.NewReader(conn) - h, err := proxyproto.Read(r) + header, err := proxyproto.Read(r) if errors.Is(err, proxyproto.ErrNoProxyProtocol) { logger.Warningf("Received connection from %v without proxy header.", conn.RemoteAddr()) return conn, nil } - if err != nil { + if header == nil || err != nil { return nil, fmt.Errorf("error parsing proxy header: %v", err) } - return &clientStreamConn{StreamConn: conn, clientAddr: h.SourceAddr}, nil + if header.Command.IsLocal() { + return conn, nil + } + return &clientStreamConn{StreamConn: conn, clientAddr: header.SourceAddr}, nil } type virtualPacketConn struct { From 2622b5e5f8fa22810d118e402ce37aeba29caca0 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 5 Aug 2024 14:20:23 -0400 Subject: [PATCH 108/119] Add test scenarios for client addr. --- service/listeners.go | 7 +--- service/listeners_test.go | 87 ++++++++++++++++++++++++++++++++++++--- service/tcp.go | 2 +- 3 files changed, 85 insertions(+), 11 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 4037771e..02258ae1 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -45,10 +45,7 @@ type clientStreamConn struct { } func (c *clientStreamConn) ClientAddr() net.Addr { - if c.clientAddr != nil { - return c.clientAddr - } - return c.StreamConn.RemoteAddr() + return c.clientAddr } // StreamListener is a network listener for stream-oriented protocols that @@ -124,7 +121,7 @@ func (sl *virtualStreamListener) Addr() net.Addr { return sl.listener.Addr() } -// ProxyListener wraps a [StreamListener] and fetches the source of the connection from the PROXY +// ProxyStreamListener wraps a [StreamListener] and fetches the source of the connection from the PROXY // protocol header string. See https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt. type ProxyStreamListener struct { StreamListener diff --git a/service/listeners_test.go b/service/listeners_test.go index 384126ad..e4e8c7c6 100644 --- a/service/listeners_test.go +++ b/service/listeners_test.go @@ -19,12 +19,89 @@ import ( "net" "testing" + "github.com/pires/go-proxyproto" "github.com/stretchr/testify/require" ) +func TestDirectListenerSetsRemoteAddrAsClientAddr(t *testing.T) { + listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + + go func() { + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoErrorf(t, err, "Failed to dial %v: %v", listener.Addr(), err) + conn.Write(makeTestPayload(50)) + conn.Close() + }() + + ln := &TCPListener{listener} + conn, err := ln.AcceptStream() + require.NoError(t, err) + require.Equal(t, conn.RemoteAddr(), conn.ClientAddr()) +} + +func TestProxyProtocolListenerParsesSourceAddressAsClientAddr(t *testing.T) { + listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + + sourceAddr := &net.TCPAddr{ + IP: net.ParseIP("10.1.1.1"), + Port: 1000, + } + go func() { + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoErrorf(t, err, "Failed to dial %v: %v", listener.Addr(), err) + header := &proxyproto.Header{ + Version: 2, + Command: proxyproto.PROXY, + TransportProtocol: proxyproto.TCPv4, + SourceAddr: sourceAddr, + DestinationAddr: conn.RemoteAddr(), + } + header.WriteTo(conn) + conn.Write(makeTestPayload(50)) + conn.Close() + }() + + ln := &ProxyStreamListener{StreamListener: &TCPListener{listener}} + conn, err := ln.AcceptStream() + require.NoError(t, err) + require.True(t, sourceAddr.IP.Equal(conn.ClientAddr().(*net.TCPAddr).IP)) + require.Equal(t, sourceAddr.Port, conn.ClientAddr().(*net.TCPAddr).Port) +} + +func TestProxyProtocolListenerUsesRemoteAddrAsClientAddrIfLocalHeader(t *testing.T) { + listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + + go func() { + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoErrorf(t, err, "Failed to dial %v: %v", listener.Addr(), err) + + header := &proxyproto.Header{ + Version: 2, + Command: proxyproto.LOCAL, + TransportProtocol: proxyproto.UNSPEC, + SourceAddr: &net.TCPAddr{ + IP: net.ParseIP("10.1.1.1"), + Port: 1000, + }, + DestinationAddr: conn.RemoteAddr(), + } + header.WriteTo(conn) + conn.Write(makeTestPayload(50)) + conn.Close() + }() + + ln := &ProxyStreamListener{StreamListener: &TCPListener{listener}} + conn, err := ln.AcceptStream() + require.NoError(t, err) + require.Equal(t, conn.RemoteAddr(), conn.ClientAddr()) +} + func TestListenerManagerStreamListenerEarlyClose(t *testing.T) { m := NewListenerManager() - ln, err := m.ListenStream("tcp", "127.0.0.1:0") + ln, err := m.ListenStream("tcp", "127.0.0.1:0", false) require.NoError(t, err) ln.Close() @@ -47,9 +124,9 @@ func writeTestPayload(ln StreamListener) error { func TestListenerManagerStreamListenerNotClosedIfStillInUse(t *testing.T) { m := NewListenerManager() - ln, err := m.ListenStream("tcp", "127.0.0.1:0") + ln, err := m.ListenStream("tcp", "127.0.0.1:0", false) require.NoError(t, err) - ln2, err := m.ListenStream("tcp", "127.0.0.1:0") + ln2, err := m.ListenStream("tcp", "127.0.0.1:0", false) require.NoError(t, err) // Close only the first listener. @@ -69,11 +146,11 @@ func TestListenerManagerStreamListenerNotClosedIfStillInUse(t *testing.T) { func TestListenerManagerStreamListenerCreatesListenerOnDemand(t *testing.T) { m := NewListenerManager() // Create a listener and immediately close it. - ln, err := m.ListenStream("tcp", "127.0.0.1:0") + ln, err := m.ListenStream("tcp", "127.0.0.1:0", false) require.NoError(t, err) ln.Close() // Now create another listener on the same address. - ln2, err := m.ListenStream("tcp", "127.0.0.1:0") + ln2, err := m.ListenStream("tcp", "127.0.0.1:0", false) require.NoError(t, err) done := make(chan struct{}) diff --git a/service/tcp.go b/service/tcp.go index 054cc01e..b4188fa0 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -215,7 +215,7 @@ type StreamAcceptFunc func() (ClientStreamConn, error) func WrapStreamAcceptFunc[T transport.StreamConn](f func() (T, error)) StreamAcceptFunc { return func() (ClientStreamConn, error) { c, err := f() - return &clientStreamConn{StreamConn: c}, err + return &clientStreamConn{StreamConn: c, clientAddr: c.RemoteAddr()}, err } } From fe8bbdda37bbb3927761690bf5ded40ed7b11ff7 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 5 Aug 2024 15:22:13 -0400 Subject: [PATCH 109/119] Address review comments. --- cmd/outline-ss-server/main.go | 20 +++++++-------- service/listeners.go | 47 +++++++++++++++++++---------------- service/listeners_test.go | 10 ++++---- 3 files changed, 40 insertions(+), 37 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index ec6b66b0..f180044c 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -101,17 +101,17 @@ type listenerSet struct { listenersMu sync.Mutex } -// ListenStream announces on a given TCP network address. Trying to listen on -// the same address twice will result in an error. +// ListenStream announces on a given network address. Trying to listen for stream connections +// on the same address twice will result in an error. func (ls *listenerSet) ListenStream(addr string) (service.StreamListener, error) { ls.listenersMu.Lock() defer ls.listenersMu.Unlock() - lnKey := "tcp/" + addr + lnKey := "stream-" + addr if _, exists := ls.listenerCloseFuncs[lnKey]; exists { - return nil, fmt.Errorf("listener %s already exists", lnKey) + return nil, fmt.Errorf("stream listener for %s already exists", addr) } - ln, err := ls.manager.ListenStream("tcp", addr) + ln, err := ls.manager.ListenStream(addr) if err != nil { return nil, err } @@ -119,17 +119,17 @@ func (ls *listenerSet) ListenStream(addr string) (service.StreamListener, error) return ln, nil } -// ListenPacket announces on a given UDP network address. Trying to listen on -// the same address twice will result in an error. +// ListenPacket announces on a given network address. Trying to listen for packet connections +// on the same address twice will result in an error. func (ls *listenerSet) ListenPacket(addr string) (net.PacketConn, error) { ls.listenersMu.Lock() defer ls.listenersMu.Unlock() - lnKey := "udp/" + addr + lnKey := "packet-" + addr if _, exists := ls.listenerCloseFuncs[lnKey]; exists { - return nil, fmt.Errorf("listener %s already exists", lnKey) + return nil, fmt.Errorf("packet listener for %s already exists", addr) } - ln, err := ls.manager.ListenPacket("udp", addr) + ln, err := ls.manager.ListenPacket(addr) if err != nil { return nil, err } diff --git a/service/listeners.go b/service/listeners.go index 8f0882ba..d8d58192 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -72,8 +72,8 @@ type acceptResponse struct { type OnCloseFunc func() error type virtualStreamListener struct { - listener StreamListener - acceptCh chan acceptResponse + addr net.Addr + acceptCh <-chan acceptResponse closeCh chan struct{} onCloseFunc OnCloseFunc } @@ -93,12 +93,13 @@ func (sl *virtualStreamListener) AcceptStream() (transport.StreamConn, error) { } func (sl *virtualStreamListener) Close() error { + sl.acceptCh = nil close(sl.closeCh) return sl.onCloseFunc() } func (sl *virtualStreamListener) Addr() net.Addr { - return sl.listener.Addr() + return sl.addr } type virtualPacketConn struct { @@ -126,7 +127,7 @@ var _ canCreateStreamListener = (*listenAddr)(nil) func (la *listenAddr) NewStreamListener(onCloseFunc OnCloseFunc) StreamListener { if ln, ok := la.ln.(StreamListener); ok { return &virtualStreamListener{ - listener: ln, + addr: ln.Addr(), acceptCh: la.acceptCh, closeCh: make(chan struct{}), onCloseFunc: onCloseFunc, @@ -161,18 +162,18 @@ func (la *listenAddr) Close() error { // ListenerManager holds and manages the state of shared listeners. type ListenerManager interface { - // ListenStream creates a new stream listener for a given network and address. + // ListenStream creates a new stream listener for a given address. // // Listeners can overlap one another, because during config changes the new // config is started before the old config is destroyed. This is done by using // reusable listener wrappers, which do not actually close the underlying socket // until all uses of the shared listener have been closed. - ListenStream(network string, addr string) (StreamListener, error) + ListenStream(addr string) (StreamListener, error) - // ListenPacket creates a new packet listener for a given network and address. + // ListenPacket creates a new packet listener for a given address. // // See notes on [ListenStream]. - ListenPacket(network string, addr string) (net.PacketConn, error) + ListenPacket(addr string) (net.PacketConn, error) } type listenerManager struct { @@ -187,12 +188,12 @@ func NewListenerManager() ListenerManager { } } -func (m *listenerManager) newStreamListener(network string, addr string) (Listener, error) { +func (m *listenerManager) newStreamListener(addr string) (Listener, error) { tcpAddr, err := net.ResolveTCPAddr("tcp", addr) if err != nil { return nil, err } - ln, err := net.ListenTCP(network, tcpAddr) + ln, err := net.ListenTCP("tcp", tcpAddr) if err != nil { return nil, err } @@ -201,7 +202,7 @@ func (m *listenerManager) newStreamListener(network string, addr string) (Listen ln: streamLn, acceptCh: make(chan acceptResponse), onCloseFunc: func() error { - m.delete(listenerKey(network, addr)) + m.delete(listenerKey("tcp", addr)) return nil }, } @@ -218,15 +219,15 @@ func (m *listenerManager) newStreamListener(network string, addr string) (Listen return lnAddr, nil } -func (m *listenerManager) newPacketListener(network string, addr string) (Listener, error) { - pc, err := net.ListenPacket(network, addr) +func (m *listenerManager) newPacketListener(addr string) (Listener, error) { + pc, err := net.ListenPacket("udp", addr) if err != nil { return nil, err } return &listenAddr{ ln: pc, onCloseFunc: func() error { - m.delete(listenerKey(network, addr)) + m.delete(listenerKey("udp", addr)) return nil }, }, nil @@ -246,9 +247,11 @@ func (m *listenerManager) getListener(network string, addr string) (RefCount[Lis err error ) if network == "tcp" { - ln, err = m.newStreamListener(network, addr) + ln, err = m.newStreamListener(addr) + } else if network == "udp" { + ln, err = m.newPacketListener(addr) } else { - ln, err = m.newPacketListener(network, addr) + return nil, fmt.Errorf("unable to get listener for unsupported network %s", network) } if err != nil { return nil, err @@ -258,26 +261,26 @@ func (m *listenerManager) getListener(network string, addr string) (RefCount[Lis return lnRefCount, nil } -func (m *listenerManager) ListenStream(network string, addr string) (StreamListener, error) { - lnRefCount, err := m.getListener(network, addr) +func (m *listenerManager) ListenStream(addr string) (StreamListener, error) { + lnRefCount, err := m.getListener("tcp", addr) if err != nil { return nil, err } if lnAddr, ok := lnRefCount.Get().(canCreateStreamListener); ok { return lnAddr.NewStreamListener(lnRefCount.Close), nil } - return nil, fmt.Errorf("unable to create stream listener for %s/%s", network, addr) + return nil, fmt.Errorf("unable to create stream listener for %s", addr) } -func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketConn, error) { - lnRefCount, err := m.getListener(network, addr) +func (m *listenerManager) ListenPacket(addr string) (net.PacketConn, error) { + lnRefCount, err := m.getListener("udp", addr) if err != nil { return nil, err } if lnAddr, ok := lnRefCount.Get().(canCreatePacketListener); ok { return lnAddr.NewPacketListener(lnRefCount.Close), nil } - return nil, fmt.Errorf("unable to create packet listener for %s/%s", network, addr) + return nil, fmt.Errorf("unable to create packet listener for %s", addr) } func (m *listenerManager) delete(key string) { diff --git a/service/listeners_test.go b/service/listeners_test.go index 384126ad..da5aaa1e 100644 --- a/service/listeners_test.go +++ b/service/listeners_test.go @@ -24,7 +24,7 @@ import ( func TestListenerManagerStreamListenerEarlyClose(t *testing.T) { m := NewListenerManager() - ln, err := m.ListenStream("tcp", "127.0.0.1:0") + ln, err := m.ListenStream("127.0.0.1:0") require.NoError(t, err) ln.Close() @@ -47,9 +47,9 @@ func writeTestPayload(ln StreamListener) error { func TestListenerManagerStreamListenerNotClosedIfStillInUse(t *testing.T) { m := NewListenerManager() - ln, err := m.ListenStream("tcp", "127.0.0.1:0") + ln, err := m.ListenStream("127.0.0.1:0") require.NoError(t, err) - ln2, err := m.ListenStream("tcp", "127.0.0.1:0") + ln2, err := m.ListenStream("127.0.0.1:0") require.NoError(t, err) // Close only the first listener. @@ -69,11 +69,11 @@ func TestListenerManagerStreamListenerNotClosedIfStillInUse(t *testing.T) { func TestListenerManagerStreamListenerCreatesListenerOnDemand(t *testing.T) { m := NewListenerManager() // Create a listener and immediately close it. - ln, err := m.ListenStream("tcp", "127.0.0.1:0") + ln, err := m.ListenStream("127.0.0.1:0") require.NoError(t, err) ln.Close() // Now create another listener on the same address. - ln2, err := m.ListenStream("tcp", "127.0.0.1:0") + ln2, err := m.ListenStream("127.0.0.1:0") require.NoError(t, err) done := make(chan struct{}) From 36a0a1d9f43c25e023b7c9b70fd54dcdae2dcd5e Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 5 Aug 2024 18:32:20 -0400 Subject: [PATCH 110/119] Use a mutex to ensure another user doesn't acquire a new closer while we're closing it. --- service/listeners.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index d8d58192..648f315e 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -305,24 +305,28 @@ type RefCount[T io.Closer] interface { } type refCount[T io.Closer] struct { + mu sync.Mutex count *atomic.Int32 value T } func NewRefCount[T io.Closer](value T) RefCount[T] { - res := &refCount[T]{ + r := &refCount[T]{ count: &atomic.Int32{}, value: value, } - res.count.Store(1) - return res + r.count.Store(1) + return r } func (r refCount[T]) Close() error { + // Lock to prevent someone from acquiring while we close the value. + r.mu.Lock() + defer r.mu.Unlock() + if count := r.count.Add(-1); count == 0 { return r.value.Close() } - return nil } From aeb2652fb3da4ea7339c88fcba7e32c5b91a5454 Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 6 Aug 2024 10:13:00 -0400 Subject: [PATCH 111/119] Move mutex up. --- cmd/outline-ss-server/main.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index f180044c..55a11acc 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -139,13 +139,14 @@ func (ls *listenerSet) ListenPacket(addr string) (net.PacketConn, error) { // Close closes all the listeners in the set, after which the set can't be used again. func (ls *listenerSet) Close() error { + ls.listenersMu.Lock() + defer ls.listenersMu.Unlock() + for addr, listenerCloseFunc := range ls.listenerCloseFuncs { if err := listenerCloseFunc(); err != nil { return fmt.Errorf("listener on address %s failed to stop: %w", addr, err) } } - ls.listenersMu.Lock() - defer ls.listenersMu.Unlock() ls.listenerCloseFuncs = nil return nil } From 8873b107083fedfe2b6627f326405c039177e4a3 Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 6 Aug 2024 17:33:18 -0400 Subject: [PATCH 112/119] Manage the ref counting next to the listener creation. --- service/listeners.go | 310 +++++++++++++++++++------------------- service/listeners_test.go | 24 ++- 2 files changed, 170 insertions(+), 164 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 648f315e..91aa877f 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -95,7 +95,10 @@ func (sl *virtualStreamListener) AcceptStream() (transport.StreamConn, error) { func (sl *virtualStreamListener) Close() error { sl.acceptCh = nil close(sl.closeCh) - return sl.onCloseFunc() + if sl.onCloseFunc != nil { + return sl.onCloseFunc() + } + return nil } func (sl *virtualStreamListener) Addr() net.Addr { @@ -108,189 +111,180 @@ type virtualPacketConn struct { } func (spc *virtualPacketConn) Close() error { - return spc.onCloseFunc() + if spc.onCloseFunc != nil { + return spc.onCloseFunc() + } + return nil } -type listenAddr struct { - ln Listener +// MultiListener manages shared listeners. +type MultiListener[T Listener] interface { + // Acquire creates a new listener from the shared listener. Listeners can overlap + // one another (e.g. during config changes the new config is started before the + // old config is destroyed), which is done by creating virtual listeners that wrap + // the shared listener. These virtual listeners do not actually close the + // underlying socket until all uses of the shared listener have been closed. + Acquire() (T, error) +} + +type multiStreamListener struct { + mu sync.Mutex + addr string + ln RefCount[StreamListener] acceptCh chan acceptResponse onCloseFunc OnCloseFunc } -type canCreateStreamListener interface { - NewStreamListener(onCloseFunc OnCloseFunc) StreamListener +// NewMultiStreamListener creates a new stream-based [MultiListener]. +func NewMultiStreamListener(addr string, onCloseFunc OnCloseFunc) MultiListener[StreamListener] { + return &multiStreamListener{ + addr: addr, + acceptCh: make(chan acceptResponse), + onCloseFunc: onCloseFunc, + } } -var _ canCreateStreamListener = (*listenAddr)(nil) +func (m *multiStreamListener) Acquire() (StreamListener, error) { + m.mu.Lock() + defer m.mu.Unlock() -// NewStreamListener creates a new [StreamListener]. -func (la *listenAddr) NewStreamListener(onCloseFunc OnCloseFunc) StreamListener { - if ln, ok := la.ln.(StreamListener); ok { - return &virtualStreamListener{ - addr: ln.Addr(), - acceptCh: la.acceptCh, - closeCh: make(chan struct{}), - onCloseFunc: onCloseFunc, + var sl StreamListener + if m.ln == nil { + tcpAddr, err := net.ResolveTCPAddr("tcp", m.addr) + if err != nil { + return nil, err + } + ln, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return nil, err } + sl = &TCPListener{ln} + m.ln = NewRefCount(sl, m.onCloseFunc) + go func() { + for { + conn, err := sl.AcceptStream() + if errors.Is(err, net.ErrClosed) { + close(m.acceptCh) + return + } + m.acceptCh <- acceptResponse{conn, err} + } + }() } - return nil -} -type canCreatePacketListener interface { - NewPacketListener(onCloseFunc OnCloseFunc) net.PacketConn + sl = m.ln.Acquire() + return &virtualStreamListener{ + addr: sl.Addr(), + acceptCh: m.acceptCh, + closeCh: make(chan struct{}), + onCloseFunc: m.ln.Close, + }, nil } -var _ canCreatePacketListener = (*listenAddr)(nil) +type multiPacketListener struct { + mu sync.Mutex + addr string + pc RefCount[net.PacketConn] + onCloseFunc OnCloseFunc +} -// NewPacketListener creates a new [net.PacketConn]. -func (cl *listenAddr) NewPacketListener(onCloseFunc OnCloseFunc) net.PacketConn { - if ln, ok := cl.ln.(net.PacketConn); ok { - return &virtualPacketConn{ - PacketConn: ln, - onCloseFunc: onCloseFunc, - } +// NewMultiPacketListener creates a new packet-based [MultiListener]. +func NewMultiPacketListener(addr string, onCloseFunc OnCloseFunc) MultiListener[net.PacketConn] { + return &multiPacketListener{ + addr: addr, + onCloseFunc: onCloseFunc, } - return nil } -func (la *listenAddr) Close() error { - if err := la.ln.Close(); err != nil { - return err +func (m *multiPacketListener) Acquire() (net.PacketConn, error) { + m.mu.Lock() + defer m.mu.Unlock() + + var pc net.PacketConn + if m.pc == nil { + pc, err := net.ListenPacket("udp", m.addr) + if err != nil { + return nil, err + } + m.pc = NewRefCount(pc, m.onCloseFunc) } - return la.onCloseFunc() + pc = m.pc.Acquire() + return &virtualPacketConn{ + PacketConn: pc, + onCloseFunc: m.pc.Close, + }, nil } -// ListenerManager holds and manages the state of shared listeners. +// ListenerManager holds the state of shared listeners. type ListenerManager interface { // ListenStream creates a new stream listener for a given address. - // - // Listeners can overlap one another, because during config changes the new - // config is started before the old config is destroyed. This is done by using - // reusable listener wrappers, which do not actually close the underlying socket - // until all uses of the shared listener have been closed. ListenStream(addr string) (StreamListener, error) // ListenPacket creates a new packet listener for a given address. - // - // See notes on [ListenStream]. ListenPacket(addr string) (net.PacketConn, error) } type listenerManager struct { - listeners map[string]RefCount[Listener] - listenersMu sync.Mutex + streamListeners map[string]MultiListener[StreamListener] + packetListeners map[string]MultiListener[net.PacketConn] + mu sync.Mutex } // NewListenerManager creates a new [ListenerManger]. func NewListenerManager() ListenerManager { return &listenerManager{ - listeners: make(map[string]RefCount[Listener]), - } -} - -func (m *listenerManager) newStreamListener(addr string) (Listener, error) { - tcpAddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return nil, err - } - ln, err := net.ListenTCP("tcp", tcpAddr) - if err != nil { - return nil, err - } - streamLn := &TCPListener{ln} - lnAddr := &listenAddr{ - ln: streamLn, - acceptCh: make(chan acceptResponse), - onCloseFunc: func() error { - m.delete(listenerKey("tcp", addr)) - return nil - }, + streamListeners: make(map[string]MultiListener[StreamListener]), + packetListeners: make(map[string]MultiListener[net.PacketConn]), } - go func() { - for { - conn, err := streamLn.AcceptStream() - if errors.Is(err, net.ErrClosed) { - close(lnAddr.acceptCh) - return - } - lnAddr.acceptCh <- acceptResponse{conn, err} - } - }() - return lnAddr, nil } -func (m *listenerManager) newPacketListener(addr string) (Listener, error) { - pc, err := net.ListenPacket("udp", addr) - if err != nil { - return nil, err - } - return &listenAddr{ - ln: pc, - onCloseFunc: func() error { - m.delete(listenerKey("udp", addr)) - return nil - }, - }, nil -} - -func (m *listenerManager) getListener(network string, addr string) (RefCount[Listener], error) { - m.listenersMu.Lock() - defer m.listenersMu.Unlock() - - lnKey := listenerKey(network, addr) - if lnRefCount, exists := m.listeners[lnKey]; exists { - return lnRefCount.Acquire(), nil - } - - var ( - ln Listener - err error - ) - if network == "tcp" { - ln, err = m.newStreamListener(addr) - } else if network == "udp" { - ln, err = m.newPacketListener(addr) - } else { - return nil, fmt.Errorf("unable to get listener for unsupported network %s", network) +func (m *listenerManager) ListenStream(addr string) (StreamListener, error) { + m.mu.Lock() + defer m.mu.Unlock() + + streamLn, exists := m.streamListeners[addr] + if !exists { + streamLn = NewMultiStreamListener( + addr, + func() error { + m.mu.Lock() + delete(m.streamListeners, addr) + m.mu.Unlock() + return nil + }, + ) + m.streamListeners[addr] = streamLn } + ln, err := streamLn.Acquire() if err != nil { - return nil, err + return nil, fmt.Errorf("unable to create stream listener for %s: %v", addr, err) } - lnRefCount := NewRefCount(ln) - m.listeners[lnKey] = lnRefCount - return lnRefCount, nil + return ln, nil } -func (m *listenerManager) ListenStream(addr string) (StreamListener, error) { - lnRefCount, err := m.getListener("tcp", addr) - if err != nil { - return nil, err - } - if lnAddr, ok := lnRefCount.Get().(canCreateStreamListener); ok { - return lnAddr.NewStreamListener(lnRefCount.Close), nil +func (m *listenerManager) ListenPacket(addr string) (net.PacketConn, error) { + m.mu.Lock() + defer m.mu.Unlock() + + packetLn, exists := m.packetListeners[addr] + if !exists { + packetLn = NewMultiPacketListener( + addr, + func() error { + m.mu.Lock() + delete(m.packetListeners, addr) + m.mu.Unlock() + return nil + }, + ) + m.packetListeners[addr] = packetLn } - return nil, fmt.Errorf("unable to create stream listener for %s", addr) -} -func (m *listenerManager) ListenPacket(addr string) (net.PacketConn, error) { - lnRefCount, err := m.getListener("udp", addr) + ln, err := packetLn.Acquire() if err != nil { - return nil, err - } - if lnAddr, ok := lnRefCount.Get().(canCreatePacketListener); ok { - return lnAddr.NewPacketListener(lnRefCount.Close), nil + return nil, fmt.Errorf("unable to create packet listener for %s: %v", addr, err) } - return nil, fmt.Errorf("unable to create packet listener for %s", addr) -} - -func (m *listenerManager) delete(key string) { - m.listenersMu.Lock() - delete(m.listeners, key) - m.listenersMu.Unlock() -} - -func listenerKey(network string, addr string) string { - return network + "/" + addr + return ln, nil } // RefCount is an atomic reference counter that can be used to track a shared @@ -299,42 +293,44 @@ type RefCount[T io.Closer] interface { io.Closer // Acquire increases the ref count and returns the wrapped object. - Acquire() RefCount[T] - - Get() T + Acquire() T } type refCount[T io.Closer] struct { - mu sync.Mutex - count *atomic.Int32 - value T + mu sync.Mutex + count *atomic.Int32 + value T + onCloseFunc OnCloseFunc } -func NewRefCount[T io.Closer](value T) RefCount[T] { +func NewRefCount[T io.Closer](value T, onCloseFunc OnCloseFunc) RefCount[T] { r := &refCount[T]{ - count: &atomic.Int32{}, - value: value, + count: &atomic.Int32{}, + value: value, + onCloseFunc: onCloseFunc, } - r.count.Store(1) return r } +func (r refCount[T]) Acquire() T { + r.count.Add(1) + return r.value +} + func (r refCount[T]) Close() error { // Lock to prevent someone from acquiring while we close the value. r.mu.Lock() defer r.mu.Unlock() if count := r.count.Add(-1); count == 0 { - return r.value.Close() + err := r.value.Close() + if err != nil { + return err + } + if r.onCloseFunc != nil { + return r.onCloseFunc() + } + return nil } return nil } - -func (r refCount[T]) Acquire() RefCount[T] { - r.count.Add(1) - return r -} - -func (r refCount[T]) Get() T { - return r.value -} diff --git a/service/listeners_test.go b/service/listeners_test.go index da5aaa1e..0a840cff 100644 --- a/service/listeners_test.go +++ b/service/listeners_test.go @@ -97,17 +97,27 @@ func (t *testRefCount) Close() error { } func TestRefCount(t *testing.T) { - var done bool - rc := NewRefCount[*testRefCount](&testRefCount{ - onCloseFunc: func() { - done = true + var objectCloseDone bool + var onCloseFuncDone bool + rc := NewRefCount[*testRefCount]( + &testRefCount{ + onCloseFunc: func() { + objectCloseDone = true + }, }, - }) + func() error { + onCloseFuncDone = true + return nil + }, + ) + rc.Acquire() rc.Acquire() require.NoError(t, rc.Close()) - require.False(t, done) + require.False(t, objectCloseDone) + require.False(t, onCloseFuncDone) require.NoError(t, rc.Close()) - require.True(t, done) + require.True(t, objectCloseDone) + require.True(t, onCloseFuncDone) } From 899d13d80af21ed132b670916e8ec3fb22e1f762 Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 6 Aug 2024 17:57:12 -0400 Subject: [PATCH 113/119] Do the lazy initialization inside an anonymous function. --- service/listeners.go | 87 +++++++++++++++++++++++++------------------- 1 file changed, 49 insertions(+), 38 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 91aa877f..9814adca 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -139,45 +139,50 @@ type multiStreamListener struct { func NewMultiStreamListener(addr string, onCloseFunc OnCloseFunc) MultiListener[StreamListener] { return &multiStreamListener{ addr: addr, - acceptCh: make(chan acceptResponse), onCloseFunc: onCloseFunc, } } func (m *multiStreamListener) Acquire() (StreamListener, error) { - m.mu.Lock() - defer m.mu.Unlock() - - var sl StreamListener - if m.ln == nil { - tcpAddr, err := net.ResolveTCPAddr("tcp", m.addr) - if err != nil { - return nil, err - } - ln, err := net.ListenTCP("tcp", tcpAddr) - if err != nil { - return nil, err - } - sl = &TCPListener{ln} - m.ln = NewRefCount(sl, m.onCloseFunc) - go func() { - for { - conn, err := sl.AcceptStream() - if errors.Is(err, net.ErrClosed) { - close(m.acceptCh) - return - } - m.acceptCh <- acceptResponse{conn, err} + refCount, err := func() (RefCount[StreamListener], error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.ln == nil { + tcpAddr, err := net.ResolveTCPAddr("tcp", m.addr) + if err != nil { + return nil, err } - }() + ln, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return nil, err + } + sl := &TCPListener{ln} + m.ln = NewRefCount[StreamListener](sl, m.onCloseFunc) + m.acceptCh = make(chan acceptResponse) + go func() { + for { + conn, err := sl.AcceptStream() + if errors.Is(err, net.ErrClosed) { + close(m.acceptCh) + return + } + m.acceptCh <- acceptResponse{conn, err} + } + }() + } + return m.ln, nil + }() + if err != nil { + return nil, err } - sl = m.ln.Acquire() + sl := refCount.Acquire() return &virtualStreamListener{ addr: sl.Addr(), acceptCh: m.acceptCh, closeCh: make(chan struct{}), - onCloseFunc: m.ln.Close, + onCloseFunc: refCount.Close, }, nil } @@ -197,21 +202,27 @@ func NewMultiPacketListener(addr string, onCloseFunc OnCloseFunc) MultiListener[ } func (m *multiPacketListener) Acquire() (net.PacketConn, error) { - m.mu.Lock() - defer m.mu.Unlock() - - var pc net.PacketConn - if m.pc == nil { - pc, err := net.ListenPacket("udp", m.addr) - if err != nil { - return nil, err + refCount, err := func() (RefCount[net.PacketConn], error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.pc == nil { + pc, err := net.ListenPacket("udp", m.addr) + if err != nil { + return nil, err + } + m.pc = NewRefCount(pc, m.onCloseFunc) } - m.pc = NewRefCount(pc, m.onCloseFunc) + return m.pc, nil + }() + if err != nil { + return nil, err } - pc = m.pc.Acquire() + + pc := refCount.Acquire() return &virtualPacketConn{ PacketConn: pc, - onCloseFunc: m.pc.Close, + onCloseFunc: refCount.Close, }, nil } From 80e5d491c6c5027b45e8b6e45e76f2428cea3ab0 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 7 Aug 2024 10:40:33 -0400 Subject: [PATCH 114/119] Fix concurrent access to `acceptCh` and `closeCh`. --- service/listeners.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/service/listeners.go b/service/listeners.go index 9814adca..1cb09d06 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -72,17 +72,23 @@ type acceptResponse struct { type OnCloseFunc func() error type virtualStreamListener struct { + mu sync.Mutex // Mutex to protect access to the channels addr net.Addr acceptCh <-chan acceptResponse closeCh chan struct{} + closed bool onCloseFunc OnCloseFunc } var _ StreamListener = (*virtualStreamListener)(nil) func (sl *virtualStreamListener) AcceptStream() (transport.StreamConn, error) { + sl.mu.Lock() + acceptCh := sl.acceptCh + sl.mu.Unlock() + select { - case acceptResponse, ok := <-sl.acceptCh: + case acceptResponse, ok := <-acceptCh: if !ok { return nil, net.ErrClosed } @@ -93,8 +99,16 @@ func (sl *virtualStreamListener) AcceptStream() (transport.StreamConn, error) { } func (sl *virtualStreamListener) Close() error { + sl.mu.Lock() + if sl.closed { + sl.mu.Unlock() + return nil + } + sl.closed = true sl.acceptCh = nil close(sl.closeCh) + sl.mu.Unlock() + if sl.onCloseFunc != nil { return sl.onCloseFunc() } From aa00f2efe1d19632e30c2cd27f7a018e92d5d4a4 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 7 Aug 2024 11:42:56 -0400 Subject: [PATCH 115/119] Use `/` in key instead of `-`. --- cmd/outline-ss-server/main.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 55a11acc..9fd570f7 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -107,7 +107,7 @@ func (ls *listenerSet) ListenStream(addr string) (service.StreamListener, error) ls.listenersMu.Lock() defer ls.listenersMu.Unlock() - lnKey := "stream-" + addr + lnKey := "stream/" + addr if _, exists := ls.listenerCloseFuncs[lnKey]; exists { return nil, fmt.Errorf("stream listener for %s already exists", addr) } @@ -125,7 +125,7 @@ func (ls *listenerSet) ListenPacket(addr string) (net.PacketConn, error) { ls.listenersMu.Lock() defer ls.listenersMu.Unlock() - lnKey := "packet-" + addr + lnKey := "packet/" + addr if _, exists := ls.listenerCloseFuncs[lnKey]; exists { return nil, fmt.Errorf("packet listener for %s already exists", addr) } From e658b90573a79bd26cf19f468da47f73ca694a07 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 7 Aug 2024 11:51:19 -0400 Subject: [PATCH 116/119] Return error from stopping listeners. --- cmd/outline-ss-server/main.go | 34 ++++++++++++++++++++-------- cmd/outline-ss-server/server_test.go | 4 +++- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 9fd570f7..aa8bb55e 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -61,7 +61,7 @@ func init() { } type SSServer struct { - stopConfig func() + stopConfig func() error lnManager service.ListenerManager natTimeout time.Duration m *outlineMetrics @@ -76,12 +76,14 @@ func (s *SSServer) loadConfig(filename string) error { // We hot swap the config by having the old and new listeners both live at // the same time. This means we create listeners for the new config first, // and then close the old ones after. - stopConfig, err := s.runConfig(*config) + sopConfig, err := s.runConfig(*config) if err != nil { return err } - go s.stopConfig() - s.stopConfig = stopConfig + if err := s.Stop(); err != nil { + return fmt.Errorf("unable to stop old config: %v", err) + } + s.stopConfig = sopConfig return nil } @@ -156,8 +158,9 @@ func (ls *listenerSet) Len() int { return len(ls.listenerCloseFuncs) } -func (s *SSServer) runConfig(config Config) (func(), error) { +func (s *SSServer) runConfig(config Config) (func() error, error) { startErrCh := make(chan error) + stopErrCh := make(chan error) stopCh := make(chan struct{}) go func() { @@ -165,7 +168,9 @@ func (s *SSServer) runConfig(config Config) (func(), error) { manager: s.lnManager, listenerCloseFuncs: make(map[string]func() error), } - defer lnSet.Close() // This closes all the listeners in the set. + defer func() { + stopErrCh <- lnSet.Close() + }() startErrCh <- func() error { portCiphers := make(map[int]*list.List) // Values are *List of *CipherEntry. @@ -216,24 +221,33 @@ func (s *SSServer) runConfig(config Config) (func(), error) { if err != nil { return nil, err } - return func() { + return func() error { logger.Infof("Stopping running config.") // TODO(sbruens): Actually wait for all handlers to be stopped, e.g. by // using a https://pkg.go.dev/sync#WaitGroup. stopCh <- struct{}{} + stopErr := <-stopErrCh + return stopErr }, nil } // Stop stops serving the current config. -func (s *SSServer) Stop() { - go s.stopConfig() +func (s *SSServer) Stop() error { + stopFunc := s.stopConfig + if stopFunc == nil { + return nil + } + if err := stopFunc(); err != nil { + logger.Errorf("Error stopping config: %v", err) + return err + } logger.Info("Stopped all listeners for running config") + return nil } // RunSSServer starts a shadowsocks server running, and returns the server or an error. func RunSSServer(filename string, natTimeout time.Duration, sm *outlineMetrics, replayHistory int) (*SSServer, error) { server := &SSServer{ - stopConfig: func() {}, lnManager: service.NewListenerManager(), natTimeout: natTimeout, m: sm, diff --git a/cmd/outline-ss-server/server_test.go b/cmd/outline-ss-server/server_test.go index 2ba0772e..0b7777b2 100644 --- a/cmd/outline-ss-server/server_test.go +++ b/cmd/outline-ss-server/server_test.go @@ -27,5 +27,7 @@ func TestRunSSServer(t *testing.T) { if err != nil { t.Fatalf("RunSSServer() error = %v", err) } - server.Stop() + if err := server.Stop(); err != nil { + t.Errorf("Error while stopping server: %v", err) + } } From fede4d8d7764e452951dc7f59a0a7bb38b828b41 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 7 Aug 2024 13:18:06 -0400 Subject: [PATCH 117/119] Use channels to ensure `virtualPacketConn`s get closed. --- service/listeners.go | 63 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 58 insertions(+), 5 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 1cb09d06..788d651f 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -64,13 +64,13 @@ func (t *TCPListener) Addr() net.Addr { return t.ln.Addr() } +type OnCloseFunc func() error + type acceptResponse struct { conn transport.StreamConn err error } -type OnCloseFunc func() error - type virtualStreamListener struct { mu sync.Mutex // Mutex to protect access to the channels addr net.Addr @@ -119,14 +119,52 @@ func (sl *virtualStreamListener) Addr() net.Addr { return sl.addr } +type packetResponse struct { + n int + addr net.Addr + err error + data []byte +} + type virtualPacketConn struct { net.PacketConn + mu sync.Mutex // Mutex to protect access to the channels + readCh <-chan packetResponse + closeCh chan struct{} + closed bool onCloseFunc OnCloseFunc } -func (spc *virtualPacketConn) Close() error { - if spc.onCloseFunc != nil { - return spc.onCloseFunc() +func (pc *virtualPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + pc.mu.Lock() + readCh := pc.readCh + pc.mu.Unlock() + + select { + case packetResponse, ok := <-readCh: + if !ok { + return 0, nil, net.ErrClosed + } + copy(p, packetResponse.data) + return packetResponse.n, packetResponse.addr, packetResponse.err + case <-pc.closeCh: + return 0, nil, net.ErrClosed + } +} + +func (pc *virtualPacketConn) Close() error { + pc.mu.Lock() + if pc.closed { + pc.mu.Unlock() + return nil + } + pc.closed = true + pc.readCh = nil + close(pc.closeCh) + pc.mu.Unlock() + + if pc.onCloseFunc != nil { + return pc.onCloseFunc() } return nil } @@ -204,6 +242,7 @@ type multiPacketListener struct { mu sync.Mutex addr string pc RefCount[net.PacketConn] + readCh chan packetResponse onCloseFunc OnCloseFunc } @@ -226,6 +265,18 @@ func (m *multiPacketListener) Acquire() (net.PacketConn, error) { return nil, err } m.pc = NewRefCount(pc, m.onCloseFunc) + m.readCh = make(chan packetResponse) + go func() { + for { + buffer := make([]byte, serverUDPBufferSize) + n, addr, err := pc.ReadFrom(buffer) + if err != nil { + close(m.readCh) + return + } + m.readCh <- packetResponse{n: n, addr: addr, err: err, data: buffer[:n]} + } + }() } return m.pc, nil }() @@ -236,6 +287,8 @@ func (m *multiPacketListener) Acquire() (net.PacketConn, error) { pc := refCount.Acquire() return &virtualPacketConn{ PacketConn: pc, + readCh: m.readCh, + closeCh: make(chan struct{}), onCloseFunc: refCount.Close, }, nil } From 4730d741237e5f697975d5f0f1ae26c3cebc0fce Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 7 Aug 2024 16:21:10 -0400 Subject: [PATCH 118/119] Add more test cases for packet listeners. --- service/listeners_test.go | 61 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 3 deletions(-) diff --git a/service/listeners_test.go b/service/listeners_test.go index 0a840cff..d627ec1a 100644 --- a/service/listeners_test.go +++ b/service/listeners_test.go @@ -51,18 +51,17 @@ func TestListenerManagerStreamListenerNotClosedIfStillInUse(t *testing.T) { require.NoError(t, err) ln2, err := m.ListenStream("127.0.0.1:0") require.NoError(t, err) - // Close only the first listener. ln.Close() + done := make(chan struct{}) go func() { ln2.AcceptStream() done <- struct{}{} }() - err = writeTestPayload(ln2) - require.NoError(t, err) + require.NoError(t, err) <-done } @@ -82,8 +81,64 @@ func TestListenerManagerStreamListenerCreatesListenerOnDemand(t *testing.T) { done <- struct{}{} }() err = writeTestPayload(ln2) + + require.NoError(t, err) + <-done +} + +func TestListenerManagerPacketListenerEarlyClose(t *testing.T) { + m := NewListenerManager() + pc, err := m.ListenPacket("127.0.0.1:0") + require.NoError(t, err) + + pc.Close() + _, _, readErr := pc.ReadFrom(nil) + _, writeErr := pc.WriteTo(nil, &net.UDPAddr{}) + + require.ErrorIs(t, readErr, net.ErrClosed) + require.ErrorIs(t, writeErr, net.ErrClosed) +} + +func TestListenerManagerPacketListenerNotClosedIfStillInUse(t *testing.T) { + m := NewListenerManager() + pc, err := m.ListenPacket("127.0.0.1:0") + require.NoError(t, err) + pc2, err := m.ListenPacket("127.0.0.1:0") + require.NoError(t, err) + // Close only the first listener. + pc.Close() + + done := make(chan struct{}) + go func() { + _, _, readErr := pc2.ReadFrom(nil) + require.NoError(t, readErr) + done <- struct{}{} + }() + _, err = pc.WriteTo(nil, pc2.LocalAddr()) + + require.NoError(t, err) + <-done +} + +func TestListenerManagerPacketListenerCreatesListenerOnDemand(t *testing.T) { + m := NewListenerManager() + // Create a listener and immediately close it. + pc, err := m.ListenPacket("127.0.0.1:0") require.NoError(t, err) + pc.Close() + // Now create another listener on the same address. + pc2, err := m.ListenPacket("127.0.0.1:0") + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + _, _, readErr := pc2.ReadFrom(nil) + require.NoError(t, readErr) + done <- struct{}{} + }() + _, err = pc2.WriteTo(nil, pc2.LocalAddr()) + require.NoError(t, err) <-done } From baea2a2344ddd781900a76ac336539d3f41c2f74 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 7 Aug 2024 17:31:44 -0400 Subject: [PATCH 119/119] Move PROXY proto logic to its own `proxyproto.go` file. --- service/listeners.go | 30 ------------------------- service/proxyproto.go | 51 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 30 deletions(-) create mode 100644 service/proxyproto.go diff --git a/service/listeners.go b/service/listeners.go index 3f180b99..25974a94 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -15,7 +15,6 @@ package service import ( - "bufio" "errors" "fmt" "io" @@ -24,7 +23,6 @@ import ( "sync/atomic" "github.com/Jigsaw-Code/outline-sdk/transport" - "github.com/pires/go-proxyproto" ) // The implementations of listeners for different network types are not @@ -142,34 +140,6 @@ func (sl *virtualStreamListener) Addr() net.Addr { return sl.addr } -// ProxyStreamListener wraps a [StreamListener] and fetches the source of the connection from the PROXY -// protocol header string. See https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt. -type ProxyStreamListener struct { - StreamListener -} - -// AcceptStream waits for the next incoming connection, parses the client IP from the PROXY protocol -// header, and adds it to the connection. -func (l *ProxyStreamListener) AcceptStream() (ClientStreamConn, error) { - conn, err := l.StreamListener.AcceptStream() - if err != nil { - return nil, err - } - r := bufio.NewReader(conn) - header, err := proxyproto.Read(r) - if errors.Is(err, proxyproto.ErrNoProxyProtocol) { - logger.Warningf("Received connection from %v without proxy header.", conn.RemoteAddr()) - return conn, nil - } - if header == nil || err != nil { - return nil, fmt.Errorf("error parsing proxy header: %v", err) - } - if header.Command.IsLocal() { - return conn, nil - } - return &clientStreamConn{StreamConn: conn, clientAddr: header.SourceAddr}, nil -} - type packetResponse struct { n int addr net.Addr diff --git a/service/proxyproto.go b/service/proxyproto.go new file mode 100644 index 00000000..3978daf9 --- /dev/null +++ b/service/proxyproto.go @@ -0,0 +1,51 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "bufio" + "errors" + "fmt" + + "github.com/pires/go-proxyproto" +) + +// ProxyStreamListener wraps a [StreamListener] and fetches the source of the connection from the PROXY +// protocol header string. See https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt. +type ProxyStreamListener struct { + StreamListener +} + +// AcceptStream waits for the next incoming connection, parses the client IP from the PROXY protocol +// header, and adds it to the connection. +func (l *ProxyStreamListener) AcceptStream() (ClientStreamConn, error) { + conn, err := l.StreamListener.AcceptStream() + if err != nil { + return nil, err + } + r := bufio.NewReader(conn) + header, err := proxyproto.Read(r) + if errors.Is(err, proxyproto.ErrNoProxyProtocol) { + logger.Warningf("Received connection from %v without proxy header.", conn.RemoteAddr()) + return conn, nil + } + if header == nil || err != nil { + return nil, fmt.Errorf("error parsing proxy header: %v", err) + } + if header.Command.IsLocal() { + return conn, nil + } + return &clientStreamConn{StreamConn: conn, clientAddr: header.SourceAddr}, nil +}