Cypherpunks.ru repositories - pyderasn.git/commitdiff
DER 2pass encoding 7.2
author Sergey Matveev <stargrave@stargrave.org>
Mon, 17 Feb 2020 13:58:41 +0000 (16:58 +0300)
committer Sergey Matveev <stargrave@stargrave.org>
Mon, 17 Feb 2020 15:06:34 +0000 (18:06 +0300)
README
doc/features.rst
doc/news.rst
pyderasn.py
tests/test_cms.py
tests/test_crl.py
tests/test_pyderasn.py

diff --git a/README b/README
index 150779fd3bad1099ccdd8a944d44d2a354b9718d..b4d771cf3f14486dfecc45668b8f126b02aaed1c 100644 (file)
--- a/README
+++ b/README
@@ -19,9 +19,9 @@ PyDERASN -- strict and fast ASN.1 DER/CER/BER library for Python
 * Ability to allow BER-encoded data with knowing if any of specified
   field has either DER or BER encoding (or possibly indefinite-length
   encoding)
-* Ability to use mmap-ed files, memoryviews, iterators and CER encoder
-  dealing with the writer, giving ability to create huge ASN.1 encoded
-  files without storing all the data in the memory first
+* Ability to use mmap-ed files, memoryviews, iterators, 2-pass DER
+  encoding mode and CER encoder dealing with the writer, giving ability
+  to create huge ASN.1 encoded files with very little memory footprint
 * Ability to decode files in event generation mode, without the need to
   keep all the data and decoded structures in the memory
 * __slots__, copy.copy() friendliness
index b17e9a8038118ace3e3e456424e29bcc6de2d0ac..ff913bbc07e4b0ee949ae637bf91469dcec6489c 100644 (file)
@@ -38,9 +38,9 @@ Also there is `asn1crypto <https://github.com/wbond/asn1crypto>`__.
   structures allow BER encoding for the whole message, except for
   ``SignedAttributes`` -- you can easily verify your CMS satisfies that
   requirement
-* Ability to use mmap-ed files, memoryviews, iterators and CER encoder
-  dealing with the writer, giving ability to create huge ASN.1 encoded
-  files without storing all the data in the memory first
+* Ability to use mmap-ed files, memoryviews, iterators, 2-pass DER
+  encoding mode and CER encoder dealing with the writer, giving ability
+  to create huge ASN.1 encoded files with very little memory footprint
 * Ability to decode files in event generation mode, without the need to
   keep all the data and decoded structures (that takes huge quantity of
   memory in all known ASN.1 libraries) in the memory
index 8ed34217d76109acccd8f260f2503e156bb58cbe..919944a05005a9ed61426f317be2c511d9ea5036 100644 (file)
@@ -7,6 +7,7 @@ News
 ---
 
 * Restored workability of some command line options
+* 2-pass DER encoding mode with very little memory footprint
 
 .. _release7.1:
 
index 8208efeb024148e54e8c6e71ca369383b393f273..c9d6a116044b5515b0f75b6b5271327345106f6d 100755 (executable)
@@ -829,6 +829,8 @@ copy the payload (without BER/CER encoding interleaved overhead) in it.
 Virtually it won't take memory more than for keeping small structures
 and 1 KB binary chunks.
 
+.. _seqof-iterators:
+
 SEQUENCE OF iterators
 _____________________
 
@@ -844,6 +846,54 @@ generator taking necessary data from the database and giving the
 ``RevokedCertificate`` objects. Only binary representation of that
 objects will take memory during DER encoding.
 
+2-pass DER encoding
+-------------------
+
+It is possible to do 2-pass encoding to DER, writing the results
+directly to the specified writer (buffer, file, whatever). It can be
+1.5+ times slower than ordinary encoding, but it needs little memory to
+store the 1st pass state. For example, the 1st pass state for
+CACert.org's CRL with ~416K certificate entries takes nearly 3.5 MB of
+memory. ``SignedData`` with a several gigabyte
+``EncapsulatedContentInfo`` takes nearly 0.5 KB of memory.
+
+If you use :ref:`mmap-ed <mmap>` memoryviews, :ref:`SEQUENCE OF
+iterators <seqof-iterators>` and write directly to an opened file, then
+the memory footprint is very small.
+
+The 1st pass traverses all the objects of the structure and returns the
+size of the DER encoded structure, together with the 1st pass state
+object. That state contains precalculated lengths for various objects
+inside the structure.
+
+::
+
+    fulllen, state = obj.encode1st()
+
+The 2nd pass takes the writer and the 1st pass state. It traverses all
+the objects again, but writes their encoded representation to the writer.
+
+::
+
+    opener = io.open if PY2 else open
+    with opener("result", "wb") as fd:
+        obj.encode2nd(fd.write, iter(state))
+
+.. warning::
+
+   You **MUST NOT** use the 1st pass state if anything has changed in
+   the objects. It is intended to be used immediately after the 1st
+   pass is done!
+
+If you use :ref:`SEQUENCE OF iterators <seqof-iterators>`, then you
+have to reinitialize the values after the 1st pass. And you **have to**
+be sure that the iterator yields exactly the same values as before.
+Yes, you have to run your iterator twice -- because this is a 2-pass
+encoding mode.
+
+If you want to encode into memory, then you can use the convenient
+:py:func:`pyderasn.encode2pass` helper.
+
 Base Obj
 --------
 .. autoclass:: pyderasn.Obj
@@ -955,6 +1005,7 @@ Various
 .. autofunction:: pyderasn.abs_decode_path
 .. autofunction:: pyderasn.agg_octet_string
 .. autofunction:: pyderasn.colonize_hex
+.. autofunction:: pyderasn.encode2pass
 .. autofunction:: pyderasn.encode_cer
 .. autofunction:: pyderasn.file_mmaped
 .. autofunction:: pyderasn.hexenc
@@ -1101,6 +1152,7 @@ from mmap import PROT_READ
 from operator import attrgetter
 from string import ascii_letters
 from string import digits
+from sys import maxsize as sys_maxsize
 from sys import version_info
 from unicodedata import category as unicat
 
@@ -1138,6 +1190,7 @@ __all__ = (
     "Choice",
     "DecodeError",
     "DecodePathDefBy",
+    "encode2pass",
     "encode_cer",
     "Enumerated",
     "ExceedingData",
@@ -1525,6 +1578,28 @@ LEN1 = len_encode(1)
 LEN1K = len_encode(1000)
 
 
+def len_size(l):
+    """How many bytes length field will take
+    """
+    if l < 128:
+        return 1
+    if l < 256:  # 1 << 8
+        return 2
+    if l < 65536:  # 1 << 16
+        return 3
+    if l < 16777216:  # 1 << 24
+        return 4
+    if l < 4294967296:  # 1 << 32
+        return 5
+    if l < 1099511627776:  # 1 << 40
+        return 6
+    if l < 281474976710656:  # 1 << 48
+        return 7
+    if l < 72057594037927936:  # 1 << 56
+        return 8
+    raise OverflowError("too big length")
+
+
 def write_full(writer, data):
     """Fully write provided data
 
@@ -1543,6 +1618,17 @@ def write_full(writer, data):
         written += n
 
 
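+# On a 64-bit system use a compact array of unsigned longs to keep the
+# state small. Otherwise fall back to an ordinary list of Python
+# integers, which is slower.
+# NOTE(review): the item size of array("L") is platform-dependent (it
+# can be 4 bytes even on a 64-bit OS) -- confirm large lengths fit.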
+if sys_maxsize > 2 ** 32:
+    def state_2pass_new():
+        return array("L")
+else:
+    def state_2pass_new():
+        return []
+
+
 ########################################################################
 # Base class
 ########################################################################
@@ -1701,9 +1787,18 @@ class Obj(object):
     def _encode(self):  # pragma: no cover
         raise NotImplementedError()
 
+    def _encode_cer(self, writer):
+        write_full(writer, self._encode())
+
     def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode):  # pragma: no cover
         yield NotImplemented
 
+    def _encode1st(self, state):
+        raise NotImplementedError()
+
+    def _encode2nd(self, writer, state_iter):
+        raise NotImplementedError()
+
     def encode(self):
         """DER encode the structure
 
@@ -1714,6 +1809,36 @@ class Obj(object):
             return raw
         return b"".join((self._expl, len_encode(len(raw)), raw))
 
+    def encode1st(self, state=None):
+        """Do the 1st pass of 2-pass encoding
+
+        :rtype: (int, array("L"))
+        :returns: full length of encoded data and precalculated various
+                  objects lengths
+        """
+        if state is None:
+            state = state_2pass_new()
+        if self._expl is None:
+            return self._encode1st(state)
+        state.append(0)
+        idx = len(state) - 1
+        vlen, _ = self._encode1st(state)
+        state[idx] = vlen
+        fulllen = len(self._expl) + len_size(vlen) + vlen
+        return fulllen, state
+
+    def encode2nd(self, writer, state_iter):
+        """Do the 2nd pass of 2-pass encoding
+
+        :param writer: must comply with ``io.RawIOBase.write`` behaviour
+        :param state_iter: iterator over the 1st pass state (``iter(state)``)
+        """
+        if self._expl is None:
+            self._encode2nd(writer, state_iter)
+        else:
+            write_full(writer, self._expl + len_encode(next(state_iter)))
+            self._encode2nd(writer, state_iter)
+
     def encode_cer(self, writer):
         """CER encode the structure to specified writer
 
@@ -1731,9 +1856,6 @@ class Obj(object):
         if self._expl is not None:
             write_full(writer, EOC)
 
-    def _encode_cer(self, writer):
-        write_full(writer, self._encode())
-
     def hexencode(self):
         """Do hexadecimal encoded :py:meth:`pyderasn.Obj.encode`
         """
@@ -2045,6 +2167,17 @@ def encode_cer(obj):
     return buf.getvalue()
 
 
+def encode2pass(obj):
+    """Encode (2-pass mode) to DER in memory buffer
+
+    :returns bytes: memory buffer contents
+    """
+    buf = BytesIO()
+    _, state = obj.encode1st()
+    obj.encode2nd(buf.write, iter(state))
+    return buf.getvalue()
+
+
 class DecodePathDefBy(object):
     """DEFINED BY representation inside decode path
     """
@@ -2464,6 +2597,13 @@ class Boolean(Obj):
         self._assert_ready()
         return b"".join((self.tag, LEN1, (b"\xFF" if self._value else b"\x00")))
 
+    def _encode1st(self, state):
+        return len(self.tag) + 2, state
+
+    def _encode2nd(self, writer, state_iter):
+        self._assert_ready()
+        write_full(writer, self._encode())
+
     def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode):
         try:
             t, _, lv = tag_strip(tlv)
@@ -2755,7 +2895,7 @@ class Integer(Obj):
             _specs=self.specs,
         )
 
-    def _encode(self):
+    def _encode_payload(self):
         self._assert_ready()
         value = self._value
         if PY2:
@@ -2792,8 +2932,20 @@ class Integer(Obj):
                     bytes_len += 1
                 else:
                     break
+        return octets
+        return b"".join((self.tag, len_encode(len(octets)), octets))
+
+    def _encode(self):
+        octets = self._encode_payload()
         return b"".join((self.tag, len_encode(len(octets)), octets))
 
+    def _encode1st(self, state):
+        l = len(self._encode_payload())
+        return len(self.tag) + len_size(l) + l, state
+
+    def _encode2nd(self, writer, state_iter):
+        write_full(writer, self._encode())
+
     def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode):
         try:
             t, _, lv = tag_strip(tlv)
@@ -3178,6 +3330,21 @@ class BitString(Obj):
             octets,
         ))
 
+    def _encode1st(self, state):
+        self._assert_ready()
+        _, octets = self._value
+        l = len(octets) + 1
+        return len(self.tag) + len_size(l) + l, state
+
+    def _encode2nd(self, writer, state_iter):
+        bit_len, octets = self._value
+        write_full(writer, b"".join((
+            self.tag,
+            len_encode(len(octets) + 1),
+            int2byte((8 - bit_len % 8) % 8),
+        )))
+        write_full(writer, octets)
+
     def _encode_cer(self, writer):
         bit_len, octets = self._value
         if len(octets) + 1 <= 1000:
@@ -3629,6 +3796,16 @@ class OctetString(Obj):
             self._value,
         ))
 
+    def _encode1st(self, state):
+        self._assert_ready()
+        l = len(self._value)
+        return len(self.tag) + len_size(l) + l, state
+
+    def _encode2nd(self, writer, state_iter):
+        value = self._value
+        write_full(writer, self.tag + len_encode(len(value)))
+        write_full(writer, value)
+
     def _encode_cer(self, writer):
         octets = self._value
         if len(octets) <= 1000:
@@ -3987,6 +4164,12 @@ class Null(Obj):
     def _encode(self):
         return self.tag + LEN0
 
+    def _encode1st(self, state):
+        return len(self.tag) + 1, state
+
+    def _encode2nd(self, writer, state_iter):
+        write_full(writer, self.tag + LEN0)
+
     def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode):
         try:
             t, _, lv = tag_strip(tlv)
@@ -4239,7 +4422,7 @@ class ObjectIdentifier(Obj):
             optional=self.optional if optional is None else optional,
         )
 
-    def _encode(self):
+    def _encode_octets(self):
         self._assert_ready()
         value = self._value
         first_value = value[1]
@@ -4255,9 +4438,19 @@ class ObjectIdentifier(Obj):
         octets = [zero_ended_encode(first_value)]
         for arc in value[2:]:
             octets.append(zero_ended_encode(arc))
-        v = b"".join(octets)
+        return b"".join(octets)
+
+    def _encode(self):
+        v = self._encode_octets()
         return b"".join((self.tag, len_encode(len(v)), v))
 
+    def _encode1st(self, state):
+        l = len(self._encode_octets())
+        return len(self.tag) + len_size(l) + l, state
+
+    def _encode2nd(self, writer, state_iter):
+        write_full(writer, self._encode())
+
     def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode):
         try:
             t, _, lv = tag_strip(tlv)
@@ -5017,6 +5210,13 @@ class UTCTime(VisibleString):
         self._assert_ready()
         return b"".join((self.tag, LEN_LEN_YYMMDDHHMMSSZ, self._encode_time()))
 
+    def _encode1st(self, state):
+        return len(self.tag) + LEN_YYMMDDHHMMSSZ_WITH_LEN, state
+
+    def _encode2nd(self, writer, state_iter):
+        self._assert_ready()
+        write_full(writer, self._encode())
+
     def _encode_cer(self, writer):
         write_full(writer, self._encode())
 
@@ -5189,6 +5389,14 @@ class GeneralizedTime(UTCTime):
             return b"".join((self.tag, len_encode(len(encoded)), encoded))
         return b"".join((self.tag, LEN_LEN_YYYYMMDDHHMMSSZ, self._encode_time()))
 
+    def _encode1st(self, state):
+        self._assert_ready()
+        vlen = len(self._encode_time())
+        return len(self.tag) + len_size(vlen) + vlen, state
+
+    def _encode2nd(self, writer, state_iter):
+        write_full(writer, self._encode())
+
 
 class GraphicString(CommonString):
     __slots__ = ()
@@ -5433,6 +5641,13 @@ class Choice(Obj):
         self._assert_ready()
         return self._value[1].encode()
 
+    def _encode1st(self, state):
+        self._assert_ready()
+        return self._value[1].encode1st(state)
+
+    def _encode2nd(self, writer, state_iter):
+        self._value[1].encode2nd(writer, state_iter)
+
     def _encode_cer(self, writer):
         self._assert_ready()
         self._value[1].encode_cer(writer)
@@ -5701,6 +5916,20 @@ class Any(Obj):
             return value
         return value.encode()
 
+    def _encode1st(self, state):
+        self._assert_ready()
+        value = self._value
+        if value.__class__ == binary_type:
+            return len(value), state
+        return value.encode1st(state)
+
+    def _encode2nd(self, writer, state_iter):
+        value = self._value
+        if value.__class__ == binary_type:
+            write_full(writer, value)
+        else:
+            value.encode2nd(writer, state_iter)
+
     def _encode_cer(self, writer):
         self._assert_ready()
         value = self._value
@@ -5859,7 +6088,19 @@ SequenceState = namedtuple(
 )
 
 
-class Sequence(Obj):
+class SequenceEncode1stMixing(object):
+    def _encode1st(self, state):
+        state.append(0)
+        idx = len(state) - 1
+        vlen = 0
+        for v in self._values_for_encoding():
+            l, _ = v.encode1st(state)
+            vlen += l
+        state[idx] = vlen
+        return len(self.tag) + len_size(vlen) + vlen, state
+
+
+class Sequence(SequenceEncode1stMixing, Obj):
     """``SEQUENCE`` structure type
 
     You have to make specification of sequence::
@@ -6106,6 +6347,11 @@ class Sequence(Obj):
         v = b"".join(v.encode() for v in self._values_for_encoding())
         return b"".join((self.tag, len_encode(len(v)), v))
 
+    def _encode2nd(self, writer, state_iter):
+        write_full(writer, self.tag + len_encode(next(state_iter)))
+        for v in self._values_for_encoding():
+            v.encode2nd(writer, state_iter)
+
     def _encode_cer(self, writer):
         write_full(writer, self.tag + LENINDEF)
         for v in self._values_for_encoding():
@@ -6350,7 +6596,7 @@ class Sequence(Obj):
             yield pp
 
 
-class Set(Sequence):
+class Set(Sequence, SequenceEncode1stMixing):
     """``SET`` structure type
 
     Its usage is identical to :py:class:`pyderasn.Sequence`.
@@ -6552,7 +6798,7 @@ SequenceOfState = namedtuple(
 )
 
 
-class SequenceOf(Obj):
+class SequenceOf(SequenceEncode1stMixing, Obj):
     """``SEQUENCE OF`` sequence type
 
     For that kind of type you must specify the object it will carry on
@@ -6781,6 +7027,31 @@ class SequenceOf(Obj):
             value = b"".join(v.encode() for v in self._values_for_encoding())
         return b"".join((self.tag, len_encode(len(value)), value))
 
+    def _encode1st(self, state):
+        state = super(SequenceOf, self)._encode1st(state)
+        if hasattr(self._value, NEXT_ATTR_NAME):
+            self._value = []
+        return state
+
+    def _encode2nd(self, writer, state_iter):
+        write_full(writer, self.tag + len_encode(next(state_iter)))
+        iterator = hasattr(self._value, NEXT_ATTR_NAME)
+        if iterator:
+            values_count = 0
+            class_expected = self.spec.__class__
+            values_for_encoding = self._values_for_encoding()
+            self._value = []
+            for v in values_for_encoding:
+                if not isinstance(v, class_expected):
+                    raise InvalidValueType((class_expected,))
+                v.encode2nd(writer, state_iter)
+                values_count += 1
+            if not self._bound_min <= values_count <= self._bound_max:
+                raise BoundsError(self._bound_min, values_count, self._bound_max)
+        else:
+            for v in self._values_for_encoding():
+                v.encode2nd(writer, state_iter)
+
     def _encode_cer(self, writer):
         write_full(writer, self.tag + LENINDEF)
         iterator = hasattr(self._value, NEXT_ATTR_NAME)
@@ -7004,6 +7275,17 @@ class SetOf(SequenceOf):
         v = b"".join(sorted(v.encode() for v in self._values_for_encoding()))
         return b"".join((self.tag, len_encode(len(v)), v))
 
+    def _encode2nd(self, writer, state_iter):
+        write_full(writer, self.tag + len_encode(next(state_iter)))
+        values = []
+        for v in self._values_for_encoding():
+            buf = BytesIO()
+            v.encode2nd(buf.write, state_iter)
+            values.append(buf.getvalue())
+        values.sort()
+        for v in values:
+            write_full(writer, v)
+
     def _encode_cer(self, writer):
         write_full(writer, self.tag + LENINDEF)
         for v in sorted(encode_cer(v) for v in self._values_for_encoding()):
index faa52ab4aac0a0d11695c04e3c9d3bf4e0e883ee..51bdaf1cbf26adb3d2843d9481d8c36766032eb3 100644 (file)
@@ -22,6 +22,7 @@ from os import environ
 from os import remove
 from os import urandom
 from subprocess import call
+from sys import getsizeof
 from tempfile import NamedTemporaryFile
 from time import time
 from unittest import skipIf
@@ -276,6 +277,10 @@ class TestSignedDataCERWithOpenSSL(TestCase):
             ))))),
         ))
         cms_path = self.tmpfile()
+        _, state = ci.encode1st()
+        with io_open(cms_path, "wb") as fd:
+            ci.encode2nd(fd.write, iter(state))
+        self.verify(cert_path, cms_path)
         with io_open(cms_path, "wb") as fd:
             ci.encode_cer(fd.write)
         self.verify(cert_path, cms_path)
@@ -290,17 +295,7 @@ class TestSignedDataCERWithOpenSSL(TestCase):
         agg_octet_string(evgens, ("encapContentInfo", "eContent"), raw, buf.write)
         self.assertSequenceEqual(buf.getvalue(), data)
 
-    @skipIf(PY2, "no mmaped memoryview support in PY2")
-    @skipIf("PYDERASN_TEST_CMS_HUGE" not in environ, "PYDERASN_TEST_CMS_HUGE is not set")
-    def test_huge(self):
-        """Huge CMS test
-
-        Environment variable PYDERASN_TEST_CMS_HUGE tells how many MiBs
-        data to sign. Pay attention that openssl cms is unable to do
-        stream verification and eats huge amounts (several times more,
-        that CMS itself) of memory.
-        """
-        key_path, cert_path, cert, skid = self.keypair()
+    def create_huge_file(self):
         rnd = urandom(1<<20)
         data_path = self.tmpfile()
         start = time()
@@ -309,10 +304,21 @@ class TestSignedDataCERWithOpenSSL(TestCase):
                 # dgst.update(rnd)
                 fd.write(rnd)
         print("data file written", time() - start)
-        data_fd = open(data_path, "rb")
-        data_raw = file_mmaped(data_fd)
+        return file_mmaped(open(data_path, "rb"))
 
-        from sys import getallocatedblocks
+    @skipIf(PY2, "no mmaped memoryview support in PY2")
+    @skipIf("PYDERASN_TEST_CMS_HUGE" not in environ, "PYDERASN_TEST_CMS_HUGE is not set")
+    def test_huge_cer(self):
+        """Huge CMS test
+
+        Environment variable PYDERASN_TEST_CMS_HUGE tells how many MiBs
+        data to sign. Pay attention that openssl cms is unable to do
+        stream verification and eats huge amounts (several times more,
+        than CMS itself) of memory.
+        """
+        data_raw = self.create_huge_file()
+        key_path, cert_path, cert, skid = self.keypair()
+        from sys import getallocatedblocks  # PY2 does not have it
         mem_start = getallocatedblocks()
         start = time()
         eci = EncapsulatedContentInfo((
@@ -376,3 +382,61 @@ class TestSignedDataCERWithOpenSSL(TestCase):
             ci.encode_cer(fd.write)
         print("CMS written", time() - start)
         self.verify(cert_path, cms_path)
+
+    @skipIf(PY2, "no mmaped memoryview support in PY2")
+    @skipIf("PYDERASN_TEST_CMS_HUGE" not in environ, "PYDERASN_TEST_CMS_HUGE is not set")
+    def test_huge_der_2pass(self):
+        """Same test as above, but 2pass DER encoder and just signature verification
+        """
+        data_raw = self.create_huge_file()
+        key_path, cert_path, cert, skid = self.keypair()
+        from sys import getallocatedblocks
+        mem_start = getallocatedblocks()
+        dgst = sha512(data_raw).digest()
+        start = time()
+        eci = EncapsulatedContentInfo((
+            ("eContentType", ContentType(id_data)),
+            ("eContent", OctetString(data_raw)),
+        ))
+        signed_attrs = SignedAttributes([
+            Attribute((
+                ("attrType", id_pkcs9_at_contentType),
+                ("attrValues", AttributeValues([AttributeValue(id_data)])),
+            )),
+            Attribute((
+                ("attrType", id_pkcs9_at_messageDigest),
+                ("attrValues", AttributeValues([AttributeValue(OctetString(dgst))])),
+            )),
+        ])
+        signature = self.sign(signed_attrs, key_path)
+        self.assertLess(getallocatedblocks(), mem_start * 2)
+        start = time()
+        ci = ContentInfo((
+            ("contentType", ContentType(id_signedData)),
+            ("content", Any((SignedData((
+                ("version", CMSVersion("v3")),
+                ("digestAlgorithms", DigestAlgorithmIdentifiers([ai_sha512])),
+                ("encapContentInfo", eci),
+                ("certificates", CertificateSet([
+                    CertificateChoices(("certificate", cert)),
+                ])),
+                ("signerInfos", SignerInfos([SignerInfo((
+                    ("version", CMSVersion("v3")),
+                    ("sid", SignerIdentifier(("subjectKeyIdentifier", skid))),
+                    ("digestAlgorithm", DigestAlgorithmIdentifier(ai_sha512)),
+                    ("signedAttrs", signed_attrs),
+                    ("signatureAlgorithm", SignatureAlgorithmIdentifier((
+                        ("algorithm", id_ecdsa_with_SHA512),
+                    ))),
+                    ("signature", SignatureValue(signature)),
+                ))])),
+            ))))),
+        ))
+        _, state = ci.encode1st()
+        print("2pass state size", getsizeof(state))
+        cms_path = self.tmpfile()
+        with io_open(cms_path, "wb") as fd:
+            ci.encode2nd(fd.write, iter(state))
+        print("CMS written", time() - start)
+        self.assertLess(getallocatedblocks(), mem_start * 2)
+        self.verify(cert_path, cms_path)
index fcfcb9e379ab92d001dad67ceacc235e5735b0aa..bb8748f4d97bfeb758f4d1b453faec88caddf38e 100644 (file)
@@ -17,7 +17,9 @@
 """CRL related schemas, just to test the performance with them
 """
 
+from io import BytesIO
 from os.path import exists
+from sys import getsizeof
 from time import time
 from unittest import skipIf
 from unittest import TestCase
@@ -76,7 +78,7 @@ CRL_PATH = "revoke.crl"
 
 @skipIf(not exists(CRL_PATH), "CACert's revoke.crl not found")
 class TestCACert(TestCase):
-    def test_cer(self):
+    def test_cer_and_2pass(self):
         with open(CRL_PATH, "rb") as fd:
             raw = fd.read()
         print("DER read")
@@ -84,16 +86,23 @@ class TestCACert(TestCase):
         crl1 = CertificateList().decod(raw)
         print("DER decoded", time() - start)
         start = time()
+        der_raw = crl1.encode()
+        print("DER encoded", time() - start)
+        self.assertSequenceEqual(der_raw, raw)
+        buf = BytesIO()
+        start = time()
+        _, state = crl1.encode1st()
+        print("1st pass state size", getsizeof(state))
+        crl1.encode2nd(buf.write, iter(state))
+        print("DER 2pass encoded", time() - start)
+        self.assertSequenceEqual(buf.getvalue(), raw)
+        start = time()
         cer_raw = encode_cer(crl1)
         print("CER encoded", time() - start)
         start = time()
         crl2 = CertificateList().decod(cer_raw, ctx={"bered": True})
         print("CER decoded", time() - start)
         self.assertEqual(crl2, crl1)
-        start = time()
-        der_raw = crl2.encode()
-        print("DER encoded", time() - start)
-        self.assertSequenceEqual(der_raw, raw)
 
     @skipIf(PY2, "Py27 mmap does not implement buffer protocol")
     def test_mmaped(self):
index 9fae13ea1eae16653d587bea0d038d472b6ba719..46e10e53abdbb12848d8be51d21d58f43ff0a11b 100644 (file)
@@ -20,6 +20,7 @@ from copy import deepcopy
 from datetime import datetime
 from datetime import timedelta
 from importlib import import_module
+from io import BytesIO
 from operator import attrgetter
 from os import environ
 from os import urandom
@@ -75,6 +76,7 @@ from pyderasn import BoundsError
 from pyderasn import Choice
 from pyderasn import DecodeError
 from pyderasn import DecodePathDefBy
+from pyderasn import encode2pass
 from pyderasn import encode_cer
 from pyderasn import Enumerated
 from pyderasn import EOC
@@ -422,6 +424,8 @@ class TestBoolean(CommonMixin, TestCase):
         pprint(obj, big_blobs=True, with_decode_path=True)
         with self.assertRaises(ObjNotReady) as err:
             obj.encode()
+        with self.assertRaises(ObjNotReady) as err:
+            encode2pass(obj)
         repr(err.exception)
         obj = Boolean(value)
         self.assertTrue(obj.ready)
@@ -506,6 +510,8 @@ class TestBoolean(CommonMixin, TestCase):
         obj = Boolean(value, impl=tag_impl)
         with self.assertRaises(NotEnoughData):
             obj.decode(obj.encode()[:-1])
+        with self.assertRaises(NotEnoughData):
+            obj.decode(encode2pass(obj)[:-1])
 
     @given(
         booleans(),
@@ -515,6 +521,8 @@ class TestBoolean(CommonMixin, TestCase):
         obj = Boolean(value, expl=tag_expl)
         with self.assertRaises(NotEnoughData):
             obj.decode(obj.encode()[:-1])
+        with self.assertRaises(NotEnoughData):
+            obj.decode(encode2pass(obj)[:-1])
 
     @given(
         integers(min_value=31),
@@ -603,6 +611,7 @@ class TestBoolean(CommonMixin, TestCase):
             pprint(obj, big_blobs=True, with_decode_path=True)
             self.assertFalse(obj.expled)
             obj_encoded = obj.encode()
+            self.assertEqual(encode2pass(obj), obj_encoded)
             self.assertSequenceEqual(encode_cer(obj), obj_encoded)
             obj_expled = obj(value, expl=tag_expl)
             self.assertTrue(obj_expled.expled)
@@ -863,6 +872,8 @@ class TestInteger(CommonMixin, TestCase):
         pprint(obj, big_blobs=True, with_decode_path=True)
         with self.assertRaises(ObjNotReady) as err:
             obj.encode()
+        with self.assertRaises(ObjNotReady) as err:
+            encode2pass(obj)
         repr(err.exception)
         obj = Integer(value)
         self.assertTrue(obj.ready)
@@ -925,6 +936,10 @@ class TestInteger(CommonMixin, TestCase):
                 Integer(values[0]).encode()
             )
         repr(err.exception)
+        with assertRaisesRegex(self, DecodeError, "bounds") as err:
+            Integer(bounds=(values[1], values[2])).decode(
+                encode2pass(Integer(values[0]))
+            )
         with self.assertRaises(BoundsError) as err:
             Integer(value=values[2], bounds=(values[0], values[1]))
         repr(err.exception)
@@ -933,6 +948,10 @@ class TestInteger(CommonMixin, TestCase):
                 Integer(values[2]).encode()
             )
         repr(err.exception)
+        with assertRaisesRegex(self, DecodeError, "bounds") as err:
+            Integer(bounds=(values[0], values[1])).decode(
+                encode2pass(Integer(values[2]))
+            )
 
     @given(data_strategy())
     def test_call(self, d):
@@ -1124,6 +1143,7 @@ class TestInteger(CommonMixin, TestCase):
             pprint(obj, big_blobs=True, with_decode_path=True)
             self.assertFalse(obj.expled)
             obj_encoded = obj.encode()
+            self.assertEqual(encode2pass(obj), obj_encoded)
             self.assertSequenceEqual(encode_cer(obj), obj_encoded)
             obj_expled = obj(value, expl=tag_expl)
             self.assertTrue(obj_expled.expled)
@@ -1362,6 +1382,8 @@ class TestBitString(CommonMixin, TestCase):
         with self.assertRaises(ObjNotReady) as err:
             obj.encode()
         repr(err.exception)
+        with self.assertRaises(ObjNotReady) as err:
+            encode2pass(obj)
         obj = BitString(value)
         self.assertTrue(obj.ready)
         repr(obj)
@@ -1540,6 +1562,7 @@ class TestBitString(CommonMixin, TestCase):
             pprint(obj, big_blobs=True, with_decode_path=True)
             self.assertFalse(obj.expled)
             obj_encoded = obj.encode()
+            self.assertEqual(encode2pass(obj), obj_encoded)
             self.assertSequenceEqual(encode_cer(obj), obj_encoded)
             obj_expled = obj(value, expl=tag_expl)
             self.assertTrue(obj_expled.expled)
@@ -1965,6 +1988,8 @@ class TestOctetString(CommonMixin, TestCase):
         with self.assertRaises(ObjNotReady) as err:
             obj.encode()
         repr(err.exception)
+        with self.assertRaises(ObjNotReady) as err:
+            encode2pass(obj)
         obj = OctetString(value)
         self.assertTrue(obj.ready)
         repr(obj)
@@ -2011,6 +2036,10 @@ class TestOctetString(CommonMixin, TestCase):
                 OctetString(value).encode()
             )
         repr(err.exception)
+        with assertRaisesRegex(self, DecodeError, "bounds") as err:
+            OctetString(bounds=(bound_min, bound_max)).decode(
+                encode2pass(OctetString(value))
+            )
         value = d.draw(binary(min_size=bound_max + 1))
         with self.assertRaises(BoundsError) as err:
             OctetString(value=value, bounds=(bound_min, bound_max))
@@ -2020,6 +2049,10 @@ class TestOctetString(CommonMixin, TestCase):
                 OctetString(value).encode()
             )
         repr(err.exception)
+        with assertRaisesRegex(self, DecodeError, "bounds") as err:
+            OctetString(bounds=(bound_min, bound_max)).decode(
+                encode2pass(OctetString(value))
+            )
 
     @given(data_strategy())
     def test_call(self, d):
@@ -2196,6 +2229,7 @@ class TestOctetString(CommonMixin, TestCase):
             pprint(obj, big_blobs=True, with_decode_path=True)
             self.assertFalse(obj.expled)
             obj_encoded = obj.encode()
+            self.assertEqual(encode2pass(obj), obj_encoded)
             self.assertSequenceEqual(encode_cer(obj), obj_encoded)
             obj_expled = obj(value, expl=tag_expl)
             self.assertTrue(obj_expled.expled)
@@ -2566,6 +2600,7 @@ class TestNull(CommonMixin, TestCase):
             pprint(obj, big_blobs=True, with_decode_path=True)
             self.assertFalse(obj.expled)
             obj_encoded = obj.encode()
+            self.assertEqual(encode2pass(obj), obj_encoded)
             self.assertSequenceEqual(encode_cer(obj), obj_encoded)
             obj_expled = obj(expl=tag_expl)
             self.assertTrue(obj_expled.expled)
@@ -2695,6 +2730,8 @@ class TestObjectIdentifier(CommonMixin, TestCase):
         with self.assertRaises(ObjNotReady) as err:
             obj.encode()
         repr(err.exception)
+        with self.assertRaises(ObjNotReady) as err:
+            encode2pass(obj)
         obj = ObjectIdentifier(value)
         self.assertTrue(obj.ready)
         self.assertFalse(obj.ber_encoded)
@@ -2932,6 +2969,7 @@ class TestObjectIdentifier(CommonMixin, TestCase):
             pprint(obj, big_blobs=True, with_decode_path=True)
             self.assertFalse(obj.expled)
             obj_encoded = obj.encode()
+            self.assertEqual(encode2pass(obj), obj_encoded)
             self.assertSequenceEqual(encode_cer(obj), obj_encoded)
             obj_expled = obj(value, expl=tag_expl)
             self.assertTrue(obj_expled.expled)
@@ -3327,6 +3365,7 @@ class TestEnumerated(CommonMixin, TestCase):
         pprint(obj, big_blobs=True, with_decode_path=True)
         self.assertFalse(obj.expled)
         obj_encoded = obj.encode()
+        self.assertEqual(encode2pass(obj), obj_encoded)
         obj_expled = obj(value, expl=tag_expl)
         self.assertTrue(obj_expled.expled)
         repr(obj_expled)
@@ -3444,6 +3483,8 @@ class StringMixin(object):
         with self.assertRaises(ObjNotReady) as err:
             obj.encode()
         repr(err.exception)
+        with self.assertRaises(ObjNotReady) as err:
+            encode2pass(obj)
         value = d.draw(text(alphabet=self.text_alphabet()))
         obj = self.base_klass(value)
         self.assertTrue(obj.ready)
@@ -3493,6 +3534,10 @@ class StringMixin(object):
                 self.base_klass(value).encode()
             )
         repr(err.exception)
+        with assertRaisesRegex(self, DecodeError, "bounds") as err:
+            self.base_klass(bounds=(bound_min, bound_max)).decode(
+                encode2pass(self.base_klass(value))
+            )
         value = d.draw(text(alphabet=self.text_alphabet(), min_size=bound_max + 1))
         with self.assertRaises(BoundsError) as err:
             self.base_klass(value=value, bounds=(bound_min, bound_max))
@@ -3502,6 +3547,10 @@ class StringMixin(object):
                 self.base_klass(value).encode()
             )
         repr(err.exception)
+        with assertRaisesRegex(self, DecodeError, "bounds") as err:
+            self.base_klass(bounds=(bound_min, bound_max)).decode(
+                encode2pass(self.base_klass(value))
+            )
 
     @given(data_strategy())
     def test_call(self, d):
@@ -3677,6 +3726,7 @@ class StringMixin(object):
         pprint(obj, big_blobs=True, with_decode_path=True)
         self.assertFalse(obj.expled)
         obj_encoded = obj.encode()
+        self.assertEqual(encode2pass(obj), obj_encoded)
         obj_expled = obj(value, expl=tag_expl)
         self.assertTrue(obj_expled.expled)
         repr(obj_expled)
@@ -4033,6 +4083,8 @@ class TimeMixin(object):
         with self.assertRaises(ObjNotReady) as err:
             obj.encode()
         repr(err.exception)
+        with self.assertRaises(ObjNotReady) as err:
+            encode2pass(obj)
         value = d.draw(datetimes(
             min_value=self.min_datetime,
             max_value=self.max_datetime,
@@ -4191,6 +4243,7 @@ class TimeMixin(object):
         pprint(obj, big_blobs=True, with_decode_path=True)
         self.assertFalse(obj.expled)
         obj_encoded = obj.encode()
+        self.assertEqual(encode2pass(obj), obj_encoded)
         self.additional_symmetric_check(value, obj_encoded)
         obj_expled = obj(value, expl=tag_expl)
         self.assertTrue(obj_expled.expled)
@@ -5000,6 +5053,8 @@ class TestAny(CommonMixin, TestCase):
         with self.assertRaises(ObjNotReady) as err:
             obj.encode()
         repr(err.exception)
+        with self.assertRaises(ObjNotReady) as err:
+            encode2pass(obj)
         obj = Any(value)
         self.assertTrue(obj.ready)
         repr(obj)
@@ -5148,6 +5203,7 @@ class TestAny(CommonMixin, TestCase):
             tag_class, _, tag_num = tag_decode(tag_strip(value)[0])
             self.assertEqual(obj.tag_order, (tag_class, tag_num))
             obj_encoded = obj.encode()
+            self.assertEqual(encode2pass(obj), obj_encoded)
             obj_expled = obj(value, expl=tag_expl)
             self.assertTrue(obj_expled.expled)
             tag_class, _, tag_num = tag_decode(tag_expl)
@@ -5384,6 +5440,8 @@ class TestChoice(CommonMixin, TestCase):
         with self.assertRaises(ObjNotReady) as err:
             obj.encode()
         repr(err.exception)
+        with self.assertRaises(ObjNotReady) as err:
+            encode2pass(obj)
         obj["whatever"] = Boolean()
         self.assertFalse(obj.ready)
         repr(obj)
@@ -5532,6 +5590,7 @@ class TestChoice(CommonMixin, TestCase):
         self.assertFalse(obj.expled)
         self.assertEqual(obj.tag_order, obj.value.tag_order)
         obj_encoded = obj.encode()
+        self.assertEqual(encode2pass(obj), obj_encoded)
         obj_expled = obj(value, expl=tag_expl)
         self.assertTrue(obj_expled.expled)
         tag_class, _, tag_num = tag_decode(tag_expl)
@@ -5879,6 +5938,8 @@ class SeqMixing(object):
         with self.assertRaises(ObjNotReady) as err:
             seq.encode()
         repr(err.exception)
+        with self.assertRaises(ObjNotReady) as err:
+            encode2pass(seq)
         for name, value in non_ready.items():
             seq[name] = Boolean(value)
         self.assertTrue(seq.ready)
@@ -6067,6 +6128,7 @@ class SeqMixing(object):
         pprint(seq, big_blobs=True, with_decode_path=True)
         self.assertTrue(seq.ready)
         seq_encoded = seq.encode()
+        self.assertEqual(encode2pass(seq), seq_encoded)
         seq_encoded_cer = encode_cer(seq)
         self.assertNotEqual(seq_encoded_cer, seq_encoded)
         self.assertSequenceEqual(
@@ -6155,6 +6217,7 @@ class SeqMixing(object):
         seq, expect_outers = d.draw(sequences_strategy(seq_klass=self.base_klass))
         self.assertTrue(seq.ready)
         seq_encoded = seq.encode()
+        self.assertEqual(encode2pass(seq), seq_encoded)
         seq_decoded, tail = seq.decode(seq_encoded)
         self.assertEqual(tail, b"")
         self.assertTrue(seq.ready)
@@ -6527,6 +6590,8 @@ class SeqOfMixing(object):
         with self.assertRaises(ObjNotReady) as err:
             seqof.encode()
         repr(err.exception)
+        with self.assertRaises(ObjNotReady) as err:
+            encode2pass(seqof)
         for i, value in enumerate(values):
             self.assertEqual(seqof[i], value)
             if not seqof[i].ready:
@@ -6570,6 +6635,10 @@ class SeqOfMixing(object):
                 SeqOf(value).encode()
             )
         repr(err.exception)
+        with assertRaisesRegex(self, DecodeError, "bounds") as err:
+            SeqOf(bounds=(bound_min, bound_max)).decode(
+                encode2pass(SeqOf(value))
+            )
         value = [Boolean(True)] * d.draw(integers(
             min_value=bound_max + 1,
             max_value=bound_max + 10,
@@ -6582,6 +6651,10 @@ class SeqOfMixing(object):
                 SeqOf(value).encode()
             )
         repr(err.exception)
+        with assertRaisesRegex(self, DecodeError, "bounds") as err:
+            SeqOf(bounds=(bound_min, bound_max)).decode(
+                encode2pass(SeqOf(value))
+            )
 
     @given(integers(min_value=1, max_value=10))
     def test_out_of_bounds(self, bound_max):
@@ -6788,6 +6861,7 @@ class SeqOfMixing(object):
         pprint(obj, big_blobs=True, with_decode_path=True)
         self.assertFalse(obj.expled)
         obj_encoded = obj.encode()
+        self.assertEqual(encode2pass(obj), obj_encoded)
         obj_encoded_cer = encode_cer(obj)
         self.assertNotEqual(obj_encoded_cer, obj_encoded)
         self.assertSequenceEqual(
@@ -6973,6 +7047,26 @@ class TestSequenceOf(SeqOfMixing, CommonMixin, TestCase):
         register_class(SeqOf)
         pickle_dumps(seqof)
 
+    def test_iterator_2pass(self):  # 2-pass encoding of a SequenceOf fed by a one-shot generator
+        class SeqOf(SequenceOf):
+            schema = Integer()
+            bounds = (1, float("+inf"))  # at least one element, no upper limit
+        def gen():  # yields a fresh stream of Integer(0) .. Integer(9) on every call
+            for i in six_xrange(10):
+                yield Integer(i)
+        seqof = SeqOf(gen())
+        self.assertTrue(seqof.ready)
+        _, state = seqof.encode1st()  # first pass: keep only the state, discard the other result
+        self.assertFalse(seqof.ready)  # NOTE(review): presumably the generator is exhausted by the first pass -- confirm
+        seqof = seqof(gen())  # rebuild the object with a fresh generator for the second pass
+        self.assertTrue(seqof.ready)
+        buf = BytesIO()
+        seqof.encode2nd(buf.write, iter(state))  # second pass: emit bytes through the writer, replaying first-pass state
+        self.assertSequenceEqual(  # decoding the written bytes must give back the generated values
+            [int(i) for i in seqof.decod(buf.getvalue())],
+            list(gen()),
+        )
+
     def test_non_ready_bound_min(self):
         class SeqOf(SequenceOf):
             schema = Integer()