]> Cypherpunks.ru repositories - gostls13.git/blob - src/crypto/rsa/pss.go
[dev.boringcrypto] all: merge master into dev.boringcrypto
[gostls13.git] / src / crypto / rsa / pss.go
1 // Copyright 2013 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package rsa
6
7 // This file implements the PSS signature scheme [1].
8 //
9 // [1] https://www.emc.com/collateral/white-papers/h11300-pkcs-1v2-2-rsa-cryptography-standard-wp.pdf
10
11 import (
12         "bytes"
13         "crypto"
14         "crypto/internal/boring"
15         "errors"
16         "hash"
17         "io"
18         "math/big"
19 )
20
21 func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) {
22         // See [1], section 9.1.1
23         hLen := hash.Size()
24         sLen := len(salt)
25         emLen := (emBits + 7) / 8
26
27         // 1.  If the length of M is greater than the input limitation for the
28         //     hash function (2^61 - 1 octets for SHA-1), output "message too
29         //     long" and stop.
30         //
31         // 2.  Let mHash = Hash(M), an octet string of length hLen.
32
33         if len(mHash) != hLen {
34                 return nil, errors.New("crypto/rsa: input must be hashed message")
35         }
36
37         // 3.  If emLen < hLen + sLen + 2, output "encoding error" and stop.
38
39         if emLen < hLen+sLen+2 {
40                 return nil, errors.New("crypto/rsa: key size too small for PSS signature")
41         }
42
43         em := make([]byte, emLen)
44         db := em[:emLen-sLen-hLen-2+1+sLen]
45         h := em[emLen-sLen-hLen-2+1+sLen : emLen-1]
46
47         // 4.  Generate a random octet string salt of length sLen; if sLen = 0,
48         //     then salt is the empty string.
49         //
50         // 5.  Let
51         //       M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt;
52         //
53         //     M' is an octet string of length 8 + hLen + sLen with eight
54         //     initial zero octets.
55         //
56         // 6.  Let H = Hash(M'), an octet string of length hLen.
57
58         var prefix [8]byte
59
60         hash.Write(prefix[:])
61         hash.Write(mHash)
62         hash.Write(salt)
63
64         h = hash.Sum(h[:0])
65         hash.Reset()
66
67         // 7.  Generate an octet string PS consisting of emLen - sLen - hLen - 2
68         //     zero octets. The length of PS may be 0.
69         //
70         // 8.  Let DB = PS || 0x01 || salt; DB is an octet string of length
71         //     emLen - hLen - 1.
72
73         db[emLen-sLen-hLen-2] = 0x01
74         copy(db[emLen-sLen-hLen-1:], salt)
75
76         // 9.  Let dbMask = MGF(H, emLen - hLen - 1).
77         //
78         // 10. Let maskedDB = DB \xor dbMask.
79
80         mgf1XOR(db, hash, h)
81
82         // 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in
83         //     maskedDB to zero.
84
85         db[0] &= (0xFF >> uint(8*emLen-emBits))
86
87         // 12. Let EM = maskedDB || H || 0xbc.
88         em[emLen-1] = 0xBC
89
90         // 13. Output EM.
91         return em, nil
92 }
93
94 func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
95         // 1.  If the length of M is greater than the input limitation for the
96         //     hash function (2^61 - 1 octets for SHA-1), output "inconsistent"
97         //     and stop.
98         //
99         // 2.  Let mHash = Hash(M), an octet string of length hLen.
100         hLen := hash.Size()
101         if hLen != len(mHash) {
102                 return ErrVerification
103         }
104
105         // 3.  If emLen < hLen + sLen + 2, output "inconsistent" and stop.
106         emLen := (emBits + 7) / 8
107         if emLen < hLen+sLen+2 {
108                 return ErrVerification
109         }
110
111         // 4.  If the rightmost octet of EM does not have hexadecimal value
112         //     0xbc, output "inconsistent" and stop.
113         if em[len(em)-1] != 0xBC {
114                 return ErrVerification
115         }
116
117         // 5.  Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and
118         //     let H be the next hLen octets.
119         db := em[:emLen-hLen-1]
120         h := em[emLen-hLen-1 : len(em)-1]
121
122         // 6.  If the leftmost 8 * emLen - emBits bits of the leftmost octet in
123         //     maskedDB are not all equal to zero, output "inconsistent" and
124         //     stop.
125         if em[0]&(0xFF<<uint(8-(8*emLen-emBits))) != 0 {
126                 return ErrVerification
127         }
128
129         // 7.  Let dbMask = MGF(H, emLen - hLen - 1).
130         //
131         // 8.  Let DB = maskedDB \xor dbMask.
132         mgf1XOR(db, hash, h)
133
134         // 9.  Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
135         //     to zero.
136         db[0] &= (0xFF >> uint(8*emLen-emBits))
137
138         if sLen == PSSSaltLengthAuto {
139         FindSaltLength:
140                 for sLen = emLen - (hLen + 2); sLen >= 0; sLen-- {
141                         switch db[emLen-hLen-sLen-2] {
142                         case 1:
143                                 break FindSaltLength
144                         case 0:
145                                 continue
146                         default:
147                                 return ErrVerification
148                         }
149                 }
150                 if sLen < 0 {
151                         return ErrVerification
152                 }
153         } else {
154                 // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
155                 //     or if the octet at position emLen - hLen - sLen - 1 (the leftmost
156                 //     position is "position 1") does not have hexadecimal value 0x01,
157                 //     output "inconsistent" and stop.
158                 for _, e := range db[:emLen-hLen-sLen-2] {
159                         if e != 0x00 {
160                                 return ErrVerification
161                         }
162                 }
163                 if db[emLen-hLen-sLen-2] != 0x01 {
164                         return ErrVerification
165                 }
166         }
167
168         // 11.  Let salt be the last sLen octets of DB.
169         salt := db[len(db)-sLen:]
170
171         // 12.  Let
172         //          M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
173         //     M' is an octet string of length 8 + hLen + sLen with eight
174         //     initial zero octets.
175         //
176         // 13. Let H' = Hash(M'), an octet string of length hLen.
177         var prefix [8]byte
178         hash.Write(prefix[:])
179         hash.Write(mHash)
180         hash.Write(salt)
181
182         h0 := hash.Sum(nil)
183
184         // 14. If H = H', output "consistent." Otherwise, output "inconsistent."
185         if !bytes.Equal(h0, h) {
186                 return ErrVerification
187         }
188         return nil
189 }
190
191 // signPSSWithSalt calculates the signature of hashed using PSS [1] with specified salt.
192 // Note that hashed must be the result of hashing the input message using the
193 // given hash function. salt is a random sequence of bytes whose length will be
194 // later used to verify the signature.
195 func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) (s []byte, err error) {
196         nBits := priv.N.BitLen()
197         em, err := emsaPSSEncode(hashed, nBits-1, salt, hash.New())
198         if err != nil {
199                 return
200         }
201
202         if boring.Enabled {
203                 boringFakeRandomBlind(rand, priv)
204                 bkey, err := boringPrivateKey(priv)
205                 if err != nil {
206                         return nil, err
207                 }
208                 // Note: BoringCrypto takes care of the "AndCheck" part of "decryptAndCheck".
209                 // (It's not just decrypt.)
210                 s, err := boring.DecryptRSANoPadding(bkey, em)
211                 if err != nil {
212                         return nil, err
213                 }
214                 return s, nil
215         }
216
217         m := new(big.Int).SetBytes(em)
218         c, err := decryptAndCheck(rand, priv, m)
219         if err != nil {
220                 return
221         }
222         s = make([]byte, (nBits+7)/8)
223         copyWithLeftPad(s, c.Bytes())
224         return
225 }
226
227 const (
228         // PSSSaltLengthAuto causes the salt in a PSS signature to be as large
229         // as possible when signing, and to be auto-detected when verifying.
230         PSSSaltLengthAuto = 0
231         // PSSSaltLengthEqualsHash causes the salt length to equal the length
232         // of the hash used in the signature.
233         PSSSaltLengthEqualsHash = -1
234 )
235
236 // PSSOptions contains options for creating and verifying PSS signatures.
237 type PSSOptions struct {
238         // SaltLength controls the length of the salt used in the PSS
239         // signature. It can either be a number of bytes, or one of the special
240         // PSSSaltLength constants.
241         SaltLength int
242
243         // Hash, if not zero, overrides the hash function passed to SignPSS.
244         // This is the only way to specify the hash function when using the
245         // crypto.Signer interface.
246         Hash crypto.Hash
247 }
248
249 // HashFunc returns pssOpts.Hash so that PSSOptions implements
250 // crypto.SignerOpts.
251 func (pssOpts *PSSOptions) HashFunc() crypto.Hash {
252         return pssOpts.Hash
253 }
254
255 func (opts *PSSOptions) saltLength() int {
256         if opts == nil {
257                 return PSSSaltLengthAuto
258         }
259         return opts.SaltLength
260 }
261
262 // SignPSS calculates the signature of hashed using RSASSA-PSS [1].
263 // Note that hashed must be the result of hashing the input message using the
264 // given hash function. The opts argument may be nil, in which case sensible
265 // defaults are used.
266 func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []byte, opts *PSSOptions) ([]byte, error) {
267         saltLength := opts.saltLength()
268         switch saltLength {
269         case PSSSaltLengthAuto:
270                 saltLength = (priv.N.BitLen()+7)/8 - 2 - hash.Size()
271         case PSSSaltLengthEqualsHash:
272                 saltLength = hash.Size()
273         }
274
275         if opts != nil && opts.Hash != 0 {
276                 hash = opts.Hash
277         }
278
279         if boring.Enabled && rand == boring.RandReader {
280                 bkey, err := boringPrivateKey(priv)
281                 if err != nil {
282                         return nil, err
283                 }
284                 return boring.SignRSAPSS(bkey, hash, hashed, saltLength)
285         }
286
287         salt := make([]byte, saltLength)
288         if _, err := io.ReadFull(rand, salt); err != nil {
289                 return nil, err
290         }
291         return signPSSWithSalt(rand, priv, hash, hashed, salt)
292 }
293
294 // VerifyPSS verifies a PSS signature.
295 // hashed is the result of hashing the input message using the given hash
296 // function and sig is the signature. A valid signature is indicated by
297 // returning a nil error. The opts argument may be nil, in which case sensible
298 // defaults are used.
299 func VerifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, opts *PSSOptions) error {
300         return verifyPSS(pub, hash, hashed, sig, opts.saltLength())
301 }
302
303 // verifyPSS verifies a PSS signature with the given salt length.
304 func verifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, saltLen int) error {
305         if boring.Enabled {
306                 bkey, err := boringPublicKey(pub)
307                 if err != nil {
308                         return err
309                 }
310                 if err := boring.VerifyRSAPSS(bkey, hash, hashed, sig, saltLen); err != nil {
311                         return ErrVerification
312                 }
313                 return nil
314         }
315         nBits := pub.N.BitLen()
316         if len(sig) != (nBits+7)/8 {
317                 return ErrVerification
318         }
319         s := new(big.Int).SetBytes(sig)
320         m := encrypt(new(big.Int), pub, s)
321         emBits := nBits - 1
322         emLen := (emBits + 7) / 8
323         if emLen < len(m.Bytes()) {
324                 return ErrVerification
325         }
326         em := make([]byte, emLen)
327         copyWithLeftPad(em, m.Bytes())
328         if saltLen == PSSSaltLengthEqualsHash {
329                 saltLen = hash.Size()
330         }
331         return emsaPSSVerify(hashed, em, emBits, saltLen, hash.New())
332 }