diff --git a/announcement_map.go b/announcement_map.go deleted file mode 100644 index c2892cd..0000000 --- a/announcement_map.go +++ /dev/null @@ -1,44 +0,0 @@ -package moqtransport - -import ( - "errors" - "sync" -) - -type announcementMap struct { - mutex sync.Mutex - announcements map[string]*Announcement - newAnnouncementChan chan string -} - -func newAnnouncementMap() *announcementMap { - return &announcementMap{ - mutex: sync.Mutex{}, - announcements: map[string]*Announcement{}, - newAnnouncementChan: make(chan string, 1), - } -} - -func (m *announcementMap) add(name string, a *Announcement) error { - m.mutex.Lock() - defer m.mutex.Unlock() - if _, ok := m.announcements[name]; ok { - return errors.New("duplicate announcement") - } - m.announcements[name] = a - m.newAnnouncementChan <- name - return nil -} - -func (m *announcementMap) get(name string) (*Announcement, bool) { - m.mutex.Lock() - defer m.mutex.Unlock() - a, ok := m.announcements[name] - return a, ok -} - -func (m *announcementMap) delete(name string) { - m.mutex.Lock() - defer m.mutex.Unlock() - delete(m.announcements, name) -} diff --git a/examples/date-server/main.go b/examples/date-server/main.go index ff5b847..70c73c3 100644 --- a/examples/date-server/main.go +++ b/examples/date-server/main.go @@ -7,7 +7,6 @@ import ( "crypto/tls" "crypto/x509" "encoding/pem" - "errors" "flag" "fmt" "log" @@ -56,6 +55,8 @@ func listen(ctx context.Context, addr string, tlsConfig *tls.Config) error { TLSConfig: tlsConfig, }, } + track := moqtransport.NewLocalTrack(0, "clock", "second") + defer track.Close() http.HandleFunc("/moq", func(w http.ResponseWriter, r *http.Request) { session, err := wt.Upgrade(w, r) if err != nil { @@ -72,8 +73,22 @@ func listen(ctx context.Context, addr string, tlsConfig *tls.Config) error { w.WriteHeader(http.StatusInternalServerError) return } - go handle(moqSession) + go handle(moqSession, track) }) + go func() { + ticker := time.NewTicker(time.Second) + id := uint64(0) + for ts := range ticker.C { + log.Printf("tick: %v, subscribers: %v\n", ts, track.SubscriberCount()) + track.WriteObject(ctx, moqtransport.Object{ + GroupID: id, + ObjectID: 0, + ObjectSendOrder: 0, + Payload: []byte(fmt.Sprintf("%v", ts)), + }) + id++ + } + }() for { conn, err := listener.Accept(ctx) if err != nil { @@ -90,45 +105,16 @@ func listen(ctx context.Context, addr string, tlsConfig *tls.Config) error { if err := s.RunServer(ctx); err != nil { return err } - go handle(s) + go handle(s, track) } } } -func handle(p *moqtransport.Session) { - go func() { - s, err := p.ReadSubscription(context.Background(), func(s *moqtransport.SendSubscription) error { - if fmt.Sprintf("%v/%v", s.Namespace(), s.Trackname()) != "clock/second" { - return errors.New("unknown namespace/trackname") - } - return nil - }) - if err != nil { - panic(err) - } - log.Printf("got subscription: %v", s) - go func() { - ticker := time.NewTicker(time.Second) - id := uint64(0) - for ts := range ticker.C { - w, err := s.NewObjectStream(id, 0, 0) // TODO: Use meaningful values - if err != nil { - log.Println(err) - return - } - if _, err := fmt.Fprintf(w, "%v", ts); err != nil { - log.Println(err) - return - } - if err := w.Close(); err != nil { - log.Println(err) - return - } - id++ - } - }() - }() - if err := p.Announce(context.Background(), "clock"); err != nil { +func handle(s *moqtransport.Session, t *moqtransport.LocalTrack) { + if err := s.AddLocalTrack(t); err != nil { + panic(err) + } + if err := s.Announce(context.Background(), "clock"); err != nil { panic(err) } } diff --git a/group_header_stream.go b/group_header_stream.go index 731b76f..e09f738 100644 --- a/group_header_stream.go +++ b/group_header_stream.go @@ -22,10 +22,14 @@ func newGroupHeaderStream(stream SendStream, subscribeID, trackAlias, groupID, o }, nil } -func (s *groupHeaderStream) NewObject() *groupHeaderStreamObject { - return &groupHeaderStreamObject{ - stream: s.stream, +func (s *groupHeaderStream) writeObject(objectID uint64, payload []byte) (int, error) { + shgo := streamHeaderGroupObject{ + ObjectID: objectID, + ObjectPayload: payload, } + buf := make([]byte, 0, 16+len(payload)) + buf = shgo.append(buf) + return s.stream.Write(buf) } func (s *groupHeaderStream) Close() error { diff --git a/group_header_stream_object.go b/group_header_stream_object.go deleted file mode 100644 index 82e6aad..0000000 --- a/group_header_stream_object.go +++ /dev/null @@ -1,16 +0,0 @@ -package moqtransport - -type groupHeaderStreamObject struct { - stream SendStream - objectID uint64 -} - -func (s *groupHeaderStreamObject) Write(payload []byte) (int, error) { - shgo := streamHeaderGroupObject{ - ObjectID: s.objectID, - ObjectPayload: payload, - } - buf := make([]byte, 0, 16+len(payload)) - buf = shgo.append(buf) - return s.stream.Write(buf) -} diff --git a/integrationtests/integration_test.go b/integrationtests/integration_test.go index 83f2d5a..0be60e3 100644 --- a/integrationtests/integration_test.go +++ b/integrationtests/integration_test.go @@ -149,18 +149,27 @@ func TestIntegration(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() server := quicServerSession(t, ctx, listener, nil) - sub, err := server.ReadSubscription(ctx, func(ss *moqtransport.SendSubscription) error { return nil }) + track := moqtransport.NewLocalTrack(0, "namespace", "track") + defer track.Close() + err := server.AddLocalTrack(track) + assert.NoError(t, err) + err = server.Announce(ctx, "namespace") assert.NoError(t, err) - assert.Equal(t, "namespace", sub.Namespace()) - assert.Equal(t, "track", sub.Trackname()) <-receivedSubscribeOK assert.NoError(t, server.Close()) }() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client := quicClientSession(t, ctx, addr, nil) - _, err := client.Subscribe(ctx, 0, 0, "namespace", "track", "auth") + announcementCh := make(chan struct{}) + client := quicClientSession(t, ctx, addr, moqtransport.AnnouncementHandlerFunc(func(a *moqtransport.Announcement, arw moqtransport.AnnouncementResponseWriter) { + assert.Equal(t, "namespace", a.Namespace()) + arw.Accept() + close(announcementCh) + })) + <-announcementCh + r, err := client.Subscribe(ctx, 0, 0, "namespace", "track", "auth") assert.NoError(t, err) + assert.NotNil(t, r) close(receivedSubscribeOK) assert.NoError(t, client.Close()) wg.Wait() @@ -172,29 +181,44 @@ func TestIntegration(t *testing.T) { listener, addr, teardown := setup() defer teardown() wg.Add(1) + subscribedCh := make(chan struct{}) receivedObject := make(chan struct{}) go func() { defer wg.Done() ctx, cancel := context.WithCancel(context.Background()) defer cancel() server := quicServerSession(t, ctx, listener, nil) - sub, err := server.ReadSubscription(ctx, func(ss *moqtransport.SendSubscription) error { return nil }) + track := moqtransport.NewLocalTrack(0, "namespace", "track") + defer track.Close() + err := server.AddLocalTrack(track) assert.NoError(t, err) - assert.Equal(t, "namespace", sub.Namespace()) - assert.Equal(t, "track", sub.Trackname()) - s, err := sub.NewObjectStream(0, 0, 0) + err = server.Announce(ctx, "namespace") assert.NoError(t, err) - _, err = s.Write([]byte("hello world")) + <-subscribedCh + err = track.WriteObject(ctx, moqtransport.Object{ + GroupID: 0, + ObjectID: 0, + ObjectSendOrder: 0, + ForwardingPreference: 0, + Payload: []byte("hello world"), + }) assert.NoError(t, err) - assert.NoError(t, s.Close()) <-receivedObject + assert.NoError(t, track.Close()) assert.NoError(t, server.Close()) }() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client := quicClientSession(t, ctx, addr, nil) + announcementCh := make(chan struct{}) + client := quicClientSession(t, ctx, addr, moqtransport.AnnouncementHandlerFunc(func(a *moqtransport.Announcement, arw moqtransport.AnnouncementResponseWriter) { + assert.Equal(t, "namespace", a.Namespace()) + arw.Accept() + close(announcementCh) + })) + <-announcementCh sub, err := client.Subscribe(ctx, 0, 0, "namespace", "track", "auth") assert.NoError(t, err) + close(subscribedCh) buf := make([]byte, 1500) n, err := sub.Read(buf) assert.NoError(t, err) @@ -210,36 +234,75 @@ func TestIntegration(t *testing.T) { listener, addr, teardown := setup() defer teardown() wg.Add(1) - receivedUnsubscribe := make(chan struct{}) + receivedSubscribeCh := make(chan struct{}) + receivedUnsubscribeCh := make(chan struct{}) + subscribedCh := make(chan struct{}) + unsubscribedCh := make(chan struct{}) go func() { defer wg.Done() ctx, cancel := context.WithCancel(context.Background()) defer cancel() server := quicServerSession(t, ctx, listener, nil) - sub, err := server.ReadSubscription(ctx, func(ss *moqtransport.SendSubscription) error { return nil }) + track := moqtransport.NewLocalTrack(0, "namespace", "track") + defer track.Close() + err := server.AddLocalTrack(track) + assert.NoError(t, err) + err = server.Announce(ctx, "namespace") assert.NoError(t, err) - assert.Equal(t, "namespace", sub.Namespace()) - assert.Equal(t, "track", sub.Trackname()) - for i := 0; i < 10; i++ { - err = sub.NewObjectPreferDatagram(0, 0, 0, nil) - if err != nil { + err = track.WriteObject(ctx, moqtransport.Object{ + GroupID: 0, + ObjectID: 0, + ObjectSendOrder: 0, + ForwardingPreference: 0, + Payload: []byte("hello world"), + }) + assert.NoError(t, err) + <-subscribedCh + assert.Equal(t, 1, track.SubscriberCount()) + err = track.WriteObject(ctx, moqtransport.Object{ + GroupID: 0, + ObjectID: 0, + ObjectSendOrder: 0, + ForwardingPreference: 0, + Payload: []byte("hello world"), + }) + assert.NoError(t, err) + close(receivedSubscribeCh) + <-unsubscribedCh + err = track.WriteObject(ctx, moqtransport.Object{ + GroupID: 0, + ObjectID: 0, + ObjectSendOrder: 0, + ForwardingPreference: 0, + Payload: []byte("hello world"), + }) + for i := 0; i < 3; i++ { + if track.SubscriberCount() == 0 { break } time.Sleep(10 * time.Millisecond) } - assert.Error(t, err) - assert.ErrorContains(t, err, "peer unsubscribed") - close(receivedUnsubscribe) + assert.NoError(t, err) + assert.Equal(t, 0, track.SubscriberCount()) + close(receivedUnsubscribeCh) assert.NoError(t, server.Close()) }() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client := quicClientSession(t, ctx, addr, nil) + announcementCh := make(chan struct{}) + client := quicClientSession(t, ctx, addr, moqtransport.AnnouncementHandlerFunc(func(a *moqtransport.Announcement, arw moqtransport.AnnouncementResponseWriter) { + assert.Equal(t, "namespace", a.Namespace()) + arw.Accept() + close(announcementCh) + })) + <-announcementCh sub, err := client.Subscribe(ctx, 0, 0, "namespace", "track", "auth") assert.NoError(t, err) + close(subscribedCh) + <-receivedSubscribeCh sub.Unsubscribe() - <-receivedUnsubscribe - assert.NoError(t, err) + close(unsubscribedCh) + <-receivedUnsubscribeCh assert.NoError(t, client.Close()) wg.Wait() }) diff --git a/local_track.go b/local_track.go new file mode 100644 index 0000000..a212b5d --- /dev/null +++ b/local_track.go @@ -0,0 +1,188 @@ +package moqtransport + +import ( + "context" + "errors" + "io" + "sync" +) + +type subscriberMapKey struct { + trackID, subscribeID uint64 +} + +type subscriberAction bool + +const ( + subscriberActionAdd subscriberAction = true + subscriberActionRemove subscriberAction = false +) + +type subscriberOp struct { + action subscriberAction + trackAlias, subscriberID uint64 + subscriber ObjectWriter + errCh chan error +} + +type ObjectForwardingPreference int + +const ( + ObjectForwardingPreferenceDatagram ObjectForwardingPreference = iota + ObjectForwardingPreferenceStream + ObjectForwardingPreferenceStreamGroup + ObjectForwardingPreferenceStreamTrack +) + +type Object struct { + GroupID uint64 + ObjectID uint64 + ObjectSendOrder uint64 + ForwardingPreference ObjectForwardingPreference + + Payload []byte +} + +// An ObjectWriter allows sending objects using. +type ObjectWriter interface { + WriteObject(Object) error + io.Closer +} + +// A LocalTrack is a local media source. Writing objects to the track will relay +// the objects to all subscribers. +// All methods are safe for concurrent use. Ordering of objects is only +// guaranteed within MultiObjectStreams. LocalTracks must be created with +// NewLocalTrack to ensure proper initialization. +type LocalTrack struct { + ID uint64 + Namespace string + Name string + + cancelCtx context.CancelFunc + cancelWG sync.WaitGroup + ctx context.Context + + subscriberCh chan subscriberOp + subscribers map[subscriberMapKey]ObjectWriter + objectCh chan Object + subscriberCountCh chan int +} + +// NewLocalTrack creates a new LocalTrack +func NewLocalTrack(id uint64, namespace, trackname string) *LocalTrack { + ctx, cancelCtx := context.WithCancel(context.Background()) + lt := &LocalTrack{ + ID: id, + Namespace: namespace, + Name: trackname, + cancelCtx: cancelCtx, + cancelWG: sync.WaitGroup{}, + ctx: ctx, + subscriberCh: make(chan subscriberOp), + subscribers: map[subscriberMapKey]ObjectWriter{}, + objectCh: make(chan Object), + subscriberCountCh: make(chan int), + } + lt.cancelWG.Add(1) + go lt.loop() + return lt +} + +func (t *LocalTrack) manageSubscriber(op subscriberOp) error { + key := subscriberMapKey{ + trackID: op.trackAlias, + subscribeID: op.subscriberID, + } + _, ok := t.subscribers[key] + switch op.action { + case subscriberActionAdd: + if ok { + return errors.New("duplicate subscriber ID in session") + } + t.subscribers[key] = op.subscriber + return nil + case subscriberActionRemove: + if !ok { + return errors.New("subscriber not found") + } + delete(t.subscribers, key) + return nil + } + return errors.New("invalid subscriber action") +} + +func (t *LocalTrack) loop() { + defer t.cancelWG.Done() + for { + select { + case <-t.ctx.Done(): + for _, v := range t.subscribers { + v.Close() + } + return + case op := <-t.subscriberCh: + op.errCh <- t.manageSubscriber(op) + case object := <-t.objectCh: + for _, v := range t.subscribers { + if err := v.WriteObject(object); err != nil { + // TODO: Notify / remove subscriber? + panic(err) + } + } + case t.subscriberCountCh <- len(t.subscribers): + } + } +} + +func (t *LocalTrack) subscribe( + trackAlias uint64, + subscribeID uint64, + subscriber ObjectWriter, +) error { + if subscriber == nil { + return errors.New("subscriber MUST NOT be nil") + } + addOp := subscriberOp{ + action: subscriberActionAdd, + trackAlias: trackAlias, + subscriberID: subscribeID, + subscriber: subscriber, + errCh: make(chan error), + } + // TODO: Should this have a timeout or similar? + t.subscriberCh <- addOp + return <-addOp.errCh +} + +func (t *LocalTrack) unsubscribe(trackAlias, subscribeID uint64) error { + removeOp := subscriberOp{ + action: subscriberActionRemove, + trackAlias: trackAlias, + subscriberID: subscribeID, + subscriber: nil, + errCh: make(chan error), + } + t.subscriberCh <- removeOp + return <-removeOp.errCh +} + +// WriteObject adds an object with forwarding preference DATAGRAM +func (t *LocalTrack) WriteObject(ctx context.Context, o Object) error { + select { + case <-ctx.Done(): + return ctx.Err() + case t.objectCh <- o: + } + return nil +} + +func (t *LocalTrack) Close() error { + t.cancelCtx() + t.cancelWG.Wait() + return nil +} + +func (t *LocalTrack) SubscriberCount() int { + return <-t.subscriberCountCh +} diff --git a/message.go b/message.go index 2fc1855..3b0abb3 100644 --- a/message.go +++ b/message.go @@ -22,6 +22,7 @@ const ( // Errors not included in current draft ErrorCodeUnsupportedVersion = 0xff01 + ErrorCodeTrackNotFound = 0xff02 ) const ( diff --git a/receive_subscription.go b/receive_subscription.go index b644b37..cad6c2c 100644 --- a/receive_subscription.go +++ b/receive_subscription.go @@ -32,10 +32,6 @@ type payloader interface { payload() []byte } -func (s *ReceiveSubscription) push(p payloader) (int, error) { - return s.writeBuffer.Write(p.payload()) -} - func (s *ReceiveSubscription) Read(buf []byte) (int, error) { return s.readBuffer.Read(buf) } @@ -44,8 +40,12 @@ func (s *ReceiveSubscription) Unsubscribe() { s.session.unsubscribe(s.subscribeID) } -func (s *ReceiveSubscription) unsubscribe() error { - return s.writeBuffer.Close() +func (s *ReceiveSubscription) close() { + s.writeBuffer.Close() +} + +func (s *ReceiveSubscription) push(p payloader) (int, error) { + return s.writeBuffer.Write(p.payload()) } func (s *ReceiveSubscription) readTrackHeaderStream(rs ReceiveStream) { diff --git a/send_subscription.go b/send_subscription.go index 07569a9..2c3d3f1 100644 --- a/send_subscription.go +++ b/send_subscription.go @@ -1,9 +1,9 @@ package moqtransport import ( + "context" "errors" "sync" - "time" "github.com/quic-go/quic-go" ) @@ -11,84 +11,94 @@ import ( var errUnsubscribed = errors.New("peer unsubscribed") type SendSubscription struct { - lock sync.RWMutex - closeCh chan struct{} - expires time.Duration - - conn Connection + cancelCtx context.CancelFunc + cancelWG sync.WaitGroup + ctx context.Context subscribeID, trackAlias uint64 namespace, trackname string - startGroup, startObject Location - endGroup, endObject Location - parameters parameters -} - -func (s *SendSubscription) SetExpires(d time.Duration) { - s.lock.Lock() - defer s.lock.Unlock() - s.expires = d -} - -func (s *SendSubscription) Namespace() string { - return s.namespace -} - -func (s *SendSubscription) Trackname() string { - return s.trackname -} - -func (s *SendSubscription) StartGroup() Location { - return s.startGroup -} - -func (s *SendSubscription) StartObject() Location { - return s.startObject + conn Connection + objectCh chan Object + trackHeaderStream *TrackHeaderStream + groupHeaderStreams map[uint64]*groupHeaderStream } -func (s *SendSubscription) EndGroup() Location { - return s.endGroup -} - -func (s *SendSubscription) EndObject() Location { - return s.endObject +func newSendSubscription(conn Connection, subscribeID, trackAlias uint64, namespace, trackname string) *SendSubscription { + ctx, cancelCtx := context.WithCancel(context.Background()) + sub := &SendSubscription{ + cancelCtx: cancelCtx, + cancelWG: sync.WaitGroup{}, + ctx: ctx, + subscribeID: subscribeID, + trackAlias: trackAlias, + namespace: namespace, + trackname: trackname, + conn: conn, + objectCh: make(chan Object, 64), + trackHeaderStream: &TrackHeaderStream{}, + groupHeaderStreams: map[uint64]*groupHeaderStream{}, + } + sub.cancelWG.Add(1) + go sub.loop() + return sub } -func (s *SendSubscription) unsubscribe() error { - close(s.closeCh) - return nil +func (s *SendSubscription) loop() { + defer s.cancelWG.Done() + for { + select { + case o := <-s.objectCh: + s.sendObject(o) + case <-s.ctx.Done(): + return + } + } } -func (s *SendSubscription) NewObjectStream(groupID, objectID, objectSendOrder uint64) (*objectStream, error) { - select { - case <-s.closeCh: - return nil, errUnsubscribed - default: - } - stream, err := s.conn.OpenUniStream() - if err != nil { - return nil, err +func (s *SendSubscription) sendObject(o Object) { + switch o.ForwardingPreference { + case ObjectForwardingPreferenceDatagram: + if err := s.sendDatagram(o); err != nil { + panic(err) + } + case ObjectForwardingPreferenceStream: + if err := s.sendObjectStream(o); err != nil { + panic(err) + } + case ObjectForwardingPreferenceStreamTrack: + if err := s.sendTrackHeaderStream(o); err != nil { + panic(err) + } + case ObjectForwardingPreferenceStreamGroup: + if err := s.sendGroupHeaderStream(o); err != nil { + panic(err) + } } - return newObjectStream(stream, s.subscribeID, s.trackAlias, groupID, objectID, objectSendOrder) } -func (s *SendSubscription) NewObjectPreferDatagram(groupID, objectID, objectSendOrder uint64, payload []byte) error { +func (s *SendSubscription) WriteObject(o Object) error { select { - case <-s.closeCh: + case s.objectCh <- o: + case <-s.ctx.Done(): return errUnsubscribed default: + panic("TODO: improve queuing/caching for slow subscribers?") } - o := objectMessage{ + return nil +} + +func (s *SendSubscription) sendDatagram(o Object) error { + om := objectMessage{ datagram: true, SubscribeID: s.subscribeID, TrackAlias: s.trackAlias, - GroupID: groupID, - ObjectID: objectID, - ObjectSendOrder: objectSendOrder, - ObjectPayload: payload, + GroupID: o.GroupID, + ObjectID: o.ObjectID, + ObjectSendOrder: o.ObjectSendOrder, + ObjectPayload: o.Payload, } - buf := make([]byte, 0, 48+len(o.ObjectPayload)) - buf = o.append(buf) + buf := make([]byte, 0, 48+len(o.Payload)) + buf = om.append(buf) err := s.conn.SendDatagram(buf) if err == nil { return nil @@ -96,39 +106,61 @@ func (s *SendSubscription) NewObjectPreferDatagram(groupID, objectID, objectSend if !errors.Is(err, &quic.DatagramTooLargeError{}) { return err } - os, err := s.NewObjectStream(groupID, objectID, objectSendOrder) + return s.sendObjectStream(o) +} + +func (s *SendSubscription) sendObjectStream(o Object) error { + stream, err := s.conn.OpenUniStream() if err != nil { return err } - _, err = os.Write(buf) + os, err := newObjectStream(stream, s.subscribeID, s.trackAlias, o.GroupID, o.ObjectID, o.ObjectSendOrder) if err != nil { return err } + if _, err := os.Write(o.Payload); err != nil { + return err + } return os.Close() } -func (s *SendSubscription) NewTrackHeaderStream(objectSendOrder uint64) (*TrackHeaderStream, error) { - select { - case <-s.closeCh: - return nil, errUnsubscribed - default: +func (s *SendSubscription) sendTrackHeaderStream(o Object) error { + if s.trackHeaderStream == nil { + stream, err := s.conn.OpenUniStream() + if err != nil { + return err + } + ts, err := newTrackHeaderStream(stream, s.subscribeID, s.trackAlias, o.ObjectSendOrder) + if err != nil { + return err + } + s.trackHeaderStream = ts } - stream, err := s.conn.OpenUniStream() - if err != nil { - return nil, err - } - return newTrackHeaderStream(stream, s.subscribeID, s.trackAlias, objectSendOrder) + _, err := s.trackHeaderStream.writeObject(o.GroupID, o.ObjectID, o.Payload) + return err } -func (s *SendSubscription) NewGroupHeaderStream(groupID, objectSendOrder uint64) (*groupHeaderStream, error) { - select { - case <-s.closeCh: - return nil, errUnsubscribed - default: - } - stream, err := s.conn.OpenUniStream() - if err != nil { - return nil, err +func (s *SendSubscription) sendGroupHeaderStream(o Object) error { + gs, ok := s.groupHeaderStreams[o.GroupID] + if !ok { + var stream SendStream + var err error + stream, err = s.conn.OpenUniStream() + if err != nil { + return err + } + gs, err = newGroupHeaderStream(stream, s.subscribeID, s.trackAlias, o.GroupID, o.ObjectSendOrder) + if err != nil { + return err + } + s.groupHeaderStreams[o.GroupID] = gs } - return newGroupHeaderStream(stream, s.subscribeID, s.trackAlias, groupID, objectSendOrder) + _, err := gs.writeObject(o.ObjectID, o.Payload) + return err +} + +func (s *SendSubscription) Close() error { + s.cancelCtx() + s.cancelWG.Wait() + return nil } diff --git a/session.go b/session.go index ba24c40..fc0ac10 100644 --- a/session.go +++ b/session.go @@ -36,16 +36,22 @@ type trackNamespacer interface { trackNamespace() string } +type trackKey struct { + namespace string + trackname string +} + type sessionInternals struct { logger *slog.Logger serverHandshakeDoneCh chan struct{} controlStreamStoreCh chan controlMessageSender // Needs to be buffered closeOnce sync.Once closed chan struct{} - sendSubscriptions *subscriptionMap[*SendSubscription] - receiveSubscriptions *subscriptionMap[*ReceiveSubscription] - localAnnouncements *announcementMap - remoteAnnouncements *announcementMap + sendSubscriptions *syncMap[uint64, *SendSubscription] + receiveSubscriptions *syncMap[uint64, *ReceiveSubscription] + localAnnouncements *syncMap[string, *Announcement] + remoteAnnouncements *syncMap[string, *Announcement] + localTracks *syncMap[trackKey, *LocalTrack] } func newSessionInternals(logSuffix string) *sessionInternals { @@ -55,10 +61,11 @@ func newSessionInternals(logSuffix string) *sessionInternals { controlStreamStoreCh: make(chan controlMessageSender, 1), closeOnce: sync.Once{}, closed: make(chan struct{}), - sendSubscriptions: newSubscriptionMap[*SendSubscription](), - receiveSubscriptions: newSubscriptionMap[*ReceiveSubscription](), - localAnnouncements: newAnnouncementMap(), - remoteAnnouncements: newAnnouncementMap(), + sendSubscriptions: newSyncMap[uint64, *SendSubscription](), + receiveSubscriptions: newSyncMap[uint64, *ReceiveSubscription](), + localAnnouncements: newSyncMap[string, *Announcement](), + remoteAnnouncements: newSyncMap[string, *Announcement](), + localTracks: newSyncMap[trackKey, *LocalTrack](), } } @@ -346,29 +353,30 @@ func (s *Session) handleControlMessage(msg message) error { func (s *Session) handleNonSetupMessage(msg message) error { switch m := msg.(type) { case *subscribeMessage: - return s.handleSubscribe(m) + s.handleSubscribe(m) case *subscribeOkMessage: return s.handleSubscriptionResponse(m) case *subscribeErrorMessage: return s.handleSubscriptionResponse(m) case *subscribeDoneMessage: - return s.handleSubscribeDone(m) + s.handleSubscribeDone(m) case *unsubscribeMessage: return s.handleUnsubscribe(m) case *announceMessage: s.handleAnnounceMessage(m) - return nil case *announceOkMessage: return s.handleAnnouncementResponse(m) case *announceErrorMessage: return s.handleAnnouncementResponse(m) case *goAwayMessage: panic("TODO") + default: + return &ProtocolError{ + code: ErrorCodeInternal, + message: "received unexpected message type on control stream", + } } - return &ProtocolError{ - code: ErrorCodeInternal, - message: "received unexpected message type on control stream", - } + return nil } func (s *Session) handleSubscriptionResponse(msg subscribeIDer) error { @@ -405,42 +413,89 @@ func (s *Session) handleAnnouncementResponse(msg trackNamespacer) error { return nil } -func (s *Session) handleSubscribe(msg *subscribeMessage) error { - sub := &SendSubscription{ - lock: sync.RWMutex{}, - closeCh: make(chan struct{}), - expires: 0, - conn: s.Conn, - subscribeID: msg.SubscribeID, - trackAlias: msg.TrackAlias, - namespace: msg.TrackNamespace, - trackname: msg.TrackName, - startGroup: msg.StartGroup, - startObject: msg.StartObject, - endGroup: msg.EndGroup, - endObject: msg.EndObject, - parameters: msg.Parameters, +func (s *Session) handleSubscribe(msg *subscribeMessage) { + t, ok := s.si.localTracks.get(trackKey{ + namespace: msg.TrackNamespace, + trackname: msg.TrackName, + }) + if !ok { + s.controlStream.enqueue(&subscribeErrorMessage{ + SubscribeID: msg.SubscribeID, + ErrorCode: ErrorCodeTrackNotFound, + ReasonPhrase: "track not found", + TrackAlias: msg.TrackAlias, + }) + return + } + sub := newSendSubscription(s.Conn, msg.SubscribeID, msg.TrackAlias, msg.TrackNamespace, msg.TrackName) + if err := s.si.sendSubscriptions.add(sub.subscribeID, sub); err != nil { + s.controlStream.enqueue(&subscribeErrorMessage{ + SubscribeID: msg.SubscribeID, + ErrorCode: ErrorCodeInternal, // TODO: Set better error code? + ReasonPhrase: "internal error", + TrackAlias: msg.TrackAlias, + }) + return + } + if err := t.subscribe( + msg.SubscribeID, + msg.SubscribeID, + sub, + ); err != nil { + s.controlStream.enqueue(&subscribeErrorMessage{ + SubscribeID: msg.SubscribeID, + ErrorCode: ErrorCodeInternal, // TODO: Set better error code? + ReasonPhrase: "internal error", + TrackAlias: msg.TrackAlias, + }) + return } - return s.si.sendSubscriptions.add(sub.subscribeID, sub) + s.controlStream.enqueue(&subscribeOkMessage{ + SubscribeID: msg.SubscribeID, + Expires: 0, // TODO + ContentExists: false, // TODO + FinalGroup: 0, // TODO + FinalObject: 0, // TODO + }) } func (s *Session) handleUnsubscribe(msg *unsubscribeMessage) error { - if err := s.si.sendSubscriptions.delete(msg.SubscribeID); err != nil { - return err + sub, ok := s.si.sendSubscriptions.get(msg.SubscribeID) + if !ok { + return errors.New("subscription not found") } + track, ok := s.si.localTracks.get(trackKey{ + namespace: sub.namespace, + trackname: sub.trackname, + }) + if !ok { + return errors.New("no track related to subscription found") + } + err := track.unsubscribe(sub.trackAlias, sub.subscribeID) + if err != nil { + panic(err) + } + sub.Close() + s.si.sendSubscriptions.delete(msg.SubscribeID) s.controlStream.enqueue(&subscribeDoneMessage{ SusbcribeID: msg.SubscribeID, StatusCode: 0, ReasonPhrase: "unsubscribed", - ContentExists: false, - FinalGroup: 0, - FinalObject: 0, + ContentExists: false, // TODO + FinalGroup: 0, // TODO + FinalObject: 0, // TODO }) - return nil + return err } -func (s *Session) handleSubscribeDone(msg *subscribeDoneMessage) error { - return s.si.receiveSubscriptions.delete(msg.SusbcribeID) +func (s *Session) handleSubscribeDone(msg *subscribeDoneMessage) { + sub, ok := s.si.receiveSubscriptions.get(msg.SusbcribeID) + if !ok { + s.si.logger.Info("got SubscribeDone for unknown subscription") + return + } + sub.close() + s.si.receiveSubscriptions.delete(msg.SusbcribeID) } func (s *Session) handleAnnounceMessage(msg *announceMessage) { @@ -509,30 +564,11 @@ func (s *Session) Close() error { return s.CloseWithError(0, "") } -// TODO: Acceptor func should not pass the complete subscription object but only -// the relevant header info -func (s *Session) ReadSubscription(ctx context.Context, accept func(*SendSubscription) error) (*SendSubscription, error) { - sub, err := s.si.sendSubscriptions.getNext(ctx) - if err != nil { - return nil, err - } - if err = accept(sub); err != nil { - s.controlStream.enqueue(&subscribeErrorMessage{ - SubscribeID: sub.subscribeID, - ErrorCode: 0, - ReasonPhrase: err.Error(), - TrackAlias: sub.trackAlias, - }) - return nil, s.si.sendSubscriptions.delete(sub.subscribeID) - } - s.controlStream.enqueue(&subscribeOkMessage{ - SubscribeID: sub.subscribeID, - Expires: 0, // TODO: Let user set these values? - ContentExists: false, // TODO: Let user set these values? - FinalGroup: 0, // TODO: Let user set these values? - FinalObject: 0, // TODO: Let user set these values? - }) - return sub, err +func (s *Session) AddLocalTrack(t *LocalTrack) error { + return s.si.localTracks.add(trackKey{ + namespace: t.Namespace, + trackname: t.Name, + }, t) } func (s *Session) Subscribe(ctx context.Context, subscribeID, trackAlias uint64, namespace, trackname, auth string) (*ReceiveSubscription, error) { @@ -576,7 +612,7 @@ func (s *Session) Subscribe(ctx context.Context, subscribeID, trackAlias uint64, case *subscribeOkMessage: return sub, nil case *subscribeErrorMessage: - _ = s.si.receiveSubscriptions.delete(sm.SubscribeID) + s.si.receiveSubscriptions.delete(sm.SubscribeID) return nil, ApplicationError{ code: v.ErrorCode, mesage: v.ReasonPhrase, diff --git a/session_test.go b/session_test.go index eeb8b42..4361538 100644 --- a/session_test.go +++ b/session_test.go @@ -92,7 +92,11 @@ func TestSession(t *testing.T) { }).Do(func(_ message) { close(done) }) - err := s.handleControlMessage(&clientSetupMessage{ + track := NewLocalTrack(0, "namespace", "track") + defer track.Close() + err := s.AddLocalTrack(track) + assert.NoError(t, err) + err = s.handleControlMessage(&clientSetupMessage{ SupportedVersions: []version{CURRENT_VERSION}, SetupParameters: map[uint64]parameter{ roleParameterKey: varintParameter{ @@ -114,14 +118,6 @@ func TestSession(t *testing.T) { Parameters: map[uint64]parameter{}, }) assert.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - sub, err := s.ReadSubscription(ctx, func(ss *SendSubscription) error { - return nil - }) - assert.NoError(t, err) - assert.NotNil(t, sub) - sub.SetExpires(time.Second) select { case <-time.After(time.Second): assert.Fail(t, "test timed out") diff --git a/subscription_map.go b/subscription_map.go deleted file mode 100644 index 79c8409..0000000 --- a/subscription_map.go +++ /dev/null @@ -1,67 +0,0 @@ -package moqtransport - -import ( - "context" - "errors" - "sync" -) - -type subscription interface { - unsubscribe() error -} - -type subscriptionMap[T subscription] struct { - mutex sync.Mutex - subscriptions map[uint64]T - newSubChan chan uint64 -} - -func newSubscriptionMap[T subscription]() *subscriptionMap[T] { - return &subscriptionMap[T]{ - mutex: sync.Mutex{}, - subscriptions: make(map[uint64]T), - newSubChan: make(chan uint64, 1), - } -} - -func (m *subscriptionMap[T]) add(id uint64, t T) error { - m.mutex.Lock() - defer m.mutex.Unlock() - if _, ok := m.subscriptions[id]; ok { - return errors.New("duplicate subscribe id") - } - m.subscriptions[id] = t - m.newSubChan <- id - return nil -} - -func (m *subscriptionMap[T]) get(id uint64) (T, bool) { - m.mutex.Lock() - defer m.mutex.Unlock() - sub, ok := m.subscriptions[id] - return sub, ok -} - -func (m *subscriptionMap[T]) delete(id uint64) error { - m.mutex.Lock() - defer m.mutex.Unlock() - sub, ok := m.subscriptions[id] - delete(m.subscriptions, id) - if !ok { - return errors.New("delete on unknown subscribe ID") - } - return sub.unsubscribe() -} - -func (m *subscriptionMap[T]) getNext(ctx context.Context) (T, error) { - for { - select { - case <-ctx.Done(): - return *new(T), ctx.Err() - case id := <-m.newSubChan: - if sub, ok := m.get(id); ok { - return sub, nil - } - } - } -} diff --git a/sync_map.go b/sync_map.go new file mode 100644 index 0000000..4c46f1e --- /dev/null +++ b/sync_map.go @@ -0,0 +1,41 @@ +package moqtransport + +import ( + "errors" + "sync" +) + +type syncMap[K comparable, V any] struct { + mutex sync.Mutex + elements map[K]V +} + +func newSyncMap[K comparable, V any]() *syncMap[K, V] { + return &syncMap[K, V]{ + mutex: sync.Mutex{}, + elements: make(map[K]V), + } +} + +func (m *syncMap[K, V]) add(k K, v V) error { + m.mutex.Lock() + defer m.mutex.Unlock() + if _, ok := m.elements[k]; ok { + return errors.New("duplicate entry") + } + m.elements[k] = v + return nil +} + +func (m *syncMap[K, V]) get(k K) (V, bool) { + m.mutex.Lock() + defer m.mutex.Unlock() + v, ok := m.elements[k] + return v, ok +} + +func (m *syncMap[K, V]) delete(k K) { + m.mutex.Lock() + defer m.mutex.Unlock() + delete(m.elements, k) +} diff --git a/track_header_stream.go b/track_header_stream.go index dbb1830..fce7192 100644 --- a/track_header_stream.go +++ b/track_header_stream.go @@ -21,12 +21,15 @@ func newTrackHeaderStream(stream SendStream, subscribeID, trackAlias, objectSend }, nil } -func (s *TrackHeaderStream) NewObject(groupID, objectID uint64) *TrackHeaderStreamObject { - return &TrackHeaderStreamObject{ - stream: s.stream, - groupID: groupID, - objectID: objectID, +func (s *TrackHeaderStream) writeObject(groupID, objectID uint64, payload []byte) (int, error) { + shto := streamHeaderTrackObject{ + GroupID: groupID, + ObjectID: objectID, + ObjectPayload: payload, } + buf := make([]byte, 0, 32+len(payload)) + buf = shto.append(buf) + return s.stream.Write(buf) } func (s *TrackHeaderStream) Close() error { diff --git a/track_header_stream_object.go b/track_header_stream_object.go deleted file mode 100644 index 4448c03..0000000 --- a/track_header_stream_object.go +++ /dev/null @@ -1,17 +0,0 @@ -package moqtransport - -type TrackHeaderStreamObject struct { - stream SendStream - groupID, objectID uint64 -} - -func (o *TrackHeaderStreamObject) Write(payload []byte) (int, error) { - shto := streamHeaderTrackObject{ - GroupID: o.groupID, - ObjectID: o.objectID, - ObjectPayload: payload, - } - buf := make([]byte, 0, 32+len(payload)) - buf = shto.append(buf) - return o.stream.Write(buf) -}