]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - tests/test_pyderasn.py
DRY decode_path strategy
[pyderasn.git] / tests / test_pyderasn.py
index faa85699fd2d628058d3c9cf5d2ce83d8fbcc218..ef5f79e4d97d88cb7fbf8c308404eb4a20499be7 100644 (file)
@@ -1,5 +1,5 @@
 # coding: utf-8
-# PyDERASN -- Python ASN.1 DER codec with abstract structures
+# PyDERASN -- Python ASN.1 DER/BER codec with abstract structures
 # Copyright (C) 2017-2018 Sergey Matveev <stargrave@stargrave.org>
 #
 # This program is free software: you can redistribute it and/or modify
@@ -125,6 +125,9 @@ tag_classes = sampled_from((
     TagClassUniversal,
 ))
 tag_forms = sampled_from((TagFormConstructed, TagFormPrimitive))
+decode_path_strat = lists(integers(), max_size=3).map(
+    lambda decode_path: tuple(str(dp) for dp in decode_path)
+)
 
 
 class TestHex(TestCase):
@@ -463,10 +466,9 @@ class TestBoolean(CommonMixin, TestCase):
     @given(
         integers(min_value=31),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_tag(self, tag, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             Boolean().decode(
                 tag_encode(tag)[:-1],
@@ -480,10 +482,9 @@ class TestBoolean(CommonMixin, TestCase):
     @given(
         integers(min_value=31),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_expl_tag(self, tag, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             Boolean(expl=Boolean.tag_default).decode(
                 tag_encode(tag)[:-1],
@@ -497,10 +498,9 @@ class TestBoolean(CommonMixin, TestCase):
     @given(
         integers(min_value=128),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_len(self, l, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             Boolean().decode(
                 Boolean.tag_default + len_encode(l)[:-1],
@@ -514,10 +514,9 @@ class TestBoolean(CommonMixin, TestCase):
     @given(
         integers(min_value=128),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_expl_len(self, l, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             Boolean(expl=Boolean.tag_default).decode(
                 Boolean.tag_default + len_encode(l)[:-1],
@@ -607,6 +606,7 @@ class TestBoolean(CommonMixin, TestCase):
         )
         self.assertTrue(bool(obj))
         self.assertTrue(obj.bered)
+        self.assertFalse(obj.lenindef)
 
     @given(
         integers(min_value=1).map(tag_ctxc),
@@ -633,8 +633,26 @@ class TestBoolean(CommonMixin, TestCase):
         self.assertSequenceEqual(tail, b"")
         self.assertSequenceEqual([bool(v) for v in seqof], values)
         self.assertSetEqual(
-            set((v.tlvlen, v.expl_tlvlen, v.expl_tlen, v.expl_llen) for v in seqof),
-            set(((3 + EOC_LEN, len(expl) + 1 + 3 + EOC_LEN, len(expl), 1),)),
+            set(
+                (
+                    v.tlvlen,
+                    v.expl_tlvlen,
+                    v.expl_tlen,
+                    v.expl_llen,
+                    v.bered,
+                    v.lenindef,
+                    v.expl_lenindef,
+                ) for v in seqof
+            ),
+            set(((
+                3 + EOC_LEN,
+                len(expl) + 1 + 3 + EOC_LEN,
+                len(expl),
+                1,
+                False,
+                False,
+                True,
+            ),)),
         )
 
 
@@ -894,10 +912,9 @@ class TestInteger(CommonMixin, TestCase):
     @given(
         integers(min_value=31),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_tag(self, tag, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             Integer().decode(
                 tag_encode(tag)[:-1],
@@ -911,10 +928,9 @@ class TestInteger(CommonMixin, TestCase):
     @given(
         integers(min_value=128),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_len(self, l, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             Integer().decode(
                 Integer.tag_default + len_encode(l)[:-1],
@@ -928,10 +944,9 @@ class TestInteger(CommonMixin, TestCase):
     @given(
         sets(integers(), min_size=2, max_size=2),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_invalid_bounds_while_decoding(self, ints, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         value, bound_min = list(sorted(ints))
 
         class Int(Integer):
@@ -1288,10 +1303,9 @@ class TestBitString(CommonMixin, TestCase):
     @given(
         integers(min_value=31),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_tag(self, tag, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             BitString().decode(
                 tag_encode(tag)[:-1],
@@ -1305,10 +1319,9 @@ class TestBitString(CommonMixin, TestCase):
     @given(
         integers(min_value=128),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_len(self, l, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             BitString().decode(
                 BitString.tag_default + len_encode(l)[:-1],
@@ -1485,7 +1498,10 @@ class TestBitString(CommonMixin, TestCase):
         )
         with assertRaisesRegex(self, DecodeError, "unallowed BER"):
             BitString(impl=tag_encode(impl)).decode(encoded_indefinite)
-        for encoded in (encoded_indefinite, encoded_definite):
+        for lenindef_expected, encoded in (
+                (True, encoded_indefinite),
+                (False, encoded_definite),
+        ):
             obj, tail = BitString(impl=tag_encode(impl)).decode(
                 encoded, ctx={"bered": True}
             )
@@ -1493,11 +1509,11 @@ class TestBitString(CommonMixin, TestCase):
             self.assertEqual(obj.bit_len, bit_len_expected)
             self.assertSequenceEqual(bytes(obj), payload_expected)
             self.assertTrue(obj.bered)
+            self.assertEqual(obj.lenindef, lenindef_expected)
             self.assertEqual(len(encoded), obj.tlvlen)
 
     def test_x690_vector(self):
-        vector_payload = hexdec("0A3B5F291CD0")
-        vector = BitString((len(vector_payload) * 8 - 4, vector_payload))
+        vector = BitString("'0A3B5F291CD'H")
         obj, tail = BitString().decode(hexdec("0307040A3B5F291CD0"))
         self.assertSequenceEqual(tail, b"")
         self.assertEqual(obj, vector)
@@ -1507,6 +1523,8 @@ class TestBitString(CommonMixin, TestCase):
         )
         self.assertSequenceEqual(tail, b"")
         self.assertEqual(obj, vector)
+        self.assertTrue(obj.bered)
+        self.assertTrue(obj.lenindef)
 
 
 @composite
@@ -1715,10 +1733,9 @@ class TestOctetString(CommonMixin, TestCase):
     @given(
         integers(min_value=31),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_tag(self, tag, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             OctetString().decode(
                 tag_encode(tag)[:-1],
@@ -1732,10 +1749,9 @@ class TestOctetString(CommonMixin, TestCase):
     @given(
         integers(min_value=128),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_len(self, l, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             OctetString().decode(
                 OctetString.tag_default + len_encode(l)[:-1],
@@ -1749,10 +1765,9 @@ class TestOctetString(CommonMixin, TestCase):
     @given(
         sets(integers(min_value=0, max_value=10), min_size=2, max_size=2),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_invalid_bounds_while_decoding(self, ints, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         value, bound_min = list(sorted(ints))
 
         class String(OctetString):
@@ -1865,13 +1880,17 @@ class TestOctetString(CommonMixin, TestCase):
         )
         with assertRaisesRegex(self, DecodeError, "unallowed BER"):
             OctetString(impl=tag_encode(impl)).decode(encoded_indefinite)
-        for encoded in (encoded_indefinite, encoded_definite):
+        for lenindef_expected, encoded in (
+                (True, encoded_indefinite),
+                (False, encoded_definite),
+        ):
             obj, tail = OctetString(impl=tag_encode(impl)).decode(
                 encoded, ctx={"bered": True}
             )
             self.assertSequenceEqual(tail, b"")
             self.assertSequenceEqual(bytes(obj), payload_expected)
             self.assertTrue(obj.bered)
+            self.assertEqual(obj.lenindef, lenindef_expected)
             self.assertEqual(len(encoded), obj.tlvlen)
 
 
@@ -1970,10 +1989,9 @@ class TestNull(CommonMixin, TestCase):
     @given(
         integers(min_value=31),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_tag(self, tag, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             Null().decode(
                 tag_encode(tag)[:-1],
@@ -1987,10 +2005,9 @@ class TestNull(CommonMixin, TestCase):
     @given(
         integers(min_value=128),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_len(self, l, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             Null().decode(
                 Null.tag_default + len_encode(l)[:-1],
@@ -2242,10 +2259,9 @@ class TestObjectIdentifier(CommonMixin, TestCase):
     @given(
         integers(min_value=31),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_tag(self, tag, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             ObjectIdentifier().decode(
                 tag_encode(tag)[:-1],
@@ -2259,10 +2275,9 @@ class TestObjectIdentifier(CommonMixin, TestCase):
     @given(
         integers(min_value=128),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_len(self, l, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             ObjectIdentifier().decode(
                 ObjectIdentifier.tag_default + len_encode(l)[:-1],
@@ -2877,10 +2892,9 @@ class StringMixin(object):
     @given(
         integers(min_value=31),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_tag(self, tag, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             self.base_klass().decode(
                 tag_encode(tag)[:-1],
@@ -2894,10 +2908,9 @@ class StringMixin(object):
     @given(
         integers(min_value=128),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_len(self, l, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             self.base_klass().decode(
                 self.base_klass.tag_default + len_encode(l)[:-1],
@@ -2911,10 +2924,9 @@ class StringMixin(object):
     @given(
         sets(integers(min_value=0, max_value=10), min_size=2, max_size=2),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_invalid_bounds_while_decoding(self, ints, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         value, bound_min = list(sorted(ints))
 
         class String(self.base_klass):
@@ -3012,10 +3024,9 @@ class TestNumericString(StringMixin, CommonMixin, TestCase):
     @given(
         sets(integers(min_value=0, max_value=10), min_size=2, max_size=2),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_invalid_bounds_while_decoding(self, ints, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         value, bound_min = list(sorted(ints))
 
         class String(self.base_klass):
@@ -3085,24 +3096,29 @@ class TestVisibleString(
     base_klass = VisibleString
 
     def test_x690_vector(self):
-        self.assertEqual(
-            str(VisibleString().decode(hexdec("1A054A6F6E6573"))[0]),
-            "Jones",
-        )
-        self.assertEqual(
-            str(VisibleString().decode(
-                hexdec("3A0904034A6F6E04026573"),
-                ctx={"bered": True},
-            )[0]),
-            "Jones",
+        obj, tail = VisibleString().decode(hexdec("1A054A6F6E6573"))
+        self.assertSequenceEqual(tail, b"")
+        self.assertEqual(str(obj), "Jones")
+        self.assertFalse(obj.bered)
+        self.assertFalse(obj.lenindef)
+
+        obj, tail = VisibleString().decode(
+            hexdec("3A0904034A6F6E04026573"),
+            ctx={"bered": True},
         )
-        self.assertEqual(
-            str(VisibleString().decode(
-                hexdec("3A8004034A6F6E040265730000"),
-                ctx={"bered": True},
-            )[0]),
-            "Jones",
+        self.assertSequenceEqual(tail, b"")
+        self.assertEqual(str(obj), "Jones")
+        self.assertTrue(obj.bered)
+        self.assertFalse(obj.lenindef)
+
+        obj, tail = VisibleString().decode(
+            hexdec("3A8004034A6F6E040265730000"),
+            ctx={"bered": True},
         )
+        self.assertSequenceEqual(tail, b"")
+        self.assertEqual(str(obj), "Jones")
+        self.assertTrue(obj.bered)
+        self.assertTrue(obj.lenindef)
 
 
 class TestGeneralString(
@@ -3605,10 +3621,9 @@ class TestAny(CommonMixin, TestCase):
     @given(
         integers(min_value=31),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_tag(self, tag, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             Any().decode(
                 tag_encode(tag)[:-1],
@@ -3622,10 +3637,9 @@ class TestAny(CommonMixin, TestCase):
     @given(
         integers(min_value=128),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_len(self, l, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             Any().decode(
                 Any.tag_default + len_encode(l)[:-1],
@@ -4351,10 +4365,9 @@ class SeqMixing(object):
     @given(
         integers(min_value=31),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_tag(self, tag, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             self.base_klass().decode(
                 tag_encode(tag)[:-1],
@@ -4368,10 +4381,9 @@ class SeqMixing(object):
     @given(
         integers(min_value=128),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_len(self, l, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             self.base_klass().decode(
                 self.base_klass.tag_default + len_encode(l)[:-1],
@@ -4904,10 +4916,9 @@ class SeqOfMixing(object):
     @given(
         integers(min_value=31),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_tag(self, tag, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             self.base_klass().decode(
                 tag_encode(tag)[:-1],
@@ -4921,10 +4932,9 @@ class SeqOfMixing(object):
     @given(
         integers(min_value=128),
         integers(min_value=0),
-        lists(integers()),
+        decode_path_strat,
     )
     def test_bad_len(self, l, offset, decode_path):
-        decode_path = tuple(str(i) for i in decode_path)
         with self.assertRaises(DecodeError) as err:
             self.base_klass().decode(
                 self.base_klass.tag_default + len_encode(l)[:-1],