]> Cypherpunks.ru repositories - pyderasn.git/commitdiff
Streaming of huge data support
author Sergey Matveev <stargrave@stargrave.org>
Sat, 15 Feb 2020 14:52:45 +0000 (17:52 +0300)
committer Sergey Matveev <stargrave@stargrave.org>
Sun, 16 Feb 2020 18:25:49 +0000 (21:25 +0300)
doc/news.rst
pyderasn.py
tests/test_cms.py

index ec2cf0262fae2b272acbad7a2b1709ca4c041e49..930a3ba0db3a4f408e79e46b9dd50ace77f247b3 100644 (file)
@@ -23,6 +23,9 @@ News
 * Initial experimental CER encoding mode, allowing streaming encoding of
   the data directly to some writeable object
 * Ability to use mmap-ed memoryviews to skip files loading to memory
+* Ability to use memoryviews as input for \*Strings. If they are
+  mmap-ed, then you can stream-encode any quantity of data
+  without copying it to memory
 
 .. _release6.3:
 
index 8d88a3b352a0fa9238718d54578cb5d742bb6043..38f9fafdec12617fb68311b182dec1a5bca2eafd 100755 (executable)
@@ -3146,13 +3146,10 @@ class OctetString(Obj):
     >>> OctetString(b"hell", bounds=(4, 4))
     OCTET STRING 4 bytes 68656c6c
 
-    .. note::
-
-       Pay attention that OCTET STRING can be encoded both in primitive
-       and constructed forms. Decoder always checks constructed form tag
-       additionally to specified primitive one. If BER decoding is
-       :ref:`not enabled <bered_ctx>`, then decoder will fail, because
-       of DER restrictions.
+    Memoryviews can be used as values. If a memoryview is made over an
+    mmap-ed file, then it does not take storage inside OctetString
+    itself. In CER encoding mode it will be streamed to the specified
+    writer, copying 1 KB chunks.
     """
     __slots__ = ("tag_constructed", "_bound_min", "_bound_max", "defined")
     tag_default = tag_encode(4)
@@ -3207,12 +3204,12 @@ class OctetString(Obj):
         )
 
     def _value_sanitize(self, value):
-        if value.__class__ == binary_type:
+        if value.__class__ == binary_type or value.__class__ == memoryview:
             pass
         elif issubclass(value.__class__, OctetString):
             value = value._value
         else:
-            raise InvalidValueType((self.__class__, bytes))
+            raise InvalidValueType((self.__class__, bytes, memoryview))
         if not self._bound_min <= len(value) <= self._bound_max:
             raise BoundsError(self._bound_min, len(value), self._bound_max)
         return value
@@ -3252,7 +3249,7 @@ class OctetString(Obj):
 
     def __bytes__(self):
         self._assert_ready()
-        return self._value
+        return bytes(self._value)
 
     def __eq__(self, their):
         if their.__class__ == binary_type:
index 359b6bbf1f4a3eb932469b4f87121ebff7be79fe..86756f49924d41f01f4b1d9e3b053c57e7947ed7 100644 (file)
 # License along with this program.  If not, see
 # <http://www.gnu.org/licenses/>.
 
-from hashlib import sha256
+from hashlib import sha512
 from io import BytesIO
 from io import open as io_open
+from os import environ
 from os import remove
 from os import urandom
 from subprocess import call
 from tempfile import NamedTemporaryFile
+from time import time
 from unittest import skipIf
 from unittest import TestCase
 
@@ -29,6 +31,7 @@ from hypothesis import given
 from hypothesis import settings
 from hypothesis.strategies import integers
 from six import PY2
+from six.moves import xrange as six_xrange
 
 from pyderasn import agg_octet_string
 from pyderasn import Any
@@ -159,39 +162,38 @@ class ContentInfo(Sequence):
 
 
 id_signedData = ObjectIdentifier("1.2.840.113549.1.7.2")
-id_sha256 = ObjectIdentifier("2.16.840.1.101.3.4.2.1")
+id_sha512 = ObjectIdentifier("2.16.840.1.101.3.4.2.3")
 id_data = ObjectIdentifier("1.2.840.113549.1.7.1")
-id_ecdsa_with_SHA256 = ObjectIdentifier("1.2.840.10045.4.3.2")
+id_ecdsa_with_SHA512 = ObjectIdentifier("1.2.840.10045.4.3.4")
 id_pkcs9_at_contentType = ObjectIdentifier("1.2.840.113549.1.9.3")
 id_pkcs9_at_messageDigest = ObjectIdentifier("1.2.840.113549.1.9.4")
 id_ce_subjectKeyIdentifier = ObjectIdentifier("2.5.29.14")
+ai_sha512 = AlgorithmIdentifier((("algorithm", id_sha512),))
 
 openssl_cms_exists = call("openssl cms -help 2>/dev/null", shell=True) == 0
 
 @skipIf(not openssl_cms_exists, "openssl cms command not found")
 class TestSignedDataCERWithOpenSSL(TestCase):
-    @settings(deadline=None)
-    @given(integers(min_value=1000, max_value=5000))
-    def runTest(self, data_len):
-        def tmpfile():
-            tmp = NamedTemporaryFile(delete=False)
-            tmp.close()
-            tmp = tmp.name
-            self.addCleanup(lambda: remove(tmp))
-            return tmp
-        key_path = tmpfile()
+    def tmpfile(self):
+        tmp = NamedTemporaryFile(delete=False)
+        tmp.close()
+        self.addCleanup(lambda: remove(tmp.name))
+        return tmp.name
+
+    def keypair(self):
+        key_path = self.tmpfile()
         self.assertEqual(0, call(
-            "openssl ecparam -name prime256v1 -genkey -out " + key_path,
+            "openssl ecparam -name secp521r1 -genkey -out " + key_path,
             shell=True,
         ))
-        cert_path = tmpfile()
+        cert_path = self.tmpfile()
         self.assertEqual(0, call(" ".join((
             "openssl req -x509 -new",
             ("-key " + key_path),
             ("-outform PEM -out " + cert_path),
             "-nodes -subj /CN=pyderasntest",
         )), shell=True))
-        cert_der_path = tmpfile()
+        cert_der_path = self.tmpfile()
         self.assertEqual(0, call(" ".join((
             "openssl x509",
             "-inform PEM -in " + cert_path,
@@ -203,9 +205,35 @@ class TestSignedDataCERWithOpenSSL(TestCase):
         for ext in cert["tbsCertificate"]["extensions"]:
             if ext["extnID"] == id_ce_subjectKeyIdentifier:
                 skid = SubjectKeyIdentifier().decod(bytes(ext["extnValue"]))
-        ai_sha256 = AlgorithmIdentifier((
-            ("algorithm", id_sha256),
-        ))
+        return key_path, cert_path, cert, skid
+
+    def sign(self, signed_attrs, key_path):
+        input_path = self.tmpfile()
+        with open(input_path, "wb") as fd:
+            fd.write(encode_cer(signed_attrs))
+        signature_path = self.tmpfile()
+        self.assertEqual(0, call(" ".join((
+            "openssl dgst -sha512",
+            ("-sign " + key_path),
+            "-binary", input_path,
+            ("> " + signature_path),
+        )), shell=True))
+        with open(signature_path, "rb") as fd:
+            signature = fd.read()
+        return signature
+
+    def verify(self, cert_path, cms_path):
+        self.assertEqual(0, call(" ".join((
+            "openssl cms -verify",
+            ("-inform DER -in " + cms_path),
+            "-signer %s -CAfile %s" % (cert_path, cert_path),
+            "-out /dev/null 2>/dev/null",
+        )), shell=True))
+
+    @settings(deadline=None)
+    @given(integers(min_value=1000, max_value=5000))
+    def test_simple(self, data_len):
+        key_path, cert_path, cert, skid = self.keypair()
         data = urandom(data_len)
         eci = EncapsulatedContentInfo((
             ("eContentType", ContentType(id_data)),
@@ -214,36 +242,23 @@ class TestSignedDataCERWithOpenSSL(TestCase):
         signed_attrs = SignedAttributes([
             Attribute((
                 ("attrType", id_pkcs9_at_contentType),
-                ("attrValues", AttributeValues([
-                    AttributeValue(id_data.encode())
-                ])),
+                ("attrValues", AttributeValues([AttributeValue(id_data)])),
             )),
             Attribute((
                 ("attrType", id_pkcs9_at_messageDigest),
                 ("attrValues", AttributeValues([
                     AttributeValue(OctetString(
-                        sha256(bytes(eci["eContent"])).digest()
-                    ).encode()),
+                        sha512(bytes(eci["eContent"])).digest()
+                    )),
                 ])),
             )),
         ])
-        input_path = tmpfile()
-        with open(input_path, "wb") as fd:
-            fd.write(encode_cer(signed_attrs))
-        signature_path = tmpfile()
-        self.assertEqual(0, call(" ".join((
-            "openssl dgst -sha256",
-            ("-sign " + key_path),
-            "-binary", input_path,
-            ("> " + signature_path),
-        )), shell=True))
-        with open(signature_path, "rb") as fd:
-            signature = fd.read()
+        signature = self.sign(signed_attrs, key_path)
         ci = ContentInfo((
             ("contentType", ContentType(id_signedData)),
             ("content", Any((SignedData((
                 ("version", CMSVersion("v3")),
-                ("digestAlgorithms", DigestAlgorithmIdentifiers([ai_sha256])),
+                ("digestAlgorithms", DigestAlgorithmIdentifiers([ai_sha512])),
                 ("encapContentInfo", eci),
                 ("certificates", CertificateSet([
                     CertificateChoices(("certificate", cert)),
@@ -251,25 +266,20 @@ class TestSignedDataCERWithOpenSSL(TestCase):
                 ("signerInfos", SignerInfos([SignerInfo((
                     ("version", CMSVersion("v3")),
                     ("sid", SignerIdentifier(("subjectKeyIdentifier", skid))),
-                    ("digestAlgorithm", DigestAlgorithmIdentifier(ai_sha256)),
+                    ("digestAlgorithm", DigestAlgorithmIdentifier(ai_sha512)),
                     ("signedAttrs", signed_attrs),
                     ("signatureAlgorithm", SignatureAlgorithmIdentifier((
-                        ("algorithm", id_ecdsa_with_SHA256),
+                        ("algorithm", id_ecdsa_with_SHA512),
                     ))),
                     ("signature", SignatureValue(signature)),
                 ))])),
             ))))),
         ))
-        output_path = tmpfile()
-        with io_open(output_path, "wb") as fd:
+        cms_path = self.tmpfile()
+        with io_open(cms_path, "wb") as fd:
             ci.encode_cer(writer=fd.write)
-        self.assertEqual(0, call(" ".join((
-            "openssl cms -verify",
-            ("-inform DER -in " + output_path),
-            "-signer %s -CAfile %s" % (cert_path, cert_path),
-            "-out /dev/null 2>/dev/null",
-        )), shell=True))
-        fd = open(output_path, "rb")
+        self.verify(cert_path, cms_path)
+        fd = open(cms_path, "rb")
         raw = memoryview(fd.read()) if PY2 else file_mmaped(fd)
         ctx = {"bered": True}
         for decode_path, obj, _ in ContentInfo().decode_evgen(raw, ctx=ctx):
@@ -279,3 +289,90 @@ class TestSignedDataCERWithOpenSSL(TestCase):
         buf = BytesIO()
         agg_octet_string(evgens, ("encapContentInfo", "eContent"), raw, buf.write)
         self.assertSequenceEqual(buf.getvalue(), data)
+
+    @skipIf(PY2, "no mmaped memoryview support in PY2")
+    @skipIf("PYDERASN_TEST_CMS_HUGE" not in environ, "PYDERASN_TEST_CMS_HUGE is not set")
+    def test_huge(self):
+        """Huge CMS test
+
+        Environment variable PYDERASN_TEST_CMS_HUGE tells how many MiB
+        of data to sign. Pay attention that openssl cms is unable to do
+        stream verification and eats huge amounts of memory (several
+        times more than the CMS itself).
+        """
+        key_path, cert_path, cert, skid = self.keypair()
+        rnd = urandom(1<<20)
+        data_path = self.tmpfile()
+        start = time()
+        with open(data_path, "wb") as fd:
+            for _ in six_xrange(int(environ.get("PYDERASN_TEST_CMS_HUGE"))):
+                # dgst.update(rnd)
+                fd.write(rnd)
+        print("data file written", time() - start)
+        data_fd = open(data_path, "rb")
+        data_raw = file_mmaped(data_fd)
+
+        from sys import getallocatedblocks
+        mem_start = getallocatedblocks()
+        start = time()
+        eci = EncapsulatedContentInfo((
+            ("eContentType", ContentType(id_data)),
+            ("eContent", OctetString(data_raw)),
+        ))
+        eci_path = self.tmpfile()
+        with open(eci_path, "wb") as fd:
+            OctetString(eci["eContent"]).encode_cer(writer=fd.write)
+        print("ECI file written", time() - start)
+        eci_fd = open(eci_path, "rb")
+        eci_raw = file_mmaped(eci_fd)
+
+        start = time()
+        dgst = sha512()
+        def hasher(data):
+            dgst.update(data)
+            return len(data)
+        evgens = OctetString().decode_evgen(eci_raw, ctx={"bered": True})
+        agg_octet_string(evgens, (), eci_raw, hasher)
+        dgst = dgst.digest()
+        print("digest calculated", time() - start)
+
+        signed_attrs = SignedAttributes([
+            Attribute((
+                ("attrType", id_pkcs9_at_contentType),
+                ("attrValues", AttributeValues([AttributeValue(id_data)])),
+            )),
+            Attribute((
+                ("attrType", id_pkcs9_at_messageDigest),
+                ("attrValues", AttributeValues([AttributeValue(OctetString(dgst))])),
+            )),
+        ])
+        signature = self.sign(signed_attrs, key_path)
+
+        self.assertLess(getallocatedblocks(), mem_start * 2)
+        start = time()
+        ci = ContentInfo((
+            ("contentType", ContentType(id_signedData)),
+            ("content", Any((SignedData((
+                ("version", CMSVersion("v3")),
+                ("digestAlgorithms", DigestAlgorithmIdentifiers([ai_sha512])),
+                ("encapContentInfo", eci),
+                ("certificates", CertificateSet([
+                    CertificateChoices(("certificate", cert)),
+                ])),
+                ("signerInfos", SignerInfos([SignerInfo((
+                    ("version", CMSVersion("v3")),
+                    ("sid", SignerIdentifier(("subjectKeyIdentifier", skid))),
+                    ("digestAlgorithm", DigestAlgorithmIdentifier(ai_sha512)),
+                    ("signedAttrs", signed_attrs),
+                    ("signatureAlgorithm", SignatureAlgorithmIdentifier((
+                        ("algorithm", id_ecdsa_with_SHA512),
+                    ))),
+                    ("signature", SignatureValue(signature)),
+                ))])),
+            ))))),
+        ))
+        cms_path = self.tmpfile()
+        with io_open(cms_path, "wb") as fd:
+            ci.encode_cer(writer=fd.write)
+        print("CMS written", time() - start)
+        self.verify(cert_path, cms_path)