]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - pyderasn.py
SequenceOf iterator support
[pyderasn.git] / pyderasn.py
index 24204ce8780ca8813723e4d09d715186fe1d1201..62776b4bbb7570aabdf37ace1ce6d0e52be14551 100755 (executable)
@@ -643,6 +643,8 @@ Various
 
 .. autofunction:: pyderasn.abs_decode_path
 .. autofunction:: pyderasn.colonize_hex
+.. autofunction:: pyderasn.encode_cer
+.. autofunction:: pyderasn.file_mmaped
 .. autofunction:: pyderasn.hexenc
 .. autofunction:: pyderasn.hexdec
 .. autofunction:: pyderasn.tag_encode
@@ -779,7 +781,10 @@ from collections import OrderedDict
 from copy import copy
 from datetime import datetime
 from datetime import timedelta
+from io import BytesIO
 from math import ceil
+from mmap import mmap
+from mmap import PROT_READ
 from operator import attrgetter
 from string import ascii_letters
 from string import digits
@@ -811,6 +816,7 @@ except ImportError:  # pragma: no cover
 __version__ = "7.0"
 
 __all__ = (
+    "agg_octet_string",
     "Any",
     "BitString",
     "BMPString",
@@ -819,8 +825,10 @@ __all__ = (
     "Choice",
     "DecodeError",
     "DecodePathDefBy",
+    "encode_cer",
     "Enumerated",
     "ExceedingData",
+    "file_mmaped",
     "GeneralizedTime",
     "GeneralString",
     "GraphicString",
@@ -886,8 +894,17 @@ NAMEDTUPLE_KWARGS = {} if version_info < (3, 6) else {"module": __name__}
 SET01 = frozenset("01")
 DECIMALS = frozenset(digits)
 DECIMAL_SIGNS = ".,"
+NEXT_ATTR_NAME = "next" if PY2 else "__next__"
 
 
+def file_mmaped(fd):
+    """Make mmap-ed memoryview for reading from file
+
+    :param fd: file object
+    :returns: memoryview over read-only mmap-ing of the whole file
+    """
+    return memoryview(mmap(fd.fileno(), 0, prot=PROT_READ))
+
 def pureint(value):
     if not set(value) <= DECIMALS:
         raise ValueError("non-pure integer")
@@ -1190,6 +1207,23 @@ def len_decode(data):
     return l, 1 + octets_num, data[1 + octets_num:]
 
 
+LEN1K = len_encode(1000)
+
+
+def write_full(writer, data):
+    """Fully write provided data
+
+    BytesIO does not guarantee that the whole data will be written at once.
+    """
+    data = memoryview(data)
+    written = 0
+    while written != len(data):
+        n = writer(data[written:])
+        if n is None:
+            raise ValueError("can not write to buf")
+        written += n
+
+
 ########################################################################
 # Base class
 ########################################################################
@@ -1314,6 +1348,10 @@ class Obj(object):
         """
         return self._tag_order
 
+    @property
+    def tag_order_cer(self):
+        return self.tag_order
+
     @property
     def tlen(self):
         """See :ref:`decoding`
@@ -1357,6 +1395,19 @@ class Obj(object):
             return raw
         return b"".join((self._expl, len_encode(len(raw)), raw))
 
+    def encode_cer(self, writer):
+        if self._expl is not None:
+            write_full(writer, self._expl + LENINDEF)
+        if getattr(self, "der_forced", False):
+            write_full(writer, self._encode())
+        else:
+            self._encode_cer(writer)
+        if self._expl is not None:
+            write_full(writer, EOC)
+
+    def _encode_cer(self, writer):
+        write_full(writer, self._encode())
+
     def hexencode(self):
         """Do hexadecimal encoded :py:meth:`pyderasn.Obj.encode`
         """
@@ -1648,6 +1699,14 @@ class Obj(object):
             )
 
 
+def encode_cer(obj):
+    """Encode to CER in memory
+    """
+    buf = BytesIO()
+    obj.encode_cer(buf.write)
+    return buf.getvalue()
+
+
 class DecodePathDefBy(object):
     """DEFINED BY representation inside decode path
     """
@@ -1886,6 +1945,7 @@ def pprint(
         with_colours=False,
         with_decode_path=False,
         decode_path_only=(),
+        decode_path=(),
 ):
     """Pretty print object
 
@@ -1938,7 +1998,7 @@ def pprint(
             else:
                 for row in _pprint_pps(pp):
                     yield row
-    return "\n".join(_pprint_pps(obj.pps()))
+    return "\n".join(_pprint_pps(obj.pps(decode_path)))
 
 
 ########################################################################
@@ -2784,6 +2844,30 @@ class BitString(Obj):
             octets,
         ))
 
+    def _encode_cer(self, writer):
+        bit_len, octets = self._value
+        if len(octets) + 1 <= 1000:
+            write_full(writer, self._encode())
+            return
+        write_full(writer, self.tag_constructed)
+        write_full(writer, LENINDEF)
+        for offset in six_xrange(0, (len(octets) // 999) * 999, 999):
+            write_full(writer, b"".join((
+                BitString.tag_default,
+                LEN1K,
+                int2byte(0),
+                octets[offset:offset + 999],
+            )))
+        tail = octets[offset+999:]
+        if len(tail) > 0:
+            tail = int2byte((8 - bit_len % 8) % 8) + tail
+            write_full(writer, b"".join((
+                BitString.tag_default,
+                len_encode(len(tail)),
+                tail,
+            )))
+        write_full(writer, EOC)
+
     def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode):
         try:
             t, tlen, lv = tag_strip(tlv)
@@ -3063,13 +3147,10 @@ class OctetString(Obj):
     >>> OctetString(b"hell", bounds=(4, 4))
     OCTET STRING 4 bytes 68656c6c
 
-    .. note::
-
-       Pay attention that OCTET STRING can be encoded both in primitive
-       and constructed forms. Decoder always checks constructed form tag
-       additionally to specified primitive one. If BER decoding is
-       :ref:`not enabled <bered_ctx>`, then decoder will fail, because
-       of DER restrictions.
+    Memoryviews can be used as a values. If memoryview is made on
+    mmap-ed file, then it does not take storage inside OctetString
+    itself. In CER encoding mode it will be streamed to the specified
+    writer, copying 1 KB chunks.
     """
     __slots__ = ("tag_constructed", "_bound_min", "_bound_max", "defined")
     tag_default = tag_encode(4)
@@ -3124,12 +3205,12 @@ class OctetString(Obj):
         )
 
     def _value_sanitize(self, value):
-        if value.__class__ == binary_type:
+        if value.__class__ == binary_type or value.__class__ == memoryview:
             pass
         elif issubclass(value.__class__, OctetString):
             value = value._value
         else:
-            raise InvalidValueType((self.__class__, bytes))
+            raise InvalidValueType((self.__class__, bytes, memoryview))
         if not self._bound_min <= len(value) <= self._bound_max:
             raise BoundsError(self._bound_min, len(value), self._bound_max)
         return value
@@ -3169,7 +3250,7 @@ class OctetString(Obj):
 
     def __bytes__(self):
         self._assert_ready()
-        return self._value
+        return bytes(self._value)
 
     def __eq__(self, their):
         if their.__class__ == binary_type:
@@ -3214,6 +3295,28 @@ class OctetString(Obj):
             self._value,
         ))
 
+    def _encode_cer(self, writer):
+        octets = self._value
+        if len(octets) <= 1000:
+            write_full(writer, self._encode())
+            return
+        write_full(writer, self.tag_constructed)
+        write_full(writer, LENINDEF)
+        for offset in six_xrange(0, (len(octets) // 1000) * 1000, 1000):
+            write_full(writer, b"".join((
+                OctetString.tag_default,
+                LEN1K,
+                octets[offset:offset + 1000],
+            )))
+        tail = octets[offset+1000:]
+        if len(tail) > 0:
+            write_full(writer, b"".join((
+                OctetString.tag_default,
+                len_encode(len(tail)),
+                tail,
+            )))
+        write_full(writer, EOC)
+
     def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode):
         try:
             t, tlen, lv = tag_strip(tlv)
@@ -3450,6 +3553,29 @@ class OctetString(Obj):
             yield pp
 
 
+def agg_octet_string(evgens, decode_path, raw, writer):
+    """Aggregate constructed string (OctetString and its derivatives)
+
+    :param evgens: iterator of generated events
+    :param decode_path: points to the string we want to decode
+    :param raw: slicebable (memoryview, bytearray, etc) with
+                the data evgens are generated one
+    :param writer: buffer.write where string is going to be saved
+    """
+    decode_path_len = len(decode_path)
+    for dp, obj, _ in evgens:
+        if dp[:decode_path_len] != decode_path:
+            continue
+        if not obj.ber_encoded:
+            write_full(writer, raw[
+                obj.offset + obj.tlen + obj.llen:
+                obj.offset + obj.tlen + obj.llen + obj.vlen -
+                (EOC_LEN if obj.expl_lenindef else 0)
+            ])
+        if len(dp) == decode_path_len:
+            break
+
+
 NullState = namedtuple("NullState", BasicState._fields, **NAMEDTUPLE_KWARGS)
 
 
@@ -4534,6 +4660,9 @@ class UTCTime(VisibleString):
         value = self._encode_time()
         return b"".join((self.tag, len_encode(len(value)), value))
 
+    def _encode_cer(self, writer):
+        write_full(writer, self._encode())
+
     def todatetime(self):
         return self._value
 
@@ -4905,6 +5034,10 @@ class Choice(Obj):
         self._assert_ready()
         return self._value[1].tag_order if self._tag_order is None else self._tag_order
 
+    @property
+    def tag_order_cer(self):
+        return min(v.tag_order_cer for v in itervalues(self.specs))
+
     def __getitem__(self, key):
         if key not in self.specs:
             raise ObjUnknown(key)
@@ -4935,6 +5068,10 @@ class Choice(Obj):
         self._assert_ready()
         return self._value[1].encode()
 
+    def _encode_cer(self, writer):
+        self._assert_ready()
+        self._value[1].encode_cer(writer)
+
     def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode):
         for choice, spec in iteritems(self.specs):
             sub_decode_path = decode_path + (choice,)
@@ -5066,7 +5203,7 @@ class Any(Obj):
     """``ANY`` special type
 
     >>> Any(Integer(-123))
-    ANY 020185
+    ANY INTEGER -123 (0X:7B)
     >>> a = Any(OctetString(b"hello world").encode())
     ANY 040b68656c6c6f20776f726c64
     >>> hexenc(bytes(a))
@@ -5114,9 +5251,9 @@ class Any(Obj):
             return value
         if isinstance(value, self.__class__):
             return value._value
-        if isinstance(value, Obj):
-            return value.encode()
-        raise InvalidValueType((self.__class__, Obj, binary_type))
+        if not isinstance(value, Obj):
+            raise InvalidValueType((self.__class__, Obj, binary_type))
+        return value
 
     @property
     def ready(self):
@@ -5160,9 +5297,13 @@ class Any(Obj):
 
     def __eq__(self, their):
         if their.__class__ == binary_type:
-            return self._value == their
+            if self._value.__class__ == binary_type:
+                return self._value == their
+            return self._value.encode() == their
         if issubclass(their.__class__, Any):
-            return self._value == their._value
+            if self.ready and their.ready:
+                return bytes(self) == bytes(their)
+            return self.ready == their.ready
         return False
 
     def __call__(
@@ -5179,7 +5320,10 @@ class Any(Obj):
 
     def __bytes__(self):
         self._assert_ready()
-        return self._value
+        value = self._value
+        if value.__class__ == binary_type:
+            return value
+        return self._value.encode()
 
     @property
     def tlen(self):
@@ -5187,7 +5331,18 @@ class Any(Obj):
 
     def _encode(self):
         self._assert_ready()
-        return self._value
+        value = self._value
+        if value.__class__ == binary_type:
+            return value
+        return value.encode()
+
+    def _encode_cer(self, writer):
+        self._assert_ready()
+        value = self._value
+        if value.__class__ == binary_type:
+            write_full(writer, value)
+        else:
+            value.encode_cer(writer)
 
     def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode):
         try:
@@ -5264,12 +5419,20 @@ class Any(Obj):
         return pp_console_row(next(self.pps()))
 
     def pps(self, decode_path=()):
+        value = self._value
+        if value is None:
+            pass
+        elif value.__class__ == binary_type:
+            value = None
+        else:
+            value = repr(value)
         yield _pp(
             obj=self,
             asn1_type_name=self.asn1_type_name,
             obj_name=self.__class__.__name__,
             decode_path=decode_path,
-            blob=self._value if self.ready else None,
+            value=value,
+            blob=self._value if self._value.__class__ == binary_type else None,
             optional=self.optional,
             default=self == self.default,
             impl=None if self.tag == self.tag_default else tag_decode(self.tag),
@@ -5591,6 +5754,12 @@ class Sequence(Obj):
         v = b"".join(v.encode() for v in self._values_for_encoding())
         return b"".join((self.tag, len_encode(len(v)), v))
 
+    def _encode_cer(self, writer):
+        write_full(writer, self.tag + LENINDEF)
+        for v in self._values_for_encoding():
+            v.encode_cer(writer)
+        write_full(writer, EOC)
+
     def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode):
         try:
             t, tlen, lv = tag_strip(tlv)
@@ -5853,6 +6022,15 @@ class Set(Sequence):
         ))
         return b"".join((self.tag, len_encode(len(v)), v))
 
+    def _encode_cer(self, writer):
+        write_full(writer, self.tag + LENINDEF)
+        for v in sorted(
+                self._values_for_encoding(),
+                key=attrgetter("tag_order_cer"),
+        ):
+            v.encode_cer(writer)
+        write_full(writer, EOC)
+
     def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode):
         try:
             t, tlen, lv = tag_strip(tlv)
@@ -6049,9 +6227,21 @@ class SequenceOf(Obj):
     >>> ints
     Ints SEQUENCE OF[INTEGER 123, INTEGER 345]
 
-    Also you can initialize sequence with preinitialized values:
+    You can initialize sequence with preinitialized values:
 
     >>> ints = Ints([Integer(123), Integer(234)])
+
+    Also you can use iterator as a value:
+
+    >>> ints = Ints(iter(Integer(i) for i in range(1000000)))
+
+    And it won't be iterated until encoding process. Pay attention that
+    bounds and required schema checks are done only during the encoding
+    process in that case! After encode was called, then value is zeroed
+    back to empty list and you have to set it again. That mode is useful
+    mainly with CER encoding mode, where all objects from the iterable
+    will be streamed to the buffer, without copying all of them to
+    memory first.
     """
     __slots__ = ("spec", "_bound_min", "_bound_max")
     tag_default = tag_encode(form=TagFormConstructed, num=16)
@@ -6095,21 +6285,31 @@ class SequenceOf(Obj):
                 self._value = copy(default_obj._value)
 
     def _value_sanitize(self, value):
+        iterator = False
         if issubclass(value.__class__, SequenceOf):
             value = value._value
+        elif hasattr(value, NEXT_ATTR_NAME):
+            iterator = True
+            value = value
         elif hasattr(value, "__iter__"):
             value = list(value)
         else:
-            raise InvalidValueType((self.__class__, iter))
-        if not self._bound_min <= len(value) <= self._bound_max:
-            raise BoundsError(self._bound_min, len(value), self._bound_max)
-        for v in value:
-            if not isinstance(v, self.spec.__class__):
-                raise InvalidValueType((self.spec.__class__,))
+            raise InvalidValueType((self.__class__, iter, "iterator"))
+        if not iterator:
+            if not self._bound_min <= len(value) <= self._bound_max:
+                raise BoundsError(self._bound_min, len(value), self._bound_max)
+            class_expected = self.spec.__class__
+            for v in value:
+                if not isinstance(v, class_expected):
+                    raise InvalidValueType((class_expected,))
         return value
 
     @property
     def ready(self):
+        if hasattr(self._value, NEXT_ATTR_NAME):
+            return True
+        if self._bound_min > 0 and len(self._value) == 0:
+            return False
         return all(v.ready for v in self._value)
 
     @property
@@ -6119,6 +6319,8 @@ class SequenceOf(Obj):
         return any(v.bered for v in self._value)
 
     def __getstate__(self):
+        if hasattr(self._value, NEXT_ATTR_NAME):
+            raise ValueError("can not pickle SequenceOf with iterator")
         return SequenceOfState(
             __version__,
             self.tag,
@@ -6194,11 +6396,9 @@ class SequenceOf(Obj):
         self._value.append(value)
 
     def __iter__(self):
-        self._assert_ready()
         return iter(self._value)
 
     def __len__(self):
-        self._assert_ready()
         return len(self._value)
 
     def __setitem__(self, key, value):
@@ -6213,8 +6413,43 @@ class SequenceOf(Obj):
         return iter(self._value)
 
     def _encode(self):
-        v = b"".join(v.encode() for v in self._values_for_encoding())
-        return b"".join((self.tag, len_encode(len(v)), v))
+        iterator = hasattr(self._value, NEXT_ATTR_NAME)
+        if iterator:
+            values = []
+            values_append = values.append
+            class_expected = self.spec.__class__
+            values_for_encoding = self._values_for_encoding()
+            self._value = []
+            for v in values_for_encoding:
+                if not isinstance(v, class_expected):
+                    raise InvalidValueType((class_expected,))
+                values_append(v.encode())
+            if not self._bound_min <= len(values) <= self._bound_max:
+                raise BoundsError(self._bound_min, len(values), self._bound_max)
+            value = b"".join(values)
+        else:
+            value = b"".join(v.encode() for v in self._values_for_encoding())
+        return b"".join((self.tag, len_encode(len(value)), value))
+
+    def _encode_cer(self, writer):
+        write_full(writer, self.tag + LENINDEF)
+        iterator = hasattr(self._value, NEXT_ATTR_NAME)
+        if iterator:
+            class_expected = self.spec.__class__
+            values_count = 0
+            values_for_encoding = self._values_for_encoding()
+            self._value = []
+            for v in values_for_encoding:
+                if not isinstance(v, class_expected):
+                    raise InvalidValueType((class_expected,))
+                v.encode_cer(writer)
+                values_count += 1
+            if not self._bound_min <= values_count <= self._bound_max:
+                raise BoundsError(self._bound_min, values_count, self._bound_max)
+        else:
+            for v in self._values_for_encoding():
+                v.encode_cer(writer)
+        write_full(writer, EOC)
 
     def _decode(
             self,
@@ -6407,10 +6642,24 @@ class SetOf(SequenceOf):
     tag_default = tag_encode(form=TagFormConstructed, num=17)
     asn1_type_name = "SET OF"
 
+    def _value_sanitize(self, value):
+        value = super(SetOf, self)._value_sanitize(value)
+        if hasattr(value, NEXT_ATTR_NAME):
+            raise ValueError(
+                "SetOf does not support iterator values, as no sense in them"
+            )
+        return value
+
     def _encode(self):
         v = b"".join(sorted(v.encode() for v in self._values_for_encoding()))
         return b"".join((self.tag, len_encode(len(v)), v))
 
+    def _encode_cer(self, writer):
+        write_full(writer, self.tag + LENINDEF)
+        for v in sorted(encode_cer(v) for v in self._values_for_encoding()):
+            write_full(writer, v)
+        write_full(writer, EOC)
+
     def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode):
         return super(SetOf, self)._decode(
             tlv,
@@ -6540,14 +6789,22 @@ def main():  # pragma: no cover
         help="Allow explicit tag out-of-bound",
     )
     parser.add_argument(
-        "DERFile",
+        "--evgen",
+        action="store_true",
+        help="Turn on event generation mode",
+    )
+    parser.add_argument(
+        "RAWFile",
         type=argparse.FileType("rb"),
-        help="Path to DER file you want to decode",
+        help="Path to BER/CER/DER file you want to decode",
     )
     args = parser.parse_args()
-    args.DERFile.seek(args.skip)
-    der = memoryview(args.DERFile.read())
-    args.DERFile.close()
+    if PY2:
+        args.RAWFile.seek(args.skip)
+        raw = memoryview(args.RAWFile.read())
+        args.RAWFile.close()
+    else:
+        raw = file_mmaped(args.RAWFile)[args.skip:]
     oid_maps = (
         [obj_by_path(_path) for _path in (args.oids or "").split(",")]
         if args.oids else ()
@@ -6564,10 +6821,9 @@ def main():  # pragma: no cover
     }
     if args.defines_by_path is not None:
         ctx["defines_by_path"] = obj_by_path(args.defines_by_path)
-    obj, tail = schema().decode(der, ctx=ctx)
     from os import environ
-    print(pprinter(
-        obj,
+    pprinter = partial(
+        pprinter,
         oid_maps=oid_maps,
         with_colours=environ.get("NO_COLOR") is None,
         with_decode_path=args.print_decode_path,
@@ -6575,7 +6831,13 @@ def main():  # pragma: no cover
             () if args.decode_path_only is None else
             tuple(args.decode_path_only.split(":"))
         ),
-    ))
+    )
+    if args.evgen:
+        for decode_path, obj, tail in schema().decode_evgen(raw, ctx=ctx):
+            print(pprinter(obj, decode_path=decode_path))
+    else:
+        obj, tail = schema().decode(raw, ctx=ctx)
+        print(pprinter(obj))
     if tail != b"":
         print("\nTrailing data: %s" % hexenc(tail))