Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to "github.com/stretchr/testify" for tests. #811

Merged
merged 1 commit into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 13 additions & 19 deletions allowed_ips_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,17 @@ package signaling
import (
"net"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestAllowedIps(t *testing.T) {
require := require.New(t)
a, err := ParseAllowedIps("127.0.0.1, 192.168.0.1, 192.168.1.1/24")
if err != nil {
t.Fatal(err)
}
if a.Empty() {
t.Fatal("should not be empty")
}
if expected := `[127.0.0.1/32, 192.168.0.1/32, 192.168.1.0/24]`; a.String() != expected {
t.Errorf("expected %s, got %s", expected, a.String())
}
require.NoError(err)
require.False(a.Empty())
require.Equal(`[127.0.0.1/32, 192.168.0.1/32, 192.168.1.0/24]`, a.String())

allowed := []string{
"127.0.0.1",
Expand All @@ -51,22 +49,18 @@ func TestAllowedIps(t *testing.T) {

for _, addr := range allowed {
t.Run(addr, func(t *testing.T) {
ip := net.ParseIP(addr)
if ip == nil {
t.Errorf("error parsing %s", addr)
} else if !a.Allowed(ip) {
t.Errorf("should allow %s", addr)
assert := assert.New(t)
if ip := net.ParseIP(addr); assert.NotNil(ip, "error parsing %s", addr) {
assert.True(a.Allowed(ip), "should allow %s", addr)
}
})
}

for _, addr := range notAllowed {
t.Run(addr, func(t *testing.T) {
ip := net.ParseIP(addr)
if ip == nil {
t.Errorf("error parsing %s", addr)
} else if a.Allowed(ip) {
t.Errorf("should not allow %s", addr)
assert := assert.New(t)
if ip := net.ParseIP(addr); assert.NotNil(ip, "error parsing %s", addr) {
assert.False(a.Allowed(ip), "should not allow %s", addr)
}
})
}
Expand Down
32 changes: 11 additions & 21 deletions api_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,42 +24,36 @@ package signaling
import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

func TestBackendChecksum(t *testing.T) {
t.Parallel()
assert := assert.New(t)
rnd := newRandomString(32)
body := []byte{1, 2, 3, 4, 5}
secret := []byte("shared-secret")

check1 := CalculateBackendChecksum(rnd, body, secret)
check2 := CalculateBackendChecksum(rnd, body, secret)
if check1 != check2 {
t.Errorf("Expected equal checksums, got %s and %s", check1, check2)
}
assert.Equal(check1, check2, "Expected equal checksums")

if !ValidateBackendChecksumValue(check1, rnd, body, secret) {
t.Errorf("Checksum %s could not be validated", check1)
}
if ValidateBackendChecksumValue(check1[1:], rnd, body, secret) {
t.Errorf("Checksum %s should not be valid", check1[1:])
}
if ValidateBackendChecksumValue(check1[:len(check1)-1], rnd, body, secret) {
t.Errorf("Checksum %s should not be valid", check1[:len(check1)-1])
}
assert.True(ValidateBackendChecksumValue(check1, rnd, body, secret), "Checksum should be valid")
assert.False(ValidateBackendChecksumValue(check1[1:], rnd, body, secret), "Checksum should not be valid")
assert.False(ValidateBackendChecksumValue(check1[:len(check1)-1], rnd, body, secret), "Checksum should not be valid")

request := &http.Request{
Header: make(http.Header),
}
request.Header.Set("Spreed-Signaling-Random", rnd)
request.Header.Set("Spreed-Signaling-Checksum", check1)
if !ValidateBackendChecksum(request, body, secret) {
t.Errorf("Checksum %s could not be validated from request", check1)
}
assert.True(ValidateBackendChecksum(request, body, secret), "Checksum could not be validated from request")
}

func TestValidNumbers(t *testing.T) {
t.Parallel()
assert := assert.New(t)
valid := []string{
"+12",
"+12345",
Expand All @@ -72,13 +66,9 @@ func TestValidNumbers(t *testing.T) {
"+123-45",
}
for _, number := range valid {
if !isValidNumber(number) {
t.Errorf("number %s should be valid", number)
}
assert.True(isValidNumber(number), "number %s should be valid", number)
}
for _, number := range invalid {
if isValidNumber(number) {
t.Errorf("number %s should not be valid", number)
}
assert.False(isValidNumber(number), "number %s should not be valid", number)
}
}
102 changes: 37 additions & 65 deletions api_signaling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ package signaling
import (
"encoding/json"
"fmt"
"reflect"
"sort"
"testing"

"github.com/stretchr/testify/assert"
)

type testCheckValid interface {
Expand All @@ -53,40 +54,34 @@ func wrapMessage(messageType string, msg testCheckValid) *ClientMessage {
}

func testMessages(t *testing.T, messageType string, valid_messages []testCheckValid, invalid_messages []testCheckValid) {
t.Helper()
assert := assert.New(t)
for _, msg := range valid_messages {
if err := msg.CheckValid(); err != nil {
t.Errorf("Message %+v should be valid, got %s", msg, err)
}
assert.NoError(msg.CheckValid(), "Message %+v should be valid", msg)

// If the inner message is valid, it should also be valid in a wrapped
// ClientMessage.
if wrapped := wrapMessage(messageType, msg); wrapped == nil {
t.Errorf("Unknown message type: %s", messageType)
} else if err := wrapped.CheckValid(); err != nil {
t.Errorf("Message %+v should be valid, got %s", wrapped, err)
if wrapped := wrapMessage(messageType, msg); assert.NotNil(wrapped, "Unknown message type: %s", messageType) {
assert.NoError(wrapped.CheckValid(), "Message %+v should be valid", wrapped)
}
}
for _, msg := range invalid_messages {
if err := msg.CheckValid(); err == nil {
t.Errorf("Message %+v should not be valid", msg)
}
assert.Error(msg.CheckValid(), "Message %+v should not be valid", msg)

// If the inner message is invalid, it should also be invalid in a
// wrapped ClientMessage.
if wrapped := wrapMessage(messageType, msg); wrapped == nil {
t.Errorf("Unknown message type: %s", messageType)
} else if err := wrapped.CheckValid(); err == nil {
t.Errorf("Message %+v should not be valid", wrapped)
if wrapped := wrapMessage(messageType, msg); assert.NotNil(wrapped, "Unknown message type: %s", messageType) {
assert.Error(wrapped.CheckValid(), "Message %+v should not be valid", wrapped)
}
}
}

func TestClientMessage(t *testing.T) {
t.Parallel()
assert := assert.New(t)
// The message needs a type.
msg := ClientMessage{}
if err := msg.CheckValid(); err == nil {
t.Errorf("Message %+v should not be valid", msg)
}
assert.Error(msg.CheckValid())
}

func TestHelloClientMessage(t *testing.T) {
Expand Down Expand Up @@ -229,9 +224,8 @@ func TestHelloClientMessage(t *testing.T) {
msg := ClientMessage{
Type: "hello",
}
if err := msg.CheckValid(); err == nil {
t.Errorf("Message %+v should not be valid", msg)
}
assert := assert.New(t)
assert.Error(msg.CheckValid())
}

func TestMessageClientMessage(t *testing.T) {
Expand Down Expand Up @@ -311,9 +305,8 @@ func TestMessageClientMessage(t *testing.T) {
msg := ClientMessage{
Type: "message",
}
if err := msg.CheckValid(); err == nil {
t.Errorf("Message %+v should not be valid", msg)
}
assert := assert.New(t)
assert.Error(msg.CheckValid())
}

func TestByeClientMessage(t *testing.T) {
Expand All @@ -330,9 +323,8 @@ func TestByeClientMessage(t *testing.T) {
msg := ClientMessage{
Type: "bye",
}
if err := msg.CheckValid(); err != nil {
t.Errorf("Message %+v should be valid, got %s", msg, err)
}
assert := assert.New(t)
assert.NoError(msg.CheckValid())
}

func TestRoomClientMessage(t *testing.T) {
Expand All @@ -349,42 +341,31 @@ func TestRoomClientMessage(t *testing.T) {
msg := ClientMessage{
Type: "room",
}
if err := msg.CheckValid(); err == nil {
t.Errorf("Message %+v should not be valid", msg)
}
assert := assert.New(t)
assert.Error(msg.CheckValid())
}

func TestErrorMessages(t *testing.T) {
t.Parallel()
assert := assert.New(t)
id := "request-id"
msg := ClientMessage{
Id: id,
}
err1 := msg.NewErrorServerMessage(&Error{})
if err1.Id != id {
t.Errorf("Expected id %s, got %+v", id, err1)
}
if err1.Type != "error" || err1.Error == nil {
t.Errorf("Expected type \"error\", got %+v", err1)
}
assert.Equal(id, err1.Id, "%+v", err1)
assert.Equal("error", err1.Type, "%+v", err1)
assert.NotNil(err1.Error, "%+v", err1)

err2 := msg.NewWrappedErrorServerMessage(fmt.Errorf("test-error"))
if err2.Id != id {
t.Errorf("Expected id %s, got %+v", id, err2)
}
if err2.Type != "error" || err2.Error == nil {
t.Errorf("Expected type \"error\", got %+v", err2)
}
if err2.Error.Code != "internal_error" {
t.Errorf("Expected code \"internal_error\", got %+v", err2)
}
if err2.Error.Message != "test-error" {
t.Errorf("Expected message \"test-error\", got %+v", err2)
assert.Equal(id, err2.Id, "%+v", err2)
assert.Equal("error", err2.Type, "%+v", err2)
if assert.NotNil(err2.Error, "%+v", err2) {
assert.Equal("internal_error", err2.Error.Code, "%+v", err2)
assert.Equal("test-error", err2.Error.Message, "%+v", err2)
}
// Test "error" interface
if err2.Error.Error() != "test-error" {
t.Errorf("Expected error string \"test-error\", got %+v", err2)
}
assert.Equal("test-error", err2.Error.Error(), "%+v", err2)
}

func TestIsChatRefresh(t *testing.T) {
Expand All @@ -397,9 +378,7 @@ func TestIsChatRefresh(t *testing.T) {
Data: data_true,
},
}
if !msg.IsChatRefresh() {
t.Error("message should be detected as chat refresh")
}
assert.True(t, msg.IsChatRefresh())

data_false := []byte("{\"type\":\"chat\",\"chat\":{\"refresh\":false}}")
msg = ServerMessage{
Expand All @@ -408,9 +387,7 @@ func TestIsChatRefresh(t *testing.T) {
Data: data_false,
},
}
if msg.IsChatRefresh() {
t.Error("message should not be detected as chat refresh")
}
assert.False(t, msg.IsChatRefresh())
}

func assertEqualStrings(t *testing.T, expected, result []string) {
Expand All @@ -427,27 +404,22 @@ func assertEqualStrings(t *testing.T, expected, result []string) {
sort.Strings(result)
}

if !reflect.DeepEqual(expected, result) {
t.Errorf("Expected %+v, got %+v", expected, result)
}
assert.Equal(t, expected, result)
}

func Test_Welcome_AddRemoveFeature(t *testing.T) {
t.Parallel()
assert := assert.New(t)
var msg WelcomeServerMessage
assertEqualStrings(t, []string{}, msg.Features)

msg.AddFeature("one", "two", "one")
assertEqualStrings(t, []string{"one", "two"}, msg.Features)
if !sort.StringsAreSorted(msg.Features) {
t.Errorf("features should be sorted, got %+v", msg.Features)
}
assert.True(sort.StringsAreSorted(msg.Features), "features should be sorted, got %+v", msg.Features)

msg.AddFeature("three")
assertEqualStrings(t, []string{"one", "two", "three"}, msg.Features)
if !sort.StringsAreSorted(msg.Features) {
t.Errorf("features should be sorted, got %+v", msg.Features)
}
assert.True(sort.StringsAreSorted(msg.Features), "features should be sorted, got %+v", msg.Features)

msg.RemoveFeature("three", "one")
assertEqualStrings(t, []string{"two"}, msg.Features)
Expand Down
6 changes: 4 additions & 2 deletions async_events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"context"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

var (
Expand All @@ -51,15 +53,15 @@ func getRealAsyncEventsForTest(t *testing.T) AsyncEvents {
url := startLocalNatsServer(t)
events, err := NewAsyncEvents(url)
if err != nil {
t.Fatal(err)
require.NoError(t, err)
}
return events
}

func getLoopbackAsyncEventsForTest(t *testing.T) AsyncEvents {
events, err := NewAsyncEvents(NatsLoopbackUrl)
if err != nil {
t.Fatal(err)
require.NoError(t, err)
}

t.Cleanup(func() {
Expand Down
Loading
Loading