Skip to content

Commit

Permalink
Merge pull request #811 from strukturag/testify
Browse files Browse the repository at this point in the history
Switch to "github.com/stretchr/testify" for tests.
  • Loading branch information
fancycode authored Sep 3, 2024
2 parents 9fdd617 + 03cad99 commit 9e4d446
Show file tree
Hide file tree
Showing 50 changed files with 3,082 additions and 6,234 deletions.
32 changes: 13 additions & 19 deletions allowed_ips_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,17 @@ package signaling
import (
"net"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestAllowedIps(t *testing.T) {
require := require.New(t)
a, err := ParseAllowedIps("127.0.0.1, 192.168.0.1, 192.168.1.1/24")
if err != nil {
t.Fatal(err)
}
if a.Empty() {
t.Fatal("should not be empty")
}
if expected := `[127.0.0.1/32, 192.168.0.1/32, 192.168.1.0/24]`; a.String() != expected {
t.Errorf("expected %s, got %s", expected, a.String())
}
require.NoError(err)
require.False(a.Empty())
require.Equal(`[127.0.0.1/32, 192.168.0.1/32, 192.168.1.0/24]`, a.String())

allowed := []string{
"127.0.0.1",
Expand All @@ -51,22 +49,18 @@ func TestAllowedIps(t *testing.T) {

for _, addr := range allowed {
t.Run(addr, func(t *testing.T) {
ip := net.ParseIP(addr)
if ip == nil {
t.Errorf("error parsing %s", addr)
} else if !a.Allowed(ip) {
t.Errorf("should allow %s", addr)
assert := assert.New(t)
if ip := net.ParseIP(addr); assert.NotNil(ip, "error parsing %s", addr) {
assert.True(a.Allowed(ip), "should allow %s", addr)
}
})
}

for _, addr := range notAllowed {
t.Run(addr, func(t *testing.T) {
ip := net.ParseIP(addr)
if ip == nil {
t.Errorf("error parsing %s", addr)
} else if a.Allowed(ip) {
t.Errorf("should not allow %s", addr)
assert := assert.New(t)
if ip := net.ParseIP(addr); assert.NotNil(ip, "error parsing %s", addr) {
assert.False(a.Allowed(ip), "should not allow %s", addr)
}
})
}
Expand Down
32 changes: 11 additions & 21 deletions api_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,42 +24,36 @@ package signaling
import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

func TestBackendChecksum(t *testing.T) {
t.Parallel()
assert := assert.New(t)
rnd := newRandomString(32)
body := []byte{1, 2, 3, 4, 5}
secret := []byte("shared-secret")

check1 := CalculateBackendChecksum(rnd, body, secret)
check2 := CalculateBackendChecksum(rnd, body, secret)
if check1 != check2 {
t.Errorf("Expected equal checksums, got %s and %s", check1, check2)
}
assert.Equal(check1, check2, "Expected equal checksums")

if !ValidateBackendChecksumValue(check1, rnd, body, secret) {
t.Errorf("Checksum %s could not be validated", check1)
}
if ValidateBackendChecksumValue(check1[1:], rnd, body, secret) {
t.Errorf("Checksum %s should not be valid", check1[1:])
}
if ValidateBackendChecksumValue(check1[:len(check1)-1], rnd, body, secret) {
t.Errorf("Checksum %s should not be valid", check1[:len(check1)-1])
}
assert.True(ValidateBackendChecksumValue(check1, rnd, body, secret), "Checksum should be valid")
assert.False(ValidateBackendChecksumValue(check1[1:], rnd, body, secret), "Checksum should not be valid")
assert.False(ValidateBackendChecksumValue(check1[:len(check1)-1], rnd, body, secret), "Checksum should not be valid")

request := &http.Request{
Header: make(http.Header),
}
request.Header.Set("Spreed-Signaling-Random", rnd)
request.Header.Set("Spreed-Signaling-Checksum", check1)
if !ValidateBackendChecksum(request, body, secret) {
t.Errorf("Checksum %s could not be validated from request", check1)
}
assert.True(ValidateBackendChecksum(request, body, secret), "Checksum could not be validated from request")
}

func TestValidNumbers(t *testing.T) {
t.Parallel()
assert := assert.New(t)
valid := []string{
"+12",
"+12345",
Expand All @@ -72,13 +66,9 @@ func TestValidNumbers(t *testing.T) {
"+123-45",
}
for _, number := range valid {
if !isValidNumber(number) {
t.Errorf("number %s should be valid", number)
}
assert.True(isValidNumber(number), "number %s should be valid", number)
}
for _, number := range invalid {
if isValidNumber(number) {
t.Errorf("number %s should not be valid", number)
}
assert.False(isValidNumber(number), "number %s should not be valid", number)
}
}
102 changes: 37 additions & 65 deletions api_signaling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ package signaling
import (
"encoding/json"
"fmt"
"reflect"
"sort"
"testing"

"github.com/stretchr/testify/assert"
)

type testCheckValid interface {
Expand All @@ -53,40 +54,34 @@ func wrapMessage(messageType string, msg testCheckValid) *ClientMessage {
}

func testMessages(t *testing.T, messageType string, valid_messages []testCheckValid, invalid_messages []testCheckValid) {
t.Helper()
assert := assert.New(t)
for _, msg := range valid_messages {
if err := msg.CheckValid(); err != nil {
t.Errorf("Message %+v should be valid, got %s", msg, err)
}
assert.NoError(msg.CheckValid(), "Message %+v should be valid", msg)

// If the inner message is valid, it should also be valid in a wrapped
// ClientMessage.
if wrapped := wrapMessage(messageType, msg); wrapped == nil {
t.Errorf("Unknown message type: %s", messageType)
} else if err := wrapped.CheckValid(); err != nil {
t.Errorf("Message %+v should be valid, got %s", wrapped, err)
if wrapped := wrapMessage(messageType, msg); assert.NotNil(wrapped, "Unknown message type: %s", messageType) {
assert.NoError(wrapped.CheckValid(), "Message %+v should be valid", wrapped)
}
}
for _, msg := range invalid_messages {
if err := msg.CheckValid(); err == nil {
t.Errorf("Message %+v should not be valid", msg)
}
assert.Error(msg.CheckValid(), "Message %+v should not be valid", msg)

// If the inner message is invalid, it should also be invalid in a
// wrapped ClientMessage.
if wrapped := wrapMessage(messageType, msg); wrapped == nil {
t.Errorf("Unknown message type: %s", messageType)
} else if err := wrapped.CheckValid(); err == nil {
t.Errorf("Message %+v should not be valid", wrapped)
if wrapped := wrapMessage(messageType, msg); assert.NotNil(wrapped, "Unknown message type: %s", messageType) {
assert.Error(wrapped.CheckValid(), "Message %+v should not be valid", wrapped)
}
}
}

func TestClientMessage(t *testing.T) {
t.Parallel()
assert := assert.New(t)
// The message needs a type.
msg := ClientMessage{}
if err := msg.CheckValid(); err == nil {
t.Errorf("Message %+v should not be valid", msg)
}
assert.Error(msg.CheckValid())
}

func TestHelloClientMessage(t *testing.T) {
Expand Down Expand Up @@ -229,9 +224,8 @@ func TestHelloClientMessage(t *testing.T) {
msg := ClientMessage{
Type: "hello",
}
if err := msg.CheckValid(); err == nil {
t.Errorf("Message %+v should not be valid", msg)
}
assert := assert.New(t)
assert.Error(msg.CheckValid())
}

func TestMessageClientMessage(t *testing.T) {
Expand Down Expand Up @@ -311,9 +305,8 @@ func TestMessageClientMessage(t *testing.T) {
msg := ClientMessage{
Type: "message",
}
if err := msg.CheckValid(); err == nil {
t.Errorf("Message %+v should not be valid", msg)
}
assert := assert.New(t)
assert.Error(msg.CheckValid())
}

func TestByeClientMessage(t *testing.T) {
Expand All @@ -330,9 +323,8 @@ func TestByeClientMessage(t *testing.T) {
msg := ClientMessage{
Type: "bye",
}
if err := msg.CheckValid(); err != nil {
t.Errorf("Message %+v should be valid, got %s", msg, err)
}
assert := assert.New(t)
assert.NoError(msg.CheckValid())
}

func TestRoomClientMessage(t *testing.T) {
Expand All @@ -349,42 +341,31 @@ func TestRoomClientMessage(t *testing.T) {
msg := ClientMessage{
Type: "room",
}
if err := msg.CheckValid(); err == nil {
t.Errorf("Message %+v should not be valid", msg)
}
assert := assert.New(t)
assert.Error(msg.CheckValid())
}

func TestErrorMessages(t *testing.T) {
t.Parallel()
assert := assert.New(t)
id := "request-id"
msg := ClientMessage{
Id: id,
}
err1 := msg.NewErrorServerMessage(&Error{})
if err1.Id != id {
t.Errorf("Expected id %s, got %+v", id, err1)
}
if err1.Type != "error" || err1.Error == nil {
t.Errorf("Expected type \"error\", got %+v", err1)
}
assert.Equal(id, err1.Id, "%+v", err1)
assert.Equal("error", err1.Type, "%+v", err1)
assert.NotNil(err1.Error, "%+v", err1)

err2 := msg.NewWrappedErrorServerMessage(fmt.Errorf("test-error"))
if err2.Id != id {
t.Errorf("Expected id %s, got %+v", id, err2)
}
if err2.Type != "error" || err2.Error == nil {
t.Errorf("Expected type \"error\", got %+v", err2)
}
if err2.Error.Code != "internal_error" {
t.Errorf("Expected code \"internal_error\", got %+v", err2)
}
if err2.Error.Message != "test-error" {
t.Errorf("Expected message \"test-error\", got %+v", err2)
assert.Equal(id, err2.Id, "%+v", err2)
assert.Equal("error", err2.Type, "%+v", err2)
if assert.NotNil(err2.Error, "%+v", err2) {
assert.Equal("internal_error", err2.Error.Code, "%+v", err2)
assert.Equal("test-error", err2.Error.Message, "%+v", err2)
}
// Test "error" interface
if err2.Error.Error() != "test-error" {
t.Errorf("Expected error string \"test-error\", got %+v", err2)
}
assert.Equal("test-error", err2.Error.Error(), "%+v", err2)
}

func TestIsChatRefresh(t *testing.T) {
Expand All @@ -397,9 +378,7 @@ func TestIsChatRefresh(t *testing.T) {
Data: data_true,
},
}
if !msg.IsChatRefresh() {
t.Error("message should be detected as chat refresh")
}
assert.True(t, msg.IsChatRefresh())

data_false := []byte("{\"type\":\"chat\",\"chat\":{\"refresh\":false}}")
msg = ServerMessage{
Expand All @@ -408,9 +387,7 @@ func TestIsChatRefresh(t *testing.T) {
Data: data_false,
},
}
if msg.IsChatRefresh() {
t.Error("message should not be detected as chat refresh")
}
assert.False(t, msg.IsChatRefresh())
}

func assertEqualStrings(t *testing.T, expected, result []string) {
Expand All @@ -427,27 +404,22 @@ func assertEqualStrings(t *testing.T, expected, result []string) {
sort.Strings(result)
}

if !reflect.DeepEqual(expected, result) {
t.Errorf("Expected %+v, got %+v", expected, result)
}
assert.Equal(t, expected, result)
}

func Test_Welcome_AddRemoveFeature(t *testing.T) {
t.Parallel()
assert := assert.New(t)
var msg WelcomeServerMessage
assertEqualStrings(t, []string{}, msg.Features)

msg.AddFeature("one", "two", "one")
assertEqualStrings(t, []string{"one", "two"}, msg.Features)
if !sort.StringsAreSorted(msg.Features) {
t.Errorf("features should be sorted, got %+v", msg.Features)
}
assert.True(sort.StringsAreSorted(msg.Features), "features should be sorted, got %+v", msg.Features)

msg.AddFeature("three")
assertEqualStrings(t, []string{"one", "two", "three"}, msg.Features)
if !sort.StringsAreSorted(msg.Features) {
t.Errorf("features should be sorted, got %+v", msg.Features)
}
assert.True(sort.StringsAreSorted(msg.Features), "features should be sorted, got %+v", msg.Features)

msg.RemoveFeature("three", "one")
assertEqualStrings(t, []string{"two"}, msg.Features)
Expand Down
6 changes: 4 additions & 2 deletions async_events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"context"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

var (
Expand All @@ -51,15 +53,15 @@ func getRealAsyncEventsForTest(t *testing.T) AsyncEvents {
url := startLocalNatsServer(t)
events, err := NewAsyncEvents(url)
if err != nil {
t.Fatal(err)
require.NoError(t, err)
}
return events
}

func getLoopbackAsyncEventsForTest(t *testing.T) AsyncEvents {
events, err := NewAsyncEvents(NatsLoopbackUrl)
if err != nil {
t.Fatal(err)
require.NoError(t, err)
}

t.Cleanup(func() {
Expand Down
Loading

0 comments on commit 9e4d446

Please sign in to comment.