From a44414bda67b50060974e84ad31c23f89de05bf5 Mon Sep 17 00:00:00 2001 From: Alex Gartner <git@agartner.com> Date: Wed, 27 Nov 2024 16:04:26 -0800 Subject: [PATCH] client: refactor options and naming (#2) --- client/client.go | 241 ++++++++++++++++++++++++++++++++++++++ client/tunnel.go | 214 --------------------------------- cmd/tunnel-client/main.go | 11 +- test/e2e_test.go | 6 +- 4 files changed, 251 insertions(+), 221 deletions(-) create mode 100644 client/client.go delete mode 100644 client/tunnel.go diff --git a/client/client.go b/client/client.go new file mode 100644 index 0000000..59d2056 --- /dev/null +++ b/client/client.go @@ -0,0 +1,241 @@ +package client + +import ( + "bytes" + "context" + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" + + gNet "gitlab.com/gartnera/golib/net" +) + +type ClientOpt func(c *Client) + +// WithControlTLSConfig sets the `tls.Config` used when connecting +// to the control server +func WithControlTLSConfig(tlsConfig *tls.Config) ClientOpt { + return func(c *Client) { + c.controlTTLSconfig = tlsConfig + } +} + +// WithHostname requests a specific hostname. A token must be specified if using a hostname. +func WithHostname(hostname string, token string) ClientOpt { + return func(c *Client) { + c.hostname = hostname + c.token = token + } +} + +// WithUseTLS enables use of TLS when connecting to the target +func WithUseTLS(useTLS bool) ClientOpt { + return func(c *Client) { + c.useTLS = useTLS + } +} + +// WithTLSSkipVerify skips verification of TLS certificates when connecting to the target +func WithTLSSkipVerify(skipVerify bool) ClientOpt { + return func(c *Client) { + c.tlsSkipVerify = skipVerify + } +} + +// WithHTTPTargetHostHeader rewrites the HTTP Host header to the target name. +// This is useful if the target breaks if it's hostname is unexpected. +func WithHTTPTargetHostHeader(useHostHeader bool) ClientOpt { + return func(c *Client) { + c.httpTargetHostHeader = useHostHeader + } +} + +type Client struct { + token string + server string + hostname string + useTLS bool + tlsSkipVerify bool + target string + httpTargetHostHeader bool + + controlTTLSconfig *tls.Config + issuedAddr string + connectLock sync.Mutex +} + +func New(server, target string, opts ...ClientOpt) *Client { + c := &Client{ + server: server, + target: target, + controlTTLSconfig: &tls.Config{}, + } + + for _, opt := range opts { + opt(c) + } + serverName, _, _ := net.SplitHostPort(c.server) + c.controlTTLSconfig.ServerName = serverName + + return c +} + +func (c *Client) Start() error { + conn, err := c.stage1(true) + if err != nil { + return fmt.Errorf("unable to complete initial connection to server: %w", err) + } + go c.stage2(conn) + + for i := 0; i < 20; i++ { + go c.both() + } + return nil +} + +func (c *Client) Shutdown() { + conn, err := tls.Dial("tcp", c.server, c.controlTTLSconfig) + if err != nil { + panic(err) + } + msg := fmt.Sprintf("backend-shutdown:%s:%s", c.token, c.hostname) + _, err = conn.Write([]byte(msg)) + if err != nil { + panic(err) + } +} + +// IssuedAddr gets the address issued by the server +func (c *Client) IssuedAddr() string { + return c.issuedAddr +} + +// IssuedAddrHTTPS gets the address issued by the server with https prefix +func (c *Client) IssuedAddrHTTPS() string { + addr := c.issuedAddr + if strings.HasSuffix(addr, ":443") { + addr = strings.TrimSuffix(addr, ":443") + } + return fmt.Sprintf("https://%s", addr) +} + +func (c *Client) stage1(first bool) (net.Conn, error) { + var err error + var conn net.Conn + backoff := time.Second * 10 + c.connectLock.Lock() + for { + conn, err = tls.Dial("tcp", c.server, c.controlTTLSconfig) + if err == nil { + break + } + fmt.Printf("error while connecting to server: %s\n", err) + time.Sleep(backoff) + backoff = backoff + (time.Second * 10) + } + c.connectLock.Unlock() + msg := fmt.Sprintf("backend-open:%s:%s", c.token, c.hostname) + _, err = conn.Write([]byte(msg)) + if err != nil { + return nil, fmt.Errorf("unable to write to conn: %w", err) + } + + buf := make([]byte, 512) + + n, err := conn.Read(buf) + if err != nil { + return nil, fmt.Errorf("unable to read from conn: %w", err) + + } + res := string(buf[:n]) + if first { + _, port, _ := net.SplitHostPort(c.server) + if port == "" { + port = "443" + } + c.issuedAddr = fmt.Sprintf("%s:%s", res, port) + } + return conn, nil +} + +func (c *Client) both() { + conn, err := c.stage1(false) + if err != nil { + fmt.Printf("unable to connect to server: %v\n", err) + go c.both() + return + } + c.stage2(conn) +} + +func (c *Client) dialTLS(network, addr string) (net.Conn, error) { + host, port, _ := net.SplitHostPort(addr) + conf := &tls.Config{ + ServerName: host, + InsecureSkipVerify: c.tlsSkipVerify, + } + addrWithPort := addr + if port == "" { + addrWithPort += ":443" + } + conn, err := tls.Dial("tcp", addrWithPort, conf) + if err != nil { + return nil, fmt.Errorf("unable to dial tls: %w", err) + } + return conn, nil +} + +func (c *Client) stage2(conn net.Conn) { + buf := make([]byte, 100) + n, err := conn.Read(buf) + + go c.both() + + // conn closed by server or other + if err != nil { + conn.Close() + return + } + res := string(buf[:n]) + if res != "frontend-connected" { + conn.Close() + return + } + defer conn.Close() + + var tConn net.Conn + if strings.HasPrefix(c.target, "http") { + lis := NewSingleConnListener(conn) + targetUrl, _ := url.Parse(c.target) + reverseProxy := NewSingleHostReverseProxy(targetUrl, c.httpTargetHostHeader) + reverseProxy.Transport = &http.Transport{ + DialTLS: c.dialTLS, + IdleConnTimeout: time.Second * 10, + } + _ = http.Serve(lis, reverseProxy) + return + } else if c.useTLS { + tConn, err = c.dialTLS("tcp", c.target) + } else { + tConn, err = net.Dial("tcp", c.target) + } + if err != nil { + s := fmt.Sprintf("target %s returned error %s", c.target, err) + r := http.Response{ + StatusCode: 500, + Body: io.NopCloser(bytes.NewBufferString(s)), + } + r.Write(conn) + return + } + + ctx := context.Background() + gNet.PipeConn(ctx, tConn, conn) + tConn.Close() +} diff --git a/client/tunnel.go b/client/tunnel.go deleted file mode 100644 index a411ae6..0000000 --- a/client/tunnel.go +++ /dev/null @@ -1,214 +0,0 @@ -package client - -import ( - "bytes" - "context" - "crypto/tls" - "fmt" - "io" - "net" - "net/http" - "net/url" - "strings" - "sync" - "time" - - gNet "gitlab.com/gartnera/golib/net" -) - -type TunnelOpt func(t *Tunnel) - -func WithControlTLSConfig(tlsConfig *tls.Config) TunnelOpt { - return func(t *Tunnel) { - t.controlTTLSconfig = tlsConfig - } -} - -type Tunnel struct { - token string - server string - hostname string - useTLS bool - tlsSkipVerify bool - target string - httpTargetHostHeader bool - - controlTTLSconfig *tls.Config - issuedAddr string - connectLock sync.Mutex -} - -func New(server, hostname, token string, useTLS, tlsSkipVerify, httpTargetHostHeader bool, target string, opts ...TunnelOpt) *Tunnel { - t := &Tunnel{ - server: server, - hostname: hostname, - token: token, - useTLS: useTLS, - tlsSkipVerify: tlsSkipVerify, - target: target, - httpTargetHostHeader: httpTargetHostHeader, - controlTTLSconfig: &tls.Config{}, - } - - for _, opt := range opts { - opt(t) - } - serverName, _, _ := net.SplitHostPort(t.server) - t.controlTTLSconfig.ServerName = serverName - - return t -} - -func (t *Tunnel) Start() error { - conn, err := t.stage1(true) - if err != nil { - return fmt.Errorf("unable to complete initial connection to server: %w", err) - } - go t.stage2(conn) - - for i := 0; i < 20; i++ { - go t.both() - } - return nil -} - -func (t *Tunnel) Shutdown() { - conn, err := tls.Dial("tcp", t.server, t.controlTTLSconfig) - if err != nil { - panic(err) - } - msg := fmt.Sprintf("backend-shutdown:%s:%s", t.token, t.hostname) - _, err = conn.Write([]byte(msg)) - if err != nil { - panic(err) - } -} - -// IssuedAddr gets the address issued by the server -func (t *Tunnel) IssuedAddr() string { - return t.issuedAddr -} - -// IssuedAddrHTTPS gets the address issued by the server with https prefix -func (t *Tunnel) IssuedAddrHTTPS() string { - addr := t.issuedAddr - if strings.HasSuffix(addr, ":443") { - addr = strings.TrimSuffix(addr, ":443") - } - return fmt.Sprintf("https://%s", addr) -} - -func (t *Tunnel) stage1(first bool) (net.Conn, error) { - var err error - var conn net.Conn - backoff := time.Second * 10 - t.connectLock.Lock() - for { - conn, err = tls.Dial("tcp", t.server, t.controlTTLSconfig) - if err == nil { - break - } - fmt.Printf("error while connecting to server: %s\n", err) - time.Sleep(backoff) - backoff = backoff + (time.Second * 10) - } - t.connectLock.Unlock() - msg := fmt.Sprintf("backend-open:%s:%s", t.token, t.hostname) - _, err = conn.Write([]byte(msg)) - if err != nil { - return nil, fmt.Errorf("unable to write to conn: %w", err) - } - - buf := make([]byte, 512) - - n, err := conn.Read(buf) - if err != nil { - return nil, fmt.Errorf("unable to read from conn: %w", err) - - } - res := string(buf[:n]) - if first { - _, port, _ := net.SplitHostPort(t.server) - if port == "" { - port = "443" - } - t.issuedAddr = fmt.Sprintf("%s:%s", res, port) - } - return conn, nil -} - -func (t *Tunnel) both() { - conn, err := t.stage1(false) - if err != nil { - fmt.Printf("unable to connect to server: %v\n", err) - go t.both() - return - } - t.stage2(conn) -} - -func (t *Tunnel) dialTLS(network, addr string) (net.Conn, error) { - host, port, _ := net.SplitHostPort(addr) - conf := &tls.Config{ - ServerName: host, - InsecureSkipVerify: t.tlsSkipVerify, - } - addrWithPort := addr - if port == "" { - addrWithPort += ":443" - } - conn, err := tls.Dial("tcp", addrWithPort, conf) - if err != nil { - return nil, fmt.Errorf("unable to dial tls: %w", err) - } - return conn, nil -} - -func (t *Tunnel) stage2(conn net.Conn) { - buf := make([]byte, 100) - n, err := conn.Read(buf) - - go t.both() - - // conn closed by server or other - if err != nil { - conn.Close() - return - } - res := string(buf[:n]) - if res != "frontend-connected" { - conn.Close() - return - } - defer conn.Close() - - var tConn net.Conn - if strings.HasPrefix(t.target, "http") { - lis := NewSingleConnListener(conn) - targetUrl, _ := url.Parse(t.target) - reverseProxy := NewSingleHostReverseProxy(targetUrl, t.httpTargetHostHeader) - reverseProxy.Transport = &http.Transport{ - DialTLS: t.dialTLS, - IdleConnTimeout: time.Second * 10, - } - _ = http.Serve(lis, reverseProxy) - return - } else if t.useTLS { - tConn, err = t.dialTLS("tcp", t.target) - } else { - tConn, err = net.Dial("tcp", t.target) - } - if err != nil { - s := fmt.Sprintf("target %s returned error %s", t.target, err) - r := http.Response{ - StatusCode: 500, - Body: io.NopCloser(bytes.NewBufferString(s)), - } - r.Write(conn) - return - } - - ctx := context.Background() - gNet.PipeConn(ctx, tConn, conn) - tConn.Close() -} diff --git a/cmd/tunnel-client/main.go b/cmd/tunnel-client/main.go index 957e43a..4ccf512 100644 --- a/cmd/tunnel-client/main.go +++ b/cmd/tunnel-client/main.go @@ -79,7 +79,7 @@ var rootCmd = &cobra.Command{ return err } - tunnels := []*client.Tunnel{} + tunnels := []*client.Client{} for _, server := range servers { // we now use the control subdomain rather than the basename of the server controlName := fmt.Sprintf("control.%s", server) @@ -95,7 +95,14 @@ var rootCmd = &cobra.Command{ hostnameFqdn = strings.Join([]string{hostname, serverHostOnly}, ".") } - tunnel := client.New(controlName, hostnameFqdn, token, useTLS, tlsSkipVerify, httpTargetHostHeader, target) + tunnel := client.New( + controlName, + target, + client.WithHostname(hostname, token), + client.WithUseTLS(useTLS), + client.WithTLSSkipVerify(tlsSkipVerify), + client.WithHTTPTargetHostHeader(httpTargetHostHeader), + ) err := tunnel.Start() if err != nil { return fmt.Errorf("start %s: %w", controlName, err) diff --git a/test/e2e_test.go b/test/e2e_test.go index 51df39f..e71050e 100644 --- a/test/e2e_test.go +++ b/test/e2e_test.go @@ -21,6 +21,7 @@ import ( ) // generateCertificate generates a CA certificate, client certificate, and returns a tls.Config. +// it can be used for both clients and servers func generateCertificate(cn string) (*tls.Config, error) { // Generate CA private key caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) @@ -108,11 +109,6 @@ func TestE2E(t *testing.T) { controlAddr := fmt.Sprintf("control.localtest.me:%s", port) client := client.New( controlAddr, - "", - "", - false, - false, - false, "localhost:1234", client.WithControlTLSConfig(tlsConfig), )