]> Cypherpunks.ru repositories - pygost.git/blobdiff - pygost/test_x509.py
Use DEFINES PyDERASN feature for less .decode() invocations
[pygost.git] / pygost / test_x509.py
index cddee1c37ae166af20eca6ab2daf1c40faf480f5..df82bdb3ebc61fc9b3f9ae58f254f4aee4869cb1 100644 (file)
@@ -32,6 +32,8 @@ from pygost.utils import hexdec
 try:
     from pyderasn import OctetString
 
+    from pygost.asn1schemas.oids import id_tc26_gost3410_2012_256
+    from pygost.asn1schemas.oids import id_tc26_gost3410_2012_512
     from pygost.asn1schemas.x509 import Certificate
 except ImportError:
     pyderasn_exists = False
@@ -48,15 +50,34 @@ class TestCertificate(TestCase):
     """
 
     def process_cert(self, curve_name, mode, hasher, prv_key_raw, cert_raw):
-        cert, tail = Certificate().decode(cert_raw)
+        cert, tail = Certificate().decode(cert_raw, ctx={
+            "defines_by_path": (
+                (
+                    (
+                        "tbsCertificate",
+                        "subjectPublicKeyInfo",
+                        "algorithm",
+                        "algorithm",
+                    ),
+                    (
+                        (
+                            ("..", "subjectPublicKey"),
+                            {
+                                id_tc26_gost3410_2012_256: OctetString(),
+                                id_tc26_gost3410_2012_512: OctetString(),
+                            },
+                        ),
+                    ),
+                ),
+            ),
+        })
         self.assertSequenceEqual(tail, b"")
         curve = GOST3410Curve(*CURVE_PARAMS[curve_name])
         prv_key = prv_unmarshal(prv_key_raw)
-        pub_key_raw, tail = OctetString().decode(
-            bytes(cert["tbsCertificate"]["subjectPublicKeyInfo"]["subjectPublicKey"])
-        )
+        spk = cert["tbsCertificate"]["subjectPublicKeyInfo"]["subjectPublicKey"]
+        self.assertIsNotNone(spk.defined)
+        _, pub_key_raw = spk.defined
         pub_key = pub_unmarshal(bytes(pub_key_raw), mode=mode)
-        self.assertSequenceEqual(tail, b"")
         self.assertSequenceEqual(pub_key, public_key(curve, prv_key))
         self.assertTrue(verify(
             curve,