From 873ef3e4a52fb409df4929de35b82753b9550aa6 Mon Sep 17 00:00:00 2001 From: Hamza El-Saawy Date: Tue, 9 Apr 2024 11:39:53 -0400 Subject: [PATCH] PR: simplify type checking Signed-off-by: Hamza El-Saawy --- hvsock_test.go | 46 +++++++++++++++++----------------------------- 1 file changed, 17 insertions(+), 29 deletions(-) diff --git a/hvsock_test.go b/hvsock_test.go index f3060645..f492f47f 100644 --- a/hvsock_test.go +++ b/hvsock_test.go @@ -57,11 +57,7 @@ func clientServer(u testUtil) (cl, sv *HvsockConn, _ *HvsockAddr) { if err != nil { return fmt.Errorf("listener accept: %w", err) } - var ok bool - sv, ok = conn.(*HvsockConn) - if !ok { - return fmt.Errorf("expected connection type %T; got %T", new(HvsockConn), conn) - } + sv = mustBeType[*HvsockConn](u.T, conn) if err := l.Close(); err != nil { return err } @@ -113,10 +109,7 @@ func TestHvSockListenerAddresses(t *testing.T) { u := newUtil(t) l, addr := serverListen(u) - la, ok := (l.Addr()).(*HvsockAddr) - if !ok { - t.Fatalf("expected type %T; got %T", new(HvsockAddr), l.Addr()) - } + la := mustBeType[*HvsockAddr](t, l.Addr()) u.Assert(*la == *addr, fmt.Sprintf("give: %v; want: %v", la, addr)) ra := rawHvsockAddr{} @@ -130,22 +123,10 @@ func TestHvSockAddresses(t *testing.T) { u := newUtil(t) cl, sv, addr := clientServer(u) - sra, ok := (sv.RemoteAddr()).(*HvsockAddr) - if !ok { - t.Fatalf("expected type %T; got %T", new(HvsockAddr), sv.RemoteAddr()) - } - sla, ok := (sv.LocalAddr()).(*HvsockAddr) - if !ok { - t.Fatalf("expected type %T; got %T", new(HvsockAddr), sv.LocalAddr()) - } - cra, ok := (cl.RemoteAddr()).(*HvsockAddr) - if !ok { - t.Fatalf("expected type %T; got %T", new(HvsockAddr), cl.RemoteAddr()) - } - cla, ok := (cl.LocalAddr()).(*HvsockAddr) - if !ok { - t.Fatalf("expected type %T; got %T", new(HvsockAddr), cl.LocalAddr()) - } + sra := mustBeType[*HvsockAddr](t, sv.RemoteAddr()) + sla := mustBeType[*HvsockAddr](t, sv.LocalAddr()) + cra := mustBeType[*HvsockAddr](t, cl.RemoteAddr()) + cla := mustBeType[*HvsockAddr](t, cl.LocalAddr()) t.Run("Info", func(t *testing.T) { tests := []struct { @@ -341,10 +322,7 @@ func TestHvSockCloseReadWriteListener(t *testing.T) { } defer c.Close() - hv, ok := c.(*HvsockConn) - if !ok { - t.Fatalf("expected type %T; got %T", new(HvsockConn), c) - } + hv := mustBeType[*HvsockConn](t, c) // // test CloseWrite() // @@ -683,3 +661,13 @@ func (u testUtil) Check() { func msgJoin(pre []string, s string) string { return strings.Join(append(pre, s), ": ") } + +func mustBeType[T any](tb testing.TB, v any) T { + tb.Helper() + + v2, ok := v.(T) + if !ok { + tb.Fatalf("expected type %T; got %T", *new(T), v) + } + return v2 +}