]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - tests/test_pyderasn.py
DEFINED BY support
[pyderasn.git] / tests / test_pyderasn.py
index fdc03b58369cb8bcd7607f5a9a5c4d7b56661ea8..23127261feb9175309eea0f8b6e41356f0e7cf56 100644 (file)
@@ -56,6 +56,7 @@ from pyderasn import BMPString
 from pyderasn import Boolean
 from pyderasn import BoundsError
 from pyderasn import Choice
+from pyderasn import decode_path_defby
 from pyderasn import DecodeError
 from pyderasn import Enumerated
 from pyderasn import GeneralizedTime
@@ -85,6 +86,7 @@ from pyderasn import SequenceOf
 from pyderasn import Set
 from pyderasn import SetOf
 from pyderasn import tag_ctxc
+from pyderasn import tag_ctxp
 from pyderasn import tag_decode
 from pyderasn import tag_encode
 from pyderasn import tag_strip
@@ -1953,12 +1955,12 @@ class TestObjectIdentifier(CommonMixin, TestCase):
                 _decoded_initial,
             ) = d.draw(oid_values_strategy())
             obj_initial = klass(
-                value_initial,
-                impl_initial,
-                expl_initial,
-                default_initial,
-                optional_initial or False,
-                _decoded_initial,
+                value=value_initial,
+                impl=impl_initial,
+                expl=expl_initial,
+                default=default_initial,
+                optional=optional_initial or False,
+                _decoded=_decoded_initial,
             )
             (
                 value,
@@ -1968,7 +1970,13 @@ class TestObjectIdentifier(CommonMixin, TestCase):
                 optional,
                 _decoded,
             ) = d.draw(oid_values_strategy(do_expl=impl_initial is None))
-            obj = obj_initial(value, impl, expl, default, optional)
+            obj = obj_initial(
+                value=value,
+                impl=impl,
+                expl=expl,
+                default=default,
+                optional=optional,
+            )
             if obj.ready:
                 value_expected = default if value is None else value
                 value_expected = (
@@ -1992,7 +2000,22 @@ class TestObjectIdentifier(CommonMixin, TestCase):
     @given(oid_values_strategy())
     def test_copy(self, values):
         for klass in (ObjectIdentifier, ObjectIdentifierInherited):
-            obj = klass(*values)
+            (
+                value,
+                impl,
+                expl,
+                default,
+                optional,
+                _decoded,
+            ) = values
+            obj = klass(
+                value=value,
+                impl=impl,
+                expl=expl,
+                default=default,
+                optional=optional,
+                _decoded=_decoded,
+            )
             obj_copied = obj.copy()
             self.assert_copied_basic_fields(obj, obj_copied)
             self.assertEqual(obj._value, obj_copied._value)
@@ -4861,3 +4884,163 @@ class TestAutoAddSlots(TestCase):
         with self.assertRaises(AttributeError):
             inher = Inher()
             inher.unexistent = "whatever"
+
+
+class TestOIDDefines(TestCase):
+    @given(data_strategy())
+    def runTest(self, d):
+        value_names = list(d.draw(sets(text_letters(), min_size=1, max_size=10)))
+        value_name_chosen = d.draw(sampled_from(value_names))
+        oids = [
+            ObjectIdentifier(oid)
+            for oid in d.draw(sets(oid_strategy(), min_size=2, max_size=10))
+        ]
+        oid_chosen = d.draw(sampled_from(oids))
+        values = d.draw(lists(
+            integers(),
+            min_size=len(value_names),
+            max_size=len(value_names),
+        ))
+        _schema = [
+            ("type", ObjectIdentifier(defines=(value_name_chosen, {
+                oid: Integer() for oid in oids[:-1]
+            }))),
+        ]
+        for i, value_name in enumerate(value_names):
+            _schema.append((value_name, Any(expl=tag_ctxp(i))))
+
+        class Seq(Sequence):
+            schema = _schema
+        seq = Seq()
+        for value_name, value in zip(value_names, values):
+            seq[value_name] = Any(Integer(value).encode())
+        seq["type"] = oid_chosen
+        seq, _ = Seq().decode(seq.encode())
+        for value_name in value_names:
+            if value_name == value_name_chosen:
+                continue
+            self.assertIsNone(seq[value_name].defined)
+        if value_name_chosen in oids[:-1]:
+            self.assertIsNotNone(seq[value_name_chosen].defined)
+            self.assertEqual(seq[value_name_chosen].defined[0], oid_chosen)
+            self.assertIsInstance(seq[value_name_chosen].defined[1], Integer)
+
+
+class TestDefinesByPath(TestCase):
+    def runTest(self):
+        class Seq(Sequence):
+            schema = (
+                ("type", ObjectIdentifier()),
+                ("value", OctetString(expl=tag_ctxc(123))),
+            )
+
+        class SeqInner(Sequence):
+            schema = (
+                ("typeInner", ObjectIdentifier()),
+                ("valueInner", Any()),
+            )
+
+        class PairValue(SetOf):
+            schema = Any()
+
+        class Pair(Sequence):
+            schema = (
+                ("type", ObjectIdentifier()),
+                ("value", PairValue()),
+            )
+
+        class Pairs(SequenceOf):
+            schema = Pair()
+
+        (
+            type_integered,
+            type_sequenced,
+            type_innered,
+            type_octet_stringed,
+        ) = [
+            ObjectIdentifier(oid)
+            for oid in sets(oid_strategy(), min_size=4, max_size=4).example()
+        ]
+        seq_integered = Seq()
+        seq_integered["type"] = type_integered
+        seq_integered["value"] = OctetString(Integer(123).encode())
+        seq_integered_raw = seq_integered.encode()
+
+        pairs = Pairs()
+        pairs_input = (
+            (type_octet_stringed, OctetString(b"whatever")),
+            (type_integered, Integer(123)),
+            (type_octet_stringed, OctetString(b"whenever")),
+            (type_integered, Integer(234)),
+        )
+        for t, v in pairs_input:
+            pair = Pair()
+            pair["type"] = t
+            pair["value"] = PairValue((Any(v),))
+            pairs.append(pair)
+        seq_inner = SeqInner()
+        seq_inner["typeInner"] = type_innered
+        seq_inner["valueInner"] = Any(pairs)
+        seq_sequenced = Seq()
+        seq_sequenced["type"] = type_sequenced
+        seq_sequenced["value"] = OctetString(seq_inner.encode())
+        seq_sequenced_raw = seq_sequenced.encode()
+
+        defines_by_path = []
+        seq_integered, _ = Seq().decode(seq_integered_raw)
+        self.assertIsNone(seq_integered["value"].defined)
+        defines_by_path.append(
+            (("type",), ("value", {
+                type_integered: Integer(),
+                type_sequenced: SeqInner(),
+            }))
+        )
+        seq_integered, _ = Seq().decode(seq_integered_raw, defines_by_path=defines_by_path)
+        self.assertIsNotNone(seq_integered["value"].defined)
+        self.assertEqual(seq_integered["value"].defined[0], type_integered)
+        self.assertEqual(seq_integered["value"].defined[1], Integer(123))
+
+        seq_sequenced, _ = Seq().decode(seq_sequenced_raw, defines_by_path=defines_by_path)
+        self.assertIsNotNone(seq_sequenced["value"].defined)
+        self.assertEqual(seq_sequenced["value"].defined[0], type_sequenced)
+        seq_inner = seq_sequenced["value"].defined[1]
+        self.assertIsNone(seq_inner["valueInner"].defined)
+
+        defines_by_path.append((
+            ("value", decode_path_defby(type_sequenced), "typeInner"),
+            ("valueInner", {type_innered: Pairs()}),
+        ))
+        seq_sequenced, _ = Seq().decode(seq_sequenced_raw, defines_by_path=defines_by_path)
+        self.assertIsNotNone(seq_sequenced["value"].defined)
+        self.assertEqual(seq_sequenced["value"].defined[0], type_sequenced)
+        seq_inner = seq_sequenced["value"].defined[1]
+        self.assertIsNotNone(seq_inner["valueInner"].defined)
+        self.assertEqual(seq_inner["valueInner"].defined[0], type_innered)
+        pairs = seq_inner["valueInner"].defined[1]
+        for pair in pairs:
+            self.assertIsNone(pair["value"][0].defined)
+
+        defines_by_path.append((
+            (
+                "value",
+                decode_path_defby(type_sequenced),
+                "valueInner",
+                decode_path_defby(type_innered),
+                any,
+                "type",
+            ),
+            ("value", {
+                type_integered: Integer(),
+                type_octet_stringed: OctetString(),
+            }),
+        ))
+        seq_sequenced, _ = Seq().decode(seq_sequenced_raw, defines_by_path=defines_by_path)
+        self.assertIsNotNone(seq_sequenced["value"].defined)
+        self.assertEqual(seq_sequenced["value"].defined[0], type_sequenced)
+        seq_inner = seq_sequenced["value"].defined[1]
+        self.assertIsNotNone(seq_inner["valueInner"].defined)
+        self.assertEqual(seq_inner["valueInner"].defined[0], type_innered)
+        pairs_got = seq_inner["valueInner"].defined[1]
+        for pair_input, pair_got in zip(pairs_input, pairs_got):
+            self.assertEqual(pair_got["value"][0].defined[0], pair_input[0])
+            self.assertEqual(pair_got["value"][0].defined[1], pair_input[1])