Skip to content

Commit

Permalink
Refactor SendSusbcription API
Browse files Browse the repository at this point in the history
  • Loading branch information
mengelbart committed Jun 7, 2024
1 parent a8fc8cc commit 60fbf81
Show file tree
Hide file tree
Showing 15 changed files with 580 additions and 374 deletions.
44 changes: 0 additions & 44 deletions announcement_map.go

This file was deleted.

60 changes: 23 additions & 37 deletions examples/date-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"flag"
"fmt"
"log"
Expand Down Expand Up @@ -56,6 +55,8 @@ func listen(ctx context.Context, addr string, tlsConfig *tls.Config) error {
TLSConfig: tlsConfig,
},
}
track := moqtransport.NewLocalTrack(0, "clock", "second")
defer track.Close()
http.HandleFunc("/moq", func(w http.ResponseWriter, r *http.Request) {
session, err := wt.Upgrade(w, r)
if err != nil {
Expand All @@ -72,8 +73,22 @@ func listen(ctx context.Context, addr string, tlsConfig *tls.Config) error {
w.WriteHeader(http.StatusInternalServerError)
return
}
go handle(moqSession)
go handle(moqSession, track)
})
go func() {
ticker := time.NewTicker(time.Second)
id := uint64(0)
for ts := range ticker.C {
log.Printf("tick: %v, subscribers: %v\n", ts, track.SubscriberCount())
track.WriteObject(ctx, moqtransport.Object{
GroupID: id,
ObjectID: 0,
ObjectSendOrder: 0,
Payload: []byte(fmt.Sprintf("%v", ts)),
})
id++
}
}()
for {
conn, err := listener.Accept(ctx)
if err != nil {
Expand All @@ -90,45 +105,16 @@ func listen(ctx context.Context, addr string, tlsConfig *tls.Config) error {
if err := s.RunServer(ctx); err != nil {
return err
}
go handle(s)
go handle(s, track)
}
}
}

func handle(p *moqtransport.Session) {
go func() {
s, err := p.ReadSubscription(context.Background(), func(s *moqtransport.SendSubscription) error {
if fmt.Sprintf("%v/%v", s.Namespace(), s.Trackname()) != "clock/second" {
return errors.New("unknown namespace/trackname")
}
return nil
})
if err != nil {
panic(err)
}
log.Printf("got subscription: %v", s)
go func() {
ticker := time.NewTicker(time.Second)
id := uint64(0)
for ts := range ticker.C {
w, err := s.NewObjectStream(id, 0, 0) // TODO: Use meaningful values
if err != nil {
log.Println(err)
return
}
if _, err := fmt.Fprintf(w, "%v", ts); err != nil {
log.Println(err)
return
}
if err := w.Close(); err != nil {
log.Println(err)
return
}
id++
}
}()
}()
if err := p.Announce(context.Background(), "clock"); err != nil {
func handle(s *moqtransport.Session, t *moqtransport.LocalTrack) {
if err := s.AddLocalTrack(t); err != nil {
panic(err)
}
if err := s.Announce(context.Background(), "clock"); err != nil {
panic(err)
}
}
Expand Down
10 changes: 7 additions & 3 deletions group_header_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@ func newGroupHeaderStream(stream SendStream, subscribeID, trackAlias, groupID, o
}, nil
}

func (s *groupHeaderStream) NewObject() *groupHeaderStreamObject {
return &groupHeaderStreamObject{
stream: s.stream,
func (s *groupHeaderStream) writeObject(objectID uint64, payload []byte) (int, error) {
shgo := streamHeaderGroupObject{
ObjectID: objectID,
ObjectPayload: payload,
}
buf := make([]byte, 0, 16+len(payload))
buf = shgo.append(buf)
return s.stream.Write(buf)
}

func (s *groupHeaderStream) Close() error {
Expand Down
16 changes: 0 additions & 16 deletions group_header_stream_object.go

This file was deleted.

113 changes: 88 additions & 25 deletions integrationtests/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,18 +149,27 @@ func TestIntegration(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
server := quicServerSession(t, ctx, listener, nil)
sub, err := server.ReadSubscription(ctx, func(ss *moqtransport.SendSubscription) error { return nil })
track := moqtransport.NewLocalTrack(0, "namespace", "track")
defer track.Close()
err := server.AddLocalTrack(track)
assert.NoError(t, err)
err = server.Announce(ctx, "namespace")
assert.NoError(t, err)
assert.Equal(t, "namespace", sub.Namespace())
assert.Equal(t, "track", sub.Trackname())
<-receivedSubscribeOK
assert.NoError(t, server.Close())
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
client := quicClientSession(t, ctx, addr, nil)
_, err := client.Subscribe(ctx, 0, 0, "namespace", "track", "auth")
announcementCh := make(chan struct{})
client := quicClientSession(t, ctx, addr, moqtransport.AnnouncementHandlerFunc(func(a *moqtransport.Announcement, arw moqtransport.AnnouncementResponseWriter) {
assert.Equal(t, "namespace", a.Namespace())
arw.Accept()
close(announcementCh)
}))
<-announcementCh
r, err := client.Subscribe(ctx, 0, 0, "namespace", "track", "auth")
assert.NoError(t, err)
assert.NotNil(t, r)
close(receivedSubscribeOK)
assert.NoError(t, client.Close())
wg.Wait()
Expand All @@ -172,29 +181,44 @@ func TestIntegration(t *testing.T) {
listener, addr, teardown := setup()
defer teardown()
wg.Add(1)
subscribedCh := make(chan struct{})
receivedObject := make(chan struct{})
go func() {
defer wg.Done()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
server := quicServerSession(t, ctx, listener, nil)
sub, err := server.ReadSubscription(ctx, func(ss *moqtransport.SendSubscription) error { return nil })
track := moqtransport.NewLocalTrack(0, "namespace", "track")
defer track.Close()
err := server.AddLocalTrack(track)
assert.NoError(t, err)
assert.Equal(t, "namespace", sub.Namespace())
assert.Equal(t, "track", sub.Trackname())
s, err := sub.NewObjectStream(0, 0, 0)
err = server.Announce(ctx, "namespace")
assert.NoError(t, err)
_, err = s.Write([]byte("hello world"))
<-subscribedCh
err = track.WriteObject(ctx, moqtransport.Object{
GroupID: 0,
ObjectID: 0,
ObjectSendOrder: 0,
ForwardingPreference: 0,
Payload: []byte("hello world"),
})
assert.NoError(t, err)
assert.NoError(t, s.Close())
<-receivedObject
assert.NoError(t, track.Close())
assert.NoError(t, server.Close())
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
client := quicClientSession(t, ctx, addr, nil)
announcementCh := make(chan struct{})
client := quicClientSession(t, ctx, addr, moqtransport.AnnouncementHandlerFunc(func(a *moqtransport.Announcement, arw moqtransport.AnnouncementResponseWriter) {
assert.Equal(t, "namespace", a.Namespace())
arw.Accept()
close(announcementCh)
}))
<-announcementCh
sub, err := client.Subscribe(ctx, 0, 0, "namespace", "track", "auth")
assert.NoError(t, err)
close(subscribedCh)
buf := make([]byte, 1500)
n, err := sub.Read(buf)
assert.NoError(t, err)
Expand All @@ -210,36 +234,75 @@ func TestIntegration(t *testing.T) {
listener, addr, teardown := setup()
defer teardown()
wg.Add(1)
receivedUnsubscribe := make(chan struct{})
receivedSubscribeCh := make(chan struct{})
receivedUnsubscribeCh := make(chan struct{})
subscribedCh := make(chan struct{})
unsubscribedCh := make(chan struct{})
go func() {
defer wg.Done()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
server := quicServerSession(t, ctx, listener, nil)
sub, err := server.ReadSubscription(ctx, func(ss *moqtransport.SendSubscription) error { return nil })
track := moqtransport.NewLocalTrack(0, "namespace", "track")
defer track.Close()
err := server.AddLocalTrack(track)
assert.NoError(t, err)
err = server.Announce(ctx, "namespace")
assert.NoError(t, err)
assert.Equal(t, "namespace", sub.Namespace())
assert.Equal(t, "track", sub.Trackname())
for i := 0; i < 10; i++ {
err = sub.NewObjectPreferDatagram(0, 0, 0, nil)
if err != nil {
err = track.WriteObject(ctx, moqtransport.Object{
GroupID: 0,
ObjectID: 0,
ObjectSendOrder: 0,
ForwardingPreference: 0,
Payload: []byte("hello world"),
})
assert.NoError(t, err)
<-subscribedCh
assert.Equal(t, 1, track.SubscriberCount())
err = track.WriteObject(ctx, moqtransport.Object{
GroupID: 0,
ObjectID: 0,
ObjectSendOrder: 0,
ForwardingPreference: 0,
Payload: []byte("hello world"),
})
assert.NoError(t, err)
close(receivedSubscribeCh)
<-unsubscribedCh
err = track.WriteObject(ctx, moqtransport.Object{
GroupID: 0,
ObjectID: 0,
ObjectSendOrder: 0,
ForwardingPreference: 0,
Payload: []byte("hello world"),
})
for i := 0; i < 3; i++ {
if track.SubscriberCount() == 0 {
break
}
time.Sleep(10 * time.Millisecond)
}
assert.Error(t, err)
assert.ErrorContains(t, err, "peer unsubscribed")
close(receivedUnsubscribe)
assert.NoError(t, err)
assert.Equal(t, 0, track.SubscriberCount())
close(receivedUnsubscribeCh)
assert.NoError(t, server.Close())
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
client := quicClientSession(t, ctx, addr, nil)
announcementCh := make(chan struct{})
client := quicClientSession(t, ctx, addr, moqtransport.AnnouncementHandlerFunc(func(a *moqtransport.Announcement, arw moqtransport.AnnouncementResponseWriter) {
assert.Equal(t, "namespace", a.Namespace())
arw.Accept()
close(announcementCh)
}))
<-announcementCh
sub, err := client.Subscribe(ctx, 0, 0, "namespace", "track", "auth")
assert.NoError(t, err)
close(subscribedCh)
<-receivedSubscribeCh
sub.Unsubscribe()
<-receivedUnsubscribe
assert.NoError(t, err)
close(unsubscribedCh)
<-receivedUnsubscribeCh
assert.NoError(t, client.Close())
wg.Wait()
})
Expand Down
Loading

0 comments on commit 60fbf81

Please sign in to comment.