]> Cypherpunks.ru repositories - gostls13.git/commitdiff
crypto/tls: add WrapSession and UnwrapSession
authorFilippo Valsorda <filippo@golang.org>
Mon, 22 May 2023 08:49:07 +0000 (10:49 +0200)
committerFilippo Valsorda <filippo@golang.org>
Wed, 24 May 2023 23:56:55 +0000 (23:56 +0000)
There was a bug in TestResumption: the first ExpiredSessionTicket was
inserting a ticket far in the future, so the second ExpiredSessionTicket
wasn't actually supposed to fail. However, there was a bug in
checkForResumption->sendSessionTicket, too: if a session was not resumed
because it was too old, its createdAt was still persisted in the next
ticket. The two bugs used to cancel each other out.

For #60105
Fixes #19199

Change-Id: Ic9b2aab943dcbf0de62b8758a6195319dc286e2f
Reviewed-on: https://go-review.googlesource.com/c/go/+/496821
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
api/next/60105.txt
src/crypto/tls/common.go
src/crypto/tls/handshake_client_test.go
src/crypto/tls/handshake_server.go
src/crypto/tls/handshake_server_tls13.go
src/crypto/tls/ticket.go
src/crypto/tls/tls_test.go

index 251f574c8fbb5c38eb64cc42db2c657b3ab70063..03fb68fa3a69809920c37992b97f51a782f59ed9 100644 (file)
@@ -3,3 +3,7 @@ pkg crypto/tls, method (*SessionState) Bytes() ([]uint8, error) #60105
 pkg crypto/tls, type SessionState struct #60105
 pkg crypto/tls, func NewResumptionState([]uint8, *SessionState) (*ClientSessionState, error) #60105
 pkg crypto/tls, method (*ClientSessionState) ResumptionState() ([]uint8, *SessionState, error) #60105
+pkg crypto/tls, method (*Config) DecryptTicket([]uint8, ConnectionState) (*SessionState, error) #60105
+pkg crypto/tls, method (*Config) EncryptTicket(ConnectionState, *SessionState) ([]uint8, error) #60105
+pkg crypto/tls, type Config struct, UnwrapSession func([]uint8, ConnectionState) (*SessionState, error) #60105
+pkg crypto/tls, type Config struct, WrapSession func(ConnectionState, *SessionState) ([]uint8, error) #60105
index ccaf7d352fcdef77b6ed8f4065af0b5a4a25799c..8da3cc50cab2d1a66fcf93ca0dd204eaf30a3e2c 100644 (file)
@@ -673,6 +673,35 @@ type Config struct {
        // session resumption. It is only used by clients.
        ClientSessionCache ClientSessionCache
 
+       // UnwrapSession is called on the server to turn a ticket/identity
+       // previously produced by [WrapSession] into a usable session.
+       //
+       // UnwrapSession will usually either decrypt a session state in the ticket
+       // (for example with [Config.EncryptTicket]), or use the ticket as a handle
+       // to recover a previously stored state. It must use [ParseSessionState] to
+       // deserialize the session state.
+       //
+       // If UnwrapSession returns an error, the connection is terminated. If it
+       // returns (nil, nil), the session is ignored. crypto/tls may still choose
+       // not to resume the returned session.
+       UnwrapSession func(identity []byte, cs ConnectionState) (*SessionState, error)
+
+       // WrapSession is called on the server to produce a session ticket/identity.
+       //
+       // WrapSession must serialize the session state with [SessionState.Bytes].
+       // It may then encrypt the serialized state (for example with
+       // [Config.DecryptTicket]) and use it as the ticket, or store the state and
+       // return a handle for it.
+       //
+       // If WrapSession returns an error, the connection is terminated.
+       //
+       // Warning: the return value will be exposed on the wire and to clients in
+       // plaintext. The application is in charge of encrypting and authenticating
+       // it (and rotating keys) or returning high-entropy identifiers. Failing to
+       // do so correctly can compromise current, previous, and future connections
+       // depending on the protocol version.
+       WrapSession func(ConnectionState, *SessionState) ([]byte, error)
+
        // MinVersion contains the minimum TLS version that is acceptable.
        //
        // By default, TLS 1.2 is currently used as the minimum when acting as a
@@ -794,6 +823,8 @@ func (c *Config) Clone() *Config {
                SessionTicketsDisabled:      c.SessionTicketsDisabled,
                SessionTicketKey:            c.SessionTicketKey,
                ClientSessionCache:          c.ClientSessionCache,
+               UnwrapSession:               c.UnwrapSession,
+               WrapSession:                 c.WrapSession,
                MinVersion:                  c.MinVersion,
                MaxVersion:                  c.MaxVersion,
                CurvePreferences:            c.CurvePreferences,
index f5695df44f10d0551d3edd84a6ea93865f2f3075..7be6f94c36adb4561198343a35b3b6af2e6c61be 100644 (file)
@@ -900,6 +900,7 @@ func testResumption(t *testing.T, version uint16) {
        }
 
        testResumeState := func(test string, didResume bool) {
+               t.Helper()
                _, hs, err := testHandshake(t, clientConfig, serverConfig)
                if err != nil {
                        t.Fatalf("%s: handshake failed: %s", test, err)
@@ -985,9 +986,11 @@ func testResumption(t *testing.T, version uint16) {
 
        // Age the session ticket a bit at a time, but don't expire it.
        d := 0 * time.Hour
+       serverConfig.Time = func() time.Time { return time.Now().Add(d) }
+       deleteTicket()
+       testResumeState("GetFreshSessionTicket", false)
        for i := 0; i < 13; i++ {
                d += 12 * time.Hour
-               serverConfig.Time = func() time.Time { return time.Now().Add(d) }
                testResumeState("OldSessionTicket", true)
        }
        // Expire it (now a little more than 7 days) and make sure a full
@@ -995,7 +998,6 @@ func testResumption(t *testing.T, version uint16) {
        // TLS 1.3 since the client should be using a fresh ticket sent over
        // by the server.
        d += 12 * time.Hour
-       serverConfig.Time = func() time.Time { return time.Now().Add(d) }
        if version == VersionTLS13 {
                testResumeState("ExpiredSessionTicket", true)
        } else {
index 7dda65676ab15455cd5deec9cf09b338ec07f6f3..ef33ab8d84532a441d8c9177f89acdb5b2dbc594 100644 (file)
@@ -70,9 +70,11 @@ func (hs *serverHandshakeState) handshake() error {
 
        // For an overview of TLS handshaking, see RFC 5246, Section 7.3.
        c.buffering = true
-       if hs.checkForResumption() {
+       if err := hs.checkForResumption(); err != nil {
+               return err
+       }
+       if hs.sessionState != nil {
                // The client has included a session ticket and so we do an abbreviated handshake.
-               c.didResume = true
                if err := hs.doResumeHandshake(); err != nil {
                        return err
                }
@@ -399,65 +401,80 @@ func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool {
 }
 
 // checkForResumption reports whether we should perform resumption on this connection.
-func (hs *serverHandshakeState) checkForResumption() bool {
+func (hs *serverHandshakeState) checkForResumption() error {
        c := hs.c
 
        if c.config.SessionTicketsDisabled {
-               return false
+               return nil
        }
 
-       plaintext := c.decryptTicket(hs.clientHello.sessionTicket)
-       if plaintext == nil {
-               return false
-       }
-       ss, err := ParseSessionState(plaintext)
-       if err != nil {
-               return false
+       var sessionState *SessionState
+       if c.config.UnwrapSession != nil {
+               ss, err := c.config.UnwrapSession(hs.clientHello.sessionTicket, c.connectionStateLocked())
+               if err != nil {
+                       return err
+               }
+               if ss == nil {
+                       return nil
+               }
+               sessionState = ss
+       } else {
+               plaintext := c.config.decryptTicket(hs.clientHello.sessionTicket, c.ticketKeys)
+               if plaintext == nil {
+                       return nil
+               }
+               ss, err := ParseSessionState(plaintext)
+               if err != nil {
+                       return nil
+               }
+               sessionState = ss
        }
-       hs.sessionState = ss
 
        // TLS 1.2 tickets don't natively have a lifetime, but we want to avoid
        // re-wrapping the same master secret in different tickets over and over for
        // too long, weakening forward secrecy.
-       createdAt := time.Unix(int64(hs.sessionState.createdAt), 0)
+       createdAt := time.Unix(int64(sessionState.createdAt), 0)
        if c.config.time().Sub(createdAt) > maxSessionTicketLifetime {
-               return false
+               return nil
        }
 
        // Never resume a session for a different TLS version.
-       if c.vers != hs.sessionState.version {
-               return false
+       if c.vers != sessionState.version {
+               return nil
        }
 
        cipherSuiteOk := false
        // Check that the client is still offering the ciphersuite in the session.
        for _, id := range hs.clientHello.cipherSuites {
-               if id == hs.sessionState.cipherSuite {
+               if id == sessionState.cipherSuite {
                        cipherSuiteOk = true
                        break
                }
        }
        if !cipherSuiteOk {
-               return false
+               return nil
        }
 
        // Check that we also support the ciphersuite from the session.
-       hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite},
+       suite := selectCipherSuite([]uint16{sessionState.cipherSuite},
                c.config.cipherSuites(), hs.cipherSuiteOk)
-       if hs.suite == nil {
-               return false
+       if suite == nil {
+               return nil
        }
 
-       sessionHasClientCerts := len(hs.sessionState.peerCertificates) != 0
+       sessionHasClientCerts := len(sessionState.peerCertificates) != 0
        needClientCerts := requiresClientCert(c.config.ClientAuth)
        if needClientCerts && !sessionHasClientCerts {
-               return false
+               return nil
        }
        if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
-               return false
+               return nil
        }
 
-       return true
+       hs.sessionState = sessionState
+       hs.suite = suite
+       c.didResume = true
+       return nil
 }
 
 func (hs *serverHandshakeState) doResumeHandshake() error {
@@ -769,13 +786,20 @@ func (hs *serverHandshakeState) sendSessionTicket() error {
                // the original time it was created.
                state.createdAt = hs.sessionState.createdAt
        }
-       stateBytes, err := state.Bytes()
-       if err != nil {
-               return err
-       }
-       m.ticket, err = c.encryptTicket(stateBytes)
-       if err != nil {
-               return err
+       if c.config.WrapSession != nil {
+               m.ticket, err = c.config.WrapSession(c.connectionStateLocked(), state)
+               if err != nil {
+                       return err
+               }
+       } else {
+               stateBytes, err := state.Bytes()
+               if err != nil {
+                       return err
+               }
+               m.ticket, err = c.config.encryptTicket(stateBytes, c.ticketKeys)
+               if err != nil {
+                       return err
+               }
        }
 
        if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil {
index 6753ad4aee05928ba073ed26c05b18ca10495e9d..0b6b16eb1245904053b7265393d4f3dcd492da27 100644 (file)
@@ -275,12 +275,29 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
                        break
                }
 
-               plaintext := c.decryptTicket(identity.label)
-               if plaintext == nil {
-                       continue
+               var sessionState *SessionState
+               if c.config.UnwrapSession != nil {
+                       var err error
+                       sessionState, err = c.config.UnwrapSession(identity.label, c.connectionStateLocked())
+                       if err != nil {
+                               return err
+                       }
+                       if sessionState == nil {
+                               continue
+                       }
+               } else {
+                       plaintext := c.config.decryptTicket(identity.label, c.ticketKeys)
+                       if plaintext == nil {
+                               continue
+                       }
+                       var err error
+                       sessionState, err = ParseSessionState(plaintext)
+                       if err != nil {
+                               continue
+                       }
                }
-               sessionState, err := ParseSessionState(plaintext)
-               if err != nil || sessionState.version != VersionTLS13 {
+
+               if sessionState.version != VersionTLS13 {
                        continue
                }
 
@@ -781,14 +798,21 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
                return err
        }
        state.secret = psk
-       stateBytes, err := state.Bytes()
-       if err != nil {
-               c.sendAlert(alertInternalError)
-               return err
-       }
-       m.label, err = c.encryptTicket(stateBytes)
-       if err != nil {
-               return err
+       if c.config.WrapSession != nil {
+               m.label, err = c.config.WrapSession(c.connectionStateLocked(), state)
+               if err != nil {
+                       return err
+               }
+       } else {
+               stateBytes, err := state.Bytes()
+               if err != nil {
+                       c.sendAlert(alertInternalError)
+                       return err
+               }
+               m.label, err = c.config.encryptTicket(stateBytes, c.ticketKeys)
+               if err != nil {
+                       return err
+               }
        }
        m.lifetime = uint32(maxSessionTicketLifetime / time.Second)
 
index 4eacd4305576298ad86b7e6ef9587cf8e43a8a14..2ea65a6b41272a785b71d0151f72add136ef45c0 100644 (file)
@@ -228,8 +228,21 @@ func (c *Conn) sessionState() (*SessionState, error) {
        }, nil
 }
 
-func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
-       if len(c.ticketKeys) == 0 {
+// EncryptTicket encrypts a ticket with the Config's configured (or default)
+// session ticket keys. It can be used as a [Config.WrapSession] implementation.
+func (c *Config) EncryptTicket(cs ConnectionState, ss *SessionState) ([]byte, error) {
+       ticketKeys := c.ticketKeys(nil)
+       stateBytes, err := ss.Bytes()
+       if err != nil {
+               return nil, err
+       }
+       return c.encryptTicket(stateBytes, ticketKeys)
+}
+
+var _ = &Config{WrapSession: (&Config{}).EncryptTicket}
+
+func (c *Config) encryptTicket(state []byte, ticketKeys []ticketKey) ([]byte, error) {
+       if len(ticketKeys) == 0 {
                return nil, errors.New("tls: internal error: session ticket keys unavailable")
        }
 
@@ -239,10 +252,10 @@ func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
        authenticated := encrypted[:len(encrypted)-sha256.Size]
        macBytes := encrypted[len(encrypted)-sha256.Size:]
 
-       if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
+       if _, err := io.ReadFull(c.rand(), iv); err != nil {
                return nil, err
        }
-       key := c.ticketKeys[0]
+       key := ticketKeys[0]
        block, err := aes.NewCipher(key.aesKey[:])
        if err != nil {
                return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
@@ -256,7 +269,26 @@ func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
        return encrypted, nil
 }
 
-func (c *Conn) decryptTicket(encrypted []byte) []byte {
+// DecryptTicket decrypts a ticket encrypted by [Config.EncryptTicket]. It can
+// be used as a [Config.UnwrapSession] implementation.
+//
+// If the ticket can't be decrypted or parsed, DecryptTicket returns (nil, nil).
+func (c *Config) DecryptTicket(identity []byte, cs ConnectionState) (*SessionState, error) {
+       ticketKeys := c.ticketKeys(nil)
+       stateBytes := c.decryptTicket(identity, ticketKeys)
+       if stateBytes == nil {
+               return nil, nil
+       }
+       s, err := ParseSessionState(stateBytes)
+       if err != nil {
+               return nil, nil // drop unparsable tickets on the floor
+       }
+       return s, nil
+}
+
+var _ = &Config{UnwrapSession: (&Config{}).DecryptTicket}
+
+func (c *Config) decryptTicket(encrypted []byte, ticketKeys []ticketKey) []byte {
        if len(encrypted) < aes.BlockSize+sha256.Size {
                return nil
        }
@@ -266,7 +298,7 @@ func (c *Conn) decryptTicket(encrypted []byte) []byte {
        authenticated := encrypted[:len(encrypted)-sha256.Size]
        macBytes := encrypted[len(encrypted)-sha256.Size:]
 
-       for _, key := range c.ticketKeys {
+       for _, key := range ticketKeys {
                mac := hmac.New(sha256.New, key.hmacKey[:])
                mac.Write(authenticated)
                expected := mac.Sum(nil)
index 3e43a56f22a0ee881ad41bae0520b1f867cbe3f6..d7691f41bdf3ab3f7c3da32c9f1cfa6a3b2936f9 100644 (file)
@@ -758,7 +758,7 @@ func TestWarningAlertFlood(t *testing.T) {
 }
 
 func TestCloneFuncFields(t *testing.T) {
-       const expectedCount = 6
+       const expectedCount = 8
        called := 0
 
        c1 := Config{
@@ -786,6 +786,14 @@ func TestCloneFuncFields(t *testing.T) {
                        called |= 1 << 5
                        return nil
                },
+               UnwrapSession: func(identity []byte, cs ConnectionState) (*SessionState, error) {
+                       called |= 1 << 6
+                       return nil, nil
+               },
+               WrapSession: func(cs ConnectionState, ss *SessionState) ([]byte, error) {
+                       called |= 1 << 7
+                       return nil, nil
+               },
        }
 
        c2 := c1.Clone()
@@ -796,6 +804,8 @@ func TestCloneFuncFields(t *testing.T) {
        c2.GetConfigForClient(nil)
        c2.VerifyPeerCertificate(nil, nil)
        c2.VerifyConnection(ConnectionState{})
+       c2.UnwrapSession(nil, ConnectionState{})
+       c2.WrapSession(ConnectionState{}, nil)
 
        if called != (1<<expectedCount)-1 {
                t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
@@ -814,7 +824,7 @@ func TestCloneNonFuncFields(t *testing.T) {
                switch fn := typ.Field(i).Name; fn {
                case "Rand":
                        f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
-               case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate":
+               case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate", "WrapSession", "UnwrapSession":
                        // DeepEqual can't compare functions. If you add a
                        // function field to this list, you must also change
                        // TestCloneFuncFields to ensure that the func field is