]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - tests/test_pyderasn.py
Preserve BER-related attributes during copy()
[pyderasn.git] / tests / test_pyderasn.py
index a0a60cb16adbea1b9cd358f54f307893b80f35e2..13aeff48e1ead5dda07499b252214dc294934527 100644 (file)
@@ -1,6 +1,6 @@
 # coding: utf-8
 # PyDERASN -- Python ASN.1 DER/BER codec with abstract structures
-# Copyright (C) 2017-2018 Sergey Matveev <stargrave@stargrave.org>
+# Copyright (C) 2017-2019 Sergey Matveev <stargrave@stargrave.org>
 #
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU Lesser General Public License as
@@ -622,6 +622,10 @@ class TestBoolean(CommonMixin, TestCase):
         self.assertTrue(obj.ber_encoded)
         self.assertFalse(obj.lenindef)
         self.assertTrue(obj.bered)
+        obj = obj.copy()
+        self.assertTrue(obj.ber_encoded)
+        self.assertFalse(obj.lenindef)
+        self.assertTrue(obj.bered)
 
     @given(
         integers(min_value=1).map(tag_ctxc),
@@ -641,6 +645,11 @@ class TestBoolean(CommonMixin, TestCase):
         self.assertFalse(obj.lenindef)
         self.assertFalse(obj.ber_encoded)
         self.assertTrue(obj.bered)
+        obj = obj.copy()
+        self.assertTrue(obj.expl_lenindef)
+        self.assertFalse(obj.lenindef)
+        self.assertFalse(obj.ber_encoded)
+        self.assertTrue(obj.bered)
         self.assertSequenceEqual(tail, junk)
         repr(obj)
         list(obj.pps())
@@ -1584,6 +1593,10 @@ class TestBitString(CommonMixin, TestCase):
             self.assertTrue(obj.ber_encoded)
             self.assertEqual(obj.lenindef, lenindef_expected)
             self.assertTrue(obj.bered)
+            obj = obj.copy()
+            self.assertTrue(obj.ber_encoded)
+            self.assertEqual(obj.lenindef, lenindef_expected)
+            self.assertTrue(obj.bered)
             self.assertEqual(len(encoded), obj.tlvlen)
 
     @given(
@@ -1716,6 +1729,10 @@ class TestBitString(CommonMixin, TestCase):
         self.assertTrue(obj.ber_encoded)
         self.assertTrue(obj.lenindef)
         self.assertTrue(obj.bered)
+        obj = obj.copy()
+        self.assertTrue(obj.ber_encoded)
+        self.assertTrue(obj.lenindef)
+        self.assertTrue(obj.bered)
 
 
 @composite
@@ -2103,6 +2120,10 @@ class TestOctetString(CommonMixin, TestCase):
             self.assertTrue(obj.ber_encoded)
             self.assertEqual(obj.lenindef, lenindef_expected)
             self.assertTrue(obj.bered)
+            obj = obj.copy()
+            self.assertTrue(obj.ber_encoded)
+            self.assertEqual(obj.lenindef, lenindef_expected)
+            self.assertTrue(obj.bered)
             self.assertEqual(len(encoded), obj.tlvlen)
 
     @given(
@@ -2410,6 +2431,7 @@ class TestObjectIdentifier(CommonMixin, TestCase):
         repr(err.exception)
         obj = ObjectIdentifier(value)
         self.assertTrue(obj.ready)
+        self.assertFalse(obj.ber_encoded)
         repr(obj)
         list(obj.pps())
         pprint(obj, big_blobs=True, with_decode_path=True)
@@ -2720,6 +2742,56 @@ class TestObjectIdentifier(CommonMixin, TestCase):
             ObjectIdentifier((2, 999, 3)),
         )
 
+    @given(data_strategy())
+    def test_nonnormalized_first_arc(self, d):
+        tampered = (
+            ObjectIdentifier.tag_default +
+            len_encode(2) +
+            b'\x80' +
+            ObjectIdentifier((1, 0)).encode()[-1:]
+        )
+        obj, _ = ObjectIdentifier().decode(tampered, ctx={"bered": True})
+        self.assertTrue(obj.ber_encoded)
+        self.assertTrue(obj.bered)
+        obj = obj.copy()
+        self.assertTrue(obj.ber_encoded)
+        self.assertTrue(obj.bered)
+        with assertRaisesRegex(self, DecodeError, "non normalized arc encoding"):
+            ObjectIdentifier().decode(tampered)
+
+    @given(data_strategy())
+    def test_nonnormalized_arcs(self, d):
+        arcs = d.draw(lists(
+            integers(min_value=0, max_value=100),
+            min_size=1,
+            max_size=5,
+        ))
+        dered = ObjectIdentifier((1, 0) + tuple(arcs)).encode()
+        _, tlen, lv = tag_strip(dered)
+        _, llen, v = len_decode(lv)
+        v_no_first_arc = v[1:]
+        idx_for_tamper = d.draw(integers(
+            min_value=0,
+            max_value=len(v_no_first_arc) - 1,
+        ))
+        tampered = list(bytearray(v_no_first_arc))
+        for _ in range(d.draw(integers(min_value=1, max_value=3))):
+            tampered.insert(idx_for_tamper, 0x80)
+        tampered = bytes(bytearray(tampered))
+        tampered = (
+            ObjectIdentifier.tag_default +
+            len_encode(len(tampered)) +
+            tampered
+        )
+        obj, _ = ObjectIdentifier().decode(tampered, ctx={"bered": True})
+        self.assertTrue(obj.ber_encoded)
+        self.assertTrue(obj.bered)
+        obj = obj.copy()
+        self.assertTrue(obj.ber_encoded)
+        self.assertTrue(obj.bered)
+        with assertRaisesRegex(self, DecodeError, "non normalized arc encoding"):
+            ObjectIdentifier().decode(tampered)
+
 
 @composite
 def enumerated_values_strategy(draw, schema=None, do_expl=False):
@@ -3451,6 +3523,10 @@ class TestVisibleString(
         self.assertTrue(obj.ber_encoded)
         self.assertFalse(obj.lenindef)
         self.assertTrue(obj.bered)
+        obj = obj.copy()
+        self.assertTrue(obj.ber_encoded)
+        self.assertFalse(obj.lenindef)
+        self.assertTrue(obj.bered)
 
         obj, tail = VisibleString().decode(
             hexdec("3A8004034A6F6E040265730000"),
@@ -3461,6 +3537,10 @@ class TestVisibleString(
         self.assertTrue(obj.ber_encoded)
         self.assertTrue(obj.lenindef)
         self.assertTrue(obj.bered)
+        obj = obj.copy()
+        self.assertTrue(obj.ber_encoded)
+        self.assertTrue(obj.lenindef)
+        self.assertTrue(obj.bered)
 
 
 class TestGeneralString(
@@ -4154,6 +4234,10 @@ class TestAny(CommonMixin, TestCase):
         self.assertTrue(obj.lenindef)
         self.assertFalse(obj.ber_encoded)
         self.assertTrue(obj.bered)
+        obj = obj.copy()
+        self.assertTrue(obj.lenindef)
+        self.assertFalse(obj.ber_encoded)
+        self.assertTrue(obj.bered)
         repr(obj)
         list(obj.pps())
         pprint(obj, big_blobs=True, with_decode_path=True)
@@ -4967,6 +5051,9 @@ class SeqMixing(object):
         self.assertDictEqual(ctx_copied, ctx_dummy)
         self.assertTrue(seq_decoded_lenindef.lenindef)
         self.assertTrue(seq_decoded_lenindef.bered)
+        seq_decoded_lenindef = seq_decoded_lenindef.copy()
+        self.assertTrue(seq_decoded_lenindef.lenindef)
+        self.assertTrue(seq_decoded_lenindef.bered)
         with self.assertRaises(DecodeError):
             seq.decode(seq_encoded_lenindef[:-1], ctx={"bered": True})
         with self.assertRaises(DecodeError):
@@ -5084,6 +5171,9 @@ class SeqMixing(object):
             seq_decoded, _ = seq_with_default.decode(seq_encoded, ctx=ctx)
             self.assertTrue(seq_decoded.ber_encoded)
             self.assertTrue(seq_decoded.bered)
+            seq_decoded = seq_decoded.copy()
+            self.assertTrue(seq_decoded.ber_encoded)
+            self.assertTrue(seq_decoded.bered)
             for name, value in _schema:
                 self.assertEqual(seq_decoded[name], seq_with_default[name])
                 self.assertEqual(seq_decoded[name], value)
@@ -5122,6 +5212,10 @@ class SeqMixing(object):
         self.assertFalse(decoded.ber_encoded)
         self.assertFalse(decoded.lenindef)
         self.assertTrue(decoded.bered)
+        decoded = decoded.copy()
+        self.assertFalse(decoded.ber_encoded)
+        self.assertFalse(decoded.lenindef)
+        self.assertTrue(decoded.bered)
 
         class Seq(self.base_klass):
             schema = (("underlying", OctetString()),)
@@ -5138,6 +5232,10 @@ class SeqMixing(object):
         self.assertFalse(decoded.ber_encoded)
         self.assertFalse(decoded.lenindef)
         self.assertTrue(decoded.bered)
+        decoded = decoded.copy()
+        self.assertFalse(decoded.ber_encoded)
+        self.assertFalse(decoded.lenindef)
+        self.assertTrue(decoded.bered)
 
 
 class TestSequence(SeqMixing, CommonMixin, TestCase):
@@ -5234,6 +5332,9 @@ class TestSet(SeqMixing, CommonMixin, TestCase):
             seq_decoded, _ = Seq().decode(seq_encoded, ctx=ctx)
             self.assertTrue(seq_decoded.ber_encoded)
             self.assertTrue(seq_decoded.bered)
+            seq_decoded = seq_decoded.copy()
+            self.assertTrue(seq_decoded.ber_encoded)
+            self.assertTrue(seq_decoded.bered)
             self.assertSequenceEqual(
                 [bytes(seq_decoded[str(i)]) for i, t in enumerate(tags)],
                 [t for t in tags],
@@ -5659,6 +5760,9 @@ class SeqOfMixing(object):
         )
         self.assertTrue(obj_decoded_lenindef.lenindef)
         self.assertTrue(obj_decoded_lenindef.bered)
+        obj_decoded_lenindef = obj_decoded_lenindef.copy()
+        self.assertTrue(obj_decoded_lenindef.lenindef)
+        self.assertTrue(obj_decoded_lenindef.bered)
         repr(obj_decoded_lenindef)
         list(obj_decoded_lenindef.pps())
         pprint(obj_decoded_lenindef, big_blobs=True, with_decode_path=True)
@@ -5681,6 +5785,10 @@ class SeqOfMixing(object):
         self.assertFalse(decoded.ber_encoded)
         self.assertFalse(decoded.lenindef)
         self.assertTrue(decoded.bered)
+        decoded = decoded.copy()
+        self.assertFalse(decoded.ber_encoded)
+        self.assertFalse(decoded.lenindef)
+        self.assertTrue(decoded.bered)
 
         class SeqOf(self.base_klass):
             schema = OctetString()
@@ -5698,6 +5806,10 @@ class SeqOfMixing(object):
         self.assertFalse(decoded.ber_encoded)
         self.assertFalse(decoded.lenindef)
         self.assertTrue(decoded.bered)
+        decoded = decoded.copy()
+        self.assertFalse(decoded.ber_encoded)
+        self.assertFalse(decoded.lenindef)
+        self.assertTrue(decoded.bered)
 
 
 class TestSequenceOf(SeqOfMixing, CommonMixin, TestCase):
@@ -5763,6 +5875,9 @@ class TestSetOf(SeqOfMixing, CommonMixin, TestCase):
             seq_decoded, _ = Seq().decode(seq_encoded, ctx=ctx)
             self.assertTrue(seq_decoded.ber_encoded)
             self.assertTrue(seq_decoded.bered)
+            seq_decoded = seq_decoded.copy()
+            self.assertTrue(seq_decoded.ber_encoded)
+            self.assertTrue(seq_decoded.bered)
             self.assertSequenceEqual(
                 [obj.encode() for obj in seq_decoded],
                 values,
@@ -6268,9 +6383,15 @@ class TestStrictDefaultExistence(TestCase):
             decoded, _ = seq.decode(raw, ctx={"allow_default_values": True})
             self.assertTrue(decoded.ber_encoded)
             self.assertTrue(decoded.bered)
+            decoded = decoded.copy()
+            self.assertTrue(decoded.ber_encoded)
+            self.assertTrue(decoded.bered)
             decoded, _ = seq.decode(raw, ctx={"bered": True})
             self.assertTrue(decoded.ber_encoded)
             self.assertTrue(decoded.bered)
+            decoded = decoded.copy()
+            self.assertTrue(decoded.ber_encoded)
+            self.assertTrue(decoded.bered)
 
 
 class TestX690PrefixedType(TestCase):