From 521a4868199657f49e0b20973dab53730b93fd54 Mon Sep 17 00:00:00 2001 From: Sergey Matveev Date: Thu, 6 Feb 2020 13:51:42 +0300 Subject: [PATCH] copy/pickle friendly Obj --- VERSION | 2 +- doc/examples.rst | 2 +- doc/install.rst | 12 +- doc/news.rst | 10 + pyderasn.py | 712 ++++++++++++++++++++++++++++++----------- tests/test_crts.py | 9 +- tests/test_pyderasn.py | 201 +++++++----- 7 files changed, 685 insertions(+), 263 deletions(-) diff --git a/VERSION b/VERSION index 2df33d7..e0ea36f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -5.6 +6.0 diff --git a/doc/examples.rst b/doc/examples.rst index 5dace60..6f16af6 100644 --- a/doc/examples.rst +++ b/doc/examples.rst @@ -362,7 +362,7 @@ Let's create some simple self-signed X.509 certificate from the ground:: tbs["validity"] = validity spki = SubjectPublicKeyInfo() - spki_algo_id = sign_algo_id.copy() + spki_algo_id = copy(sign_algo_id) spki_algo_id["algorithm"] = ObjectIdentifier("1.2.840.113549.1.1.1") spki["algorithm"] = spki_algo_id spki["subjectPublicKey"] = BitString(hexdec("".join(( diff --git a/doc/install.rst b/doc/install.rst index 72a196c..074a774 100644 --- a/doc/install.rst +++ b/doc/install.rst @@ -4,11 +4,11 @@ Install Preferable way is to :ref:`download ` tarball with the signature from `official website `__:: - $ [fetch|wget] http://pyderasn.cypherpunks.ru/pyderasn-5.6.tar.xz - $ [fetch|wget] http://pyderasn.cypherpunks.ru/pyderasn-5.6.tar.xz.sig - $ gpg --verify pyderasn-5.6.tar.xz.sig pyderasn-5.6.tar.xz - $ xz --decompress --stdout pyderasn-5.6.tar.xz | tar xf - - $ cd pyderasn-5.6 + $ [fetch|wget] http://pyderasn.cypherpunks.ru/pyderasn-6.0.tar.xz + $ [fetch|wget] http://pyderasn.cypherpunks.ru/pyderasn-6.0.tar.xz.sig + $ gpg --verify pyderasn-6.0.tar.xz.sig pyderasn-6.0.tar.xz + $ xz --decompress --stdout pyderasn-6.0.tar.xz | tar xf - + $ cd pyderasn-6.0 $ python setup.py install # or copy pyderasn.py (+six.py, possibly termcolor.py) to your PYTHONPATH @@ -19,7 +19,7 @@ You can also find it mirrored on :ref:`download ` page. You could use pip (**no** OpenPGP authentication is performed!) with PyPI:: $ cat > requirements.txt < 0 - def copy(self): # pragma: no cover - """Make a copy of object, safe to be mutated + def __getstate__(self): # pragma: no cover + """Used for making safe to be mutable pickleable copies """ raise NotImplementedError() + def __setstate__(self, state): + if state.version != __version__: + raise ValueError("data is pickled by different PyDERASN version") + self.tag = self.tag_default + self._value = None + self._expl = None + self.default = None + self.optional = False + self.offset = 0 + self.llen = 0 + self.vlen = 0 + self.expl_lenindef = False + self.lenindef = False + self.ber_encoded = False + @property def tlen(self): """See :ref:`decoding` @@ -1181,7 +1199,8 @@ class Obj(object): :param tag_only: decode only the tag, without length and contents (used only in Choice and Set structures, trying to determine if tag satisfies the scheme) - :param _ctx_immutable: do we need to copy ``ctx`` before using it + :param _ctx_immutable: do we need to ``copy.copy()`` ``ctx`` + before using it? :returns: (Obj, remaining data) .. seealso:: :ref:`decoding` @@ -1691,6 +1710,22 @@ def pprint( # ASN.1 primitive types ######################################################################## +BooleanState = namedtuple("BooleanState", ( + "version", + "value", + "tag", + "expl", + "default", + "optional", + "offset", + "llen", + "vlen", + "expl_lenindef", + "lenindef", + "ber_encoded", +)) + + class Boolean(Obj): """``BOOLEAN`` boolean type @@ -1745,20 +1780,35 @@ class Boolean(Obj): def ready(self): return self._value is not None - def copy(self): - obj = self.__class__() - obj._value = self._value - obj.tag = self.tag - obj._expl = self._expl - obj.default = self.default - obj.optional = self.optional - obj.offset = self.offset - obj.llen = self.llen - obj.vlen = self.vlen - obj.expl_lenindef = self.expl_lenindef - obj.lenindef = self.lenindef - obj.ber_encoded = self.ber_encoded - return obj + def __getstate__(self): + return BooleanState( + __version__, + self._value, + self.tag, + self._expl, + self.default, + self.optional, + self.offset, + self.llen, + self.vlen, + self.expl_lenindef, + self.lenindef, + self.ber_encoded, + ) + + def __setstate__(self, state): + super(Boolean, self).__setstate__(state) + self._value = state.value + self.tag = state.tag + self._expl = state.expl + self.default = state.default + self.optional = state.optional + self.offset = state.offset + self.llen = state.llen + self.vlen = state.vlen + self.expl_lenindef = state.expl_lenindef + self.lenindef = state.lenindef + self.ber_encoded = state.ber_encoded def __nonzero__(self): self._assert_ready() @@ -1901,6 +1951,25 @@ class Boolean(Obj): yield pp +IntegerState = namedtuple("IntegerState", ( + "version", + "specs", + "value", + "bound_min", + "bound_max", + "tag", + "expl", + "default", + "optional", + "offset", + "llen", + "vlen", + "expl_lenindef", + "lenindef", + "ber_encoded", +)) + + class Integer(Obj): """``INTEGER`` integer type @@ -2002,22 +2071,41 @@ class Integer(Obj): def ready(self): return self._value is not None - def copy(self): - obj = self.__class__(_specs=self.specs) - obj._value = self._value - obj._bound_min = self._bound_min - obj._bound_max = self._bound_max - obj.tag = self.tag - obj._expl = self._expl - obj.default = self.default - obj.optional = self.optional - obj.offset = self.offset - obj.llen = self.llen - obj.vlen = self.vlen - obj.expl_lenindef = self.expl_lenindef - obj.lenindef = self.lenindef - obj.ber_encoded = self.ber_encoded - return obj + def __getstate__(self): + return IntegerState( + __version__, + self.specs, + self._value, + self._bound_min, + self._bound_max, + self.tag, + self._expl, + self.default, + self.optional, + self.offset, + self.llen, + self.vlen, + self.expl_lenindef, + self.lenindef, + self.ber_encoded, + ) + + def __setstate__(self, state): + super(Integer, self).__setstate__(state) + self.specs = state.specs + self._value = state.value + self._bound_min = state.bound_min + self._bound_max = state.bound_max + self.tag = state.tag + self._expl = state.expl + self.default = state.default + self.optional = state.optional + self.offset = state.offset + self.llen = state.llen + self.vlen = state.vlen + self.expl_lenindef = state.expl_lenindef + self.lenindef = state.lenindef + self.ber_encoded = state.ber_encoded def __int__(self): self._assert_ready() @@ -2233,6 +2321,23 @@ class Integer(Obj): SET01 = frozenset(("0", "1")) +BitStringState = namedtuple("BitStringState", ( + "version", + "specs", + "value", + "tag", + "expl", + "default", + "optional", + "offset", + "llen", + "vlen", + "expl_lenindef", + "lenindef", + "ber_encoded", + "tag_constructed", + "defined", +)) class BitString(Obj): @@ -2387,23 +2492,41 @@ class BitString(Obj): def ready(self): return self._value is not None - def copy(self): - obj = self.__class__(_specs=self.specs) - value = self._value - if value is not None: - value = (value[0], value[1]) - obj._value = value - obj.tag = self.tag - obj._expl = self._expl - obj.default = self.default - obj.optional = self.optional - obj.offset = self.offset - obj.llen = self.llen - obj.vlen = self.vlen - obj.expl_lenindef = self.expl_lenindef - obj.lenindef = self.lenindef - obj.ber_encoded = self.ber_encoded - return obj + def __getstate__(self): + return BitStringState( + __version__, + self.specs, + self._value, + self.tag, + self._expl, + self.default, + self.optional, + self.offset, + self.llen, + self.vlen, + self.expl_lenindef, + self.lenindef, + self.ber_encoded, + self.tag_constructed, + self.defined, + ) + + def __setstate__(self, state): + super(BitString, self).__setstate__(state) + self.specs = state.specs + self._value = state.value + self.tag = state.tag + self._expl = state.expl + self.default = state.default + self.optional = state.optional + self.offset = state.offset + self.llen = state.llen + self.vlen = state.vlen + self.expl_lenindef = state.expl_lenindef + self.lenindef = state.lenindef + self.ber_encoded = state.ber_encoded + self.tag_constructed = state.tag_constructed + self.defined = state.defined def __iter__(self): self._assert_ready() @@ -2708,6 +2831,26 @@ class BitString(Obj): yield pp +OctetStringState = namedtuple("OctetStringState", ( + "version", + "value", + "bound_min", + "bound_max", + "tag", + "expl", + "default", + "optional", + "offset", + "llen", + "vlen", + "expl_lenindef", + "lenindef", + "ber_encoded", + "tag_constructed", + "defined", +)) + + class OctetString(Obj): """``OCTET STRING`` binary string type @@ -2797,22 +2940,43 @@ class OctetString(Obj): def ready(self): return self._value is not None - def copy(self): - obj = self.__class__() - obj._value = self._value - obj._bound_min = self._bound_min - obj._bound_max = self._bound_max - obj.tag = self.tag - obj._expl = self._expl - obj.default = self.default - obj.optional = self.optional - obj.offset = self.offset - obj.llen = self.llen - obj.vlen = self.vlen - obj.expl_lenindef = self.expl_lenindef - obj.lenindef = self.lenindef - obj.ber_encoded = self.ber_encoded - return obj + def __getstate__(self): + return OctetStringState( + __version__, + self._value, + self._bound_min, + self._bound_max, + self.tag, + self._expl, + self.default, + self.optional, + self.offset, + self.llen, + self.vlen, + self.expl_lenindef, + self.lenindef, + self.ber_encoded, + self.tag_constructed, + self.defined, + ) + + def __setstate__(self, state): + super(OctetString, self).__setstate__(state) + self._value = state.value + self._bound_min = state.bound_min + self._bound_max = state.bound_max + self.tag = state.tag + self._expl = state.expl + self.default = state.default + self.optional = state.optional + self.offset = state.offset + self.llen = state.llen + self.vlen = state.vlen + self.expl_lenindef = state.expl_lenindef + self.lenindef = state.lenindef + self.ber_encoded = state.ber_encoded + self.tag_constructed = state.tag_constructed + self.defined = state.defined def __bytes__(self): self._assert_ready() @@ -3057,6 +3221,21 @@ class OctetString(Obj): yield pp +NullState = namedtuple("NullState", ( + "version", + "tag", + "expl", + "default", + "optional", + "offset", + "llen", + "vlen", + "expl_lenindef", + "lenindef", + "ber_encoded", +)) + + class Null(Obj): """``NULL`` null object @@ -3089,19 +3268,33 @@ class Null(Obj): def ready(self): return True - def copy(self): - obj = self.__class__() - obj.tag = self.tag - obj._expl = self._expl - obj.default = self.default - obj.optional = self.optional - obj.offset = self.offset - obj.llen = self.llen - obj.vlen = self.vlen - obj.expl_lenindef = self.expl_lenindef - obj.lenindef = self.lenindef - obj.ber_encoded = self.ber_encoded - return obj + def __getstate__(self): + return NullState( + __version__, + self.tag, + self._expl, + self.default, + self.optional, + self.offset, + self.llen, + self.vlen, + self.expl_lenindef, + self.lenindef, + self.ber_encoded, + ) + + def __setstate__(self, state): + super(Null, self).__setstate__(state) + self.tag = state.tag + self._expl = state.expl + self.default = state.default + self.optional = state.optional + self.offset = state.offset + self.llen = state.llen + self.vlen = state.vlen + self.expl_lenindef = state.expl_lenindef + self.lenindef = state.lenindef + self.ber_encoded = state.ber_encoded def __eq__(self, their): if not issubclass(their.__class__, Null): @@ -3196,6 +3389,23 @@ class Null(Obj): yield pp +ObjectIdentifierState = namedtuple("ObjectIdentifierState", ( + "version", + "value", + "tag", + "expl", + "default", + "optional", + "offset", + "llen", + "vlen", + "expl_lenindef", + "lenindef", + "ber_encoded", + "defines", +)) + + class ObjectIdentifier(Obj): """``OBJECT IDENTIFIER`` OID type @@ -3294,21 +3504,37 @@ class ObjectIdentifier(Obj): def ready(self): return self._value is not None - def copy(self): - obj = self.__class__() - obj._value = self._value - obj.defines = self.defines - obj.tag = self.tag - obj._expl = self._expl - obj.default = self.default - obj.optional = self.optional - obj.offset = self.offset - obj.llen = self.llen - obj.vlen = self.vlen - obj.expl_lenindef = self.expl_lenindef - obj.lenindef = self.lenindef - obj.ber_encoded = self.ber_encoded - return obj + def __getstate__(self): + return ObjectIdentifierState( + __version__, + self._value, + self.tag, + self._expl, + self.default, + self.optional, + self.offset, + self.llen, + self.vlen, + self.expl_lenindef, + self.lenindef, + self.ber_encoded, + self.defines, + ) + + def __setstate__(self, state): + super(ObjectIdentifier, self).__setstate__(state) + self._value = state.value + self.tag = state.tag + self._expl = state.expl + self.default = state.default + self.optional = state.optional + self.offset = state.offset + self.llen = state.llen + self.vlen = state.vlen + self.expl_lenindef = state.expl_lenindef + self.lenindef = state.lenindef + self.ber_encoded = state.ber_encoded + self.defines = state.defines def __iter__(self): self._assert_ready() @@ -3542,23 +3768,6 @@ class Enumerated(Integer): raise InvalidValueType((self.__class__, int, str)) return value - def copy(self): - obj = self.__class__(_specs=self.specs) - obj._value = self._value - obj._bound_min = self._bound_min - obj._bound_max = self._bound_max - obj.tag = self.tag - obj._expl = self._expl - obj.default = self.default - obj.optional = self.optional - obj.offset = self.offset - obj.llen = self.llen - obj.vlen = self.vlen - obj.expl_lenindef = self.expl_lenindef - obj.lenindef = self.lenindef - obj.ber_encoded = self.ber_encoded - return obj - def __call__( self, value=None, @@ -3769,6 +3978,12 @@ class NumericString(AllowableCharsMixin, CommonString): return value +PrintableStringState = namedtuple( + "PrintableStringState", + OctetStringState._fields + ("allowable_chars",), +) + + class PrintableString(AllowableCharsMixin, CommonString): """Printable string @@ -3817,10 +4032,15 @@ class PrintableString(AllowableCharsMixin, CommonString): raise DecodeError("non-printable value") return value - def copy(self): - obj = super(PrintableString, self).copy() - obj._allowable_chars = self._allowable_chars - return obj + def __getstate__(self): + return PrintableStringState( + *super(PrintableString, self).__getstate__(), + **{"allowable_chars": self._allowable_chars} + ) + + def __setstate__(self, state): + super(PrintableString, self).__setstate__(state) + self._allowable_chars = state.allowable_chars def __call__( self, @@ -4161,6 +4381,23 @@ class BMPString(CommonString): asn1_type_name = "BMPString" +ChoiceState = namedtuple("ChoiceState", ( + "version", + "specs", + "value", + "tag", + "expl", + "default", + "optional", + "offset", + "llen", + "vlen", + "expl_lenindef", + "lenindef", + "ber_encoded", +)) + + class Choice(Obj): """``CHOICE`` special type @@ -4234,7 +4471,7 @@ class Choice(Obj): default_obj._value = default_value self.default = default_obj if value is None: - self._value = default_obj.copy()._value + self._value = copy(default_obj._value) def _value_sanitize(self, value): if isinstance(value, tuple) and len(value) == 2: @@ -4260,21 +4497,36 @@ class Choice(Obj): self._value[1].bered ) - def copy(self): - obj = self.__class__(schema=self.specs) - obj._expl = self._expl - obj.default = self.default - obj.optional = self.optional - obj.offset = self.offset - obj.llen = self.llen - obj.vlen = self.vlen - obj.expl_lenindef = self.expl_lenindef - obj.lenindef = self.lenindef - obj.ber_encoded = self.ber_encoded - value = self._value - if value is not None: - obj._value = (value[0], value[1].copy()) - return obj + def __getstate__(self): + return ChoiceState( + __version__, + self.specs, + copy(self._value), + self.tag, + self._expl, + self.default, + self.optional, + self.offset, + self.llen, + self.vlen, + self.expl_lenindef, + self.lenindef, + self.ber_encoded, + ) + + def __setstate__(self, state): + super(Choice, self).__setstate__(state) + self.specs = state.specs + self._value = state.value + self._expl = state.expl + self.default = state.default + self.optional = state.optional + self.offset = state.offset + self.llen = state.llen + self.vlen = state.vlen + self.expl_lenindef = state.expl_lenindef + self.lenindef = state.lenindef + self.ber_encoded = state.ber_encoded def __eq__(self, their): if isinstance(their, tuple) and len(their) == 2: @@ -4448,6 +4700,22 @@ class PrimitiveTypes(Choice): )) +AnyState = namedtuple("AnyState", ( + "version", + "value", + "tag", + "expl", + "optional", + "offset", + "llen", + "vlen", + "expl_lenindef", + "lenindef", + "ber_encoded", + "defined", +)) + + class Any(Obj): """``ANY`` special type @@ -4502,19 +4770,35 @@ class Any(Obj): return False return self.defined[1].bered - def copy(self): - obj = self.__class__() - obj._value = self._value - obj.tag = self.tag - obj._expl = self._expl - obj.optional = self.optional - obj.offset = self.offset - obj.llen = self.llen - obj.vlen = self.vlen - obj.expl_lenindef = self.expl_lenindef - obj.lenindef = self.lenindef - obj.ber_encoded = self.ber_encoded - return obj + def __getstate__(self): + return AnyState( + __version__, + self._value, + self.tag, + self._expl, + self.optional, + self.offset, + self.llen, + self.vlen, + self.expl_lenindef, + self.lenindef, + self.ber_encoded, + self.defined, + ) + + def __setstate__(self, state): + super(Any, self).__setstate__(state) + self._value = state.value + self.tag = state.tag + self._expl = state.expl + self.optional = state.optional + self.offset = state.offset + self.llen = state.llen + self.vlen = state.vlen + self.expl_lenindef = state.expl_lenindef + self.lenindef = state.lenindef + self.ber_encoded = state.ber_encoded + self.defined = state.defined def __eq__(self, their): if isinstance(their, binary_type): @@ -4694,6 +4978,23 @@ def abs_decode_path(decode_path, rel_path): return decode_path + rel_path +SequenceState = namedtuple("SequenceState", ( + "version", + "specs", + "value", + "tag", + "expl", + "default", + "optional", + "offset", + "llen", + "vlen", + "expl_lenindef", + "lenindef", + "ber_encoded", +)) + + class Sequence(Obj): """``SEQUENCE`` structure type @@ -4823,7 +5124,7 @@ class Sequence(Obj): default_obj._value = default_value self.default = default_obj if value is None: - self._value = default_obj.copy()._value + self._value = copy(default_obj._value) @property def ready(self): @@ -4843,20 +5144,37 @@ class Sequence(Obj): return True return any(value.bered for value in itervalues(self._value)) - def copy(self): - obj = self.__class__(schema=self.specs) - obj.tag = self.tag - obj._expl = self._expl - obj.default = self.default - obj.optional = self.optional - obj.offset = self.offset - obj.llen = self.llen - obj.vlen = self.vlen - obj.expl_lenindef = self.expl_lenindef - obj.lenindef = self.lenindef - obj.ber_encoded = self.ber_encoded - obj._value = {k: v.copy() for k, v in iteritems(self._value)} - return obj + def __getstate__(self): + return SequenceState( + __version__, + self.specs, + {k: copy(v) for k, v in iteritems(self._value)}, + self.tag, + self._expl, + self.default, + self.optional, + self.offset, + self.llen, + self.vlen, + self.expl_lenindef, + self.lenindef, + self.ber_encoded, + ) + + def __setstate__(self, state): + super(Sequence, self).__setstate__(state) + self.specs = state.specs + self._value = state.value + self.tag = state.tag + self._expl = state.expl + self.default = state.default + self.optional = state.optional + self.offset = state.offset + self.llen = state.llen + self.vlen = state.vlen + self.expl_lenindef = state.expl_lenindef + self.lenindef = state.lenindef + self.ber_encoded = state.ber_encoded def __eq__(self, their): if not isinstance(their, self.__class__): @@ -5323,6 +5641,25 @@ class Set(Sequence): return obj, tail +SequenceOfState = namedtuple("SequenceOfState", ( + "version", + "spec", + "value", + "bound_min", + "bound_max", + "tag", + "expl", + "default", + "optional", + "offset", + "llen", + "vlen", + "expl_lenindef", + "lenindef", + "ber_encoded", +)) + + class SequenceOf(Obj): """``SEQUENCE OF`` sequence type @@ -5392,7 +5729,7 @@ class SequenceOf(Obj): default_obj._value = default_value self.default = default_obj if value is None: - self._value = default_obj.copy()._value + self._value = copy(default_obj._value) def _value_sanitize(self, value): if issubclass(value.__class__, SequenceOf): @@ -5418,22 +5755,41 @@ class SequenceOf(Obj): return True return any(v.bered for v in self._value) - def copy(self): - obj = self.__class__(schema=self.spec) - obj._bound_min = self._bound_min - obj._bound_max = self._bound_max - obj.tag = self.tag - obj._expl = self._expl - obj.default = self.default - obj.optional = self.optional - obj.offset = self.offset - obj.llen = self.llen - obj.vlen = self.vlen - obj.expl_lenindef = self.expl_lenindef - obj.lenindef = self.lenindef - obj.ber_encoded = self.ber_encoded - obj._value = [v.copy() for v in self._value] - return obj + def __getstate__(self): + return SequenceOfState( + __version__, + self.spec, + [copy(v) for v in self._value], + self._bound_min, + self._bound_max, + self.tag, + self._expl, + self.default, + self.optional, + self.offset, + self.llen, + self.vlen, + self.expl_lenindef, + self.lenindef, + self.ber_encoded, + ) + + def __setstate__(self, state): + super(SequenceOf, self).__setstate__(state) + self.spec = state.spec + self._value = state.value + self._bound_min = state.bound_min + self._bound_max = state.bound_max + self.tag = state.tag + self._expl = state.expl + self.default = state.default + self.optional = state.optional + self.offset = state.offset + self.llen = state.llen + self.vlen = state.vlen + self.expl_lenindef = state.expl_lenindef + self.lenindef = state.lenindef + self.ber_encoded = state.ber_encoded def __eq__(self, their): if isinstance(their, self.__class__): diff --git a/tests/test_crts.py b/tests/test_crts.py index 4b58fac..d878529 100644 --- a/tests/test_crts.py +++ b/tests/test_crts.py @@ -15,9 +15,14 @@ # License along with this program. If not, see # . +from copy import copy from datetime import datetime from unittest import TestCase +from six.moves.cPickle import dumps as pickle_dumps +from six.moves.cPickle import HIGHEST_PROTOCOL as pickle_proto +from six.moves.cPickle import loads as pickle_loads + from pyderasn import Any from pyderasn import BitString from pyderasn import Boolean @@ -279,6 +284,7 @@ class TestGoSelfSignedVector(TestCase): self.assertSequenceEqual(crt.encode(), raw) pprint(crt) repr(crt) + pickle_loads(pickle_dumps(crt, pickle_proto)) tbs = TBSCertificate() tbs["serialNumber"] = CertificateSerialNumber(10143011886257155224) @@ -322,7 +328,7 @@ class TestGoSelfSignedVector(TestCase): tbs["validity"] = validity spki = SubjectPublicKeyInfo() - spki_algo_id = sign_algo_id.copy() + spki_algo_id = copy(sign_algo_id) spki_algo_id["algorithm"] = ObjectIdentifier("1.2.840.113549.1.1.1") spki["algorithm"] = spki_algo_id spki["subjectPublicKey"] = BitString(hexdec("".join(( @@ -400,3 +406,4 @@ class TestGoPayPalVector(TestCase): self.assertSequenceEqual(crt.encode(), raw) pprint(crt) repr(crt) + pickle_loads(pickle_dumps(crt, pickle_proto)) diff --git a/tests/test_pyderasn.py b/tests/test_pyderasn.py index aaac6cd..a39d125 100644 --- a/tests/test_pyderasn.py +++ b/tests/test_pyderasn.py @@ -15,12 +15,15 @@ # License along with this program. If not, see # . +from copy import copy from copy import deepcopy from datetime import datetime +from importlib import import_module from string import ascii_letters from string import digits from string import printable from string import whitespace +from time import time from unittest import TestCase from hypothesis import assume @@ -51,6 +54,9 @@ from six import iterbytes from six import PY2 from six import text_type from six import unichr as six_unichr +from six.moves.cPickle import dumps as pickle_dumps +from six.moves.cPickle import HIGHEST_PROTOCOL as pickle_proto +from six.moves.cPickle import loads as pickle_loads from pyderasn import _pp from pyderasn import abs_decode_path @@ -134,6 +140,18 @@ decode_path_strat = lists(integers(), max_size=3).map( lambda decode_path: tuple(str(dp) for dp in decode_path) ) ctx_dummy = dictionaries(integers(), integers(), min_size=2, max_size=4).example() +copy_funcs = ( + copy, + lambda obj: pickle_loads(pickle_dumps(obj, pickle_proto)), +) +self_module = import_module(__name__) + + +def register_class(klass): + klassname = klass.__name__ + str(time()).replace(".", "") + klass.__name__ = klassname + klass.__qualname__ = klassname + setattr(self_module, klassname, klass) def assert_exceeding_data(self, call, junk): @@ -456,8 +474,9 @@ class TestBoolean(CommonMixin, TestCase): def test_copy(self, values): for klass in (Boolean, BooleanInherited): obj = klass(*values) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) @given( booleans(), @@ -633,7 +652,7 @@ class TestBoolean(CommonMixin, TestCase): self.assertTrue(obj.ber_encoded) self.assertFalse(obj.lenindef) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.ber_encoded) self.assertFalse(obj.lenindef) self.assertTrue(obj.bered) @@ -656,7 +675,7 @@ class TestBoolean(CommonMixin, TestCase): self.assertFalse(obj.lenindef) self.assertFalse(obj.ber_encoded) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.expl_lenindef) self.assertFalse(obj.lenindef) self.assertFalse(obj.ber_encoded) @@ -954,12 +973,13 @@ class TestInteger(CommonMixin, TestCase): def test_copy(self, values): for klass in (Integer, IntegerInherited): obj = klass(*values) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj.specs, obj_copied.specs) - self.assertEqual(obj._bound_min, obj_copied._bound_min) - self.assertEqual(obj._bound_max, obj_copied._bound_max) - self.assertEqual(obj._value, obj_copied._value) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj.specs, obj_copied.specs) + self.assertEqual(obj._bound_min, obj_copied._bound_min) + self.assertEqual(obj._bound_max, obj_copied._bound_max) + self.assertEqual(obj._value, obj_copied._value) @given( integers(), @@ -1359,6 +1379,7 @@ class TestBitString(CommonMixin, TestCase): class BS(klass): schema = _schema + register_class(BS) obj = BS( value=value, impl=impl, @@ -1367,10 +1388,11 @@ class TestBitString(CommonMixin, TestCase): optional=optional or False, _decoded=_decoded, ) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj.specs, obj_copied.specs) - self.assertEqual(obj._value, obj_copied._value) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj.specs, obj_copied.specs) + self.assertEqual(obj._value, obj_copied._value) @given( binary(), @@ -1614,7 +1636,7 @@ class TestBitString(CommonMixin, TestCase): self.assertTrue(obj.ber_encoded) self.assertEqual(obj.lenindef, lenindef_expected) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.ber_encoded) self.assertEqual(obj.lenindef, lenindef_expected) self.assertTrue(obj.bered) @@ -1750,7 +1772,7 @@ class TestBitString(CommonMixin, TestCase): self.assertTrue(obj.ber_encoded) self.assertTrue(obj.lenindef) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.ber_encoded) self.assertTrue(obj.lenindef) self.assertTrue(obj.bered) @@ -1947,11 +1969,12 @@ class TestOctetString(CommonMixin, TestCase): def test_copy(self, values): for klass in (OctetString, OctetStringInherited): obj = klass(*values) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj._bound_min, obj_copied._bound_min) - self.assertEqual(obj._bound_max, obj_copied._bound_max) - self.assertEqual(obj._value, obj_copied._value) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj._bound_min, obj_copied._bound_min) + self.assertEqual(obj._bound_max, obj_copied._bound_max) + self.assertEqual(obj._value, obj_copied._value) @given( binary(), @@ -2146,7 +2169,7 @@ class TestOctetString(CommonMixin, TestCase): self.assertTrue(obj.ber_encoded) self.assertEqual(obj.lenindef, lenindef_expected) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.ber_encoded) self.assertEqual(obj.lenindef, lenindef_expected) self.assertTrue(obj.bered) @@ -2286,8 +2309,9 @@ class TestNull(CommonMixin, TestCase): optional=optional or False, _decoded=_decoded, ) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) @given(integers(min_value=1).map(tag_encode)) def test_stripped(self, tag_impl): @@ -2562,9 +2586,10 @@ class TestObjectIdentifier(CommonMixin, TestCase): optional=optional, _decoded=_decoded, ) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj._value, obj_copied._value) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj._value, obj_copied._value) @settings(max_examples=LONG_TEST_MAX_EXAMPLES) @given( @@ -2788,7 +2813,7 @@ class TestObjectIdentifier(CommonMixin, TestCase): obj, _ = ObjectIdentifier().decode(tampered, ctx={"bered": True}) self.assertTrue(obj.ber_encoded) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.ber_encoded) self.assertTrue(obj.bered) with assertRaisesRegex(self, DecodeError, "non normalized arc encoding"): @@ -2821,7 +2846,7 @@ class TestObjectIdentifier(CommonMixin, TestCase): obj, _ = ObjectIdentifier().decode(tampered, ctx={"bered": True}) self.assertTrue(obj.ber_encoded) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.ber_encoded) self.assertTrue(obj.bered) with assertRaisesRegex(self, DecodeError, "non normalized arc encoding"): @@ -3007,6 +3032,7 @@ class TestEnumerated(CommonMixin, TestCase): class E(Enumerated): schema = schema_input + register_class(E) obj = E( value=value, impl=impl, @@ -3015,9 +3041,10 @@ class TestEnumerated(CommonMixin, TestCase): optional=optional, _decoded=_decoded, ) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj.specs, obj_copied.specs) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj.specs, obj_copied.specs) @settings(max_examples=LONG_TEST_MAX_EXAMPLES) @given(data_strategy()) @@ -3281,11 +3308,12 @@ class StringMixin(object): def test_copy(self, d): values = d.draw(string_values_strategy(self.text_alphabet())) obj = self.base_klass(*values) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj._bound_min, obj_copied._bound_min) - self.assertEqual(obj._bound_max, obj_copied._bound_max) - self.assertEqual(obj._value, obj_copied._value) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj._bound_min, obj_copied._bound_min) + self.assertEqual(obj._bound_max, obj_copied._bound_max) + self.assertEqual(obj._value, obj_copied._value) @given(data_strategy()) def test_stripped(self, d): @@ -3518,7 +3546,7 @@ class TestPrintableString( self.base_klass(s, **kwargs) klass = self.base_klass(**kwargs) obj = klass(s) - obj = obj.copy() + obj = copy(obj) obj(s) @@ -3583,7 +3611,7 @@ class TestVisibleString( self.assertTrue(obj.ber_encoded) self.assertFalse(obj.lenindef) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.ber_encoded) self.assertFalse(obj.lenindef) self.assertTrue(obj.bered) @@ -3597,7 +3625,7 @@ class TestVisibleString( self.assertTrue(obj.ber_encoded) self.assertTrue(obj.lenindef) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.ber_encoded) self.assertTrue(obj.lenindef) self.assertTrue(obj.bered) @@ -3781,9 +3809,10 @@ class TimeMixin(object): max_datetime=self.max_datetime, )) obj = self.base_klass(*values) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj._value, obj_copied._value) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj._value, obj_copied._value) @given(data_strategy()) def test_stripped(self, d): @@ -4208,9 +4237,10 @@ class TestAny(CommonMixin, TestCase): def test_copy(self, values): for klass in (Any, AnyInherited): obj = klass(*values) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj._value, obj_copied._value) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj._value, obj_copied._value) @given(binary().map(OctetString)) def test_stripped(self, value): @@ -4352,7 +4382,7 @@ class TestAny(CommonMixin, TestCase): self.assertTrue(obj.lenindef) self.assertFalse(obj.ber_encoded) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.lenindef) self.assertFalse(obj.ber_encoded) self.assertTrue(obj.bered) @@ -4584,6 +4614,7 @@ class TestChoice(CommonMixin, TestCase): class Wahl(self.base_klass): schema = _schema + register_class(Wahl) obj = Wahl( value=value, expl=expl, @@ -4591,15 +4622,17 @@ class TestChoice(CommonMixin, TestCase): optional=optional or False, _decoded=_decoded, ) - obj_copied = obj.copy() - self.assertIsNone(obj.tag) - self.assertIsNone(obj_copied.tag) - # hack for assert_copied_basic_fields - obj.tag = "whatever" - obj_copied.tag = "whatever" - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj._value, obj_copied._value) - self.assertEqual(obj.specs, obj_copied.specs) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assertIsNone(obj.tag) + self.assertIsNone(obj_copied.tag) + # hack for assert_copied_basic_fields + obj.tag = "whatever" + obj_copied.tag = "whatever" + self.assert_copied_basic_fields(obj, obj_copied) + obj.tag = None + self.assertEqual(obj._value, obj_copied._value) + self.assertEqual(obj.specs, obj_copied.specs) @given(booleans()) def test_stripped(self, value): @@ -5035,13 +5068,15 @@ class SeqMixing(object): def test_copy(self, d): class SeqInherited(self.base_klass): pass + register_class(SeqInherited) for klass in (self.base_klass, SeqInherited): values = d.draw(seq_values_strategy(seq_klass=klass)) obj = klass(*values) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj.specs, obj_copied.specs) - self.assertEqual(obj._value, obj_copied._value) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj.specs, obj_copied.specs) + self.assertEqual(obj._value, obj_copied._value) @given(data_strategy()) def test_stripped(self, d): @@ -5170,7 +5205,7 @@ class SeqMixing(object): self.assertDictEqual(ctx_copied, ctx_dummy) self.assertTrue(seq_decoded_lenindef.lenindef) self.assertTrue(seq_decoded_lenindef.bered) - seq_decoded_lenindef = seq_decoded_lenindef.copy() + seq_decoded_lenindef = copy(seq_decoded_lenindef) self.assertTrue(seq_decoded_lenindef.lenindef) self.assertTrue(seq_decoded_lenindef.bered) with self.assertRaises(DecodeError): @@ -5296,7 +5331,7 @@ class SeqMixing(object): seq_decoded, _ = seq_with_default.decode(seq_encoded, ctx=ctx) self.assertTrue(seq_decoded.ber_encoded) self.assertTrue(seq_decoded.bered) - seq_decoded = seq_decoded.copy() + seq_decoded = copy(seq_decoded) self.assertTrue(seq_decoded.ber_encoded) self.assertTrue(seq_decoded.bered) for name, value in _schema: @@ -5336,7 +5371,7 @@ class SeqMixing(object): self.assertFalse(decoded.ber_encoded) self.assertFalse(decoded.lenindef) self.assertTrue(decoded.bered) - decoded = decoded.copy() + decoded = copy(decoded) self.assertFalse(decoded.ber_encoded) self.assertFalse(decoded.lenindef) self.assertTrue(decoded.bered) @@ -5356,7 +5391,7 @@ class SeqMixing(object): self.assertFalse(decoded.ber_encoded) self.assertFalse(decoded.lenindef) self.assertTrue(decoded.bered) - decoded = decoded.copy() + decoded = copy(decoded) self.assertFalse(decoded.ber_encoded) self.assertFalse(decoded.lenindef) self.assertTrue(decoded.bered) @@ -5456,7 +5491,7 @@ class TestSet(SeqMixing, CommonMixin, TestCase): seq_decoded, _ = Seq().decode(seq_encoded, ctx=ctx) self.assertTrue(seq_decoded.ber_encoded) self.assertTrue(seq_decoded.bered) - seq_decoded = seq_decoded.copy() + seq_decoded = copy(seq_decoded) self.assertTrue(seq_decoded.ber_encoded) self.assertTrue(seq_decoded.bered) self.assertSequenceEqual( @@ -5734,6 +5769,7 @@ class SeqOfMixing(object): class SeqOf(self.base_klass): schema = _schema + register_class(SeqOf) obj = SeqOf( value=value, bounds=bounds, @@ -5743,11 +5779,12 @@ class SeqOfMixing(object): optional=optional or False, _decoded=_decoded, ) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj._bound_min, obj_copied._bound_min) - self.assertEqual(obj._bound_max, obj_copied._bound_max) - self.assertEqual(obj._value, obj_copied._value) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj._bound_min, obj_copied._bound_min) + self.assertEqual(obj._bound_max, obj_copied._bound_max) + self.assertEqual(obj._value, obj_copied._value) @given( lists(binary()), @@ -5884,7 +5921,7 @@ class SeqOfMixing(object): ) self.assertTrue(obj_decoded_lenindef.lenindef) self.assertTrue(obj_decoded_lenindef.bered) - obj_decoded_lenindef = obj_decoded_lenindef.copy() + obj_decoded_lenindef = copy(obj_decoded_lenindef) self.assertTrue(obj_decoded_lenindef.lenindef) self.assertTrue(obj_decoded_lenindef.bered) repr(obj_decoded_lenindef) @@ -5915,7 +5952,7 @@ class SeqOfMixing(object): self.assertFalse(decoded.ber_encoded) self.assertFalse(decoded.lenindef) self.assertTrue(decoded.bered) - decoded = decoded.copy() + decoded = copy(decoded) self.assertFalse(decoded.ber_encoded) self.assertFalse(decoded.lenindef) self.assertTrue(decoded.bered) @@ -5936,7 +5973,7 @@ class SeqOfMixing(object): self.assertFalse(decoded.ber_encoded) self.assertFalse(decoded.lenindef) self.assertTrue(decoded.bered) - decoded = decoded.copy() + decoded = copy(decoded) self.assertFalse(decoded.ber_encoded) self.assertFalse(decoded.lenindef) self.assertTrue(decoded.bered) @@ -6005,7 +6042,7 @@ class TestSetOf(SeqOfMixing, CommonMixin, TestCase): seq_decoded, _ = Seq().decode(seq_encoded, ctx=ctx) self.assertTrue(seq_decoded.ber_encoded) self.assertTrue(seq_decoded.bered) - seq_decoded = seq_decoded.copy() + seq_decoded = copy(seq_decoded) self.assertTrue(seq_decoded.ber_encoded) self.assertTrue(seq_decoded.bered) self.assertSequenceEqual( @@ -6516,13 +6553,13 @@ class TestStrictDefaultExistence(TestCase): decoded, _ = seq.decode(raw, ctx={"allow_default_values": True}) self.assertTrue(decoded.ber_encoded) self.assertTrue(decoded.bered) - decoded = decoded.copy() + decoded = copy(decoded) self.assertTrue(decoded.ber_encoded) self.assertTrue(decoded.bered) decoded, _ = seq.decode(raw, ctx={"bered": True}) self.assertTrue(decoded.ber_encoded) self.assertTrue(decoded.bered) - decoded = decoded.copy() + decoded = copy(decoded) self.assertTrue(decoded.ber_encoded) self.assertTrue(decoded.bered) @@ -6574,3 +6611,15 @@ class TestExplOOB(TestCase): with assertRaisesRegex(self, DecodeError, "explicit tag out-of-bound"): Integer(expl=expl).decode(raw) Integer(expl=expl).decode(raw, ctx={"allow_expl_oob": True}) + + +class TestPickleDifferentVersion(TestCase): + def runTest(self): + pickled = pickle_dumps(Integer(123), pickle_proto) + import pyderasn + version_orig = pyderasn.__version__ + pyderasn.__version__ += "different" + with assertRaisesRegex(self, ValueError, "different PyDERASN version"): + pickle_loads(pickled) + pyderasn.__version__ = version_orig + pickle_loads(pickled) -- 2.44.0