Cypherpunks.ru repositories - pyderasn.git/commitdiff
Streaming of huge data support
author Sergey Matveev <stargrave@stargrave.org>
Sat, 15 Feb 2020 14:52:45 +0000 (17:52 +0300)
committer Sergey Matveev <stargrave@stargrave.org>
Sun, 16 Feb 2020 18:25:49 +0000 (21:25 +0300)
doc/news.rst
pyderasn.py
tests/test_cms.py

index ec2cf0262fae2b272acbad7a2b1709ca4c041e49..930a3ba0db3a4f408e79e46b9dd50ace77f247b3 100644 (file)
@@ -23,6 +23,9 @@ News
 * Initial experimental CER encoding mode, allowing streaming encoding of
   the data directly to some writeable object
 * Ability to use mmap-ed memoryviews to skip files loading to memory
 * Initial experimental CER encoding mode, allowing streaming encoding of
   the data directly to some writeable object
 * Ability to use mmap-ed memoryviews to skip files loading to memory
+* Ability to use a memoryview as the input for \*Strings. If it is
+  mmap-ed, then you can encode any quantity of data in a streaming
+  fashion, without copying it to memory
 
 .. _release6.3:
 
 
 .. _release6.3:
 
index 8d88a3b352a0fa9238718d54578cb5d742bb6043..38f9fafdec12617fb68311b182dec1a5bca2eafd 100755 (executable)
@@ -3146,13 +3146,10 @@ class OctetString(Obj):
     >>> OctetString(b"hell", bounds=(4, 4))
     OCTET STRING 4 bytes 68656c6c
 
     >>> OctetString(b"hell", bounds=(4, 4))
     OCTET STRING 4 bytes 68656c6c
 
-    .. note::
-
-       Pay attention that OCTET STRING can be encoded both in primitive
-       and constructed forms. Decoder always checks constructed form tag
-       additionally to specified primitive one. If BER decoding is
-       :ref:`not enabled <bered_ctx>`, then decoder will fail, because
-       of DER restrictions.
+    Memoryviews can be used as values. If a memoryview is made on an
+    mmap-ed file, then it does not take storage inside OctetString
+    itself. In CER encoding mode it will be streamed to the specified
+    writer, copying 1 KB chunks.
     """
     __slots__ = ("tag_constructed", "_bound_min", "_bound_max", "defined")
     tag_default = tag_encode(4)
     """
     __slots__ = ("tag_constructed", "_bound_min", "_bound_max", "defined")
     tag_default = tag_encode(4)
@@ -3207,12 +3204,12 @@ class OctetString(Obj):
         )
 
     def _value_sanitize(self, value):
         )
 
     def _value_sanitize(self, value):
-        if value.__class__ == binary_type:
+        if value.__class__ == binary_type or value.__class__ == memoryview:
             pass
         elif issubclass(value.__class__, OctetString):
             value = value._value
         else:
             pass
         elif issubclass(value.__class__, OctetString):
             value = value._value
         else:
-            raise InvalidValueType((self.__class__, bytes))
+            raise InvalidValueType((self.__class__, bytes, memoryview))
         if not self._bound_min <= len(value) <= self._bound_max:
             raise BoundsError(self._bound_min, len(value), self._bound_max)
         return value
         if not self._bound_min <= len(value) <= self._bound_max:
             raise BoundsError(self._bound_min, len(value), self._bound_max)
         return value
@@ -3252,7 +3249,7 @@ class OctetString(Obj):
 
     def __bytes__(self):
         self._assert_ready()
 
     def __bytes__(self):
         self._assert_ready()
-        return self._value
+        return bytes(self._value)
 
     def __eq__(self, their):
         if their.__class__ == binary_type:
 
     def __eq__(self, their):
         if their.__class__ == binary_type:
index 359b6bbf1f4a3eb932469b4f87121ebff7be79fe..86756f49924d41f01f4b1d9e3b053c57e7947ed7 100644 (file)
 # License along with this program.  If not, see
 # <http://www.gnu.org/licenses/>.
 
 # License along with this program.  If not, see
 # <http://www.gnu.org/licenses/>.
 
-from hashlib import sha256
+from hashlib import sha512
 from io import BytesIO
 from io import open as io_open
 from io import BytesIO
 from io import open as io_open
+from os import environ
 from os import remove
 from os import urandom
 from subprocess import call
 from tempfile import NamedTemporaryFile
 from os import remove
 from os import urandom
 from subprocess import call
 from tempfile import NamedTemporaryFile
+from time import time
 from unittest import skipIf
 from unittest import TestCase
 
 from unittest import skipIf
 from unittest import TestCase
 
@@ -29,6 +31,7 @@ from hypothesis import given
 from hypothesis import settings
 from hypothesis.strategies import integers
 from six import PY2
 from hypothesis import settings
 from hypothesis.strategies import integers
 from six import PY2
+from six.moves import xrange as six_xrange
 
 from pyderasn import agg_octet_string
 from pyderasn import Any
 
 from pyderasn import agg_octet_string
 from pyderasn import Any
@@ -159,39 +162,38 @@ class ContentInfo(Sequence):
 
 
 id_signedData = ObjectIdentifier("1.2.840.113549.1.7.2")
 
 
 id_signedData = ObjectIdentifier("1.2.840.113549.1.7.2")
-id_sha256 = ObjectIdentifier("2.16.840.1.101.3.4.2.1")
+id_sha512 = ObjectIdentifier("2.16.840.1.101.3.4.2.3")
 id_data = ObjectIdentifier("1.2.840.113549.1.7.1")
 id_data = ObjectIdentifier("1.2.840.113549.1.7.1")
-id_ecdsa_with_SHA256 = ObjectIdentifier("1.2.840.10045.4.3.2")
+id_ecdsa_with_SHA512 = ObjectIdentifier("1.2.840.10045.4.3.4")
 id_pkcs9_at_contentType = ObjectIdentifier("1.2.840.113549.1.9.3")
 id_pkcs9_at_messageDigest = ObjectIdentifier("1.2.840.113549.1.9.4")
 id_ce_subjectKeyIdentifier = ObjectIdentifier("2.5.29.14")
 id_pkcs9_at_contentType = ObjectIdentifier("1.2.840.113549.1.9.3")
 id_pkcs9_at_messageDigest = ObjectIdentifier("1.2.840.113549.1.9.4")
 id_ce_subjectKeyIdentifier = ObjectIdentifier("2.5.29.14")
+ai_sha512 = AlgorithmIdentifier((("algorithm", id_sha512),))
 
 openssl_cms_exists = call("openssl cms -help 2>/dev/null", shell=True) == 0
 
 @skipIf(not openssl_cms_exists, "openssl cms command not found")
 class TestSignedDataCERWithOpenSSL(TestCase):
 
 openssl_cms_exists = call("openssl cms -help 2>/dev/null", shell=True) == 0
 
 @skipIf(not openssl_cms_exists, "openssl cms command not found")
 class TestSignedDataCERWithOpenSSL(TestCase):
-    @settings(deadline=None)
-    @given(integers(min_value=1000, max_value=5000))
-    def runTest(self, data_len):
-        def tmpfile():
-            tmp = NamedTemporaryFile(delete=False)
-            tmp.close()
-            tmp = tmp.name
-            self.addCleanup(lambda: remove(tmp))
-            return tmp
-        key_path = tmpfile()
+    def tmpfile(self):
+        tmp = NamedTemporaryFile(delete=False)
+        tmp.close()
+        self.addCleanup(lambda: remove(tmp.name))
+        return tmp.name
+
+    def keypair(self):
+        key_path = self.tmpfile()
         self.assertEqual(0, call(
         self.assertEqual(0, call(
-            "openssl ecparam -name prime256v1 -genkey -out " + key_path,
+            "openssl ecparam -name secp521r1 -genkey -out " + key_path,
             shell=True,
         ))
             shell=True,
         ))
-        cert_path = tmpfile()
+        cert_path = self.tmpfile()
         self.assertEqual(0, call(" ".join((
             "openssl req -x509 -new",
             ("-key " + key_path),
             ("-outform PEM -out " + cert_path),
             "-nodes -subj /CN=pyderasntest",
         )), shell=True))
         self.assertEqual(0, call(" ".join((
             "openssl req -x509 -new",
             ("-key " + key_path),
             ("-outform PEM -out " + cert_path),
             "-nodes -subj /CN=pyderasntest",
         )), shell=True))
-        cert_der_path = tmpfile()
+        cert_der_path = self.tmpfile()
         self.assertEqual(0, call(" ".join((
             "openssl x509",
             "-inform PEM -in " + cert_path,
         self.assertEqual(0, call(" ".join((
             "openssl x509",
             "-inform PEM -in " + cert_path,
@@ -203,9 +205,35 @@ class TestSignedDataCERWithOpenSSL(TestCase):
         for ext in cert["tbsCertificate"]["extensions"]:
             if ext["extnID"] == id_ce_subjectKeyIdentifier:
                 skid = SubjectKeyIdentifier().decod(bytes(ext["extnValue"]))
         for ext in cert["tbsCertificate"]["extensions"]:
             if ext["extnID"] == id_ce_subjectKeyIdentifier:
                 skid = SubjectKeyIdentifier().decod(bytes(ext["extnValue"]))
-        ai_sha256 = AlgorithmIdentifier((
-            ("algorithm", id_sha256),
-        ))
+        return key_path, cert_path, cert, skid
+
+    def sign(self, signed_attrs, key_path):
+        input_path = self.tmpfile()
+        with open(input_path, "wb") as fd:
+            fd.write(encode_cer(signed_attrs))
+        signature_path = self.tmpfile()
+        self.assertEqual(0, call(" ".join((
+            "openssl dgst -sha512",
+            ("-sign " + key_path),
+            "-binary", input_path,
+            ("> " + signature_path),
+        )), shell=True))
+        with open(signature_path, "rb") as fd:
+            signature = fd.read()
+        return signature
+
+    def verify(self, cert_path, cms_path):
+        self.assertEqual(0, call(" ".join((
+            "openssl cms -verify",
+            ("-inform DER -in " + cms_path),
+            "-signer %s -CAfile %s" % (cert_path, cert_path),
+            "-out /dev/null 2>/dev/null",
+        )), shell=True))
+
+    @settings(deadline=None)
+    @given(integers(min_value=1000, max_value=5000))
+    def test_simple(self, data_len):
+        key_path, cert_path, cert, skid = self.keypair()
         data = urandom(data_len)
         eci = EncapsulatedContentInfo((
             ("eContentType", ContentType(id_data)),
         data = urandom(data_len)
         eci = EncapsulatedContentInfo((
             ("eContentType", ContentType(id_data)),
@@ -214,36 +242,23 @@ class TestSignedDataCERWithOpenSSL(TestCase):
         signed_attrs = SignedAttributes([
             Attribute((
                 ("attrType", id_pkcs9_at_contentType),
         signed_attrs = SignedAttributes([
             Attribute((
                 ("attrType", id_pkcs9_at_contentType),
-                ("attrValues", AttributeValues([
-                    AttributeValue(id_data.encode())
-                ])),
+                ("attrValues", AttributeValues([AttributeValue(id_data)])),
             )),
             Attribute((
                 ("attrType", id_pkcs9_at_messageDigest),
                 ("attrValues", AttributeValues([
                     AttributeValue(OctetString(
             )),
             Attribute((
                 ("attrType", id_pkcs9_at_messageDigest),
                 ("attrValues", AttributeValues([
                     AttributeValue(OctetString(
-                        sha256(bytes(eci["eContent"])).digest()
-                    ).encode()),
+                        sha512(bytes(eci["eContent"])).digest()
+                    )),
                 ])),
             )),
         ])
                 ])),
             )),
         ])
-        input_path = tmpfile()
-        with open(input_path, "wb") as fd:
-            fd.write(encode_cer(signed_attrs))
-        signature_path = tmpfile()
-        self.assertEqual(0, call(" ".join((
-            "openssl dgst -sha256",
-            ("-sign " + key_path),
-            "-binary", input_path,
-            ("> " + signature_path),
-        )), shell=True))
-        with open(signature_path, "rb") as fd:
-            signature = fd.read()
+        signature = self.sign(signed_attrs, key_path)
         ci = ContentInfo((
             ("contentType", ContentType(id_signedData)),
             ("content", Any((SignedData((
                 ("version", CMSVersion("v3")),
         ci = ContentInfo((
             ("contentType", ContentType(id_signedData)),
             ("content", Any((SignedData((
                 ("version", CMSVersion("v3")),
-                ("digestAlgorithms", DigestAlgorithmIdentifiers([ai_sha256])),
+                ("digestAlgorithms", DigestAlgorithmIdentifiers([ai_sha512])),
                 ("encapContentInfo", eci),
                 ("certificates", CertificateSet([
                     CertificateChoices(("certificate", cert)),
                 ("encapContentInfo", eci),
                 ("certificates", CertificateSet([
                     CertificateChoices(("certificate", cert)),
@@ -251,25 +266,20 @@ class TestSignedDataCERWithOpenSSL(TestCase):
                 ("signerInfos", SignerInfos([SignerInfo((
                     ("version", CMSVersion("v3")),
                     ("sid", SignerIdentifier(("subjectKeyIdentifier", skid))),
                 ("signerInfos", SignerInfos([SignerInfo((
                     ("version", CMSVersion("v3")),
                     ("sid", SignerIdentifier(("subjectKeyIdentifier", skid))),
-                    ("digestAlgorithm", DigestAlgorithmIdentifier(ai_sha256)),
+                    ("digestAlgorithm", DigestAlgorithmIdentifier(ai_sha512)),
                     ("signedAttrs", signed_attrs),
                     ("signatureAlgorithm", SignatureAlgorithmIdentifier((
                     ("signedAttrs", signed_attrs),
                     ("signatureAlgorithm", SignatureAlgorithmIdentifier((
-                        ("algorithm", id_ecdsa_with_SHA256),
+                        ("algorithm", id_ecdsa_with_SHA512),
                     ))),
                     ("signature", SignatureValue(signature)),
                 ))])),
             ))))),
         ))
                     ))),
                     ("signature", SignatureValue(signature)),
                 ))])),
             ))))),
         ))
-        output_path = tmpfile()
-        with io_open(output_path, "wb") as fd:
+        cms_path = self.tmpfile()
+        with io_open(cms_path, "wb") as fd:
             ci.encode_cer(writer=fd.write)
             ci.encode_cer(writer=fd.write)
-        self.assertEqual(0, call(" ".join((
-            "openssl cms -verify",
-            ("-inform DER -in " + output_path),
-            "-signer %s -CAfile %s" % (cert_path, cert_path),
-            "-out /dev/null 2>/dev/null",
-        )), shell=True))
-        fd = open(output_path, "rb")
+        self.verify(cert_path, cms_path)
+        fd = open(cms_path, "rb")
         raw = memoryview(fd.read()) if PY2 else file_mmaped(fd)
         ctx = {"bered": True}
         for decode_path, obj, _ in ContentInfo().decode_evgen(raw, ctx=ctx):
         raw = memoryview(fd.read()) if PY2 else file_mmaped(fd)
         ctx = {"bered": True}
         for decode_path, obj, _ in ContentInfo().decode_evgen(raw, ctx=ctx):
@@ -279,3 +289,90 @@ class TestSignedDataCERWithOpenSSL(TestCase):
         buf = BytesIO()
         agg_octet_string(evgens, ("encapContentInfo", "eContent"), raw, buf.write)
         self.assertSequenceEqual(buf.getvalue(), data)
         buf = BytesIO()
         agg_octet_string(evgens, ("encapContentInfo", "eContent"), raw, buf.write)
         self.assertSequenceEqual(buf.getvalue(), data)
+
+    @skipIf(PY2, "no mmaped memoryview support in PY2")
+    @skipIf("PYDERASN_TEST_CMS_HUGE" not in environ, "PYDERASN_TEST_CMS_HUGE is not set")
+    def test_huge(self):
+        """Huge CMS test
+
+        Environment variable PYDERASN_TEST_CMS_HUGE tells how many MiBs
+        of data to sign. Pay attention that openssl cms is unable to do
+        stream verification and eats huge amounts of memory (several
+        times more than the CMS itself).
+        """
+        key_path, cert_path, cert, skid = self.keypair()
+        rnd = urandom(1<<20)
+        data_path = self.tmpfile()
+        start = time()
+        with open(data_path, "wb") as fd:
+            for _ in six_xrange(int(environ.get("PYDERASN_TEST_CMS_HUGE"))):
+                # dgst.update(rnd)
+                fd.write(rnd)
+        print("data file written", time() - start)
+        data_fd = open(data_path, "rb")
+        data_raw = file_mmaped(data_fd)
+
+        from sys import getallocatedblocks
+        mem_start = getallocatedblocks()
+        start = time()
+        eci = EncapsulatedContentInfo((
+            ("eContentType", ContentType(id_data)),
+            ("eContent", OctetString(data_raw)),
+        ))
+        eci_path = self.tmpfile()
+        with open(eci_path, "wb") as fd:
+            OctetString(eci["eContent"]).encode_cer(writer=fd.write)
+        print("ECI file written", time() - start)
+        eci_fd = open(eci_path, "rb")
+        eci_raw = file_mmaped(eci_fd)
+
+        start = time()
+        dgst = sha512()
+        def hasher(data):
+            dgst.update(data)
+            return len(data)
+        evgens = OctetString().decode_evgen(eci_raw, ctx={"bered": True})
+        agg_octet_string(evgens, (), eci_raw, hasher)
+        dgst = dgst.digest()
+        print("digest calculated", time() - start)
+
+        signed_attrs = SignedAttributes([
+            Attribute((
+                ("attrType", id_pkcs9_at_contentType),
+                ("attrValues", AttributeValues([AttributeValue(id_data)])),
+            )),
+            Attribute((
+                ("attrType", id_pkcs9_at_messageDigest),
+                ("attrValues", AttributeValues([AttributeValue(OctetString(dgst))])),
+            )),
+        ])
+        signature = self.sign(signed_attrs, key_path)
+
+        self.assertLess(getallocatedblocks(), mem_start * 2)
+        start = time()
+        ci = ContentInfo((
+            ("contentType", ContentType(id_signedData)),
+            ("content", Any((SignedData((
+                ("version", CMSVersion("v3")),
+                ("digestAlgorithms", DigestAlgorithmIdentifiers([ai_sha512])),
+                ("encapContentInfo", eci),
+                ("certificates", CertificateSet([
+                    CertificateChoices(("certificate", cert)),
+                ])),
+                ("signerInfos", SignerInfos([SignerInfo((
+                    ("version", CMSVersion("v3")),
+                    ("sid", SignerIdentifier(("subjectKeyIdentifier", skid))),
+                    ("digestAlgorithm", DigestAlgorithmIdentifier(ai_sha512)),
+                    ("signedAttrs", signed_attrs),
+                    ("signatureAlgorithm", SignatureAlgorithmIdentifier((
+                        ("algorithm", id_ecdsa_with_SHA512),
+                    ))),
+                    ("signature", SignatureValue(signature)),
+                ))])),
+            ))))),
+        ))
+        cms_path = self.tmpfile()
+        with io_open(cms_path, "wb") as fd:
+            ci.encode_cer(writer=fd.write)
+        print("CMS written", time() - start)
+        self.verify(cert_path, cms_path)