From 659eb0e090ad8c9209c4e5b030806125509844f9 Mon Sep 17 00:00:00 2001 From: Sergey Matveev Date: Mon, 17 Feb 2020 16:58:41 +0300 Subject: [PATCH] DER 2pass encoding --- README | 6 +- doc/features.rst | 6 +- doc/news.rst | 1 + pyderasn.py | 300 +++++++++++++++++++++++++++++++++++++++-- tests/test_cms.py | 92 +++++++++++-- tests/test_crl.py | 19 ++- tests/test_pyderasn.py | 94 +++++++++++++ 7 files changed, 484 insertions(+), 34 deletions(-) diff --git a/README b/README index 150779f..b4d771c 100644 --- a/README +++ b/README @@ -19,9 +19,9 @@ PyDERASN -- strict and fast ASN.1 DER/CER/BER library for Python * Ability to allow BER-encoded data with knowing if any of specified field has either DER or BER encoding (or possibly indefinite-length encoding) -* Ability to use mmap-ed files, memoryviews, iterators and CER encoder - dealing with the writer, giving ability to create huge ASN.1 encoded - files without storing all the data in the memory first +* Ability to use mmap-ed files, memoryviews, iterators, 2-pass DER + encoding mode and CER encoder dealing with the writer, giving ability + to create huge ASN.1 encoded files with very little memory footprint * Ability to decode files in event generation mode, without the need to keep all the data and decoded structures in the memory * __slots__, copy.copy() friendliness diff --git a/doc/features.rst b/doc/features.rst index b17e9a8..ff913bb 100644 --- a/doc/features.rst +++ b/doc/features.rst @@ -38,9 +38,9 @@ Also there is `asn1crypto `__. structures allow BER encoding for the whole message, except for ``SignedAttributes`` -- you can easily verify your CMS satisfies that requirement -* Ability to use mmap-ed files, memoryviews, iterators and CER encoder - dealing with the writer, giving ability to create huge ASN.1 encoded - files without storing all the data in the memory first +* Ability to use mmap-ed files, memoryviews, iterators, 2-pass DER + encoding mode and CER encoder dealing with the writer, giving ability + to create huge ASN.1 encoded files with very little memory footprint * Ability to decode files in event generation mode, without the need to keep all the data and decoded structures (that takes huge quantity of memory in all known ASN.1 libraries) in the memory diff --git a/doc/news.rst b/doc/news.rst index 8ed3421..919944a 100644 --- a/doc/news.rst +++ b/doc/news.rst @@ -7,6 +7,7 @@ News --- * Restored workability of some command line options +* 2-pass DER encoding mode with very little memory footprint .. _release7.1: diff --git a/pyderasn.py b/pyderasn.py index 8208efe..c9d6a11 100755 --- a/pyderasn.py +++ b/pyderasn.py @@ -829,6 +829,8 @@ copy the payload (without BER/CER encoding interleaved overhead) in it. Virtually it won't take memory more than for keeping small structures and 1 KB binary chunks. +.. _seqof-iterators: + SEQUENCE OF iterators _____________________ @@ -844,6 +846,54 @@ generator taking necessary data from the database and giving the ``RevokedCertificate`` objects. Only binary representation of that objects will take memory during DER encoding. +2-pass DER encoding +------------------- + +There is ability to do 2-pass encoding to DER, writing results directly +to specified writer (buffer, file, whatever). It could be 1.5+ times +slower than ordinary encoding, but it takes little memory for 1st pass +state storing. For example, 1st pass state for CACert.org's CRL with +~416K of certificate entries takes nearly 3.5 MB of memory. +``SignedData`` with several gigabyte ``EncapsulatedContentInfo`` takes +nearly 0.5 KB of memory. + +If you use :ref:`mmap-ed ` memoryviews, :ref:`SEQUENCE OF +iterators ` and write directly to opened file, then +there is very small memory footprint. + +1st pass traverses through all the objects of the structure and returns +the size of DER encoded structure, together with 1st pass state object. +That state contains precalculated lengths for various objects inside the +structure. + +:: + + fulllen, state = obj.encode1st() + +2nd pass takes the writer and 1st pass state. It traverses through all +the objects again, but writes their encoded representation to the writer. + +:: + + opener = io.open if PY2 else open + with opener("result", "wb") as fd: + obj.encode2nd(fd.write, iter(state)) + +.. warning:: + + You **MUST NOT** use 1st pass state if anything is changed in the + objects. It is intended to be used immediately after 1st pass is + done! + +If you use :ref:`SEQUENCE OF iterators `, then you +have to reinitialize the values after the 1st pass. And you **have to** +be sure that the iterator gives exactly the same values as previously. +Yes, you have to run your iterator twice -- because this is two pass +encoding mode. + +If you want to encode to the memory, then you can use convenient +:py:func:`pyderasn.encode2pass` helper. + Base Obj -------- .. autoclass:: pyderasn.Obj @@ -955,6 +1005,7 @@ Various .. autofunction:: pyderasn.abs_decode_path .. autofunction:: pyderasn.agg_octet_string .. autofunction:: pyderasn.colonize_hex +.. autofunction:: pyderasn.encode2pass .. autofunction:: pyderasn.encode_cer .. autofunction:: pyderasn.file_mmaped .. autofunction:: pyderasn.hexenc @@ -1101,6 +1152,7 @@ from mmap import PROT_READ from operator import attrgetter from string import ascii_letters from string import digits +from sys import maxsize as sys_maxsize from sys import version_info from unicodedata import category as unicat @@ -1138,6 +1190,7 @@ __all__ = ( "Choice", "DecodeError", "DecodePathDefBy", + "encode2pass", "encode_cer", "Enumerated", "ExceedingData", @@ -1525,6 +1578,28 @@ LEN1 = len_encode(1) LEN1K = len_encode(1000) +def len_size(l): + """How many bytes length field will take + """ + if l < 128: + return 1 + if l < 256: # 1 << 8 + return 2 + if l < 65536: # 1 << 16 + return 3 + if l < 16777216: # 1 << 24 + return 4 + if l < 4294967296: # 1 << 32 + return 5 + if l < 1099511627776: # 1 << 40 + return 6 + if l < 281474976710656: # 1 << 48 + return 7 + if l < 72057594037927936: # 1 << 56 + return 8 + raise OverflowError("too big length") + + def write_full(writer, data): """Fully write provided data @@ -1543,6 +1618,17 @@ def write_full(writer, data): written += n +# If it is 64-bit system, then use compact 64-bit array of unsigned +# longs. Use an ordinary list with universal integers otherwise, that +# is slower. +if sys_maxsize > 2 ** 32: + def state_2pass_new(): + return array("L") +else: + def state_2pass_new(): + return [] + + ######################################################################## # Base class ######################################################################## @@ -1701,9 +1787,18 @@ class Obj(object): def _encode(self): # pragma: no cover raise NotImplementedError() + def _encode_cer(self, writer): + write_full(writer, self._encode()) + def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode): # pragma: no cover yield NotImplemented + def _encode1st(self, state): + raise NotImplementedError() + + def _encode2nd(self, writer, state_iter): + raise NotImplementedError() + def encode(self): """DER encode the structure @@ -1714,6 +1809,36 @@ class Obj(object): return raw return b"".join((self._expl, len_encode(len(raw)), raw)) + def encode1st(self, state=None): + """Do the 1st pass of 2-pass encoding + + :rtype: (int, array("L")) + :returns: full length of encoded data and precalculated various + objects lengths + """ + if state is None: + state = state_2pass_new() + if self._expl is None: + return self._encode1st(state) + state.append(0) + idx = len(state) - 1 + vlen, _ = self._encode1st(state) + state[idx] = vlen + fulllen = len(self._expl) + len_size(vlen) + vlen + return fulllen, state + + def encode2nd(self, writer, state_iter): + """Do the 2nd pass of 2-pass encoding + + :param writer: must comply with ``io.RawIOBase.write`` behaviour + :param state_iter: iterator over the 1st pass state (``iter(state)``) + """ + if self._expl is None: + self._encode2nd(writer, state_iter) + else: + write_full(writer, self._expl + len_encode(next(state_iter))) + self._encode2nd(writer, state_iter) + def encode_cer(self, writer): """CER encode the structure to specified writer @@ -1731,9 +1856,6 @@ class Obj(object): if self._expl is not None: write_full(writer, EOC) - def _encode_cer(self, writer): - write_full(writer, self._encode()) - def hexencode(self): """Do hexadecimal encoded :py:meth:`pyderasn.Obj.encode` """ @@ -2045,6 +2167,17 @@ def encode_cer(obj): return buf.getvalue() +def encode2pass(obj): + """Encode (2-pass mode) to DER in memory buffer + + :returns bytes: memory buffer contents + """ + buf = BytesIO() + _, state = obj.encode1st() + obj.encode2nd(buf.write, iter(state)) + return buf.getvalue() + + class DecodePathDefBy(object): """DEFINED BY representation inside decode path """ @@ -2464,6 +2597,13 @@ class Boolean(Obj): self._assert_ready() return b"".join((self.tag, LEN1, (b"\xFF" if self._value else b"\x00"))) + def _encode1st(self, state): + return len(self.tag) + 2, state + + def _encode2nd(self, writer, state_iter): + self._assert_ready() + write_full(writer, self._encode()) + def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode): try: t, _, lv = tag_strip(tlv) @@ -2755,7 +2895,7 @@ class Integer(Obj): _specs=self.specs, ) - def _encode(self): + def _encode_payload(self): self._assert_ready() value = self._value if PY2: @@ -2792,8 +2932,20 @@ class Integer(Obj): bytes_len += 1 else: break + return octets + return b"".join((self.tag, len_encode(len(octets)), octets)) + + def _encode(self): + octets = self._encode_payload() return b"".join((self.tag, len_encode(len(octets)), octets)) + def _encode1st(self, state): + l = len(self._encode_payload()) + return len(self.tag) + len_size(l) + l, state + + def _encode2nd(self, writer, state_iter): + write_full(writer, self._encode()) + def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode): try: t, _, lv = tag_strip(tlv) @@ -3178,6 +3330,21 @@ class BitString(Obj): octets, )) + def _encode1st(self, state): + self._assert_ready() + _, octets = self._value + l = len(octets) + 1 + return len(self.tag) + len_size(l) + l, state + + def _encode2nd(self, writer, state_iter): + bit_len, octets = self._value + write_full(writer, b"".join(( + self.tag, + len_encode(len(octets) + 1), + int2byte((8 - bit_len % 8) % 8), + ))) + write_full(writer, octets) + def _encode_cer(self, writer): bit_len, octets = self._value if len(octets) + 1 <= 1000: @@ -3629,6 +3796,16 @@ class OctetString(Obj): self._value, )) + def _encode1st(self, state): + self._assert_ready() + l = len(self._value) + return len(self.tag) + len_size(l) + l, state + + def _encode2nd(self, writer, state_iter): + value = self._value + write_full(writer, self.tag + len_encode(len(value))) + write_full(writer, value) + def _encode_cer(self, writer): octets = self._value if len(octets) <= 1000: @@ -3987,6 +4164,12 @@ class Null(Obj): def _encode(self): return self.tag + LEN0 + def _encode1st(self, state): + return len(self.tag) + 1, state + + def _encode2nd(self, writer, state_iter): + write_full(writer, self.tag + LEN0) + def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode): try: t, _, lv = tag_strip(tlv) @@ -4239,7 +4422,7 @@ class ObjectIdentifier(Obj): optional=self.optional if optional is None else optional, ) - def _encode(self): + def _encode_octets(self): self._assert_ready() value = self._value first_value = value[1] @@ -4255,9 +4438,19 @@ class ObjectIdentifier(Obj): octets = [zero_ended_encode(first_value)] for arc in value[2:]: octets.append(zero_ended_encode(arc)) - v = b"".join(octets) + return b"".join(octets) + + def _encode(self): + v = self._encode_octets() return b"".join((self.tag, len_encode(len(v)), v)) + def _encode1st(self, state): + l = len(self._encode_octets()) + return len(self.tag) + len_size(l) + l, state + + def _encode2nd(self, writer, state_iter): + write_full(writer, self._encode()) + def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode): try: t, _, lv = tag_strip(tlv) @@ -5017,6 +5210,13 @@ class UTCTime(VisibleString): self._assert_ready() return b"".join((self.tag, LEN_LEN_YYMMDDHHMMSSZ, self._encode_time())) + def _encode1st(self, state): + return len(self.tag) + LEN_YYMMDDHHMMSSZ_WITH_LEN, state + + def _encode2nd(self, writer, state_iter): + self._assert_ready() + write_full(writer, self._encode()) + def _encode_cer(self, writer): write_full(writer, self._encode()) @@ -5189,6 +5389,14 @@ class GeneralizedTime(UTCTime): return b"".join((self.tag, len_encode(len(encoded)), encoded)) return b"".join((self.tag, LEN_LEN_YYYYMMDDHHMMSSZ, self._encode_time())) + def _encode1st(self, state): + self._assert_ready() + vlen = len(self._encode_time()) + return len(self.tag) + len_size(vlen) + vlen, state + + def _encode2nd(self, writer, state_iter): + write_full(writer, self._encode()) + class GraphicString(CommonString): __slots__ = () @@ -5433,6 +5641,13 @@ class Choice(Obj): self._assert_ready() return self._value[1].encode() + def _encode1st(self, state): + self._assert_ready() + return self._value[1].encode1st(state) + + def _encode2nd(self, writer, state_iter): + self._value[1].encode2nd(writer, state_iter) + def _encode_cer(self, writer): self._assert_ready() self._value[1].encode_cer(writer) @@ -5701,6 +5916,20 @@ class Any(Obj): return value return value.encode() + def _encode1st(self, state): + self._assert_ready() + value = self._value + if value.__class__ == binary_type: + return len(value), state + return value.encode1st(state) + + def _encode2nd(self, writer, state_iter): + value = self._value + if value.__class__ == binary_type: + write_full(writer, value) + else: + value.encode2nd(writer, state_iter) + def _encode_cer(self, writer): self._assert_ready() value = self._value @@ -5859,7 +6088,19 @@ SequenceState = namedtuple( ) -class Sequence(Obj): +class SequenceEncode1stMixing(object): + def _encode1st(self, state): + state.append(0) + idx = len(state) - 1 + vlen = 0 + for v in self._values_for_encoding(): + l, _ = v.encode1st(state) + vlen += l + state[idx] = vlen + return len(self.tag) + len_size(vlen) + vlen, state + + +class Sequence(SequenceEncode1stMixing, Obj): """``SEQUENCE`` structure type You have to make specification of sequence:: @@ -6106,6 +6347,11 @@ class Sequence(Obj): v = b"".join(v.encode() for v in self._values_for_encoding()) return b"".join((self.tag, len_encode(len(v)), v)) + def _encode2nd(self, writer, state_iter): + write_full(writer, self.tag + len_encode(next(state_iter))) + for v in self._values_for_encoding(): + v.encode2nd(writer, state_iter) + def _encode_cer(self, writer): write_full(writer, self.tag + LENINDEF) for v in self._values_for_encoding(): @@ -6350,7 +6596,7 @@ class Sequence(Obj): yield pp -class Set(Sequence): +class Set(Sequence, SequenceEncode1stMixing): """``SET`` structure type Its usage is identical to :py:class:`pyderasn.Sequence`. @@ -6552,7 +6798,7 @@ SequenceOfState = namedtuple( ) -class SequenceOf(Obj): +class SequenceOf(SequenceEncode1stMixing, Obj): """``SEQUENCE OF`` sequence type For that kind of type you must specify the object it will carry on @@ -6781,6 +7027,31 @@ class SequenceOf(Obj): value = b"".join(v.encode() for v in self._values_for_encoding()) return b"".join((self.tag, len_encode(len(value)), value)) + def _encode1st(self, state): + state = super(SequenceOf, self)._encode1st(state) + if hasattr(self._value, NEXT_ATTR_NAME): + self._value = [] + return state + + def _encode2nd(self, writer, state_iter): + write_full(writer, self.tag + len_encode(next(state_iter))) + iterator = hasattr(self._value, NEXT_ATTR_NAME) + if iterator: + values_count = 0 + class_expected = self.spec.__class__ + values_for_encoding = self._values_for_encoding() + self._value = [] + for v in values_for_encoding: + if not isinstance(v, class_expected): + raise InvalidValueType((class_expected,)) + v.encode2nd(writer, state_iter) + values_count += 1 + if not self._bound_min <= values_count <= self._bound_max: + raise BoundsError(self._bound_min, values_count, self._bound_max) + else: + for v in self._values_for_encoding(): + v.encode2nd(writer, state_iter) + def _encode_cer(self, writer): write_full(writer, self.tag + LENINDEF) iterator = hasattr(self._value, NEXT_ATTR_NAME) @@ -7004,6 +7275,17 @@ class SetOf(SequenceOf): v = b"".join(sorted(v.encode() for v in self._values_for_encoding())) return b"".join((self.tag, len_encode(len(v)), v)) + def _encode2nd(self, writer, state_iter): + write_full(writer, self.tag + len_encode(next(state_iter))) + values = [] + for v in self._values_for_encoding(): + buf = BytesIO() + v.encode2nd(buf.write, state_iter) + values.append(buf.getvalue()) + values.sort() + for v in values: + write_full(writer, v) + def _encode_cer(self, writer): write_full(writer, self.tag + LENINDEF) for v in sorted(encode_cer(v) for v in self._values_for_encoding()): diff --git a/tests/test_cms.py b/tests/test_cms.py index faa52ab..51bdaf1 100644 --- a/tests/test_cms.py +++ b/tests/test_cms.py @@ -22,6 +22,7 @@ from os import environ from os import remove from os import urandom from subprocess import call +from sys import getsizeof from tempfile import NamedTemporaryFile from time import time from unittest import skipIf @@ -276,6 +277,10 @@ class TestSignedDataCERWithOpenSSL(TestCase): ))))), )) cms_path = self.tmpfile() + _, state = ci.encode1st() + with io_open(cms_path, "wb") as fd: + ci.encode2nd(fd.write, iter(state)) + self.verify(cert_path, cms_path) with io_open(cms_path, "wb") as fd: ci.encode_cer(fd.write) self.verify(cert_path, cms_path) @@ -290,17 +295,7 @@ class TestSignedDataCERWithOpenSSL(TestCase): agg_octet_string(evgens, ("encapContentInfo", "eContent"), raw, buf.write) self.assertSequenceEqual(buf.getvalue(), data) - @skipIf(PY2, "no mmaped memoryview support in PY2") - @skipIf("PYDERASN_TEST_CMS_HUGE" not in environ, "PYDERASN_TEST_CMS_HUGE is not set") - def test_huge(self): - """Huge CMS test - - Environment variable PYDERASN_TEST_CMS_HUGE tells how many MiBs - data to sign. Pay attention that openssl cms is unable to do - stream verification and eats huge amounts (several times more, - that CMS itself) of memory. - """ - key_path, cert_path, cert, skid = self.keypair() + def create_huge_file(self): rnd = urandom(1<<20) data_path = self.tmpfile() start = time() @@ -309,10 +304,21 @@ class TestSignedDataCERWithOpenSSL(TestCase): # dgst.update(rnd) fd.write(rnd) print("data file written", time() - start) - data_fd = open(data_path, "rb") - data_raw = file_mmaped(data_fd) + return file_mmaped(open(data_path, "rb")) - from sys import getallocatedblocks + @skipIf(PY2, "no mmaped memoryview support in PY2") + @skipIf("PYDERASN_TEST_CMS_HUGE" not in environ, "PYDERASN_TEST_CMS_HUGE is not set") + def test_huge_cer(self): + """Huge CMS test + + Environment variable PYDERASN_TEST_CMS_HUGE tells how many MiBs + data to sign. Pay attention that openssl cms is unable to do + stream verification and eats huge amounts (several times more, + than CMS itself) of memory. + """ + data_raw = self.create_huge_file() + key_path, cert_path, cert, skid = self.keypair() + from sys import getallocatedblocks # PY2 does not have it mem_start = getallocatedblocks() start = time() eci = EncapsulatedContentInfo(( @@ -376,3 +382,61 @@ class TestSignedDataCERWithOpenSSL(TestCase): ci.encode_cer(fd.write) print("CMS written", time() - start) self.verify(cert_path, cms_path) + + @skipIf(PY2, "no mmaped memoryview support in PY2") + @skipIf("PYDERASN_TEST_CMS_HUGE" not in environ, "PYDERASN_TEST_CMS_HUGE is not set") + def test_huge_der_2pass(self): + """Same test as above, but 2pass DER encoder and just signature verification + """ + data_raw = self.create_huge_file() + key_path, cert_path, cert, skid = self.keypair() + from sys import getallocatedblocks + mem_start = getallocatedblocks() + dgst = sha512(data_raw).digest() + start = time() + eci = EncapsulatedContentInfo(( + ("eContentType", ContentType(id_data)), + ("eContent", OctetString(data_raw)), + )) + signed_attrs = SignedAttributes([ + Attribute(( + ("attrType", id_pkcs9_at_contentType), + ("attrValues", AttributeValues([AttributeValue(id_data)])), + )), + Attribute(( + ("attrType", id_pkcs9_at_messageDigest), + ("attrValues", AttributeValues([AttributeValue(OctetString(dgst))])), + )), + ]) + signature = self.sign(signed_attrs, key_path) + self.assertLess(getallocatedblocks(), mem_start * 2) + start = time() + ci = ContentInfo(( + ("contentType", ContentType(id_signedData)), + ("content", Any((SignedData(( + ("version", CMSVersion("v3")), + ("digestAlgorithms", DigestAlgorithmIdentifiers([ai_sha512])), + ("encapContentInfo", eci), + ("certificates", CertificateSet([ + CertificateChoices(("certificate", cert)), + ])), + ("signerInfos", SignerInfos([SignerInfo(( + ("version", CMSVersion("v3")), + ("sid", SignerIdentifier(("subjectKeyIdentifier", skid))), + ("digestAlgorithm", DigestAlgorithmIdentifier(ai_sha512)), + ("signedAttrs", signed_attrs), + ("signatureAlgorithm", SignatureAlgorithmIdentifier(( + ("algorithm", id_ecdsa_with_SHA512), + ))), + ("signature", SignatureValue(signature)), + ))])), + ))))), + )) + _, state = ci.encode1st() + print("2pass state size", getsizeof(state)) + cms_path = self.tmpfile() + with io_open(cms_path, "wb") as fd: + ci.encode2nd(fd.write, iter(state)) + print("CMS written", time() - start) + self.assertLess(getallocatedblocks(), mem_start * 2) + self.verify(cert_path, cms_path) diff --git a/tests/test_crl.py b/tests/test_crl.py index fcfcb9e..bb8748f 100644 --- a/tests/test_crl.py +++ b/tests/test_crl.py @@ -17,7 +17,9 @@ """CRL related schemas, just to test the performance with them """ +from io import BytesIO from os.path import exists +from sys import getsizeof from time import time from unittest import skipIf from unittest import TestCase @@ -76,7 +78,7 @@ CRL_PATH = "revoke.crl" @skipIf(not exists(CRL_PATH), "CACert's revoke.crl not found") class TestCACert(TestCase): - def test_cer(self): + def test_cer_and_2pass(self): with open(CRL_PATH, "rb") as fd: raw = fd.read() print("DER read") @@ -84,16 +86,23 @@ class TestCACert(TestCase): crl1 = CertificateList().decod(raw) print("DER decoded", time() - start) start = time() + der_raw = crl1.encode() + print("DER encoded", time() - start) + self.assertSequenceEqual(der_raw, raw) + buf = BytesIO() + start = time() + _, state = crl1.encode1st() + print("1st pass state size", getsizeof(state)) + crl1.encode2nd(buf.write, iter(state)) + print("DER 2pass encoded", time() - start) + self.assertSequenceEqual(buf.getvalue(), raw) + start = time() cer_raw = encode_cer(crl1) print("CER encoded", time() - start) start = time() crl2 = CertificateList().decod(cer_raw, ctx={"bered": True}) print("CER decoded", time() - start) self.assertEqual(crl2, crl1) - start = time() - der_raw = crl2.encode() - print("DER encoded", time() - start) - self.assertSequenceEqual(der_raw, raw) @skipIf(PY2, "Py27 mmap does not implement buffer protocol") def test_mmaped(self): diff --git a/tests/test_pyderasn.py b/tests/test_pyderasn.py index 9fae13e..46e10e5 100644 --- a/tests/test_pyderasn.py +++ b/tests/test_pyderasn.py @@ -20,6 +20,7 @@ from copy import deepcopy from datetime import datetime from datetime import timedelta from importlib import import_module +from io import BytesIO from operator import attrgetter from os import environ from os import urandom @@ -75,6 +76,7 @@ from pyderasn import BoundsError from pyderasn import Choice from pyderasn import DecodeError from pyderasn import DecodePathDefBy +from pyderasn import encode2pass from pyderasn import encode_cer from pyderasn import Enumerated from pyderasn import EOC @@ -422,6 +424,8 @@ class TestBoolean(CommonMixin, TestCase): pprint(obj, big_blobs=True, with_decode_path=True) with self.assertRaises(ObjNotReady) as err: obj.encode() + with self.assertRaises(ObjNotReady) as err: + encode2pass(obj) repr(err.exception) obj = Boolean(value) self.assertTrue(obj.ready) @@ -506,6 +510,8 @@ class TestBoolean(CommonMixin, TestCase): obj = Boolean(value, impl=tag_impl) with self.assertRaises(NotEnoughData): obj.decode(obj.encode()[:-1]) + with self.assertRaises(NotEnoughData): + obj.decode(encode2pass(obj)[:-1]) @given( booleans(), @@ -515,6 +521,8 @@ class TestBoolean(CommonMixin, TestCase): obj = Boolean(value, expl=tag_expl) with self.assertRaises(NotEnoughData): obj.decode(obj.encode()[:-1]) + with self.assertRaises(NotEnoughData): + obj.decode(encode2pass(obj)[:-1]) @given( integers(min_value=31), @@ -603,6 +611,7 @@ class TestBoolean(CommonMixin, TestCase): pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) self.assertSequenceEqual(encode_cer(obj), obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) @@ -863,6 +872,8 @@ class TestInteger(CommonMixin, TestCase): pprint(obj, big_blobs=True, with_decode_path=True) with self.assertRaises(ObjNotReady) as err: obj.encode() + with self.assertRaises(ObjNotReady) as err: + encode2pass(obj) repr(err.exception) obj = Integer(value) self.assertTrue(obj.ready) @@ -925,6 +936,10 @@ class TestInteger(CommonMixin, TestCase): Integer(values[0]).encode() ) repr(err.exception) + with assertRaisesRegex(self, DecodeError, "bounds") as err: + Integer(bounds=(values[1], values[2])).decode( + encode2pass(Integer(values[0])) + ) with self.assertRaises(BoundsError) as err: Integer(value=values[2], bounds=(values[0], values[1])) repr(err.exception) @@ -933,6 +948,10 @@ class TestInteger(CommonMixin, TestCase): Integer(values[2]).encode() ) repr(err.exception) + with assertRaisesRegex(self, DecodeError, "bounds") as err: + Integer(bounds=(values[0], values[1])).decode( + encode2pass(Integer(values[2])) + ) @given(data_strategy()) def test_call(self, d): @@ -1124,6 +1143,7 @@ class TestInteger(CommonMixin, TestCase): pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) self.assertSequenceEqual(encode_cer(obj), obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) @@ -1362,6 +1382,8 @@ class TestBitString(CommonMixin, TestCase): with self.assertRaises(ObjNotReady) as err: obj.encode() repr(err.exception) + with self.assertRaises(ObjNotReady) as err: + encode2pass(obj) obj = BitString(value) self.assertTrue(obj.ready) repr(obj) @@ -1540,6 +1562,7 @@ class TestBitString(CommonMixin, TestCase): pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) self.assertSequenceEqual(encode_cer(obj), obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) @@ -1965,6 +1988,8 @@ class TestOctetString(CommonMixin, TestCase): with self.assertRaises(ObjNotReady) as err: obj.encode() repr(err.exception) + with self.assertRaises(ObjNotReady) as err: + encode2pass(obj) obj = OctetString(value) self.assertTrue(obj.ready) repr(obj) @@ -2011,6 +2036,10 @@ class TestOctetString(CommonMixin, TestCase): OctetString(value).encode() ) repr(err.exception) + with assertRaisesRegex(self, DecodeError, "bounds") as err: + OctetString(bounds=(bound_min, bound_max)).decode( + encode2pass(OctetString(value)) + ) value = d.draw(binary(min_size=bound_max + 1)) with self.assertRaises(BoundsError) as err: OctetString(value=value, bounds=(bound_min, bound_max)) @@ -2020,6 +2049,10 @@ class TestOctetString(CommonMixin, TestCase): OctetString(value).encode() ) repr(err.exception) + with assertRaisesRegex(self, DecodeError, "bounds") as err: + OctetString(bounds=(bound_min, bound_max)).decode( + encode2pass(OctetString(value)) + ) @given(data_strategy()) def test_call(self, d): @@ -2196,6 +2229,7 @@ class TestOctetString(CommonMixin, TestCase): pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) self.assertSequenceEqual(encode_cer(obj), obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) @@ -2566,6 +2600,7 @@ class TestNull(CommonMixin, TestCase): pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) self.assertSequenceEqual(encode_cer(obj), obj_encoded) obj_expled = obj(expl=tag_expl) self.assertTrue(obj_expled.expled) @@ -2695,6 +2730,8 @@ class TestObjectIdentifier(CommonMixin, TestCase): with self.assertRaises(ObjNotReady) as err: obj.encode() repr(err.exception) + with self.assertRaises(ObjNotReady) as err: + encode2pass(obj) obj = ObjectIdentifier(value) self.assertTrue(obj.ready) self.assertFalse(obj.ber_encoded) @@ -2932,6 +2969,7 @@ class TestObjectIdentifier(CommonMixin, TestCase): pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) self.assertSequenceEqual(encode_cer(obj), obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) @@ -3327,6 +3365,7 @@ class TestEnumerated(CommonMixin, TestCase): pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) repr(obj_expled) @@ -3444,6 +3483,8 @@ class StringMixin(object): with self.assertRaises(ObjNotReady) as err: obj.encode() repr(err.exception) + with self.assertRaises(ObjNotReady) as err: + encode2pass(obj) value = d.draw(text(alphabet=self.text_alphabet())) obj = self.base_klass(value) self.assertTrue(obj.ready) @@ -3493,6 +3534,10 @@ class StringMixin(object): self.base_klass(value).encode() ) repr(err.exception) + with assertRaisesRegex(self, DecodeError, "bounds") as err: + self.base_klass(bounds=(bound_min, bound_max)).decode( + encode2pass(self.base_klass(value)) + ) value = d.draw(text(alphabet=self.text_alphabet(), min_size=bound_max + 1)) with self.assertRaises(BoundsError) as err: self.base_klass(value=value, bounds=(bound_min, bound_max)) @@ -3502,6 +3547,10 @@ class StringMixin(object): self.base_klass(value).encode() ) repr(err.exception) + with assertRaisesRegex(self, DecodeError, "bounds") as err: + self.base_klass(bounds=(bound_min, bound_max)).decode( + encode2pass(self.base_klass(value)) + ) @given(data_strategy()) def test_call(self, d): @@ -3677,6 +3726,7 @@ class StringMixin(object): pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) repr(obj_expled) @@ -4033,6 +4083,8 @@ class TimeMixin(object): with self.assertRaises(ObjNotReady) as err: obj.encode() repr(err.exception) + with self.assertRaises(ObjNotReady) as err: + encode2pass(obj) value = d.draw(datetimes( min_value=self.min_datetime, max_value=self.max_datetime, @@ -4191,6 +4243,7 @@ class TimeMixin(object): pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) self.additional_symmetric_check(value, obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) @@ -5000,6 +5053,8 @@ class TestAny(CommonMixin, TestCase): with self.assertRaises(ObjNotReady) as err: obj.encode() repr(err.exception) + with self.assertRaises(ObjNotReady) as err: + encode2pass(obj) obj = Any(value) self.assertTrue(obj.ready) repr(obj) @@ -5148,6 +5203,7 @@ class TestAny(CommonMixin, TestCase): tag_class, _, tag_num = tag_decode(tag_strip(value)[0]) self.assertEqual(obj.tag_order, (tag_class, tag_num)) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) tag_class, _, tag_num = tag_decode(tag_expl) @@ -5384,6 +5440,8 @@ class TestChoice(CommonMixin, TestCase): with self.assertRaises(ObjNotReady) as err: obj.encode() repr(err.exception) + with self.assertRaises(ObjNotReady) as err: + encode2pass(obj) obj["whatever"] = Boolean() self.assertFalse(obj.ready) repr(obj) @@ -5532,6 +5590,7 @@ class TestChoice(CommonMixin, TestCase): self.assertFalse(obj.expled) self.assertEqual(obj.tag_order, obj.value.tag_order) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) tag_class, _, tag_num = tag_decode(tag_expl) @@ -5879,6 +5938,8 @@ class SeqMixing(object): with self.assertRaises(ObjNotReady) as err: seq.encode() repr(err.exception) + with self.assertRaises(ObjNotReady) as err: + encode2pass(seq) for name, value in non_ready.items(): seq[name] = Boolean(value) self.assertTrue(seq.ready) @@ -6067,6 +6128,7 @@ class SeqMixing(object): pprint(seq, big_blobs=True, with_decode_path=True) self.assertTrue(seq.ready) seq_encoded = seq.encode() + self.assertEqual(encode2pass(seq), seq_encoded) seq_encoded_cer = encode_cer(seq) self.assertNotEqual(seq_encoded_cer, seq_encoded) self.assertSequenceEqual( @@ -6155,6 +6217,7 @@ class SeqMixing(object): seq, expect_outers = d.draw(sequences_strategy(seq_klass=self.base_klass)) self.assertTrue(seq.ready) seq_encoded = seq.encode() + self.assertEqual(encode2pass(seq), seq_encoded) seq_decoded, tail = seq.decode(seq_encoded) self.assertEqual(tail, b"") self.assertTrue(seq.ready) @@ -6527,6 +6590,8 @@ class SeqOfMixing(object): with self.assertRaises(ObjNotReady) as err: seqof.encode() repr(err.exception) + with self.assertRaises(ObjNotReady) as err: + encode2pass(seqof) for i, value in enumerate(values): self.assertEqual(seqof[i], value) if not seqof[i].ready: @@ -6570,6 +6635,10 @@ class SeqOfMixing(object): SeqOf(value).encode() ) repr(err.exception) + with assertRaisesRegex(self, DecodeError, "bounds") as err: + SeqOf(bounds=(bound_min, bound_max)).decode( + encode2pass(SeqOf(value)) + ) value = [Boolean(True)] * d.draw(integers( min_value=bound_max + 1, max_value=bound_max + 10, @@ -6582,6 +6651,10 @@ class SeqOfMixing(object): SeqOf(value).encode() ) repr(err.exception) + with assertRaisesRegex(self, DecodeError, "bounds") as err: + SeqOf(bounds=(bound_min, bound_max)).decode( + encode2pass(SeqOf(value)) + ) @given(integers(min_value=1, max_value=10)) def test_out_of_bounds(self, bound_max): @@ -6788,6 +6861,7 @@ class SeqOfMixing(object): pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) obj_encoded_cer = encode_cer(obj) self.assertNotEqual(obj_encoded_cer, obj_encoded) self.assertSequenceEqual( @@ -6973,6 +7047,26 @@ class TestSequenceOf(SeqOfMixing, CommonMixin, TestCase): register_class(SeqOf) pickle_dumps(seqof) + def test_iterator_2pass(self): + class SeqOf(SequenceOf): + schema = Integer() + bounds = (1, float("+inf")) + def gen(): + for i in six_xrange(10): + yield Integer(i) + seqof = SeqOf(gen()) + self.assertTrue(seqof.ready) + _, state = seqof.encode1st() + self.assertFalse(seqof.ready) + seqof = seqof(gen()) + self.assertTrue(seqof.ready) + buf = BytesIO() + seqof.encode2nd(buf.write, iter(state)) + self.assertSequenceEqual( + [int(i) for i in seqof.decod(buf.getvalue())], + list(gen()), + ) + def test_non_ready_bound_min(self): class SeqOf(SequenceOf): schema = Integer() -- 2.44.0