]> Cypherpunks.ru repositories - pyderasn.git/blob - tests/test_cms.py
agg_octet_string
[pyderasn.git] / tests / test_cms.py
1 # coding: utf-8
2 # PyDERASN -- Python ASN.1 DER codec with abstract structures
3 # Copyright (C) 2017-2020 Sergey Matveev <stargrave@stargrave.org>
4 #
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Lesser General Public License as
7 # published by the Free Software Foundation, version 3 of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU Lesser General Public License for more details.
13 #
14 # You should have received a copy of the GNU Lesser General Public
15 # License along with this program.  If not, see
16 # <http://www.gnu.org/licenses/>.
17
18 from hashlib import sha256
19 from io import BytesIO
20 from io import open as io_open
21 from os import remove
22 from os import urandom
23 from subprocess import call
24 from tempfile import NamedTemporaryFile
25 from unittest import skipIf
26 from unittest import TestCase
27
28 from hypothesis import given
29 from hypothesis import settings
30 from hypothesis.strategies import integers
31 from six import PY2
32
33 from pyderasn import agg_octet_string
34 from pyderasn import Any
35 from pyderasn import Choice
36 from pyderasn import encode_cer
37 from pyderasn import file_mmaped
38 from pyderasn import Integer
39 from pyderasn import ObjectIdentifier
40 from pyderasn import OctetString
41 from pyderasn import Sequence
42 from pyderasn import SetOf
43 from pyderasn import tag_ctxc
44 from pyderasn import tag_ctxp
45 from tests.test_crts import AlgorithmIdentifier
46 from tests.test_crts import Certificate
47 from tests.test_crts import SubjectKeyIdentifier
48
49
class CMSVersion(Integer):
    """CMS version number: an INTEGER with named values ``v0`` .. ``v5``."""

    # Yields (("v0", 0), ("v1", 1), ..., ("v5", 5)).
    schema = tuple(("v%d" % num, num) for num in range(6))
59
60
class AttributeValue(Any):
    """A single attribute value, kept as an opaque ANY.

    In this file it is always built from already-encoded bytes, e.g.
    ``AttributeValue(oid.encode())``.
    """
63
64
class AttributeValues(SetOf):
    """SET OF AttributeValue."""

    schema = AttributeValue()
67
68
class Attribute(Sequence):
    """Attribute: a type OID followed by the set of its values."""

    schema = (
        ("attrType", ObjectIdentifier()),  # which attribute this is
        ("attrValues", AttributeValues()),  # its value(s)
    )
74
75
class SignatureAlgorithmIdentifier(AlgorithmIdentifier):
    """AlgorithmIdentifier naming the signature algorithm of a SignerInfo."""
78
79
class SignedAttributes(SetOf):
    """SET OF Attribute covered by the signer's signature.

    ``der_forced`` is set because the test signs the DER encoding of this
    set (it is passed to ``encode_cer`` before being signed).
    """

    der_forced = True
    bounds = (1, 32)  # at least one and at most 32 attributes
    schema = Attribute()
84
85
class SignerIdentifier(Choice):
    """CHOICE identifying the signer's certificate.

    Only the ``subjectKeyIdentifier`` alternative (IMPLICIT [0]) is
    modelled; the ``issuerAndSerialNumber`` alternative is not.
    """

    schema = (
        ("subjectKeyIdentifier", SubjectKeyIdentifier(impl=tag_ctxp(0))),
    )
91
92
class DigestAlgorithmIdentifiers(SetOf):
    """SET OF AlgorithmIdentifier listing the digest algorithms in use."""

    schema = AlgorithmIdentifier()
95
96
class DigestAlgorithmIdentifier(AlgorithmIdentifier):
    """AlgorithmIdentifier naming a signer's message digest algorithm."""
99
100
class SignatureValue(OctetString):
    """OCTET STRING carrying the raw signature bytes."""
103
104
class SignerInfo(Sequence):
    """Per-signer information of a SignedData.

    ``signedAttrs`` is IMPLICIT [0] and optional. The trailing
    ``unsignedAttrs`` field (IMPLICIT [1]) is not modelled here.
    """

    schema = (
        ("version", CMSVersion()),
        ("sid", SignerIdentifier()),
        ("digestAlgorithm", DigestAlgorithmIdentifier()),
        (
            "signedAttrs",
            SignedAttributes(impl=tag_ctxc(0), optional=True),
        ),
        ("signatureAlgorithm", SignatureAlgorithmIdentifier()),
        ("signature", SignatureValue()),
    )
115
116
class SignerInfos(SetOf):
    """SET OF SignerInfo."""

    schema = SignerInfo()
119
120
class ContentType(ObjectIdentifier):
    """OBJECT IDENTIFIER naming the type of some content."""
123
124
class EncapsulatedContentInfo(Sequence):
    """The signed content and its type.

    ``eContent`` is an OCTET STRING wrapped in an EXPLICIT [0] tag and may
    be absent.
    """

    schema = (
        ("eContentType", ContentType()),
        ("eContent", OctetString(optional=True, expl=tag_ctxc(0))),
    )
130
131
class CertificateChoices(Choice):
    """CHOICE of certificate formats; only a plain ``certificate`` is modelled."""

    schema = (("certificate", Certificate()),)
137
138
class CertificateSet(SetOf):
    """SET OF CertificateChoices."""

    schema = CertificateChoices()
141
142
class SignedData(Sequence):
    """SignedData content type.

    ``certificates`` is IMPLICIT [0] and optional. The ``crls`` field
    (IMPLICIT [1]), which would sit between ``certificates`` and
    ``signerInfos``, is not modelled here.
    """

    schema = (
        ("version", CMSVersion()),
        ("digestAlgorithms", DigestAlgorithmIdentifiers()),
        ("encapContentInfo", EncapsulatedContentInfo()),
        (
            "certificates",
            CertificateSet(impl=tag_ctxc(0), optional=True),
        ),
        ("signerInfos", SignerInfos()),
    )
152
153
class ContentInfo(Sequence):
    """Top-level CMS wrapper: content type OID plus EXPLICIT [0] content."""

    schema = (
        ("contentType", ContentType()),
        ("content", Any(expl=tag_ctxc(0))),  # opaque; decoded per contentType
    )
159
160
# Content types (1.2.840.113549.1.7.*).
id_data = ObjectIdentifier("1.2.840.113549.1.7.1")
id_signedData = ObjectIdentifier("1.2.840.113549.1.7.2")

# Attribute types used in signedAttrs (1.2.840.113549.1.9.*).
id_pkcs9_at_contentType = ObjectIdentifier("1.2.840.113549.1.9.3")
id_pkcs9_at_messageDigest = ObjectIdentifier("1.2.840.113549.1.9.4")

# Digest and signature algorithms.
id_sha256 = ObjectIdentifier("2.16.840.1.101.3.4.2.1")
id_ecdsa_with_SHA256 = ObjectIdentifier("1.2.840.10045.4.3.2")

# Certificate extension used to locate the signer's key identifier.
id_ce_subjectKeyIdentifier = ObjectIdentifier("2.5.29.14")
168
# True when the "openssl cms" subcommand can be run (zero exit status).
openssl_cms_exists = not call("openssl cms -help 2>/dev/null", shell=True)
170
@skipIf(not openssl_cms_exists, "openssl cms command not found")
class TestSignedDataCERWithOpenSSL(TestCase):
    """Interoperability test of CER-encoded CMS SignedData with OpenSSL.

    Each generated ``data_len`` produces:

    * a fresh prime256v1 key and a self-signed certificate (via ``openssl``);
    * a SignedData over random content, encoded with ``encode_cer``.

    ``openssl cms -verify`` must accept the result. The file is then
    decoded back in streaming mode (``decode_evgen``, over a memory-mapped
    file on Python 3). The aggregated eContent must equal the original data.
    """

    def _tmpfile(self):
        """Create an empty temporary file, schedule its removal, return its path."""
        tmp = NamedTemporaryFile(delete=False)
        tmp.close()
        self.addCleanup(remove, tmp.name)
        return tmp.name

    def _sh(self, *parts):
        """Run ``parts`` joined by spaces through the shell; assert exit code 0."""
        self.assertEqual(0, call(" ".join(parts), shell=True))

    @settings(deadline=None)
    @given(integers(min_value=1000, max_value=5000))
    def runTest(self, data_len):
        # Key material and a self-signed certificate in PEM and DER forms.
        key_path = self._tmpfile()
        self._sh("openssl ecparam -name prime256v1 -genkey -out", key_path)
        cert_path = self._tmpfile()
        self._sh(
            "openssl req -x509 -new",
            "-key " + key_path,
            "-outform PEM -out " + cert_path,
            "-nodes -subj /CN=pyderasntest",
        )
        cert_der_path = self._tmpfile()
        self._sh(
            "openssl x509",
            "-inform PEM -in " + cert_path,
            "-outform DER -out " + cert_der_path,
        )
        # cert_path now holds the certificate followed by the private key.
        # It is passed below as both -signer and -CAfile.
        # NOTE(review): confirm that appending the key is still required.
        self._sh("cat", key_path, ">>", cert_path)

        with open(cert_der_path, "rb") as fd:
            cert = Certificate().decod(fd.read())
        skid = None
        for ext in cert["tbsCertificate"]["extensions"]:
            if ext["extnID"] == id_ce_subjectKeyIdentifier:
                skid = SubjectKeyIdentifier().decod(bytes(ext["extnValue"]))
                break
        if skid is None:
            # The signer is identified by SKID only, so it must be present.
            self.fail("certificate has no subjectKeyIdentifier extension")

        ai_sha256 = AlgorithmIdentifier((("algorithm", id_sha256),))
        data = urandom(data_len)
        eci = EncapsulatedContentInfo((
            ("eContentType", ContentType(id_data)),
            ("eContent", OctetString(data)),
        ))
        signed_attrs = SignedAttributes([
            Attribute((
                ("attrType", id_pkcs9_at_contentType),
                ("attrValues", AttributeValues([
                    AttributeValue(id_data.encode())
                ])),
            )),
            Attribute((
                ("attrType", id_pkcs9_at_messageDigest),
                ("attrValues", AttributeValues([
                    AttributeValue(OctetString(
                        sha256(bytes(eci["eContent"])).digest()
                    ).encode()),
                ])),
            )),
        ])

        # Sign the standalone encoding of the signed attributes. Here they
        # carry no [0] implicit tag, and SignedAttributes is der_forced.
        input_path = self._tmpfile()
        with open(input_path, "wb") as fd:
            fd.write(encode_cer(signed_attrs))
        signature_path = self._tmpfile()
        self._sh(
            "openssl dgst -sha256",
            "-sign " + key_path,
            "-binary", input_path,
            "> " + signature_path,
        )
        with open(signature_path, "rb") as fd:
            signature = fd.read()

        ci = ContentInfo((
            ("contentType", ContentType(id_signedData)),
            ("content", Any(SignedData((
                ("version", CMSVersion("v3")),
                ("digestAlgorithms", DigestAlgorithmIdentifiers([ai_sha256])),
                ("encapContentInfo", eci),
                ("certificates", CertificateSet([
                    CertificateChoices(("certificate", cert)),
                ])),
                ("signerInfos", SignerInfos([SignerInfo((
                    ("version", CMSVersion("v3")),
                    ("sid", SignerIdentifier(("subjectKeyIdentifier", skid))),
                    ("digestAlgorithm", DigestAlgorithmIdentifier(ai_sha256)),
                    ("signedAttrs", signed_attrs),
                    ("signatureAlgorithm", SignatureAlgorithmIdentifier((
                        ("algorithm", id_ecdsa_with_SHA256),
                    ))),
                    ("signature", SignatureValue(signature)),
                ))])),
            )))),
        ))
        output_path = self._tmpfile()
        with io_open(output_path, "wb") as fd:
            ci.encode_cer(writer=fd.write)
        self._sh(
            "openssl cms -verify",
            "-inform DER -in " + output_path,
            "-signer %s -CAfile %s" % (cert_path, cert_path),
            "-out /dev/null 2>/dev/null",
        )

        # Stream-decode the BER/CER result and collect eContent chunks.
        # Everything that reads raw stays inside the with block, so the file
        # is closed afterwards (the original leaked this handle).
        ctx = {"bered": True}
        buf = BytesIO()
        with open(output_path, "rb") as fd:
            raw = memoryview(fd.read()) if PY2 else file_mmaped(fd)
            for decode_path, obj, _ in ContentInfo().decode_evgen(raw, ctx=ctx):
                if decode_path == ("content",):
                    break
            else:
                self.fail("ContentInfo decoding yielded no content field")
            # Decode SignedData from the content's offset. offset= is passed
            # presumably so reported offsets index into the whole raw buffer,
            # which is what agg_octet_string receives.
            evgens = SignedData().decode_evgen(
                raw[obj.offset:], offset=obj.offset, ctx=ctx
            )
            agg_octet_string(
                evgens, ("encapContentInfo", "eContent"), raw, buf.write
            )
        self.assertSequenceEqual(buf.getvalue(), data)