]> Cypherpunks.ru repositories - gostls13.git/blobdiff - src/crypto/x509/verify.go
crypto/x509: implement AddCertWithConstraint
[gostls13.git] / src / crypto / x509 / verify.go
index 218d794cca460ac4d41fc3e3a12d7570ba5a18a3..9d3c3246d3098dc474ed1f0b49db3ee316fe9b6c 100644 (file)
@@ -7,6 +7,7 @@ package x509
 import (
        "bytes"
        "crypto"
+       "crypto/x509/pkix"
        "errors"
        "fmt"
        "net"
@@ -173,11 +174,6 @@ var errNotParsed = errors.New("x509: missing ASN.1 contents; use ParseCertificat
 
 // VerifyOptions contains parameters for Certificate.Verify.
 type VerifyOptions struct {
-       // IsBoring is a validity check for BoringCrypto.
-       // If not nil, it will be called to check whether a given certificate
-       // can be used for constructing verification chains.
-       IsBoring func(*Certificate) bool
-
        // DNSName, if set, is checked against the leaf certificate with
        // Certificate.VerifyHostname or the platform verifier.
        DNSName string
@@ -595,41 +591,19 @@ func (c *Certificate) isValid(certType int, currentChain []*Certificate, opts *V
        }
        comparisonCount := 0
 
-       var leaf *Certificate
        if certType == intermediateCertificate || certType == rootCertificate {
                if len(currentChain) == 0 {
                        return errors.New("x509: internal error: empty chain when appending CA cert")
                }
-               leaf = currentChain[0]
-       }
-
-       if (len(c.ExtKeyUsage) > 0 || len(c.UnknownExtKeyUsage) > 0) && len(opts.KeyUsages) > 0 {
-               acceptableUsage := false
-               um := make(map[ExtKeyUsage]bool, len(opts.KeyUsages))
-               for _, u := range opts.KeyUsages {
-                       um[u] = true
-               }
-               if !um[ExtKeyUsageAny] {
-                       for _, u := range c.ExtKeyUsage {
-                               if u == ExtKeyUsageAny || um[u] {
-                                       acceptableUsage = true
-                                       break
-                               }
-                       }
-                       if !acceptableUsage {
-                               return CertificateInvalidError{c, IncompatibleUsage, ""}
-                       }
-               }
        }
 
        if (certType == intermediateCertificate || certType == rootCertificate) &&
                c.hasNameConstraints() {
                toCheck := []*Certificate{}
-               if leaf.hasSANExtension() {
-                       toCheck = append(toCheck, leaf)
-               }
-               if c.hasSANExtension() {
-                       toCheck = append(toCheck, c)
+               for _, c := range currentChain {
+                       if c.hasSANExtension() {
+                               toCheck = append(toCheck, c)
+                       }
                }
                for _, sanCert := range toCheck {
                        err := forEachSAN(sanCert.getSANExtension(), func(tag int, data []byte) error {
@@ -729,7 +703,7 @@ func (c *Certificate) isValid(certType int, currentChain []*Certificate, opts *V
                }
        }
 
-       if opts.IsBoring != nil && !opts.IsBoring(c) {
+       if !boringAllowCert(c) {
                // IncompatibleUsage is not quite right here,
                // but it's also the "no chains found" error
                // and is close enough.
@@ -768,6 +742,8 @@ func (c *Certificate) isValid(certType int, currentChain []*Certificate, opts *V
 // Certificates that use SHA1WithRSA and ECDSAWithSHA1 signatures are not supported,
 // and will not be used to build chains.
 //
+// Certificates other than c in the returned chains should not be modified.
+//
 // WARNING: this function doesn't do any revocation checking.
 func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err error) {
        // Platform-specific verification needs the ASN.1 contents so
@@ -776,7 +752,7 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e
                return nil, errNotParsed
        }
        for i := 0; i < opts.Intermediates.len(); i++ {
-               c, err := opts.Intermediates.cert(i)
+               c, _, err := opts.Intermediates.cert(i)
                if err != nil {
                        return nil, fmt.Errorf("crypto/x509: error fetching intermediate: %w", err)
                }
@@ -787,7 +763,10 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e
 
        // Use platform verifiers, where available, if Roots is from SystemCertPool.
        if runtime.GOOS == "windows" || runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
-               if opts.Roots == nil {
+               // Don't use the system verifier if the system pool was replaced with a non-system pool,
+               // i.e. if SetFallbackRoots was called with x509usefallbackroots=1.
+               systemPool := systemRootsPool()
+               if opts.Roots == nil && (systemPool == nil || systemPool.systemPool) {
                        return c.systemVerify(&opts)
                }
                if opts.Roots != nil && opts.Roots.systemPool {
@@ -808,10 +787,6 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e
                }
        }
 
-       if len(opts.KeyUsages) == 0 {
-               opts.KeyUsages = []ExtKeyUsage{ExtKeyUsageServerAuth}
-       }
-
        err = c.isValid(leafCertificate, nil, &opts)
        if err != nil {
                return
@@ -824,10 +799,40 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e
                }
        }
 
+       var candidateChains [][]*Certificate
        if opts.Roots.contains(c) {
-               return [][]*Certificate{{c}}, nil
+               candidateChains = [][]*Certificate{{c}}
+       } else {
+               candidateChains, err = c.buildChains([]*Certificate{c}, nil, &opts)
+               if err != nil {
+                       return nil, err
+               }
+       }
+
+       if len(opts.KeyUsages) == 0 {
+               opts.KeyUsages = []ExtKeyUsage{ExtKeyUsageServerAuth}
+       }
+
+       for _, eku := range opts.KeyUsages {
+               if eku == ExtKeyUsageAny {
+                       // If any key usage is acceptable, no need to check the chain for
+                       // key usages.
+                       return candidateChains, nil
+               }
        }
-       return c.buildChains([]*Certificate{c}, nil, &opts)
+
+       chains = make([][]*Certificate, 0, len(candidateChains))
+       for _, candidate := range candidateChains {
+               if checkChainForKeyUsage(candidate, opts.KeyUsages) {
+                       chains = append(chains, candidate)
+               }
+       }
+
+       if len(chains) == 0 {
+               return nil, CertificateInvalidError{c, IncompatibleUsage, ""}
+       }
+
+       return chains, nil
 }
 
 func appendToFreshChain(chain []*Certificate, cert *Certificate) []*Certificate {
@@ -837,6 +842,50 @@ func appendToFreshChain(chain []*Certificate, cert *Certificate) []*Certificate
        return n
 }
 
+// alreadyInChain checks whether a candidate certificate is present in a chain.
+// Rather than doing a direct byte for byte equivalency check, we check if the
+// subject, public key, and SAN, if present, are equal. This prevents loops that
+// are created by mutual cross-signatures, or other cross-signature bridge
+// oddities.
+func alreadyInChain(candidate *Certificate, chain []*Certificate) bool {
+       type pubKeyEqual interface {
+               Equal(crypto.PublicKey) bool
+       }
+
+       var candidateSAN *pkix.Extension
+       for _, ext := range candidate.Extensions {
+               if ext.Id.Equal(oidExtensionSubjectAltName) {
+                       candidateSAN = &ext
+                       break
+               }
+       }
+
+       for _, cert := range chain {
+               if !bytes.Equal(candidate.RawSubject, cert.RawSubject) {
+                       continue
+               }
+               if !candidate.PublicKey.(pubKeyEqual).Equal(cert.PublicKey) {
+                       continue
+               }
+               var certSAN *pkix.Extension
+               for _, ext := range cert.Extensions {
+                       if ext.Id.Equal(oidExtensionSubjectAltName) {
+                               certSAN = &ext
+                               break
+                       }
+               }
+               if candidateSAN == nil && certSAN == nil {
+                       return true
+               } else if candidateSAN == nil || certSAN == nil {
+                       return false
+               }
+               if bytes.Equal(candidateSAN.Value, certSAN.Value) {
+                       return true
+               }
+       }
+       return false
+}
+
 // maxChainSignatureChecks is the maximum number of CheckSignatureFrom calls
 // that an invocation of buildChains will (transitively) make. Most chains are
 // less than 15 certificates long, so this leaves space for multiple chains and
@@ -849,18 +898,9 @@ func (c *Certificate) buildChains(currentChain []*Certificate, sigChecks *int, o
                hintCert *Certificate
        )
 
-       type pubKeyEqual interface {
-               Equal(crypto.PublicKey) bool
-       }
-
-       considerCandidate := func(certType int, candidate *Certificate) {
-               for _, cert := range currentChain {
-                       // If a certificate already appeared in the chain we've built, don't
-                       // reconsider it. This prevents loops, for isntance those created by
-                       // mutual cross-signatures, or other cross-signature bridges oddities.
-                       if bytes.Equal(cert.RawSubject, candidate.RawSubject) && cert.PublicKey.(pubKeyEqual).Equal(candidate.PublicKey) {
-                               return
-                       }
+       considerCandidate := func(certType int, candidate potentialParent) {
+               if alreadyInChain(candidate.cert, currentChain) {
+                       return
                }
 
                if sigChecks == nil {
@@ -872,25 +912,39 @@ func (c *Certificate) buildChains(currentChain []*Certificate, sigChecks *int, o
                        return
                }
 
-               if err := c.CheckSignatureFrom(candidate); err != nil {
+               if err := c.CheckSignatureFrom(candidate.cert); err != nil {
                        if hintErr == nil {
                                hintErr = err
-                               hintCert = candidate
+                               hintCert = candidate.cert
                        }
                        return
                }
 
-               err = candidate.isValid(certType, currentChain, opts)
+               err = candidate.cert.isValid(certType, currentChain, opts)
                if err != nil {
+                       if hintErr == nil {
+                               hintErr = err
+                               hintCert = candidate.cert
+                       }
                        return
                }
 
+               if candidate.constraint != nil {
+                       if err := candidate.constraint(currentChain); err != nil {
+                               if hintErr == nil {
+                                       hintErr = err
+                                       hintCert = candidate.cert
+                               }
+                               return
+                       }
+               }
+
                switch certType {
                case rootCertificate:
-                       chains = append(chains, appendToFreshChain(currentChain, candidate))
+                       chains = append(chains, appendToFreshChain(currentChain, candidate.cert))
                case intermediateCertificate:
                        var childChains [][]*Certificate
-                       childChains, err = candidate.buildChains(appendToFreshChain(currentChain, candidate), sigChecks, opts)
+                       childChains, err = candidate.cert.buildChains(appendToFreshChain(currentChain, candidate.cert), sigChecks, opts)
                        chains = append(chains, childChains...)
                }
        }
@@ -1034,7 +1088,7 @@ func toLowerCaseASCII(in string) string {
 // IP addresses can be optionally enclosed in square brackets and are checked
 // against the IPAddresses field. Other names are checked case insensitively
 // against the DNSNames field. If the names are valid hostnames, the certificate
-// fields can have a wildcard as the left-most label.
+// fields can have a wildcard as the complete left-most label (e.g. *.example.com).
 //
 // Note that the legacy Common Name field is ignored.
 func (c *Certificate) VerifyHostname(h string) error {