]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - tests/test_cms.py
agg_octet_string
[pyderasn.git] / tests / test_cms.py
index 86df541d2169b43f8b656dd13d215b1bbb1dbb12..359b6bbf1f4a3eb932469b4f87121ebff7be79fe 100644 (file)
 # License along with this program.  If not, see
 # <http://www.gnu.org/licenses/>.
 
-from datetime import datetime
 from hashlib import sha256
+from io import BytesIO
+from io import open as io_open
+from os import remove
 from os import urandom
-from random import randint
 from subprocess import call
+from tempfile import NamedTemporaryFile
+from unittest import skipIf
 from unittest import TestCase
 
+from hypothesis import given
+from hypothesis import settings
+from hypothesis.strategies import integers
+from six import PY2
+
+from pyderasn import agg_octet_string
 from pyderasn import Any
 from pyderasn import Choice
 from pyderasn import encode_cer
+from pyderasn import file_mmaped
 from pyderasn import Integer
 from pyderasn import ObjectIdentifier
 from pyderasn import OctetString
@@ -32,11 +42,9 @@ from pyderasn import Sequence
 from pyderasn import SetOf
 from pyderasn import tag_ctxc
 from pyderasn import tag_ctxp
-from pyderasn import UTCTime
 from tests.test_crts import AlgorithmIdentifier
 from tests.test_crts import Certificate
 from tests.test_crts import SubjectKeyIdentifier
-from tests.test_crts import Time
 
 
 class CMSVersion(Integer):
@@ -158,13 +166,39 @@ id_pkcs9_at_contentType = ObjectIdentifier("1.2.840.113549.1.9.3")
 id_pkcs9_at_messageDigest = ObjectIdentifier("1.2.840.113549.1.9.4")
 id_ce_subjectKeyIdentifier = ObjectIdentifier("2.5.29.14")
 
-
-class TestSignedDataCER(TestCase):
-    def runTest(self):
-        # openssl ecparam -name prime256v1 -genkey -out key.pem
-        # openssl req -x509 -new -key key.pem -outform PEM -out cert.pem
-        #    -days 365 -nodes -subj "/CN=doesnotmatter"
-        with open("cert.cer", "rb") as fd:
+openssl_cms_exists = call("openssl cms -help 2>/dev/null", shell=True) == 0
+
+@skipIf(not openssl_cms_exists, "openssl cms command not found")
+class TestSignedDataCERWithOpenSSL(TestCase):
+    @settings(deadline=None)
+    @given(integers(min_value=1000, max_value=5000))
+    def runTest(self, data_len):
+        def tmpfile():
+            tmp = NamedTemporaryFile(delete=False)
+            tmp.close()
+            tmp = tmp.name
+            self.addCleanup(lambda: remove(tmp))
+            return tmp
+        key_path = tmpfile()
+        self.assertEqual(0, call(
+            "openssl ecparam -name prime256v1 -genkey -out " + key_path,
+            shell=True,
+        ))
+        cert_path = tmpfile()
+        self.assertEqual(0, call(" ".join((
+            "openssl req -x509 -new",
+            ("-key " + key_path),
+            ("-outform PEM -out " + cert_path),
+            "-nodes -subj /CN=pyderasntest",
+        )), shell=True))
+        cert_der_path = tmpfile()
+        self.assertEqual(0, call(" ".join((
+            "openssl x509",
+            "-inform PEM -in " + cert_path,
+            "-outform DER -out " + cert_der_path,
+        )), shell=True))
+        self.assertEqual(0, call("cat %s >> %s" % (key_path, cert_path), shell=True))
+        with open(cert_der_path, "rb") as fd:
             cert = Certificate().decod(fd.read())
         for ext in cert["tbsCertificate"]["extensions"]:
             if ext["extnID"] == id_ce_subjectKeyIdentifier:
@@ -172,7 +206,7 @@ class TestSignedDataCER(TestCase):
         ai_sha256 = AlgorithmIdentifier((
             ("algorithm", id_sha256),
         ))
-        data = urandom(randint(1000, 3000))
+        data = urandom(data_len)
         eci = EncapsulatedContentInfo((
             ("eContentType", ContentType(id_data)),
             ("eContent", OctetString(data)),
@@ -193,15 +227,18 @@ class TestSignedDataCER(TestCase):
                 ])),
             )),
         ])
-        with open("/tmp/in", "wb") as fd:
+        input_path = tmpfile()
+        with open(input_path, "wb") as fd:
             fd.write(encode_cer(signed_attrs))
+        signature_path = tmpfile()
         self.assertEqual(0, call(" ".join((
             "openssl dgst -sha256",
-            "-sign key.pem",
-            "-binary",
-            "/tmp/in",
-            "> /tmp/signature",
+            ("-sign " + key_path),
+            "-binary", input_path,
+            ("> " + signature_path),
         )), shell=True))
+        with open(signature_path, "rb") as fd:
+            signature = fd.read()
         ci = ContentInfo((
             ("contentType", ContentType(id_signedData)),
             ("content", Any((SignedData((
@@ -213,23 +250,32 @@ class TestSignedDataCER(TestCase):
                 ])),
                 ("signerInfos", SignerInfos([SignerInfo((
                     ("version", CMSVersion("v3")),
-                    ("sid", SignerIdentifier(
-                        ("subjectKeyIdentifier", skid)
-                    )),
+                    ("sid", SignerIdentifier(("subjectKeyIdentifier", skid))),
                     ("digestAlgorithm", DigestAlgorithmIdentifier(ai_sha256)),
                     ("signedAttrs", signed_attrs),
                     ("signatureAlgorithm", SignatureAlgorithmIdentifier((
                         ("algorithm", id_ecdsa_with_SHA256),
                     ))),
-                    ("signature", SignatureValue(open("/tmp/signature", "rb").read())),
+                    ("signature", SignatureValue(signature)),
                 ))])),
             ))))),
         ))
-        with open("/tmp/out.p7m", "wb") as fd:
-            fd.write(encode_cer(ci))
+        output_path = tmpfile()
+        with io_open(output_path, "wb") as fd:
+            ci.encode_cer(writer=fd.write)
         self.assertEqual(0, call(" ".join((
             "openssl cms -verify",
-            "-inform DER -in /tmp/out.p7m",
-            "-signer cert.pem -CAfile cert.pem",
-            "-out /dev/null",
+            ("-inform DER -in " + output_path),
+            "-signer %s -CAfile %s" % (cert_path, cert_path),
+            "-out /dev/null 2>/dev/null",
         )), shell=True))
+        fd = open(output_path, "rb")
+        raw = memoryview(fd.read()) if PY2 else file_mmaped(fd)
+        ctx = {"bered": True}
+        for decode_path, obj, _ in ContentInfo().decode_evgen(raw, ctx=ctx):
+            if decode_path == ("content",):
+                break
+        evgens = SignedData().decode_evgen(raw[obj.offset:], offset=obj.offset, ctx=ctx)
+        buf = BytesIO()
+        agg_octet_string(evgens, ("encapContentInfo", "eContent"), raw, buf.write)
+        self.assertSequenceEqual(buf.getvalue(), data)