diff --git a/quic.go b/quic.go index 286302f0..ba5c2af0 100644 --- a/quic.go +++ b/quic.go @@ -228,16 +228,22 @@ func (q *QUICConn) HandleData(level QUICEncryptionLevel, data []byte) error { return nil } // The handshake goroutine has exited. + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() c.hand.Write(c.quic.readbuf) c.quic.readbuf = nil for q.conn.hand.Len() >= 4 && q.conn.handshakeErr == nil { b := q.conn.hand.Bytes() n := int(b[1])<<16 | int(b[2])<<8 | int(b[3]) - if 4+n < len(b) { + if n > maxHandshake { + q.conn.handshakeErr = fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake) + break + } + if len(b) < 4+n { return nil } if err := q.conn.handlePostHandshakeMessage(); err != nil { - return quicError(err) + q.conn.handshakeErr = err } } if q.conn.handshakeErr != nil { diff --git a/quic_test.go b/quic_test.go index 9a29fa56..323906a2 100644 --- a/quic_test.go +++ b/quic_test.go @@ -85,7 +85,7 @@ func (q *testQUICConn) setWriteSecret(level QUICEncryptionLevel, suite uint16, s var errTransportParametersRequired = errors.New("transport parameters required") -func runTestQUICConnection(ctx context.Context, cli, srv *testQUICConn, onHandleCryptoData func()) error { +func runTestQUICConnection(ctx context.Context, cli, srv *testQUICConn, onEvent func(e QUICEvent, src, dst *testQUICConn) bool) error { a, b := cli, srv for _, c := range []*testQUICConn{a, b} { if !c.conn.conn.quic.started { @@ -97,6 +97,9 @@ func runTestQUICConnection(ctx context.Context, cli, srv *testQUICConn, onHandle idleCount := 0 for { e := a.conn.NextEvent() + if onEvent != nil && onEvent(e, a, b) { + continue + } switch e.Kind { case QUICNoEvent: idleCount++ @@ -211,6 +214,37 @@ func TestQUICSessionResumption(t *testing.T) { } } +func TestQUICFragmentaryData(t *testing.T) { + clientConfig := testConfig.Clone() + clientConfig.MinVersion = VersionTLS13 + clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) + clientConfig.ServerName = "example.go.dev" + + serverConfig := testConfig.Clone() + serverConfig.MinVersion = VersionTLS13 + + cli := newTestQUICClient(t, clientConfig) + cli.conn.SetTransportParameters(nil) + srv := newTestQUICServer(t, serverConfig) + srv.conn.SetTransportParameters(nil) + onEvent := func(e QUICEvent, src, dst *testQUICConn) bool { + if e.Kind == QUICWriteData { + // Provide the data one byte at a time. + for i := range e.Data { + if err := dst.conn.HandleData(e.Level, e.Data[i:i+1]); err != nil { + t.Errorf("HandleData: %v", err) + break + } + } + return true + } + return false + } + if err := runTestQUICConnection(context.Background(), cli, srv, onEvent); err != nil { + t.Fatalf("error during first connection handshake: %v", err) + } +} + func TestQUICPostHandshakeClientAuthentication(t *testing.T) { // RFC 9001, Section 4.4. config := testConfig.Clone() @@ -264,6 +298,28 @@ func TestQUICPostHandshakeKeyUpdate(t *testing.T) { } } +func TestQUICPostHandshakeMessageTooLarge(t *testing.T) { + config := testConfig.Clone() + config.MinVersion = VersionTLS13 + cli := newTestQUICClient(t, config) + cli.conn.SetTransportParameters(nil) + srv := newTestQUICServer(t, config) + srv.conn.SetTransportParameters(nil) + if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { + t.Fatalf("error during connection handshake: %v", err) + } + + size := maxHandshake + 1 + if err := cli.conn.HandleData(QUICEncryptionLevelApplication, []byte{ + byte(typeNewSessionTicket), + byte(size >> 16), + byte(size >> 8), + byte(size), + }); err == nil { + t.Fatalf("%v-byte post-handshake message: got no error, want one", size) + } +} + func TestQUICHandshakeError(t *testing.T) { clientConfig := testConfig.Clone() clientConfig.MinVersion = VersionTLS13 @@ -298,26 +354,22 @@ func TestQUICConnectionState(t *testing.T) { cli.conn.SetTransportParameters(nil) srv := newTestQUICServer(t, config) srv.conn.SetTransportParameters(nil) - onHandleCryptoData := func() { + onEvent := func(e QUICEvent, src, dst *testQUICConn) bool { cliCS := cli.conn.ConnectionState() - cliWantALPN := "" if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; ok { - cliWantALPN = "h3" - } - if want, got := cliCS.NegotiatedProtocol, cliWantALPN; want != got { - t.Errorf("cli.ConnectionState().NegotiatedProtocol = %q, want %q", want, got) + if want, got := cliCS.NegotiatedProtocol, "h3"; want != got { + t.Errorf("cli.ConnectionState().NegotiatedProtocol = %q, want %q", want, got) + } } - srvCS := srv.conn.ConnectionState() - srvWantALPN := "" if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; ok { - srvWantALPN = "h3" - } - if want, got := srvCS.NegotiatedProtocol, srvWantALPN; want != got { - t.Errorf("srv.ConnectionState().NegotiatedProtocol = %q, want %q", want, got) + if want, got := srvCS.NegotiatedProtocol, "h3"; want != got { + t.Errorf("srv.ConnectionState().NegotiatedProtocol = %q, want %q", want, got) + } } + return false } - if err := runTestQUICConnection(context.Background(), cli, srv, onHandleCryptoData); err != nil { + if err := runTestQUICConnection(context.Background(), cli, srv, onEvent); err != nil { t.Fatalf("error during connection handshake: %v", err) } } diff --git a/u_quic.go b/u_quic.go index e0c44d55..c312522e 100644 --- a/u_quic.go +++ b/u_quic.go @@ -7,6 +7,7 @@ package tls import ( "context" "errors" + "fmt" ) // A UQUICConn represents a connection which uses a QUIC implementation as the underlying @@ -108,16 +109,22 @@ func (q *UQUICConn) HandleData(level QUICEncryptionLevel, data []byte) error { return nil } // The handshake goroutine has exited. + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() c.hand.Write(c.quic.readbuf) c.quic.readbuf = nil for q.conn.hand.Len() >= 4 && q.conn.handshakeErr == nil { b := q.conn.hand.Bytes() n := int(b[1])<<16 | int(b[2])<<8 | int(b[3]) - if 4+n < len(b) { + if n > maxHandshake { + q.conn.handshakeErr = fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake) + break + } + if len(b) < 4+n { return nil } if err := q.conn.handlePostHandshakeMessage(); err != nil { - return quicError(err) + q.conn.handshakeErr = err } } if q.conn.handshakeErr != nil {