diff --git a/application.go b/application.go index 29b1fae..2784df5 100644 --- a/application.go +++ b/application.go @@ -226,6 +226,14 @@ func (a *Application) Close() { a.logger.Errorf("closing api server: %v", err) } } + + if a.Tunnel != nil { + a.Tunnel.Close() + } + if a.SOCKS5 != nil { + a.SOCKS5.Close() + } + if a.P2p != nil { err := a.P2p.Close() if err != nil { @@ -235,12 +243,6 @@ func (a *Application) Close() { if a.Dns != nil { a.Dns.Close() } - if a.Tunnel != nil { - a.Tunnel.Close() - } - if a.SOCKS5 != nil { - a.SOCKS5.Close() - } if a.vpnDevice != nil { err := a.vpnDevice.Close() if err != nil { diff --git a/cmd/awl-tray/tray.go b/cmd/awl-tray/tray.go index 104e6d9..3bb7b4c 100644 --- a/cmd/awl-tray/tray.go +++ b/cmd/awl-tray/tray.go @@ -161,15 +161,15 @@ func setPeersConnectedCounter(peers int) { } func refreshPeersCounterOnPeersConnectionChanged(peerID *string) { - app.Conf.RLock() - defer app.Conf.RUnlock() - if peerID != nil { if _, known := app.Conf.GetPeer(*peerID); !known { return } } + app.Conf.RLock() + defer app.Conf.RUnlock() + connected := 0 for _, knownPeer := range app.Conf.KnownPeers { online := app.P2p.IsConnected(knownPeer.PeerId()) diff --git a/service/socks5.go b/service/socks5.go index 454cf91..9cfef15 100644 --- a/service/socks5.go +++ b/service/socks5.go @@ -103,8 +103,14 @@ func (s *SOCKS5) ServeConns(ctx context.Context) { proxyConns := s.client.ConnsChan() for conn := range proxyConns { go func() { - // TODO: check err? - _ = s.proxyConn(ctx, conn) + defer func() { + _ = conn.Close() + }() + + err := s.proxyConn(ctx, conn) + if err != nil { + _ = s.server.SendServerFailureReply(conn) + } }() } } @@ -119,22 +125,16 @@ func (s *SOCKS5) SetProxyingLocalhostEnabled(enabled bool) { } func (s *SOCKS5) proxyConn(ctx context.Context, conn net.Conn) error { - defer func() { - _ = conn.Close() - }() - s.conf.RLock() usePeerID := s.conf.SOCKS5.UsingPeerID s.conf.RUnlock() if usePeerID == "" { - _ = s.server.SendServerFailureReply(conn) return errors.New("no peer is set for proxy") } peer, exists := s.conf.GetPeer(usePeerID) if !exists || !peer.AllowedUsingAsExitNode { - _ = s.server.SendServerFailureReply(conn) return fmt.Errorf("configured proxy peer %s don't allow us proxying", usePeerID) } @@ -158,34 +158,46 @@ func (s *SOCKS5) proxyConn(ctx context.Context, conn net.Conn) error { } func (s *SOCKS5) handleStream(conn net.Conn, stream network.Stream) { + // TODO: SetDeadline on conn for ~5 min just in case? wg := &sync.WaitGroup{} wg.Add(2) go func() { defer wg.Done() // Copy from conn to stream - s.copyStream(conn, stream) + _ = s.copyStream(conn, stream) }() go func() { defer wg.Done() // Copy from stream to conn - s.copyStream(stream, conn) - - // in some cases stream could finish writing before conn (e.g for errors) - // without closing conn we will have a deadlock - _ = conn.Close() + _ = s.copyStream(stream, conn) }() wg.Wait() } -func (s *SOCKS5) copyStream(from io.ReadCloser, to io.WriteCloser) { +func (s *SOCKS5) copyStream(from io.ReadCloser, to io.WriteCloser) error { const bufSize = 32 * 1024 buf := pool.Get(bufSize) defer func() { pool.Put(buf) }() - _, _ = io.CopyBuffer(to, from, buf) - // ignore error, we can do nothing about it + _, err := io.CopyBuffer(to, from, buf) + + type closeWriter interface { + CloseWrite() error + } + if conn, ok := to.(closeWriter); ok { + _ = conn.CloseWrite() + } + + type closeReader interface { + CloseRead() error + } + if conn, ok := from.(closeReader); ok { + _ = conn.CloseRead() + } + + return err } diff --git a/socks5/server.go b/socks5/server.go index 12cfe65..1b4a9d4 100644 --- a/socks5/server.go +++ b/socks5/server.go @@ -55,6 +55,7 @@ func (s *Server) ServeStreamConn(stream network.Stream) error { return s.socks.ServeConn(conn) } +// ServeConn is only used in tests. TODO: refactor tests func (s *Server) ServeConn(ioConn io.ReadWriteCloser) error { conn := ReadWriterConnWrapper{ReadWriteCloser: ioConn} return s.socks.ServeConn(conn) diff --git a/vpn/vpn.go b/vpn/vpn.go index 22a1cea..75d5cf3 100644 --- a/vpn/vpn.go +++ b/vpn/vpn.go @@ -32,6 +32,7 @@ type Device struct { localIP net.IP outboundCh chan *Packet + closeCh chan struct{} packetsPool sync.Pool logger *log.ZapEventLogger } @@ -62,7 +63,8 @@ func NewDevice(existingTun tun.Device, interfaceName string, localIP net.IP, ipM New: func() interface{} { return new(Packet) }}, - logger: log.Logger("awl/vpn"), + logger: log.Logger("awl/vpn"), + closeCh: make(chan struct{}), } go dev.tunEventsReader() go dev.tunPacketsReader() @@ -106,6 +108,7 @@ func (d *Device) OutboundChan() <-chan *Packet { } func (d *Device) Close() error { + close(d.closeCh) return d.tun.Close() } @@ -173,7 +176,12 @@ func (d *Device) tunPacketsReader() { continue } - d.outboundCh <- data + select { + case <-d.closeCh: + return + case d.outboundCh <- data: + // ok + } packets[i] = nil }