From e1fc062464feb418991e23b9cfc3fa19498d3a8f Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Mon, 28 Oct 2024 09:24:23 +0100 Subject: [PATCH 1/3] proxy: Add timeouts to requests to Janus and cancel if session is closed. --- client.go | 6 +++--- proxy/proxy_client.go | 5 +++-- proxy/proxy_server.go | 49 +++++++++++++++++++++++++++++++++--------- proxy/proxy_session.go | 24 +++++++++++++++------ 4 files changed, 62 insertions(+), 22 deletions(-) diff --git a/client.go b/client.go index c9f1de09..3980218c 100644 --- a/client.go +++ b/client.go @@ -158,15 +158,15 @@ func NewClient(ctx context.Context, conn *websocket.Conn, remoteAddress string, } client := &Client{ - ctx: ctx, agent: agent, logRTT: true, } - client.SetConn(conn, remoteAddress, handler) + client.SetConn(ctx, conn, remoteAddress, handler) return client, nil } -func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string, handler ClientHandler) { +func (c *Client) SetConn(ctx context.Context, conn *websocket.Conn, remoteAddress string, handler ClientHandler) { + c.ctx = ctx c.conn = conn c.addr = remoteAddress c.SetHandler(handler) diff --git a/proxy/proxy_client.go b/proxy/proxy_client.go index cee7328f..935a2b93 100644 --- a/proxy/proxy_client.go +++ b/proxy/proxy_client.go @@ -22,6 +22,7 @@ package main import ( + "context" "sync/atomic" "time" @@ -37,11 +38,11 @@ type ProxyClient struct { session atomic.Pointer[ProxySession] } -func NewProxyClient(proxy *ProxyServer, conn *websocket.Conn, addr string) (*ProxyClient, error) { +func NewProxyClient(ctx context.Context, proxy *ProxyServer, conn *websocket.Conn, addr string) (*ProxyClient, error) { client := &ProxyClient{ proxy: proxy, } - client.SetConn(conn, addr, client) + client.SetConn(ctx, conn, addr, client) return client, nil } diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index 68b77e25..7dea28cc 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -62,6 +62,9 @@ const ( initialMcuRetry = time.Second maxMcuRetry = time.Second * 16 + // MCU requests will be cancelled if they take too long. + defaultMcuTimeoutSeconds = 10 + updateLoadInterval = time.Second expireSessionsInterval = 10 * time.Second @@ -103,6 +106,7 @@ type ProxyServer struct { welcomeMessage string welcomeMsg *signaling.WelcomeServerMessage config *goconf.ConfigFile + mcuTimeout time.Duration url string mcu signaling.Mcu @@ -319,6 +323,12 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (* maxIncoming, maxOutgoing := getTargetBandwidths(config) + mcuTimeoutSeconds, _ := config.GetInt("mcu", "timeout") + if mcuTimeoutSeconds <= 0 { + mcuTimeoutSeconds = defaultMcuTimeoutSeconds + } + mcuTimeout := time.Duration(mcuTimeoutSeconds) * time.Second + result := &ProxyServer{ version: version, country: country, @@ -328,7 +338,8 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (* Country: country, Features: defaultProxyFeatures, }, - config: config, + config: config, + mcuTimeout: mcuTimeout, shutdownChannel: make(chan struct{}), @@ -634,14 +645,14 @@ func (s *ProxyServer) proxyHandler(w http.ResponseWriter, r *http.Request) { return } - client, err := NewProxyClient(s, conn, addr) + client, err := NewProxyClient(r.Context(), s, conn, addr) if err != nil { log.Printf("Could not create client for %s: %s", addr, err) return } go client.WritePump() - go client.ReadPump() + client.ReadPump() } func (s *ProxyServer) clientClosed(client *signaling.Client) { @@ -789,7 +800,7 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { return } - ctx := context.WithValue(context.Background(), ContextKeySession, session) + ctx := context.WithValue(session.Context(), ContextKeySession, session) session.MarkUsed() switch message.Type { @@ -873,8 +884,11 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s return } + ctx2, cancel := context.WithTimeout(ctx, s.mcuTimeout) + defer cancel() + id := uuid.New().String() - publisher, err := s.mcu.NewPublisher(ctx, session, id, cmd.Sid, cmd.StreamType, cmd.Bitrate, cmd.MediaTypes, &emptyInitiator{}) + publisher, err := s.mcu.NewPublisher(ctx2, session, id, cmd.Sid, cmd.StreamType, cmd.Bitrate, cmd.MediaTypes, &emptyInitiator{}) if err == context.DeadlineExceeded { log.Printf("Timeout while creating %s publisher %s for %s", cmd.StreamType, id, session.PublicId()) session.sendMessage(message.NewErrorServerMessage(TimeoutCreatingPublisher)) @@ -977,7 +991,10 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s log.Printf("Created remote %s subscriber %s as %s for %s on %s", cmd.StreamType, subscriber.Id(), id, session.PublicId(), cmd.RemoteUrl) } else { - subscriber, err = s.mcu.NewSubscriber(ctx, session, publisherId, cmd.StreamType, &emptyInitiator{}) + ctx2, cancel := context.WithTimeout(ctx, s.mcuTimeout) + defer cancel() + + subscriber, err = s.mcu.NewSubscriber(ctx2, session, publisherId, cmd.StreamType, &emptyInitiator{}) if err != nil { handleCreateError(err) return @@ -1083,7 +1100,10 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s return } - if err := publisher.PublishRemote(ctx, session.PublicId(), cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil { + ctx2, cancel := context.WithTimeout(ctx, s.mcuTimeout) + defer cancel() + + if err := publisher.PublishRemote(ctx2, session.PublicId(), cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil { var je *janus.ErrorMsg if !errors.As(err, &je) || je.Err.Code != signaling.JANUS_VIDEOROOM_ERROR_ID_EXISTS { log.Printf("Error publishing %s %s to remote %s (port=%d, rtcpPort=%d): %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, cmd.Port, cmd.RtcpPort, err) @@ -1091,13 +1111,19 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s return } - if err := publisher.UnpublishRemote(ctx, session.PublicId()); err != nil { + ctx2, cancel = context.WithTimeout(ctx, s.mcuTimeout) + defer cancel() + + if err := publisher.UnpublishRemote(ctx2, session.PublicId()); err != nil { log.Printf("Error unpublishing old %s %s to remote %s (port=%d, rtcpPort=%d): %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, cmd.Port, cmd.RtcpPort, err) session.sendMessage(message.NewWrappedErrorServerMessage(err)) return } - if err := publisher.PublishRemote(ctx, session.PublicId(), cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil { + ctx2, cancel = context.WithTimeout(ctx, s.mcuTimeout) + defer cancel() + + if err := publisher.PublishRemote(ctx2, session.PublicId(), cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil { log.Printf("Error publishing %s %s to remote %s (port=%d, rtcpPort=%d): %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, cmd.Port, cmd.RtcpPort, err) session.sendMessage(message.NewWrappedErrorServerMessage(err)) return @@ -1202,7 +1228,10 @@ func (s *ProxyServer) processPayload(ctx context.Context, client *ProxyClient, s return } - mcuClient.SendMessage(ctx, nil, mcuData, func(err error, response map[string]interface{}) { + ctx2, cancel := context.WithTimeout(ctx, s.mcuTimeout) + defer cancel() + + mcuClient.SendMessage(ctx2, nil, mcuData, func(err error, response map[string]interface{}) { var responseMsg *signaling.ProxyServerMessage if err != nil { log.Printf("Error sending %+v to %s client %s: %s", mcuData, mcuClient.StreamType(), payload.ClientId, err) diff --git a/proxy/proxy_session.go b/proxy/proxy_session.go index f2c9f499..1fefca69 100644 --- a/proxy/proxy_session.go +++ b/proxy/proxy_session.go @@ -37,10 +37,12 @@ const ( ) type ProxySession struct { - proxy *ProxyServer - id string - sid uint64 - lastUsed atomic.Int64 + proxy *ProxyServer + id string + sid uint64 + lastUsed atomic.Int64 + ctx context.Context + closeFunc context.CancelFunc clientLock sync.Mutex client *ProxyClient @@ -56,10 +58,13 @@ type ProxySession struct { } func NewProxySession(proxy *ProxyServer, sid uint64, id string) *ProxySession { + ctx, closeFunc := context.WithCancel(context.Background()) result := &ProxySession{ - proxy: proxy, - id: id, - sid: sid, + proxy: proxy, + id: id, + sid: sid, + ctx: ctx, + closeFunc: closeFunc, publishers: make(map[string]signaling.McuPublisher), publisherIds: make(map[signaling.McuPublisher]string), @@ -71,6 +76,10 @@ func NewProxySession(proxy *ProxyServer, sid uint64, id string) *ProxySession { return result } +func (s *ProxySession) Context() context.Context { + return s.ctx +} + func (s *ProxySession) PublicId() string { return s.id } @@ -95,6 +104,7 @@ func (s *ProxySession) MarkUsed() { } func (s *ProxySession) Close() { + s.closeFunc() s.clearPublishers() s.clearSubscribers() } From 8077ca410448291b2affbf7e7f02b01ba914dcf5 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Mon, 28 Oct 2024 10:17:38 +0100 Subject: [PATCH 2/3] proxy: Implement "bye" message. --- proxy/proxy_server.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index 7dea28cc..da28a7ad 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -808,6 +808,8 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { s.processCommand(ctx, client, session, &message) case "payload": s.processPayload(ctx, client, session, &message) + case "bye": + s.processBye(ctx, client, session, &message) default: session.sendMessage(message.NewErrorServerMessage(UnsupportedMessage)) } @@ -1252,6 +1254,11 @@ func (s *ProxyServer) processPayload(ctx context.Context, client *ProxyClient, s }) } +func (s *ProxyServer) processBye(ctx context.Context, client *ProxyClient, session *ProxySession, message *signaling.ProxyClientMessage) { + log.Printf("Closing session %s", session.PublicId()) + s.DeleteSession(session.Sid()) +} + func (s *ProxyServer) parseToken(tokenValue string) (*signaling.TokenClaims, string, error) { reason := "auth-failed" token, err := jwt.ParseWithClaims(tokenValue, &signaling.TokenClaims{}, func(token *jwt.Token) (interface{}, error) { From 128b506ea0f45fae1fabf71995bdcaf3a06d0afe Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Mon, 28 Oct 2024 10:41:58 +0100 Subject: [PATCH 3/3] Add test for cancellation of proxy request. --- proxy/proxy_server_test.go | 177 ++++++++++++++++++++++- proxy/proxy_testclient_test.go | 254 +++++++++++++++++++++++++++++++++ 2 files changed, 425 insertions(+), 6 deletions(-) create mode 100644 proxy/proxy_testclient_test.go diff --git a/proxy/proxy_server_test.go b/proxy/proxy_server_test.go index 9dd3714f..9cb97174 100644 --- a/proxy/proxy_server_test.go +++ b/proxy/proxy_server_test.go @@ -27,6 +27,8 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "errors" + "fmt" "net" "net/http/httptest" "os" @@ -46,6 +48,8 @@ import ( const ( KeypairSizeForTest = 2048 TokenIdForTest = "foo" + + testTimeout = 10 * time.Second ) func getWebsocketUrl(url string) string { @@ -58,13 +62,48 @@ func getWebsocketUrl(url string) string { } } +func WaitForProxyServer(ctx context.Context, t *testing.T, proxy *ProxyServer) { + // Wait for any channel messages to be processed. + time.Sleep(10 * time.Millisecond) + proxy.Stop() + for { + proxy.clientsLock.Lock() + clients := len(proxy.clients) + sessions := len(proxy.sessions) + proxy.clientsLock.Unlock() + proxy.remoteConnectionsLock.Lock() + remoteConnections := len(proxy.remoteConnections) + proxy.remoteConnectionsLock.Unlock() + if clients == 0 && + sessions == 0 && + remoteConnections == 0 { + break + } + + select { + case <-ctx.Done(): + proxy.clientsLock.Lock() + proxy.remoteConnectionsLock.Lock() + assert.Fail(t, fmt.Sprintf("Error waiting for clients %+v / sessions %+v / remoteConnections %+v to terminate: %+v", proxy.clients, proxy.sessions, proxy.remoteConnections, ctx.Err())) + proxy.remoteConnectionsLock.Unlock() + proxy.clientsLock.Unlock() + return + default: + time.Sleep(time.Millisecond) + } + } +} + func newProxyServerForTest(t *testing.T) (*ProxyServer, *rsa.PrivateKey, *httptest.Server) { require := require.New(t) tempdir := t.TempDir() var proxy *ProxyServer t.Cleanup(func() { if proxy != nil { - proxy.Stop() + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + WaitForProxyServer(ctx, t, proxy) } }) @@ -124,7 +163,7 @@ func TestTokenValid(t *testing.T) { Token: tokenString, } if session, err := proxy.NewSession(hello); assert.NoError(t, err) { - defer session.Close() + defer proxy.DeleteSession(session.Sid()) } } @@ -148,7 +187,7 @@ func TestTokenNotSigned(t *testing.T) { } if session, err := proxy.NewSession(hello); !assert.ErrorIs(t, err, TokenAuthFailed) { if session != nil { - defer session.Close() + defer proxy.DeleteSession(session.Sid()) } } } @@ -173,7 +212,7 @@ func TestTokenUnknown(t *testing.T) { } if session, err := proxy.NewSession(hello); !assert.ErrorIs(t, err, TokenAuthFailed) { if session != nil { - defer session.Close() + defer proxy.DeleteSession(session.Sid()) } } } @@ -198,7 +237,7 @@ func TestTokenInFuture(t *testing.T) { } if session, err := proxy.NewSession(hello); !assert.ErrorIs(t, err, TokenNotValidYet) { if session != nil { - defer session.Close() + defer proxy.DeleteSession(session.Sid()) } } } @@ -223,7 +262,7 @@ func TestTokenExpired(t *testing.T) { } if session, err := proxy.NewSession(hello); !assert.ErrorIs(t, err, TokenExpired) { if session != nil { - defer session.Close() + defer proxy.DeleteSession(session.Sid()) } } } @@ -290,3 +329,129 @@ func TestWebsocketFeatures(t *testing.T) { assert.NoError(conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Time{})) } + +func TestProxyCreateSession(t *testing.T) { + signaling.CatchLogForTest(t) + assert := assert.New(t) + require := require.New(t) + _, key, server := newProxyServerForTest(t) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client := NewProxyTestClient(ctx, t, server.URL) + defer client.CloseWithBye() + + require.NoError(client.SendHello(key)) + + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + _, err := client.RunUntilLoad(ctx, 0) + assert.NoError(err) +} + +type HangingTestMCU struct { + t *testing.T + ctx context.Context +} + +func NewHangingTestMCU(t *testing.T) *HangingTestMCU { + ctx, closeFunc := context.WithCancel(context.Background()) + t.Cleanup(func() { + closeFunc() + }) + + return &HangingTestMCU{ + t: t, + ctx: ctx, + } +} + +func (m *HangingTestMCU) Start(ctx context.Context) error { + return nil +} + +func (m *HangingTestMCU) Stop() { +} + +func (m *HangingTestMCU) Reload(config *goconf.ConfigFile) { +} + +func (m *HangingTestMCU) SetOnConnected(f func()) { +} + +func (m *HangingTestMCU) SetOnDisconnected(f func()) { +} + +func (m *HangingTestMCU) GetStats() interface{} { + return nil +} + +func (m *HangingTestMCU) NewPublisher(ctx context.Context, listener signaling.McuListener, id string, sid string, streamType signaling.StreamType, bitrate int, mediaTypes signaling.MediaType, initiator signaling.McuInitiator) (signaling.McuPublisher, error) { + ctx2, cancel := context.WithTimeout(m.ctx, testTimeout*2) + defer cancel() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-ctx2.Done(): + return nil, errors.New("Should have been cancelled before") + } +} + +func (m *HangingTestMCU) NewSubscriber(ctx context.Context, listener signaling.McuListener, publisher string, streamType signaling.StreamType, initiator signaling.McuInitiator) (signaling.McuSubscriber, error) { + ctx2, cancel := context.WithTimeout(m.ctx, testTimeout*2) + defer cancel() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-ctx2.Done(): + return nil, errors.New("Should have been cancelled before") + } +} + +func TestProxyCancelOnClose(t *testing.T) { + signaling.CatchLogForTest(t) + assert := assert.New(t) + require := require.New(t) + proxy, key, server := newProxyServerForTest(t) + + proxy.mcu = NewHangingTestMCU(t) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client := NewProxyTestClient(ctx, t, server.URL) + defer client.CloseWithBye() + + require.NoError(client.SendHello(key)) + + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + _, err := client.RunUntilLoad(ctx, 0) + assert.NoError(err) + + require.NoError(client.SendCommand(&signaling.CommandProxyClientMessage{ + Type: "create-publisher", + StreamType: signaling.StreamTypeVideo, + })) + + // Simulate expired session while request is still being processed. + go func() { + if session := proxy.GetSession(1); assert.NotNil(session) { + session.Close() + } + }() + + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + if err := checkMessageType(message, "error"); assert.NoError(err) { + assert.Equal("internal_error", message.Error.Code) + assert.Equal(context.Canceled.Error(), message.Error.Message) + } + } +} diff --git a/proxy/proxy_testclient_test.go b/proxy/proxy_testclient_test.go new file mode 100644 index 00000000..32641578 --- /dev/null +++ b/proxy/proxy_testclient_test.go @@ -0,0 +1,254 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2024 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync" + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + signaling "github.com/strukturag/nextcloud-spreed-signaling" +) + +var ( + ErrNoMessageReceived = errors.New("no message was received by the server") +) + +type ProxyTestClient struct { + t *testing.T + assert *assert.Assertions + require *require.Assertions + + mu sync.Mutex + conn *websocket.Conn + messageChan chan []byte + readErrorChan chan error + + sessionId string +} + +func NewProxyTestClient(ctx context.Context, t *testing.T, url string) *ProxyTestClient { + conn, _, err := websocket.DefaultDialer.DialContext(ctx, getWebsocketUrl(url), nil) + require.NoError(t, err) + + messageChan := make(chan []byte) + readErrorChan := make(chan error, 1) + + go func() { + for { + messageType, data, err := conn.ReadMessage() + if err != nil { + readErrorChan <- err + return + } else if !assert.Equal(t, websocket.TextMessage, messageType) { + return + } + + messageChan <- data + } + }() + + client := &ProxyTestClient{ + t: t, + assert: assert.New(t), + require: require.New(t), + + conn: conn, + messageChan: messageChan, + readErrorChan: readErrorChan, + } + return client +} + +func (c *ProxyTestClient) CloseWithBye() { + c.SendBye() // nolint + c.Close() +} + +func (c *ProxyTestClient) Close() { + c.mu.Lock() + defer c.mu.Unlock() + if err := c.conn.WriteMessage(websocket.CloseMessage, []byte{}); err == websocket.ErrCloseSent { + // Already closed + return + } + + // Wait a bit for close message to be processed. + time.Sleep(100 * time.Millisecond) + c.assert.NoError(c.conn.Close()) + + // Drain any entries in the channels to terminate the read goroutine. +loop: + for { + select { + case <-c.readErrorChan: + case <-c.messageChan: + default: + break loop + } + } +} + +func (c *ProxyTestClient) SendBye() error { + hello := &signaling.ProxyClientMessage{ + Id: "9876", + Type: "bye", + Bye: &signaling.ByeProxyClientMessage{}, + } + return c.WriteJSON(hello) +} + +func (c *ProxyTestClient) WriteJSON(data interface{}) error { + if msg, ok := data.(*signaling.ProxyClientMessage); ok { + if err := msg.CheckValid(); err != nil { + return err + } + } + + c.mu.Lock() + defer c.mu.Unlock() + return c.conn.WriteJSON(data) +} + +func (c *ProxyTestClient) RunUntilMessage(ctx context.Context) (message *signaling.ProxyServerMessage, err error) { + select { + case err = <-c.readErrorChan: + case msg := <-c.messageChan: + var m signaling.ProxyServerMessage + if err = json.Unmarshal(msg, &m); err == nil { + message = &m + } + case <-ctx.Done(): + err = ctx.Err() + } + return +} + +func checkUnexpectedClose(err error) error { + if err != nil && websocket.IsUnexpectedCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + websocket.CloseNoStatusReceived) { + return fmt.Errorf("Connection was closed with unexpected error: %s", err) + } + + return nil +} + +func checkMessageType(message *signaling.ProxyServerMessage, expectedType string) error { + if message == nil { + return ErrNoMessageReceived + } + + if message.Type != expectedType { + return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message) + } + switch message.Type { + case "hello": + if message.Hello == nil { + return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message) + } + case "command": + if message.Command == nil { + return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message) + } + case "event": + if message.Event == nil { + return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message) + } + } + + return nil +} + +func (c *ProxyTestClient) SendHello(key interface{}) error { + claims := &signaling.TokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now().Add(-maxTokenAge / 2)), + Issuer: TokenIdForTest, + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(key) + c.require.NoError(err) + + hello := &signaling.ProxyClientMessage{ + Id: "1234", + Type: "hello", + Hello: &signaling.HelloProxyClientMessage{ + Version: "1.0", + Features: []string{}, + Token: tokenString, + }, + } + return c.WriteJSON(hello) +} + +func (c *ProxyTestClient) RunUntilHello(ctx context.Context) (message *signaling.ProxyServerMessage, err error) { + if message, err = c.RunUntilMessage(ctx); err != nil { + return nil, err + } + if err := checkUnexpectedClose(err); err != nil { + return nil, err + } + if err := checkMessageType(message, "hello"); err != nil { + return nil, err + } + c.sessionId = message.Hello.SessionId + return message, nil +} + +func (c *ProxyTestClient) RunUntilLoad(ctx context.Context, load int64) (message *signaling.ProxyServerMessage, err error) { + if message, err = c.RunUntilMessage(ctx); err != nil { + return nil, err + } + if err := checkUnexpectedClose(err); err != nil { + return nil, err + } + if err := checkMessageType(message, "event"); err != nil { + return nil, err + } + if expectedType := "update-load"; message.Event.Type != expectedType { + return nil, fmt.Errorf("Expected \"%s\" event message, got %+v", expectedType, message) + } + if load != message.Event.Load { + return nil, fmt.Errorf("Expected load %d, got %+v", load, message) + } + return message, nil +} + +func (c *ProxyTestClient) SendCommand(command *signaling.CommandProxyClientMessage) error { + message := &signaling.ProxyClientMessage{ + Id: "2345", + Type: "command", + Command: command, + } + return c.WriteJSON(message) +}