]> Cypherpunks.ru repositories - gostls13.git/commitdiff
crypto/tls: add Config.GetConfigForClient
authorAdam Langley <agl@golang.org>
Mon, 10 Oct 2016 22:27:34 +0000 (15:27 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 18 Oct 2016 06:44:05 +0000 (06:44 +0000)
GetConfigForClient allows the tls.Config to be updated on a per-client
basis.

Fixes #16066.
Fixes #15707.
Fixes #15699.

Change-Id: I2c675a443d557f969441226729f98502b38901ea
Reviewed-on: https://go-review.googlesource.com/30790
Run-TryBot: Adam Langley <agl@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/crypto/tls/common.go
src/crypto/tls/handshake_server.go
src/crypto/tls/handshake_server_test.go
src/crypto/tls/tls_test.go

index 021f475700e5f50cf9b5dca1a1f9e8227db9ec77..7199cd9d714b7c031a3aa8a01153ca7bca2a8eec 100644 (file)
@@ -303,7 +303,27 @@ type Config struct {
        // If GetCertificate is nil or returns nil, then the certificate is
        // retrieved from NameToCertificate. If NameToCertificate is nil, the
        // first element of Certificates will be used.
-       GetCertificate func(clientHello *ClientHelloInfo) (*Certificate, error)
+       GetCertificate func(*ClientHelloInfo) (*Certificate, error)
+
+       // GetConfigForClient, if not nil, is called after a ClientHello is
+       // received from a client. It may return a non-nil Config in order to
+       // change the Config that will be used to handle this connection. If
+       // the returned Config is nil, the original Config will be used. The
+       // Config returned by this callback may not be subsequently modified.
+       //
+       // If GetConfigForClient is nil, the Config passed to Server() will be
+       // used for all connections.
+       //
+       // Uniquely for the fields in the returned Config, session ticket keys
+       // will be duplicated from the original Config if not set.
+       // Specifically, if SetSessionTicketKeys was called on the original
+       // config but not on the returned config then the ticket keys from the
+       // original config will be copied into the new config before use.
+       // Otherwise, if SessionTicketKey was set in the original config but
+       // not in the returned config then it will be copied into the returned
+       // config before use. If neither of those cases applies then the key
+       // material from the returned config will be used for session tickets.
+       GetConfigForClient func(*ClientHelloInfo) (*Config, error)
 
        // RootCAs defines the set of root certificate authorities
        // that clients use when verifying server certificates.
@@ -398,13 +418,17 @@ type Config struct {
 
        serverInitOnce sync.Once // guards calling (*Config).serverInit
 
-       // mutex protects sessionTicketKeys
+       // mutex protects sessionTicketKeys and originalConfig.
        mutex sync.RWMutex
        // sessionTicketKeys contains zero or more ticket keys. If the length
        // is zero, SessionTicketsDisabled must be true. The first key is used
        // for new tickets and any subsequent keys can be used to decrypt old
        // tickets.
        sessionTicketKeys []ticketKey
+       // originalConfig is set to the Config that was passed to Server if
+       // this Config is returned by a GetConfigForClient callback. It's used
+       // by serverInit in order to copy session ticket keys if needed.
+       originalConfig *Config
 }
 
 // ticketKeyNameLen is the number of bytes of identifier that is prepended to
@@ -434,12 +458,18 @@ func ticketKeyFromBytes(b [32]byte) (key ticketKey) {
 // Clone returns a shallow clone of c.
 // Only the exported fields are copied.
 func (c *Config) Clone() *Config {
+       var sessionTicketKeys []ticketKey
+       c.mutex.RLock()
+       sessionTicketKeys = c.sessionTicketKeys
+       c.mutex.RUnlock()
+
        return &Config{
                Rand:                        c.Rand,
                Time:                        c.Time,
                Certificates:                c.Certificates,
                NameToCertificate:           c.NameToCertificate,
                GetCertificate:              c.GetCertificate,
+               GetConfigForClient:          c.GetConfigForClient,
                RootCAs:                     c.RootCAs,
                NextProtos:                  c.NextProtos,
                ServerName:                  c.ServerName,
@@ -457,6 +487,8 @@ func (c *Config) Clone() *Config {
                DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
                Renegotiation:               c.Renegotiation,
                KeyLogWriter:                c.KeyLogWriter,
+               sessionTicketKeys:           sessionTicketKeys,
+               // originalConfig is deliberately not duplicated.
        }
 }
 
@@ -465,6 +497,11 @@ func (c *Config) serverInit() {
                return
        }
 
+       var originalConfig *Config
+       c.mutex.Lock()
+       originalConfig, c.originalConfig = c.originalConfig, nil
+       c.mutex.Unlock()
+
        alreadySet := false
        for _, b := range c.SessionTicketKey {
                if b != 0 {
@@ -474,13 +511,21 @@ func (c *Config) serverInit() {
        }
 
        if !alreadySet {
-               if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil {
+               if originalConfig != nil {
+                       copy(c.SessionTicketKey[:], originalConfig.SessionTicketKey[:])
+               } else if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil {
                        c.SessionTicketsDisabled = true
                        return
                }
        }
 
-       c.sessionTicketKeys = []ticketKey{ticketKeyFromBytes(c.SessionTicketKey)}
+       if originalConfig != nil {
+               originalConfig.mutex.RLock()
+               c.sessionTicketKeys = originalConfig.sessionTicketKeys
+               originalConfig.mutex.RUnlock()
+       } else {
+               c.sessionTicketKeys = []ticketKey{ticketKeyFromBytes(c.SessionTicketKey)}
+       }
 }
 
 func (c *Config) ticketKeys() []ticketKey {
index 17141ae43cf4a15ad02f96840535dd4bf08bdaf6..630a99acebf04d627964dc398f90abdde0e59d18 100644 (file)
@@ -37,11 +37,9 @@ type serverHandshakeState struct {
 // serverHandshake performs a TLS handshake as a server.
 // c.out.Mutex <= L; c.handshakeMutex <= L.
 func (c *Conn) serverHandshake() error {
-       config := c.config
-
        // If this is the first server handshake, we generate a random key to
        // encrypt the tickets with.
-       config.serverInitOnce.Do(config.serverInit)
+       c.config.serverInitOnce.Do(c.config.serverInit)
 
        hs := serverHandshakeState{
                c: c,
@@ -112,7 +110,6 @@ func (c *Conn) serverHandshake() error {
 // readClientHello reads a ClientHello message from the client and decides
 // whether we will perform session resumption.
 func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) {
-       config := hs.c.config
        c := hs.c
 
        msg, err := c.readHandshake()
@@ -125,7 +122,29 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) {
                c.sendAlert(alertUnexpectedMessage)
                return false, unexpectedMessageError(hs.clientHello, msg)
        }
-       c.vers, ok = config.mutualVersion(hs.clientHello.vers)
+
+       clientHelloInfo := &ClientHelloInfo{
+               CipherSuites:    hs.clientHello.cipherSuites,
+               ServerName:      hs.clientHello.serverName,
+               SupportedCurves: hs.clientHello.supportedCurves,
+               SupportedPoints: hs.clientHello.supportedPoints,
+       }
+
+       if c.config.GetConfigForClient != nil {
+               if newConfig, err := c.config.GetConfigForClient(clientHelloInfo); err != nil {
+                       c.sendAlert(alertInternalError)
+                       return false, err
+               } else if newConfig != nil {
+                       newConfig.mutex.Lock()
+                       newConfig.originalConfig = c.config
+                       newConfig.mutex.Unlock()
+
+                       newConfig.serverInitOnce.Do(newConfig.serverInit)
+                       c.config = newConfig
+               }
+       }
+
+       c.vers, ok = c.config.mutualVersion(hs.clientHello.vers)
        if !ok {
                c.sendAlert(alertProtocolVersion)
                return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers)
@@ -135,7 +154,7 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) {
        hs.hello = new(serverHelloMsg)
 
        supportedCurve := false
-       preferredCurves := config.curvePreferences()
+       preferredCurves := c.config.curvePreferences()
 Curves:
        for _, curve := range hs.clientHello.supportedCurves {
                for _, supported := range preferredCurves {
@@ -171,7 +190,7 @@ Curves:
 
        hs.hello.vers = c.vers
        hs.hello.random = make([]byte, 32)
-       _, err = io.ReadFull(config.rand(), hs.hello.random)
+       _, err = io.ReadFull(c.config.rand(), hs.hello.random)
        if err != nil {
                c.sendAlert(alertInternalError)
                return false, err
@@ -196,20 +215,15 @@ Curves:
        } else {
                // Although sending an empty NPN extension is reasonable, Firefox has
                // had a bug around this. Best to send nothing at all if
-               // config.NextProtos is empty. See
+               // c.config.NextProtos is empty. See
                // https://golang.org/issue/5445.
-               if hs.clientHello.nextProtoNeg && len(config.NextProtos) > 0 {
+               if hs.clientHello.nextProtoNeg && len(c.config.NextProtos) > 0 {
                        hs.hello.nextProtoNeg = true
-                       hs.hello.nextProtos = config.NextProtos
+                       hs.hello.nextProtos = c.config.NextProtos
                }
        }
 
-       hs.cert, err = config.getCertificate(&ClientHelloInfo{
-               CipherSuites:    hs.clientHello.cipherSuites,
-               ServerName:      hs.clientHello.serverName,
-               SupportedCurves: hs.clientHello.supportedCurves,
-               SupportedPoints: hs.clientHello.supportedPoints,
-       })
+       hs.cert, err = c.config.getCertificate(clientHelloInfo)
        if err != nil {
                c.sendAlert(alertInternalError)
                return false, err
@@ -354,18 +368,17 @@ func (hs *serverHandshakeState) doResumeHandshake() error {
 }
 
 func (hs *serverHandshakeState) doFullHandshake() error {
-       config := hs.c.config
        c := hs.c
 
        if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 {
                hs.hello.ocspStapling = true
        }
 
-       hs.hello.ticketSupported = hs.clientHello.ticketSupported && !config.SessionTicketsDisabled
+       hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled
        hs.hello.cipherSuite = hs.suite.id
 
        hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite)
-       if config.ClientAuth == NoClientCert {
+       if c.config.ClientAuth == NoClientCert {
                // No need to keep a full record of the handshake if client
                // certificates won't be used.
                hs.finishedHash.discardHandshakeBuffer()
@@ -394,7 +407,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
        }
 
        keyAgreement := hs.suite.ka(c.vers)
-       skx, err := keyAgreement.generateServerKeyExchange(config, hs.cert, hs.clientHello, hs.hello)
+       skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello)
        if err != nil {
                c.sendAlert(alertHandshakeFailure)
                return err
@@ -406,7 +419,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
                }
        }
 
-       if config.ClientAuth >= RequestClientCert {
+       if c.config.ClientAuth >= RequestClientCert {
                // Request a client certificate
                certReq := new(certificateRequestMsg)
                certReq.certificateTypes = []byte{
@@ -423,8 +436,8 @@ func (hs *serverHandshakeState) doFullHandshake() error {
                // to our request. When we know the CAs we trust, then
                // we can send them down, so that the client can choose
                // an appropriate certificate to give to us.
-               if config.ClientCAs != nil {
-                       certReq.certificateAuthorities = config.ClientCAs.Subjects()
+               if c.config.ClientCAs != nil {
+                       certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
                }
                hs.finishedHash.Write(certReq.marshal())
                if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil {
@@ -452,7 +465,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
        var ok bool
        // If we requested a client certificate, then the client must send a
        // certificate message, even if it's empty.
-       if config.ClientAuth >= RequestClientCert {
+       if c.config.ClientAuth >= RequestClientCert {
                if certMsg, ok = msg.(*certificateMsg); !ok {
                        c.sendAlert(alertUnexpectedMessage)
                        return unexpectedMessageError(certMsg, msg)
@@ -461,7 +474,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
 
                if len(certMsg.certificates) == 0 {
                        // The client didn't actually send a certificate
-                       switch config.ClientAuth {
+                       switch c.config.ClientAuth {
                        case RequireAnyClientCert, RequireAndVerifyClientCert:
                                c.sendAlert(alertBadCertificate)
                                return errors.New("tls: client didn't provide a certificate")
@@ -487,13 +500,13 @@ func (hs *serverHandshakeState) doFullHandshake() error {
        }
        hs.finishedHash.Write(ckx.marshal())
 
-       preMasterSecret, err := keyAgreement.processClientKeyExchange(config, hs.cert, ckx, c.vers)
+       preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers)
        if err != nil {
                c.sendAlert(alertHandshakeFailure)
                return err
        }
        hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random)
-       if err := config.writeKeyLog(hs.clientHello.random, hs.masterSecret); err != nil {
+       if err := c.config.writeKeyLog(hs.clientHello.random, hs.masterSecret); err != nil {
                c.sendAlert(alertInternalError)
                return err
        }
index 38d8275a9a71284dd8b410ab948971cb42e0e95b..94193bdac76eaa735caf96d4b36d72d293ab06af 100644 (file)
@@ -1141,6 +1141,141 @@ func TestSNIGivenOnFailure(t *testing.T) {
        }
 }
 
+var getConfigForClientTests = []struct {
+       setup          func(config *Config)
+       callback       func(clientHello *ClientHelloInfo) (*Config, error)
+       errorSubstring string
+       verify         func(config *Config) error
+}{
+       {
+               nil,
+               func(clientHello *ClientHelloInfo) (*Config, error) {
+                       return nil, nil
+               },
+               "",
+               nil,
+       },
+       {
+               nil,
+               func(clientHello *ClientHelloInfo) (*Config, error) {
+                       return nil, errors.New("should bubble up")
+               },
+               "should bubble up",
+               nil,
+       },
+       {
+               nil,
+               func(clientHello *ClientHelloInfo) (*Config, error) {
+                       config := testConfig.Clone()
+                       // Setting a maximum version of TLS 1.1 should cause
+                       // the handshake to fail.
+                       config.MaxVersion = VersionTLS11
+                       return config, nil
+               },
+               "version 301 when expecting version 302",
+               nil,
+       },
+       {
+               func(config *Config) {
+                       for i := range config.SessionTicketKey {
+                               config.SessionTicketKey[i] = byte(i)
+                       }
+                       config.sessionTicketKeys = nil
+               },
+               func(clientHello *ClientHelloInfo) (*Config, error) {
+                       config := testConfig.Clone()
+                       for i := range config.SessionTicketKey {
+                               config.SessionTicketKey[i] = 0
+                       }
+                       config.sessionTicketKeys = nil
+                       return config, nil
+               },
+               "",
+               func(config *Config) error {
+                       // The value of SessionTicketKey should have been
+                       // duplicated into the per-connection Config.
+                       for i := range config.SessionTicketKey {
+                               if b := config.SessionTicketKey[i]; b != byte(i) {
+                                       return fmt.Errorf("SessionTicketKey was not duplicated from original Config: byte %d has value %d", i, b)
+                               }
+                       }
+                       return nil
+               },
+       },
+       {
+               func(config *Config) {
+                       var dummyKey [32]byte
+                       for i := range dummyKey {
+                               dummyKey[i] = byte(i)
+                       }
+
+                       config.SetSessionTicketKeys([][32]byte{dummyKey})
+               },
+               func(clientHello *ClientHelloInfo) (*Config, error) {
+                       config := testConfig.Clone()
+                       config.sessionTicketKeys = nil
+                       return config, nil
+               },
+               "",
+               func(config *Config) error {
+                       // The session ticket keys should have been duplicated
+                       // into the per-connection Config.
+                       if l := len(config.sessionTicketKeys); l != 1 {
+                               return fmt.Errorf("got len(sessionTicketKeys) == %d, wanted 1", l)
+                       }
+                       return nil
+               },
+       },
+}
+
+func TestGetConfigForClient(t *testing.T) {
+       serverConfig := testConfig.Clone()
+       clientConfig := testConfig.Clone()
+       clientConfig.MinVersion = VersionTLS12
+
+       for i, test := range getConfigForClientTests {
+               if test.setup != nil {
+                       test.setup(serverConfig)
+               }
+
+               var configReturned *Config
+               serverConfig.GetConfigForClient = func(clientHello *ClientHelloInfo) (*Config, error) {
+                       config, err := test.callback(clientHello)
+                       configReturned = config
+                       return config, err
+               }
+               c, s := net.Pipe()
+               done := make(chan error)
+
+               go func() {
+                       defer s.Close()
+                       done <- Server(s, serverConfig).Handshake()
+               }()
+
+               clientErr := Client(c, clientConfig).Handshake()
+               c.Close()
+
+               serverErr := <-done
+
+               if len(test.errorSubstring) == 0 {
+                       if serverErr != nil || clientErr != nil {
+                               t.Errorf("%#d: expected no error but got serverErr: %q, clientErr: %q", i, serverErr, clientErr)
+                       }
+                       if test.verify != nil {
+                               if err := test.verify(configReturned); err != nil {
+                                       t.Errorf("#%d: verify returned error: %v", i, err)
+                               }
+                       }
+               } else {
+                       if serverErr == nil {
+                               t.Errorf("%#d: expected error containing %q but got no error", i, test.errorSubstring)
+                       } else if !strings.Contains(serverErr.Error(), test.errorSubstring) {
+                               t.Errorf("%#d: expected error to contain %q but it was %q", i, test.errorSubstring, serverErr)
+                       }
+               }
+       }
+}
+
 func bigFromString(s string) *big.Int {
        ret := new(big.Int)
        ret.SetString(s, 10)
index 8b8dfa4e1e6ab7cbc4abb278e2e8df75f965bfbf..d2a674d08b6aab5e12755a71db5eff8ead60e885 100644 (file)
@@ -477,7 +477,7 @@ func TestClone(t *testing.T) {
                case "Rand":
                        f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
                        continue
-               case "Time", "GetCertificate":
+               case "Time", "GetCertificate", "GetConfigForClient":
                        // DeepEqual can't compare functions.
                        continue
                case "Certificates":