]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - tests/test_crts.py
keep_memoryview context option
[pyderasn.git] / tests / test_crts.py
index 93162e0b8ad5a6eac35f2551ae115802ae5f87b8..8ceeb086f9a1eca60710eb889618520185104753 100644 (file)
@@ -1,11 +1,10 @@
 # coding: utf-8
 # coding: utf-8
-# PyDERASN -- Python ASN.1 DER codec with abstract structures
-# Copyright (C) 2017 Sergey Matveev <stargrave@stargrave.org>
+# PyDERASN -- Python ASN.1 DER/CER/BER codec with abstract structures
+# Copyright (C) 2017-2022 Sergey Matveev <stargrave@stargrave.org>
 #
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU Lesser General Public License as
 #
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU Lesser General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
+# published by the Free Software Foundation, version 3 of the License.
 #
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 #
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # License along with this program.  If not, see
 # <http://www.gnu.org/licenses/>.
 
 # License along with this program.  If not, see
 # <http://www.gnu.org/licenses/>.
 
+from copy import copy
 from datetime import datetime
 from datetime import datetime
+from pickle import dumps as pickle_dumps
+from pickle import HIGHEST_PROTOCOL as pickle_proto
+from pickle import loads as pickle_loads
 from unittest import TestCase
 
 from pyderasn import Any
 from pyderasn import BitString
 from pyderasn import Boolean
 from pyderasn import Choice
 from unittest import TestCase
 
 from pyderasn import Any
 from pyderasn import BitString
 from pyderasn import Boolean
 from pyderasn import Choice
+from pyderasn import DecodeError
+from pyderasn import encode_cer
 from pyderasn import GeneralizedTime
 from pyderasn import hexdec
 from pyderasn import IA5String
 from pyderasn import GeneralizedTime
 from pyderasn import hexdec
 from pyderasn import IA5String
@@ -37,32 +42,34 @@ from pyderasn import SequenceOf
 from pyderasn import SetOf
 from pyderasn import tag_ctxc
 from pyderasn import tag_ctxp
 from pyderasn import SetOf
 from pyderasn import tag_ctxc
 from pyderasn import tag_ctxp
+from pyderasn import TeletexString
 from pyderasn import UTCTime
 from pyderasn import UTCTime
-
-
-some_oids = {
-    "1.2.840.113549.1.1.1": "id-rsaEncryption",
-    "1.2.840.113549.1.1.5": "id-sha1WithRSAEncryption",
-    "1.2.840.113549.1.9.1": "id-emailAddress",
-    "2.5.29.14": "id-ce-subjectKeyIdentifier",
-    "2.5.29.15": "id-ce-keyUsage",
-    "2.5.29.17": "id-ce-subjectAltName",
-    "2.5.29.18": "id-ce-issuerAltName",
-    "2.5.29.19": "id-ce-basicConstraints",
-    "2.5.29.31": "id-ce-cRLDistributionPoints",
-    "2.5.29.35": "id-ce-authorityKeyIdentifier",
-    "2.5.29.37": "id-ce-extKeyUsage",
-    "2.5.4.3": "id-at-commonName",
-    "2.5.4.6": "id-at-countryName",
-    "2.5.4.7": "id-at-localityName",
-    "2.5.4.8": "id-at-stateOrProvinceName",
-    "2.5.4.10": "id-at-organizationName",
-    "2.5.4.11": "id-at-organizationalUnitName",
+from pyderasn import UTF8String
+
+
+name2oid = {
+    "id-rsaEncryption": ObjectIdentifier("1.2.840.113549.1.1.1"),
+    "id-sha1WithRSAEncryption": ObjectIdentifier("1.2.840.113549.1.1.5"),
+    "id-emailAddress": ObjectIdentifier("1.2.840.113549.1.9.1"),
+    "id-ce-subjectKeyIdentifier": ObjectIdentifier("2.5.29.14"),
+    "id-ce-keyUsage": ObjectIdentifier("2.5.29.15"),
+    "id-ce-subjectAltName": ObjectIdentifier("2.5.29.17"),
+    "id-ce-issuerAltName": ObjectIdentifier("2.5.29.18"),
+    "id-ce-basicConstraints": ObjectIdentifier("2.5.29.19"),
+    "id-ce-cRLDistributionPoints": ObjectIdentifier("2.5.29.31"),
+    "id-ce-authorityKeyIdentifier": ObjectIdentifier("2.5.29.35"),
+    "id-ce-extKeyUsage": ObjectIdentifier("2.5.29.37"),
+    "id-at-commonName": ObjectIdentifier("2.5.4.3"),
+    "id-at-countryName": ObjectIdentifier("2.5.4.6"),
+    "id-at-localityName": ObjectIdentifier("2.5.4.7"),
+    "id-at-stateOrProvinceName": ObjectIdentifier("2.5.4.8"),
+    "id-at-organizationName": ObjectIdentifier("2.5.4.10"),
+    "id-at-organizationalUnitName": ObjectIdentifier("2.5.4.11"),
 }
 }
+stroid2name = {str(oid): name for name, oid in name2oid.items()}
 
 
 class Version(Integer):
 
 
 class Version(Integer):
-    __slots__ = ()
     schema = (
         ("v1", 0),
         ("v2", 1),
     schema = (
         ("v1", 0),
         ("v2", 1),
@@ -71,12 +78,10 @@ class Version(Integer):
 
 
 class CertificateSerialNumber(Integer):
 
 
 class CertificateSerialNumber(Integer):
-    __slots__ = ()
     pass
 
 
 class AlgorithmIdentifier(Sequence):
     pass
 
 
 class AlgorithmIdentifier(Sequence):
-    __slots__ = ()
     schema = (
         ("algorithm", ObjectIdentifier()),
         ("parameters", Any(optional=True)),
     schema = (
         ("algorithm", ObjectIdentifier()),
         ("parameters", Any(optional=True)),
@@ -84,43 +89,56 @@ class AlgorithmIdentifier(Sequence):
 
 
 class AttributeType(ObjectIdentifier):
 
 
 class AttributeType(ObjectIdentifier):
-    __slots__ = ()
     pass
 
 
 class AttributeValue(Any):
     pass
 
 
 class AttributeValue(Any):
-    __slots__ = ()
     pass
 
 
     pass
 
 
+class OrganizationName(Choice):
+    schema = (
+        ("printableString", PrintableString()),
+        ("teletexString", TeletexString()),
+    )
+
+
+class CommonName(Choice):
+    schema = (
+        ("printableString", PrintableString()),
+        ("utf8String", UTF8String()),
+    )
+
+
 class AttributeTypeAndValue(Sequence):
 class AttributeTypeAndValue(Sequence):
-    __slots__ = ()
     schema = (
     schema = (
-        ("type", AttributeType()),
+        ("type", AttributeType(defines=((("value",), {
+            name2oid["id-at-countryName"]: PrintableString(),
+            name2oid["id-at-localityName"]: PrintableString(),
+            name2oid["id-at-stateOrProvinceName"]: PrintableString(),
+            name2oid["id-at-organizationName"]: OrganizationName(),
+            name2oid["id-at-commonName"]: CommonName(),
+        }),))),
         ("value", AttributeValue()),
     )
 
 
 class RelativeDistinguishedName(SetOf):
         ("value", AttributeValue()),
     )
 
 
 class RelativeDistinguishedName(SetOf):
-    __slots__ = ()
     schema = AttributeTypeAndValue()
     bounds = (1, float("+inf"))
 
 
 class RDNSequence(SequenceOf):
     schema = AttributeTypeAndValue()
     bounds = (1, float("+inf"))
 
 
 class RDNSequence(SequenceOf):
-    __slots__ = ()
     schema = RelativeDistinguishedName()
 
 
 class Name(Choice):
     schema = RelativeDistinguishedName()
 
 
 class Name(Choice):
-    __slots__ = ()
     schema = (
         ("rdnSequence", RDNSequence()),
     )
 
 
 class Time(Choice):
     schema = (
         ("rdnSequence", RDNSequence()),
     )
 
 
 class Time(Choice):
-    __slots__ = ()
     schema = (
         ("utcTime", UTCTime()),
         ("generalTime", GeneralizedTime()),
     schema = (
         ("utcTime", UTCTime()),
         ("generalTime", GeneralizedTime()),
@@ -128,7 +146,6 @@ class Time(Choice):
 
 
 class Validity(Sequence):
 
 
 class Validity(Sequence):
-    __slots__ = ()
     schema = (
         ("notBefore", Time()),
         ("notAfter", Time()),
     schema = (
         ("notBefore", Time()),
         ("notAfter", Time()),
@@ -136,7 +153,6 @@ class Validity(Sequence):
 
 
 class SubjectPublicKeyInfo(Sequence):
 
 
 class SubjectPublicKeyInfo(Sequence):
-    __slots__ = ()
     schema = (
         ("algorithm", AlgorithmIdentifier()),
         ("subjectPublicKey", BitString()),
     schema = (
         ("algorithm", AlgorithmIdentifier()),
         ("subjectPublicKey", BitString()),
@@ -144,12 +160,18 @@ class SubjectPublicKeyInfo(Sequence):
 
 
 class UniqueIdentifier(BitString):
 
 
 class UniqueIdentifier(BitString):
-    __slots__ = ()
+    pass
+
+
+class KeyIdentifier(OctetString):
+    pass
+
+
+class SubjectKeyIdentifier(KeyIdentifier):
     pass
 
 
 class Extension(Sequence):
     pass
 
 
 class Extension(Sequence):
-    __slots__ = ()
     schema = (
         ("extnID", ObjectIdentifier()),
         ("critical", Boolean(default=False)),
     schema = (
         ("extnID", ObjectIdentifier()),
         ("critical", Boolean(default=False)),
@@ -158,13 +180,11 @@ class Extension(Sequence):
 
 
 class Extensions(SequenceOf):
 
 
 class Extensions(SequenceOf):
-    __slots__ = ()
     schema = Extension()
     bounds = (1, float("+inf"))
 
 
 class TBSCertificate(Sequence):
     schema = Extension()
     bounds = (1, float("+inf"))
 
 
 class TBSCertificate(Sequence):
-    __slots__ = ()
     schema = (
         ("version", Version(expl=tag_ctxc(0), default="v1")),
         ("serialNumber", CertificateSerialNumber()),
     schema = (
         ("version", Version(expl=tag_ctxc(0), default="v1")),
         ("serialNumber", CertificateSerialNumber()),
@@ -180,12 +200,12 @@ class TBSCertificate(Sequence):
 
 
 class Certificate(Sequence):
 
 
 class Certificate(Sequence):
-    __slots__ = ()
     schema = (
         ("tbsCertificate", TBSCertificate()),
         ("signatureAlgorithm", AlgorithmIdentifier()),
         ("signatureValue", BitString()),
     )
     schema = (
         ("tbsCertificate", TBSCertificate()),
         ("signatureAlgorithm", AlgorithmIdentifier()),
         ("signatureValue", BitString()),
     )
+    der_forced = True
 
 
 class TestGoSelfSignedVector(TestCase):
 
 
 class TestGoSelfSignedVector(TestCase):
@@ -209,8 +229,7 @@ class TestGoSelfSignedVector(TestCase):
             "ba3ca12568fdc6c7b4511cd40a7f659980402df2b998bb9a4a8cbeb34c0f0a78c",
             "f8d91ede14a5ed76bf116fe360aafa8821490435",
         )))
             "ba3ca12568fdc6c7b4511cd40a7f659980402df2b998bb9a4a8cbeb34c0f0a78c",
             "f8d91ede14a5ed76bf116fe360aafa8821490435",
         )))
-        crt, tail = Certificate().decode(raw)
-        self.assertSequenceEqual(tail, b"")
+        crt = Certificate().decod(raw, ctx={"keep_memoryview": True})
         tbs = crt["tbsCertificate"]
         self.assertEqual(tbs["version"], 0)
         self.assertFalse(tbs["version"].decoded)
         tbs = crt["tbsCertificate"]
         self.assertEqual(tbs["version"], 0)
         self.assertFalse(tbs["version"].decoded)
@@ -224,12 +243,12 @@ class TestGoSelfSignedVector(TestCase):
                 expect.encode(),
             )
         assert_raw_equals(tbs["serialNumber"], Integer(10143011886257155224))
                 expect.encode(),
             )
         assert_raw_equals(tbs["serialNumber"], Integer(10143011886257155224))
-        algo_id = AlgorithmIdentifier()
-        algo_id["algorithm"] = ObjectIdentifier("1.2.840.113549.1.1.5")
-        algo_id["parameters"] = Any(Null())
+        algo_id = AlgorithmIdentifier((
+            ("algorithm", name2oid["id-sha1WithRSAEncryption"]),
+            ("parameters", Any(Null())),
+        ))
         self.assertEqual(tbs["signature"], algo_id)
         assert_raw_equals(tbs["signature"], algo_id)
         self.assertEqual(tbs["signature"], algo_id)
         assert_raw_equals(tbs["signature"], algo_id)
-        issuer = Name()
         rdnSeq = RDNSequence()
         for oid, klass, text in (
                 ("2.5.4.6", PrintableString, "XX"),
         rdnSeq = RDNSequence()
         for oid, klass, text in (
                 ("2.5.4.6", PrintableString, "XX"),
@@ -239,28 +258,31 @@ class TestGoSelfSignedVector(TestCase):
                 ("2.5.4.3", PrintableString, "false.example.com"),
                 ("1.2.840.113549.1.9.1", IA5String, "false@example.com"),
         ):
                 ("2.5.4.3", PrintableString, "false.example.com"),
                 ("1.2.840.113549.1.9.1", IA5String, "false@example.com"),
         ):
-            attr = AttributeTypeAndValue()
-            attr["type"] = AttributeType(oid)
-            attr["value"] = AttributeValue(klass(text))
-            rdn = RelativeDistinguishedName()
-            rdn.append(attr)
-            rdnSeq.append(rdn)
-        issuer["rdnSequence"] = rdnSeq
+            rdnSeq.append(
+                RelativeDistinguishedName((
+                    AttributeTypeAndValue((
+                        ("type", AttributeType(oid)),
+                        ("value", AttributeValue(klass(text))),
+                    )),
+                ))
+            )
+        issuer = Name(("rdnSequence", rdnSeq))
         self.assertEqual(tbs["issuer"], issuer)
         assert_raw_equals(tbs["issuer"], issuer)
         self.assertEqual(tbs["issuer"], issuer)
         assert_raw_equals(tbs["issuer"], issuer)
-        validity = Validity()
-        validity["notBefore"] = Time(
-            ("utcTime", UTCTime(datetime(2009, 10, 8, 0, 25, 53)))
-        )
-        validity["notAfter"] = Time(
-            ("utcTime", UTCTime(datetime(2010, 10, 8, 0, 25, 53)))
-        )
+        validity = Validity((
+            ("notBefore", Time(
+                ("utcTime", UTCTime(datetime(2009, 10, 8, 0, 25, 53)))
+            )),
+            ("notAfter", Time(
+                ("utcTime", UTCTime(datetime(2010, 10, 8, 0, 25, 53)))
+            )),
+        ))
         self.assertEqual(tbs["validity"], validity)
         assert_raw_equals(tbs["validity"], validity)
         self.assertEqual(tbs["subject"], issuer)
         assert_raw_equals(tbs["subject"], issuer)
         spki = SubjectPublicKeyInfo()
         self.assertEqual(tbs["validity"], validity)
         assert_raw_equals(tbs["validity"], validity)
         self.assertEqual(tbs["subject"], issuer)
         assert_raw_equals(tbs["subject"], issuer)
         spki = SubjectPublicKeyInfo()
-        algo_id["algorithm"] = ObjectIdentifier("1.2.840.113549.1.1.1")
+        algo_id["algorithm"] = name2oid["id-rsaEncryption"]
         spki["algorithm"] = algo_id
         spki["subjectPublicKey"] = BitString(hexdec("".join((
             "3048024100cdb7639c3278f006aa277f6eaf42902b592d8cbcbe38a1c92ba4695",
         spki["algorithm"] = algo_id
         spki["subjectPublicKey"] = BitString(hexdec("".join((
             "3048024100cdb7639c3278f006aa277f6eaf42902b592d8cbcbe38a1c92ba4695",
@@ -272,22 +294,25 @@ class TestGoSelfSignedVector(TestCase):
         self.assertNotIn("issuerUniqueID", tbs)
         self.assertNotIn("subjectUniqueID", tbs)
         self.assertNotIn("extensions", tbs)
         self.assertNotIn("issuerUniqueID", tbs)
         self.assertNotIn("subjectUniqueID", tbs)
         self.assertNotIn("extensions", tbs)
-        algo_id["algorithm"] = ObjectIdentifier("1.2.840.113549.1.1.5")
+        algo_id["algorithm"] = name2oid["id-sha1WithRSAEncryption"]
         self.assertEqual(crt["signatureAlgorithm"], algo_id)
         self.assertEqual(crt["signatureValue"], BitString(hexdec("".join((
             "a67b06ec5ece92772ca413cba3ca12568fdc6c7b4511cd40a7f659980402df2b",
             "998bb9a4a8cbeb34c0f0a78cf8d91ede14a5ed76bf116fe360aafa8821490435",
         )))))
         self.assertSequenceEqual(crt.encode(), raw)
         self.assertEqual(crt["signatureAlgorithm"], algo_id)
         self.assertEqual(crt["signatureValue"], BitString(hexdec("".join((
             "a67b06ec5ece92772ca413cba3ca12568fdc6c7b4511cd40a7f659980402df2b",
             "998bb9a4a8cbeb34c0f0a78cf8d91ede14a5ed76bf116fe360aafa8821490435",
         )))))
         self.assertSequenceEqual(crt.encode(), raw)
+        crt = Certificate().decod(raw)
         pprint(crt)
         repr(crt)
         pprint(crt)
         repr(crt)
+        pickle_loads(pickle_dumps(crt, pickle_proto))
 
         tbs = TBSCertificate()
         tbs["serialNumber"] = CertificateSerialNumber(10143011886257155224)
 
 
         tbs = TBSCertificate()
         tbs["serialNumber"] = CertificateSerialNumber(10143011886257155224)
 
-        sign_algo_id = AlgorithmIdentifier()
-        sign_algo_id["algorithm"] = ObjectIdentifier("1.2.840.113549.1.1.5")
-        sign_algo_id["parameters"] = Any(Null())
+        sign_algo_id = AlgorithmIdentifier((
+            ("algorithm", name2oid["id-sha1WithRSAEncryption"]),
+            ("parameters", Any(Null())),
+        ))
         tbs["signature"] = sign_algo_id
 
         rdnSeq = RDNSequence()
         tbs["signature"] = sign_algo_id
 
         rdnSeq = RDNSequence()
@@ -299,25 +324,32 @@ class TestGoSelfSignedVector(TestCase):
                 ("2.5.4.3", PrintableString, "false.example.com"),
                 ("1.2.840.113549.1.9.1", IA5String, "false@example.com"),
         ):
                 ("2.5.4.3", PrintableString, "false.example.com"),
                 ("1.2.840.113549.1.9.1", IA5String, "false@example.com"),
         ):
-            attr = AttributeTypeAndValue()
-            attr["type"] = AttributeType(oid)
-            attr["value"] = AttributeValue(klass(text))
-            rdn = RelativeDistinguishedName()
-            rdn.append(attr)
-            rdnSeq.append(rdn)
+            rdnSeq.append(
+                RelativeDistinguishedName((
+                    AttributeTypeAndValue((
+                        ("type", AttributeType(oid)),
+                        ("value", AttributeValue(klass(text))),
+                    )),
+                ))
+            )
         issuer = Name()
         issuer["rdnSequence"] = rdnSeq
         tbs["issuer"] = issuer
         tbs["subject"] = issuer
 
         issuer = Name()
         issuer["rdnSequence"] = rdnSeq
         tbs["issuer"] = issuer
         tbs["subject"] = issuer
 
-        validity = Validity()
-        validity["notBefore"] = Time(("utcTime", UTCTime(datetime(2009, 10, 8, 0, 25, 53))))
-        validity["notAfter"] = Time(("utcTime", UTCTime(datetime(2010, 10, 8, 0, 25, 53))))
+        validity = Validity((
+            ("notBefore", Time(
+                ("utcTime", UTCTime(datetime(2009, 10, 8, 0, 25, 53)),),
+            )),
+            ("notAfter", Time(
+                ("utcTime", UTCTime(datetime(2010, 10, 8, 0, 25, 53)),),
+            )),
+        ))
         tbs["validity"] = validity
 
         spki = SubjectPublicKeyInfo()
         tbs["validity"] = validity
 
         spki = SubjectPublicKeyInfo()
-        spki_algo_id = sign_algo_id.copy()
-        spki_algo_id["algorithm"] = ObjectIdentifier("1.2.840.113549.1.1.1")
+        spki_algo_id = copy(sign_algo_id)
+        spki_algo_id["algorithm"] = name2oid["id-rsaEncryption"]
         spki["algorithm"] = spki_algo_id
         spki["subjectPublicKey"] = BitString(hexdec("".join((
             "3048024100cdb7639c3278f006aa277f6eaf42902b592d8cbcbe38a1c92ba4695",
         spki["algorithm"] = spki_algo_id
         spki["subjectPublicKey"] = BitString(hexdec("".join((
             "3048024100cdb7639c3278f006aa277f6eaf42902b592d8cbcbe38a1c92ba4695",
@@ -334,9 +366,15 @@ class TestGoSelfSignedVector(TestCase):
             "998bb9a4a8cbeb34c0f0a78cf8d91ede14a5ed76bf116fe360aafa8821490435",
         ))))
         self.assertSequenceEqual(crt.encode(), raw)
             "998bb9a4a8cbeb34c0f0a78cf8d91ede14a5ed76bf116fe360aafa8821490435",
         ))))
         self.assertSequenceEqual(crt.encode(), raw)
+        self.assertEqual(
+            Certificate().decod(encode_cer(crt), ctx={"bered": True}),
+            crt,
+        )
 
 
 class TestGoPayPalVector(TestCase):
 
 
 class TestGoPayPalVector(TestCase):
+    """PayPal certificate with "www.paypal.com\x00ssl.secureconnection.cc" name
+    """
     def runTest(self):
         raw = hexdec("".join((
             "30820644308205ada003020102020300f09b300d06092a864886f70d010105050",
     def runTest(self):
         raw = hexdec("".join((
             "30820644308205ada003020102020300f09b300d06092a864886f70d010105050",
@@ -390,8 +428,5 @@ class TestGoPayPalVector(TestCase):
             "07ba44cce54a2d723f9847f626dc054605076321ab469b9c78d5545b3d0c1ec86",
             "48cb55023826fdbb8221c439607a8bb",
         )))
             "07ba44cce54a2d723f9847f626dc054605076321ab469b9c78d5545b3d0c1ec86",
             "48cb55023826fdbb8221c439607a8bb",
         )))
-        crt, tail = Certificate().decode(raw)
-        self.assertSequenceEqual(tail, b"")
-        self.assertSequenceEqual(crt.encode(), raw)
-        pprint(crt)
-        repr(crt)
+        with self.assertRaisesRegex(DecodeError, "alphabet value"):
+            crt = Certificate().decod(raw)