Add a QUICConn type for use by QUIC implementations.
A QUICConn provides unencrypted handshake bytes and connection
secrets to the QUIC layer, and receives handshake bytes.
For #44886
Change-Id: I859dda4cc6d466a1df2fb863a69d3a2a069110d5
Reviewed-on: https://go-review.googlesource.com/c/go/+/493655
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
Reviewed-by: Marten Seemann <martenseemann@gmail.com>
--- /dev/null
+pkg crypto/tls, const QUICEncryptionLevelApplication = 2 #44886
+pkg crypto/tls, const QUICEncryptionLevelApplication QUICEncryptionLevel #44886
+pkg crypto/tls, const QUICEncryptionLevelHandshake = 1 #44886
+pkg crypto/tls, const QUICEncryptionLevelHandshake QUICEncryptionLevel #44886
+pkg crypto/tls, const QUICEncryptionLevelInitial = 0 #44886
+pkg crypto/tls, const QUICEncryptionLevelInitial QUICEncryptionLevel #44886
+pkg crypto/tls, const QUICHandshakeDone = 6 #44886
+pkg crypto/tls, const QUICHandshakeDone QUICEventKind #44886
+pkg crypto/tls, const QUICNoEvent = 0 #44886
+pkg crypto/tls, const QUICNoEvent QUICEventKind #44886
+pkg crypto/tls, const QUICSetReadSecret = 1 #44886
+pkg crypto/tls, const QUICSetReadSecret QUICEventKind #44886
+pkg crypto/tls, const QUICSetWriteSecret = 2 #44886
+pkg crypto/tls, const QUICSetWriteSecret QUICEventKind #44886
+pkg crypto/tls, const QUICTransportParameters = 4 #44886
+pkg crypto/tls, const QUICTransportParameters QUICEventKind #44886
+pkg crypto/tls, const QUICTransportParametersRequired = 5 #44886
+pkg crypto/tls, const QUICTransportParametersRequired QUICEventKind #44886
+pkg crypto/tls, const QUICWriteData = 3 #44886
+pkg crypto/tls, const QUICWriteData QUICEventKind #44886
+pkg crypto/tls, func QUICClient(*QUICConfig) *QUICConn #44886
+pkg crypto/tls, func QUICServer(*QUICConfig) *QUICConn #44886
+pkg crypto/tls, method (*QUICConn) Close() error #44886
+pkg crypto/tls, method (*QUICConn) ConnectionState() ConnectionState #44886
+pkg crypto/tls, method (*QUICConn) HandleData(QUICEncryptionLevel, []uint8) error #44886
+pkg crypto/tls, method (*QUICConn) NextEvent() QUICEvent #44886
+pkg crypto/tls, method (*QUICConn) SetTransportParameters([]uint8) #44886
+pkg crypto/tls, method (*QUICConn) Start(context.Context) error #44886
+pkg crypto/tls, method (AlertError) Error() string #44886
+pkg crypto/tls, method (QUICEncryptionLevel) String() string #44886
+pkg crypto/tls, type AlertError uint8 #44886
+pkg crypto/tls, type QUICConfig struct #44886
+pkg crypto/tls, type QUICConfig struct, TLSConfig *Config #44886
+pkg crypto/tls, type QUICConn struct #44886
+pkg crypto/tls, type QUICEncryptionLevel int #44886
+pkg crypto/tls, type QUICEvent struct #44886
+pkg crypto/tls, type QUICEvent struct, Data []uint8 #44886
+pkg crypto/tls, type QUICEvent struct, Kind QUICEventKind #44886
+pkg crypto/tls, type QUICEvent struct, Level QUICEncryptionLevel #44886
+pkg crypto/tls, type QUICEvent struct, Suite uint16 #44886
+pkg crypto/tls, type QUICEventKind int #44886
import "strconv"
+// An AlertError is a TLS alert.
+//
+// When using a QUIC transport, QUICConn methods will return an error
+// which wraps AlertError rather than sending a TLS alert.
+type AlertError uint8
+
+func (e AlertError) Error() string {
+ return alert(e).String()
+}
+
type alert uint8
const (
extensionCertificateAuthorities uint16 = 47
extensionSignatureAlgorithmsCert uint16 = 50
extensionKeyShare uint16 = 51
+ extensionQUICTransportParameters uint16 = 57
extensionRenegotiationInfo uint16 = 0xff01
)
conn net.Conn
isClient bool
handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
+ quic *quicState // nil for non-QUIC connections
// isHandshakeComplete is true if the connection is currently transferring
// application data (i.e. is not currently processing a handshake).
nextCipher any // next encryption state
nextMac hash.Hash // next MAC algorithm
- trafficSecret []byte // current TLS 1.3 traffic secret
+ level QUICEncryptionLevel // current QUIC encryption level
+ trafficSecret []byte // current TLS 1.3 traffic secret
}
type permanentError struct {
return nil
}
-func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) {
+func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) {
hc.trafficSecret = secret
+ hc.level = level
key, iv := suite.trafficKey(secret)
hc.cipher = suite.aead(key, iv)
for i := range hc.seq {
}
c.input.Reset(nil)
+ if c.quic != nil {
+ return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with QUIC transport"))
+ }
+
// Read header, payload.
if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
// RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
case recordTypeAlert:
+ if c.quic != nil {
+ return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
+ }
if len(data) != 2 {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
// sendAlertLocked sends a TLS alert message.
func (c *Conn) sendAlertLocked(err alert) error {
+ if c.quic != nil {
+ return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
+ }
+
switch err {
case alertNoRenegotiation, alertCloseNotify:
c.tmp[0] = alertLevelWarning
// writeRecordLocked writes a TLS record with the given type and payload to the
// connection and updates the record layer state.
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
+ if c.quic != nil {
+ if typ != recordTypeHandshake {
+ return 0, errors.New("tls: internal error: sending non-handshake message to QUIC transport")
+ }
+ c.quicWriteCryptoData(c.out.level, data)
+ if !c.buffering {
+ if _, err := c.flush(); err != nil {
+ return 0, err
+ }
+ }
+ return len(data), nil
+ }
+
outBufPtr := outBufPool.Get().(*[]byte)
outBuf := *outBufPtr
defer func() {
return err
}
+// readHandshakeBytes reads handshake data until c.hand contains at least n bytes.
+func (c *Conn) readHandshakeBytes(n int) error {
+ if c.quic != nil {
+ return c.quicReadHandshakeBytes(n)
+ }
+ for c.hand.Len() < n {
+ if err := c.readRecord(); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
// readHandshake reads the next handshake message from
// the record layer. If transcript is non-nil, the message
// is written to the passed transcriptHash.
func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
- for c.hand.Len() < 4 {
- if err := c.readRecord(); err != nil {
- return nil, err
- }
+ if err := c.readHandshakeBytes(4); err != nil {
+ return nil, err
}
-
data := c.hand.Bytes()
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if n > maxHandshake {
c.sendAlertLocked(alertInternalError)
return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
}
- for c.hand.Len() < 4+n {
- if err := c.readRecord(); err != nil {
- return nil, err
- }
+ if err := c.readHandshakeBytes(4 + n); err != nil {
+ return nil, err
}
data = c.hand.Next(4 + n)
+ return c.unmarshalHandshakeMessage(data, transcript)
+}
+
+func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript transcriptHash) (handshakeMessage, error) {
var m handshakeMessage
switch data[0] {
case typeHelloRequest:
if err != nil {
return err
}
-
c.retryCount++
if c.retryCount > maxUselessRecords {
c.sendAlert(alertUnexpectedMessage)
return c.handleNewSessionTicket(msg)
case *keyUpdateMsg:
return c.handleKeyUpdate(msg)
- default:
- c.sendAlert(alertUnexpectedMessage)
- return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
}
+ // The QUIC layer is supposed to treat an unexpected post-handshake CertificateRequest
+ // as a QUIC-level PROTOCOL_VIOLATION error (RFC 9001, Section 4.4). Returning an
+ // unexpected_message alert here doesn't provide it with enough information to distinguish
+ // this condition from other unexpected messages. This is probably fine.
+ c.sendAlert(alertUnexpectedMessage)
+ return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
}
func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
+ if c.quic != nil {
+ c.sendAlert(alertUnexpectedMessage)
+ return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
+ }
+
cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
if cipherSuite == nil {
return c.in.setErrorLocked(c.sendAlert(alertInternalError))
}
newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
- c.in.setTrafficSecret(cipherSuite, newSecret)
+ c.in.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
if keyUpdate.updateRequested {
c.out.Lock()
}
newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
- c.out.setTrafficSecret(cipherSuite, newSecret)
+ c.out.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
}
return nil
// this cancellation. In the former case, we need to close the connection.
defer cancel()
- // Start the "interrupter" goroutine, if this context might be canceled.
- // (The background context cannot).
- //
- // The interrupter goroutine waits for the input context to be done and
- // closes the connection if this happens before the function returns.
- if ctx.Done() != nil {
+ if c.quic != nil {
+ c.quic.cancelc = handshakeCtx.Done()
+ c.quic.cancel = cancel
+ } else if ctx.Done() != nil {
+ // Start the "interrupter" goroutine, if this context might be canceled.
+ // (The background context cannot).
+ //
+ // The interrupter goroutine waits for the input context to be done and
+ // closes the connection if this happens before the function returns.
done := make(chan struct{})
interruptRes := make(chan error, 1)
defer func() {
panic("tls: internal error: handshake returned an error but is marked successful")
}
+ if c.quic != nil {
+ if c.handshakeErr == nil {
+ c.quicHandshakeComplete()
+ // Provide the 1-RTT read secret now that the handshake is complete.
+ // The QUIC layer MUST NOT decrypt 1-RTT packets prior to completing
+ // the handshake (RFC 9001, Section 5.7).
+ c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret)
+ } else {
+ var a alert
+ c.out.Lock()
+ if !errors.As(c.out.err, &a) {
+ a = alertInternalError
+ }
+ c.out.Unlock()
+ // Return an error which wraps both the handshake error and
+ // any alert error we may have sent, or alertInternalError
+ // if we didn't send an alert.
+ // Truncate the text of the alert to 0 characters.
+ c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a))
+ }
+ close(c.quic.blockedc)
+ close(c.quic.signalc)
+ }
+
return c.handshakeErr
}
vers: clientHelloVersion,
compressionMethods: []uint8{compressionNone},
random: make([]byte, 32),
- sessionId: make([]byte, 32),
ocspStapling: true,
scts: true,
serverName: hostnameInSNI(config.ServerName),
// A random session ID is used to detect when the server accepted a ticket
// and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as
// a compatibility measure (see RFC 8446, Section 4.1.2).
- if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil {
- return nil, nil, errors.New("tls: short read from Rand: " + err.Error())
+ //
+ // The session ID is not set for QUIC connections (see RFC 9001, Section 8.4).
+ if c.quic == nil {
+ hello.sessionId = make([]byte, 32)
+ if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil {
+ return nil, nil, errors.New("tls: short read from Rand: " + err.Error())
+ }
}
if hello.vers >= VersionTLS12 {
hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
}
+ if c.quic != nil {
+ p, err := c.quicGetTransportParameters()
+ if err != nil {
+ return nil, nil, err
+ }
+ if p == nil {
+ p = []byte{}
+ }
+ hello.quicTransportParameters = p
+ }
+
return hello, key, nil
}
}
// Try to resume a previously negotiated TLS session, if available.
- cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
+ cacheKey = c.clientSessionCacheKey()
+ if cacheKey == "" {
+ return "", nil, nil, nil, nil
+ }
session, ok := c.config.ClientSessionCache.Get(cacheKey)
if !ok || session == nil {
return cacheKey, nil, nil, nil, nil
}
}
- if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol); err != nil {
+ if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol, false); err != nil {
c.sendAlert(alertUnsupportedExtension)
return false, err
}
// checkALPN ensure that the server's choice of ALPN protocol is compatible with
// the protocols that we advertised in the Client Hello.
-func checkALPN(clientProtos []string, serverProto string) error {
+func checkALPN(clientProtos []string, serverProto string, quic bool) error {
if serverProto == "" {
+ if quic && len(clientProtos) > 0 {
+ // RFC 9001, Section 8.1
+ return errors.New("tls: server did not select an ALPN protocol")
+ }
return nil
}
if len(clientProtos) == 0 {
// clientSessionCacheKey returns a key used to cache sessionTickets that could
// be used to resume previously negotiated TLS sessions with a server.
-func clientSessionCacheKey(serverAddr net.Addr, config *Config) string {
- if len(config.ServerName) > 0 {
- return config.ServerName
+func (c *Conn) clientSessionCacheKey() string {
+ if len(c.config.ServerName) > 0 {
+ return c.config.ServerName
+ }
+ if c.conn != nil {
+ return c.conn.RemoteAddr().String()
}
- return serverAddr.String()
+ return ""
}
// hostnameInSNI converts name into an appropriate hostname for SNI.
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
+ if hs.c.quic != nil {
+ return nil
+ }
if hs.sentDummyCCS {
return nil
}
clientSecret := hs.suite.deriveSecret(handshakeSecret,
clientHandshakeTrafficLabel, hs.transcript)
- c.out.setTrafficSecret(hs.suite, clientSecret)
+ c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
serverSecret := hs.suite.deriveSecret(handshakeSecret,
serverHandshakeTrafficLabel, hs.transcript)
- c.in.setTrafficSecret(hs.suite, serverSecret)
+ c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
+
+ if c.quic != nil {
+ if c.hand.Len() != 0 {
+ c.sendAlert(alertUnexpectedMessage)
+ }
+ c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret)
+ c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret)
+ }
err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret)
if err != nil {
return unexpectedMessageError(encryptedExtensions, msg)
}
- if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil {
- c.sendAlert(alertUnsupportedExtension)
+ if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol, c.quic != nil); err != nil {
+ // RFC 8446 specifies that no_application_protocol is sent by servers, but
+ // does not specify how clients handle the selection of an incompatible protocol.
+ // RFC 9001 Section 8.1 specifies that QUIC clients send no_application_protocol
+ // in this case. Always sending no_application_protocol seems reasonable.
+ c.sendAlert(alertNoApplicationProtocol)
return err
}
c.clientProtocol = encryptedExtensions.alpnProtocol
+ if c.quic != nil {
+ if encryptedExtensions.quicTransportParameters == nil {
+ // RFC 9001 Section 8.2.
+ c.sendAlert(alertMissingExtension)
+ return errors.New("tls: server did not send a quic_transport_parameters extension")
+ }
+ c.quicSetTransportParameters(encryptedExtensions.quicTransportParameters)
+ } else {
+ if encryptedExtensions.quicTransportParameters != nil {
+ c.sendAlert(alertUnsupportedExtension)
+ return errors.New("tls: server sent an unexpected quic_transport_parameters extension")
+ }
+ }
+
return nil
}
clientApplicationTrafficLabel, hs.transcript)
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
serverApplicationTrafficLabel, hs.transcript)
- c.in.setTrafficSecret(hs.suite, serverSecret)
+ c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret)
err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret)
if err != nil {
return err
}
- c.out.setTrafficSecret(hs.suite, hs.trafficSecret)
+ c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret)
if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil {
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
resumptionLabel, hs.transcript)
}
+ if c.quic != nil {
+ if c.hand.Len() != 0 {
+ c.sendAlert(alertUnexpectedMessage)
+ }
+ c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, hs.trafficSecret)
+ }
+
return nil
}
scts: c.scts,
}
- cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
- c.config.ClientSessionCache.Put(cacheKey, session)
+ cacheKey := c.clientSessionCacheKey()
+ if cacheKey != "" {
+ c.config.ClientSessionCache.Put(cacheKey, session)
+ }
return nil
}
pskModes []uint8
pskIdentities []pskIdentity
pskBinders [][]byte
+ quicTransportParameters []byte
}
func (m *clientHelloMsg) marshal() ([]byte, error) {
})
})
}
+ if m.quicTransportParameters != nil { // marshal zero-length parameters when present
+ // RFC 9001, Section 8.2
+ exts.AddUint16(extensionQUICTransportParameters)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(m.quicTransportParameters)
+ })
+ }
if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
// RFC 8446, Section 4.2.11
exts.AddUint16(extensionPreSharedKey)
if !readUint8LengthPrefixed(&extData, &m.pskModes) {
return false
}
+ case extensionQUICTransportParameters:
+ m.quicTransportParameters = make([]byte, len(extData))
+ if !extData.CopyBytes(m.quicTransportParameters) {
+ return false
+ }
case extensionPreSharedKey:
// RFC 8446, Section 4.2.11
if !extensions.Empty() {
}
type encryptedExtensionsMsg struct {
- raw []byte
- alpnProtocol string
+ raw []byte
+ alpnProtocol string
+ quicTransportParameters []byte
}
func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
})
})
}
+ if m.quicTransportParameters != nil { // marshal zero-length parameters when present
+ // draft-ietf-quic-tls-32, Section 8.2
+ b.AddUint16(extensionQUICTransportParameters)
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddBytes(m.quicTransportParameters)
+ })
+ }
})
})
return false
}
m.alpnProtocol = string(proto)
+ case extensionQUICTransportParameters:
+ m.quicTransportParameters = make([]byte, len(extData))
+ if !extData.CopyBytes(m.quicTransportParameters) {
+ return false
+ }
default:
// Ignore unknown extensions.
continue
m.pskIdentities = append(m.pskIdentities, psk)
m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
}
+ if rand.Intn(10) > 5 {
+ m.quicTransportParameters = randomBytes(rand.Intn(500), rand)
+ }
if rand.Intn(10) > 5 {
m.earlyData = true
}
c.serverName = hs.clientHello.serverName
}
- selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols)
+ selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, false)
if err != nil {
c.sendAlert(alertNoApplicationProtocol)
return err
// negotiateALPN picks a shared ALPN protocol that both sides support in server
// preference order. If ALPN is not configured or the peer doesn't support it,
// it returns "" and no error.
-func negotiateALPN(serverProtos, clientProtos []string) (string, error) {
+func negotiateALPN(serverProtos, clientProtos []string, quic bool) (string, error) {
if len(serverProtos) == 0 || len(clientProtos) == 0 {
+ if quic && len(serverProtos) != 0 {
+ // RFC 9001, Section 8.1
+ return "", fmt.Errorf("tls: client did not request an application protocol")
+ }
return "", nil
}
var http11fallback bool
return errors.New("tls: invalid client key share")
}
+ if c.quic != nil {
+ if hs.clientHello.quicTransportParameters == nil {
+ // RFC 9001 Section 8.2.
+ c.sendAlert(alertMissingExtension)
+ return errors.New("tls: client did not send a quic_transport_parameters extension")
+ }
+ c.quicSetTransportParameters(hs.clientHello.quicTransportParameters)
+ } else {
+ if hs.clientHello.quicTransportParameters != nil {
+ c.sendAlert(alertUnsupportedExtension)
+ return errors.New("tls: client sent an unexpected quic_transport_parameters extension")
+ }
+ }
+
c.serverName = hs.clientHello.serverName
return nil
}
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
+ if hs.c.quic != nil {
+ return nil
+ }
if hs.sentDummyCCS {
return nil
}
clientSecret := hs.suite.deriveSecret(hs.handshakeSecret,
clientHandshakeTrafficLabel, hs.transcript)
- c.in.setTrafficSecret(hs.suite, clientSecret)
+ c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
serverSecret := hs.suite.deriveSecret(hs.handshakeSecret,
serverHandshakeTrafficLabel, hs.transcript)
- c.out.setTrafficSecret(hs.suite, serverSecret)
+ c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
+
+ if c.quic != nil {
+ if c.hand.Len() != 0 {
+ c.sendAlert(alertUnexpectedMessage)
+ }
+ c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret)
+ c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret)
+ }
err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret)
if err != nil {
encryptedExtensions := new(encryptedExtensionsMsg)
- selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols)
+ selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil)
if err != nil {
c.sendAlert(alertNoApplicationProtocol)
return err
encryptedExtensions.alpnProtocol = selectedProto
c.clientProtocol = selectedProto
+ if c.quic != nil {
+ p, err := c.quicGetTransportParameters()
+ if err != nil {
+ return err
+ }
+ encryptedExtensions.quicTransportParameters = p
+ }
+
if _, err := hs.c.writeHandshakeRecord(encryptedExtensions, hs.transcript); err != nil {
return err
}
clientApplicationTrafficLabel, hs.transcript)
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
serverApplicationTrafficLabel, hs.transcript)
- c.out.setTrafficSecret(hs.suite, serverSecret)
+ c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret)
+
+ if c.quic != nil {
+ if c.hand.Len() != 0 {
+ // TODO: Handle this in setTrafficSecret?
+ c.sendAlert(alertUnexpectedMessage)
+ }
+ c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, serverSecret)
+ }
err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret)
if err != nil {
return errors.New("tls: invalid client finished hash")
}
- c.in.setTrafficSecret(hs.suite, hs.trafficSecret)
+ c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret)
return nil
}
--- /dev/null
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tls
+
+import (
+ "context"
+ "errors"
+ "fmt"
+)
+
+// QUICEncryptionLevel represents a QUIC encryption level used to transmit
+// handshake messages.
+type QUICEncryptionLevel int
+
+const (
+ QUICEncryptionLevelInitial = QUICEncryptionLevel(iota)
+ QUICEncryptionLevelHandshake
+ QUICEncryptionLevelApplication
+)
+
+func (l QUICEncryptionLevel) String() string {
+ switch l {
+ case QUICEncryptionLevelInitial:
+ return "Initial"
+ case QUICEncryptionLevelHandshake:
+ return "Handshake"
+ case QUICEncryptionLevelApplication:
+ return "Application"
+ default:
+ return fmt.Sprintf("QUICEncryptionLevel(%v)", int(l))
+ }
+}
+
+// A QUICConn represents a connection which uses a QUIC implementation as the underlying
+// transport as described in RFC 9001.
+//
+// Methods of QUICConn are not safe for concurrent use.
+type QUICConn struct {
+ conn *Conn
+}
+
+// A QUICConfig configures a QUICConn.
+type QUICConfig struct {
+ TLSConfig *Config
+}
+
+// A QUICEventKind is a type of operation on a QUIC connection.
+type QUICEventKind int
+
+const (
+ // QUICNoEvent indicates that there are no events available.
+ QUICNoEvent QUICEventKind = iota
+
+ // QUICSetReadSecret and QUICSetWriteSecret provide the read and write
+ // secrets for a given encryption level.
+ // QUICEvent.Level, QUICEvent.Data, and QUICEvent.Suite are set.
+ //
+ // Secrets for the Initial encryption level are derived from the initial
+ // destination connection ID, and are not provided by the QUICConn.
+ QUICSetReadSecret
+ QUICSetWriteSecret
+
+ // QUICWriteData provides data to send to the peer in CRYPTO frames.
+ // QUICEvent.Data is set.
+ QUICWriteData
+
+ // QUICTransportParameters provides the peer's QUIC transport parameters.
+ // QUICEvent.Data is set.
+ QUICTransportParameters
+
+ // QUICTransportParametersRequired indicates that the caller must provide
+ // QUIC transport parameters to send to the peer. The caller should set
+ // the transport parameters with QUICConn.SetTransportParameters and call
+ // QUICConn.NextEvent again.
+ //
+ // If transport parameters are set before calling QUICConn.Start, the
+ // connection will never generate a QUICTransportParametersRequired event.
+ QUICTransportParametersRequired
+
+ // QUICHandshakeDone indicates that the TLS handshake has completed.
+ QUICHandshakeDone
+)
+
+// A QUICEvent is an event occurring on a QUIC connection.
+//
+// The type of event is specified by the Kind field.
+// The contents of the other fields are kind-specific.
+type QUICEvent struct {
+ Kind QUICEventKind
+
+ // Set for QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData.
+ Level QUICEncryptionLevel
+
+ // Set for QUICTransportParameters, QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData.
+ // The contents are owned by crypto/tls, and are valid until the next NextEvent call.
+ Data []byte
+
+ // Set for QUICSetReadSecret and QUICSetWriteSecret.
+ Suite uint16
+}
+
+type quicState struct {
+ events []QUICEvent
+ nextEvent int
+
+ // eventArr is a statically allocated event array, large enough to handle
+ // the usual maximum number of events resulting from a single call:
+ // transport parameters, Initial data, Handshake write and read secrets,
+ // Handshake data, Application write secret, Application data.
+ eventArr [7]QUICEvent
+
+ started bool
+ signalc chan struct{} // handshake data is available to be read
+ blockedc chan struct{} // handshake is waiting for data, closed when done
+ cancelc <-chan struct{} // handshake has been canceled
+ cancel context.CancelFunc
+
+ // readbuf is shared between HandleData and the handshake goroutine.
+ // HandshakeCryptoData passes ownership to the handshake goroutine by
+ // reading from signalc, and reclaims ownership by reading from blockedc.
+ readbuf []byte
+
+ transportParams []byte // to send to the peer
+}
+
+// QUICClient returns a new TLS client side connection using QUICTransport as the
+// underlying transport. The config cannot be nil.
+//
+// The config's MinVersion must be at least TLS 1.3.
+func QUICClient(config *QUICConfig) *QUICConn {
+ return newQUICConn(Client(nil, config.TLSConfig))
+}
+
+// QUICServer returns a new TLS server side connection using QUICTransport as the
+// underlying transport. The config cannot be nil.
+//
+// The config's MinVersion must be at least TLS 1.3.
+func QUICServer(config *QUICConfig) *QUICConn {
+ return newQUICConn(Server(nil, config.TLSConfig))
+}
+
+func newQUICConn(conn *Conn) *QUICConn {
+ conn.quic = &quicState{
+ signalc: make(chan struct{}),
+ blockedc: make(chan struct{}),
+ }
+ conn.quic.events = conn.quic.eventArr[:0]
+ return &QUICConn{
+ conn: conn,
+ }
+}
+
+// Start starts the client or server handshake protocol.
+// It may produce connection events, which may be read with NextEvent.
+//
+// Start must be called at most once.
+func (q *QUICConn) Start(ctx context.Context) error {
+ if q.conn.quic.started {
+ return quicError(errors.New("tls: Start called more than once"))
+ }
+ q.conn.quic.started = true
+ if q.conn.config.MinVersion < VersionTLS13 {
+ return quicError(errors.New("tls: Config MinVersion must be at least TLS 1.13"))
+ }
+ go q.conn.HandshakeContext(ctx)
+ if _, ok := <-q.conn.quic.blockedc; !ok {
+ return q.conn.handshakeErr
+ }
+ return nil
+}
+
+// NextEvent returns the next event occurring on the connection.
+// It returns an event with a Kind of QUICNoEvent when no events are available.
+func (q *QUICConn) NextEvent() QUICEvent {
+ qs := q.conn.quic
+ if last := qs.nextEvent - 1; last >= 0 && len(qs.events[last].Data) > 0 {
+ // Write over some of the previous event's data,
+ // to catch callers erroniously retaining it.
+ qs.events[last].Data[0] = 0
+ }
+ if qs.nextEvent >= len(qs.events) {
+ qs.events = qs.events[:0]
+ qs.nextEvent = 0
+ return QUICEvent{Kind: QUICNoEvent}
+ }
+ e := qs.events[qs.nextEvent]
+ qs.events[qs.nextEvent] = QUICEvent{} // zero out references to data
+ qs.nextEvent++
+ return e
+}
+
+// Close closes the connection and stops any in-progress handshake.
+func (q *QUICConn) Close() error {
+ if q.conn.quic.cancel == nil {
+ return nil // never started
+ }
+ q.conn.quic.cancel()
+ for range q.conn.quic.blockedc {
+ // Wait for the handshake goroutine to return.
+ }
+ return q.conn.handshakeErr
+}
+
+// HandleData handles handshake bytes received from the peer.
+// It may produce connection events, which may be read with NextEvent.
+func (q *QUICConn) HandleData(level QUICEncryptionLevel, data []byte) error {
+ c := q.conn
+ if c.in.level != level {
+ return quicError(c.in.setErrorLocked(errors.New("tls: handshake data received at wrong level")))
+ }
+ c.quic.readbuf = data
+ <-c.quic.signalc
+ _, ok := <-c.quic.blockedc
+ if ok {
+ // The handshake goroutine is waiting for more data.
+ return nil
+ }
+ // The handshake goroutine has exited.
+ c.hand.Write(c.quic.readbuf)
+ c.quic.readbuf = nil
+ for q.conn.hand.Len() >= 4 && q.conn.handshakeErr == nil {
+ b := q.conn.hand.Bytes()
+ n := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
+ if 4+n < len(b) {
+ return nil
+ }
+ if err := q.conn.handlePostHandshakeMessage(); err != nil {
+ return quicError(err)
+ }
+ }
+ if q.conn.handshakeErr != nil {
+ return quicError(q.conn.handshakeErr)
+ }
+ return nil
+}
+
+// ConnectionState returns basic TLS details about the connection.
+func (q *QUICConn) ConnectionState() ConnectionState {
+ return q.conn.ConnectionState()
+}
+
+// SetTransportParameters sets the transport parameters to send to the peer.
+//
+// Server connections may delay setting the transport parameters until after
+// receiving the client's transport parameters. See QUICTransportParametersRequired.
+func (q *QUICConn) SetTransportParameters(params []byte) {
+ if params == nil {
+ params = []byte{}
+ }
+ q.conn.quic.transportParams = params
+ if q.conn.quic.started {
+ <-q.conn.quic.signalc
+ <-q.conn.quic.blockedc
+ }
+}
+
+// quicError ensures err is an AlertError.
+// If err is not already, quicError wraps it with alertInternalError.
+func quicError(err error) error {
+ if err == nil {
+ return nil
+ }
+ var ae AlertError
+ if errors.As(err, &ae) {
+ return err
+ }
+ var a alert
+ if !errors.As(err, &a) {
+ a = alertInternalError
+ }
+ // Return an error wrapping the original error and an AlertError.
+ // Truncate the text of the alert to 0 characters.
+ return fmt.Errorf("%w%.0w", err, AlertError(a))
+}
+
+func (c *Conn) quicReadHandshakeBytes(n int) error {
+ for c.hand.Len() < n {
+ if err := c.quicWaitForSignal(); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (c *Conn) quicSetReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
+ c.quic.events = append(c.quic.events, QUICEvent{
+ Kind: QUICSetReadSecret,
+ Level: level,
+ Suite: suite,
+ Data: secret,
+ })
+}
+
+func (c *Conn) quicSetWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
+ c.quic.events = append(c.quic.events, QUICEvent{
+ Kind: QUICSetWriteSecret,
+ Level: level,
+ Suite: suite,
+ Data: secret,
+ })
+}
+
+func (c *Conn) quicWriteCryptoData(level QUICEncryptionLevel, data []byte) {
+ var last *QUICEvent
+ if len(c.quic.events) > 0 {
+ last = &c.quic.events[len(c.quic.events)-1]
+ }
+ if last == nil || last.Kind != QUICWriteData || last.Level != level {
+ c.quic.events = append(c.quic.events, QUICEvent{
+ Kind: QUICWriteData,
+ Level: level,
+ })
+ last = &c.quic.events[len(c.quic.events)-1]
+ }
+ last.Data = append(last.Data, data...)
+}
+
+func (c *Conn) quicSetTransportParameters(params []byte) {
+ c.quic.events = append(c.quic.events, QUICEvent{
+ Kind: QUICTransportParameters,
+ Data: params,
+ })
+}
+
+func (c *Conn) quicGetTransportParameters() ([]byte, error) {
+ if c.quic.transportParams == nil {
+ c.quic.events = append(c.quic.events, QUICEvent{
+ Kind: QUICTransportParametersRequired,
+ })
+ }
+ for c.quic.transportParams == nil {
+ if err := c.quicWaitForSignal(); err != nil {
+ return nil, err
+ }
+ }
+ return c.quic.transportParams, nil
+}
+
+func (c *Conn) quicHandshakeComplete() {
+ c.quic.events = append(c.quic.events, QUICEvent{
+ Kind: QUICHandshakeDone,
+ })
+}
+
+// quicWaitForSignal notifies the QUICConn that handshake progress is blocked,
+// and waits for a signal that the handshake should proceed.
+//
+// The handshake may become blocked waiting for handshake bytes
+// or for the user to provide transport parameters.
+func (c *Conn) quicWaitForSignal() error {
+ // Drop the handshake mutex while blocked to allow the user
+ // to call ConnectionState before the handshake completes.
+ c.handshakeMutex.Unlock()
+ defer c.handshakeMutex.Lock()
+ // Send on blockedc to notify the QUICConn that the handshake is blocked.
+ // Exported methods of QUICConn wait for the handshake to become blocked
+ // before returning to the user.
+ select {
+ case c.quic.blockedc <- struct{}{}:
+ case <-c.quic.cancelc:
+ return c.sendAlertLocked(alertCloseNotify)
+ }
+ // The QUICConn reads from signalc to notify us that the handshake may
+ // be able to proceed. (The QUICConn reads, because we close signalc to
+ // indicate that the handshake has completed.)
+ select {
+ case c.quic.signalc <- struct{}{}:
+ c.hand.Write(c.quic.readbuf)
+ c.quic.readbuf = nil
+ case <-c.quic.cancelc:
+ return c.sendAlertLocked(alertCloseNotify)
+ }
+ return nil
+}
--- /dev/null
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tls
+
+import (
+ "context"
+ "errors"
+ "reflect"
+ "testing"
+)
+
+type testQUICConn struct {
+ t *testing.T
+ conn *QUICConn
+ readSecret map[QUICEncryptionLevel]suiteSecret
+ writeSecret map[QUICEncryptionLevel]suiteSecret
+ gotParams []byte
+ complete bool
+}
+
+func newTestQUICClient(t *testing.T, config *Config) *testQUICConn {
+ q := &testQUICConn{t: t}
+ q.conn = QUICClient(&QUICConfig{
+ TLSConfig: config,
+ })
+ t.Cleanup(func() {
+ q.conn.Close()
+ })
+ return q
+}
+
+func newTestQUICServer(t *testing.T, config *Config) *testQUICConn {
+ q := &testQUICConn{t: t}
+ q.conn = QUICServer(&QUICConfig{
+ TLSConfig: config,
+ })
+ t.Cleanup(func() {
+ q.conn.Close()
+ })
+ return q
+}
+
+type suiteSecret struct {
+ suite uint16
+ secret []byte
+}
+
+func (q *testQUICConn) setReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
+ if _, ok := q.writeSecret[level]; !ok {
+ q.t.Errorf("SetReadSecret for level %v called before SetWriteSecret", level)
+ }
+ if level == QUICEncryptionLevelApplication && !q.complete {
+ q.t.Errorf("SetReadSecret for level %v called before HandshakeComplete", level)
+ }
+ if _, ok := q.readSecret[level]; ok {
+ q.t.Errorf("SetReadSecret for level %v called twice", level)
+ }
+ if q.readSecret == nil {
+ q.readSecret = map[QUICEncryptionLevel]suiteSecret{}
+ }
+ switch level {
+ case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication:
+ q.readSecret[level] = suiteSecret{suite, secret}
+ default:
+ q.t.Errorf("SetReadSecret for unexpected level %v", level)
+ }
+}
+
+func (q *testQUICConn) setWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
+ if _, ok := q.writeSecret[level]; ok {
+ q.t.Errorf("SetWriteSecret for level %v called twice", level)
+ }
+ if q.writeSecret == nil {
+ q.writeSecret = map[QUICEncryptionLevel]suiteSecret{}
+ }
+ switch level {
+ case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication:
+ q.writeSecret[level] = suiteSecret{suite, secret}
+ default:
+ q.t.Errorf("SetWriteSecret for unexpected level %v", level)
+ }
+}
+
+var errTransportParametersRequired = errors.New("transport parameters required")
+
+func runTestQUICConnection(ctx context.Context, a, b *testQUICConn, onHandleCryptoData func()) error {
+ for _, c := range []*testQUICConn{a, b} {
+ if !c.conn.conn.quic.started {
+ if err := c.conn.Start(ctx); err != nil {
+ return err
+ }
+ }
+ }
+ idleCount := 0
+ for {
+ e := a.conn.NextEvent()
+ switch e.Kind {
+ case QUICNoEvent:
+ idleCount++
+ if idleCount == 2 {
+ if !a.complete || !b.complete {
+ return errors.New("handshake incomplete")
+ }
+ return nil
+ }
+ a, b = b, a
+ case QUICSetReadSecret:
+ a.setReadSecret(e.Level, e.Suite, e.Data)
+ case QUICSetWriteSecret:
+ a.setWriteSecret(e.Level, e.Suite, e.Data)
+ case QUICWriteData:
+ if err := b.conn.HandleData(e.Level, e.Data); err != nil {
+ return err
+ }
+ case QUICTransportParameters:
+ a.gotParams = e.Data
+ if a.gotParams == nil {
+ a.gotParams = []byte{}
+ }
+ case QUICTransportParametersRequired:
+ return errTransportParametersRequired
+ case QUICHandshakeDone:
+ a.complete = true
+ }
+ if e.Kind != QUICNoEvent {
+ idleCount = 0
+ }
+ }
+}
+
+func TestQUICConnection(t *testing.T) {
+ config := testConfig.Clone()
+ config.MinVersion = VersionTLS13
+
+ cli := newTestQUICClient(t, config)
+ cli.conn.SetTransportParameters(nil)
+
+ srv := newTestQUICServer(t, config)
+ srv.conn.SetTransportParameters(nil)
+
+ if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
+ t.Fatalf("error during connection handshake: %v", err)
+ }
+
+ if _, ok := cli.readSecret[QUICEncryptionLevelHandshake]; !ok {
+ t.Errorf("client has no Handshake secret")
+ }
+ if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; !ok {
+ t.Errorf("client has no Application secret")
+ }
+ if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; !ok {
+ t.Errorf("server has no Handshake secret")
+ }
+ if _, ok := srv.readSecret[QUICEncryptionLevelApplication]; !ok {
+ t.Errorf("server has no Application secret")
+ }
+ for _, level := range []QUICEncryptionLevel{QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication} {
+ if _, ok := cli.readSecret[level]; !ok {
+ t.Errorf("client has no %v read secret", level)
+ }
+ if _, ok := srv.readSecret[level]; !ok {
+ t.Errorf("server has no %v read secret", level)
+ }
+ if !reflect.DeepEqual(cli.readSecret[level], srv.writeSecret[level]) {
+ t.Errorf("client read secret does not match server write secret for level %v", level)
+ }
+ if !reflect.DeepEqual(cli.writeSecret[level], srv.readSecret[level]) {
+ t.Errorf("client write secret does not match server read secret for level %v", level)
+ }
+ }
+}
+
+func TestQUICSessionResumption(t *testing.T) {
+ clientConfig := testConfig.Clone()
+ clientConfig.MinVersion = VersionTLS13
+ clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
+ clientConfig.ServerName = "example.go.dev"
+
+ serverConfig := testConfig.Clone()
+ serverConfig.MinVersion = VersionTLS13
+
+ cli := newTestQUICClient(t, clientConfig)
+ cli.conn.SetTransportParameters(nil)
+ srv := newTestQUICServer(t, serverConfig)
+ srv.conn.SetTransportParameters(nil)
+ if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
+ t.Fatalf("error during first connection handshake: %v", err)
+ }
+ if cli.conn.ConnectionState().DidResume {
+ t.Errorf("first connection unexpectedly used session resumption")
+ }
+
+ cli2 := newTestQUICClient(t, clientConfig)
+ cli2.conn.SetTransportParameters(nil)
+ srv2 := newTestQUICServer(t, serverConfig)
+ srv2.conn.SetTransportParameters(nil)
+ if err := runTestQUICConnection(context.Background(), cli2, srv2, nil); err != nil {
+ t.Fatalf("error during second connection handshake: %v", err)
+ }
+ if !cli2.conn.ConnectionState().DidResume {
+ t.Errorf("second connection did not use session resumption")
+ }
+}
+
+func TestQUICPostHandshakeClientAuthentication(t *testing.T) {
+ // RFC 9001, Section 4.4.
+ config := testConfig.Clone()
+ config.MinVersion = VersionTLS13
+ cli := newTestQUICClient(t, config)
+ cli.conn.SetTransportParameters(nil)
+ srv := newTestQUICServer(t, config)
+ srv.conn.SetTransportParameters(nil)
+ if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
+ t.Fatalf("error during connection handshake: %v", err)
+ }
+
+ certReq := new(certificateRequestMsgTLS13)
+ certReq.ocspStapling = true
+ certReq.scts = true
+ certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
+ certReqBytes, err := certReq.marshal()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{
+ byte(typeCertificateRequest),
+ byte(0), byte(0), byte(len(certReqBytes)),
+ }, certReqBytes...)); err == nil {
+ t.Fatalf("post-handshake authentication request: got no error, want one")
+ }
+}
+
+func TestQUICPostHandshakeKeyUpdate(t *testing.T) {
+ // RFC 9001, Section 6.
+ config := testConfig.Clone()
+ config.MinVersion = VersionTLS13
+ cli := newTestQUICClient(t, config)
+ cli.conn.SetTransportParameters(nil)
+ srv := newTestQUICServer(t, config)
+ srv.conn.SetTransportParameters(nil)
+ if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
+ t.Fatalf("error during connection handshake: %v", err)
+ }
+
+ keyUpdate := new(keyUpdateMsg)
+ keyUpdateBytes, err := keyUpdate.marshal()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{
+ byte(typeKeyUpdate),
+ byte(0), byte(0), byte(len(keyUpdateBytes)),
+ }, keyUpdateBytes...)); !errors.Is(err, alertUnexpectedMessage) {
+ t.Fatalf("key update request: got error %v, want alertUnexpectedMessage", err)
+ }
+}
+
+func TestQUICHandshakeError(t *testing.T) {
+ clientConfig := testConfig.Clone()
+ clientConfig.MinVersion = VersionTLS13
+ clientConfig.InsecureSkipVerify = false
+ clientConfig.ServerName = "name"
+
+ serverConfig := testConfig.Clone()
+ serverConfig.MinVersion = VersionTLS13
+
+ cli := newTestQUICClient(t, clientConfig)
+ cli.conn.SetTransportParameters(nil)
+ srv := newTestQUICServer(t, serverConfig)
+ srv.conn.SetTransportParameters(nil)
+ err := runTestQUICConnection(context.Background(), cli, srv, nil)
+ if !errors.Is(err, AlertError(alertBadCertificate)) {
+ t.Errorf("connection handshake terminated with error %q, want alertBadCertificate", err)
+ }
+ var e *CertificateVerificationError
+ if !errors.As(err, &e) {
+ t.Errorf("connection handshake terminated with error %q, want CertificateVerificationError", err)
+ }
+}
+
+// Test that QUICConn.ConnectionState can be used during the handshake,
+// and that it reports the application protocol as soon as it has been
+// negotiated.
+func TestQUICConnectionState(t *testing.T) {
+ config := testConfig.Clone()
+ config.MinVersion = VersionTLS13
+ config.NextProtos = []string{"h3"}
+ cli := newTestQUICClient(t, config)
+ cli.conn.SetTransportParameters(nil)
+ srv := newTestQUICServer(t, config)
+ srv.conn.SetTransportParameters(nil)
+ onHandleCryptoData := func() {
+ cliCS := cli.conn.ConnectionState()
+ cliWantALPN := ""
+ if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; ok {
+ cliWantALPN = "h3"
+ }
+ if want, got := cliCS.NegotiatedProtocol, cliWantALPN; want != got {
+ t.Errorf("cli.ConnectionState().NegotiatedProtocol = %q, want %q", want, got)
+ }
+
+ srvCS := srv.conn.ConnectionState()
+ srvWantALPN := ""
+ if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; ok {
+ srvWantALPN = "h3"
+ }
+ if want, got := srvCS.NegotiatedProtocol, srvWantALPN; want != got {
+ t.Errorf("srv.ConnectionState().NegotiatedProtocol = %q, want %q", want, got)
+ }
+ }
+ if err := runTestQUICConnection(context.Background(), cli, srv, onHandleCryptoData); err != nil {
+ t.Fatalf("error during connection handshake: %v", err)
+ }
+}
+
+func TestQUICStartContextPropagation(t *testing.T) {
+ const key = "key"
+ const value = "value"
+ ctx := context.WithValue(context.Background(), key, value)
+ config := testConfig.Clone()
+ config.MinVersion = VersionTLS13
+ calls := 0
+ config.GetConfigForClient = func(info *ClientHelloInfo) (*Config, error) {
+ calls++
+ got, _ := info.Context().Value(key).(string)
+ if got != value {
+ t.Errorf("GetConfigForClient context key %q has value %q, want %q", key, got, value)
+ }
+ return nil, nil
+ }
+ cli := newTestQUICClient(t, config)
+ cli.conn.SetTransportParameters(nil)
+ srv := newTestQUICServer(t, config)
+ srv.conn.SetTransportParameters(nil)
+ if err := runTestQUICConnection(ctx, cli, srv, nil); err != nil {
+ t.Fatalf("error during connection handshake: %v", err)
+ }
+ if calls != 1 {
+ t.Errorf("GetConfigForClient called %v times, want 1", calls)
+ }
+}
+
+func TestQUICDelayedTransportParameters(t *testing.T) {
+ clientConfig := testConfig.Clone()
+ clientConfig.MinVersion = VersionTLS13
+ clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
+ clientConfig.ServerName = "example.go.dev"
+
+ serverConfig := testConfig.Clone()
+ serverConfig.MinVersion = VersionTLS13
+
+ cliParams := "client params"
+ srvParams := "server params"
+
+ cli := newTestQUICClient(t, clientConfig)
+ srv := newTestQUICServer(t, serverConfig)
+ if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired {
+ t.Fatalf("handshake with no client parameters: %v; want errTransportParametersRequired", err)
+ }
+ cli.conn.SetTransportParameters([]byte(cliParams))
+ if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired {
+ t.Fatalf("handshake with no server parameters: %v; want errTransportParametersRequired", err)
+ }
+ srv.conn.SetTransportParameters([]byte(srvParams))
+ if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
+ t.Fatalf("error during connection handshake: %v", err)
+ }
+
+ if got, want := string(cli.gotParams), srvParams; got != want {
+ t.Errorf("client got transport params: %q, want %q", got, want)
+ }
+ if got, want := string(srv.gotParams), cliParams; got != want {
+ t.Errorf("server got transport params: %q, want %q", got, want)
+ }
+}
+
+func TestQUICEmptyTransportParameters(t *testing.T) {
+ config := testConfig.Clone()
+ config.MinVersion = VersionTLS13
+
+ cli := newTestQUICClient(t, config)
+ cli.conn.SetTransportParameters(nil)
+ srv := newTestQUICServer(t, config)
+ srv.conn.SetTransportParameters(nil)
+ if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
+ t.Fatalf("error during connection handshake: %v", err)
+ }
+
+ if cli.gotParams == nil {
+ t.Errorf("client did not get transport params")
+ }
+ if srv.gotParams == nil {
+ t.Errorf("server did not get transport params")
+ }
+ if len(cli.gotParams) != 0 {
+ t.Errorf("client got transport params: %v, want empty", cli.gotParams)
+ }
+ if len(srv.gotParams) != 0 {
+ t.Errorf("server got transport params: %v, want empty", srv.gotParams)
+ }
+}
+
+func TestQUICCanceledWaitingForData(t *testing.T) {
+ config := testConfig.Clone()
+ config.MinVersion = VersionTLS13
+ cli := newTestQUICClient(t, config)
+ cli.conn.SetTransportParameters(nil)
+ cli.conn.Start(context.Background())
+ for cli.conn.NextEvent().Kind != QUICNoEvent {
+ }
+ err := cli.conn.Close()
+ if !errors.Is(err, alertCloseNotify) {
+ t.Errorf("conn.Close() = %v, want alertCloseNotify", err)
+ }
+}
+
+func TestQUICCanceledWaitingForTransportParams(t *testing.T) {
+ config := testConfig.Clone()
+ config.MinVersion = VersionTLS13
+ cli := newTestQUICClient(t, config)
+ cli.conn.Start(context.Background())
+ for cli.conn.NextEvent().Kind != QUICTransportParametersRequired {
+ }
+ err := cli.conn.Close()
+ if !errors.Is(err, alertCloseNotify) {
+ t.Errorf("conn.Close() = %v, want alertCloseNotify", err)
+ }
+}