Skip to content

Commit

Permalink
Make local track an interface of iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
mengelbart committed Nov 3, 2024
1 parent 4a61461 commit 1e9617b
Show file tree
Hide file tree
Showing 14 changed files with 424 additions and 459 deletions.
2 changes: 1 addition & 1 deletion control_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (s *controlStream) readMessages() {
return
}
if err = s.handle(msg); err != nil {
s.logger.Error("failed to handle control stream message", "error", err)
s.logger.Error("failed to handle control stream message", "error", err, "msg", msg)
panic("TODO: Close connection")
}
}
Expand Down
12 changes: 5 additions & 7 deletions examples/chat/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
)

type clientRoom struct {
lt *moqtransport.LocalTrack
lt *moqtransport.ListTrack
rts []*moqtransport.RemoteTrack
}

Expand Down Expand Up @@ -142,8 +142,8 @@ func (c *Client) handleCatalogDeltas(roomID, username string, previous *chatalog
func (c *Client) joinRoom(roomID, username string) error {
c.rm.lock.Lock()
defer c.rm.lock.Unlock()
lt := moqtransport.NewLocalTrack(fmt.Sprintf("moq-chat/%v/participant/%v", roomID, username), "")
if err := c.session.AddLocalTrack(lt); err != nil {
lt := moqtransport.NewListTrack()
if err := c.session.AddLocalTrack(fmt.Sprintf("moq-chat/%v/participant/%v", roomID, username), "", lt); err != nil {
return err
}
c.rm.rooms[roomID] = &clientRoom{
Expand Down Expand Up @@ -225,16 +225,14 @@ func (c *Client) Run() error {
fmt.Println("server not subscribed, dropping message")
break
}
err := c.rm.rooms[fields[1]].lt.WriteObject(context.Background(), moqtransport.Object{
// TODO: Set correct groupid and objectid
c.rm.rooms[fields[1]].lt.Append(moqtransport.Object{
GroupID: 0,
ObjectID: 0,
PublisherPriority: 0,
ForwardingPreference: moqtransport.ObjectForwardingPreferenceStream,
Payload: []byte(strings.TrimSpace(msg)),
})
if err != nil {
return fmt.Errorf("failed to write to room: %v", err)
}
default:
fmt.Println("invalid command, try 'join' or 'msg'")
}
Expand Down
25 changes: 11 additions & 14 deletions examples/chat/room.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ type roomID string

type user struct {
name string
track *moqtransport.LocalTrack
track *moqtransport.ListTrack
session *moqtransport.Session
}

type room struct {
ID roomID
catalogTrack *moqtransport.LocalTrack
catalogTrack *moqtransport.ListTrack
catalogGroup uint64
users *chatalog[*user]
usersLock sync.Mutex
Expand All @@ -29,25 +29,25 @@ type room struct {
func newRoom(id roomID) *room {
return &room{
ID: id,
catalogTrack: moqtransport.NewLocalTrack(fmt.Sprintf("moq-chat/%v", id), ""),
catalogTrack: moqtransport.NewListTrack(),
catalogGroup: 0,
users: &chatalog[*user]{version: 1, participants: map[string]*user{}},
usersLock: sync.Mutex{},
}
}

func (r *room) addParticipant(username string, session *moqtransport.Session, track *moqtransport.LocalTrack) error {
func (r *room) addParticipant(username string, session *moqtransport.Session, track *moqtransport.ListTrack) error {
r.usersLock.Lock()
defer r.usersLock.Unlock()
log.Printf("saving user: %v", username)
if _, ok := r.users.participants[username]; ok {
return errors.New("duplicate participant")
}
for _, u := range r.users.participants {
if err := session.AddLocalTrack(u.track); err != nil {
if err := session.AddLocalTrack(fmt.Sprintf("moq-chat/%v/participant/%v", r.ID, u.name), "", u.track); err != nil {
panic(err)
}
if err := u.session.AddLocalTrack(track); err != nil {
if err := u.session.AddLocalTrack(fmt.Sprintf("moq-chat/%v/participant/%v", r.ID, username), "", track); err != nil {
panic(err)
}
}
Expand Down Expand Up @@ -79,35 +79,32 @@ func (r *room) announceUser(username string, s *moqtransport.Session, arw moqtra
}
catalog := r.users.serialize()
fmt.Printf("sending catalog: %v\n", catalog)
r.catalogTrack.WriteObject(context.Background(), moqtransport.Object{
r.catalogTrack.Append(moqtransport.Object{
GroupID: r.catalogGroup,
ObjectID: 0,
PublisherPriority: 0,
ForwardingPreference: moqtransport.ObjectForwardingPreferenceStreamTrack,
Payload: []byte(catalog),
})
r.catalogGroup += 1
go func(remote *moqtransport.RemoteTrack, local *moqtransport.LocalTrack) {
go func(remote *moqtransport.RemoteTrack, local *moqtransport.ListTrack) {
for {
obj, err := remote.ReadObject(context.Background())
if err != nil {
panic(err)
}
fmt.Printf("relay read object: %v\n", obj)
err = local.WriteObject(context.Background(), obj)
if err != nil {
panic(err)
}
local.Append(obj)
}
}(sub, u.track)
}

func (r *room) subscribeCatalog(s *moqtransport.Session, sub *moqtransport.Subscription, srw moqtransport.SubscriptionResponseWriter) {
if err := s.AddLocalTrack(r.catalogTrack); err != nil {
if err := s.AddLocalTrack(fmt.Sprintf("moq-chat/%v", r.ID), "", r.catalogTrack); err != nil {
srw.Reject(uint64(errorCodeInternal), "failed to setup room catalog track")
return
}
track := moqtransport.NewLocalTrack(fmt.Sprintf("moq-chat/%v/participant/%v", r.ID, sub.Authorization), "") // TODO: Track ID?
track := moqtransport.NewListTrack()
err := r.addParticipant(sub.Authorization, s, track)
if err != nil {
srw.Reject(uint64(errorCodeDuplicateUsername), "username already in use")
Expand Down
20 changes: 10 additions & 10 deletions examples/date/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type moqHandler struct {
trackname string
publish bool
subscribe bool
localTrack *moqtransport.LocalTrack
localTrack *moqtransport.ListTrack
}

func (h *moqHandler) runClient(ctx context.Context, wt bool) error {
Expand All @@ -40,7 +40,7 @@ func (h *moqHandler) runClient(ctx context.Context, wt bool) error {
return err
}
if h.publish {
h.setupDateTrack(ctx)
h.setupDateTrack()
}
h.handle(ctx, conn)
select {}
Expand All @@ -60,7 +60,7 @@ func (h *moqHandler) runServer(ctx context.Context) error {
},
}
if h.publish {
h.setupDateTrack(ctx)
h.setupDateTrack()
}
http.HandleFunc("/moq", func(w http.ResponseWriter, r *http.Request) {
session, err := wt.Upgrade(w, r)
Expand Down Expand Up @@ -100,7 +100,7 @@ func (h *moqHandler) handle(ctx context.Context, conn moqtransport.Connection) {
srw.Reject(0, "endpoint does not publish any tracks")
return
}
if sub.Namespace != h.namespace && sub.TrackName != h.trackname {
if sub.Namespace != h.namespace && sub.Trackname != h.trackname {
srw.Reject(0, "unknown track")
return
}
Expand Down Expand Up @@ -153,21 +153,21 @@ func (h *moqHandler) subscribeAndRead(ctx context.Context, s *moqtransport.Sessi
return nil
}

func (h *moqHandler) setupDateTrack(ctx context.Context) {
h.localTrack = moqtransport.NewLocalTrack(h.namespace, h.trackname)
func (h *moqHandler) setupDateTrack() {
h.localTrack = moqtransport.NewListTrack()
go func() {
defer h.localTrack.Close()
ticker := time.NewTicker(time.Second)
id := uint64(0)
i := 0
for ts := range ticker.C {
h.localTrack.WriteObject(ctx, moqtransport.Object{
GroupID: id,
h.localTrack.Append(moqtransport.Object{
GroupID: uint64(i),
ObjectID: 0,
PublisherPriority: 0,
ForwardingPreference: moqtransport.ObjectForwardingPreferenceStream,
Payload: []byte(fmt.Sprintf("%v", ts)),
})
id++
i++
}
}()
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/mengelbart/moqtransport

go 1.22.0
go 1.23.2

require (
github.com/quic-go/quic-go v0.45.0
Expand Down
124 changes: 68 additions & 56 deletions integrationtests/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
"log"
"math/big"
"net"
"sync"
Expand Down Expand Up @@ -150,9 +151,9 @@ func TestIntegration(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
server := quicServerSession(t, ctx, listener, nil)
track := moqtransport.NewLocalTrack("namespace", "track")
track := moqtransport.NewListTrack()
defer track.Close()
err := server.AddLocalTrack(track)
err := server.AddLocalTrack("namespace", "track", track)
assert.NoError(t, err)
err = server.Announce(ctx, "namespace")
assert.NoError(t, err)
Expand Down Expand Up @@ -189,23 +190,20 @@ func TestIntegration(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
server := quicServerSession(t, ctx, listener, nil)
track := moqtransport.NewLocalTrack("namespace", "track")
track := moqtransport.NewListTrack()
defer track.Close()
err := server.AddLocalTrack(track)
err := server.AddLocalTrack("namespace", "track", track)
assert.NoError(t, err)
err = server.Announce(ctx, "namespace")
assert.NoError(t, err)
<-subscribedCh
err = track.WriteObject(ctx, moqtransport.Object{
GroupID: 0,
ObjectID: 0,
track.Append(moqtransport.Object{
PublisherPriority: 0,
ForwardingPreference: 0,
Payload: []byte("hello world"),
})
assert.NoError(t, err)
<-receivedObject
assert.NoError(t, track.Close())
track.Close()
assert.NoError(t, server.Close())
}()
ctx, cancel := context.WithCancel(context.Background())
Expand Down Expand Up @@ -234,77 +232,91 @@ func TestIntegration(t *testing.T) {
listener, addr, teardown := setup()
defer teardown()
wg.Add(1)
receivedSubscribeCh := make(chan struct{})
receivedUnsubscribeCh := make(chan struct{})
subscribedCh := make(chan struct{})
unsubscribedCh := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

trackCreated := make(chan struct{})
unsubscribed := make(chan struct{})

go func() {
defer wg.Done()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
server := quicServerSession(t, ctx, listener, nil)
track := moqtransport.NewLocalTrack("namespace", "track")
defer track.Close()
err := server.AddLocalTrack(track)
assert.NoError(t, err)
err = server.Announce(ctx, "namespace")
track := moqtransport.NewListTrack()
err := server.AddLocalTrack("namespace", "track", track)
assert.NoError(t, err)
err = track.WriteObject(ctx, moqtransport.Object{
close(trackCreated)
track.Append(moqtransport.Object{
GroupID: 0,
ObjectID: 0,
PublisherPriority: 0,
ForwardingPreference: 0,
Payload: []byte("hello world"),
Payload: []byte("hello world: 0"),
})
assert.NoError(t, err)
<-subscribedCh
assert.Equal(t, 1, track.SubscriberCount())
err = track.WriteObject(ctx, moqtransport.Object{
track.Append(moqtransport.Object{
GroupID: 0,
ObjectID: 0,
ObjectID: 1,
PublisherPriority: 0,
ForwardingPreference: 0,
Payload: []byte("hello world"),
Payload: []byte("hello world: 1"),
})
assert.NoError(t, err)
close(receivedSubscribeCh)
<-unsubscribedCh
err = track.WriteObject(ctx, moqtransport.Object{
track.Append(moqtransport.Object{
GroupID: 0,
ObjectID: 0,
ObjectID: 2,
PublisherPriority: 0,
ForwardingPreference: 0,
Payload: []byte("hello world"),
Payload: []byte("hello world: 2"),
})
for i := 0; i < 3; i++ {
if track.SubscriberCount() == 0 {
break
}
time.Sleep(10 * time.Millisecond)
}
assert.NoError(t, err)
assert.Equal(t, 0, track.SubscriberCount())
close(receivedUnsubscribeCh)
<-unsubscribed
time.Sleep(time.Second)
track.Close()
assert.NoError(t, server.Close())
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
announcementCh := make(chan struct{})
client := quicClientSession(t, ctx, addr, moqtransport.AnnouncementHandlerFunc(func(_ *moqtransport.Session, a *moqtransport.Announcement, arw moqtransport.AnnouncementResponseWriter) {
assert.Equal(t, "namespace", a.Namespace())
arw.Accept()
close(announcementCh)
}))
<-announcementCh

client := quicClientSession(t, ctx, addr, nil)

<-trackCreated
sub, err := client.Subscribe(ctx, 0, 0, "namespace", "track", "auth")
assert.NoError(t, err)
close(subscribedCh)
<-receivedSubscribeCh
res := []moqtransport.Object{}
for i := 0; i < 3; i++ {
o, err := sub.ReadObject(ctx)
log.Printf("read object %v", o)
assert.NoError(t, err)
res = append(res, o)
}
sub.Unsubscribe()
close(unsubscribedCh)
<-receivedUnsubscribeCh
close(unsubscribed)
time.Sleep(time.Second)
assert.NoError(t, client.Close())
wg.Wait()

expected := []moqtransport.Object{
{
GroupID: 0,
ObjectID: 0,
PublisherPriority: 0,
ForwardingPreference: 0,
Payload: []byte("hello world: 0"),
},
{
GroupID: 0,
ObjectID: 1,
PublisherPriority: 0,
ForwardingPreference: 0,
Payload: []byte("hello world: 1"),
},
{
GroupID: 0,
ObjectID: 2,
PublisherPriority: 0,
ForwardingPreference: 0,
Payload: []byte("hello world: 2"),
},
}
assert.Len(t, res, 3)
for _, o := range expected {
assert.Contains(t, res, o)
}
})
}

Expand Down
Loading

0 comments on commit 1e9617b

Please sign in to comment.