Cypherpunks.ru repositories - pyderasn.git/commitdiff
SequenceOf iterator support
author Sergey Matveev <stargrave@stargrave.org>
Sat, 15 Feb 2020 15:52:35 +0000 (18:52 +0300)
committer Sergey Matveev <stargrave@stargrave.org>
Sun, 16 Feb 2020 18:25:49 +0000 (21:25 +0300)
doc/news.rst
pyderasn.py
tests/test_pyderasn.py

index 930a3ba0db3a4f408e79e46b9dd50ace77f247b3..658001b8787ab69dce678a5422dd5c173c584184 100644 (file)
@@ -11,6 +11,8 @@ News
   by tag, but by encoded representation
 * ``Any`` does not allow empty data value now. Now it checks if it has
   valid ASN.1 tag
+* ``SetOf`` is not treated as ready if no value was set and the minimum
+  bound is greater than zero
 * ``Any`` allows an ordinary ``Obj`` storing, without its forceful
   encoded representation storage
 * Initial support for so called ``evgen_mode``: event generation mode,
index 38f9fafdec12617fb68311b182dec1a5bca2eafd..62776b4bbb7570aabdf37ace1ce6d0e52be14551 100755 (executable)
@@ -894,6 +894,7 @@ NAMEDTUPLE_KWARGS = {} if version_info < (3, 6) else {"module": __name__}
 SET01 = frozenset("01")
 DECIMALS = frozenset(digits)
 DECIMAL_SIGNS = ".,"
+NEXT_ATTR_NAME = "next" if PY2 else "__next__"
 
 
 def file_mmaped(fd):
@@ -6226,9 +6227,21 @@ class SequenceOf(Obj):
     >>> ints
     Ints SEQUENCE OF[INTEGER 123, INTEGER 345]
 
-    Also you can initialize sequence with preinitialized values:
+    You can initialize sequence with preinitialized values:
 
     >>> ints = Ints([Integer(123), Integer(234)])
+
+    You can also use an iterator as a value:
+
+    >>> ints = Ints(iter(Integer(i) for i in range(1000000)))
+
+    It won't be iterated until the encoding process. Pay attention that
+    bounds and required schema checks are done only during the encoding
+    process in that case! After encode has been called, the value is reset
+    back to an empty list and you have to set it again. That mode is useful
+    mainly with CER encoding mode, where all objects from the iterable
+    will be streamed to the buffer, without copying all of them to
+    memory first.
     """
     __slots__ = ("spec", "_bound_min", "_bound_max")
     tag_default = tag_encode(form=TagFormConstructed, num=16)
@@ -6272,21 +6285,31 @@ class SequenceOf(Obj):
                 self._value = copy(default_obj._value)
 
     def _value_sanitize(self, value):
+        iterator = False
         if issubclass(value.__class__, SequenceOf):
             value = value._value
+        elif hasattr(value, NEXT_ATTR_NAME):
+            iterator = True
+            value = value
         elif hasattr(value, "__iter__"):
             value = list(value)
         else:
-            raise InvalidValueType((self.__class__, iter))
-        if not self._bound_min <= len(value) <= self._bound_max:
-            raise BoundsError(self._bound_min, len(value), self._bound_max)
-        for v in value:
-            if not isinstance(v, self.spec.__class__):
-                raise InvalidValueType((self.spec.__class__,))
+            raise InvalidValueType((self.__class__, iter, "iterator"))
+        if not iterator:
+            if not self._bound_min <= len(value) <= self._bound_max:
+                raise BoundsError(self._bound_min, len(value), self._bound_max)
+            class_expected = self.spec.__class__
+            for v in value:
+                if not isinstance(v, class_expected):
+                    raise InvalidValueType((class_expected,))
         return value
 
     @property
     def ready(self):
+        if hasattr(self._value, NEXT_ATTR_NAME):
+            return True
+        if self._bound_min > 0 and len(self._value) == 0:
+            return False
         return all(v.ready for v in self._value)
 
     @property
@@ -6296,6 +6319,8 @@ class SequenceOf(Obj):
         return any(v.bered for v in self._value)
 
     def __getstate__(self):
+        if hasattr(self._value, NEXT_ATTR_NAME):
+            raise ValueError("can not pickle SequenceOf with iterator")
         return SequenceOfState(
             __version__,
             self.tag,
@@ -6371,11 +6396,9 @@ class SequenceOf(Obj):
         self._value.append(value)
 
     def __iter__(self):
-        self._assert_ready()
         return iter(self._value)
 
     def __len__(self):
-        self._assert_ready()
         return len(self._value)
 
     def __setitem__(self, key, value):
@@ -6390,13 +6413,42 @@ class SequenceOf(Obj):
         return iter(self._value)
 
     def _encode(self):
-        v = b"".join(v.encode() for v in self._values_for_encoding())
-        return b"".join((self.tag, len_encode(len(v)), v))
+        iterator = hasattr(self._value, NEXT_ATTR_NAME)
+        if iterator:
+            values = []
+            values_append = values.append
+            class_expected = self.spec.__class__
+            values_for_encoding = self._values_for_encoding()
+            self._value = []
+            for v in values_for_encoding:
+                if not isinstance(v, class_expected):
+                    raise InvalidValueType((class_expected,))
+                values_append(v.encode())
+            if not self._bound_min <= len(values) <= self._bound_max:
+                raise BoundsError(self._bound_min, len(values), self._bound_max)
+            value = b"".join(values)
+        else:
+            value = b"".join(v.encode() for v in self._values_for_encoding())
+        return b"".join((self.tag, len_encode(len(value)), value))
 
     def _encode_cer(self, writer):
         write_full(writer, self.tag + LENINDEF)
-        for v in self._values_for_encoding():
-            v.encode_cer(writer)
+        iterator = hasattr(self._value, NEXT_ATTR_NAME)
+        if iterator:
+            class_expected = self.spec.__class__
+            values_count = 0
+            values_for_encoding = self._values_for_encoding()
+            self._value = []
+            for v in values_for_encoding:
+                if not isinstance(v, class_expected):
+                    raise InvalidValueType((class_expected,))
+                v.encode_cer(writer)
+                values_count += 1
+            if not self._bound_min <= values_count <= self._bound_max:
+                raise BoundsError(self._bound_min, values_count, self._bound_max)
+        else:
+            for v in self._values_for_encoding():
+                v.encode_cer(writer)
         write_full(writer, EOC)
 
     def _decode(
@@ -6590,6 +6642,14 @@ class SetOf(SequenceOf):
     tag_default = tag_encode(form=TagFormConstructed, num=17)
     asn1_type_name = "SET OF"
 
+    def _value_sanitize(self, value):
+        value = super(SetOf, self)._value_sanitize(value)
+        if hasattr(value, NEXT_ATTR_NAME):
+            raise ValueError(
+                "SetOf does not support iterator values, as no sense in them"
+            )
+        return value
+
     def _encode(self):
         v = b"".join(sorted(v.encode() for v in self._values_for_encoding()))
         return b"".join((self.tag, len_encode(len(v)), v))
index cd165b857a40b5d628e5839082895b3fdb33a18b..0489afdbbbe9e4538fe980b29f883959f2f8e761 100644 (file)
@@ -60,6 +60,7 @@ from six import iterbytes
 from six import PY2
 from six import text_type
 from six import unichr as six_unichr
+from six.moves import xrange as six_xrange
 from six.moves.cPickle import dumps as pickle_dumps
 from six.moves.cPickle import HIGHEST_PROTOCOL as pickle_proto
 from six.moves.cPickle import loads as pickle_loads
@@ -6899,6 +6900,57 @@ class TestSequenceOf(SeqOfMixing, CommonMixin, TestCase):
         self.assertEqual(obj1, obj2)
         self.assertSequenceEqual(list(obj1), list(obj2))
 
+    def test_iterator_pickling(self):
+        class SeqOf(SequenceOf):
+            schema = Integer()
+        register_class(SeqOf)
+        seqof = SeqOf()
+        pickle_dumps(seqof)
+        seqof = seqof(iter(six_xrange(10)))
+        with assertRaisesRegex(self, ValueError, "iterator"):
+            pickle_dumps(seqof)
+
+    def test_iterator_bounds(self):
+        class SeqOf(SequenceOf):
+            schema = Integer()
+            bounds = (10, 20)
+        seqof = None
+        def gen(n):
+            for i in six_xrange(n):
+                yield Integer(i)
+        for n in (9, 21):
+            seqof = SeqOf(gen(n))
+            self.assertTrue(seqof.ready)
+            with self.assertRaises(BoundsError):
+                seqof.encode()
+            self.assertFalse(seqof.ready)
+            seqof = seqof(gen(n))
+            self.assertTrue(seqof.ready)
+            with self.assertRaises(BoundsError):
+                encode_cer(seqof)
+            self.assertFalse(seqof.ready)
+
+    def test_iterator_twice(self):
+        class SeqOf(SequenceOf):
+            schema = Integer()
+            bounds = (1, float("+inf"))
+        def gen():
+            for i in six_xrange(10):
+                yield Integer(i)
+        seqof = SeqOf(gen())
+        self.assertTrue(seqof.ready)
+        seqof.encode()
+        self.assertFalse(seqof.ready)
+        register_class(SeqOf)
+        pickle_dumps(seqof)
+
+    def test_non_ready_bound_min(self):
+        class SeqOf(SequenceOf):
+            schema = Integer()
+            bounds = (1, float("+inf"))
+        seqof = SeqOf()
+        self.assertFalse(seqof.ready)
+
 
 class TestSetOf(SeqOfMixing, CommonMixin, TestCase):
     class SeqOf(SetOf):