X-Git-Url: http://www.git.cypherpunks.ru/?a=blobdiff_plain;f=pyderasn.py;h=39d2f2f3fe4260193b5f96959b9c47325b828318;hb=605094b65aa961926cf8a192cf79cd7f90e059ff;hp=6c14e4619972a8a31e68fdafd8cefbd48f344659;hpb=7af6ff46e2927030706fc4c0df0ddeadfa55dd79;p=pyderasn.git diff --git a/pyderasn.py b/pyderasn.py index 6c14e46..39d2f2f 100755 --- a/pyderasn.py +++ b/pyderasn.py @@ -75,9 +75,17 @@ tags simultaneously. There are :py:func:`pyderasn.tag_ctxp` and :py:func:`pyderasn.tag_ctxc` functions, allowing you to easily create ``CONTEXT`` ``PRIMITIVE``/``CONSTRUCTED`` tags, by specifying only the required tag -number. Pay attention that explicit tags always have *constructed* tag -(``tag_ctxc``), but implicit tags for primitive types are primitive -(``tag_ctxp``). +number. + +.. note:: + + EXPLICIT tags always have **constructed** tag. PyDERASN does not + explicitly check correctness of schema input here. + +.. note:: + + Implicit tags have **primitive** (``tag_ctxp``) encoding for + primitive values. :: @@ -420,7 +428,7 @@ defines_by_path context option ______________________________ Sometimes you either can not or do not want to explicitly set *defines* -in the scheme. You can dynamically apply those definitions when calling +in the schema. You can dynamically apply those definitions when calling ``.decode()`` method. Specify ``defines_by_path`` key in the :ref:`decode context `. Its @@ -663,9 +671,9 @@ from copy import copy from datetime import datetime from datetime import timedelta from math import ceil -from os import environ from string import ascii_letters from string import digits +from sys import version_info from unicodedata import category as unicat from six import add_metaclass @@ -690,7 +698,7 @@ except ImportError: # pragma: no cover def colored(what, *args, **kwargs): return what -__version__ = "6.1" +__version__ = "6.3" __all__ = ( "Any", @@ -764,7 +772,20 @@ EOC = b"\x00\x00" EOC_LEN = len(EOC) LENINDEF = b"\x80" # length indefinite mark LENINDEF_PP_CHAR = "I" if PY2 else "∞" -NAMEDTUPLE_KWARGS = {} if PY2 else {"module": __name__} +NAMEDTUPLE_KWARGS = {} if version_info < (3, 6) else {"module": __name__} +SET01 = frozenset("01") +DECIMALS = frozenset(digits) +DECIMAL_SIGNS = ".," + + +def pureint(value): + if not set(value) <= DECIMALS: + raise ValueError("non-pure integer") + return int(value) + +def fractions2float(fractions_raw): + pureint(fractions_raw) + return float("0." + fractions_raw) ######################################################################## @@ -1056,6 +1077,21 @@ class AutoAddSlots(type): return type.__new__(cls, name, bases, _dict) +BasicState = namedtuple("BasicState", ( + "version", + "tag", + "expl", + "default", + "optional", + "offset", + "llen", + "vlen", + "expl_lenindef", + "lenindef", + "ber_encoded", +), **NAMEDTUPLE_KWARGS) + + @add_metaclass(AutoAddSlots) class Obj(object): """Common ASN.1 object class @@ -1128,17 +1164,16 @@ class Obj(object): def __setstate__(self, state): if state.version != __version__: raise ValueError("data is pickled by different PyDERASN version") - self.tag = self.tag_default - self._value = None - self._expl = None - self.default = None - self.optional = False - self.offset = 0 - self.llen = 0 - self.vlen = 0 - self.expl_lenindef = False - self.lenindef = False - self.ber_encoded = False + self.tag = state.tag + self._expl = state.expl + self.default = state.default + self.optional = state.optional + self.offset = state.offset + self.llen = state.llen + self.vlen = state.vlen + self.expl_lenindef = state.expl_lenindef + self.lenindef = state.lenindef + self.ber_encoded = state.ber_encoded @property def tlen(self): @@ -1207,7 +1242,7 @@ class Obj(object): :param ctx: optional :ref:`context ` governing decoding process :param tag_only: decode only the tag, without length and contents (used only in Choice and Set structures, trying to - determine if tag satisfies the scheme) + determine if tag satisfies the schema) :param _ctx_immutable: do we need to ``copy.copy()`` ``ctx`` before using it? :returns: (Obj, remaining data) @@ -1631,9 +1666,9 @@ def pp_console_row( with_colours, )) if with_blob: - if isinstance(pp.blob, binary_type): + if pp.blob.__class__ == binary_type: cols.append(hexenc(pp.blob)) - elif isinstance(pp.blob, tuple): + elif pp.blob.__class__ == tuple: cols.append(", ".join(pp.blob)) if pp.optional: cols.append(_colourize("OPTIONAL", "red", with_colours)) @@ -1653,12 +1688,12 @@ def pp_console_blob(pp, decode_path_len_decrease=0): decode_path_len = len(pp.decode_path) - decode_path_len_decrease if decode_path_len > 0: cols.append(" ." * (decode_path_len + 1)) - if isinstance(pp.blob, binary_type): + if pp.blob.__class__ == binary_type: blob = hexenc(pp.blob).upper() for i in six_xrange(0, len(blob), 32): chunk = blob[i:i + 32] yield " ".join(cols + [colonize_hex(chunk)]) - elif isinstance(pp.blob, tuple): + elif pp.blob.__class__ == tuple: yield " ".join(cols + [", ".join(pp.blob)]) @@ -1728,20 +1763,11 @@ def pprint( # ASN.1 primitive types ######################################################################## -BooleanState = namedtuple("BooleanState", ( - "version", - "value", - "tag", - "expl", - "default", - "optional", - "offset", - "llen", - "vlen", - "expl_lenindef", - "lenindef", - "ber_encoded", -), **NAMEDTUPLE_KWARGS) +BooleanState = namedtuple( + "BooleanState", + BasicState._fields + ("value",), + **NAMEDTUPLE_KWARGS +) class Boolean(Obj): @@ -1788,7 +1814,7 @@ class Boolean(Obj): self._value = default def _value_sanitize(self, value): - if isinstance(value, bool): + if value.__class__ == bool: return value if issubclass(value.__class__, Boolean): return value._value @@ -1801,7 +1827,6 @@ class Boolean(Obj): def __getstate__(self): return BooleanState( __version__, - self._value, self.tag, self._expl, self.default, @@ -1812,21 +1837,12 @@ class Boolean(Obj): self.expl_lenindef, self.lenindef, self.ber_encoded, + self._value, ) def __setstate__(self, state): super(Boolean, self).__setstate__(state) self._value = state.value - self.tag = state.tag - self._expl = state.expl - self.default = state.default - self.optional = state.optional - self.offset = state.offset - self.llen = state.llen - self.vlen = state.vlen - self.expl_lenindef = state.expl_lenindef - self.lenindef = state.lenindef - self.ber_encoded = state.ber_encoded def __nonzero__(self): self._assert_ready() @@ -1837,7 +1853,7 @@ class Boolean(Obj): return self._value def __eq__(self, their): - if isinstance(their, bool): + if their.__class__ == bool: return self._value == their if not issubclass(their.__class__, Boolean): return False @@ -1969,23 +1985,11 @@ class Boolean(Obj): yield pp -IntegerState = namedtuple("IntegerState", ( - "version", - "specs", - "value", - "bound_min", - "bound_max", - "tag", - "expl", - "default", - "optional", - "offset", - "llen", - "vlen", - "expl_lenindef", - "lenindef", - "ber_encoded", -), **NAMEDTUPLE_KWARGS) +IntegerState = namedtuple( + "IntegerState", + BasicState._fields + ("specs", "value", "bound_min", "bound_max"), + **NAMEDTUPLE_KWARGS +) class Integer(Obj): @@ -2051,7 +2055,7 @@ class Integer(Obj): super(Integer, self).__init__(impl, expl, default, optional, _decoded) self._value = value specs = getattr(self, "schema", {}) if _specs is None else _specs - self.specs = specs if isinstance(specs, dict) else dict(specs) + self.specs = specs if specs.__class__ == dict else dict(specs) self._bound_min, self._bound_max = getattr( self, "bounds", @@ -2075,7 +2079,7 @@ class Integer(Obj): pass elif issubclass(value.__class__, Integer): value = value._value - elif isinstance(value, str): + elif value.__class__ == str: value = self.specs.get(value) if value is None: raise ObjUnknown("integer value: %s" % value) @@ -2092,10 +2096,6 @@ class Integer(Obj): def __getstate__(self): return IntegerState( __version__, - self.specs, - self._value, - self._bound_min, - self._bound_max, self.tag, self._expl, self.default, @@ -2106,6 +2106,10 @@ class Integer(Obj): self.expl_lenindef, self.lenindef, self.ber_encoded, + self.specs, + self._value, + self._bound_min, + self._bound_max, ) def __setstate__(self, state): @@ -2114,16 +2118,6 @@ class Integer(Obj): self._value = state.value self._bound_min = state.bound_min self._bound_max = state.bound_max - self.tag = state.tag - self._expl = state.expl - self.default = state.default - self.optional = state.optional - self.offset = state.offset - self.llen = state.llen - self.vlen = state.vlen - self.expl_lenindef = state.expl_lenindef - self.lenindef = state.lenindef - self.ber_encoded = state.ber_encoded def __int__(self): self._assert_ready() @@ -2338,24 +2332,11 @@ class Integer(Obj): yield pp -SET01 = frozenset(("0", "1")) -BitStringState = namedtuple("BitStringState", ( - "version", - "specs", - "value", - "tag", - "expl", - "default", - "optional", - "offset", - "llen", - "vlen", - "expl_lenindef", - "lenindef", - "ber_encoded", - "tag_constructed", - "defined", -), **NAMEDTUPLE_KWARGS) +BitStringState = namedtuple( + "BitStringState", + BasicState._fields + ("specs", "value", "tag_constructed", "defined"), + **NAMEDTUPLE_KWARGS +) class BitString(Obj): @@ -2433,7 +2414,7 @@ class BitString(Obj): """ super(BitString, self).__init__(impl, expl, default, optional, _decoded) specs = getattr(self, "schema", {}) if _specs is None else _specs - self.specs = specs if isinstance(specs, dict) else dict(specs) + self.specs = specs if specs.__class__ == dict else dict(specs) self._value = None if value is None else self._value_sanitize(value) if default is not None: default = self._value_sanitize(default) @@ -2479,14 +2460,14 @@ class BitString(Obj): len(value) * 4, hexdec(value + ("" if len(value) % 2 == 0 else "0")), ) - if isinstance(value, binary_type): + if value.__class__ == binary_type: return (len(value) * 8, value) raise InvalidValueType((self.__class__, string_types, binary_type)) - if isinstance(value, tuple): + if value.__class__ == tuple: if ( len(value) == 2 and isinstance(value[0], integer_types) and - isinstance(value[1], binary_type) + value[1].__class__ == binary_type ): return value bits = [] @@ -2513,8 +2494,6 @@ class BitString(Obj): def __getstate__(self): return BitStringState( __version__, - self.specs, - self._value, self.tag, self._expl, self.default, @@ -2525,6 +2504,8 @@ class BitString(Obj): self.expl_lenindef, self.lenindef, self.ber_encoded, + self.specs, + self._value, self.tag_constructed, self.defined, ) @@ -2533,16 +2514,6 @@ class BitString(Obj): super(BitString, self).__setstate__(state) self.specs = state.specs self._value = state.value - self.tag = state.tag - self._expl = state.expl - self.default = state.default - self.optional = state.optional - self.offset = state.offset - self.llen = state.llen - self.vlen = state.vlen - self.expl_lenindef = state.expl_lenindef - self.lenindef = state.lenindef - self.ber_encoded = state.ber_encoded self.tag_constructed = state.tag_constructed self.defined = state.defined @@ -2561,7 +2532,7 @@ class BitString(Obj): return self._value[1] def __eq__(self, their): - if isinstance(their, bytes): + if their.__class__ == bytes: return self._value[1] == their if not issubclass(their.__class__, BitString): return False @@ -2593,7 +2564,7 @@ class BitString(Obj): ) def __getitem__(self, key): - if isinstance(key, int): + if key.__class__ == int: bit_len, octets = self._value if key >= bit_len: return False @@ -2618,64 +2589,6 @@ class BitString(Obj): octets, )) - def _decode_chunk(self, lv, offset, decode_path): - try: - l, llen, v = len_decode(lv) - except DecodeError as err: - raise err.__class__( - msg=err.msg, - klass=self.__class__, - decode_path=decode_path, - offset=offset, - ) - if l > len(v): - raise NotEnoughData( - "encoded length is longer than data", - klass=self.__class__, - decode_path=decode_path, - offset=offset, - ) - if l == 0: - raise NotEnoughData( - "zero length", - klass=self.__class__, - decode_path=decode_path, - offset=offset, - ) - pad_size = byte2int(v) - if l == 1 and pad_size != 0: - raise DecodeError( - "invalid empty value", - klass=self.__class__, - decode_path=decode_path, - offset=offset, - ) - if pad_size > 7: - raise DecodeError( - "too big pad", - klass=self.__class__, - decode_path=decode_path, - offset=offset, - ) - if byte2int(v[l - 1:l]) & ((1 << pad_size) - 1) != 0: - raise DecodeError( - "invalid pad", - klass=self.__class__, - decode_path=decode_path, - offset=offset, - ) - v, tail = v[:l], v[l:] - obj = self.__class__( - value=((len(v) - 1) * 8 - pad_size, v[1:].tobytes()), - impl=self.tag, - expl=self._expl, - default=self.default, - optional=self.optional, - _specs=self.specs, - _decoded=(offset, llen, l), - ) - return obj, tail - def _decode(self, tlv, offset, decode_path, ctx, tag_only): try: t, tlen, lv = tag_strip(tlv) @@ -2689,23 +2602,8 @@ class BitString(Obj): if t == self.tag: if tag_only: # pragma: no cover return None - return self._decode_chunk(lv, offset, decode_path) - if t == self.tag_constructed: - if not ctx.get("bered", False): - raise DecodeError( - "unallowed BER constructed encoding", - klass=self.__class__, - decode_path=decode_path, - offset=offset, - ) - if tag_only: # pragma: no cover - return None - lenindef = False try: l, llen, v = len_decode(lv) - except LenIndefForm: - llen, l, v = 1, 0, lv[1:] - lenindef = True except DecodeError as err: raise err.__class__( msg=err.msg, @@ -2720,90 +2618,160 @@ class BitString(Obj): decode_path=decode_path, offset=offset, ) - if not lenindef and l == 0: + if l == 0: raise NotEnoughData( "zero length", klass=self.__class__, decode_path=decode_path, offset=offset, ) - chunks = [] - sub_offset = offset + tlen + llen - vlen = 0 - while True: - if lenindef: - if v[:EOC_LEN].tobytes() == EOC: - break - else: - if vlen == l: - break - if vlen > l: - raise DecodeError( - "chunk out of bounds", - klass=self.__class__, - decode_path=decode_path + (str(len(chunks) - 1),), - offset=chunks[-1].offset, - ) - sub_decode_path = decode_path + (str(len(chunks)),) - try: - chunk, v_tail = BitString().decode( - v, - offset=sub_offset, - decode_path=sub_decode_path, - leavemm=True, - ctx=ctx, - _ctx_immutable=False, - ) - except TagMismatch: - raise DecodeError( - "expected BitString encoded chunk", - klass=self.__class__, - decode_path=sub_decode_path, - offset=sub_offset, - ) - chunks.append(chunk) - sub_offset += chunk.tlvlen - vlen += chunk.tlvlen - v = v_tail - if len(chunks) == 0: + pad_size = byte2int(v) + if l == 1 and pad_size != 0: raise DecodeError( - "no chunks", + "invalid empty value", klass=self.__class__, decode_path=decode_path, offset=offset, ) - values = [] - bit_len = 0 - for chunk_i, chunk in enumerate(chunks[:-1]): - if chunk.bit_len % 8 != 0: - raise DecodeError( - "BitString chunk is not multiple of 8 bits", - klass=self.__class__, - decode_path=decode_path + (str(chunk_i),), - offset=chunk.offset, - ) - values.append(bytes(chunk)) - bit_len += chunk.bit_len - chunk_last = chunks[-1] - values.append(bytes(chunk_last)) - bit_len += chunk_last.bit_len + if pad_size > 7: + raise DecodeError( + "too big pad", + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + if byte2int(v[l - 1:l]) & ((1 << pad_size) - 1) != 0: + raise DecodeError( + "invalid pad", + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + v, tail = v[:l], v[l:] obj = self.__class__( - value=(bit_len, b"".join(values)), + value=((len(v) - 1) * 8 - pad_size, v[1:].tobytes()), impl=self.tag, expl=self._expl, default=self.default, optional=self.optional, _specs=self.specs, - _decoded=(offset, llen, vlen + (EOC_LEN if lenindef else 0)), + _decoded=(offset, llen, l), ) - obj.lenindef = lenindef - obj.ber_encoded = True - return obj, (v[EOC_LEN:] if lenindef else v) - raise TagMismatch( - klass=self.__class__, - decode_path=decode_path, - offset=offset, + return obj, tail + if t != self.tag_constructed: + raise TagMismatch( + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + if not ctx.get("bered", False): + raise DecodeError( + "unallowed BER constructed encoding", + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + if tag_only: # pragma: no cover + return None + lenindef = False + try: + l, llen, v = len_decode(lv) + except LenIndefForm: + llen, l, v = 1, 0, lv[1:] + lenindef = True + except DecodeError as err: + raise err.__class__( + msg=err.msg, + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + if l > len(v): + raise NotEnoughData( + "encoded length is longer than data", + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + if not lenindef and l == 0: + raise NotEnoughData( + "zero length", + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + chunks = [] + sub_offset = offset + tlen + llen + vlen = 0 + while True: + if lenindef: + if v[:EOC_LEN].tobytes() == EOC: + break + else: + if vlen == l: + break + if vlen > l: + raise DecodeError( + "chunk out of bounds", + klass=self.__class__, + decode_path=decode_path + (str(len(chunks) - 1),), + offset=chunks[-1].offset, + ) + sub_decode_path = decode_path + (str(len(chunks)),) + try: + chunk, v_tail = BitString().decode( + v, + offset=sub_offset, + decode_path=sub_decode_path, + leavemm=True, + ctx=ctx, + _ctx_immutable=False, + ) + except TagMismatch: + raise DecodeError( + "expected BitString encoded chunk", + klass=self.__class__, + decode_path=sub_decode_path, + offset=sub_offset, + ) + chunks.append(chunk) + sub_offset += chunk.tlvlen + vlen += chunk.tlvlen + v = v_tail + if len(chunks) == 0: + raise DecodeError( + "no chunks", + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + values = [] + bit_len = 0 + for chunk_i, chunk in enumerate(chunks[:-1]): + if chunk.bit_len % 8 != 0: + raise DecodeError( + "BitString chunk is not multiple of 8 bits", + klass=self.__class__, + decode_path=decode_path + (str(chunk_i),), + offset=chunk.offset, + ) + values.append(bytes(chunk)) + bit_len += chunk.bit_len + chunk_last = chunks[-1] + values.append(bytes(chunk_last)) + bit_len += chunk_last.bit_len + obj = self.__class__( + value=(bit_len, b"".join(values)), + impl=self.tag, + expl=self._expl, + default=self.default, + optional=self.optional, + _specs=self.specs, + _decoded=(offset, llen, vlen + (EOC_LEN if lenindef else 0)), ) + obj.lenindef = lenindef + obj.ber_encoded = True + return obj, (v[EOC_LEN:] if lenindef else v) def __repr__(self): return pp_console_row(next(self.pps())) @@ -2849,24 +2817,17 @@ class BitString(Obj): yield pp -OctetStringState = namedtuple("OctetStringState", ( - "version", - "value", - "bound_min", - "bound_max", - "tag", - "expl", - "default", - "optional", - "offset", - "llen", - "vlen", - "expl_lenindef", - "lenindef", - "ber_encoded", - "tag_constructed", - "defined", -), **NAMEDTUPLE_KWARGS) +OctetStringState = namedtuple( + "OctetStringState", + BasicState._fields + ( + "value", + "bound_min", + "bound_max", + "tag_constructed", + "defined", + ), + **NAMEDTUPLE_KWARGS +) class OctetString(Obj): @@ -2945,7 +2906,7 @@ class OctetString(Obj): ) def _value_sanitize(self, value): - if isinstance(value, binary_type): + if value.__class__ == binary_type: pass elif issubclass(value.__class__, OctetString): value = value._value @@ -2962,9 +2923,6 @@ class OctetString(Obj): def __getstate__(self): return OctetStringState( __version__, - self._value, - self._bound_min, - self._bound_max, self.tag, self._expl, self.default, @@ -2975,6 +2933,9 @@ class OctetString(Obj): self.expl_lenindef, self.lenindef, self.ber_encoded, + self._value, + self._bound_min, + self._bound_max, self.tag_constructed, self.defined, ) @@ -2984,16 +2945,6 @@ class OctetString(Obj): self._value = state.value self._bound_min = state.bound_min self._bound_max = state.bound_max - self.tag = state.tag - self._expl = state.expl - self.default = state.default - self.optional = state.optional - self.offset = state.offset - self.llen = state.llen - self.vlen = state.vlen - self.expl_lenindef = state.expl_lenindef - self.lenindef = state.lenindef - self.ber_encoded = state.ber_encoded self.tag_constructed = state.tag_constructed self.defined = state.defined @@ -3002,7 +2953,7 @@ class OctetString(Obj): return self._value def __eq__(self, their): - if isinstance(their, binary_type): + if their.__class__ == binary_type: return self._value == their if not issubclass(their.__class__, OctetString): return False @@ -3044,51 +2995,6 @@ class OctetString(Obj): self._value, )) - def _decode_chunk(self, lv, offset, decode_path, ctx): - try: - l, llen, v = len_decode(lv) - except DecodeError as err: - raise err.__class__( - msg=err.msg, - klass=self.__class__, - decode_path=decode_path, - offset=offset, - ) - if l > len(v): - raise NotEnoughData( - "encoded length is longer than data", - klass=self.__class__, - decode_path=decode_path, - offset=offset, - ) - v, tail = v[:l], v[l:] - try: - obj = self.__class__( - value=v.tobytes(), - bounds=(self._bound_min, self._bound_max), - impl=self.tag, - expl=self._expl, - default=self.default, - optional=self.optional, - _decoded=(offset, llen, l), - ctx=ctx, - ) - except DecodeError as err: - raise DecodeError( - msg=err.msg, - klass=self.__class__, - decode_path=decode_path, - offset=offset, - ) - except BoundsError as err: - raise DecodeError( - msg=str(err), - klass=self.__class__, - decode_path=decode_path, - offset=offset, - ) - return obj, tail - def _decode(self, tlv, offset, decode_path, ctx, tag_only): try: t, tlen, lv = tag_strip(tlv) @@ -3102,23 +3008,8 @@ class OctetString(Obj): if t == self.tag: if tag_only: return None - return self._decode_chunk(lv, offset, decode_path, ctx) - if t == self.tag_constructed: - if not ctx.get("bered", False): - raise DecodeError( - "unallowed BER constructed encoding", - klass=self.__class__, - decode_path=decode_path, - offset=offset, - ) - if tag_only: - return None - lenindef = False try: l, llen, v = len_decode(lv) - except LenIndefForm: - llen, l, v = 1, 0, lv[1:] - lenindef = True except DecodeError as err: raise err.__class__( msg=err.msg, @@ -3133,77 +3024,134 @@ class OctetString(Obj): decode_path=decode_path, offset=offset, ) - chunks = [] - sub_offset = offset + tlen + llen - vlen = 0 - while True: - if lenindef: - if v[:EOC_LEN].tobytes() == EOC: - break - else: - if vlen == l: - break - if vlen > l: - raise DecodeError( - "chunk out of bounds", - klass=self.__class__, - decode_path=decode_path + (str(len(chunks) - 1),), - offset=chunks[-1].offset, - ) - sub_decode_path = decode_path + (str(len(chunks)),) - try: - chunk, v_tail = OctetString().decode( - v, - offset=sub_offset, - decode_path=sub_decode_path, - leavemm=True, - ctx=ctx, - _ctx_immutable=False, - ) - except TagMismatch: - raise DecodeError( - "expected OctetString encoded chunk", - klass=self.__class__, - decode_path=sub_decode_path, - offset=sub_offset, - ) - chunks.append(chunk) - sub_offset += chunk.tlvlen - vlen += chunk.tlvlen - v = v_tail + v, tail = v[:l], v[l:] try: obj = self.__class__( - value=b"".join(bytes(chunk) for chunk in chunks), + value=v.tobytes(), bounds=(self._bound_min, self._bound_max), impl=self.tag, expl=self._expl, default=self.default, optional=self.optional, - _decoded=(offset, llen, vlen + (EOC_LEN if lenindef else 0)), + _decoded=(offset, llen, l), + ctx=ctx, + ) + except DecodeError as err: + raise DecodeError( + msg=err.msg, + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + except BoundsError as err: + raise DecodeError( + msg=str(err), + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + return obj, tail + if t != self.tag_constructed: + raise TagMismatch( + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + if not ctx.get("bered", False): + raise DecodeError( + "unallowed BER constructed encoding", + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + if tag_only: + return None + lenindef = False + try: + l, llen, v = len_decode(lv) + except LenIndefForm: + llen, l, v = 1, 0, lv[1:] + lenindef = True + except DecodeError as err: + raise err.__class__( + msg=err.msg, + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + if l > len(v): + raise NotEnoughData( + "encoded length is longer than data", + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + chunks = [] + sub_offset = offset + tlen + llen + vlen = 0 + while True: + if lenindef: + if v[:EOC_LEN].tobytes() == EOC: + break + else: + if vlen == l: + break + if vlen > l: + raise DecodeError( + "chunk out of bounds", + klass=self.__class__, + decode_path=decode_path + (str(len(chunks) - 1),), + offset=chunks[-1].offset, + ) + sub_decode_path = decode_path + (str(len(chunks)),) + try: + chunk, v_tail = OctetString().decode( + v, + offset=sub_offset, + decode_path=sub_decode_path, + leavemm=True, ctx=ctx, + _ctx_immutable=False, ) - except DecodeError as err: - raise DecodeError( - msg=err.msg, - klass=self.__class__, - decode_path=decode_path, - offset=offset, - ) - except BoundsError as err: + except TagMismatch: raise DecodeError( - msg=str(err), + "expected OctetString encoded chunk", klass=self.__class__, - decode_path=decode_path, - offset=offset, + decode_path=sub_decode_path, + offset=sub_offset, ) - obj.lenindef = lenindef - obj.ber_encoded = True - return obj, (v[EOC_LEN:] if lenindef else v) - raise TagMismatch( - klass=self.__class__, - decode_path=decode_path, - offset=offset, - ) + chunks.append(chunk) + sub_offset += chunk.tlvlen + vlen += chunk.tlvlen + v = v_tail + try: + obj = self.__class__( + value=b"".join(bytes(chunk) for chunk in chunks), + bounds=(self._bound_min, self._bound_max), + impl=self.tag, + expl=self._expl, + default=self.default, + optional=self.optional, + _decoded=(offset, llen, vlen + (EOC_LEN if lenindef else 0)), + ctx=ctx, + ) + except DecodeError as err: + raise DecodeError( + msg=err.msg, + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + except BoundsError as err: + raise DecodeError( + msg=str(err), + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + obj.lenindef = lenindef + obj.ber_encoded = True + return obj, (v[EOC_LEN:] if lenindef else v) def __repr__(self): return pp_console_row(next(self.pps())) @@ -3242,19 +3190,7 @@ class OctetString(Obj): yield pp -NullState = namedtuple("NullState", ( - "version", - "tag", - "expl", - "default", - "optional", - "offset", - "llen", - "vlen", - "expl_lenindef", - "lenindef", - "ber_encoded", -), **NAMEDTUPLE_KWARGS) +NullState = namedtuple("NullState", BasicState._fields, **NAMEDTUPLE_KWARGS) class Null(Obj): @@ -3304,19 +3240,6 @@ class Null(Obj): self.ber_encoded, ) - def __setstate__(self, state): - super(Null, self).__setstate__(state) - self.tag = state.tag - self._expl = state.expl - self.default = state.default - self.optional = state.optional - self.offset = state.offset - self.llen = state.llen - self.vlen = state.vlen - self.expl_lenindef = state.expl_lenindef - self.lenindef = state.lenindef - self.ber_encoded = state.ber_encoded - def __eq__(self, their): if not issubclass(their.__class__, Null): return False @@ -3410,28 +3333,11 @@ class Null(Obj): yield pp -ObjectIdentifierState = namedtuple("ObjectIdentifierState", ( - "version", - "value", - "tag", - "expl", - "default", - "optional", - "offset", - "llen", - "vlen", - "expl_lenindef", - "lenindef", - "ber_encoded", - "defines", -), **NAMEDTUPLE_KWARGS) - - -def pureint(value): - i = int(value) - if (value[0] in "+- ") or (value[-1] == " "): - raise ValueError("non-pure integer") - return i +ObjectIdentifierState = namedtuple( + "ObjectIdentifierState", + BasicState._fields + ("value", "defines"), + **NAMEDTUPLE_KWARGS +) class ObjectIdentifier(Obj): @@ -3500,10 +3406,10 @@ class ObjectIdentifier(Obj): self.defines = defines def __add__(self, their): + if their.__class__ == tuple: + return self.__class__(self._value + their) if isinstance(their, self.__class__): return self.__class__(self._value + their._value) - if isinstance(their, tuple): - return self.__class__(self._value + their) raise InvalidValueType((self.__class__, tuple)) def _value_sanitize(self, value): @@ -3514,7 +3420,7 @@ class ObjectIdentifier(Obj): value = tuple(pureint(arc) for arc in value.split(".")) except ValueError: raise InvalidOID("unacceptable arcs values") - if isinstance(value, tuple): + if value.__class__ == tuple: if len(value) < 2: raise InvalidOID("less than 2 arcs") first_arc = value[0] @@ -3537,7 +3443,6 @@ class ObjectIdentifier(Obj): def __getstate__(self): return ObjectIdentifierState( __version__, - self._value, self.tag, self._expl, self.default, @@ -3548,22 +3453,13 @@ class ObjectIdentifier(Obj): self.expl_lenindef, self.lenindef, self.ber_encoded, + self._value, self.defines, ) def __setstate__(self, state): super(ObjectIdentifier, self).__setstate__(state) self._value = state.value - self.tag = state.tag - self._expl = state.expl - self.default = state.default - self.optional = state.optional - self.offset = state.offset - self.llen = state.llen - self.vlen = state.vlen - self.expl_lenindef = state.expl_lenindef - self.lenindef = state.lenindef - self.ber_encoded = state.ber_encoded self.defines = state.defines def __iter__(self): @@ -3582,7 +3478,7 @@ class ObjectIdentifier(Obj): ) def __eq__(self, their): - if isinstance(their, tuple): + if their.__class__ == tuple: return self._value == their if not issubclass(their.__class__, ObjectIdentifier): return False @@ -3818,7 +3714,7 @@ class Enumerated(Integer): def escape_control_unicode(c): - if unicat(c).startswith("C"): + if unicat(c)[0] == "C": c = repr(c).lstrip("u").strip("'") return c @@ -3892,9 +3788,9 @@ class CommonString(OctetString): value_decoded = None if isinstance(value, self.__class__): value_raw = value._value - elif isinstance(value, text_type): + elif value.__class__ == text_type: value_decoded = value - elif isinstance(value, binary_type): + elif value.__class__ == binary_type: value_raw = value else: raise InvalidValueType((self.__class__, text_type, binary_type)) @@ -3918,9 +3814,9 @@ class CommonString(OctetString): return value_raw def __eq__(self, their): - if isinstance(their, binary_type): + if their.__class__ == binary_type: return self._value == their - if isinstance(their, text_type): + if their.__class__ == text_type: return self._value == their.encode(self.encoding) if not isinstance(their, self.__class__): return False @@ -4145,11 +4041,6 @@ LEN_YYYYMMDDHHMMSSDMZ = len("YYYYMMDDHHMMSSDMZ") LEN_YYYYMMDDHHMMSSZ = len("YYYYMMDDHHMMSSZ") -def fractions2float(fractions_raw): - pureint(fractions_raw) - return float("0." + fractions_raw) - - class VisibleString(CommonString): __slots__ = () tag_default = tag_encode(26) @@ -4164,6 +4055,16 @@ UTCTimeState = namedtuple( ) +def str_to_time_fractions(value): + v = pureint(value) + year, v = (v // 10**10), (v % 10**10) + month, v = (v // 10**8), (v % 10**8) + day, v = (v // 10**6), (v % 10**6) + hour, v = (v // 10**4), (v % 10**4) + minute, second = (v // 100), (v % 100) + return year, month, day, hour, minute, second + + class UTCTime(VisibleString): """``UTCTime`` datetime type @@ -4178,6 +4079,9 @@ class UTCTime(VisibleString): >>> UTCTime(datetime(2057, 9, 30, 22, 7, 50)).todatetime() datetime.datetime(1957, 9, 30, 22, 7, 50) + If BER encoded value was met, then ``ber_raw`` attribute will hold + its raw representation. + .. warning:: Pay attention that UTCTime can not hold full year, so all years @@ -4191,7 +4095,7 @@ class UTCTime(VisibleString): * minutes are not exceeding 60 * offset value is not exceeding 14 hours """ - __slots__ = ("_ber_raw",) + __slots__ = ("ber_raw",) tag_default = tag_encode(23) encoding = "ascii" asn1_type_name = "UTCTime" @@ -4219,10 +4123,10 @@ class UTCTime(VisibleString): None, None, impl, expl, None, optional, _decoded, ctx, ) self._value = value - self._ber_raw = None + self.ber_raw = None if value is not None: - self._value, self._ber_raw = self._value_sanitize(value, ctx) - self.ber_encoded = self._ber_raw is not None + self._value, self.ber_raw = self._value_sanitize(value, ctx) + self.ber_encoded = self.ber_raw is not None if default is not None: default, _ = self._value_sanitize(default) self.default = self.__class__( @@ -4236,18 +4140,12 @@ class UTCTime(VisibleString): self.optional = optional def _strptime_bered(self, value): - year = pureint(value[:2]) - year += 2000 if year < 50 else 1900 - decoded = datetime( - year, # %Y - pureint(value[2:4]), # %m - pureint(value[4:6]), # %d - pureint(value[6:8]), # %H - pureint(value[8:10]), # %M - ) + year, month, day, hour, minute, _ = str_to_time_fractions(value[:10] + "00") value = value[10:] if len(value) == 0: raise ValueError("no timezone") + year += 2000 if year < 50 else 1900 + decoded = datetime(year, month, day, hour, minute) offset = 0 if value[-1] == "Z": value = value[:-1] @@ -4260,10 +4158,11 @@ class UTCTime(VisibleString): sign = 1 else: raise ValueError("invalid UTC offset") - offset = 60 * pureint(value[-2:]) + v = pureint(value[-4:]) + offset, v = (60 * (v % 100)), v // 100 if offset >= 3600: raise ValueError("invalid UTC offset minutes") - offset += 3600 * pureint(value[-4:-2]) + offset += 3600 * v if offset > 14 * 3600: raise ValueError("too big UTC offset") offset *= sign @@ -4275,8 +4174,7 @@ class UTCTime(VisibleString): seconds = pureint(value) if seconds >= 60: raise ValueError("invalid seconds value") - decoded += timedelta(seconds=seconds) - return offset, decoded + return offset, decoded + timedelta(seconds=seconds) def _strptime(self, value): # datetime.strptime's format: %y%m%d%H%M%SZ @@ -4284,16 +4182,9 @@ class UTCTime(VisibleString): raise ValueError("invalid UTCTime length") if value[-1] != "Z": raise ValueError("non UTC timezone") - year = pureint(value[:2]) + year, month, day, hour, minute, second = str_to_time_fractions(value[:-1]) year += 2000 if year < 50 else 1900 - return datetime( - year, # %y - pureint(value[2:4]), # %m - pureint(value[4:6]), # %d - pureint(value[6:8]), # %H - pureint(value[8:10]), # %M - pureint(value[10:12]), # %S - ) + return datetime(year, month, day, hour, minute, second) def _dt_sanitize(self, value): if value.year < 1950 or value.year > 2049: @@ -4301,7 +4192,7 @@ class UTCTime(VisibleString): return value.replace(microsecond=0) def _value_sanitize(self, value, ctx=None): - if isinstance(value, binary_type): + if value.__class__ == binary_type: try: value_decoded = value.decode("ascii") except (UnicodeEncodeError, UnicodeDecodeError) as err: @@ -4315,7 +4206,7 @@ class UTCTime(VisibleString): try: offset, _value = self._strptime_bered(value_decoded) _value = _value - timedelta(seconds=offset) - return self._dt_sanitize(_value), value_decoded + return self._dt_sanitize(_value), value except (TypeError, ValueError, OverflowError) as _err: err = _err raise DecodeError( @@ -4324,7 +4215,7 @@ class UTCTime(VisibleString): ) if isinstance(value, self.__class__): return value._value, None - if isinstance(value, datetime): + if value.__class__ == datetime: return self._dt_sanitize(value), None raise InvalidValueType((self.__class__, datetime)) @@ -4332,35 +4223,35 @@ class UTCTime(VisibleString): if self.ready: value = self._value.isoformat() if self.ber_encoded: - value += " (%s)" % self._ber_raw + value += " (%s)" % self.ber_raw return value def __unicode__(self): if self.ready: value = self._value.isoformat() if self.ber_encoded: - value += " (%s)" % self._ber_raw + value += " (%s)" % self.ber_raw return value return text_type(self._pp_value()) def __getstate__(self): return UTCTimeState( *super(UTCTime, self).__getstate__(), - **{"ber_raw": self._ber_raw} + **{"ber_raw": self.ber_raw} ) def __setstate__(self, state): super(UTCTime, self).__setstate__(state) - self._ber_raw = state.ber_raw + self.ber_raw = state.ber_raw def __bytes__(self): self._assert_ready() return self._encode_time() def __eq__(self, their): - if isinstance(their, binary_type): + if their.__class__ == binary_type: return self._encode_time() == their - if isinstance(their, datetime): + if their.__class__ == datetime: return self.todatetime() == their if not isinstance(their, self.__class__): return False @@ -4453,14 +4344,9 @@ class GeneralizedTime(UTCTime): def _strptime_bered(self, value): if len(value) < 4 + 3 * 2: raise ValueError("invalid GeneralizedTime") - decoded = datetime( - pureint(value[:4]), # %Y - pureint(value[4:6]), # %m - pureint(value[6:8]), # %d - pureint(value[8:10]), # %H - ) - value = value[10:] - offset = 0 + year, month, day, hour, _, _ = str_to_time_fractions(value[:10] + "0000") + decoded = datetime(year, month, day, hour) + offset, value = 0, value[10:] if len(value) == 0: return offset, decoded if value[-1] == "Z": @@ -4470,22 +4356,24 @@ class GeneralizedTime(UTCTime): idx = value.rfind(char) if idx == -1: continue - offset_raw = value[idx + 1:].replace(":", "") - if len(offset_raw) not in (2, 4): + offset_raw, value = value[idx + 1:].replace(":", ""), value[:idx] + v = pureint(offset_raw) + if len(offset_raw) == 4: + offset, v = (60 * (v % 100)), v // 100 + if offset >= 3600: + raise ValueError("invalid UTC offset minutes") + elif len(offset_raw) == 2: + pass + else: raise ValueError("invalid UTC offset") - value = value[:idx] - offset = 60 * pureint(offset_raw[2:] or "0") - if offset >= 3600: - raise ValueError("invalid UTC offset minutes") - offset += 3600 * pureint(offset_raw[:2]) + offset += 3600 * v if offset > 14 * 3600: raise ValueError("too big UTC offset") offset *= sign break if len(value) == 0: return offset, decoded - decimal_signs = ".," - if value[0] in decimal_signs: + if value[0] in DECIMAL_SIGNS: return offset, ( decoded + timedelta(seconds=3600 * fractions2float(value[1:])) ) @@ -4495,7 +4383,7 @@ class GeneralizedTime(UTCTime): value = value[2:] if len(value) == 0: return offset, decoded - if value[0] in decimal_signs: + if value[0] in DECIMAL_SIGNS: return offset, ( decoded + timedelta(seconds=60 * fractions2float(value[1:])) ) @@ -4505,7 +4393,7 @@ class GeneralizedTime(UTCTime): value = value[2:] if len(value) == 0: return offset, decoded - if value[0] not in decimal_signs: + if value[0] not in DECIMAL_SIGNS: raise ValueError("invalid format after seconds") return offset, ( decoded + timedelta(microseconds=10**6 * fractions2float(value[1:])) @@ -4517,14 +4405,7 @@ class GeneralizedTime(UTCTime): # datetime.strptime's format: %Y%m%d%H%M%SZ if value[-1] != "Z": raise ValueError("non UTC timezone") - return datetime( - pureint(value[:4]), # %Y - pureint(value[4:6]), # %m - pureint(value[6:8]), # %d - pureint(value[8:10]), # %H - pureint(value[10:12]), # %M - pureint(value[12:14]), # %S - ) + return datetime(*str_to_time_fractions(value[:-1])) if l >= LEN_YYYYMMDDHHMMSSDMZ: # datetime.strptime's format: %Y%m%d%H%M%S.%fZ if value[-1] != "Z": @@ -4538,16 +4419,8 @@ class GeneralizedTime(UTCTime): if us_len > 6: raise ValueError("only microsecond fractions are supported") us = pureint(us + ("0" * (6 - us_len))) - decoded = datetime( - pureint(value[:4]), # %Y - pureint(value[4:6]), # %m - pureint(value[6:8]), # %d - pureint(value[8:10]), # %H - pureint(value[10:12]), # %M - pureint(value[12:14]), # %S - us, # %f - ) - return decoded + year, month, day, hour, minute, second = str_to_time_fractions(value[:14]) + return datetime(year, month, day, hour, minute, second, us) raise ValueError("invalid GeneralizedTime length") def _encode_time(self): @@ -4591,21 +4464,11 @@ class BMPString(CommonString): asn1_type_name = "BMPString" -ChoiceState = namedtuple("ChoiceState", ( - "version", - "specs", - "value", - "tag", - "expl", - "default", - "optional", - "offset", - "llen", - "vlen", - "expl_lenindef", - "lenindef", - "ber_encoded", -), **NAMEDTUPLE_KWARGS) +ChoiceState = namedtuple( + "ChoiceState", + BasicState._fields + ("specs", "value",), + **NAMEDTUPLE_KWARGS +) class Choice(Obj): @@ -4669,7 +4532,7 @@ class Choice(Obj): if len(schema) == 0: raise ValueError("schema must be specified") self.specs = ( - schema if isinstance(schema, OrderedDict) else OrderedDict(schema) + schema if schema.__class__ == OrderedDict else OrderedDict(schema) ) self._value = None if value is not None: @@ -4684,7 +4547,7 @@ class Choice(Obj): self._value = copy(default_obj._value) def _value_sanitize(self, value): - if isinstance(value, tuple) and len(value) == 2: + if (value.__class__ == tuple) and len(value) == 2: choice, obj = value spec = self.specs.get(choice) if spec is None: @@ -4710,8 +4573,6 @@ class Choice(Obj): def __getstate__(self): return ChoiceState( __version__, - self.specs, - copy(self._value), self.tag, self._expl, self.default, @@ -4722,24 +4583,17 @@ class Choice(Obj): self.expl_lenindef, self.lenindef, self.ber_encoded, + self.specs, + copy(self._value), ) def __setstate__(self, state): super(Choice, self).__setstate__(state) self.specs = state.specs self._value = state.value - self._expl = state.expl - self.default = state.default - self.optional = state.optional - self.offset = state.offset - self.llen = state.llen - self.vlen = state.vlen - self.expl_lenindef = state.expl_lenindef - self.lenindef = state.lenindef - self.ber_encoded = state.ber_encoded def __eq__(self, their): - if isinstance(their, tuple) and len(their) == 2: + if (their.__class__ == tuple) and len(their) == 2: return self._value == their if not isinstance(their, self.__class__): return False @@ -4910,20 +4764,11 @@ class PrimitiveTypes(Choice): )) -AnyState = namedtuple("AnyState", ( - "version", - "value", - "tag", - "expl", - "optional", - "offset", - "llen", - "vlen", - "expl_lenindef", - "lenindef", - "ber_encoded", - "defined", -), **NAMEDTUPLE_KWARGS) +AnyState = namedtuple( + "AnyState", + BasicState._fields + ("value", "defined"), + **NAMEDTUPLE_KWARGS +) class Any(Obj): @@ -4960,7 +4805,7 @@ class Any(Obj): self.defined = None def _value_sanitize(self, value): - if isinstance(value, binary_type): + if value.__class__ == binary_type: return value if isinstance(value, self.__class__): return value._value @@ -4983,9 +4828,9 @@ class Any(Obj): def __getstate__(self): return AnyState( __version__, - self._value, self.tag, self._expl, + None, self.optional, self.offset, self.llen, @@ -4993,25 +4838,17 @@ class Any(Obj): self.expl_lenindef, self.lenindef, self.ber_encoded, + self._value, self.defined, ) def __setstate__(self, state): super(Any, self).__setstate__(state) self._value = state.value - self.tag = state.tag - self._expl = state.expl - self.optional = state.optional - self.offset = state.offset - self.llen = state.llen - self.vlen = state.vlen - self.expl_lenindef = state.expl_lenindef - self.lenindef = state.lenindef - self.ber_encoded = state.ber_encoded self.defined = state.defined def __eq__(self, their): - if isinstance(their, binary_type): + if their.__class__ == binary_type: return self._value == their if issubclass(their.__class__, Any): return self._value == their._value @@ -5157,7 +4994,7 @@ def get_def_by_path(defines_by_path, sub_decode_path): if len(path) != len(sub_decode_path): continue for p1, p2 in zip(path, sub_decode_path): - if (p1 != any) and (p1 != p2): + if (not p1 is any) and (p1 != p2): break else: return define @@ -5188,21 +5025,11 @@ def abs_decode_path(decode_path, rel_path): return decode_path + rel_path -SequenceState = namedtuple("SequenceState", ( - "version", - "specs", - "value", - "tag", - "expl", - "default", - "optional", - "offset", - "llen", - "vlen", - "expl_lenindef", - "lenindef", - "ber_encoded", -), **NAMEDTUPLE_KWARGS) +SequenceState = namedtuple( + "SequenceState", + BasicState._fields + ("specs", "value",), + **NAMEDTUPLE_KWARGS +) class Sequence(Obj): @@ -5314,7 +5141,7 @@ class Sequence(Obj): if schema is None: schema = getattr(self, "schema", ()) self.specs = ( - schema if isinstance(schema, OrderedDict) else OrderedDict(schema) + schema if schema.__class__ == OrderedDict else OrderedDict(schema) ) self._value = {} if value is not None: @@ -5357,8 +5184,6 @@ class Sequence(Obj): def __getstate__(self): return SequenceState( __version__, - self.specs, - {k: copy(v) for k, v in iteritems(self._value)}, self.tag, self._expl, self.default, @@ -5369,22 +5194,14 @@ class Sequence(Obj): self.expl_lenindef, self.lenindef, self.ber_encoded, + self.specs, + {k: copy(v) for k, v in iteritems(self._value)}, ) def __setstate__(self, state): super(Sequence, self).__setstate__(state) self.specs = state.specs self._value = state.value - self.tag = state.tag - self._expl = state.expl - self.default = state.default - self.optional = state.optional - self.offset = state.offset - self.llen = state.llen - self.vlen = state.vlen - self.expl_lenindef = state.expl_lenindef - self.lenindef = state.lenindef - self.ber_encoded = state.ber_encoded def __eq__(self, their): if not isinstance(their, self.__class__): @@ -5442,19 +5259,17 @@ class Sequence(Obj): return spec.default return None - def _encoded_values(self): - raws = [] + def _values_for_encoding(self): for name, spec in iteritems(self.specs): value = self._value.get(name) if value is None: if spec.optional: continue raise ObjNotReady(name) - raws.append(value.encode()) - return raws + yield value def _encode(self): - v = b"".join(self._encoded_values()) + v = b"".join(v.encode() for v in self._values_for_encoding()) return b"".join((self.tag, len_encode(len(v)), v)) def _decode(self, tlv, offset, decode_path, ctx, tag_only): @@ -5689,8 +5504,8 @@ class Set(Sequence): .. _allow_unordered_set_ctx: DER prohibits unordered values encoding and will raise an error - during decode. If If :ref:`bered ` context option is set, - then no error will occure. Also you can disable strict values + during decode. If :ref:`bered ` context option is set, + then no error will occur. Also you can disable strict values ordering check by setting ``"allow_unordered_set": True`` :ref:`context ` option. """ @@ -5699,7 +5514,7 @@ class Set(Sequence): asn1_type_name = "SET" def _encode(self): - raws = self._encoded_values() + raws = [v.encode() for v in self._values_for_encoding()] raws.sort() v = b"".join(raws) return b"".join((self.tag, len_encode(len(v)), v)) @@ -5840,34 +5655,23 @@ class Set(Sequence): tail = v[EOC_LEN:] obj.lenindef = True obj._value = values - if not obj.ready: - raise DecodeError( - "not all values are ready", - klass=self.__class__, - decode_path=decode_path, - offset=offset, - ) + for name, spec in iteritems(self.specs): + if name not in values and not spec.optional: + raise DecodeError( + "%s value is not ready" % name, + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) obj.ber_encoded = ber_encoded return obj, tail -SequenceOfState = namedtuple("SequenceOfState", ( - "version", - "spec", - "value", - "bound_min", - "bound_max", - "tag", - "expl", - "default", - "optional", - "offset", - "llen", - "vlen", - "expl_lenindef", - "lenindef", - "ber_encoded", -), **NAMEDTUPLE_KWARGS) +SequenceOfState = namedtuple( + "SequenceOfState", + BasicState._fields + ("spec", "value", "bound_min", "bound_max"), + **NAMEDTUPLE_KWARGS +) class SequenceOf(Obj): @@ -5968,10 +5772,6 @@ class SequenceOf(Obj): def __getstate__(self): return SequenceOfState( __version__, - self.spec, - [copy(v) for v in self._value], - self._bound_min, - self._bound_max, self.tag, self._expl, self.default, @@ -5982,6 +5782,10 @@ class SequenceOf(Obj): self.expl_lenindef, self.lenindef, self.ber_encoded, + self.spec, + [copy(v) for v in self._value], + self._bound_min, + self._bound_max, ) def __setstate__(self, state): @@ -5990,16 +5794,6 @@ class SequenceOf(Obj): self._value = state.value self._bound_min = state.bound_min self._bound_max = state.bound_max - self.tag = state.tag - self._expl = state.expl - self.default = state.default - self.optional = state.optional - self.offset = state.offset - self.llen = state.llen - self.vlen = state.vlen - self.expl_lenindef = state.expl_lenindef - self.lenindef = state.lenindef - self.ber_encoded = state.ber_encoded def __eq__(self, their): if isinstance(their, self.__class__): @@ -6065,11 +5859,11 @@ class SequenceOf(Obj): def __getitem__(self, key): return self._value[key] - def _encoded_values(self): - return [v.encode() for v in self._value] + def _values_for_encoding(self): + return iter(self._value) def _encode(self): - v = b"".join(self._encoded_values()) + v = b"".join(v.encode() for v in self._values_for_encoding()) return b"".join((self.tag, len_encode(len(v)), v)) def _decode(self, tlv, offset, decode_path, ctx, tag_only, ordering_check=False): @@ -6232,7 +6026,7 @@ class SetOf(SequenceOf): asn1_type_name = "SET OF" def _encode(self): - raws = self._encoded_values() + raws = [v.encode() for v in self._values_for_encoding()] raws.sort() v = b"".join(raws) return b"".join((self.tag, len_encode(len(v)), v)) @@ -6390,6 +6184,7 @@ def main(): # pragma: no cover if args.defines_by_path is not None: ctx["defines_by_path"] = obj_by_path(args.defines_by_path) obj, tail = schema().decode(der, ctx=ctx) + from os import environ print(pprinter( obj, oid_maps=oid_maps,