]> Cypherpunks.ru repositories - gostls13.git/commitdiff
crypto/tls: add GetClientCertificate callback
authorAdam Langley <agl@golang.org>
Wed, 26 Oct 2016 17:05:03 +0000 (10:05 -0700)
committerAdam Langley <agl@golang.org>
Thu, 27 Oct 2016 17:20:07 +0000 (17:20 +0000)
Currently, the selection of a client certificate done internally based
on the limitations given by the server's request and the certifcates in
the Config. This means that it's not possible for an application to
control that selection based on details of the request.

This change adds a callback, GetClientCertificate, that is called by a
Client during the handshake and which allows applications to select the
best certificate at that time.

(Based on https://golang.org/cl/25570/ by Bernd Fix.)

Fixes #16626.

Change-Id: Ia4cea03235d2aa3c9fd49c99c227593c8e86ddd9
Reviewed-on: https://go-review.googlesource.com/32115
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/crypto/tls/common.go
src/crypto/tls/handshake_client.go
src/crypto/tls/handshake_client_test.go
src/crypto/tls/tls_test.go

index beca79897b57a78bd5065d1b706c3a919a21a2e1..5b2c6664b2fb064b4adf929bd3127a356c65c0b1 100644 (file)
@@ -286,6 +286,21 @@ type ClientHelloInfo struct {
        Conn net.Conn
 }
 
+// CertificateRequestInfo contains information from a server's
+// CertificateRequest message, which is used to demand a certificate and proof
+// of control from a client.
+type CertificateRequestInfo struct {
+       // AcceptableCAs contains zero or more, DER-encoded, X.501
+       // Distinguished Names. These are the names of root or intermediate CAs
+       // that the server wishes the returned certificate to be signed by. An
+       // empty slice indicates that the server has no preference.
+       AcceptableCAs [][]byte
+
+       // SignatureSchemes lists the signature schemes that the server is
+       // willing to verify.
+       SignatureSchemes []SignatureScheme
+}
+
 // RenegotiationSupport enumerates the different levels of support for TLS
 // renegotiation. TLS renegotiation is the act of performing subsequent
 // handshakes on a connection after the first. This significantly complicates
@@ -328,10 +343,11 @@ type Config struct {
        // If Time is nil, TLS uses time.Now.
        Time func() time.Time
 
-       // Certificates contains one or more certificate chains
-       // to present to the other side of the connection.
-       // Server configurations must include at least one certificate
-       // or else set GetCertificate.
+       // Certificates contains one or more certificate chains to present to
+       // the other side of the connection. Server configurations must include
+       // at least one certificate or else set GetCertificate. Clients doing
+       // client-authentication may set either Certificates or
+       // GetClientCertificate.
        Certificates []Certificate
 
        // NameToCertificate maps from a certificate name to an element of
@@ -351,6 +367,21 @@ type Config struct {
        // first element of Certificates will be used.
        GetCertificate func(*ClientHelloInfo) (*Certificate, error)
 
+       // GetClientCertificate, if not nil, is called when a server requests a
+       // certificate from a client. If set, the contents of Certificates will
+       // be ignored.
+       //
+       // If GetClientCertificate returns an error, the handshake will be
+       // aborted and that error will be returned. Otherwise
+       // GetClientCertificate must return a non-nil Certificate. If
+       // Certificate.Certificate is empty then no certificate will be sent to
+       // the server. If this is unacceptable to the server then it may abort
+       // the handshake.
+       //
+       // GetClientCertificate may be called multiple times for the same
+       // connection if renegotiation occurs or if TLS 1.3 is in use.
+       GetClientCertificate func(*CertificateRequestInfo) (*Certificate, error)
+
        // GetConfigForClient, if not nil, is called after a ClientHello is
        // received from a client. It may return a non-nil Config in order to
        // change the Config that will be used to handle this connection. If
index c331c652157889752b34b7ac2b255114bf5b50f8..89bdd5944dc76811883d8fd7420ab7a282d7e793 100644 (file)
@@ -199,7 +199,7 @@ NextCipherSuite:
        // Otherwise, in a full handshake, if we don't have any certificates
        // configured then we will never send a CertificateVerify message and
        // thus no signatures are needed in that case either.
-       if isResume || len(c.config.Certificates) == 0 {
+       if isResume || (len(c.config.Certificates) == 0 && c.config.GetClientCertificate == nil) {
                hs.finishedHash.discardHandshakeBuffer()
        }
 
@@ -377,71 +377,11 @@ func (hs *clientHandshakeState) doFullHandshake() error {
        certReq, ok := msg.(*certificateRequestMsg)
        if ok {
                certRequested = true
-
-               // RFC 4346 on the certificateAuthorities field:
-               // A list of the distinguished names of acceptable certificate
-               // authorities. These distinguished names may specify a desired
-               // distinguished name for a root CA or for a subordinate CA;
-               // thus, this message can be used to describe both known roots
-               // and a desired authorization space. If the
-               // certificate_authorities list is empty then the client MAY
-               // send any certificate of the appropriate
-               // ClientCertificateType, unless there is some external
-               // arrangement to the contrary.
-
                hs.finishedHash.Write(certReq.marshal())
 
-               var rsaAvail, ecdsaAvail bool
-               for _, certType := range certReq.certificateTypes {
-                       switch certType {
-                       case certTypeRSASign:
-                               rsaAvail = true
-                       case certTypeECDSASign:
-                               ecdsaAvail = true
-                       }
-               }
-
-               // We need to search our list of client certs for one
-               // where SignatureAlgorithm is acceptable to the server and the
-               // Issuer is in certReq.certificateAuthorities
-       findCert:
-               for i, chain := range c.config.Certificates {
-                       if !rsaAvail && !ecdsaAvail {
-                               continue
-                       }
-
-                       for j, cert := range chain.Certificate {
-                               x509Cert := chain.Leaf
-                               // parse the certificate if this isn't the leaf
-                               // node, or if chain.Leaf was nil
-                               if j != 0 || x509Cert == nil {
-                                       if x509Cert, err = x509.ParseCertificate(cert); err != nil {
-                                               c.sendAlert(alertInternalError)
-                                               return errors.New("tls: failed to parse client certificate #" + strconv.Itoa(i) + ": " + err.Error())
-                                       }
-                               }
-
-                               switch {
-                               case rsaAvail && x509Cert.PublicKeyAlgorithm == x509.RSA:
-                               case ecdsaAvail && x509Cert.PublicKeyAlgorithm == x509.ECDSA:
-                               default:
-                                       continue findCert
-                               }
-
-                               if len(certReq.certificateAuthorities) == 0 {
-                                       // they gave us an empty list, so just take the
-                                       // first cert from c.config.Certificates
-                                       chainToSend = &chain
-                                       break findCert
-                               }
-
-                               for _, ca := range certReq.certificateAuthorities {
-                                       if bytes.Equal(x509Cert.RawIssuer, ca) {
-                                               chainToSend = &chain
-                                               break findCert
-                                       }
-                               }
-                       }
+               if chainToSend, err = hs.getCertificate(certReq); err != nil {
+                       c.sendAlert(alertInternalError)
+                       return err
                }
 
                msg, err = c.readHandshake()
@@ -462,9 +402,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
        // certificate to send.
        if certRequested {
                certMsg = new(certificateMsg)
-               if chainToSend != nil {
-                       certMsg.certificates = chainToSend.Certificate
-               }
+               certMsg.certificates = chainToSend.Certificate
                hs.finishedHash.Write(certMsg.marshal())
                if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
                        return err
@@ -483,7 +421,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
                }
        }
 
-       if chainToSend != nil {
+       if chainToSend != nil && len(chainToSend.Certificate) > 0 {
                certVerify := &certificateVerifyMsg{
                        hasSignatureAndHash: c.vers >= VersionTLS12,
                }
@@ -727,6 +665,117 @@ func (hs *clientHandshakeState) sendFinished(out []byte) error {
        return nil
 }
 
+// tls11SignatureSchemes contains the signature schemes that we synthesise for
+// a TLS <= 1.1 connection, based on the supported certificate types.
+var tls11SignatureSchemes = []SignatureScheme{ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1}
+
+const (
+       // tls11SignatureSchemesNumECDSA is the number of initial elements of
+       // tls11SignatureSchemes that use ECDSA.
+       tls11SignatureSchemesNumECDSA = 3
+       // tls11SignatureSchemesNumRSA is the number of trailing elements of
+       // tls11SignatureSchemes that use RSA.
+       tls11SignatureSchemesNumRSA = 4
+)
+
+func (hs *clientHandshakeState) getCertificate(certReq *certificateRequestMsg) (*Certificate, error) {
+       c := hs.c
+
+       var rsaAvail, ecdsaAvail bool
+       for _, certType := range certReq.certificateTypes {
+               switch certType {
+               case certTypeRSASign:
+                       rsaAvail = true
+               case certTypeECDSASign:
+                       ecdsaAvail = true
+               }
+       }
+
+       if c.config.GetClientCertificate != nil {
+               var signatureSchemes []SignatureScheme
+
+               if !certReq.hasSignatureAndHash {
+                       // Prior to TLS 1.2, the signature schemes were not
+                       // included in the certificate request message. In this
+                       // case we use a plausible list based on the acceptable
+                       // certificate types.
+                       signatureSchemes = tls11SignatureSchemes
+                       if !ecdsaAvail {
+                               signatureSchemes = signatureSchemes[tls11SignatureSchemesNumECDSA:]
+                       }
+                       if !rsaAvail {
+                               signatureSchemes = signatureSchemes[:len(signatureSchemes)-tls11SignatureSchemesNumRSA]
+                       }
+               } else {
+                       signatureSchemes = make([]SignatureScheme, 0, len(certReq.signatureAndHashes))
+                       for _, sah := range certReq.signatureAndHashes {
+                               signatureSchemes = append(signatureSchemes, SignatureScheme(sah.hash)<<8+SignatureScheme(sah.signature))
+                       }
+               }
+
+               return c.config.GetClientCertificate(&CertificateRequestInfo{
+                       AcceptableCAs:    certReq.certificateAuthorities,
+                       SignatureSchemes: signatureSchemes,
+               })
+       }
+
+       // RFC 4346 on the certificateAuthorities field: A list of the
+       // distinguished names of acceptable certificate authorities.
+       // These distinguished names may specify a desired
+       // distinguished name for a root CA or for a subordinate CA;
+       // thus, this message can be used to describe both known roots
+       // and a desired authorization space. If the
+       // certificate_authorities list is empty then the client MAY
+       // send any certificate of the appropriate
+       // ClientCertificateType, unless there is some external
+       // arrangement to the contrary.
+
+       // We need to search our list of client certs for one
+       // where SignatureAlgorithm is acceptable to the server and the
+       // Issuer is in certReq.certificateAuthorities
+findCert:
+       for i, chain := range c.config.Certificates {
+               if !rsaAvail && !ecdsaAvail {
+                       continue
+               }
+
+               for j, cert := range chain.Certificate {
+                       x509Cert := chain.Leaf
+                       // parse the certificate if this isn't the leaf
+                       // node, or if chain.Leaf was nil
+                       if j != 0 || x509Cert == nil {
+                               var err error
+                               if x509Cert, err = x509.ParseCertificate(cert); err != nil {
+                                       c.sendAlert(alertInternalError)
+                                       return nil, errors.New("tls: failed to parse client certificate #" + strconv.Itoa(i) + ": " + err.Error())
+                               }
+                       }
+
+                       switch {
+                       case rsaAvail && x509Cert.PublicKeyAlgorithm == x509.RSA:
+                       case ecdsaAvail && x509Cert.PublicKeyAlgorithm == x509.ECDSA:
+                       default:
+                               continue findCert
+                       }
+
+                       if len(certReq.certificateAuthorities) == 0 {
+                               // they gave us an empty list, so just take the
+                               // first cert from c.config.Certificates
+                               return &chain, nil
+                       }
+
+                       for _, ca := range certReq.certificateAuthorities {
+                               if bytes.Equal(x509Cert.RawIssuer, ca) {
+                                       return &chain, nil
+                               }
+                       }
+               }
+       }
+
+       // No acceptable certificate found. Don't send a certificate.
+       return new(Certificate), nil
+}
+
 // clientSessionCacheKey returns a key used to cache sessionTickets that could
 // be used to resume previously negotiated TLS sessions with a server.
 func clientSessionCacheKey(serverAddr net.Addr, config *Config) string {
index 24d119e50c5eb77e8731814c9e9fd68147733aba..d603915e17c545fa9562a7e34153ce8e19662b3a 100644 (file)
@@ -1408,3 +1408,137 @@ func TestHandshakeRace(t *testing.T) {
                <-readDone
        }
 }
+
+func TestTLS11SignatureSchemes(t *testing.T) {
+       expected := tls11SignatureSchemesNumECDSA + tls11SignatureSchemesNumRSA
+       if expected != len(tls11SignatureSchemes) {
+               t.Errorf("expected to find %d TLS 1.1 signature schemes, but found %d", expected, len(tls11SignatureSchemes))
+       }
+}
+
+var getClientCertificateTests = []struct {
+       setup               func(*Config)
+       expectedClientError string
+       verify              func(*testing.T, int, *ConnectionState)
+}{
+       {
+               func(clientConfig *Config) {
+                       // Returning a Certificate with no certificate data
+                       // should result in an empty message being sent to the
+                       // server.
+                       clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
+                               if len(cri.SignatureSchemes) == 0 {
+                                       panic("empty SignatureSchemes")
+                               }
+                               return new(Certificate), nil
+                       }
+               },
+               "",
+               func(t *testing.T, testNum int, cs *ConnectionState) {
+                       if l := len(cs.PeerCertificates); l != 0 {
+                               t.Errorf("#%d: expected no certificates but got %d", testNum, l)
+                       }
+               },
+       },
+       {
+               func(clientConfig *Config) {
+                       // With TLS 1.1, the SignatureSchemes should be
+                       // synthesised from the supported certificate types.
+                       clientConfig.MaxVersion = VersionTLS11
+                       clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
+                               if len(cri.SignatureSchemes) == 0 {
+                                       panic("empty SignatureSchemes")
+                               }
+                               return new(Certificate), nil
+                       }
+               },
+               "",
+               func(t *testing.T, testNum int, cs *ConnectionState) {
+                       if l := len(cs.PeerCertificates); l != 0 {
+                               t.Errorf("#%d: expected no certificates but got %d", testNum, l)
+                       }
+               },
+       },
+       {
+               func(clientConfig *Config) {
+                       // Returning an error should abort the handshake with
+                       // that error.
+                       clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
+                               return nil, errors.New("GetClientCertificate")
+                       }
+               },
+               "GetClientCertificate",
+               func(t *testing.T, testNum int, cs *ConnectionState) {
+               },
+       },
+       {
+               func(clientConfig *Config) {
+                       clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
+                               return &testConfig.Certificates[0], nil
+                       }
+               },
+               "",
+               func(t *testing.T, testNum int, cs *ConnectionState) {
+                       if l := len(cs.VerifiedChains); l != 0 {
+                               t.Errorf("#%d: expected some verified chains, but found none", testNum)
+                       }
+               },
+       },
+}
+
+func TestGetClientCertificate(t *testing.T) {
+       issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
+       if err != nil {
+               panic(err)
+       }
+
+       for i, test := range getClientCertificateTests {
+               serverConfig := testConfig.Clone()
+               serverConfig.ClientAuth = RequestClientCert
+               serverConfig.RootCAs = x509.NewCertPool()
+               serverConfig.RootCAs.AddCert(issuer)
+
+               clientConfig := testConfig.Clone()
+
+               test.setup(clientConfig)
+
+               type serverResult struct {
+                       cs  ConnectionState
+                       err error
+               }
+
+               c, s := net.Pipe()
+               done := make(chan serverResult)
+
+               go func() {
+                       defer s.Close()
+                       server := Server(s, serverConfig)
+                       err := server.Handshake()
+
+                       var cs ConnectionState
+                       if err == nil {
+                               cs = server.ConnectionState()
+                       }
+                       done <- serverResult{cs, err}
+               }()
+
+               clientErr := Client(c, clientConfig).Handshake()
+               c.Close()
+
+               result := <-done
+
+               if clientErr != nil {
+                       if len(test.expectedClientError) == 0 {
+                               t.Errorf("#%d: client error: %v", i, clientErr)
+                       } else if got := clientErr.Error(); got != test.expectedClientError {
+                               t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got)
+                       }
+               } else if len(test.expectedClientError) > 0 {
+                       t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError)
+               } else if err := result.err; err != nil {
+                       t.Errorf("#%d: server error: %v", i, err)
+               } else {
+                       test.verify(t, i, &result.cs)
+               }
+       }
+}
index b4fedd59c0ffbf30a849993158bf1eb9785ae887..a0c09081a6af4bdd9e0d6a2ee0b94a2955e2e285 100644 (file)
@@ -584,7 +584,7 @@ func TestClone(t *testing.T) {
                case "Rand":
                        f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
                        continue
-               case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate":
+               case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate":
                        // DeepEqual can't compare functions.
                        continue
                case "Certificates":