]> Cypherpunks.ru repositories - pyderasn.git/blobdiff - tests/test_pyderasn.py
OID test vector from Go
[pyderasn.git] / tests / test_pyderasn.py
index b16c6f6cc7fd318c6534599d3887ef1415688dbc..73eea21c6151a8ac3979f103eb42c3b7632cf0b5 100644 (file)
@@ -33,6 +33,7 @@ from time import mktime
 from time import time
 from unittest import TestCase
 
+from dateutil.tz import UTC
 from hypothesis import assume
 from hypothesis import given
 from hypothesis import settings
@@ -276,6 +277,19 @@ class TestTagCoder(TestCase):
         with self.assertRaises(DecodeError):
             len_decode(octets)
 
+    @given(tag_classes, tag_forms, integers(min_value=31))
+    def test_leading_zero_byte(self, klass, form, num):
+        raw = tag_encode(klass=klass, form=form, num=num)
+        raw = b"".join((raw[:1], b"\x80", raw[1:]))
+        with assertRaisesRegex(self, DecodeError, "leading zero byte"):
+            tag_strip(raw)
+
+    @given(tag_classes, tag_forms, integers(max_value=30, min_value=0))
+    def test_unexpected_long_form(self, klass, form, num):
+        raw = int2byte(klass | form | 31) + int2byte(num)
+        with assertRaisesRegex(self, DecodeError, "unexpected long form"):
+            tag_strip(raw)
+
 
 class TestLenCoder(TestCase):
     @settings(max_examples=LONG_TEST_MAX_EXAMPLES)
@@ -2662,7 +2676,6 @@ class TestNull(CommonMixin, TestCase):
             repr(obj)
             list(obj.pps())
 
-
     @given(integers(min_value=1))
     def test_invalid_len(self, l):
         with self.assertRaises(InvalidLength):
@@ -3069,6 +3082,10 @@ class TestObjectIdentifier(CommonMixin, TestCase):
                 data,
             )))
 
+    def test_go_non_minimal_encoding(self):
+        with self.assertRaises(DecodeError):
+            ObjectIdentifier().decode(hexdec("060a2a80864886f70d01010b"))
+
     def test_x690_vector(self):
         self.assertEqual(
             ObjectIdentifier().decode(hexdec("0603883703"))[0],
@@ -3463,9 +3480,7 @@ class StringMixin(object):
         repr(err.exception)
 
     def text_alphabet(self):
-        if self.base_klass.encoding in ("ascii", "iso-8859-1"):
-            return printable + whitespace
-        return None
+        return "".join(six_unichr(c) for c in six_xrange(256))
 
     @given(booleans())
     def test_optional(self, optional):
@@ -3829,7 +3844,7 @@ class TestNumericString(StringMixin, CommonMixin, TestCase):
 
     @given(text(alphabet=ascii_letters, min_size=1, max_size=5))
     def test_non_numeric(self, non_numeric_text):
-        with assertRaisesRegex(self, DecodeError, "non-numeric"):
+        with assertRaisesRegex(self, DecodeError, "alphabet value"):
             self.base_klass(non_numeric_text)
 
     @given(
@@ -3879,7 +3894,7 @@ class TestPrintableString(
 
     @given(text(alphabet=sorted(set(whitespace) - set(" ")), min_size=1, max_size=5))
     def test_non_printable(self, non_printable_text):
-        with assertRaisesRegex(self, DecodeError, "non-printable"):
+        with assertRaisesRegex(self, DecodeError, "alphabet value"):
             self.base_klass(non_printable_text)
 
     @given(
@@ -3913,7 +3928,7 @@ class TestPrintableString(
             for prop in kwargs.keys():
                 self.assertFalse(getattr(obj, prop))
             s += c
-            with assertRaisesRegex(self, DecodeError, "non-printable"):
+            with assertRaisesRegex(self, DecodeError, "alphabet value"):
                 self.base_klass(s)
             self.base_klass(s, **kwargs)
             klass = self.base_klass(**kwargs)
@@ -3952,6 +3967,18 @@ class TestIA5String(
 ):
     base_klass = IA5String
 
+    def text_alphabet(self):
+        return "".join(six_unichr(c) for c in six_xrange(128))
+
+    @given(integers(min_value=128, max_value=255))
+    def test_alphabet_bad(self, code):
+        with self.assertRaises(DecodeError):
+            self.base_klass().decod(
+                self.base_klass.tag_default +
+                len_encode(1) +
+                bytes(bytearray([code])),
+            )
+
 
 class TestGraphicString(
         UnicodeDecodeErrorMixin,
@@ -3970,6 +3997,9 @@ class TestVisibleString(
 ):
     base_klass = VisibleString
 
+    def text_alphabet(self):
+        return " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"
+
     def test_x690_vector(self):
         obj, tail = VisibleString().decode(hexdec("1A054A6F6E6573"))
         self.assertSequenceEqual(tail, b"")
@@ -4006,6 +4036,38 @@ class TestVisibleString(
         self.assertTrue(obj.lenindef)
         self.assertTrue(obj.bered)
 
+    @given(one_of((
+        integers(min_value=0, max_value=ord(" ") - 1),
+        integers(min_value=ord("~") + 1, max_value=255),
+    )))
+    def test_alphabet_bad(self, code):
+        with self.assertRaises(DecodeError):
+            self.base_klass().decod(
+                self.base_klass.tag_default +
+                len_encode(1) +
+                bytes(bytearray([code])),
+            )
+
+    @given(
+        sets(integers(min_value=0, max_value=10), min_size=2, max_size=2),
+        integers(min_value=0),
+        decode_path_strat,
+    )
+    def test_invalid_bounds_while_decoding(self, ints, offset, decode_path):
+        value, bound_min = list(sorted(ints))
+
+        class String(self.base_klass):
+            bounds = (bound_min, bound_min)
+        with self.assertRaises(DecodeError) as err:
+            String().decode(
+                self.base_klass(b"1" * value).encode(),
+                offset=offset,
+                decode_path=decode_path,
+            )
+        repr(err.exception)
+        self.assertEqual(err.exception.offset, offset)
+        self.assertEqual(err.exception.decode_path, decode_path)
+
 
 class TestGeneralString(
         UnicodeDecodeErrorMixin,
@@ -4437,8 +4499,13 @@ class TestGeneralizedTime(TimeMixin, CommonMixin, TestCase):
                 mktime(obj.todatetime().timetuple()),
                 mktime(dt.timetuple()),
             )
-        elif not PY2:
-            self.assertEqual(obj.todatetime().timestamp(), dt.timestamp())
+        else:
+            try:
+                obj.todatetime().timestamp()
+            except:
+                pass
+            else:
+                self.assertEqual(obj.todatetime().timestamp(), dt.timestamp())
         self.assertEqual(obj.ber_encoded, not dered)
         self.assertEqual(obj.bered, not dered)
         self.assertEqual(obj.ber_raw, None if dered else data)
@@ -4670,6 +4737,10 @@ class TestGeneralizedTime(TimeMixin, CommonMixin, TestCase):
             with self.assertRaises(DecodeError):
                 GeneralizedTime(data)
 
+    def test_aware(self):
+        with assertRaisesRegex(self, ValueError, "only naive"):
+            GeneralizedTime(datetime(2000, 1, 1, 1, tzinfo=UTC))
+
 
 class TestUTCTime(TimeMixin, CommonMixin, TestCase):
     base_klass = UTCTime
@@ -5003,6 +5074,10 @@ class TestUTCTime(TimeMixin, CommonMixin, TestCase):
                 junk
             )
 
+    def test_aware(self):
+        with assertRaisesRegex(self, ValueError, "only naive"):
+            UTCTime(datetime(2000, 1, 1, 1, tzinfo=UTC))
+
 
 @composite
 def tlv_value_strategy(draw):
@@ -6274,6 +6349,7 @@ class SeqMixing(object):
             min_size=len(_schema),
             max_size=len(_schema),
         ))]
+
         class Wahl(Choice):
             schema = (("int", Integer()),)
 
@@ -7046,6 +7122,7 @@ class TestSequenceOf(SeqOfMixing, CommonMixin, TestCase):
             schema = Integer()
             bounds = (10, 20)
         seqof = None
+
         def gen(n):
             for i in six_xrange(n):
                 yield Integer(i)
@@ -7065,6 +7142,7 @@ class TestSequenceOf(SeqOfMixing, CommonMixin, TestCase):
         class SeqOf(SequenceOf):
             schema = Integer()
             bounds = (1, float("+inf"))
+
         def gen():
             for i in six_xrange(10):
                 yield Integer(i)
@@ -7079,6 +7157,7 @@ class TestSequenceOf(SeqOfMixing, CommonMixin, TestCase):
         class SeqOf(SequenceOf):
             schema = Integer()
             bounds = (1, float("+inf"))
+
         def gen():
             for i in six_xrange(10):
                 yield Integer(i)
@@ -7613,6 +7692,7 @@ class TestDefinesByPath(TestCase):
 
     def test_remaining_data(self):
         oid = ObjectIdentifier("1.2.3")
+
         class Seq(Sequence):
             schema = (
                 ("oid", ObjectIdentifier(defines=((("tgt",), {
@@ -7630,6 +7710,7 @@ class TestDefinesByPath(TestCase):
 
     def test_remaining_data_seqof(self):
         oid = ObjectIdentifier("1.2.3")
+
         class SeqOf(SetOf):
             schema = OctetString()