Skip to content

Commit

Permalink
fix race condition in tests (#3834)
Browse files Browse the repository at this point in the history
  • Loading branch information
aler9 authored Oct 5, 2024
1 parent 534b637 commit 2586782
Show file tree
Hide file tree
Showing 10 changed files with 53 additions and 34 deletions.
2 changes: 2 additions & 0 deletions internal/protocols/hls/from_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,7 @@ func TestFromStreamSkipUnsupportedTracks(t *testing.T) {

err = FromStream(stream, l, m)
require.NoError(t, err)
defer stream.RemoveReader(l)

require.Equal(t, 2, n)
}
2 changes: 2 additions & 0 deletions internal/protocols/mpegts/from_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,7 @@ func TestFromStreamSkipUnsupportedTracks(t *testing.T) {

err = FromStream(stream, l, nil, nil, 0)
require.NoError(t, err)
defer stream.RemoveReader(l)

require.Equal(t, 1, n)
}
2 changes: 2 additions & 0 deletions internal/protocols/rtmp/from_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,7 @@ func TestFromStreamSkipUnsupportedTracks(t *testing.T) {

err = FromStream(stream, l, conn, nil, 0)
require.NoError(t, err)
defer stream.RemoveReader(l)

require.Equal(t, 2, n)
}
3 changes: 3 additions & 0 deletions internal/protocols/webrtc/from_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ func TestFromStreamSkipUnsupportedTracks(t *testing.T) {

err = FromStream(stream, l, pc)
require.NoError(t, err)
defer stream.RemoveReader(l)

require.Equal(t, 1, n)
}

Expand Down Expand Up @@ -93,6 +95,7 @@ func TestFromStream(t *testing.T) {

err = FromStream(stream, nil, pc)
require.NoError(t, err)
defer stream.RemoveReader(nil)

require.Equal(t, ca.webrtcCaps, pc.OutgoingTracks[0].Caps)
})
Expand Down
4 changes: 2 additions & 2 deletions internal/servers/hls/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ func TestServerRead(t *testing.T) {

s.PathReady(&dummyPath{})

time.Sleep(100 * time.Millisecond)
str.WaitRunningReader()

for i := 0; i < 4; i++ {
str.WriteUnit(test.MediaH264, test.FormatH264, &unit.H264{
Expand Down Expand Up @@ -398,7 +398,7 @@ func TestServerReadAuthorizationHeader(t *testing.T) {

s.PathReady(&dummyPath{})

time.Sleep(100 * time.Millisecond)
str.WaitRunningReader()

for i := 0; i < 4; i++ {
str.WriteUnit(test.MediaH264, test.FormatH264, &unit.H264{
Expand Down
8 changes: 5 additions & 3 deletions internal/servers/rtmp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,14 @@ func TestServerPublish(t *testing.T) {
return nil
})

path.stream.StartReader(reader)
defer path.stream.RemoveReader(reader)

err = w.WriteH264(0, 0, true, [][]byte{
{5, 2, 3, 4},
})
require.NoError(t, err)

path.stream.StartReader(reader)
defer path.stream.RemoveReader(reader)

<-recv
})
}
Expand Down Expand Up @@ -250,6 +250,8 @@ func TestServerRead(t *testing.T) {
videoTrack, _ := r.Tracks()
require.Equal(t, test.FormatH264, videoTrack)

stream.WaitRunningReader()

stream.WriteUnit(desc.Medias[0], desc.Medias[0].Formats[0], &unit.H264{
Base: unit.Base{
NTP: time.Time{},
Expand Down
6 changes: 3 additions & 3 deletions internal/servers/rtsp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ func TestServerPublish(t *testing.T) {
return nil
})

path.stream.StartReader(reader)
defer path.stream.RemoveReader(reader)

err = source.WritePacketRTP(media0, &rtp.Packet{
Header: rtp.Header{
Version: 2,
Expand All @@ -163,9 +166,6 @@ func TestServerPublish(t *testing.T) {
})
require.NoError(t, err)

path.stream.StartReader(reader)
defer path.stream.RemoveReader(reader)

<-recv
}

Expand Down
8 changes: 5 additions & 3 deletions internal/servers/srt/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ func TestServerPublish(t *testing.T) {
return nil
})

path.stream.StartReader(reader)
defer path.stream.RemoveReader(reader)

err = w.WriteH264(track, 0, 0, true, [][]byte{
{5, 2},
})
Expand All @@ -164,9 +167,6 @@ func TestServerPublish(t *testing.T) {
err = bw.Flush()
require.NoError(t, err)

path.stream.StartReader(reader)
defer path.stream.RemoveReader(reader)

<-recv
}

Expand Down Expand Up @@ -219,6 +219,8 @@ func TestServerRead(t *testing.T) {
require.NoError(t, err)
defer reader.Close()

stream.WaitRunningReader()

stream.WriteUnit(desc.Medias[0], desc.Medias[0].Formats[0], &unit.H264{
Base: unit.Base{
NTP: time.Time{},
Expand Down
38 changes: 15 additions & 23 deletions internal/servers/webrtc/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,9 @@ func TestServerPublish(t *testing.T) {
return nil
})

path.stream.StartReader(reader)
defer path.stream.RemoveReader(reader)

err = track.WriteRTP(&rtp.Packet{
Header: rtp.Header{
Version: 2,
Expand All @@ -353,9 +356,6 @@ func TestServerPublish(t *testing.T) {
})
require.NoError(t, err)

path.stream.StartReader(reader)
defer path.stream.RemoveReader(reader)

<-recv
}

Expand Down Expand Up @@ -572,29 +572,20 @@ func TestServerRead(t *testing.T) {
}

writerDone := make(chan struct{})
defer func() { <-writerDone }()

writerTerminate := make(chan struct{})
defer close(writerTerminate)

go func() {
defer close(writerDone)
for {
select {
case <-time.After(100 * time.Millisecond):
case <-writerTerminate:
return
}

r := reflect.New(reflect.TypeOf(ca.unit).Elem())
r.Elem().Set(reflect.ValueOf(ca.unit).Elem())

if g, ok := r.Interface().(*unit.Generic); ok {
clone := *g.RTPPackets[0]
str.WriteRTPPacket(desc.Medias[0], desc.Medias[0].Formats[0], &clone, time.Time{}, 0)
} else {
str.WriteUnit(desc.Medias[0], desc.Medias[0].Formats[0], r.Interface().(unit.Unit))
}

str.WaitRunningReader()

r := reflect.New(reflect.TypeOf(ca.unit).Elem())
r.Elem().Set(reflect.ValueOf(ca.unit).Elem())

if g, ok := r.Interface().(*unit.Generic); ok {
clone := *g.RTPPackets[0]
str.WriteRTPPacket(desc.Medias[0], desc.Medias[0].Formats[0], &clone, time.Time{}, 0)
} else {
str.WriteUnit(desc.Medias[0], desc.Medias[0].Formats[0], r.Interface().(unit.Unit))
}
}()

Expand All @@ -615,6 +606,7 @@ func TestServerRead(t *testing.T) {

wc.StartReading()

<-writerDone
<-done
})
}
Expand Down
14 changes: 14 additions & 0 deletions internal/stream/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ type Stream struct {
rtspStream *gortsplib.ServerStream
rtspsStream *gortsplib.ServerStream
streamReaders map[Reader]*streamReader

readerRunning chan struct{}
}

// New allocates a Stream.
Expand All @@ -55,6 +57,7 @@ func New(

s.streamMedias = make(map[*description.Media]*streamMedia)
s.streamReaders = make(map[Reader]*streamReader)
s.readerRunning = make(chan struct{})

for _, media := range desc.Medias {
var err error
Expand Down Expand Up @@ -180,6 +183,12 @@ func (s *Stream) StartReader(reader Reader) {
sf.startReader(sr)
}
}

select {
case <-s.readerRunning:
default:
close(s.readerRunning)
}
}

// ReaderError returns whenever there's an error.
Expand Down Expand Up @@ -209,6 +218,11 @@ func (s *Stream) ReaderFormats(reader Reader) []format.Format {
return formats
}

// WaitRunningReader waits for a running reader.
func (s *Stream) WaitRunningReader() {
<-s.readerRunning
}

// WriteUnit writes a Unit.
func (s *Stream) WriteUnit(medi *description.Media, forma format.Format, u unit.Unit) {
sm := s.streamMedias[medi]
Expand Down

0 comments on commit 2586782

Please sign in to comment.