]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - pyderasn.py
DER 2pass encoding
[pyderasn.git] / pyderasn.py
index 8208efeb024148e54e8c6e71ca369383b393f273..c9d6a116044b5515b0f75b6b5271327345106f6d 100755 (executable)
@@ -829,6 +829,8 @@ copy the payload (without BER/CER encoding interleaved overhead) in it.
 Virtually it won't take memory more than for keeping small structures
 and 1 KB binary chunks.
 
+.. _seqof-iterators:
+
 SEQUENCE OF iterators
 _____________________
 
@@ -844,6 +846,54 @@ generator taking necessary data from the database and giving the
 ``RevokedCertificate`` objects. Only binary representation of that
 objects will take memory during DER encoding.
 
+2-pass DER encoding
+-------------------
+
+There is ability to do 2-pass encoding to DER, writing results directly
+to specified writer (buffer, file, whatever). It could be 1.5+ times
+slower than ordinary encoding, but it takes little memory for 1st pass
+state storing. For example, 1st pass state for CACert.org's CRL with
+~416K of certificate entries takes nearly 3.5 MB of memory.
+``SignedData`` with several gigabyte ``EncapsulatedContentInfo`` takes
+nearly 0.5 KB of memory.
+
+If you use :ref:`mmap-ed <mmap>` memoryviews, :ref:`SEQUENCE OF
+iterators <seqof-iterators>` and write directly to opened file, then
+there is very small memory footprint.
+
+1st pass traverses through all the objects of the structure and returns
+the size of DER encoded structure, together with 1st pass state object.
+That state contains precalculated lengths for various objects inside the
+structure.
+
+::
+
+    fulllen, state = obj.encode1st()
+
+2nd pass takes the writer and 1st pass state. It traverses through all
+the objects again, but writes their encoded representation to the writer.
+
+::
+
+    opener = io.open if PY2 else open
+    with opener("result", "wb") as fd:
+        obj.encode2nd(fd.write, iter(state))
+
+.. warning::
+
+   You **MUST NOT** use 1st pass state if anything is changed in the
+   objects. It is intended to be used immediately after 1st pass is
+   done!
+
+If you use :ref:`SEQUENCE OF iterators <seqof-iterators>`, then you
+have to reinitialize the values after the 1st pass. And you **have to**
+be sure that the iterator gives exactly the same values as previously.
+Yes, you have to run your iterator twice -- because this is two pass
+encoding mode.
+
+If you want to encode to the memory, then you can use convenient
+:py:func:`pyderasn.encode2pass` helper.
+
 Base Obj
 --------
 .. autoclass:: pyderasn.Obj
@@ -955,6 +1005,7 @@ Various
 .. autofunction:: pyderasn.abs_decode_path
 .. autofunction:: pyderasn.agg_octet_string
 .. autofunction:: pyderasn.colonize_hex
+.. autofunction:: pyderasn.encode2pass
 .. autofunction:: pyderasn.encode_cer
 .. autofunction:: pyderasn.file_mmaped
 .. autofunction:: pyderasn.hexenc
@@ -1101,6 +1152,7 @@ from mmap import PROT_READ
 from operator import attrgetter
 from string import ascii_letters
 from string import digits
+from sys import maxsize as sys_maxsize
 from sys import version_info
 from unicodedata import category as unicat
 
@@ -1138,6 +1190,7 @@ __all__ = (
     "Choice",
     "DecodeError",
     "DecodePathDefBy",
+    "encode2pass",
     "encode_cer",
     "Enumerated",
     "ExceedingData",
@@ -1525,6 +1578,28 @@ LEN1 = len_encode(1)
 LEN1K = len_encode(1000)
 
 
+def len_size(l):
+    """How many bytes length field will take
+    """
+    if l < 128:
+        return 1
+    if l < 256:  # 1 << 8
+        return 2
+    if l < 65536:  # 1 << 16
+        return 3
+    if l < 16777216:  # 1 << 24
+        return 4
+    if l < 4294967296:  # 1 << 32
+        return 5
+    if l < 1099511627776:  # 1 << 40
+        return 6
+    if l < 281474976710656:  # 1 << 48
+        return 7
+    if l < 72057594037927936:  # 1 << 56
+        return 8
+    raise OverflowError("too big length")
+
+
 def write_full(writer, data):
     """Fully write provided data
 
@@ -1543,6 +1618,17 @@ def write_full(writer, data):
         written += n
 
 
+# If it is 64-bit system, then use compact 64-bit array of unsigned
+# longs. Use an ordinary list with universal integers otherwise, that
+# is slower.
+if sys_maxsize > 2 ** 32:
+    def state_2pass_new():
+        return array("L")
+else:
+    def state_2pass_new():
+        return []
+
+
 ########################################################################
 # Base class
 ########################################################################
@@ -1701,9 +1787,18 @@ class Obj(object):
     def _encode(self):  # pragma: no cover
         raise NotImplementedError()
 
+    def _encode_cer(self, writer):
+        write_full(writer, self._encode())
+
     def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode):  # pragma: no cover
         yield NotImplemented
 
+    def _encode1st(self, state):
+        raise NotImplementedError()
+
+    def _encode2nd(self, writer, state_iter):
+        raise NotImplementedError()
+
     def encode(self):
         """DER encode the structure
 
@@ -1714,6 +1809,36 @@ class Obj(object):
             return raw
         return b"".join((self._expl, len_encode(len(raw)), raw))
 
+    def encode1st(self, state=None):
+        """Do the 1st pass of 2-pass encoding
+
+        :rtype: (int, array("L"))
+        :returns: full length of encoded data and precalculated various
+                  objects lengths
+        """
+        if state is None:
+            state = state_2pass_new()
+        if self._expl is None:
+            return self._encode1st(state)
+        state.append(0)
+        idx = len(state) - 1
+        vlen, _ = self._encode1st(state)
+        state[idx] = vlen
+        fulllen = len(self._expl) + len_size(vlen) + vlen
+        return fulllen, state
+
+    def encode2nd(self, writer, state_iter):
+        """Do the 2nd pass of 2-pass encoding
+
+        :param writer: must comply with ``io.RawIOBase.write`` behaviour
+        :param state_iter: iterator over the 1st pass state (``iter(state)``)
+        """
+        if self._expl is None:
+            self._encode2nd(writer, state_iter)
+        else:
+            write_full(writer, self._expl + len_encode(next(state_iter)))
+            self._encode2nd(writer, state_iter)
+
     def encode_cer(self, writer):
         """CER encode the structure to specified writer
 
@@ -1731,9 +1856,6 @@ class Obj(object):
         if self._expl is not None:
             write_full(writer, EOC)
 
-    def _encode_cer(self, writer):
-        write_full(writer, self._encode())
-
     def hexencode(self):
         """Do hexadecimal encoded :py:meth:`pyderasn.Obj.encode`
         """
@@ -2045,6 +2167,17 @@ def encode_cer(obj):
     return buf.getvalue()
 
 
+def encode2pass(obj):
+    """Encode (2-pass mode) to DER in memory buffer
+
+    :returns bytes: memory buffer contents
+    """
+    buf = BytesIO()
+    _, state = obj.encode1st()
+    obj.encode2nd(buf.write, iter(state))
+    return buf.getvalue()
+
+
 class DecodePathDefBy(object):
     """DEFINED BY representation inside decode path
     """
@@ -2464,6 +2597,13 @@ class Boolean(Obj):
         self._assert_ready()
         return b"".join((self.tag, LEN1, (b"\xFF" if self._value else b"\x00")))
 
+    def _encode1st(self, state):
+        return len(self.tag) + 2, state
+
+    def _encode2nd(self, writer, state_iter):
+        self._assert_ready()
+        write_full(writer, self._encode())
+
     def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode):
         try:
             t, _, lv = tag_strip(tlv)
@@ -2755,7 +2895,7 @@ class Integer(Obj):
             _specs=self.specs,
         )
 
-    def _encode(self):
+    def _encode_payload(self):
         self._assert_ready()
         value = self._value
         if PY2:
@@ -2792,8 +2932,20 @@ class Integer(Obj):
                     bytes_len += 1
                 else:
                     break
+        return octets
+        return b"".join((self.tag, len_encode(len(octets)), octets))
+
+    def _encode(self):
+        octets = self._encode_payload()
         return b"".join((self.tag, len_encode(len(octets)), octets))
 
+    def _encode1st(self, state):
+        l = len(self._encode_payload())
+        return len(self.tag) + len_size(l) + l, state
+
+    def _encode2nd(self, writer, state_iter):
+        write_full(writer, self._encode())
+
     def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode):
         try:
             t, _, lv = tag_strip(tlv)
@@ -3178,6 +3330,21 @@ class BitString(Obj):
             octets,
         ))
 
+    def _encode1st(self, state):
+        self._assert_ready()
+        _, octets = self._value
+        l = len(octets) + 1
+        return len(self.tag) + len_size(l) + l, state
+
+    def _encode2nd(self, writer, state_iter):
+        bit_len, octets = self._value
+        write_full(writer, b"".join((
+            self.tag,
+            len_encode(len(octets) + 1),
+            int2byte((8 - bit_len % 8) % 8),
+        )))
+        write_full(writer, octets)
+
     def _encode_cer(self, writer):
         bit_len, octets = self._value
         if len(octets) + 1 <= 1000:
@@ -3629,6 +3796,16 @@ class OctetString(Obj):
             self._value,
         ))
 
+    def _encode1st(self, state):
+        self._assert_ready()
+        l = len(self._value)
+        return len(self.tag) + len_size(l) + l, state
+
+    def _encode2nd(self, writer, state_iter):
+        value = self._value
+        write_full(writer, self.tag + len_encode(len(value)))
+        write_full(writer, value)
+
     def _encode_cer(self, writer):
         octets = self._value
         if len(octets) <= 1000:
@@ -3987,6 +4164,12 @@ class Null(Obj):
     def _encode(self):
         return self.tag + LEN0
 
+    def _encode1st(self, state):
+        return len(self.tag) + 1, state
+
+    def _encode2nd(self, writer, state_iter):
+        write_full(writer, self.tag + LEN0)
+
     def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode):
         try:
             t, _, lv = tag_strip(tlv)
@@ -4239,7 +4422,7 @@ class ObjectIdentifier(Obj):
             optional=self.optional if optional is None else optional,
         )
 
-    def _encode(self):
+    def _encode_octets(self):
         self._assert_ready()
         value = self._value
         first_value = value[1]
@@ -4255,9 +4438,19 @@ class ObjectIdentifier(Obj):
         octets = [zero_ended_encode(first_value)]
         for arc in value[2:]:
             octets.append(zero_ended_encode(arc))
-        v = b"".join(octets)
+        return b"".join(octets)
+
+    def _encode(self):
+        v = self._encode_octets()
         return b"".join((self.tag, len_encode(len(v)), v))
 
+    def _encode1st(self, state):
+        l = len(self._encode_octets())
+        return len(self.tag) + len_size(l) + l, state
+
+    def _encode2nd(self, writer, state_iter):
+        write_full(writer, self._encode())
+
     def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode):
         try:
             t, _, lv = tag_strip(tlv)
@@ -5017,6 +5210,13 @@ class UTCTime(VisibleString):
         self._assert_ready()
         return b"".join((self.tag, LEN_LEN_YYMMDDHHMMSSZ, self._encode_time()))
 
+    def _encode1st(self, state):
+        return len(self.tag) + LEN_YYMMDDHHMMSSZ_WITH_LEN, state
+
+    def _encode2nd(self, writer, state_iter):
+        self._assert_ready()
+        write_full(writer, self._encode())
+
     def _encode_cer(self, writer):
         write_full(writer, self._encode())
 
@@ -5189,6 +5389,14 @@ class GeneralizedTime(UTCTime):
             return b"".join((self.tag, len_encode(len(encoded)), encoded))
         return b"".join((self.tag, LEN_LEN_YYYYMMDDHHMMSSZ, self._encode_time()))
 
+    def _encode1st(self, state):
+        self._assert_ready()
+        vlen = len(self._encode_time())
+        return len(self.tag) + len_size(vlen) + vlen, state
+
+    def _encode2nd(self, writer, state_iter):
+        write_full(writer, self._encode())
+
 
 class GraphicString(CommonString):
     __slots__ = ()
@@ -5433,6 +5641,13 @@ class Choice(Obj):
         self._assert_ready()
         return self._value[1].encode()
 
+    def _encode1st(self, state):
+        self._assert_ready()
+        return self._value[1].encode1st(state)
+
+    def _encode2nd(self, writer, state_iter):
+        self._value[1].encode2nd(writer, state_iter)
+
     def _encode_cer(self, writer):
         self._assert_ready()
         self._value[1].encode_cer(writer)
@@ -5701,6 +5916,20 @@ class Any(Obj):
             return value
         return value.encode()
 
+    def _encode1st(self, state):
+        self._assert_ready()
+        value = self._value
+        if value.__class__ == binary_type:
+            return len(value), state
+        return value.encode1st(state)
+
+    def _encode2nd(self, writer, state_iter):
+        value = self._value
+        if value.__class__ == binary_type:
+            write_full(writer, value)
+        else:
+            value.encode2nd(writer, state_iter)
+
     def _encode_cer(self, writer):
         self._assert_ready()
         value = self._value
@@ -5859,7 +6088,19 @@ SequenceState = namedtuple(
 )
 
 
-class Sequence(Obj):
+class SequenceEncode1stMixing(object):
+    def _encode1st(self, state):
+        state.append(0)
+        idx = len(state) - 1
+        vlen = 0
+        for v in self._values_for_encoding():
+            l, _ = v.encode1st(state)
+            vlen += l
+        state[idx] = vlen
+        return len(self.tag) + len_size(vlen) + vlen, state
+
+
+class Sequence(SequenceEncode1stMixing, Obj):
     """``SEQUENCE`` structure type
 
     You have to make specification of sequence::
@@ -6106,6 +6347,11 @@ class Sequence(Obj):
         v = b"".join(v.encode() for v in self._values_for_encoding())
         return b"".join((self.tag, len_encode(len(v)), v))
 
+    def _encode2nd(self, writer, state_iter):
+        write_full(writer, self.tag + len_encode(next(state_iter)))
+        for v in self._values_for_encoding():
+            v.encode2nd(writer, state_iter)
+
     def _encode_cer(self, writer):
         write_full(writer, self.tag + LENINDEF)
         for v in self._values_for_encoding():
@@ -6350,7 +6596,7 @@ class Sequence(Obj):
             yield pp
 
 
-class Set(Sequence):
+class Set(Sequence, SequenceEncode1stMixing):
     """``SET`` structure type
 
     Its usage is identical to :py:class:`pyderasn.Sequence`.
@@ -6552,7 +6798,7 @@ SequenceOfState = namedtuple(
 )
 
 
-class SequenceOf(Obj):
+class SequenceOf(SequenceEncode1stMixing, Obj):
     """``SEQUENCE OF`` sequence type
 
     For that kind of type you must specify the object it will carry on
@@ -6781,6 +7027,31 @@ class SequenceOf(Obj):
             value = b"".join(v.encode() for v in self._values_for_encoding())
         return b"".join((self.tag, len_encode(len(value)), value))
 
+    def _encode1st(self, state):
+        state = super(SequenceOf, self)._encode1st(state)
+        if hasattr(self._value, NEXT_ATTR_NAME):
+            self._value = []
+        return state
+
+    def _encode2nd(self, writer, state_iter):
+        write_full(writer, self.tag + len_encode(next(state_iter)))
+        iterator = hasattr(self._value, NEXT_ATTR_NAME)
+        if iterator:
+            values_count = 0
+            class_expected = self.spec.__class__
+            values_for_encoding = self._values_for_encoding()
+            self._value = []
+            for v in values_for_encoding:
+                if not isinstance(v, class_expected):
+                    raise InvalidValueType((class_expected,))
+                v.encode2nd(writer, state_iter)
+                values_count += 1
+            if not self._bound_min <= values_count <= self._bound_max:
+                raise BoundsError(self._bound_min, values_count, self._bound_max)
+        else:
+            for v in self._values_for_encoding():
+                v.encode2nd(writer, state_iter)
+
     def _encode_cer(self, writer):
         write_full(writer, self.tag + LENINDEF)
         iterator = hasattr(self._value, NEXT_ATTR_NAME)
@@ -7004,6 +7275,17 @@ class SetOf(SequenceOf):
         v = b"".join(sorted(v.encode() for v in self._values_for_encoding()))
         return b"".join((self.tag, len_encode(len(v)), v))
 
+    def _encode2nd(self, writer, state_iter):
+        write_full(writer, self.tag + len_encode(next(state_iter)))
+        values = []
+        for v in self._values_for_encoding():
+            buf = BytesIO()
+            v.encode2nd(buf.write, state_iter)
+            values.append(buf.getvalue())
+        values.sort()
+        for v in values:
+            write_full(writer, v)
+
     def _encode_cer(self, writer):
         write_full(writer, self.tag + LENINDEF)
         for v in sorted(encode_cer(v) for v in self._values_for_encoding()):