]> Cypherpunks.ru repositories - gostls13.git/commitdiff
crypto/tls: add VerifyPeerCertificate to tls.Config
authorJoshua Boelter <joshua.boelter@intel.com>
Wed, 13 Jul 2016 22:22:28 +0000 (16:22 -0600)
committerAdam Langley <agl@golang.org>
Mon, 24 Oct 2016 23:24:11 +0000 (23:24 +0000)
VerifyPeerCertificate returns an error if the peer should not be
trusted. It will be called after the initial handshake and before
any other verification checks on the cert or chain are performed.
This provides the callee an opportunity to augment the certificate
verification.

If VerifyPeerCertificate is not nil and returns an error,
then the handshake will fail.

Fixes #16363

Change-Id: I6a22f199f0e81b6f5d5f37c54d85ab878216bb22
Reviewed-on: https://go-review.googlesource.com/26654
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/crypto/tls/common.go
src/crypto/tls/handshake_client.go
src/crypto/tls/handshake_client_test.go
src/crypto/tls/handshake_server.go
src/crypto/tls/tls_test.go

index fc898e202acdad123185c647fd8d0ef5af6eb20a..9078b63cdf231765fb56381eaf8c44de867692cc 100644 (file)
@@ -325,6 +325,18 @@ type Config struct {
        // material from the returned config will be used for session tickets.
        GetConfigForClient func(*ClientHelloInfo) (*Config, error)
 
+       // VerifyPeerCertificate, if not nil, is called after normal
+       // certificate verification by either a TLS client or server. It
+       // receives the raw ASN.1 certificates provided by the peer and also
+       // any verified chains that normal processing found. If it returns a
+       // non-nil error, the handshake is aborted and that error results.
+       //
+       // If normal verification fails then the handshake will abort before
+       // considering this callback. If normal verification is disabled by
+       // setting InsecureSkipVerify then this callback will be considered but
+       // the verifiedChains argument will always be nil.
+       VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
+
        // RootCAs defines the set of root certificate authorities
        // that clients use when verifying server certificates.
        // If RootCAs is nil, TLS uses the host's root CA set.
@@ -474,6 +486,7 @@ func (c *Config) Clone() *Config {
                NameToCertificate:           c.NameToCertificate,
                GetCertificate:              c.GetCertificate,
                GetConfigForClient:          c.GetConfigForClient,
+               VerifyPeerCertificate:       c.VerifyPeerCertificate,
                RootCAs:                     c.RootCAs,
                NextProtos:                  c.NextProtos,
                ServerName:                  c.ServerName,
index e42953a0755e3236c2f3734548b816b7647aae9a..c331c652157889752b34b7ac2b255114bf5b50f8 100644 (file)
@@ -304,6 +304,13 @@ func (hs *clientHandshakeState) doFullHandshake() error {
                        }
                }
 
+               if c.config.VerifyPeerCertificate != nil {
+                       if err := c.config.VerifyPeerCertificate(certMsg.certificates, c.verifiedChains); err != nil {
+                               c.sendAlert(alertBadCertificate)
+                               return err
+                       }
+               }
+
                switch certs[0].PublicKey.(type) {
                case *rsa.PublicKey, *ecdsa.PublicKey:
                        break
index 3de1dfab86628d80a8776dc488f7d6263ab505d4..7bbeed003417f9fdd2d70bb0a34c984edca56042 100644 (file)
@@ -1067,6 +1067,160 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
        }
 }
 
+func TestVerifyPeerCertificate(t *testing.T) {
+       issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
+       if err != nil {
+               panic(err)
+       }
+
+       rootCAs := x509.NewCertPool()
+       rootCAs.AddCert(issuer)
+
+       now := func() time.Time { return time.Unix(1476984729, 0) }
+
+       sentinelErr := errors.New("TestVerifyPeerCertificate")
+
+       verifyCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+               if l := len(rawCerts); l != 1 {
+                       return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
+               }
+               if len(validatedChains) == 0 {
+                       return errors.New("got len(validatedChains) = 0, wanted non-zero")
+               }
+               *called = true
+               return nil
+       }
+
+       tests := []struct {
+               configureServer func(*Config, *bool)
+               configureClient func(*Config, *bool)
+               validate        func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error)
+       }{
+               {
+                       configureServer: func(config *Config, called *bool) {
+                               config.InsecureSkipVerify = false
+                               config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+                                       return verifyCallback(called, rawCerts, validatedChains)
+                               }
+                       },
+                       configureClient: func(config *Config, called *bool) {
+                               config.InsecureSkipVerify = false
+                               config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+                                       return verifyCallback(called, rawCerts, validatedChains)
+                               }
+                       },
+                       validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+                               if clientErr != nil {
+                                       t.Errorf("#%d: client handshake failed: %v", testNo, clientErr)
+                               }
+                               if serverErr != nil {
+                                       t.Errorf("#%d: server handshake failed: %v", testNo, serverErr)
+                               }
+                               if !clientCalled {
+                                       t.Error("#%d: client did not call callback", testNo)
+                               }
+                               if !serverCalled {
+                                       t.Error("#%d: server did not call callback", testNo)
+                               }
+                       },
+               },
+               {
+                       configureServer: func(config *Config, called *bool) {
+                               config.InsecureSkipVerify = false
+                               config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+                                       return sentinelErr
+                               }
+                       },
+                       configureClient: func(config *Config, called *bool) {
+                               config.VerifyPeerCertificate = nil
+                       },
+                       validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+                               if serverErr != sentinelErr {
+                                       t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
+                               }
+                       },
+               },
+               {
+                       configureServer: func(config *Config, called *bool) {
+                               config.InsecureSkipVerify = false
+                       },
+                       configureClient: func(config *Config, called *bool) {
+                               config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+                                       return sentinelErr
+                               }
+                       },
+                       validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+                               if clientErr != sentinelErr {
+                                       t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
+                               }
+                       },
+               },
+               {
+                       configureServer: func(config *Config, called *bool) {
+                               config.InsecureSkipVerify = false
+                       },
+                       configureClient: func(config *Config, called *bool) {
+                               config.InsecureSkipVerify = true
+                               config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+                                       if l := len(rawCerts); l != 1 {
+                                               return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
+                                       }
+                                       // With InsecureSkipVerify set, this
+                                       // callback should still be called but
+                                       // validatedChains must be empty.
+                                       if l := len(validatedChains); l != 0 {
+                                               return errors.New("got len(validatedChains) = 0, wanted zero")
+                                       }
+                                       *called = true
+                                       return nil
+                               }
+                       },
+                       validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+                               if clientErr != nil {
+                                       t.Errorf("#%d: client handshake failed: %v", testNo, clientErr)
+                               }
+                               if serverErr != nil {
+                                       t.Errorf("#%d: server handshake failed: %v", testNo, serverErr)
+                               }
+                               if !clientCalled {
+                                       t.Error("#%d: client did not call callback", testNo)
+                               }
+                       },
+               },
+       }
+
+       for i, test := range tests {
+               c, s := net.Pipe()
+               done := make(chan error)
+
+               var clientCalled, serverCalled bool
+
+               go func() {
+                       config := testConfig.Clone()
+                       config.ServerName = "example.golang"
+                       config.ClientAuth = RequireAndVerifyClientCert
+                       config.ClientCAs = rootCAs
+                       config.Time = now
+                       test.configureServer(config, &serverCalled)
+
+                       err = Server(s, config).Handshake()
+                       s.Close()
+                       done <- err
+               }()
+
+               config := testConfig.Clone()
+               config.ServerName = "example.golang"
+               config.RootCAs = rootCAs
+               config.Time = now
+               test.configureClient(config, &clientCalled)
+               clientErr := Client(c, config).Handshake()
+               c.Close()
+               serverErr := <-done
+
+               test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr)
+       }
+}
+
 // brokenConn wraps a net.Conn and causes all Writes after a certain number to
 // fail with brokenConnErr.
 type brokenConn struct {
index 630a99acebf04d627964dc398f90abdde0e59d18..724ed71df4c11165001de1b3a54381a666e22d34 100644 (file)
@@ -747,6 +747,13 @@ func (hs *serverHandshakeState) processCertsFromClient(certificates [][]byte) (c
                c.verifiedChains = chains
        }
 
+       if c.config.VerifyPeerCertificate != nil {
+               if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
+                       c.sendAlert(alertBadCertificate)
+                       return nil, err
+               }
+       }
+
        if len(certs) == 0 {
                return nil, nil
        }
index 87cfa3f7f1dcdf500a8554c2bd0e710224e81664..99153a5dad4f7e3bda0b257e8d58bc863ec5719a 100644 (file)
@@ -477,7 +477,7 @@ func TestClone(t *testing.T) {
                case "Rand":
                        f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
                        continue
-               case "Time", "GetCertificate", "GetConfigForClient":
+               case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate":
                        // DeepEqual can't compare functions.
                        continue
                case "Certificates":