]> Cypherpunks.ru repositories - pyderasn.git/commitdiff
ctx is safe to use as immutable
authorSergey Matveev <stargrave@stargrave.org>
Sat, 8 Dec 2018 16:05:10 +0000 (19:05 +0300)
committerSergey Matveev <stargrave@stargrave.org>
Sat, 8 Dec 2018 16:24:59 +0000 (19:24 +0300)
VERSION
doc/news.rst
pyderasn.py
tests/test_pyderasn.py

diff --git a/VERSION b/VERSION
index 515be8f918de9d7addeeccda132a1db7b29afc14..4caecc733e6bc437e8afacd9a27cd84edd18a5bc 100644 (file)
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-4.4
+4.5
index 369b8234b1f665239b01aa98ce595863f4c110ab..87863f5f5f5273b23173bd1cae92f75973f78505 100644 (file)
@@ -1,6 +1,13 @@
 News
 ====
 
+.. _release4.5:
+
+4.5
+---
+* ``ctx`` parameter can be safely used in .decode() and won't be muted
+
+
 .. _release4.4:
 
 4.4
index aa91b04052c100a8ba6dfc102cd840dcfce64c4a..4142af690d616b316ac46494dfde8857fcb25a02 100755 (executable)
@@ -539,6 +539,7 @@ from codecs import getdecoder
 from codecs import getencoder
 from collections import namedtuple
 from collections import OrderedDict
+from copy import copy
 from datetime import datetime
 from math import ceil
 from os import environ
@@ -646,6 +647,7 @@ LENINDEF_PP_CHAR = "I" if PY2 else "∞"
 class ASN1Error(ValueError):
     pass
 
+
 class DecodeError(ASN1Error):
     def __init__(self, msg="", klass=None, decode_path=(), offset=0):
         """
@@ -1027,6 +1029,7 @@ class Obj(object):
             decode_path=(),
             ctx=None,
             tag_only=False,
+            _ctx_immutable=True,
     ):
         """Decode the data
 
@@ -1038,10 +1041,13 @@ class Obj(object):
         :param tag_only: decode only the tag, without length and contents
                          (used only in Choice and Set structures, trying to
                          determine if tag satisfies the scheme)
+        :param _ctx_immutable: do we need to copy ``ctx`` before using it
         :returns: (Obj, remaining data)
         """
         if ctx is None:
             ctx = {}
+        elif _ctx_immutable:
+            ctx = copy(ctx)
         tlv = memoryview(data)
         if self._expl is None:
             result = self._decode(
@@ -2387,6 +2393,7 @@ class BitString(Obj):
                         decode_path=sub_decode_path,
                         leavemm=True,
                         ctx=ctx,
+                        _ctx_immutable=False,
                     )
                 except TagMismatch:
                     raise DecodeError(
@@ -2751,6 +2758,7 @@ class OctetString(Obj):
                         decode_path=sub_decode_path,
                         leavemm=True,
                         ctx=ctx,
+                        _ctx_immutable=False,
                     )
                 except TagMismatch:
                     raise DecodeError(
@@ -4007,6 +4015,7 @@ class Choice(Obj):
                     decode_path=sub_decode_path,
                     ctx=ctx,
                     tag_only=True,
+                    _ctx_immutable=False,
                 )
             except TagMismatch:
                 continue
@@ -4025,6 +4034,7 @@ class Choice(Obj):
             leavemm=True,
             decode_path=sub_decode_path,
             ctx=ctx,
+            _ctx_immutable=False,
         )
         obj = self.__class__(
             schema=self.specs,
@@ -4226,6 +4236,7 @@ class Any(Obj):
                     decode_path=decode_path + (str(chunk_i),),
                     leavemm=True,
                     ctx=ctx,
+                    _ctx_immutable=False,
                 )
                 vlen += chunk.tlvlen
                 sub_offset += chunk.tlvlen
@@ -4641,6 +4652,7 @@ class Sequence(Obj):
                     leavemm=True,
                     decode_path=sub_decode_path,
                     ctx=ctx,
+                    _ctx_immutable=False,
                 )
             except TagMismatch:
                 if spec.optional:
@@ -4665,6 +4677,7 @@ class Sequence(Obj):
                             leavemm=True,
                             decode_path=sub_sub_decode_path,
                             ctx=ctx,
+                            _ctx_immutable=False,
                         )
                         if len(defined_tail) > 0:
                             raise DecodeError(
@@ -4684,6 +4697,7 @@ class Sequence(Obj):
                         leavemm=True,
                         decode_path=sub_decode_path + (DecodePathDefBy(defined_by),),
                         ctx=ctx,
+                        _ctx_immutable=False,
                     )
                     if len(defined_tail) > 0:
                         raise DecodeError(
@@ -4885,6 +4899,7 @@ class Set(Sequence):
                         decode_path=sub_decode_path,
                         ctx=ctx,
                         tag_only=True,
+                        _ctx_immutable=False,
                     )
                 except TagMismatch:
                     continue
@@ -4901,6 +4916,7 @@ class Set(Sequence):
                 leavemm=True,
                 decode_path=sub_decode_path,
                 ctx=ctx,
+                _ctx_immutable=False,
             )
             value_len = value.fulllen
             if value_prev.tobytes() > v[:value_len].tobytes():
@@ -5210,6 +5226,7 @@ class SequenceOf(Obj):
                 leavemm=True,
                 decode_path=sub_decode_path,
                 ctx=ctx,
+                _ctx_immutable=False,
             )
             value_len = value.fulllen
             if ordering_check:
index 26addf001e5f164b8187fa19250a96902fa06e58..a0a60cb16adbea1b9cd358f54f307893b80f35e2 100644 (file)
@@ -16,6 +16,7 @@
 # License along with this program.  If not, see
 # <http://www.gnu.org/licenses/>.
 
+from copy import deepcopy
 from datetime import datetime
 from string import ascii_letters
 from string import digits
@@ -132,6 +133,7 @@ tag_forms = sampled_from((TagFormConstructed, TagFormPrimitive))
 decode_path_strat = lists(integers(), max_size=3).map(
     lambda decode_path: tuple(str(dp) for dp in decode_path)
 )
+ctx_dummy = dictionaries(integers(), integers(), min_size=2, max_size=4).example()
 
 
 class TestHex(TestCase):
@@ -561,10 +563,13 @@ class TestBoolean(CommonMixin, TestCase):
             list(obj_expled.pps())
             pprint(obj_expled, big_blobs=True, with_decode_path=True)
             obj_expled_encoded = obj_expled.encode()
+            ctx_copied = deepcopy(ctx_dummy)
             obj_decoded, tail = obj_expled.decode(
                 obj_expled_encoded + tail_junk,
                 offset=offset,
+                ctx=ctx_copied,
             )
+            self.assertDictEqual(ctx_copied, ctx_dummy)
             repr(obj_decoded)
             list(obj_decoded.pps())
             pprint(obj_decoded, big_blobs=True, with_decode_path=True)
@@ -1041,10 +1046,13 @@ class TestInteger(CommonMixin, TestCase):
             list(obj_expled.pps())
             pprint(obj_expled, big_blobs=True, with_decode_path=True)
             obj_expled_encoded = obj_expled.encode()
+            ctx_copied = deepcopy(ctx_dummy)
             obj_decoded, tail = obj_expled.decode(
                 obj_expled_encoded + tail_junk,
                 offset=offset,
+                ctx=ctx_copied,
             )
+            self.assertDictEqual(ctx_copied, ctx_dummy)
             repr(obj_decoded)
             list(obj_decoded.pps())
             pprint(obj_decoded, big_blobs=True, with_decode_path=True)
@@ -1424,10 +1432,13 @@ class TestBitString(CommonMixin, TestCase):
             list(obj_expled.pps())
             pprint(obj_expled, big_blobs=True, with_decode_path=True)
             obj_expled_encoded = obj_expled.encode()
+            ctx_copied = deepcopy(ctx_dummy)
             obj_decoded, tail = obj_expled.decode(
                 obj_expled_encoded + tail_junk,
                 offset=offset,
+                ctx=ctx_copied,
             )
+            self.assertDictEqual(ctx_copied, ctx_dummy)
             repr(obj_decoded)
             list(obj_decoded.pps())
             pprint(obj_decoded, big_blobs=True, with_decode_path=True)
@@ -2002,10 +2013,13 @@ class TestOctetString(CommonMixin, TestCase):
             list(obj_expled.pps())
             pprint(obj_expled, big_blobs=True, with_decode_path=True)
             obj_expled_encoded = obj_expled.encode()
+            ctx_copied = deepcopy(ctx_dummy)
             obj_decoded, tail = obj_expled.decode(
                 obj_expled_encoded + tail_junk,
                 offset=offset,
+                ctx=ctx_copied,
             )
+            self.assertDictEqual(ctx_copied, ctx_dummy)
             repr(obj_decoded)
             list(obj_decoded.pps())
             pprint(obj_decoded, big_blobs=True, with_decode_path=True)
@@ -2299,10 +2313,13 @@ class TestNull(CommonMixin, TestCase):
             list(obj_expled.pps())
             pprint(obj_expled, big_blobs=True, with_decode_path=True)
             obj_expled_encoded = obj_expled.encode()
+            ctx_copied = deepcopy(ctx_dummy)
             obj_decoded, tail = obj_expled.decode(
                 obj_expled_encoded + tail_junk,
                 offset=offset,
+                ctx=ctx_copied,
             )
+            self.assertDictEqual(ctx_copied, ctx_dummy)
             repr(obj_decoded)
             list(obj_decoded.pps())
             pprint(obj_decoded, big_blobs=True, with_decode_path=True)
@@ -2631,10 +2648,13 @@ class TestObjectIdentifier(CommonMixin, TestCase):
             list(obj_expled.pps())
             pprint(obj_expled, big_blobs=True, with_decode_path=True)
             obj_expled_encoded = obj_expled.encode()
+            ctx_copied = deepcopy(ctx_dummy)
             obj_decoded, tail = obj_expled.decode(
                 obj_expled_encoded + tail_junk,
                 offset=offset,
+                ctx=ctx_copied,
             )
+            self.assertDictEqual(ctx_copied, ctx_dummy)
             repr(obj_decoded)
             list(obj_decoded.pps())
             pprint(obj_decoded, big_blobs=True, with_decode_path=True)
@@ -2922,10 +2942,13 @@ class TestEnumerated(CommonMixin, TestCase):
         list(obj_expled.pps())
         pprint(obj_expled, big_blobs=True, with_decode_path=True)
         obj_expled_encoded = obj_expled.encode()
+        ctx_copied = deepcopy(ctx_dummy)
         obj_decoded, tail = obj_expled.decode(
             obj_expled_encoded + tail_junk,
             offset=offset,
+            ctx=ctx_copied,
         )
+        self.assertDictEqual(ctx_copied, ctx_dummy)
         repr(obj_decoded)
         list(obj_decoded.pps())
         pprint(obj_decoded, big_blobs=True, with_decode_path=True)
@@ -3247,10 +3270,13 @@ class StringMixin(object):
         list(obj_expled.pps())
         pprint(obj_expled, big_blobs=True, with_decode_path=True)
         obj_expled_encoded = obj_expled.encode()
+        ctx_copied = deepcopy(ctx_dummy)
         obj_decoded, tail = obj_expled.decode(
             obj_expled_encoded + tail_junk,
             offset=offset,
+            ctx=ctx_copied,
         )
+        self.assertDictEqual(ctx_copied, ctx_dummy)
         repr(obj_decoded)
         list(obj_decoded.pps())
         pprint(obj_decoded, big_blobs=True, with_decode_path=True)
@@ -3673,10 +3699,13 @@ class TimeMixin(object):
         list(obj_expled.pps())
         pprint(obj_expled, big_blobs=True, with_decode_path=True)
         obj_expled_encoded = obj_expled.encode()
+        ctx_copied = deepcopy(ctx_dummy)
         obj_decoded, tail = obj_expled.decode(
             obj_expled_encoded + tail_junk,
             offset=offset,
+            ctx=ctx_copied,
         )
+        self.assertDictEqual(ctx_copied, ctx_dummy)
         repr(obj_decoded)
         list(obj_decoded.pps())
         pprint(obj_decoded, big_blobs=True, with_decode_path=True)
@@ -4060,10 +4089,13 @@ class TestAny(CommonMixin, TestCase):
             list(obj_expled.pps())
             pprint(obj_expled, big_blobs=True, with_decode_path=True)
             obj_expled_encoded = obj_expled.encode()
+            ctx_copied = deepcopy(ctx_dummy)
             obj_decoded, tail = obj_expled.decode(
                 obj_expled_encoded + tail_junk,
                 offset=offset,
+                ctx=ctx_copied,
             )
+            self.assertDictEqual(ctx_copied, ctx_dummy)
             repr(obj_decoded)
             list(obj_decoded.pps())
             pprint(obj_decoded, big_blobs=True, with_decode_path=True)
@@ -4411,10 +4443,13 @@ class TestChoice(CommonMixin, TestCase):
         list(obj_expled.pps())
         pprint(obj_expled, big_blobs=True, with_decode_path=True)
         obj_expled_encoded = obj_expled.encode()
+        ctx_copied = deepcopy(ctx_dummy)
         obj_decoded, tail = obj_expled.decode(
             obj_expled_encoded + tail_junk,
             offset=offset,
+            ctx=ctx_copied,
         )
+        self.assertDictEqual(ctx_copied, ctx_dummy)
         repr(obj_decoded)
         list(obj_decoded.pps())
         pprint(obj_decoded, big_blobs=True, with_decode_path=True)
@@ -4922,10 +4957,14 @@ class SeqMixing(object):
         t, _, lv = tag_strip(seq_encoded)
         _, _, v = len_decode(lv)
         seq_encoded_lenindef = t + LENINDEF + v + EOC
+        ctx_copied = deepcopy(ctx_dummy)
+        ctx_copied["bered"] = True
         seq_decoded_lenindef, tail_lenindef = seq.decode(
             seq_encoded_lenindef + tail_junk,
-            ctx={"bered": True},
+            ctx=ctx_copied,
         )
+        del ctx_copied["bered"]
+        self.assertDictEqual(ctx_copied, ctx_dummy)
         self.assertTrue(seq_decoded_lenindef.lenindef)
         self.assertTrue(seq_decoded_lenindef.bered)
         with self.assertRaises(DecodeError):
@@ -5575,10 +5614,13 @@ class SeqOfMixing(object):
         list(obj_expled.pps())
         pprint(obj_expled, big_blobs=True, with_decode_path=True)
         obj_expled_encoded = obj_expled.encode()
+        ctx_copied = deepcopy(ctx_dummy)
         obj_decoded, tail = obj_expled.decode(
             obj_expled_encoded + tail_junk,
             offset=offset,
+            ctx=ctx_copied,
         )
+        self.assertDictEqual(ctx_copied, ctx_dummy)
         repr(obj_decoded)
         list(obj_decoded.pps())
         pprint(obj_decoded, big_blobs=True, with_decode_path=True)
@@ -6044,7 +6086,12 @@ class TestDefinesByPath(TestCase):
         pprint(seq_sequenced, big_blobs=True, with_decode_path=True)
 
         defines_by_path = []
-        seq_integered, _ = Seq().decode(seq_integered_raw)
+        ctx_copied = deepcopy(ctx_dummy)
+        seq_integered, _ = Seq().decode(
+            seq_integered_raw,
+            ctx=ctx_copied,
+        )
+        self.assertDictEqual(ctx_copied, ctx_dummy)
         self.assertIsNone(seq_integered["value"].defined)
         defines_by_path.append(
             (("type",), ((("value",), {
@@ -6052,10 +6099,13 @@ class TestDefinesByPath(TestCase):
                 type_sequenced: SeqInner(),
             }),))
         )
+        ctx_copied["defines_by_path"] = defines_by_path
         seq_integered, _ = Seq().decode(
             seq_integered_raw,
-            ctx={"defines_by_path": defines_by_path},
+            ctx=ctx_copied,
         )
+        del ctx_copied["defines_by_path"]
+        self.assertDictEqual(ctx_copied, ctx_dummy)
         self.assertIsNotNone(seq_integered["value"].defined)
         self.assertEqual(seq_integered["value"].defined[0], type_integered)
         self.assertEqual(seq_integered["value"].defined[1], Integer(123))
@@ -6066,10 +6116,13 @@ class TestDefinesByPath(TestCase):
         list(seq_integered.pps())
         pprint(seq_integered, big_blobs=True, with_decode_path=True)
 
+        ctx_copied["defines_by_path"] = defines_by_path
         seq_sequenced, _ = Seq().decode(
             seq_sequenced_raw,
-            ctx={"defines_by_path": defines_by_path},
+            ctx=ctx_copied,
         )
+        del ctx_copied["defines_by_path"]
+        self.assertDictEqual(ctx_copied, ctx_dummy)
         self.assertIsNotNone(seq_sequenced["value"].defined)
         self.assertEqual(seq_sequenced["value"].defined[0], type_sequenced)
         seq_inner = seq_sequenced["value"].defined[1]
@@ -6082,10 +6135,13 @@ class TestDefinesByPath(TestCase):
             ("value", DecodePathDefBy(type_sequenced), "typeInner"),
             ((("valueInner",), {type_innered: Pairs()}),),
         ))
+        ctx_copied["defines_by_path"] = defines_by_path
         seq_sequenced, _ = Seq().decode(
             seq_sequenced_raw,
-            ctx={"defines_by_path": defines_by_path},
+            ctx=ctx_copied,
         )
+        del ctx_copied["defines_by_path"]
+        self.assertDictEqual(ctx_copied, ctx_dummy)
         self.assertIsNotNone(seq_sequenced["value"].defined)
         self.assertEqual(seq_sequenced["value"].defined[0], type_sequenced)
         seq_inner = seq_sequenced["value"].defined[1]
@@ -6112,10 +6168,13 @@ class TestDefinesByPath(TestCase):
                 type_octet_stringed: OctetString(),
             }),),
         ))
+        ctx_copied["defines_by_path"] = defines_by_path
         seq_sequenced, _ = Seq().decode(
             seq_sequenced_raw,
-            ctx={"defines_by_path": defines_by_path},
+            ctx=ctx_copied,
         )
+        del ctx_copied["defines_by_path"]
+        self.assertDictEqual(ctx_copied, ctx_dummy)
         self.assertIsNotNone(seq_sequenced["value"].defined)
         self.assertEqual(seq_sequenced["value"].defined[0], type_sequenced)
         seq_inner = seq_sequenced["value"].defined[1]