diff --git a/client/client.go b/client/client.go index 59d2056..96af421 100644 --- a/client/client.go +++ b/client/client.go @@ -22,7 +22,7 @@ type ClientOpt func(c *Client) // to the control server func WithControlTLSConfig(tlsConfig *tls.Config) ClientOpt { return func(c *Client) { - c.controlTTLSconfig = tlsConfig + c.controlTTLSconfig = tlsConfig.Clone() } } diff --git a/server/server.go b/server/server.go index 202ec05..866bd3d 100644 --- a/server/server.go +++ b/server/server.go @@ -37,6 +37,7 @@ func New(basename string, logger *zap.Logger) *Server { func (s *Server) Start(laddr string, tlsConfig *tls.Config) error { var serverName string + tlsConfig = tlsConfig.Clone() tlsConfig.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { serverName = info.ServerName return nil, nil diff --git a/test/e2e_test.go b/test/e2e_test.go index e71050e..885fb66 100644 --- a/test/e2e_test.go +++ b/test/e2e_test.go @@ -90,9 +90,42 @@ func generateCertificate(cn string) (*tls.Config, error) { return tlsConfig, nil } +// startTCPEchoServer starts a server which echos back any content recieved +func startTCPEchoServer() (string, error) { + ln, err := net.Listen("tcp", ":0") + if err != nil { + return "", fmt.Errorf("listen: %w", err) + } + + go func() { + for { + conn, err := ln.Accept() + if err != nil { + continue + } + + go func(c net.Conn) { + buf := make([]byte, 1024) + for { + n, err := c.Read(buf) + if err != nil { + c.Close() + return + } + c.Write(buf[:n]) + } + }(conn) + } + }() + + return ln.Addr().String(), nil +} + +// TestE2E tests the default options end to end. +// Reminder: frontend (browser) <-> server <-> client <-> backend (target service) func TestE2E(t *testing.T) { r := require.New(t) - logger := zaptest.NewLogger(t) + logger := zaptest.NewLogger(t, zaptest.Level(zap.InfoLevel)) server := server.New("localtest.me", logger) @@ -105,15 +138,35 @@ func TestE2E(t *testing.T) { }() time.Sleep(time.Millisecond * 50) + // echo backend will just send back whatever it gets + backend, err := startTCPEchoServer() + r.NoError(err) + _, port, _ := net.SplitHostPort(server.Addr().String()) controlAddr := fmt.Sprintf("control.localtest.me:%s", port) client := client.New( controlAddr, - "localhost:1234", + backend, client.WithControlTLSConfig(tlsConfig), ) err = client.Start() r.NoError(err, "client start") r.Contains(client.IssuedAddr(), "localtest.me") + + // connect a client and ensure it get the same data back + frontend, err := tls.Dial("tcp", client.IssuedAddr(), tlsConfig) + r.NoError(err) + + sentData := make([]byte, 1000) + _, err = rand.Read(sentData) + r.NoError(err) + _, err = frontend.Write(sentData) + r.NoError(err) + + receivedData := make([]byte, 2000) + recievedCount, err := frontend.Read(receivedData) + r.Equal(len(sentData), recievedCount) + receivedData = receivedData[:recievedCount] + r.Equal(sentData, receivedData) }