diff --git a/internal/protocols/hls/from_stream_test.go b/internal/protocols/hls/from_stream_test.go index a04b968ad62..53bd483a750 100644 --- a/internal/protocols/hls/from_stream_test.go +++ b/internal/protocols/hls/from_stream_test.go @@ -76,5 +76,7 @@ func TestFromStreamSkipUnsupportedTracks(t *testing.T) { err = FromStream(stream, l, m) require.NoError(t, err) + defer stream.RemoveReader(l) + require.Equal(t, 2, n) } diff --git a/internal/protocols/mpegts/from_stream_test.go b/internal/protocols/mpegts/from_stream_test.go index 90595a70b78..62c695632ee 100644 --- a/internal/protocols/mpegts/from_stream_test.go +++ b/internal/protocols/mpegts/from_stream_test.go @@ -64,5 +64,7 @@ func TestFromStreamSkipUnsupportedTracks(t *testing.T) { err = FromStream(stream, l, nil, nil, 0) require.NoError(t, err) + defer stream.RemoveReader(l) + require.Equal(t, 1, n) } diff --git a/internal/protocols/rtmp/from_stream_test.go b/internal/protocols/rtmp/from_stream_test.go index 0c23b0e5d6a..d1d0d1eb709 100644 --- a/internal/protocols/rtmp/from_stream_test.go +++ b/internal/protocols/rtmp/from_stream_test.go @@ -78,5 +78,7 @@ func TestFromStreamSkipUnsupportedTracks(t *testing.T) { err = FromStream(stream, l, conn, nil, 0) require.NoError(t, err) + defer stream.RemoveReader(l) + require.Equal(t, 2, n) } diff --git a/internal/protocols/webrtc/from_stream_test.go b/internal/protocols/webrtc/from_stream_test.go index 25c73a9e8a6..17d20bb95fb 100644 --- a/internal/protocols/webrtc/from_stream_test.go +++ b/internal/protocols/webrtc/from_stream_test.go @@ -66,6 +66,8 @@ func TestFromStreamSkipUnsupportedTracks(t *testing.T) { err = FromStream(stream, l, pc) require.NoError(t, err) + defer stream.RemoveReader(l) + require.Equal(t, 1, n) } @@ -93,6 +95,7 @@ func TestFromStream(t *testing.T) { err = FromStream(stream, nil, pc) require.NoError(t, err) + defer stream.RemoveReader(nil) require.Equal(t, ca.webrtcCaps, pc.OutgoingTracks[0].Caps) }) diff --git a/internal/servers/hls/server_test.go b/internal/servers/hls/server_test.go index e00ce5b0d50..86d51c927de 100644 --- a/internal/servers/hls/server_test.go +++ b/internal/servers/hls/server_test.go @@ -305,7 +305,7 @@ func TestServerRead(t *testing.T) { s.PathReady(&dummyPath{}) - time.Sleep(100 * time.Millisecond) + str.WaitRunningReader() for i := 0; i < 4; i++ { str.WriteUnit(test.MediaH264, test.FormatH264, &unit.H264{ @@ -398,7 +398,7 @@ func TestServerReadAuthorizationHeader(t *testing.T) { s.PathReady(&dummyPath{}) - time.Sleep(100 * time.Millisecond) + str.WaitRunningReader() for i := 0; i < 4; i++ { str.WriteUnit(test.MediaH264, test.FormatH264, &unit.H264{ diff --git a/internal/servers/rtmp/server_test.go b/internal/servers/rtmp/server_test.go index de7c7edcf72..bdd41769235 100644 --- a/internal/servers/rtmp/server_test.go +++ b/internal/servers/rtmp/server_test.go @@ -163,14 +163,14 @@ func TestServerPublish(t *testing.T) { return nil }) + path.stream.StartReader(reader) + defer path.stream.RemoveReader(reader) + err = w.WriteH264(0, 0, true, [][]byte{ {5, 2, 3, 4}, }) require.NoError(t, err) - path.stream.StartReader(reader) - defer path.stream.RemoveReader(reader) - <-recv }) } @@ -250,6 +250,8 @@ func TestServerRead(t *testing.T) { videoTrack, _ := r.Tracks() require.Equal(t, test.FormatH264, videoTrack) + stream.WaitRunningReader() + stream.WriteUnit(desc.Medias[0], desc.Medias[0].Formats[0], &unit.H264{ Base: unit.Base{ NTP: time.Time{}, diff --git a/internal/servers/rtsp/server_test.go b/internal/servers/rtsp/server_test.go index 7f54a4bc44d..a22c089db87 100644 --- a/internal/servers/rtsp/server_test.go +++ b/internal/servers/rtsp/server_test.go @@ -150,6 +150,9 @@ func TestServerPublish(t *testing.T) { return nil }) + path.stream.StartReader(reader) + defer path.stream.RemoveReader(reader) + err = source.WritePacketRTP(media0, &rtp.Packet{ Header: rtp.Header{ Version: 2, @@ -163,9 +166,6 @@ func TestServerPublish(t *testing.T) { }) require.NoError(t, err) - path.stream.StartReader(reader) - defer path.stream.RemoveReader(reader) - <-recv } diff --git a/internal/servers/srt/server_test.go b/internal/servers/srt/server_test.go index f6461375f7b..04c40561796 100644 --- a/internal/servers/srt/server_test.go +++ b/internal/servers/srt/server_test.go @@ -156,6 +156,9 @@ func TestServerPublish(t *testing.T) { return nil }) + path.stream.StartReader(reader) + defer path.stream.RemoveReader(reader) + err = w.WriteH264(track, 0, 0, true, [][]byte{ {5, 2}, }) @@ -164,9 +167,6 @@ func TestServerPublish(t *testing.T) { err = bw.Flush() require.NoError(t, err) - path.stream.StartReader(reader) - defer path.stream.RemoveReader(reader) - <-recv } @@ -219,6 +219,8 @@ func TestServerRead(t *testing.T) { require.NoError(t, err) defer reader.Close() + stream.WaitRunningReader() + stream.WriteUnit(desc.Medias[0], desc.Medias[0].Formats[0], &unit.H264{ Base: unit.Base{ NTP: time.Time{}, diff --git a/internal/servers/webrtc/server_test.go b/internal/servers/webrtc/server_test.go index 19f37648a34..80fa556199c 100644 --- a/internal/servers/webrtc/server_test.go +++ b/internal/servers/webrtc/server_test.go @@ -340,6 +340,9 @@ func TestServerPublish(t *testing.T) { return nil }) + path.stream.StartReader(reader) + defer path.stream.RemoveReader(reader) + err = track.WriteRTP(&rtp.Packet{ Header: rtp.Header{ Version: 2, @@ -353,9 +356,6 @@ func TestServerPublish(t *testing.T) { }) require.NoError(t, err) - path.stream.StartReader(reader) - defer path.stream.RemoveReader(reader) - <-recv } @@ -572,29 +572,20 @@ func TestServerRead(t *testing.T) { } writerDone := make(chan struct{}) - defer func() { <-writerDone }() - - writerTerminate := make(chan struct{}) - defer close(writerTerminate) go func() { defer close(writerDone) - for { - select { - case <-time.After(100 * time.Millisecond): - case <-writerTerminate: - return - } - - r := reflect.New(reflect.TypeOf(ca.unit).Elem()) - r.Elem().Set(reflect.ValueOf(ca.unit).Elem()) - - if g, ok := r.Interface().(*unit.Generic); ok { - clone := *g.RTPPackets[0] - str.WriteRTPPacket(desc.Medias[0], desc.Medias[0].Formats[0], &clone, time.Time{}, 0) - } else { - str.WriteUnit(desc.Medias[0], desc.Medias[0].Formats[0], r.Interface().(unit.Unit)) - } + + str.WaitRunningReader() + + r := reflect.New(reflect.TypeOf(ca.unit).Elem()) + r.Elem().Set(reflect.ValueOf(ca.unit).Elem()) + + if g, ok := r.Interface().(*unit.Generic); ok { + clone := *g.RTPPackets[0] + str.WriteRTPPacket(desc.Medias[0], desc.Medias[0].Formats[0], &clone, time.Time{}, 0) + } else { + str.WriteUnit(desc.Medias[0], desc.Medias[0].Formats[0], r.Interface().(unit.Unit)) } }() @@ -615,6 +606,7 @@ func TestServerRead(t *testing.T) { wc.StartReading() + <-writerDone <-done }) } diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 9d654f57e22..8ff8b949a10 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -36,6 +36,8 @@ type Stream struct { rtspStream *gortsplib.ServerStream rtspsStream *gortsplib.ServerStream streamReaders map[Reader]*streamReader + + readerRunning chan struct{} } // New allocates a Stream. @@ -55,6 +57,7 @@ func New( s.streamMedias = make(map[*description.Media]*streamMedia) s.streamReaders = make(map[Reader]*streamReader) + s.readerRunning = make(chan struct{}) for _, media := range desc.Medias { var err error @@ -180,6 +183,12 @@ func (s *Stream) StartReader(reader Reader) { sf.startReader(sr) } } + + select { + case <-s.readerRunning: + default: + close(s.readerRunning) + } } // ReaderError returns whenever there's an error. @@ -209,6 +218,11 @@ func (s *Stream) ReaderFormats(reader Reader) []format.Format { return formats } +// WaitRunningReader waits for a running reader. +func (s *Stream) WaitRunningReader() { + <-s.readerRunning +} + // WriteUnit writes a Unit. func (s *Stream) WriteUnit(medi *description.Media, forma format.Format, u unit.Unit) { sm := s.streamMedias[medi]