]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - pyderasn.py
Unify quotes
[pyderasn.git] / pyderasn.py
index cc44e24c91293ca4ffe85ba7684279e87718925c..3b5c4d824a47eda126121f843281c0050b2eed6a 100755 (executable)
@@ -269,11 +269,11 @@ for AlgorithmIdentifier of X.509's
 ``tbsCertificate.subjectPublicKeyInfo.algorithm.algorithm``::
 
         (
-            (('parameters',), {
+            (("parameters",), {
                 id_ecPublicKey: ECParameters(),
                 id_GostR3410_2001: GostR34102001PublicKeyParameters(),
             }),
-            (('..', 'subjectPublicKey'), {
+            (("..", "subjectPublicKey"), {
                 id_rsaEncryption: RSAPublicKey(),
                 id_GostR3410_2001: OctetString(),
             }),
@@ -289,7 +289,7 @@ Following types can be automatically decoded (DEFINED BY):
 * :py:class:`pyderasn.BitString` (that is multiple of 8 bits)
 * :py:class:`pyderasn.OctetString`
 * :py:class:`pyderasn.SequenceOf`/:py:class:`pyderasn.SetOf`
-  ``Any``/``OctetString``-s
+  ``Any``/``BitString``/``OctetString``-s
 
 When any of those fields is automatically decoded, then ``.defined``
 attribute contains ``(OID, value)`` tuple. ``OID`` tells by which OID it
@@ -329,7 +329,7 @@ of ``PKIResponse``::
         (
             (
                 "content",
-                decode_path_defby(id_signedData),
+                DecodePathDefBy(id_signedData),
                 "encapContentInfo",
                 "eContentType",
             ),
@@ -341,10 +341,10 @@ of ``PKIResponse``::
         (
             (
                 "content",
-                decode_path_defby(id_signedData),
+                DecodePathDefBy(id_signedData),
                 "encapContentInfo",
                 "eContent",
-                decode_path_defby(id_cct_PKIResponse),
+                DecodePathDefBy(id_cct_PKIResponse),
                 "controlSequence",
                 any,
                 "attrType",
@@ -358,7 +358,7 @@ of ``PKIResponse``::
         ),
     ))
 
-Pay attention for :py:func:`pyderasn.decode_path_defby` and ``any``.
+Pay attention for :py:class:`pyderasn.DecodePathDefBy` and ``any``.
 First function is useful for path construction when some automatic
 decoding is already done. ``any`` means literally any value it meet --
 useful for SEQUENCE/SET OF-s.
@@ -472,6 +472,8 @@ from collections import namedtuple
 from collections import OrderedDict
 from datetime import datetime
 from math import ceil
+from os import environ
+from string import digits
 
 from six import add_metaclass
 from six import binary_type
@@ -486,6 +488,13 @@ from six import text_type
 from six.moves import xrange as six_xrange
 
 
+try:
+    from termcolor import colored
+except ImportError:
+    def colored(what, *args):
+        return what
+
+
 __all__ = (
     "Any",
     "BitString",
@@ -493,8 +502,8 @@ __all__ = (
     "Boolean",
     "BoundsError",
     "Choice",
-    "decode_path_defby",
     "DecodeError",
+    "DecodePathDefBy",
     "Enumerated",
     "GeneralizedTime",
     "GeneralString",
@@ -581,7 +590,7 @@ class DecodeError(Exception):
             c for c in (
                 "" if self.klass is None else self.klass.__name__,
                 (
-                    ("(%s)" % ".".join(self.decode_path))
+                    ("(%s)" % ".".join(str(dp) for dp in self.decode_path))
                     if len(self.decode_path) > 0 else ""
                 ),
                 ("(at %d)" % self.offset) if self.offset > 0 else "",
@@ -1001,10 +1010,24 @@ class Obj(object):
         return self.expl_tlen + self.expl_llen + self.expl_vlen
 
 
-def decode_path_defby(defined_by):
+class DecodePathDefBy(object):
     """DEFINED BY representation inside decode path
     """
-    return "DEFINED BY (%s)" % defined_by
+    __slots__ = ("defined_by",)
+
+    def __init__(self, defined_by):
+        self.defined_by = defined_by
+
+    def __eq__(self, their):
+        if not isinstance(their, self.__class__):
+            return False
+        return self.defined_by == their.defined_by
+
+    def __str__(self):
+        return "DEFINED BY " + str(self.defined_by)
+
+    def __repr__(self):
+        return "<%s: %s>" % (self.__class__.__name__, self.defined_by)
 
 
 ########################################################################
@@ -1072,49 +1095,75 @@ def _pp(
     )
 
 
-def pp_console_row(pp, oids=None, with_offsets=False, with_blob=True):
+def _colorize(what, colour, with_colours, attrs=("bold",)):
+    return colored(what, colour, attrs=attrs) if with_colours else what
+
+
+def pp_console_row(
+        pp,
+        oids=None,
+        with_offsets=False,
+        with_blob=True,
+        with_colours=False,
+):
     cols = []
     if with_offsets:
-        cols.append("%5d%s [%d,%d,%4d]" % (
+        col = "%5d%s" % (
             pp.offset,
             (
                 "  " if pp.expl_offset is None else
                 ("-%d" % (pp.offset - pp.expl_offset))
             ),
-            pp.tlen,
-            pp.llen,
-            pp.vlen,
-        ))
+        )
+        cols.append(_colorize(col, "red", with_colours, ()))
+        col = "[%d,%d,%4d]" % (pp.tlen, pp.llen, pp.vlen)
+        cols.append(_colorize(col, "green", with_colours, ()))
     if len(pp.decode_path) > 0:
         cols.append(" ." * (len(pp.decode_path)))
-        cols.append("%s:" % pp.decode_path[-1])
+        ent = pp.decode_path[-1]
+        if isinstance(ent, DecodePathDefBy):
+            cols.append(_colorize("DEFINED BY", "red", with_colours, ("reverse",)))
+            value = str(ent.defined_by)
+            if (
+                    oids is not None and
+                    ent.defined_by.asn1_type_name ==
+                    ObjectIdentifier.asn1_type_name and
+                    value in oids
+            ):
+                cols.append(_colorize("%s:" % oids[value], "green", with_colours))
+            else:
+                cols.append(_colorize("%s:" % value, "white", with_colours, ("reverse",)))
+        else:
+            cols.append(_colorize("%s:" % ent, "yellow", with_colours, ("reverse",)))
     if pp.expl is not None:
         klass, _, num = pp.expl
-        cols.append("[%s%d] EXPLICIT" % (TagClassReprs[klass], num))
+        col = "[%s%d] EXPLICIT" % (TagClassReprs[klass], num)
+        cols.append(_colorize(col, "blue", with_colours))
     if pp.impl is not None:
         klass, _, num = pp.impl
-        cols.append("[%s%d]" % (TagClassReprs[klass], num))
+        col = "[%s%d]" % (TagClassReprs[klass], num)
+        cols.append(_colorize(col, "blue", with_colours))
     if pp.asn1_type_name.replace(" ", "") != pp.obj_name.upper():
-        cols.append(pp.obj_name)
-    cols.append(pp.asn1_type_name)
+        cols.append(_colorize(pp.obj_name, "magenta", with_colours))
+    cols.append(_colorize(pp.asn1_type_name, "cyan", with_colours))
     if pp.value is not None:
         value = pp.value
+        cols.append(_colorize(value, "white", with_colours, ("reverse",)))
         if (
                 oids is not None and
                 pp.asn1_type_name == ObjectIdentifier.asn1_type_name and
                 value in oids
         ):
-            value = "%s (%s)" % (oids[value], pp.value)
-        cols.append(value)
+            cols.append(_colorize("(%s)" % oids[value], "green", with_colours))
     if with_blob:
         if isinstance(pp.blob, binary_type):
             cols.append(hexenc(pp.blob))
         elif isinstance(pp.blob, tuple):
             cols.append(", ".join(pp.blob))
     if pp.optional:
-        cols.append("OPTIONAL")
+        cols.append(_colorize("OPTIONAL", "red", with_colours))
     if pp.default:
-        cols.append("DEFAULT")
+        cols.append(_colorize("DEFAULT", "red", with_colours))
     return " ".join(cols)
 
 
@@ -1133,7 +1182,7 @@ def pp_console_blob(pp):
         yield " ".join(cols + [", ".join(pp.blob)])
 
 
-def pprint(obj, oids=None, big_blobs=False):
+def pprint(obj, oids=None, big_blobs=False, with_colours=False):
     """Pretty print object
 
     :param Obj obj: object you want to pretty print
@@ -1142,6 +1191,8 @@ def pprint(obj, oids=None, big_blobs=False):
     :param big_blobs: if large binary objects are met (like OctetString
                       values), do we need to print them too, on separate
                       lines
+    :param with_colours: colourize output, if ``termcolor`` library
+                         is available
     """
     def _pprint_pps(pps):
         for pp in pps:
@@ -1152,11 +1203,18 @@ def pprint(obj, oids=None, big_blobs=False):
                         oids=oids,
                         with_offsets=True,
                         with_blob=False,
+                        with_colours=with_colours,
                     )
                     for row in pp_console_blob(pp):
                         yield row
                 else:
-                    yield pp_console_row(pp, oids=oids, with_offsets=True)
+                    yield pp_console_row(
+                        pp,
+                        oids=oids,
+                        with_offsets=True,
+                        with_blob=True,
+                        with_colours=with_colours,
+                    )
             else:
                 for row in _pprint_pps(pp):
                     yield row
@@ -1708,12 +1766,12 @@ class BitString(Obj):
 
         class KeyUsage(BitString):
             schema = (
-                ('digitalSignature', 0),
-                ('nonRepudiation', 1),
-                ('keyEncipherment', 2),
+                ("digitalSignature", 0),
+                ("nonRepudiation", 1),
+                ("keyEncipherment", 2),
             )
 
-    >>> b = KeyUsage(('keyEncipherment', 'nonRepudiation'))
+    >>> b = KeyUsage(("keyEncipherment", "nonRepudiation"))
     KeyUsage BIT STRING 3 bits nonRepudiation, keyEncipherment
     >>> b.named
     ['nonRepudiation', 'keyEncipherment']
@@ -1957,7 +2015,7 @@ class BitString(Obj):
                 decode_path=decode_path,
                 offset=offset,
             )
-        if byte2int(v[-1:]) & ((1 << pad_size) - 1) != 0:
+        if byte2int(v[l - 1:l]) & ((1 << pad_size) - 1) != 0:
             raise DecodeError(
                 "invalid pad",
                 klass=self.__class__,
@@ -2009,7 +2067,7 @@ class BitString(Obj):
         defined_by, defined = self.defined or (None, None)
         if defined_by is not None:
             yield defined.pps(
-                decode_path=decode_path + (decode_path_defby(defined_by),)
+                decode_path=decode_path + (DecodePathDefBy(defined_by),)
             )
 
 
@@ -2198,6 +2256,13 @@ class OctetString(Obj):
                 optional=self.optional,
                 _decoded=(offset, llen, l),
             )
+        except DecodeError as err:
+            raise DecodeError(
+                msg=err.msg,
+                klass=self.__class__,
+                decode_path=decode_path,
+                offset=offset,
+            )
         except BoundsError as err:
             raise DecodeError(
                 msg=str(err),
@@ -2233,7 +2298,7 @@ class OctetString(Obj):
         defined_by, defined = self.defined or (None, None)
         if defined_by is not None:
             yield defined.pps(
-                decode_path=decode_path + (decode_path_defby(defined_by),)
+                decode_path=decode_path + (DecodePathDefBy(defined_by),)
             )
 
 
@@ -2754,7 +2819,7 @@ class CommonString(OctetString):
 
     >>> PrintableString("привет мир")
     Traceback (most recent call last):
-    UnicodeEncodeError: 'ascii' codec can't encode characters in position 0-5: ordinal not in range(128)
+    pyderasn.DecodeError: 'ascii' codec can't encode characters in position 0-5: ordinal not in range(128)
 
     >>> BMPString("ада", bounds=(2, 2))
     Traceback (most recent call last):
@@ -2810,14 +2875,17 @@ class CommonString(OctetString):
             value_raw = value
         else:
             raise InvalidValueType((self.__class__, text_type, binary_type))
-        value_raw = (
-            value_decoded.encode(self.encoding)
-            if value_raw is None else value_raw
-        )
-        value_decoded = (
-            value_raw.decode(self.encoding)
-            if value_decoded is None else value_decoded
-        )
+        try:
+            value_raw = (
+                value_decoded.encode(self.encoding)
+                if value_raw is None else value_raw
+            )
+            value_decoded = (
+                value_raw.decode(self.encoding)
+                if value_decoded is None else value_decoded
+            )
+        except (UnicodeEncodeError, UnicodeDecodeError) as err:
+            raise DecodeError(str(err))
         if not self._bound_min <= len(value_decoded) <= self._bound_max:
             raise BoundsError(
                 self._bound_min,
@@ -2879,6 +2947,13 @@ class NumericString(CommonString):
     tag_default = tag_encode(18)
     encoding = "ascii"
     asn1_type_name = "NumericString"
+    allowable_chars = set(digits.encode("ascii"))
+
+    def _value_sanitize(self, value):
+        value = super(NumericString, self)._value_sanitize(value)
+        if not set(value) <= self.allowable_chars:
+            raise DecodeError("non-numeric value")
+        return value
 
 
 class PrintableString(CommonString):
@@ -3153,8 +3228,8 @@ class Choice(Obj):
 
         class GeneralName(Choice):
             schema = (
-                ('rfc822Name', IA5String(impl=tag_ctxp(1))),
-                ('dNSName', IA5String(impl=tag_ctxp(2))),
+                ("rfc822Name", IA5String(impl=tag_ctxp(1))),
+                ("dNSName", IA5String(impl=tag_ctxp(2))),
             )
 
     >>> gn = GeneralName()
@@ -3545,7 +3620,7 @@ class Any(Obj):
         defined_by, defined = self.defined or (None, None)
         if defined_by is not None:
             yield defined.pps(
-                decode_path=decode_path + (decode_path_defby(defined_by),)
+                decode_path=decode_path + (DecodePathDefBy(defined_by),)
             )
 
 
@@ -3642,6 +3717,8 @@ class Sequence(Obj):
     >>> tbs = TBSCertificate()
     >>> tbs["version"] = Version("v2") # no need to explicitly add ``expl``
 
+    Assign ``None`` to remove value from sequence.
+
     You can know if value exists/set in the sequence and take its value:
 
     >>> "extnID" in ext, "extnValue" in ext, "critical" in ext
@@ -3869,11 +3946,14 @@ class Sequence(Obj):
                     for i, _value in enumerate(value):
                         sub_sub_decode_path = sub_decode_path + (
                             str(i),
-                            decode_path_defby(defined_by),
+                            DecodePathDefBy(defined_by),
                         )
                         defined_value, defined_tail = defined_spec.decode(
                             memoryview(bytes(_value)),
-                            sub_offset + value.tlen + value.llen,
+                            sub_offset + (
+                                (value.tlen + value.llen + value.expl_tlen + value.expl_llen)
+                                if value.expled else (value.tlen + value.llen)
+                            ),
                             leavemm=True,
                             decode_path=sub_sub_decode_path,
                             ctx=ctx,
@@ -3889,16 +3969,19 @@ class Sequence(Obj):
                 else:
                     defined_value, defined_tail = defined_spec.decode(
                         memoryview(bytes(value)),
-                        sub_offset + value.tlen + value.llen,
+                        sub_offset + (
+                            (value.tlen + value.llen + value.expl_tlen + value.expl_llen)
+                            if value.expled else (value.tlen + value.llen)
+                        ),
                         leavemm=True,
-                        decode_path=sub_decode_path + (decode_path_defby(defined_by),),
+                        decode_path=sub_decode_path + (DecodePathDefBy(defined_by),),
                         ctx=ctx,
                     )
                     if len(defined_tail) > 0:
                         raise DecodeError(
                             "remaining data",
                             klass=self.__class__,
-                            decode_path=sub_decode_path + (decode_path_defby(defined_by),),
+                            decode_path=sub_decode_path + (DecodePathDefBy(defined_by),),
                             offset=offset,
                         )
                     value.defined = (defined_by, defined_value)
@@ -4389,7 +4472,7 @@ def generic_decoder():  # pragma: no cover
         __slots__ = ()
         schema = choice
 
-    def pprint_any(obj, oids=None):
+    def pprint_any(obj, oids=None, with_colours=False):
         def _pprint_pps(pps):
             for pp in pps:
                 if hasattr(pp, "_fields"):
@@ -4403,6 +4486,7 @@ def generic_decoder():  # pragma: no cover
                         oids=oids,
                         with_offsets=True,
                         with_blob=False,
+                        with_colours=with_colours,
                     )
                     for row in pp_console_blob(pp):
                         yield row
@@ -4457,7 +4541,11 @@ def main():  # pragma: no cover
             {"defines_by_path": obj_by_path(args.defines_by_path)}
         ),
     )
-    print(pprinter(obj, oids=oids))
+    print(pprinter(
+        obj,
+        oids=oids,
+        with_colours=True if environ.get("NO_COLOR") is None else False,
+    ))
     if tail != b"":
         print("\nTrailing data: %s" % hexenc(tail))