X-Git-Url: http://www.git.cypherpunks.ru/?a=blobdiff_plain;f=tests%2Ftest_pyderasn.py;h=29a3cb2da95a79a7a1f5d861d2fbc3c4bcac1697;hb=333d098f0af80eae5481c99b86419d25ea927d22;hp=13aeff48e1ead5dda07499b252214dc294934527;hpb=afc0f9f65430bed928619c783373ae3c6a82be1b;p=pyderasn.git diff --git a/tests/test_pyderasn.py b/tests/test_pyderasn.py index 13aeff4..29a3cb2 100644 --- a/tests/test_pyderasn.py +++ b/tests/test_pyderasn.py @@ -1,11 +1,10 @@ # coding: utf-8 # PyDERASN -- Python ASN.1 DER/BER codec with abstract structures -# Copyright (C) 2017-2019 Sergey Matveev +# Copyright (C) 2017-2020 Sergey Matveev # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. +# published by the Free Software Foundation, version 3 of the License. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of @@ -16,12 +15,15 @@ # License along with this program. If not, see # . +from copy import copy from copy import deepcopy from datetime import datetime +from importlib import import_module from string import ascii_letters from string import digits from string import printable from string import whitespace +from time import time from unittest import TestCase from hypothesis import assume @@ -52,6 +54,9 @@ from six import iterbytes from six import PY2 from six import text_type from six import unichr as six_unichr +from six.moves.cPickle import dumps as pickle_dumps +from six.moves.cPickle import HIGHEST_PROTOCOL as pickle_proto +from six.moves.cPickle import loads as pickle_loads from pyderasn import _pp from pyderasn import abs_decode_path @@ -66,6 +71,7 @@ from pyderasn import DecodePathDefBy from pyderasn import Enumerated from pyderasn import EOC from pyderasn import EOC_LEN +from pyderasn import ExceedingData from pyderasn import GeneralizedTime from pyderasn import GeneralString from pyderasn import GraphicString @@ -134,6 +140,24 @@ decode_path_strat = lists(integers(), max_size=3).map( lambda decode_path: tuple(str(dp) for dp in decode_path) ) ctx_dummy = dictionaries(integers(), integers(), min_size=2, max_size=4).example() +copy_funcs = ( + copy, + lambda obj: pickle_loads(pickle_dumps(obj, pickle_proto)), +) +self_module = import_module(__name__) + + +def register_class(klass): + klassname = klass.__name__ + str(time()).replace(".", "") + klass.__name__ = klassname + klass.__qualname__ = klassname + setattr(self_module, klassname, klass) + + +def assert_exceeding_data(self, call, junk): + if len(junk) > 0: + with assertRaisesRegex(self, ExceedingData, "%d trailing bytes" % len(junk)): + call() class TestHex(TestCase): @@ -450,8 +474,9 @@ class TestBoolean(CommonMixin, TestCase): def test_copy(self, values): for klass in (Boolean, BooleanInherited): obj = klass(*values) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) @given( booleans(), @@ -562,10 +587,10 @@ class TestBoolean(CommonMixin, TestCase): repr(obj_expled) list(obj_expled.pps()) pprint(obj_expled, big_blobs=True, with_decode_path=True) - obj_expled_encoded = obj_expled.encode() + obj_expled_hex_encoded = obj_expled.hexencode() ctx_copied = deepcopy(ctx_dummy) - obj_decoded, tail = obj_expled.decode( - obj_expled_encoded + tail_junk, + obj_decoded, tail = obj_expled.hexdecode( + obj_expled_hex_encoded + hexenc(tail_junk), offset=offset, ctx=ctx_copied, ) @@ -578,7 +603,7 @@ class TestBoolean(CommonMixin, TestCase): self.assertNotEqual(obj_decoded, obj) self.assertEqual(bool(obj_decoded), bool(obj_expled)) self.assertEqual(bool(obj_decoded), bool(obj)) - self.assertSequenceEqual(obj_decoded.encode(), obj_expled_encoded) + self.assertSequenceEqual(obj_decoded.hexencode(), obj_expled_hex_encoded) self.assertSequenceEqual(obj_decoded.expl_tag, tag_expl) self.assertEqual(obj_decoded.expl_tlen, len(tag_expl)) self.assertEqual( @@ -592,6 +617,11 @@ class TestBoolean(CommonMixin, TestCase): offset + obj_decoded.expl_tlen + obj_decoded.expl_llen, ) self.assertEqual(obj_decoded.expl_offset, offset) + assert_exceeding_data( + self, + lambda: obj_expled.hexdecod(obj_expled_hex_encoded + hexenc(tail_junk)), + tail_junk, + ) @given(integers(min_value=2)) def test_invalid_len(self, l): @@ -622,7 +652,7 @@ class TestBoolean(CommonMixin, TestCase): self.assertTrue(obj.ber_encoded) self.assertFalse(obj.lenindef) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.ber_encoded) self.assertFalse(obj.lenindef) self.assertTrue(obj.bered) @@ -645,7 +675,7 @@ class TestBoolean(CommonMixin, TestCase): self.assertFalse(obj.lenindef) self.assertFalse(obj.ber_encoded) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.expl_lenindef) self.assertFalse(obj.lenindef) self.assertFalse(obj.ber_encoded) @@ -943,12 +973,13 @@ class TestInteger(CommonMixin, TestCase): def test_copy(self, values): for klass in (Integer, IntegerInherited): obj = klass(*values) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj.specs, obj_copied.specs) - self.assertEqual(obj._bound_min, obj_copied._bound_min) - self.assertEqual(obj._bound_max, obj_copied._bound_max) - self.assertEqual(obj._value, obj_copied._value) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj.specs, obj_copied.specs) + self.assertEqual(obj._bound_min, obj_copied._bound_min) + self.assertEqual(obj._bound_max, obj_copied._bound_max) + self.assertEqual(obj._value, obj_copied._value) @given( integers(), @@ -1084,6 +1115,11 @@ class TestInteger(CommonMixin, TestCase): offset + obj_decoded.expl_tlen + obj_decoded.expl_llen, ) self.assertEqual(obj_decoded.expl_offset, offset) + assert_exceeding_data( + self, + lambda: obj_expled.decod(obj_expled_encoded + tail_junk), + tail_junk, + ) def test_go_vectors_valid(self): for data, expect in (( @@ -1138,7 +1174,7 @@ def bit_string_values_strategy(draw, schema=None, value_required=False, do_expl= def _value(value_required): if not value_required and draw(booleans()): - return + return None generation_choice = 0 if value_required: generation_choice = draw(sampled_from((1, 2, 3))) @@ -1147,9 +1183,9 @@ def bit_string_values_strategy(draw, schema=None, value_required=False, do_expl= sampled_from(("0", "1")), max_size=len(schema), ))) - elif generation_choice == 2 or draw(booleans()): + if generation_choice == 2 or draw(booleans()): return draw(binary(max_size=len(schema) // 8)) - elif generation_choice == 3 or draw(booleans()): + if generation_choice == 3 or draw(booleans()): return tuple(draw(lists(sampled_from([name for name, _ in schema])))) return None value = _value(value_required) @@ -1343,6 +1379,7 @@ class TestBitString(CommonMixin, TestCase): class BS(klass): schema = _schema + register_class(BS) obj = BS( value=value, impl=impl, @@ -1351,10 +1388,11 @@ class TestBitString(CommonMixin, TestCase): optional=optional or False, _decoded=_decoded, ) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj.specs, obj_copied.specs) - self.assertEqual(obj._value, obj_copied._value) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj.specs, obj_copied.specs) + self.assertEqual(obj._value, obj_copied._value) @given( binary(), @@ -1474,6 +1512,11 @@ class TestBitString(CommonMixin, TestCase): self.assertSetEqual(set(value), set(obj_decoded.named)) for name in value: obj_decoded[name] + assert_exceeding_data( + self, + lambda: obj_expled.decod(obj_expled_encoded + tail_junk), + tail_junk, + ) @given(integers(min_value=1, max_value=255)) def test_bad_zero_value(self, pad_size): @@ -1593,7 +1636,7 @@ class TestBitString(CommonMixin, TestCase): self.assertTrue(obj.ber_encoded) self.assertEqual(obj.lenindef, lenindef_expected) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.ber_encoded) self.assertEqual(obj.lenindef, lenindef_expected) self.assertTrue(obj.bered) @@ -1729,7 +1772,7 @@ class TestBitString(CommonMixin, TestCase): self.assertTrue(obj.ber_encoded) self.assertTrue(obj.lenindef) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.ber_encoded) self.assertTrue(obj.lenindef) self.assertTrue(obj.bered) @@ -1926,11 +1969,12 @@ class TestOctetString(CommonMixin, TestCase): def test_copy(self, values): for klass in (OctetString, OctetStringInherited): obj = klass(*values) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj._bound_min, obj_copied._bound_min) - self.assertEqual(obj._bound_max, obj_copied._bound_max) - self.assertEqual(obj._value, obj_copied._value) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj._bound_min, obj_copied._bound_min) + self.assertEqual(obj._bound_max, obj_copied._bound_max) + self.assertEqual(obj._value, obj_copied._value) @given( binary(), @@ -2059,6 +2103,11 @@ class TestOctetString(CommonMixin, TestCase): offset + obj_decoded.expl_tlen + obj_decoded.expl_llen, ) self.assertEqual(obj_decoded.expl_offset, offset) + assert_exceeding_data( + self, + lambda: obj_expled.decod(obj_expled_encoded + tail_junk), + tail_junk, + ) @given( integers(min_value=1, max_value=30), @@ -2120,7 +2169,7 @@ class TestOctetString(CommonMixin, TestCase): self.assertTrue(obj.ber_encoded) self.assertEqual(obj.lenindef, lenindef_expected) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.ber_encoded) self.assertEqual(obj.lenindef, lenindef_expected) self.assertTrue(obj.bered) @@ -2260,8 +2309,9 @@ class TestNull(CommonMixin, TestCase): optional=optional or False, _decoded=_decoded, ) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) @given(integers(min_value=1).map(tag_encode)) def test_stripped(self, tag_impl): @@ -2361,6 +2411,11 @@ class TestNull(CommonMixin, TestCase): offset + obj_decoded.expl_tlen + obj_decoded.expl_llen, ) self.assertEqual(obj_decoded.expl_offset, offset) + assert_exceeding_data( + self, + lambda: obj_expled.decod(obj_expled_encoded + tail_junk), + tail_junk, + ) @given(integers(min_value=1)) def test_invalid_len(self, l): @@ -2531,9 +2586,10 @@ class TestObjectIdentifier(CommonMixin, TestCase): optional=optional, _decoded=_decoded, ) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj._value, obj_copied._value) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj._value, obj_copied._value) @settings(max_examples=LONG_TEST_MAX_EXAMPLES) @given( @@ -2699,6 +2755,11 @@ class TestObjectIdentifier(CommonMixin, TestCase): offset + obj_decoded.expl_tlen + obj_decoded.expl_llen, ) self.assertEqual(obj_decoded.expl_offset, offset) + assert_exceeding_data( + self, + lambda: obj_expled.decod(obj_expled_encoded + tail_junk), + tail_junk, + ) @given( oid_strategy().map(ObjectIdentifier), @@ -2742,8 +2803,7 @@ class TestObjectIdentifier(CommonMixin, TestCase): ObjectIdentifier((2, 999, 3)), ) - @given(data_strategy()) - def test_nonnormalized_first_arc(self, d): + def test_nonnormalized_first_arc(self): tampered = ( ObjectIdentifier.tag_default + len_encode(2) + @@ -2753,12 +2813,34 @@ class TestObjectIdentifier(CommonMixin, TestCase): obj, _ = ObjectIdentifier().decode(tampered, ctx={"bered": True}) self.assertTrue(obj.ber_encoded) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.ber_encoded) self.assertTrue(obj.bered) with assertRaisesRegex(self, DecodeError, "non normalized arc encoding"): ObjectIdentifier().decode(tampered) + @given(data_strategy()) + def test_negative_arcs(self, d): + oid = list(d.draw(oid_strategy())) + if len(oid) == 2: + return + idx = d.draw(integers(min_value=3, max_value=len(oid))) + oid[idx - 1] *= -1 + if oid[idx - 1] == 0: + oid[idx - 1] = -1 + with self.assertRaises(InvalidOID): + ObjectIdentifier(tuple(oid)) + with self.assertRaises(InvalidOID): + ObjectIdentifier(".".join(str(i) for i in oid)) + + @given(data_strategy()) + def test_plused_arcs(self, d): + oid = [str(arc) for arc in d.draw(oid_strategy())] + idx = d.draw(integers(min_value=0, max_value=len(oid))) + oid[idx - 1] = "+" + oid[idx - 1] + with self.assertRaises(InvalidOID): + ObjectIdentifier(".".join(str(i) for i in oid)) + @given(data_strategy()) def test_nonnormalized_arcs(self, d): arcs = d.draw(lists( @@ -2767,8 +2849,8 @@ class TestObjectIdentifier(CommonMixin, TestCase): max_size=5, )) dered = ObjectIdentifier((1, 0) + tuple(arcs)).encode() - _, tlen, lv = tag_strip(dered) - _, llen, v = len_decode(lv) + _, _, lv = tag_strip(dered) + _, _, v = len_decode(lv) v_no_first_arc = v[1:] idx_for_tamper = d.draw(integers( min_value=0, @@ -2786,7 +2868,7 @@ class TestObjectIdentifier(CommonMixin, TestCase): obj, _ = ObjectIdentifier().decode(tampered, ctx={"bered": True}) self.assertTrue(obj.ber_encoded) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.ber_encoded) self.assertTrue(obj.bered) with assertRaisesRegex(self, DecodeError, "non normalized arc encoding"): @@ -2972,6 +3054,7 @@ class TestEnumerated(CommonMixin, TestCase): class E(Enumerated): schema = schema_input + register_class(E) obj = E( value=value, impl=impl, @@ -2980,9 +3063,10 @@ class TestEnumerated(CommonMixin, TestCase): optional=optional, _decoded=_decoded, ) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj.specs, obj_copied.specs) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj.specs, obj_copied.specs) @settings(max_examples=LONG_TEST_MAX_EXAMPLES) @given(data_strategy()) @@ -3043,6 +3127,11 @@ class TestEnumerated(CommonMixin, TestCase): offset + obj_decoded.expl_tlen + obj_decoded.expl_llen, ) self.assertEqual(obj_decoded.expl_offset, offset) + assert_exceeding_data( + self, + lambda: obj_expled.decod(obj_expled_encoded + tail_junk), + tail_junk, + ) @composite @@ -3241,11 +3330,12 @@ class StringMixin(object): def test_copy(self, d): values = d.draw(string_values_strategy(self.text_alphabet())) obj = self.base_klass(*values) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj._bound_min, obj_copied._bound_min) - self.assertEqual(obj._bound_max, obj_copied._bound_max) - self.assertEqual(obj._value, obj_copied._value) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj._bound_min, obj_copied._bound_min) + self.assertEqual(obj._bound_max, obj_copied._bound_max) + self.assertEqual(obj._value, obj_copied._value) @given(data_strategy()) def test_stripped(self, d): @@ -3373,6 +3463,11 @@ class StringMixin(object): offset + obj_decoded.expl_tlen + obj_decoded.expl_llen, ) self.assertEqual(obj_decoded.expl_offset, offset) + assert_exceeding_data( + self, + lambda: obj_expled.decod(obj_expled_encoded + tail_junk), + tail_junk, + ) class TestUTF8String(StringMixin, CommonMixin, TestCase): @@ -3461,6 +3556,29 @@ class TestPrintableString( self.assertEqual(err.exception.offset, offset) self.assertEqual(err.exception.decode_path, decode_path) + def test_allowable_invalid_chars(self): + for c, kwargs in ( + ("*", {"allow_asterisk": True}), + ("&", {"allow_ampersand": True}), + ("&*", {"allow_asterisk": True, "allow_ampersand": True}), + ): + s = "hello invalid" + obj = self.base_klass(s) + for prop in kwargs.keys(): + self.assertFalse(getattr(obj, prop)) + s += c + with assertRaisesRegex(self, DecodeError, "non-printable"): + self.base_klass(s) + self.base_klass(s, **kwargs) + klass = self.base_klass(**kwargs) + obj = klass(s) + for prop in kwargs.keys(): + self.assertTrue(getattr(obj, prop)) + obj = copy(obj) + obj(s) + for prop in kwargs.keys(): + self.assertTrue(getattr(obj, prop)) + class TestTeletexString( UnicodeDecodeErrorMixin, @@ -3523,7 +3641,7 @@ class TestVisibleString( self.assertTrue(obj.ber_encoded) self.assertFalse(obj.lenindef) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.ber_encoded) self.assertFalse(obj.lenindef) self.assertTrue(obj.bered) @@ -3537,7 +3655,7 @@ class TestVisibleString( self.assertTrue(obj.ber_encoded) self.assertTrue(obj.lenindef) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.ber_encoded) self.assertTrue(obj.lenindef) self.assertTrue(obj.bered) @@ -3619,7 +3737,10 @@ class TimeMixin(object): with self.assertRaises(ObjNotReady) as err: obj.encode() repr(err.exception) - value = d.draw(datetimes(min_value=self.min_datetime)) + value = d.draw(datetimes( + min_value=self.min_datetime, + max_value=self.max_datetime, + )) obj = self.base_klass(value) self.assertTrue(obj.ready) repr(obj) @@ -3721,9 +3842,10 @@ class TimeMixin(object): max_datetime=self.max_datetime, )) obj = self.base_klass(*values) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj._value, obj_copied._value) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj._value, obj_copied._value) @given(data_strategy()) def test_stripped(self, d): @@ -3773,6 +3895,7 @@ class TimeMixin(object): pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.additional_symmetric_check(value, obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) repr(obj_expled) @@ -3807,6 +3930,11 @@ class TimeMixin(object): offset + obj_decoded.expl_tlen + obj_decoded.expl_llen, ) self.assertEqual(obj_decoded.expl_offset, offset) + assert_exceeding_data( + self, + lambda: obj_expled.decod(obj_expled_encoded + tail_junk), + tail_junk, + ) class TestGeneralizedTime(TimeMixin, CommonMixin, TestCase): @@ -3815,6 +3943,28 @@ class TestGeneralizedTime(TimeMixin, CommonMixin, TestCase): min_datetime = datetime(1900, 1, 1) max_datetime = datetime(9999, 12, 31) + def additional_symmetric_check(self, value, obj_encoded): + if value.microsecond > 0: + self.assertFalse(obj_encoded.endswith(b"0Z")) + + def test_x690_vector_valid(self): + for data in (( + b"19920521000000Z", + b"19920622123421Z", + b"19920722132100.3Z", + )): + GeneralizedTime(data) + + def test_x690_vector_invalid(self): + for data in (( + b"19920520240000Z", + b"19920622123421.0Z", + b"19920722132100.30Z", + )): + with self.assertRaises(DecodeError) as err: + GeneralizedTime(data) + repr(err.exception) + def test_go_vectors_invalid(self): for data in (( b"20100102030405", @@ -3891,6 +4041,11 @@ class TestGeneralizedTime(TimeMixin, CommonMixin, TestCase): junk ) + def test_ns_fractions(self): + GeneralizedTime(b"20010101000000.000001Z") + with assertRaisesRegex(self, DecodeError, "only microsecond fractions"): + GeneralizedTime(b"20010101000000.0000001Z") + class TestUTCTime(TimeMixin, CommonMixin, TestCase): base_klass = UTCTime @@ -3898,6 +4053,26 @@ class TestUTCTime(TimeMixin, CommonMixin, TestCase): min_datetime = datetime(2000, 1, 1) max_datetime = datetime(2049, 12, 31) + def additional_symmetric_check(self, value, obj_encoded): + pass + + def test_x690_vector_valid(self): + for data in (( + b"920521000000Z", + b"920622123421Z", + b"920722132100Z", + )): + UTCTime(data) + + def test_x690_vector_invalid(self): + for data in (( + b"920520240000Z", + b"9207221321Z", + )): + with self.assertRaises(DecodeError) as err: + UTCTime(data) + repr(err.exception) + def test_go_vectors_invalid(self): for data in (( b"a10506234540Z", @@ -4095,9 +4270,10 @@ class TestAny(CommonMixin, TestCase): def test_copy(self, values): for klass in (Any, AnyInherited): obj = klass(*values) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj._value, obj_copied._value) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj._value, obj_copied._value) @given(binary().map(OctetString)) def test_stripped(self, value): @@ -4200,6 +4376,11 @@ class TestAny(CommonMixin, TestCase): self.assertEqual(obj_decoded.tlen, 0) self.assertEqual(obj_decoded.llen, 0) self.assertEqual(obj_decoded.vlen, len(value)) + assert_exceeding_data( + self, + lambda: obj_expled.decod(obj_expled_encoded + tail_junk), + tail_junk, + ) @given( integers(min_value=1).map(tag_ctxc), @@ -4234,7 +4415,7 @@ class TestAny(CommonMixin, TestCase): self.assertTrue(obj.lenindef) self.assertFalse(obj.ber_encoded) self.assertTrue(obj.bered) - obj = obj.copy() + obj = copy(obj) self.assertTrue(obj.lenindef) self.assertFalse(obj.ber_encoded) self.assertTrue(obj.bered) @@ -4466,6 +4647,7 @@ class TestChoice(CommonMixin, TestCase): class Wahl(self.base_klass): schema = _schema + register_class(Wahl) obj = Wahl( value=value, expl=expl, @@ -4473,15 +4655,17 @@ class TestChoice(CommonMixin, TestCase): optional=optional or False, _decoded=_decoded, ) - obj_copied = obj.copy() - self.assertIsNone(obj.tag) - self.assertIsNone(obj_copied.tag) - # hack for assert_copied_basic_fields - obj.tag = "whatever" - obj_copied.tag = "whatever" - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj._value, obj_copied._value) - self.assertEqual(obj.specs, obj_copied.specs) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assertIsNone(obj.tag) + self.assertIsNone(obj_copied.tag) + # hack for assert_copied_basic_fields + obj.tag = "whatever" + obj_copied.tag = "whatever" + self.assert_copied_basic_fields(obj, obj_copied) + obj.tag = None + self.assertEqual(obj._value, obj_copied._value) + self.assertEqual(obj.specs, obj_copied.specs) @given(booleans()) def test_stripped(self, value): @@ -4564,6 +4748,11 @@ class TestChoice(CommonMixin, TestCase): ], obj_encoded, ) + assert_exceeding_data( + self, + lambda: obj_expled.decod(obj_expled_encoded + tail_junk), + tail_junk, + ) @given(integers()) def test_set_get(self, value): @@ -4623,15 +4812,13 @@ def seq_values_strategy(draw, seq_klass, do_expl=False): value = None if draw(booleans()): value = seq_klass() - value._value = { - k: v for k, v in draw(dictionaries( - integers(), - one_of( - booleans().map(Boolean), - integers().map(Integer), - ), - )).items() - } + value._value = draw(dictionaries( + integers(), + one_of( + booleans().map(Boolean), + integers().map(Integer), + ), + )) schema = None if draw(booleans()): schema = list(draw(dictionaries( @@ -4650,15 +4837,13 @@ def seq_values_strategy(draw, seq_klass, do_expl=False): default = None if draw(booleans()): default = seq_klass() - default._value = { - k: v for k, v in draw(dictionaries( - integers(), - one_of( - booleans().map(Boolean), - integers().map(Integer), - ), - )).items() - } + default._value = draw(dictionaries( + integers(), + one_of( + booleans().map(Boolean), + integers().map(Integer), + ), + )) optional = draw(one_of(none(), booleans())) _decoded = ( draw(integers(min_value=0)), @@ -4916,13 +5101,15 @@ class SeqMixing(object): def test_copy(self, d): class SeqInherited(self.base_klass): pass + register_class(SeqInherited) for klass in (self.base_klass, SeqInherited): values = d.draw(seq_values_strategy(seq_klass=klass)) obj = klass(*values) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj.specs, obj_copied.specs) - self.assertEqual(obj._value, obj_copied._value) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj.specs, obj_copied.specs) + self.assertEqual(obj._value, obj_copied._value) @given(data_strategy()) def test_stripped(self, d): @@ -5051,7 +5238,7 @@ class SeqMixing(object): self.assertDictEqual(ctx_copied, ctx_dummy) self.assertTrue(seq_decoded_lenindef.lenindef) self.assertTrue(seq_decoded_lenindef.bered) - seq_decoded_lenindef = seq_decoded_lenindef.copy() + seq_decoded_lenindef = copy(seq_decoded_lenindef) self.assertTrue(seq_decoded_lenindef.lenindef) self.assertTrue(seq_decoded_lenindef.bered) with self.assertRaises(DecodeError): @@ -5086,6 +5273,12 @@ class SeqMixing(object): obj.encode(), ) + assert_exceeding_data( + self, + lambda: seq.decod(seq_encoded_lenindef + tail_junk, ctx={"bered": True}), + tail_junk, + ) + @settings(max_examples=LONG_TEST_MAX_EXAMPLES) @given(data_strategy()) def test_symmetric_with_seq(self, d): @@ -5171,7 +5364,7 @@ class SeqMixing(object): seq_decoded, _ = seq_with_default.decode(seq_encoded, ctx=ctx) self.assertTrue(seq_decoded.ber_encoded) self.assertTrue(seq_decoded.bered) - seq_decoded = seq_decoded.copy() + seq_decoded = copy(seq_decoded) self.assertTrue(seq_decoded.ber_encoded) self.assertTrue(seq_decoded.bered) for name, value in _schema: @@ -5202,8 +5395,7 @@ class SeqMixing(object): with self.assertRaises(TagMismatch): seq_missing.decode(seq_encoded) - @given(data_strategy()) - def test_bered(self, d): + def test_bered(self): class Seq(self.base_klass): schema = (("underlying", Boolean()),) encoded = Boolean.tag_default + len_encode(1) + b"\x01" @@ -5212,7 +5404,7 @@ class SeqMixing(object): self.assertFalse(decoded.ber_encoded) self.assertFalse(decoded.lenindef) self.assertTrue(decoded.bered) - decoded = decoded.copy() + decoded = copy(decoded) self.assertFalse(decoded.ber_encoded) self.assertFalse(decoded.lenindef) self.assertTrue(decoded.bered) @@ -5232,7 +5424,7 @@ class SeqMixing(object): self.assertFalse(decoded.ber_encoded) self.assertFalse(decoded.lenindef) self.assertTrue(decoded.bered) - decoded = decoded.copy() + decoded = copy(decoded) self.assertFalse(decoded.ber_encoded) self.assertFalse(decoded.lenindef) self.assertTrue(decoded.bered) @@ -5332,12 +5524,12 @@ class TestSet(SeqMixing, CommonMixin, TestCase): seq_decoded, _ = Seq().decode(seq_encoded, ctx=ctx) self.assertTrue(seq_decoded.ber_encoded) self.assertTrue(seq_decoded.bered) - seq_decoded = seq_decoded.copy() + seq_decoded = copy(seq_decoded) self.assertTrue(seq_decoded.ber_encoded) self.assertTrue(seq_decoded.bered) self.assertSequenceEqual( [bytes(seq_decoded[str(i)]) for i, t in enumerate(tags)], - [t for t in tags], + tags, ) @@ -5610,6 +5802,7 @@ class SeqOfMixing(object): class SeqOf(self.base_klass): schema = _schema + register_class(SeqOf) obj = SeqOf( value=value, bounds=bounds, @@ -5619,11 +5812,12 @@ class SeqOfMixing(object): optional=optional or False, _decoded=_decoded, ) - obj_copied = obj.copy() - self.assert_copied_basic_fields(obj, obj_copied) - self.assertEqual(obj._bound_min, obj_copied._bound_min) - self.assertEqual(obj._bound_max, obj_copied._bound_max) - self.assertEqual(obj._value, obj_copied._value) + for copy_func in copy_funcs: + obj_copied = copy_func(obj) + self.assert_copied_basic_fields(obj, obj_copied) + self.assertEqual(obj._bound_min, obj_copied._bound_min) + self.assertEqual(obj._bound_max, obj_copied._bound_max) + self.assertEqual(obj._value, obj_copied._value) @given( lists(binary()), @@ -5760,20 +5954,26 @@ class SeqOfMixing(object): ) self.assertTrue(obj_decoded_lenindef.lenindef) self.assertTrue(obj_decoded_lenindef.bered) - obj_decoded_lenindef = obj_decoded_lenindef.copy() + obj_decoded_lenindef = copy(obj_decoded_lenindef) self.assertTrue(obj_decoded_lenindef.lenindef) self.assertTrue(obj_decoded_lenindef.bered) repr(obj_decoded_lenindef) list(obj_decoded_lenindef.pps()) pprint(obj_decoded_lenindef, big_blobs=True, with_decode_path=True) + self.assertEqual(tail_lenindef, tail_junk) self.assertEqual(obj_decoded_lenindef.tlvlen, len(obj_encoded_lenindef)) with self.assertRaises(DecodeError): obj.decode(obj_encoded_lenindef[:-1], ctx={"bered": True}) with self.assertRaises(DecodeError): obj.decode(obj_encoded_lenindef[:-2], ctx={"bered": True}) - @given(data_strategy()) - def test_bered(self, d): + assert_exceeding_data( + self, + lambda: obj_expled.decod(obj_expled_encoded + tail_junk), + tail_junk, + ) + + def test_bered(self): class SeqOf(self.base_klass): schema = Boolean() encoded = Boolean(False).encode() @@ -5785,7 +5985,7 @@ class SeqOfMixing(object): self.assertFalse(decoded.ber_encoded) self.assertFalse(decoded.lenindef) self.assertTrue(decoded.bered) - decoded = decoded.copy() + decoded = copy(decoded) self.assertFalse(decoded.ber_encoded) self.assertFalse(decoded.lenindef) self.assertTrue(decoded.bered) @@ -5806,7 +6006,7 @@ class SeqOfMixing(object): self.assertFalse(decoded.ber_encoded) self.assertFalse(decoded.lenindef) self.assertTrue(decoded.bered) - decoded = decoded.copy() + decoded = copy(decoded) self.assertFalse(decoded.ber_encoded) self.assertFalse(decoded.lenindef) self.assertTrue(decoded.bered) @@ -5875,7 +6075,7 @@ class TestSetOf(SeqOfMixing, CommonMixin, TestCase): seq_decoded, _ = Seq().decode(seq_encoded, ctx=ctx) self.assertTrue(seq_decoded.ber_encoded) self.assertTrue(seq_decoded.bered) - seq_decoded = seq_decoded.copy() + seq_decoded = copy(seq_decoded) self.assertTrue(seq_decoded.ber_encoded) self.assertTrue(seq_decoded.bered) self.assertSequenceEqual( @@ -6081,7 +6281,10 @@ class TestPP(TestCase): chosen_id = oids[chosen] pp = _pp(asn1_type_name=ObjectIdentifier.asn1_type_name, value=chosen) self.assertNotIn(chosen_id, pp_console_row(pp)) - self.assertIn(chosen_id, pp_console_row(pp, oids=oids)) + self.assertIn( + chosen_id, + pp_console_row(pp, oid_maps=[{'whatever': 'whenever'}, oids]), + ) class TestAutoAddSlots(TestCase): @@ -6383,13 +6586,13 @@ class TestStrictDefaultExistence(TestCase): decoded, _ = seq.decode(raw, ctx={"allow_default_values": True}) self.assertTrue(decoded.ber_encoded) self.assertTrue(decoded.bered) - decoded = decoded.copy() + decoded = copy(decoded) self.assertTrue(decoded.ber_encoded) self.assertTrue(decoded.bered) decoded, _ = seq.decode(raw, ctx={"bered": True}) self.assertTrue(decoded.ber_encoded) self.assertTrue(decoded.bered) - decoded = decoded.copy() + decoded = copy(decoded) self.assertTrue(decoded.ber_encoded) self.assertTrue(decoded.bered) @@ -6441,3 +6644,15 @@ class TestExplOOB(TestCase): with assertRaisesRegex(self, DecodeError, "explicit tag out-of-bound"): Integer(expl=expl).decode(raw) Integer(expl=expl).decode(raw, ctx={"allow_expl_oob": True}) + + +class TestPickleDifferentVersion(TestCase): + def runTest(self): + pickled = pickle_dumps(Integer(123), pickle_proto) + import pyderasn + version_orig = pyderasn.__version__ + pyderasn.__version__ += "different" + with assertRaisesRegex(self, ValueError, "different PyDERASN version"): + pickle_loads(pickled) + pyderasn.__version__ = version_orig + pickle_loads(pickled)