]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - tests/test_pyderasn.py
Convenient decod() helper method
[pyderasn.git] / tests / test_pyderasn.py
index edf6a7b440f0a6989836d01c0e2a2e2f1344e01d..9fbb33eb39f2d75231721cee73de95a32d66d717 100644 (file)
@@ -65,6 +65,7 @@ from pyderasn import DecodePathDefBy
 from pyderasn import Enumerated
 from pyderasn import EOC
 from pyderasn import EOC_LEN
+from pyderasn import ExceedingData
 from pyderasn import GeneralizedTime
 from pyderasn import GeneralString
 from pyderasn import GraphicString
@@ -135,6 +136,12 @@ decode_path_strat = lists(integers(), max_size=3).map(
 ctx_dummy = dictionaries(integers(), integers(), min_size=2, max_size=4).example()
 
 
+def assert_exceeding_data(self, call, junk):
+    if len(junk) > 0:
+        with assertRaisesRegex(self, ExceedingData, "%d trailing bytes" % len(junk)):
+            call()
+
+
 class TestHex(TestCase):
     @given(binary())
     def test_symmetric(self, data):
@@ -591,6 +598,11 @@ class TestBoolean(CommonMixin, TestCase):
                 offset + obj_decoded.expl_tlen + obj_decoded.expl_llen,
             )
             self.assertEqual(obj_decoded.expl_offset, offset)
+            assert_exceeding_data(
+                self,
+                lambda: obj_expled.decod(obj_expled_encoded + tail_junk),
+                tail_junk,
+            )
 
     @given(integers(min_value=2))
     def test_invalid_len(self, l):
@@ -1083,6 +1095,11 @@ class TestInteger(CommonMixin, TestCase):
                 offset + obj_decoded.expl_tlen + obj_decoded.expl_llen,
             )
             self.assertEqual(obj_decoded.expl_offset, offset)
+            assert_exceeding_data(
+                self,
+                lambda: obj_expled.decod(obj_expled_encoded + tail_junk),
+                tail_junk,
+            )
 
     def test_go_vectors_valid(self):
         for data, expect in ((
@@ -1473,6 +1490,11 @@ class TestBitString(CommonMixin, TestCase):
                 self.assertSetEqual(set(value), set(obj_decoded.named))
                 for name in value:
                     obj_decoded[name]
+            assert_exceeding_data(
+                self,
+                lambda: obj_expled.decod(obj_expled_encoded + tail_junk),
+                tail_junk,
+            )
 
     @given(integers(min_value=1, max_value=255))
     def test_bad_zero_value(self, pad_size):
@@ -2058,6 +2080,11 @@ class TestOctetString(CommonMixin, TestCase):
                 offset + obj_decoded.expl_tlen + obj_decoded.expl_llen,
             )
             self.assertEqual(obj_decoded.expl_offset, offset)
+            assert_exceeding_data(
+                self,
+                lambda: obj_expled.decod(obj_expled_encoded + tail_junk),
+                tail_junk,
+            )
 
     @given(
         integers(min_value=1, max_value=30),
@@ -2360,6 +2387,11 @@ class TestNull(CommonMixin, TestCase):
                 offset + obj_decoded.expl_tlen + obj_decoded.expl_llen,
             )
             self.assertEqual(obj_decoded.expl_offset, offset)
+            assert_exceeding_data(
+                self,
+                lambda: obj_expled.decod(obj_expled_encoded + tail_junk),
+                tail_junk,
+            )
 
     @given(integers(min_value=1))
     def test_invalid_len(self, l):
@@ -2698,6 +2730,11 @@ class TestObjectIdentifier(CommonMixin, TestCase):
                 offset + obj_decoded.expl_tlen + obj_decoded.expl_llen,
             )
             self.assertEqual(obj_decoded.expl_offset, offset)
+            assert_exceeding_data(
+                self,
+                lambda: obj_expled.decod(obj_expled_encoded + tail_junk),
+                tail_junk,
+            )
 
     @given(
         oid_strategy().map(ObjectIdentifier),
@@ -3041,6 +3078,11 @@ class TestEnumerated(CommonMixin, TestCase):
             offset + obj_decoded.expl_tlen + obj_decoded.expl_llen,
         )
         self.assertEqual(obj_decoded.expl_offset, offset)
+        assert_exceeding_data(
+            self,
+            lambda: obj_expled.decod(obj_expled_encoded + tail_junk),
+            tail_junk,
+        )
 
 
 @composite
@@ -3371,6 +3413,11 @@ class StringMixin(object):
             offset + obj_decoded.expl_tlen + obj_decoded.expl_llen,
         )
         self.assertEqual(obj_decoded.expl_offset, offset)
+        assert_exceeding_data(
+            self,
+            lambda: obj_expled.decod(obj_expled_encoded + tail_junk),
+            tail_junk,
+        )
 
 
 class TestUTF8String(StringMixin, CommonMixin, TestCase):
@@ -3806,6 +3853,11 @@ class TimeMixin(object):
             offset + obj_decoded.expl_tlen + obj_decoded.expl_llen,
         )
         self.assertEqual(obj_decoded.expl_offset, offset)
+        assert_exceeding_data(
+            self,
+            lambda: obj_expled.decod(obj_expled_encoded + tail_junk),
+            tail_junk,
+        )
 
 
 class TestGeneralizedTime(TimeMixin, CommonMixin, TestCase):
@@ -4246,6 +4298,11 @@ class TestAny(CommonMixin, TestCase):
             self.assertEqual(obj_decoded.tlen, 0)
             self.assertEqual(obj_decoded.llen, 0)
             self.assertEqual(obj_decoded.vlen, len(value))
+            assert_exceeding_data(
+                self,
+                lambda: obj_expled.decod(obj_expled_encoded + tail_junk),
+                tail_junk,
+            )
 
     @given(
         integers(min_value=1).map(tag_ctxc),
@@ -4610,6 +4667,11 @@ class TestChoice(CommonMixin, TestCase):
             ],
             obj_encoded,
         )
+        assert_exceeding_data(
+            self,
+            lambda: obj_expled.decod(obj_expled_encoded + tail_junk),
+            tail_junk,
+        )
 
     @given(integers())
     def test_set_get(self, value):
@@ -5128,6 +5190,12 @@ class SeqMixing(object):
                     obj.encode(),
                 )
 
+        assert_exceeding_data(
+            self,
+            lambda: seq.decod(seq_encoded_lenindef + tail_junk, ctx={"bered": True}),
+            tail_junk,
+        )
+
     @settings(max_examples=LONG_TEST_MAX_EXAMPLES)
     @given(data_strategy())
     def test_symmetric_with_seq(self, d):
@@ -5814,6 +5882,12 @@ class SeqOfMixing(object):
         with self.assertRaises(DecodeError):
             obj.decode(obj_encoded_lenindef[:-2], ctx={"bered": True})
 
+        assert_exceeding_data(
+            self,
+            lambda: obj_expled.decod(obj_expled_encoded + tail_junk),
+            tail_junk,
+        )
+
     def test_bered(self):
         class SeqOf(self.base_klass):
             schema = Boolean()