]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - pyderasn.py
Indexed comparison slightly faster than .startswith()
[pyderasn.git] / pyderasn.py
index 5adc5b068d17915f0c543144b66557b9fc355f30..39f2d885062a49f4e59b630e1305ffa21fede80b 100755 (executable)
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 # coding: utf-8
+# cython: language_level=3
 # PyDERASN -- Python ASN.1 DER/BER codec with abstract structures
 # Copyright (C) 2017-2020 Sergey Matveev <stargrave@stargrave.org>
 #
@@ -340,7 +341,8 @@ Let's parse that output, human::
  Only applicable to BER encoded data. If object has BER-specific
  encoding, then ``BER`` will be shown. It does not depend on indefinite
  length encoding. ``EOC``, ``BOOLEAN``, ``BIT STRING``, ``OCTET STRING``
- (and its derivatives), ``SET``, ``SET OF`` could be BERed.
+ (and its derivatives), ``SET``, ``SET OF``, ``UTCTime``, ``GeneralizedTime``
+ could be BERed.
 
 
 .. _definedby:
@@ -490,8 +492,8 @@ constructed primitive types should be parsed successfully.
 
 * If object is encoded in BER form (not the DER one), then ``ber_encoded``
   attribute is set to True. Only ``BOOLEAN``, ``BIT STRING``, ``OCTET
-  STRING``, ``OBJECT IDENTIFIER``, ``SEQUENCE``, ``SET``, ``SET OF``
-  can contain it.
+  STRING``, ``OBJECT IDENTIFIER``, ``SEQUENCE``, ``SET``, ``SET OF``,
+  ``UTCTime``, ``GeneralizedTime`` can contain it.
 * If object has an indefinite length encoding, then its ``lenindef``
   attribute is set to True. Only ``BIT STRING``, ``OCTET STRING``,
   ``SEQUENCE``, ``SET``, ``SEQUENCE OF``, ``SET OF``, ``ANY`` can
@@ -659,6 +661,7 @@ from collections import namedtuple
 from collections import OrderedDict
 from copy import copy
 from datetime import datetime
+from datetime import timedelta
 from math import ceil
 from os import environ
 from string import ascii_letters
@@ -687,7 +690,7 @@ except ImportError:  # pragma: no cover
     def colored(what, *args, **kwargs):
         return what
 
-__version__ = "6.0"
+__version__ = "6.1"
 
 __all__ = (
     "Any",
@@ -761,6 +764,7 @@ EOC = b"\x00\x00"
 EOC_LEN = len(EOC)
 LENINDEF = b"\x80"  # length indefinite mark
 LENINDEF_PP_CHAR = "I" if PY2 else "∞"
+NAMEDTUPLE_KWARGS = {} if PY2 else {"module": __name__}
 
 
 ########################################################################
@@ -1477,7 +1481,7 @@ PP = namedtuple("PP", (
     "lenindef",
     "ber_encoded",
     "bered",
-))
+), **NAMEDTUPLE_KWARGS)
 
 
 def _pp(
@@ -1737,7 +1741,7 @@ BooleanState = namedtuple("BooleanState", (
     "expl_lenindef",
     "lenindef",
     "ber_encoded",
-))
+), **NAMEDTUPLE_KWARGS)
 
 
 class Boolean(Obj):
@@ -1981,7 +1985,7 @@ IntegerState = namedtuple("IntegerState", (
     "expl_lenindef",
     "lenindef",
     "ber_encoded",
-))
+), **NAMEDTUPLE_KWARGS)
 
 
 class Integer(Obj):
@@ -2351,7 +2355,7 @@ BitStringState = namedtuple("BitStringState", (
     "ber_encoded",
     "tag_constructed",
     "defined",
-))
+), **NAMEDTUPLE_KWARGS)
 
 
 class BitString(Obj):
@@ -2862,7 +2866,7 @@ OctetStringState = namedtuple("OctetStringState", (
     "ber_encoded",
     "tag_constructed",
     "defined",
-))
+), **NAMEDTUPLE_KWARGS)
 
 
 class OctetString(Obj):
@@ -2902,6 +2906,7 @@ class OctetString(Obj):
             default=None,
             optional=False,
             _decoded=(0, 0, 0),
+            ctx=None,
     ):
         """
         :param value: set the value. Either binary type, or
@@ -3039,7 +3044,7 @@ class OctetString(Obj):
             self._value,
         ))
 
-    def _decode_chunk(self, lv, offset, decode_path):
+    def _decode_chunk(self, lv, offset, decode_path, ctx):
         try:
             l, llen, v = len_decode(lv)
         except DecodeError as err:
@@ -3066,6 +3071,7 @@ class OctetString(Obj):
                 default=self.default,
                 optional=self.optional,
                 _decoded=(offset, llen, l),
+                ctx=ctx,
             )
         except DecodeError as err:
             raise DecodeError(
@@ -3096,7 +3102,7 @@ class OctetString(Obj):
         if t == self.tag:
             if tag_only:
                 return None
-            return self._decode_chunk(lv, offset, decode_path)
+            return self._decode_chunk(lv, offset, decode_path, ctx)
         if t == self.tag_constructed:
             if not ctx.get("bered", False):
                 raise DecodeError(
@@ -3174,6 +3180,7 @@ class OctetString(Obj):
                     default=self.default,
                     optional=self.optional,
                     _decoded=(offset, llen, vlen + (EOC_LEN if lenindef else 0)),
+                    ctx=ctx,
                 )
             except DecodeError as err:
                 raise DecodeError(
@@ -3247,7 +3254,7 @@ NullState = namedtuple("NullState", (
     "expl_lenindef",
     "lenindef",
     "ber_encoded",
-))
+), **NAMEDTUPLE_KWARGS)
 
 
 class Null(Obj):
@@ -3417,7 +3424,14 @@ ObjectIdentifierState = namedtuple("ObjectIdentifierState", (
     "lenindef",
     "ber_encoded",
     "defines",
-))
+), **NAMEDTUPLE_KWARGS)
+
+
+def pureint(value):
+    i = int(value)
+    if (value[0] in "+- ") or (value[-1] == " "):
+        raise ValueError("non-pure integer")
+    return i
 
 
 class ObjectIdentifier(Obj):
@@ -3497,7 +3511,7 @@ class ObjectIdentifier(Obj):
             return value._value
         if isinstance(value, string_types):
             try:
-                value = tuple(int(arc) for arc in value.split("."))
+                value = tuple(pureint(arc) for arc in value.split("."))
             except ValueError:
                 raise InvalidOID("unacceptable arcs values")
         if isinstance(value, tuple):
@@ -3511,6 +3525,8 @@ class ObjectIdentifier(Obj):
                 pass
             else:
                 raise InvalidOID("unacceptable first arc value")
+            if not all(arc >= 0 for arc in value):
+                raise InvalidOID("negative arc value")
             return value
         raise InvalidValueType((self.__class__, str, tuple))
 
@@ -3802,7 +3818,7 @@ class Enumerated(Integer):
 
 
 def escape_control_unicode(c):
-    if unicat(c).startswith("C"):
+    if unicat(c)[0] == "C":
         c = repr(c).lstrip("u").strip("'")
     return c
 
@@ -3995,6 +4011,7 @@ class NumericString(AllowableCharsMixin, CommonString):
 PrintableStringState = namedtuple(
     "PrintableStringState",
     OctetStringState._fields + ("allowable_chars",),
+    **NAMEDTUPLE_KWARGS
 )
 
 
@@ -4029,6 +4046,7 @@ class PrintableString(AllowableCharsMixin, CommonString):
             default=None,
             optional=False,
             _decoded=(0, 0, 0),
+            ctx=None,
             allow_asterisk=False,
             allow_ampersand=False,
     ):
@@ -4041,7 +4059,7 @@ class PrintableString(AllowableCharsMixin, CommonString):
         if allow_ampersand:
             self._allowable_chars |= self._ampersand
         super(PrintableString, self).__init__(
-            value, bounds, impl, expl, default, optional, _decoded,
+            value, bounds, impl, expl, default, optional, _decoded, ctx,
         )
 
     @property
@@ -4127,7 +4145,26 @@ LEN_YYYYMMDDHHMMSSDMZ = len("YYYYMMDDHHMMSSDMZ")
 LEN_YYYYMMDDHHMMSSZ = len("YYYYMMDDHHMMSSZ")
 
 
-class UTCTime(CommonString):
+def fractions2float(fractions_raw):
+    pureint(fractions_raw)
+    return float("0." + fractions_raw)
+
+
+class VisibleString(CommonString):
+    __slots__ = ()
+    tag_default = tag_encode(26)
+    encoding = "ascii"
+    asn1_type_name = "VisibleString"
+
+
+UTCTimeState = namedtuple(
+    "UTCTimeState",
+    OctetStringState._fields + ("ber_raw",),
+    **NAMEDTUPLE_KWARGS
+)
+
+
+class UTCTime(VisibleString):
     """``UTCTime`` datetime type
 
     >>> t = UTCTime(datetime(2017, 9, 30, 22, 7, 50, 123))
@@ -4143,9 +4180,18 @@ class UTCTime(CommonString):
 
     .. warning::
 
-       BER encoding is unsupported.
+       Note that UTCTime can not hold a full year, so two-digit years
+       below 50 are treated as 20xx and all others as 19xx, according
+       to the X.509 recommendation.
+
+    .. warning::
+
+       UTC offsets are not strictly validated, only very crudely:
+
+       * offset minutes must be less than 60
+       * offset value must not exceed 14 hours
     """
-    __slots__ = ()
+    __slots__ = ("_ber_raw",)
     tag_default = tag_encode(23)
     encoding = "ascii"
     asn1_type_name = "UTCTime"
@@ -4159,6 +4205,7 @@ class UTCTime(CommonString):
             optional=False,
             _decoded=(0, 0, 0),
             bounds=None,  # dummy argument, workability for OctetString.decode
+            ctx=None,
     ):
         """
         :param value: set the value. Either datetime type, or
@@ -4169,13 +4216,15 @@ class UTCTime(CommonString):
         :param bool optional: is object ``OPTIONAL`` in sequence
         """
         super(UTCTime, self).__init__(
-            None, None, impl, expl, default, optional, _decoded,
+            None, None, impl, expl, None, optional, _decoded, ctx,
         )
         self._value = value
+        self._ber_raw = None
         if value is not None:
-            self._value = self._value_sanitize(value)
+            self._value, self._ber_raw = self._value_sanitize(value, ctx)
+            self.ber_encoded = self._ber_raw is not None
         if default is not None:
-            default = self._value_sanitize(default)
+            default, _ = self._value_sanitize(default)
             self.default = self.__class__(
                 value=default,
                 impl=self.tag,
@@ -4183,6 +4232,51 @@ class UTCTime(CommonString):
             )
             if self._value is None:
                 self._value = default
+            optional = True
+        self.optional = optional
+
+    def _strptime_bered(self, value):
+        year = pureint(value[:2])
+        year += 2000 if year < 50 else 1900
+        decoded = datetime(
+            year,  # %Y
+            pureint(value[2:4]),  # %m
+            pureint(value[4:6]),  # %d
+            pureint(value[6:8]),  # %H
+            pureint(value[8:10]),  # %M
+        )
+        value = value[10:]
+        if len(value) == 0:
+            raise ValueError("no timezone")
+        offset = 0
+        if value[-1] == "Z":
+            value = value[:-1]
+        else:
+            if len(value) < 5:
+                raise ValueError("invalid UTC offset")
+            if value[-5] == "-":
+                sign = -1
+            elif value[-5] == "+":
+                sign = 1
+            else:
+                raise ValueError("invalid UTC offset")
+            offset = 60 * pureint(value[-2:])
+            if offset >= 3600:
+                raise ValueError("invalid UTC offset minutes")
+            offset += 3600 * pureint(value[-4:-2])
+            if offset > 14 * 3600:
+                raise ValueError("too big UTC offset")
+            offset *= sign
+            value = value[:-5]
+        if len(value) == 0:
+            return offset, decoded
+        if len(value) != 2:
+            raise ValueError("invalid UTC offset seconds")
+        seconds = pureint(value)
+        if seconds >= 60:
+            raise ValueError("invalid seconds value")
+        decoded += timedelta(seconds=seconds)
+        return offset, decoded
 
     def _strptime(self, value):
         # datetime.strptime's format: %y%m%d%H%M%SZ
@@ -4190,35 +4284,82 @@ class UTCTime(CommonString):
             raise ValueError("invalid UTCTime length")
         if value[-1] != "Z":
             raise ValueError("non UTC timezone")
+        year = pureint(value[:2])
+        year += 2000 if year < 50 else 1900
         return datetime(
-            2000 + int(value[:2]),  # %y
-            int(value[2:4]),  # %m
-            int(value[4:6]),  # %d
-            int(value[6:8]),  # %H
-            int(value[8:10]),  # %M
-            int(value[10:12]),  # %S
+            year,  # %y
+            pureint(value[2:4]),  # %m
+            pureint(value[4:6]),  # %d
+            pureint(value[6:8]),  # %H
+            pureint(value[8:10]),  # %M
+            pureint(value[10:12]),  # %S
         )
 
-    def _value_sanitize(self, value):
+    def _dt_sanitize(self, value):
+        if value.year < 1950 or value.year > 2049:
+            raise ValueError("UTCTime can hold only 1950-2049 years")
+        return value.replace(microsecond=0)
+
+    def _value_sanitize(self, value, ctx=None):
         if isinstance(value, binary_type):
             try:
                 value_decoded = value.decode("ascii")
             except (UnicodeEncodeError, UnicodeDecodeError) as err:
                 raise DecodeError("invalid UTCTime encoding: %r" % err)
+            err = None
             try:
-                self._strptime(value_decoded)
-            except (TypeError, ValueError) as err:
-                raise DecodeError("invalid UTCTime format: %r" % err)
-            return value
+                return self._strptime(value_decoded), None
+            except (TypeError, ValueError) as _err:
+                err = _err
+                if (ctx is not None) and ctx.get("bered", False):
+                    try:
+                        offset, _value = self._strptime_bered(value_decoded)
+                        _value = _value - timedelta(seconds=offset)
+                        return self._dt_sanitize(_value), value_decoded
+                    except (TypeError, ValueError, OverflowError) as _err:
+                        err = _err
+            raise DecodeError(
+                "invalid %s format: %r" % (self.asn1_type_name, err),
+                klass=self.__class__,
+            )
         if isinstance(value, self.__class__):
-            return value._value
+            return value._value, None
         if isinstance(value, datetime):
-            return value.strftime("%y%m%d%H%M%SZ").encode("ascii")
+            return self._dt_sanitize(value), None
         raise InvalidValueType((self.__class__, datetime))
 
+    def _pp_value(self):
+        if self.ready:
+            value = self._value.isoformat()
+            if self.ber_encoded:
+                value += " (%s)" % self._ber_raw
+            return value
+
+    def __unicode__(self):
+        if self.ready:
+            value = self._value.isoformat()
+            if self.ber_encoded:
+                value += " (%s)" % self._ber_raw
+            return value
+        return text_type(self._pp_value())
+
+    def __getstate__(self):
+        return UTCTimeState(
+            *super(UTCTime, self).__getstate__(),
+            **{"ber_raw": self._ber_raw}
+        )
+
+    def __setstate__(self, state):
+        super(UTCTime, self).__setstate__(state)
+        self._ber_raw = state.ber_raw
+
+    def __bytes__(self):
+        self._assert_ready()
+        return self._encode_time()
+
     def __eq__(self, their):
         if isinstance(their, binary_type):
-            return self._value == their
+            return self._encode_time() == their
         if isinstance(their, datetime):
             return self.todatetime() == their
         if not isinstance(their, self.__class__):
@@ -4229,25 +4370,16 @@ class UTCTime(CommonString):
             self._expl == their._expl
         )
 
-    def todatetime(self):
-        """Convert to datetime
+    def _encode_time(self):
+        return self._value.strftime("%y%m%d%H%M%SZ").encode("ascii")
 
-        :returns: datetime
+    def _encode(self):
+        self._assert_ready()
+        value = self._encode_time()
+        return b"".join((self.tag, len_encode(len(value)), value))
 
-        Pay attention that UTCTime can not hold full year, so all years
-        having < 50 years are treated as 20xx, 19xx otherwise, according
-        to X.509 recomendation.
-        """
-        value = self._strptime(self._value.decode("ascii"))
-        year = value.year % 100
-        return datetime(
-            year=(2000 + year) if year < 50 else (1900 + year),
-            month=value.month,
-            day=value.day,
-            hour=value.hour,
-            minute=value.minute,
-            second=value.second,
-        )
+    def todatetime(self):
+        return self._value
 
     def __repr__(self):
         return pp_console_row(next(self.pps()))
@@ -4258,7 +4390,7 @@ class UTCTime(CommonString):
             asn1_type_name=self.asn1_type_name,
             obj_name=self.__class__.__name__,
             decode_path=decode_path,
-            value=self.todatetime().isoformat() if self.ready else None,
+            value=self._pp_value(),
             optional=self.optional,
             default=self == self.default,
             impl=None if self.tag == self.tag_default else tag_decode(self.tag),
@@ -4293,13 +4425,19 @@ class GeneralizedTime(UTCTime):
 
     .. warning::
 
-       BER encoding is unsupported.
+       Only microsecond fractions are supported in DER encoding.
+       :py:exc:`pyderasn.DecodeError` will be raised during decoding of
+       higher precision values.
 
     .. warning::
 
-       Only microsecond fractions are supported.
-       :py:exc:`pyderasn.DecodeError` will be raised during decoding of
-       higher precision values.
+       BER encoded data can lose information (accuracy) during decoding
+       because of float transformations.
+
+    .. warning::
+
+       Local times (without explicit timezone specification) are treated
+       as UTC; no transformations are made.
 
     .. warning::
 
@@ -4309,6 +4447,70 @@ class GeneralizedTime(UTCTime):
     tag_default = tag_encode(24)
     asn1_type_name = "GeneralizedTime"
 
+    def _dt_sanitize(self, value):
+        return value
+
+    def _strptime_bered(self, value):
+        if len(value) < 4 + 3 * 2:
+            raise ValueError("invalid GeneralizedTime")
+        decoded = datetime(
+            pureint(value[:4]),  # %Y
+            pureint(value[4:6]),  # %m
+            pureint(value[6:8]),  # %d
+            pureint(value[8:10]),  # %H
+        )
+        value = value[10:]
+        offset = 0
+        if len(value) == 0:
+            return offset, decoded
+        if value[-1] == "Z":
+            value = value[:-1]
+        else:
+            for char, sign in (("-", -1), ("+", 1)):
+                idx = value.rfind(char)
+                if idx == -1:
+                    continue
+                offset_raw = value[idx + 1:].replace(":", "")
+                if len(offset_raw) not in (2, 4):
+                    raise ValueError("invalid UTC offset")
+                value = value[:idx]
+                offset = 60 * pureint(offset_raw[2:] or "0")
+                if offset >= 3600:
+                    raise ValueError("invalid UTC offset minutes")
+                offset += 3600 * pureint(offset_raw[:2])
+                if offset > 14 * 3600:
+                    raise ValueError("too big UTC offset")
+                offset *= sign
+                break
+        if len(value) == 0:
+            return offset, decoded
+        decimal_signs = ".,"
+        if value[0] in decimal_signs:
+            return offset, (
+                decoded + timedelta(seconds=3600 * fractions2float(value[1:]))
+            )
+        if len(value) < 2:
+            raise ValueError("stripped minutes")
+        decoded += timedelta(seconds=60 * pureint(value[:2]))
+        value = value[2:]
+        if len(value) == 0:
+            return offset, decoded
+        if value[0] in decimal_signs:
+            return offset, (
+                decoded + timedelta(seconds=60 * fractions2float(value[1:]))
+            )
+        if len(value) < 2:
+            raise ValueError("stripped seconds")
+        decoded += timedelta(seconds=pureint(value[:2]))
+        value = value[2:]
+        if len(value) == 0:
+            return offset, decoded
+        if value[0] not in decimal_signs:
+            raise ValueError("invalid format after seconds")
+        return offset, (
+            decoded + timedelta(microseconds=10**6 * fractions2float(value[1:]))
+        )
+
     def _strptime(self, value):
         l = len(value)
         if l == LEN_YYYYMMDDHHMMSSZ:
@@ -4316,12 +4518,12 @@ class GeneralizedTime(UTCTime):
             if value[-1] != "Z":
                 raise ValueError("non UTC timezone")
             return datetime(
-                int(value[:4]),  # %Y
-                int(value[4:6]),  # %m
-                int(value[6:8]),  # %d
-                int(value[8:10]),  # %H
-                int(value[10:12]),  # %M
-                int(value[12:14]),  # %S
+                pureint(value[:4]),  # %Y
+                pureint(value[4:6]),  # %m
+                pureint(value[6:8]),  # %d
+                pureint(value[8:10]),  # %H
+                pureint(value[10:12]),  # %M
+                pureint(value[12:14]),  # %S
             )
         if l >= LEN_YYYYMMDDHHMMSSDMZ:
             # datetime.strptime's format: %Y%m%d%H%M%S.%fZ
@@ -4335,44 +4537,25 @@ class GeneralizedTime(UTCTime):
             us_len = len(us)
             if us_len > 6:
                 raise ValueError("only microsecond fractions are supported")
-            us = int(us + ("0" * (6 - us_len)))
+            us = pureint(us + ("0" * (6 - us_len)))
             decoded = datetime(
-                int(value[:4]),  # %Y
-                int(value[4:6]),  # %m
-                int(value[6:8]),  # %d
-                int(value[8:10]),  # %H
-                int(value[10:12]),  # %M
-                int(value[12:14]),  # %S
+                pureint(value[:4]),  # %Y
+                pureint(value[4:6]),  # %m
+                pureint(value[6:8]),  # %d
+                pureint(value[8:10]),  # %H
+                pureint(value[10:12]),  # %M
+                pureint(value[12:14]),  # %S
                 us,  # %f
             )
             return decoded
         raise ValueError("invalid GeneralizedTime length")
 
-    def _value_sanitize(self, value):
-        if isinstance(value, binary_type):
-            try:
-                value_decoded = value.decode("ascii")
-            except (UnicodeEncodeError, UnicodeDecodeError) as err:
-                raise DecodeError("invalid GeneralizedTime encoding: %r" % err)
-            try:
-                self._strptime(value_decoded)
-            except (TypeError, ValueError) as err:
-                raise DecodeError(
-                    "invalid GeneralizedTime format: %r" % err,
-                    klass=self.__class__,
-                )
-            return value
-        if isinstance(value, self.__class__):
-            return value._value
-        if isinstance(value, datetime):
-            encoded = value.strftime("%Y%m%d%H%M%S")
-            if value.microsecond > 0:
-                encoded = encoded + (".%06d" % value.microsecond).rstrip("0")
-            return (encoded + "Z").encode("ascii")
-        raise InvalidValueType((self.__class__, datetime))
-
-    def todatetime(self):
-        return self._strptime(self._value.decode("ascii"))
+    def _encode_time(self):
+        value = self._value
+        encoded = value.strftime("%Y%m%d%H%M%S")
+        if value.microsecond > 0:
+            encoded += (".%06d" % value.microsecond).rstrip("0")
+        return (encoded + "Z").encode("ascii")
 
 
 class GraphicString(CommonString):
@@ -4382,13 +4565,6 @@ class GraphicString(CommonString):
     asn1_type_name = "GraphicString"
 
 
-class VisibleString(CommonString):
-    __slots__ = ()
-    tag_default = tag_encode(26)
-    encoding = "ascii"
-    asn1_type_name = "VisibleString"
-
-
 class ISO646String(VisibleString):
     __slots__ = ()
     asn1_type_name = "ISO646String"
@@ -4429,7 +4605,7 @@ ChoiceState = namedtuple("ChoiceState", (
     "expl_lenindef",
     "lenindef",
     "ber_encoded",
-))
+), **NAMEDTUPLE_KWARGS)
 
 
 class Choice(Obj):
@@ -4747,7 +4923,7 @@ AnyState = namedtuple("AnyState", (
     "lenindef",
     "ber_encoded",
     "defined",
-))
+), **NAMEDTUPLE_KWARGS)
 
 
 class Any(Obj):
@@ -5026,7 +5202,7 @@ SequenceState = namedtuple("SequenceState", (
     "expl_lenindef",
     "lenindef",
     "ber_encoded",
-))
+), **NAMEDTUPLE_KWARGS)
 
 
 class Sequence(Obj):
@@ -5691,7 +5867,7 @@ SequenceOfState = namedtuple("SequenceOfState", (
     "expl_lenindef",
     "lenindef",
     "ber_encoded",
-))
+), **NAMEDTUPLE_KWARGS)
 
 
 class SequenceOf(Obj):