-
Notifications
You must be signed in to change notification settings - Fork 43
/
backend_test.go
122 lines (100 loc) · 2.88 KB
/
backend_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
package pgproto3_test
import (
"io"
"testing"
"github.com/jackc/pgio"
"github.com/jackc/pgproto3/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestBackendReceiveInterrupted checks that a Receive which fails after only a
// message header has arrived can be resumed: once the rest of the body is
// available, the next Receive returns the complete message.
func TestBackendReceiveInterrupted(t *testing.T) {
	t.Parallel()

	server := &interruptReader{}
	// Header only: type 'Q' and a length of 6, which counts the 4 length bytes
	// themselves plus a 2-byte body that is not pushed until later.
	server.push([]byte{'Q', 0, 0, 0, 6})

	backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)

	msg, err := backend.Receive()
	if err == nil {
		t.Fatal("expected err")
	}
	if msg != nil {
		t.Fatalf("did not expect msg, but %v", msg)
	}

	// The missing body: the query string "I" followed by its NUL terminator.
	server.push([]byte{'I', 0})

	msg, err = backend.Receive()
	if err != nil {
		t.Fatal(err)
	}
	// Bind the asserted value to its own name (rather than shadowing msg) so a
	// failure reports the message actually received, not a nil *Query.
	query, ok := msg.(*pgproto3.Query)
	if !ok || query.String != "I" {
		t.Fatalf("unexpected msg: %v", msg)
	}
}
// TestBackendReceiveUnexpectedEOF checks that both Receive and
// ReceiveStartupMessage report io.ErrUnexpectedEOF when the input stops before
// the number of bytes declared in a message's length field has been read.
func TestBackendReceiveUnexpectedEOF(t *testing.T) {
	t.Parallel()

	server := &interruptReader{}
	// Regular message header declaring a 2-byte body that is never sent.
	server.push([]byte{'Q', 0, 0, 0, 6})

	backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)

	// Receive regular msg
	msg, err := backend.Receive()
	assert.Nil(t, msg)
	assert.Equal(t, io.ErrUnexpectedEOF, err)

	// Receive StartupMessage msg
	var dst []byte
	dst = pgio.AppendUint32(dst, 1000) // declare a 1000-byte startup packet (the length includes itself)
	dst = pgio.AppendUint32(dst, 1)    // then send only 4 more bytes (a uint32 with value 1), far short of 1000
	server.push(dst)

	msg, err = backend.ReceiveStartupMessage()
	assert.Nil(t, msg)
	assert.Equal(t, io.ErrUnexpectedEOF, err)
}
// TestStartupMessage exercises Backend.ReceiveStartupMessage. A well-formed
// StartupMessage must survive an Encode / ReceiveStartupMessage round trip, and
// a startup packet whose declared length is too large or too small must be
// rejected with an "invalid length" error.
func TestStartupMessage(t *testing.T) {
	t.Parallel()

	// backendFor builds a Backend whose underlying reader yields packet.
	backendFor := func(packet []byte) *pgproto3.Backend {
		r := &interruptReader{}
		r.push(packet)
		return pgproto3.NewBackend(pgproto3.NewChunkReader(r), nil)
	}

	t.Run("valid StartupMessage", func(t *testing.T) {
		expected := &pgproto3.StartupMessage{
			ProtocolVersion: pgproto3.ProtocolVersionNumber,
			Parameters:      map[string]string{"username": "tester"},
		}

		encoded, err := expected.Encode([]byte{})
		require.NoError(t, err)

		got, err := backendFor(encoded).ReceiveStartupMessage()
		require.NoError(t, err)
		require.Equal(t, expected, got)
	})

	t.Run("invalid packet length", func(t *testing.T) {
		const wantErr = "invalid length of startup packet"

		cases := []struct {
			name      string
			packetLen uint32
		}{
			// The length field is the "Length of message contents in bytes,
			// including self", so the maximum startup packet length is
			// 10000+4; 10005 is just past that limit.
			{name: "large packet length", packetLen: 10005},
			{name: "short packet length", packetLen: 3},
		}

		for _, tc := range cases {
			t.Run(tc.name, func(t *testing.T) {
				header := pgio.AppendUint32(nil, tc.packetLen)
				header = pgio.AppendUint32(header, pgproto3.ProtocolVersionNumber)

				got, err := backendFor(header).ReceiveStartupMessage()
				require.Error(t, err)
				require.Nil(t, got)
				require.Contains(t, err.Error(), wantErr)
			})
		}
	})
}