]> Cypherpunks.ru repositories - pygost.git/blobdiff - pygost/test_cms.py
Use DEFINES PyDERASN feature for less .decode() invocations
[pygost.git] / pygost / test_cms.py
index 8eeea91f3edabe6a93eacdfd9d1e0e41c6e31d0e..da6c8596bd95005f305b9fc102740059004d6027 100644 (file)
@@ -35,15 +35,13 @@ from pygost.wrap import unwrap_cryptopro
 from pygost.wrap import unwrap_gost
 
 try:
+    from pyderasn import DecodePathDefBy
     from pyderasn import OctetString
 
     from pygost.asn1schemas.cms import ContentInfo
-    from pygost.asn1schemas.cms import DigestedData
-    from pygost.asn1schemas.cms import EnvelopedData
-    from pygost.asn1schemas.cms import Gost2814789EncryptedKey
-    from pygost.asn1schemas.cms import Gost2814789Parameters
-    from pygost.asn1schemas.cms import GostR3410KeyTransport
-    from pygost.asn1schemas.cms import SignedData
+    from pygost.asn1schemas.oids import id_envelopedData
+    from pygost.asn1schemas.oids import id_tc26_gost3410_2012_256
+    from pygost.asn1schemas.oids import id_tc26_gost3410_2012_512
 except ImportError:
     pyderasn_exists = False
 else:
@@ -67,8 +65,8 @@ class TestSigned(TestCase):
     ):
         content_info, tail = ContentInfo().decode(content_info_raw)
         self.assertSequenceEqual(tail, b"")
-        signed_data, tail = SignedData().decode(bytes(content_info["content"]))
-        self.assertSequenceEqual(tail, b"")
+        self.assertIsNotNone(content_info["content"].defined)
+        _, signed_data = content_info["content"].defined
         self.assertEqual(len(signed_data["signerInfos"]), 1)
         curve = GOST3410Curve(*CURVE_PARAMS[curve_name])
         self.assertTrue(verify(
@@ -127,8 +125,8 @@ class TestDigested(TestCase):
     def process_cms(self, content_info_raw, hasher):
         content_info, tail = ContentInfo().decode(content_info_raw)
         self.assertSequenceEqual(tail, b"")
-        digested_data, tail = DigestedData().decode(bytes(content_info["content"]))
-        self.assertSequenceEqual(tail, b"")
+        self.assertIsNotNone(content_info["content"].defined)
+        _, digested_data = content_info["content"].defined
         self.assertSequenceEqual(
             hasher(bytes(digested_data["encapContentInfo"]["eContent"])).digest(),
             bytes(digested_data["digest"]),
@@ -169,30 +167,56 @@ class TestEnvelopedKTRI(TestCase):
             plaintext_expected,
     ):
         sbox = "Gost28147_tc26_ParamZ"
-        content_info, tail = ContentInfo().decode(content_info_raw)
-        self.assertSequenceEqual(tail, b"")
-        enveloped_data, tail = EnvelopedData().decode(bytes(content_info["content"]))
+        content_info, tail = ContentInfo().decode(content_info_raw, ctx={
+            "defines_by_path": [
+                (
+                    (
+                        "content",
+                        DecodePathDefBy(id_envelopedData),
+                        "recipientInfos",
+                        any,
+                        "ktri",
+                        "encryptedKey",
+                        DecodePathDefBy(spki_algorithm),
+                        "transportParameters",
+                        "ephemeralPublicKey",
+                        "algorithm",
+                        "algorithm",
+                    ),
+                    (
+                        (
+                            ("..", "subjectPublicKey"),
+                            {
+                                id_tc26_gost3410_2012_256: OctetString(),
+                                id_tc26_gost3410_2012_512: OctetString(),
+                            },
+                        ),
+                    ),
+                ) for spki_algorithm in (
+                    id_tc26_gost3410_2012_256,
+                    id_tc26_gost3410_2012_512,
+                )
+            ],
+        })
         self.assertSequenceEqual(tail, b"")
+        self.assertIsNotNone(content_info["content"].defined)
+        _, enveloped_data = content_info["content"].defined
         eci = enveloped_data["encryptedContentInfo"]
         ri = enveloped_data["recipientInfos"][0]
-        encrypted_key, tail = GostR3410KeyTransport().decode(
-            bytes(ri["ktri"]["encryptedKey"])
-        )
-        self.assertSequenceEqual(tail, b"")
+        self.assertIsNotNone(ri["ktri"]["encryptedKey"].defined)
+        _, encrypted_key = ri["ktri"]["encryptedKey"].defined
         ukm = bytes(encrypted_key["transportParameters"]["ukm"])
-        spk = bytes(encrypted_key["transportParameters"]["ephemeralPublicKey"]["subjectPublicKey"])
-        pub_key_their, tail = OctetString().decode(spk)
-        self.assertSequenceEqual(tail, b"")
+        spk = encrypted_key["transportParameters"]["ephemeralPublicKey"]["subjectPublicKey"]
+        self.assertIsNotNone(spk.defined)
+        _, pub_key_their = spk.defined
         curve = GOST3410Curve(*CURVE_PARAMS[curve_name])
         kek = keker(curve, prv_key_our, bytes(pub_key_their), ukm)
         key_wrapped = bytes(encrypted_key["sessionEncryptedKey"]["encryptedKey"])
         mac = bytes(encrypted_key["sessionEncryptedKey"]["macKey"])
         cek = unwrap_cryptopro(kek, ukm + key_wrapped + mac, sbox=sbox)
         ciphertext = bytes(eci["encryptedContent"])
-        encryption_params, tail = Gost2814789Parameters().decode(
-            bytes(eci["contentEncryptionAlgorithm"]["parameters"])
-        )
-        self.assertSequenceEqual(tail, b"")
+        self.assertIsNotNone(eci["contentEncryptionAlgorithm"]["parameters"].defined)
+        _, encryption_params = eci["contentEncryptionAlgorithm"]["parameters"].defined
         iv = bytes(encryption_params["iv"])
         self.assertSequenceEqual(
             cfb_decrypt(cek, ciphertext, iv, sbox=sbox, mesh=True),
@@ -323,32 +347,54 @@ class TestEnvelopedKARI(TestCase):
             plaintext_expected,
     ):
         sbox = "Gost28147_tc26_ParamZ"
-        content_info, tail = ContentInfo().decode(content_info_raw)
-        self.assertSequenceEqual(tail, b"")
-        enveloped_data, tail = EnvelopedData().decode(bytes(content_info["content"]))
+        content_info, tail = ContentInfo().decode(content_info_raw, ctx={
+            "defines_by_path": [
+                (
+                    (
+                        "content",
+                        DecodePathDefBy(id_envelopedData),
+                        "recipientInfos",
+                        any,
+                        "kari",
+                        "originator",
+                        "originatorKey",
+                        "algorithm",
+                        "algorithm",
+                    ),
+                    (
+                        (
+                            ("..", "publicKey"),
+                            {
+                                id_tc26_gost3410_2012_256: OctetString(),
+                                id_tc26_gost3410_2012_512: OctetString(),
+                            },
+                        ),
+                    ),
+                ) for spki_algorithm in (
+                    id_tc26_gost3410_2012_256,
+                    id_tc26_gost3410_2012_512,
+                )
+            ],
+        })
         self.assertSequenceEqual(tail, b"")
+        self.assertIsNotNone(content_info["content"].defined)
+        _, enveloped_data = content_info["content"].defined
         eci = enveloped_data["encryptedContentInfo"]
         kari = enveloped_data["recipientInfos"][0]["kari"]
-        pub_key_their, tail = OctetString().decode(
-            bytes(kari["originator"]["originatorKey"]["publicKey"]),
-        )
-        self.assertSequenceEqual(tail, b"")
+        self.assertIsNotNone(kari["originator"]["originatorKey"]["publicKey"].defined)
+        _, pub_key_their = kari["originator"]["originatorKey"]["publicKey"].defined
         ukm = bytes(kari["ukm"])
         rek = kari["recipientEncryptedKeys"][0]
         curve = GOST3410Curve(*CURVE_PARAMS[curve_name])
         kek = keker(curve, prv_key_our, bytes(pub_key_their), ukm)
-        encrypted_key, tail = Gost2814789EncryptedKey().decode(
-            bytes(rek["encryptedKey"]),
-        )
-        self.assertSequenceEqual(tail, b"")
+        self.assertIsNotNone(rek["encryptedKey"].defined)
+        _, encrypted_key = rek["encryptedKey"].defined
         key_wrapped = bytes(encrypted_key["encryptedKey"])
         mac = bytes(encrypted_key["macKey"])
         cek = unwrap_gost(kek, ukm + key_wrapped + mac, sbox=sbox)
         ciphertext = bytes(eci["encryptedContent"])
-        encryption_params, tail = Gost2814789Parameters().decode(
-            bytes(eci["contentEncryptionAlgorithm"]["parameters"])
-        )
-        self.assertSequenceEqual(tail, b"")
+        self.assertIsNotNone(eci["contentEncryptionAlgorithm"]["parameters"].defined)
+        _, encryption_params = eci["contentEncryptionAlgorithm"]["parameters"].defined
         iv = bytes(encryption_params["iv"])
         self.assertSequenceEqual(
             cfb_decrypt(cek, ciphertext, iv, sbox=sbox, mesh=True),