diff --git a/go/netutil/conn_test.go b/go/netutil/conn_test.go index 78776035856..b27f81a6311 100644 --- a/go/netutil/conn_test.go +++ b/go/netutil/conn_test.go @@ -15,18 +15,17 @@ package netutil import ( "net" - "strings" "sync" "testing" "time" + + "github.com/stretchr/testify/assert" ) func createSocketPair(t *testing.T) (net.Listener, net.Conn, net.Conn) { // Create a listener. listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("Listen failed: %v", err) - } + assert.NoError(t, err) addr := listener.Addr().String() // Dial a client, Accept a server. @@ -38,9 +37,7 @@ func createSocketPair(t *testing.T) (net.Listener, net.Conn, net.Conn) { defer wg.Done() var err error clientConn, err = net.Dial("tcp", addr) - if err != nil { - t.Errorf("Dial failed: %v", err) - } + assert.NoError(t, err) }() var serverConn net.Conn @@ -49,9 +46,7 @@ func createSocketPair(t *testing.T) (net.Listener, net.Conn, net.Conn) { defer wg.Done() var err error serverConn, err = listener.Accept() - if err != nil { - t.Errorf("Accept failed: %v", err) - } + assert.NoError(t, err) }() wg.Wait() @@ -77,13 +72,7 @@ func TestReadTimeout(t *testing.T) { select { case err := <-c: - if err == nil { - t.Fatalf("Expected error, got nil") - } - - if !strings.HasSuffix(err.Error(), "i/o timeout") { - t.Errorf("Expected error timeout, got %s", err) - } + assert.ErrorContains(t, err, "i/o timeout", "Expected error timeout") case <-time.After(10 * time.Second): t.Errorf("Timeout did not happen") } @@ -113,13 +102,7 @@ func TestWriteTimeout(t *testing.T) { select { case err := <-c: - if err == nil { - t.Fatalf("Expected error, got nil") - } - - if !strings.HasSuffix(err.Error(), "i/o timeout") { - t.Errorf("Expected error timeout, got %s", err) - } + assert.ErrorContains(t, err, "i/o timeout", "Expected error timeout") case <-time.After(10 * time.Second): t.Errorf("Timeout did not happen") } @@ -167,3 +150,42 @@ func TestNoTimeouts(t *testing.T) { // NOOP } } + +func TestSetDeadline(t *testing.T) { + listener, sConn, cConn := createSocketPair(t) + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + cConnWithTimeout := NewConnWithTimeouts(cConn, 0, 24*time.Hour) + + assert.Panics(t, func() { _ = cConnWithTimeout.SetDeadline(time.Now()) }) +} + +func TestSetReadDeadline(t *testing.T) { + listener, sConn, cConn := createSocketPair(t) + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + cConnWithTimeout := NewConnWithTimeouts(cConn, 0, 24*time.Hour) + + assert.Panics(t, func() { _ = cConnWithTimeout.SetReadDeadline(time.Now()) }) +} + +func TestSetWriteDeadline(t *testing.T) { + listener, sConn, cConn := createSocketPair(t) + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + cConnWithTimeout := NewConnWithTimeouts(cConn, 0, 24*time.Hour) + + assert.Panics(t, func() { _ = cConnWithTimeout.SetWriteDeadline(time.Now()) }) +} diff --git a/go/netutil/netutil_test.go b/go/netutil/netutil_test.go index c0c0e16cfed..e5df2065033 100644 --- a/go/netutil/netutil_test.go +++ b/go/netutil/netutil_test.go @@ -17,7 +17,10 @@ limitations under the License. package netutil import ( + "net" "testing" + + "github.com/stretchr/testify/assert" ) func TestSplitHostPort(t *testing.T) { @@ -33,12 +36,9 @@ func TestSplitHostPort(t *testing.T) { } for input, want := range table { gotHost, gotPort, err := SplitHostPort(input) - if err != nil { - t.Errorf("SplitHostPort error: %v", err) - } - if gotHost != want.host || gotPort != want.port { - t.Errorf("SplitHostPort(%#v) = (%v, %v), want (%v, %v)", input, gotHost, gotPort, want.host, want.port) - } + assert.NoError(t, err) + assert.Equal(t, want.host, gotHost) + assert.Equal(t, want.port, gotPort) } } @@ -50,9 +50,7 @@ func TestSplitHostPortFail(t *testing.T) { } for _, input := range inputs { _, _, err := SplitHostPort(input) - if err == nil { - t.Errorf("expected error from SplitHostPort(%q), but got none", input) - } + assert.Error(t, err) } } @@ -66,9 +64,7 @@ func TestJoinHostPort(t *testing.T) { "[::1]:321": {host: "::1", port: 321}, } for want, input := range table { - if got := JoinHostPort(input.host, input.port); got != want { - t.Errorf("SplitHostPort(%v, %v) = %#v, want %#v", input.host, input.port, got, want) - } + assert.Equal(t, want, JoinHostPort(input.host, input.port)) } } @@ -83,8 +79,34 @@ func TestNormalizeIP(t *testing.T) { "127.": "127.", } for input, want := range table { - if got := NormalizeIP(input); got != want { - t.Errorf("NormalizeIP(%#v) = %#v, want %#v", input, got, want) - } + assert.Equal(t, want, NormalizeIP(input)) } } + +func TestDNSTracker(t *testing.T) { + refresh := DNSTracker("localhost") + _, err := refresh() + assert.NoError(t, err) + + refresh = DNSTracker("") + val, err := refresh() + assert.NoError(t, err) + assert.False(t, val, "DNS name resolution should not have changed") +} + +func TestAddrEqual(t *testing.T) { + addr1 := net.ParseIP("1.2.3.4") + addr2 := net.ParseIP("127.0.0.1") + + addrSet1 := []net.IP{addr1, addr2} + addrSet2 := []net.IP{addr1} + addrSet3 := []net.IP{addr2} + ok := addrEqual(addrSet1, addrSet2) + assert.False(t, ok, "addresses %q and %q should not be equal", addrSet1, addrSet2) + + ok = addrEqual(addrSet3, addrSet2) + assert.False(t, ok, "addresses %q and %q should not be equal", addrSet3, addrSet2) + + ok = addrEqual(addrSet1, addrSet1) + assert.True(t, ok, "addresses %q and %q should be equal", addrSet1, addrSet1) +}