From: Sergey Matveev Date: Mon, 2 Oct 2017 13:05:55 +0000 (+0300) Subject: Auto add __slots__ to all inherited classes X-Git-Tag: 1.0~3 X-Git-Url: http://www.git.cypherpunks.ru/?p=pyderasn.git;a=commitdiff_plain;h=1d553e0c1bbb7639409b846fe65711ca1fb00105 Auto add __slots__ to all inherited classes --- diff --git a/THANKS b/THANKS index 670e586..a778939 100644 --- a/THANKS +++ b/THANKS @@ -1,2 +1,4 @@ * pyasn1 (http://pyasn1.sourceforge.net/) project * Go encoding/asn1 (https://golang.org/pkg/encoding/asn1/) library +* Nikolay Ivanov for helping with + AutoAddSlots metaclass diff --git a/pyderasn.py b/pyderasn.py index 4ae1eb0..476e4e8 100755 --- a/pyderasn.py +++ b/pyderasn.py @@ -318,6 +318,7 @@ from collections import OrderedDict from datetime import datetime from math import ceil +from six import add_metaclass from six import binary_type from six import byte2int from six import indexbytes @@ -655,6 +656,13 @@ def len_decode(data): # Base class ######################################################################## +class AutoAddSlots(type): + def __new__(cls, name, bases, _dict): + _dict["__slots__"] = _dict.get("__slots__", ()) + return type.__new__(cls, name, bases, _dict) + + +@add_metaclass(AutoAddSlots) class Obj(object): """Common ASN.1 object class diff --git a/tests/test_crts.py b/tests/test_crts.py index 93162e0..5a839eb 100644 --- a/tests/test_crts.py +++ b/tests/test_crts.py @@ -62,7 +62,6 @@ some_oids = { class Version(Integer): - __slots__ = () schema = ( ("v1", 0), ("v2", 1), @@ -71,12 +70,10 @@ class Version(Integer): class CertificateSerialNumber(Integer): - __slots__ = () pass class AlgorithmIdentifier(Sequence): - __slots__ = () schema = ( ("algorithm", ObjectIdentifier()), ("parameters", Any(optional=True)), @@ -84,17 +81,14 @@ class AlgorithmIdentifier(Sequence): class AttributeType(ObjectIdentifier): - __slots__ = () pass class AttributeValue(Any): - __slots__ = () pass class AttributeTypeAndValue(Sequence): - __slots__ = () schema = ( ("type", AttributeType()), ("value", AttributeValue()), @@ -102,25 +96,21 @@ class AttributeTypeAndValue(Sequence): class RelativeDistinguishedName(SetOf): - __slots__ = () schema = AttributeTypeAndValue() bounds = (1, float("+inf")) class RDNSequence(SequenceOf): - __slots__ = () schema = RelativeDistinguishedName() class Name(Choice): - __slots__ = () schema = ( ("rdnSequence", RDNSequence()), ) class Time(Choice): - __slots__ = () schema = ( ("utcTime", UTCTime()), ("generalTime", GeneralizedTime()), @@ -128,7 +118,6 @@ class Time(Choice): class Validity(Sequence): - __slots__ = () schema = ( ("notBefore", Time()), ("notAfter", Time()), @@ -136,7 +125,6 @@ class Validity(Sequence): class SubjectPublicKeyInfo(Sequence): - __slots__ = () schema = ( ("algorithm", AlgorithmIdentifier()), ("subjectPublicKey", BitString()), @@ -144,12 +132,10 @@ class SubjectPublicKeyInfo(Sequence): class UniqueIdentifier(BitString): - __slots__ = () pass class Extension(Sequence): - __slots__ = () schema = ( ("extnID", ObjectIdentifier()), ("critical", Boolean(default=False)), @@ -158,13 +144,11 @@ class Extension(Sequence): class Extensions(SequenceOf): - __slots__ = () schema = Extension() bounds = (1, float("+inf")) class TBSCertificate(Sequence): - __slots__ = () schema = ( ("version", Version(expl=tag_ctxc(0), default="v1")), ("serialNumber", CertificateSerialNumber()), @@ -180,7 +164,6 @@ class TBSCertificate(Sequence): class Certificate(Sequence): - __slots__ = () schema = ( ("tbsCertificate", TBSCertificate()), ("signatureAlgorithm", AlgorithmIdentifier()), diff --git a/tests/test_pyderasn.py b/tests/test_pyderasn.py index 2fa22ab..0be393e 100644 --- a/tests/test_pyderasn.py +++ b/tests/test_pyderasn.py @@ -292,7 +292,6 @@ class CommonMixin(object): @given(binary()) def test_impl_inherited(self, impl_tag): class Inherited(self.base_klass): - __slots__ = () impl = impl_tag obj = Inherited() self.assertSequenceEqual(obj.impl, impl_tag) @@ -301,7 +300,6 @@ class CommonMixin(object): @given(binary()) def test_expl_inherited(self, expl_tag): class Inherited(self.base_klass): - __slots__ = () expl = expl_tag obj = Inherited() self.assertSequenceEqual(obj.expl, expl_tag) @@ -338,7 +336,7 @@ def boolean_values_strat(draw, do_expl=False): class BooleanInherited(Boolean): - __slots__ = () + pass class TestBoolean(CommonMixin, TestCase): @@ -626,7 +624,7 @@ def integer_values_strat(draw, do_expl=False): class IntegerInherited(Integer): - __slots__ = () + pass class TestInteger(CommonMixin, TestCase): @@ -642,7 +640,6 @@ class TestInteger(CommonMixin, TestCase): missing = names_input.pop() class Int(Integer): - __slots__ = () schema = [(n, 123) for n in names_input] with self.assertRaises(ObjUnknown) as err: Int(missing) @@ -651,7 +648,6 @@ class TestInteger(CommonMixin, TestCase): @given(sets(text_letters(), min_size=2)) def test_known_name(self, names_input): class Int(Integer): - __slots__ = () schema = [(n, 123) for n in names_input] Int(names_input.pop()) @@ -705,7 +701,6 @@ class TestInteger(CommonMixin, TestCase): names_input = dict(zip(names_input, values_input)) class Int(Integer): - __slots__ = () schema = names_input _int = Int(chosen_name) self.assertEqual(_int.named, chosen_name) @@ -1029,7 +1024,7 @@ def bit_string_values_strat(draw, schema=None, value_required=False, do_expl=Fal class BitStringInherited(BitString): - __slots__ = () + pass class TestBitString(CommonMixin, TestCase): @@ -1063,7 +1058,6 @@ class TestBitString(CommonMixin, TestCase): self.assertGreater(len(obj.encode()), (leading_zeros + 1 + trailing_zeros) // 8) class BS(BitString): - __slots__ = () schema = (("whatever", 0),) obj = BS("'%s1%s'B" % (("0" * leading_zeros), ("0" * trailing_zeros))) self.assertEqual(obj.bit_len, leading_zeros + 1) @@ -1100,7 +1094,6 @@ class TestBitString(CommonMixin, TestCase): missing = _schema.pop() class BS(BitString): - __slots__ = () schema = [(n, i) for i, n in enumerate(_schema)] with self.assertRaises(ObjUnknown) as err: BS((missing,)) @@ -1155,7 +1148,6 @@ class TestBitString(CommonMixin, TestCase): ) = d.draw(bit_string_values_strat()) class BS(klass): - __slots__ = () schema = schema_initial obj_initial = BS( value=value_initial, @@ -1200,7 +1192,6 @@ class TestBitString(CommonMixin, TestCase): _schema, value, impl, expl, default, optional, _decoded = values class BS(klass): - __slots__ = () schema = _schema obj = BS( value=value, @@ -1283,7 +1274,6 @@ class TestBitString(CommonMixin, TestCase): offset = d.draw(integers(min_value=0)) for klass in (BitString, BitStringInherited): class BS(klass): - __slots__ = () schema = _schema obj = BS( value=value, @@ -1408,7 +1398,7 @@ def octet_string_values_strat(draw, do_expl=False): class OctetStringInherited(OctetString): - __slots__ = () + pass class TestOctetString(CommonMixin, TestCase): @@ -1690,7 +1680,7 @@ def null_values_strat(draw, do_expl=False): class NullInherited(Null): - __slots__ = () + pass class TestNull(CommonMixin, TestCase): @@ -1883,7 +1873,7 @@ def oid_values_strat(draw, do_expl=False): class ObjectIdentifierInherited(ObjectIdentifier): - __slots__ = () + pass class TestObjectIdentifier(CommonMixin, TestCase): @@ -2212,7 +2202,6 @@ def enumerated_values_strat(draw, schema=None, do_expl=False): class TestEnumerated(CommonMixin, TestCase): class EWhatever(Enumerated): - __slots__ = () schema = (("whatever", 0),) base_klass = EWhatever @@ -2231,7 +2220,6 @@ class TestEnumerated(CommonMixin, TestCase): missing = schema_input.pop() class E(Enumerated): - __slots__ = () schema = [(n, 123) for n in schema_input] with self.assertRaises(ObjUnknown) as err: E(missing) @@ -2247,7 +2235,6 @@ class TestEnumerated(CommonMixin, TestCase): _input = list(zip(schema_input, values_input)) class E(Enumerated): - __slots__ = () schema = _input with self.assertRaises(DecodeError) as err: E(missing_value) @@ -2274,14 +2261,13 @@ class TestEnumerated(CommonMixin, TestCase): @given(integers(), integers(), binary(), binary()) def test_comparison(self, value1, value2, tag1, tag2): class E(Enumerated): - __slots__ = () schema = ( ("whatever0", value1), ("whatever1", value2), ) class EInherited(E): - __slots__ = () + pass for klass in (E, EInherited): obj1 = klass(value1) obj2 = klass(value2) @@ -2304,7 +2290,6 @@ class TestEnumerated(CommonMixin, TestCase): ) = d.draw(enumerated_values_strat()) class E(Enumerated): - __slots__ = () schema = schema_initial obj_initial = E( value=value_initial, @@ -2362,7 +2347,6 @@ class TestEnumerated(CommonMixin, TestCase): schema_input, value, impl, expl, default, optional, _decoded = values class E(Enumerated): - __slots__ = () schema = schema_input obj = E( value=value, @@ -2387,7 +2371,6 @@ class TestEnumerated(CommonMixin, TestCase): value = d.draw(sampled_from(sorted([v for _, v in schema_input]))) class E(Enumerated): - __slots__ = () schema = schema_input obj = E( value=value, @@ -3137,7 +3120,7 @@ def any_values_strat(draw, do_expl=False): class AnyInherited(Any): - __slots__ = () + pass class TestAny(CommonMixin, TestCase): @@ -3367,7 +3350,7 @@ def choice_values_strat(draw, value_required=False, schema=None, do_expl=False): class ChoiceInherited(Choice): - __slots__ = () + pass class TestChoice(CommonMixin, TestCase): @@ -3424,7 +3407,7 @@ class TestChoice(CommonMixin, TestCase): @given(booleans(), booleans()) def test_comparison(self, value1, value2): class WahlInherited(self.base_klass): - __slots__ = () + pass for klass in (self.base_klass, WahlInherited): obj1 = klass(("whatever", Boolean(value1))) obj2 = klass(("whatever", Boolean(value2))) @@ -3445,7 +3428,6 @@ class TestChoice(CommonMixin, TestCase): ) = d.draw(choice_values_strat()) class Wahl(klass): - __slots__ = () schema = schema_initial obj_initial = Wahl( value=value_initial, @@ -3497,7 +3479,6 @@ class TestChoice(CommonMixin, TestCase): _schema, value, expl, default, optional, _decoded = values class Wahl(self.base_klass): - __slots__ = () schema = _schema obj = Wahl( value=value, @@ -3541,7 +3522,6 @@ class TestChoice(CommonMixin, TestCase): offset = d.draw(integers(min_value=0)) class Wahl(self.base_klass): - __slots__ = () schema = _schema obj = Wahl( value=value, @@ -3714,7 +3694,7 @@ def sequence_strat(draw, seq_klass): for i, (klass, value, default) in enumerate(inputs): schema.append((names[i], klass(default=default, **inits[i]))) seq_name = draw(text_letters()) - Seq = type(seq_name, (seq_klass,), {"__slots__": (), "schema": tuple(schema)}) + Seq = type(seq_name, (seq_klass,), {"schema": tuple(schema)}) seq = Seq() expects = [] for i, (klass, value, default) in enumerate(inputs): @@ -3777,7 +3757,7 @@ def sequences_strat(draw, seq_klass): seq(default=(seq if i in defaulted else None), **inits[i]), )) seq_name = draw(text_letters()) - Seq = type(seq_name, (seq_klass,), {"__slots__": (), "schema": tuple(schema)}) + Seq = type(seq_name, (seq_klass,), {"schema": tuple(schema)}) seq_outer = Seq() expect_outers = [] for name, (seq_inner, expects_inner) in zip(names, seq_expectses): @@ -3801,7 +3781,6 @@ class SeqMixing(object): def test_invalid_value_type_set(self): class Seq(self.base_klass): - __slots__ = () schema = (("whatever", Boolean()),) seq = Seq() with self.assertRaises(InvalidValueType) as err: @@ -3836,7 +3815,6 @@ class SeqMixing(object): schema_input.append((name, Boolean())) class Seq(self.base_klass): - __slots__ = () schema = tuple(schema_input) seq = Seq() for name in ready.keys(): @@ -3862,7 +3840,7 @@ class SeqMixing(object): @given(data_strategy()) def test_call(self, d): class SeqInherited(self.base_klass): - __slots__ = () + pass for klass in (self.base_klass, SeqInherited): ( value_initial, @@ -3918,7 +3896,7 @@ class SeqMixing(object): @given(data_strategy()) def test_copy(self, d): class SeqInherited(self.base_klass): - __slots__ = () + pass for klass in (self.base_klass, SeqInherited): values = d.draw(seq_values_strat(seq_klass=klass)) obj = klass(*values) @@ -3933,7 +3911,6 @@ class SeqMixing(object): tag_impl = tag_encode(d.draw(integers(min_value=1))) class Seq(self.base_klass): - __slots__ = () impl = tag_impl schema = (("whatever", Integer()),) seq = Seq() @@ -3947,7 +3924,6 @@ class SeqMixing(object): tag_expl = tag_ctxc(d.draw(integers(min_value=1))) class Seq(self.base_klass): - __slots__ = () expl = tag_expl schema = (("whatever", Integer()),) seq = Seq() @@ -3966,7 +3942,6 @@ class SeqMixing(object): assume(False) class Seq(self.base_klass): - __slots__ = () schema = ( ("whatever", Integer()), ("junk", Any()), @@ -4093,7 +4068,6 @@ class SeqMixing(object): )).items()) class Seq(self.base_klass): - __slots__ = () schema = [ (n, Integer(default=d)) for n, (_, d) in _schema @@ -4123,7 +4097,6 @@ class SeqMixing(object): ))] class SeqWithoutDefault(self.base_klass): - __slots__ = () schema = [ (n, Integer(impl=t)) for (n, _), t in zip(_schema, tags) @@ -4134,7 +4107,6 @@ class SeqMixing(object): seq_encoded = seq_without_default.encode() class SeqWithDefault(self.base_klass): - __slots__ = () schema = [ (n, Integer(default=v, impl=t)) for (n, v), t in zip(_schema, tags) @@ -4156,7 +4128,6 @@ class SeqMixing(object): names_tags = [(name, tag) for tag, name in sorted(zip(tags, names))] class SeqFull(self.base_klass): - __slots__ = () schema = [(n, Integer(impl=t)) for n, t in names_tags] seq_full = SeqFull() for i, name in enumerate(names): @@ -4165,7 +4136,6 @@ class SeqMixing(object): altered = names_tags[:-2] + names_tags[-1:] class SeqMissing(self.base_klass): - __slots__ = () schema = [(n, Integer(impl=t)) for n, t in altered] seq_missing = SeqMissing() with self.assertRaises(TagMismatch): @@ -4181,7 +4151,6 @@ class TestSequence(SeqMixing, CommonMixin, TestCase): ) def test_remaining(self, value, junk): class Seq(Sequence): - __slots__ = () schema = ( ("whatever", Integer()), ) @@ -4199,7 +4168,6 @@ class TestSequence(SeqMixing, CommonMixin, TestCase): missing = names.pop() class Seq(Sequence): - __slots__ = () schema = [(n, Boolean()) for n in names] seq = Seq() with self.assertRaises(ObjUnknown) as err: @@ -4222,7 +4190,6 @@ class TestSet(SeqMixing, CommonMixin, TestCase): ] class Seq(Set): - __slots__ = () schema = [(str(i), OctetString(impl=t)) for i, t in enumerate(tags)] seq = Seq() for name, _ in Seq.schema: @@ -4287,7 +4254,6 @@ class SeqOfMixing(object): def test_invalid_values_type(self): class SeqOf(self.base_klass): - __slots__ = () schema = Integer() with self.assertRaises(InvalidValueType) as err: SeqOf([Integer(123), Boolean(False), Integer(234)]) @@ -4300,7 +4266,6 @@ class SeqOfMixing(object): @given(booleans(), booleans(), binary(), binary()) def test_comparison(self, value1, value2, tag1, tag2): class SeqOf(self.base_klass): - __slots__ = () schema = Boolean() obj1 = SeqOf([Boolean(value1)]) obj2 = SeqOf([Boolean(value2)]) @@ -4314,7 +4279,6 @@ class SeqOfMixing(object): @given(lists(booleans())) def test_iter(self, values): class SeqOf(self.base_klass): - __slots__ = () schema = Boolean() obj = SeqOf([Boolean(value) for value in values]) self.assertEqual(len(obj), len(values)) @@ -4334,7 +4298,6 @@ class SeqOfMixing(object): ] class SeqOf(self.base_klass): - __slots__ = () schema = Integer() values = d.draw(permutations(ready + non_ready)) seqof = SeqOf() @@ -4356,7 +4319,6 @@ class SeqOfMixing(object): def test_spec_mismatch(self): class SeqOf(self.base_klass): - __slots__ = () schema = Integer() seqof = SeqOf() seqof.append(Integer(123)) @@ -4368,7 +4330,6 @@ class SeqOfMixing(object): @given(data_strategy()) def test_bounds_satisfied(self, d): class SeqOf(self.base_klass): - __slots__ = () schema = Boolean() bound_min = d.draw(integers(min_value=0, max_value=1 << 7)) bound_max = d.draw(integers(min_value=bound_min, max_value=1 << 7)) @@ -4378,7 +4339,6 @@ class SeqOfMixing(object): @given(data_strategy()) def test_bounds_unsatisfied(self, d): class SeqOf(self.base_klass): - __slots__ = () schema = Boolean() bound_min = d.draw(integers(min_value=1, max_value=1 << 7)) bound_max = d.draw(integers(min_value=bound_min, max_value=1 << 7)) @@ -4397,7 +4357,6 @@ class SeqOfMixing(object): @given(integers(min_value=1, max_value=10)) def test_out_of_bounds(self, bound_max): class SeqOf(self.base_klass): - __slots__ = () schema = Integer() bounds = (0, bound_max) seqof = SeqOf() @@ -4420,7 +4379,6 @@ class SeqOfMixing(object): ) = d.draw(seqof_values_strat()) class SeqOf(self.base_klass): - __slots__ = () schema = schema_initial obj_initial = SeqOf( value=value_initial, @@ -4498,7 +4456,6 @@ class SeqOfMixing(object): _schema, value, bounds, impl, expl, default, optional, _decoded = values class SeqOf(self.base_klass): - __slots__ = () schema = _schema obj = SeqOf( value=value, @@ -4521,7 +4478,6 @@ class SeqOfMixing(object): ) def test_stripped(self, values, tag_impl): class SeqOf(self.base_klass): - __slots__ = () schema = OctetString() obj = SeqOf([OctetString(v) for v in values], impl=tag_impl) with self.assertRaises(NotEnoughData): @@ -4533,7 +4489,6 @@ class SeqOfMixing(object): ) def test_stripped_expl(self, values, tag_expl): class SeqOf(self.base_klass): - __slots__ = () schema = OctetString() obj = SeqOf([OctetString(v) for v in values], expl=tag_expl) with self.assertRaises(NotEnoughData): @@ -4590,7 +4545,6 @@ class SeqOfMixing(object): _, _, _, _, _, default, optional, _decoded = values class SeqOf(self.base_klass): - __slots__ = () schema = Integer() obj = SeqOf( value=value, @@ -4639,7 +4593,6 @@ class SeqOfMixing(object): class TestSequenceOf(SeqOfMixing, CommonMixin, TestCase): class SeqOf(SequenceOf): - __slots__ = () schema = "whatever" base_klass = SeqOf @@ -4650,7 +4603,6 @@ class TestSequenceOf(SeqOfMixing, CommonMixin, TestCase): class TestSetOf(SeqOfMixing, CommonMixin, TestCase): class SeqOf(SetOf): - __slots__ = () schema = "whatever" base_klass = SeqOf @@ -4666,7 +4618,6 @@ class TestSetOf(SeqOfMixing, CommonMixin, TestCase): values = [OctetString(v) for v in d.draw(lists(binary()))] class Seq(SetOf): - __slots__ = () schema = OctetString() seq = Seq(values) seq_encoded = seq.encode() @@ -4686,7 +4637,6 @@ class TestGoMarshalVectors(TestCase): self.assertSequenceEqual(Integer(-129).encode(), hexdec("0202ff7f")) class Seq(Sequence): - __slots__ = () schema = ( ("erste", Integer()), ("zweite", Integer(optional=True)) @@ -4701,7 +4651,6 @@ class TestGoMarshalVectors(TestCase): self.assertSequenceEqual(seq.encode(), hexdec("3006020140020141")) class NestedSeq(Sequence): - __slots__ = () schema = ( ("nest", Seq()), ) @@ -4717,7 +4666,6 @@ class TestGoMarshalVectors(TestCase): ) class Seq(Sequence): - __slots__ = () schema = ( ("erste", Integer(impl=tag_encode(5, klass=TagClassContext))), ) @@ -4726,7 +4674,6 @@ class TestGoMarshalVectors(TestCase): self.assertSequenceEqual(seq.encode(), hexdec("3003850140")) class Seq(Sequence): - __slots__ = () schema = ( ("erste", Integer(expl=tag_ctxc(5))), ) @@ -4735,7 +4682,6 @@ class TestGoMarshalVectors(TestCase): self.assertSequenceEqual(seq.encode(), hexdec("3005a503020140")) class Seq(Sequence): - __slots__ = () schema = ( ("erste", Null( impl=tag_encode(0, klass=TagClassContext), @@ -4762,7 +4708,6 @@ class TestGoMarshalVectors(TestCase): ) class Seq(Sequence): - __slots__ = () schema = ( ("erste", GeneralizedTime()), ) @@ -4810,7 +4755,6 @@ class TestGoMarshalVectors(TestCase): self.assertSequenceEqual(UTF8String("Σ").encode(), hexdec("0c02cea3")) class Seq(Sequence): - __slots__ = () schema = ( ("erste", IA5String()), ) @@ -4819,7 +4763,6 @@ class TestGoMarshalVectors(TestCase): self.assertSequenceEqual(seq.encode(), hexdec("3006160474657374")) class Seq(Sequence): - __slots__ = () schema = ( ("erste", PrintableString()), ) @@ -4830,7 +4773,6 @@ class TestGoMarshalVectors(TestCase): self.assertSequenceEqual(seq.encode(), hexdec("30071305746573742a")) class Seq(Sequence): - __slots__ = () schema = ( ("erste", Any(optional=True)), ("zweite", Integer()), @@ -4840,18 +4782,15 @@ class TestGoMarshalVectors(TestCase): self.assertSequenceEqual(seq.encode(), hexdec("3003020140")) class Seq(SetOf): - __slots__ = () schema = Integer() seq = Seq() seq.append(Integer(10)) self.assertSequenceEqual(seq.encode(), hexdec("310302010a")) class _SeqOf(SequenceOf): - __slots__ = () schema = PrintableString() class SeqOf(SequenceOf): - __slots__ = () schema = _SeqOf() _seqof = _SeqOf() _seqof.append(PrintableString("1")) @@ -4860,7 +4799,6 @@ class TestGoMarshalVectors(TestCase): self.assertSequenceEqual(seqof.encode(), hexdec("30053003130131")) class Seq(Sequence): - __slots__ = () schema = ( ("erste", Integer(default=1)), ) @@ -4885,3 +4823,13 @@ class TestPP(TestCase): pp = _pp(asn1_type_name=ObjectIdentifier.asn1_type_name, value=chosen) self.assertNotIn(chosen_id, pp_console_row(pp)) self.assertIn(chosen_id, pp_console_row(pp, oids=oids)) + + +class TestAutoAddSlots(TestCase): + def runTest(self): + class Inher(Integer): + pass + + with self.assertRaises(AttributeError): + inher = Inher() + inher.unexistent = "whatever"