]> Cypherpunks.ru repositories - gostls13.git/blobdiff - src/crypto/x509/verify.go
crypto/x509: implement AddCertWithConstraint
[gostls13.git] / src / crypto / x509 / verify.go
index 9ef11466a470a3730ea08614a4a0dd8e907d723f..9d3c3246d3098dc474ed1f0b49db3ee316fe9b6c 100644 (file)
@@ -6,6 +6,8 @@ package x509
 
 import (
        "bytes"
+       "crypto"
+       "crypto/x509/pkix"
        "errors"
        "fmt"
        "net"
@@ -500,9 +502,9 @@ func (c *Certificate) checkNameConstraints(count *int,
        maxConstraintComparisons int,
        nameType string,
        name string,
-       parsedName interface{},
-       match func(parsedName, constraint interface{}) (match bool, err error),
-       permitted, excluded interface{}) error {
+       parsedName any,
+       match func(parsedName, constraint any) (match bool, err error),
+       permitted, excluded any) error {
 
        excludedValue := reflect.ValueOf(excluded)
 
@@ -589,81 +591,87 @@ func (c *Certificate) isValid(certType int, currentChain []*Certificate, opts *V
        }
        comparisonCount := 0
 
-       var leaf *Certificate
        if certType == intermediateCertificate || certType == rootCertificate {
                if len(currentChain) == 0 {
                        return errors.New("x509: internal error: empty chain when appending CA cert")
                }
-               leaf = currentChain[0]
        }
 
        if (certType == intermediateCertificate || certType == rootCertificate) &&
-               c.hasNameConstraints() && leaf.hasSANExtension() {
-               err := forEachSAN(leaf.getSANExtension(), func(tag int, data []byte) error {
-                       switch tag {
-                       case nameTypeEmail:
-                               name := string(data)
-                               mailbox, ok := parseRFC2821Mailbox(name)
-                               if !ok {
-                                       return fmt.Errorf("x509: cannot parse rfc822Name %q", mailbox)
-                               }
-
-                               if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "email address", name, mailbox,
-                                       func(parsedName, constraint interface{}) (bool, error) {
-                                               return matchEmailConstraint(parsedName.(rfc2821Mailbox), constraint.(string))
-                                       }, c.PermittedEmailAddresses, c.ExcludedEmailAddresses); err != nil {
-                                       return err
-                               }
-
-                       case nameTypeDNS:
-                               name := string(data)
-                               if _, ok := domainToReverseLabels(name); !ok {
-                                       return fmt.Errorf("x509: cannot parse dnsName %q", name)
-                               }
-
-                               if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "DNS name", name, name,
-                                       func(parsedName, constraint interface{}) (bool, error) {
-                                               return matchDomainConstraint(parsedName.(string), constraint.(string))
-                                       }, c.PermittedDNSDomains, c.ExcludedDNSDomains); err != nil {
-                                       return err
-                               }
-
-                       case nameTypeURI:
-                               name := string(data)
-                               uri, err := url.Parse(name)
-                               if err != nil {
-                                       return fmt.Errorf("x509: internal error: URI SAN %q failed to parse", name)
-                               }
-
-                               if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "URI", name, uri,
-                                       func(parsedName, constraint interface{}) (bool, error) {
-                                               return matchURIConstraint(parsedName.(*url.URL), constraint.(string))
-                                       }, c.PermittedURIDomains, c.ExcludedURIDomains); err != nil {
-                                       return err
-                               }
-
-                       case nameTypeIP:
-                               ip := net.IP(data)
-                               if l := len(ip); l != net.IPv4len && l != net.IPv6len {
-                                       return fmt.Errorf("x509: internal error: IP SAN %x failed to parse", data)
+               c.hasNameConstraints() {
+               toCheck := []*Certificate{}
+               for _, c := range currentChain {
+                       if c.hasSANExtension() {
+                               toCheck = append(toCheck, c)
+                       }
+               }
+               for _, sanCert := range toCheck {
+                       err := forEachSAN(sanCert.getSANExtension(), func(tag int, data []byte) error {
+                               switch tag {
+                               case nameTypeEmail:
+                                       name := string(data)
+                                       mailbox, ok := parseRFC2821Mailbox(name)
+                                       if !ok {
+                                               return fmt.Errorf("x509: cannot parse rfc822Name %q", mailbox)
+                                       }
+
+                                       if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "email address", name, mailbox,
+                                               func(parsedName, constraint any) (bool, error) {
+                                                       return matchEmailConstraint(parsedName.(rfc2821Mailbox), constraint.(string))
+                                               }, c.PermittedEmailAddresses, c.ExcludedEmailAddresses); err != nil {
+                                               return err
+                                       }
+
+                               case nameTypeDNS:
+                                       name := string(data)
+                                       if _, ok := domainToReverseLabels(name); !ok {
+                                               return fmt.Errorf("x509: cannot parse dnsName %q", name)
+                                       }
+
+                                       if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "DNS name", name, name,
+                                               func(parsedName, constraint any) (bool, error) {
+                                                       return matchDomainConstraint(parsedName.(string), constraint.(string))
+                                               }, c.PermittedDNSDomains, c.ExcludedDNSDomains); err != nil {
+                                               return err
+                                       }
+
+                               case nameTypeURI:
+                                       name := string(data)
+                                       uri, err := url.Parse(name)
+                                       if err != nil {
+                                               return fmt.Errorf("x509: internal error: URI SAN %q failed to parse", name)
+                                       }
+
+                                       if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "URI", name, uri,
+                                               func(parsedName, constraint any) (bool, error) {
+                                                       return matchURIConstraint(parsedName.(*url.URL), constraint.(string))
+                                               }, c.PermittedURIDomains, c.ExcludedURIDomains); err != nil {
+                                               return err
+                                       }
+
+                               case nameTypeIP:
+                                       ip := net.IP(data)
+                                       if l := len(ip); l != net.IPv4len && l != net.IPv6len {
+                                               return fmt.Errorf("x509: internal error: IP SAN %x failed to parse", data)
+                                       }
+
+                                       if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "IP address", ip.String(), ip,
+                                               func(parsedName, constraint any) (bool, error) {
+                                                       return matchIPConstraint(parsedName.(net.IP), constraint.(*net.IPNet))
+                                               }, c.PermittedIPRanges, c.ExcludedIPRanges); err != nil {
+                                               return err
+                                       }
+
+                               default:
+                                       // Unknown SAN types are ignored.
                                }
 
-                               if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "IP address", ip.String(), ip,
-                                       func(parsedName, constraint interface{}) (bool, error) {
-                                               return matchIPConstraint(parsedName.(net.IP), constraint.(*net.IPNet))
-                                       }, c.PermittedIPRanges, c.ExcludedIPRanges); err != nil {
-                                       return err
-                               }
+                               return nil
+                       })
 
-                       default:
-                               // Unknown SAN types are ignored.
+                       if err != nil {
+                               return err
                        }
-
-                       return nil
-               })
-
-               if err != nil {
-                       return err
                }
        }
 
@@ -695,6 +703,13 @@ func (c *Certificate) isValid(certType int, currentChain []*Certificate, opts *V
                }
        }
 
+       if !boringAllowCert(c) {
+               // IncompatibleUsage is not quite right here,
+               // but it's also the "no chains found" error
+               // and is close enough.
+               return CertificateInvalidError{c, IncompatibleUsage, ""}
+       }
+
        return nil
 }
 
@@ -724,6 +739,11 @@ func (c *Certificate) isValid(certType int, currentChain []*Certificate, opts *V
 // list. (While this is not specified, it is common practice in order to limit
 // the types of certificates a CA can issue.)
 //
+// Certificates that use SHA1WithRSA and ECDSAWithSHA1 signatures are not supported,
+// and will not be used to build chains.
+//
+// Certificates other than c in the returned chains should not be modified.
+//
 // WARNING: this function doesn't do any revocation checking.
 func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err error) {
        // Platform-specific verification needs the ASN.1 contents so
@@ -732,7 +752,7 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e
                return nil, errNotParsed
        }
        for i := 0; i < opts.Intermediates.len(); i++ {
-               c, err := opts.Intermediates.cert(i)
+               c, _, err := opts.Intermediates.cert(i)
                if err != nil {
                        return nil, fmt.Errorf("crypto/x509: error fetching intermediate: %w", err)
                }
@@ -741,9 +761,23 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e
                }
        }
 
-       // Use Windows's own verification and chain building.
-       if opts.Roots == nil && runtime.GOOS == "windows" {
-               return c.systemVerify(&opts)
+       // Use platform verifiers, where available, if Roots is from SystemCertPool.
+       if runtime.GOOS == "windows" || runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
+               // Don't use the system verifier if the system pool was replaced with a non-system pool,
+               // i.e. if SetFallbackRoots was called with x509usefallbackroots=1.
+               systemPool := systemRootsPool()
+               if opts.Roots == nil && (systemPool == nil || systemPool.systemPool) {
+                       return c.systemVerify(&opts)
+               }
+               if opts.Roots != nil && opts.Roots.systemPool {
+                       platformChains, err := c.systemVerify(&opts)
+                       // If the platform verifier succeeded, or there are no additional
+                       // roots, return the platform verifier result. Otherwise, continue
+                       // with the Go verifier.
+                       if err == nil || opts.Roots.len() == 0 {
+                               return platformChains, err
+                       }
+               }
        }
 
        if opts.Roots == nil {
@@ -767,27 +801,29 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e
 
        var candidateChains [][]*Certificate
        if opts.Roots.contains(c) {
-               candidateChains = append(candidateChains, []*Certificate{c})
+               candidateChains = [][]*Certificate{{c}}
        } else {
-               if candidateChains, err = c.buildChains(nil, []*Certificate{c}, nil, &opts); err != nil {
+               candidateChains, err = c.buildChains([]*Certificate{c}, nil, &opts)
+               if err != nil {
                        return nil, err
                }
        }
 
-       keyUsages := opts.KeyUsages
-       if len(keyUsages) == 0 {
-               keyUsages = []ExtKeyUsage{ExtKeyUsageServerAuth}
+       if len(opts.KeyUsages) == 0 {
+               opts.KeyUsages = []ExtKeyUsage{ExtKeyUsageServerAuth}
        }
 
-       // If any key usage is acceptable then we're done.
-       for _, usage := range keyUsages {
-               if usage == ExtKeyUsageAny {
+       for _, eku := range opts.KeyUsages {
+               if eku == ExtKeyUsageAny {
+                       // If any key usage is acceptable, no need to check the chain for
+                       // key usages.
                        return candidateChains, nil
                }
        }
 
+       chains = make([][]*Certificate, 0, len(candidateChains))
        for _, candidate := range candidateChains {
-               if checkChainForKeyUsage(candidate, keyUsages) {
+               if checkChainForKeyUsage(candidate, opts.KeyUsages) {
                        chains = append(chains, candidate)
                }
        }
@@ -806,23 +842,65 @@ func appendToFreshChain(chain []*Certificate, cert *Certificate) []*Certificate
        return n
 }
 
+// alreadyInChain checks whether a candidate certificate is present in a chain.
+// Rather than doing a direct byte for byte equivalency check, we check if the
+// subject, public key, and SAN, if present, are equal. This prevents loops that
+// are created by mutual cross-signatures, or other cross-signature bridge
+// oddities.
+func alreadyInChain(candidate *Certificate, chain []*Certificate) bool {
+       type pubKeyEqual interface {
+               Equal(crypto.PublicKey) bool
+       }
+
+       var candidateSAN *pkix.Extension
+       for _, ext := range candidate.Extensions {
+               if ext.Id.Equal(oidExtensionSubjectAltName) {
+                       candidateSAN = &ext
+                       break
+               }
+       }
+
+       for _, cert := range chain {
+               if !bytes.Equal(candidate.RawSubject, cert.RawSubject) {
+                       continue
+               }
+               if !candidate.PublicKey.(pubKeyEqual).Equal(cert.PublicKey) {
+                       continue
+               }
+               var certSAN *pkix.Extension
+               for _, ext := range cert.Extensions {
+                       if ext.Id.Equal(oidExtensionSubjectAltName) {
+                               certSAN = &ext
+                               break
+                       }
+               }
+               if candidateSAN == nil && certSAN == nil {
+                       return true
+               } else if candidateSAN == nil || certSAN == nil {
+                       return false
+               }
+               if bytes.Equal(candidateSAN.Value, certSAN.Value) {
+                       return true
+               }
+       }
+       return false
+}
+
 // maxChainSignatureChecks is the maximum number of CheckSignatureFrom calls
 // that an invocation of buildChains will (transitively) make. Most chains are
 // less than 15 certificates long, so this leaves space for multiple chains and
 // for failed checks due to different intermediates having the same Subject.
 const maxChainSignatureChecks = 100
 
-func (c *Certificate) buildChains(cache map[*Certificate][][]*Certificate, currentChain []*Certificate, sigChecks *int, opts *VerifyOptions) (chains [][]*Certificate, err error) {
+func (c *Certificate) buildChains(currentChain []*Certificate, sigChecks *int, opts *VerifyOptions) (chains [][]*Certificate, err error) {
        var (
                hintErr  error
                hintCert *Certificate
        )
 
-       considerCandidate := func(certType int, candidate *Certificate) {
-               for _, cert := range currentChain {
-                       if cert.Equal(candidate) {
-                               return
-                       }
+       considerCandidate := func(certType int, candidate potentialParent) {
+               if alreadyInChain(candidate.cert, currentChain) {
+                       return
                }
 
                if sigChecks == nil {
@@ -834,31 +912,39 @@ func (c *Certificate) buildChains(cache map[*Certificate][][]*Certificate, curre
                        return
                }
 
-               if err := c.CheckSignatureFrom(candidate); err != nil {
+               if err := c.CheckSignatureFrom(candidate.cert); err != nil {
                        if hintErr == nil {
                                hintErr = err
-                               hintCert = candidate
+                               hintCert = candidate.cert
                        }
                        return
                }
 
-               err = candidate.isValid(certType, currentChain, opts)
+               err = candidate.cert.isValid(certType, currentChain, opts)
                if err != nil {
+                       if hintErr == nil {
+                               hintErr = err
+                               hintCert = candidate.cert
+                       }
                        return
                }
 
+               if candidate.constraint != nil {
+                       if err := candidate.constraint(currentChain); err != nil {
+                               if hintErr == nil {
+                                       hintErr = err
+                                       hintCert = candidate.cert
+                               }
+                               return
+                       }
+               }
+
                switch certType {
                case rootCertificate:
-                       chains = append(chains, appendToFreshChain(currentChain, candidate))
+                       chains = append(chains, appendToFreshChain(currentChain, candidate.cert))
                case intermediateCertificate:
-                       if cache == nil {
-                               cache = make(map[*Certificate][][]*Certificate)
-                       }
-                       childChains, ok := cache[candidate]
-                       if !ok {
-                               childChains, err = candidate.buildChains(cache, appendToFreshChain(currentChain, candidate), sigChecks, opts)
-                               cache[candidate] = childChains
-                       }
+                       var childChains [][]*Certificate
+                       childChains, err = candidate.cert.buildChains(appendToFreshChain(currentChain, candidate.cert), sigChecks, opts)
                        chains = append(chains, childChains...)
                }
        }
@@ -1002,7 +1088,7 @@ func toLowerCaseASCII(in string) string {
 // IP addresses can be optionally enclosed in square brackets and are checked
 // against the IPAddresses field. Other names are checked case insensitively
 // against the DNSNames field. If the names are valid hostnames, the certificate
-// fields can have a wildcard as the left-most label.
+// fields can have a wildcard as the complete left-most label (e.g. *.example.com).
 //
 // Note that the legacy Common Name field is ignored.
 func (c *Certificate) VerifyHostname(h string) error {
@@ -1085,14 +1171,6 @@ NextCert:
                        for _, usage := range cert.ExtKeyUsage {
                                if requestedUsage == usage {
                                        continue NextRequestedUsage
-                               } else if requestedUsage == ExtKeyUsageServerAuth &&
-                                       (usage == ExtKeyUsageNetscapeServerGatedCrypto ||
-                                               usage == ExtKeyUsageMicrosoftServerGatedCrypto) {
-                                       // In order to support COMODO
-                                       // certificate chains, we have to
-                                       // accept Netscape or Microsoft SGC
-                                       // usages as equal to ServerAuth.
-                                       continue NextRequestedUsage
                                }
                        }