]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - tests/test_pyderasn.py
Stricter validation of *Time
[pyderasn.git] / tests / test_pyderasn.py
index 1a7ed2c488f713e9af4b79a174816546547fc957..f2e73b4202a0a626f762e71b94b54f94f9548aff 100644 (file)
@@ -155,9 +155,11 @@ def register_class(klass):
 
 
 def assert_exceeding_data(self, call, junk):
-    if len(junk) > 0:
-        with assertRaisesRegex(self, ExceedingData, "%d trailing bytes" % len(junk)):
-            call()
+    if len(junk) <= 0:
+        return
+    with assertRaisesRegex(self, ExceedingData, "%d trailing bytes" % len(junk)) as err:
+        call()
+    repr(err)
 
 
 class TestHex(TestCase):
@@ -1641,6 +1643,9 @@ class TestBitString(CommonMixin, TestCase):
             self.assertEqual(obj.lenindef, lenindef_expected)
             self.assertTrue(obj.bered)
             self.assertEqual(len(encoded), obj.tlvlen)
+            repr(obj)
+            list(obj.pps())
+            pprint(obj, big_blobs=True, with_decode_path=True)
 
     @given(
         integers(min_value=0),
@@ -2174,6 +2179,9 @@ class TestOctetString(CommonMixin, TestCase):
             self.assertEqual(obj.lenindef, lenindef_expected)
             self.assertTrue(obj.bered)
             self.assertEqual(len(encoded), obj.tlvlen)
+            repr(obj)
+            list(obj.pps())
+            pprint(obj, big_blobs=True, with_decode_path=True)
 
     @given(
         integers(min_value=0),
@@ -2819,6 +2827,28 @@ class TestObjectIdentifier(CommonMixin, TestCase):
         with assertRaisesRegex(self, DecodeError, "non normalized arc encoding"):
             ObjectIdentifier().decode(tampered)
 
+    @given(data_strategy())
+    def test_negative_arcs(self, d):
+        oid = list(d.draw(oid_strategy()))
+        if len(oid) == 2:
+            return
+        idx = d.draw(integers(min_value=3, max_value=len(oid)))
+        oid[idx - 1] *= -1
+        if oid[idx - 1] == 0:
+            oid[idx - 1] = -1
+        with self.assertRaises(InvalidOID):
+            ObjectIdentifier(tuple(oid))
+        with self.assertRaises(InvalidOID):
+            ObjectIdentifier(".".join(str(i) for i in oid))
+
+    @given(data_strategy())
+    def test_plused_arcs(self, d):
+        oid = [str(arc) for arc in d.draw(oid_strategy())]
+        idx = d.draw(integers(min_value=0, max_value=len(oid)))
+        oid[idx - 1] = "+" + oid[idx - 1]
+        with self.assertRaises(InvalidOID):
+            ObjectIdentifier(".".join(str(i) for i in oid))
+
     @given(data_strategy())
     def test_nonnormalized_arcs(self, d):
         arcs = d.draw(lists(
@@ -3715,7 +3745,10 @@ class TimeMixin(object):
         with self.assertRaises(ObjNotReady) as err:
             obj.encode()
         repr(err.exception)
-        value = d.draw(datetimes(min_value=self.min_datetime))
+        value = d.draw(datetimes(
+            min_value=self.min_datetime,
+            max_value=self.max_datetime,
+        ))
         obj = self.base_klass(value)
         self.assertTrue(obj.ready)
         repr(obj)
@@ -4021,6 +4054,35 @@ class TestGeneralizedTime(TimeMixin, CommonMixin, TestCase):
         with assertRaisesRegex(self, DecodeError, "only microsecond fractions"):
             GeneralizedTime(b"20010101000000.0000001Z")
 
+    def test_non_pure_integers(self):
+        for data in ((
+                # b"20000102030405Z,
+                b"+2000102030405Z",
+                b"2000+102030405Z",
+                b"200001+2030405Z",
+                b"20000102+30405Z",
+                b"2000010203+405Z",
+                b"200001020304+5Z",
+                b"20000102030405.+6Z",
+                b"20000102030405.-6Z",
+                b" 2000102030405Z",
+                b"2000 102030405Z",
+                b"200001 2030405Z",
+                b"20000102 30405Z",
+                b"2000010203 405Z",
+                b"200001020304 5Z",
+                b"20000102030405. 6Z",
+                b"200 0102030405Z",
+                b"20001 02030405Z",
+                b"2000012 030405Z",
+                b"200001023 0405Z",
+                b"20000102034 05Z",
+                b"2000010203045 Z",
+                b"20000102030405.6 Z",
+        )):
+            with self.assertRaises(DecodeError):
+                GeneralizedTime(data)
+
 
 class TestUTCTime(TimeMixin, CommonMixin, TestCase):
     base_klass = UTCTime
@@ -4091,6 +4153,31 @@ class TestUTCTime(TimeMixin, CommonMixin, TestCase):
             datetime(1991, 5, 6, 23, 45, 40, 0),
         )
 
+    def test_non_pure_integers(self):
+        for data in ((
+                # b"000102030405Z",
+                b"+10102030405Z",
+                b"00+102030405Z",
+                b"0001+2030405Z",
+                b"000102+30405Z",
+                b"00010203+405Z",
+                b"0001020304+5Z",
+                b" 10102030405Z",
+                b"00 102030405Z",
+                b"0001 2030405Z",
+                b"000102 30405Z",
+                b"00010203 405Z",
+                b"0001020304 5Z",
+                b"1 0102030405Z",
+                b"001 02030405Z",
+                b"00012 030405Z",
+                b"0001023 0405Z",
+                b"000102034 05Z",
+                b"00010203045 Z",
+        )):
+            with self.assertRaises(DecodeError):
+                UTCTime(data)
+
     @given(integers(min_value=0, max_value=49))
     def test_pre50(self, year):
         self.assertEqual(
@@ -5203,6 +5290,8 @@ class SeqMixing(object):
         t, _, lv = tag_strip(seq_encoded)
         _, _, v = len_decode(lv)
         seq_encoded_lenindef = t + LENINDEF + v + EOC
+        with self.assertRaises(DecodeError):
+            seq.decode(seq_encoded_lenindef)
         ctx_copied = deepcopy(ctx_dummy)
         ctx_copied["bered"] = True
         seq_decoded_lenindef, tail_lenindef = seq.decode(
@@ -5923,6 +6012,8 @@ class SeqOfMixing(object):
         t, _, lv = tag_strip(obj_encoded)
         _, _, v = len_decode(lv)
         obj_encoded_lenindef = t + LENINDEF + v + EOC
+        with self.assertRaises(DecodeError):
+            obj.decode(obj_encoded_lenindef)
         obj_decoded_lenindef, tail_lenindef = obj.decode(
             obj_encoded_lenindef + tail_junk,
             ctx={"bered": True},
@@ -6287,32 +6378,33 @@ class TestOIDDefines(TestCase):
             min_size=len(value_names),
             max_size=len(value_names),
         ))
-        _schema = [
-            ("type", ObjectIdentifier(defines=(((value_name_chosen,), {
-                oid: Integer() for oid in oids[:-1]
-            }),))),
-        ]
-        for i, value_name in enumerate(value_names):
-            _schema.append((value_name, Any(expl=tag_ctxp(i))))
+        for definable_class in (Any, OctetString, BitString):
+            _schema = [
+                ("type", ObjectIdentifier(defines=(((value_name_chosen,), {
+                    oid: Integer() for oid in oids[:-1]
+                }),))),
+            ]
+            for i, value_name in enumerate(value_names):
+                _schema.append((value_name, definable_class(expl=tag_ctxp(i))))
 
-        class Seq(Sequence):
-            schema = _schema
-        seq = Seq()
-        for value_name, value in zip(value_names, values):
-            seq[value_name] = Any(Integer(value).encode())
-        seq["type"] = oid_chosen
-        seq, _ = Seq().decode(seq.encode())
-        for value_name in value_names:
-            if value_name == value_name_chosen:
-                continue
-            self.assertIsNone(seq[value_name].defined)
-        if value_name_chosen in oids[:-1]:
-            self.assertIsNotNone(seq[value_name_chosen].defined)
-            self.assertEqual(seq[value_name_chosen].defined[0], oid_chosen)
-            self.assertIsInstance(seq[value_name_chosen].defined[1], Integer)
-        repr(seq)
-        list(seq.pps())
-        pprint(seq, big_blobs=True, with_decode_path=True)
+            class Seq(Sequence):
+                schema = _schema
+            seq = Seq()
+            for value_name, value in zip(value_names, values):
+                seq[value_name] = definable_class(Integer(value).encode())
+            seq["type"] = oid_chosen
+            seq, _ = Seq().decode(seq.encode())
+            for value_name in value_names:
+                if value_name == value_name_chosen:
+                    continue
+                self.assertIsNone(seq[value_name].defined)
+            if value_name_chosen in oids[:-1]:
+                self.assertIsNotNone(seq[value_name_chosen].defined)
+                self.assertEqual(seq[value_name_chosen].defined[0], oid_chosen)
+                self.assertIsInstance(seq[value_name_chosen].defined[1], Integer)
+            repr(seq)
+            list(seq.pps())
+            pprint(seq, big_blobs=True, with_decode_path=True)
 
 
 class TestDefinesByPath(TestCase):
@@ -6363,10 +6455,10 @@ class TestDefinesByPath(TestCase):
             (type_integered, Integer(234)),
         )
         for t, v in pairs_input:
-            pair = Pair()
-            pair["type"] = t
-            pair["value"] = PairValue((Any(v),))
-            pairs.append(pair)
+            pairs.append(Pair((
+                ("type", t),
+                ("value", PairValue((Any(v),))),
+            )))
         seq_inner = SeqInner()
         seq_inner["typeInner"] = type_innered
         seq_inner["valueInner"] = Any(pairs)
@@ -6504,6 +6596,43 @@ class TestDefinesByPath(TestCase):
         decoded, _ = Outer().decode(outer.encode())
         self.assertEqual(decoded["tgt"].defined[1], Integer(tgt))
 
+    def test_remaining_data(self):
+        oid = ObjectIdentifier("1.2.3")
+        class Seq(Sequence):
+            schema = (
+                ("oid", ObjectIdentifier(defines=((("tgt",), {
+                    oid: Integer(),
+                }),))),
+                ("tgt", OctetString()),
+            )
+
+        seq = Seq((
+            ("oid", oid),
+            ("tgt", OctetString(Integer(123).encode() + b"junk")),
+        ))
+        with assertRaisesRegex(self, DecodeError, "remaining data"):
+            Seq().decode(seq.encode())
+
+    def test_remaining_data_seqof(self):
+        oid = ObjectIdentifier("1.2.3")
+        class SeqOf(SetOf):
+            schema = OctetString()
+
+        class Seq(Sequence):
+            schema = (
+                ("oid", ObjectIdentifier(defines=((("tgt",), {
+                    oid: Integer(),
+                }),))),
+                ("tgt", SeqOf()),
+            )
+
+        seq = Seq((
+            ("oid", oid),
+            ("tgt", SeqOf([OctetString(Integer(123).encode() + b"junk")])),
+        ))
+        with assertRaisesRegex(self, DecodeError, "remaining data"):
+            Seq().decode(seq.encode())
+
 
 class TestAbsDecodePath(TestCase):
     @given(
@@ -6511,10 +6640,9 @@ class TestAbsDecodePath(TestCase):
         lists(text(alphabet=ascii_letters, min_size=1), min_size=1).map(tuple),
     )
     def test_concat(self, decode_path, rel_path):
-        self.assertSequenceEqual(
-            abs_decode_path(decode_path, rel_path),
-            decode_path + rel_path,
-        )
+        dp = abs_decode_path(decode_path, rel_path)
+        self.assertSequenceEqual(dp, decode_path + rel_path)
+        repr(dp)
 
     @given(
         lists(text(alphabet=ascii_letters, min_size=1)).map(tuple),