]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - tests/test_cms.py
Streaming of huge data support
[pyderasn.git] / tests / test_cms.py
index 359b6bbf1f4a3eb932469b4f87121ebff7be79fe..86756f49924d41f01f4b1d9e3b053c57e7947ed7 100644 (file)
 # License along with this program.  If not, see
 # <http://www.gnu.org/licenses/>.
 
 # License along with this program.  If not, see
 # <http://www.gnu.org/licenses/>.
 
-from hashlib import sha256
+from hashlib import sha512
 from io import BytesIO
 from io import open as io_open
 from io import BytesIO
 from io import open as io_open
+from os import environ
 from os import remove
 from os import urandom
 from subprocess import call
 from tempfile import NamedTemporaryFile
 from os import remove
 from os import urandom
 from subprocess import call
 from tempfile import NamedTemporaryFile
+from time import time
 from unittest import skipIf
 from unittest import TestCase
 
 from unittest import skipIf
 from unittest import TestCase
 
@@ -29,6 +31,7 @@ from hypothesis import given
 from hypothesis import settings
 from hypothesis.strategies import integers
 from six import PY2
 from hypothesis import settings
 from hypothesis.strategies import integers
 from six import PY2
+from six.moves import xrange as six_xrange
 
 from pyderasn import agg_octet_string
 from pyderasn import Any
 
 from pyderasn import agg_octet_string
 from pyderasn import Any
@@ -159,39 +162,38 @@ class ContentInfo(Sequence):
 
 
 id_signedData = ObjectIdentifier("1.2.840.113549.1.7.2")
 
 
 id_signedData = ObjectIdentifier("1.2.840.113549.1.7.2")
-id_sha256 = ObjectIdentifier("2.16.840.1.101.3.4.2.1")
+id_sha512 = ObjectIdentifier("2.16.840.1.101.3.4.2.3")
 id_data = ObjectIdentifier("1.2.840.113549.1.7.1")
 id_data = ObjectIdentifier("1.2.840.113549.1.7.1")
-id_ecdsa_with_SHA256 = ObjectIdentifier("1.2.840.10045.4.3.2")
+id_ecdsa_with_SHA512 = ObjectIdentifier("1.2.840.10045.4.3.4")
 id_pkcs9_at_contentType = ObjectIdentifier("1.2.840.113549.1.9.3")
 id_pkcs9_at_messageDigest = ObjectIdentifier("1.2.840.113549.1.9.4")
 id_ce_subjectKeyIdentifier = ObjectIdentifier("2.5.29.14")
 id_pkcs9_at_contentType = ObjectIdentifier("1.2.840.113549.1.9.3")
 id_pkcs9_at_messageDigest = ObjectIdentifier("1.2.840.113549.1.9.4")
 id_ce_subjectKeyIdentifier = ObjectIdentifier("2.5.29.14")
+ai_sha512 = AlgorithmIdentifier((("algorithm", id_sha512),))
 
 openssl_cms_exists = call("openssl cms -help 2>/dev/null", shell=True) == 0
 
 @skipIf(not openssl_cms_exists, "openssl cms command not found")
 class TestSignedDataCERWithOpenSSL(TestCase):
 
 openssl_cms_exists = call("openssl cms -help 2>/dev/null", shell=True) == 0
 
 @skipIf(not openssl_cms_exists, "openssl cms command not found")
 class TestSignedDataCERWithOpenSSL(TestCase):
-    @settings(deadline=None)
-    @given(integers(min_value=1000, max_value=5000))
-    def runTest(self, data_len):
-        def tmpfile():
-            tmp = NamedTemporaryFile(delete=False)
-            tmp.close()
-            tmp = tmp.name
-            self.addCleanup(lambda: remove(tmp))
-            return tmp
-        key_path = tmpfile()
+    def tmpfile(self):
+        tmp = NamedTemporaryFile(delete=False)
+        tmp.close()
+        self.addCleanup(lambda: remove(tmp.name))
+        return tmp.name
+
+    def keypair(self):
+        key_path = self.tmpfile()
         self.assertEqual(0, call(
         self.assertEqual(0, call(
-            "openssl ecparam -name prime256v1 -genkey -out " + key_path,
+            "openssl ecparam -name secp521r1 -genkey -out " + key_path,
             shell=True,
         ))
             shell=True,
         ))
-        cert_path = tmpfile()
+        cert_path = self.tmpfile()
         self.assertEqual(0, call(" ".join((
             "openssl req -x509 -new",
             ("-key " + key_path),
             ("-outform PEM -out " + cert_path),
             "-nodes -subj /CN=pyderasntest",
         )), shell=True))
         self.assertEqual(0, call(" ".join((
             "openssl req -x509 -new",
             ("-key " + key_path),
             ("-outform PEM -out " + cert_path),
             "-nodes -subj /CN=pyderasntest",
         )), shell=True))
-        cert_der_path = tmpfile()
+        cert_der_path = self.tmpfile()
         self.assertEqual(0, call(" ".join((
             "openssl x509",
             "-inform PEM -in " + cert_path,
         self.assertEqual(0, call(" ".join((
             "openssl x509",
             "-inform PEM -in " + cert_path,
@@ -203,9 +205,35 @@ class TestSignedDataCERWithOpenSSL(TestCase):
         for ext in cert["tbsCertificate"]["extensions"]:
             if ext["extnID"] == id_ce_subjectKeyIdentifier:
                 skid = SubjectKeyIdentifier().decod(bytes(ext["extnValue"]))
         for ext in cert["tbsCertificate"]["extensions"]:
             if ext["extnID"] == id_ce_subjectKeyIdentifier:
                 skid = SubjectKeyIdentifier().decod(bytes(ext["extnValue"]))
-        ai_sha256 = AlgorithmIdentifier((
-            ("algorithm", id_sha256),
-        ))
+        return key_path, cert_path, cert, skid
+
+    def sign(self, signed_attrs, key_path):
+        input_path = self.tmpfile()
+        with open(input_path, "wb") as fd:
+            fd.write(encode_cer(signed_attrs))
+        signature_path = self.tmpfile()
+        self.assertEqual(0, call(" ".join((
+            "openssl dgst -sha512",
+            ("-sign " + key_path),
+            "-binary", input_path,
+            ("> " + signature_path),
+        )), shell=True))
+        with open(signature_path, "rb") as fd:
+            signature = fd.read()
+        return signature
+
+    def verify(self, cert_path, cms_path):
+        self.assertEqual(0, call(" ".join((
+            "openssl cms -verify",
+            ("-inform DER -in " + cms_path),
+            "-signer %s -CAfile %s" % (cert_path, cert_path),
+            "-out /dev/null 2>/dev/null",
+        )), shell=True))
+
+    @settings(deadline=None)
+    @given(integers(min_value=1000, max_value=5000))
+    def test_simple(self, data_len):
+        key_path, cert_path, cert, skid = self.keypair()
         data = urandom(data_len)
         eci = EncapsulatedContentInfo((
             ("eContentType", ContentType(id_data)),
         data = urandom(data_len)
         eci = EncapsulatedContentInfo((
             ("eContentType", ContentType(id_data)),
@@ -214,36 +242,23 @@ class TestSignedDataCERWithOpenSSL(TestCase):
         signed_attrs = SignedAttributes([
             Attribute((
                 ("attrType", id_pkcs9_at_contentType),
         signed_attrs = SignedAttributes([
             Attribute((
                 ("attrType", id_pkcs9_at_contentType),
-                ("attrValues", AttributeValues([
-                    AttributeValue(id_data.encode())
-                ])),
+                ("attrValues", AttributeValues([AttributeValue(id_data)])),
             )),
             Attribute((
                 ("attrType", id_pkcs9_at_messageDigest),
                 ("attrValues", AttributeValues([
                     AttributeValue(OctetString(
             )),
             Attribute((
                 ("attrType", id_pkcs9_at_messageDigest),
                 ("attrValues", AttributeValues([
                     AttributeValue(OctetString(
-                        sha256(bytes(eci["eContent"])).digest()
-                    ).encode()),
+                        sha512(bytes(eci["eContent"])).digest()
+                    )),
                 ])),
             )),
         ])
                 ])),
             )),
         ])
-        input_path = tmpfile()
-        with open(input_path, "wb") as fd:
-            fd.write(encode_cer(signed_attrs))
-        signature_path = tmpfile()
-        self.assertEqual(0, call(" ".join((
-            "openssl dgst -sha256",
-            ("-sign " + key_path),
-            "-binary", input_path,
-            ("> " + signature_path),
-        )), shell=True))
-        with open(signature_path, "rb") as fd:
-            signature = fd.read()
+        signature = self.sign(signed_attrs, key_path)
         ci = ContentInfo((
             ("contentType", ContentType(id_signedData)),
             ("content", Any((SignedData((
                 ("version", CMSVersion("v3")),
         ci = ContentInfo((
             ("contentType", ContentType(id_signedData)),
             ("content", Any((SignedData((
                 ("version", CMSVersion("v3")),
-                ("digestAlgorithms", DigestAlgorithmIdentifiers([ai_sha256])),
+                ("digestAlgorithms", DigestAlgorithmIdentifiers([ai_sha512])),
                 ("encapContentInfo", eci),
                 ("certificates", CertificateSet([
                     CertificateChoices(("certificate", cert)),
                 ("encapContentInfo", eci),
                 ("certificates", CertificateSet([
                     CertificateChoices(("certificate", cert)),
@@ -251,25 +266,20 @@ class TestSignedDataCERWithOpenSSL(TestCase):
                 ("signerInfos", SignerInfos([SignerInfo((
                     ("version", CMSVersion("v3")),
                     ("sid", SignerIdentifier(("subjectKeyIdentifier", skid))),
                 ("signerInfos", SignerInfos([SignerInfo((
                     ("version", CMSVersion("v3")),
                     ("sid", SignerIdentifier(("subjectKeyIdentifier", skid))),
-                    ("digestAlgorithm", DigestAlgorithmIdentifier(ai_sha256)),
+                    ("digestAlgorithm", DigestAlgorithmIdentifier(ai_sha512)),
                     ("signedAttrs", signed_attrs),
                     ("signatureAlgorithm", SignatureAlgorithmIdentifier((
                     ("signedAttrs", signed_attrs),
                     ("signatureAlgorithm", SignatureAlgorithmIdentifier((
-                        ("algorithm", id_ecdsa_with_SHA256),
+                        ("algorithm", id_ecdsa_with_SHA512),
                     ))),
                     ("signature", SignatureValue(signature)),
                 ))])),
             ))))),
         ))
                     ))),
                     ("signature", SignatureValue(signature)),
                 ))])),
             ))))),
         ))
-        output_path = tmpfile()
-        with io_open(output_path, "wb") as fd:
+        cms_path = self.tmpfile()
+        with io_open(cms_path, "wb") as fd:
             ci.encode_cer(writer=fd.write)
             ci.encode_cer(writer=fd.write)
-        self.assertEqual(0, call(" ".join((
-            "openssl cms -verify",
-            ("-inform DER -in " + output_path),
-            "-signer %s -CAfile %s" % (cert_path, cert_path),
-            "-out /dev/null 2>/dev/null",
-        )), shell=True))
-        fd = open(output_path, "rb")
+        self.verify(cert_path, cms_path)
+        fd = open(cms_path, "rb")
         raw = memoryview(fd.read()) if PY2 else file_mmaped(fd)
         ctx = {"bered": True}
         for decode_path, obj, _ in ContentInfo().decode_evgen(raw, ctx=ctx):
         raw = memoryview(fd.read()) if PY2 else file_mmaped(fd)
         ctx = {"bered": True}
         for decode_path, obj, _ in ContentInfo().decode_evgen(raw, ctx=ctx):
@@ -279,3 +289,90 @@ class TestSignedDataCERWithOpenSSL(TestCase):
         buf = BytesIO()
         agg_octet_string(evgens, ("encapContentInfo", "eContent"), raw, buf.write)
         self.assertSequenceEqual(buf.getvalue(), data)
         buf = BytesIO()
         agg_octet_string(evgens, ("encapContentInfo", "eContent"), raw, buf.write)
         self.assertSequenceEqual(buf.getvalue(), data)
+
+    @skipIf(PY2, "no mmaped memoryview support in PY2")
+    @skipIf("PYDERASN_TEST_CMS_HUGE" not in environ, "PYDERASN_TEST_CMS_HUGE is not set")
+    def test_huge(self):
+        """Huge CMS test
+
+        Environment variable PYDERASN_TEST_CMS_HUGE tells how many MiBs
+        data to sign. Pay attention that openssl cms is unable to do
+        stream verification and eats huge amounts (several times more,
+        that CMS itself) of memory.
+        """
+        key_path, cert_path, cert, skid = self.keypair()
+        rnd = urandom(1<<20)
+        data_path = self.tmpfile()
+        start = time()
+        with open(data_path, "wb") as fd:
+            for _ in six_xrange(int(environ.get("PYDERASN_TEST_CMS_HUGE"))):
+                # dgst.update(rnd)
+                fd.write(rnd)
+        print("data file written", time() - start)
+        data_fd = open(data_path, "rb")
+        data_raw = file_mmaped(data_fd)
+
+        from sys import getallocatedblocks
+        mem_start = getallocatedblocks()
+        start = time()
+        eci = EncapsulatedContentInfo((
+            ("eContentType", ContentType(id_data)),
+            ("eContent", OctetString(data_raw)),
+        ))
+        eci_path = self.tmpfile()
+        with open(eci_path, "wb") as fd:
+            OctetString(eci["eContent"]).encode_cer(writer=fd.write)
+        print("ECI file written", time() - start)
+        eci_fd = open(eci_path, "rb")
+        eci_raw = file_mmaped(eci_fd)
+
+        start = time()
+        dgst = sha512()
+        def hasher(data):
+            dgst.update(data)
+            return len(data)
+        evgens = OctetString().decode_evgen(eci_raw, ctx={"bered": True})
+        agg_octet_string(evgens, (), eci_raw, hasher)
+        dgst = dgst.digest()
+        print("digest calculated", time() - start)
+
+        signed_attrs = SignedAttributes([
+            Attribute((
+                ("attrType", id_pkcs9_at_contentType),
+                ("attrValues", AttributeValues([AttributeValue(id_data)])),
+            )),
+            Attribute((
+                ("attrType", id_pkcs9_at_messageDigest),
+                ("attrValues", AttributeValues([AttributeValue(OctetString(dgst))])),
+            )),
+        ])
+        signature = self.sign(signed_attrs, key_path)
+
+        self.assertLess(getallocatedblocks(), mem_start * 2)
+        start = time()
+        ci = ContentInfo((
+            ("contentType", ContentType(id_signedData)),
+            ("content", Any((SignedData((
+                ("version", CMSVersion("v3")),
+                ("digestAlgorithms", DigestAlgorithmIdentifiers([ai_sha512])),
+                ("encapContentInfo", eci),
+                ("certificates", CertificateSet([
+                    CertificateChoices(("certificate", cert)),
+                ])),
+                ("signerInfos", SignerInfos([SignerInfo((
+                    ("version", CMSVersion("v3")),
+                    ("sid", SignerIdentifier(("subjectKeyIdentifier", skid))),
+                    ("digestAlgorithm", DigestAlgorithmIdentifier(ai_sha512)),
+                    ("signedAttrs", signed_attrs),
+                    ("signatureAlgorithm", SignatureAlgorithmIdentifier((
+                        ("algorithm", id_ecdsa_with_SHA512),
+                    ))),
+                    ("signature", SignatureValue(signature)),
+                ))])),
+            ))))),
+        ))
+        cms_path = self.tmpfile()
+        with io_open(cms_path, "wb") as fd:
+            ci.encode_cer(writer=fd.write)
+        print("CMS written", time() - start)
+        self.verify(cert_path, cms_path)