]> Cypherpunks.ru repositories - gostls13.git/commitdiff
crypto/tls: support QUIC as a transport
authorDamien Neil <dneil@google.com>
Fri, 14 Oct 2022 17:48:42 +0000 (10:48 -0700)
committerDamien Neil <dneil@google.com>
Wed, 24 May 2023 22:40:18 +0000 (22:40 +0000)
Add a QUICConn type for use by QUIC implementations.

A QUICConn provides unencrypted handshake bytes and connection
secrets to the QUIC layer, and receives handshake bytes.

For #44886

Change-Id: I859dda4cc6d466a1df2fb863a69d3a2a069110d5
Reviewed-on: https://go-review.googlesource.com/c/go/+/493655
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
Reviewed-by: Marten Seemann <martenseemann@gmail.com>
12 files changed:
api/next/44886.txt [new file with mode: 0644]
src/crypto/tls/alert.go
src/crypto/tls/common.go
src/crypto/tls/conn.go
src/crypto/tls/handshake_client.go
src/crypto/tls/handshake_client_tls13.go
src/crypto/tls/handshake_messages.go
src/crypto/tls/handshake_messages_test.go
src/crypto/tls/handshake_server.go
src/crypto/tls/handshake_server_tls13.go
src/crypto/tls/quic.go [new file with mode: 0644]
src/crypto/tls/quic_test.go [new file with mode: 0644]

diff --git a/api/next/44886.txt b/api/next/44886.txt
new file mode 100644 (file)
index 0000000..b3ab699
--- /dev/null
@@ -0,0 +1,41 @@
+pkg crypto/tls, const QUICEncryptionLevelApplication = 2 #44886
+pkg crypto/tls, const QUICEncryptionLevelApplication QUICEncryptionLevel #44886
+pkg crypto/tls, const QUICEncryptionLevelHandshake = 1 #44886
+pkg crypto/tls, const QUICEncryptionLevelHandshake QUICEncryptionLevel #44886
+pkg crypto/tls, const QUICEncryptionLevelInitial = 0 #44886
+pkg crypto/tls, const QUICEncryptionLevelInitial QUICEncryptionLevel #44886
+pkg crypto/tls, const QUICHandshakeDone = 6 #44886
+pkg crypto/tls, const QUICHandshakeDone QUICEventKind #44886
+pkg crypto/tls, const QUICNoEvent = 0 #44886
+pkg crypto/tls, const QUICNoEvent QUICEventKind #44886
+pkg crypto/tls, const QUICSetReadSecret = 1 #44886
+pkg crypto/tls, const QUICSetReadSecret QUICEventKind #44886
+pkg crypto/tls, const QUICSetWriteSecret = 2 #44886
+pkg crypto/tls, const QUICSetWriteSecret QUICEventKind #44886
+pkg crypto/tls, const QUICTransportParameters = 4 #44886
+pkg crypto/tls, const QUICTransportParameters QUICEventKind #44886
+pkg crypto/tls, const QUICTransportParametersRequired = 5 #44886
+pkg crypto/tls, const QUICTransportParametersRequired QUICEventKind #44886
+pkg crypto/tls, const QUICWriteData = 3 #44886
+pkg crypto/tls, const QUICWriteData QUICEventKind #44886
+pkg crypto/tls, func QUICClient(*QUICConfig) *QUICConn #44886
+pkg crypto/tls, func QUICServer(*QUICConfig) *QUICConn #44886
+pkg crypto/tls, method (*QUICConn) Close() error #44886
+pkg crypto/tls, method (*QUICConn) ConnectionState() ConnectionState #44886
+pkg crypto/tls, method (*QUICConn) HandleData(QUICEncryptionLevel, []uint8) error #44886
+pkg crypto/tls, method (*QUICConn) NextEvent() QUICEvent #44886
+pkg crypto/tls, method (*QUICConn) SetTransportParameters([]uint8) #44886
+pkg crypto/tls, method (*QUICConn) Start(context.Context) error #44886
+pkg crypto/tls, method (AlertError) Error() string #44886
+pkg crypto/tls, method (QUICEncryptionLevel) String() string #44886
+pkg crypto/tls, type AlertError uint8 #44886
+pkg crypto/tls, type QUICConfig struct #44886
+pkg crypto/tls, type QUICConfig struct, TLSConfig *Config #44886
+pkg crypto/tls, type QUICConn struct #44886
+pkg crypto/tls, type QUICEncryptionLevel int #44886
+pkg crypto/tls, type QUICEvent struct #44886
+pkg crypto/tls, type QUICEvent struct, Data []uint8 #44886
+pkg crypto/tls, type QUICEvent struct, Kind QUICEventKind #44886
+pkg crypto/tls, type QUICEvent struct, Level QUICEncryptionLevel #44886
+pkg crypto/tls, type QUICEvent struct, Suite uint16 #44886
+pkg crypto/tls, type QUICEventKind int #44886
index 4790b7372459e5ac1068d7aef25bff0e251d4b1e..33022cd2b4bf8a677186176fe5c57e350f13c91f 100644 (file)
@@ -6,6 +6,16 @@ package tls
 
 import "strconv"
 
+// An AlertError is a TLS alert.
+//
+// When using a QUIC transport, QUICConn methods will return an error
+// which wraps AlertError rather than sending a TLS alert.
+type AlertError uint8
+
+func (e AlertError) Error() string {
+       return alert(e).String()
+}
+
 type alert uint8
 
 const (
index 5394d64ac6c810957ef66726f7a8d2bef775ec9c..b8332e90fd233f57d83c24b34b1e7758d179e818 100644 (file)
@@ -99,6 +99,7 @@ const (
        extensionCertificateAuthorities  uint16 = 47
        extensionSignatureAlgorithmsCert uint16 = 50
        extensionKeyShare                uint16 = 51
+       extensionQUICTransportParameters uint16 = 57
        extensionRenegotiationInfo       uint16 = 0xff01
 )
 
index 847d3f8f063c6f8d9cc289d46fb6b2947eae593d..e3607c8fecaf3c0958dc687d3fc7324fb55c0172 100644 (file)
@@ -29,6 +29,7 @@ type Conn struct {
        conn        net.Conn
        isClient    bool
        handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
+       quic        *quicState                  // nil for non-QUIC connections
 
        // isHandshakeComplete is true if the connection is currently transferring
        // application data (i.e. is not currently processing a handshake).
@@ -176,7 +177,8 @@ type halfConn struct {
        nextCipher any       // next encryption state
        nextMac    hash.Hash // next MAC algorithm
 
-       trafficSecret []byte // current TLS 1.3 traffic secret
+       level         QUICEncryptionLevel // current QUIC encryption level
+       trafficSecret []byte              // current TLS 1.3 traffic secret
 }
 
 type permanentError struct {
@@ -221,8 +223,9 @@ func (hc *halfConn) changeCipherSpec() error {
        return nil
 }
 
-func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) {
+func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) {
        hc.trafficSecret = secret
+       hc.level = level
        key, iv := suite.trafficKey(secret)
        hc.cipher = suite.aead(key, iv)
        for i := range hc.seq {
@@ -613,6 +616,10 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
        }
        c.input.Reset(nil)
 
+       if c.quic != nil {
+               return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with QUIC transport"))
+       }
+
        // Read header, payload.
        if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
                // RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
@@ -702,6 +709,9 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
                return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
 
        case recordTypeAlert:
+               if c.quic != nil {
+                       return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
+               }
                if len(data) != 2 {
                        return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
                }
@@ -819,6 +829,10 @@ func (c *Conn) readFromUntil(r io.Reader, n int) error {
 
 // sendAlertLocked sends a TLS alert message.
 func (c *Conn) sendAlertLocked(err alert) error {
+       if c.quic != nil {
+               return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
+       }
+
        switch err {
        case alertNoRenegotiation, alertCloseNotify:
                c.tmp[0] = alertLevelWarning
@@ -953,6 +967,19 @@ var outBufPool = sync.Pool{
 // writeRecordLocked writes a TLS record with the given type and payload to the
 // connection and updates the record layer state.
 func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
+       if c.quic != nil {
+               if typ != recordTypeHandshake {
+                       return 0, errors.New("tls: internal error: sending non-handshake message to QUIC transport")
+               }
+               c.quicWriteCryptoData(c.out.level, data)
+               if !c.buffering {
+                       if _, err := c.flush(); err != nil {
+                               return 0, err
+                       }
+               }
+               return len(data), nil
+       }
+
        outBufPtr := outBufPool.Get().(*[]byte)
        outBuf := *outBufPtr
        defer func() {
@@ -1037,28 +1064,40 @@ func (c *Conn) writeChangeCipherRecord() error {
        return err
 }
 
+// readHandshakeBytes reads handshake data until c.hand contains at least n bytes.
+func (c *Conn) readHandshakeBytes(n int) error {
+       if c.quic != nil {
+               return c.quicReadHandshakeBytes(n)
+       }
+       for c.hand.Len() < n {
+               if err := c.readRecord(); err != nil {
+                       return err
+               }
+       }
+       return nil
+}
+
 // readHandshake reads the next handshake message from
 // the record layer. If transcript is non-nil, the message
 // is written to the passed transcriptHash.
 func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
-       for c.hand.Len() < 4 {
-               if err := c.readRecord(); err != nil {
-                       return nil, err
-               }
+       if err := c.readHandshakeBytes(4); err != nil {
+               return nil, err
        }
-
        data := c.hand.Bytes()
        n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
        if n > maxHandshake {
                c.sendAlertLocked(alertInternalError)
                return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
        }
-       for c.hand.Len() < 4+n {
-               if err := c.readRecord(); err != nil {
-                       return nil, err
-               }
+       if err := c.readHandshakeBytes(4 + n); err != nil {
+               return nil, err
        }
        data = c.hand.Next(4 + n)
+       return c.unmarshalHandshakeMessage(data, transcript)
+}
+
+func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript transcriptHash) (handshakeMessage, error) {
        var m handshakeMessage
        switch data[0] {
        case typeHelloRequest:
@@ -1249,7 +1288,6 @@ func (c *Conn) handlePostHandshakeMessage() error {
        if err != nil {
                return err
        }
-
        c.retryCount++
        if c.retryCount > maxUselessRecords {
                c.sendAlert(alertUnexpectedMessage)
@@ -1261,20 +1299,28 @@ func (c *Conn) handlePostHandshakeMessage() error {
                return c.handleNewSessionTicket(msg)
        case *keyUpdateMsg:
                return c.handleKeyUpdate(msg)
-       default:
-               c.sendAlert(alertUnexpectedMessage)
-               return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
        }
+       // The QUIC layer is supposed to treat an unexpected post-handshake CertificateRequest
+       // as a QUIC-level PROTOCOL_VIOLATION error (RFC 9001, Section 4.4). Returning an
+       // unexpected_message alert here doesn't provide it with enough information to distinguish
+       // this condition from other unexpected messages. This is probably fine.
+       c.sendAlert(alertUnexpectedMessage)
+       return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
 }
 
 func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
+       if c.quic != nil {
+               c.sendAlert(alertUnexpectedMessage)
+               return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
+       }
+
        cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
        if cipherSuite == nil {
                return c.in.setErrorLocked(c.sendAlert(alertInternalError))
        }
 
        newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
-       c.in.setTrafficSecret(cipherSuite, newSecret)
+       c.in.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
 
        if keyUpdate.updateRequested {
                c.out.Lock()
@@ -1293,7 +1339,7 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
                }
 
                newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
-               c.out.setTrafficSecret(cipherSuite, newSecret)
+               c.out.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
        }
 
        return nil
@@ -1454,12 +1500,15 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
        // this cancellation. In the former case, we need to close the connection.
        defer cancel()
 
-       // Start the "interrupter" goroutine, if this context might be canceled.
-       // (The background context cannot).
-       //
-       // The interrupter goroutine waits for the input context to be done and
-       // closes the connection if this happens before the function returns.
-       if ctx.Done() != nil {
+       if c.quic != nil {
+               c.quic.cancelc = handshakeCtx.Done()
+               c.quic.cancel = cancel
+       } else if ctx.Done() != nil {
+               // Start the "interrupter" goroutine, if this context might be canceled.
+               // (The background context cannot).
+               //
+               // The interrupter goroutine waits for the input context to be done and
+               // closes the connection if this happens before the function returns.
                done := make(chan struct{})
                interruptRes := make(chan error, 1)
                defer func() {
@@ -1510,6 +1559,30 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
                panic("tls: internal error: handshake returned an error but is marked successful")
        }
 
+       if c.quic != nil {
+               if c.handshakeErr == nil {
+                       c.quicHandshakeComplete()
+                       // Provide the 1-RTT read secret now that the handshake is complete.
+                       // The QUIC layer MUST NOT decrypt 1-RTT packets prior to completing
+                       // the handshake (RFC 9001, Section 5.7).
+                       c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret)
+               } else {
+                       var a alert
+                       c.out.Lock()
+                       if !errors.As(c.out.err, &a) {
+                               a = alertInternalError
+                       }
+                       c.out.Unlock()
+                       // Return an error which wraps both the handshake error and
+                       // any alert error we may have sent, or alertInternalError
+                       // if we didn't send an alert.
+                       // Truncate the text of the alert to 0 characters.
+                       c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a))
+               }
+               close(c.quic.blockedc)
+               close(c.quic.signalc)
+       }
+
        return c.handshakeErr
 }
 
index 63d86b9f3a7ef1b4155b45f559ac41a787f4b5cb..9f74cc4ef9723eb69d95013e7fc7a78a20a5d482 100644 (file)
@@ -71,7 +71,6 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
                vers:                         clientHelloVersion,
                compressionMethods:           []uint8{compressionNone},
                random:                       make([]byte, 32),
-               sessionId:                    make([]byte, 32),
                ocspStapling:                 true,
                scts:                         true,
                serverName:                   hostnameInSNI(config.ServerName),
@@ -114,8 +113,13 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
        // A random session ID is used to detect when the server accepted a ticket
        // and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as
        // a compatibility measure (see RFC 8446, Section 4.1.2).
-       if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil {
-               return nil, nil, errors.New("tls: short read from Rand: " + err.Error())
+       //
+       // The session ID is not set for QUIC connections (see RFC 9001, Section 8.4).
+       if c.quic == nil {
+               hello.sessionId = make([]byte, 32)
+               if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil {
+                       return nil, nil, errors.New("tls: short read from Rand: " + err.Error())
+               }
        }
 
        if hello.vers >= VersionTLS12 {
@@ -144,6 +148,17 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
                hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
        }
 
+       if c.quic != nil {
+               p, err := c.quicGetTransportParameters()
+               if err != nil {
+                       return nil, nil, err
+               }
+               if p == nil {
+                       p = []byte{}
+               }
+               hello.quicTransportParameters = p
+       }
+
        return hello, key, nil
 }
 
@@ -271,7 +286,10 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
        }
 
        // Try to resume a previously negotiated TLS session, if available.
-       cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
+       cacheKey = c.clientSessionCacheKey()
+       if cacheKey == "" {
+               return "", nil, nil, nil, nil
+       }
        session, ok := c.config.ClientSessionCache.Get(cacheKey)
        if !ok || session == nil {
                return cacheKey, nil, nil, nil, nil
@@ -722,7 +740,7 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) {
                }
        }
 
-       if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol); err != nil {
+       if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol, false); err != nil {
                c.sendAlert(alertUnsupportedExtension)
                return false, err
        }
@@ -760,8 +778,12 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) {
 
 // checkALPN ensure that the server's choice of ALPN protocol is compatible with
 // the protocols that we advertised in the Client Hello.
-func checkALPN(clientProtos []string, serverProto string) error {
+func checkALPN(clientProtos []string, serverProto string, quic bool) error {
        if serverProto == "" {
+               if quic && len(clientProtos) > 0 {
+                       // RFC 9001, Section 8.1
+                       return errors.New("tls: server did not select an ALPN protocol")
+               }
                return nil
        }
        if len(clientProtos) == 0 {
@@ -1003,11 +1025,14 @@ func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate,
 
 // clientSessionCacheKey returns a key used to cache sessionTickets that could
 // be used to resume previously negotiated TLS sessions with a server.
-func clientSessionCacheKey(serverAddr net.Addr, config *Config) string {
-       if len(config.ServerName) > 0 {
-               return config.ServerName
+func (c *Conn) clientSessionCacheKey() string {
+       if len(c.config.ServerName) > 0 {
+               return c.config.ServerName
+       }
+       if c.conn != nil {
+               return c.conn.RemoteAddr().String()
        }
-       return serverAddr.String()
+       return ""
 }
 
 // hostnameInSNI converts name into an appropriate hostname for SNI.
index 4a8661085ebf57245138022c8449c76fc8375499..15e0a748485717839bcbb066172d0fa9c1299416 100644 (file)
@@ -172,6 +172,9 @@ func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error {
 // sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
 // with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
 func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
+       if hs.c.quic != nil {
+               return nil
+       }
        if hs.sentDummyCCS {
                return nil
        }
@@ -383,10 +386,18 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
 
        clientSecret := hs.suite.deriveSecret(handshakeSecret,
                clientHandshakeTrafficLabel, hs.transcript)
-       c.out.setTrafficSecret(hs.suite, clientSecret)
+       c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
        serverSecret := hs.suite.deriveSecret(handshakeSecret,
                serverHandshakeTrafficLabel, hs.transcript)
-       c.in.setTrafficSecret(hs.suite, serverSecret)
+       c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
+
+       if c.quic != nil {
+               if c.hand.Len() != 0 {
+                       c.sendAlert(alertUnexpectedMessage)
+               }
+               c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret)
+               c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret)
+       }
 
        err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret)
        if err != nil {
@@ -419,12 +430,30 @@ func (hs *clientHandshakeStateTLS13) readServerParameters() error {
                return unexpectedMessageError(encryptedExtensions, msg)
        }
 
-       if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil {
-               c.sendAlert(alertUnsupportedExtension)
+       if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol, c.quic != nil); err != nil {
+               // RFC 8446 specifies that no_application_protocol is sent by servers, but
+               // does not specify how clients handle the selection of an incompatible protocol.
+               // RFC 9001 Section 8.1 specifies that QUIC clients send no_application_protocol
+               // in this case. Always sending no_application_protocol seems reasonable.
+               c.sendAlert(alertNoApplicationProtocol)
                return err
        }
        c.clientProtocol = encryptedExtensions.alpnProtocol
 
+       if c.quic != nil {
+               if encryptedExtensions.quicTransportParameters == nil {
+                       // RFC 9001 Section 8.2.
+                       c.sendAlert(alertMissingExtension)
+                       return errors.New("tls: server did not send a quic_transport_parameters extension")
+               }
+               c.quicSetTransportParameters(encryptedExtensions.quicTransportParameters)
+       } else {
+               if encryptedExtensions.quicTransportParameters != nil {
+                       c.sendAlert(alertUnsupportedExtension)
+                       return errors.New("tls: server sent an unexpected quic_transport_parameters extension")
+               }
+       }
+
        return nil
 }
 
@@ -552,7 +581,7 @@ func (hs *clientHandshakeStateTLS13) readServerFinished() error {
                clientApplicationTrafficLabel, hs.transcript)
        serverSecret := hs.suite.deriveSecret(hs.masterSecret,
                serverApplicationTrafficLabel, hs.transcript)
-       c.in.setTrafficSecret(hs.suite, serverSecret)
+       c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret)
 
        err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret)
        if err != nil {
@@ -648,13 +677,20 @@ func (hs *clientHandshakeStateTLS13) sendClientFinished() error {
                return err
        }
 
-       c.out.setTrafficSecret(hs.suite, hs.trafficSecret)
+       c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret)
 
        if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil {
                c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
                        resumptionLabel, hs.transcript)
        }
 
+       if c.quic != nil {
+               if c.hand.Len() != 0 {
+                       c.sendAlert(alertUnexpectedMessage)
+               }
+               c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, hs.trafficSecret)
+       }
+
        return nil
 }
 
@@ -702,8 +738,10 @@ func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
                scts:               c.scts,
        }
 
-       cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
-       c.config.ClientSessionCache.Put(cacheKey, session)
+       cacheKey := c.clientSessionCacheKey()
+       if cacheKey != "" {
+               c.config.ClientSessionCache.Put(cacheKey, session)
+       }
 
        return nil
 }
index 695aacf126a15dd29de193c5677e0ce9f9d193fc..eac01fd085df7301f904650ba680e8e72f11315c 100644 (file)
@@ -93,6 +93,7 @@ type clientHelloMsg struct {
        pskModes                         []uint8
        pskIdentities                    []pskIdentity
        pskBinders                       [][]byte
+       quicTransportParameters          []byte
 }
 
 func (m *clientHelloMsg) marshal() ([]byte, error) {
@@ -246,6 +247,13 @@ func (m *clientHelloMsg) marshal() ([]byte, error) {
                        })
                })
        }
+       if m.quicTransportParameters != nil { // marshal zero-length parameters when present
+               // RFC 9001, Section 8.2
+               exts.AddUint16(extensionQUICTransportParameters)
+               exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+                       exts.AddBytes(m.quicTransportParameters)
+               })
+       }
        if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
                // RFC 8446, Section 4.2.11
                exts.AddUint16(extensionPreSharedKey)
@@ -560,6 +568,11 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
                        if !readUint8LengthPrefixed(&extData, &m.pskModes) {
                                return false
                        }
+               case extensionQUICTransportParameters:
+                       m.quicTransportParameters = make([]byte, len(extData))
+                       if !extData.CopyBytes(m.quicTransportParameters) {
+                               return false
+                       }
                case extensionPreSharedKey:
                        // RFC 8446, Section 4.2.11
                        if !extensions.Empty() {
@@ -860,8 +873,9 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
 }
 
 type encryptedExtensionsMsg struct {
-       raw          []byte
-       alpnProtocol string
+       raw                     []byte
+       alpnProtocol            string
+       quicTransportParameters []byte
 }
 
 func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
@@ -883,6 +897,13 @@ func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
                                        })
                                })
                        }
+                       if m.quicTransportParameters != nil { // marshal zero-length parameters when present
+                               // draft-ietf-quic-tls-32, Section 8.2
+                               b.AddUint16(extensionQUICTransportParameters)
+                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       b.AddBytes(m.quicTransportParameters)
+                               })
+                       }
                })
        })
 
@@ -921,6 +942,11 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
                                return false
                        }
                        m.alpnProtocol = string(proto)
+               case extensionQUICTransportParameters:
+                       m.quicTransportParameters = make([]byte, len(extData))
+                       if !extData.CopyBytes(m.quicTransportParameters) {
+                               return false
+                       }
                default:
                        // Ignore unknown extensions.
                        continue
index 206e2fb024febca5adfe3c7a58b9815804dbdb6c..1ef6c432ffad51192778c760a89b8e02abf9921c 100644 (file)
@@ -197,6 +197,9 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
                m.pskIdentities = append(m.pskIdentities, psk)
                m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
        }
+       if rand.Intn(10) > 5 {
+               m.quicTransportParameters = randomBytes(rand.Intn(500), rand)
+       }
        if rand.Intn(10) > 5 {
                m.earlyData = true
        }
index a17ba2fe27cebb17fe34f43ab694a51454c7d6ea..450c5f77147794cdb1a51bce8254d6b1796ee482 100644 (file)
@@ -218,7 +218,7 @@ func (hs *serverHandshakeState) processClientHello() error {
                c.serverName = hs.clientHello.serverName
        }
 
-       selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols)
+       selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, false)
        if err != nil {
                c.sendAlert(alertNoApplicationProtocol)
                return err
@@ -279,8 +279,12 @@ func (hs *serverHandshakeState) processClientHello() error {
 // negotiateALPN picks a shared ALPN protocol that both sides support in server
 // preference order. If ALPN is not configured or the peer doesn't support it,
 // it returns "" and no error.
-func negotiateALPN(serverProtos, clientProtos []string) (string, error) {
+func negotiateALPN(serverProtos, clientProtos []string, quic bool) (string, error) {
        if len(serverProtos) == 0 || len(clientProtos) == 0 {
+               if quic && len(serverProtos) != 0 {
+                       // RFC 9001, Section 8.1
+                       return "", fmt.Errorf("tls: client did not request an application protocol")
+               }
                return "", nil
        }
        var http11fallback bool
index b7b568cd84ac80a8108ae2fa461c08555db2d98b..69ebe1c7d59659564820ab9217047e50571a1776 100644 (file)
@@ -226,6 +226,20 @@ GroupSelection:
                return errors.New("tls: invalid client key share")
        }
 
+       if c.quic != nil {
+               if hs.clientHello.quicTransportParameters == nil {
+                       // RFC 9001 Section 8.2.
+                       c.sendAlert(alertMissingExtension)
+                       return errors.New("tls: client did not send a quic_transport_parameters extension")
+               }
+               c.quicSetTransportParameters(hs.clientHello.quicTransportParameters)
+       } else {
+               if hs.clientHello.quicTransportParameters != nil {
+                       c.sendAlert(alertUnsupportedExtension)
+                       return errors.New("tls: client sent an unexpected quic_transport_parameters extension")
+               }
+       }
+
        c.serverName = hs.clientHello.serverName
        return nil
 }
@@ -397,6 +411,9 @@ func (hs *serverHandshakeStateTLS13) pickCertificate() error {
 // sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
 // with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
 func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
+       if hs.c.quic != nil {
+               return nil
+       }
        if hs.sentDummyCCS {
                return nil
        }
@@ -548,10 +565,18 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
 
        clientSecret := hs.suite.deriveSecret(hs.handshakeSecret,
                clientHandshakeTrafficLabel, hs.transcript)
-       c.in.setTrafficSecret(hs.suite, clientSecret)
+       c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
        serverSecret := hs.suite.deriveSecret(hs.handshakeSecret,
                serverHandshakeTrafficLabel, hs.transcript)
-       c.out.setTrafficSecret(hs.suite, serverSecret)
+       c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
+
+       if c.quic != nil {
+               if c.hand.Len() != 0 {
+                       c.sendAlert(alertUnexpectedMessage)
+               }
+               c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret)
+               c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret)
+       }
 
        err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret)
        if err != nil {
@@ -566,7 +591,7 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
 
        encryptedExtensions := new(encryptedExtensionsMsg)
 
-       selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols)
+       selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil)
        if err != nil {
                c.sendAlert(alertNoApplicationProtocol)
                return err
@@ -574,6 +599,14 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
        encryptedExtensions.alpnProtocol = selectedProto
        c.clientProtocol = selectedProto
 
+       if c.quic != nil {
+               p, err := c.quicGetTransportParameters()
+               if err != nil {
+                       return err
+               }
+               encryptedExtensions.quicTransportParameters = p
+       }
+
        if _, err := hs.c.writeHandshakeRecord(encryptedExtensions, hs.transcript); err != nil {
                return err
        }
@@ -672,7 +705,15 @@ func (hs *serverHandshakeStateTLS13) sendServerFinished() error {
                clientApplicationTrafficLabel, hs.transcript)
        serverSecret := hs.suite.deriveSecret(hs.masterSecret,
                serverApplicationTrafficLabel, hs.transcript)
-       c.out.setTrafficSecret(hs.suite, serverSecret)
+       c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret)
+
+       if c.quic != nil {
+               if c.hand.Len() != 0 {
+                       // TODO: Handle this in setTrafficSecret?
+                       c.sendAlert(alertUnexpectedMessage)
+               }
+               c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, serverSecret)
+       }
 
        err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret)
        if err != nil {
@@ -887,7 +928,7 @@ func (hs *serverHandshakeStateTLS13) readClientFinished() error {
                return errors.New("tls: invalid client finished hash")
        }
 
-       c.in.setTrafficSecret(hs.suite, hs.trafficSecret)
+       c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret)
 
        return nil
 }
diff --git a/src/crypto/tls/quic.go b/src/crypto/tls/quic.go
new file mode 100644 (file)
index 0000000..a59b893
--- /dev/null
@@ -0,0 +1,376 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tls
+
+import (
+       "context"
+       "errors"
+       "fmt"
+)
+
+// QUICEncryptionLevel represents a QUIC encryption level used to transmit
+// handshake messages.
+type QUICEncryptionLevel int
+
+const (
+       QUICEncryptionLevelInitial = QUICEncryptionLevel(iota)
+       QUICEncryptionLevelHandshake
+       QUICEncryptionLevelApplication
+)
+
+func (l QUICEncryptionLevel) String() string {
+       switch l {
+       case QUICEncryptionLevelInitial:
+               return "Initial"
+       case QUICEncryptionLevelHandshake:
+               return "Handshake"
+       case QUICEncryptionLevelApplication:
+               return "Application"
+       default:
+               return fmt.Sprintf("QUICEncryptionLevel(%v)", int(l))
+       }
+}
+
+// A QUICConn represents a connection which uses a QUIC implementation as the underlying
+// transport as described in RFC 9001.
+//
+// Methods of QUICConn are not safe for concurrent use.
+type QUICConn struct {
+       conn *Conn
+}
+
+// A QUICConfig configures a QUICConn.
+type QUICConfig struct {
+       TLSConfig *Config
+}
+
+// A QUICEventKind is a type of operation on a QUIC connection.
+type QUICEventKind int
+
+const (
+       // QUICNoEvent indicates that there are no events available.
+       QUICNoEvent QUICEventKind = iota
+
+       // QUICSetReadSecret and QUICSetWriteSecret provide the read and write
+       // secrets for a given encryption level.
+       // QUICEvent.Level, QUICEvent.Data, and QUICEvent.Suite are set.
+       //
+       // Secrets for the Initial encryption level are derived from the initial
+       // destination connection ID, and are not provided by the QUICConn.
+       QUICSetReadSecret
+       QUICSetWriteSecret
+
+       // QUICWriteData provides data to send to the peer in CRYPTO frames.
+       // QUICEvent.Data is set.
+       QUICWriteData
+
+       // QUICTransportParameters provides the peer's QUIC transport parameters.
+       // QUICEvent.Data is set.
+       QUICTransportParameters
+
+       // QUICTransportParametersRequired indicates that the caller must provide
+       // QUIC transport parameters to send to the peer. The caller should set
+       // the transport parameters with QUICConn.SetTransportParameters and call
+       // QUICConn.NextEvent again.
+       //
+       // If transport parameters are set before calling QUICConn.Start, the
+       // connection will never generate a QUICTransportParametersRequired event.
+       QUICTransportParametersRequired
+
+       // QUICHandshakeDone indicates that the TLS handshake has completed.
+       QUICHandshakeDone
+)
+
+// A QUICEvent is an event occurring on a QUIC connection.
+//
+// The type of event is specified by the Kind field.
+// The contents of the other fields are kind-specific.
+type QUICEvent struct {
+       Kind QUICEventKind
+
+       // Set for QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData.
+       Level QUICEncryptionLevel
+
+       // Set for QUICTransportParameters, QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData.
+       // The contents are owned by crypto/tls, and are valid until the next NextEvent call.
+       Data []byte
+
+       // Set for QUICSetReadSecret and QUICSetWriteSecret.
+       Suite uint16
+}
+
+type quicState struct {
+       events    []QUICEvent
+       nextEvent int
+
+       // eventArr is a statically allocated event array, large enough to handle
+       // the usual maximum number of events resulting from a single call:
+       // transport parameters, Initial data, Handshake write and read secrets,
+       // Handshake data, Application write secret, Application data.
+       eventArr [7]QUICEvent
+
+       started  bool
+       signalc  chan struct{}   // handshake data is available to be read
+       blockedc chan struct{}   // handshake is waiting for data, closed when done
+       cancelc  <-chan struct{} // handshake has been canceled
+       cancel   context.CancelFunc
+
+       // readbuf is shared between HandleData and the handshake goroutine.
+       // HandshakeCryptoData passes ownership to the handshake goroutine by
+       // reading from signalc, and reclaims ownership by reading from blockedc.
+       readbuf []byte
+
+       transportParams []byte // to send to the peer
+}
+
+// QUICClient returns a new TLS client side connection using QUICTransport as the
+// underlying transport. The config cannot be nil.
+//
+// The config's MinVersion must be at least TLS 1.3.
+func QUICClient(config *QUICConfig) *QUICConn {
+       return newQUICConn(Client(nil, config.TLSConfig))
+}
+
+// QUICServer returns a new TLS server side connection using QUICTransport as the
+// underlying transport. The config cannot be nil.
+//
+// The config's MinVersion must be at least TLS 1.3.
+func QUICServer(config *QUICConfig) *QUICConn {
+       return newQUICConn(Server(nil, config.TLSConfig))
+}
+
+func newQUICConn(conn *Conn) *QUICConn {
+       conn.quic = &quicState{
+               signalc:  make(chan struct{}),
+               blockedc: make(chan struct{}),
+       }
+       conn.quic.events = conn.quic.eventArr[:0]
+       return &QUICConn{
+               conn: conn,
+       }
+}
+
+// Start starts the client or server handshake protocol.
+// It may produce connection events, which may be read with NextEvent.
+//
+// Start must be called at most once.
+func (q *QUICConn) Start(ctx context.Context) error {
+       if q.conn.quic.started {
+               return quicError(errors.New("tls: Start called more than once"))
+       }
+       q.conn.quic.started = true
+       if q.conn.config.MinVersion < VersionTLS13 {
+               return quicError(errors.New("tls: Config MinVersion must be at least TLS 1.13"))
+       }
+       go q.conn.HandshakeContext(ctx)
+       if _, ok := <-q.conn.quic.blockedc; !ok {
+               return q.conn.handshakeErr
+       }
+       return nil
+}
+
+// NextEvent returns the next event occurring on the connection.
+// It returns an event with a Kind of QUICNoEvent when no events are available.
+func (q *QUICConn) NextEvent() QUICEvent {
+       qs := q.conn.quic
+       if last := qs.nextEvent - 1; last >= 0 && len(qs.events[last].Data) > 0 {
+               // Write over some of the previous event's data,
+               // to catch callers erroniously retaining it.
+               qs.events[last].Data[0] = 0
+       }
+       if qs.nextEvent >= len(qs.events) {
+               qs.events = qs.events[:0]
+               qs.nextEvent = 0
+               return QUICEvent{Kind: QUICNoEvent}
+       }
+       e := qs.events[qs.nextEvent]
+       qs.events[qs.nextEvent] = QUICEvent{} // zero out references to data
+       qs.nextEvent++
+       return e
+}
+
+// Close closes the connection and stops any in-progress handshake.
+func (q *QUICConn) Close() error {
+       if q.conn.quic.cancel == nil {
+               return nil // never started
+       }
+       q.conn.quic.cancel()
+       for range q.conn.quic.blockedc {
+               // Wait for the handshake goroutine to return.
+       }
+       return q.conn.handshakeErr
+}
+
+// HandleData handles handshake bytes received from the peer.
+// It may produce connection events, which may be read with NextEvent.
+func (q *QUICConn) HandleData(level QUICEncryptionLevel, data []byte) error {
+       c := q.conn
+       if c.in.level != level {
+               return quicError(c.in.setErrorLocked(errors.New("tls: handshake data received at wrong level")))
+       }
+       c.quic.readbuf = data
+       <-c.quic.signalc
+       _, ok := <-c.quic.blockedc
+       if ok {
+               // The handshake goroutine is waiting for more data.
+               return nil
+       }
+       // The handshake goroutine has exited.
+       c.hand.Write(c.quic.readbuf)
+       c.quic.readbuf = nil
+       for q.conn.hand.Len() >= 4 && q.conn.handshakeErr == nil {
+               b := q.conn.hand.Bytes()
+               n := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
+               if 4+n < len(b) {
+                       return nil
+               }
+               if err := q.conn.handlePostHandshakeMessage(); err != nil {
+                       return quicError(err)
+               }
+       }
+       if q.conn.handshakeErr != nil {
+               return quicError(q.conn.handshakeErr)
+       }
+       return nil
+}
+
+// ConnectionState returns basic TLS details about the connection.
+func (q *QUICConn) ConnectionState() ConnectionState {
+       return q.conn.ConnectionState()
+}
+
+// SetTransportParameters sets the transport parameters to send to the peer.
+//
+// Server connections may delay setting the transport parameters until after
+// receiving the client's transport parameters. See QUICTransportParametersRequired.
+func (q *QUICConn) SetTransportParameters(params []byte) {
+       if params == nil {
+               params = []byte{}
+       }
+       q.conn.quic.transportParams = params
+       if q.conn.quic.started {
+               <-q.conn.quic.signalc
+               <-q.conn.quic.blockedc
+       }
+}
+
+// quicError ensures err is an AlertError.
+// If err is not already, quicError wraps it with alertInternalError.
+func quicError(err error) error {
+       if err == nil {
+               return nil
+       }
+       var ae AlertError
+       if errors.As(err, &ae) {
+               return err
+       }
+       var a alert
+       if !errors.As(err, &a) {
+               a = alertInternalError
+       }
+       // Return an error wrapping the original error and an AlertError.
+       // Truncate the text of the alert to 0 characters.
+       return fmt.Errorf("%w%.0w", err, AlertError(a))
+}
+
+func (c *Conn) quicReadHandshakeBytes(n int) error {
+       for c.hand.Len() < n {
+               if err := c.quicWaitForSignal(); err != nil {
+                       return err
+               }
+       }
+       return nil
+}
+
+func (c *Conn) quicSetReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
+       c.quic.events = append(c.quic.events, QUICEvent{
+               Kind:  QUICSetReadSecret,
+               Level: level,
+               Suite: suite,
+               Data:  secret,
+       })
+}
+
+func (c *Conn) quicSetWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
+       c.quic.events = append(c.quic.events, QUICEvent{
+               Kind:  QUICSetWriteSecret,
+               Level: level,
+               Suite: suite,
+               Data:  secret,
+       })
+}
+
+func (c *Conn) quicWriteCryptoData(level QUICEncryptionLevel, data []byte) {
+       var last *QUICEvent
+       if len(c.quic.events) > 0 {
+               last = &c.quic.events[len(c.quic.events)-1]
+       }
+       if last == nil || last.Kind != QUICWriteData || last.Level != level {
+               c.quic.events = append(c.quic.events, QUICEvent{
+                       Kind:  QUICWriteData,
+                       Level: level,
+               })
+               last = &c.quic.events[len(c.quic.events)-1]
+       }
+       last.Data = append(last.Data, data...)
+}
+
+func (c *Conn) quicSetTransportParameters(params []byte) {
+       c.quic.events = append(c.quic.events, QUICEvent{
+               Kind: QUICTransportParameters,
+               Data: params,
+       })
+}
+
+func (c *Conn) quicGetTransportParameters() ([]byte, error) {
+       if c.quic.transportParams == nil {
+               c.quic.events = append(c.quic.events, QUICEvent{
+                       Kind: QUICTransportParametersRequired,
+               })
+       }
+       for c.quic.transportParams == nil {
+               if err := c.quicWaitForSignal(); err != nil {
+                       return nil, err
+               }
+       }
+       return c.quic.transportParams, nil
+}
+
+func (c *Conn) quicHandshakeComplete() {
+       c.quic.events = append(c.quic.events, QUICEvent{
+               Kind: QUICHandshakeDone,
+       })
+}
+
+// quicWaitForSignal notifies the QUICConn that handshake progress is blocked,
+// and waits for a signal that the handshake should proceed.
+//
+// The handshake may become blocked waiting for handshake bytes
+// or for the user to provide transport parameters.
+func (c *Conn) quicWaitForSignal() error {
+       // Drop the handshake mutex while blocked to allow the user
+       // to call ConnectionState before the handshake completes.
+       c.handshakeMutex.Unlock()
+       defer c.handshakeMutex.Lock()
+       // Send on blockedc to notify the QUICConn that the handshake is blocked.
+       // Exported methods of QUICConn wait for the handshake to become blocked
+       // before returning to the user.
+       select {
+       case c.quic.blockedc <- struct{}{}:
+       case <-c.quic.cancelc:
+               return c.sendAlertLocked(alertCloseNotify)
+       }
+       // The QUICConn reads from signalc to notify us that the handshake may
+       // be able to proceed. (The QUICConn reads, because we close signalc to
+       // indicate that the handshake has completed.)
+       select {
+       case c.quic.signalc <- struct{}{}:
+               c.hand.Write(c.quic.readbuf)
+               c.quic.readbuf = nil
+       case <-c.quic.cancelc:
+               return c.sendAlertLocked(alertCloseNotify)
+       }
+       return nil
+}
diff --git a/src/crypto/tls/quic_test.go b/src/crypto/tls/quic_test.go
new file mode 100644 (file)
index 0000000..58054de
--- /dev/null
@@ -0,0 +1,430 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tls
+
+import (
+       "context"
+       "errors"
+       "reflect"
+       "testing"
+)
+
+type testQUICConn struct {
+       t           *testing.T
+       conn        *QUICConn
+       readSecret  map[QUICEncryptionLevel]suiteSecret
+       writeSecret map[QUICEncryptionLevel]suiteSecret
+       gotParams   []byte
+       complete    bool
+}
+
+func newTestQUICClient(t *testing.T, config *Config) *testQUICConn {
+       q := &testQUICConn{t: t}
+       q.conn = QUICClient(&QUICConfig{
+               TLSConfig: config,
+       })
+       t.Cleanup(func() {
+               q.conn.Close()
+       })
+       return q
+}
+
+func newTestQUICServer(t *testing.T, config *Config) *testQUICConn {
+       q := &testQUICConn{t: t}
+       q.conn = QUICServer(&QUICConfig{
+               TLSConfig: config,
+       })
+       t.Cleanup(func() {
+               q.conn.Close()
+       })
+       return q
+}
+
+type suiteSecret struct {
+       suite  uint16
+       secret []byte
+}
+
+func (q *testQUICConn) setReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
+       if _, ok := q.writeSecret[level]; !ok {
+               q.t.Errorf("SetReadSecret for level %v called before SetWriteSecret", level)
+       }
+       if level == QUICEncryptionLevelApplication && !q.complete {
+               q.t.Errorf("SetReadSecret for level %v called before HandshakeComplete", level)
+       }
+       if _, ok := q.readSecret[level]; ok {
+               q.t.Errorf("SetReadSecret for level %v called twice", level)
+       }
+       if q.readSecret == nil {
+               q.readSecret = map[QUICEncryptionLevel]suiteSecret{}
+       }
+       switch level {
+       case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication:
+               q.readSecret[level] = suiteSecret{suite, secret}
+       default:
+               q.t.Errorf("SetReadSecret for unexpected level %v", level)
+       }
+}
+
+func (q *testQUICConn) setWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
+       if _, ok := q.writeSecret[level]; ok {
+               q.t.Errorf("SetWriteSecret for level %v called twice", level)
+       }
+       if q.writeSecret == nil {
+               q.writeSecret = map[QUICEncryptionLevel]suiteSecret{}
+       }
+       switch level {
+       case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication:
+               q.writeSecret[level] = suiteSecret{suite, secret}
+       default:
+               q.t.Errorf("SetWriteSecret for unexpected level %v", level)
+       }
+}
+
+var errTransportParametersRequired = errors.New("transport parameters required")
+
+func runTestQUICConnection(ctx context.Context, a, b *testQUICConn, onHandleCryptoData func()) error {
+       for _, c := range []*testQUICConn{a, b} {
+               if !c.conn.conn.quic.started {
+                       if err := c.conn.Start(ctx); err != nil {
+                               return err
+                       }
+               }
+       }
+       idleCount := 0
+       for {
+               e := a.conn.NextEvent()
+               switch e.Kind {
+               case QUICNoEvent:
+                       idleCount++
+                       if idleCount == 2 {
+                               if !a.complete || !b.complete {
+                                       return errors.New("handshake incomplete")
+                               }
+                               return nil
+                       }
+                       a, b = b, a
+               case QUICSetReadSecret:
+                       a.setReadSecret(e.Level, e.Suite, e.Data)
+               case QUICSetWriteSecret:
+                       a.setWriteSecret(e.Level, e.Suite, e.Data)
+               case QUICWriteData:
+                       if err := b.conn.HandleData(e.Level, e.Data); err != nil {
+                               return err
+                       }
+               case QUICTransportParameters:
+                       a.gotParams = e.Data
+                       if a.gotParams == nil {
+                               a.gotParams = []byte{}
+                       }
+               case QUICTransportParametersRequired:
+                       return errTransportParametersRequired
+               case QUICHandshakeDone:
+                       a.complete = true
+               }
+               if e.Kind != QUICNoEvent {
+                       idleCount = 0
+               }
+       }
+}
+
+func TestQUICConnection(t *testing.T) {
+       config := testConfig.Clone()
+       config.MinVersion = VersionTLS13
+
+       cli := newTestQUICClient(t, config)
+       cli.conn.SetTransportParameters(nil)
+
+       srv := newTestQUICServer(t, config)
+       srv.conn.SetTransportParameters(nil)
+
+       if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
+               t.Fatalf("error during connection handshake: %v", err)
+       }
+
+       if _, ok := cli.readSecret[QUICEncryptionLevelHandshake]; !ok {
+               t.Errorf("client has no Handshake secret")
+       }
+       if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; !ok {
+               t.Errorf("client has no Application secret")
+       }
+       if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; !ok {
+               t.Errorf("server has no Handshake secret")
+       }
+       if _, ok := srv.readSecret[QUICEncryptionLevelApplication]; !ok {
+               t.Errorf("server has no Application secret")
+       }
+       for _, level := range []QUICEncryptionLevel{QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication} {
+               if _, ok := cli.readSecret[level]; !ok {
+                       t.Errorf("client has no %v read secret", level)
+               }
+               if _, ok := srv.readSecret[level]; !ok {
+                       t.Errorf("server has no %v read secret", level)
+               }
+               if !reflect.DeepEqual(cli.readSecret[level], srv.writeSecret[level]) {
+                       t.Errorf("client read secret does not match server write secret for level %v", level)
+               }
+               if !reflect.DeepEqual(cli.writeSecret[level], srv.readSecret[level]) {
+                       t.Errorf("client write secret does not match server read secret for level %v", level)
+               }
+       }
+}
+
+func TestQUICSessionResumption(t *testing.T) {
+       clientConfig := testConfig.Clone()
+       clientConfig.MinVersion = VersionTLS13
+       clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
+       clientConfig.ServerName = "example.go.dev"
+
+       serverConfig := testConfig.Clone()
+       serverConfig.MinVersion = VersionTLS13
+
+       cli := newTestQUICClient(t, clientConfig)
+       cli.conn.SetTransportParameters(nil)
+       srv := newTestQUICServer(t, serverConfig)
+       srv.conn.SetTransportParameters(nil)
+       if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
+               t.Fatalf("error during first connection handshake: %v", err)
+       }
+       if cli.conn.ConnectionState().DidResume {
+               t.Errorf("first connection unexpectedly used session resumption")
+       }
+
+       cli2 := newTestQUICClient(t, clientConfig)
+       cli2.conn.SetTransportParameters(nil)
+       srv2 := newTestQUICServer(t, serverConfig)
+       srv2.conn.SetTransportParameters(nil)
+       if err := runTestQUICConnection(context.Background(), cli2, srv2, nil); err != nil {
+               t.Fatalf("error during second connection handshake: %v", err)
+       }
+       if !cli2.conn.ConnectionState().DidResume {
+               t.Errorf("second connection did not use session resumption")
+       }
+}
+
+func TestQUICPostHandshakeClientAuthentication(t *testing.T) {
+       // RFC 9001, Section 4.4.
+       config := testConfig.Clone()
+       config.MinVersion = VersionTLS13
+       cli := newTestQUICClient(t, config)
+       cli.conn.SetTransportParameters(nil)
+       srv := newTestQUICServer(t, config)
+       srv.conn.SetTransportParameters(nil)
+       if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
+               t.Fatalf("error during connection handshake: %v", err)
+       }
+
+       certReq := new(certificateRequestMsgTLS13)
+       certReq.ocspStapling = true
+       certReq.scts = true
+       certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
+       certReqBytes, err := certReq.marshal()
+       if err != nil {
+               t.Fatal(err)
+       }
+       if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{
+               byte(typeCertificateRequest),
+               byte(0), byte(0), byte(len(certReqBytes)),
+       }, certReqBytes...)); err == nil {
+               t.Fatalf("post-handshake authentication request: got no error, want one")
+       }
+}
+
+func TestQUICPostHandshakeKeyUpdate(t *testing.T) {
+       // RFC 9001, Section 6.
+       config := testConfig.Clone()
+       config.MinVersion = VersionTLS13
+       cli := newTestQUICClient(t, config)
+       cli.conn.SetTransportParameters(nil)
+       srv := newTestQUICServer(t, config)
+       srv.conn.SetTransportParameters(nil)
+       if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
+               t.Fatalf("error during connection handshake: %v", err)
+       }
+
+       keyUpdate := new(keyUpdateMsg)
+       keyUpdateBytes, err := keyUpdate.marshal()
+       if err != nil {
+               t.Fatal(err)
+       }
+       if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{
+               byte(typeKeyUpdate),
+               byte(0), byte(0), byte(len(keyUpdateBytes)),
+       }, keyUpdateBytes...)); !errors.Is(err, alertUnexpectedMessage) {
+               t.Fatalf("key update request: got error %v, want alertUnexpectedMessage", err)
+       }
+}
+
+func TestQUICHandshakeError(t *testing.T) {
+       clientConfig := testConfig.Clone()
+       clientConfig.MinVersion = VersionTLS13
+       clientConfig.InsecureSkipVerify = false
+       clientConfig.ServerName = "name"
+
+       serverConfig := testConfig.Clone()
+       serverConfig.MinVersion = VersionTLS13
+
+       cli := newTestQUICClient(t, clientConfig)
+       cli.conn.SetTransportParameters(nil)
+       srv := newTestQUICServer(t, serverConfig)
+       srv.conn.SetTransportParameters(nil)
+       err := runTestQUICConnection(context.Background(), cli, srv, nil)
+       if !errors.Is(err, AlertError(alertBadCertificate)) {
+               t.Errorf("connection handshake terminated with error %q, want alertBadCertificate", err)
+       }
+       var e *CertificateVerificationError
+       if !errors.As(err, &e) {
+               t.Errorf("connection handshake terminated with error %q, want CertificateVerificationError", err)
+       }
+}
+
+// Test that QUICConn.ConnectionState can be used during the handshake,
+// and that it reports the application protocol as soon as it has been
+// negotiated.
+func TestQUICConnectionState(t *testing.T) {
+       config := testConfig.Clone()
+       config.MinVersion = VersionTLS13
+       config.NextProtos = []string{"h3"}
+       cli := newTestQUICClient(t, config)
+       cli.conn.SetTransportParameters(nil)
+       srv := newTestQUICServer(t, config)
+       srv.conn.SetTransportParameters(nil)
+       onHandleCryptoData := func() {
+               cliCS := cli.conn.ConnectionState()
+               cliWantALPN := ""
+               if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; ok {
+                       cliWantALPN = "h3"
+               }
+               if want, got := cliCS.NegotiatedProtocol, cliWantALPN; want != got {
+                       t.Errorf("cli.ConnectionState().NegotiatedProtocol = %q, want %q", want, got)
+               }
+
+               srvCS := srv.conn.ConnectionState()
+               srvWantALPN := ""
+               if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; ok {
+                       srvWantALPN = "h3"
+               }
+               if want, got := srvCS.NegotiatedProtocol, srvWantALPN; want != got {
+                       t.Errorf("srv.ConnectionState().NegotiatedProtocol = %q, want %q", want, got)
+               }
+       }
+       if err := runTestQUICConnection(context.Background(), cli, srv, onHandleCryptoData); err != nil {
+               t.Fatalf("error during connection handshake: %v", err)
+       }
+}
+
+func TestQUICStartContextPropagation(t *testing.T) {
+       const key = "key"
+       const value = "value"
+       ctx := context.WithValue(context.Background(), key, value)
+       config := testConfig.Clone()
+       config.MinVersion = VersionTLS13
+       calls := 0
+       config.GetConfigForClient = func(info *ClientHelloInfo) (*Config, error) {
+               calls++
+               got, _ := info.Context().Value(key).(string)
+               if got != value {
+                       t.Errorf("GetConfigForClient context key %q has value %q, want %q", key, got, value)
+               }
+               return nil, nil
+       }
+       cli := newTestQUICClient(t, config)
+       cli.conn.SetTransportParameters(nil)
+       srv := newTestQUICServer(t, config)
+       srv.conn.SetTransportParameters(nil)
+       if err := runTestQUICConnection(ctx, cli, srv, nil); err != nil {
+               t.Fatalf("error during connection handshake: %v", err)
+       }
+       if calls != 1 {
+               t.Errorf("GetConfigForClient called %v times, want 1", calls)
+       }
+}
+
+func TestQUICDelayedTransportParameters(t *testing.T) {
+       clientConfig := testConfig.Clone()
+       clientConfig.MinVersion = VersionTLS13
+       clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
+       clientConfig.ServerName = "example.go.dev"
+
+       serverConfig := testConfig.Clone()
+       serverConfig.MinVersion = VersionTLS13
+
+       cliParams := "client params"
+       srvParams := "server params"
+
+       cli := newTestQUICClient(t, clientConfig)
+       srv := newTestQUICServer(t, serverConfig)
+       if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired {
+               t.Fatalf("handshake with no client parameters: %v; want errTransportParametersRequired", err)
+       }
+       cli.conn.SetTransportParameters([]byte(cliParams))
+       if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired {
+               t.Fatalf("handshake with no server parameters: %v; want errTransportParametersRequired", err)
+       }
+       srv.conn.SetTransportParameters([]byte(srvParams))
+       if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
+               t.Fatalf("error during connection handshake: %v", err)
+       }
+
+       if got, want := string(cli.gotParams), srvParams; got != want {
+               t.Errorf("client got transport params: %q, want %q", got, want)
+       }
+       if got, want := string(srv.gotParams), cliParams; got != want {
+               t.Errorf("server got transport params: %q, want %q", got, want)
+       }
+}
+
+func TestQUICEmptyTransportParameters(t *testing.T) {
+       config := testConfig.Clone()
+       config.MinVersion = VersionTLS13
+
+       cli := newTestQUICClient(t, config)
+       cli.conn.SetTransportParameters(nil)
+       srv := newTestQUICServer(t, config)
+       srv.conn.SetTransportParameters(nil)
+       if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
+               t.Fatalf("error during connection handshake: %v", err)
+       }
+
+       if cli.gotParams == nil {
+               t.Errorf("client did not get transport params")
+       }
+       if srv.gotParams == nil {
+               t.Errorf("server did not get transport params")
+       }
+       if len(cli.gotParams) != 0 {
+               t.Errorf("client got transport params: %v, want empty", cli.gotParams)
+       }
+       if len(srv.gotParams) != 0 {
+               t.Errorf("server got transport params: %v, want empty", srv.gotParams)
+       }
+}
+
+func TestQUICCanceledWaitingForData(t *testing.T) {
+       config := testConfig.Clone()
+       config.MinVersion = VersionTLS13
+       cli := newTestQUICClient(t, config)
+       cli.conn.SetTransportParameters(nil)
+       cli.conn.Start(context.Background())
+       for cli.conn.NextEvent().Kind != QUICNoEvent {
+       }
+       err := cli.conn.Close()
+       if !errors.Is(err, alertCloseNotify) {
+               t.Errorf("conn.Close() = %v, want alertCloseNotify", err)
+       }
+}
+
+func TestQUICCanceledWaitingForTransportParams(t *testing.T) {
+       config := testConfig.Clone()
+       config.MinVersion = VersionTLS13
+       cli := newTestQUICClient(t, config)
+       cli.conn.Start(context.Background())
+       for cli.conn.NextEvent().Kind != QUICTransportParametersRequired {
+       }
+       err := cli.conn.Close()
+       if !errors.Is(err, alertCloseNotify) {
+               t.Errorf("conn.Close() = %v, want alertCloseNotify", err)
+       }
+}