]> Cypherpunks.ru repositories - pyderasn.git/blob - tests/test_cms.py
51bdaf1cbf26adb3d2843d9481d8c36766032eb3
[pyderasn.git] / tests / test_cms.py
1 # coding: utf-8
2 # PyDERASN -- Python ASN.1 DER/CER/BER codec with abstract structures
3 # Copyright (C) 2017-2020 Sergey Matveev <stargrave@stargrave.org>
4 #
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Lesser General Public License as
7 # published by the Free Software Foundation, version 3 of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU Lesser General Public License for more details.
13 #
14 # You should have received a copy of the GNU Lesser General Public
15 # License along with this program.  If not, see
16 # <http://www.gnu.org/licenses/>.
17
18 from hashlib import sha512
19 from io import BytesIO
20 from io import open as io_open
21 from os import environ
22 from os import remove
23 from os import urandom
24 from subprocess import call
25 from sys import getsizeof
26 from tempfile import NamedTemporaryFile
27 from time import time
28 from unittest import skipIf
29 from unittest import TestCase
30
31 from hypothesis import given
32 from hypothesis import settings
33 from hypothesis.strategies import integers
34 from six import PY2
35 from six.moves import xrange as six_xrange
36
37 from pyderasn import agg_octet_string
38 from pyderasn import Any
39 from pyderasn import Choice
40 from pyderasn import encode_cer
41 from pyderasn import file_mmaped
42 from pyderasn import Integer
43 from pyderasn import ObjectIdentifier
44 from pyderasn import OctetString
45 from pyderasn import Sequence
46 from pyderasn import SetOf
47 from pyderasn import tag_ctxc
48 from pyderasn import tag_ctxp
49 from tests.test_crts import AlgorithmIdentifier
50 from tests.test_crts import Certificate
51 from tests.test_crts import SubjectKeyIdentifier
52
53
class CMSVersion(Integer):
    """CMSVersion ::= INTEGER { v0(0), v1(1), v2(2), v3(3), v4(4), v5(5) }

    Every named value is spelled "v<N>" and maps to the integer N.
    """
    schema = tuple(("v%d" % number, number) for number in range(6))
63
64
class AttributeValue(Any):
    """AttributeValue ::= ANY -- its concrete type depends on attrType."""
67
68
class AttributeValues(SetOf):
    """SET OF AttributeValue, as carried by Attribute.attrValues."""
    schema = AttributeValue()
71
72
class Attribute(Sequence):
    """Attribute ::= SEQUENCE { attrType OBJECT IDENTIFIER, attrValues SET OF AttributeValue }"""
    schema = (
        # Identifies which attribute this is (e.g. contentType, messageDigest).
        ("attrType", ObjectIdentifier()),
        # One or more values whose type is determined by attrType.
        ("attrValues", AttributeValues()),
    )
78
79
class SignatureAlgorithmIdentifier(AlgorithmIdentifier):
    """SignatureAlgorithmIdentifier ::= AlgorithmIdentifier"""
82
83
class SignedAttributes(SetOf):
    """SignedAttributes ::= SET SIZE (1..MAX) OF Attribute"""
    schema = Attribute()
    # At least one attribute is mandatory; the upper limit of 32 is only a
    # cap chosen for these tests.
    bounds = (1, 32)
    # RFC 5652 signs the DER encoding of the signed attributes. This flag
    # presumably keeps pyderasn emitting DER for this SET OF even when the
    # enclosing message is CER-encoded.
    der_forced = True
88
89
class SignerIdentifier(Choice):
    """SignerIdentifier CHOICE, restricted to the subjectKeyIdentifier arm.

    The issuerAndSerialNumber alternative of RFC 5652 is not modelled here.
    """
    schema = (
        # [0] IMPLICIT: the key identifier is a primitive OCTET STRING.
        ("subjectKeyIdentifier", SubjectKeyIdentifier(impl=tag_ctxp(0))),
    )
95
96
class DigestAlgorithmIdentifiers(SetOf):
    """DigestAlgorithmIdentifiers ::= SET OF DigestAlgorithmIdentifier"""
    schema = AlgorithmIdentifier()
99
100
class DigestAlgorithmIdentifier(AlgorithmIdentifier):
    """DigestAlgorithmIdentifier ::= AlgorithmIdentifier"""
103
104
class SignatureValue(OctetString):
    """SignatureValue ::= OCTET STRING -- raw signature bytes."""
107
108
class SignerInfo(Sequence):
    """SignerInfo (RFC 5652 section 5.3).

    The optional unsignedAttrs [1] field is not modelled here.
    """
    schema = (
        ("version", CMSVersion()),
        ("sid", SignerIdentifier()),
        ("digestAlgorithm", DigestAlgorithmIdentifier()),
        # signedAttrs [0] IMPLICIT SignedAttributes OPTIONAL
        ("signedAttrs", SignedAttributes(impl=tag_ctxc(0), optional=True)),
        ("signatureAlgorithm", SignatureAlgorithmIdentifier()),
        ("signature", SignatureValue()),
    )
119
120
class SignerInfos(SetOf):
    """SignerInfos ::= SET OF SignerInfo"""
    schema = SignerInfo()
123
124
class ContentType(ObjectIdentifier):
    """ContentType ::= OBJECT IDENTIFIER"""
127
128
class EncapsulatedContentInfo(Sequence):
    """EncapsulatedContentInfo: content type plus optional wrapped content."""
    schema = (
        ("eContentType", ContentType()),
        # eContent [0] EXPLICIT OCTET STRING OPTIONAL
        ("eContent", OctetString(expl=tag_ctxc(0), optional=True)),
    )
134
135
class CertificateChoices(Choice):
    """CertificateChoices, reduced to the plain X.509 certificate arm.

    The other RFC 5652 alternatives (extended, attribute and other
    certificate formats) are not modelled here.
    """
    schema = (
        ("certificate", Certificate()),
    )
141
142
class CertificateSet(SetOf):
    """CertificateSet ::= SET OF CertificateChoices"""
    schema = CertificateChoices()
145
146
class SignedData(Sequence):
    """SignedData (RFC 5652 section 5.1).

    The optional crls [1] field is not modelled here.
    """
    schema = (
        ("version", CMSVersion()),
        ("digestAlgorithms", DigestAlgorithmIdentifiers()),
        ("encapContentInfo", EncapsulatedContentInfo()),
        # certificates [0] IMPLICIT CertificateSet OPTIONAL
        ("certificates", CertificateSet(impl=tag_ctxc(0), optional=True)),
        ("signerInfos", SignerInfos()),
    )
156
157
class ContentInfo(Sequence):
    """ContentInfo: top-level CMS wrapper of a typed content."""
    schema = (
        ("contentType", ContentType()),
        # content [0] EXPLICIT ANY DEFINED BY contentType
        ("content", Any(expl=tag_ctxc(0))),
    )
163
164
# CMS content types.
id_signedData = ObjectIdentifier("1.2.840.113549.1.7.2")
id_data = ObjectIdentifier("1.2.840.113549.1.7.1")

# Digest and signature algorithms.
id_sha512 = ObjectIdentifier("2.16.840.1.101.3.4.2.3")
id_ecdsa_with_SHA512 = ObjectIdentifier("1.2.840.10045.4.3.4")

# PKCS#9 signed attribute types.
id_pkcs9_at_contentType = ObjectIdentifier("1.2.840.113549.1.9.3")
id_pkcs9_at_messageDigest = ObjectIdentifier("1.2.840.113549.1.9.4")

# X.509 extension carrying the signer's key identifier.
id_ce_subjectKeyIdentifier = ObjectIdentifier("2.5.29.14")

# SHA-512 AlgorithmIdentifier without parameters.
ai_sha512 = AlgorithmIdentifier((("algorithm", id_sha512),))

# Probed once at import time: the OpenSSL-based tests are skipped when the
# "openssl cms" subcommand is unavailable (non-zero exit status).
openssl_cms_exists = call(
    "openssl cms -help 2>/dev/null",
    shell=True,
) == 0

175
@skipIf(not openssl_cms_exists, "openssl cms command not found")
class TestSignedDataCERWithOpenSSL(TestCase):
    """Check pyderasn-built CMS SignedData against ``openssl cms -verify``.

    Each test generates a fresh secp521r1 key and a self-signed certificate
    with the ``openssl`` command line tool. It signs the encoded signed
    attributes with ``openssl dgst`` and assembles an id-signedData
    ContentInfo with pyderasn. Finally it requires ``openssl cms -verify``
    to accept the written file.
    """

    def tmpfile(self):
        """Create an empty temporary file and return its path.

        The file is removed by a cleanup registered on this test case.
        """
        tmp = NamedTemporaryFile(delete=False)
        tmp.close()
        self.addCleanup(remove, tmp.name)
        return tmp.name

    def keypair(self):
        """Generate an ECDSA key and a self-signed certificate for it.

        :returns: (key_path, cert_path, cert, skid) where key_path is the
            PEM private key file; cert_path is a PEM file holding the
            certificate followed by the private key; cert is the decoded
            Certificate; skid is its decoded SubjectKeyIdentifier.
        :raises AssertionError: if any openssl call fails or the certificate
            carries no subjectKeyIdentifier extension.
        """
        key_path = self.tmpfile()
        self.assertEqual(0, call(
            "openssl ecparam -name secp521r1 -genkey -out " + key_path,
            shell=True,
        ))
        cert_path = self.tmpfile()
        self.assertEqual(0, call(" ".join((
            "openssl req -x509 -new",
            ("-key " + key_path),
            ("-outform PEM -out " + cert_path),
            "-nodes -subj /CN=pyderasntest",
        )), shell=True))
        cert_der_path = self.tmpfile()
        self.assertEqual(0, call(" ".join((
            "openssl x509",
            "-inform PEM -in " + cert_path,
            "-outform DER -out " + cert_der_path,
        )), shell=True))
        # cert_path is later passed as -signer, so append the private key to it.
        self.assertEqual(0, call("cat %s >> %s" % (key_path, cert_path), shell=True))
        with open(cert_der_path, "rb") as fd:
            cert = Certificate().decod(fd.read())
        skid = None
        extensions = cert["tbsCertificate"]["extensions"]
        if extensions is not None:
            for ext in extensions:
                if ext["extnID"] == id_ce_subjectKeyIdentifier:
                    skid = SubjectKeyIdentifier().decod(bytes(ext["extnValue"]))
        if skid is None:
            # Previously surfaced as an UnboundLocalError; fail explicitly.
            self.fail("generated certificate has no subjectKeyIdentifier extension")
        return key_path, cert_path, cert, skid

    def sign(self, signed_attrs, key_path):
        """Sign encode_cer(signed_attrs) with SHA-512 using key_path.

        :returns: raw signature bytes as produced by ``openssl dgst -binary``
        """
        input_path = self.tmpfile()
        with open(input_path, "wb") as fd:
            fd.write(encode_cer(signed_attrs))
        signature_path = self.tmpfile()
        self.assertEqual(0, call(" ".join((
            "openssl dgst -sha512",
            ("-sign " + key_path),
            "-binary", input_path,
            ("> " + signature_path),
        )), shell=True))
        with open(signature_path, "rb") as fd:
            signature = fd.read()
        return signature

    def verify(self, cert_path, cms_path):
        """Assert that ``openssl cms -verify`` accepts the file at cms_path.

        cert_path serves both as the signer certificate and as the CA file.
        """
        self.assertEqual(0, call(" ".join((
            "openssl cms -verify",
            ("-inform DER -in " + cms_path),
            "-signer %s -CAfile %s" % (cert_path, cert_path),
            "-out /dev/null 2>/dev/null",
        )), shell=True))

    def _signed_attrs(self, digest):
        """Build SignedAttributes: contentType=id-data, messageDigest=digest."""
        return SignedAttributes([
            Attribute((
                ("attrType", id_pkcs9_at_contentType),
                ("attrValues", AttributeValues([AttributeValue(id_data)])),
            )),
            Attribute((
                ("attrType", id_pkcs9_at_messageDigest),
                ("attrValues", AttributeValues([
                    AttributeValue(OctetString(digest)),
                ])),
            )),
        ])

    def _content_info(self, eci, cert, skid, signed_attrs, signature):
        """Wrap eci and a single ECDSA-with-SHA512 signer into a ContentInfo."""
        signer_info = SignerInfo((
            ("version", CMSVersion("v3")),
            ("sid", SignerIdentifier(("subjectKeyIdentifier", skid))),
            ("digestAlgorithm", DigestAlgorithmIdentifier(ai_sha512)),
            ("signedAttrs", signed_attrs),
            ("signatureAlgorithm", SignatureAlgorithmIdentifier((
                ("algorithm", id_ecdsa_with_SHA512),
            ))),
            ("signature", SignatureValue(signature)),
        ))
        signed_data = SignedData((
            ("version", CMSVersion("v3")),
            ("digestAlgorithms", DigestAlgorithmIdentifiers([ai_sha512])),
            ("encapContentInfo", eci),
            ("certificates", CertificateSet([
                CertificateChoices(("certificate", cert)),
            ])),
            ("signerInfos", SignerInfos([signer_info])),
        ))
        return ContentInfo((
            ("contentType", ContentType(id_signedData)),
            ("content", Any(signed_data)),
        ))

    @settings(deadline=None)
    @given(integers(min_value=1000, max_value=5000))
    def test_simple(self, data_len):
        """Sign random data and verify both the 2-pass DER and CER outputs.

        It also streams the eContent back out of the CER file and compares
        it with the original data.
        """
        key_path, cert_path, cert, skid = self.keypair()
        data = urandom(data_len)
        eci = EncapsulatedContentInfo((
            ("eContentType", ContentType(id_data)),
            ("eContent", OctetString(data)),
        ))
        signed_attrs = self._signed_attrs(sha512(bytes(eci["eContent"])).digest())
        signature = self.sign(signed_attrs, key_path)
        ci = self._content_info(eci, cert, skid, signed_attrs, signature)
        cms_path = self.tmpfile()

        # Two-pass encoding: encode1st() yields the state encode2nd() consumes.
        _, state = ci.encode1st()
        with io_open(cms_path, "wb") as fd:
            ci.encode2nd(fd.write, iter(state))
        self.verify(cert_path, cms_path)

        # Overwrite the same file with the streaming CER encoding.
        with io_open(cms_path, "wb") as fd:
            ci.encode_cer(fd.write)
        self.verify(cert_path, cms_path)

        # The file handle used to be leaked here; close it deterministically.
        with open(cms_path, "rb") as fd:
            # PY2 lacks mmap-backed memoryview, so read the whole file there.
            raw = memoryview(fd.read()) if PY2 else file_mmaped(fd)
            ctx = {"bered": True}
            content = None
            for decode_path, obj, _ in ContentInfo().decode_evgen(raw, ctx=ctx):
                if decode_path == ("content",):
                    content = obj
                    break
            self.assertIsNotNone(content, "ContentInfo.content not found")
            evgens = SignedData().decode_evgen(
                raw[content.offset:], offset=content.offset, ctx=ctx,
            )
            buf = BytesIO()
            agg_octet_string(evgens, ("encapContentInfo", "eContent"), raw, buf.write)
        self.assertSequenceEqual(buf.getvalue(), data)

    def create_huge_file(self):
        """Write PYDERASN_TEST_CMS_HUGE MiB of random data and mmap it.

        :returns: memory-mapped view of the written file; the underlying
            file handle is closed by a registered cleanup.
        """
        rnd = urandom(1 << 20)
        data_path = self.tmpfile()
        start = time()
        with open(data_path, "wb") as fd:
            for _ in six_xrange(int(environ.get("PYDERASN_TEST_CMS_HUGE"))):
                fd.write(rnd)
        print("data file written", time() - start)
        fd = open(data_path, "rb")
        self.addCleanup(fd.close)
        return file_mmaped(fd)

    @skipIf(PY2, "no mmaped memoryview support in PY2")
    @skipIf("PYDERASN_TEST_CMS_HUGE" not in environ, "PYDERASN_TEST_CMS_HUGE is not set")
    def test_huge_cer(self):
        """Huge CMS test

        The environment variable PYDERASN_TEST_CMS_HUGE sets how many MiBs
        of data to sign. Note that openssl cms cannot verify in a streaming
        fashion and uses a huge amount of memory (several times the size of
        the CMS itself).
        """
        data_raw = self.create_huge_file()
        key_path, cert_path, cert, skid = self.keypair()
        from sys import getallocatedblocks  # PY2 does not have it
        mem_start = getallocatedblocks()
        start = time()
        eci = EncapsulatedContentInfo((
            ("eContentType", ContentType(id_data)),
            ("eContent", OctetString(data_raw)),
        ))
        eci_path = self.tmpfile()
        with open(eci_path, "wb") as fd:
            OctetString(eci["eContent"]).encode_cer(fd.write)
        print("ECI file written", time() - start)
        eci_fd = open(eci_path, "rb")
        self.addCleanup(eci_fd.close)
        eci_raw = file_mmaped(eci_fd)

        # Hash the CER-encoded eContent chunk by chunk via the event decoder.
        start = time()
        digest_ctx = sha512()

        def hasher(data):
            digest_ctx.update(data)
            return len(data)

        evgens = OctetString().decode_evgen(eci_raw, ctx={"bered": True})
        agg_octet_string(evgens, (), eci_raw, hasher)
        digest = digest_ctx.digest()
        print("digest calculated", time() - start)

        signed_attrs = self._signed_attrs(digest)
        signature = self.sign(signed_attrs, key_path)

        self.assertLess(getallocatedblocks(), mem_start * 2)
        start = time()
        ci = self._content_info(eci, cert, skid, signed_attrs, signature)
        cms_path = self.tmpfile()
        with io_open(cms_path, "wb") as fd:
            ci.encode_cer(fd.write)
        print("CMS written", time() - start)
        self.verify(cert_path, cms_path)

    @skipIf(PY2, "no mmaped memoryview support in PY2")
    @skipIf("PYDERASN_TEST_CMS_HUGE" not in environ, "PYDERASN_TEST_CMS_HUGE is not set")
    def test_huge_der_2pass(self):
        """Same as test_huge_cer, but uses the 2-pass DER encoder.

        The digest is computed directly over the mapped data.
        """
        data_raw = self.create_huge_file()
        key_path, cert_path, cert, skid = self.keypair()
        from sys import getallocatedblocks  # PY2 does not have it
        mem_start = getallocatedblocks()
        digest = sha512(data_raw).digest()
        eci = EncapsulatedContentInfo((
            ("eContentType", ContentType(id_data)),
            ("eContent", OctetString(data_raw)),
        ))
        signed_attrs = self._signed_attrs(digest)
        signature = self.sign(signed_attrs, key_path)
        self.assertLess(getallocatedblocks(), mem_start * 2)
        start = time()
        ci = self._content_info(eci, cert, skid, signed_attrs, signature)
        _, state = ci.encode1st()
        print("2pass state size", getsizeof(state))
        cms_path = self.tmpfile()
        with io_open(cms_path, "wb") as fd:
            ci.encode2nd(fd.write, iter(state))
        print("CMS written", time() - start)
        self.assertLess(getallocatedblocks(), mem_start * 2)
        self.verify(cert_path, cms_path)