X-Git-Url: http://www.git.cypherpunks.ru/?p=pyderasn.git;a=blobdiff_plain;f=tests%2Ftest_pyderasn.py;h=309bc2290df05440bccc2b63d65a53547100e852;hp=3161102997daf57449b26c20194b1081613f8b65;hb=2e6117387cfb10eca87e9846498a9a045f05dba3;hpb=25f29a82ccf3a411032a89cee111edd87f07ad3e diff --git a/tests/test_pyderasn.py b/tests/test_pyderasn.py index 3161102..309bc22 100644 --- a/tests/test_pyderasn.py +++ b/tests/test_pyderasn.py @@ -345,6 +345,9 @@ class CommonMixin(object): obj = Inherited() self.assertSequenceEqual(obj.impl, impl_tag) self.assertFalse(obj.expled) + if obj.ready: + tag_class, _, tag_num = tag_decode(impl_tag) + self.assertEqual(obj.tag_order, (tag_class, tag_num)) @given(binary(min_size=1)) def test_expl_inherited(self, expl_tag): @@ -353,6 +356,9 @@ class CommonMixin(object): obj = Inherited() self.assertSequenceEqual(obj.expl, expl_tag) self.assertTrue(obj.expled) + if obj.ready: + tag_class, _, tag_num = tag_decode(expl_tag) + self.assertEqual(obj.tag_order, (tag_class, tag_num)) def assert_copied_basic_fields(self, obj, obj_copied): self.assertEqual(obj, obj_copied) @@ -363,6 +369,8 @@ class CommonMixin(object): self.assertEqual(obj.offset, obj_copied.offset) self.assertEqual(obj.llen, obj_copied.llen) self.assertEqual(obj.vlen, obj_copied.vlen) + if obj.ready: + self.assertEqual(obj.tag_order, obj_copied.tag_order) @composite @@ -4667,9 +4675,16 @@ class TestUTCTime(TimeMixin, CommonMixin, TestCase): ) +@composite +def tlv_value_strategy(draw): + tag_num = draw(integers(min_value=1)) + data = draw(binary()) + return b"".join((tag_encode(tag_num), len_encode(len(data)), data)) + + @composite def any_values_strategy(draw, do_expl=False): - value = draw(one_of(none(), binary())) + value = draw(one_of(none(), tlv_value_strategy())) expl = None if do_expl: expl = draw(one_of(none(), integers(min_value=1).map(tag_encode))) @@ -4699,7 +4714,7 @@ class TestAny(CommonMixin, TestCase): obj = Any(optional=optional) self.assertEqual(obj.optional, optional) - @given(binary()) + @given(tlv_value_strategy()) def test_ready(self, value): obj = Any() self.assertFalse(obj.ready) @@ -4733,7 +4748,7 @@ class TestAny(CommonMixin, TestCase): pprint(obj, big_blobs=True, with_decode_path=True) self.assertSequenceEqual(obj.encode(), integer_encoded) - @given(binary(min_size=1), binary(min_size=1)) + @given(tlv_value_strategy(), tlv_value_strategy()) def test_comparison(self, value1, value2): for klass in (Any, AnyInherited): obj1 = klass(value1) @@ -4797,7 +4812,7 @@ class TestAny(CommonMixin, TestCase): obj.decode(obj.encode()[:-1]) @given( - binary(), + tlv_value_strategy(), integers(min_value=1).map(tag_ctxc), ) def test_stripped_expl(self, value, tag_expl): @@ -4853,9 +4868,13 @@ class TestAny(CommonMixin, TestCase): list(obj.pps()) pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) + tag_class, _, tag_num = tag_decode(tag_strip(value)[0]) + self.assertEqual(obj.tag_order, (tag_class, tag_num)) obj_encoded = obj.encode() obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) + tag_class, _, tag_num = tag_decode(tag_expl) + self.assertEqual(obj_expled.tag_order, (tag_class, tag_num)) repr(obj_expled) list(obj_expled.pps()) pprint(obj_expled, big_blobs=True, with_decode_path=True) @@ -5219,9 +5238,12 @@ class TestChoice(CommonMixin, TestCase): list(obj.pps()) pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) + self.assertEqual(obj.tag_order, obj.value.tag_order) obj_encoded = obj.encode() obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) + tag_class, _, tag_num = tag_decode(tag_expl) + self.assertEqual(obj_expled.tag_order, (tag_class, tag_num)) repr(obj_expled) list(obj_expled.pps()) pprint(obj_expled, big_blobs=True, with_decode_path=True) @@ -5652,8 +5674,9 @@ class SeqMixing(object): with self.assertRaises(NotEnoughData): seq.decode(seq.encode()[:-1]) - @given(binary(min_size=2)) - def test_non_tag_mismatch_raised(self, junk): + @given(integers(min_value=3), binary(min_size=2)) + def test_non_tag_mismatch_raised(self, junk_tag_num, junk): + junk = tag_encode(junk_tag_num) + junk try: _, _, len_encoded = tag_strip(memoryview(junk)) len_decode(len_encoded) @@ -5999,42 +6022,44 @@ class TestSet(SeqMixing, CommonMixin, TestCase): @settings(max_examples=LONG_TEST_MAX_EXAMPLES) @given(data_strategy()) def test_sorted(self, d): - tags = [ - tag_encode(tag) for tag in - d.draw(sets(integers(min_value=1), min_size=1, max_size=10)) - ] + class DummySeq(Sequence): + schema = (("null", Null()),) + + tag_nums = d.draw(sets(integers(min_value=1), min_size=1, max_size=50)) + _, _, dummy_seq_tag_num = tag_decode(DummySeq.tag_default) + assume(any(i > dummy_seq_tag_num for i in tag_nums)) + tag_nums -= set([dummy_seq_tag_num]) + _schema = [(str(i), OctetString(impl=tag_encode(i))) for i in tag_nums] + _schema.append(("seq", DummySeq())) class Seq(Set): - schema = [(str(i), OctetString(impl=t)) for i, t in enumerate(tags)] + schema = d.draw(permutations(_schema)) seq = Seq() - for name, _ in Seq.schema: - seq[name] = OctetString(b"") + for name, _ in _schema: + if name != "seq": + seq[name] = OctetString(name.encode("ascii")) + seq["seq"] = DummySeq((("null", Null()),)) + seq_encoded = seq.encode() seq_decoded, _ = seq.decode(seq_encoded) + seq_encoded_expected = [] + for tag_num in sorted(tag_nums | set([dummy_seq_tag_num])): + if tag_num == dummy_seq_tag_num: + seq_encoded_expected.append(seq["seq"].encode()) + else: + seq_encoded_expected.append(seq[str(tag_num)].encode()) self.assertSequenceEqual( seq_encoded[seq_decoded.tlen + seq_decoded.llen:], - b"".join(sorted([seq[name].encode() for name, _ in Seq.schema])), + b"".join(seq_encoded_expected), ) - @settings(max_examples=LONG_TEST_MAX_EXAMPLES) - @given(data_strategy()) - def test_unsorted(self, d): - tags = [ - tag_encode(tag) for tag in - d.draw(sets(integers(min_value=1), min_size=2, max_size=5)) - ] - tags = d.draw(permutations(tags)) - assume(tags != sorted(tags)) - encoded = b"".join(OctetString(t, impl=t).encode() for t in tags) + encoded = b"".join(seq[str(i)].encode() for i in tag_nums) + encoded += seq["seq"].encode() seq_encoded = b"".join(( Set.tag_default, len_encode(len(encoded)), encoded, )) - - class Seq(Set): - schema = [(str(i), OctetString(impl=t)) for i, t in enumerate(tags)] - seq = Seq() with assertRaisesRegex(self, DecodeError, "unordered SET"): seq.decode(seq_encoded) for ctx in ({"bered": True}, {"allow_unordered_set": True}): @@ -6044,10 +6069,6 @@ class TestSet(SeqMixing, CommonMixin, TestCase): seq_decoded = copy(seq_decoded) self.assertTrue(seq_decoded.ber_encoded) self.assertTrue(seq_decoded.bered) - self.assertSequenceEqual( - [bytes(seq_decoded[str(i)]) for i, t in enumerate(tags)], - tags, - ) def test_same_value_twice(self): class Seq(Set):