]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - pyderasn.py
*Time BER decoding support
[pyderasn.git] / pyderasn.py
index 78080294ec422a7a636503c4cd96e20892764f54..0d3a6637840115c15fdcf90bc898f095f0c640f7 100755 (executable)
@@ -340,7 +340,8 @@ Let's parse that output, human::
  Only applicable to BER encoded data. If object has BER-specific
  encoding, then ``BER`` will be shown. It does not depend on indefinite
  length encoding. ``EOC``, ``BOOLEAN``, ``BIT STRING``, ``OCTET STRING``
- (and its derivatives), ``SET``, ``SET OF`` could be BERed.
+ (and its derivatives), ``SET``, ``SET OF``, ``UTCTime``, ``GeneralizedTime``
+ could be BERed.
 
 
 .. _definedby:
@@ -490,8 +491,8 @@ constructed primitive types should be parsed successfully.
 
 * If object is encoded in BER form (not the DER one), then ``ber_encoded``
   attribute is set to True. Only ``BOOLEAN``, ``BIT STRING``, ``OCTET
-  STRING``, ``OBJECT IDENTIFIER``, ``SEQUENCE``, ``SET``, ``SET OF``
-  can contain it.
+  STRING``, ``OBJECT IDENTIFIER``, ``SEQUENCE``, ``SET``, ``SET OF``,
+  ``UTCTime``, ``GeneralizedTime`` can contain it.
 * If object has an indefinite length encoding, then its ``lenindef``
   attribute is set to True. Only ``BIT STRING``, ``OCTET STRING``,
   ``SEQUENCE``, ``SET``, ``SEQUENCE OF``, ``SET OF``, ``ANY`` can
@@ -659,6 +660,7 @@ from collections import namedtuple
 from collections import OrderedDict
 from copy import copy
 from datetime import datetime
+from datetime import timedelta
 from math import ceil
 from os import environ
 from string import ascii_letters
@@ -2902,6 +2904,7 @@ class OctetString(Obj):
             default=None,
             optional=False,
             _decoded=(0, 0, 0),
+            ctx=None,
     ):
         """
         :param value: set the value. Either binary type, or
@@ -3039,7 +3042,7 @@ class OctetString(Obj):
             self._value,
         ))
 
-    def _decode_chunk(self, lv, offset, decode_path):
+    def _decode_chunk(self, lv, offset, decode_path, ctx):
         try:
             l, llen, v = len_decode(lv)
         except DecodeError as err:
@@ -3066,6 +3069,7 @@ class OctetString(Obj):
                 default=self.default,
                 optional=self.optional,
                 _decoded=(offset, llen, l),
+                ctx=ctx,
             )
         except DecodeError as err:
             raise DecodeError(
@@ -3096,7 +3100,7 @@ class OctetString(Obj):
         if t == self.tag:
             if tag_only:
                 return None
-            return self._decode_chunk(lv, offset, decode_path)
+            return self._decode_chunk(lv, offset, decode_path, ctx)
         if t == self.tag_constructed:
             if not ctx.get("bered", False):
                 raise DecodeError(
@@ -3174,6 +3178,7 @@ class OctetString(Obj):
                     default=self.default,
                     optional=self.optional,
                     _decoded=(offset, llen, vlen + (EOC_LEN if lenindef else 0)),
+                    ctx=ctx,
                 )
             except DecodeError as err:
                 raise DecodeError(
@@ -4038,6 +4043,7 @@ class PrintableString(AllowableCharsMixin, CommonString):
             default=None,
             optional=False,
             _decoded=(0, 0, 0),
+            ctx=None,
             allow_asterisk=False,
             allow_ampersand=False,
     ):
@@ -4050,7 +4056,7 @@ class PrintableString(AllowableCharsMixin, CommonString):
         if allow_ampersand:
             self._allowable_chars |= self._ampersand
         super(PrintableString, self).__init__(
-            value, bounds, impl, expl, default, optional, _decoded,
+            value, bounds, impl, expl, default, optional, _decoded, ctx,
         )
 
     @property
@@ -4136,6 +4142,11 @@ LEN_YYYYMMDDHHMMSSDMZ = len("YYYYMMDDHHMMSSDMZ")
 LEN_YYYYMMDDHHMMSSZ = len("YYYYMMDDHHMMSSZ")
 
 
+def fractions2float(fractions_raw):
+    pureint(fractions_raw)
+    return float("0." + fractions_raw)
+
+
 class VisibleString(CommonString):
     __slots__ = ()
     tag_default = tag_encode(26)
@@ -4143,6 +4154,9 @@ class VisibleString(CommonString):
     asn1_type_name = "VisibleString"
 
 
+UTCTimeState = namedtuple("UTCTimeState", OctetStringState._fields + ("ber_raw",))
+
+
 class UTCTime(VisibleString):
     """``UTCTime`` datetime type
 
@@ -4159,9 +4173,18 @@ class UTCTime(VisibleString):
 
     .. warning::
 
-       BER encoding is unsupported.
+       Pay attention that UTCTime can not hold full year, so all years
+       having < 50 years are treated as 20xx, 19xx otherwise, according
+       to X.509 recommendation.
+
+    .. warning::
+
+       No strict validation of UTC offsets are made, but very crude:
+
+       * minutes are not exceeding 60
+       * offset value is not exceeding 14 hours
     """
-    __slots__ = ()
+    __slots__ = ("_ber_raw",)
     tag_default = tag_encode(23)
     encoding = "ascii"
     asn1_type_name = "UTCTime"
@@ -4175,6 +4198,7 @@ class UTCTime(VisibleString):
             optional=False,
             _decoded=(0, 0, 0),
             bounds=None,  # dummy argument, workability for OctetString.decode
+            ctx=None,
     ):
         """
         :param value: set the value. Either datetime type, or
@@ -4185,13 +4209,15 @@ class UTCTime(VisibleString):
         :param bool optional: is object ``OPTIONAL`` in sequence
         """
         super(UTCTime, self).__init__(
-            None, None, impl, expl, default, optional, _decoded,
+            None, None, impl, expl, None, optional, _decoded, ctx,
         )
         self._value = value
+        self._ber_raw = None
         if value is not None:
-            self._value = self._value_sanitize(value)
+            self._value, self._ber_raw = self._value_sanitize(value, ctx)
+            self.ber_encoded = self._ber_raw is not None
         if default is not None:
-            default = self._value_sanitize(default)
+            default, _ = self._value_sanitize(default)
             self.default = self.__class__(
                 value=default,
                 impl=self.tag,
@@ -4199,6 +4225,51 @@ class UTCTime(VisibleString):
             )
             if self._value is None:
                 self._value = default
+            optional = True
+        self.optional = optional
+
+    def _strptime_bered(self, value):
+        year = pureint(value[:2])
+        year += 2000 if year < 50 else 1900
+        decoded = datetime(
+            year,  # %Y
+            pureint(value[2:4]),  # %m
+            pureint(value[4:6]),  # %d
+            pureint(value[6:8]),  # %H
+            pureint(value[8:10]),  # %M
+        )
+        value = value[10:]
+        if len(value) == 0:
+            raise ValueError("no timezone")
+        offset = 0
+        if value[-1] == "Z":
+            value = value[:-1]
+        else:
+            if len(value) < 5:
+                raise ValueError("invalid UTC offset")
+            if value[-5] == "-":
+                sign = -1
+            elif value[-5] == "+":
+                sign = 1
+            else:
+                raise ValueError("invalid UTC offset")
+            offset = 60 * pureint(value[-2:])
+            if offset >= 3600:
+                raise ValueError("invalid UTC offset minutes")
+            offset += 3600 * pureint(value[-4:-2])
+            if offset > 14 * 3600:
+                raise ValueError("too big UTC offset")
+            offset *= sign
+            value = value[:-5]
+        if len(value) == 0:
+            return offset, decoded
+        if len(value) != 2:
+            raise ValueError("invalid UTC offset seconds")
+        seconds = pureint(value)
+        if seconds >= 60:
+            raise ValueError("invalid seconds value")
+        decoded += timedelta(seconds=seconds)
+        return offset, decoded
 
     def _strptime(self, value):
         # datetime.strptime's format: %y%m%d%H%M%SZ
@@ -4206,8 +4277,10 @@ class UTCTime(VisibleString):
             raise ValueError("invalid UTCTime length")
         if value[-1] != "Z":
             raise ValueError("non UTC timezone")
+        year = pureint(value[:2])
+        year += 2000 if year < 50 else 1900
         return datetime(
-            2000 + int(value[:2]),  # %y
+            year,  # %y
             pureint(value[2:4]),  # %m
             pureint(value[4:6]),  # %d
             pureint(value[6:8]),  # %H
@@ -4215,26 +4288,71 @@ class UTCTime(VisibleString):
             pureint(value[10:12]),  # %S
         )
 
-    def _value_sanitize(self, value):
+    def _dt_sanitize(self, value):
+        if value.year < 1950 or value.year > 2049:
+            raise ValueError("UTCTime can hold only 1950-2049 years")
+        return value.replace(microsecond=0)
+
+    def _value_sanitize(self, value, ctx=None):
         if isinstance(value, binary_type):
             try:
                 value_decoded = value.decode("ascii")
             except (UnicodeEncodeError, UnicodeDecodeError) as err:
                 raise DecodeError("invalid UTCTime encoding: %r" % err)
+            err = None
             try:
-                self._strptime(value_decoded)
-            except (TypeError, ValueError) as err:
-                raise DecodeError("invalid UTCTime format: %r" % err)
-            return value
+                return self._strptime(value_decoded), None
+            except (TypeError, ValueError) as _err:
+                err = _err
+                if (ctx is not None) and ctx.get("bered", False):
+                    try:
+                        offset, _value = self._strptime_bered(value_decoded)
+                        _value = _value - timedelta(seconds=offset)
+                        return self._dt_sanitize(_value), value_decoded
+                    except (TypeError, ValueError, OverflowError) as _err:
+                        err = _err
+            raise DecodeError(
+                "invalid %s format: %r" % (self.asn1_type_name, err),
+                klass=self.__class__,
+            )
         if isinstance(value, self.__class__):
-            return value._value
+            return value._value, None
         if isinstance(value, datetime):
-            return value.strftime("%y%m%d%H%M%SZ").encode("ascii")
+            return self._dt_sanitize(value), None
         raise InvalidValueType((self.__class__, datetime))
 
+    def _pp_value(self):
+        if self.ready:
+            value = self._value.isoformat()
+            if self.ber_encoded:
+                value += " (%s)" % self._ber_raw
+            return value
+
+    def __unicode__(self):
+        if self.ready:
+            value = self._value.isoformat()
+            if self.ber_encoded:
+                value += " (%s)" % self._ber_raw
+            return value
+        return text_type(self._pp_value())
+
+    def __getstate__(self):
+        return UTCTimeState(
+            *super(UTCTime, self).__getstate__(),
+            **{"ber_raw": self._ber_raw}
+        )
+
+    def __setstate__(self, state):
+        super(UTCTime, self).__setstate__(state)
+        self._ber_raw = state.ber_raw
+
+    def __bytes__(self):
+        self._assert_ready()
+        return self._encode_time()
+
     def __eq__(self, their):
         if isinstance(their, binary_type):
-            return self._value == their
+            return self._encode_time() == their
         if isinstance(their, datetime):
             return self.todatetime() == their
         if not isinstance(their, self.__class__):
@@ -4245,25 +4363,16 @@ class UTCTime(VisibleString):
             self._expl == their._expl
         )
 
-    def todatetime(self):
-        """Convert to datetime
+    def _encode_time(self):
+        return self._value.strftime("%y%m%d%H%M%SZ").encode("ascii")
 
-        :returns: datetime
+    def _encode(self):
+        self._assert_ready()
+        value = self._encode_time()
+        return b"".join((self.tag, len_encode(len(value)), value))
 
-        Pay attention that UTCTime can not hold full year, so all years
-        having < 50 years are treated as 20xx, 19xx otherwise, according
-        to X.509 recomendation.
-        """
-        value = self._strptime(self._value.decode("ascii"))
-        year = value.year % 100
-        return datetime(
-            year=(2000 + year) if year < 50 else (1900 + year),
-            month=value.month,
-            day=value.day,
-            hour=value.hour,
-            minute=value.minute,
-            second=value.second,
-        )
+    def todatetime(self):
+        return self._value
 
     def __repr__(self):
         return pp_console_row(next(self.pps()))
@@ -4274,7 +4383,7 @@ class UTCTime(VisibleString):
             asn1_type_name=self.asn1_type_name,
             obj_name=self.__class__.__name__,
             decode_path=decode_path,
-            value=self.todatetime().isoformat() if self.ready else None,
+            value=self._pp_value(),
             optional=self.optional,
             default=self == self.default,
             impl=None if self.tag == self.tag_default else tag_decode(self.tag),
@@ -4309,13 +4418,19 @@ class GeneralizedTime(UTCTime):
 
     .. warning::
 
-       BER encoding is unsupported.
+       Only microsecond fractions are supported in DER encoding.
+       :py:exc:`pyderasn.DecodeError` will be raised during decoding of
+       higher precision values.
 
     .. warning::
 
-       Only microsecond fractions are supported.
-       :py:exc:`pyderasn.DecodeError` will be raised during decoding of
-       higher precision values.
+       BER encoded data can loss information (accuracy) during decoding
+       because of float transformations.
+
+    .. warning::
+
+       Local times (without explicit timezone specification) are treated
+       as UTC one, no transformations are made.
 
     .. warning::
 
@@ -4325,6 +4440,70 @@ class GeneralizedTime(UTCTime):
     tag_default = tag_encode(24)
     asn1_type_name = "GeneralizedTime"
 
+    def _dt_sanitize(self, value):
+        return value
+
+    def _strptime_bered(self, value):
+        if len(value) < 4 + 3 * 2:
+            raise ValueError("invalid GeneralizedTime")
+        decoded = datetime(
+            pureint(value[:4]),  # %Y
+            pureint(value[4:6]),  # %m
+            pureint(value[6:8]),  # %d
+            pureint(value[8:10]),  # %H
+        )
+        value = value[10:]
+        offset = 0
+        if len(value) == 0:
+            return offset, decoded
+        if value[-1] == "Z":
+            value = value[:-1]
+        else:
+            for char, sign in (("-", -1), ("+", 1)):
+                idx = value.rfind(char)
+                if idx == -1:
+                    continue
+                offset_raw = value[idx + 1:].replace(":", "")
+                if len(offset_raw) not in (2, 4):
+                    raise ValueError("invalid UTC offset")
+                value = value[:idx]
+                offset = 60 * pureint(offset_raw[2:] or "0")
+                if offset >= 3600:
+                    raise ValueError("invalid UTC offset minutes")
+                offset += 3600 * pureint(offset_raw[:2])
+                if offset > 14 * 3600:
+                    raise ValueError("too big UTC offset")
+                offset *= sign
+                break
+        if len(value) == 0:
+            return offset, decoded
+        decimal_signs = ".,"
+        if value[0] in decimal_signs:
+            return offset, (
+                decoded + timedelta(seconds=3600 * fractions2float(value[1:]))
+            )
+        if len(value) < 2:
+            raise ValueError("stripped minutes")
+        decoded += timedelta(seconds=60 * pureint(value[:2]))
+        value = value[2:]
+        if len(value) == 0:
+            return offset, decoded
+        if value[0] in decimal_signs:
+            return offset, (
+                decoded + timedelta(seconds=60 * fractions2float(value[1:]))
+            )
+        if len(value) < 2:
+            raise ValueError("stripped seconds")
+        decoded += timedelta(seconds=pureint(value[:2]))
+        value = value[2:]
+        if len(value) == 0:
+            return offset, decoded
+        if value[0] not in decimal_signs:
+            raise ValueError("invalid format after seconds")
+        return offset, (
+            decoded + timedelta(microseconds=10**6 * fractions2float(value[1:]))
+        )
+
     def _strptime(self, value):
         l = len(value)
         if l == LEN_YYYYMMDDHHMMSSZ:
@@ -4364,31 +4543,12 @@ class GeneralizedTime(UTCTime):
             return decoded
         raise ValueError("invalid GeneralizedTime length")
 
-    def _value_sanitize(self, value):
-        if isinstance(value, binary_type):
-            try:
-                value_decoded = value.decode("ascii")
-            except (UnicodeEncodeError, UnicodeDecodeError) as err:
-                raise DecodeError("invalid GeneralizedTime encoding: %r" % err)
-            try:
-                self._strptime(value_decoded)
-            except (TypeError, ValueError) as err:
-                raise DecodeError(
-                    "invalid GeneralizedTime format: %r" % err,
-                    klass=self.__class__,
-                )
-            return value
-        if isinstance(value, self.__class__):
-            return value._value
-        if isinstance(value, datetime):
-            encoded = value.strftime("%Y%m%d%H%M%S")
-            if value.microsecond > 0:
-                encoded = encoded + (".%06d" % value.microsecond).rstrip("0")
-            return (encoded + "Z").encode("ascii")
-        raise InvalidValueType((self.__class__, datetime))
-
-    def todatetime(self):
-        return self._strptime(self._value.decode("ascii"))
+    def _encode_time(self):
+        value = self._value
+        encoded = value.strftime("%Y%m%d%H%M%S")
+        if value.microsecond > 0:
+            encoded += (".%06d" % value.microsecond).rstrip("0")
+        return (encoded + "Z").encode("ascii")
 
 
 class GraphicString(CommonString):