]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - pyderasn.py
Raise copyright years
[pyderasn.git] / pyderasn.py
index 4a27bbe5e1c5222388bf4fc4056b250ac6447e67..54242a3bd537fc9360217174cb79b9ccf832a73a 100755 (executable)
@@ -4,7 +4,7 @@
 # pylint: disable=line-too-long,superfluous-parens,protected-access,too-many-lines
 # pylint: disable=too-many-return-statements,too-many-branches,too-many-statements
 # PyDERASN -- Python ASN.1 DER/CER/BER codec with abstract structures
 # pylint: disable=line-too-long,superfluous-parens,protected-access,too-many-lines
 # pylint: disable=too-many-return-statements,too-many-branches,too-many-statements
 # PyDERASN -- Python ASN.1 DER/CER/BER codec with abstract structures
-# Copyright (C) 2017-2021 Sergey Matveev <stargrave@stargrave.org>
+# Copyright (C) 2017-2024 Sergey Matveev <stargrave@stargrave.org>
 #
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU Lesser General Public License as
 #
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU Lesser General Public License as
@@ -235,6 +235,7 @@ Currently available context options:
 * :ref:`bered <bered_ctx>`
 * :ref:`defines_by_path <defines_by_path_ctx>`
 * :ref:`evgen_mode_upto <evgen_mode_upto_ctx>`
 * :ref:`bered <bered_ctx>`
 * :ref:`defines_by_path <defines_by_path_ctx>`
 * :ref:`evgen_mode_upto <evgen_mode_upto_ctx>`
+* :ref:`keep_memoryview <keep_memoryview_ctx>`
 
 .. _pprinting:
 
 
 .. _pprinting:
 
@@ -706,6 +707,15 @@ creates read-only memoryview on the file contents::
    page cache used for mmaps. It can take twice the necessary size in
    the memory: both in page cache and ZFS ARC.
 
    page cache used for mmaps. It can take twice the necessary size in
    the memory: both in page cache and ZFS ARC.
 
+.. _keep_memoryview_ctx:
+
+That read-only memoryview could be safe to be used as a value inside
+decoded :py:class:`pyderasn.OctetString` and :py:class:`pyderasn.Any`
+objects. You can enable that by setting `"keep_memoryview": True` in
+:ref:`decode context <ctx>`. No OCTET STRING and ANY values will be
+copied to memory. Of course that works only in DER encoding, where the
+value is continuously encoded.
+
 CER encoding
 ____________
 
 CER encoding
 ____________
 
@@ -971,12 +981,12 @@ _____________
 UTCTime
 _______
 .. autoclass:: pyderasn.UTCTime
 UTCTime
 _______
 .. autoclass:: pyderasn.UTCTime
-   :members: __init__, todatetime
+   :members: __init__, todatetime, totzdatetime
 
 GeneralizedTime
 _______________
 .. autoclass:: pyderasn.GeneralizedTime
 
 GeneralizedTime
 _______________
 .. autoclass:: pyderasn.GeneralizedTime
-   :members: __init__, todatetime
+   :members: __init__, todatetime, totzdatetime
 
 Special types
 -------------
 
 Special types
 -------------
@@ -1159,8 +1169,6 @@ Now you can print only the specified tree, for example signature algorithm::
 """
 
 from array import array
 """
 
 from array import array
-from codecs import getdecoder
-from codecs import getencoder
 from collections import namedtuple
 from collections import OrderedDict
 from copy import copy
 from collections import namedtuple
 from collections import OrderedDict
 from copy import copy
@@ -1182,7 +1190,13 @@ except ImportError:  # pragma: no cover
     def colored(what, *args, **kwargs):
         return what
 
     def colored(what, *args, **kwargs):
         return what
 
-__version__ = "9.0"
+try:
+    from dateutil.tz import UTC as tzUTC
+except ImportError:  # pragma: no cover
+    tzUTC = "missing"
+
+
+__version__ = "9.3"
 
 __all__ = (
     "agg_octet_string",
 
 __all__ = (
     "agg_octet_string",
@@ -1441,20 +1455,16 @@ class BoundsError(ASN1Error):
 # Basic coders
 ########################################################################
 
 # Basic coders
 ########################################################################
 
-_hexdecoder = getdecoder("hex")
-_hexencoder = getencoder("hex")
-
-
 def hexdec(data):
     """Binary data to hexadecimal string convert
     """
 def hexdec(data):
     """Binary data to hexadecimal string convert
     """
-    return _hexdecoder(data)[0]
+    return bytes.fromhex(data)
 
 
 def hexenc(data):
     """Hexadecimal string to binary data convert
     """
 
 
 def hexenc(data):
     """Hexadecimal string to binary data convert
     """
-    return _hexencoder(data)[0].decode("ascii")
+    return data.hex()
 
 
 def int_bytes_len(num, byte_len=8):
 
 
 def int_bytes_len(num, byte_len=8):
@@ -3630,6 +3640,7 @@ class OctetString(Obj):
     tag_default = tag_encode(4)
     asn1_type_name = "OCTET STRING"
     evgen_mode_skip_value = True
     tag_default = tag_encode(4)
     asn1_type_name = "OCTET STRING"
     evgen_mode_skip_value = True
+    memoryview_safe = True
 
     def __init__(
             self,
 
     def __init__(
             self,
@@ -3726,6 +3737,10 @@ class OctetString(Obj):
         self._assert_ready()
         return bytes(self._value)
 
         self._assert_ready()
         return bytes(self._value)
 
+    def memoryview(self):
+        self._assert_ready()
+        return memoryview(self._value)
+
     def __eq__(self, their):
         if their.__class__ == bytes:
             return self._value == their
     def __eq__(self, their):
         if their.__class__ == bytes:
             return self._value == their
@@ -3839,12 +3854,15 @@ class OctetString(Obj):
                     decode_path=decode_path,
                     offset=offset,
                 )
                     decode_path=decode_path,
                     offset=offset,
                 )
+            if evgen_mode and self.evgen_mode_skip_value:
+                value = None
+            elif self.memoryview_safe and ctx.get("keep_memoryview", False):
+                value = v
+            else:
+                value = v.tobytes()
             try:
                 obj = self.__class__(
             try:
                 obj = self.__class__(
-                    value=(
-                        None if (evgen_mode and self.evgen_mode_skip_value)
-                        else v.tobytes()
-                    ),
+                    value=value,
                     bounds=(self._bound_min, self._bound_max),
                     impl=self.tag,
                     expl=self._expl,
                     bounds=(self._bound_min, self._bound_max),
                     impl=self.tag,
                     expl=self._expl,
@@ -4694,6 +4712,7 @@ class CommonString(OctetString):
          - utf-16-be
     """
     __slots__ = ()
          - utf-16-be
     """
     __slots__ = ()
+    memoryview_safe = False
 
     def _value_sanitize(self, value):
         value_raw = None
 
     def _value_sanitize(self, value):
         value_raw = None
@@ -4743,6 +4762,9 @@ class CommonString(OctetString):
             return self._value.decode(self.encoding)
         return str(self._value)
 
             return self._value.decode(self.encoding)
         return str(self._value)
 
+    def memoryview(self):
+        raise ValueError("CommonString does not support .memoryview()")
+
     def __repr__(self):
         return pp_console_row(next(self.pps()))
 
     def __repr__(self):
         return pp_console_row(next(self.pps()))
 
@@ -4784,6 +4806,8 @@ class UTF8String(CommonString):
 
 
 class AllowableCharsMixin:
 
 
 class AllowableCharsMixin:
+    __slots__ = ()
+
     @property
     def allowable_chars(self):
         return frozenset(chr(c) for c in self._allowable_chars)
     @property
     def allowable_chars(self):
         return frozenset(chr(c) for c in self._allowable_chars)
@@ -4795,6 +4819,9 @@ class AllowableCharsMixin:
         return value
 
 
         return value
 
 
+NUMERIC_ALLOWABLE_CHARS = frozenset(digits.encode("ascii") + b" ")
+
+
 class NumericString(AllowableCharsMixin, CommonString):
     """Numeric string
 
 class NumericString(AllowableCharsMixin, CommonString):
     """Numeric string
 
@@ -4804,11 +4831,14 @@ class NumericString(AllowableCharsMixin, CommonString):
     >>> NumericString().allowable_chars
     frozenset(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ' '])
     """
     >>> NumericString().allowable_chars
     frozenset(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ' '])
     """
-    __slots__ = ()
+    __slots__ = ("_allowable_chars",)
     tag_default = tag_encode(18)
     encoding = "ascii"
     asn1_type_name = "NumericString"
     tag_default = tag_encode(18)
     encoding = "ascii"
     asn1_type_name = "NumericString"
-    _allowable_chars = frozenset(digits.encode("ascii") + b" ")
+
+    def __init__(self, *args, **kwargs):
+        self._allowable_chars = NUMERIC_ALLOWABLE_CHARS
+        super().__init__(*args, **kwargs)
 
 
 PrintableStringState = namedtuple(
 
 
 PrintableStringState = namedtuple(
@@ -4818,6 +4848,11 @@ PrintableStringState = namedtuple(
 )
 
 
 )
 
 
+PRINTABLE_ALLOWABLE_CHARS = frozenset(
+    (ascii_letters + digits + " '()+,-./:=?").encode("ascii")
+)
+
+
 class PrintableString(AllowableCharsMixin, CommonString):
     """Printable string
 
 class PrintableString(AllowableCharsMixin, CommonString):
     """Printable string
 
@@ -4830,13 +4865,10 @@ class PrintableString(AllowableCharsMixin, CommonString):
     >>> obj.allow_asterisk, obj.allow_ampersand
     (True, False)
     """
     >>> obj.allow_asterisk, obj.allow_ampersand
     (True, False)
     """
-    __slots__ = ()
+    __slots__ = ("_allowable_chars",)
     tag_default = tag_encode(19)
     encoding = "ascii"
     asn1_type_name = "PrintableString"
     tag_default = tag_encode(19)
     encoding = "ascii"
     asn1_type_name = "PrintableString"
-    _allowable_chars = frozenset(
-        (ascii_letters + digits + " '()+,-./:=?").encode("ascii")
-    )
     _asterisk = frozenset("*".encode("ascii"))
     _ampersand = frozenset("&".encode("ascii"))
 
     _asterisk = frozenset("*".encode("ascii"))
     _ampersand = frozenset("&".encode("ascii"))
 
@@ -4857,11 +4889,13 @@ class PrintableString(AllowableCharsMixin, CommonString):
         :param allow_asterisk: allow asterisk character
         :param allow_ampersand: allow ampersand character
         """
         :param allow_asterisk: allow asterisk character
         :param allow_ampersand: allow ampersand character
         """
+        allowable_chars = PRINTABLE_ALLOWABLE_CHARS
         if allow_asterisk:
         if allow_asterisk:
-            self._allowable_chars |= self._asterisk
+            allowable_chars |= self._asterisk
         if allow_ampersand:
         if allow_ampersand:
-            self._allowable_chars |= self._ampersand
-        super(PrintableString, self).__init__(
+            allowable_chars |= self._ampersand
+        self._allowable_chars = allowable_chars
+        super().__init__(
             value, bounds, impl, expl, default, optional, _decoded, ctx,
         )
 
             value, bounds, impl, expl, default, optional, _decoded, ctx,
         )
 
@@ -4930,6 +4964,11 @@ class VideotexString(CommonString):
     asn1_type_name = "VideotexString"
 
 
     asn1_type_name = "VideotexString"
 
 
+IA5_ALLOWABLE_CHARS = frozenset(b"".join(
+    chr(c).encode("ascii") for c in range(128)
+))
+
+
 class IA5String(AllowableCharsMixin, CommonString):
     """IA5 string
 
 class IA5String(AllowableCharsMixin, CommonString):
     """IA5 string
 
@@ -4944,13 +4983,14 @@ class IA5String(AllowableCharsMixin, CommonString):
     >>> IA5String().allowable_chars
     frozenset(["NUL", ... "DEL"])
     """
     >>> IA5String().allowable_chars
     frozenset(["NUL", ... "DEL"])
     """
-    __slots__ = ()
+    __slots__ = ("_allowable_chars",)
     tag_default = tag_encode(22)
     encoding = "ascii"
     asn1_type_name = "IA5"
     tag_default = tag_encode(22)
     encoding = "ascii"
     asn1_type_name = "IA5"
-    _allowable_chars = frozenset(b"".join(
-        chr(c).encode("ascii") for c in range(128)
-    ))
+
+    def __init__(self, *args, **kwargs):
+        self._allowable_chars = IA5_ALLOWABLE_CHARS
+        super().__init__(*args, **kwargs)
 
 
 LEN_YYMMDDHHMMSSZ = len("YYMMDDHHMMSSZ")
 
 
 LEN_YYMMDDHHMMSSZ = len("YYMMDDHHMMSSZ")
@@ -4961,6 +5001,11 @@ LEN_YYYYMMDDHHMMSSZ = len("YYYYMMDDHHMMSSZ")
 LEN_LEN_YYYYMMDDHHMMSSZ = len_encode(LEN_YYYYMMDDHHMMSSZ)
 
 
 LEN_LEN_YYYYMMDDHHMMSSZ = len_encode(LEN_YYYYMMDDHHMMSSZ)
 
 
+VISIBLE_ALLOWABLE_CHARS = frozenset(b"".join(
+    chr(c).encode("ascii") for c in range(ord(" "), ord("~") + 1)
+))
+
+
 class VisibleString(AllowableCharsMixin, CommonString):
     """Visible string
 
 class VisibleString(AllowableCharsMixin, CommonString):
     """Visible string
 
@@ -4970,13 +5015,14 @@ class VisibleString(AllowableCharsMixin, CommonString):
     >>> VisibleString().allowable_chars
     frozenset([" ", ... "~"])
     """
     >>> VisibleString().allowable_chars
     frozenset([" ", ... "~"])
     """
-    __slots__ = ()
+    __slots__ = ("_allowable_chars",)
     tag_default = tag_encode(26)
     encoding = "ascii"
     asn1_type_name = "VisibleString"
     tag_default = tag_encode(26)
     encoding = "ascii"
     asn1_type_name = "VisibleString"
-    _allowable_chars = frozenset(b"".join(
-        chr(c).encode("ascii") for c in range(ord(" "), ord("~") + 1)
-    ))
+
+    def __init__(self, *args, **kwargs):
+        self._allowable_chars = VISIBLE_ALLOWABLE_CHARS
+        super().__init__(*args, **kwargs)
 
 
 class ISO646String(VisibleString):
 
 
 class ISO646String(VisibleString):
@@ -5014,6 +5060,8 @@ class UTCTime(VisibleString):
     datetime.datetime(2017, 9, 30, 22, 7, 50)
     >>> UTCTime(datetime(2057, 9, 30, 22, 7, 50)).todatetime()
     datetime.datetime(1957, 9, 30, 22, 7, 50)
     datetime.datetime(2017, 9, 30, 22, 7, 50)
     >>> UTCTime(datetime(2057, 9, 30, 22, 7, 50)).todatetime()
     datetime.datetime(1957, 9, 30, 22, 7, 50)
+    >>> UTCTime(datetime(2057, 9, 30, 22, 7, 50)).totzdatetime()
+    datetime.datetime(1957, 9, 30, 22, 7, 50, tzinfo=tzutc())
 
     If BER encoded value was met, then ``ber_raw`` attribute will hold
     its raw representation.
 
     If BER encoded value was met, then ``ber_raw`` attribute will hold
     its raw representation.
@@ -5223,6 +5271,12 @@ class UTCTime(VisibleString):
     def todatetime(self):
         return self._value
 
     def todatetime(self):
         return self._value
 
+    def totzdatetime(self):
+        try:
+            return self._value.replace(tzinfo=tzUTC)
+        except TypeError as err:
+            raise NotImplementedError("Missing dateutil.tz") from err
+
     def __repr__(self):
         return pp_console_row(next(self.pps()))
 
     def __repr__(self):
         return pp_console_row(next(self.pps()))
 
@@ -5810,7 +5864,7 @@ class Any(Obj):
             value = self._value_sanitize(value)
             self._value = value
             if self._expl is None:
             value = self._value_sanitize(value)
             self._value = value
             if self._expl is None:
-                if value.__class__ == bytes:
+                if value.__class__ == bytes or value.__class__ == memoryview:
                     tag_class, _, tag_num = tag_decode(tag_strip(value)[0])
                 else:
                     tag_class, tag_num = value.tag_order
                     tag_class, _, tag_num = tag_decode(tag_strip(value)[0])
                 else:
                     tag_class, tag_num = value.tag_order
@@ -5820,7 +5874,7 @@ class Any(Obj):
         self.defined = None
 
     def _value_sanitize(self, value):
         self.defined = None
 
     def _value_sanitize(self, value):
-        if value.__class__ == bytes:
+        if value.__class__ == bytes or value.__class__ == memoryview:
             if len(value) == 0:
                 raise ValueError("%s value can not be empty" % self.__class__.__name__)
             return value
             if len(value) == 0:
                 raise ValueError("%s value can not be empty" % self.__class__.__name__)
             return value
@@ -5871,13 +5925,13 @@ class Any(Obj):
         self.defined = state.defined
 
     def __eq__(self, their):
         self.defined = state.defined
 
     def __eq__(self, their):
-        if their.__class__ == bytes:
-            if self._value.__class__ == bytes:
+        if their.__class__ == bytes or their.__class__ == memoryview:
+            if self._value.__class__ == bytes or their.__class__ == memoryview:
                 return self._value == their
             return self._value.encode() == their
         if issubclass(their.__class__, Any):
             if self.ready and their.ready:
                 return self._value == their
             return self._value.encode() == their
         if issubclass(their.__class__, Any):
             if self.ready and their.ready:
-                return bytes(self) == bytes(their)
+                return self.memoryview() == their.memoryview()
             return self.ready == their.ready
         return False
 
             return self.ready == their.ready
         return False
 
@@ -5898,8 +5952,17 @@ class Any(Obj):
         value = self._value
         if value.__class__ == bytes:
             return value
         value = self._value
         if value.__class__ == bytes:
             return value
+        if value.__class__ == memoryview:
+            return bytes(value)
         return self._value.encode()
 
         return self._value.encode()
 
+    def memoryview(self):
+        self._assert_ready()
+        value = self._value
+        if value.__class__ == memoryview:
+            return memoryview(value)
+        return memoryview(bytes(self))
+
     @property
     def tlen(self):
         return 0
     @property
     def tlen(self):
         return 0
@@ -5907,20 +5970,20 @@ class Any(Obj):
     def _encode(self):
         self._assert_ready()
         value = self._value
     def _encode(self):
         self._assert_ready()
         value = self._value
-        if value.__class__ == bytes:
-            return value
+        if value.__class__ == bytes or value.__class__ == memoryview:
+            return bytes(self)
         return value.encode()
 
     def _encode1st(self, state):
         self._assert_ready()
         value = self._value
         return value.encode()
 
     def _encode1st(self, state):
         self._assert_ready()
         value = self._value
-        if value.__class__ == bytes:
+        if value.__class__ == bytes or value.__class__ == memoryview:
             return len(value), state
         return value.encode1st(state)
 
     def _encode2nd(self, writer, state_iter):
         value = self._value
             return len(value), state
         return value.encode1st(state)
 
     def _encode2nd(self, writer, state_iter):
         value = self._value
-        if value.__class__ == bytes:
+        if value.__class__ == bytes or value.__class__ == memoryview:
             write_full(writer, value)
         else:
             value.encode2nd(writer, state_iter)
             write_full(writer, value)
         else:
             value.encode2nd(writer, state_iter)
@@ -5928,7 +5991,7 @@ class Any(Obj):
     def _encode_cer(self, writer):
         self._assert_ready()
         value = self._value
     def _encode_cer(self, writer):
         self._assert_ready()
         value = self._value
-        if value.__class__ == bytes:
+        if value.__class__ == bytes or value.__class__ == memoryview:
             write_full(writer, value)
         else:
             value.encode_cer(writer)
             write_full(writer, value)
         else:
             value.encode_cer(writer)
@@ -5995,8 +6058,14 @@ class Any(Obj):
             )
         tlvlen = tlen + llen + l
         v, tail = tlv[:tlvlen], v[l:]
             )
         tlvlen = tlen + llen + l
         v, tail = tlv[:tlvlen], v[l:]
+        if evgen_mode:
+            value = None
+        elif ctx.get("keep_memoryview", False):
+            value = v
+        else:
+            value = v.tobytes()
         obj = self.__class__(
         obj = self.__class__(
-            value=None if evgen_mode else v.tobytes(),
+            value=value,
             expl=self._expl,
             optional=self.optional,
             _decoded=(offset, 0, tlvlen),
             expl=self._expl,
             optional=self.optional,
             _decoded=(offset, 0, tlvlen),
@@ -6011,7 +6080,7 @@ class Any(Obj):
         value = self._value
         if value is None:
             pass
         value = self._value
         if value is None:
             pass
-        elif value.__class__ == bytes:
+        elif value.__class__ == bytes or value.__class__ == memoryview:
             value = None
         else:
             value = repr(value)
             value = None
         else:
             value = repr(value)
@@ -6021,7 +6090,10 @@ class Any(Obj):
             obj_name=self.__class__.__name__,
             decode_path=decode_path,
             value=value,
             obj_name=self.__class__.__name__,
             decode_path=decode_path,
             value=value,
-            blob=self._value if self._value.__class__ == bytes else None,
+            blob=self._value if (
+                self._value.__class__ == bytes or
+                value.__class__ == memoryview
+            ) else None,
             optional=self.optional,
             default=self == self.default,
             impl=None if self.tag == self.tag_default else tag_decode(self.tag),
             optional=self.optional,
             default=self == self.default,
             impl=None if self.tag == self.tag_default else tag_decode(self.tag),
@@ -6083,7 +6155,9 @@ SequenceState = namedtuple(
 )
 
 
 )
 
 
-class SequenceEncode1stMixing:
+class SequenceEncode1stMixin:
+    __slots__ = ()
+
     def _encode1st(self, state):
         state.append(0)
         idx = len(state) - 1
     def _encode1st(self, state):
         state.append(0)
         idx = len(state) - 1
@@ -6095,7 +6169,7 @@ class SequenceEncode1stMixing:
         return len(self.tag) + len_size(vlen) + vlen, state
 
 
         return len(self.tag) + len_size(vlen) + vlen, state
 
 
-class Sequence(SequenceEncode1stMixing, Obj):
+class Sequence(SequenceEncode1stMixin, Obj):
     """``SEQUENCE`` structure type
 
     You have to make specification of sequence::
     """``SEQUENCE`` structure type
 
     You have to make specification of sequence::
@@ -6593,7 +6667,7 @@ class Sequence(SequenceEncode1stMixing, Obj):
             yield pp
 
 
             yield pp
 
 
-class Set(Sequence, SequenceEncode1stMixing):
+class Set(Sequence, SequenceEncode1stMixin):
     """``SET`` structure type
 
     Its usage is identical to :py:class:`pyderasn.Sequence`.
     """``SET`` structure type
 
     Its usage is identical to :py:class:`pyderasn.Sequence`.
@@ -6796,7 +6870,7 @@ SequenceOfState = namedtuple(
 )
 
 
 )
 
 
-class SequenceOf(SequenceEncode1stMixing, Obj):
+class SequenceOf(SequenceEncode1stMixin, Obj):
     """``SEQUENCE OF`` sequence type
 
     For that kind of type you must specify the object it will carry on
     """``SEQUENCE OF`` sequence type
 
     For that kind of type you must specify the object it will carry on