X-Git-Url: http://www.git.cypherpunks.ru/?a=blobdiff_plain;f=pyderasn.py;h=50faa0be870ced74ab9d33afb4ac0f4008846421;hb=97fe8b8eae995499566e3ae26e2e941e633997bc;hp=1245a435792c30b33c733aba5409bdd816173d53;hpb=41b748d4ced632131e9edf607d9e32cd590fa464;p=pyderasn.git diff --git a/pyderasn.py b/pyderasn.py index 1245a43..50faa0b 100755 --- a/pyderasn.py +++ b/pyderasn.py @@ -555,6 +555,8 @@ from six import indexbytes from six import int2byte from six import integer_types from six import iterbytes +from six import iteritems +from six import itervalues from six import PY2 from six import string_types from six import text_type @@ -1549,10 +1551,10 @@ class Boolean(Obj): self._value = default def _value_sanitize(self, value): - if issubclass(value.__class__, Boolean): - return value._value if isinstance(value, bool): return value + if issubclass(value.__class__, Boolean): + return value._value raise InvalidValueType((self.__class__, bool)) @property @@ -1798,10 +1800,10 @@ class Integer(Obj): self._value = default def _value_sanitize(self, value): - if issubclass(value.__class__, Integer): - value = value._value - elif isinstance(value, integer_types): + if isinstance(value, integer_types): pass + elif issubclass(value.__class__, Integer): + value = value._value elif isinstance(value, str): value = self.specs.get(value) if value is None: @@ -1861,7 +1863,7 @@ class Integer(Obj): @property def named(self): - for name, value in self.specs.items(): + for name, value in iteritems(self.specs): if value == self._value: return name @@ -2045,6 +2047,9 @@ class Integer(Obj): yield pp +SET01 = frozenset(("0", "1")) + + class BitString(Obj): """``BIT STRING`` bit string type @@ -2150,8 +2155,6 @@ class BitString(Obj): return bit_len, bytes(octets) def _value_sanitize(self, value): - if issubclass(value.__class__, BitString): - return value._value if isinstance(value, (string_types, binary_type)): if ( isinstance(value, string_types) and @@ -2159,7 +2162,7 @@ class BitString(Obj): ): if value.endswith("'B"): value = value[1:-2] - if not set(value) <= set(("0", "1")): + if not frozenset(value) <= SET01: raise ValueError("B's coding contains unacceptable chars") return self._bits2octets(value) elif value.endswith("'H"): @@ -2187,11 +2190,13 @@ class BitString(Obj): bits.append(bit) if len(bits) == 0: return self._bits2octets("") - bits = set(bits) + bits = frozenset(bits) return self._bits2octets("".join( ("1" if bit in bits else "0") for bit in six_xrange(max(bits) + 1) )) + if issubclass(value.__class__, BitString): + return value._value raise InvalidValueType((self.__class__, binary_type, string_types)) @property @@ -2243,7 +2248,7 @@ class BitString(Obj): @property def named(self): - return [name for name, bit in self.specs.items() if self[bit]] + return [name for name, bit in iteritems(self.specs) if self[bit]] def __call__( self, @@ -2600,10 +2605,10 @@ class OctetString(Obj): ) def _value_sanitize(self, value): - if issubclass(value.__class__, OctetString): - value = value._value - elif isinstance(value, binary_type): + if isinstance(value, binary_type): pass + elif issubclass(value.__class__, OctetString): + value = value._value else: raise InvalidValueType((self.__class__, bytes)) if not self._bound_min <= len(value) <= self._bound_max: @@ -3355,7 +3360,10 @@ class Enumerated(Integer): if isinstance(value, self.__class__): value = value._value elif isinstance(value, integer_types): - if value not in list(self.specs.values()): + for _value in itervalues(self.specs): + if _value == value: + break + else: raise DecodeError( "unknown integer value: %s" % value, klass=self.__class__, @@ -3561,7 +3569,7 @@ class AllowableCharsMixin(object): def allowable_chars(self): if PY2: return self._allowable_chars - return set(six_unichr(c) for c in self._allowable_chars) + return frozenset(six_unichr(c) for c in self._allowable_chars) class NumericString(AllowableCharsMixin, CommonString): @@ -3577,11 +3585,11 @@ class NumericString(AllowableCharsMixin, CommonString): tag_default = tag_encode(18) encoding = "ascii" asn1_type_name = "NumericString" - _allowable_chars = set(digits.encode("ascii") + b" ") + _allowable_chars = frozenset(digits.encode("ascii") + b" ") def _value_sanitize(self, value): value = super(NumericString, self)._value_sanitize(value) - if not set(value) <= self._allowable_chars: + if not frozenset(value) <= self._allowable_chars: raise DecodeError("non-numeric value") return value @@ -3598,13 +3606,13 @@ class PrintableString(AllowableCharsMixin, CommonString): tag_default = tag_encode(19) encoding = "ascii" asn1_type_name = "PrintableString" - _allowable_chars = set( + _allowable_chars = frozenset( (ascii_letters + digits + " '()+,-./:=?").encode("ascii") ) def _value_sanitize(self, value): value = super(PrintableString, self)._value_sanitize(value) - if not set(value) <= self._allowable_chars: + if not frozenset(value) <= self._allowable_chars: raise DecodeError("non-printable value") return value @@ -3700,10 +3708,6 @@ class UTCTime(CommonString): self._value = default def _value_sanitize(self, value): - if isinstance(value, self.__class__): - return value._value - if isinstance(value, datetime): - return value.strftime(self.fmt).encode("ascii") if isinstance(value, binary_type): try: value_decoded = value.decode("ascii") @@ -3717,6 +3721,10 @@ class UTCTime(CommonString): return value else: raise DecodeError("invalid UTCTime length") + if isinstance(value, self.__class__): + return value._value + if isinstance(value, datetime): + return value.strftime(self.fmt).encode("ascii") raise InvalidValueType((self.__class__, datetime)) def __eq__(self, their): @@ -3802,12 +3810,6 @@ class GeneralizedTime(UTCTime): fmt_ms = "%Y%m%d%H%M%S.%fZ" def _value_sanitize(self, value): - if isinstance(value, self.__class__): - return value._value - if isinstance(value, datetime): - return value.strftime( - self.fmt_ms if value.microsecond > 0 else self.fmt - ).encode("ascii") if isinstance(value, binary_type): try: value_decoded = value.decode("ascii") @@ -3834,6 +3836,12 @@ class GeneralizedTime(UTCTime): "invalid GeneralizedTime length", klass=self.__class__, ) + if isinstance(value, self.__class__): + return value._value + if isinstance(value, datetime): + return value.strftime( + self.fmt_ms if value.microsecond > 0 else self.fmt + ).encode("ascii") raise InvalidValueType((self.__class__, datetime)) def todatetime(self): @@ -3959,8 +3967,6 @@ class Choice(Obj): self._value = default_obj.copy()._value def _value_sanitize(self, value): - if isinstance(value, self.__class__): - return value._value if isinstance(value, tuple) and len(value) == 2: choice, obj = value spec = self.specs.get(choice) @@ -3969,6 +3975,8 @@ class Choice(Obj): if not isinstance(obj, spec.__class__): raise InvalidValueType((spec,)) return (choice, spec(obj)) + if isinstance(value, self.__class__): + return value._value raise InvalidValueType((self.__class__, tuple)) @property @@ -4064,7 +4072,7 @@ class Choice(Obj): return self._value[1].encode() def _decode(self, tlv, offset, decode_path, ctx, tag_only): - for choice, spec in self.specs.items(): + for choice, spec in iteritems(self.specs): sub_decode_path = decode_path + (choice,) try: spec.decode( @@ -4204,12 +4212,12 @@ class Any(Obj): self.defined = None def _value_sanitize(self, value): + if isinstance(value, binary_type): + return value if isinstance(value, self.__class__): return value._value if isinstance(value, Obj): return value.encode() - if isinstance(value, binary_type): - return value raise InvalidValueType((self.__class__, Obj, binary_type)) @property @@ -4549,7 +4557,7 @@ class Sequence(Obj): @property def ready(self): - for name, spec in self.specs.items(): + for name, spec in iteritems(self.specs): value = self._value.get(name) if value is None: if spec.optional: @@ -4564,7 +4572,7 @@ class Sequence(Obj): def bered(self): if self.expl_lenindef or self.lenindef or self.ber_encoded: return True - return any(value.bered for value in self._value.values()) + return any(value.bered for value in itervalues(self._value)) def copy(self): obj = self.__class__(schema=self.specs) @@ -4578,7 +4586,7 @@ class Sequence(Obj): obj.expl_lenindef = self.expl_lenindef obj.lenindef = self.lenindef obj.ber_encoded = self.ber_encoded - obj._value = {k: v.copy() for k, v in self._value.items()} + obj._value = {k: v.copy() for k, v in iteritems(self._value)} return obj def __eq__(self, their): @@ -4639,7 +4647,7 @@ class Sequence(Obj): def _encoded_values(self): raws = [] - for name, spec in self.specs.items(): + for name, spec in iteritems(self.specs): value = self._value.get(name) if value is None: if spec.optional: @@ -4705,7 +4713,7 @@ class Sequence(Obj): values = {} ber_encoded = False ctx_allow_default_values = ctx.get("allow_default_values", False) - for name, spec in self.specs.items(): + for name, spec in iteritems(self.specs): if spec.optional and ( (lenindef and v[:EOC_LEN].tobytes() == EOC) or len(v) == 0 @@ -4899,6 +4907,9 @@ class Set(Sequence): v = b"".join(raws) return b"".join((self.tag, len_encode(len(v)), v)) + def _specs_items(self): + return iteritems(self.specs) + def _decode(self, tlv, offset, decode_path, ctx, tag_only): try: t, tlen, lv = tag_strip(tlv) @@ -4953,11 +4964,11 @@ class Set(Sequence): ctx_allow_default_values = ctx.get("allow_default_values", False) ctx_allow_unordered_set = ctx.get("allow_unordered_set", False) value_prev = memoryview(v[:0]) - specs_items = self.specs.items + while len(v) > 0: if lenindef and v[:EOC_LEN].tobytes() == EOC: break - for name, spec in specs_items(): + for name, spec in self._specs_items(): sub_decode_path = decode_path + (name,) try: spec.decode(