]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - pyderasn.py
Correct permissions in tarball
[pyderasn.git] / pyderasn.py
index 68dd21339f62e6801e88788215eb19fcda613fb7..8208efeb024148e54e8c6e71ca369383b393f273 100755 (executable)
@@ -1,7 +1,9 @@
 #!/usr/bin/env python
 # coding: utf-8
 # cython: language_level=3
-# PyDERASN -- Python ASN.1 DER/BER codec with abstract structures
+# pylint: disable=line-too-long,superfluous-parens,protected-access,too-many-lines
+# pylint: disable=too-many-return-statements,too-many-branches,too-many-statements
+# PyDERASN -- Python ASN.1 DER/CER/BER codec with abstract structures
 # Copyright (C) 2017-2020 Sergey Matveev <stargrave@stargrave.org>
 #
 # This program is free software: you can redistribute it and/or modify
@@ -1084,6 +1086,7 @@ Now you can print only the specified tree, for example signature algorithm::
                          . . 05:00
 """
 
+from array import array
 from codecs import getdecoder
 from codecs import getencoder
 from collections import namedtuple
@@ -1123,7 +1126,7 @@ except ImportError:  # pragma: no cover
     def colored(what, *args, **kwargs):
         return what
 
-__version__ = "7.0"
+__version__ = "7.2"
 
 __all__ = (
     "agg_octet_string",
@@ -1517,6 +1520,8 @@ def len_decode(data):
     return l, 1 + octets_num, data[1 + octets_num:]
 
 
+LEN0 = len_encode(0)
+LEN1 = len_encode(1)
 LEN1K = len_encode(1000)
 
 
@@ -2457,11 +2462,7 @@ class Boolean(Obj):
 
     def _encode(self):
         self._assert_ready()
-        return b"".join((
-            self.tag,
-            len_encode(1),
-            (b"\xFF" if self._value else b"\x00"),
-        ))
+        return b"".join((self.tag, LEN1, (b"\xFF" if self._value else b"\x00")))
 
     def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode):
         try:
@@ -2703,11 +2704,11 @@ class Integer(Obj):
 
     def __hash__(self):
         self._assert_ready()
-        return hash(
-            self.tag +
-            bytes(self._expl or b"") +
+        return hash(b"".join((
+            self.tag,
+            bytes(self._expl or b""),
             str(self._value).encode("ascii"),
-        )
+        )))
 
     def __eq__(self, their):
         if isinstance(their, integer_types):
@@ -3984,7 +3985,7 @@ class Null(Obj):
         )
 
     def _encode(self):
-        return self.tag + len_encode(0)
+        return self.tag + LEN0
 
     def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode):
         try:
@@ -4130,7 +4131,7 @@ class ObjectIdentifier(Obj):
 
     def __add__(self, their):
         if their.__class__ == tuple:
-            return self.__class__(self._value + their)
+            return self.__class__(self._value + array("L", their))
         if isinstance(their, self.__class__):
             return self.__class__(self._value + their._value)
         raise InvalidValueType((self.__class__, tuple))
@@ -4140,10 +4141,15 @@ class ObjectIdentifier(Obj):
             return value._value
         if isinstance(value, string_types):
             try:
-                value = tuple(pureint(arc) for arc in value.split("."))
+                value = array("L", (pureint(arc) for arc in value.split(".")))
             except ValueError:
                 raise InvalidOID("unacceptable arcs values")
         if value.__class__ == tuple:
+            try:
+                value = array("L", value)
+            except OverflowError as err:
+                raise InvalidOID(repr(err))
+        if value.__class__ is array:
             if len(value) < 2:
                 raise InvalidOID("less than 2 arcs")
             first_arc = value[0]
@@ -4195,15 +4201,15 @@ class ObjectIdentifier(Obj):
 
     def __hash__(self):
         self._assert_ready()
-        return hash(
-            self.tag +
-            bytes(self._expl or b"") +
+        return hash(b"".join((
+            self.tag,
+            bytes(self._expl or b""),
             str(self._value).encode("ascii"),
-        )
+        )))
 
     def __eq__(self, their):
         if their.__class__ == tuple:
-            return self._value == their
+            return self._value == array("L", their)
         if not issubclass(their.__class__, ObjectIdentifier):
             return False
         return (
@@ -4295,7 +4301,7 @@ class ObjectIdentifier(Obj):
                 offset=offset,
             )
         v, tail = v[:l], v[l:]
-        arcs = []
+        arcs = array("L")
         ber_encoded = False
         while len(v) > 0:
             i = 0
@@ -4306,10 +4312,23 @@ class ObjectIdentifier(Obj):
                     if ctx.get("bered", False):
                         ber_encoded = True
                     else:
-                        raise DecodeError("non normalized arc encoding")
+                        raise DecodeError(
+                            "non normalized arc encoding",
+                            klass=self.__class__,
+                            decode_path=decode_path,
+                            offset=offset,
+                        )
                 arc = (arc << 7) | (octet & 0x7F)
                 if octet & 0x80 == 0:
-                    arcs.append(arc)
+                    try:
+                        arcs.append(arc)
+                    except OverflowError:
+                        raise DecodeError(
+                            "too huge value for local unsigned long",
+                            klass=self.__class__,
+                            decode_path=decode_path,
+                            offset=offset,
+                        )
                     v = v[i + 1:]
                     break
                 i += 1
@@ -4331,7 +4350,7 @@ class ObjectIdentifier(Obj):
             first_arc = 2
             second_arc -= 80
         obj = self.__class__(
-            value=tuple([first_arc, second_arc] + arcs[1:]),
+            value=array("L", (first_arc, second_arc)) + arcs[1:],
             impl=self.tag,
             expl=self._expl,
             default=self.default,
@@ -4762,8 +4781,11 @@ class IA5String(CommonString):
 
 
 LEN_YYMMDDHHMMSSZ = len("YYMMDDHHMMSSZ")
+LEN_LEN_YYMMDDHHMMSSZ = len_encode(LEN_YYMMDDHHMMSSZ)
+LEN_YYMMDDHHMMSSZ_WITH_LEN = len(LEN_LEN_YYMMDDHHMMSSZ) + LEN_YYMMDDHHMMSSZ
 LEN_YYYYMMDDHHMMSSDMZ = len("YYYYMMDDHHMMSSDMZ")
 LEN_YYYYMMDDHHMMSSZ = len("YYYYMMDDHHMMSSZ")
+LEN_LEN_YYYYMMDDHHMMSSZ = len_encode(LEN_YYYYMMDDHHMMSSZ)
 
 
 class VisibleString(CommonString):
@@ -4951,6 +4973,7 @@ class UTCTime(VisibleString):
             if self.ber_encoded:
                 value += " (%s)" % self.ber_raw
             return value
+        return None
 
     def __unicode__(self):
         if self.ready:
@@ -4992,8 +5015,7 @@ class UTCTime(VisibleString):
 
     def _encode(self):
         self._assert_ready()
-        value = self._encode_time()
-        return b"".join((self.tag, len_encode(len(value)), value))
+        return b"".join((self.tag, LEN_LEN_YYMMDDHHMMSSZ, self._encode_time()))
 
     def _encode_cer(self, writer):
         write_full(writer, self._encode())
@@ -5159,6 +5181,14 @@ class GeneralizedTime(UTCTime):
             encoded += (".%06d" % value.microsecond).rstrip("0")
         return (encoded + "Z").encode("ascii")
 
+    def _encode(self):
+        self._assert_ready()
+        value = self._value
+        if value.microsecond > 0:
+            encoded = self._encode_time()
+            return b"".join((self.tag, len_encode(len(encoded)), encoded))
+        return b"".join((self.tag, LEN_LEN_YYYYMMDDHHMMSSZ, self._encode_time()))
+
 
 class GraphicString(CommonString):
     __slots__ = ()
@@ -5797,19 +5827,6 @@ class Any(Obj):
 # ASN.1 constructed types
 ########################################################################
 
-def get_def_by_path(defines_by_path, sub_decode_path):
-    """Get define by decode path
-    """
-    for path, define in defines_by_path:
-        if len(path) != len(sub_decode_path):
-            continue
-        for p1, p2 in zip(path, sub_decode_path):
-            if (not p1 is any) and (p1 != p2):
-                break
-        else:
-            return define
-
-
 def abs_decode_path(decode_path, rel_path):
     """Create an absolute decode path from current and relative ones
 
@@ -6350,17 +6367,16 @@ class Set(Sequence):
     tag_default = tag_encode(form=TagFormConstructed, num=17)
     asn1_type_name = "SET"
 
-    def _encode(self):
-        v = b"".join(value.encode() for value in sorted(
-            self._values_for_encoding(),
+    def _values_for_encoding(self):
+        return sorted(
+            super(Set, self)._values_for_encoding(),
             key=attrgetter("tag_order"),
-        ))
-        return b"".join((self.tag, len_encode(len(v)), v))
+        )
 
     def _encode_cer(self, writer):
         write_full(writer, self.tag + LENINDEF)
         for v in sorted(
-                self._values_for_encoding(),
+                super(Set, self)._values_for_encoding(),
                 key=attrgetter("tag_order_cer"),
         ):
             v.encode_cer(writer)
@@ -6625,7 +6641,6 @@ class SequenceOf(Obj):
             value = value._value
         elif hasattr(value, NEXT_ATTR_NAME):
             iterator = True
-            value = value
         elif hasattr(value, "__iter__"):
             value = list(value)
         else:
@@ -7048,6 +7063,7 @@ def generic_decoder():  # pragma: no cover
             with_colours=False,
             with_decode_path=False,
             decode_path_only=(),
+            decode_path=(),
     ):
         def _pprint_pps(pps):
             for pp in pps:
@@ -7079,13 +7095,13 @@ def generic_decoder():  # pragma: no cover
                 else:
                     for row in _pprint_pps(pp):
                         yield row
-        return "\n".join(_pprint_pps(obj.pps()))
+        return "\n".join(_pprint_pps(obj.pps(decode_path)))
     return SEQUENCEOF(), pprint_any
 
 
 def main():  # pragma: no cover
     import argparse
-    parser = argparse.ArgumentParser(description="PyDERASN ASN.1 BER/DER decoder")
+    parser = argparse.ArgumentParser(description="PyDERASN ASN.1 BER/CER/DER decoder")
     parser.add_argument(
         "--skip",
         type=int,
@@ -7144,9 +7160,9 @@ def main():  # pragma: no cover
         [obj_by_path(_path) for _path in (args.oids or "").split(",")]
         if args.oids else ()
     )
+    from functools import partial
     if args.schema:
         schema = obj_by_path(args.schema)
-        from functools import partial
         pprinter = partial(pprint, big_blobs=True)
     else:
         schema, pprinter = generic_decoder()