Skip to content

Commit

Permalink
Add required tests for go/netutil (#15392)
Browse files Browse the repository at this point in the history
Signed-off-by: Noble Mittal <[email protected]>
  • Loading branch information
beingnoble03 authored Mar 5, 2024
1 parent 9d861f8 commit 171e305
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 39 deletions.
70 changes: 46 additions & 24 deletions go/netutil/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,17 @@ package netutil

import (
"net"
"strings"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func createSocketPair(t *testing.T) (net.Listener, net.Conn, net.Conn) {
// Create a listener.
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen failed: %v", err)
}
assert.NoError(t, err)
addr := listener.Addr().String()

// Dial a client, Accept a server.
Expand All @@ -38,9 +37,7 @@ func createSocketPair(t *testing.T) (net.Listener, net.Conn, net.Conn) {
defer wg.Done()
var err error
clientConn, err = net.Dial("tcp", addr)
if err != nil {
t.Errorf("Dial failed: %v", err)
}
assert.NoError(t, err)
}()

var serverConn net.Conn
Expand All @@ -49,9 +46,7 @@ func createSocketPair(t *testing.T) (net.Listener, net.Conn, net.Conn) {
defer wg.Done()
var err error
serverConn, err = listener.Accept()
if err != nil {
t.Errorf("Accept failed: %v", err)
}
assert.NoError(t, err)
}()

wg.Wait()
Expand All @@ -77,13 +72,7 @@ func TestReadTimeout(t *testing.T) {

select {
case err := <-c:
if err == nil {
t.Fatalf("Expected error, got nil")
}

if !strings.HasSuffix(err.Error(), "i/o timeout") {
t.Errorf("Expected error timeout, got %s", err)
}
assert.ErrorContains(t, err, "i/o timeout", "Expected error timeout")
case <-time.After(10 * time.Second):
t.Errorf("Timeout did not happen")
}
Expand Down Expand Up @@ -113,13 +102,7 @@ func TestWriteTimeout(t *testing.T) {

select {
case err := <-c:
if err == nil {
t.Fatalf("Expected error, got nil")
}

if !strings.HasSuffix(err.Error(), "i/o timeout") {
t.Errorf("Expected error timeout, got %s", err)
}
assert.ErrorContains(t, err, "i/o timeout", "Expected error timeout")
case <-time.After(10 * time.Second):
t.Errorf("Timeout did not happen")
}
Expand Down Expand Up @@ -167,3 +150,42 @@ func TestNoTimeouts(t *testing.T) {
// NOOP
}
}

func TestSetDeadline(t *testing.T) {
listener, sConn, cConn := createSocketPair(t)
defer func() {
listener.Close()
sConn.Close()
cConn.Close()
}()

cConnWithTimeout := NewConnWithTimeouts(cConn, 0, 24*time.Hour)

assert.Panics(t, func() { _ = cConnWithTimeout.SetDeadline(time.Now()) })
}

func TestSetReadDeadline(t *testing.T) {
listener, sConn, cConn := createSocketPair(t)
defer func() {
listener.Close()
sConn.Close()
cConn.Close()
}()

cConnWithTimeout := NewConnWithTimeouts(cConn, 0, 24*time.Hour)

assert.Panics(t, func() { _ = cConnWithTimeout.SetReadDeadline(time.Now()) })
}

func TestSetWriteDeadline(t *testing.T) {
listener, sConn, cConn := createSocketPair(t)
defer func() {
listener.Close()
sConn.Close()
cConn.Close()
}()

cConnWithTimeout := NewConnWithTimeouts(cConn, 0, 24*time.Hour)

assert.Panics(t, func() { _ = cConnWithTimeout.SetWriteDeadline(time.Now()) })
}
52 changes: 37 additions & 15 deletions go/netutil/netutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ limitations under the License.
package netutil

import (
"net"
"testing"

"github.com/stretchr/testify/assert"
)

func TestSplitHostPort(t *testing.T) {
Expand All @@ -33,12 +36,9 @@ func TestSplitHostPort(t *testing.T) {
}
for input, want := range table {
gotHost, gotPort, err := SplitHostPort(input)
if err != nil {
t.Errorf("SplitHostPort error: %v", err)
}
if gotHost != want.host || gotPort != want.port {
t.Errorf("SplitHostPort(%#v) = (%v, %v), want (%v, %v)", input, gotHost, gotPort, want.host, want.port)
}
assert.NoError(t, err)
assert.Equal(t, want.host, gotHost)
assert.Equal(t, want.port, gotPort)
}
}

Expand All @@ -50,9 +50,7 @@ func TestSplitHostPortFail(t *testing.T) {
}
for _, input := range inputs {
_, _, err := SplitHostPort(input)
if err == nil {
t.Errorf("expected error from SplitHostPort(%q), but got none", input)
}
assert.Error(t, err)
}
}

Expand All @@ -66,9 +64,7 @@ func TestJoinHostPort(t *testing.T) {
"[::1]:321": {host: "::1", port: 321},
}
for want, input := range table {
if got := JoinHostPort(input.host, input.port); got != want {
t.Errorf("SplitHostPort(%v, %v) = %#v, want %#v", input.host, input.port, got, want)
}
assert.Equal(t, want, JoinHostPort(input.host, input.port))
}
}

Expand All @@ -83,8 +79,34 @@ func TestNormalizeIP(t *testing.T) {
"127.": "127.",
}
for input, want := range table {
if got := NormalizeIP(input); got != want {
t.Errorf("NormalizeIP(%#v) = %#v, want %#v", input, got, want)
}
assert.Equal(t, want, NormalizeIP(input))
}
}

func TestDNSTracker(t *testing.T) {
refresh := DNSTracker("localhost")
_, err := refresh()
assert.NoError(t, err)

refresh = DNSTracker("")
val, err := refresh()
assert.NoError(t, err)
assert.False(t, val, "DNS name resolution should not have changed")
}

func TestAddrEqual(t *testing.T) {
addr1 := net.ParseIP("1.2.3.4")
addr2 := net.ParseIP("127.0.0.1")

addrSet1 := []net.IP{addr1, addr2}
addrSet2 := []net.IP{addr1}
addrSet3 := []net.IP{addr2}
ok := addrEqual(addrSet1, addrSet2)
assert.False(t, ok, "addresses %q and %q should not be equal", addrSet1, addrSet2)

ok = addrEqual(addrSet3, addrSet2)
assert.False(t, ok, "addresses %q and %q should not be equal", addrSet3, addrSet2)

ok = addrEqual(addrSet1, addrSet1)
assert.True(t, ok, "addresses %q and %q should be equal", addrSet1, addrSet1)
}

0 comments on commit 171e305

Please sign in to comment.