From 71ceadbf4ce8fb663e9da60dcae26063412fc6d1 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Mon, 11 Nov 2024 10:17:25 +0100 Subject: [PATCH] Add more cases when to stop remote publishing. --- proxy/proxy_server.go | 12 + proxy/proxy_server_test.go | 511 +++++++++++++++++++++++++++++++++++++ proxy/proxy_session.go | 83 ++++++ 3 files changed, 606 insertions(+) diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index 0e6634dc..b256d099 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -1163,6 +1163,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s } } + session.AddRemotePublisher(publisher, cmd.Hostname, cmd.Port, cmd.RtcpPort) response := &signaling.ProxyServerMessage{ Id: message.Id, Type: "command", @@ -1193,6 +1194,8 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s return } + session.RemoveRemotePublisher(publisher, cmd.Hostname, cmd.Port, cmd.RtcpPort) + response := &signaling.ProxyServerMessage{ Id: message.Id, Type: "command", @@ -1599,3 +1602,12 @@ func (s *ProxyServer) getRemoteConnection(url string) (*RemoteConnection, error) s.remoteConnections[url] = conn return conn, nil } + +func (s *ProxyServer) PublisherDeleted(publisher signaling.McuPublisher) { + s.sessionsLock.RLock() + defer s.sessionsLock.RUnlock() + + for _, session := range s.sessions { + session.OnPublisherDeleted(publisher) + } +} diff --git a/proxy/proxy_server_test.go b/proxy/proxy_server_test.go index cb6f42f0..973b6dc3 100644 --- a/proxy/proxy_server_test.go +++ b/proxy/proxy_server_test.go @@ -33,6 +33,7 @@ import ( "net/http/httptest" "os" "strings" + "sync" "sync/atomic" "testing" "time" @@ -835,3 +836,513 @@ func TestProxyRemoteSubscriber(t *testing.T) { } } } + +func TestProxyCloseRemoteOnSessionClose(t *testing.T) { + signaling.CatchLogForTest(t) + assert := assert.New(t) + require := require.New(t) + proxy, key, server := newProxyServerForTest(t) + + mcu := NewRemoteSubscriberTestMCU(t) + proxy.mcu = mcu + // Unused but must be set so remote subscribing works + proxy.tokenId = "token" + proxy.tokenKey = key + proxy.remoteHostname = "test-hostname" + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client := NewProxyTestClient(ctx, t, server.URL) + defer client.CloseWithBye() + + require.NoError(client.SendHello(key)) + + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + _, err := client.RunUntilLoad(ctx, 0) + assert.NoError(err) + + publisherId := "the-publisher-id" + claims := &signaling.TokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now().Add(-maxTokenAge / 2)), + Issuer: TokenIdForTest, + Subject: publisherId, + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(key) + require.NoError(err) + + require.NoError(client.WriteJSON(&signaling.ProxyClientMessage{ + Id: "2345", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "create-subscriber", + StreamType: signaling.StreamTypeVideo, + PublisherId: publisherId, + RemoteUrl: "https://remote-hostname", + RemoteToken: tokenString, + }, + })) + + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("2345", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + } + } + + // Closing the session will cause any active remote publishers stop be stopped. + client.CloseWithBye() + + if assert.NotNil(mcu.publisher) && assert.NotNil(mcu.subscriber) { + select { + case <-mcu.subscriber.closed.Done(): + case <-ctx.Done(): + assert.Fail("subscriber was not closed") + } + select { + case <-mcu.publisher.closed.Done(): + case <-ctx.Done(): + assert.Fail("publisher was not closed") + } + } +} + +type UnpublishRemoteTestMCU struct { + TestMCU + + publisher atomic.Pointer[UnpublishRemoteTestPublisher] +} + +func NewUnpublishRemoteTestMCU(t *testing.T) *UnpublishRemoteTestMCU { + return &UnpublishRemoteTestMCU{ + TestMCU: TestMCU{ + t: t, + }, + } +} + +type UnpublishRemoteTestPublisher struct { + TestMCUPublisher + + t *testing.T + + mu sync.RWMutex + remoteId string + remoteData *remotePublisherData +} + +func (m *UnpublishRemoteTestMCU) NewPublisher(ctx context.Context, listener signaling.McuListener, id string, sid string, streamType signaling.StreamType, settings signaling.NewPublisherSettings, initiator signaling.McuInitiator) (signaling.McuPublisher, error) { + publisher := &UnpublishRemoteTestPublisher{ + TestMCUPublisher: TestMCUPublisher{ + id: id, + sid: sid, + streamType: streamType, + }, + + t: m.t, + } + m.publisher.Store(publisher) + return publisher, nil +} + +func (p *UnpublishRemoteTestPublisher) getRemoteId() string { + p.mu.RLock() + defer p.mu.RUnlock() + return p.remoteId +} + +func (p *UnpublishRemoteTestPublisher) getRemoteData() *remotePublisherData { + p.mu.RLock() + defer p.mu.RUnlock() + return p.remoteData +} + +func (p *UnpublishRemoteTestPublisher) clearRemote() { + p.mu.Lock() + defer p.mu.Unlock() + p.remoteId = "" + p.remoteData = nil +} + +func (p *UnpublishRemoteTestPublisher) PublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error { + p.mu.Lock() + defer p.mu.Unlock() + if assert.Empty(p.t, p.remoteId) { + p.remoteId = remoteId + p.remoteData = &remotePublisherData{ + hostname: hostname, + port: port, + rtcpPort: rtcpPort, + } + } + return nil +} + +func (p *UnpublishRemoteTestPublisher) UnpublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error { + p.mu.Lock() + defer p.mu.Unlock() + assert.Equal(p.t, remoteId, p.remoteId) + if remoteData := p.remoteData; assert.NotNil(p.t, remoteData) && + assert.Equal(p.t, remoteData.hostname, hostname) && + assert.EqualValues(p.t, remoteData.port, port) && + assert.EqualValues(p.t, remoteData.rtcpPort, rtcpPort) { + p.remoteId = "" + p.remoteData = nil + } + return nil +} + +func TestProxyUnpublishRemote(t *testing.T) { + signaling.CatchLogForTest(t) + assert := assert.New(t) + require := require.New(t) + proxy, key, server := newProxyServerForTest(t) + + mcu := NewUnpublishRemoteTestMCU(t) + proxy.mcu = mcu + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client1 := NewProxyTestClient(ctx, t, server.URL) + defer client1.CloseWithBye() + + require.NoError(client1.SendHello(key)) + + if hello, err := client1.RunUntilHello(ctx); assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + _, err := client1.RunUntilLoad(ctx, 0) + assert.NoError(err) + + publisherId := "the-publisher-id" + require.NoError(client1.WriteJSON(&signaling.ProxyClientMessage{ + Id: "2345", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "create-publisher", + PublisherId: publisherId, + Sid: "1234-abcd", + StreamType: signaling.StreamTypeVideo, + PublisherSettings: &signaling.NewPublisherSettings{ + Bitrate: 1234567, + MediaTypes: signaling.MediaTypeAudio | signaling.MediaTypeVideo, + }, + }, + })) + + var clientId string + if message, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("2345", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + clientId = message.Command.Id + } + } + + client2 := NewProxyTestClient(ctx, t, server.URL) + defer client2.CloseWithBye() + + require.NoError(client2.SendHello(key)) + + hello2, err := client2.RunUntilHello(ctx) + if assert.NoError(err) { + assert.NotEmpty(hello2.Hello.SessionId, "%+v", hello2) + } + + _, err = client2.RunUntilLoad(ctx, 0) + assert.NoError(err) + + require.NoError(client2.WriteJSON(&signaling.ProxyClientMessage{ + Id: "3456", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "publish-remote", + StreamType: signaling.StreamTypeVideo, + ClientId: clientId, + Hostname: "remote-host", + Port: 10001, + RtcpPort: 10002, + }, + })) + + if message, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("3456", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + } + } + + if publisher := mcu.publisher.Load(); assert.NotNil(publisher) { + assert.Equal(hello2.Hello.SessionId, publisher.getRemoteId()) + if remoteData := publisher.getRemoteData(); assert.NotNil(remoteData) { + assert.Equal("remote-host", remoteData.hostname) + assert.EqualValues(10001, remoteData.port) + assert.EqualValues(10002, remoteData.rtcpPort) + } + } + + require.NoError(client2.WriteJSON(&signaling.ProxyClientMessage{ + Id: "4567", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "unpublish-remote", + StreamType: signaling.StreamTypeVideo, + ClientId: clientId, + Hostname: "remote-host", + Port: 10001, + RtcpPort: 10002, + }, + })) + + if message, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("4567", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + } + } + + if publisher := mcu.publisher.Load(); assert.NotNil(publisher) { + assert.Empty(publisher.getRemoteId()) + assert.Nil(publisher.getRemoteData()) + } +} + +func TestProxyUnpublishRemotePublisherClosed(t *testing.T) { + signaling.CatchLogForTest(t) + assert := assert.New(t) + require := require.New(t) + proxy, key, server := newProxyServerForTest(t) + + mcu := NewUnpublishRemoteTestMCU(t) + proxy.mcu = mcu + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client1 := NewProxyTestClient(ctx, t, server.URL) + defer client1.CloseWithBye() + + require.NoError(client1.SendHello(key)) + + if hello, err := client1.RunUntilHello(ctx); assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + _, err := client1.RunUntilLoad(ctx, 0) + assert.NoError(err) + + publisherId := "the-publisher-id" + require.NoError(client1.WriteJSON(&signaling.ProxyClientMessage{ + Id: "2345", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "create-publisher", + PublisherId: publisherId, + Sid: "1234-abcd", + StreamType: signaling.StreamTypeVideo, + PublisherSettings: &signaling.NewPublisherSettings{ + Bitrate: 1234567, + MediaTypes: signaling.MediaTypeAudio | signaling.MediaTypeVideo, + }, + }, + })) + + var clientId string + if message, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("2345", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + clientId = message.Command.Id + } + } + + client2 := NewProxyTestClient(ctx, t, server.URL) + defer client2.CloseWithBye() + + require.NoError(client2.SendHello(key)) + + hello2, err := client2.RunUntilHello(ctx) + if assert.NoError(err) { + assert.NotEmpty(hello2.Hello.SessionId, "%+v", hello2) + } + + _, err = client2.RunUntilLoad(ctx, 0) + assert.NoError(err) + + require.NoError(client2.WriteJSON(&signaling.ProxyClientMessage{ + Id: "3456", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "publish-remote", + StreamType: signaling.StreamTypeVideo, + ClientId: clientId, + Hostname: "remote-host", + Port: 10001, + RtcpPort: 10002, + }, + })) + + if message, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("3456", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + } + } + + if publisher := mcu.publisher.Load(); assert.NotNil(publisher) { + assert.Equal(hello2.Hello.SessionId, publisher.getRemoteId()) + if remoteData := publisher.getRemoteData(); assert.NotNil(remoteData) { + assert.Equal("remote-host", remoteData.hostname) + assert.EqualValues(10001, remoteData.port) + assert.EqualValues(10002, remoteData.rtcpPort) + } + } + + require.NoError(client1.WriteJSON(&signaling.ProxyClientMessage{ + Id: "4567", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "delete-publisher", + ClientId: clientId, + }, + })) + + if message, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("4567", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + } + } + + // Remote publishing was not stopped explicitly... + if publisher := mcu.publisher.Load(); assert.NotNil(publisher) { + assert.Equal(hello2.Hello.SessionId, publisher.getRemoteId()) + if remoteData := publisher.getRemoteData(); assert.NotNil(remoteData) { + assert.Equal("remote-host", remoteData.hostname) + assert.EqualValues(10001, remoteData.port) + assert.EqualValues(10002, remoteData.rtcpPort) + } + } + + // ...but the session no longer contains information on the remote publisher. + if data, err := proxy.cookie.DecodePublic(hello2.Hello.SessionId); assert.NoError(err) { + session := proxy.GetSession(data.Sid) + if assert.NotNil(session) { + session.remotePublishersLock.Lock() + defer session.remotePublishersLock.Unlock() + assert.Empty(session.remotePublishers) + } + } + + if publisher := mcu.publisher.Load(); assert.NotNil(publisher) { + publisher.clearRemote() + } +} + +func TestProxyUnpublishRemoteOnSessionClose(t *testing.T) { + signaling.CatchLogForTest(t) + assert := assert.New(t) + require := require.New(t) + proxy, key, server := newProxyServerForTest(t) + + mcu := NewUnpublishRemoteTestMCU(t) + proxy.mcu = mcu + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client1 := NewProxyTestClient(ctx, t, server.URL) + defer client1.CloseWithBye() + + require.NoError(client1.SendHello(key)) + + if hello, err := client1.RunUntilHello(ctx); assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + _, err := client1.RunUntilLoad(ctx, 0) + assert.NoError(err) + + publisherId := "the-publisher-id" + require.NoError(client1.WriteJSON(&signaling.ProxyClientMessage{ + Id: "2345", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "create-publisher", + PublisherId: publisherId, + Sid: "1234-abcd", + StreamType: signaling.StreamTypeVideo, + PublisherSettings: &signaling.NewPublisherSettings{ + Bitrate: 1234567, + MediaTypes: signaling.MediaTypeAudio | signaling.MediaTypeVideo, + }, + }, + })) + + var clientId string + if message, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("2345", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + clientId = message.Command.Id + } + } + + client2 := NewProxyTestClient(ctx, t, server.URL) + defer client2.CloseWithBye() + + require.NoError(client2.SendHello(key)) + + hello2, err := client2.RunUntilHello(ctx) + if assert.NoError(err) { + assert.NotEmpty(hello2.Hello.SessionId, "%+v", hello2) + } + + _, err = client2.RunUntilLoad(ctx, 0) + assert.NoError(err) + + require.NoError(client2.WriteJSON(&signaling.ProxyClientMessage{ + Id: "3456", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "publish-remote", + StreamType: signaling.StreamTypeVideo, + ClientId: clientId, + Hostname: "remote-host", + Port: 10001, + RtcpPort: 10002, + }, + })) + + if message, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("3456", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + } + } + + if publisher := mcu.publisher.Load(); assert.NotNil(publisher) { + assert.Equal(hello2.Hello.SessionId, publisher.getRemoteId()) + if remoteData := publisher.getRemoteData(); assert.NotNil(remoteData) { + assert.Equal("remote-host", remoteData.hostname) + assert.EqualValues(10001, remoteData.port) + assert.EqualValues(10002, remoteData.rtcpPort) + } + } + + // Closing the session will cause any active remote publishers stop be stopped. + client2.CloseWithBye() + + if publisher := mcu.publisher.Load(); assert.NotNil(publisher) { + assert.Empty(publisher.getRemoteId()) + assert.Nil(publisher.getRemoteData()) + } +} diff --git a/proxy/proxy_session.go b/proxy/proxy_session.go index ed9ac260..de6645be 100644 --- a/proxy/proxy_session.go +++ b/proxy/proxy_session.go @@ -23,6 +23,7 @@ package main import ( "context" + "fmt" "log" "sync" "sync/atomic" @@ -36,6 +37,12 @@ const ( sessionExpirationTime = time.Minute ) +type remotePublisherData struct { + hostname string + port int + rtcpPort int +} + type ProxySession struct { proxy *ProxyServer id string @@ -55,6 +62,9 @@ type ProxySession struct { subscribersLock sync.Mutex subscribers map[string]signaling.McuSubscriber subscriberIds map[signaling.McuSubscriber]string + + remotePublishersLock sync.Mutex + remotePublishers map[signaling.McuPublisher]map[string]*remotePublisherData } func NewProxySession(proxy *ProxyServer, sid uint64, id string) *ProxySession { @@ -121,6 +131,7 @@ func (s *ProxySession) Close() { s.closeFunc() s.clearPublishers() s.clearSubscribers() + s.clearRemotePublishers() s.proxy.DeleteSession(s.Sid()) } @@ -287,6 +298,8 @@ func (s *ProxySession) DeletePublisher(publisher signaling.McuPublisher) string delete(s.publishers, id) delete(s.publisherIds, publisher) + delete(s.remotePublishers, publisher) + go s.proxy.PublisherDeleted(publisher) return id } @@ -329,6 +342,22 @@ func (s *ProxySession) clearPublishers() { clear(s.publisherIds) } +func (s *ProxySession) clearRemotePublishers() { + s.remotePublishersLock.Lock() + defer s.remotePublishersLock.Unlock() + + go func(remotePublishers map[signaling.McuPublisher]map[string]*remotePublisherData) { + for publisher, entries := range remotePublishers { + for _, data := range entries { + if err := publisher.UnpublishRemote(context.Background(), s.PublicId(), data.hostname, data.port, data.rtcpPort); err != nil { + log.Printf("Error unpublishing %s %s from remote %s: %s", publisher.StreamType(), publisher.Id(), data.hostname, err) + } + } + } + }(s.remotePublishers) + s.remotePublishers = nil +} + func (s *ProxySession) clearSubscribers() { s.publishersLock.Lock() defer s.publishersLock.Unlock() @@ -349,4 +378,58 @@ func (s *ProxySession) clearSubscribers() { func (s *ProxySession) NotifyDisconnected() { s.clearPublishers() s.clearSubscribers() + s.clearRemotePublishers() +} + +func (s *ProxySession) AddRemotePublisher(publisher signaling.McuPublisher, hostname string, port int, rtcpPort int) bool { + s.remotePublishersLock.Lock() + defer s.remotePublishersLock.Unlock() + + remote, found := s.remotePublishers[publisher] + if !found { + remote = make(map[string]*remotePublisherData) + if s.remotePublishers == nil { + s.remotePublishers = make(map[signaling.McuPublisher]map[string]*remotePublisherData) + } + s.remotePublishers[publisher] = remote + } + + key := fmt.Sprintf("%s:%d%d", hostname, port, rtcpPort) + if _, found := remote[key]; found { + return false + } + + data := &remotePublisherData{ + hostname: hostname, + port: port, + rtcpPort: rtcpPort, + } + remote[key] = data + return true +} + +func (s *ProxySession) RemoveRemotePublisher(publisher signaling.McuPublisher, hostname string, port int, rtcpPort int) { + s.remotePublishersLock.Lock() + defer s.remotePublishersLock.Unlock() + + remote, found := s.remotePublishers[publisher] + if !found { + return + } + + key := fmt.Sprintf("%s:%d%d", hostname, port, rtcpPort) + delete(remote, key) + if len(remote) == 0 { + delete(s.remotePublishers, publisher) + if len(s.remotePublishers) == 0 { + s.remotePublishers = nil + } + } +} + +func (s *ProxySession) OnPublisherDeleted(publisher signaling.McuPublisher) { + s.remotePublishersLock.Lock() + defer s.remotePublishersLock.Unlock() + + delete(s.remotePublishers, publisher) }