X-Git-Url: http://www.git.cypherpunks.ru/?a=blobdiff_plain;f=pyderasn.py;h=162c1404a3af15ddbae345c63d391f8d34fb45e2;hb=498f47100e12f1bef247b29460ac4621fa0fc9f9;hp=7bb1799863c454bb22bdbdc70c75411c47635f04;hpb=749705c0df79a03dfda53dc0f78044efba8590a6;p=pyderasn.git diff --git a/pyderasn.py b/pyderasn.py index 7bb1799..162c140 100755 --- a/pyderasn.py +++ b/pyderasn.py @@ -1,7 +1,9 @@ #!/usr/bin/env python # coding: utf-8 # cython: language_level=3 -# PyDERASN -- Python ASN.1 DER/BER codec with abstract structures +# pylint: disable=line-too-long,superfluous-parens,protected-access,too-many-lines +# pylint: disable=too-many-return-statements,too-many-branches,too-many-statements +# PyDERASN -- Python ASN.1 DER/CER/BER codec with abstract structures # Copyright (C) 2017-2020 Sergey Matveev # # This program is free software: you can redistribute it and/or modify @@ -663,6 +665,11 @@ parse the CRL above with fully assembled ``RevokedCertificate``:: ): print("serial number:", int(obj["userCertificate"])) +.. note:: + + SEQUENCE/SET values with DEFAULT specified are automatically decoded + without evgen mode. + .. _mmap: mmap-ed file @@ -761,6 +768,8 @@ forcefully encoded in DER during CER encoding, by specifying bounds = (1, 32) der_forced = True +.. _agg_octet_string: + agg_octet_string ________________ @@ -827,6 +836,8 @@ copy the payload (without BER/CER encoding interleaved overhead) in it. Virtually it won't take memory more than for keeping small structures and 1 KB binary chunks. +.. _seqof-iterators: + SEQUENCE OF iterators _____________________ @@ -842,6 +853,54 @@ generator taking necessary data from the database and giving the ``RevokedCertificate`` objects. Only binary representation of that objects will take memory during DER encoding. +2-pass DER encoding +------------------- + +There is ability to do 2-pass encoding to DER, writing results directly +to specified writer (buffer, file, whatever). It could be 1.5+ times +slower than ordinary encoding, but it takes little memory for 1st pass +state storing. For example, 1st pass state for CACert.org's CRL with +~416K of certificate entries takes nearly 3.5 MB of memory. +``SignedData`` with several gigabyte ``EncapsulatedContentInfo`` takes +nearly 0.5 KB of memory. + +If you use :ref:`mmap-ed ` memoryviews, :ref:`SEQUENCE OF +iterators ` and write directly to opened file, then +there is very small memory footprint. + +1st pass traverses through all the objects of the structure and returns +the size of DER encoded structure, together with 1st pass state object. +That state contains precalculated lengths for various objects inside the +structure. + +:: + + fulllen, state = obj.encode1st() + +2nd pass takes the writer and 1st pass state. It traverses through all +the objects again, but writes their encoded representation to the writer. + +:: + + opener = io.open if PY2 else open + with opener("result", "wb") as fd: + obj.encode2nd(fd.write, iter(state)) + +.. warning:: + + You **MUST NOT** use 1st pass state if anything is changed in the + objects. It is intended to be used immediately after 1st pass is + done! + +If you use :ref:`SEQUENCE OF iterators `, then you +have to reinitialize the values after the 1st pass. And you **have to** +be sure that the iterator gives exactly the same values as previously. +Yes, you have to run your iterator twice -- because this is two pass +encoding mode. + +If you want to encode to the memory, then you can use convenient +:py:func:`pyderasn.encode2pass` helper. + Base Obj -------- .. autoclass:: pyderasn.Obj @@ -858,7 +917,7 @@ _______ Integer _______ .. autoclass:: pyderasn.Integer - :members: __init__, named + :members: __init__, named, tohex BitString _________ @@ -953,6 +1012,7 @@ Various .. autofunction:: pyderasn.abs_decode_path .. autofunction:: pyderasn.agg_octet_string .. autofunction:: pyderasn.colonize_hex +.. autofunction:: pyderasn.encode2pass .. autofunction:: pyderasn.encode_cer .. autofunction:: pyderasn.file_mmaped .. autofunction:: pyderasn.hexenc @@ -1099,6 +1159,7 @@ from mmap import PROT_READ from operator import attrgetter from string import ascii_letters from string import digits +from sys import maxsize as sys_maxsize from sys import version_info from unicodedata import category as unicat @@ -1124,7 +1185,7 @@ except ImportError: # pragma: no cover def colored(what, *args, **kwargs): return what -__version__ = "7.0" +__version__ = "7.4" __all__ = ( "agg_octet_string", @@ -1134,8 +1195,10 @@ __all__ = ( "Boolean", "BoundsError", "Choice", + "colonize_hex", "DecodeError", "DecodePathDefBy", + "encode2pass", "encode_cer", "Enumerated", "ExceedingData", @@ -1518,9 +1581,33 @@ def len_decode(data): return l, 1 + octets_num, data[1 + octets_num:] +LEN0 = len_encode(0) +LEN1 = len_encode(1) LEN1K = len_encode(1000) +def len_size(l): + """How many bytes length field will take + """ + if l < 128: + return 1 + if l < 256: # 1 << 8 + return 2 + if l < 65536: # 1 << 16 + return 3 + if l < 16777216: # 1 << 24 + return 4 + if l < 4294967296: # 1 << 32 + return 5 + if l < 1099511627776: # 1 << 40 + return 6 + if l < 281474976710656: # 1 << 48 + return 7 + if l < 72057594037927936: # 1 << 56 + return 8 + raise OverflowError("too big length") + + def write_full(writer, data): """Fully write provided data @@ -1539,6 +1626,17 @@ def write_full(writer, data): written += n +# If it is 64-bit system, then use compact 64-bit array of unsigned +# longs. Use an ordinary list with universal integers otherwise, that +# is slower. +if sys_maxsize > 2 ** 32: + def state_2pass_new(): + return array("L") +else: + def state_2pass_new(): + return [] + + ######################################################################## # Base class ######################################################################## @@ -1669,13 +1767,13 @@ class Obj(object): @property def tlen(self): - """See :ref:`decoding` + """.. seealso:: :ref:`decoding` """ return len(self.tag) @property def tlvlen(self): - """See :ref:`decoding` + """.. seealso:: :ref:`decoding` """ return self.tlen + self.llen + self.vlen @@ -1697,9 +1795,18 @@ class Obj(object): def _encode(self): # pragma: no cover raise NotImplementedError() + def _encode_cer(self, writer): + write_full(writer, self._encode()) + def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode): # pragma: no cover yield NotImplemented + def _encode1st(self, state): + raise NotImplementedError() + + def _encode2nd(self, writer, state_iter): + raise NotImplementedError() + def encode(self): """DER encode the structure @@ -1710,6 +1817,36 @@ class Obj(object): return raw return b"".join((self._expl, len_encode(len(raw)), raw)) + def encode1st(self, state=None): + """Do the 1st pass of 2-pass encoding + + :rtype: (int, array("L")) + :returns: full length of encoded data and precalculated various + objects lengths + """ + if state is None: + state = state_2pass_new() + if self._expl is None: + return self._encode1st(state) + state.append(0) + idx = len(state) - 1 + vlen, _ = self._encode1st(state) + state[idx] = vlen + fulllen = len(self._expl) + len_size(vlen) + vlen + return fulllen, state + + def encode2nd(self, writer, state_iter): + """Do the 2nd pass of 2-pass encoding + + :param writer: must comply with ``io.RawIOBase.write`` behaviour + :param state_iter: iterator over the 1st pass state (``iter(state)``) + """ + if self._expl is None: + self._encode2nd(writer, state_iter) + else: + write_full(writer, self._expl + len_encode(next(state_iter))) + self._encode2nd(writer, state_iter) + def encode_cer(self, writer): """CER encode the structure to specified writer @@ -1727,9 +1864,6 @@ class Obj(object): if self._expl is not None: write_full(writer, EOC) - def _encode_cer(self, writer): - write_full(writer, self._encode()) - def hexencode(self): """Do hexadecimal encoded :py:meth:`pyderasn.Obj.encode` """ @@ -1795,7 +1929,8 @@ class Obj(object): That method is identical to :py:meth:`pyderasn.Obj.decode`, but it returns the generator producing ``(decode_path, obj, tail)`` - values. See :ref:`evgen mode `. + values. + .. seealso:: :ref:`evgen mode `. """ if ctx is None: ctx = {} @@ -1944,25 +2079,25 @@ class Obj(object): @property def expled(self): - """See :ref:`decoding` + """.. seealso:: :ref:`decoding` """ return self._expl is not None @property def expl_tag(self): - """See :ref:`decoding` + """.. seealso:: :ref:`decoding` """ return self._expl @property def expl_tlen(self): - """See :ref:`decoding` + """.. seealso:: :ref:`decoding` """ return len(self._expl) @property def expl_llen(self): - """See :ref:`decoding` + """.. seealso:: :ref:`decoding` """ if self.expl_lenindef: return 1 @@ -1970,31 +2105,31 @@ class Obj(object): @property def expl_offset(self): - """See :ref:`decoding` + """.. seealso:: :ref:`decoding` """ return self.offset - self.expl_tlen - self.expl_llen @property def expl_vlen(self): - """See :ref:`decoding` + """.. seealso:: :ref:`decoding` """ return self.tlvlen @property def expl_tlvlen(self): - """See :ref:`decoding` + """.. seealso:: :ref:`decoding` """ return self.expl_tlen + self.expl_llen + self.expl_vlen @property def fulloffset(self): - """See :ref:`decoding` + """.. seealso:: :ref:`decoding` """ return self.expl_offset if self.expled else self.offset @property def fulllen(self): - """See :ref:`decoding` + """.. seealso:: :ref:`decoding` """ return self.expl_tlvlen if self.expled else self.tlvlen @@ -2041,6 +2176,17 @@ def encode_cer(obj): return buf.getvalue() +def encode2pass(obj): + """Encode (2-pass mode) to DER in memory buffer + + :returns bytes: memory buffer contents + """ + buf = BytesIO() + _, state = obj.encode1st() + obj.encode2nd(buf.write, iter(state)) + return buf.getvalue() + + class DecodePathDefBy(object): """DEFINED BY representation inside decode path """ @@ -2232,13 +2378,8 @@ def pp_console_row( cols.append(_colourize("(%s)" % oid_name, "green", with_colours)) break if pp.asn1_type_name == Integer.asn1_type_name: - hex_repr = hex(int(pp.obj._value))[2:].upper() - if len(hex_repr) % 2 != 0: - hex_repr = "0" + hex_repr cols.append(_colourize( - "(%s)" % colonize_hex(hex_repr), - "green", - with_colours, + "(%s)" % colonize_hex(pp.obj.tohex()), "green", with_colours, )) if with_blob: if pp.blob.__class__ == binary_type: @@ -2458,11 +2599,14 @@ class Boolean(Obj): def _encode(self): self._assert_ready() - return b"".join(( - self.tag, - len_encode(1), - (b"\xFF" if self._value else b"\x00"), - )) + return b"".join((self.tag, LEN1, (b"\xFF" if self._value else b"\x00"))) + + def _encode1st(self, state): + return len(self.tag) + 2, state + + def _encode2nd(self, writer, state_iter): + self._assert_ready() + write_full(writer, self._encode()) def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode): try: @@ -2702,13 +2846,23 @@ class Integer(Obj): self._assert_ready() return int(self._value) + def tohex(self): + """Hexadecimal representation + + Use :py:func:`pyderasn.colonize_hex` for colonizing it. + """ + hex_repr = hex(int(self))[2:].upper() + if len(hex_repr) % 2 != 0: + hex_repr = "0" + hex_repr + return hex_repr + def __hash__(self): self._assert_ready() - return hash( - self.tag + - bytes(self._expl or b"") + + return hash(b"".join(( + self.tag, + bytes(self._expl or b""), str(self._value).encode("ascii"), - ) + ))) def __eq__(self, their): if isinstance(their, integer_types): @@ -2755,7 +2909,7 @@ class Integer(Obj): _specs=self.specs, ) - def _encode(self): + def _encode_payload(self): self._assert_ready() value = self._value if PY2: @@ -2792,8 +2946,19 @@ class Integer(Obj): bytes_len += 1 else: break + return octets + + def _encode(self): + octets = self._encode_payload() return b"".join((self.tag, len_encode(len(octets)), octets)) + def _encode1st(self, state): + l = len(self._encode_payload()) + return len(self.tag) + len_size(l) + l, state + + def _encode2nd(self, writer, state_iter): + write_full(writer, self._encode()) + def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode): try: t, _, lv = tag_strip(tlv) @@ -3178,6 +3343,21 @@ class BitString(Obj): octets, )) + def _encode1st(self, state): + self._assert_ready() + _, octets = self._value + l = len(octets) + 1 + return len(self.tag) + len_size(l) + l, state + + def _encode2nd(self, writer, state_iter): + bit_len, octets = self._value + write_full(writer, b"".join(( + self.tag, + len_encode(len(octets) + 1), + int2byte((8 - bit_len % 8) % 8), + ))) + write_full(writer, octets) + def _encode_cer(self, writer): bit_len, octets = self._value if len(octets) + 1 <= 1000: @@ -3629,6 +3809,16 @@ class OctetString(Obj): self._value, )) + def _encode1st(self, state): + self._assert_ready() + l = len(self._value) + return len(self.tag) + len_size(l) + l, state + + def _encode2nd(self, writer, state_iter): + value = self._value + write_full(writer, self.tag + len_encode(len(value))) + write_full(writer, value) + def _encode_cer(self, writer): octets = self._value if len(octets) <= 1000: @@ -3897,6 +4087,8 @@ def agg_octet_string(evgens, decode_path, raw, writer): :param writer: buffer.write where string is going to be saved :param writer: where string is going to be saved. Must comply with ``io.RawIOBase.write`` behaviour + + .. seealso:: :ref:`agg_octet_string` """ decode_path_len = len(decode_path) for dp, obj, _ in evgens: @@ -3985,7 +4177,13 @@ class Null(Obj): ) def _encode(self): - return self.tag + len_encode(0) + return self.tag + LEN0 + + def _encode1st(self, state): + return len(self.tag) + 1, state + + def _encode2nd(self, writer, state_iter): + write_full(writer, self.tag + LEN0) def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode): try: @@ -4108,7 +4306,9 @@ class ObjectIdentifier(Obj): tuple element is ``{OID: pyderasn.Obj()}`` dictionary, mapping between current OID value and structure applied to defined field. - :ref:`Read about DEFINED BY ` + + .. seealso:: :ref:`definedby` + :param bytes impl: override default tag with ``IMPLICIT`` one :param bytes expl: override default tag with ``EXPLICIT`` one :param default: set default value. Type same as in ``value`` @@ -4201,11 +4401,11 @@ class ObjectIdentifier(Obj): def __hash__(self): self._assert_ready() - return hash( - self.tag + - bytes(self._expl or b"") + + return hash(b"".join(( + self.tag, + bytes(self._expl or b""), str(self._value).encode("ascii"), - ) + ))) def __eq__(self, their): if their.__class__ == tuple: @@ -4239,7 +4439,7 @@ class ObjectIdentifier(Obj): optional=self.optional if optional is None else optional, ) - def _encode(self): + def _encode_octets(self): self._assert_ready() value = self._value first_value = value[1] @@ -4255,9 +4455,19 @@ class ObjectIdentifier(Obj): octets = [zero_ended_encode(first_value)] for arc in value[2:]: octets.append(zero_ended_encode(arc)) - v = b"".join(octets) + return b"".join(octets) + + def _encode(self): + v = self._encode_octets() return b"".join((self.tag, len_encode(len(v)), v)) + def _encode1st(self, state): + l = len(self._encode_octets()) + return len(self.tag) + len_size(l) + l, state + + def _encode2nd(self, writer, state_iter): + write_full(writer, self._encode()) + def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode): try: t, _, lv = tag_strip(tlv) @@ -4781,8 +4991,11 @@ class IA5String(CommonString): LEN_YYMMDDHHMMSSZ = len("YYMMDDHHMMSSZ") +LEN_LEN_YYMMDDHHMMSSZ = len_encode(LEN_YYMMDDHHMMSSZ) +LEN_YYMMDDHHMMSSZ_WITH_LEN = len(LEN_LEN_YYMMDDHHMMSSZ) + LEN_YYMMDDHHMMSSZ LEN_YYYYMMDDHHMMSSDMZ = len("YYYYMMDDHHMMSSDMZ") LEN_YYYYMMDDHHMMSSZ = len("YYYYMMDDHHMMSSZ") +LEN_LEN_YYYYMMDDHHMMSSZ = len_encode(LEN_YYYYMMDDHHMMSSZ) class VisibleString(CommonString): @@ -4970,6 +5183,7 @@ class UTCTime(VisibleString): if self.ber_encoded: value += " (%s)" % self.ber_raw return value + return None def __unicode__(self): if self.ready: @@ -5011,8 +5225,14 @@ class UTCTime(VisibleString): def _encode(self): self._assert_ready() - value = self._encode_time() - return b"".join((self.tag, len_encode(len(value)), value)) + return b"".join((self.tag, LEN_LEN_YYMMDDHHMMSSZ, self._encode_time())) + + def _encode1st(self, state): + return len(self.tag) + LEN_YYMMDDHHMMSSZ_WITH_LEN, state + + def _encode2nd(self, writer, state_iter): + self._assert_ready() + write_full(writer, self._encode()) def _encode_cer(self, writer): write_full(writer, self._encode()) @@ -5178,6 +5398,22 @@ class GeneralizedTime(UTCTime): encoded += (".%06d" % value.microsecond).rstrip("0") return (encoded + "Z").encode("ascii") + def _encode(self): + self._assert_ready() + value = self._value + if value.microsecond > 0: + encoded = self._encode_time() + return b"".join((self.tag, len_encode(len(encoded)), encoded)) + return b"".join((self.tag, LEN_LEN_YYYYMMDDHHMMSSZ, self._encode_time())) + + def _encode1st(self, state): + self._assert_ready() + vlen = len(self._encode_time()) + return len(self.tag) + len_size(vlen) + vlen, state + + def _encode2nd(self, writer, state_iter): + write_full(writer, self._encode()) + class GraphicString(CommonString): __slots__ = () @@ -5422,6 +5658,13 @@ class Choice(Obj): self._assert_ready() return self._value[1].encode() + def _encode1st(self, state): + self._assert_ready() + return self._value[1].encode1st(state) + + def _encode2nd(self, writer, state_iter): + self._value[1].encode2nd(writer, state_iter) + def _encode_cer(self, writer): self._assert_ready() self._value[1].encode_cer(writer) @@ -5690,6 +5933,20 @@ class Any(Obj): return value return value.encode() + def _encode1st(self, state): + self._assert_ready() + value = self._value + if value.__class__ == binary_type: + return len(value), state + return value.encode1st(state) + + def _encode2nd(self, writer, state_iter): + value = self._value + if value.__class__ == binary_type: + write_full(writer, value) + else: + value.encode2nd(writer, state_iter) + def _encode_cer(self, writer): self._assert_ready() value = self._value @@ -5816,19 +6073,6 @@ class Any(Obj): # ASN.1 constructed types ######################################################################## -def get_def_by_path(defines_by_path, sub_decode_path): - """Get define by decode path - """ - for path, define in defines_by_path: - if len(path) != len(sub_decode_path): - continue - for p1, p2 in zip(path, sub_decode_path): - if (not p1 is any) and (p1 != p2): - break - else: - return define - - def abs_decode_path(decode_path, rel_path): """Create an absolute decode path from current and relative ones @@ -5861,7 +6105,19 @@ SequenceState = namedtuple( ) -class Sequence(Obj): +class SequenceEncode1stMixing(object): + def _encode1st(self, state): + state.append(0) + idx = len(state) - 1 + vlen = 0 + for v in self._values_for_encoding(): + l, _ = v.encode1st(state) + vlen += l + state[idx] = vlen + return len(self.tag) + len_size(vlen) + vlen, state + + +class Sequence(SequenceEncode1stMixing, Obj): """``SEQUENCE`` structure type You have to make specification of sequence:: @@ -5949,11 +6205,10 @@ class Sequence(Obj): defaulted values existence validation by setting ``"allow_default_values": True`` :ref:`context ` option. - .. warning:: - - Check for default value existence is not performed in - ``evgen_mode``, because previously decoded values are not stored - in memory, to be able to compare them. + All values with DEFAULT specified are decoded atomically in + :ref:`evgen mode `. If DEFAULT value is some kind of + SEQUENCE, then it will be yielded as a single element, not + disassembled. That is required for DEFAULT existence check. Two sequences are equal if they have equal specification (schema), implicit/explicit tagging and the same values. @@ -6108,6 +6363,11 @@ class Sequence(Obj): v = b"".join(v.encode() for v in self._values_for_encoding()) return b"".join((self.tag, len_encode(len(v)), v)) + def _encode2nd(self, writer, state_iter): + write_full(writer, self.tag + len_encode(next(state_iter))) + for v in self._values_for_encoding(): + v.encode2nd(writer, state_iter) + def _encode_cer(self, writer): write_full(writer, self.tag + LENINDEF) for v in self._values_for_encoding(): @@ -6174,9 +6434,10 @@ class Sequence(Obj): len(v) == 0 ): continue + spec_defaulted = spec.default is not None sub_decode_path = decode_path + (name,) try: - if evgen_mode: + if evgen_mode and not spec_defaulted: for _decode_path, value, v_tail in spec.decode_evgen( v, sub_offset, @@ -6254,9 +6515,10 @@ class Sequence(Obj): vlen += value_len sub_offset += value_len v = v_tail - if not evgen_mode: - if spec.default is not None and value == spec.default: - # This will not work in evgen_mode + if spec_defaulted: + if evgen_mode: + yield sub_decode_path, value, v_tail + if value == spec.default: if ctx_bered or ctx_allow_default_values: ber_encoded = True else: @@ -6266,6 +6528,7 @@ class Sequence(Obj): decode_path=sub_decode_path, offset=sub_offset, ) + if not evgen_mode: values[name] = value spec_defines = getattr(spec, "defines", ()) if len(spec_defines) == 0: @@ -6352,7 +6615,7 @@ class Sequence(Obj): yield pp -class Set(Sequence): +class Set(Sequence, SequenceEncode1stMixing): """``SET`` structure type Its usage is identical to :py:class:`pyderasn.Sequence`. @@ -6369,17 +6632,16 @@ class Set(Sequence): tag_default = tag_encode(form=TagFormConstructed, num=17) asn1_type_name = "SET" - def _encode(self): - v = b"".join(value.encode() for value in sorted( - self._values_for_encoding(), + def _values_for_encoding(self): + return sorted( + super(Set, self)._values_for_encoding(), key=attrgetter("tag_order"), - )) - return b"".join((self.tag, len_encode(len(v)), v)) + ) def _encode_cer(self, writer): write_full(writer, self.tag + LENINDEF) for v in sorted( - self._values_for_encoding(), + super(Set, self)._values_for_encoding(), key=attrgetter("tag_order_cer"), ): v.encode_cer(writer) @@ -6466,7 +6728,8 @@ class Set(Sequence): decode_path=decode_path, offset=offset, ) - if evgen_mode: + spec_defaulted = spec.default is not None + if evgen_mode and not spec_defaulted: for _decode_path, value, v_tail in spec.decode_evgen( v, sub_offset, @@ -6498,17 +6761,20 @@ class Set(Sequence): decode_path=sub_decode_path, offset=sub_offset, ) - if spec.default is None or value != spec.default: - pass - elif ctx_bered or ctx_allow_default_values: - ber_encoded = True - else: - raise DecodeError( - "DEFAULT value met", - klass=self.__class__, - decode_path=sub_decode_path, - offset=sub_offset, - ) + if spec_defaulted: + if evgen_mode: + yield sub_decode_path, value, v_tail + if value != spec.default: + pass + elif ctx_bered or ctx_allow_default_values: + ber_encoded = True + else: + raise DecodeError( + "DEFAULT value met", + klass=self.__class__, + decode_path=sub_decode_path, + offset=sub_offset, + ) values[name] = value del _specs_items[name] tag_order_prev = value_tag_order @@ -6555,7 +6821,7 @@ SequenceOfState = namedtuple( ) -class SequenceOf(Obj): +class SequenceOf(SequenceEncode1stMixing, Obj): """``SEQUENCE OF`` sequence type For that kind of type you must specify the object it will carry on @@ -6644,7 +6910,6 @@ class SequenceOf(Obj): value = value._value elif hasattr(value, NEXT_ATTR_NAME): iterator = True - value = value elif hasattr(value, "__iter__"): value = list(value) else: @@ -6785,6 +7050,31 @@ class SequenceOf(Obj): value = b"".join(v.encode() for v in self._values_for_encoding()) return b"".join((self.tag, len_encode(len(value)), value)) + def _encode1st(self, state): + state = super(SequenceOf, self)._encode1st(state) + if hasattr(self._value, NEXT_ATTR_NAME): + self._value = [] + return state + + def _encode2nd(self, writer, state_iter): + write_full(writer, self.tag + len_encode(next(state_iter))) + iterator = hasattr(self._value, NEXT_ATTR_NAME) + if iterator: + values_count = 0 + class_expected = self.spec.__class__ + values_for_encoding = self._values_for_encoding() + self._value = [] + for v in values_for_encoding: + if not isinstance(v, class_expected): + raise InvalidValueType((class_expected,)) + v.encode2nd(writer, state_iter) + values_count += 1 + if not self._bound_min <= values_count <= self._bound_max: + raise BoundsError(self._bound_min, values_count, self._bound_max) + else: + for v in self._values_for_encoding(): + v.encode2nd(writer, state_iter) + def _encode_cer(self, writer): write_full(writer, self.tag + LENINDEF) iterator = hasattr(self._value, NEXT_ATTR_NAME) @@ -7008,6 +7298,17 @@ class SetOf(SequenceOf): v = b"".join(sorted(v.encode() for v in self._values_for_encoding())) return b"".join((self.tag, len_encode(len(v)), v)) + def _encode2nd(self, writer, state_iter): + write_full(writer, self.tag + len_encode(next(state_iter))) + values = [] + for v in self._values_for_encoding(): + buf = BytesIO() + v.encode2nd(buf.write, state_iter) + values.append(buf.getvalue()) + values.sort() + for v in values: + write_full(writer, v) + def _encode_cer(self, writer): write_full(writer, self.tag + LENINDEF) for v in sorted(encode_cer(v) for v in self._values_for_encoding()): @@ -7067,6 +7368,7 @@ def generic_decoder(): # pragma: no cover with_colours=False, with_decode_path=False, decode_path_only=(), + decode_path=(), ): def _pprint_pps(pps): for pp in pps: @@ -7098,13 +7400,13 @@ def generic_decoder(): # pragma: no cover else: for row in _pprint_pps(pp): yield row - return "\n".join(_pprint_pps(obj.pps())) + return "\n".join(_pprint_pps(obj.pps(decode_path))) return SEQUENCEOF(), pprint_any def main(): # pragma: no cover import argparse - parser = argparse.ArgumentParser(description="PyDERASN ASN.1 BER/DER decoder") + parser = argparse.ArgumentParser(description="PyDERASN ASN.1 BER/CER/DER decoder") parser.add_argument( "--skip", type=int, @@ -7163,9 +7465,9 @@ def main(): # pragma: no cover [obj_by_path(_path) for _path in (args.oids or "").split(",")] if args.oids else () ) + from functools import partial if args.schema: schema = obj_by_path(args.schema) - from functools import partial pprinter = partial(pprint, big_blobs=True) else: schema, pprinter = generic_decoder() @@ -7197,4 +7499,5 @@ def main(): # pragma: no cover if __name__ == "__main__": + from pyderasn import * main()