]> Cypherpunks.ru repositories - gostls13.git/commitdiff
crypto/x509: add new OID type and use it in Certificate
authorMateusz Poliwczak <mpoliwczak34@gmail.com>
Tue, 31 Oct 2023 17:26:43 +0000 (17:26 +0000)
committerGopher Robot <gobot@golang.org>
Tue, 31 Oct 2023 19:22:19 +0000 (19:22 +0000)
Fixes #60665

Change-Id: I814b7d4b26b964f74443584fb2048b3e27e3b675
GitHub-Last-Rev: 693c741c76e6369e36aa2a599ee6242d632573c7
GitHub-Pull-Request: golang/go#62096
Reviewed-on: https://go-review.googlesource.com/c/go/+/520535
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Mateusz Poliwczak <mpoliwczak34@gmail.com>
Auto-Submit: Roland Shoemaker <roland@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Cherry Mui <cherryyz@google.com>
api/next/60665.txt [new file with mode: 0644]
src/crypto/x509/oid.go [new file with mode: 0644]
src/crypto/x509/oid_test.go [new file with mode: 0644]
src/crypto/x509/parser.go
src/crypto/x509/x509.go
src/crypto/x509/x509_test.go

diff --git a/api/next/60665.txt b/api/next/60665.txt
new file mode 100644 (file)
index 0000000..10e50e1
--- /dev/null
@@ -0,0 +1,6 @@
+pkg crypto/x509, type Certificate struct, Policies []OID #60665
+pkg crypto/x509, type OID struct #60665
+pkg crypto/x509, method (OID) Equal(OID) bool #60665
+pkg crypto/x509, method (OID) EqualASN1OID(asn1.ObjectIdentifier) bool #60665
+pkg crypto/x509, method (OID) String() string #60665
+pkg crypto/x509, func OIDFromInts([]uint64) (OID, error) #60665
diff --git a/src/crypto/x509/oid.go b/src/crypto/x509/oid.go
new file mode 100644 (file)
index 0000000..5359af6
--- /dev/null
@@ -0,0 +1,273 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package x509
+
+import (
+       "bytes"
+       "encoding/asn1"
+       "errors"
+       "math"
+       "math/big"
+       "math/bits"
+       "strconv"
+       "strings"
+)
+
+var (
+       errInvalidOID = errors.New("invalid oid")
+)
+
+// An OID represents an ASN.1 OBJECT IDENTIFIER.
+type OID struct {
+       der []byte
+}
+
+func newOIDFromDER(der []byte) (OID, bool) {
+       if len(der) == 0 || der[len(der)-1]&0x80 != 0 {
+               return OID{}, false
+       }
+
+       start := 0
+       for i, v := range der {
+               // ITU-T X.690, section 8.19.2:
+               // The subidentifier shall be encoded in the fewest possible octets,
+               // that is, the leading octet of the subidentifier shall not have the value 0x80.
+               if i == start && v == 0x80 {
+                       return OID{}, false
+               }
+               if v&0x80 == 0 {
+                       start = i + 1
+               }
+       }
+
+       return OID{der}, true
+}
+
+// OIDFromInts creates a new OID using ints, each integer is a separate component.
+func OIDFromInts(oid []uint64) (OID, error) {
+       if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
+               return OID{}, errInvalidOID
+       }
+
+       length := base128IntLength(oid[0]*40 + oid[1])
+       for _, v := range oid[2:] {
+               length += base128IntLength(v)
+       }
+
+       der := make([]byte, 0, length)
+       der = appendBase128Int(der, oid[0]*40+oid[1])
+       for _, v := range oid[2:] {
+               der = appendBase128Int(der, v)
+       }
+       return OID{der}, nil
+}
+
+func base128IntLength(n uint64) int {
+       if n == 0 {
+               return 1
+       }
+       return (bits.Len64(n) + 6) / 7
+}
+
+func appendBase128Int(dst []byte, n uint64) []byte {
+       for i := base128IntLength(n) - 1; i >= 0; i-- {
+               o := byte(n >> uint(i*7))
+               o &= 0x7f
+               if i != 0 {
+                       o |= 0x80
+               }
+               dst = append(dst, o)
+       }
+       return dst
+}
+
+// Equal returns true when oid and other represents the same Object Identifier.
+func (oid OID) Equal(other OID) bool {
+       // There is only one possible DER encoding of
+       // each unique Object Identifier.
+       return bytes.Equal(oid.der, other.der)
+}
+
+func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, failed bool) {
+       offset = initOffset
+       var ret64 int64
+       for shifted := 0; offset < len(bytes); shifted++ {
+               // 5 * 7 bits per byte == 35 bits of data
+               // Thus the representation is either non-minimal or too large for an int32
+               if shifted == 5 {
+                       failed = true
+                       return
+               }
+               ret64 <<= 7
+               b := bytes[offset]
+               // integers should be minimally encoded, so the leading octet should
+               // never be 0x80
+               if shifted == 0 && b == 0x80 {
+                       failed = true
+                       return
+               }
+               ret64 |= int64(b & 0x7f)
+               offset++
+               if b&0x80 == 0 {
+                       ret = int(ret64)
+                       // Ensure that the returned value fits in an int on all platforms
+                       if ret64 > math.MaxInt32 {
+                               failed = true
+                       }
+                       return
+               }
+       }
+       failed = true
+       return
+}
+
+// EqualASN1OID returns whether an OID equals an asn1.ObjectIdentifier. If
+// asn1.ObjectIdentifier cannot represent the OID specified by oid, because
+// a component of OID requires more than 31 bits, it returns false.
+func (oid OID) EqualASN1OID(other asn1.ObjectIdentifier) bool {
+       if len(other) < 2 {
+               return false
+       }
+       v, offset, failed := parseBase128Int(oid.der, 0)
+       if failed {
+               // This should never happen, since we've already parsed the OID,
+               // but just in case.
+               return false
+       }
+       if v < 80 {
+               a, b := v/40, v%40
+               if other[0] != a || other[1] != b {
+                       return false
+               }
+       } else {
+               a, b := 2, v-80
+               if other[0] != a || other[1] != b {
+                       return false
+               }
+       }
+
+       i := 2
+       for ; offset < len(oid.der); i++ {
+               v, offset, failed = parseBase128Int(oid.der, offset)
+               if failed {
+                       // Again, shouldn't happen, since we've already parsed
+                       // the OID, but better safe than sorry.
+                       return false
+               }
+               if v != other[i] {
+                       return false
+               }
+       }
+
+       return i == len(other)
+}
+
+// Strings returns the string representation of the Object Identifier.
+func (oid OID) String() string {
+       var b strings.Builder
+       b.Grow(32)
+       const (
+               valSize         = 64 // size in bits of val.
+               bitsPerByte     = 7
+               maxValSafeShift = (1 << (valSize - bitsPerByte)) - 1
+       )
+       var (
+               start    = 0
+               val      = uint64(0)
+               numBuf   = make([]byte, 0, 21)
+               bigVal   *big.Int
+               overflow bool
+       )
+       for i, v := range oid.der {
+               curVal := v & 0x7F
+               valEnd := v&0x80 == 0
+               if valEnd {
+                       if start != 0 {
+                               b.WriteByte('.')
+                       }
+               }
+               if !overflow && val > maxValSafeShift {
+                       if bigVal == nil {
+                               bigVal = new(big.Int)
+                       }
+                       bigVal = bigVal.SetUint64(val)
+                       overflow = true
+               }
+               if overflow {
+                       bigVal = bigVal.Lsh(bigVal, bitsPerByte).Or(bigVal, big.NewInt(int64(curVal)))
+                       if valEnd {
+                               if start == 0 {
+                                       b.WriteString("2.")
+                                       bigVal = bigVal.Sub(bigVal, big.NewInt(80))
+                               }
+                               numBuf = bigVal.Append(numBuf, 10)
+                               b.Write(numBuf)
+                               numBuf = numBuf[:0]
+                               val = 0
+                               start = i + 1
+                               overflow = false
+                       }
+                       continue
+               }
+               val <<= bitsPerByte
+               val |= uint64(curVal)
+               if valEnd {
+                       if start == 0 {
+                               if val < 80 {
+                                       b.Write(strconv.AppendUint(numBuf, val/40, 10))
+                                       b.WriteByte('.')
+                                       b.Write(strconv.AppendUint(numBuf, val%40, 10))
+                               } else {
+                                       b.WriteString("2.")
+                                       b.Write(strconv.AppendUint(numBuf, val-80, 10))
+                               }
+                       } else {
+                               b.Write(strconv.AppendUint(numBuf, val, 10))
+                       }
+                       val = 0
+                       start = i + 1
+               }
+       }
+       return b.String()
+}
+
+func (oid OID) toASN1OID() (asn1.ObjectIdentifier, bool) {
+       out := make([]int, 0, len(oid.der)+1)
+
+       const (
+               valSize         = 31 // amount of usable bits of val for OIDs.
+               bitsPerByte     = 7
+               maxValSafeShift = (1 << (valSize - bitsPerByte)) - 1
+       )
+
+       val := 0
+
+       for _, v := range oid.der {
+               if val > maxValSafeShift {
+                       return nil, false
+               }
+
+               val <<= bitsPerByte
+               val |= int(v & 0x7F)
+
+               if v&0x80 == 0 {
+                       if len(out) == 0 {
+                               if val < 80 {
+                                       out = append(out, val/40)
+                                       out = append(out, val%40)
+                               } else {
+                                       out = append(out, 2)
+                                       out = append(out, val-80)
+                               }
+                               val = 0
+                               continue
+                       }
+                       out = append(out, val)
+                       val = 0
+               }
+       }
+
+       return out, true
+}
diff --git a/src/crypto/x509/oid_test.go b/src/crypto/x509/oid_test.go
new file mode 100644 (file)
index 0000000..b2be107
--- /dev/null
@@ -0,0 +1,110 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package x509
+
+import (
+       "encoding/asn1"
+       "math"
+       "testing"
+)
+
+func TestOID(t *testing.T) {
+       var tests = []struct {
+               raw   []byte
+               valid bool
+               str   string
+               ints  []uint64
+       }{
+               {[]byte{}, false, "", nil},
+               {[]byte{0x80, 0x01}, false, "", nil},
+               {[]byte{0x01, 0x80, 0x01}, false, "", nil},
+
+               {[]byte{1, 2, 3}, true, "0.1.2.3", []uint64{0, 1, 2, 3}},
+               {[]byte{41, 2, 3}, true, "1.1.2.3", []uint64{1, 1, 2, 3}},
+               {[]byte{86, 2, 3}, true, "2.6.2.3", []uint64{2, 6, 2, 3}},
+
+               {[]byte{41, 255, 255, 255, 127}, true, "1.1.268435455", []uint64{1, 1, 268435455}},
+               {[]byte{41, 0x87, 255, 255, 255, 127}, true, "1.1.2147483647", []uint64{1, 1, 2147483647}},
+               {[]byte{41, 255, 255, 255, 255, 127}, true, "1.1.34359738367", []uint64{1, 1, 34359738367}},
+               {[]byte{42, 255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "1.2.9223372036854775807", []uint64{1, 2, 9223372036854775807}},
+               {[]byte{43, 0x81, 255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "1.3.18446744073709551615", []uint64{1, 3, 18446744073709551615}},
+               {[]byte{44, 0x83, 255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "1.4.36893488147419103231", nil},
+               {[]byte{85, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "2.5.1180591620717411303423", nil},
+               {[]byte{85, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "2.5.19342813113834066795298815", nil},
+
+               {[]byte{255, 255, 255, 127}, true, "2.268435375", []uint64{2, 268435375}},
+               {[]byte{0x87, 255, 255, 255, 127}, true, "2.2147483567", []uint64{2, 2147483567}},
+               {[]byte{255, 127}, true, "2.16303", []uint64{2, 16303}},
+               {[]byte{255, 255, 255, 255, 127}, true, "2.34359738287", []uint64{2, 34359738287}},
+               {[]byte{255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "2.9223372036854775727", []uint64{2, 9223372036854775727}},
+               {[]byte{0x81, 255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "2.18446744073709551535", []uint64{2, 18446744073709551535}},
+               {[]byte{0x83, 255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "2.36893488147419103151", nil},
+               {[]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "2.1180591620717411303343", nil},
+               {[]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "2.19342813113834066795298735", nil},
+       }
+
+       for _, v := range tests {
+               oid, ok := newOIDFromDER(v.raw)
+               if ok != v.valid {
+                       if ok {
+                               t.Errorf("%v: unexpected success while parsing: %v", v.raw, oid)
+                       } else {
+                               t.Errorf("%v: unexpected failure while parsing", v.raw)
+                       }
+                       continue
+               }
+
+               if !ok {
+                       continue
+               }
+
+               if str := oid.String(); str != v.str {
+                       t.Errorf("%v: oid.String() = %v, want; %v", v.raw, str, v.str)
+               }
+
+               var asn1OID asn1.ObjectIdentifier
+               for _, v := range v.ints {
+                       if v > math.MaxInt32 {
+                               asn1OID = nil
+                               break
+                       }
+                       asn1OID = append(asn1OID, int(v))
+               }
+
+               o, ok := oid.toASN1OID()
+               if shouldOk := asn1OID != nil; shouldOk != ok {
+                       if ok {
+                               t.Errorf("%v: oid.toASN1OID() unexpected success", v.raw)
+                       } else {
+                               t.Errorf("%v: oid.toASN1OID() unexpected fauilure", v.raw)
+                       }
+                       continue
+               }
+
+               if asn1OID != nil {
+                       if !o.Equal(asn1OID) {
+                               t.Errorf("%v: oid.toASN1OID(asn1OID).Equal(oid) = false, want: true", v.raw)
+                       }
+               }
+
+               if v.ints != nil {
+                       oid2, err := OIDFromInts(v.ints)
+                       if err != nil {
+                               t.Errorf("%v: OIDFromInts() unexpected error: %v", v.raw, err)
+                       }
+                       if !oid2.Equal(oid) {
+                               t.Errorf("%v: %#v.Equal(%#v) = false, want: true", v.raw, oid2, oid)
+                       }
+               }
+       }
+}
+
+func mustNewOIDFromInts(t *testing.T, ints []uint64) OID {
+       oid, err := OIDFromInts(ints)
+       if err != nil {
+               t.Fatalf("OIDFromInts(%v) unexpected error: %v", ints, err)
+       }
+       return oid
+}
index 019a53b5dcf21df019549d5040b1448e9ce94373..812b0d2d28540a3af388501937aea67e7b642925 100644 (file)
@@ -435,23 +435,23 @@ func parseExtKeyUsageExtension(der cryptobyte.String) ([]ExtKeyUsage, []asn1.Obj
        return extKeyUsages, unknownUsages, nil
 }
 
-func parseCertificatePoliciesExtension(der cryptobyte.String) ([]asn1.ObjectIdentifier, error) {
-       var oids []asn1.ObjectIdentifier
+func parseCertificatePoliciesExtension(der cryptobyte.String) ([]OID, error) {
+       var oids []OID
        if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) {
                return nil, errors.New("x509: invalid certificate policies")
        }
        for !der.Empty() {
                var cp cryptobyte.String
-               if !der.ReadASN1(&cp, cryptobyte_asn1.SEQUENCE) {
+               var OIDBytes cryptobyte.String
+               if !der.ReadASN1(&cp, cryptobyte_asn1.SEQUENCE) || !cp.ReadASN1(&OIDBytes, cryptobyte_asn1.OBJECT_IDENTIFIER) {
                        return nil, errors.New("x509: invalid certificate policies")
                }
-               var oid asn1.ObjectIdentifier
-               if !cp.ReadASN1ObjectIdentifier(&oid) {
+               oid, ok := newOIDFromDER(OIDBytes)
+               if !ok {
                        return nil, errors.New("x509: invalid certificate policies")
                }
                oids = append(oids, oid)
        }
-
        return oids, nil
 }
 
@@ -748,10 +748,16 @@ func processExtensions(out *Certificate) error {
                                }
                                out.SubjectKeyId = skid
                        case 32:
-                               out.PolicyIdentifiers, err = parseCertificatePoliciesExtension(e.Value)
+                               out.Policies, err = parseCertificatePoliciesExtension(e.Value)
                                if err != nil {
                                        return err
                                }
+                               out.PolicyIdentifiers = make([]asn1.ObjectIdentifier, 0, len(out.Policies))
+                               for _, oid := range out.Policies {
+                                       if oid, ok := oid.toASN1OID(); ok {
+                                               out.PolicyIdentifiers = append(out.PolicyIdentifiers, oid)
+                                       }
+                               }
                        default:
                                // Unknown extensions are recorded if critical.
                                unhandled = true
index dfc5092b307c02313252bec665c0601549c22ef3..b2e31f76b4cb5ccbb03cc4f65339c75e50dd3e59 100644 (file)
@@ -772,7 +772,15 @@ type Certificate struct {
        // CRL Distribution Points
        CRLDistributionPoints []string
 
+       // PolicyIdentifiers contains asn1.ObjectIdentifiers, the components
+       // of which are limited to int32. If a certificate contains a policy which
+       // cannot be represented by asn1.ObjectIdentifier, it will not be included in
+       // PolicyIdentifiers, but will be present in Policies, which contains all parsed
+       // policy OIDs.
        PolicyIdentifiers []asn1.ObjectIdentifier
+
+       // Policies contains all policy identifiers included in the certificate.
+       Policies []OID
 }
 
 // ErrUnsupportedAlgorithm results from attempting to perform an operation that
@@ -1179,7 +1187,7 @@ func buildCertExtensions(template *Certificate, subjectIsEmpty bool, authorityKe
 
        if len(template.PolicyIdentifiers) > 0 &&
                !oidInExtensions(oidExtensionCertificatePolicies, template.ExtraExtensions) {
-               ret[n], err = marshalCertificatePolicies(template.PolicyIdentifiers)
+               ret[n], err = marshalCertificatePolicies(template.Policies, template.PolicyIdentifiers)
                if err != nil {
                        return nil, err
                }
@@ -1364,14 +1372,27 @@ func marshalBasicConstraints(isCA bool, maxPathLen int, maxPathLenZero bool) (pk
        return ext, err
 }
 
-func marshalCertificatePolicies(policyIdentifiers []asn1.ObjectIdentifier) (pkix.Extension, error) {
+func marshalCertificatePolicies(policies []OID, policyIdentifiers []asn1.ObjectIdentifier) (pkix.Extension, error) {
        ext := pkix.Extension{Id: oidExtensionCertificatePolicies}
-       policies := make([]policyInformation, len(policyIdentifiers))
-       for i, policy := range policyIdentifiers {
-               policies[i].Policy = policy
-       }
+
+       b := cryptobyte.NewBuilder(make([]byte, 0, 128))
+       b.AddASN1(cryptobyte_asn1.SEQUENCE, func(child *cryptobyte.Builder) {
+               for _, v := range policies {
+                       child.AddASN1(cryptobyte_asn1.SEQUENCE, func(child *cryptobyte.Builder) {
+                               child.AddASN1(cryptobyte_asn1.OBJECT_IDENTIFIER, func(child *cryptobyte.Builder) {
+                                       child.AddBytes(v.der)
+                               })
+                       })
+               }
+               for _, v := range policyIdentifiers {
+                       child.AddASN1(cryptobyte_asn1.SEQUENCE, func(child *cryptobyte.Builder) {
+                               child.AddASN1ObjectIdentifier(v)
+                       })
+               }
+       })
+
        var err error
-       ext.Value, err = asn1.Marshal(policies)
+       ext.Value, err = b.Bytes()
        return ext, err
 }
 
index 9a80b2b4343325997570d6f7a988714781338aed..bdc03216bcbc625d6ceed4c9d9b799f8e50a47cd 100644 (file)
@@ -24,12 +24,14 @@ import (
        "fmt"
        "internal/testenv"
        "io"
+       "math"
        "math/big"
        "net"
        "net/url"
        "os/exec"
        "reflect"
        "runtime"
+       "slices"
        "strings"
        "testing"
        "time"
@@ -671,6 +673,7 @@ func TestCreateSelfSignedCertificate(t *testing.T) {
                        URIs:           []*url.URL{parseURI("https://foo.com/wibble#foo")},
 
                        PolicyIdentifiers:       []asn1.ObjectIdentifier{[]int{1, 2, 3}},
+                       Policies:                []OID{mustNewOIDFromInts(t, []uint64{1, 2, 3, math.MaxUint32, math.MaxUint64})},
                        PermittedDNSDomains:     []string{".example.com", "example.com"},
                        ExcludedDNSDomains:      []string{"bar.example.com"},
                        PermittedIPRanges:       []*net.IPNet{parseCIDR("192.168.1.1/16"), parseCIDR("1.2.3.4/8")},
@@ -3917,3 +3920,49 @@ func TestDuplicateAttributesCSR(t *testing.T) {
                t.Fatal("ParseCertificateRequest should succeed when parsing CSR with duplicate attributes")
        }
 }
+
+func TestCertificateOIDPolicies(t *testing.T) {
+       template := Certificate{
+               SerialNumber:      big.NewInt(1),
+               Subject:           pkix.Name{CommonName: "Cert"},
+               NotBefore:         time.Unix(1000, 0),
+               NotAfter:          time.Unix(100000, 0),
+               PolicyIdentifiers: []asn1.ObjectIdentifier{[]int{1, 2, 3}},
+               Policies: []OID{
+                       mustNewOIDFromInts(t, []uint64{1, 2, 3, 4, 5}),
+                       mustNewOIDFromInts(t, []uint64{1, 2, 3, math.MaxInt32}),
+                       mustNewOIDFromInts(t, []uint64{1, 2, 3, math.MaxUint32, math.MaxUint64}),
+               },
+       }
+
+       var expectPolicyIdentifiers = []asn1.ObjectIdentifier{
+               []int{1, 2, 3, 4, 5},
+               []int{1, 2, 3, math.MaxInt32},
+               []int{1, 2, 3},
+       }
+
+       var expectPolicies = []OID{
+               mustNewOIDFromInts(t, []uint64{1, 2, 3, 4, 5}),
+               mustNewOIDFromInts(t, []uint64{1, 2, 3, math.MaxInt32}),
+               mustNewOIDFromInts(t, []uint64{1, 2, 3, math.MaxUint32, math.MaxUint64}),
+               mustNewOIDFromInts(t, []uint64{1, 2, 3}),
+       }
+
+       certDER, err := CreateCertificate(rand.Reader, &template, &template, rsaPrivateKey.Public(), rsaPrivateKey)
+       if err != nil {
+               t.Fatalf("CreateCertificate() unexpected error: %v", err)
+       }
+
+       cert, err := ParseCertificate(certDER)
+       if err != nil {
+               t.Fatalf("ParseCertificate() unexpected error: %v", err)
+       }
+
+       if !slices.EqualFunc(cert.PolicyIdentifiers, expectPolicyIdentifiers, slices.Equal) {
+               t.Errorf("cert.PolicyIdentifiers = %v, want: %v", cert.PolicyIdentifiers, expectPolicyIdentifiers)
+       }
+
+       if !slices.EqualFunc(cert.Policies, expectPolicies, OID.Equal) {
+               t.Errorf("cert.Policies = %v, want: %v", cert.Policies, expectPolicies)
+       }
+}