]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - pyderasn.py
Ability to set values during Sequence initialization
[pyderasn.git] / pyderasn.py
index e69b19056ce010fa0ea09e4a111fd68964c34c95..68d67d07e91a746886d22bb629d19a8ae6000f43 100755 (executable)
@@ -269,11 +269,11 @@ for AlgorithmIdentifier of X.509's
 ``tbsCertificate.subjectPublicKeyInfo.algorithm.algorithm``::
 
         (
-            (('parameters',), {
+            (("parameters",), {
                 id_ecPublicKey: ECParameters(),
                 id_GostR3410_2001: GostR34102001PublicKeyParameters(),
             }),
-            (('..', 'subjectPublicKey'), {
+            (("..", "subjectPublicKey"), {
                 id_rsaEncryption: RSAPublicKey(),
                 id_GostR3410_2001: OctetString(),
             }),
@@ -289,7 +289,7 @@ Following types can be automatically decoded (DEFINED BY):
 * :py:class:`pyderasn.BitString` (that is multiple of 8 bits)
 * :py:class:`pyderasn.OctetString`
 * :py:class:`pyderasn.SequenceOf`/:py:class:`pyderasn.SetOf`
-  ``Any``/``OctetString``-s
+  ``Any``/``BitString``/``OctetString``-s
 
 When any of those fields is automatically decoded, then ``.defined``
 attribute contains ``(OID, value)`` tuple. ``OID`` tells by which OID it
@@ -472,6 +472,8 @@ from collections import namedtuple
 from collections import OrderedDict
 from datetime import datetime
 from math import ceil
+from os import environ
+from string import digits
 
 from six import add_metaclass
 from six import binary_type
@@ -588,7 +590,7 @@ class DecodeError(Exception):
             c for c in (
                 "" if self.klass is None else self.klass.__name__,
                 (
-                    ("(%s)" % ".".join(self.decode_path))
+                    ("(%s)" % ".".join(str(dp) for dp in self.decode_path))
                     if len(self.decode_path) > 0 else ""
                 ),
                 ("(at %d)" % self.offset) if self.offset > 0 else "",
@@ -910,7 +912,7 @@ class Obj(object):
     def _encode(self):  # pragma: no cover
         raise NotImplementedError()
 
-    def _decode(self, tlv, offset, decode_path, ctx):  # pragma: no cover
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):  # pragma: no cover
         raise NotImplementedError()
 
     def encode(self):
@@ -919,7 +921,15 @@ class Obj(object):
             return raw
         return b"".join((self._expl, len_encode(len(raw)), raw))
 
-    def decode(self, data, offset=0, leavemm=False, decode_path=(), ctx=None):
+    def decode(
+            self,
+            data,
+            offset=0,
+            leavemm=False,
+            decode_path=(),
+            ctx=None,
+            tag_only=False,
+    ):
         """Decode the data
 
         :param data: either binary or memoryview
@@ -927,18 +937,25 @@ class Obj(object):
         :param bool leavemm: do we need to leave memoryview of remaining
                     data as is, or convert it to bytes otherwise
         :param ctx: optional :ref:`context <ctx>` governing decoding process.
+        :param tag_only: decode only the tag, without length and contents
+                         (used only in Choice and Set structures, trying to
+                         determine if tag satisfies the scheme)
         :returns: (Obj, remaining data)
         """
         if ctx is None:
             ctx = {}
         tlv = memoryview(data)
         if self._expl is None:
-            obj, tail = self._decode(
+            result = self._decode(
                 tlv,
                 offset,
                 decode_path=decode_path,
                 ctx=ctx,
+                tag_only=tag_only,
             )
+            if tag_only:
+                return
+            obj, tail = result
         else:
             try:
                 t, tlen, lv = tag_strip(tlv)
@@ -971,12 +988,16 @@ class Obj(object):
                     decode_path=decode_path,
                     offset=offset,
                 )
-            obj, tail = self._decode(
+            result = self._decode(
                 v,
                 offset=offset + tlen + llen,
                 decode_path=decode_path,
                 ctx=ctx,
+                tag_only=tag_only,
             )
+            if tag_only:
+                return
+            obj, tail = result
         return obj, (tail if leavemm else tail.tobytes())
 
     @property
@@ -1011,11 +1032,14 @@ class Obj(object):
 class DecodePathDefBy(object):
     """DEFINED BY representation inside decode path
     """
-    __slots__ = ('defined_by',)
+    __slots__ = ("defined_by",)
 
     def __init__(self, defined_by):
         self.defined_by = defined_by
 
+    def __ne__(self, their):
+        return not(self == their)
+
     def __eq__(self, their):
         if not isinstance(their, self.__class__):
             return False
@@ -1130,9 +1154,9 @@ def pp_console_row(
             ):
                 cols.append(_colorize("%s:" % oids[value], "green", with_colours))
             else:
-                cols.append(_colorize("%s:" % value, "white", with_colours))
+                cols.append(_colorize("%s:" % value, "white", with_colours, ("reverse",)))
         else:
-            cols.append(_colorize("%s:" % ent, "yellow", with_colours))
+            cols.append(_colorize("%s:" % ent, "yellow", with_colours, ("reverse",)))
     if pp.expl is not None:
         klass, _, num = pp.expl
         col = "[%s%d] EXPLICIT" % (TagClassReprs[klass], num)
@@ -1146,7 +1170,7 @@ def pp_console_row(
     cols.append(_colorize(pp.asn1_type_name, "cyan", with_colours))
     if pp.value is not None:
         value = pp.value
-        cols.append(_colorize(value, "white", with_colours))
+        cols.append(_colorize(value, "white", with_colours, ("reverse",)))
         if (
                 oids is not None and
                 pp.asn1_type_name == ObjectIdentifier.asn1_type_name and
@@ -1332,7 +1356,7 @@ class Boolean(Obj):
             (b"\xFF" if self._value else b"\x00"),
         ))
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, _, lv = tag_strip(tlv)
         except DecodeError as err:
@@ -1348,6 +1372,8 @@ class Boolean(Obj):
                 decode_path=decode_path,
                 offset=offset,
             )
+        if tag_only:
+            return
         try:
             l, _, v = len_decode(lv)
         except DecodeError as err:
@@ -1625,7 +1651,7 @@ class Integer(Obj):
                     break
         return b"".join((self.tag, len_encode(len(octets)), octets))
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, _, lv = tag_strip(tlv)
         except DecodeError as err:
@@ -1641,6 +1667,8 @@ class Integer(Obj):
                 decode_path=decode_path,
                 offset=offset,
             )
+        if tag_only:
+            return
         try:
             l, llen, v = len_decode(lv)
         except DecodeError as err:
@@ -1764,12 +1792,12 @@ class BitString(Obj):
 
         class KeyUsage(BitString):
             schema = (
-                ('digitalSignature', 0),
-                ('nonRepudiation', 1),
-                ('keyEncipherment', 2),
+                ("digitalSignature", 0),
+                ("nonRepudiation", 1),
+                ("keyEncipherment", 2),
             )
 
-    >>> b = KeyUsage(('keyEncipherment', 'nonRepudiation'))
+    >>> b = KeyUsage(("keyEncipherment", "nonRepudiation"))
     KeyUsage BIT STRING 3 bits nonRepudiation, keyEncipherment
     >>> b.named
     ['nonRepudiation', 'keyEncipherment']
@@ -1959,7 +1987,7 @@ class BitString(Obj):
             octets,
         ))
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, _, lv = tag_strip(tlv)
         except DecodeError as err:
@@ -1975,6 +2003,8 @@ class BitString(Obj):
                 decode_path=decode_path,
                 offset=offset,
             )
+        if tag_only:
+            return
         try:
             l, llen, v = len_decode(lv)
         except DecodeError as err:
@@ -2013,7 +2043,7 @@ class BitString(Obj):
                 decode_path=decode_path,
                 offset=offset,
             )
-        if byte2int(v[-1:]) & ((1 << pad_size) - 1) != 0:
+        if byte2int(v[l - 1:l]) & ((1 << pad_size) - 1) != 0:
             raise DecodeError(
                 "invalid pad",
                 klass=self.__class__,
@@ -2211,7 +2241,7 @@ class OctetString(Obj):
             self._value,
         ))
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, _, lv = tag_strip(tlv)
         except DecodeError as err:
@@ -2227,6 +2257,8 @@ class OctetString(Obj):
                 decode_path=decode_path,
                 offset=offset,
             )
+        if tag_only:
+            return
         try:
             l, llen, v = len_decode(lv)
         except DecodeError as err:
@@ -2254,6 +2286,13 @@ class OctetString(Obj):
                 optional=self.optional,
                 _decoded=(offset, llen, l),
             )
+        except DecodeError as err:
+            raise DecodeError(
+                msg=err.msg,
+                klass=self.__class__,
+                decode_path=decode_path,
+                offset=offset,
+            )
         except BoundsError as err:
             raise DecodeError(
                 msg=str(err),
@@ -2360,7 +2399,7 @@ class Null(Obj):
     def _encode(self):
         return self.tag + len_encode(0)
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, _, lv = tag_strip(tlv)
         except DecodeError as err:
@@ -2376,6 +2415,8 @@ class Null(Obj):
                 decode_path=decode_path,
                 offset=offset,
             )
+        if tag_only:
+            return
         try:
             l, _, v = len_decode(lv)
         except DecodeError as err:
@@ -2605,7 +2646,7 @@ class ObjectIdentifier(Obj):
         v = b"".join(octets)
         return b"".join((self.tag, len_encode(len(v)), v))
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, _, lv = tag_strip(tlv)
         except DecodeError as err:
@@ -2621,6 +2662,8 @@ class ObjectIdentifier(Obj):
                 decode_path=decode_path,
                 offset=offset,
             )
+        if tag_only:
+            return
         try:
             l, llen, v = len_decode(lv)
         except DecodeError as err:
@@ -2810,7 +2853,7 @@ class CommonString(OctetString):
 
     >>> PrintableString("привет мир")
     Traceback (most recent call last):
-    UnicodeEncodeError: 'ascii' codec can't encode characters in position 0-5: ordinal not in range(128)
+    pyderasn.DecodeError: 'ascii' codec can't encode characters in position 0-5: ordinal not in range(128)
 
     >>> BMPString("ада", bounds=(2, 2))
     Traceback (most recent call last):
@@ -2866,14 +2909,17 @@ class CommonString(OctetString):
             value_raw = value
         else:
             raise InvalidValueType((self.__class__, text_type, binary_type))
-        value_raw = (
-            value_decoded.encode(self.encoding)
-            if value_raw is None else value_raw
-        )
-        value_decoded = (
-            value_raw.decode(self.encoding)
-            if value_decoded is None else value_decoded
-        )
+        try:
+            value_raw = (
+                value_decoded.encode(self.encoding)
+                if value_raw is None else value_raw
+            )
+            value_decoded = (
+                value_raw.decode(self.encoding)
+                if value_decoded is None else value_decoded
+            )
+        except (UnicodeEncodeError, UnicodeDecodeError) as err:
+            raise DecodeError(str(err))
         if not self._bound_min <= len(value_decoded) <= self._bound_max:
             raise BoundsError(
                 self._bound_min,
@@ -2935,6 +2981,13 @@ class NumericString(CommonString):
     tag_default = tag_encode(18)
     encoding = "ascii"
     asn1_type_name = "NumericString"
+    allowable_chars = set(digits.encode("ascii"))
+
+    def _value_sanitize(self, value):
+        value = super(NumericString, self)._value_sanitize(value)
+        if not set(value) <= self.allowable_chars:
+            raise DecodeError("non-numeric value")
+        return value
 
 
 class PrintableString(CommonString):
@@ -3209,8 +3262,8 @@ class Choice(Obj):
 
         class GeneralName(Choice):
             schema = (
-                ('rfc822Name', IA5String(impl=tag_ctxp(1))),
-                ('dNSName', IA5String(impl=tag_ctxp(2))),
+                ("rfc822Name", IA5String(impl=tag_ctxp(1))),
+                ("dNSName", IA5String(impl=tag_ctxp(2))),
             )
 
     >>> gn = GeneralName()
@@ -3372,32 +3425,45 @@ class Choice(Obj):
         self._assert_ready()
         return self._value[1].encode()
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         for choice, spec in self.specs.items():
+            sub_decode_path = decode_path + (choice,)
             try:
-                value, tail = spec.decode(
+                spec.decode(
                     tlv,
                     offset=offset,
                     leavemm=True,
-                    decode_path=decode_path + (choice,),
+                    decode_path=sub_decode_path,
                     ctx=ctx,
+                    tag_only=True,
                 )
             except TagMismatch:
                 continue
-            obj = self.__class__(
-                schema=self.specs,
-                expl=self._expl,
-                default=self.default,
-                optional=self.optional,
-                _decoded=(offset, 0, value.tlvlen),
+            break
+        else:
+            raise TagMismatch(
+                klass=self.__class__,
+                decode_path=decode_path,
+                offset=offset,
             )
-            obj._value = (choice, value)
-            return obj, tail
-        raise TagMismatch(
-            klass=self.__class__,
-            decode_path=decode_path,
+        if tag_only:
+            return
+        value, tail = spec.decode(
+            tlv,
             offset=offset,
+            leavemm=True,
+            decode_path=sub_decode_path,
+            ctx=ctx,
         )
+        obj = self.__class__(
+            schema=self.specs,
+            expl=self._expl,
+            default=self.default,
+            optional=self.optional,
+            _decoded=(offset, 0, value.tlvlen),
+        )
+        obj._value = (choice, value)
+        return obj, tail
 
     def __repr__(self):
         value = pp_console_row(next(self.pps()))
@@ -3547,7 +3613,7 @@ class Any(Obj):
         self._assert_ready()
         return self._value
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, tlen, lv = tag_strip(tlv)
             l, llen, v = len_decode(lv)
@@ -3674,7 +3740,7 @@ class Sequence(Obj):
     pyderasn.InvalidValueType: invalid value type, expected: <class 'pyderasn.ObjectIdentifier'>
     >>> ext["extnID"] = ObjectIdentifier("1.2.3")
 
-    You can know if sequence is ready to be encoded:
+    You can determine if sequence is ready to be encoded:
 
     >>> ext.ready
     False
@@ -3700,7 +3766,15 @@ class Sequence(Obj):
 
     Assign ``None`` to remove value from sequence.
 
-    You can know if value exists/set in the sequence and take its value:
+    You can set values in Sequence during its initialization:
+
+    >>> AlgorithmIdentifier((
+        ("algorithm", ObjectIdentifier("1.2.3")),
+        ("parameters", Any(Null()))
+    ))
+    AlgorithmIdentifier SEQUENCE[OBJECT IDENTIFIER 1.2.3, ANY 0500 OPTIONAL]
+
+    You can determine if value exists/set in the sequence and take its value:
 
     >>> "extnID" in ext, "extnValue" in ext, "critical" in ext
     (True, True, False)
@@ -3757,9 +3831,17 @@ class Sequence(Obj):
         )
         self._value = {}
         if value is not None:
-            self._value = self._value_sanitize(value)
+            if issubclass(value.__class__, Sequence):
+                self._value = value._value
+            elif hasattr(value, "__iter__"):
+                for seq_key, seq_value in value:
+                    self[seq_key] = seq_value
+            else:
+                raise InvalidValueType((Sequence,))
         if default is not None:
-            default_value = self._value_sanitize(default)
+            if not issubclass(default.__class__, Sequence):
+                raise InvalidValueType((Sequence,))
+            default_value = default._value
             default_obj = self.__class__(impl=self.tag, expl=self._expl)
             default_obj.specs = self.specs
             default_obj._value = default_value
@@ -3767,11 +3849,6 @@ class Sequence(Obj):
             if value is None:
                 self._value = default_obj.copy()._value
 
-    def _value_sanitize(self, value):
-        if not issubclass(value.__class__, Sequence):
-            raise InvalidValueType((Sequence,))
-        return value._value
-
     @property
     def ready(self):
         for name, spec in self.specs.items():
@@ -3868,7 +3945,7 @@ class Sequence(Obj):
         v = b"".join(self._encoded_values())
         return b"".join((self.tag, len_encode(len(v)), v))
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, tlen, lv = tag_strip(tlv)
         except DecodeError as err:
@@ -3884,6 +3961,8 @@ class Sequence(Obj):
                 decode_path=decode_path,
                 offset=offset,
             )
+        if tag_only:
+            return
         try:
             l, llen, v = len_decode(lv)
         except DecodeError as err:
@@ -4062,7 +4141,7 @@ class Set(Sequence):
         v = b"".join(raws)
         return b"".join((self.tag, len_encode(len(v)), v))
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, tlen, lv = tag_strip(tlv)
         except DecodeError as err:
@@ -4078,6 +4157,8 @@ class Set(Sequence):
                 decode_path=decode_path,
                 offset=offset,
             )
+        if tag_only:
+            return
         try:
             l, llen, v = len_decode(lv)
         except DecodeError as err:
@@ -4099,23 +4180,18 @@ class Set(Sequence):
         specs_items = self.specs.items
         while len(v) > 0:
             for name, spec in specs_items():
+                sub_decode_path = decode_path + (name,)
                 try:
-                    value, v_tail = spec.decode(
+                    spec.decode(
                         v,
                         sub_offset,
                         leavemm=True,
-                        decode_path=decode_path + (name,),
+                        decode_path=sub_decode_path,
                         ctx=ctx,
+                        tag_only=True,
                     )
                 except TagMismatch:
                     continue
-                sub_offset += (
-                    value.expl_tlvlen if value.expled else value.tlvlen
-                )
-                v = v_tail
-                if spec.default is None or value != spec.default:  # pragma: no cover
-                    # SeqMixing.test_encoded_default_accepted covers that place
-                    values[name] = value
                 break
             else:
                 raise TagMismatch(
@@ -4123,6 +4199,20 @@ class Set(Sequence):
                     decode_path=decode_path,
                     offset=offset,
                 )
+            value, v_tail = spec.decode(
+                v,
+                sub_offset,
+                leavemm=True,
+                decode_path=sub_decode_path,
+                ctx=ctx,
+            )
+            sub_offset += (
+                value.expl_tlvlen if value.expled else value.tlvlen
+            )
+            v = v_tail
+            if spec.default is None or value != spec.default:  # pragma: no cover
+                # SeqMixing.test_encoded_default_accepted covers that place
+                values[name] = value
         obj = self.__class__(
             schema=self.specs,
             impl=self.tag,
@@ -4315,7 +4405,7 @@ class SequenceOf(Obj):
         v = b"".join(self._encoded_values())
         return b"".join((self.tag, len_encode(len(v)), v))
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, tlen, lv = tag_strip(tlv)
         except DecodeError as err:
@@ -4331,6 +4421,8 @@ class SequenceOf(Obj):
                 decode_path=decode_path,
                 offset=offset,
             )
+        if tag_only:
+            return
         try:
             l, llen, v = len_decode(lv)
         except DecodeError as err:
@@ -4499,11 +4591,6 @@ def main():  # pragma: no cover
         "--defines-by-path",
         help="Python path to decoder's defines_by_path",
     )
-    parser.add_argument(
-        "--with-colours",
-        action='store_true',
-        help="Enable coloured output",
-    )
     parser.add_argument(
         "DERFile",
         type=argparse.FileType("rb"),
@@ -4530,7 +4617,7 @@ def main():  # pragma: no cover
     print(pprinter(
         obj,
         oids=oids,
-        with_colours=True if args.with_colours else False,
+        with_colours=True if environ.get("NO_COLOR") is None else False,
     ))
     if tail != b"":
         print("\nTrailing data: %s" % hexenc(tail))