]> Cypherpunks.ru repositories - pyderasn.git/blob - tests/test_cms.py
Streaming of huge data support
[pyderasn.git] / tests / test_cms.py
1 # coding: utf-8
2 # PyDERASN -- Python ASN.1 DER codec with abstract structures
3 # Copyright (C) 2017-2020 Sergey Matveev <stargrave@stargrave.org>
4 #
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Lesser General Public License as
7 # published by the Free Software Foundation, version 3 of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU Lesser General Public License for more details.
13 #
14 # You should have received a copy of the GNU Lesser General Public
15 # License along with this program.  If not, see
16 # <http://www.gnu.org/licenses/>.
17
18 from hashlib import sha512
19 from io import BytesIO
20 from io import open as io_open
21 from os import environ
22 from os import remove
23 from os import urandom
24 from subprocess import call
25 from tempfile import NamedTemporaryFile
26 from time import time
27 from unittest import skipIf
28 from unittest import TestCase
29
30 from hypothesis import given
31 from hypothesis import settings
32 from hypothesis.strategies import integers
33 from six import PY2
34 from six.moves import xrange as six_xrange
35
36 from pyderasn import agg_octet_string
37 from pyderasn import Any
38 from pyderasn import Choice
39 from pyderasn import encode_cer
40 from pyderasn import file_mmaped
41 from pyderasn import Integer
42 from pyderasn import ObjectIdentifier
43 from pyderasn import OctetString
44 from pyderasn import Sequence
45 from pyderasn import SetOf
46 from pyderasn import tag_ctxc
47 from pyderasn import tag_ctxp
48 from tests.test_crts import AlgorithmIdentifier
49 from tests.test_crts import Certificate
50 from tests.test_crts import SubjectKeyIdentifier
51
52
class CMSVersion(Integer):
    """Integer whose values 0 through 5 carry the names ``v0`` .. ``v5``."""

    # Same six (name, value) pairs as an explicit table, generated
    # instead of spelled out: ("v0", 0), ("v1", 1), ... ("v5", 5).
    schema = tuple(("v%d" % number, number) for number in range(6))
62
63
class AttributeValue(Any):
    """Single attribute value; an ``Any`` with no additional constraints."""
66
67
class AttributeValues(SetOf):
    """SET OF :class:`AttributeValue` elements."""

    schema = AttributeValue()
70
71
class Attribute(Sequence):
    """SEQUENCE pairing an attribute type OID with its set of values."""

    schema = (
        # Object identifier naming the kind of attribute.
        ("attrType", ObjectIdentifier()),
        # Values carried for that attribute type.
        ("attrValues", AttributeValues()),
    )
77
78
class SignatureAlgorithmIdentifier(AlgorithmIdentifier):
    """AlgorithmIdentifier subtype used for the signature algorithm field."""
81
82
class SignedAttributes(SetOf):
    """SET OF :class:`Attribute`, holding between 1 and 32 elements."""

    schema = Attribute()
    # Allowed (minimum, maximum) number of attributes in the set.
    bounds = (1, 32)
    # NOTE(review): presumably keeps this set DER-encoded even when the
    # enclosing structure is written as CER -- confirm against pyderasn docs.
    der_forced = True
87
88
class SignerIdentifier(Choice):
    """CHOICE identifying the signer.

    Only the ``subjectKeyIdentifier`` alternative (implicitly tagged
    with context-primitive tag 0) is modelled here.
    """

    schema = (
        # Alternative intentionally left out of this test schema:
        # ("issuerAndSerialNumber", IssuerAndSerialNumber()),
        ("subjectKeyIdentifier", SubjectKeyIdentifier(impl=tag_ctxp(0))),
    )
94
95
class DigestAlgorithmIdentifiers(SetOf):
    """SET OF ``AlgorithmIdentifier`` elements."""

    schema = AlgorithmIdentifier()
98
99
class DigestAlgorithmIdentifier(AlgorithmIdentifier):
    """AlgorithmIdentifier subtype used for the digest algorithm field."""
102
103
class SignatureValue(OctetString):
    """OCTET STRING subtype holding the raw signature bytes."""
106
107
class SignerInfo(Sequence):
    """SEQUENCE describing one signer and its signature.

    ``signedAttrs`` is optional and implicitly tagged with
    context-constructed tag 0. Unsigned attributes are not modelled.
    """

    schema = (
        ("version", CMSVersion()),
        ("sid", SignerIdentifier()),
        ("digestAlgorithm", DigestAlgorithmIdentifier()),
        ("signedAttrs", SignedAttributes(impl=tag_ctxc(0), optional=True)),
        ("signatureAlgorithm", SignatureAlgorithmIdentifier()),
        ("signature", SignatureValue()),
        # Field intentionally left out of this test schema:
        # ("unsignedAttrs", UnsignedAttributes(impl=tag_ctxc(1), optional=True)),
    )
118
119
class SignerInfos(SetOf):
    """SET OF :class:`SignerInfo` elements."""

    schema = SignerInfo()
122
123
class ContentType(ObjectIdentifier):
    """Object identifier subtype naming the type of some content."""
126
127
class EncapsulatedContentInfo(Sequence):
    """SEQUENCE of a content type OID and the optional content itself.

    ``eContent`` is an OCTET STRING explicitly tagged with
    context-constructed tag 0 and may be absent.
    """

    schema = (
        ("eContentType", ContentType()),
        ("eContent", OctetString(expl=tag_ctxc(0), optional=True)),
    )
133
134
class CertificateChoices(Choice):
    """CHOICE of certificate formats; only ``certificate`` is modelled."""

    schema = (
        ("certificate", Certificate()),
        # Other alternatives are not needed by these tests.
    )
140
141
class CertificateSet(SetOf):
    """SET OF :class:`CertificateChoices` elements."""

    schema = CertificateChoices()
144
145
class SignedData(Sequence):
    """SEQUENCE carrying the signed content, certificates and signer infos.

    ``certificates`` is optional and implicitly tagged with
    context-constructed tag 0. Revocation information is not modelled.
    """

    schema = (
        ("version", CMSVersion()),
        ("digestAlgorithms", DigestAlgorithmIdentifiers()),
        ("encapContentInfo", EncapsulatedContentInfo()),
        ("certificates", CertificateSet(impl=tag_ctxc(0), optional=True)),
        # Field intentionally left out of this test schema:
        # ("crls", RevocationInfoChoices(impl=tag_ctxc(1), optional=True)),
        ("signerInfos", SignerInfos()),
    )
155
156
class ContentInfo(Sequence):
    """Outermost SEQUENCE: a content type OID plus the content.

    ``content`` is an ``Any`` explicitly tagged with context-constructed
    tag 0, so its concrete structure is selected by the caller.
    """

    schema = (
        ("contentType", ContentType()),
        ("content", Any(expl=tag_ctxc(0))),
    )
162
163
# Content type identifiers.
id_data = ObjectIdentifier("1.2.840.113549.1.7.1")
id_signedData = ObjectIdentifier("1.2.840.113549.1.7.2")

# Attribute type identifiers placed into SignedAttributes.
id_pkcs9_at_contentType = ObjectIdentifier("1.2.840.113549.1.9.3")
id_pkcs9_at_messageDigest = ObjectIdentifier("1.2.840.113549.1.9.4")

# Digest and signature algorithm identifiers.
id_sha512 = ObjectIdentifier("2.16.840.1.101.3.4.2.3")
id_ecdsa_with_SHA512 = ObjectIdentifier("1.2.840.10045.4.3.4")

# Certificate extension identifier looked up when extracting the key id.
id_ce_subjectKeyIdentifier = ObjectIdentifier("2.5.29.14")

# AlgorithmIdentifier with only the "algorithm" field set to id_sha512.
ai_sha512 = AlgorithmIdentifier((("algorithm", id_sha512),))

# True when the shell command "openssl cms -help" exits with status 0;
# evaluated once at import time and used to skip the test class below.
openssl_cms_exists = 0 == call("openssl cms -help 2>/dev/null", shell=True)
174
@skipIf(not openssl_cms_exists, "openssl cms command not found")
class TestSignedDataCERWithOpenSSL(TestCase):
    """Build CER-encoded CMS SignedData and verify it with the openssl CLI.

    Keys, certificates and signatures are produced by shelling out to
    ``openssl``; the whole class is skipped when ``openssl cms`` is not
    available.
    """

    def tmpfile(self):
        """Create an empty temporary file and return its path.

        The file is removed when the test's cleanups run.
        """
        tmp = NamedTemporaryFile(delete=False)
        tmp.close()
        self.addCleanup(remove, tmp.name)
        return tmp.name

    def keypair(self):
        """Generate a secp521r1 key and a self-signed certificate for it.

        The private key is appended to the PEM certificate file, so that
        single file can be handed to openssl both as signer and CA.

        Returns a ``(key_path, cert_path, cert, skid)`` tuple: paths to
        the key and to the PEM certificate (with key appended), the
        decoded ``Certificate`` and its decoded ``SubjectKeyIdentifier``.
        """
        key_path = self.tmpfile()
        self.assertEqual(0, call(
            "openssl ecparam -name secp521r1 -genkey -out " + key_path,
            shell=True,
        ))
        cert_path = self.tmpfile()
        self.assertEqual(0, call(" ".join((
            "openssl req -x509 -new",
            ("-key " + key_path),
            ("-outform PEM -out " + cert_path),
            "-nodes -subj /CN=pyderasntest",
        )), shell=True))
        cert_der_path = self.tmpfile()
        self.assertEqual(0, call(" ".join((
            "openssl x509",
            "-inform PEM -in " + cert_path,
            "-outform DER -out " + cert_der_path,
        )), shell=True))
        self.assertEqual(0, call("cat %s >> %s" % (key_path, cert_path), shell=True))
        with open(cert_der_path, "rb") as fd:
            cert = Certificate().decod(fd.read())
        # Start from None so that a certificate lacking the extension is
        # reported as a test failure rather than an UnboundLocalError.
        skid = None
        for ext in cert["tbsCertificate"]["extensions"]:
            if ext["extnID"] == id_ce_subjectKeyIdentifier:
                skid = SubjectKeyIdentifier().decod(bytes(ext["extnValue"]))
        self.assertIsNotNone(
            skid,
            "generated certificate has no subjectKeyIdentifier extension",
        )
        return key_path, cert_path, cert, skid

    def sign(self, signed_attrs, key_path):
        """Sign the CER encoding of *signed_attrs* with the key at *key_path*.

        Runs ``openssl dgst -sha512 -sign`` and returns the raw signature
        bytes it produced.
        """
        input_path = self.tmpfile()
        with open(input_path, "wb") as fd:
            fd.write(encode_cer(signed_attrs))
        signature_path = self.tmpfile()
        self.assertEqual(0, call(" ".join((
            "openssl dgst -sha512",
            ("-sign " + key_path),
            "-binary", input_path,
            ("> " + signature_path),
        )), shell=True))
        with open(signature_path, "rb") as fd:
            return fd.read()

    def verify(self, cert_path, cms_path):
        """Assert that ``openssl cms -verify`` accepts the CMS at *cms_path*.

        *cert_path* is passed both as the signer certificate and as the
        CA file.
        """
        self.assertEqual(0, call(" ".join((
            "openssl cms -verify",
            ("-inform DER -in " + cms_path),
            "-signer %s -CAfile %s" % (cert_path, cert_path),
            "-out /dev/null 2>/dev/null",
        )), shell=True))

    def _signed_attrs(self, digest):
        """Build SignedAttributes: content type id_data plus *digest*."""
        return SignedAttributes([
            Attribute((
                ("attrType", id_pkcs9_at_contentType),
                ("attrValues", AttributeValues([AttributeValue(id_data)])),
            )),
            Attribute((
                ("attrType", id_pkcs9_at_messageDigest),
                ("attrValues", AttributeValues([
                    AttributeValue(OctetString(digest)),
                ])),
            )),
        ])

    def _content_info(self, eci, cert, skid, signed_attrs, signature):
        """Wrap everything into a ContentInfo holding a v3 SignedData."""
        signer_info = SignerInfo((
            ("version", CMSVersion("v3")),
            ("sid", SignerIdentifier(("subjectKeyIdentifier", skid))),
            ("digestAlgorithm", DigestAlgorithmIdentifier(ai_sha512)),
            ("signedAttrs", signed_attrs),
            ("signatureAlgorithm", SignatureAlgorithmIdentifier((
                ("algorithm", id_ecdsa_with_SHA512),
            ))),
            ("signature", SignatureValue(signature)),
        ))
        signed_data = SignedData((
            ("version", CMSVersion("v3")),
            ("digestAlgorithms", DigestAlgorithmIdentifiers([ai_sha512])),
            ("encapContentInfo", eci),
            ("certificates", CertificateSet([
                CertificateChoices(("certificate", cert)),
            ])),
            ("signerInfos", SignerInfos([signer_info])),
        ))
        return ContentInfo((
            ("contentType", ContentType(id_signedData)),
            ("content", Any(signed_data)),
        ))

    @settings(deadline=None)
    @given(integers(min_value=1000, max_value=5000))
    def test_simple(self, data_len):
        """Sign random data, verify with openssl, then stream-decode it back."""
        key_path, cert_path, cert, skid = self.keypair()
        data = urandom(data_len)
        eci = EncapsulatedContentInfo((
            ("eContentType", ContentType(id_data)),
            ("eContent", OctetString(data)),
        ))
        signed_attrs = self._signed_attrs(
            sha512(bytes(eci["eContent"])).digest()
        )
        signature = self.sign(signed_attrs, key_path)
        ci = self._content_info(eci, cert, skid, signed_attrs, signature)
        cms_path = self.tmpfile()
        with io_open(cms_path, "wb") as fd:
            ci.encode_cer(writer=fd.write)
        self.verify(cert_path, cms_path)
        # Keep the file open for the whole decoding pass and close it
        # deterministically: this body runs once per hypothesis example.
        with open(cms_path, "rb") as fd:
            raw = memoryview(fd.read()) if PY2 else file_mmaped(fd)
            ctx = {"bered": True}
            for decode_path, obj, _ in ContentInfo().decode_evgen(raw, ctx=ctx):
                if decode_path == ("content",):
                    break
            else:
                self.fail("no \"content\" event found while decoding ContentInfo")
            evgens = SignedData().decode_evgen(raw[obj.offset:], offset=obj.offset, ctx=ctx)
            buf = BytesIO()
            agg_octet_string(evgens, ("encapContentInfo", "eContent"), raw, buf.write)
        self.assertSequenceEqual(buf.getvalue(), data)

    @skipIf(PY2, "no mmaped memoryview support in PY2")
    @skipIf("PYDERASN_TEST_CMS_HUGE" not in environ, "PYDERASN_TEST_CMS_HUGE is not set")
    def test_huge(self):
        """Huge CMS test

        Environment variable PYDERASN_TEST_CMS_HUGE tells how many MiBs
        data to sign. Pay attention that openssl cms is unable to do
        stream verification and eats huge amounts (several times more,
        than CMS itself) of memory.
        """
        key_path, cert_path, cert, skid = self.keypair()
        rnd = urandom(1 << 20)
        data_path = self.tmpfile()
        start = time()
        with open(data_path, "wb") as fd:
            for _ in six_xrange(int(environ["PYDERASN_TEST_CMS_HUGE"])):
                fd.write(rnd)
        print("data file written", time() - start)
        data_fd = open(data_path, "rb")
        # Registered after tmpfile()'s removal hook, so it runs before it.
        self.addCleanup(data_fd.close)
        data_raw = file_mmaped(data_fd)

        # Imported here rather than at module level; presumably because
        # it is unavailable on Python 2, where this test is skipped.
        from sys import getallocatedblocks
        mem_start = getallocatedblocks()
        start = time()
        eci = EncapsulatedContentInfo((
            ("eContentType", ContentType(id_data)),
            ("eContent", OctetString(data_raw)),
        ))
        eci_path = self.tmpfile()
        with open(eci_path, "wb") as fd:
            OctetString(eci["eContent"]).encode_cer(writer=fd.write)
        print("ECI file written", time() - start)
        eci_fd = open(eci_path, "rb")
        self.addCleanup(eci_fd.close)
        eci_raw = file_mmaped(eci_fd)

        start = time()
        dgst = sha512()

        def hasher(data):
            # Writer callback: feed each chunk to the hash and report
            # its length back to the caller.
            dgst.update(data)
            return len(data)

        evgens = OctetString().decode_evgen(eci_raw, ctx={"bered": True})
        agg_octet_string(evgens, (), eci_raw, hasher)
        digest = dgst.digest()
        print("digest calculated", time() - start)

        signed_attrs = self._signed_attrs(digest)
        signature = self.sign(signed_attrs, key_path)

        # The allocated block count must stay below twice its value
        # from before the content was wrapped and hashed.
        self.assertLess(getallocatedblocks(), mem_start * 2)
        start = time()
        ci = self._content_info(eci, cert, skid, signed_attrs, signature)
        cms_path = self.tmpfile()
        with io_open(cms_path, "wb") as fd:
            ci.encode_cer(writer=fd.write)
        print("CMS written", time() - start)
        self.verify(cert_path, cms_path)