]> Cypherpunks.ru repositories - pyderasn.git/blob - tests/test_cms.py
Raise copyright years
[pyderasn.git] / tests / test_cms.py
1 # coding: utf-8
2 # PyDERASN -- Python ASN.1 DER/CER/BER codec with abstract structures
3 # Copyright (C) 2017-2022 Sergey Matveev <stargrave@stargrave.org>
4 #
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Lesser General Public License as
7 # published by the Free Software Foundation, version 3 of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU Lesser General Public License for more details.
13 #
14 # You should have received a copy of the GNU Lesser General Public
15 # License along with this program.  If not, see
16 # <http://www.gnu.org/licenses/>.
17
18 from hashlib import sha512
19 from io import BytesIO
20 from io import open as io_open
21 from os import environ
22 from os import remove
23 from os import urandom
24 from subprocess import call
25 from sys import getsizeof
26 from tempfile import NamedTemporaryFile
27 from time import time
28 from unittest import skipIf
29 from unittest import TestCase
30
31 from hypothesis import given
32 from hypothesis import settings
33 from hypothesis.strategies import integers
34
35 from pyderasn import agg_octet_string
36 from pyderasn import Any
37 from pyderasn import Choice
38 from pyderasn import encode_cer
39 from pyderasn import file_mmaped
40 from pyderasn import Integer
41 from pyderasn import ObjectIdentifier
42 from pyderasn import OctetString
43 from pyderasn import Sequence
44 from pyderasn import SetOf
45 from pyderasn import tag_ctxc
46 from pyderasn import tag_ctxp
47 from tests.test_crts import AlgorithmIdentifier
48 from tests.test_crts import Certificate
49 from tests.test_crts import SubjectKeyIdentifier
50
51
class CMSVersion(Integer):
    """CMS syntax version: INTEGER with named values "v0" through "v5"."""
    schema = tuple(("v%d" % number, number) for number in range(6))
61
62
class AttributeValue(Any):
    """A single attribute value of arbitrary ASN.1 type (kept as ``Any``)."""
65
66
class AttributeValues(SetOf):
    """SET OF AttributeValue: all values carried by one Attribute."""
    schema = AttributeValue()
69
70
class Attribute(Sequence):
    """Typed attribute: an OID naming the attribute, followed by its values."""
    # Field order is the encoding order; do not reorder.
    schema = (
        ("attrType", ObjectIdentifier()),
        ("attrValues", AttributeValues()),
    )
76
77
class SignatureAlgorithmIdentifier(AlgorithmIdentifier):
    """AlgorithmIdentifier naming the signature algorithm of a SignerInfo."""
80
81
class SignedAttributes(SetOf):
    """SET OF Attribute whose encoding is what the signer actually signs."""
    schema = Attribute()
    # At least one attribute is required; no upper limit.
    bounds = (1, float("+inf"))
    # NOTE(review): presumably makes pyderasn emit DER for this SET OF even
    # when the enclosing structure is CER-encoded, so the signed bytes stay
    # canonical — confirm against the pyderasn documentation.
    der_forced = True
86
87
class SignerIdentifier(Choice):
    """Identifies the signer, here by its certificate's SubjectKeyIdentifier.

    Only the implicitly tagged [0] subjectKeyIdentifier alternative is
    declared; issuerAndSerialNumber is left commented out.
    """
    schema = (
        # ("issuerAndSerialNumber", IssuerAndSerialNumber()),
        ("subjectKeyIdentifier", SubjectKeyIdentifier(impl=tag_ctxp(0))),
    )
93
94
class DigestAlgorithmIdentifiers(SetOf):
    """SET OF AlgorithmIdentifier listing the digest algorithms in use."""
    schema = AlgorithmIdentifier()
97
98
class DigestAlgorithmIdentifier(AlgorithmIdentifier):
    """AlgorithmIdentifier naming the digest algorithm of a SignerInfo."""
101
102
class SignatureValue(OctetString):
    """OCTET STRING holding the raw signature bytes of a SignerInfo."""
105
106
class SignerInfo(Sequence):
    """Per-signer information: identity, algorithms, signed attributes, signature.

    Only the fields these tests need are declared; unsignedAttrs is left
    commented out.
    """
    # Field order is the encoding order; do not reorder.
    schema = (
        ("version", CMSVersion()),
        ("sid", SignerIdentifier()),
        ("digestAlgorithm", DigestAlgorithmIdentifier()),
        # Implicitly tagged [0], constructed; may be absent.
        ("signedAttrs", SignedAttributes(impl=tag_ctxc(0), optional=True)),
        ("signatureAlgorithm", SignatureAlgorithmIdentifier()),
        ("signature", SignatureValue()),
        # ("unsignedAttrs", UnsignedAttributes(impl=tag_ctxc(1), optional=True)),
    )
117
118
class SignerInfos(SetOf):
    """SET OF SignerInfo, one entry per signer."""
    schema = SignerInfo()
121
122
class ContentType(ObjectIdentifier):
    """OBJECT IDENTIFIER naming a content type (e.g. id-data, id-signedData)."""
125
126
class EncapsulatedContentInfo(Sequence):
    """The signed content: its type OID and, optionally, the content bytes."""
    schema = (
        ("eContentType", ContentType()),
        # OCTET STRING wrapped in an explicit [0] tag; may be omitted.
        ("eContent", OctetString(expl=tag_ctxc(0), optional=True)),
    )
132
133
class CertificateChoices(Choice):
    """CHOICE of certificate formats; only a plain Certificate is declared."""
    schema = (
        ('certificate', Certificate()),
        # ...
    )
139
140
class CertificateSet(SetOf):
    """SET OF CertificateChoices, used for SignedData.certificates."""
    schema = CertificateChoices()
143
144
class SignedData(Sequence):
    """SignedData content: digest algorithms, content, certificates, signers.

    The crls field is not declared (left commented out).
    """
    # Field order is the encoding order; do not reorder.
    schema = (
        ("version", CMSVersion()),
        ("digestAlgorithms", DigestAlgorithmIdentifiers()),
        ("encapContentInfo", EncapsulatedContentInfo()),
        # Implicitly tagged [0], constructed; may be absent.
        ("certificates", CertificateSet(impl=tag_ctxc(0), optional=True)),
        # ("crls", RevocationInfoChoices(impl=tag_ctxc(1), optional=True)),
        ("signerInfos", SignerInfos()),
    )
154
155
class ContentInfo(Sequence):
    """Outermost wrapper: a content type OID plus the content itself.

    ``content`` is declared as ``Any`` inside an explicit [0] tag, so any
    already-built structure (the tests put a SignedData there) can be
    embedded, and decoding leaves it undecoded.
    """
    schema = (
        ("contentType", ContentType()),
        ("content", Any(expl=tag_ctxc(0))),
    )
161
162
# Object identifiers used to build and inspect the test SignedData.
id_signedData = ObjectIdentifier("1.2.840.113549.1.7.2")
id_sha512 = ObjectIdentifier("2.16.840.1.101.3.4.2.3")
id_data = ObjectIdentifier("1.2.840.113549.1.7.1")
id_ecdsa_with_SHA512 = ObjectIdentifier("1.2.840.10045.4.3.4")
id_pkcs9_at_contentType = ObjectIdentifier("1.2.840.113549.1.9.3")
id_pkcs9_at_messageDigest = ObjectIdentifier("1.2.840.113549.1.9.4")
id_ce_subjectKeyIdentifier = ObjectIdentifier("2.5.29.14")
# SHA-512 AlgorithmIdentifier with only the algorithm field set.
ai_sha512 = AlgorithmIdentifier((("algorithm", id_sha512),))

# Probe whether the local openssl has a working "cms" subcommand; the
# OpenSSL-based test class is skipped when it does not. Note that this runs
# a shell subprocess every time the module is imported.
openssl_cms_exists = call("openssl cms -help 2>/dev/null", shell=True) == 0
173
@skipIf(not openssl_cms_exists, "openssl cms command not found")
class TestSignedDataCERWithOpenSSL(TestCase):
    """Check pyderasn-encoded CMS SignedData against the openssl CLI.

    Each test creates a fresh secp521r1 key and self-signed certificate with
    openssl, signs the signed attributes with ``openssl dgst``, assembles a
    ContentInfo/SignedData with pyderasn, writes it to a temporary file and
    requires ``openssl cms -verify`` to accept it.
    """

    def tmpfile(self):
        """Create an empty temporary file and return its path.

        The file is removed by a cleanup registered on this test.
        """
        tmp = NamedTemporaryFile(delete=False)
        tmp.close()
        self.addCleanup(lambda: remove(tmp.name))
        return tmp.name

    def keypair(self):
        """Generate a secp521r1 key and a self-signed certificate for it.

        :returns: ``(key_path, cert_path, cert, skid)``: path of the PEM
            private key, path of a PEM file holding the certificate followed
            by the key, the decoded :class:`Certificate`, and its decoded
            SubjectKeyIdentifier extension value.
        :raises AssertionError: if an openssl command fails or the
            certificate carries no SubjectKeyIdentifier extension.
        """
        key_path = self.tmpfile()
        self.assertEqual(0, call(
            "openssl ecparam -name secp521r1 -genkey -out " + key_path,
            shell=True,
        ))
        cert_path = self.tmpfile()
        self.assertEqual(0, call(" ".join((
            "openssl req -x509 -new",
            ("-key " + key_path),
            ("-outform PEM -out " + cert_path),
            "-nodes -subj /CN=pyderasntest",
        )), shell=True))
        cert_der_path = self.tmpfile()
        self.assertEqual(0, call(" ".join((
            "openssl x509",
            "-inform PEM -in " + cert_path,
            "-outform DER -out " + cert_der_path,
        )), shell=True))
        # Append the private key to the PEM certificate file.
        self.assertEqual(0, call("cat %s >> %s" % (key_path, cert_path), shell=True))
        with open(cert_der_path, "rb") as fd:
            cert = Certificate().decod(fd.read())
        # The SignerIdentifier refers to the signer through this extension;
        # without it the test cannot proceed, so fail with a clear message
        # instead of an UnboundLocalError on "skid".
        skid = None
        extensions = cert["tbsCertificate"]["extensions"]
        for ext in (() if extensions is None else extensions):
            if ext["extnID"] == id_ce_subjectKeyIdentifier:
                skid = SubjectKeyIdentifier().decod(bytes(ext["extnValue"]))
                break
        self.assertIsNotNone(
            skid, "certificate has no SubjectKeyIdentifier extension",
        )
        return key_path, cert_path, cert, skid

    def sign(self, signed_attrs, key_path):
        """Sign the CER encoding of ``signed_attrs`` using SHA-512.

        :param signed_attrs: :class:`SignedAttributes` to sign
        :param key_path: PEM private key file
        :returns: raw signature bytes as written by ``openssl dgst -sign``
        """
        input_path = self.tmpfile()
        with open(input_path, "wb") as fd:
            fd.write(encode_cer(signed_attrs))
        signature_path = self.tmpfile()
        self.assertEqual(0, call(" ".join((
            "openssl dgst -sha512",
            ("-sign " + key_path),
            "-binary", input_path,
            ("> " + signature_path),
        )), shell=True))
        with open(signature_path, "rb") as fd:
            signature = fd.read()
        return signature

    def verify(self, cert_path, cms_path):
        """Assert that ``openssl cms -verify`` accepts the file at ``cms_path``.

        ``cert_path`` is passed both as ``-signer`` and ``-CAfile``; openssl's
        verified output and stderr are discarded.
        """
        # NOTE(review): per the openssl-cms manual, in -verify mode "-signer"
        # names a file that receives the signer certificates on success, so
        # cert_path gets rewritten here — confirm this is intended.
        self.assertEqual(0, call(" ".join((
            "openssl cms -verify",
            ("-inform DER -in " + cms_path),
            "-signer %s -CAfile %s" % (cert_path, cert_path),
            "-out /dev/null 2>/dev/null",
        )), shell=True))

    @staticmethod
    def _signed_attrs(digest):
        """Build SignedAttributes: contentType=id-data and the messageDigest.

        :param digest: SHA-512 digest bytes of the encapsulated content
        """
        return SignedAttributes([
            Attribute((
                ("attrType", id_pkcs9_at_contentType),
                ("attrValues", AttributeValues([AttributeValue(id_data)])),
            )),
            Attribute((
                ("attrType", id_pkcs9_at_messageDigest),
                ("attrValues", AttributeValues([AttributeValue(OctetString(digest))])),
            )),
        ])

    @staticmethod
    def _content_info(eci, cert, skid, signed_attrs, signature):
        """Wrap everything into a ContentInfo carrying a v3 SignedData.

        The single SignerInfo identifies the signer by ``skid`` and declares
        SHA-512 digest with ECDSA-with-SHA512 signature.
        """
        return ContentInfo((
            ("contentType", ContentType(id_signedData)),
            ("content", Any((SignedData((
                ("version", CMSVersion("v3")),
                ("digestAlgorithms", DigestAlgorithmIdentifiers([ai_sha512])),
                ("encapContentInfo", eci),
                ("certificates", CertificateSet([
                    CertificateChoices(("certificate", cert)),
                ])),
                ("signerInfos", SignerInfos([SignerInfo((
                    ("version", CMSVersion("v3")),
                    ("sid", SignerIdentifier(("subjectKeyIdentifier", skid))),
                    ("digestAlgorithm", DigestAlgorithmIdentifier(ai_sha512)),
                    ("signedAttrs", signed_attrs),
                    ("signatureAlgorithm", SignatureAlgorithmIdentifier((
                        ("algorithm", id_ecdsa_with_SHA512),
                    ))),
                    ("signature", SignatureValue(signature)),
                ))])),
            ))))),
        ))

    @settings(deadline=None)
    @given(integers(min_value=1000, max_value=5000))
    def test_simple(self, data_len):
        """Round-trip random data through 2-pass DER and CER encodings.

        Both encodings must verify with openssl. The CER file is then
        stream-decoded and the aggregated eContent must equal the input.
        """
        key_path, cert_path, cert, skid = self.keypair()
        data = urandom(data_len)
        eci = EncapsulatedContentInfo((
            ("eContentType", ContentType(id_data)),
            ("eContent", OctetString(data)),
        ))
        signed_attrs = self._signed_attrs(sha512(bytes(eci["eContent"])).digest())
        signature = self.sign(signed_attrs, key_path)
        ci = self._content_info(eci, cert, skid, signed_attrs, signature)
        cms_path = self.tmpfile()
        _, state = ci.encode1st()
        with io_open(cms_path, "wb") as fd:
            ci.encode2nd(fd.write, iter(state))
        self.verify(cert_path, cms_path)
        with io_open(cms_path, "wb") as fd:
            ci.encode_cer(fd.write)
        self.verify(cert_path, cms_path)
        # "with" guarantees the file is closed even if decoding fails.
        with open(cms_path, "rb") as fd:
            raw = file_mmaped(fd)
            ctx = {"bered": True}
            for decode_path, obj, _ in ContentInfo().decode_evgen(raw, ctx=ctx):
                if decode_path == ("content",):
                    break
            else:
                self.fail("ContentInfo has no content field")
            evgens = SignedData().decode_evgen(raw[obj.offset:], offset=obj.offset, ctx=ctx)
            buf = BytesIO()
            agg_octet_string(evgens, ("encapContentInfo", "eContent"), raw, buf.write)
        self.assertSequenceEqual(buf.getvalue(), data)

    def create_huge_file(self):
        """Write PYDERASN_TEST_CMS_HUGE MiB of data and return it mmap'ed.

        The same random 1 MiB chunk is written repeatedly. The file object
        backing the mapping is closed by a test cleanup.
        """
        rnd = urandom(1 << 20)
        data_path = self.tmpfile()
        start = time()
        with open(data_path, "wb") as fd:
            for _ in range(int(environ["PYDERASN_TEST_CMS_HUGE"])):
                fd.write(rnd)
        print("data file written", time() - start)
        fd = open(data_path, "rb")
        # Registered after the file's removal cleanup, so it runs first.
        self.addCleanup(fd.close)
        return file_mmaped(fd)

    @skipIf("PYDERASN_TEST_CMS_HUGE" not in environ, "PYDERASN_TEST_CMS_HUGE is not set")
    def test_huge_cer(self):
        """Huge CMS test with streaming CER encoding.

        Environment variable PYDERASN_TEST_CMS_HUGE tells how many MiB of
        data to sign. Note that openssl cms cannot verify in a streaming
        fashion and consumes a lot of memory (several times more than the
        CMS itself).
        """
        data_raw = self.create_huge_file()
        key_path, cert_path, cert, skid = self.keypair()
        from sys import getallocatedblocks
        mem_start = getallocatedblocks()
        start = time()
        eci = EncapsulatedContentInfo((
            ("eContentType", ContentType(id_data)),
            ("eContent", OctetString(data_raw)),
        ))
        eci_path = self.tmpfile()
        with open(eci_path, "wb") as fd:
            OctetString(eci["eContent"]).encode_cer(fd.write)
        print("ECI file written", time() - start)
        eci_fd = open(eci_path, "rb")
        self.addCleanup(eci_fd.close)
        eci_raw = file_mmaped(eci_fd)

        # Hash the content by streaming over the CER-encoded OCTET STRING
        # chunks instead of materializing it.
        start = time()
        hash_state = sha512()

        def hasher(data):
            hash_state.update(data)
            return len(data)

        evgens = OctetString().decode_evgen(eci_raw, ctx={"bered": True})
        agg_octet_string(evgens, (), eci_raw, hasher)
        digest = hash_state.digest()
        print("digest calculated", time() - start)

        signed_attrs = self._signed_attrs(digest)
        signature = self.sign(signed_attrs, key_path)

        self.assertLess(getallocatedblocks(), mem_start * 2)
        start = time()
        ci = self._content_info(eci, cert, skid, signed_attrs, signature)
        cms_path = self.tmpfile()
        with io_open(cms_path, "wb") as fd:
            ci.encode_cer(fd.write)
        print("CMS written", time() - start)
        self.verify(cert_path, cms_path)

    @skipIf("PYDERASN_TEST_CMS_HUGE" not in environ, "PYDERASN_TEST_CMS_HUGE is not set")
    def test_huge_der_2pass(self):
        """Like test_huge_cer, but with the 2-pass DER encoder.

        The digest is computed directly over the mapped data.
        """
        data_raw = self.create_huge_file()
        key_path, cert_path, cert, skid = self.keypair()
        from sys import getallocatedblocks
        mem_start = getallocatedblocks()
        digest = sha512(data_raw).digest()
        eci = EncapsulatedContentInfo((
            ("eContentType", ContentType(id_data)),
            ("eContent", OctetString(data_raw)),
        ))
        signed_attrs = self._signed_attrs(digest)
        signature = self.sign(signed_attrs, key_path)
        self.assertLess(getallocatedblocks(), mem_start * 2)
        start = time()
        ci = self._content_info(eci, cert, skid, signed_attrs, signature)
        _, state = ci.encode1st()
        print("2pass state size", getsizeof(state))
        cms_path = self.tmpfile()
        with io_open(cms_path, "wb") as fd:
            ci.encode2nd(fd.write, iter(state))
        print("CMS written", time() - start)
        self.assertLess(getallocatedblocks(), mem_start * 2)
        self.verify(cert_path, cms_path)