]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - pyderasn.py
Change order of value sanitizers
[pyderasn.git] / pyderasn.py
index f9689286fdad36c4c537b49a75dde39fb45fcd8b..50faa0be870ced74ab9d33afb4ac0f4008846421 100755 (executable)
@@ -1551,10 +1551,10 @@ class Boolean(Obj):
                 self._value = default
 
     def _value_sanitize(self, value):
-        if issubclass(value.__class__, Boolean):
-            return value._value
         if isinstance(value, bool):
             return value
+        if issubclass(value.__class__, Boolean):
+            return value._value
         raise InvalidValueType((self.__class__, bool))
 
     @property
@@ -1800,10 +1800,10 @@ class Integer(Obj):
                 self._value = default
 
     def _value_sanitize(self, value):
-        if issubclass(value.__class__, Integer):
-            value = value._value
-        elif isinstance(value, integer_types):
+        if isinstance(value, integer_types):
             pass
+        elif issubclass(value.__class__, Integer):
+            value = value._value
         elif isinstance(value, str):
             value = self.specs.get(value)
             if value is None:
@@ -2047,6 +2047,9 @@ class Integer(Obj):
             yield pp
 
 
+SET01 = frozenset(("0", "1"))
+
+
 class BitString(Obj):
     """``BIT STRING`` bit string type
 
@@ -2152,8 +2155,6 @@ class BitString(Obj):
         return bit_len, bytes(octets)
 
     def _value_sanitize(self, value):
-        if issubclass(value.__class__, BitString):
-            return value._value
         if isinstance(value, (string_types, binary_type)):
             if (
                     isinstance(value, string_types) and
@@ -2161,7 +2162,7 @@ class BitString(Obj):
             ):
                 if value.endswith("'B"):
                     value = value[1:-2]
-                    if not set(value) <= set(("0", "1")):
+                    if not frozenset(value) <= SET01:
                         raise ValueError("B's coding contains unacceptable chars")
                     return self._bits2octets(value)
                 elif value.endswith("'H"):
@@ -2189,11 +2190,13 @@ class BitString(Obj):
                 bits.append(bit)
             if len(bits) == 0:
                 return self._bits2octets("")
-            bits = set(bits)
+            bits = frozenset(bits)
             return self._bits2octets("".join(
                 ("1" if bit in bits else "0")
                 for bit in six_xrange(max(bits) + 1)
             ))
+        if issubclass(value.__class__, BitString):
+            return value._value
         raise InvalidValueType((self.__class__, binary_type, string_types))
 
     @property
@@ -2602,10 +2605,10 @@ class OctetString(Obj):
         )
 
     def _value_sanitize(self, value):
-        if issubclass(value.__class__, OctetString):
-            value = value._value
-        elif isinstance(value, binary_type):
+        if isinstance(value, binary_type):
             pass
+        elif issubclass(value.__class__, OctetString):
+            value = value._value
         else:
             raise InvalidValueType((self.__class__, bytes))
         if not self._bound_min <= len(value) <= self._bound_max:
@@ -3566,7 +3569,7 @@ class AllowableCharsMixin(object):
     def allowable_chars(self):
         if PY2:
             return self._allowable_chars
-        return set(six_unichr(c) for c in self._allowable_chars)
+        return frozenset(six_unichr(c) for c in self._allowable_chars)
 
 
 class NumericString(AllowableCharsMixin, CommonString):
@@ -3582,11 +3585,11 @@ class NumericString(AllowableCharsMixin, CommonString):
     tag_default = tag_encode(18)
     encoding = "ascii"
     asn1_type_name = "NumericString"
-    _allowable_chars = set(digits.encode("ascii") + b" ")
+    _allowable_chars = frozenset(digits.encode("ascii") + b" ")
 
     def _value_sanitize(self, value):
         value = super(NumericString, self)._value_sanitize(value)
-        if not set(value) <= self._allowable_chars:
+        if not frozenset(value) <= self._allowable_chars:
             raise DecodeError("non-numeric value")
         return value
 
@@ -3603,13 +3606,13 @@ class PrintableString(AllowableCharsMixin, CommonString):
     tag_default = tag_encode(19)
     encoding = "ascii"
     asn1_type_name = "PrintableString"
-    _allowable_chars = set(
+    _allowable_chars = frozenset(
         (ascii_letters + digits + " '()+,-./:=?").encode("ascii")
     )
 
     def _value_sanitize(self, value):
         value = super(PrintableString, self)._value_sanitize(value)
-        if not set(value) <= self._allowable_chars:
+        if not frozenset(value) <= self._allowable_chars:
             raise DecodeError("non-printable value")
         return value
 
@@ -3705,10 +3708,6 @@ class UTCTime(CommonString):
                 self._value = default
 
     def _value_sanitize(self, value):
-        if isinstance(value, self.__class__):
-            return value._value
-        if isinstance(value, datetime):
-            return value.strftime(self.fmt).encode("ascii")
         if isinstance(value, binary_type):
             try:
                 value_decoded = value.decode("ascii")
@@ -3722,6 +3721,10 @@ class UTCTime(CommonString):
                 return value
             else:
                 raise DecodeError("invalid UTCTime length")
+        if isinstance(value, self.__class__):
+            return value._value
+        if isinstance(value, datetime):
+            return value.strftime(self.fmt).encode("ascii")
         raise InvalidValueType((self.__class__, datetime))
 
     def __eq__(self, their):
@@ -3807,12 +3810,6 @@ class GeneralizedTime(UTCTime):
     fmt_ms = "%Y%m%d%H%M%S.%fZ"
 
     def _value_sanitize(self, value):
-        if isinstance(value, self.__class__):
-            return value._value
-        if isinstance(value, datetime):
-            return value.strftime(
-                self.fmt_ms if value.microsecond > 0 else self.fmt
-            ).encode("ascii")
         if isinstance(value, binary_type):
             try:
                 value_decoded = value.decode("ascii")
@@ -3839,6 +3836,12 @@ class GeneralizedTime(UTCTime):
                     "invalid GeneralizedTime length",
                     klass=self.__class__,
                 )
+        if isinstance(value, self.__class__):
+            return value._value
+        if isinstance(value, datetime):
+            return value.strftime(
+                self.fmt_ms if value.microsecond > 0 else self.fmt
+            ).encode("ascii")
         raise InvalidValueType((self.__class__, datetime))
 
     def todatetime(self):
@@ -3964,8 +3967,6 @@ class Choice(Obj):
                 self._value = default_obj.copy()._value
 
     def _value_sanitize(self, value):
-        if isinstance(value, self.__class__):
-            return value._value
         if isinstance(value, tuple) and len(value) == 2:
             choice, obj = value
             spec = self.specs.get(choice)
@@ -3974,6 +3975,8 @@ class Choice(Obj):
             if not isinstance(obj, spec.__class__):
                 raise InvalidValueType((spec,))
             return (choice, spec(obj))
+        if isinstance(value, self.__class__):
+            return value._value
         raise InvalidValueType((self.__class__, tuple))
 
     @property
@@ -4209,12 +4212,12 @@ class Any(Obj):
         self.defined = None
 
     def _value_sanitize(self, value):
+        if isinstance(value, binary_type):
+            return value
         if isinstance(value, self.__class__):
             return value._value
         if isinstance(value, Obj):
             return value.encode()
-        if isinstance(value, binary_type):
-            return value
         raise InvalidValueType((self.__class__, Obj, binary_type))
 
     @property