]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - tests/test_pyderasn.py
Valid DER SET ordering
[pyderasn.git] / tests / test_pyderasn.py
index 3161102997daf57449b26c20194b1081613f8b65..309bc2290df05440bccc2b63d65a53547100e852 100644 (file)
@@ -345,6 +345,9 @@ class CommonMixin(object):
         obj = Inherited()
         self.assertSequenceEqual(obj.impl, impl_tag)
         self.assertFalse(obj.expled)
+        if obj.ready:
+            tag_class, _, tag_num = tag_decode(impl_tag)
+            self.assertEqual(obj.tag_order, (tag_class, tag_num))
 
     @given(binary(min_size=1))
     def test_expl_inherited(self, expl_tag):
@@ -353,6 +356,9 @@ class CommonMixin(object):
         obj = Inherited()
         self.assertSequenceEqual(obj.expl, expl_tag)
         self.assertTrue(obj.expled)
+        if obj.ready:
+            tag_class, _, tag_num = tag_decode(expl_tag)
+            self.assertEqual(obj.tag_order, (tag_class, tag_num))
 
     def assert_copied_basic_fields(self, obj, obj_copied):
         self.assertEqual(obj, obj_copied)
@@ -363,6 +369,8 @@ class CommonMixin(object):
         self.assertEqual(obj.offset, obj_copied.offset)
         self.assertEqual(obj.llen, obj_copied.llen)
         self.assertEqual(obj.vlen, obj_copied.vlen)
+        if obj.ready:
+            self.assertEqual(obj.tag_order, obj_copied.tag_order)
 
 
 @composite
@@ -4667,9 +4675,16 @@ class TestUTCTime(TimeMixin, CommonMixin, TestCase):
             )
 
 
+@composite
+def tlv_value_strategy(draw):
+    tag_num = draw(integers(min_value=1))
+    data = draw(binary())
+    return b"".join((tag_encode(tag_num), len_encode(len(data)), data))
+
+
 @composite
 def any_values_strategy(draw, do_expl=False):
-    value = draw(one_of(none(), binary()))
+    value = draw(one_of(none(), tlv_value_strategy()))
     expl = None
     if do_expl:
         expl = draw(one_of(none(), integers(min_value=1).map(tag_encode)))
@@ -4699,7 +4714,7 @@ class TestAny(CommonMixin, TestCase):
         obj = Any(optional=optional)
         self.assertEqual(obj.optional, optional)
 
-    @given(binary())
+    @given(tlv_value_strategy())
     def test_ready(self, value):
         obj = Any()
         self.assertFalse(obj.ready)
@@ -4733,7 +4748,7 @@ class TestAny(CommonMixin, TestCase):
             pprint(obj, big_blobs=True, with_decode_path=True)
             self.assertSequenceEqual(obj.encode(), integer_encoded)
 
-    @given(binary(min_size=1), binary(min_size=1))
+    @given(tlv_value_strategy(), tlv_value_strategy())
     def test_comparison(self, value1, value2):
         for klass in (Any, AnyInherited):
             obj1 = klass(value1)
@@ -4797,7 +4812,7 @@ class TestAny(CommonMixin, TestCase):
             obj.decode(obj.encode()[:-1])
 
     @given(
-        binary(),
+        tlv_value_strategy(),
         integers(min_value=1).map(tag_ctxc),
     )
     def test_stripped_expl(self, value, tag_expl):
@@ -4853,9 +4868,13 @@ class TestAny(CommonMixin, TestCase):
             list(obj.pps())
             pprint(obj, big_blobs=True, with_decode_path=True)
             self.assertFalse(obj.expled)
+            tag_class, _, tag_num = tag_decode(tag_strip(value)[0])
+            self.assertEqual(obj.tag_order, (tag_class, tag_num))
             obj_encoded = obj.encode()
             obj_expled = obj(value, expl=tag_expl)
             self.assertTrue(obj_expled.expled)
+            tag_class, _, tag_num = tag_decode(tag_expl)
+            self.assertEqual(obj_expled.tag_order, (tag_class, tag_num))
             repr(obj_expled)
             list(obj_expled.pps())
             pprint(obj_expled, big_blobs=True, with_decode_path=True)
@@ -5219,9 +5238,12 @@ class TestChoice(CommonMixin, TestCase):
         list(obj.pps())
         pprint(obj, big_blobs=True, with_decode_path=True)
         self.assertFalse(obj.expled)
+        self.assertEqual(obj.tag_order, obj.value.tag_order)
         obj_encoded = obj.encode()
         obj_expled = obj(value, expl=tag_expl)
         self.assertTrue(obj_expled.expled)
+        tag_class, _, tag_num = tag_decode(tag_expl)
+        self.assertEqual(obj_expled.tag_order, (tag_class, tag_num))
         repr(obj_expled)
         list(obj_expled.pps())
         pprint(obj_expled, big_blobs=True, with_decode_path=True)
@@ -5652,8 +5674,9 @@ class SeqMixing(object):
         with self.assertRaises(NotEnoughData):
             seq.decode(seq.encode()[:-1])
 
-    @given(binary(min_size=2))
-    def test_non_tag_mismatch_raised(self, junk):
+    @given(integers(min_value=3), binary(min_size=2))
+    def test_non_tag_mismatch_raised(self, junk_tag_num, junk):
+        junk = tag_encode(junk_tag_num) + junk
         try:
             _, _, len_encoded = tag_strip(memoryview(junk))
             len_decode(len_encoded)
@@ -5999,42 +6022,44 @@ class TestSet(SeqMixing, CommonMixin, TestCase):
     @settings(max_examples=LONG_TEST_MAX_EXAMPLES)
     @given(data_strategy())
     def test_sorted(self, d):
-        tags = [
-            tag_encode(tag) for tag in
-            d.draw(sets(integers(min_value=1), min_size=1, max_size=10))
-        ]
+        class DummySeq(Sequence):
+            schema = (("null", Null()),)
+
+        tag_nums = d.draw(sets(integers(min_value=1), min_size=1, max_size=50))
+        _, _, dummy_seq_tag_num = tag_decode(DummySeq.tag_default)
+        assume(any(i > dummy_seq_tag_num for i in tag_nums))
+        tag_nums -= set([dummy_seq_tag_num])
+        _schema = [(str(i), OctetString(impl=tag_encode(i))) for i in tag_nums]
+        _schema.append(("seq", DummySeq()))
 
         class Seq(Set):
-            schema = [(str(i), OctetString(impl=t)) for i, t in enumerate(tags)]
+            schema = d.draw(permutations(_schema))
         seq = Seq()
-        for name, _ in Seq.schema:
-            seq[name] = OctetString(b"")
+        for name, _ in _schema:
+            if name != "seq":
+                seq[name] = OctetString(name.encode("ascii"))
+        seq["seq"] = DummySeq((("null", Null()),))
+
         seq_encoded = seq.encode()
         seq_decoded, _ = seq.decode(seq_encoded)
+        seq_encoded_expected = []
+        for tag_num in sorted(tag_nums | set([dummy_seq_tag_num])):
+            if tag_num == dummy_seq_tag_num:
+                seq_encoded_expected.append(seq["seq"].encode())
+            else:
+                seq_encoded_expected.append(seq[str(tag_num)].encode())
         self.assertSequenceEqual(
             seq_encoded[seq_decoded.tlen + seq_decoded.llen:],
-            b"".join(sorted([seq[name].encode() for name, _ in Seq.schema])),
+            b"".join(seq_encoded_expected),
         )
 
-    @settings(max_examples=LONG_TEST_MAX_EXAMPLES)
-    @given(data_strategy())
-    def test_unsorted(self, d):
-        tags = [
-            tag_encode(tag) for tag in
-            d.draw(sets(integers(min_value=1), min_size=2, max_size=5))
-        ]
-        tags = d.draw(permutations(tags))
-        assume(tags != sorted(tags))
-        encoded = b"".join(OctetString(t, impl=t).encode() for t in tags)
+        encoded = b"".join(seq[str(i)].encode() for i in tag_nums)
+        encoded += seq["seq"].encode()
         seq_encoded = b"".join((
             Set.tag_default,
             len_encode(len(encoded)),
             encoded,
         ))
-
-        class Seq(Set):
-            schema = [(str(i), OctetString(impl=t)) for i, t in enumerate(tags)]
-        seq = Seq()
         with assertRaisesRegex(self, DecodeError, "unordered SET"):
             seq.decode(seq_encoded)
         for ctx in ({"bered": True}, {"allow_unordered_set": True}):
@@ -6044,10 +6069,6 @@ class TestSet(SeqMixing, CommonMixin, TestCase):
             seq_decoded = copy(seq_decoded)
             self.assertTrue(seq_decoded.ber_encoded)
             self.assertTrue(seq_decoded.bered)
-            self.assertSequenceEqual(
-                [bytes(seq_decoded[str(i)]) for i, t in enumerate(tags)],
-                tags,
-            )
 
     def test_same_value_twice(self):
         class Seq(Set):