Skip to content

Commit

Permalink
Refactor control stream sending and session creation API
Browse files Browse the repository at this point in the history
  • Loading branch information
mengelbart committed Jun 7, 2024
1 parent 58fec26 commit a8fc8cc
Show file tree
Hide file tree
Showing 20 changed files with 565 additions and 616 deletions.
12 changes: 6 additions & 6 deletions announcement.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
package moqtransport

type AnnouncementResponseWriter interface {
Accept() error
Reject(code uint64, reason string) error
Accept()
Reject(code uint64, reason string)
}

type defaultAnnouncementResponseWriter struct {
a *Announcement
s *Session
}

func (a *defaultAnnouncementResponseWriter) Accept() error {
return a.s.acceptAnnouncement(a.a)
func (a *defaultAnnouncementResponseWriter) Accept() {
a.s.acceptAnnouncement(a.a)
}

func (a *defaultAnnouncementResponseWriter) Reject(code uint64, reason string) error {
return a.s.rejectAnnouncement(a.a, code, reason)
func (a *defaultAnnouncementResponseWriter) Reject(code uint64, reason string) {
a.s.rejectAnnouncement(a.a, code, reason)
}

type AnnouncementHandler interface {
Expand Down
6 changes: 6 additions & 0 deletions announcement_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,9 @@ func (m *announcementMap) get(name string) (*Announcement, bool) {
a, ok := m.announcements[name]
return a, ok
}

func (m *announcementMap) delete(name string) {
m.mutex.Lock()
defer m.mutex.Unlock()
delete(m.announcements, name)
}
83 changes: 83 additions & 0 deletions control_stream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package moqtransport

import (
"io"
"log/slog"

"github.com/quic-go/quic-go/quicvarint"
)

type controlStream struct {
logger *slog.Logger
stream Stream
handle messageHandler
parser parser
sendQueue chan message
closeCh chan struct{}
}

type messageHandler func(message) error

func newControlStream(s Stream, h messageHandler) *controlStream {
cs := &controlStream{
logger: defaultLogger.WithGroup("MOQ_CONTROL_STREAM"),
stream: s,
handle: h,
parser: newParser(quicvarint.NewReader(s)),
sendQueue: make(chan message, 64),
closeCh: make(chan struct{}),
}
go cs.readMessages()
go cs.writeMessages()
return cs
}

func (s *controlStream) readMessages() {
for {
msg, err := s.parser.parse()
if err != nil {
if err == io.EOF {
return
}
s.logger.Error("TODO", "error", err)
return
}
if err = s.handle(msg); err != nil {
s.logger.Error("failed to handle control stream message", "error", err)
panic("TODO: Close connection")
}
}
}

func (s *controlStream) writeMessages() {
for {
select {
case <-s.closeCh:
s.logger.Info("close called, leaving control stream write loop")
return
case msg := <-s.sendQueue:
s.logger.Info("sending control message", "message", msg)
buf := make([]byte, 0, 1500)
buf = msg.append(buf)
if _, err := s.stream.Write(buf); err != nil {
if err == io.EOF {
s.logger.Info("write stream closed, leaving control stream write loop")
return
}
s.logger.Error("failed to write to control stream", "error", err)
}
}
}
}

func (s *controlStream) enqueue(m message) {
select {
case s.sendQueue <- m:
default:
s.logger.Warn("dropping control stream message because send queue is full")
}
}

func (s *controlStream) close() {
close(s.closeCh)
}
32 changes: 19 additions & 13 deletions examples/date-client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ func main() {
}

func run(ctx context.Context, addr string, wt bool, namespace, trackname string) error {
var session *moqtransport.Session
var conn moqtransport.Connection
var err error
if wt {
Expand All @@ -38,24 +37,31 @@ func run(ctx context.Context, addr string, wt bool, namespace, trackname string)
if err != nil {
return err
}
session, err = moqtransport.NewClientSession(conn, moqtransport.IngestionDeliveryRole, !wt)
if err != nil {
announcementWaitCh := make(chan *moqtransport.Announcement)
session := &moqtransport.Session{
Conn: conn,
EnableDatagrams: true,
LocalRole: moqtransport.DeliveryRole,
AnnouncementHandler: moqtransport.AnnouncementHandlerFunc(func(a *moqtransport.Announcement, arw moqtransport.AnnouncementResponseWriter) {
if a.Namespace() == "clock" {
arw.Accept()
announcementWaitCh <- a
return
}
arw.Reject(0, "invalid namespace")
}),
}
if err = session.RunClient(); err != nil {
return err
}
defer session.Close()

session.HandleAnnouncements(moqtransport.AnnouncementHandlerFunc(func(a *moqtransport.Announcement, arw moqtransport.AnnouncementResponseWriter) {
if a.Namespace() == "clock" {
arw.Accept()
return
}
arw.Reject(0, "invalid namespace")
}))
log.Println("got Announcement")
a := <-announcementWaitCh
log.Printf("got Announcement: %v\n", a)
log.Println("subscribing")
rs, err := session.Subscribe(context.Background(), 0, 0, namespace, trackname, "")
if err != nil {
panic(err)
return err
}
log.Println("got subscription")
buf := make([]byte, 64_000)
Expand All @@ -66,7 +72,7 @@ func run(ctx context.Context, addr string, wt bool, namespace, trackname string)
log.Printf("got last object")
return nil
}
panic(err)
return err
}
log.Printf("got object: %v\n", string(buf[:n]))
}
Expand Down
71 changes: 12 additions & 59 deletions examples/date-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,19 @@ func main() {
certFile := flag.String("cert", "localhost.pem", "TLS certificate file")
keyFile := flag.String("key", "localhost-key.pem", "TLS key file")
addr := flag.String("addr", "localhost:8080", "listen address")
wt := flag.Bool("webtransport", false, "Serve WebTransport only")
quic := flag.Bool("quic", false, "Serve QUIC only")
flag.Parse()

if err := run(context.Background(), *addr, *wt, *quic, *certFile, *keyFile); err != nil {
if err := run(context.Background(), *addr, *certFile, *keyFile); err != nil {
log.Fatal(err)
}
}

func run(ctx context.Context, addr string, wt, quic bool, certFile, keyFile string) error {
func run(ctx context.Context, addr string, certFile, keyFile string) error {
tlsConfig, err := generateTLSConfigWithCertAndKey(certFile, keyFile)
if err != nil {
log.Printf("failed to generate TLS config from cert file and key, generating in memory certs: %v", err)
tlsConfig = generateTLSConfig()
}
if wt {
return listenWebTransport(addr, tlsConfig)
}
if quic {
return listenQUIC(ctx, addr, tlsConfig)
}
return listen(ctx, addr, tlsConfig)
}

Expand All @@ -71,8 +63,11 @@ func listen(ctx context.Context, addr string, tlsConfig *tls.Config) error {
w.WriteHeader(http.StatusInternalServerError)
return
}
moqSession, err := moqtransport.NewServerSession(webtransportmoq.New(session), false)
if err != nil {
moqSession := &moqtransport.Session{
Conn: webtransportmoq.New(session),
EnableDatagrams: true,
}
if err := moqSession.RunServer(ctx); err != nil {
log.Printf("MoQ Session initialization failed: %v", err)
w.WriteHeader(http.StatusInternalServerError)
return
Expand All @@ -88,60 +83,18 @@ func listen(ctx context.Context, addr string, tlsConfig *tls.Config) error {
go wt.ServeQUICConn(conn)
}
if conn.ConnectionState().TLS.NegotiatedProtocol == "moq-00" {
s, err := moqtransport.NewServerSession(quicmoq.New(conn), true)
if err != nil {
s := &moqtransport.Session{
Conn: quicmoq.New(conn),
EnableDatagrams: true,
}
if err := s.RunServer(ctx); err != nil {
return err
}
go handle(s)
}
}
}

func listenWebTransport(addr string, tlsConfig *tls.Config) error {
wt := webtransport.Server{
H3: http3.Server{
Addr: addr,
TLSConfig: tlsConfig,
},
}
http.HandleFunc("/moq", func(w http.ResponseWriter, r *http.Request) {
session, err := wt.Upgrade(w, r)
if err != nil {
log.Printf("upgrading to webtransport failed: %v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
moqSession, err := moqtransport.NewServerSession(webtransportmoq.New(session), false)
if err != nil {
log.Printf("MoQ Session initialization failed: %v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
go handle(moqSession)
})
return wt.ListenAndServe()
}

func listenQUIC(ctx context.Context, addr string, tlsConfig *tls.Config) error {
listener, err := quic.ListenAddr(addr, tlsConfig, &quic.Config{
EnableDatagrams: true,
})
if err != nil {
return err
}
for {
conn, err := listener.Accept(ctx)
if err != nil {
return err
}
s, err := moqtransport.NewServerSession(quicmoq.New(conn), true)
if err != nil {
return err
}
go handle(s)
}
}

func handle(p *moqtransport.Session) {
go func() {
s, err := p.ReadSubscription(context.Background(), func(s *moqtransport.SendSubscription) error {
Expand Down
Loading

0 comments on commit a8fc8cc

Please sign in to comment.