Skip to content

Commit

Permalink
chore: Refactor RelayConn implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Ehco1996 committed Feb 3, 2024
1 parent 82e57b2 commit 44e1ce2
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 67 deletions.
36 changes: 36 additions & 0 deletions internal/cmgr/cmgr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package cmgr

import (
"sync"

"github.com/Ehco1996/ehco/internal/transporter"
)

// connection manager interface
type Cmgr interface {
ListAllConnections() []transporter.RelayConn
}

type cmgrImpl struct {
lock sync.RWMutex

// k: relay name, v: connectionList
connectionsMap map[string][]transporter.RelayConn
}

func NewCmgr() Cmgr {
return &cmgrImpl{
connectionsMap: make(map[string][]transporter.RelayConn),
}
}

func (cm *cmgrImpl) ListAllConnections() []transporter.RelayConn {
cm.lock.RLock()
defer cm.lock.RUnlock()

var conns []transporter.RelayConn
for _, v := range cm.connectionsMap {
conns = append(conns, v...)
}
return conns
}
72 changes: 38 additions & 34 deletions internal/transporter/conn.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package transporter

import (
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net"
Expand All @@ -11,50 +9,56 @@ import (
"go.uber.org/zap"
)

type RelayConn struct {
clientConn net.Conn
remoteConn net.Conn
type RelayConn interface {
// Transport transports data between the client and the remote server.
// The remoteLabel is the label of the remote server.
Transport(remoteLabel string) error

cs ConnStats
Label string // same load with relay label
// ToJSON returns the JSON representation of the connection.
// ToJSON() string
}

func NewRelayConn(label string, clientConn, remoteConn net.Conn) *RelayConn {
return &RelayConn{clientConn: clientConn, remoteConn: remoteConn, cs: NewConnStats()}
func NewRelayConn(label string, clientConn, remoteConn net.Conn) RelayConn {
return &relayConnImpl{
Label: label,
Stats: &Stats{},

clientConn: clientConn,
remoteConn: remoteConn,
}
}

type relayConnImpl struct {
// same with relay label
Label string `json:"label"`
Stats *Stats `json:"stats"`

clientConn net.Conn
remoteConn net.Conn
}

func (rc *RelayConn) Transport(remoteLabel string) error {
func (rc *relayConnImpl) Transport(remoteLabel string) error {
name := rc.Name()
shortName := shortHashSHA256(name)
cl := zap.L().Named(shortName)
cl.Debug("transport start", zap.String("full name", name), zap.String("stats", rc.cs.GetStats().String()))
defer cl.Debug("transport end", zap.String("stats", rc.cs.GetStats().String()))
err := transport(rc.clientConn, rc.remoteConn, remoteLabel, rc.cs)
cl.Debug("transport start", zap.String("full name", name), zap.String("stats", rc.Stats.String()))

err := transport(rc.clientConn, rc.remoteConn, remoteLabel, rc.Stats)
if err != nil {
cl.Error("transport error", zap.Error(err))
}
cl.Debug("transport end", zap.String("stats", rc.Stats.String()))
return err
}

func shortHashSHA256(input string) string {
hasher := sha256.New()
hasher.Write([]byte(input))
hash := hasher.Sum(nil)
return hex.EncodeToString(hash)[:7]
}

func connectionName(conn net.Conn) string {
return fmt.Sprintf("l:<%s> r:<%s>", conn.LocalAddr(), conn.RemoteAddr())
}

func (rc *RelayConn) Name() string {
func (rc *relayConnImpl) Name() string {
return fmt.Sprintf("c1:[%s] c2:[%s]", connectionName(rc.clientConn), connectionName(rc.remoteConn))
}

type readOnlyConn struct {
io.Reader
remoteLabel string
cs ConnStats
stats *Stats
}

func (r readOnlyConn) Read(p []byte) (n int, err error) {
Expand All @@ -64,36 +68,36 @@ func (r readOnlyConn) Read(p []byte) (n int, err error) {
r.remoteLabel, web.METRIC_CONN_TYPE_TCP, web.METRIC_CONN_FLOW_READ,
).Add(float64(n))
// record the traffic
r.cs.RecordTraffic(int64(n), 0)
r.stats.Record(int64(n), 0)
return
}

type writeOnlyConn struct {
io.Writer
remoteLabel string
cs ConnStats
stats *Stats
}

func (w writeOnlyConn) Write(p []byte) (n int, err error) {
n, err = w.Writer.Write(p)
web.NetWorkTransmitBytes.WithLabelValues(
w.remoteLabel, web.METRIC_CONN_TYPE_TCP, web.METRIC_CONN_FLOW_WRITE,
).Add(float64(n))
w.cs.RecordTraffic(0, int64(n))
w.stats.Record(0, int64(n))
return
}

// Note that this code assumes that conn1 is the connection to the client and conn2 is the connection to the remote server.
// leave some optimization chance for future
// * use io.CopyBuffer
// * use go routine pool
func transport(conn1, conn2 net.Conn, remoteLabel string, cs ConnStats) error {
func transport(conn1, conn2 net.Conn, remoteLabel string, stats *Stats) error {
errCH := make(chan error, 1)
// copy conn1 to conn2,read from conn1 and write to conn2
go func() {
_, err := io.Copy(
writeOnlyConn{Writer: conn2, cs: cs, remoteLabel: remoteLabel},
readOnlyConn{Reader: conn1, cs: cs, remoteLabel: remoteLabel},
writeOnlyConn{Writer: conn2, stats: stats, remoteLabel: remoteLabel},
readOnlyConn{Reader: conn1, stats: stats, remoteLabel: remoteLabel},
)
if tcpConn, ok := conn2.(*net.TCPConn); ok {
_ = tcpConn.CloseWrite() // all data is written to conn2 now, so close the write side of conn2 to send eof
Expand All @@ -103,8 +107,8 @@ func transport(conn1, conn2 net.Conn, remoteLabel string, cs ConnStats) error {

// reverse copy conn2 to conn1,read from conn2 and write to conn1
_, err := io.Copy(
writeOnlyConn{Writer: conn1, cs: cs, remoteLabel: remoteLabel},
readOnlyConn{Reader: conn2, cs: cs, remoteLabel: remoteLabel},
writeOnlyConn{Writer: conn1, stats: stats, remoteLabel: remoteLabel},
readOnlyConn{Reader: conn2, stats: stats, remoteLabel: remoteLabel},
)
if tcpConn, ok := conn1.(*net.TCPConn); ok {
_ = tcpConn.CloseWrite()
Expand Down
35 changes: 3 additions & 32 deletions internal/transporter/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,40 +11,11 @@ type Stats struct {
down int64
}

func (s *Stats) ReSet() {
s.up = 0
s.down = 0
}

func (s *Stats) String() string {
return fmt.Sprintf("up: %s, down: %s", bytes.PrettyByteSize(float64(s.up)), bytes.PrettyByteSize(float64(s.down)))
}

type ConnStats interface {
RecordTraffic(down, up int64)

ReSetTraffic()

GetStats() *Stats
}

func NewConnStats() ConnStats {
return &connStatsImpl{s: &Stats{up: 0, down: 0}}
}

type connStatsImpl struct {
s *Stats
}

func (c *connStatsImpl) RecordTraffic(down, up int64) {
c.s.down += down
c.s.up += up
}

func (c *connStatsImpl) ReSetTraffic() {
c.s.ReSet()
}

func (c *connStatsImpl) GetStats() *Stats {
return c.s
func (s *Stats) Record(up, down int64) {
s.up += up
s.down += down
}
19 changes: 19 additions & 0 deletions internal/transporter/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package transporter

import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net"
)

func shortHashSHA256(input string) string {
hasher := sha256.New()
hasher.Write([]byte(input))
hash := hasher.Sum(nil)
return hex.EncodeToString(hash)[:7]
}

func connectionName(conn net.Conn) string {
return fmt.Sprintf("l:<%s> r:<%s>", conn.LocalAddr(), conn.RemoteAddr())
}
2 changes: 1 addition & 1 deletion internal/web/metrics.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package web
package web // todo move to another package

import (
"fmt"
Expand Down

0 comments on commit 44e1ce2

Please sign in to comment.