From a8fc8cce935bb68203a5bbce474e1ff1f0688d0a Mon Sep 17 00:00:00 2001 From: Mathis Engelbart Date: Mon, 13 May 2024 09:57:26 +0200 Subject: [PATCH] Refactor control stream sending and session creation API --- announcement.go | 12 +- announcement_map.go | 6 + control_stream.go | 83 +++++ examples/date-client/main.go | 32 +- examples/date-server/main.go | 71 +---- integrationtests/integration_test.go | 79 +++-- message.go | 3 + mock_announcement_handler_test.go | 52 ---- mock_connection_test.go | 17 +- mock_control_message_sender_test.go | 63 ++++ mock_control_stream_handler_test.go | 79 ----- mock_parser_factory_test.go | 52 ---- mock_parser_test.go | 1 + mock_receive_stream_test.go | 1 + mock_send_stream_test.go | 1 + mock_stream_test.go | 1 + mockgen.go | 11 +- receive_subscription.go | 4 +- session.go | 446 +++++++++++++-------------- session_test.go | 167 ++++++---- 20 files changed, 565 insertions(+), 616 deletions(-) create mode 100644 control_stream.go delete mode 100644 mock_announcement_handler_test.go create mode 100644 mock_control_message_sender_test.go delete mode 100644 mock_control_stream_handler_test.go delete mode 100644 mock_parser_factory_test.go diff --git a/announcement.go b/announcement.go index 06699a1..881e4d0 100644 --- a/announcement.go +++ b/announcement.go @@ -1,8 +1,8 @@ package moqtransport type AnnouncementResponseWriter interface { - Accept() error - Reject(code uint64, reason string) error + Accept() + Reject(code uint64, reason string) } type defaultAnnouncementResponseWriter struct { @@ -10,12 +10,12 @@ type defaultAnnouncementResponseWriter struct { s *Session } -func (a *defaultAnnouncementResponseWriter) Accept() error { - return a.s.acceptAnnouncement(a.a) +func (a *defaultAnnouncementResponseWriter) Accept() { + a.s.acceptAnnouncement(a.a) } -func (a *defaultAnnouncementResponseWriter) Reject(code uint64, reason string) error { - return a.s.rejectAnnouncement(a.a, code, reason) +func (a *defaultAnnouncementResponseWriter) Reject(code uint64, reason string) { + a.s.rejectAnnouncement(a.a, code, reason) } type AnnouncementHandler interface { diff --git a/announcement_map.go b/announcement_map.go index cfe1378..c2892cd 100644 --- a/announcement_map.go +++ b/announcement_map.go @@ -36,3 +36,9 @@ func (m *announcementMap) get(name string) (*Announcement, bool) { a, ok := m.announcements[name] return a, ok } + +func (m *announcementMap) delete(name string) { + m.mutex.Lock() + defer m.mutex.Unlock() + delete(m.announcements, name) +} diff --git a/control_stream.go b/control_stream.go new file mode 100644 index 0000000..60e4c65 --- /dev/null +++ b/control_stream.go @@ -0,0 +1,83 @@ +package moqtransport + +import ( + "io" + "log/slog" + + "github.com/quic-go/quic-go/quicvarint" +) + +type controlStream struct { + logger *slog.Logger + stream Stream + handle messageHandler + parser parser + sendQueue chan message + closeCh chan struct{} +} + +type messageHandler func(message) error + +func newControlStream(s Stream, h messageHandler) *controlStream { + cs := &controlStream{ + logger: defaultLogger.WithGroup("MOQ_CONTROL_STREAM"), + stream: s, + handle: h, + parser: newParser(quicvarint.NewReader(s)), + sendQueue: make(chan message, 64), + closeCh: make(chan struct{}), + } + go cs.readMessages() + go cs.writeMessages() + return cs +} + +func (s *controlStream) readMessages() { + for { + msg, err := s.parser.parse() + if err != nil { + if err == io.EOF { + return + } + s.logger.Error("TODO", "error", err) + return + } + if err = s.handle(msg); err != nil { + s.logger.Error("failed to handle control stream message", "error", err) + panic("TODO: Close connection") + } + } +} + +func (s *controlStream) writeMessages() { + for { + select { + case <-s.closeCh: + s.logger.Info("close called, leaving control stream write loop") + return + case msg := <-s.sendQueue: + s.logger.Info("sending control message", "message", msg) + buf := make([]byte, 0, 1500) + buf = msg.append(buf) + if _, err := s.stream.Write(buf); err != nil { + if err == io.EOF { + s.logger.Info("write stream closed, leaving control stream write loop") + return + } + s.logger.Error("failed to write to control stream", "error", err) + } + } + } +} + +func (s *controlStream) enqueue(m message) { + select { + case s.sendQueue <- m: + default: + s.logger.Warn("dropping control stream message because send queue is full") + } +} + +func (s *controlStream) close() { + close(s.closeCh) +} diff --git a/examples/date-client/main.go b/examples/date-client/main.go index 4339508..f042776 100644 --- a/examples/date-client/main.go +++ b/examples/date-client/main.go @@ -27,7 +27,6 @@ func main() { } func run(ctx context.Context, addr string, wt bool, namespace, trackname string) error { - var session *moqtransport.Session var conn moqtransport.Connection var err error if wt { @@ -38,24 +37,31 @@ func run(ctx context.Context, addr string, wt bool, namespace, trackname string) if err != nil { return err } - session, err = moqtransport.NewClientSession(conn, moqtransport.IngestionDeliveryRole, !wt) - if err != nil { + announcementWaitCh := make(chan *moqtransport.Announcement) + session := &moqtransport.Session{ + Conn: conn, + EnableDatagrams: true, + LocalRole: moqtransport.DeliveryRole, + AnnouncementHandler: moqtransport.AnnouncementHandlerFunc(func(a *moqtransport.Announcement, arw moqtransport.AnnouncementResponseWriter) { + if a.Namespace() == "clock" { + arw.Accept() + announcementWaitCh <- a + return + } + arw.Reject(0, "invalid namespace") + }), + } + if err = session.RunClient(); err != nil { return err } defer session.Close() - session.HandleAnnouncements(moqtransport.AnnouncementHandlerFunc(func(a *moqtransport.Announcement, arw moqtransport.AnnouncementResponseWriter) { - if a.Namespace() == "clock" { - arw.Accept() - return - } - arw.Reject(0, "invalid namespace") - })) - log.Println("got Announcement") + a := <-announcementWaitCh + log.Printf("got Announcement: %v\n", a) log.Println("subscribing") rs, err := session.Subscribe(context.Background(), 0, 0, namespace, trackname, "") if err != nil { - panic(err) + return err } log.Println("got subscription") buf := make([]byte, 64_000) @@ -66,7 +72,7 @@ func run(ctx context.Context, addr string, wt bool, namespace, trackname string) log.Printf("got last object") return nil } - panic(err) + return err } log.Printf("got object: %v\n", string(buf[:n])) } diff --git a/examples/date-server/main.go b/examples/date-server/main.go index d1ca391..ff5b847 100644 --- a/examples/date-server/main.go +++ b/examples/date-server/main.go @@ -27,27 +27,19 @@ func main() { certFile := flag.String("cert", "localhost.pem", "TLS certificate file") keyFile := flag.String("key", "localhost-key.pem", "TLS key file") addr := flag.String("addr", "localhost:8080", "listen address") - wt := flag.Bool("webtransport", false, "Serve WebTransport only") - quic := flag.Bool("quic", false, "Serve QUIC only") flag.Parse() - if err := run(context.Background(), *addr, *wt, *quic, *certFile, *keyFile); err != nil { + if err := run(context.Background(), *addr, *certFile, *keyFile); err != nil { log.Fatal(err) } } -func run(ctx context.Context, addr string, wt, quic bool, certFile, keyFile string) error { +func run(ctx context.Context, addr string, certFile, keyFile string) error { tlsConfig, err := generateTLSConfigWithCertAndKey(certFile, keyFile) if err != nil { log.Printf("failed to generate TLS config from cert file and key, generating in memory certs: %v", err) tlsConfig = generateTLSConfig() } - if wt { - return listenWebTransport(addr, tlsConfig) - } - if quic { - return listenQUIC(ctx, addr, tlsConfig) - } return listen(ctx, addr, tlsConfig) } @@ -71,8 +63,11 @@ func listen(ctx context.Context, addr string, tlsConfig *tls.Config) error { w.WriteHeader(http.StatusInternalServerError) return } - moqSession, err := moqtransport.NewServerSession(webtransportmoq.New(session), false) - if err != nil { + moqSession := &moqtransport.Session{ + Conn: webtransportmoq.New(session), + EnableDatagrams: true, + } + if err := moqSession.RunServer(ctx); err != nil { log.Printf("MoQ Session initialization failed: %v", err) w.WriteHeader(http.StatusInternalServerError) return @@ -88,8 +83,11 @@ func listen(ctx context.Context, addr string, tlsConfig *tls.Config) error { go wt.ServeQUICConn(conn) } if conn.ConnectionState().TLS.NegotiatedProtocol == "moq-00" { - s, err := moqtransport.NewServerSession(quicmoq.New(conn), true) - if err != nil { + s := &moqtransport.Session{ + Conn: quicmoq.New(conn), + EnableDatagrams: true, + } + if err := s.RunServer(ctx); err != nil { return err } go handle(s) @@ -97,51 +95,6 @@ func listen(ctx context.Context, addr string, tlsConfig *tls.Config) error { } } -func listenWebTransport(addr string, tlsConfig *tls.Config) error { - wt := webtransport.Server{ - H3: http3.Server{ - Addr: addr, - TLSConfig: tlsConfig, - }, - } - http.HandleFunc("/moq", func(w http.ResponseWriter, r *http.Request) { - session, err := wt.Upgrade(w, r) - if err != nil { - log.Printf("upgrading to webtransport failed: %v", err) - w.WriteHeader(http.StatusInternalServerError) - return - } - moqSession, err := moqtransport.NewServerSession(webtransportmoq.New(session), false) - if err != nil { - log.Printf("MoQ Session initialization failed: %v", err) - w.WriteHeader(http.StatusInternalServerError) - return - } - go handle(moqSession) - }) - return wt.ListenAndServe() -} - -func listenQUIC(ctx context.Context, addr string, tlsConfig *tls.Config) error { - listener, err := quic.ListenAddr(addr, tlsConfig, &quic.Config{ - EnableDatagrams: true, - }) - if err != nil { - return err - } - for { - conn, err := listener.Accept(ctx) - if err != nil { - return err - } - s, err := moqtransport.NewServerSession(quicmoq.New(conn), true) - if err != nil { - return err - } - go handle(s) - } -} - func handle(p *moqtransport.Session) { go func() { s, err := p.ReadSubscription(context.Background(), func(s *moqtransport.SendSubscription) error { diff --git a/integrationtests/integration_test.go b/integrationtests/integration_test.go index 47ba796..83f2d5a 100644 --- a/integrationtests/integration_test.go +++ b/integrationtests/integration_test.go @@ -21,10 +21,29 @@ import ( "go.uber.org/goleak" ) -func quicClientSession(t *testing.T, ctx context.Context, addr string) *moqtransport.Session { +func quicServerSession(t *testing.T, ctx context.Context, listener *quic.Listener, handler moqtransport.AnnouncementHandler) *moqtransport.Session { + conn, err := listener.Accept(ctx) + assert.NoError(t, err) + session := &moqtransport.Session{ + Conn: quicmoq.New(conn), + EnableDatagrams: true, + AnnouncementHandler: handler, + } + err = session.RunServer(ctx) + assert.NoError(t, err) + return session +} + +func quicClientSession(t *testing.T, ctx context.Context, addr string, handler moqtransport.AnnouncementHandler) *moqtransport.Session { conn, err := quic.DialAddr(ctx, addr, generateTLSConfig(), &quic.Config{EnableDatagrams: true}) assert.NoError(t, err) - session, err := moqtransport.NewClientSession(quicmoq.New(conn), moqtransport.IngestionDeliveryRole, true) + session := &moqtransport.Session{ + Conn: quicmoq.New(conn), + EnableDatagrams: true, + LocalRole: moqtransport.IngestionDeliveryRole, + AnnouncementHandler: handler, + } + err = session.RunClient() assert.NoError(t, err) return session } @@ -50,20 +69,17 @@ func TestIntegration(t *testing.T) { defer wg.Done() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - conn, err := listener.Accept(ctx) - assert.NoError(t, err) - server, err := moqtransport.NewServerSession(quicmoq.New(conn), true) - assert.NoError(t, err) + server := quicServerSession(t, ctx, listener, nil) assert.NotNil(t, server) <-sessionEstablished assert.NoError(t, server.Close()) }() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client := quicClientSession(t, ctx, addr) - assert.NoError(t, client.Close()) + client := quicClientSession(t, ctx, addr, nil) close(sessionEstablished) wg.Wait() + assert.NoError(t, client.Close()) }) t.Run("announce", func(t *testing.T) { @@ -77,21 +93,16 @@ func TestIntegration(t *testing.T) { defer wg.Done() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - conn, err := listener.Accept(ctx) - assert.NoError(t, err) - server, err := moqtransport.NewServerSession(quicmoq.New(conn), true) - assert.NoError(t, err) + server := quicServerSession(t, ctx, listener, nil) assert.NoError(t, server.Announce(ctx, "/namespace")) close(receivedAnnounceOK) assert.NoError(t, server.Close()) }() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client := quicClientSession(t, ctx, addr) - client.HandleAnnouncements(moqtransport.AnnouncementHandlerFunc(func(a *moqtransport.Announcement, arw moqtransport.AnnouncementResponseWriter) { + client := quicClientSession(t, ctx, addr, moqtransport.AnnouncementHandlerFunc(func(a *moqtransport.Announcement, arw moqtransport.AnnouncementResponseWriter) { assert.Equal(t, "/namespace", a.Namespace()) - err := arw.Accept() - assert.NoError(t, err) + arw.Accept() })) <-receivedAnnounceOK assert.NoError(t, client.Close()) @@ -109,11 +120,8 @@ func TestIntegration(t *testing.T) { defer wg.Done() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - conn, err := listener.Accept(ctx) - assert.NoError(t, err) - server, err := moqtransport.NewServerSession(quicmoq.New(conn), true) - assert.NoError(t, err) - err = server.Announce(ctx, "/namespace") + server := quicServerSession(t, ctx, listener, nil) + err := server.Announce(ctx, "/namespace") assert.Error(t, err) assert.ErrorContains(t, err, "TEST_ERR") close(receivedAnnounceError) @@ -121,10 +129,8 @@ func TestIntegration(t *testing.T) { }() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client := quicClientSession(t, ctx, addr) - client.HandleAnnouncements(moqtransport.AnnouncementHandlerFunc(func(_ *moqtransport.Announcement, arw moqtransport.AnnouncementResponseWriter) { - err := arw.Reject(0, "TEST_ERR") - assert.NoError(t, err) + client := quicClientSession(t, ctx, addr, moqtransport.AnnouncementHandlerFunc(func(_ *moqtransport.Announcement, arw moqtransport.AnnouncementResponseWriter) { + arw.Reject(0, "TEST_ERR") })) <-receivedAnnounceError assert.NoError(t, client.Close()) @@ -142,10 +148,7 @@ func TestIntegration(t *testing.T) { defer wg.Done() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - conn, err := listener.Accept(ctx) - assert.NoError(t, err) - server, err := moqtransport.NewServerSession(quicmoq.New(conn), true) - assert.NoError(t, err) + server := quicServerSession(t, ctx, listener, nil) sub, err := server.ReadSubscription(ctx, func(ss *moqtransport.SendSubscription) error { return nil }) assert.NoError(t, err) assert.Equal(t, "namespace", sub.Namespace()) @@ -155,7 +158,7 @@ func TestIntegration(t *testing.T) { }() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client := quicClientSession(t, ctx, addr) + client := quicClientSession(t, ctx, addr, nil) _, err := client.Subscribe(ctx, 0, 0, "namespace", "track", "auth") assert.NoError(t, err) close(receivedSubscribeOK) @@ -174,10 +177,7 @@ func TestIntegration(t *testing.T) { defer wg.Done() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - conn, err := listener.Accept(ctx) - assert.NoError(t, err) - server, err := moqtransport.NewServerSession(quicmoq.New(conn), true) - assert.NoError(t, err) + server := quicServerSession(t, ctx, listener, nil) sub, err := server.ReadSubscription(ctx, func(ss *moqtransport.SendSubscription) error { return nil }) assert.NoError(t, err) assert.Equal(t, "namespace", sub.Namespace()) @@ -192,7 +192,7 @@ func TestIntegration(t *testing.T) { }() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client := quicClientSession(t, ctx, addr) + client := quicClientSession(t, ctx, addr, nil) sub, err := client.Subscribe(ctx, 0, 0, "namespace", "track", "auth") assert.NoError(t, err) buf := make([]byte, 1500) @@ -215,10 +215,7 @@ func TestIntegration(t *testing.T) { defer wg.Done() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - conn, err := listener.Accept(ctx) - assert.NoError(t, err) - server, err := moqtransport.NewServerSession(quicmoq.New(conn), true) - assert.NoError(t, err) + server := quicServerSession(t, ctx, listener, nil) sub, err := server.ReadSubscription(ctx, func(ss *moqtransport.SendSubscription) error { return nil }) assert.NoError(t, err) assert.Equal(t, "namespace", sub.Namespace()) @@ -237,10 +234,10 @@ func TestIntegration(t *testing.T) { }() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - client := quicClientSession(t, ctx, addr) + client := quicClientSession(t, ctx, addr, nil) sub, err := client.Subscribe(ctx, 0, 0, "namespace", "track", "auth") assert.NoError(t, err) - assert.NoError(t, sub.Unsubscribe()) + sub.Unsubscribe() <-receivedUnsubscribe assert.NoError(t, err) assert.NoError(t, client.Close()) diff --git a/message.go b/message.go index 140f676..2fc1855 100644 --- a/message.go +++ b/message.go @@ -19,6 +19,9 @@ const ( ErrorCodeDuplicateTrackAlias = 0x04 ErrorCodeParameterLengthMismatch = 0x05 ErrorCodeGoAwayTimeout = 0x10 + + // Errors not included in current draft + ErrorCodeUnsupportedVersion = 0xff01 ) const ( diff --git a/mock_announcement_handler_test.go b/mock_announcement_handler_test.go deleted file mode 100644 index 3939318..0000000 --- a/mock_announcement_handler_test.go +++ /dev/null @@ -1,52 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/mengelbart/moqtransport (interfaces: AnnouncementHandler) -// -// Generated by this command: -// -// mockgen -build_flags=-tags=gomock -package moqtransport -self_package github.com/mengelbart/moqtransport -destination mock_announcement_handler_test.go github.com/mengelbart/moqtransport AnnouncementHandler -// -// Package moqtransport is a generated GoMock package. -package moqtransport - -import ( - reflect "reflect" - - gomock "go.uber.org/mock/gomock" -) - -// MockAnnouncementHandler is a mock of AnnouncementHandler interface. -type MockAnnouncementHandler struct { - ctrl *gomock.Controller - recorder *MockAnnouncementHandlerMockRecorder -} - -// MockAnnouncementHandlerMockRecorder is the mock recorder for MockAnnouncementHandler. -type MockAnnouncementHandlerMockRecorder struct { - mock *MockAnnouncementHandler -} - -// NewMockAnnouncementHandler creates a new mock instance. -func NewMockAnnouncementHandler(ctrl *gomock.Controller) *MockAnnouncementHandler { - mock := &MockAnnouncementHandler{ctrl: ctrl} - mock.recorder = &MockAnnouncementHandlerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockAnnouncementHandler) EXPECT() *MockAnnouncementHandlerMockRecorder { - return m.recorder -} - -// handle mocks base method. -func (m *MockAnnouncementHandler) handle(arg0 string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "handle", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// handle indicates an expected call of handle. -func (mr *MockAnnouncementHandlerMockRecorder) handle(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handle", reflect.TypeOf((*MockAnnouncementHandler)(nil).handle), arg0) -} diff --git a/mock_connection_test.go b/mock_connection_test.go index 78aee25..e034547 100644 --- a/mock_connection_test.go +++ b/mock_connection_test.go @@ -5,6 +5,7 @@ // // mockgen -build_flags=-tags=gomock -package moqtransport -self_package github.com/mengelbart/moqtransport -destination mock_connection_test.go github.com/mengelbart/moqtransport Connection // + // Package moqtransport is a generated GoMock package. package moqtransport @@ -145,28 +146,28 @@ func (mr *MockConnectionMockRecorder) OpenUniStreamSync(arg0 any) *gomock.Call { // ReceiveDatagram mocks base method. func (m *MockConnection) ReceiveDatagram(arg0 context.Context) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReceiveMessage", arg0) + ret := m.ctrl.Call(m, "ReceiveDatagram", arg0) ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } -// ReceiveMessage indicates an expected call of ReceiveMessage. -func (mr *MockConnectionMockRecorder) ReceiveMessage(arg0 any) *gomock.Call { +// ReceiveDatagram indicates an expected call of ReceiveDatagram. +func (mr *MockConnectionMockRecorder) ReceiveDatagram(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockConnection)(nil).ReceiveDatagram), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveDatagram", reflect.TypeOf((*MockConnection)(nil).ReceiveDatagram), arg0) } // SendDatagram mocks base method. func (m *MockConnection) SendDatagram(arg0 []byte) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendMessage", arg0) + ret := m.ctrl.Call(m, "SendDatagram", arg0) ret0, _ := ret[0].(error) return ret0 } -// SendMessage indicates an expected call of SendMessage. -func (mr *MockConnectionMockRecorder) SendMessage(arg0 any) *gomock.Call { +// SendDatagram indicates an expected call of SendDatagram. +func (mr *MockConnectionMockRecorder) SendDatagram(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockConnection)(nil).SendDatagram), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendDatagram", reflect.TypeOf((*MockConnection)(nil).SendDatagram), arg0) } diff --git a/mock_control_message_sender_test.go b/mock_control_message_sender_test.go new file mode 100644 index 0000000..195d8fe --- /dev/null +++ b/mock_control_message_sender_test.go @@ -0,0 +1,63 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/mengelbart/moqtransport (interfaces: ControlMessageSender) +// +// Generated by this command: +// +// mockgen -build_flags=-tags=gomock -package moqtransport -self_package github.com/mengelbart/moqtransport -destination mock_control_message_sender_test.go github.com/mengelbart/moqtransport ControlMessageSender +// + +// Package moqtransport is a generated GoMock package. +package moqtransport + +import ( + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockControlMessageSender is a mock of ControlMessageSender interface. +type MockControlMessageSender struct { + ctrl *gomock.Controller + recorder *MockControlMessageSenderMockRecorder +} + +// MockControlMessageSenderMockRecorder is the mock recorder for MockControlMessageSender. +type MockControlMessageSenderMockRecorder struct { + mock *MockControlMessageSender +} + +// NewMockControlMessageSender creates a new mock instance. +func NewMockControlMessageSender(ctrl *gomock.Controller) *MockControlMessageSender { + mock := &MockControlMessageSender{ctrl: ctrl} + mock.recorder = &MockControlMessageSenderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockControlMessageSender) EXPECT() *MockControlMessageSenderMockRecorder { + return m.recorder +} + +// close mocks base method. +func (m *MockControlMessageSender) close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "close") +} + +// close indicates an expected call of close. +func (mr *MockControlMessageSenderMockRecorder) close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockControlMessageSender)(nil).close)) +} + +// enqueue mocks base method. +func (m *MockControlMessageSender) enqueue(arg0 message) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "enqueue", arg0) +} + +// enqueue indicates an expected call of enqueue. +func (mr *MockControlMessageSenderMockRecorder) enqueue(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "enqueue", reflect.TypeOf((*MockControlMessageSender)(nil).enqueue), arg0) +} diff --git a/mock_control_stream_handler_test.go b/mock_control_stream_handler_test.go deleted file mode 100644 index 2fdb022..0000000 --- a/mock_control_stream_handler_test.go +++ /dev/null @@ -1,79 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/mengelbart/moqtransport (interfaces: ControlStreamHandler) -// -// Generated by this command: -// -// mockgen -build_flags=-tags=gomock -package moqtransport -self_package github.com/mengelbart/moqtransport -destination mock_control_stream_handler_test.go github.com/mengelbart/moqtransport ControlStreamHandler -// -// Package moqtransport is a generated GoMock package. -package moqtransport - -import ( - reflect "reflect" - - gomock "go.uber.org/mock/gomock" -) - -// MockControlStreamHandler is a mock of ControlStreamHandler interface. -type MockControlStreamHandler struct { - ctrl *gomock.Controller - recorder *MockControlStreamHandlerMockRecorder -} - -// MockControlStreamHandlerMockRecorder is the mock recorder for MockControlStreamHandler. -type MockControlStreamHandlerMockRecorder struct { - mock *MockControlStreamHandler -} - -// NewMockControlStreamHandler creates a new mock instance. -func NewMockControlStreamHandler(ctrl *gomock.Controller) *MockControlStreamHandler { - mock := &MockControlStreamHandler{ctrl: ctrl} - mock.recorder = &MockControlStreamHandlerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockControlStreamHandler) EXPECT() *MockControlStreamHandlerMockRecorder { - return m.recorder -} - -// Read mocks base method. -func (m *MockControlStreamHandler) Read(arg0 []byte) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Read", arg0) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Read indicates an expected call of Read. -func (mr *MockControlStreamHandlerMockRecorder) Read(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockControlStreamHandler)(nil).Read), arg0) -} - -// readMessages mocks base method. -func (m *MockControlStreamHandler) readMessages(arg0 messageHandler) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "readMessages", arg0) -} - -// readMessages indicates an expected call of readMessages. -func (mr *MockControlStreamHandlerMockRecorder) readMessages(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "readMessages", reflect.TypeOf((*MockControlStreamHandler)(nil).readMessages), arg0) -} - -// send mocks base method. -func (m *MockControlStreamHandler) send(arg0 message) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "send", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// send indicates an expected call of send. -func (mr *MockControlStreamHandlerMockRecorder) send(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "send", reflect.TypeOf((*MockControlStreamHandler)(nil).send), arg0) -} diff --git a/mock_parser_factory_test.go b/mock_parser_factory_test.go deleted file mode 100644 index bd93470..0000000 --- a/mock_parser_factory_test.go +++ /dev/null @@ -1,52 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/mengelbart/moqtransport (interfaces: ParserFactory) -// -// Generated by this command: -// -// mockgen -build_flags=-tags=gomock -package moqtransport -self_package github.com/mengelbart/moqtransport -destination mock_parser_factory_test.go github.com/mengelbart/moqtransport ParserFactory -// -// Package moqtransport is a generated GoMock package. -package moqtransport - -import ( - reflect "reflect" - - gomock "go.uber.org/mock/gomock" -) - -// MockParserFactory is a mock of ParserFactory interface. -type MockParserFactory struct { - ctrl *gomock.Controller - recorder *MockParserFactoryMockRecorder -} - -// MockParserFactoryMockRecorder is the mock recorder for MockParserFactory. -type MockParserFactoryMockRecorder struct { - mock *MockParserFactory -} - -// NewMockParserFactory creates a new mock instance. -func NewMockParserFactory(ctrl *gomock.Controller) *MockParserFactory { - mock := &MockParserFactory{ctrl: ctrl} - mock.recorder = &MockParserFactoryMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockParserFactory) EXPECT() *MockParserFactoryMockRecorder { - return m.recorder -} - -// new mocks base method. -func (m *MockParserFactory) new(arg0 messageReader) parser { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "new", arg0) - ret0, _ := ret[0].(parser) - return ret0 -} - -// new indicates an expected call of new. -func (mr *MockParserFactoryMockRecorder) new(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "new", reflect.TypeOf((*MockParserFactory)(nil).new), arg0) -} diff --git a/mock_parser_test.go b/mock_parser_test.go index 6ee6049..74f6953 100644 --- a/mock_parser_test.go +++ b/mock_parser_test.go @@ -5,6 +5,7 @@ // // mockgen -build_flags=-tags=gomock -package moqtransport -self_package github.com/mengelbart/moqtransport -destination mock_parser_test.go github.com/mengelbart/moqtransport Parser // + // Package moqtransport is a generated GoMock package. package moqtransport diff --git a/mock_receive_stream_test.go b/mock_receive_stream_test.go index b0476bb..c3d4bfa 100644 --- a/mock_receive_stream_test.go +++ b/mock_receive_stream_test.go @@ -5,6 +5,7 @@ // // mockgen -build_flags=-tags=gomock -package moqtransport -self_package github.com/mengelbart/moqtransport -destination mock_receive_stream_test.go github.com/mengelbart/moqtransport ReceiveStream // + // Package moqtransport is a generated GoMock package. package moqtransport diff --git a/mock_send_stream_test.go b/mock_send_stream_test.go index d6a6e53..4e3f8b8 100644 --- a/mock_send_stream_test.go +++ b/mock_send_stream_test.go @@ -5,6 +5,7 @@ // // mockgen -build_flags=-tags=gomock -package moqtransport -self_package github.com/mengelbart/moqtransport -destination mock_send_stream_test.go github.com/mengelbart/moqtransport SendStream // + // Package moqtransport is a generated GoMock package. package moqtransport diff --git a/mock_stream_test.go b/mock_stream_test.go index 7b2bc03..3737a71 100644 --- a/mock_stream_test.go +++ b/mock_stream_test.go @@ -5,6 +5,7 @@ // // mockgen -build_flags=-tags=gomock -package moqtransport -self_package github.com/mengelbart/moqtransport -destination mock_stream_test.go github.com/mengelbart/moqtransport Stream // + // Package moqtransport is a generated GoMock package. package moqtransport diff --git a/mockgen.go b/mockgen.go index 9609161..fb4c795 100644 --- a/mockgen.go +++ b/mockgen.go @@ -3,22 +3,15 @@ package moqtransport //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package moqtransport -self_package github.com/mengelbart/moqtransport -destination mock_stream_test.go github.com/mengelbart/moqtransport Stream" -type Stream = stream //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package moqtransport -self_package github.com/mengelbart/moqtransport -destination mock_receive_stream_test.go github.com/mengelbart/moqtransport ReceiveStream" -type ReceiveStream = receiveStream //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package moqtransport -self_package github.com/mengelbart/moqtransport -destination mock_send_stream_test.go github.com/mengelbart/moqtransport SendStream" -type SendStream = sendStream //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package moqtransport -self_package github.com/mengelbart/moqtransport -destination mock_connection_test.go github.com/mengelbart/moqtransport Connection" -type Connection = connection //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package moqtransport -self_package github.com/mengelbart/moqtransport -destination mock_parser_test.go github.com/mengelbart/moqtransport Parser" type Parser = parser -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package moqtransport -self_package github.com/mengelbart/moqtransport -destination mock_parser_factory_test.go github.com/mengelbart/moqtransport ParserFactory" -type ParserFactory = parserFactory - -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package moqtransport -self_package github.com/mengelbart/moqtransport -destination mock_control_stream_handler_test.go github.com/mengelbart/moqtransport ControlStreamHandler" -type ControlStreamHandler = controlStreamHandler +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package moqtransport -self_package github.com/mengelbart/moqtransport -destination mock_control_message_sender_test.go github.com/mengelbart/moqtransport ControlMessageSender" +type ControlMessageSender = controlMessageSender diff --git a/receive_subscription.go b/receive_subscription.go index 22113e7..b644b37 100644 --- a/receive_subscription.go +++ b/receive_subscription.go @@ -40,8 +40,8 @@ func (s *ReceiveSubscription) Read(buf []byte) (int, error) { return s.readBuffer.Read(buf) } -func (s *ReceiveSubscription) Unsubscribe() error { - return s.session.unsubscribe(s.subscribeID) +func (s *ReceiveSubscription) Unsubscribe() { + s.session.unsubscribe(s.subscribeID) } func (s *ReceiveSubscription) unsubscribe() error { diff --git a/session.go b/session.go index 6702c7c..ba24c40 100644 --- a/session.go +++ b/session.go @@ -13,183 +13,173 @@ import ( "github.com/quic-go/quic-go/quicvarint" ) +const ( + serverLoggingSuffix = "SERVER" + clientLoggingSuffix = "CLIENT" +) + var ( - errClosed = errors.New("session closed") - errUnsupportedVersion = errors.New("unsupported version") + errClosed = errors.New("session closed") ) -type messageHandler func(message) error +type parser interface { + parse() (message, error) +} -type controlStreamHandler interface { - io.Reader - send(message) error - readMessages(handler messageHandler) +type subscribeIDer interface { + message + subscribeID() uint64 } -type defaultCtrlStreamHandler struct { - logger *slog.Logger - stream Stream - parser parser +type trackNamespacer interface { + message + trackNamespace() string } -func (h *defaultCtrlStreamHandler) readNext() (message, error) { - msg, err := h.parser.parse() - h.logger.Info("handling message", "message", msg) - return msg, err +type sessionInternals struct { + logger *slog.Logger + serverHandshakeDoneCh chan struct{} + controlStreamStoreCh chan controlMessageSender // Needs to be buffered + closeOnce sync.Once + closed chan struct{} + sendSubscriptions *subscriptionMap[*SendSubscription] + receiveSubscriptions *subscriptionMap[*ReceiveSubscription] + localAnnouncements *announcementMap + remoteAnnouncements *announcementMap } -func (h *defaultCtrlStreamHandler) readMessages(handler messageHandler) { - for { - msg, err := h.readNext() - if err != nil { - if err == io.EOF { - return - } - h.logger.Error("TODO", "error", err) - return - } - if err = handler(msg); err != nil { - h.logger.Error("TODO", "error", err) - return - } +func newSessionInternals(logSuffix string) *sessionInternals { + return &sessionInternals{ + logger: defaultLogger.WithGroup(fmt.Sprintf("MOQ_SESSION_%v", logSuffix)), + serverHandshakeDoneCh: make(chan struct{}), + controlStreamStoreCh: make(chan controlMessageSender, 1), + closeOnce: sync.Once{}, + closed: make(chan struct{}), + sendSubscriptions: newSubscriptionMap[*SendSubscription](), + receiveSubscriptions: newSubscriptionMap[*ReceiveSubscription](), + localAnnouncements: newAnnouncementMap(), + remoteAnnouncements: newAnnouncementMap(), } } -func (h *defaultCtrlStreamHandler) Read(buf []byte) (int, error) { - return h.stream.Read(buf) +type controlMessageSender interface { + enqueue(message) + close() } -func (s defaultCtrlStreamHandler) send(msg message) error { - s.logger.Info("sending message", "message", msg) - return sendOnStream(s.stream, msg) -} +type Session struct { + Conn Connection + EnableDatagrams bool + LocalRole Role + RemoteRole Role + AnnouncementHandler AnnouncementHandler + HandshakeDone bool -func sendOnStream(stream SendStream, msg message) error { - buf := make([]byte, 0, 1500) - buf = msg.append(buf) - if _, err := stream.Write(buf); err != nil { - return err - } - return nil + controlStream controlMessageSender + isClient bool + si *sessionInternals } -type parser interface { - parse() (message, error) +func (s *Session) initRole() { + switch s.LocalRole { + case IngestionRole, DeliveryRole, IngestionDeliveryRole: + default: + s.LocalRole = IngestionDeliveryRole + } + switch s.RemoteRole { + case IngestionRole, DeliveryRole, IngestionDeliveryRole: + default: + s.RemoteRole = IngestionDeliveryRole + } } -type subscribeIDer interface { - message - subscribeID() uint64 +func (s *Session) validateRemoteRoleParameter(setupParameters parameters) error { + remoteRoleParam, ok := setupParameters[roleParameterKey] + if !ok { + return s.CloseWithError(ErrorCodeProtocolViolation, "missing role parameter") + } + remoteRoleParamValue, ok := remoteRoleParam.(varintParameter) + if !ok { + return s.CloseWithError(ErrorCodeProtocolViolation, "invalid role parameter type") + } + switch Role(remoteRoleParamValue.v) { + case IngestionRole, DeliveryRole, IngestionDeliveryRole: + s.RemoteRole = Role(remoteRoleParamValue.v) + default: + return s.CloseWithError(ErrorCodeProtocolViolation, "invalid role parameter value") + } + return nil } -type trackNamespacer interface { - message - trackNamespace() string +func (s *Session) storeControlStream(cs controlMessageSender) { + s.si.controlStreamStoreCh <- cs } -type Session struct { - logger *slog.Logger - - closeOnce sync.Once - closed chan struct{} - conn Connection - cms controlStreamHandler - - enableDatagrams bool - - mutex sync.Mutex - announcementHandler AnnouncementHandler - - sendSubscriptions *subscriptionMap[*SendSubscription] - receiveSubscriptions *subscriptionMap[*ReceiveSubscription] - localAnnouncements *announcementMap - remoteAnnouncements *announcementMap +func (s *Session) loadControlStream() controlMessageSender { + return <-s.si.controlStreamStoreCh } -func NewClientSession(conn Connection, clientRole Role, enableDatagrams bool) (*Session, error) { - ctrlStream, err := conn.OpenStreamSync(context.TODO()) +func (s *Session) RunClient() error { + s.si = newSessionInternals(clientLoggingSuffix) + s.isClient = true + s.initRole() + controlStream, err := s.Conn.OpenStream() if err != nil { - return nil, fmt.Errorf("opening control stream failed: %w", err) - } - ctrlStreamHandler := &defaultCtrlStreamHandler{ - logger: defaultLogger.WithGroup("MOQ_CONTROL_STREAM"), - stream: ctrlStream, - parser: newParser(quicvarint.NewReader(ctrlStream)), + return err } - csm := &clientSetupMessage{ + s.controlStream = newControlStream(controlStream, s.handleControlMessage) + s.controlStream.enqueue(&clientSetupMessage{ SupportedVersions: []version{CURRENT_VERSION}, SetupParameters: map[uint64]parameter{ roleParameterKey: varintParameter{ k: roleParameterKey, - v: uint64(clientRole), + v: uint64(s.LocalRole), }, }, - } - if err = ctrlStreamHandler.send(csm); err != nil { - return nil, fmt.Errorf("sending message on control stream failed: %w", err) - } - msg, err := ctrlStreamHandler.readNext() - if err != nil { - return nil, fmt.Errorf("parsing message filed: %w", err) - } - ssm, ok := msg.(*serverSetupMessage) - if !ok { - pe := ProtocolError{ - code: ErrorCodeProtocolViolation, - message: "received unexpected first message on control stream", - } - _ = conn.CloseWithError(pe.code, pe.message) - return nil, pe - } - if !slices.Contains(csm.SupportedVersions, ssm.SelectedVersion) { - return nil, errUnsupportedVersion - } - l := defaultLogger.WithGroup("MOQ_CLIENT_SESSION") - s, err := newSession(conn, ctrlStreamHandler, enableDatagrams, l), nil - if err != nil { - return nil, err - } - s.run() - return s, nil + }) + go s.run() + return nil } -func NewServerSession(conn Connection, enableDatagrams bool) (*Session, error) { - ctrlStream, err := conn.AcceptStream(context.TODO()) +func (s *Session) RunServer(ctx context.Context) error { + s.si = newSessionInternals(serverLoggingSuffix) + s.isClient = false + s.initRole() + controlStream, err := s.Conn.AcceptStream(ctx) if err != nil { - return nil, fmt.Errorf("accepting control stream failed: %w", err) + return err } - ctrlStreamHandler := &defaultCtrlStreamHandler{ - logger: defaultLogger.WithGroup("MOQ_CONTROL_STREAM"), - stream: ctrlStream, - parser: newParser(quicvarint.NewReader(ctrlStream)), + s.si.controlStreamStoreCh <- newControlStream(controlStream, s.handleControlMessage) + select { + case <-ctx.Done(): + s.Close() + return ctx.Err() + case <-s.si.serverHandshakeDoneCh: } - m, err := ctrlStreamHandler.readNext() - if err != nil { - return nil, fmt.Errorf("parsing message failed: %w", err) + s.si.logger.Info("server handshake done") + go s.run() + return nil +} + +func (s *Session) initClient(setup *serverSetupMessage) error { + if setup.SelectedVersion != CURRENT_VERSION { + return s.CloseWithError(ErrorCodeUnsupportedVersion, "unsupported version") } - msg, ok := m.(*clientSetupMessage) - if !ok { - pe := ProtocolError{ - code: ErrorCodeProtocolViolation, - message: "received unexpected first message on control stream", - } - _ = conn.CloseWithError(pe.code, pe.message) - return nil, pe + if err := s.validateRemoteRoleParameter(setup.SetupParameters); err != nil { + return err } - // TODO: Algorithm to select best matching version - if !slices.Contains(msg.SupportedVersions, CURRENT_VERSION) { - return nil, errUnsupportedVersion + s.HandshakeDone = true + return nil +} + +func (s *Session) initServer(setup *clientSetupMessage) error { + s.controlStream = s.loadControlStream() + if !slices.Contains(setup.SupportedVersions, CURRENT_VERSION) { + return s.CloseWithError(ErrorCodeUnsupportedVersion, "unsupported version") } - _, ok = msg.SetupParameters[roleParameterKey] - if !ok { - pe := ProtocolError{ - code: ErrorCodeProtocolViolation, - message: "missing role parameter", - } - _ = conn.CloseWithError(pe.code, pe.message) - return nil, pe + if err := s.validateRemoteRoleParameter(setup.SetupParameters); err != nil { + return err } - // TODO: save role parameter ssm := &serverSetupMessage{ SelectedVersion: CURRENT_VERSION, SetupParameters: map[uint64]parameter{ @@ -199,39 +189,17 @@ func NewServerSession(conn Connection, enableDatagrams bool) (*Session, error) { }, }, } - if err = ctrlStreamHandler.send(ssm); err != nil { - return nil, fmt.Errorf("sending message on control stream failed: %w", err) - } - l := defaultLogger.WithGroup("MOQ_SERVER_SESSION") - s, err := newSession(conn, ctrlStreamHandler, enableDatagrams, l), nil - if err != nil { - return nil, err - } - s.run() - return s, nil -} - -func newSession(conn Connection, cms controlStreamHandler, enableDatagrams bool, logger *slog.Logger) *Session { - s := &Session{ - logger: logger, - closed: make(chan struct{}), - conn: conn, - cms: cms, - enableDatagrams: enableDatagrams, - sendSubscriptions: newSubscriptionMap[*SendSubscription](), - receiveSubscriptions: newSubscriptionMap[*ReceiveSubscription](), - localAnnouncements: newAnnouncementMap(), - remoteAnnouncements: newAnnouncementMap(), - } - return s + s.controlStream.enqueue(ssm) + s.HandshakeDone = true + close(s.si.serverHandshakeDoneCh) + return nil } func (s *Session) run() { go s.acceptUnidirectionalStreams() - if s.enableDatagrams { + if s.EnableDatagrams { go s.acceptDatagrams() } - go s.cms.readMessages(s.handleControlMessage) } func (s *Session) acceptUnidirectionalStream() (ReceiveStream, error) { @@ -239,19 +207,19 @@ func (s *Session) acceptUnidirectionalStream() (ReceiveStream, error) { defer cancel() go func() { select { - case <-s.closed: + case <-s.si.closed: cancel() case <-ctx.Done(): } }() - return s.conn.AcceptUniStream(ctx) + return s.Conn.AcceptUniStream(ctx) } func (s *Session) acceptUnidirectionalStreams() { for { stream, err := s.acceptUnidirectionalStream() if err != nil { - s.logger.Error("failed to accept uni stream", "error", err) + s.si.logger.Error("failed to accept uni stream", "error", err) s.peerClosed() return } @@ -263,30 +231,30 @@ func (s *Session) handleIncomingUniStream(stream ReceiveStream) { p := newParser(quicvarint.NewReader(stream)) msg, err := p.parse() if err != nil { - s.logger.Error("failed to parse message", "error", err) + s.si.logger.Error("failed to parse message", "error", err) return } switch h := msg.(type) { case *objectMessage: - sub, ok := s.receiveSubscriptions.get(h.SubscribeID) + sub, ok := s.si.receiveSubscriptions.get(h.SubscribeID) if !ok { - s.logger.Warn("got object for unknown subscribe ID") + s.si.logger.Warn("got object for unknown subscribe ID") return } if _, err := sub.push(h); err != nil { panic(err) } case *streamHeaderTrackMessage: - sub, ok := s.receiveSubscriptions.get(h.SubscribeID) + sub, ok := s.si.receiveSubscriptions.get(h.SubscribeID) if !ok { - s.logger.Warn("got stream header track message for unknown subscription") + s.si.logger.Warn("got stream header track message for unknown subscription") return } sub.readTrackHeaderStream(stream) case *streamHeaderGroupMessage: - sub, ok := s.receiveSubscriptions.get(h.SubscribeID) + sub, ok := s.si.receiveSubscriptions.get(h.SubscribeID) if !ok { - s.logger.Warn("got stream header track message for unknown subscription") + s.si.logger.Warn("got stream header track message for unknown subscription") return } sub.readGroupHeaderStream(stream) @@ -298,19 +266,19 @@ func (s *Session) acceptDatagram() ([]byte, error) { defer cancel() go func() { select { - case <-s.closed: + case <-s.si.closed: cancel() case <-ctx.Done(): } }() - return s.conn.ReceiveDatagram(ctx) + return s.Conn.ReceiveDatagram(ctx) } func (s *Session) acceptDatagrams() { for { dgram, err := s.acceptDatagram() if err != nil { - s.logger.Error("failed to receive datagram", "error", err) + s.si.logger.Error("failed to receive datagram", "error", err) s.peerClosed() return } @@ -326,12 +294,12 @@ func (s *Session) readObjectMessages(r messageReader) { if err == io.EOF { return } - s.logger.Error("failed to parse message", "error", err) + s.si.logger.Error("failed to parse message", "error", err) pe := &ProtocolError{ code: ErrorCodeProtocolViolation, message: "invalid message format", } - _ = s.conn.CloseWithError(pe.code, pe.message) + _ = s.Conn.CloseWithError(pe.code, pe.message) return } o, ok := msg.(*objectMessage) @@ -341,17 +309,41 @@ func (s *Session) readObjectMessages(r messageReader) { message: "received unexpected control message on object stream or datagram", } // TODO: Set error on session to surface to application? - _ = s.conn.CloseWithError(pe.code, pe.message) + _ = s.Conn.CloseWithError(pe.code, pe.message) return } if err = s.handleObjectMessage(o); err != nil { - s.logger.Info("failed to handle message", "error", err) + s.si.logger.Info("failed to handle message", "error", err) return } } } func (s *Session) handleControlMessage(msg message) error { + s.si.logger.Info("received message", "message", msg) + if s.HandshakeDone { + if err := s.handleNonSetupMessage(msg); err != nil { + return err + } + return nil + } + switch mt := msg.(type) { + case *serverSetupMessage: + return s.initClient(mt) + case *clientSetupMessage: + return s.initServer(mt) + } + s.si.logger.Info("received message during handshake", "message", msg) + pe := ProtocolError{ + code: ErrorCodeProtocolViolation, + message: "received unexpected first message on control stream", + } + s.controlStream.close() + _ = s.Conn.CloseWithError(pe.code, pe.message) + return pe +} + +func (s *Session) handleNonSetupMessage(msg message) error { switch m := msg.(type) { case *subscribeMessage: return s.handleSubscribe(m) @@ -380,7 +372,7 @@ func (s *Session) handleControlMessage(msg message) error { } func (s *Session) handleSubscriptionResponse(msg subscribeIDer) error { - sub, ok := s.receiveSubscriptions.get(msg.subscribeID()) + sub, ok := s.si.receiveSubscriptions.get(msg.subscribeID()) if !ok { return &ProtocolError{ code: ErrorCodeInternal, @@ -390,14 +382,14 @@ func (s *Session) handleSubscriptionResponse(msg subscribeIDer) error { // TODO: Run a goroutine to avoid blocking here? select { case sub.responseCh <- msg: - case <-s.closed: + case <-s.si.closed: return errClosed } return nil } func (s *Session) handleAnnouncementResponse(msg trackNamespacer) error { - a, ok := s.localAnnouncements.get(msg.trackNamespace()) + a, ok := s.si.localAnnouncements.get(msg.trackNamespace()) if !ok { return &ProtocolError{ code: ErrorCodeInternal, @@ -407,7 +399,7 @@ func (s *Session) handleAnnouncementResponse(msg trackNamespacer) error { // TODO: Run a goroutine to avoid blocking here? select { case a.responseCh <- msg: - case <-s.closed: + case <-s.si.closed: return errClosed } return nil @@ -418,7 +410,7 @@ func (s *Session) handleSubscribe(msg *subscribeMessage) error { lock: sync.RWMutex{}, closeCh: make(chan struct{}), expires: 0, - conn: s.conn, + conn: s.Conn, subscribeID: msg.SubscribeID, trackAlias: msg.TrackAlias, namespace: msg.TrackNamespace, @@ -429,14 +421,14 @@ func (s *Session) handleSubscribe(msg *subscribeMessage) error { endObject: msg.EndObject, parameters: msg.Parameters, } - return s.sendSubscriptions.add(sub.subscribeID, sub) + return s.si.sendSubscriptions.add(sub.subscribeID, sub) } func (s *Session) handleUnsubscribe(msg *unsubscribeMessage) error { - if err := s.sendSubscriptions.delete(msg.SubscribeID); err != nil { + if err := s.si.sendSubscriptions.delete(msg.SubscribeID); err != nil { return err } - return s.cms.send(&subscribeDoneMessage{ + s.controlStream.enqueue(&subscribeDoneMessage{ SusbcribeID: msg.SubscribeID, StatusCode: 0, ReasonPhrase: "unsubscribed", @@ -444,10 +436,11 @@ func (s *Session) handleUnsubscribe(msg *unsubscribeMessage) error { FinalGroup: 0, FinalObject: 0, }) + return nil } func (s *Session) handleSubscribeDone(msg *subscribeDoneMessage) error { - return s.receiveSubscriptions.delete(msg.SusbcribeID) + return s.si.receiveSubscriptions.delete(msg.SusbcribeID) } func (s *Session) handleAnnounceMessage(msg *announceMessage) { @@ -456,59 +449,60 @@ func (s *Session) handleAnnounceMessage(msg *announceMessage) { namespace: msg.TrackNamespace, parameters: msg.TrackRequestParameters, } - s.mutex.Lock() - defer s.mutex.Unlock() - h := s.announcementHandler - if h != nil { - go h.Handle(a, &defaultAnnouncementResponseWriter{ + if err := s.si.remoteAnnouncements.add(a.namespace, a); err != nil { + s.si.logger.Error("dropping announcement", "error", err) + return + } + if s.AnnouncementHandler != nil { + go s.AnnouncementHandler.Handle(a, &defaultAnnouncementResponseWriter{ a: a, s: s, }) } } -func (s *Session) rejectAnnouncement(a *Announcement, code uint64, reason string) error { - return s.cms.send(&announceErrorMessage{ +func (s *Session) rejectAnnouncement(a *Announcement, code uint64, reason string) { + s.si.remoteAnnouncements.delete(a.namespace) + s.controlStream.enqueue(&announceErrorMessage{ TrackNamespace: a.namespace, ErrorCode: code, ReasonPhrase: reason, }) } -func (s *Session) acceptAnnouncement(a *Announcement) error { - if err := s.remoteAnnouncements.add(a.namespace, a); err != nil { - return err - } - return s.cms.send(&announceOkMessage{ +func (s *Session) acceptAnnouncement(a *Announcement) { + s.controlStream.enqueue(&announceOkMessage{ TrackNamespace: a.namespace, }) } func (s *Session) handleObjectMessage(o *objectMessage) error { - sub, ok := s.receiveSubscriptions.get(o.SubscribeID) + sub, ok := s.si.receiveSubscriptions.get(o.SubscribeID) if ok { _, err := sub.push(o) return err } - s.logger.Warn("dropping object message for unknown track") + s.si.logger.Warn("dropping object message for unknown track") return nil } -func (s *Session) unsubscribe(id uint64) error { - return s.cms.send(&unsubscribeMessage{ +func (s *Session) unsubscribe(id uint64) { + s.controlStream.enqueue(&unsubscribeMessage{ SubscribeID: id, }) } func (s *Session) peerClosed() { - s.closeOnce.Do(func() { - close(s.closed) + s.si.logger.Info("peerClosed called") + s.si.closeOnce.Do(func() { + close(s.si.closed) + s.controlStream.close() }) } func (s *Session) CloseWithError(code uint64, msg string) error { s.peerClosed() - return s.conn.CloseWithError(code, msg) + return s.Conn.CloseWithError(code, msg) } func (s *Session) Close() error { @@ -518,22 +512,20 @@ func (s *Session) Close() error { // TODO: Acceptor func should not pass the complete subscription object but only // the relevant header info func (s *Session) ReadSubscription(ctx context.Context, accept func(*SendSubscription) error) (*SendSubscription, error) { - sub, err := s.sendSubscriptions.getNext(ctx) + sub, err := s.si.sendSubscriptions.getNext(ctx) if err != nil { return nil, err } if err = accept(sub); err != nil { - if err = s.cms.send(&subscribeErrorMessage{ + s.controlStream.enqueue(&subscribeErrorMessage{ SubscribeID: sub.subscribeID, ErrorCode: 0, ReasonPhrase: err.Error(), TrackAlias: sub.trackAlias, - }); err != nil { - panic(err) - } - return nil, s.sendSubscriptions.delete(sub.subscribeID) + }) + return nil, s.si.sendSubscriptions.delete(sub.subscribeID) } - err = s.cms.send(&subscribeOkMessage{ + s.controlStream.enqueue(&subscribeOkMessage{ SubscribeID: sub.subscribeID, Expires: 0, // TODO: Let user set these values? ContentExists: false, // TODO: Let user set these values? @@ -543,12 +535,6 @@ func (s *Session) ReadSubscription(ctx context.Context, accept func(*SendSubscri return sub, err } -func (s *Session) HandleAnnouncements(handler AnnouncementHandler) { - s.mutex.Lock() - defer s.mutex.Unlock() - s.announcementHandler = handler -} - func (s *Session) Subscribe(ctx context.Context, subscribeID, trackAlias uint64, namespace, trackname, auth string) (*ReceiveSubscription, error) { sm := &subscribeMessage{ SubscribeID: subscribeID, @@ -568,31 +554,29 @@ func (s *Session) Subscribe(ctx context.Context, subscribeID, trackAlias uint64, } } sub := newReceiveSubscription(sm.SubscribeID, s) - if err := s.receiveSubscriptions.add(sm.SubscribeID, sub); err != nil { - return nil, err - } - if err := s.cms.send(sm); err != nil { + if err := s.si.receiveSubscriptions.add(sm.SubscribeID, sub); err != nil { return nil, err } + s.controlStream.enqueue(sm) var resp subscribeIDer select { case <-ctx.Done(): return nil, ctx.Err() - case <-s.closed: + case <-s.si.closed: return nil, errClosed case resp = <-sub.responseCh: } if resp.subscribeID() != sm.SubscribeID { // Should never happen, because messages are routed based on subscribe // ID. Wrong IDs would thus never end up here. - s.logger.Error("internal error: received response message for wrong subscription ID", "expected_id", sm.SubscribeID, "repsonse_id", resp.subscribeID()) + s.si.logger.Error("internal error: received response message for wrong subscription ID", "expected_id", sm.SubscribeID, "repsonse_id", resp.subscribeID()) return nil, errors.New("internal error: received response message for wrong subscription ID") } switch v := resp.(type) { case *subscribeOkMessage: return sub, nil case *subscribeErrorMessage: - _ = s.receiveSubscriptions.delete(sm.SubscribeID) + _ = s.si.receiveSubscriptions.delete(sm.SubscribeID) return nil, ApplicationError{ code: v.ErrorCode, mesage: v.ReasonPhrase, @@ -616,24 +600,22 @@ func (s *Session) Announce(ctx context.Context, namespace string) error { a := &Announcement{ responseCh: responseCh, } - if err := s.localAnnouncements.add(am.TrackNamespace, a); err != nil { - return err - } - if err := s.cms.send(am); err != nil { + if err := s.si.localAnnouncements.add(am.TrackNamespace, a); err != nil { return err } + s.controlStream.enqueue(am) var resp trackNamespacer select { case <-ctx.Done(): return ctx.Err() - case <-s.closed: + case <-s.si.closed: return errClosed case resp = <-responseCh: } if resp.trackNamespace() != am.TrackNamespace { // Should never happen, because messages are routed based on trackname. // Wrong tracknames would thus never end up here. - s.logger.Error("internal error: received response message for wrong announce track namespace", "expected_track_namespace", am.TrackNamespace, "response_track_namespace", resp.trackNamespace()) + s.si.logger.Error("internal error: received response message for wrong announce track namespace", "expected_track_namespace", am.TrackNamespace, "response_track_namespace", resp.trackNamespace()) return errors.New("internal error: received response message for wrong announce track namespace") } switch v := resp.(type) { diff --git a/session_test.go b/session_test.go index a2612ba..eeb8b42 100644 --- a/session_test.go +++ b/session_test.go @@ -2,8 +2,6 @@ package moqtransport import ( "context" - "log/slog" - "sync" "testing" "time" @@ -11,28 +9,28 @@ import ( "go.uber.org/mock/gomock" ) -func session(conn Connection, ctrlStream controlStreamHandler) *Session { - return &Session{ - logger: slog.Default(), - closeOnce: sync.Once{}, - closed: make(chan struct{}), - conn: conn, - cms: ctrlStream, - enableDatagrams: false, - sendSubscriptions: newSubscriptionMap[*SendSubscription](), - receiveSubscriptions: newSubscriptionMap[*ReceiveSubscription](), - localAnnouncements: newAnnouncementMap(), - remoteAnnouncements: newAnnouncementMap(), +func session(conn Connection, ctrl controlMessageSender, h AnnouncementHandler) *Session { + s := &Session{ + Conn: conn, + EnableDatagrams: false, + LocalRole: 0, + RemoteRole: 0, + AnnouncementHandler: h, + HandshakeDone: false, + isClient: false, + si: newSessionInternals("SERVER"), } + s.storeControlStream(ctrl) + return s } func TestSession(t *testing.T) { t.Run("handle_object", func(t *testing.T) { ctrl := gomock.NewController(t) mc := NewMockConnection(ctrl) - csh := NewMockControlStreamHandler(ctrl) - s := *session(mc, csh) - err := s.receiveSubscriptions.add(0, newReceiveSubscription(0, &s)) + done := make(chan struct{}) + s := session(mc, nil, nil) + err := s.si.receiveSubscriptions.add(0, newReceiveSubscription(0, s)) assert.NoError(t, err) object := &objectMessage{ SubscribeID: 0, @@ -42,10 +40,9 @@ func TestSession(t *testing.T) { ObjectSendOrder: 0, ObjectPayload: []byte{0x0a, 0x0b}, } - done := make(chan struct{}) go func() { buf := make([]byte, 1024) - sub, ok := s.receiveSubscriptions.get(0) + sub, ok := s.si.receiveSubscriptions.get(0) assert.True(t, ok) n, err1 := sub.Read(buf) assert.NoError(t, err1) @@ -62,8 +59,10 @@ func TestSession(t *testing.T) { t.Run("handle_client_setup", func(t *testing.T) { ctrl := gomock.NewController(t) mc := NewMockConnection(ctrl) - csh := NewMockControlStreamHandler(ctrl) - s := session(mc, csh) + csh := NewMockControlMessageSender(ctrl) + csh.EXPECT().enqueue(gomock.Any()).AnyTimes() + done := make(chan struct{}) + s := session(mc, csh, nil) csm := &clientSetupMessage{ SupportedVersions: []version{CURRENT_VERSION}, SetupParameters: map[uint64]parameter{ @@ -74,16 +73,17 @@ func TestSession(t *testing.T) { }, } err := s.handleControlMessage(csm) - assert.Error(t, err) - assert.EqualError(t, err, "received unexpected message type on control stream") + assert.NoError(t, err) + close(done) }) t.Run("handle_subscribe_request", func(t *testing.T) { ctrl := gomock.NewController(t) mc := NewMockConnection(ctrl) - csh := NewMockControlStreamHandler(ctrl) - s := session(mc, csh) + csh := NewMockControlMessageSender(ctrl) + s := session(mc, csh, nil) done := make(chan struct{}) - csh.EXPECT().send(&subscribeOkMessage{ + csh.EXPECT().enqueue(gomock.Any()).Times(1) // Setup message + csh.EXPECT().enqueue(&subscribeOkMessage{ SubscribeID: 17, Expires: 0, ContentExists: false, @@ -92,20 +92,28 @@ func TestSession(t *testing.T) { }).Do(func(_ message) { close(done) }) - go func() { - err := s.handleControlMessage(&subscribeMessage{ - SubscribeID: 17, - TrackAlias: 0, - TrackNamespace: "namespace", - TrackName: "track", - StartGroup: Location{}, - StartObject: Location{}, - EndGroup: Location{}, - EndObject: Location{}, - Parameters: map[uint64]parameter{}, - }) - assert.NoError(t, err) - }() + err := s.handleControlMessage(&clientSetupMessage{ + SupportedVersions: []version{CURRENT_VERSION}, + SetupParameters: map[uint64]parameter{ + roleParameterKey: varintParameter{ + k: roleParameterKey, + v: uint64(IngestionDeliveryRole), + }, + }, + }) + assert.NoError(t, err) + err = s.handleControlMessage(&subscribeMessage{ + SubscribeID: 17, + TrackAlias: 0, + TrackNamespace: "namespace", + TrackName: "track", + StartGroup: Location{}, + StartObject: Location{}, + EndGroup: Location{}, + EndObject: Location{}, + Parameters: map[uint64]parameter{}, + }) + assert.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() sub, err := s.ReadSubscription(ctx, func(ss *SendSubscription) error { @@ -123,26 +131,33 @@ func TestSession(t *testing.T) { t.Run("handle_announcement", func(t *testing.T) { ctrl := gomock.NewController(t) mc := NewMockConnection(ctrl) - csh := NewMockControlStreamHandler(ctrl) - s := session(mc, csh) + csh := NewMockControlMessageSender(ctrl) + s := session(mc, csh, AnnouncementHandlerFunc(func(a *Announcement, arw AnnouncementResponseWriter) { + assert.NotNil(t, a) + arw.Accept() + })) done := make(chan struct{}) - csh.EXPECT().send(&announceOkMessage{ + csh.EXPECT().enqueue(gomock.Any()).Times(1) // setup message + csh.EXPECT().enqueue(&announceOkMessage{ TrackNamespace: "namespace", }).Do(func(_ message) { close(done) }) - s.HandleAnnouncements(AnnouncementHandlerFunc(func(a *Announcement, arw AnnouncementResponseWriter) { - err := arw.Accept() - assert.NoError(t, err) - assert.NotNil(t, a) - })) - go func() { - err := s.handleControlMessage(&announceMessage{ - TrackNamespace: "namespace", - TrackRequestParameters: map[uint64]parameter{}, - }) - assert.NoError(t, err) - }() + err := s.handleControlMessage(&clientSetupMessage{ + SupportedVersions: []version{CURRENT_VERSION}, + SetupParameters: map[uint64]parameter{ + roleParameterKey: varintParameter{ + k: roleParameterKey, + v: uint64(IngestionDeliveryRole), + }, + }, + }) + assert.NoError(t, err) + err = s.handleControlMessage(&announceMessage{ + TrackNamespace: "namespace", + TrackRequestParameters: map[uint64]parameter{}, + }) + assert.NoError(t, err) select { case <-time.After(time.Second): assert.Fail(t, "test timed out") @@ -152,10 +167,11 @@ func TestSession(t *testing.T) { t.Run("subscribe", func(t *testing.T) { ctrl := gomock.NewController(t) mc := NewMockConnection(ctrl) - csh := NewMockControlStreamHandler(ctrl) - s := session(mc, csh) + csh := NewMockControlMessageSender(ctrl) + s := session(mc, csh, nil) done := make(chan struct{}) - csh.EXPECT().send(&subscribeMessage{ + csh.EXPECT().enqueue(gomock.Any()).Times(1) + csh.EXPECT().enqueue(&subscribeMessage{ SubscribeID: 17, TrackAlias: 0, TrackNamespace: "namespace", @@ -175,19 +191,34 @@ func TestSession(t *testing.T) { close(done) }() }) + err := s.handleControlMessage(&clientSetupMessage{ + SupportedVersions: []version{CURRENT_VERSION}, + SetupParameters: map[uint64]parameter{ + roleParameterKey: varintParameter{ + k: roleParameterKey, + v: uint64(IngestionDeliveryRole), + }, + }, + }) + assert.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() track, err := s.Subscribe(ctx, 17, 0, "namespace", "track", "auth") assert.NoError(t, err) assert.NotNil(t, track) - <-done + select { + case <-time.After(time.Second): + assert.Fail(t, "test timed out") + case <-done: + } }) t.Run("announce", func(t *testing.T) { ctrl := gomock.NewController(t) mc := NewMockConnection(ctrl) - csh := NewMockControlStreamHandler(ctrl) - s := session(mc, csh) - csh.EXPECT().send(&announceMessage{ + csh := NewMockControlMessageSender(ctrl) + s := session(mc, csh, nil) + csh.EXPECT().enqueue(gomock.Any()).Times(1) + csh.EXPECT().enqueue(&announceMessage{ TrackNamespace: "namespace", TrackRequestParameters: map[uint64]parameter{}, }).Do(func(_ message) { @@ -198,9 +229,19 @@ func TestSession(t *testing.T) { assert.NoError(t, err) }() }) + err := s.handleControlMessage(&clientSetupMessage{ + SupportedVersions: []version{CURRENT_VERSION}, + SetupParameters: map[uint64]parameter{ + roleParameterKey: varintParameter{ + k: roleParameterKey, + v: uint64(IngestionDeliveryRole), + }, + }, + }) + assert.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - err := s.Announce(ctx, "namespace") + err = s.Announce(ctx, "namespace") assert.NoError(t, err) }) }