]> Cypherpunks.ru repositories - gostls13.git/commitdiff
crypto/internal/bigmod: move nat implementation out of crypto/rsa
authorFilippo Valsorda <filippo@golang.org>
Sat, 12 Nov 2022 13:01:15 +0000 (14:01 +0100)
committerGopher Robot <gobot@golang.org>
Mon, 21 Nov 2022 16:19:15 +0000 (16:19 +0000)
This will let us reuse it in crypto/ecdsa for the NIST scalar fields.

The main change in API is around encoding and decoding. The SetBytes +
ExpandFor sequence was hacky: SetBytes could produce a bigger size than
the modulus if leading zeroes in the top byte overflowed the limb
boundary, so ExpandFor had to check for and tolerate that. Also, the
caller was responsible for checking that the overflow was actually all
zeroes (which we weren't doing, exposing a crasher in decryption and
signature verification) and then for checking that the result was less
than the modulus. Instead, make SetBytes take a modulus and return an
error if the value overflows. Same with Bytes: we were always allocating
based on Size before FillBytes anyway, so now Bytes takes a modulus.
Finally, SetBig was almost only used for moduli, so replaced
NewModulusFromNat and SetBig with NewModulusFromBig.

Moved the constant-time bitLen to math/big.Int.BitLen. It's slower, but
BitLen is primarily used in cryptographic code, so it's safer this way.

Change-Id: Ibaf7f36d80695578cb80484167d82ce1aa83832f
Reviewed-on: https://go-review.googlesource.com/c/go/+/450055
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cherry Mui <cherryyz@google.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
src/crypto/internal/bigmod/nat.go [moved from src/crypto/rsa/nat.go with 77% similarity]
src/crypto/internal/bigmod/nat_test.go [new file with mode: 0644]
src/crypto/rsa/nat_test.go [deleted file]
src/crypto/rsa/pkcs1v15.go
src/crypto/rsa/pss.go
src/crypto/rsa/rsa.go
src/crypto/rsa/rsa_test.go
src/go/build/deps_test.go
src/math/big/int.go
src/math/big/nat.go

similarity index 77%
rename from src/crypto/rsa/nat.go
rename to src/crypto/internal/bigmod/nat.go
index 5398d10606760102fab38a06762cf1dbef6da7a5..679eb34b1f9578e4b275ab7b9101b1768106b59a 100644 (file)
@@ -2,9 +2,10 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-package rsa
+package bigmod
 
 import (
+       "errors"
        "math/big"
        "math/bits"
 )
@@ -53,12 +54,12 @@ func ctGeq(x, y uint) choice {
        return not(choice(carry))
 }
 
-// nat represents an arbitrary natural number
+// Nat represents an arbitrary natural number
 //
-// Each nat has an announced length, which is the number of limbs it has stored.
+// Each Nat has an announced length, which is the number of limbs it has stored.
 // Operations on this number are allowed to leak this length, but will not leak
 // any information about the values contained in those limbs.
-type nat struct {
+type Nat struct {
        // limbs is a little-endian representation in base 2^W with
        // W = bits.UintSize - 1. The top bit is always unset between operations.
        //
@@ -75,21 +76,18 @@ type nat struct {
 const preallocTarget = 2048
 const preallocLimbs = (preallocTarget + _W) / _W
 
-// newNat returns a new nat with a size of zero, just like new(nat), but with
+// NewNat returns a new nat with a size of zero, just like new(Nat), but with
 // the preallocated capacity to hold a number of up to preallocTarget bits.
-// newNat inlines, so the allocation can live on the stack.
-func newNat() *nat {
+// NewNat inlines, so the allocation can live on the stack.
+func NewNat() *Nat {
        limbs := make([]uint, 0, preallocLimbs)
-       return &nat{limbs}
+       return &Nat{limbs}
 }
 
 // expand expands x to n limbs, leaving its value unchanged.
-func (x *nat) expand(n int) *nat {
-       for len(x.limbs) > n {
-               if x.limbs[len(x.limbs)-1] != 0 {
-                       panic("rsa: internal error: shrinking nat")
-               }
-               x.limbs = x.limbs[:len(x.limbs)-1]
+func (x *Nat) expand(n int) *Nat {
+       if len(x.limbs) > n {
+               panic("bigmod: internal error: shrinking nat")
        }
        if cap(x.limbs) < n {
                newLimbs := make([]uint, n)
@@ -106,7 +104,7 @@ func (x *nat) expand(n int) *nat {
 }
 
 // reset returns a zero nat of n limbs, reusing x's storage if n <= cap(x.limbs).
-func (x *nat) reset(n int) *nat {
+func (x *Nat) reset(n int) *Nat {
        if cap(x.limbs) < n {
                x.limbs = make([]uint, n)
                return x
@@ -119,19 +117,18 @@ func (x *nat) reset(n int) *nat {
 }
 
 // set assigns x = y, optionally resizing x to the appropriate size.
-func (x *nat) set(y *nat) *nat {
+func (x *Nat) set(y *Nat) *Nat {
        x.reset(len(y.limbs))
        copy(x.limbs, y.limbs)
        return x
 }
 
-// set assigns x = n, optionally resizing n to the appropriate size.
+// setBig assigns x = n, optionally resizing n to the appropriate size.
 //
 // The announced length of x is set based on the actual bit size of the input,
 // ignoring leading zeroes.
-func (x *nat) setBig(n *big.Int) *nat {
-       bitSize := bigBitLen(n)
-       requiredLimbs := (bitSize + _W - 1) / _W
+func (x *Nat) setBig(n *big.Int) *Nat {
+       requiredLimbs := (n.BitLen() + _W - 1) / _W
        x.reset(requiredLimbs)
 
        outI := 0
@@ -154,20 +151,15 @@ func (x *nat) setBig(n *big.Int) *nat {
        return x
 }
 
-// fillBytes sets bytes to x as a zero-extended big-endian byte slice.
+// Bytes returns x as a zero-extended big-endian byte slice. The size of the
+// slice will match the size of m.
 //
-// If bytes is not long enough to contain the number or at least len(x.limbs)-1
-// limbs, or has zero length, fillBytes will panic.
-func (x *nat) fillBytes(bytes []byte) []byte {
-       if len(bytes) == 0 {
-               panic("nat: fillBytes invoked with too small buffer")
-       }
-       for i := range bytes {
-               bytes[i] = 0
-       }
+// x must have the same size as m and it must be reduced modulo m.
+func (x *Nat) Bytes(m *Modulus) []byte {
+       bytes := make([]byte, m.Size())
        shift := 0
        outI := len(bytes) - 1
-       for i, limb := range x.limbs {
+       for _, limb := range x.limbs {
                remainingBits := _W
                for remainingBits >= 8 {
                        bytes[outI] |= byte(limb) << shift
@@ -177,9 +169,6 @@ func (x *nat) fillBytes(bytes []byte) []byte {
                        shift = 0
                        outI--
                        if outI < 0 {
-                               if limb != 0 || i < len(x.limbs)-1 {
-                                       panic("nat: fillBytes invoked with too small buffer")
-                               }
                                return bytes
                        }
                }
@@ -189,18 +178,14 @@ func (x *nat) fillBytes(bytes []byte) []byte {
        return bytes
 }
 
-// setBytes assigns x = b, where b is a slice of big-endian bytes, optionally
-// resizing n to the appropriate size.
+// SetBytes assigns x = b, where b is a slice of big-endian bytes.
+// SetBytes returns an error if b > m.
 //
-// The announced length of the output depends only on the length of b. Unlike
-// big.Int, creating a nat will not remove leading zeros.
-func (x *nat) setBytes(b []byte) *nat {
-       bitSize := len(b) * 8
-       requiredLimbs := (bitSize + _W - 1) / _W
-       x.reset(requiredLimbs)
-
+// The output will be resized to the size of m and overwritten.
+func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
        outI := 0
        shift := 0
+       x.resetFor(m)
        for i := len(b) - 1; i >= 0; i-- {
                bi := b[i]
                x.limbs[outI] |= uint(bi) << shift
@@ -208,19 +193,27 @@ func (x *nat) setBytes(b []byte) *nat {
                if shift >= _W {
                        shift -= _W
                        x.limbs[outI] &= _MASK
+                       overflow := bi >> (8 - shift)
                        outI++
-                       if shift > 0 {
-                               x.limbs[outI] = uint(bi) >> (8 - shift)
+                       if outI >= len(x.limbs) {
+                               if overflow > 0 || i > 0 {
+                                       return nil, errors.New("input overflows the modulus")
+                               }
+                               break
                        }
+                       x.limbs[outI] = uint(overflow)
                }
        }
-       return x
+       if x.cmpGeq(m.nat) == yes {
+               return nil, errors.New("input overflows the modulus")
+       }
+       return x, nil
 }
 
-// cmpEq returns 1 if x == y, and 0 otherwise.
+// Equal returns 1 if x == y, and 0 otherwise.
 //
 // Both operands must have the same announced length.
-func (x *nat) cmpEq(y *nat) choice {
+func (x *Nat) Equal(y *Nat) choice {
        // Eliminate bounds checks in the loop.
        size := len(x.limbs)
        xLimbs := x.limbs[:size]
@@ -236,7 +229,7 @@ func (x *nat) cmpEq(y *nat) choice {
 // cmpGeq returns 1 if x >= y, and 0 otherwise.
 //
 // Both operands must have the same announced length.
-func (x *nat) cmpGeq(y *nat) choice {
+func (x *Nat) cmpGeq(y *Nat) choice {
        // Eliminate bounds checks in the loop.
        size := len(x.limbs)
        xLimbs := x.limbs[:size]
@@ -254,7 +247,7 @@ func (x *nat) cmpGeq(y *nat) choice {
 // assign sets x <- y if on == 1, and does nothing otherwise.
 //
 // Both operands must have the same announced length.
-func (x *nat) assign(on choice, y *nat) *nat {
+func (x *Nat) assign(on choice, y *Nat) *Nat {
        // Eliminate bounds checks in the loop.
        size := len(x.limbs)
        xLimbs := x.limbs[:size]
@@ -270,7 +263,7 @@ func (x *nat) assign(on choice, y *nat) *nat {
 // carry of the addition regardless of on.
 //
 // Both operands must have the same announced length.
-func (x *nat) add(on choice, y *nat) (c uint) {
+func (x *Nat) add(on choice, y *Nat) (c uint) {
        // Eliminate bounds checks in the loop.
        size := len(x.limbs)
        xLimbs := x.limbs[:size]
@@ -288,7 +281,7 @@ func (x *nat) add(on choice, y *nat) (c uint) {
 // borrow of the subtraction regardless of on.
 //
 // Both operands must have the same announced length.
-func (x *nat) sub(on choice, y *nat) (c uint) {
+func (x *Nat) sub(on choice, y *Nat) (c uint) {
        // Eliminate bounds checks in the loop.
        size := len(x.limbs)
        xLimbs := x.limbs[:size]
@@ -302,26 +295,26 @@ func (x *nat) sub(on choice, y *nat) (c uint) {
        return
 }
 
-// modulus is used for modular arithmetic, precomputing relevant constants.
+// Modulus is used for modular arithmetic, precomputing relevant constants.
 //
 // Moduli are assumed to be odd numbers. Moduli can also leak the exact
 // number of bits needed to store their value, and are stored without padding.
 //
 // Their actual value is still kept secret.
-type modulus struct {
+type Modulus struct {
        // The underlying natural number for this modulus.
        //
        // This will be stored without any padding, and shouldn't alias with any
        // other natural number being used.
-       nat     *nat
+       nat     *Nat
        leading int  // number of leading zeros in the modulus
        m0inv   uint // -nat.limbs[0]⁻¹ mod _W
-       RR      *nat // R*R for montgomeryRepresentation
+       rr      *Nat // R*R for montgomeryRepresentation
 }
 
 // rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
-func rr(m *modulus) *nat {
-       rr := newNat().expandFor(m)
+func rr(m *Modulus) *Nat {
+       rr := NewNat().ExpandFor(m)
        // R*R is 2^(2 * _W * n). We can safely get 2^(_W * (n - 1)) by setting the
        // most significant limb to 1. We then get to R*R by shifting left by _W
        // n + 1 times.
@@ -351,21 +344,15 @@ func minusInverseModW(x uint) uint {
        return (1 << _W) - (y & _MASK)
 }
 
-// modulusFromNat creates a new modulus from a nat.
+// NewModulusFromBig creates a new Modulus from a [big.Int].
 //
-// The nat should be odd, nonzero, and the number of significant bits in the
-// number should be leakable. The nat shouldn't be reused.
-func modulusFromNat(nat *nat) *modulus {
-       m := &modulus{}
-       m.nat = nat
-       size := len(m.nat.limbs)
-       for m.nat.limbs[size-1] == 0 {
-               size--
-       }
-       m.nat.limbs = m.nat.limbs[:size]
-       m.leading = _W - bitLen(m.nat.limbs[size-1])
+// The Int must be odd. The number of significant bits must be leakable.
+func NewModulusFromBig(n *big.Int) *Modulus {
+       m := &Modulus{}
+       m.nat = NewNat().setBig(n)
+       m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
        m.m0inv = minusInverseModW(m.nat.limbs[0])
-       m.RR = rr(m)
+       m.rr = rr(m)
        return m
 }
 
@@ -383,26 +370,22 @@ func bitLen(n uint) int {
        return len
 }
 
-// bigBitLen is a version of big.Int.BitLen that only leaks the bit length of x,
-// but not its value. big.Int.BitLen uses bits.Len.
-func bigBitLen(x *big.Int) int {
-       xLimbs := x.Bits()
-       fullLimbs := len(xLimbs) - 1
-       topLimb := uint(xLimbs[len(xLimbs)-1])
-       return fullLimbs*bits.UintSize + bitLen(topLimb)
-}
-
-// modulusSize returns the size of m in bytes.
-func modulusSize(m *modulus) int {
+// Size returns the size of m in bytes.
+func (m *Modulus) Size() int {
        bits := len(m.nat.limbs)*_W - int(m.leading)
        return (bits + 7) / 8
 }
 
+// Nat returns m as a Nat. The return value must not be written to.
+func (m *Modulus) Nat() *Nat {
+       return m.nat
+}
+
 // shiftIn calculates x = x << _W + y mod m.
 //
 // This assumes that x is already reduced mod m, and that y < 2^_W.
-func (x *nat) shiftIn(y uint, m *modulus) *nat {
-       d := newNat().resetFor(m)
+func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
+       d := NewNat().resetFor(m)
 
        // Eliminate bounds checks in the loop.
        size := len(m.nat.limbs)
@@ -432,19 +415,19 @@ func (x *nat) shiftIn(y uint, m *modulus) *nat {
                        dLimbs[i] = res & _MASK
                        borrow = res >> _W
                }
-               // See modAdd for how carry (aka overflow), borrow (aka underflow), and
+               // See Add for how carry (aka overflow), borrow (aka underflow), and
                // needSubtraction relate.
                needSubtraction = ctEq(carry, borrow)
        }
        return x.assign(needSubtraction, d)
 }
 
-// mod calculates out = x mod m.
+// Mod calculates out = x mod m.
 //
 // This works regardless how large the value of x is.
 //
 // The output will be resized to the size of m and overwritten.
-func (out *nat) mod(x *nat, m *modulus) *nat {
+func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
        out.resetFor(m)
        // Working our way from the most significant to the least significant limb,
        // we can insert each limb at the least significant position, shifting all
@@ -470,38 +453,36 @@ func (out *nat) mod(x *nat, m *modulus) *nat {
        return out
 }
 
-// expandFor ensures out has the right size to work with operations modulo m.
+// ExpandFor ensures out has the right size to work with operations modulo m.
 //
-// This assumes that out has as many or fewer limbs than m, or that the extra
-// limbs are all zero (which may happen when decoding a value that has leading
-// zeroes in its bytes representation that spill over the limb threshold).
-func (out *nat) expandFor(m *modulus) *nat {
+// The announced size of out must be smaller than or equal to that of m.
+func (out *Nat) ExpandFor(m *Modulus) *Nat {
        return out.expand(len(m.nat.limbs))
 }
 
 // resetFor ensures out has the right size to work with operations modulo m.
 //
 // out is zeroed and may start at any size.
-func (out *nat) resetFor(m *modulus) *nat {
+func (out *Nat) resetFor(m *Modulus) *Nat {
        return out.reset(len(m.nat.limbs))
 }
 
-// modSub computes x = x - y mod m.
+// Sub computes x = x - y mod m.
 //
 // The length of both operands must be the same as the modulus. Both operands
 // must already be reduced modulo m.
-func (x *nat) modSub(y *nat, m *modulus) *nat {
+func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
        underflow := x.sub(yes, y)
        // If the subtraction underflowed, add m.
        x.add(choice(underflow), m.nat)
        return x
 }
 
-// modAdd computes x = x + y mod m.
+// Add computes x = x + y mod m.
 //
 // The length of both operands must be the same as the modulus. Both operands
 // must already be reduced modulo m.
-func (x *nat) modAdd(y *nat, m *modulus) *nat {
+func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
        overflow := x.add(yes, y)
        underflow := not(x.cmpGeq(m.nat)) // x < m
 
@@ -540,22 +521,22 @@ func (x *nat) modAdd(y *nat, m *modulus) *nat {
 // numbers in this representation.
 //
 // This assumes that x is already reduced mod m.
-func (x *nat) montgomeryRepresentation(m *modulus) *nat {
+func (x *Nat) montgomeryRepresentation(m *Modulus) *Nat {
        // A Montgomery multiplication (which computes a * b / R) by R * R works out
        // to a multiplication by R, which takes the value out of the Montgomery domain.
-       return x.montgomeryMul(newNat().set(x), m.RR, m)
+       return x.montgomeryMul(NewNat().set(x), m.rr, m)
 }
 
 // montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and
 // n = len(m.nat.limbs).
 //
 // This assumes that x is already reduced mod m.
-func (x *nat) montgomeryReduction(m *modulus) *nat {
+func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
        // By Montgomery multiplying with 1 not in Montgomery representation, we
        // convert out back from Montgomery representation, because it works out to
        // dividing by R.
-       t0 := newNat().set(x)
-       t1 := newNat().expandFor(m)
+       t0 := NewNat().set(x)
+       t1 := NewNat().ExpandFor(m)
        t1.limbs[0] = 1
        return x.montgomeryMul(t0, t1, m)
 }
@@ -565,7 +546,7 @@ func (x *nat) montgomeryReduction(m *modulus) *nat {
 //
 // All inputs should be the same length, not aliasing d, and already
 // reduced modulo m. d will be resized to the size of m and overwritten.
-func (d *nat) montgomeryMul(a *nat, b *nat, m *modulus) *nat {
+func (d *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
        // See https://bearssl.org/bigint.html#montgomery-reduction-and-multiplication
        // for a description of the algorithm.
 
@@ -599,7 +580,7 @@ func (d *nat) montgomeryMul(a *nat, b *nat, m *modulus) *nat {
                dLimbs[size-1] = z & _MASK
                overflow = z >> _W // overflow <= 1
        }
-       // See modAdd for how overflow, underflow, and needSubtraction relate.
+       // See Add for how overflow, underflow, and needSubtraction relate.
        underflow := not(d.cmpGeq(m.nat)) // d < m
        needSubtraction := ctEq(overflow, uint(underflow))
        d.sub(needSubtraction, m.nat)
@@ -607,31 +588,31 @@ func (d *nat) montgomeryMul(a *nat, b *nat, m *modulus) *nat {
        return d
 }
 
-// modMul calculates x *= y mod m.
+// Mul calculates x *= y mod m.
 //
 // x and y must already be reduced modulo m, they must share its announced
 // length, and they may not alias.
-func (x *nat) modMul(y *nat, m *modulus) *nat {
+func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
        // A Montgomery multiplication by a value out of the Montgomery domain
        // takes the result out of Montgomery representation.
-       xR := newNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
+       xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
        return x.montgomeryMul(xR, y, m)                  // x = xR * y / R mod m
 }
 
-// exp calculates out = x^e mod m.
+// Exp calculates out = x^e mod m.
 //
 // The exponent e is represented in big-endian order. The output will be resized
 // to the size of m and overwritten. x must already be reduced modulo m.
-func (out *nat) exp(x *nat, e []byte, m *modulus) *nat {
+func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
        // We use a 4 bit window. For our RSA workload, 4 bit windows are faster
        // than 2 bit windows, but use an extra 12 nats worth of scratch space.
        // Using bit sizes that don't divide 8 are more complex to implement.
 
-       table := [(1 << 4) - 1]*nat{ // table[i] = x ^ (i+1)
+       table := [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1)
                // newNat calls are unrolled so they are allocated on the stack.
-               newNat(), newNat(), newNat(), newNat(), newNat(),
-               newNat(), newNat(), newNat(), newNat(), newNat(),
-               newNat(), newNat(), newNat(), newNat(), newNat(),
+               NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
+               NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
+               NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
        }
        table[0].set(x).montgomeryRepresentation(m)
        for i := 1; i < len(table); i++ {
@@ -641,8 +622,8 @@ func (out *nat) exp(x *nat, e []byte, m *modulus) *nat {
        out.resetFor(m)
        out.limbs[0] = 1
        out.montgomeryRepresentation(m)
-       t0 := newNat().expandFor(m)
-       t1 := newNat().expandFor(m)
+       t0 := NewNat().ExpandFor(m)
+       t1 := NewNat().ExpandFor(m)
        for _, b := range e {
                for _, j := range []int{4, 0} {
                        // Square four times.
diff --git a/src/crypto/internal/bigmod/nat_test.go b/src/crypto/internal/bigmod/nat_test.go
new file mode 100644 (file)
index 0000000..6431d25
--- /dev/null
@@ -0,0 +1,412 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package bigmod
+
+import (
+       "math/big"
+       "math/bits"
+       "math/rand"
+       "reflect"
+       "testing"
+       "testing/quick"
+)
+
+// Generate generates an even nat. It's used by testing/quick to produce random
+// *nat values for quick.Check invocations.
+func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
+       limbs := make([]uint, size)
+       for i := 0; i < size; i++ {
+               limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2)
+       }
+       return reflect.ValueOf(&Nat{limbs})
+}
+
+func testModAddCommutative(a *Nat, b *Nat) bool {
+       m := maxModulus(uint(len(a.limbs)))
+       aPlusB := new(Nat).set(a)
+       aPlusB.Add(b, m)
+       bPlusA := new(Nat).set(b)
+       bPlusA.Add(a, m)
+       return aPlusB.Equal(bPlusA) == 1
+}
+
+func TestModAddCommutative(t *testing.T) {
+       err := quick.Check(testModAddCommutative, &quick.Config{})
+       if err != nil {
+               t.Error(err)
+       }
+}
+
+func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
+       m := maxModulus(uint(len(a.limbs)))
+       original := new(Nat).set(a)
+       a.Sub(b, m)
+       a.Add(b, m)
+       return a.Equal(original) == 1
+}
+
+func TestModSubThenAddIdentity(t *testing.T) {
+       err := quick.Check(testModSubThenAddIdentity, &quick.Config{})
+       if err != nil {
+               t.Error(err)
+       }
+}
+
+func testMontgomeryRoundtrip(a *Nat) bool {
+       one := &Nat{make([]uint, len(a.limbs))}
+       one.limbs[0] = 1
+       aPlusOne := new(big.Int).SetBytes(natBytes(a))
+       aPlusOne.Add(aPlusOne, big.NewInt(1))
+       m := NewModulusFromBig(aPlusOne)
+       monty := new(Nat).set(a)
+       monty.montgomeryRepresentation(m)
+       aAgain := new(Nat).set(monty)
+       aAgain.montgomeryMul(monty, one, m)
+       return a.Equal(aAgain) == 1
+}
+
+func TestMontgomeryRoundtrip(t *testing.T) {
+       err := quick.Check(testMontgomeryRoundtrip, &quick.Config{})
+       if err != nil {
+               t.Error(err)
+       }
+}
+
+func TestShiftIn(t *testing.T) {
+       if bits.UintSize != 64 {
+               t.Skip("examples are only valid in 64 bit")
+       }
+       examples := []struct {
+               m, x, expected []byte
+               y              uint64
+       }{{
+               m:        []byte{13},
+               x:        []byte{0},
+               y:        0x7FFF_FFFF_FFFF_FFFF,
+               expected: []byte{7},
+       }, {
+               m:        []byte{13},
+               x:        []byte{7},
+               y:        0x7FFF_FFFF_FFFF_FFFF,
+               expected: []byte{11},
+       }, {
+               m:        []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
+               x:        make([]byte, 9),
+               y:        0x7FFF_FFFF_FFFF_FFFF,
+               expected: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+       }, {
+               m:        []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
+               x:        []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+               y:        0,
+               expected: []byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08},
+       }}
+
+       for i, tt := range examples {
+               m := modulusFromBytes(tt.m)
+               got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m)
+               if got.Equal(natFromBytes(tt.expected).ExpandFor(m)) != 1 {
+                       t.Errorf("%d: got %x, expected %x", i, got, tt.expected)
+               }
+       }
+}
+
+func TestModulusAndNatSizes(t *testing.T) {
+       // These are 126 bit (2 * _W on 64-bit architectures) values, serialized as
+       // 128 bits worth of bytes. If leading zeroes are stripped, they fit in two
+       // limbs, if they are not, they fit in three. This can be a problem because
+       // modulus strips leading zeroes and nat does not.
+       m := modulusFromBytes([]byte{
+               0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+               0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
+       xb := []byte{0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+               0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}
+       natFromBytes(xb).ExpandFor(m) // must not panic for shrinking
+       NewNat().SetBytes(xb, m)
+}
+
+func TestSetBytes(t *testing.T) {
+       tests := []struct {
+               m, b []byte
+               fail bool
+       }{{
+               m: []byte{0xff, 0xff},
+               b: []byte{0x00, 0x01},
+       }, {
+               m:    []byte{0xff, 0xff},
+               b:    []byte{0xff, 0xff},
+               fail: true,
+       }, {
+               m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+               b: []byte{0x00, 0x01},
+       }, {
+               m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+               b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
+       }, {
+               m:    []byte{0xff, 0xff},
+               b:    []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+               fail: true,
+       }, {
+               m:    []byte{0xff, 0xff},
+               b:    []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+               fail: true,
+       }, {
+               m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+               b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
+       }, {
+               m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+               b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
+               fail: true,
+       }, {
+               m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+               b:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+               fail: true,
+       }, {
+               m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+               b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
+               fail: true,
+       }, {
+               m:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfd},
+               b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+               fail: true,
+       }}
+
+       for i, tt := range tests {
+               m := modulusFromBytes(tt.m)
+               got, err := NewNat().SetBytes(tt.b, m)
+               if err != nil {
+                       if !tt.fail {
+                               t.Errorf("%d: unexpected error: %v", i, err)
+                       }
+                       continue
+               }
+               if err == nil && tt.fail {
+                       t.Errorf("%d: unexpected success", i)
+                       continue
+               }
+               if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes {
+                       t.Errorf("%d: got %x, expected %x", i, got, expected)
+               }
+       }
+
+       f := func(xBytes []byte) bool {
+               m := maxModulus(uint(len(xBytes)*8/_W + 1))
+               got, err := NewNat().SetBytes(xBytes, m)
+               if err != nil {
+                       return false
+               }
+               return got.Equal(natFromBytes(xBytes).ExpandFor(m)) == yes
+       }
+
+       err := quick.Check(f, &quick.Config{})
+       if err != nil {
+               t.Error(err)
+       }
+}
+
+func TestExpand(t *testing.T) {
+       sliced := []uint{1, 2, 3, 4}
+       examples := []struct {
+               in  []uint
+               n   int
+               out []uint
+       }{{
+               []uint{1, 2},
+               4,
+               []uint{1, 2, 0, 0},
+       }, {
+               sliced[:2],
+               4,
+               []uint{1, 2, 0, 0},
+       }, {
+               []uint{1, 2},
+               2,
+               []uint{1, 2},
+       }}
+
+       for i, tt := range examples {
+               got := (&Nat{tt.in}).expand(tt.n)
+               if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 {
+                       t.Errorf("%d: got %x, expected %x", i, got, tt.out)
+               }
+       }
+}
+
+func TestMod(t *testing.T) {
+       m := modulusFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d})
+       x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
+       out := new(Nat)
+       out.Mod(x, m)
+       expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})
+       if out.Equal(expected) != 1 {
+               t.Errorf("%+v != %+v", out, expected)
+       }
+}
+
+func TestModSub(t *testing.T) {
+       m := modulusFromBytes([]byte{13})
+       x := &Nat{[]uint{6}}
+       y := &Nat{[]uint{7}}
+       x.Sub(y, m)
+       expected := &Nat{[]uint{12}}
+       if x.Equal(expected) != 1 {
+               t.Errorf("%+v != %+v", x, expected)
+       }
+       x.Sub(y, m)
+       expected = &Nat{[]uint{5}}
+       if x.Equal(expected) != 1 {
+               t.Errorf("%+v != %+v", x, expected)
+       }
+}
+
+func TestModAdd(t *testing.T) {
+       m := modulusFromBytes([]byte{13})
+       x := &Nat{[]uint{6}}
+       y := &Nat{[]uint{7}}
+       x.Add(y, m)
+       expected := &Nat{[]uint{0}}
+       if x.Equal(expected) != 1 {
+               t.Errorf("%+v != %+v", x, expected)
+       }
+       x.Add(y, m)
+       expected = &Nat{[]uint{7}}
+       if x.Equal(expected) != 1 {
+               t.Errorf("%+v != %+v", x, expected)
+       }
+}
+
+func TestExp(t *testing.T) {
+       m := modulusFromBytes([]byte{13})
+       x := &Nat{[]uint{3}}
+       out := &Nat{[]uint{0}}
+       out.Exp(x, []byte{12}, m)
+       expected := &Nat{[]uint{1}}
+       if out.Equal(expected) != 1 {
+               t.Errorf("%+v != %+v", out, expected)
+       }
+}
+
+func natBytes(n *Nat) []byte {
+       return n.Bytes(maxModulus(uint(len(n.limbs))))
+}
+
+func natFromBytes(b []byte) *Nat {
+       bb := new(big.Int).SetBytes(b)
+       return NewNat().setBig(bb)
+}
+
+func modulusFromBytes(b []byte) *Modulus {
+       bb := new(big.Int).SetBytes(b)
+       return NewModulusFromBig(bb)
+}
+
+// maxModulus returns the biggest modulus that can fit in n limbs.
+func maxModulus(n uint) *Modulus {
+       m := big.NewInt(1)
+       m.Lsh(m, n*_W)
+       m.Sub(m, big.NewInt(1))
+       return NewModulusFromBig(m)
+}
+
+func makeBenchmarkModulus() *Modulus {
+       return maxModulus(32)
+}
+
+func makeBenchmarkValue() *Nat {
+       x := make([]uint, 32)
+       for i := 0; i < 32; i++ {
+               x[i] = _MASK - 1
+       }
+       return &Nat{limbs: x}
+}
+
+func makeBenchmarkExponent() []byte {
+       e := make([]byte, 256)
+       for i := 0; i < 32; i++ {
+               e[i] = 0xFF
+       }
+       return e
+}
+
+func BenchmarkModAdd(b *testing.B) {
+       x := makeBenchmarkValue()
+       y := makeBenchmarkValue()
+       m := makeBenchmarkModulus()
+
+       b.ResetTimer()
+       for i := 0; i < b.N; i++ {
+               x.Add(y, m)
+       }
+}
+
+func BenchmarkModSub(b *testing.B) {
+       x := makeBenchmarkValue()
+       y := makeBenchmarkValue()
+       m := makeBenchmarkModulus()
+
+       b.ResetTimer()
+       for i := 0; i < b.N; i++ {
+               x.Sub(y, m)
+       }
+}
+
+func BenchmarkMontgomeryRepr(b *testing.B) {
+       x := makeBenchmarkValue()
+       m := makeBenchmarkModulus()
+
+       b.ResetTimer()
+       for i := 0; i < b.N; i++ {
+               x.montgomeryRepresentation(m)
+       }
+}
+
+func BenchmarkMontgomeryMul(b *testing.B) {
+       x := makeBenchmarkValue()
+       y := makeBenchmarkValue()
+       out := makeBenchmarkValue()
+       m := makeBenchmarkModulus()
+
+       b.ResetTimer()
+       for i := 0; i < b.N; i++ {
+               out.montgomeryMul(x, y, m)
+       }
+}
+
+func BenchmarkModMul(b *testing.B) {
+       x := makeBenchmarkValue()
+       y := makeBenchmarkValue()
+       m := makeBenchmarkModulus()
+
+       b.ResetTimer()
+       for i := 0; i < b.N; i++ {
+               x.Mul(y, m)
+       }
+}
+
+func BenchmarkExpBig(b *testing.B) {
+       out := new(big.Int)
+       exponentBytes := makeBenchmarkExponent()
+       x := new(big.Int).SetBytes(exponentBytes)
+       e := new(big.Int).SetBytes(exponentBytes)
+       n := new(big.Int).SetBytes(exponentBytes)
+       one := new(big.Int).SetUint64(1)
+       n.Add(n, one)
+
+       b.ResetTimer()
+       for i := 0; i < b.N; i++ {
+               out.Exp(x, e, n)
+       }
+}
+
+func BenchmarkExp(b *testing.B) {
+       x := makeBenchmarkValue()
+       e := makeBenchmarkExponent()
+       out := makeBenchmarkValue()
+       m := makeBenchmarkModulus()
+
+       b.ResetTimer()
+       for i := 0; i < b.N; i++ {
+               out.Exp(x, e, m)
+       }
+}
diff --git a/src/crypto/rsa/nat_test.go b/src/crypto/rsa/nat_test.go
deleted file mode 100644 (file)
index d72ba11..0000000
+++ /dev/null
@@ -1,384 +0,0 @@
-// Copyright 2021 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package rsa
-
-import (
-       "bytes"
-       "math/big"
-       "math/bits"
-       "math/rand"
-       "reflect"
-       "testing"
-       "testing/quick"
-)
-
-// Generate generates an even nat. It's used by testing/quick to produce random
-// *nat values for quick.Check invocations.
-func (*nat) Generate(r *rand.Rand, size int) reflect.Value {
-       limbs := make([]uint, size)
-       for i := 0; i < size; i++ {
-               limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2)
-       }
-       return reflect.ValueOf(&nat{limbs})
-}
-
-func testModAddCommutative(a *nat, b *nat) bool {
-       mLimbs := make([]uint, len(a.limbs))
-       for i := 0; i < len(mLimbs); i++ {
-               mLimbs[i] = _MASK
-       }
-       m := modulusFromNat(&nat{mLimbs})
-       aPlusB := new(nat).set(a)
-       aPlusB.modAdd(b, m)
-       bPlusA := new(nat).set(b)
-       bPlusA.modAdd(a, m)
-       return aPlusB.cmpEq(bPlusA) == 1
-}
-
-func TestModAddCommutative(t *testing.T) {
-       err := quick.Check(testModAddCommutative, &quick.Config{})
-       if err != nil {
-               t.Error(err)
-       }
-}
-
-func testModSubThenAddIdentity(a *nat, b *nat) bool {
-       mLimbs := make([]uint, len(a.limbs))
-       for i := 0; i < len(mLimbs); i++ {
-               mLimbs[i] = _MASK
-       }
-       m := modulusFromNat(&nat{mLimbs})
-       original := new(nat).set(a)
-       a.modSub(b, m)
-       a.modAdd(b, m)
-       return a.cmpEq(original) == 1
-}
-
-func TestModSubThenAddIdentity(t *testing.T) {
-       err := quick.Check(testModSubThenAddIdentity, &quick.Config{})
-       if err != nil {
-               t.Error(err)
-       }
-}
-
-func testMontgomeryRoundtrip(a *nat) bool {
-       one := &nat{make([]uint, len(a.limbs))}
-       one.limbs[0] = 1
-       aPlusOne := new(nat).set(a)
-       aPlusOne.add(1, one)
-       m := modulusFromNat(aPlusOne)
-       monty := new(nat).set(a)
-       monty.montgomeryRepresentation(m)
-       aAgain := new(nat).set(monty)
-       aAgain.montgomeryMul(monty, one, m)
-       return a.cmpEq(aAgain) == 1
-}
-
-func TestMontgomeryRoundtrip(t *testing.T) {
-       err := quick.Check(testMontgomeryRoundtrip, &quick.Config{})
-       if err != nil {
-               t.Error(err)
-       }
-}
-
-func TestFromBig(t *testing.T) {
-       expected := []byte{0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
-       theBig := new(big.Int).SetBytes(expected)
-       actual := new(nat).setBig(theBig).fillBytes(make([]byte, len(expected)))
-       if !bytes.Equal(actual, expected) {
-               t.Errorf("%+x != %+x", actual, expected)
-       }
-}
-
-func TestFillBytes(t *testing.T) {
-       xBytes := []byte{0xAA, 0xFF, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
-       x := new(nat).setBytes(xBytes)
-       for l := 20; l >= len(xBytes); l-- {
-               buf := make([]byte, l)
-               rand.Read(buf)
-               actual := x.fillBytes(buf)
-               expected := make([]byte, l)
-               copy(expected[l-len(xBytes):], xBytes)
-               if !bytes.Equal(actual, expected) {
-                       t.Errorf("%d: %+v != %+v", l, actual, expected)
-               }
-       }
-       for l := len(xBytes) - 1; l >= 0; l-- {
-               (func() {
-                       defer func() {
-                               if recover() == nil {
-                                       t.Errorf("%d: expected panic", l)
-                               }
-                       }()
-                       x.fillBytes(make([]byte, l))
-               })()
-       }
-}
-
-func TestFromBytes(t *testing.T) {
-       f := func(xBytes []byte) bool {
-               if len(xBytes) == 0 {
-                       return true
-               }
-               actual := new(nat).setBytes(xBytes).fillBytes(make([]byte, len(xBytes)))
-               if !bytes.Equal(actual, xBytes) {
-                       t.Errorf("%+x != %+x", actual, xBytes)
-                       return false
-               }
-               return true
-       }
-
-       err := quick.Check(f, &quick.Config{})
-       if err != nil {
-               t.Error(err)
-       }
-
-       f([]byte{0xFF, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88})
-       f(bytes.Repeat([]byte{0xFF}, _W))
-}
-
-func TestShiftIn(t *testing.T) {
-       if bits.UintSize != 64 {
-               t.Skip("examples are only valid in 64 bit")
-       }
-       examples := []struct {
-               m, x, expected []byte
-               y              uint64
-       }{{
-               m:        []byte{13},
-               x:        []byte{0},
-               y:        0x7FFF_FFFF_FFFF_FFFF,
-               expected: []byte{7},
-       }, {
-               m:        []byte{13},
-               x:        []byte{7},
-               y:        0x7FFF_FFFF_FFFF_FFFF,
-               expected: []byte{11},
-       }, {
-               m:        []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
-               x:        make([]byte, 9),
-               y:        0x7FFF_FFFF_FFFF_FFFF,
-               expected: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
-       }, {
-               m:        []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
-               x:        []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
-               y:        0,
-               expected: []byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08},
-       }}
-
-       for i, tt := range examples {
-               m := modulusFromNat(new(nat).setBytes(tt.m))
-               got := new(nat).setBytes(tt.x).expandFor(m).shiftIn(uint(tt.y), m)
-               if got.cmpEq(new(nat).setBytes(tt.expected).expandFor(m)) != 1 {
-                       t.Errorf("%d: got %x, expected %x", i, got, tt.expected)
-               }
-       }
-}
-
-func TestModulusAndNatSizes(t *testing.T) {
-       // These are 126 bit (2 * _W on 64-bit architectures) values, serialized as
-       // 128 bits worth of bytes. If leading zeroes are stripped, they fit in two
-       // limbs, if they are not, they fit in three. This can be a problem because
-       // modulus strips leading zeroes and nat does not.
-       m := modulusFromNat(new(nat).setBytes([]byte{
-               0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-               0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}))
-       x := new(nat).setBytes([]byte{
-               0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
-               0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe})
-       x.expandFor(m) // must not panic for shrinking
-}
-
-func TestExpand(t *testing.T) {
-       sliced := []uint{1, 2, 3, 4}
-       examples := []struct {
-               in  []uint
-               n   int
-               out []uint
-       }{{
-               []uint{1, 2},
-               4,
-               []uint{1, 2, 0, 0},
-       }, {
-               sliced[:2],
-               4,
-               []uint{1, 2, 0, 0},
-       }, {
-               []uint{1, 2},
-               2,
-               []uint{1, 2},
-       }, {
-               []uint{1, 2, 0},
-               2,
-               []uint{1, 2},
-       }}
-
-       for i, tt := range examples {
-               got := (&nat{tt.in}).expand(tt.n)
-               if len(got.limbs) != len(tt.out) || got.cmpEq(&nat{tt.out}) != 1 {
-                       t.Errorf("%d: got %x, expected %x", i, got, tt.out)
-               }
-       }
-}
-
-func TestMod(t *testing.T) {
-       m := modulusFromNat(new(nat).setBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}))
-       x := new(nat).setBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
-       out := new(nat)
-       out.mod(x, m)
-       expected := new(nat).setBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})
-       if out.cmpEq(expected) != 1 {
-               t.Errorf("%+v != %+v", out, expected)
-       }
-}
-
-func TestModSub(t *testing.T) {
-       m := modulusFromNat(&nat{[]uint{13}})
-       x := &nat{[]uint{6}}
-       y := &nat{[]uint{7}}
-       x.modSub(y, m)
-       expected := &nat{[]uint{12}}
-       if x.cmpEq(expected) != 1 {
-               t.Errorf("%+v != %+v", x, expected)
-       }
-       x.modSub(y, m)
-       expected = &nat{[]uint{5}}
-       if x.cmpEq(expected) != 1 {
-               t.Errorf("%+v != %+v", x, expected)
-       }
-}
-
-func TestModAdd(t *testing.T) {
-       m := modulusFromNat(&nat{[]uint{13}})
-       x := &nat{[]uint{6}}
-       y := &nat{[]uint{7}}
-       x.modAdd(y, m)
-       expected := &nat{[]uint{0}}
-       if x.cmpEq(expected) != 1 {
-               t.Errorf("%+v != %+v", x, expected)
-       }
-       x.modAdd(y, m)
-       expected = &nat{[]uint{7}}
-       if x.cmpEq(expected) != 1 {
-               t.Errorf("%+v != %+v", x, expected)
-       }
-}
-
-func TestExp(t *testing.T) {
-       m := modulusFromNat(&nat{[]uint{13}})
-       x := &nat{[]uint{3}}
-       out := &nat{[]uint{0}}
-       out.exp(x, []byte{12}, m)
-       expected := &nat{[]uint{1}}
-       if out.cmpEq(expected) != 1 {
-               t.Errorf("%+v != %+v", out, expected)
-       }
-}
-
-func makeBenchmarkModulus() *modulus {
-       m := make([]uint, 32)
-       for i := 0; i < 32; i++ {
-               m[i] = _MASK
-       }
-       return modulusFromNat(&nat{limbs: m})
-}
-
-func makeBenchmarkValue() *nat {
-       x := make([]uint, 32)
-       for i := 0; i < 32; i++ {
-               x[i] = _MASK - 1
-       }
-       return &nat{limbs: x}
-}
-
-func makeBenchmarkExponent() []byte {
-       e := make([]byte, 256)
-       for i := 0; i < 32; i++ {
-               e[i] = 0xFF
-       }
-       return e
-}
-
-func BenchmarkModAdd(b *testing.B) {
-       x := makeBenchmarkValue()
-       y := makeBenchmarkValue()
-       m := makeBenchmarkModulus()
-
-       b.ResetTimer()
-       for i := 0; i < b.N; i++ {
-               x.modAdd(y, m)
-       }
-}
-
-func BenchmarkModSub(b *testing.B) {
-       x := makeBenchmarkValue()
-       y := makeBenchmarkValue()
-       m := makeBenchmarkModulus()
-
-       b.ResetTimer()
-       for i := 0; i < b.N; i++ {
-               x.modSub(y, m)
-       }
-}
-
-func BenchmarkMontgomeryRepr(b *testing.B) {
-       x := makeBenchmarkValue()
-       m := makeBenchmarkModulus()
-
-       b.ResetTimer()
-       for i := 0; i < b.N; i++ {
-               x.montgomeryRepresentation(m)
-       }
-}
-
-func BenchmarkMontgomeryMul(b *testing.B) {
-       x := makeBenchmarkValue()
-       y := makeBenchmarkValue()
-       out := makeBenchmarkValue()
-       m := makeBenchmarkModulus()
-
-       b.ResetTimer()
-       for i := 0; i < b.N; i++ {
-               out.montgomeryMul(x, y, m)
-       }
-}
-
-func BenchmarkModMul(b *testing.B) {
-       x := makeBenchmarkValue()
-       y := makeBenchmarkValue()
-       m := makeBenchmarkModulus()
-
-       b.ResetTimer()
-       for i := 0; i < b.N; i++ {
-               x.modMul(y, m)
-       }
-}
-
-func BenchmarkExpBig(b *testing.B) {
-       out := new(big.Int)
-       exponentBytes := makeBenchmarkExponent()
-       x := new(big.Int).SetBytes(exponentBytes)
-       e := new(big.Int).SetBytes(exponentBytes)
-       n := new(big.Int).SetBytes(exponentBytes)
-       one := new(big.Int).SetUint64(1)
-       n.Add(n, one)
-
-       b.ResetTimer()
-       for i := 0; i < b.N; i++ {
-               out.Exp(x, e, n)
-       }
-}
-
-func BenchmarkExp(b *testing.B) {
-       x := makeBenchmarkValue()
-       e := makeBenchmarkExponent()
-       out := makeBenchmarkValue()
-       m := makeBenchmarkModulus()
-
-       b.ResetTimer()
-       for i := 0; i < b.N; i++ {
-               out.exp(x, e, m)
-       }
-}
index 971aee6a6d555e895406ac87a29ff2552b4b67c4..e51b9d2ca78780fb8e80e8550e5bbdbff7ee0b23 100644 (file)
@@ -75,7 +75,7 @@ func EncryptPKCS1v15(random io.Reader, pub *PublicKey, msg []byte) ([]byte, erro
                return boring.EncryptRSANoPadding(bkey, em)
        }
 
-       return encrypt(pub, em), nil
+       return encrypt(pub, em)
 }
 
 // DecryptPKCS1v15 decrypts a plaintext using RSA and the padding scheme from PKCS #1 v1.5.
@@ -333,7 +333,10 @@ func VerifyPKCS1v15(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte)
                return ErrVerification
        }
 
-       em := encrypt(pub, sig)
+       em, err := encrypt(pub, sig)
+       if err != nil {
+               return ErrVerification
+       }
        // EM = 0x00 || 0x01 || PS || 0x00 || T
 
        ok := subtle.ConstantTimeByteEq(em[0], 0)
index 6f1a0c12a5b9700e27c3a59f52208693494858fc..f7d23b55ef811a58881ade88e5b82d3494fedd14 100644 (file)
@@ -208,7 +208,7 @@ func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
 // given hash function. salt is a random sequence of bytes whose length will be
 // later used to verify the signature.
 func signPSSWithSalt(priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) {
-       emBits := bigBitLen(priv.N) - 1
+       emBits := priv.N.BitLen() - 1
        em, err := emsaPSSEncode(hashed, emBits, salt, hash.New())
        if err != nil {
                return nil, err
@@ -302,7 +302,7 @@ func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte,
        saltLength := opts.saltLength()
        switch saltLength {
        case PSSSaltLengthAuto:
-               saltLength = (bigBitLen(priv.N)-1+7)/8 - 2 - hash.Size()
+               saltLength = (priv.N.BitLen()-1+7)/8 - 2 - hash.Size()
                if saltLength < 0 {
                        return nil, ErrMessageTooLong
                }
@@ -349,9 +349,12 @@ func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts
                return invalidSaltLenErr
        }
 
-       emBits := bigBitLen(pub.N) - 1
+       emBits := pub.N.BitLen() - 1
        emLen := (emBits + 7) / 8
-       em := encrypt(pub, sig)
+       em, err := encrypt(pub, sig)
+       if err != nil {
+               return ErrVerification
+       }
 
        // Like in signPSSWithSalt, deal with mismatches between emLen and the size
        // of the modulus. The spec would have us wire emLen into the encoding
index 55b233152705c56ca19055e0c011eac43832a388..88a5e28e626028c2efece660b2f1820309afee7e 100644 (file)
@@ -27,6 +27,7 @@ package rsa
 
 import (
        "crypto"
+       "crypto/internal/bigmod"
        "crypto/internal/boring"
        "crypto/internal/boring/bbig"
        "crypto/internal/randutil"
@@ -54,7 +55,7 @@ type PublicKey struct {
 // Size returns the modulus size in bytes. Raw signatures and ciphertexts
 // for or by this public key will have the same size.
 func (pub *PublicKey) Size() int {
-       return (bigBitLen(pub.N) + 7) / 8
+       return (pub.N.BitLen() + 7) / 8
 }
 
 // Equal reports whether pub and x have the same value.
@@ -209,7 +210,7 @@ type PrecomputedValues struct {
        // complexity.
        CRTValues []CRTValue
 
-       n, p, q *modulus // moduli for CRT with Montgomery precomputed constants
+       n, p, q *bigmod.Modulus // moduli for CRT with Montgomery precomputed constants
 }
 
 // CRTValue contains the precomputed Chinese remainder theorem values.
@@ -314,9 +315,9 @@ func GenerateMultiPrimeKey(random io.Reader, nprimes int, bits int) (*PrivateKey
                                Dq:        Dq,
                                Qinv:      Qinv,
                                CRTValues: make([]CRTValue, 0), // non-nil, to match Precompute
-                               n:         modulusFromNat(newNat().setBig(N)),
-                               p:         modulusFromNat(newNat().setBig(P)),
-                               q:         modulusFromNat(newNat().setBig(Q)),
+                               n:         bigmod.NewModulusFromBig(N),
+                               p:         bigmod.NewModulusFromBig(P),
+                               q:         bigmod.NewModulusFromBig(Q),
                        },
                }
                return key, nil
@@ -451,15 +452,17 @@ func mgf1XOR(out []byte, hash hash.Hash, seed []byte) {
 // be returned if the size of the salt is too large.
 var ErrMessageTooLong = errors.New("crypto/rsa: message too long for RSA key size")
 
-func encrypt(pub *PublicKey, plaintext []byte) []byte {
+func encrypt(pub *PublicKey, plaintext []byte) ([]byte, error) {
        boring.Unreachable()
 
-       N := modulusFromNat(newNat().setBig(pub.N))
-       m := newNat().setBytes(plaintext).expandFor(N)
+       N := bigmod.NewModulusFromBig(pub.N)
+       m, err := bigmod.NewNat().SetBytes(plaintext, N)
+       if err != nil {
+               return nil, err
+       }
        e := intToBytes(pub.E)
 
-       out := make([]byte, modulusSize(N))
-       return newNat().exp(m, e, N).fillBytes(out)
+       return bigmod.NewNat().Exp(m, e, N).Bytes(N), nil
 }
 
 // intToBytes returns i as a big-endian slice of bytes with no leading zeroes,
@@ -538,7 +541,7 @@ func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, l
                return boring.EncryptRSANoPadding(bkey, em)
        }
 
-       return encrypt(pub, em), nil
+       return encrypt(pub, em)
 }
 
 // ErrDecryption represents a failure to decrypt a message.
@@ -553,9 +556,9 @@ var ErrVerification = errors.New("crypto/rsa: verification error")
 // in the future.
 func (priv *PrivateKey) Precompute() {
        if priv.Precomputed.n == nil && len(priv.Primes) == 2 {
-               priv.Precomputed.n = modulusFromNat(newNat().setBig(priv.N))
-               priv.Precomputed.p = modulusFromNat(newNat().setBig(priv.Primes[0]))
-               priv.Precomputed.q = modulusFromNat(newNat().setBig(priv.Primes[1]))
+               priv.Precomputed.n = bigmod.NewModulusFromBig(priv.N)
+               priv.Precomputed.p = bigmod.NewModulusFromBig(priv.Primes[0])
+               priv.Precomputed.q = bigmod.NewModulusFromBig(priv.Primes[1])
        }
 
        // Fill in the backwards-compatibility *big.Int values.
@@ -598,47 +601,53 @@ func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) {
                boring.Unreachable()
        }
 
-       N := priv.Precomputed.n
-       if N == nil {
-               N = modulusFromNat(newNat().setBig(priv.N))
-       }
-       c := newNat().setBytes(ciphertext).expandFor(N)
-       if c.cmpGeq(N.nat) == 1 {
-               return nil, ErrDecryption
-       }
-       if priv.N.Sign() == 0 {
-               return nil, ErrDecryption
-       }
-
-       var m *nat
+       var (
+               err  error
+               m, c *bigmod.Nat
+               N    *bigmod.Modulus
+               t0   = bigmod.NewNat()
+       )
        if priv.Precomputed.n == nil {
-               m = newNat().exp(c, priv.D.Bytes(), N)
+               N = bigmod.NewModulusFromBig(priv.N)
+               c, err = bigmod.NewNat().SetBytes(ciphertext, N)
+               if err != nil {
+                       return nil, ErrDecryption
+               }
+               m = bigmod.NewNat().Exp(c, priv.D.Bytes(), N)
        } else {
-               t0 := newNat()
+               N = priv.Precomputed.n
                P, Q := priv.Precomputed.p, priv.Precomputed.q
+               Qinv, err := bigmod.NewNat().SetBytes(priv.Precomputed.Qinv.Bytes(), P)
+               if err != nil {
+                       return nil, ErrDecryption
+               }
+               c, err = bigmod.NewNat().SetBytes(ciphertext, N)
+               if err != nil {
+                       return nil, ErrDecryption
+               }
+
                // m = c ^ Dp mod p
-               m = newNat().exp(t0.mod(c, P), priv.Precomputed.Dp.Bytes(), P)
+               m = bigmod.NewNat().Exp(t0.Mod(c, P), priv.Precomputed.Dp.Bytes(), P)
                // m2 = c ^ Dq mod q
-               m2 := newNat().exp(t0.mod(c, Q), priv.Precomputed.Dq.Bytes(), Q)
+               m2 := bigmod.NewNat().Exp(t0.Mod(c, Q), priv.Precomputed.Dq.Bytes(), Q)
                // m = m - m2 mod p
-               m.modSub(t0.mod(m2, P), P)
+               m.Sub(t0.Mod(m2, P), P)
                // m = m * Qinv mod p
-               m.modMul(newNat().setBig(priv.Precomputed.Qinv).expandFor(P), P)
+               m.Mul(Qinv, P)
                // m = m * q mod N
-               m.expandFor(N).modMul(t0.mod(Q.nat, N), N)
+               m.ExpandFor(N).Mul(t0.Mod(Q.Nat(), N), N)
                // m = m + m2 mod N
-               m.modAdd(m2.expandFor(N), N)
+               m.Add(m2.ExpandFor(N), N)
        }
 
        if check {
-               c1 := newNat().exp(m, intToBytes(priv.E), N)
-               if c1.cmpEq(c) != 1 {
+               c1 := bigmod.NewNat().Exp(m, intToBytes(priv.E), N)
+               if c1.Equal(c) != 1 {
                        return nil, ErrDecryption
                }
        }
 
-       out := make([]byte, modulusSize(N))
-       return m.fillBytes(out), nil
+       return m.Bytes(N), nil
 }
 
 // DecryptOAEP decrypts ciphertext using RSA-OAEP.
index 1699ab6fb14b55d25522bf8d88275b7368baa7bf..16101f043a770f82f82328e894fd5950882ddc16 100644 (file)
@@ -276,6 +276,27 @@ func testEverything(t *testing.T, priv *PrivateKey) {
                }
                hash[1] ^= 0x80
        }
+
+       // Check that an input bigger than the modulus is handled correctly,
+       // whether it is longer than the byte size of the modulus or not.
+       c := bytes.Repeat([]byte{0xff}, priv.Size())
+       err = VerifyPSS(&priv.PublicKey, crypto.SHA256, hash[:], c, opts)
+       if err == nil {
+               t.Errorf("VerifyPSS accepted a large signature")
+       }
+       _, err = DecryptPKCS1v15(nil, priv, c)
+       if err == nil {
+               t.Errorf("DecryptPKCS1v15 accepted a large ciphertext")
+       }
+       c = append(c, 0xff)
+       err = VerifyPSS(&priv.PublicKey, crypto.SHA256, hash[:], c, opts)
+       if err == nil {
+               t.Errorf("VerifyPSS accepted a long signature")
+       }
+       _, err = DecryptPKCS1v15(nil, priv, c)
+       if err == nil {
+               t.Errorf("DecryptPKCS1v15 accepted a long ciphertext")
+       }
 }
 
 func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") }
index d275822ce5399c20f459ae3c8762d546a8b79ac3..12dd0e1e2cc8abd3d7d638c6459648ce2aea3e53 100644 (file)
@@ -436,6 +436,7 @@ var depsRules = `
        < encoding/asn1
        < golang.org/x/crypto/cryptobyte/asn1
        < golang.org/x/crypto/cryptobyte
+       < crypto/internal/bigmod
        < crypto/dsa, crypto/elliptic, crypto/rsa
        < crypto/ecdsa
        < CRYPTO-MATH;
index d747078e23c5cfdde06f2ccf3981e2f109981c48..76d6eb9caed127c03bb19fe4713b9a3ffb8ef59e 100644 (file)
@@ -35,6 +35,9 @@ var intOne = &Int{false, natOne}
 //      0 if x == 0
 //     +1 if x >  0
 func (x *Int) Sign() int {
+       // This function is used in cryptographic operations. It must not leak
+       // anything but the Int's sign and bit size through side-channels. Any
+       // changes must be reviewed by a security expert.
        if len(x.abs) == 0 {
                return 0
        }
@@ -96,6 +99,9 @@ func (z *Int) Set(x *Int) *Int {
 // Bits is intended to support implementation of missing low-level Int
 // functionality outside this package; it should be avoided otherwise.
 func (x *Int) Bits() []Word {
+       // This function is used in cryptographic operations. It must not leak
+       // anything but the Int's sign and bit size through side-channels. Any
+       // changes must be reviewed by a security expert.
        return x.abs
 }
 
@@ -487,6 +493,9 @@ func (z *Int) SetBytes(buf []byte) *Int {
 //
 // To use a fixed length slice, or a preallocated one, use FillBytes.
 func (x *Int) Bytes() []byte {
+       // This function is used in cryptographic operations. It must not leak
+       // anything but the Int's sign and bit size through side-channels. Any
+       // changes must be reviewed by a security expert.
        buf := make([]byte, len(x.abs)*_S)
        return buf[x.abs.bytes(buf):]
 }
@@ -507,6 +516,9 @@ func (x *Int) FillBytes(buf []byte) []byte {
 // BitLen returns the length of the absolute value of x in bits.
 // The bit length of 0 is 0.
 func (x *Int) BitLen() int {
+       // This function is used in cryptographic operations. It must not leak
+       // anything but the Int's sign and bit size through side-channels. Any
+       // changes must be reviewed by a security expert.
        return x.abs.bitLen()
 }
 
index a7f4dc6999addccbf70b5152238c7e925137a5d4..4166a90ac0da9477d7fa63e430e47af8337c34d2 100644 (file)
@@ -661,8 +661,18 @@ var natPool sync.Pool
 // bitLen returns the length of x in bits.
 // Unlike most methods, it works even if x is not normalized.
 func (x nat) bitLen() int {
+       // This function is used in cryptographic operations. It must not leak
+       // anything but the Int's sign and bit size through side-channels. Any
+       // changes must be reviewed by a security expert.
+       //
+       // In particular, bits.Len and bits.LeadingZeros use a lookup table for the
+       // low-order bits on some architectures.
        if i := len(x) - 1; i >= 0 {
-               return i*_W + bits.Len(uint(x[i]))
+               l := i * _W
+               for top := x[i]; top != 0; top >>= 1 {
+                       l++
+               }
+               return l
        }
        return 0
 }
@@ -1292,6 +1302,9 @@ func (z nat) expNNMontgomery(x, y, m nat) nat {
 // cannot be represented in buf, bytes panics. The number i of unused
 // bytes at the beginning of buf is returned as result.
 func (z nat) bytes(buf []byte) (i int) {
+       // This function is used in cryptographic operations. It must not leak
+       // anything but the Int's sign and bit size through side-channels. Any
+       // changes must be reviewed by a security expert.
        i = len(buf)
        for _, d := range z {
                for j := 0; j < _S; j++ {