X-Git-Url: http://www.git.cypherpunks.ru/?a=blobdiff_plain;f=pyderasn.py;h=099e8a436c21bf18412651ab6563958856feed16;hb=5b15d6d7bdfe67388f01fd2ab1638f8c02c1294b;hp=885f816b2097c93589d9cc2e44af0cc2c5ac5961;hpb=eb67733960022e82168120c03b5c0e81272ddb2b;p=pyderasn.git diff --git a/pyderasn.py b/pyderasn.py index 885f816..099e8a4 100755 --- a/pyderasn.py +++ b/pyderasn.py @@ -269,11 +269,11 @@ for AlgorithmIdentifier of X.509's ``tbsCertificate.subjectPublicKeyInfo.algorithm.algorithm``:: ( - (('parameters',), { + (("parameters",), { id_ecPublicKey: ECParameters(), id_GostR3410_2001: GostR34102001PublicKeyParameters(), }), - (('..', 'subjectPublicKey'), { + (("..", "subjectPublicKey"), { id_rsaEncryption: RSAPublicKey(), id_GostR3410_2001: OctetString(), }), @@ -473,6 +473,7 @@ from collections import OrderedDict from datetime import datetime from math import ceil from os import environ +from string import digits from six import add_metaclass from six import binary_type @@ -560,6 +561,8 @@ TagClassReprs = { TagClassPrivate: "PRIVATE ", TagClassUniversal: "UNIV ", } +EOC = b"\x00\x00" +EOC_LEN = len(EOC) ######################################################################## @@ -605,6 +608,10 @@ class NotEnoughData(DecodeError): pass +class LenIndefiniteForm(DecodeError): + pass + + class TagMismatch(DecodeError): pass @@ -805,7 +812,7 @@ def len_decode(data): if octets_num + 1 > len(data): raise NotEnoughData("encoded length is longer than data") if octets_num == 0: - raise DecodeError("long form instead of short one") + raise LenIndefiniteForm() if byte2int(data[1:]) == 0: raise DecodeError("leading zeros") l = 0 @@ -842,6 +849,9 @@ class Obj(object): "offset", "llen", "vlen", + "lenindef", + "expl_lenindef", + "bered", ) def __init__( @@ -863,6 +873,9 @@ class Obj(object): self.optional = optional self.offset, self.llen, self.vlen = _decoded self.default = None + self.lenindef = False + self.expl_lenindef = False + self.bered = False @property def ready(self): # pragma: no cover @@ -911,7 +924,7 @@ class Obj(object): def _encode(self): # pragma: no cover raise NotImplementedError() - def _decode(self, tlv, offset, decode_path, ctx): # pragma: no cover + def _decode(self, tlv, offset, decode_path, ctx, tag_only): # pragma: no cover raise NotImplementedError() def encode(self): @@ -920,7 +933,15 @@ class Obj(object): return raw return b"".join((self._expl, len_encode(len(raw)), raw)) - def decode(self, data, offset=0, leavemm=False, decode_path=(), ctx=None): + def decode( + self, + data, + offset=0, + leavemm=False, + decode_path=(), + ctx=None, + tag_only=False, + ): """Decode the data :param data: either binary or memoryview @@ -928,18 +949,25 @@ class Obj(object): :param bool leavemm: do we need to leave memoryview of remaining data as is, or convert it to bytes otherwise :param ctx: optional :ref:`context ` governing decoding process. + :param tag_only: decode only the tag, without length and contents + (used only in Choice and Set structures, trying to + determine if tag satisfies the scheme) :returns: (Obj, remaining data) """ if ctx is None: ctx = {} tlv = memoryview(data) if self._expl is None: - obj, tail = self._decode( + result = self._decode( tlv, offset, decode_path=decode_path, ctx=ctx, + tag_only=tag_only, ) + if tag_only: + return + obj, tail = result else: try: t, tlen, lv = tag_strip(tlv) @@ -958,6 +986,35 @@ class Obj(object): ) try: l, llen, v = len_decode(lv) + except LenIndefiniteForm as err: + if not ctx.get("bered", False): + raise err.__class__( + msg=err.msg, + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + llen, v = 1, lv[1:] + offset += tlen + llen + result = self._decode( + v, + offset=offset, + decode_path=decode_path, + ctx=ctx, + tag_only=tag_only, + ) + if tag_only: + return + obj, tail = result + eoc_expected, tail = tail[:EOC_LEN], tail[EOC_LEN:] + if eoc_expected.tobytes() != EOC: + raise DecodeError( + msg="no EOC", + decode_path=decode_path, + offset=offset, + ) + obj.vlen += EOC_LEN + obj.expl_lenindef = True except DecodeError as err: raise err.__class__( msg=err.msg, @@ -965,19 +1022,24 @@ class Obj(object): decode_path=decode_path, offset=offset, ) - if l > len(v): - raise NotEnoughData( - "encoded length is longer than data", - klass=self.__class__, + else: + if l > len(v): + raise NotEnoughData( + "encoded length is longer than data", + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + result = self._decode( + v, + offset=offset + tlen + llen, decode_path=decode_path, - offset=offset, + ctx=ctx, + tag_only=tag_only, ) - obj, tail = self._decode( - v, - offset=offset + tlen + llen, - decode_path=decode_path, - ctx=ctx, - ) + if tag_only: + return + obj, tail = result return obj, (tail if leavemm else tail.tobytes()) @property @@ -994,6 +1056,8 @@ class Obj(object): @property def expl_llen(self): + if self.expl_lenindef: + return 1 return len(len_encode(self.tlvlen)) @property @@ -1012,11 +1076,14 @@ class Obj(object): class DecodePathDefBy(object): """DEFINED BY representation inside decode path """ - __slots__ = ('defined_by',) + __slots__ = ("defined_by",) def __init__(self, defined_by): self.defined_by = defined_by + def __ne__(self, their): + return not(self == their) + def __eq__(self, their): if not isinstance(their, self.__class__): return False @@ -1333,7 +1400,7 @@ class Boolean(Obj): (b"\xFF" if self._value else b"\x00"), )) - def _decode(self, tlv, offset, decode_path, ctx): + def _decode(self, tlv, offset, decode_path, ctx, tag_only): try: t, _, lv = tag_strip(tlv) except DecodeError as err: @@ -1349,6 +1416,8 @@ class Boolean(Obj): decode_path=decode_path, offset=offset, ) + if tag_only: + return try: l, _, v = len_decode(lv) except DecodeError as err: @@ -1373,10 +1442,14 @@ class Boolean(Obj): offset=offset, ) first_octet = byte2int(v) + bered = False if first_octet == 0: value = False elif first_octet == 0xFF: value = True + elif ctx.get("bered", False): + value = True + bered = True else: raise DecodeError( "unacceptable Boolean value", @@ -1392,6 +1465,7 @@ class Boolean(Obj): optional=self.optional, _decoded=(offset, 1, 1), ) + obj.bered = bered return obj, v[1:] def __repr__(self): @@ -1626,7 +1700,7 @@ class Integer(Obj): break return b"".join((self.tag, len_encode(len(octets)), octets)) - def _decode(self, tlv, offset, decode_path, ctx): + def _decode(self, tlv, offset, decode_path, ctx, tag_only): try: t, _, lv = tag_strip(tlv) except DecodeError as err: @@ -1642,6 +1716,8 @@ class Integer(Obj): decode_path=decode_path, offset=offset, ) + if tag_only: + return try: l, llen, v = len_decode(lv) except DecodeError as err: @@ -1765,19 +1841,19 @@ class BitString(Obj): class KeyUsage(BitString): schema = ( - ('digitalSignature', 0), - ('nonRepudiation', 1), - ('keyEncipherment', 2), + ("digitalSignature", 0), + ("nonRepudiation", 1), + ("keyEncipherment", 2), ) - >>> b = KeyUsage(('keyEncipherment', 'nonRepudiation')) + >>> b = KeyUsage(("keyEncipherment", "nonRepudiation")) KeyUsage BIT STRING 3 bits nonRepudiation, keyEncipherment >>> b.named ['nonRepudiation', 'keyEncipherment'] >>> b.specs {'nonRepudiation': 1, 'digitalSignature': 0, 'keyEncipherment': 2} """ - __slots__ = ("specs", "defined") + __slots__ = ("tag_constructed", "specs", "defined") tag_default = tag_encode(3) asn1_type_name = "BIT STRING" @@ -1815,6 +1891,12 @@ class BitString(Obj): if value is None: self._value = default self.defined = None + tag_klass, _, tag_num = tag_decode(self.tag) + self.tag_constructed = tag_encode( + klass=tag_klass, + form=TagFormConstructed, + num=tag_num, + ) def _bits2octets(self, bits): if len(self.specs) > 0: @@ -1960,22 +2042,7 @@ class BitString(Obj): octets, )) - def _decode(self, tlv, offset, decode_path, ctx): - try: - t, _, lv = tag_strip(tlv) - except DecodeError as err: - raise err.__class__( - msg=err.msg, - klass=self.__class__, - decode_path=decode_path, - offset=offset, - ) - if t != self.tag: - raise TagMismatch( - klass=self.__class__, - decode_path=decode_path, - offset=offset, - ) + def _decode_chunk(self, lv, offset, decode_path, ctx): try: l, llen, v = len_decode(lv) except DecodeError as err: @@ -2033,6 +2100,129 @@ class BitString(Obj): ) return obj, tail + def _decode(self, tlv, offset, decode_path, ctx, tag_only): + try: + t, tlen, lv = tag_strip(tlv) + except DecodeError as err: + raise err.__class__( + msg=err.msg, + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + if t == self.tag: + if tag_only: + return + return self._decode_chunk(lv, offset, decode_path, ctx) + if t == self.tag_constructed: + if not ctx.get("bered", False): + raise DecodeError( + msg="unallowed BER constructed encoding", + decode_path=decode_path, + offset=offset, + ) + if tag_only: + return + lenindef = False + try: + l, llen, v = len_decode(lv) + except LenIndefiniteForm: + llen, l, v = 1, 0, lv[1:] + lenindef = True + except DecodeError as err: + raise err.__class__( + msg=err.msg, + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + if l > 0 and l > len(v): + raise NotEnoughData( + "encoded length is longer than data", + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + if not lenindef and l == 0: + raise NotEnoughData( + "zero length", + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + chunks = [] + sub_offset = offset + tlen + llen + vlen = 0 + while True: + if lenindef: + if v[:EOC_LEN].tobytes() == EOC: + break + else: + if vlen == l: + break + if vlen > l: + raise DecodeError( + msg="chunk out of bounds", + decode_path=len(chunks) - 1, + offset=chunks[-1].offset, + ) + sub_decode_path = decode_path + (str(len(chunks)),) + try: + chunk, v_tail = BitString().decode( + v, + offset=sub_offset, + decode_path=sub_decode_path, + leavemm=True, + ctx=ctx, + ) + except TagMismatch: + raise DecodeError( + msg="expected BitString encoded chunk", + decode_path=sub_decode_path, + offset=sub_offset, + ) + chunks.append(chunk) + sub_offset += chunk.tlvlen + vlen += chunk.tlvlen + v = v_tail + if len(chunks) == 0: + raise DecodeError( + msg="no chunks", + decode_path=decode_path, + offset=offset, + ) + values = [] + bit_len = 0 + for chunk_i, chunk in enumerate(chunks[:-1]): + if chunk.bit_len % 8 != 0: + raise DecodeError( + msg="BitString chunk is not multiple of 8 bit", + decode_path=decode_path + (str(chunk_i),), + offset=chunk.offset, + ) + values.append(bytes(chunk)) + bit_len += chunk.bit_len + chunk_last = chunks[-1] + values.append(bytes(chunk_last)) + bit_len += chunk_last.bit_len + obj = self.__class__( + value=(bit_len, b"".join(values)), + impl=self.tag, + expl=self._expl, + default=self.default, + optional=self.optional, + _specs=self.specs, + _decoded=(offset, llen, vlen + (EOC_LEN if lenindef else 0)), + ) + obj.lenindef = lenindef + obj.bered = True + return obj, (v[EOC_LEN:] if lenindef else v) + raise TagMismatch( + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + def __repr__(self): return pp_console_row(next(self.pps())) @@ -2086,7 +2276,7 @@ class OctetString(Obj): >>> OctetString(b"hell", bounds=(4, 4)) OCTET STRING 4 bytes 68656c6c """ - __slots__ = ("_bound_min", "_bound_max", "defined") + __slots__ = ("tag_constructed", "_bound_min", "_bound_max", "defined") tag_default = tag_encode(4) asn1_type_name = "OCTET STRING" @@ -2135,6 +2325,12 @@ class OctetString(Obj): if self._value is None: self._value = default self.defined = None + tag_klass, _, tag_num = tag_decode(self.tag) + self.tag_constructed = tag_encode( + klass=tag_klass, + form=TagFormConstructed, + num=tag_num, + ) def _value_sanitize(self, value): if issubclass(value.__class__, OctetString): @@ -2212,22 +2408,7 @@ class OctetString(Obj): self._value, )) - def _decode(self, tlv, offset, decode_path, ctx): - try: - t, _, lv = tag_strip(tlv) - except DecodeError as err: - raise err.__class__( - msg=err.msg, - klass=self.__class__, - decode_path=decode_path, - offset=offset, - ) - if t != self.tag: - raise TagMismatch( - klass=self.__class__, - decode_path=decode_path, - offset=offset, - ) + def _decode_chunk(self, lv, offset, decode_path, ctx): try: l, llen, v = len_decode(lv) except DecodeError as err: @@ -2255,6 +2436,13 @@ class OctetString(Obj): optional=self.optional, _decoded=(offset, llen, l), ) + except DecodeError as err: + raise DecodeError( + msg=err.msg, + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) except BoundsError as err: raise DecodeError( msg=str(err), @@ -2264,6 +2452,130 @@ class OctetString(Obj): ) return obj, tail + def _decode(self, tlv, offset, decode_path, ctx, tag_only): + try: + t, tlen, lv = tag_strip(tlv) + except DecodeError as err: + raise err.__class__( + msg=err.msg, + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + if t == self.tag: + if tag_only: + return + return self._decode_chunk(lv, offset, decode_path, ctx) + if t == self.tag_constructed: + if not ctx.get("bered", False): + raise DecodeError( + msg="unallowed BER constructed encoding", + decode_path=decode_path, + offset=offset, + ) + if tag_only: + return + lenindef = False + try: + l, llen, v = len_decode(lv) + except LenIndefiniteForm: + llen, l, v = 1, 0, lv[1:] + lenindef = True + except DecodeError as err: + raise err.__class__( + msg=err.msg, + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + if l > 0 and l > len(v): + raise NotEnoughData( + "encoded length is longer than data", + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + if not lenindef and l == 0: + raise NotEnoughData( + "zero length", + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + chunks = [] + sub_offset = offset + tlen + llen + vlen = 0 + while True: + if lenindef: + if v[:EOC_LEN].tobytes() == EOC: + break + else: + if vlen == l: + break + if vlen > l: + raise DecodeError( + msg="chunk out of bounds", + decode_path=len(chunks) - 1, + offset=chunks[-1].offset, + ) + sub_decode_path = decode_path + (str(len(chunks)),) + try: + chunk, v_tail = OctetString().decode( + v, + offset=sub_offset, + decode_path=sub_decode_path, + leavemm=True, + ctx=ctx, + ) + except TagMismatch: + raise DecodeError( + msg="expected OctetString encoded chunk", + decode_path=sub_decode_path, + offset=sub_offset, + ) + chunks.append(chunk) + sub_offset += chunk.tlvlen + vlen += chunk.tlvlen + v = v_tail + if len(chunks) == 0: + raise DecodeError( + msg="no chunks", + decode_path=decode_path, + offset=offset, + ) + try: + obj = self.__class__( + value=b"".join(bytes(chunk) for chunk in chunks), + bounds=(self._bound_min, self._bound_max), + impl=self.tag, + expl=self._expl, + default=self.default, + optional=self.optional, + _decoded=(offset, llen, vlen + (EOC_LEN if lenindef else 0)), + ) + except DecodeError as err: + raise DecodeError( + msg=err.msg, + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + except BoundsError as err: + raise DecodeError( + msg=str(err), + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + obj.lenindef = lenindef + obj.bered = True + return obj, (v[EOC_LEN:] if lenindef else v) + raise TagMismatch( + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + def __repr__(self): return pp_console_row(next(self.pps())) @@ -2361,7 +2673,7 @@ class Null(Obj): def _encode(self): return self.tag + len_encode(0) - def _decode(self, tlv, offset, decode_path, ctx): + def _decode(self, tlv, offset, decode_path, ctx, tag_only): try: t, _, lv = tag_strip(tlv) except DecodeError as err: @@ -2377,6 +2689,8 @@ class Null(Obj): decode_path=decode_path, offset=offset, ) + if tag_only: + return try: l, _, v = len_decode(lv) except DecodeError as err: @@ -2606,7 +2920,7 @@ class ObjectIdentifier(Obj): v = b"".join(octets) return b"".join((self.tag, len_encode(len(v)), v)) - def _decode(self, tlv, offset, decode_path, ctx): + def _decode(self, tlv, offset, decode_path, ctx, tag_only): try: t, _, lv = tag_strip(tlv) except DecodeError as err: @@ -2622,6 +2936,8 @@ class ObjectIdentifier(Obj): decode_path=decode_path, offset=offset, ) + if tag_only: + return try: l, llen, v = len_decode(lv) except DecodeError as err: @@ -2811,7 +3127,7 @@ class CommonString(OctetString): >>> PrintableString("привет мир") Traceback (most recent call last): - UnicodeEncodeError: 'ascii' codec can't encode characters in position 0-5: ordinal not in range(128) + pyderasn.DecodeError: 'ascii' codec can't encode characters in position 0-5: ordinal not in range(128) >>> BMPString("ада", bounds=(2, 2)) Traceback (most recent call last): @@ -2867,14 +3183,17 @@ class CommonString(OctetString): value_raw = value else: raise InvalidValueType((self.__class__, text_type, binary_type)) - value_raw = ( - value_decoded.encode(self.encoding) - if value_raw is None else value_raw - ) - value_decoded = ( - value_raw.decode(self.encoding) - if value_decoded is None else value_decoded - ) + try: + value_raw = ( + value_decoded.encode(self.encoding) + if value_raw is None else value_raw + ) + value_decoded = ( + value_raw.decode(self.encoding) + if value_decoded is None else value_decoded + ) + except (UnicodeEncodeError, UnicodeDecodeError) as err: + raise DecodeError(str(err)) if not self._bound_min <= len(value_decoded) <= self._bound_max: raise BoundsError( self._bound_min, @@ -2936,6 +3255,13 @@ class NumericString(CommonString): tag_default = tag_encode(18) encoding = "ascii" asn1_type_name = "NumericString" + allowable_chars = set(digits.encode("ascii")) + + def _value_sanitize(self, value): + value = super(NumericString, self)._value_sanitize(value) + if not set(value) <= self.allowable_chars: + raise DecodeError("non-numeric value") + return value class PrintableString(CommonString): @@ -3210,8 +3536,8 @@ class Choice(Obj): class GeneralName(Choice): schema = ( - ('rfc822Name', IA5String(impl=tag_ctxp(1))), - ('dNSName', IA5String(impl=tag_ctxp(2))), + ("rfc822Name", IA5String(impl=tag_ctxp(1))), + ("dNSName", IA5String(impl=tag_ctxp(2))), ) >>> gn = GeneralName() @@ -3373,32 +3699,45 @@ class Choice(Obj): self._assert_ready() return self._value[1].encode() - def _decode(self, tlv, offset, decode_path, ctx): + def _decode(self, tlv, offset, decode_path, ctx, tag_only): for choice, spec in self.specs.items(): + sub_decode_path = decode_path + (choice,) try: - value, tail = spec.decode( + spec.decode( tlv, offset=offset, leavemm=True, - decode_path=decode_path + (choice,), + decode_path=sub_decode_path, ctx=ctx, + tag_only=True, ) except TagMismatch: continue - obj = self.__class__( - schema=self.specs, - expl=self._expl, - default=self.default, - optional=self.optional, - _decoded=(offset, 0, value.tlvlen), + break + else: + raise TagMismatch( + klass=self.__class__, + decode_path=decode_path, + offset=offset, ) - obj._value = (choice, value) - return obj, tail - raise TagMismatch( - klass=self.__class__, - decode_path=decode_path, + if tag_only: + return + value, tail = spec.decode( + tlv, offset=offset, + leavemm=True, + decode_path=sub_decode_path, + ctx=ctx, + ) + obj = self.__class__( + schema=self.specs, + expl=self._expl, + default=self.default, + optional=self.optional, + _decoded=(offset, 0, value.tlvlen), ) + obj._value = (choice, value) + return obj, tail def __repr__(self): value = pp_console_row(next(self.pps())) @@ -3548,10 +3887,52 @@ class Any(Obj): self._assert_ready() return self._value - def _decode(self, tlv, offset, decode_path, ctx): + def _decode(self, tlv, offset, decode_path, ctx, tag_only): try: t, tlen, lv = tag_strip(tlv) + except DecodeError as err: + raise err.__class__( + msg=err.msg, + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + try: l, llen, v = len_decode(lv) + except LenIndefiniteForm as err: + if not ctx.get("bered", False): + raise err.__class__( + msg=err.msg, + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + llen, vlen, v = 1, 0, lv[1:] + sub_offset = offset + tlen + llen + chunk_i = 0 + while True: + if v[:EOC_LEN].tobytes() == EOC: + tlvlen = tlen + llen + vlen + EOC_LEN + obj = self.__class__( + value=tlv[:tlvlen].tobytes(), + expl=self._expl, + optional=self.optional, + _decoded=(offset, 0, tlvlen), + ) + obj.lenindef = True + obj.tag = t + return obj, v[EOC_LEN:] + else: + chunk, v = Any().decode( + v, + offset=sub_offset, + decode_path=decode_path + (str(chunk_i),), + leavemm=True, + ctx=ctx, + ) + vlen += chunk.tlvlen + sub_offset += chunk.tlvlen + chunk_i += 1 except DecodeError as err: raise err.__class__( msg=err.msg, @@ -3675,7 +4056,7 @@ class Sequence(Obj): pyderasn.InvalidValueType: invalid value type, expected: >>> ext["extnID"] = ObjectIdentifier("1.2.3") - You can know if sequence is ready to be encoded: + You can determine if sequence is ready to be encoded: >>> ext.ready False @@ -3701,7 +4082,15 @@ class Sequence(Obj): Assign ``None`` to remove value from sequence. - You can know if value exists/set in the sequence and take its value: + You can set values in Sequence during its initialization: + + >>> AlgorithmIdentifier(( + ("algorithm", ObjectIdentifier("1.2.3")), + ("parameters", Any(Null())) + )) + AlgorithmIdentifier SEQUENCE[OBJECT IDENTIFIER 1.2.3, ANY 0500 OPTIONAL] + + You can determine if value exists/set in the sequence and take its value: >>> "extnID" in ext, "extnValue" in ext, "critical" in ext (True, True, False) @@ -3758,9 +4147,17 @@ class Sequence(Obj): ) self._value = {} if value is not None: - self._value = self._value_sanitize(value) + if issubclass(value.__class__, Sequence): + self._value = value._value + elif hasattr(value, "__iter__"): + for seq_key, seq_value in value: + self[seq_key] = seq_value + else: + raise InvalidValueType((Sequence,)) if default is not None: - default_value = self._value_sanitize(default) + if not issubclass(default.__class__, Sequence): + raise InvalidValueType((Sequence,)) + default_value = default._value default_obj = self.__class__(impl=self.tag, expl=self._expl) default_obj.specs = self.specs default_obj._value = default_value @@ -3768,11 +4165,6 @@ class Sequence(Obj): if value is None: self._value = default_obj.copy()._value - def _value_sanitize(self, value): - if not issubclass(value.__class__, Sequence): - raise InvalidValueType((Sequence,)) - return value._value - @property def ready(self): for name, spec in self.specs.items(): @@ -3869,7 +4261,7 @@ class Sequence(Obj): v = b"".join(self._encoded_values()) return b"".join((self.tag, len_encode(len(v)), v)) - def _decode(self, tlv, offset, decode_path, ctx): + def _decode(self, tlv, offset, decode_path, ctx, tag_only): try: t, tlen, lv = tag_strip(tlv) except DecodeError as err: @@ -3885,8 +4277,21 @@ class Sequence(Obj): decode_path=decode_path, offset=offset, ) + if tag_only: + return + lenindef = False try: l, llen, v = len_decode(lv) + except LenIndefiniteForm as err: + if not ctx.get("bered", False): + raise err.__class__( + msg=err.msg, + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + l, llen, v = 0, 1, lv[1:] + lenindef = True except DecodeError as err: raise err.__class__( msg=err.msg, @@ -3901,11 +4306,16 @@ class Sequence(Obj): decode_path=decode_path, offset=offset, ) - v, tail = v[:l], v[l:] + if not lenindef: + v, tail = v[:l], v[l:] + vlen = 0 sub_offset = offset + tlen + llen values = {} for name, spec in self.specs.items(): - if len(v) == 0 and spec.optional: + if spec.optional and ( + (lenindef and v[:EOC_LEN].tobytes() == EOC) or + len(v) == 0 + ): continue sub_decode_path = decode_path + (name,) try: @@ -3968,7 +4378,9 @@ class Sequence(Obj): ) value.defined = (defined_by, defined_value) - sub_offset += (value.expl_tlvlen if value.expled else value.tlvlen) + value_len = value.expl_tlvlen if value.expled else value.tlvlen + vlen += value_len + sub_offset += value_len v = v_tail if spec.default is not None and value == spec.default: if ctx.get("strict_default_existence", False): @@ -3995,7 +4407,17 @@ class Sequence(Obj): abs_decode_path(sub_decode_path[:-1], rel_path), (value, defined), )) - if len(v) > 0: + if lenindef: + if v[:EOC_LEN].tobytes() != EOC: + raise DecodeError( + "no EOC", + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + tail = v[EOC_LEN:] + vlen += EOC_LEN + elif len(v) > 0: raise DecodeError( "remaining data", klass=self.__class__, @@ -4008,9 +4430,10 @@ class Sequence(Obj): expl=self._expl, default=self.default, optional=self.optional, - _decoded=(offset, llen, l), + _decoded=(offset, llen, vlen), ) obj._value = values + obj.lenindef = lenindef return obj, tail def __repr__(self): @@ -4063,7 +4486,7 @@ class Set(Sequence): v = b"".join(raws) return b"".join((self.tag, len_encode(len(v)), v)) - def _decode(self, tlv, offset, decode_path, ctx): + def _decode(self, tlv, offset, decode_path, ctx, tag_only): try: t, tlen, lv = tag_strip(tlv) except DecodeError as err: @@ -4079,8 +4502,21 @@ class Set(Sequence): decode_path=decode_path, offset=offset, ) + if tag_only: + return + lenindef = False try: l, llen, v = len_decode(lv) + except LenIndefiniteForm as err: + if not ctx.get("bered", False): + raise err.__class__( + msg=err.msg, + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + l, llen, v = 0, 1, lv[1:] + lenindef = True except DecodeError as err: raise err.__class__( msg=err.msg, @@ -4094,29 +4530,28 @@ class Set(Sequence): klass=self.__class__, offset=offset, ) - v, tail = v[:l], v[l:] + if not lenindef: + v, tail = v[:l], v[l:] + vlen = 0 sub_offset = offset + tlen + llen values = {} specs_items = self.specs.items while len(v) > 0: + if lenindef and v[:EOC_LEN].tobytes() == EOC: + break for name, spec in specs_items(): + sub_decode_path = decode_path + (name,) try: - value, v_tail = spec.decode( + spec.decode( v, sub_offset, leavemm=True, - decode_path=decode_path + (name,), + decode_path=sub_decode_path, ctx=ctx, + tag_only=True, ) except TagMismatch: continue - sub_offset += ( - value.expl_tlvlen if value.expled else value.tlvlen - ) - v = v_tail - if spec.default is None or value != spec.default: # pragma: no cover - # SeqMixing.test_encoded_default_accepted covers that place - values[name] = value break else: raise TagMismatch( @@ -4124,16 +4559,38 @@ class Set(Sequence): decode_path=decode_path, offset=offset, ) + value, v_tail = spec.decode( + v, + sub_offset, + leavemm=True, + decode_path=sub_decode_path, + ctx=ctx, + ) + value_len = value.expl_tlvlen if value.expled else value.tlvlen + sub_offset += value_len + vlen += value_len + v = v_tail + if spec.default is None or value != spec.default: # pragma: no cover + # SeqMixing.test_encoded_default_accepted covers that place + values[name] = value obj = self.__class__( schema=self.specs, impl=self.tag, expl=self._expl, default=self.default, optional=self.optional, - _decoded=(offset, llen, l), + _decoded=(offset, llen, vlen + (EOC_LEN if lenindef else 0)), ) obj._value = values - return obj, tail + if not obj.ready: + raise DecodeError( + msg="not all values are ready", + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + obj.lenindef = lenindef + return obj, (v[EOC_LEN:] if lenindef else tail) class SequenceOf(Obj): @@ -4316,7 +4773,7 @@ class SequenceOf(Obj): v = b"".join(self._encoded_values()) return b"".join((self.tag, len_encode(len(v)), v)) - def _decode(self, tlv, offset, decode_path, ctx): + def _decode(self, tlv, offset, decode_path, ctx, tag_only): try: t, tlen, lv = tag_strip(tlv) except DecodeError as err: @@ -4332,8 +4789,21 @@ class SequenceOf(Obj): decode_path=decode_path, offset=offset, ) + if tag_only: + return + lenindef = False try: l, llen, v = len_decode(lv) + except LenIndefiniteForm as err: + if not ctx.get("bered", False): + raise err.__class__( + msg=err.msg, + klass=self.__class__, + decode_path=decode_path, + offset=offset, + ) + l, llen, v = 0, 1, lv[1:] + lenindef = True except DecodeError as err: raise err.__class__( msg=err.msg, @@ -4348,11 +4818,15 @@ class SequenceOf(Obj): decode_path=decode_path, offset=offset, ) - v, tail = v[:l], v[l:] + if not lenindef: + v, tail = v[:l], v[l:] + vlen = 0 sub_offset = offset + tlen + llen _value = [] spec = self.spec while len(v) > 0: + if lenindef and v[:EOC_LEN].tobytes() == EOC: + break value, v_tail = spec.decode( v, sub_offset, @@ -4360,7 +4834,9 @@ class SequenceOf(Obj): decode_path=decode_path + (str(len(_value)),), ctx=ctx, ) - sub_offset += (value.expl_tlvlen if value.expled else value.tlvlen) + value_len = value.expl_tlvlen if value.expled else value.tlvlen + sub_offset += value_len + vlen += value_len v = v_tail _value.append(value) obj = self.__class__( @@ -4371,9 +4847,10 @@ class SequenceOf(Obj): expl=self._expl, default=self.default, optional=self.optional, - _decoded=(offset, llen, l), + _decoded=(offset, llen, vlen), ) - return obj, tail + obj.lenindef = lenindef + return obj, (v[EOC_LEN:] if lenindef else tail) def __repr__(self): return "%s[%s]" % ( @@ -4516,13 +4993,10 @@ def main(): # pragma: no cover pprinter = partial(pprint, big_blobs=True) else: schema, pprinter = generic_decoder() - obj, tail = schema().decode( - der, - ctx=( - None if args.defines_by_path is None else - {"defines_by_path": obj_by_path(args.defines_by_path)} - ), - ) + ctx = {"bered": True} + if args.defines_by_path is not None: + ctx["defines_by_path"] = obj_by_path(args.defines_by_path) + obj, tail = schema().decode(der, ctx=ctx) print(pprinter( obj, oids=oids,