From a44414bda67b50060974e84ad31c23f89de05bf5 Mon Sep 17 00:00:00 2001
From: Alex Gartner <git@agartner.com>
Date: Wed, 27 Nov 2024 16:04:26 -0800
Subject: [PATCH] client: refactor options and naming (#2)

---
 client/client.go          | 241 ++++++++++++++++++++++++++++++++++++++
 client/tunnel.go          | 214 ---------------------------------
 cmd/tunnel-client/main.go |  11 +-
 test/e2e_test.go          |   6 +-
 4 files changed, 251 insertions(+), 221 deletions(-)
 create mode 100644 client/client.go
 delete mode 100644 client/tunnel.go

diff --git a/client/client.go b/client/client.go
new file mode 100644
index 0000000..59d2056
--- /dev/null
+++ b/client/client.go
@@ -0,0 +1,241 @@
+package client
+
+import (
+	"bytes"
+	"context"
+	"crypto/tls"
+	"fmt"
+	"io"
+	"net"
+	"net/http"
+	"net/url"
+	"strings"
+	"sync"
+	"time"
+
+	gNet "gitlab.com/gartnera/golib/net"
+)
+
+type ClientOpt func(c *Client)
+
+// WithControlTLSConfig sets the `tls.Config` used when connecting
+// to the control server
+func WithControlTLSConfig(tlsConfig *tls.Config) ClientOpt {
+	return func(c *Client) {
+		c.controlTTLSconfig = tlsConfig
+	}
+}
+
+// WithHostname requests a specific hostname. A token must be specified if using a hostname.
+func WithHostname(hostname string, token string) ClientOpt {
+	return func(c *Client) {
+		c.hostname = hostname
+		c.token = token
+	}
+}
+
+// WithUseTLS enables use of TLS when connecting to the target
+func WithUseTLS(useTLS bool) ClientOpt {
+	return func(c *Client) {
+		c.useTLS = useTLS
+	}
+}
+
+// WithTLSSkipVerify skips verification of TLS certificates when connecting to the target
+func WithTLSSkipVerify(skipVerify bool) ClientOpt {
+	return func(c *Client) {
+		c.tlsSkipVerify = skipVerify
+	}
+}
+
+// WithHTTPTargetHostHeader rewrites the HTTP Host header to the target name.
+// This is useful if the target breaks if it's hostname is unexpected.
+func WithHTTPTargetHostHeader(useHostHeader bool) ClientOpt {
+	return func(c *Client) {
+		c.httpTargetHostHeader = useHostHeader
+	}
+}
+
+type Client struct {
+	token                string
+	server               string
+	hostname             string
+	useTLS               bool
+	tlsSkipVerify        bool
+	target               string
+	httpTargetHostHeader bool
+
+	controlTTLSconfig *tls.Config
+	issuedAddr        string
+	connectLock       sync.Mutex
+}
+
+func New(server, target string, opts ...ClientOpt) *Client {
+	c := &Client{
+		server:            server,
+		target:            target,
+		controlTTLSconfig: &tls.Config{},
+	}
+
+	for _, opt := range opts {
+		opt(c)
+	}
+	serverName, _, _ := net.SplitHostPort(c.server)
+	c.controlTTLSconfig.ServerName = serverName
+
+	return c
+}
+
+func (c *Client) Start() error {
+	conn, err := c.stage1(true)
+	if err != nil {
+		return fmt.Errorf("unable to complete initial connection to server: %w", err)
+	}
+	go c.stage2(conn)
+
+	for i := 0; i < 20; i++ {
+		go c.both()
+	}
+	return nil
+}
+
+func (c *Client) Shutdown() {
+	conn, err := tls.Dial("tcp", c.server, c.controlTTLSconfig)
+	if err != nil {
+		panic(err)
+	}
+	msg := fmt.Sprintf("backend-shutdown:%s:%s", c.token, c.hostname)
+	_, err = conn.Write([]byte(msg))
+	if err != nil {
+		panic(err)
+	}
+}
+
+// IssuedAddr gets the address issued by the server
+func (c *Client) IssuedAddr() string {
+	return c.issuedAddr
+}
+
+// IssuedAddrHTTPS gets the address issued by the server with https prefix
+func (c *Client) IssuedAddrHTTPS() string {
+	addr := c.issuedAddr
+	if strings.HasSuffix(addr, ":443") {
+		addr = strings.TrimSuffix(addr, ":443")
+	}
+	return fmt.Sprintf("https://%s", addr)
+}
+
+func (c *Client) stage1(first bool) (net.Conn, error) {
+	var err error
+	var conn net.Conn
+	backoff := time.Second * 10
+	c.connectLock.Lock()
+	for {
+		conn, err = tls.Dial("tcp", c.server, c.controlTTLSconfig)
+		if err == nil {
+			break
+		}
+		fmt.Printf("error while connecting to server: %s\n", err)
+		time.Sleep(backoff)
+		backoff = backoff + (time.Second * 10)
+	}
+	c.connectLock.Unlock()
+	msg := fmt.Sprintf("backend-open:%s:%s", c.token, c.hostname)
+	_, err = conn.Write([]byte(msg))
+	if err != nil {
+		return nil, fmt.Errorf("unable to write to conn: %w", err)
+	}
+
+	buf := make([]byte, 512)
+
+	n, err := conn.Read(buf)
+	if err != nil {
+		return nil, fmt.Errorf("unable to read from conn: %w", err)
+
+	}
+	res := string(buf[:n])
+	if first {
+		_, port, _ := net.SplitHostPort(c.server)
+		if port == "" {
+			port = "443"
+		}
+		c.issuedAddr = fmt.Sprintf("%s:%s", res, port)
+	}
+	return conn, nil
+}
+
+func (c *Client) both() {
+	conn, err := c.stage1(false)
+	if err != nil {
+		fmt.Printf("unable to connect to server: %v\n", err)
+		go c.both()
+		return
+	}
+	c.stage2(conn)
+}
+
+func (c *Client) dialTLS(network, addr string) (net.Conn, error) {
+	host, port, _ := net.SplitHostPort(addr)
+	conf := &tls.Config{
+		ServerName:         host,
+		InsecureSkipVerify: c.tlsSkipVerify,
+	}
+	addrWithPort := addr
+	if port == "" {
+		addrWithPort += ":443"
+	}
+	conn, err := tls.Dial("tcp", addrWithPort, conf)
+	if err != nil {
+		return nil, fmt.Errorf("unable to dial tls: %w", err)
+	}
+	return conn, nil
+}
+
+func (c *Client) stage2(conn net.Conn) {
+	buf := make([]byte, 100)
+	n, err := conn.Read(buf)
+
+	go c.both()
+
+	// conn closed by server or other
+	if err != nil {
+		conn.Close()
+		return
+	}
+	res := string(buf[:n])
+	if res != "frontend-connected" {
+		conn.Close()
+		return
+	}
+	defer conn.Close()
+
+	var tConn net.Conn
+	if strings.HasPrefix(c.target, "http") {
+		lis := NewSingleConnListener(conn)
+		targetUrl, _ := url.Parse(c.target)
+		reverseProxy := NewSingleHostReverseProxy(targetUrl, c.httpTargetHostHeader)
+		reverseProxy.Transport = &http.Transport{
+			DialTLS:         c.dialTLS,
+			IdleConnTimeout: time.Second * 10,
+		}
+		_ = http.Serve(lis, reverseProxy)
+		return
+	} else if c.useTLS {
+		tConn, err = c.dialTLS("tcp", c.target)
+	} else {
+		tConn, err = net.Dial("tcp", c.target)
+	}
+	if err != nil {
+		s := fmt.Sprintf("target %s returned error %s", c.target, err)
+		r := http.Response{
+			StatusCode: 500,
+			Body:       io.NopCloser(bytes.NewBufferString(s)),
+		}
+		r.Write(conn)
+		return
+	}
+
+	ctx := context.Background()
+	gNet.PipeConn(ctx, tConn, conn)
+	tConn.Close()
+}
diff --git a/client/tunnel.go b/client/tunnel.go
deleted file mode 100644
index a411ae6..0000000
--- a/client/tunnel.go
+++ /dev/null
@@ -1,214 +0,0 @@
-package client
-
-import (
-	"bytes"
-	"context"
-	"crypto/tls"
-	"fmt"
-	"io"
-	"net"
-	"net/http"
-	"net/url"
-	"strings"
-	"sync"
-	"time"
-
-	gNet "gitlab.com/gartnera/golib/net"
-)
-
-type TunnelOpt func(t *Tunnel)
-
-func WithControlTLSConfig(tlsConfig *tls.Config) TunnelOpt {
-	return func(t *Tunnel) {
-		t.controlTTLSconfig = tlsConfig
-	}
-}
-
-type Tunnel struct {
-	token                string
-	server               string
-	hostname             string
-	useTLS               bool
-	tlsSkipVerify        bool
-	target               string
-	httpTargetHostHeader bool
-
-	controlTTLSconfig *tls.Config
-	issuedAddr        string
-	connectLock       sync.Mutex
-}
-
-func New(server, hostname, token string, useTLS, tlsSkipVerify, httpTargetHostHeader bool, target string, opts ...TunnelOpt) *Tunnel {
-	t := &Tunnel{
-		server:               server,
-		hostname:             hostname,
-		token:                token,
-		useTLS:               useTLS,
-		tlsSkipVerify:        tlsSkipVerify,
-		target:               target,
-		httpTargetHostHeader: httpTargetHostHeader,
-		controlTTLSconfig:    &tls.Config{},
-	}
-
-	for _, opt := range opts {
-		opt(t)
-	}
-	serverName, _, _ := net.SplitHostPort(t.server)
-	t.controlTTLSconfig.ServerName = serverName
-
-	return t
-}
-
-func (t *Tunnel) Start() error {
-	conn, err := t.stage1(true)
-	if err != nil {
-		return fmt.Errorf("unable to complete initial connection to server: %w", err)
-	}
-	go t.stage2(conn)
-
-	for i := 0; i < 20; i++ {
-		go t.both()
-	}
-	return nil
-}
-
-func (t *Tunnel) Shutdown() {
-	conn, err := tls.Dial("tcp", t.server, t.controlTTLSconfig)
-	if err != nil {
-		panic(err)
-	}
-	msg := fmt.Sprintf("backend-shutdown:%s:%s", t.token, t.hostname)
-	_, err = conn.Write([]byte(msg))
-	if err != nil {
-		panic(err)
-	}
-}
-
-// IssuedAddr gets the address issued by the server
-func (t *Tunnel) IssuedAddr() string {
-	return t.issuedAddr
-}
-
-// IssuedAddrHTTPS gets the address issued by the server with https prefix
-func (t *Tunnel) IssuedAddrHTTPS() string {
-	addr := t.issuedAddr
-	if strings.HasSuffix(addr, ":443") {
-		addr = strings.TrimSuffix(addr, ":443")
-	}
-	return fmt.Sprintf("https://%s", addr)
-}
-
-func (t *Tunnel) stage1(first bool) (net.Conn, error) {
-	var err error
-	var conn net.Conn
-	backoff := time.Second * 10
-	t.connectLock.Lock()
-	for {
-		conn, err = tls.Dial("tcp", t.server, t.controlTTLSconfig)
-		if err == nil {
-			break
-		}
-		fmt.Printf("error while connecting to server: %s\n", err)
-		time.Sleep(backoff)
-		backoff = backoff + (time.Second * 10)
-	}
-	t.connectLock.Unlock()
-	msg := fmt.Sprintf("backend-open:%s:%s", t.token, t.hostname)
-	_, err = conn.Write([]byte(msg))
-	if err != nil {
-		return nil, fmt.Errorf("unable to write to conn: %w", err)
-	}
-
-	buf := make([]byte, 512)
-
-	n, err := conn.Read(buf)
-	if err != nil {
-		return nil, fmt.Errorf("unable to read from conn: %w", err)
-
-	}
-	res := string(buf[:n])
-	if first {
-		_, port, _ := net.SplitHostPort(t.server)
-		if port == "" {
-			port = "443"
-		}
-		t.issuedAddr = fmt.Sprintf("%s:%s", res, port)
-	}
-	return conn, nil
-}
-
-func (t *Tunnel) both() {
-	conn, err := t.stage1(false)
-	if err != nil {
-		fmt.Printf("unable to connect to server: %v\n", err)
-		go t.both()
-		return
-	}
-	t.stage2(conn)
-}
-
-func (t *Tunnel) dialTLS(network, addr string) (net.Conn, error) {
-	host, port, _ := net.SplitHostPort(addr)
-	conf := &tls.Config{
-		ServerName:         host,
-		InsecureSkipVerify: t.tlsSkipVerify,
-	}
-	addrWithPort := addr
-	if port == "" {
-		addrWithPort += ":443"
-	}
-	conn, err := tls.Dial("tcp", addrWithPort, conf)
-	if err != nil {
-		return nil, fmt.Errorf("unable to dial tls: %w", err)
-	}
-	return conn, nil
-}
-
-func (t *Tunnel) stage2(conn net.Conn) {
-	buf := make([]byte, 100)
-	n, err := conn.Read(buf)
-
-	go t.both()
-
-	// conn closed by server or other
-	if err != nil {
-		conn.Close()
-		return
-	}
-	res := string(buf[:n])
-	if res != "frontend-connected" {
-		conn.Close()
-		return
-	}
-	defer conn.Close()
-
-	var tConn net.Conn
-	if strings.HasPrefix(t.target, "http") {
-		lis := NewSingleConnListener(conn)
-		targetUrl, _ := url.Parse(t.target)
-		reverseProxy := NewSingleHostReverseProxy(targetUrl, t.httpTargetHostHeader)
-		reverseProxy.Transport = &http.Transport{
-			DialTLS:         t.dialTLS,
-			IdleConnTimeout: time.Second * 10,
-		}
-		_ = http.Serve(lis, reverseProxy)
-		return
-	} else if t.useTLS {
-		tConn, err = t.dialTLS("tcp", t.target)
-	} else {
-		tConn, err = net.Dial("tcp", t.target)
-	}
-	if err != nil {
-		s := fmt.Sprintf("target %s returned error %s", t.target, err)
-		r := http.Response{
-			StatusCode: 500,
-			Body:       io.NopCloser(bytes.NewBufferString(s)),
-		}
-		r.Write(conn)
-		return
-	}
-
-	ctx := context.Background()
-	gNet.PipeConn(ctx, tConn, conn)
-	tConn.Close()
-}
diff --git a/cmd/tunnel-client/main.go b/cmd/tunnel-client/main.go
index 957e43a..4ccf512 100644
--- a/cmd/tunnel-client/main.go
+++ b/cmd/tunnel-client/main.go
@@ -79,7 +79,7 @@ var rootCmd = &cobra.Command{
 			return err
 		}
 
-		tunnels := []*client.Tunnel{}
+		tunnels := []*client.Client{}
 		for _, server := range servers {
 			// we now use the control subdomain rather than the basename of the server
 			controlName := fmt.Sprintf("control.%s", server)
@@ -95,7 +95,14 @@ var rootCmd = &cobra.Command{
 				hostnameFqdn = strings.Join([]string{hostname, serverHostOnly}, ".")
 			}
 
-			tunnel := client.New(controlName, hostnameFqdn, token, useTLS, tlsSkipVerify, httpTargetHostHeader, target)
+			tunnel := client.New(
+				controlName,
+				target,
+				client.WithHostname(hostname, token),
+				client.WithUseTLS(useTLS),
+				client.WithTLSSkipVerify(tlsSkipVerify),
+				client.WithHTTPTargetHostHeader(httpTargetHostHeader),
+			)
 			err := tunnel.Start()
 			if err != nil {
 				return fmt.Errorf("start %s: %w", controlName, err)
diff --git a/test/e2e_test.go b/test/e2e_test.go
index 51df39f..e71050e 100644
--- a/test/e2e_test.go
+++ b/test/e2e_test.go
@@ -21,6 +21,7 @@ import (
 )
 
 // generateCertificate generates a CA certificate, client certificate, and returns a tls.Config.
+// it can be used for both clients and servers
 func generateCertificate(cn string) (*tls.Config, error) {
 	// Generate CA private key
 	caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
@@ -108,11 +109,6 @@ func TestE2E(t *testing.T) {
 	controlAddr := fmt.Sprintf("control.localtest.me:%s", port)
 	client := client.New(
 		controlAddr,
-		"",
-		"",
-		false,
-		false,
-		false,
 		"localhost:1234",
 		client.WithControlTLSConfig(tlsConfig),
 	)