From: Filippo Valsorda Date: Mon, 22 May 2023 08:49:07 +0000 (+0200) Subject: crypto/tls: add WrapSession and UnwrapSession X-Git-Tag: go1.21rc1~249 X-Git-Url: http://www.git.cypherpunks.ru/?a=commitdiff_plain;h=6824765b4b981291712ae6d60702f6f0350f57d5;p=gostls13.git crypto/tls: add WrapSession and UnwrapSession There was a bug in TestResumption: the first ExpiredSessionTicket was inserting a ticket far in the future, so the second ExpiredSessionTicket wasn't actually supposed to fail. However, there was a bug in checkForResumption->sendSessionTicket, too: if a session was not resumed because it was too old, its createdAt was still persisted in the next ticket. The two bugs used to cancel each other out. For #60105 Fixes #19199 Change-Id: Ic9b2aab943dcbf0de62b8758a6195319dc286e2f Reviewed-on: https://go-review.googlesource.com/c/go/+/496821 Run-TryBot: Filippo Valsorda TryBot-Result: Gopher Robot Reviewed-by: Roland Shoemaker Reviewed-by: Damien Neil --- diff --git a/api/next/60105.txt b/api/next/60105.txt index 251f574c8f..03fb68fa3a 100644 --- a/api/next/60105.txt +++ b/api/next/60105.txt @@ -3,3 +3,7 @@ pkg crypto/tls, method (*SessionState) Bytes() ([]uint8, error) #60105 pkg crypto/tls, type SessionState struct #60105 pkg crypto/tls, func NewResumptionState([]uint8, *SessionState) (*ClientSessionState, error) #60105 pkg crypto/tls, method (*ClientSessionState) ResumptionState() ([]uint8, *SessionState, error) #60105 +pkg crypto/tls, method (*Config) DecryptTicket([]uint8, ConnectionState) (*SessionState, error) #60105 +pkg crypto/tls, method (*Config) EncryptTicket(ConnectionState, *SessionState) ([]uint8, error) #60105 +pkg crypto/tls, type Config struct, UnwrapSession func([]uint8, ConnectionState) (*SessionState, error) #60105 +pkg crypto/tls, type Config struct, WrapSession func(ConnectionState, *SessionState) ([]uint8, error) #60105 diff --git a/src/crypto/tls/common.go b/src/crypto/tls/common.go index ccaf7d352f..8da3cc50ca 100644 --- a/src/crypto/tls/common.go +++ b/src/crypto/tls/common.go @@ -673,6 +673,35 @@ type Config struct { // session resumption. It is only used by clients. ClientSessionCache ClientSessionCache + // UnwrapSession is called on the server to turn a ticket/identity + // previously produced by [WrapSession] into a usable session. + // + // UnwrapSession will usually either decrypt a session state in the ticket + // (for example with [Config.EncryptTicket]), or use the ticket as a handle + // to recover a previously stored state. It must use [ParseSessionState] to + // deserialize the session state. + // + // If UnwrapSession returns an error, the connection is terminated. If it + // returns (nil, nil), the session is ignored. crypto/tls may still choose + // not to resume the returned session. + UnwrapSession func(identity []byte, cs ConnectionState) (*SessionState, error) + + // WrapSession is called on the server to produce a session ticket/identity. + // + // WrapSession must serialize the session state with [SessionState.Bytes]. + // It may then encrypt the serialized state (for example with + // [Config.DecryptTicket]) and use it as the ticket, or store the state and + // return a handle for it. + // + // If WrapSession returns an error, the connection is terminated. + // + // Warning: the return value will be exposed on the wire and to clients in + // plaintext. The application is in charge of encrypting and authenticating + // it (and rotating keys) or returning high-entropy identifiers. Failing to + // do so correctly can compromise current, previous, and future connections + // depending on the protocol version. + WrapSession func(ConnectionState, *SessionState) ([]byte, error) + // MinVersion contains the minimum TLS version that is acceptable. // // By default, TLS 1.2 is currently used as the minimum when acting as a @@ -794,6 +823,8 @@ func (c *Config) Clone() *Config { SessionTicketsDisabled: c.SessionTicketsDisabled, SessionTicketKey: c.SessionTicketKey, ClientSessionCache: c.ClientSessionCache, + UnwrapSession: c.UnwrapSession, + WrapSession: c.WrapSession, MinVersion: c.MinVersion, MaxVersion: c.MaxVersion, CurvePreferences: c.CurvePreferences, diff --git a/src/crypto/tls/handshake_client_test.go b/src/crypto/tls/handshake_client_test.go index f5695df44f..7be6f94c36 100644 --- a/src/crypto/tls/handshake_client_test.go +++ b/src/crypto/tls/handshake_client_test.go @@ -900,6 +900,7 @@ func testResumption(t *testing.T, version uint16) { } testResumeState := func(test string, didResume bool) { + t.Helper() _, hs, err := testHandshake(t, clientConfig, serverConfig) if err != nil { t.Fatalf("%s: handshake failed: %s", test, err) @@ -985,9 +986,11 @@ func testResumption(t *testing.T, version uint16) { // Age the session ticket a bit at a time, but don't expire it. d := 0 * time.Hour + serverConfig.Time = func() time.Time { return time.Now().Add(d) } + deleteTicket() + testResumeState("GetFreshSessionTicket", false) for i := 0; i < 13; i++ { d += 12 * time.Hour - serverConfig.Time = func() time.Time { return time.Now().Add(d) } testResumeState("OldSessionTicket", true) } // Expire it (now a little more than 7 days) and make sure a full @@ -995,7 +998,6 @@ func testResumption(t *testing.T, version uint16) { // TLS 1.3 since the client should be using a fresh ticket sent over // by the server. d += 12 * time.Hour - serverConfig.Time = func() time.Time { return time.Now().Add(d) } if version == VersionTLS13 { testResumeState("ExpiredSessionTicket", true) } else { diff --git a/src/crypto/tls/handshake_server.go b/src/crypto/tls/handshake_server.go index 7dda65676a..ef33ab8d84 100644 --- a/src/crypto/tls/handshake_server.go +++ b/src/crypto/tls/handshake_server.go @@ -70,9 +70,11 @@ func (hs *serverHandshakeState) handshake() error { // For an overview of TLS handshaking, see RFC 5246, Section 7.3. c.buffering = true - if hs.checkForResumption() { + if err := hs.checkForResumption(); err != nil { + return err + } + if hs.sessionState != nil { // The client has included a session ticket and so we do an abbreviated handshake. - c.didResume = true if err := hs.doResumeHandshake(); err != nil { return err } @@ -399,65 +401,80 @@ func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool { } // checkForResumption reports whether we should perform resumption on this connection. -func (hs *serverHandshakeState) checkForResumption() bool { +func (hs *serverHandshakeState) checkForResumption() error { c := hs.c if c.config.SessionTicketsDisabled { - return false + return nil } - plaintext := c.decryptTicket(hs.clientHello.sessionTicket) - if plaintext == nil { - return false - } - ss, err := ParseSessionState(plaintext) - if err != nil { - return false + var sessionState *SessionState + if c.config.UnwrapSession != nil { + ss, err := c.config.UnwrapSession(hs.clientHello.sessionTicket, c.connectionStateLocked()) + if err != nil { + return err + } + if ss == nil { + return nil + } + sessionState = ss + } else { + plaintext := c.config.decryptTicket(hs.clientHello.sessionTicket, c.ticketKeys) + if plaintext == nil { + return nil + } + ss, err := ParseSessionState(plaintext) + if err != nil { + return nil + } + sessionState = ss } - hs.sessionState = ss // TLS 1.2 tickets don't natively have a lifetime, but we want to avoid // re-wrapping the same master secret in different tickets over and over for // too long, weakening forward secrecy. - createdAt := time.Unix(int64(hs.sessionState.createdAt), 0) + createdAt := time.Unix(int64(sessionState.createdAt), 0) if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { - return false + return nil } // Never resume a session for a different TLS version. - if c.vers != hs.sessionState.version { - return false + if c.vers != sessionState.version { + return nil } cipherSuiteOk := false // Check that the client is still offering the ciphersuite in the session. for _, id := range hs.clientHello.cipherSuites { - if id == hs.sessionState.cipherSuite { + if id == sessionState.cipherSuite { cipherSuiteOk = true break } } if !cipherSuiteOk { - return false + return nil } // Check that we also support the ciphersuite from the session. - hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite}, + suite := selectCipherSuite([]uint16{sessionState.cipherSuite}, c.config.cipherSuites(), hs.cipherSuiteOk) - if hs.suite == nil { - return false + if suite == nil { + return nil } - sessionHasClientCerts := len(hs.sessionState.peerCertificates) != 0 + sessionHasClientCerts := len(sessionState.peerCertificates) != 0 needClientCerts := requiresClientCert(c.config.ClientAuth) if needClientCerts && !sessionHasClientCerts { - return false + return nil } if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { - return false + return nil } - return true + hs.sessionState = sessionState + hs.suite = suite + c.didResume = true + return nil } func (hs *serverHandshakeState) doResumeHandshake() error { @@ -769,13 +786,20 @@ func (hs *serverHandshakeState) sendSessionTicket() error { // the original time it was created. state.createdAt = hs.sessionState.createdAt } - stateBytes, err := state.Bytes() - if err != nil { - return err - } - m.ticket, err = c.encryptTicket(stateBytes) - if err != nil { - return err + if c.config.WrapSession != nil { + m.ticket, err = c.config.WrapSession(c.connectionStateLocked(), state) + if err != nil { + return err + } + } else { + stateBytes, err := state.Bytes() + if err != nil { + return err + } + m.ticket, err = c.config.encryptTicket(stateBytes, c.ticketKeys) + if err != nil { + return err + } } if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil { diff --git a/src/crypto/tls/handshake_server_tls13.go b/src/crypto/tls/handshake_server_tls13.go index 6753ad4aee..0b6b16eb12 100644 --- a/src/crypto/tls/handshake_server_tls13.go +++ b/src/crypto/tls/handshake_server_tls13.go @@ -275,12 +275,29 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { break } - plaintext := c.decryptTicket(identity.label) - if plaintext == nil { - continue + var sessionState *SessionState + if c.config.UnwrapSession != nil { + var err error + sessionState, err = c.config.UnwrapSession(identity.label, c.connectionStateLocked()) + if err != nil { + return err + } + if sessionState == nil { + continue + } + } else { + plaintext := c.config.decryptTicket(identity.label, c.ticketKeys) + if plaintext == nil { + continue + } + var err error + sessionState, err = ParseSessionState(plaintext) + if err != nil { + continue + } } - sessionState, err := ParseSessionState(plaintext) - if err != nil || sessionState.version != VersionTLS13 { + + if sessionState.version != VersionTLS13 { continue } @@ -781,14 +798,21 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { return err } state.secret = psk - stateBytes, err := state.Bytes() - if err != nil { - c.sendAlert(alertInternalError) - return err - } - m.label, err = c.encryptTicket(stateBytes) - if err != nil { - return err + if c.config.WrapSession != nil { + m.label, err = c.config.WrapSession(c.connectionStateLocked(), state) + if err != nil { + return err + } + } else { + stateBytes, err := state.Bytes() + if err != nil { + c.sendAlert(alertInternalError) + return err + } + m.label, err = c.config.encryptTicket(stateBytes, c.ticketKeys) + if err != nil { + return err + } } m.lifetime = uint32(maxSessionTicketLifetime / time.Second) diff --git a/src/crypto/tls/ticket.go b/src/crypto/tls/ticket.go index 4eacd43055..2ea65a6b41 100644 --- a/src/crypto/tls/ticket.go +++ b/src/crypto/tls/ticket.go @@ -228,8 +228,21 @@ func (c *Conn) sessionState() (*SessionState, error) { }, nil } -func (c *Conn) encryptTicket(state []byte) ([]byte, error) { - if len(c.ticketKeys) == 0 { +// EncryptTicket encrypts a ticket with the Config's configured (or default) +// session ticket keys. It can be used as a [Config.WrapSession] implementation. +func (c *Config) EncryptTicket(cs ConnectionState, ss *SessionState) ([]byte, error) { + ticketKeys := c.ticketKeys(nil) + stateBytes, err := ss.Bytes() + if err != nil { + return nil, err + } + return c.encryptTicket(stateBytes, ticketKeys) +} + +var _ = &Config{WrapSession: (&Config{}).EncryptTicket} + +func (c *Config) encryptTicket(state []byte, ticketKeys []ticketKey) ([]byte, error) { + if len(ticketKeys) == 0 { return nil, errors.New("tls: internal error: session ticket keys unavailable") } @@ -239,10 +252,10 @@ func (c *Conn) encryptTicket(state []byte) ([]byte, error) { authenticated := encrypted[:len(encrypted)-sha256.Size] macBytes := encrypted[len(encrypted)-sha256.Size:] - if _, err := io.ReadFull(c.config.rand(), iv); err != nil { + if _, err := io.ReadFull(c.rand(), iv); err != nil { return nil, err } - key := c.ticketKeys[0] + key := ticketKeys[0] block, err := aes.NewCipher(key.aesKey[:]) if err != nil { return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) @@ -256,7 +269,26 @@ func (c *Conn) encryptTicket(state []byte) ([]byte, error) { return encrypted, nil } -func (c *Conn) decryptTicket(encrypted []byte) []byte { +// DecryptTicket decrypts a ticket encrypted by [Config.EncryptTicket]. It can +// be used as a [Config.UnwrapSession] implementation. +// +// If the ticket can't be decrypted or parsed, DecryptTicket returns (nil, nil). +func (c *Config) DecryptTicket(identity []byte, cs ConnectionState) (*SessionState, error) { + ticketKeys := c.ticketKeys(nil) + stateBytes := c.decryptTicket(identity, ticketKeys) + if stateBytes == nil { + return nil, nil + } + s, err := ParseSessionState(stateBytes) + if err != nil { + return nil, nil // drop unparsable tickets on the floor + } + return s, nil +} + +var _ = &Config{UnwrapSession: (&Config{}).DecryptTicket} + +func (c *Config) decryptTicket(encrypted []byte, ticketKeys []ticketKey) []byte { if len(encrypted) < aes.BlockSize+sha256.Size { return nil } @@ -266,7 +298,7 @@ func (c *Conn) decryptTicket(encrypted []byte) []byte { authenticated := encrypted[:len(encrypted)-sha256.Size] macBytes := encrypted[len(encrypted)-sha256.Size:] - for _, key := range c.ticketKeys { + for _, key := range ticketKeys { mac := hmac.New(sha256.New, key.hmacKey[:]) mac.Write(authenticated) expected := mac.Sum(nil) diff --git a/src/crypto/tls/tls_test.go b/src/crypto/tls/tls_test.go index 3e43a56f22..d7691f41bd 100644 --- a/src/crypto/tls/tls_test.go +++ b/src/crypto/tls/tls_test.go @@ -758,7 +758,7 @@ func TestWarningAlertFlood(t *testing.T) { } func TestCloneFuncFields(t *testing.T) { - const expectedCount = 6 + const expectedCount = 8 called := 0 c1 := Config{ @@ -786,6 +786,14 @@ func TestCloneFuncFields(t *testing.T) { called |= 1 << 5 return nil }, + UnwrapSession: func(identity []byte, cs ConnectionState) (*SessionState, error) { + called |= 1 << 6 + return nil, nil + }, + WrapSession: func(cs ConnectionState, ss *SessionState) ([]byte, error) { + called |= 1 << 7 + return nil, nil + }, } c2 := c1.Clone() @@ -796,6 +804,8 @@ func TestCloneFuncFields(t *testing.T) { c2.GetConfigForClient(nil) c2.VerifyPeerCertificate(nil, nil) c2.VerifyConnection(ConnectionState{}) + c2.UnwrapSession(nil, ConnectionState{}) + c2.WrapSession(ConnectionState{}, nil) if called != (1<