Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ws): send pings every 10 seconds #51

Merged
merged 8 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package connection
package proxy

import (
"bytes"
Expand All @@ -12,19 +12,36 @@ import (
)

const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second

// Time allowed to read the next pong message from the peer.
pongWait = 60 * time.Second

// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10
DefaultWriteTimeout = 10 * time.Second
DefaultReadTimeout = 20 * time.Second
DefaultPingFrequency = DefaultReadTimeout / 4
DefaultDisconnectAfter = 3 * time.Minute

// Maximum message size allowed from peer.
maxMessageSize = 10 * (2 << 20)
)

type Config struct {
WriteTimeout time.Duration
ReadTimeout time.Duration
PingFrequency time.Duration
DisconnectAfter time.Duration
}

func DefaultConfig() Config {
return Config{
WriteTimeout: DefaultWriteTimeout,
ReadTimeout: DefaultReadTimeout,
PingFrequency: DefaultPingFrequency,
DisconnectAfter: DefaultDisconnectAfter,
}
}

type WriterCloser interface {
Write(msg []byte)
Close()
}

var (
newline = []byte{'\n'}
space = []byte{' '}
Expand All @@ -37,20 +54,21 @@ var (
}
)

// proxy is a responsible for reading from read chan and sending it over wsConn
// and reading fom wsChan and sending it over send chan
// proxy is a responsible for reading from reader chan and sending it over conn
// and reading fom conn and sending it over writer.
type proxy struct {
send *safeChannel
read chan []byte

conn *websocket.Conn
writer WriterCloser
reader chan []byte
conn *websocket.Conn
cfg Config
}

func startProxy(wsConn *websocket.Conn, send *safeChannel, read chan []byte) {
func Start(wsConn *websocket.Conn, writer WriterCloser, reader chan []byte, cfg Config) {
proxy := &proxy{
send: send,
read: read,
conn: wsConn,
writer: writer,
reader: reader,
conn: wsConn,
cfg: cfg,
}

wg := sync.WaitGroup{}
Expand All @@ -67,33 +85,32 @@ func startProxy(wsConn *websocket.Conn, send *safeChannel, read chan []byte) {
})

go recovery.DoNotPanic(func() {
disconnectAfter := 3 * time.Minute
timeout := time.After(disconnectAfter)
timeout := time.After(cfg.DisconnectAfter)

<-timeout
logging.Info("Connection closed after", disconnectAfter)
logging.Info("Connection closed after", cfg.DisconnectAfter)

proxy.conn.Close()
})

wg.Wait()
}

// readPump pumps messages from the websocket proxy to send.
// readPump pumps messages from the websocket proxy to writer.
//
// The application runs readPump in a per-proxy goroutine. The application
// ensures that there is at most one reader on a proxy by executing all
// reads from this goroutine.
func (p *proxy) readPump() {
defer func() {
p.conn.Close()
p.send.close()
p.writer.Close()
}()

p.conn.SetReadLimit(maxMessageSize)
p.conn.SetReadDeadline(time.Now().Add(pongWait))
p.conn.SetReadDeadline(time.Now().Add(p.cfg.ReadTimeout))
p.conn.SetPongHandler(func(string) error {
p.conn.SetReadDeadline(time.Now().Add(pongWait))
p.conn.SetReadDeadline(time.Now().Add(p.cfg.ReadTimeout))
return nil
})

Expand All @@ -112,26 +129,26 @@ func (p *proxy) readPump() {
break
}
message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1))
p.send.write(message)
p.writer.Write(message)
}
}

// writePump pumps messages from the read chan to the websocket proxy.
// writePump pumps messages from the reader chan to the websocket proxy.
//
// A goroutine running writePump is started for each proxy. The
// application ensures that there is at most one writer to a proxy by
// executing all writes from this goroutine.
func (p *proxy) writePump() {
ticker := time.NewTicker(pingPeriod)
ticker := time.NewTicker(p.cfg.PingFrequency)
defer func() {
ticker.Stop()
p.conn.Close()
}()

for {
select {
case message, ok := <-p.read:
p.conn.SetWriteDeadline(time.Now().Add(writeWait))
case message, ok := <-p.reader:
p.conn.SetWriteDeadline(time.Now().Add(p.cfg.WriteTimeout))
if !ok {
// The hub closed the channel.
p.conn.WriteMessage(websocket.CloseMessage, []byte{})
Expand All @@ -148,7 +165,7 @@ func (p *proxy) writePump() {
return
}
case <-ticker.C:
p.conn.SetWriteDeadline(time.Now().Add(writeWait))
p.conn.SetWriteDeadline(time.Now().Add(p.cfg.WriteTimeout))
if err := p.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
Expand Down
15 changes: 7 additions & 8 deletions internal/pass/connection/proxy_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ type proxyPool struct {
proxies map[string]*proxyPair
}

// registerMobileConn register proxyPair if not existing in pool and returns it.
func (pp *proxyPool) getOrCreateProxyPair(id string) *proxyPair {
// getOrCreateProxyPair registers proxyPair if not existing in pool and returns it.
func (pp *proxyPool) getOrCreateProxyPair(id string, disconnectAfter time.Duration) *proxyPair {
pp.mu.Lock()
defer pp.mu.Unlock()
v, ok := pp.proxies[id]
if !ok {
v = initProxyPair()
v = initProxyPair(disconnectAfter)
}
pp.proxies[id] = v
return v
Expand Down Expand Up @@ -48,12 +48,11 @@ type proxyPair struct {
}

// initProxyPair returns proxyPair and runs loop responsible for proxing data.
func initProxyPair() *proxyPair {
const proxyTimeout = 3 * time.Minute
func initProxyPair(disconnectAfter time.Duration) *proxyPair {
return &proxyPair{
toMobileDataCh: newSafeChannel(),
toExtensionDataCh: newSafeChannel(),
expiresAt: time.Now().Add(proxyTimeout),
expiresAt: time.Now().Add(disconnectAfter + time.Minute),
}
}

Expand All @@ -69,7 +68,7 @@ func newSafeChannel() *safeChannel {
}
}

func (sc *safeChannel) write(data []byte) {
func (sc *safeChannel) Write(data []byte) {
sc.mu.Lock()
defer sc.mu.Unlock()

Expand All @@ -80,7 +79,7 @@ func (sc *safeChannel) write(data []byte) {
sc.channel <- data
}

func (sc *safeChannel) close() {
func (sc *safeChannel) Close() {
sc.mu.Lock()
defer sc.mu.Unlock()

Expand Down
21 changes: 12 additions & 9 deletions internal/pass/connection/proxy_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@ import (
"time"

"github.com/twofas/2fas-server/internal/common/logging"
"github.com/twofas/2fas-server/internal/pass/connection/proxy"
)

// ProxyServer manages proxy connections between Browser Extension and Mobile.
type ProxyServer struct {
proxyPool *proxyPool
idLabel string
proxyPool *proxyPool
idLabel string
proxyConfig proxy.Config
}

func NewProxyServer(idLabel string) *ProxyServer {
func NewProxyServer(idLabel string, proxyConfig proxy.Config) *ProxyServer {
proxyPool := &proxyPool{proxies: map[string]*proxyPair{}}
go func() {
ticker := time.NewTicker(30 * time.Second)
Expand All @@ -24,8 +26,9 @@ func NewProxyServer(idLabel string) *ProxyServer {
}
}()
return &ProxyServer{
proxyPool: proxyPool,
idLabel: idLabel,
proxyPool: proxyPool,
idLabel: idLabel,
proxyConfig: proxyConfig,
}
}

Expand All @@ -38,8 +41,8 @@ func (p *ProxyServer) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *ht

log.Infof("Starting ServeExtensionProxyToMobileWS")

proxyPair := p.proxyPool.getOrCreateProxyPair(id)
startProxy(conn, proxyPair.toMobileDataCh, proxyPair.toExtensionDataCh.channel)
proxyPair := p.proxyPool.getOrCreateProxyPair(id, p.proxyConfig.DisconnectAfter)
proxy.Start(conn, proxyPair.toMobileDataCh, proxyPair.toExtensionDataCh.channel, p.proxyConfig)

p.proxyPool.deleteProxyPair(id)
return nil
Expand All @@ -52,9 +55,9 @@ func (p *ProxyServer) ServeMobileProxyToExtensionWS(w http.ResponseWriter, r *ht
}

logging.Infof("Starting ServeMobileProxyToExtensionWS for dev: %v", id)
proxyPair := p.proxyPool.getOrCreateProxyPair(id)
proxyPair := p.proxyPool.getOrCreateProxyPair(id, p.proxyConfig.DisconnectAfter)

startProxy(conn, proxyPair.toExtensionDataCh, proxyPair.toMobileDataCh.channel)
proxy.Start(conn, proxyPair.toExtensionDataCh, proxyPair.toMobileDataCh.channel, p.proxyConfig)

p.proxyPool.deleteProxyPair(id)
return nil
Expand Down
Loading