// If GetCertificate is nil or returns nil, then the certificate is
// retrieved from NameToCertificate. If NameToCertificate is nil, the
// first element of Certificates will be used.
- GetCertificate func(clientHello *ClientHelloInfo) (*Certificate, error)
+ GetCertificate func(*ClientHelloInfo) (*Certificate, error)
+
+ // GetConfigForClient, if not nil, is called after a ClientHello is
+ // received from a client. It may return a non-nil Config in order to
+ // change the Config that will be used to handle this connection. If
+ // the returned Config is nil, the original Config will be used. The
+ // Config returned by this callback may not be subsequently modified.
+ //
+ // If GetConfigForClient is nil, the Config passed to Server() will be
+ // used for all connections.
+ //
+ // Uniquely for the fields in the returned Config, session ticket keys
+ // will be duplicated from the original Config if not set.
+ // Specifically, if SetSessionTicketKeys was called on the original
+ // config but not on the returned config then the ticket keys from the
+ // original config will be copied into the new config before use.
+ // Otherwise, if SessionTicketKey was set in the original config but
+ // not in the returned config then it will be copied into the returned
+ // config before use. If neither of those cases applies then the key
+ // material from the returned config will be used for session tickets.
+ GetConfigForClient func(*ClientHelloInfo) (*Config, error)
// RootCAs defines the set of root certificate authorities
// that clients use when verifying server certificates.
serverInitOnce sync.Once // guards calling (*Config).serverInit
- // mutex protects sessionTicketKeys
+ // mutex protects sessionTicketKeys and originalConfig.
mutex sync.RWMutex
// sessionTicketKeys contains zero or more ticket keys. If the length
// is zero, SessionTicketsDisabled must be true. The first key is used
// for new tickets and any subsequent keys can be used to decrypt old
// tickets.
sessionTicketKeys []ticketKey
+ // originalConfig is set to the Config that was passed to Server if
+ // this Config is returned by a GetConfigForClient callback. It's used
+ // by serverInit in order to copy session ticket keys if needed.
+ originalConfig *Config
}
// ticketKeyNameLen is the number of bytes of identifier that is prepended to
// Clone returns a shallow clone of c.
// Only the exported fields are copied.
func (c *Config) Clone() *Config {
+ var sessionTicketKeys []ticketKey
+ c.mutex.RLock()
+ sessionTicketKeys = c.sessionTicketKeys
+ c.mutex.RUnlock()
+
return &Config{
Rand: c.Rand,
Time: c.Time,
Certificates: c.Certificates,
NameToCertificate: c.NameToCertificate,
GetCertificate: c.GetCertificate,
+ GetConfigForClient: c.GetConfigForClient,
RootCAs: c.RootCAs,
NextProtos: c.NextProtos,
ServerName: c.ServerName,
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
Renegotiation: c.Renegotiation,
KeyLogWriter: c.KeyLogWriter,
+ sessionTicketKeys: sessionTicketKeys,
+ // originalConfig is deliberately not duplicated.
}
}
return
}
+ var originalConfig *Config
+ c.mutex.Lock()
+ originalConfig, c.originalConfig = c.originalConfig, nil
+ c.mutex.Unlock()
+
alreadySet := false
for _, b := range c.SessionTicketKey {
if b != 0 {
}
if !alreadySet {
- if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil {
+ if originalConfig != nil {
+ copy(c.SessionTicketKey[:], originalConfig.SessionTicketKey[:])
+ } else if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil {
c.SessionTicketsDisabled = true
return
}
}
- c.sessionTicketKeys = []ticketKey{ticketKeyFromBytes(c.SessionTicketKey)}
+ if originalConfig != nil {
+ originalConfig.mutex.RLock()
+ c.sessionTicketKeys = originalConfig.sessionTicketKeys
+ originalConfig.mutex.RUnlock()
+ } else {
+ c.sessionTicketKeys = []ticketKey{ticketKeyFromBytes(c.SessionTicketKey)}
+ }
}
func (c *Config) ticketKeys() []ticketKey {
// serverHandshake performs a TLS handshake as a server.
// c.out.Mutex <= L; c.handshakeMutex <= L.
func (c *Conn) serverHandshake() error {
- config := c.config
-
// If this is the first server handshake, we generate a random key to
// encrypt the tickets with.
- config.serverInitOnce.Do(config.serverInit)
+ c.config.serverInitOnce.Do(c.config.serverInit)
hs := serverHandshakeState{
c: c,
// readClientHello reads a ClientHello message from the client and decides
// whether we will perform session resumption.
func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) {
- config := hs.c.config
c := hs.c
msg, err := c.readHandshake()
c.sendAlert(alertUnexpectedMessage)
return false, unexpectedMessageError(hs.clientHello, msg)
}
- c.vers, ok = config.mutualVersion(hs.clientHello.vers)
+
+ clientHelloInfo := &ClientHelloInfo{
+ CipherSuites: hs.clientHello.cipherSuites,
+ ServerName: hs.clientHello.serverName,
+ SupportedCurves: hs.clientHello.supportedCurves,
+ SupportedPoints: hs.clientHello.supportedPoints,
+ }
+
+ if c.config.GetConfigForClient != nil {
+ if newConfig, err := c.config.GetConfigForClient(clientHelloInfo); err != nil {
+ c.sendAlert(alertInternalError)
+ return false, err
+ } else if newConfig != nil {
+ newConfig.mutex.Lock()
+ newConfig.originalConfig = c.config
+ newConfig.mutex.Unlock()
+
+ newConfig.serverInitOnce.Do(newConfig.serverInit)
+ c.config = newConfig
+ }
+ }
+
+ c.vers, ok = c.config.mutualVersion(hs.clientHello.vers)
if !ok {
c.sendAlert(alertProtocolVersion)
return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers)
hs.hello = new(serverHelloMsg)
supportedCurve := false
- preferredCurves := config.curvePreferences()
+ preferredCurves := c.config.curvePreferences()
Curves:
for _, curve := range hs.clientHello.supportedCurves {
for _, supported := range preferredCurves {
hs.hello.vers = c.vers
hs.hello.random = make([]byte, 32)
- _, err = io.ReadFull(config.rand(), hs.hello.random)
+ _, err = io.ReadFull(c.config.rand(), hs.hello.random)
if err != nil {
c.sendAlert(alertInternalError)
return false, err
} else {
// Although sending an empty NPN extension is reasonable, Firefox has
// had a bug around this. Best to send nothing at all if
- // config.NextProtos is empty. See
+ // c.config.NextProtos is empty. See
// https://golang.org/issue/5445.
- if hs.clientHello.nextProtoNeg && len(config.NextProtos) > 0 {
+ if hs.clientHello.nextProtoNeg && len(c.config.NextProtos) > 0 {
hs.hello.nextProtoNeg = true
- hs.hello.nextProtos = config.NextProtos
+ hs.hello.nextProtos = c.config.NextProtos
}
}
- hs.cert, err = config.getCertificate(&ClientHelloInfo{
- CipherSuites: hs.clientHello.cipherSuites,
- ServerName: hs.clientHello.serverName,
- SupportedCurves: hs.clientHello.supportedCurves,
- SupportedPoints: hs.clientHello.supportedPoints,
- })
+ hs.cert, err = c.config.getCertificate(clientHelloInfo)
if err != nil {
c.sendAlert(alertInternalError)
return false, err
}
func (hs *serverHandshakeState) doFullHandshake() error {
- config := hs.c.config
c := hs.c
if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 {
hs.hello.ocspStapling = true
}
- hs.hello.ticketSupported = hs.clientHello.ticketSupported && !config.SessionTicketsDisabled
+ hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled
hs.hello.cipherSuite = hs.suite.id
hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite)
- if config.ClientAuth == NoClientCert {
+ if c.config.ClientAuth == NoClientCert {
// No need to keep a full record of the handshake if client
// certificates won't be used.
hs.finishedHash.discardHandshakeBuffer()
}
keyAgreement := hs.suite.ka(c.vers)
- skx, err := keyAgreement.generateServerKeyExchange(config, hs.cert, hs.clientHello, hs.hello)
+ skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello)
if err != nil {
c.sendAlert(alertHandshakeFailure)
return err
}
}
- if config.ClientAuth >= RequestClientCert {
+ if c.config.ClientAuth >= RequestClientCert {
// Request a client certificate
certReq := new(certificateRequestMsg)
certReq.certificateTypes = []byte{
// to our request. When we know the CAs we trust, then
// we can send them down, so that the client can choose
// an appropriate certificate to give to us.
- if config.ClientCAs != nil {
- certReq.certificateAuthorities = config.ClientCAs.Subjects()
+ if c.config.ClientCAs != nil {
+ certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
}
hs.finishedHash.Write(certReq.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil {
var ok bool
// If we requested a client certificate, then the client must send a
// certificate message, even if it's empty.
- if config.ClientAuth >= RequestClientCert {
+ if c.config.ClientAuth >= RequestClientCert {
if certMsg, ok = msg.(*certificateMsg); !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
if len(certMsg.certificates) == 0 {
// The client didn't actually send a certificate
- switch config.ClientAuth {
+ switch c.config.ClientAuth {
case RequireAnyClientCert, RequireAndVerifyClientCert:
c.sendAlert(alertBadCertificate)
return errors.New("tls: client didn't provide a certificate")
}
hs.finishedHash.Write(ckx.marshal())
- preMasterSecret, err := keyAgreement.processClientKeyExchange(config, hs.cert, ckx, c.vers)
+ preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers)
if err != nil {
c.sendAlert(alertHandshakeFailure)
return err
}
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random)
- if err := config.writeKeyLog(hs.clientHello.random, hs.masterSecret); err != nil {
+ if err := c.config.writeKeyLog(hs.clientHello.random, hs.masterSecret); err != nil {
c.sendAlert(alertInternalError)
return err
}
}
}
+var getConfigForClientTests = []struct {
+ setup func(config *Config)
+ callback func(clientHello *ClientHelloInfo) (*Config, error)
+ errorSubstring string
+ verify func(config *Config) error
+}{
+ {
+ nil,
+ func(clientHello *ClientHelloInfo) (*Config, error) {
+ return nil, nil
+ },
+ "",
+ nil,
+ },
+ {
+ nil,
+ func(clientHello *ClientHelloInfo) (*Config, error) {
+ return nil, errors.New("should bubble up")
+ },
+ "should bubble up",
+ nil,
+ },
+ {
+ nil,
+ func(clientHello *ClientHelloInfo) (*Config, error) {
+ config := testConfig.Clone()
+ // Setting a maximum version of TLS 1.1 should cause
+ // the handshake to fail.
+ config.MaxVersion = VersionTLS11
+ return config, nil
+ },
+ "version 301 when expecting version 302",
+ nil,
+ },
+ {
+ func(config *Config) {
+ for i := range config.SessionTicketKey {
+ config.SessionTicketKey[i] = byte(i)
+ }
+ config.sessionTicketKeys = nil
+ },
+ func(clientHello *ClientHelloInfo) (*Config, error) {
+ config := testConfig.Clone()
+ for i := range config.SessionTicketKey {
+ config.SessionTicketKey[i] = 0
+ }
+ config.sessionTicketKeys = nil
+ return config, nil
+ },
+ "",
+ func(config *Config) error {
+ // The value of SessionTicketKey should have been
+ // duplicated into the per-connection Config.
+ for i := range config.SessionTicketKey {
+ if b := config.SessionTicketKey[i]; b != byte(i) {
+ return fmt.Errorf("SessionTicketKey was not duplicated from original Config: byte %d has value %d", i, b)
+ }
+ }
+ return nil
+ },
+ },
+ {
+ func(config *Config) {
+ var dummyKey [32]byte
+ for i := range dummyKey {
+ dummyKey[i] = byte(i)
+ }
+
+ config.SetSessionTicketKeys([][32]byte{dummyKey})
+ },
+ func(clientHello *ClientHelloInfo) (*Config, error) {
+ config := testConfig.Clone()
+ config.sessionTicketKeys = nil
+ return config, nil
+ },
+ "",
+ func(config *Config) error {
+ // The session ticket keys should have been duplicated
+ // into the per-connection Config.
+ if l := len(config.sessionTicketKeys); l != 1 {
+ return fmt.Errorf("got len(sessionTicketKeys) == %d, wanted 1", l)
+ }
+ return nil
+ },
+ },
+}
+
+func TestGetConfigForClient(t *testing.T) {
+ serverConfig := testConfig.Clone()
+ clientConfig := testConfig.Clone()
+ clientConfig.MinVersion = VersionTLS12
+
+ for i, test := range getConfigForClientTests {
+ if test.setup != nil {
+ test.setup(serverConfig)
+ }
+
+ var configReturned *Config
+ serverConfig.GetConfigForClient = func(clientHello *ClientHelloInfo) (*Config, error) {
+ config, err := test.callback(clientHello)
+ configReturned = config
+ return config, err
+ }
+ c, s := net.Pipe()
+ done := make(chan error)
+
+ go func() {
+ defer s.Close()
+ done <- Server(s, serverConfig).Handshake()
+ }()
+
+ clientErr := Client(c, clientConfig).Handshake()
+ c.Close()
+
+ serverErr := <-done
+
+ if len(test.errorSubstring) == 0 {
+ if serverErr != nil || clientErr != nil {
+ t.Errorf("%#d: expected no error but got serverErr: %q, clientErr: %q", i, serverErr, clientErr)
+ }
+ if test.verify != nil {
+ if err := test.verify(configReturned); err != nil {
+ t.Errorf("#%d: verify returned error: %v", i, err)
+ }
+ }
+ } else {
+ if serverErr == nil {
+ t.Errorf("%#d: expected error containing %q but got no error", i, test.errorSubstring)
+ } else if !strings.Contains(serverErr.Error(), test.errorSubstring) {
+ t.Errorf("%#d: expected error to contain %q but it was %q", i, test.errorSubstring, serverErr)
+ }
+ }
+ }
+}
+
func bigFromString(s string) *big.Int {
ret := new(big.Int)
ret.SetString(s, 10)
case "Rand":
f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
continue
- case "Time", "GetCertificate":
+ case "Time", "GetCertificate", "GetConfigForClient":
// DeepEqual can't compare functions.
continue
case "Certificates":