]> Cypherpunks.ru repositories - pyderasn.git/blob - tests/test_cms.py
Fix example SignedAttributes bounds
[pyderasn.git] / tests / test_cms.py
1 # coding: utf-8
2 # PyDERASN -- Python ASN.1 DER/CER/BER codec with abstract structures
3 # Copyright (C) 2017-2020 Sergey Matveev <stargrave@stargrave.org>
4 #
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Lesser General Public License as
7 # published by the Free Software Foundation, version 3 of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU Lesser General Public License for more details.
13 #
14 # You should have received a copy of the GNU Lesser General Public
15 # License along with this program.  If not, see
16 # <http://www.gnu.org/licenses/>.
17
18 from hashlib import sha512
19 from io import BytesIO
20 from io import open as io_open
21 from os import environ
22 from os import remove
23 from os import urandom
24 from subprocess import call
25 from sys import getsizeof
26 from tempfile import NamedTemporaryFile
27 from time import time
28 from unittest import skipIf
29 from unittest import TestCase
30
31 from hypothesis import given
32 from hypothesis import settings
33 from hypothesis.strategies import integers
34 from six import PY2
35 from six.moves import xrange as six_xrange
36
37 from pyderasn import agg_octet_string
38 from pyderasn import Any
39 from pyderasn import Choice
40 from pyderasn import encode_cer
41 from pyderasn import file_mmaped
42 from pyderasn import Integer
43 from pyderasn import ObjectIdentifier
44 from pyderasn import OctetString
45 from pyderasn import Sequence
46 from pyderasn import SetOf
47 from pyderasn import tag_ctxc
48 from pyderasn import tag_ctxp
49 from tests.test_crts import AlgorithmIdentifier
50 from tests.test_crts import Certificate
51 from tests.test_crts import SubjectKeyIdentifier
52
53
class CMSVersion(Integer):
    """CMSVersion ::= INTEGER { v0(0), v1(1), v2(2), v3(3), v4(4), v5(5) }

    The named values map "vN" onto the integer N for N in 0..5.
    """
    schema = tuple(("v%d" % version, version) for version in range(6))
63
64
class AttributeValue(Any):
    """AttributeValue ::= ANY -- kept opaque, its contents depend on attrType"""
67
68
class AttributeValues(SetOf):
    """SET OF AttributeValue -- the attrValues component of Attribute"""
    schema = AttributeValue()
71
72
class Attribute(Sequence):
    """Attribute ::= SEQUENCE { attrType OBJECT IDENTIFIER, attrValues SET OF ... }"""
    schema = (
        ("attrType", ObjectIdentifier()),  # identifies the attribute kind
        ("attrValues", AttributeValues()),  # one or more values of that kind
    )
78
79
class SignatureAlgorithmIdentifier(AlgorithmIdentifier):
    """SignatureAlgorithmIdentifier ::= AlgorithmIdentifier"""
82
83
class SignedAttributes(SetOf):
    """SignedAttributes ::= SET SIZE (1..MAX) OF Attribute"""
    schema = Attribute()
    # At least one attribute, no upper limit.
    bounds = (1, float("+inf"))
    # The signature is computed over the DER encoding of SignedAttrs
    # (RFC 5652, section 5.4), so keep this SET OF DER-encoded even when
    # the enclosing structure is CER-encoded.
    der_forced = True
88
89
class SignerIdentifier(Choice):
    """SignerIdentifier ::= CHOICE { ..., subjectKeyIdentifier [0] ... }

    Only the subjectKeyIdentifier alternative is modelled here; the
    issuerAndSerialNumber alternative is not needed by these tests.
    """
    schema = (
        ("subjectKeyIdentifier", SubjectKeyIdentifier(impl=tag_ctxp(0))),
    )
95
96
class DigestAlgorithmIdentifiers(SetOf):
    """DigestAlgorithmIdentifiers ::= SET OF DigestAlgorithmIdentifier"""
    schema = AlgorithmIdentifier()
99
100
class DigestAlgorithmIdentifier(AlgorithmIdentifier):
    """DigestAlgorithmIdentifier ::= AlgorithmIdentifier"""
103
104
class SignatureValue(OctetString):
    """SignatureValue ::= OCTET STRING"""
107
108
class SignerInfo(Sequence):
    """Per-signer information of SignedData (RFC 5652, section 5.3).

    The optional unsignedAttrs [1] component is not modelled.
    """
    schema = (
        ("version", CMSVersion()),
        ("sid", SignerIdentifier()),
        ("digestAlgorithm", DigestAlgorithmIdentifier()),
        # [0] IMPLICIT SignedAttributes OPTIONAL
        ("signedAttrs", SignedAttributes(impl=tag_ctxc(0), optional=True)),
        ("signatureAlgorithm", SignatureAlgorithmIdentifier()),
        ("signature", SignatureValue()),
    )
119
120
class SignerInfos(SetOf):
    """SignerInfos ::= SET OF SignerInfo"""
    schema = SignerInfo()
123
124
class ContentType(ObjectIdentifier):
    """ContentType ::= OBJECT IDENTIFIER"""
127
128
class EncapsulatedContentInfo(Sequence):
    """Signed content together with its type (RFC 5652, section 5.2)."""
    schema = (
        ("eContentType", ContentType()),
        # [0] EXPLICIT OCTET STRING OPTIONAL
        ("eContent", OctetString(expl=tag_ctxc(0), optional=True)),
    )
134
135
class CertificateChoices(Choice):
    """CertificateChoices ::= CHOICE { certificate Certificate, ... }

    Only the plain X.509 certificate alternative is modelled.
    """
    schema = (
        ("certificate", Certificate()),
    )
141
142
class CertificateSet(SetOf):
    """CertificateSet ::= SET OF CertificateChoices"""
    schema = CertificateChoices()
145
146
class SignedData(Sequence):
    """SignedData content type (RFC 5652, section 5.1).

    The optional crls [1] component is not modelled.
    """
    schema = (
        ("version", CMSVersion()),
        ("digestAlgorithms", DigestAlgorithmIdentifiers()),
        ("encapContentInfo", EncapsulatedContentInfo()),
        # [0] IMPLICIT CertificateSet OPTIONAL
        ("certificates", CertificateSet(impl=tag_ctxc(0), optional=True)),
        ("signerInfos", SignerInfos()),
    )
156
157
class ContentInfo(Sequence):
    """Top-level CMS wrapper: a content type plus [0] EXPLICIT content."""
    schema = (
        ("contentType", ContentType()),
        # Kept as Any: its actual type is determined by contentType
        ("content", Any(expl=tag_ctxc(0))),
    )
163
164
# PKCS #7 / CMS content types
id_data = ObjectIdentifier("1.2.840.113549.1.7.1")
id_signedData = ObjectIdentifier("1.2.840.113549.1.7.2")

# Algorithms
id_sha512 = ObjectIdentifier("2.16.840.1.101.3.4.2.3")
id_ecdsa_with_SHA512 = ObjectIdentifier("1.2.840.10045.4.3.4")
ai_sha512 = AlgorithmIdentifier((("algorithm", id_sha512),))

# PKCS #9 attributes
id_pkcs9_at_contentType = ObjectIdentifier("1.2.840.113549.1.9.3")
id_pkcs9_at_messageDigest = ObjectIdentifier("1.2.840.113549.1.9.4")

# X.509 certificate extensions
id_ce_subjectKeyIdentifier = ObjectIdentifier("2.5.29.14")

# OpenSSL-backed tests are skipped when "openssl cms" can not be run
openssl_cms_exists = (call("openssl cms -help 2>/dev/null", shell=True) == 0)
175
@skipIf(not openssl_cms_exists, "openssl cms command not found")
class TestSignedDataCERWithOpenSSL(TestCase):
    """Build CMS SignedData with pyderasn and verify it with OpenSSL.

    Every test generates a fresh secp521r1 key and self-signed certificate
    with the ``openssl`` command line utility. It signs the encoded
    SignedAttributes with ``openssl dgst`` and wraps everything into a
    ContentInfo. The result is then checked with ``openssl cms -verify``.
    """

    def tmpfile(self):
        """Create an empty temporary file and return its path.

        The file is removed during test cleanup.
        """
        tmp = NamedTemporaryFile(delete=False)
        tmp.close()
        self.addCleanup(lambda: remove(tmp.name))
        return tmp.name

    def keypair(self):
        """Generate an EC private key and a self-signed certificate for it.

        :returns: (key_path, cert_path, cert, skid). key_path is the PEM
            private key. cert_path is a PEM file holding the certificate
            followed by the private key. cert is the decoded Certificate.
            skid is the certificate's decoded SubjectKeyIdentifier.
        :raises AssertionError: if any openssl invocation fails or the
            certificate lacks a subjectKeyIdentifier extension
        """
        key_path = self.tmpfile()
        self.assertEqual(0, call(
            "openssl ecparam -name secp521r1 -genkey -out " + key_path,
            shell=True,
        ))
        cert_path = self.tmpfile()
        self.assertEqual(0, call(" ".join((
            "openssl req -x509 -new",
            ("-key " + key_path),
            ("-outform PEM -out " + cert_path),
            "-nodes -subj /CN=pyderasntest",
        )), shell=True))
        cert_der_path = self.tmpfile()
        self.assertEqual(0, call(" ".join((
            "openssl x509",
            "-inform PEM -in " + cert_path,
            "-outform DER -out " + cert_der_path,
        )), shell=True))
        self.assertEqual(0, call("cat %s >> %s" % (key_path, cert_path), shell=True))
        with open(cert_der_path, "rb") as fd:
            cert = Certificate().decod(fd.read())
        # Without this guard a missing extension would surface as an
        # UnboundLocalError on "skid" instead of a readable test failure.
        skid = None
        extensions = cert["tbsCertificate"]["extensions"]
        for ext in (() if extensions is None else extensions):
            if ext["extnID"] == id_ce_subjectKeyIdentifier:
                skid = SubjectKeyIdentifier().decod(bytes(ext["extnValue"]))
                break
        self.assertIsNotNone(
            skid, "certificate has no subjectKeyIdentifier extension",
        )
        return key_path, cert_path, cert, skid

    def sign(self, signed_attrs, key_path):
        """Sign the encoded signed_attrs with SHA512 and the key at key_path.

        :returns: raw signature bytes as produced by ``openssl dgst -binary``
        """
        input_path = self.tmpfile()
        with open(input_path, "wb") as fd:
            fd.write(encode_cer(signed_attrs))
        signature_path = self.tmpfile()
        self.assertEqual(0, call(" ".join((
            "openssl dgst -sha512",
            ("-sign " + key_path),
            "-binary", input_path,
            ("> " + signature_path),
        )), shell=True))
        with open(signature_path, "rb") as fd:
            signature = fd.read()
        return signature

    def verify(self, cert_path, cms_path):
        """Assert that ``openssl cms -verify`` accepts the DER file cms_path.

        cert_path is used both as the signer certificate and as the CA file.
        """
        self.assertEqual(0, call(" ".join((
            "openssl cms -verify",
            ("-inform DER -in " + cms_path),
            "-signer %s -CAfile %s" % (cert_path, cert_path),
            "-out /dev/null 2>/dev/null",
        )), shell=True))

    @staticmethod
    def _signed_attrs(digest):
        """Build SignedAttributes with contentType=id-data and messageDigest=digest."""
        return SignedAttributes([
            Attribute((
                ("attrType", id_pkcs9_at_contentType),
                ("attrValues", AttributeValues([AttributeValue(id_data)])),
            )),
            Attribute((
                ("attrType", id_pkcs9_at_messageDigest),
                ("attrValues", AttributeValues([
                    AttributeValue(OctetString(digest)),
                ])),
            )),
        ])

    @staticmethod
    def _content_info(eci, cert, skid, signed_attrs, signature):
        """Wrap a v3 SHA512/ECDSA SignedData with a single signer into ContentInfo."""
        return ContentInfo((
            ("contentType", ContentType(id_signedData)),
            ("content", Any((SignedData((
                ("version", CMSVersion("v3")),
                ("digestAlgorithms", DigestAlgorithmIdentifiers([ai_sha512])),
                ("encapContentInfo", eci),
                ("certificates", CertificateSet([
                    CertificateChoices(("certificate", cert)),
                ])),
                ("signerInfos", SignerInfos([SignerInfo((
                    ("version", CMSVersion("v3")),
                    ("sid", SignerIdentifier(("subjectKeyIdentifier", skid))),
                    ("digestAlgorithm", DigestAlgorithmIdentifier(ai_sha512)),
                    ("signedAttrs", signed_attrs),
                    ("signatureAlgorithm", SignatureAlgorithmIdentifier((
                        ("algorithm", id_ecdsa_with_SHA512),
                    ))),
                    ("signature", SignatureValue(signature)),
                ))])),
            ))))),
        ))

    @settings(deadline=None)
    @given(integers(min_value=1000, max_value=5000))
    def test_simple(self, data_len):
        """Sign random data, encode it with the 2-pass DER and the CER encoders,
        verify both with OpenSSL, then stream-decode eContent back.
        """
        key_path, cert_path, cert, skid = self.keypair()
        data = urandom(data_len)
        eci = EncapsulatedContentInfo((
            ("eContentType", ContentType(id_data)),
            ("eContent", OctetString(data)),
        ))
        signed_attrs = self._signed_attrs(sha512(bytes(eci["eContent"])).digest())
        signature = self.sign(signed_attrs, key_path)
        ci = self._content_info(eci, cert, skid, signed_attrs, signature)
        cms_path = self.tmpfile()

        # Two-pass DER encoder
        _, state = ci.encode1st()
        with io_open(cms_path, "wb") as fd:
            ci.encode2nd(fd.write, iter(state))
        self.verify(cert_path, cms_path)

        # CER encoder
        with io_open(cms_path, "wb") as fd:
            ci.encode_cer(fd.write)
        self.verify(cert_path, cms_path)

        # Event-generating BER decoder: locate the content and aggregate
        # eContent chunks back into a single buffer.
        # "with" guarantees the file is closed even when an assertion fails.
        with open(cms_path, "rb") as fd:
            raw = memoryview(fd.read()) if PY2 else file_mmaped(fd)
            ctx = {"bered": True}
            for decode_path, obj, _ in ContentInfo().decode_evgen(raw, ctx=ctx):
                if decode_path == ("content",):
                    break
            else:
                self.fail("content not found in decoded ContentInfo")
            evgens = SignedData().decode_evgen(
                raw[obj.offset:], offset=obj.offset, ctx=ctx,
            )
            buf = BytesIO()
            agg_octet_string(evgens, ("encapContentInfo", "eContent"), raw, buf.write)
            self.assertSequenceEqual(buf.getvalue(), data)

    def create_huge_file(self):
        """Write PYDERASN_TEST_CMS_HUGE MiBs of random data to a temporary file.

        :returns: file_mmaped() view of the written file. The underlying file
            object is closed during test cleanup, before the file is removed.
        """
        rnd = urandom(1 << 20)
        data_path = self.tmpfile()
        start = time()
        with open(data_path, "wb") as fd:
            for _ in six_xrange(int(environ["PYDERASN_TEST_CMS_HUGE"])):
                fd.write(rnd)
        print("data file written", time() - start)
        data_fd = open(data_path, "rb")
        # Cleanups run LIFO: this one runs before the tmpfile removal.
        self.addCleanup(data_fd.close)
        return file_mmaped(data_fd)

    @skipIf(PY2, "no mmaped memoryview support in PY2")
    @skipIf("PYDERASN_TEST_CMS_HUGE" not in environ, "PYDERASN_TEST_CMS_HUGE is not set")
    def test_huge_cer(self):
        """Huge CMS test

        Environment variable PYDERASN_TEST_CMS_HUGE tells how many MiBs
        of data to sign. Pay attention that openssl cms is unable to do
        stream verification and eats huge amounts (several times more
        than the CMS itself) of memory.
        """
        data_raw = self.create_huge_file()
        key_path, cert_path, cert, skid = self.keypair()
        from sys import getallocatedblocks  # PY2 does not have it
        mem_start = getallocatedblocks()
        start = time()
        eci = EncapsulatedContentInfo((
            ("eContentType", ContentType(id_data)),
            ("eContent", OctetString(data_raw)),
        ))
        eci_path = self.tmpfile()
        with open(eci_path, "wb") as fd:
            OctetString(eci["eContent"]).encode_cer(fd.write)
        print("ECI file written", time() - start)
        eci_fd = open(eci_path, "rb")
        self.addCleanup(eci_fd.close)
        eci_raw = file_mmaped(eci_fd)

        # Hash eContent chunk by chunk while decoding its CER encoding
        start = time()
        hash_ctx = sha512()

        def hasher(data):
            hash_ctx.update(data)
            return len(data)

        evgens = OctetString().decode_evgen(eci_raw, ctx={"bered": True})
        agg_octet_string(evgens, (), eci_raw, hasher)
        dgst = hash_ctx.digest()
        print("digest calculated", time() - start)

        signed_attrs = self._signed_attrs(dgst)
        signature = self.sign(signed_attrs, key_path)

        self.assertLess(getallocatedblocks(), mem_start * 2)
        start = time()
        ci = self._content_info(eci, cert, skid, signed_attrs, signature)
        cms_path = self.tmpfile()
        with io_open(cms_path, "wb") as fd:
            ci.encode_cer(fd.write)
        print("CMS written", time() - start)
        self.verify(cert_path, cms_path)

    @skipIf(PY2, "no mmaped memoryview support in PY2")
    @skipIf("PYDERASN_TEST_CMS_HUGE" not in environ, "PYDERASN_TEST_CMS_HUGE is not set")
    def test_huge_der_2pass(self):
        """Same as test_huge_cer, but using the 2-pass DER encoder.

        The digest is calculated directly over the data and only the
        signature is verified.
        """
        data_raw = self.create_huge_file()
        key_path, cert_path, cert, skid = self.keypair()
        from sys import getallocatedblocks  # PY2 does not have it
        mem_start = getallocatedblocks()
        dgst = sha512(data_raw).digest()
        eci = EncapsulatedContentInfo((
            ("eContentType", ContentType(id_data)),
            ("eContent", OctetString(data_raw)),
        ))
        signed_attrs = self._signed_attrs(dgst)
        signature = self.sign(signed_attrs, key_path)
        self.assertLess(getallocatedblocks(), mem_start * 2)
        start = time()
        ci = self._content_info(eci, cert, skid, signed_attrs, signature)
        _, state = ci.encode1st()
        print("2pass state size", getsizeof(state))
        cms_path = self.tmpfile()
        with io_open(cms_path, "wb") as fd:
            ci.encode2nd(fd.write, iter(state))
        print("CMS written", time() - start)
        self.assertLess(getallocatedblocks(), mem_start * 2)
        self.verify(cert_path, cms_path)