]> Cypherpunks.ru repositories - pyderasn.git/commitdiff
Valid DER SET ordering
authorSergey Matveev <stargrave@stargrave.org>
Wed, 12 Feb 2020 14:10:35 +0000 (17:10 +0300)
committerSergey Matveev <stargrave@stargrave.org>
Sun, 16 Feb 2020 09:20:11 +0000 (12:20 +0300)
doc/news.rst
pyderasn.py
tests/test_pyderasn.py

index 47fcc107d04231fc650d0bc0d96a6806b0d182d7..6f360012a01beee2dd37b60720187ae29fd77062 100644 (file)
@@ -7,6 +7,10 @@ News
 ---
 * Fixed invalid behaviour where SET OF allowed multiple objects with the
   same tag to be successfully decoded
+* Fixed possibly invalid SET DER encoding where objects were not sorted
+  by tag, but by encoded representation
+* Any does not allow empty data value now. Now it checks if it has valid
+  ASN.1 tag
 
 .. _release6.3:
 
index 4cabec726c6a2c14f12a242fd326355f3334ee56..f1be6473130142969088c58038efc402a5995a10 100755 (executable)
@@ -352,7 +352,6 @@ Let's parse that output, human::
  (and its derivatives), ``SET``, ``SET OF``, ``UTCTime``, ``GeneralizedTime``
  could be BERed.
 
-
 .. _definedby:
 
 DEFINED BY
@@ -781,6 +780,7 @@ from copy import copy
 from datetime import datetime
 from datetime import timedelta
 from math import ceil
+from operator import attrgetter
 from string import ascii_letters
 from string import digits
 from sys import version_info
@@ -1190,6 +1190,7 @@ class AutoAddSlots(type):
 BasicState = namedtuple("BasicState", (
     "version",
     "tag",
+    "tag_order",
     "expl",
     "default",
     "optional",
@@ -1211,6 +1212,7 @@ class Obj(object):
     """
     __slots__ = (
         "tag",
+        "_tag_order",
         "_value",
         "_expl",
         "default",
@@ -1235,6 +1237,13 @@ class Obj(object):
         self._expl = getattr(self, "expl", None) if expl is None else expl
         if self.tag != self.tag_default and self._expl is not None:
             raise ValueError("implicit and explicit tags can not be set simultaneously")
+        if self.tag is None:
+            self._tag_order = None
+        else:
+            tag_class, _, tag_num = tag_decode(
+                self.tag if self._expl is None else self._expl
+            )
+            self._tag_order = (tag_class, tag_num)
         if default is not None:
             optional = True
         self.optional = optional
@@ -1275,6 +1284,7 @@ class Obj(object):
         if state.version != __version__:
             raise ValueError("data is pickled by different PyDERASN version")
         self.tag = state.tag
+        self._tag_order = state.tag_order
         self._expl = state.expl
         self.default = state.default
         self.optional = state.optional
@@ -1285,6 +1295,12 @@ class Obj(object):
         self.lenindef = state.lenindef
         self.ber_encoded = state.ber_encoded
 
+    @property
+    def tag_order(self):
+        """Tag's (class, number) used for DER/CER sorting
+        """
+        return self._tag_order
+
     @property
     def tlen(self):
         """See :ref:`decoding`
@@ -1938,6 +1954,7 @@ class Boolean(Obj):
         return BooleanState(
             __version__,
             self.tag,
+            self._tag_order,
             self._expl,
             self.default,
             self.optional,
@@ -2207,6 +2224,7 @@ class Integer(Obj):
         return IntegerState(
             __version__,
             self.tag,
+            self._tag_order,
             self._expl,
             self.default,
             self.optional,
@@ -2607,6 +2625,7 @@ class BitString(Obj):
         return BitStringState(
             __version__,
             self.tag,
+            self._tag_order,
             self._expl,
             self.default,
             self.optional,
@@ -3042,6 +3061,7 @@ class OctetString(Obj):
         return OctetStringState(
             __version__,
             self.tag,
+            self._tag_order,
             self._expl,
             self.default,
             self.optional,
@@ -3347,6 +3367,7 @@ class Null(Obj):
         return NullState(
             __version__,
             self.tag,
+            self._tag_order,
             self._expl,
             self.default,
             self.optional,
@@ -3562,6 +3583,7 @@ class ObjectIdentifier(Obj):
         return ObjectIdentifierState(
             __version__,
             self.tag,
+            self._tag_order,
             self._expl,
             self.default,
             self.optional,
@@ -4663,6 +4685,9 @@ class Choice(Obj):
             self.default = default_obj
             if value is None:
                 self._value = copy(default_obj._value)
+        if self._expl is not None:
+            tag_class, _, tag_num = tag_decode(self._expl)
+            self._tag_order = (tag_class, tag_num)
 
     def _value_sanitize(self, value):
         if (value.__class__ == tuple) and len(value) == 2:
@@ -4692,6 +4717,7 @@ class Choice(Obj):
         return ChoiceState(
             __version__,
             self.tag,
+            self._tag_order,
             self._expl,
             self.default,
             self.optional,
@@ -4749,6 +4775,11 @@ class Choice(Obj):
         self._assert_ready()
         return self._value[1]
 
+    @property
+    def tag_order(self):
+        self._assert_ready()
+        return self._value[1].tag_order if self._tag_order is None else self._tag_order
+
     def __getitem__(self, key):
         if key not in self.specs:
             raise ObjUnknown(key)
@@ -4917,17 +4948,31 @@ class Any(Obj):
         """
         :param value: set the value. Either any kind of pyderasn's
                       **ready** object, or bytes. Pay attention that
-                      **no** validation is performed is raw binary value
-                      is valid TLV
+                      **no** validation is performed if raw binary value
+                      is valid TLV, except just tag decoding
         :param bytes expl: override default tag with ``EXPLICIT`` one
         :param bool optional: is object ``OPTIONAL`` in sequence
         """
         super(Any, self).__init__(None, expl, None, optional, _decoded)
-        self._value = None if value is None else self._value_sanitize(value)
+        if value is None:
+            self._value = None
+        else:
+            value = self._value_sanitize(value)
+            self._value = value
+            if self._expl is None:
+                if value.__class__ == binary_type:
+                    tag_class, _, tag_num = tag_decode(tag_strip(value)[0])
+                else:
+                    tag_class, tag_num = value.tag_order
+            else:
+                tag_class, _, tag_num = tag_decode(self._expl)
+            self._tag_order = (tag_class, tag_num)
         self.defined = None
 
     def _value_sanitize(self, value):
         if value.__class__ == binary_type:
+            if len(value) == 0:
+                raise ValueError("Any value can not be empty")
             return value
         if isinstance(value, self.__class__):
             return value._value
@@ -4939,6 +4984,11 @@ class Any(Obj):
     def ready(self):
         return self._value is not None
 
+    @property
+    def tag_order(self):
+        self._assert_ready()
+        return self._tag_order
+
     @property
     def bered(self):
         if self.expl_lenindef or self.lenindef:
@@ -4951,6 +5001,7 @@ class Any(Obj):
         return AnyState(
             __version__,
             self.tag,
+            self._tag_order,
             self._expl,
             None,
             self.optional,
@@ -5307,6 +5358,7 @@ class Sequence(Obj):
         return SequenceState(
             __version__,
             self.tag,
+            self._tag_order,
             self._expl,
             self.default,
             self.optional,
@@ -5636,9 +5688,10 @@ class Set(Sequence):
     asn1_type_name = "SET"
 
     def _encode(self):
-        raws = [v.encode() for v in self._values_for_encoding()]
-        raws.sort()
-        v = b"".join(raws)
+        v = b"".join(value.encode() for value in sorted(
+            self._values_for_encoding(),
+            key=attrgetter("tag_order"),
+        ))
         return b"".join((self.tag, len_encode(len(v)), v))
 
     def _decode(self, tlv, offset, decode_path, ctx, tag_only):
@@ -5694,7 +5747,7 @@ class Set(Sequence):
         ber_encoded = False
         ctx_allow_default_values = ctx.get("allow_default_values", False)
         ctx_allow_unordered_set = ctx.get("allow_unordered_set", False)
-        value_prev = memoryview(v[:0])
+        tag_order_prev = (0, 0)
         _specs_items = copy(self.specs)
 
         while len(v) > 0:
@@ -5729,8 +5782,9 @@ class Set(Sequence):
                 ctx=ctx,
                 _ctx_immutable=False,
             )
+            value_tag_order = value.tag_order
             value_len = value.fulllen
-            if value_prev.tobytes() > v[:value_len].tobytes():
+            if tag_order_prev >= value_tag_order:
                 if ctx_bered or ctx_allow_unordered_set:
                     ber_encoded = True
                 else:
@@ -5753,7 +5807,7 @@ class Set(Sequence):
                 )
             values[name] = value
             del _specs_items[name]
-            value_prev = v[:value_len]
+            tag_order_prev = value_tag_order
             sub_offset += value_len
             vlen += value_len
             v = v_tail
@@ -5895,6 +5949,7 @@ class SequenceOf(Obj):
         return SequenceOfState(
             __version__,
             self.tag,
+            self._tag_order,
             self._expl,
             self.default,
             self.optional,
@@ -6148,9 +6203,7 @@ class SetOf(SequenceOf):
     asn1_type_name = "SET OF"
 
     def _encode(self):
-        raws = [v.encode() for v in self._values_for_encoding()]
-        raws.sort()
-        v = b"".join(raws)
+        v = b"".join(sorted(v.encode() for v in self._values_for_encoding()))
         return b"".join((self.tag, len_encode(len(v)), v))
 
     def _decode(self, tlv, offset, decode_path, ctx, tag_only):
index 3161102997daf57449b26c20194b1081613f8b65..309bc2290df05440bccc2b63d65a53547100e852 100644 (file)
@@ -345,6 +345,9 @@ class CommonMixin(object):
         obj = Inherited()
         self.assertSequenceEqual(obj.impl, impl_tag)
         self.assertFalse(obj.expled)
+        if obj.ready:
+            tag_class, _, tag_num = tag_decode(impl_tag)
+            self.assertEqual(obj.tag_order, (tag_class, tag_num))
 
     @given(binary(min_size=1))
     def test_expl_inherited(self, expl_tag):
@@ -353,6 +356,9 @@ class CommonMixin(object):
         obj = Inherited()
         self.assertSequenceEqual(obj.expl, expl_tag)
         self.assertTrue(obj.expled)
+        if obj.ready:
+            tag_class, _, tag_num = tag_decode(expl_tag)
+            self.assertEqual(obj.tag_order, (tag_class, tag_num))
 
     def assert_copied_basic_fields(self, obj, obj_copied):
         self.assertEqual(obj, obj_copied)
@@ -363,6 +369,8 @@ class CommonMixin(object):
         self.assertEqual(obj.offset, obj_copied.offset)
         self.assertEqual(obj.llen, obj_copied.llen)
         self.assertEqual(obj.vlen, obj_copied.vlen)
+        if obj.ready:
+            self.assertEqual(obj.tag_order, obj_copied.tag_order)
 
 
 @composite
@@ -4667,9 +4675,16 @@ class TestUTCTime(TimeMixin, CommonMixin, TestCase):
             )
 
 
+@composite
+def tlv_value_strategy(draw):
+    tag_num = draw(integers(min_value=1))
+    data = draw(binary())
+    return b"".join((tag_encode(tag_num), len_encode(len(data)), data))
+
+
 @composite
 def any_values_strategy(draw, do_expl=False):
-    value = draw(one_of(none(), binary()))
+    value = draw(one_of(none(), tlv_value_strategy()))
     expl = None
     if do_expl:
         expl = draw(one_of(none(), integers(min_value=1).map(tag_encode)))
@@ -4699,7 +4714,7 @@ class TestAny(CommonMixin, TestCase):
         obj = Any(optional=optional)
         self.assertEqual(obj.optional, optional)
 
-    @given(binary())
+    @given(tlv_value_strategy())
     def test_ready(self, value):
         obj = Any()
         self.assertFalse(obj.ready)
@@ -4733,7 +4748,7 @@ class TestAny(CommonMixin, TestCase):
             pprint(obj, big_blobs=True, with_decode_path=True)
             self.assertSequenceEqual(obj.encode(), integer_encoded)
 
-    @given(binary(min_size=1), binary(min_size=1))
+    @given(tlv_value_strategy(), tlv_value_strategy())
     def test_comparison(self, value1, value2):
         for klass in (Any, AnyInherited):
             obj1 = klass(value1)
@@ -4797,7 +4812,7 @@ class TestAny(CommonMixin, TestCase):
             obj.decode(obj.encode()[:-1])
 
     @given(
-        binary(),
+        tlv_value_strategy(),
         integers(min_value=1).map(tag_ctxc),
     )
     def test_stripped_expl(self, value, tag_expl):
@@ -4853,9 +4868,13 @@ class TestAny(CommonMixin, TestCase):
             list(obj.pps())
             pprint(obj, big_blobs=True, with_decode_path=True)
             self.assertFalse(obj.expled)
+            tag_class, _, tag_num = tag_decode(tag_strip(value)[0])
+            self.assertEqual(obj.tag_order, (tag_class, tag_num))
             obj_encoded = obj.encode()
             obj_expled = obj(value, expl=tag_expl)
             self.assertTrue(obj_expled.expled)
+            tag_class, _, tag_num = tag_decode(tag_expl)
+            self.assertEqual(obj_expled.tag_order, (tag_class, tag_num))
             repr(obj_expled)
             list(obj_expled.pps())
             pprint(obj_expled, big_blobs=True, with_decode_path=True)
@@ -5219,9 +5238,12 @@ class TestChoice(CommonMixin, TestCase):
         list(obj.pps())
         pprint(obj, big_blobs=True, with_decode_path=True)
         self.assertFalse(obj.expled)
+        self.assertEqual(obj.tag_order, obj.value.tag_order)
         obj_encoded = obj.encode()
         obj_expled = obj(value, expl=tag_expl)
         self.assertTrue(obj_expled.expled)
+        tag_class, _, tag_num = tag_decode(tag_expl)
+        self.assertEqual(obj_expled.tag_order, (tag_class, tag_num))
         repr(obj_expled)
         list(obj_expled.pps())
         pprint(obj_expled, big_blobs=True, with_decode_path=True)
@@ -5652,8 +5674,9 @@ class SeqMixing(object):
         with self.assertRaises(NotEnoughData):
             seq.decode(seq.encode()[:-1])
 
-    @given(binary(min_size=2))
-    def test_non_tag_mismatch_raised(self, junk):
+    @given(integers(min_value=3), binary(min_size=2))
+    def test_non_tag_mismatch_raised(self, junk_tag_num, junk):
+        junk = tag_encode(junk_tag_num) + junk
         try:
             _, _, len_encoded = tag_strip(memoryview(junk))
             len_decode(len_encoded)
@@ -5999,42 +6022,44 @@ class TestSet(SeqMixing, CommonMixin, TestCase):
     @settings(max_examples=LONG_TEST_MAX_EXAMPLES)
     @given(data_strategy())
     def test_sorted(self, d):
-        tags = [
-            tag_encode(tag) for tag in
-            d.draw(sets(integers(min_value=1), min_size=1, max_size=10))
-        ]
+        class DummySeq(Sequence):
+            schema = (("null", Null()),)
+
+        tag_nums = d.draw(sets(integers(min_value=1), min_size=1, max_size=50))
+        _, _, dummy_seq_tag_num = tag_decode(DummySeq.tag_default)
+        assume(any(i > dummy_seq_tag_num for i in tag_nums))
+        tag_nums -= set([dummy_seq_tag_num])
+        _schema = [(str(i), OctetString(impl=tag_encode(i))) for i in tag_nums]
+        _schema.append(("seq", DummySeq()))
 
         class Seq(Set):
-            schema = [(str(i), OctetString(impl=t)) for i, t in enumerate(tags)]
+            schema = d.draw(permutations(_schema))
         seq = Seq()
-        for name, _ in Seq.schema:
-            seq[name] = OctetString(b"")
+        for name, _ in _schema:
+            if name != "seq":
+                seq[name] = OctetString(name.encode("ascii"))
+        seq["seq"] = DummySeq((("null", Null()),))
+
         seq_encoded = seq.encode()
         seq_decoded, _ = seq.decode(seq_encoded)
+        seq_encoded_expected = []
+        for tag_num in sorted(tag_nums | set([dummy_seq_tag_num])):
+            if tag_num == dummy_seq_tag_num:
+                seq_encoded_expected.append(seq["seq"].encode())
+            else:
+                seq_encoded_expected.append(seq[str(tag_num)].encode())
         self.assertSequenceEqual(
             seq_encoded[seq_decoded.tlen + seq_decoded.llen:],
-            b"".join(sorted([seq[name].encode() for name, _ in Seq.schema])),
+            b"".join(seq_encoded_expected),
         )
 
-    @settings(max_examples=LONG_TEST_MAX_EXAMPLES)
-    @given(data_strategy())
-    def test_unsorted(self, d):
-        tags = [
-            tag_encode(tag) for tag in
-            d.draw(sets(integers(min_value=1), min_size=2, max_size=5))
-        ]
-        tags = d.draw(permutations(tags))
-        assume(tags != sorted(tags))
-        encoded = b"".join(OctetString(t, impl=t).encode() for t in tags)
+        encoded = b"".join(seq[str(i)].encode() for i in tag_nums)
+        encoded += seq["seq"].encode()
         seq_encoded = b"".join((
             Set.tag_default,
             len_encode(len(encoded)),
             encoded,
         ))
-
-        class Seq(Set):
-            schema = [(str(i), OctetString(impl=t)) for i, t in enumerate(tags)]
-        seq = Seq()
         with assertRaisesRegex(self, DecodeError, "unordered SET"):
             seq.decode(seq_encoded)
         for ctx in ({"bered": True}, {"allow_unordered_set": True}):
@@ -6044,10 +6069,6 @@ class TestSet(SeqMixing, CommonMixin, TestCase):
             seq_decoded = copy(seq_decoded)
             self.assertTrue(seq_decoded.ber_encoded)
             self.assertTrue(seq_decoded.bered)
-            self.assertSequenceEqual(
-                [bytes(seq_decoded[str(i)]) for i, t in enumerate(tags)],
-                tags,
-            )
 
     def test_same_value_twice(self):
         class Seq(Set):