Cypherpunks.ru repositories - pyderasn.git/blobdiff - pyderasn.py
Ability to set values during Sequence initialization
[pyderasn.git] / pyderasn.py
index bf0e4857a63d40a0940a01983f1f8b4a4b5eea36..68d67d07e91a746886d22bb629d19a8ae6000f43 100755 (executable)
@@ -269,11 +269,11 @@ for AlgorithmIdentifier of X.509's
 ``tbsCertificate.subjectPublicKeyInfo.algorithm.algorithm``::
 
         (
-            (('parameters',), {
+            (("parameters",), {
                 id_ecPublicKey: ECParameters(),
                 id_GostR3410_2001: GostR34102001PublicKeyParameters(),
             }),
-            (('..', 'subjectPublicKey'), {
+            (("..", "subjectPublicKey"), {
                 id_rsaEncryption: RSAPublicKey(),
                 id_GostR3410_2001: OctetString(),
             }),
@@ -289,7 +289,7 @@ Following types can be automatically decoded (DEFINED BY):
 * :py:class:`pyderasn.BitString` (that is multiple of 8 bits)
 * :py:class:`pyderasn.OctetString`
 * :py:class:`pyderasn.SequenceOf`/:py:class:`pyderasn.SetOf`
-  ``Any``/``OctetString``-s
+  ``Any``/``BitString``/``OctetString``-s
 
 When any of those fields is automatically decoded, then ``.defined``
 attribute contains ``(OID, value)`` tuple. ``OID`` tells by which OID it
@@ -473,6 +473,7 @@ from collections import OrderedDict
 from datetime import datetime
 from math import ceil
 from os import environ
+from string import digits
 
 from six import add_metaclass
 from six import binary_type
@@ -911,7 +912,7 @@ class Obj(object):
     def _encode(self):  # pragma: no cover
         raise NotImplementedError()
 
-    def _decode(self, tlv, offset, decode_path, ctx):  # pragma: no cover
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):  # pragma: no cover
         raise NotImplementedError()
 
     def encode(self):
@@ -920,7 +921,15 @@ class Obj(object):
             return raw
         return b"".join((self._expl, len_encode(len(raw)), raw))
 
-    def decode(self, data, offset=0, leavemm=False, decode_path=(), ctx=None):
+    def decode(
+            self,
+            data,
+            offset=0,
+            leavemm=False,
+            decode_path=(),
+            ctx=None,
+            tag_only=False,
+    ):
         """Decode the data
 
         :param data: either binary or memoryview
@@ -928,18 +937,25 @@ class Obj(object):
         :param bool leavemm: do we need to leave memoryview of remaining
                     data as is, or convert it to bytes otherwise
         :param ctx: optional :ref:`context <ctx>` governing decoding process.
+        :param tag_only: decode only the tag, without length and contents;
+                         nothing is returned in that case (used only in Choice
+                         and Set structures to check if tag satisfies the scheme)
         :returns: (Obj, remaining data)
         """
         if ctx is None:
             ctx = {}
         tlv = memoryview(data)
         if self._expl is None:
-            obj, tail = self._decode(
+            result = self._decode(
                 tlv,
                 offset,
                 decode_path=decode_path,
                 ctx=ctx,
+                tag_only=tag_only,
             )
+            if tag_only:
+                return
+            obj, tail = result
         else:
             try:
                 t, tlen, lv = tag_strip(tlv)
@@ -972,12 +988,16 @@ class Obj(object):
                     decode_path=decode_path,
                     offset=offset,
                 )
-            obj, tail = self._decode(
+            result = self._decode(
                 v,
                 offset=offset + tlen + llen,
                 decode_path=decode_path,
                 ctx=ctx,
+                tag_only=tag_only,
             )
+            if tag_only:
+                return
+            obj, tail = result
         return obj, (tail if leavemm else tail.tobytes())
 
     @property
@@ -1012,11 +1032,14 @@ class Obj(object):
 class DecodePathDefBy(object):
     """DEFINED BY representation inside decode path
     """
-    __slots__ = ('defined_by',)
+    __slots__ = ("defined_by",)
 
     def __init__(self, defined_by):
         self.defined_by = defined_by
 
+    def __ne__(self, their):
+        return not(self == their)
+
     def __eq__(self, their):
         if not isinstance(their, self.__class__):
             return False
@@ -1333,7 +1356,7 @@ class Boolean(Obj):
             (b"\xFF" if self._value else b"\x00"),
         ))
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, _, lv = tag_strip(tlv)
         except DecodeError as err:
@@ -1349,6 +1372,8 @@ class Boolean(Obj):
                 decode_path=decode_path,
                 offset=offset,
             )
+        if tag_only:
+            return
         try:
             l, _, v = len_decode(lv)
         except DecodeError as err:
@@ -1626,7 +1651,7 @@ class Integer(Obj):
                     break
         return b"".join((self.tag, len_encode(len(octets)), octets))
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, _, lv = tag_strip(tlv)
         except DecodeError as err:
@@ -1642,6 +1667,8 @@ class Integer(Obj):
                 decode_path=decode_path,
                 offset=offset,
             )
+        if tag_only:
+            return
         try:
             l, llen, v = len_decode(lv)
         except DecodeError as err:
@@ -1765,12 +1792,12 @@ class BitString(Obj):
 
         class KeyUsage(BitString):
             schema = (
-                ('digitalSignature', 0),
-                ('nonRepudiation', 1),
-                ('keyEncipherment', 2),
+                ("digitalSignature", 0),
+                ("nonRepudiation", 1),
+                ("keyEncipherment", 2),
             )
 
-    >>> b = KeyUsage(('keyEncipherment', 'nonRepudiation'))
+    >>> b = KeyUsage(("keyEncipherment", "nonRepudiation"))
     KeyUsage BIT STRING 3 bits nonRepudiation, keyEncipherment
     >>> b.named
     ['nonRepudiation', 'keyEncipherment']
@@ -1960,7 +1987,7 @@ class BitString(Obj):
             octets,
         ))
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, _, lv = tag_strip(tlv)
         except DecodeError as err:
@@ -1976,6 +2003,8 @@ class BitString(Obj):
                 decode_path=decode_path,
                 offset=offset,
             )
+        if tag_only:
+            return
         try:
             l, llen, v = len_decode(lv)
         except DecodeError as err:
@@ -2014,7 +2043,7 @@ class BitString(Obj):
                 decode_path=decode_path,
                 offset=offset,
             )
-        if byte2int(v[-1:]) & ((1 << pad_size) - 1) != 0:
+        if byte2int(v[l - 1:l]) & ((1 << pad_size) - 1) != 0:
             raise DecodeError(
                 "invalid pad",
                 klass=self.__class__,
@@ -2212,7 +2241,7 @@ class OctetString(Obj):
             self._value,
         ))
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, _, lv = tag_strip(tlv)
         except DecodeError as err:
@@ -2228,6 +2257,8 @@ class OctetString(Obj):
                 decode_path=decode_path,
                 offset=offset,
             )
+        if tag_only:
+            return
         try:
             l, llen, v = len_decode(lv)
         except DecodeError as err:
@@ -2255,6 +2286,13 @@ class OctetString(Obj):
                 optional=self.optional,
                 _decoded=(offset, llen, l),
             )
+        except DecodeError as err:
+            raise DecodeError(
+                msg=err.msg,
+                klass=self.__class__,
+                decode_path=decode_path,
+                offset=offset,
+            )
         except BoundsError as err:
             raise DecodeError(
                 msg=str(err),
@@ -2361,7 +2399,7 @@ class Null(Obj):
     def _encode(self):
         return self.tag + len_encode(0)
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, _, lv = tag_strip(tlv)
         except DecodeError as err:
@@ -2377,6 +2415,8 @@ class Null(Obj):
                 decode_path=decode_path,
                 offset=offset,
             )
+        if tag_only:
+            return
         try:
             l, _, v = len_decode(lv)
         except DecodeError as err:
@@ -2606,7 +2646,7 @@ class ObjectIdentifier(Obj):
         v = b"".join(octets)
         return b"".join((self.tag, len_encode(len(v)), v))
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, _, lv = tag_strip(tlv)
         except DecodeError as err:
@@ -2622,6 +2662,8 @@ class ObjectIdentifier(Obj):
                 decode_path=decode_path,
                 offset=offset,
             )
+        if tag_only:
+            return
         try:
             l, llen, v = len_decode(lv)
         except DecodeError as err:
@@ -2811,7 +2853,7 @@ class CommonString(OctetString):
 
     >>> PrintableString("привет мир")
     Traceback (most recent call last):
-    UnicodeEncodeError: 'ascii' codec can't encode characters in position 0-5: ordinal not in range(128)
+    pyderasn.DecodeError: 'ascii' codec can't encode characters in position 0-5: ordinal not in range(128)
 
     >>> BMPString("ада", bounds=(2, 2))
     Traceback (most recent call last):
@@ -2867,14 +2909,17 @@ class CommonString(OctetString):
             value_raw = value
         else:
             raise InvalidValueType((self.__class__, text_type, binary_type))
-        value_raw = (
-            value_decoded.encode(self.encoding)
-            if value_raw is None else value_raw
-        )
-        value_decoded = (
-            value_raw.decode(self.encoding)
-            if value_decoded is None else value_decoded
-        )
+        try:
+            value_raw = (
+                value_decoded.encode(self.encoding)
+                if value_raw is None else value_raw
+            )
+            value_decoded = (
+                value_raw.decode(self.encoding)
+                if value_decoded is None else value_decoded
+            )
+        except (UnicodeEncodeError, UnicodeDecodeError) as err:
+            raise DecodeError(str(err))
         if not self._bound_min <= len(value_decoded) <= self._bound_max:
             raise BoundsError(
                 self._bound_min,
@@ -2936,6 +2981,13 @@ class NumericString(CommonString):
     tag_default = tag_encode(18)
     encoding = "ascii"
     asn1_type_name = "NumericString"
+    allowable_chars = set(digits.encode("ascii"))
+
+    def _value_sanitize(self, value):
+        value = super(NumericString, self)._value_sanitize(value)
+        if not set(value) <= self.allowable_chars:
+            raise DecodeError("non-numeric value")
+        return value
 
 
 class PrintableString(CommonString):
@@ -3210,8 +3262,8 @@ class Choice(Obj):
 
         class GeneralName(Choice):
             schema = (
-                ('rfc822Name', IA5String(impl=tag_ctxp(1))),
-                ('dNSName', IA5String(impl=tag_ctxp(2))),
+                ("rfc822Name", IA5String(impl=tag_ctxp(1))),
+                ("dNSName", IA5String(impl=tag_ctxp(2))),
             )
 
     >>> gn = GeneralName()
@@ -3373,32 +3425,45 @@ class Choice(Obj):
         self._assert_ready()
         return self._value[1].encode()
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         for choice, spec in self.specs.items():
+            sub_decode_path = decode_path + (choice,)
             try:
-                value, tail = spec.decode(
+                spec.decode(
                     tlv,
                     offset=offset,
                     leavemm=True,
-                    decode_path=decode_path + (choice,),
+                    decode_path=sub_decode_path,
                     ctx=ctx,
+                    tag_only=True,
                 )
             except TagMismatch:
                 continue
-            obj = self.__class__(
-                schema=self.specs,
-                expl=self._expl,
-                default=self.default,
-                optional=self.optional,
-                _decoded=(offset, 0, value.tlvlen),
+            break
+        else:
+            raise TagMismatch(
+                klass=self.__class__,
+                decode_path=decode_path,
+                offset=offset,
             )
-            obj._value = (choice, value)
-            return obj, tail
-        raise TagMismatch(
-            klass=self.__class__,
-            decode_path=decode_path,
+        if tag_only:
+            return
+        value, tail = spec.decode(
+            tlv,
             offset=offset,
+            leavemm=True,
+            decode_path=sub_decode_path,
+            ctx=ctx,
         )
+        obj = self.__class__(
+            schema=self.specs,
+            expl=self._expl,
+            default=self.default,
+            optional=self.optional,
+            _decoded=(offset, 0, value.tlvlen),
+        )
+        obj._value = (choice, value)
+        return obj, tail
 
     def __repr__(self):
         value = pp_console_row(next(self.pps()))
@@ -3548,7 +3613,7 @@ class Any(Obj):
         self._assert_ready()
         return self._value
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, tlen, lv = tag_strip(tlv)
             l, llen, v = len_decode(lv)
@@ -3675,7 +3740,7 @@ class Sequence(Obj):
     pyderasn.InvalidValueType: invalid value type, expected: <class 'pyderasn.ObjectIdentifier'>
     >>> ext["extnID"] = ObjectIdentifier("1.2.3")
 
-    You can know if sequence is ready to be encoded:
+    You can determine whether the sequence is ready to be encoded:
 
     >>> ext.ready
     False
@@ -3701,7 +3766,15 @@ class Sequence(Obj):
 
     Assign ``None`` to remove value from sequence.
 
-    You can know if value exists/set in the sequence and take its value:
+    You can set values in Sequence during its initialization:
+
+    >>> AlgorithmIdentifier((
+    ...     ("algorithm", ObjectIdentifier("1.2.3")),
+    ...     ("parameters", Any(Null()))
+    ... ))
+    AlgorithmIdentifier SEQUENCE[OBJECT IDENTIFIER 1.2.3, ANY 0500 OPTIONAL]
+
+    You can check whether a value is set in the sequence and retrieve it:
 
     >>> "extnID" in ext, "extnValue" in ext, "critical" in ext
     (True, True, False)
@@ -3758,9 +3831,17 @@ class Sequence(Obj):
         )
         self._value = {}
         if value is not None:
-            self._value = self._value_sanitize(value)
+            if issubclass(value.__class__, Sequence):
+                self._value = value._value
+            elif hasattr(value, "__iter__"):
+                for seq_key, seq_value in value:
+                    self[seq_key] = seq_value
+            else:
+                raise InvalidValueType((Sequence,))
         if default is not None:
-            default_value = self._value_sanitize(default)
+            if not issubclass(default.__class__, Sequence):
+                raise InvalidValueType((Sequence,))
+            default_value = default._value
             default_obj = self.__class__(impl=self.tag, expl=self._expl)
             default_obj.specs = self.specs
             default_obj._value = default_value
@@ -3768,11 +3849,6 @@ class Sequence(Obj):
             if value is None:
                 self._value = default_obj.copy()._value
 
-    def _value_sanitize(self, value):
-        if not issubclass(value.__class__, Sequence):
-            raise InvalidValueType((Sequence,))
-        return value._value
-
     @property
     def ready(self):
         for name, spec in self.specs.items():
@@ -3869,7 +3945,7 @@ class Sequence(Obj):
         v = b"".join(self._encoded_values())
         return b"".join((self.tag, len_encode(len(v)), v))
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, tlen, lv = tag_strip(tlv)
         except DecodeError as err:
@@ -3885,6 +3961,8 @@ class Sequence(Obj):
                 decode_path=decode_path,
                 offset=offset,
             )
+        if tag_only:
+            return
         try:
             l, llen, v = len_decode(lv)
         except DecodeError as err:
@@ -4063,7 +4141,7 @@ class Set(Sequence):
         v = b"".join(raws)
         return b"".join((self.tag, len_encode(len(v)), v))
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, tlen, lv = tag_strip(tlv)
         except DecodeError as err:
@@ -4079,6 +4157,8 @@ class Set(Sequence):
                 decode_path=decode_path,
                 offset=offset,
             )
+        if tag_only:
+            return
         try:
             l, llen, v = len_decode(lv)
         except DecodeError as err:
@@ -4100,23 +4180,18 @@ class Set(Sequence):
         specs_items = self.specs.items
         while len(v) > 0:
             for name, spec in specs_items():
+                sub_decode_path = decode_path + (name,)
                 try:
-                    value, v_tail = spec.decode(
+                    spec.decode(
                         v,
                         sub_offset,
                         leavemm=True,
-                        decode_path=decode_path + (name,),
+                        decode_path=sub_decode_path,
                         ctx=ctx,
+                        tag_only=True,
                     )
                 except TagMismatch:
                     continue
-                sub_offset += (
-                    value.expl_tlvlen if value.expled else value.tlvlen
-                )
-                v = v_tail
-                if spec.default is None or value != spec.default:  # pragma: no cover
-                    # SeqMixing.test_encoded_default_accepted covers that place
-                    values[name] = value
                 break
             else:
                 raise TagMismatch(
@@ -4124,6 +4199,20 @@ class Set(Sequence):
                     decode_path=decode_path,
                     offset=offset,
                 )
+            value, v_tail = spec.decode(
+                v,
+                sub_offset,
+                leavemm=True,
+                decode_path=sub_decode_path,
+                ctx=ctx,
+            )
+            sub_offset += (
+                value.expl_tlvlen if value.expled else value.tlvlen
+            )
+            v = v_tail
+            if spec.default is None or value != spec.default:  # pragma: no cover
+                # SeqMixing.test_encoded_default_accepted covers that place
+                values[name] = value
         obj = self.__class__(
             schema=self.specs,
             impl=self.tag,
@@ -4316,7 +4405,7 @@ class SequenceOf(Obj):
         v = b"".join(self._encoded_values())
         return b"".join((self.tag, len_encode(len(v)), v))
 
-    def _decode(self, tlv, offset, decode_path, ctx):
+    def _decode(self, tlv, offset, decode_path, ctx, tag_only):
         try:
             t, tlen, lv = tag_strip(tlv)
         except DecodeError as err:
@@ -4332,6 +4421,8 @@ class SequenceOf(Obj):
                 decode_path=decode_path,
                 offset=offset,
             )
+        if tag_only:
+            return
         try:
             l, llen, v = len_decode(lv)
         except DecodeError as err: