]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - pyderasn.py
Stricter validation of *Time
[pyderasn.git] / pyderasn.py
index 5adc5b068d17915f0c543144b66557b9fc355f30..78080294ec422a7a636503c4cd96e20892764f54 100755 (executable)
@@ -687,7 +687,7 @@ except ImportError:  # pragma: no cover
     def colored(what, *args, **kwargs):
         return what
 
-__version__ = "6.0"
+__version__ = "6.1"
 
 __all__ = (
     "Any",
@@ -3420,6 +3420,13 @@ ObjectIdentifierState = namedtuple("ObjectIdentifierState", (
 ))
 
 
+def pureint(value):
+    i = int(value)
+    if (value[0] in "+- ") or (value[-1] == " "):
+        raise ValueError("non-pure integer")
+    return i
+
+
 class ObjectIdentifier(Obj):
     """``OBJECT IDENTIFIER`` OID type
 
@@ -3497,7 +3504,7 @@ class ObjectIdentifier(Obj):
             return value._value
         if isinstance(value, string_types):
             try:
-                value = tuple(int(arc) for arc in value.split("."))
+                value = tuple(pureint(arc) for arc in value.split("."))
             except ValueError:
                 raise InvalidOID("unacceptable arcs values")
         if isinstance(value, tuple):
@@ -3511,6 +3518,8 @@ class ObjectIdentifier(Obj):
                 pass
             else:
                 raise InvalidOID("unacceptable first arc value")
+            if not all(arc >= 0 for arc in value):
+                raise InvalidOID("negative arc value")
             return value
         raise InvalidValueType((self.__class__, str, tuple))
 
@@ -4127,7 +4136,14 @@ LEN_YYYYMMDDHHMMSSDMZ = len("YYYYMMDDHHMMSSDMZ")
 LEN_YYYYMMDDHHMMSSZ = len("YYYYMMDDHHMMSSZ")
 
 
-class UTCTime(CommonString):
+class VisibleString(CommonString):
+    __slots__ = ()
+    tag_default = tag_encode(26)
+    encoding = "ascii"
+    asn1_type_name = "VisibleString"
+
+
+class UTCTime(VisibleString):
     """``UTCTime`` datetime type
 
     >>> t = UTCTime(datetime(2017, 9, 30, 22, 7, 50, 123))
@@ -4192,11 +4208,11 @@ class UTCTime(CommonString):
             raise ValueError("non UTC timezone")
         return datetime(
             2000 + int(value[:2]),  # %y
-            int(value[2:4]),  # %m
-            int(value[4:6]),  # %d
-            int(value[6:8]),  # %H
-            int(value[8:10]),  # %M
-            int(value[10:12]),  # %S
+            pureint(value[2:4]),  # %m
+            pureint(value[4:6]),  # %d
+            pureint(value[6:8]),  # %H
+            pureint(value[8:10]),  # %M
+            pureint(value[10:12]),  # %S
         )
 
     def _value_sanitize(self, value):
@@ -4316,12 +4332,12 @@ class GeneralizedTime(UTCTime):
             if value[-1] != "Z":
                 raise ValueError("non UTC timezone")
             return datetime(
-                int(value[:4]),  # %Y
-                int(value[4:6]),  # %m
-                int(value[6:8]),  # %d
-                int(value[8:10]),  # %H
-                int(value[10:12]),  # %M
-                int(value[12:14]),  # %S
+                pureint(value[:4]),  # %Y
+                pureint(value[4:6]),  # %m
+                pureint(value[6:8]),  # %d
+                pureint(value[8:10]),  # %H
+                pureint(value[10:12]),  # %M
+                pureint(value[12:14]),  # %S
             )
         if l >= LEN_YYYYMMDDHHMMSSDMZ:
             # datetime.strptime's format: %Y%m%d%H%M%S.%fZ
@@ -4335,14 +4351,14 @@ class GeneralizedTime(UTCTime):
             us_len = len(us)
             if us_len > 6:
                 raise ValueError("only microsecond fractions are supported")
-            us = int(us + ("0" * (6 - us_len)))
+            us = pureint(us + ("0" * (6 - us_len)))
             decoded = datetime(
-                int(value[:4]),  # %Y
-                int(value[4:6]),  # %m
-                int(value[6:8]),  # %d
-                int(value[8:10]),  # %H
-                int(value[10:12]),  # %M
-                int(value[12:14]),  # %S
+                pureint(value[:4]),  # %Y
+                pureint(value[4:6]),  # %m
+                pureint(value[6:8]),  # %d
+                pureint(value[8:10]),  # %H
+                pureint(value[10:12]),  # %M
+                pureint(value[12:14]),  # %S
                 us,  # %f
             )
             return decoded
@@ -4382,13 +4398,6 @@ class GraphicString(CommonString):
     asn1_type_name = "GraphicString"
 
 
-class VisibleString(CommonString):
-    __slots__ = ()
-    tag_default = tag_encode(26)
-    encoding = "ascii"
-    asn1_type_name = "VisibleString"
-
-
 class ISO646String(VisibleString):
     __slots__ = ()
     asn1_type_name = "ISO646String"