This will let us reuse it in crypto/ecdsa for the NIST scalar fields.
The main change in API is around encoding and decoding. The SetBytes +
ExpandFor sequence was hacky: SetBytes could produce a bigger size than
the modulus if leading zeroes in the top byte overflowed the limb
boundary, so ExpandFor had to check for and tolerate that. Also, the
caller was responsible for checking that the overflow was actually all
zeroes (which we weren't doing, exposing a crasher in decryption and
signature verification) and then for checking that the result was less
than the modulus. Instead, make SetBytes take a modulus and return an
error if the value overflows. Same with Bytes: we were always allocating
based on Size before FillBytes anyway, so now Bytes takes a modulus.
Finally, SetBig was almost only used for moduli, so replaced
NewModulusFromNat and SetBig with NewModulusFromBig.
Moved the constant-time bitLen to math/big.Int.BitLen. It's slower, but
BitLen is primarily used in cryptographic code, so it's safer this way.
Change-Id: Ibaf7f36d80695578cb80484167d82ce1aa83832f
Reviewed-on: https://go-review.googlesource.com/c/go/+/450055
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cherry Mui <cherryyz@google.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-package rsa
+package bigmod
import (
+ "errors"
"math/big"
"math/bits"
)
return not(choice(carry))
}
-// nat represents an arbitrary natural number
+// Nat represents an arbitrary natural number
//
-// Each nat has an announced length, which is the number of limbs it has stored.
+// Each Nat has an announced length, which is the number of limbs it has stored.
// Operations on this number are allowed to leak this length, but will not leak
// any information about the values contained in those limbs.
-type nat struct {
+type Nat struct {
// limbs is a little-endian representation in base 2^W with
// W = bits.UintSize - 1. The top bit is always unset between operations.
//
const preallocTarget = 2048
const preallocLimbs = (preallocTarget + _W) / _W
-// newNat returns a new nat with a size of zero, just like new(nat), but with
+// NewNat returns a new nat with a size of zero, just like new(Nat), but with
// the preallocated capacity to hold a number of up to preallocTarget bits.
-// newNat inlines, so the allocation can live on the stack.
-func newNat() *nat {
+// NewNat inlines, so the allocation can live on the stack.
+func NewNat() *Nat {
limbs := make([]uint, 0, preallocLimbs)
- return &nat{limbs}
+ return &Nat{limbs}
}
// expand expands x to n limbs, leaving its value unchanged.
-func (x *nat) expand(n int) *nat {
- for len(x.limbs) > n {
- if x.limbs[len(x.limbs)-1] != 0 {
- panic("rsa: internal error: shrinking nat")
- }
- x.limbs = x.limbs[:len(x.limbs)-1]
+func (x *Nat) expand(n int) *Nat {
+ if len(x.limbs) > n {
+ panic("bigmod: internal error: shrinking nat")
}
if cap(x.limbs) < n {
newLimbs := make([]uint, n)
}
// reset returns a zero nat of n limbs, reusing x's storage if n <= cap(x.limbs).
-func (x *nat) reset(n int) *nat {
+func (x *Nat) reset(n int) *Nat {
if cap(x.limbs) < n {
x.limbs = make([]uint, n)
return x
}
// set assigns x = y, optionally resizing x to the appropriate size.
-func (x *nat) set(y *nat) *nat {
+func (x *Nat) set(y *Nat) *Nat {
x.reset(len(y.limbs))
copy(x.limbs, y.limbs)
return x
}
-// set assigns x = n, optionally resizing n to the appropriate size.
+// setBig assigns x = n, optionally resizing n to the appropriate size.
//
// The announced length of x is set based on the actual bit size of the input,
// ignoring leading zeroes.
-func (x *nat) setBig(n *big.Int) *nat {
- bitSize := bigBitLen(n)
- requiredLimbs := (bitSize + _W - 1) / _W
+func (x *Nat) setBig(n *big.Int) *Nat {
+ requiredLimbs := (n.BitLen() + _W - 1) / _W
x.reset(requiredLimbs)
outI := 0
return x
}
-// fillBytes sets bytes to x as a zero-extended big-endian byte slice.
+// Bytes returns x as a zero-extended big-endian byte slice. The size of the
+// slice will match the size of m.
//
-// If bytes is not long enough to contain the number or at least len(x.limbs)-1
-// limbs, or has zero length, fillBytes will panic.
-func (x *nat) fillBytes(bytes []byte) []byte {
- if len(bytes) == 0 {
- panic("nat: fillBytes invoked with too small buffer")
- }
- for i := range bytes {
- bytes[i] = 0
- }
+// x must have the same size as m and it must be reduced modulo m.
+func (x *Nat) Bytes(m *Modulus) []byte {
+ bytes := make([]byte, m.Size())
shift := 0
outI := len(bytes) - 1
- for i, limb := range x.limbs {
+ for _, limb := range x.limbs {
remainingBits := _W
for remainingBits >= 8 {
bytes[outI] |= byte(limb) << shift
shift = 0
outI--
if outI < 0 {
- if limb != 0 || i < len(x.limbs)-1 {
- panic("nat: fillBytes invoked with too small buffer")
- }
return bytes
}
}
return bytes
}
-// setBytes assigns x = b, where b is a slice of big-endian bytes, optionally
-// resizing n to the appropriate size.
+// SetBytes assigns x = b, where b is a slice of big-endian bytes.
+// SetBytes returns an error if b > m.
//
-// The announced length of the output depends only on the length of b. Unlike
-// big.Int, creating a nat will not remove leading zeros.
-func (x *nat) setBytes(b []byte) *nat {
- bitSize := len(b) * 8
- requiredLimbs := (bitSize + _W - 1) / _W
- x.reset(requiredLimbs)
-
+// The output will be resized to the size of m and overwritten.
+func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
outI := 0
shift := 0
+ x.resetFor(m)
for i := len(b) - 1; i >= 0; i-- {
bi := b[i]
x.limbs[outI] |= uint(bi) << shift
if shift >= _W {
shift -= _W
x.limbs[outI] &= _MASK
+ overflow := bi >> (8 - shift)
outI++
- if shift > 0 {
- x.limbs[outI] = uint(bi) >> (8 - shift)
+ if outI >= len(x.limbs) {
+ if overflow > 0 || i > 0 {
+ return nil, errors.New("input overflows the modulus")
+ }
+ break
}
+ x.limbs[outI] = uint(overflow)
}
}
- return x
+ if x.cmpGeq(m.nat) == yes {
+ return nil, errors.New("input overflows the modulus")
+ }
+ return x, nil
}
-// cmpEq returns 1 if x == y, and 0 otherwise.
+// Equal returns 1 if x == y, and 0 otherwise.
//
// Both operands must have the same announced length.
-func (x *nat) cmpEq(y *nat) choice {
+func (x *Nat) Equal(y *Nat) choice {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
// cmpGeq returns 1 if x >= y, and 0 otherwise.
//
// Both operands must have the same announced length.
-func (x *nat) cmpGeq(y *nat) choice {
+func (x *Nat) cmpGeq(y *Nat) choice {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
// assign sets x <- y if on == 1, and does nothing otherwise.
//
// Both operands must have the same announced length.
-func (x *nat) assign(on choice, y *nat) *nat {
+func (x *Nat) assign(on choice, y *Nat) *Nat {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
// carry of the addition regardless of on.
//
// Both operands must have the same announced length.
-func (x *nat) add(on choice, y *nat) (c uint) {
+func (x *Nat) add(on choice, y *Nat) (c uint) {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
// borrow of the subtraction regardless of on.
//
// Both operands must have the same announced length.
-func (x *nat) sub(on choice, y *nat) (c uint) {
+func (x *Nat) sub(on choice, y *Nat) (c uint) {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
return
}
-// modulus is used for modular arithmetic, precomputing relevant constants.
+// Modulus is used for modular arithmetic, precomputing relevant constants.
//
// Moduli are assumed to be odd numbers. Moduli can also leak the exact
// number of bits needed to store their value, and are stored without padding.
//
// Their actual value is still kept secret.
-type modulus struct {
+type Modulus struct {
// The underlying natural number for this modulus.
//
// This will be stored without any padding, and shouldn't alias with any
// other natural number being used.
- nat *nat
+ nat *Nat
leading int // number of leading zeros in the modulus
m0inv uint // -nat.limbs[0]⁻¹ mod _W
- RR *nat // R*R for montgomeryRepresentation
+ rr *Nat // R*R for montgomeryRepresentation
}
// rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
-func rr(m *modulus) *nat {
- rr := newNat().expandFor(m)
+func rr(m *Modulus) *Nat {
+ rr := NewNat().ExpandFor(m)
// R*R is 2^(2 * _W * n). We can safely get 2^(_W * (n - 1)) by setting the
// most significant limb to 1. We then get to R*R by shifting left by _W
// n + 1 times.
return (1 << _W) - (y & _MASK)
}
-// modulusFromNat creates a new modulus from a nat.
+// NewModulusFromBig creates a new Modulus from a [big.Int].
//
-// The nat should be odd, nonzero, and the number of significant bits in the
-// number should be leakable. The nat shouldn't be reused.
-func modulusFromNat(nat *nat) *modulus {
- m := &modulus{}
- m.nat = nat
- size := len(m.nat.limbs)
- for m.nat.limbs[size-1] == 0 {
- size--
- }
- m.nat.limbs = m.nat.limbs[:size]
- m.leading = _W - bitLen(m.nat.limbs[size-1])
+// The Int must be odd. The number of significant bits must be leakable.
+func NewModulusFromBig(n *big.Int) *Modulus {
+ m := &Modulus{}
+ m.nat = NewNat().setBig(n)
+ m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
m.m0inv = minusInverseModW(m.nat.limbs[0])
- m.RR = rr(m)
+ m.rr = rr(m)
return m
}
return len
}
-// bigBitLen is a version of big.Int.BitLen that only leaks the bit length of x,
-// but not its value. big.Int.BitLen uses bits.Len.
-func bigBitLen(x *big.Int) int {
- xLimbs := x.Bits()
- fullLimbs := len(xLimbs) - 1
- topLimb := uint(xLimbs[len(xLimbs)-1])
- return fullLimbs*bits.UintSize + bitLen(topLimb)
-}
-
-// modulusSize returns the size of m in bytes.
-func modulusSize(m *modulus) int {
+// Size returns the size of m in bytes.
+func (m *Modulus) Size() int {
bits := len(m.nat.limbs)*_W - int(m.leading)
return (bits + 7) / 8
}
+// Nat returns m as a Nat. The return value must not be written to.
+func (m *Modulus) Nat() *Nat {
+ return m.nat
+}
+
// shiftIn calculates x = x << _W + y mod m.
//
// This assumes that x is already reduced mod m, and that y < 2^_W.
-func (x *nat) shiftIn(y uint, m *modulus) *nat {
- d := newNat().resetFor(m)
+func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
+ d := NewNat().resetFor(m)
// Eliminate bounds checks in the loop.
size := len(m.nat.limbs)
dLimbs[i] = res & _MASK
borrow = res >> _W
}
- // See modAdd for how carry (aka overflow), borrow (aka underflow), and
+ // See Add for how carry (aka overflow), borrow (aka underflow), and
// needSubtraction relate.
needSubtraction = ctEq(carry, borrow)
}
return x.assign(needSubtraction, d)
}
-// mod calculates out = x mod m.
+// Mod calculates out = x mod m.
//
// This works regardless how large the value of x is.
//
// The output will be resized to the size of m and overwritten.
-func (out *nat) mod(x *nat, m *modulus) *nat {
+func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
out.resetFor(m)
// Working our way from the most significant to the least significant limb,
// we can insert each limb at the least significant position, shifting all
return out
}
-// expandFor ensures out has the right size to work with operations modulo m.
+// ExpandFor ensures out has the right size to work with operations modulo m.
//
-// This assumes that out has as many or fewer limbs than m, or that the extra
-// limbs are all zero (which may happen when decoding a value that has leading
-// zeroes in its bytes representation that spill over the limb threshold).
-func (out *nat) expandFor(m *modulus) *nat {
+// The announced size of out must be smaller than or equal to that of m.
+func (out *Nat) ExpandFor(m *Modulus) *Nat {
return out.expand(len(m.nat.limbs))
}
// resetFor ensures out has the right size to work with operations modulo m.
//
// out is zeroed and may start at any size.
-func (out *nat) resetFor(m *modulus) *nat {
+func (out *Nat) resetFor(m *Modulus) *Nat {
return out.reset(len(m.nat.limbs))
}
-// modSub computes x = x - y mod m.
+// Sub computes x = x - y mod m.
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
-func (x *nat) modSub(y *nat, m *modulus) *nat {
+func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
underflow := x.sub(yes, y)
// If the subtraction underflowed, add m.
x.add(choice(underflow), m.nat)
return x
}
-// modAdd computes x = x + y mod m.
+// Add computes x = x + y mod m.
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
-func (x *nat) modAdd(y *nat, m *modulus) *nat {
+func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
overflow := x.add(yes, y)
underflow := not(x.cmpGeq(m.nat)) // x < m
// numbers in this representation.
//
// This assumes that x is already reduced mod m.
-func (x *nat) montgomeryRepresentation(m *modulus) *nat {
+func (x *Nat) montgomeryRepresentation(m *Modulus) *Nat {
// A Montgomery multiplication (which computes a * b / R) by R * R works out
// to a multiplication by R, which takes the value out of the Montgomery domain.
- return x.montgomeryMul(newNat().set(x), m.RR, m)
+ return x.montgomeryMul(NewNat().set(x), m.rr, m)
}
// montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and
// n = len(m.nat.limbs).
//
// This assumes that x is already reduced mod m.
-func (x *nat) montgomeryReduction(m *modulus) *nat {
+func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
// By Montgomery multiplying with 1 not in Montgomery representation, we
// convert out back from Montgomery representation, because it works out to
// dividing by R.
- t0 := newNat().set(x)
- t1 := newNat().expandFor(m)
+ t0 := NewNat().set(x)
+ t1 := NewNat().ExpandFor(m)
t1.limbs[0] = 1
return x.montgomeryMul(t0, t1, m)
}
//
// All inputs should be the same length, not aliasing d, and already
// reduced modulo m. d will be resized to the size of m and overwritten.
-func (d *nat) montgomeryMul(a *nat, b *nat, m *modulus) *nat {
+func (d *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
// See https://bearssl.org/bigint.html#montgomery-reduction-and-multiplication
// for a description of the algorithm.
dLimbs[size-1] = z & _MASK
overflow = z >> _W // overflow <= 1
}
- // See modAdd for how overflow, underflow, and needSubtraction relate.
+ // See Add for how overflow, underflow, and needSubtraction relate.
underflow := not(d.cmpGeq(m.nat)) // d < m
needSubtraction := ctEq(overflow, uint(underflow))
d.sub(needSubtraction, m.nat)
return d
}
-// modMul calculates x *= y mod m.
+// Mul calculates x *= y mod m.
//
// x and y must already be reduced modulo m, they must share its announced
// length, and they may not alias.
-func (x *nat) modMul(y *nat, m *modulus) *nat {
+func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
// A Montgomery multiplication by a value out of the Montgomery domain
// takes the result out of Montgomery representation.
- xR := newNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
+ xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
}
-// exp calculates out = x^e mod m.
+// Exp calculates out = x^e mod m.
//
// The exponent e is represented in big-endian order. The output will be resized
// to the size of m and overwritten. x must already be reduced modulo m.
-func (out *nat) exp(x *nat, e []byte, m *modulus) *nat {
+func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
// We use a 4 bit window. For our RSA workload, 4 bit windows are faster
// than 2 bit windows, but use an extra 12 nats worth of scratch space.
// Using bit sizes that don't divide 8 are more complex to implement.
- table := [(1 << 4) - 1]*nat{ // table[i] = x ^ (i+1)
+ table := [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1)
// newNat calls are unrolled so they are allocated on the stack.
- newNat(), newNat(), newNat(), newNat(), newNat(),
- newNat(), newNat(), newNat(), newNat(), newNat(),
- newNat(), newNat(), newNat(), newNat(), newNat(),
+ NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
+ NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
+ NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
}
table[0].set(x).montgomeryRepresentation(m)
for i := 1; i < len(table); i++ {
out.resetFor(m)
out.limbs[0] = 1
out.montgomeryRepresentation(m)
- t0 := newNat().expandFor(m)
- t1 := newNat().expandFor(m)
+ t0 := NewNat().ExpandFor(m)
+ t1 := NewNat().ExpandFor(m)
for _, b := range e {
for _, j := range []int{4, 0} {
// Square four times.
--- /dev/null
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package bigmod
+
+import (
+ "math/big"
+ "math/bits"
+ "math/rand"
+ "reflect"
+ "testing"
+ "testing/quick"
+)
+
+// Generate generates an even nat. It's used by testing/quick to produce random
+// *nat values for quick.Check invocations.
+func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
+ limbs := make([]uint, size)
+ for i := 0; i < size; i++ {
+ limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2)
+ }
+ return reflect.ValueOf(&Nat{limbs})
+}
+
+func testModAddCommutative(a *Nat, b *Nat) bool {
+ m := maxModulus(uint(len(a.limbs)))
+ aPlusB := new(Nat).set(a)
+ aPlusB.Add(b, m)
+ bPlusA := new(Nat).set(b)
+ bPlusA.Add(a, m)
+ return aPlusB.Equal(bPlusA) == 1
+}
+
+func TestModAddCommutative(t *testing.T) {
+ err := quick.Check(testModAddCommutative, &quick.Config{})
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
+ m := maxModulus(uint(len(a.limbs)))
+ original := new(Nat).set(a)
+ a.Sub(b, m)
+ a.Add(b, m)
+ return a.Equal(original) == 1
+}
+
+func TestModSubThenAddIdentity(t *testing.T) {
+ err := quick.Check(testModSubThenAddIdentity, &quick.Config{})
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func testMontgomeryRoundtrip(a *Nat) bool {
+ one := &Nat{make([]uint, len(a.limbs))}
+ one.limbs[0] = 1
+ aPlusOne := new(big.Int).SetBytes(natBytes(a))
+ aPlusOne.Add(aPlusOne, big.NewInt(1))
+ m := NewModulusFromBig(aPlusOne)
+ monty := new(Nat).set(a)
+ monty.montgomeryRepresentation(m)
+ aAgain := new(Nat).set(monty)
+ aAgain.montgomeryMul(monty, one, m)
+ return a.Equal(aAgain) == 1
+}
+
+func TestMontgomeryRoundtrip(t *testing.T) {
+ err := quick.Check(testMontgomeryRoundtrip, &quick.Config{})
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func TestShiftIn(t *testing.T) {
+ if bits.UintSize != 64 {
+ t.Skip("examples are only valid in 64 bit")
+ }
+ examples := []struct {
+ m, x, expected []byte
+ y uint64
+ }{{
+ m: []byte{13},
+ x: []byte{0},
+ y: 0x7FFF_FFFF_FFFF_FFFF,
+ expected: []byte{7},
+ }, {
+ m: []byte{13},
+ x: []byte{7},
+ y: 0x7FFF_FFFF_FFFF_FFFF,
+ expected: []byte{11},
+ }, {
+ m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
+ x: make([]byte, 9),
+ y: 0x7FFF_FFFF_FFFF_FFFF,
+ expected: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ }, {
+ m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
+ x: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ y: 0,
+ expected: []byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08},
+ }}
+
+ for i, tt := range examples {
+ m := modulusFromBytes(tt.m)
+ got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m)
+ if got.Equal(natFromBytes(tt.expected).ExpandFor(m)) != 1 {
+ t.Errorf("%d: got %x, expected %x", i, got, tt.expected)
+ }
+ }
+}
+
+func TestModulusAndNatSizes(t *testing.T) {
+ // These are 126 bit (2 * _W on 64-bit architectures) values, serialized as
+ // 128 bits worth of bytes. If leading zeroes are stripped, they fit in two
+ // limbs, if they are not, they fit in three. This can be a problem because
+ // modulus strips leading zeroes and nat does not.
+ m := modulusFromBytes([]byte{
+ 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
+ xb := []byte{0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}
+ natFromBytes(xb).ExpandFor(m) // must not panic for shrinking
+ NewNat().SetBytes(xb, m)
+}
+
+func TestSetBytes(t *testing.T) {
+ tests := []struct {
+ m, b []byte
+ fail bool
+ }{{
+ m: []byte{0xff, 0xff},
+ b: []byte{0x00, 0x01},
+ }, {
+ m: []byte{0xff, 0xff},
+ b: []byte{0xff, 0xff},
+ fail: true,
+ }, {
+ m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ b: []byte{0x00, 0x01},
+ }, {
+ m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
+ }, {
+ m: []byte{0xff, 0xff},
+ b: []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ fail: true,
+ }, {
+ m: []byte{0xff, 0xff},
+ b: []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ fail: true,
+ }, {
+ m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
+ }, {
+ m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
+ fail: true,
+ }, {
+ m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ fail: true,
+ }, {
+ m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
+ fail: true,
+ }, {
+ m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfd},
+ b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ fail: true,
+ }}
+
+ for i, tt := range tests {
+ m := modulusFromBytes(tt.m)
+ got, err := NewNat().SetBytes(tt.b, m)
+ if err != nil {
+ if !tt.fail {
+ t.Errorf("%d: unexpected error: %v", i, err)
+ }
+ continue
+ }
+ if err == nil && tt.fail {
+ t.Errorf("%d: unexpected success", i)
+ continue
+ }
+ if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes {
+ t.Errorf("%d: got %x, expected %x", i, got, expected)
+ }
+ }
+
+ f := func(xBytes []byte) bool {
+ m := maxModulus(uint(len(xBytes)*8/_W + 1))
+ got, err := NewNat().SetBytes(xBytes, m)
+ if err != nil {
+ return false
+ }
+ return got.Equal(natFromBytes(xBytes).ExpandFor(m)) == yes
+ }
+
+ err := quick.Check(f, &quick.Config{})
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func TestExpand(t *testing.T) {
+ sliced := []uint{1, 2, 3, 4}
+ examples := []struct {
+ in []uint
+ n int
+ out []uint
+ }{{
+ []uint{1, 2},
+ 4,
+ []uint{1, 2, 0, 0},
+ }, {
+ sliced[:2],
+ 4,
+ []uint{1, 2, 0, 0},
+ }, {
+ []uint{1, 2},
+ 2,
+ []uint{1, 2},
+ }}
+
+ for i, tt := range examples {
+ got := (&Nat{tt.in}).expand(tt.n)
+ if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 {
+ t.Errorf("%d: got %x, expected %x", i, got, tt.out)
+ }
+ }
+}
+
+func TestMod(t *testing.T) {
+ m := modulusFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d})
+ x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
+ out := new(Nat)
+ out.Mod(x, m)
+ expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})
+ if out.Equal(expected) != 1 {
+ t.Errorf("%+v != %+v", out, expected)
+ }
+}
+
+func TestModSub(t *testing.T) {
+ m := modulusFromBytes([]byte{13})
+ x := &Nat{[]uint{6}}
+ y := &Nat{[]uint{7}}
+ x.Sub(y, m)
+ expected := &Nat{[]uint{12}}
+ if x.Equal(expected) != 1 {
+ t.Errorf("%+v != %+v", x, expected)
+ }
+ x.Sub(y, m)
+ expected = &Nat{[]uint{5}}
+ if x.Equal(expected) != 1 {
+ t.Errorf("%+v != %+v", x, expected)
+ }
+}
+
+func TestModAdd(t *testing.T) {
+ m := modulusFromBytes([]byte{13})
+ x := &Nat{[]uint{6}}
+ y := &Nat{[]uint{7}}
+ x.Add(y, m)
+ expected := &Nat{[]uint{0}}
+ if x.Equal(expected) != 1 {
+ t.Errorf("%+v != %+v", x, expected)
+ }
+ x.Add(y, m)
+ expected = &Nat{[]uint{7}}
+ if x.Equal(expected) != 1 {
+ t.Errorf("%+v != %+v", x, expected)
+ }
+}
+
+func TestExp(t *testing.T) {
+ m := modulusFromBytes([]byte{13})
+ x := &Nat{[]uint{3}}
+ out := &Nat{[]uint{0}}
+ out.Exp(x, []byte{12}, m)
+ expected := &Nat{[]uint{1}}
+ if out.Equal(expected) != 1 {
+ t.Errorf("%+v != %+v", out, expected)
+ }
+}
+
+func natBytes(n *Nat) []byte {
+ return n.Bytes(maxModulus(uint(len(n.limbs))))
+}
+
+func natFromBytes(b []byte) *Nat {
+ bb := new(big.Int).SetBytes(b)
+ return NewNat().setBig(bb)
+}
+
+func modulusFromBytes(b []byte) *Modulus {
+ bb := new(big.Int).SetBytes(b)
+ return NewModulusFromBig(bb)
+}
+
+// maxModulus returns the biggest modulus that can fit in n limbs.
+func maxModulus(n uint) *Modulus {
+ m := big.NewInt(1)
+ m.Lsh(m, n*_W)
+ m.Sub(m, big.NewInt(1))
+ return NewModulusFromBig(m)
+}
+
+func makeBenchmarkModulus() *Modulus {
+ return maxModulus(32)
+}
+
+func makeBenchmarkValue() *Nat {
+ x := make([]uint, 32)
+ for i := 0; i < 32; i++ {
+ x[i] = _MASK - 1
+ }
+ return &Nat{limbs: x}
+}
+
+func makeBenchmarkExponent() []byte {
+ e := make([]byte, 256)
+ for i := 0; i < 32; i++ {
+ e[i] = 0xFF
+ }
+ return e
+}
+
+func BenchmarkModAdd(b *testing.B) {
+ x := makeBenchmarkValue()
+ y := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x.Add(y, m)
+ }
+}
+
+func BenchmarkModSub(b *testing.B) {
+ x := makeBenchmarkValue()
+ y := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x.Sub(y, m)
+ }
+}
+
+func BenchmarkMontgomeryRepr(b *testing.B) {
+ x := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x.montgomeryRepresentation(m)
+ }
+}
+
+func BenchmarkMontgomeryMul(b *testing.B) {
+ x := makeBenchmarkValue()
+ y := makeBenchmarkValue()
+ out := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ out.montgomeryMul(x, y, m)
+ }
+}
+
+func BenchmarkModMul(b *testing.B) {
+ x := makeBenchmarkValue()
+ y := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x.Mul(y, m)
+ }
+}
+
+func BenchmarkExpBig(b *testing.B) {
+ out := new(big.Int)
+ exponentBytes := makeBenchmarkExponent()
+ x := new(big.Int).SetBytes(exponentBytes)
+ e := new(big.Int).SetBytes(exponentBytes)
+ n := new(big.Int).SetBytes(exponentBytes)
+ one := new(big.Int).SetUint64(1)
+ n.Add(n, one)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ out.Exp(x, e, n)
+ }
+}
+
+func BenchmarkExp(b *testing.B) {
+ x := makeBenchmarkValue()
+ e := makeBenchmarkExponent()
+ out := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ out.Exp(x, e, m)
+ }
+}
+++ /dev/null
-// Copyright 2021 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package rsa
-
-import (
- "bytes"
- "math/big"
- "math/bits"
- "math/rand"
- "reflect"
- "testing"
- "testing/quick"
-)
-
-// Generate generates an even nat. It's used by testing/quick to produce random
-// *nat values for quick.Check invocations.
-func (*nat) Generate(r *rand.Rand, size int) reflect.Value {
- limbs := make([]uint, size)
- for i := 0; i < size; i++ {
- limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2)
- }
- return reflect.ValueOf(&nat{limbs})
-}
-
-func testModAddCommutative(a *nat, b *nat) bool {
- mLimbs := make([]uint, len(a.limbs))
- for i := 0; i < len(mLimbs); i++ {
- mLimbs[i] = _MASK
- }
- m := modulusFromNat(&nat{mLimbs})
- aPlusB := new(nat).set(a)
- aPlusB.modAdd(b, m)
- bPlusA := new(nat).set(b)
- bPlusA.modAdd(a, m)
- return aPlusB.cmpEq(bPlusA) == 1
-}
-
-func TestModAddCommutative(t *testing.T) {
- err := quick.Check(testModAddCommutative, &quick.Config{})
- if err != nil {
- t.Error(err)
- }
-}
-
-func testModSubThenAddIdentity(a *nat, b *nat) bool {
- mLimbs := make([]uint, len(a.limbs))
- for i := 0; i < len(mLimbs); i++ {
- mLimbs[i] = _MASK
- }
- m := modulusFromNat(&nat{mLimbs})
- original := new(nat).set(a)
- a.modSub(b, m)
- a.modAdd(b, m)
- return a.cmpEq(original) == 1
-}
-
-func TestModSubThenAddIdentity(t *testing.T) {
- err := quick.Check(testModSubThenAddIdentity, &quick.Config{})
- if err != nil {
- t.Error(err)
- }
-}
-
-func testMontgomeryRoundtrip(a *nat) bool {
- one := &nat{make([]uint, len(a.limbs))}
- one.limbs[0] = 1
- aPlusOne := new(nat).set(a)
- aPlusOne.add(1, one)
- m := modulusFromNat(aPlusOne)
- monty := new(nat).set(a)
- monty.montgomeryRepresentation(m)
- aAgain := new(nat).set(monty)
- aAgain.montgomeryMul(monty, one, m)
- return a.cmpEq(aAgain) == 1
-}
-
-func TestMontgomeryRoundtrip(t *testing.T) {
- err := quick.Check(testMontgomeryRoundtrip, &quick.Config{})
- if err != nil {
- t.Error(err)
- }
-}
-
-func TestFromBig(t *testing.T) {
- expected := []byte{0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
- theBig := new(big.Int).SetBytes(expected)
- actual := new(nat).setBig(theBig).fillBytes(make([]byte, len(expected)))
- if !bytes.Equal(actual, expected) {
- t.Errorf("%+x != %+x", actual, expected)
- }
-}
-
-func TestFillBytes(t *testing.T) {
- xBytes := []byte{0xAA, 0xFF, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
- x := new(nat).setBytes(xBytes)
- for l := 20; l >= len(xBytes); l-- {
- buf := make([]byte, l)
- rand.Read(buf)
- actual := x.fillBytes(buf)
- expected := make([]byte, l)
- copy(expected[l-len(xBytes):], xBytes)
- if !bytes.Equal(actual, expected) {
- t.Errorf("%d: %+v != %+v", l, actual, expected)
- }
- }
- for l := len(xBytes) - 1; l >= 0; l-- {
- (func() {
- defer func() {
- if recover() == nil {
- t.Errorf("%d: expected panic", l)
- }
- }()
- x.fillBytes(make([]byte, l))
- })()
- }
-}
-
-func TestFromBytes(t *testing.T) {
- f := func(xBytes []byte) bool {
- if len(xBytes) == 0 {
- return true
- }
- actual := new(nat).setBytes(xBytes).fillBytes(make([]byte, len(xBytes)))
- if !bytes.Equal(actual, xBytes) {
- t.Errorf("%+x != %+x", actual, xBytes)
- return false
- }
- return true
- }
-
- err := quick.Check(f, &quick.Config{})
- if err != nil {
- t.Error(err)
- }
-
- f([]byte{0xFF, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88})
- f(bytes.Repeat([]byte{0xFF}, _W))
-}
-
-func TestShiftIn(t *testing.T) {
- if bits.UintSize != 64 {
- t.Skip("examples are only valid in 64 bit")
- }
- examples := []struct {
- m, x, expected []byte
- y uint64
- }{{
- m: []byte{13},
- x: []byte{0},
- y: 0x7FFF_FFFF_FFFF_FFFF,
- expected: []byte{7},
- }, {
- m: []byte{13},
- x: []byte{7},
- y: 0x7FFF_FFFF_FFFF_FFFF,
- expected: []byte{11},
- }, {
- m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
- x: make([]byte, 9),
- y: 0x7FFF_FFFF_FFFF_FFFF,
- expected: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
- }, {
- m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
- x: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
- y: 0,
- expected: []byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08},
- }}
-
- for i, tt := range examples {
- m := modulusFromNat(new(nat).setBytes(tt.m))
- got := new(nat).setBytes(tt.x).expandFor(m).shiftIn(uint(tt.y), m)
- if got.cmpEq(new(nat).setBytes(tt.expected).expandFor(m)) != 1 {
- t.Errorf("%d: got %x, expected %x", i, got, tt.expected)
- }
- }
-}
-
-func TestModulusAndNatSizes(t *testing.T) {
- // These are 126 bit (2 * _W on 64-bit architectures) values, serialized as
- // 128 bits worth of bytes. If leading zeroes are stripped, they fit in two
- // limbs, if they are not, they fit in three. This can be a problem because
- // modulus strips leading zeroes and nat does not.
- m := modulusFromNat(new(nat).setBytes([]byte{
- 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
- 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}))
- x := new(nat).setBytes([]byte{
- 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
- 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe})
- x.expandFor(m) // must not panic for shrinking
-}
-
-func TestExpand(t *testing.T) {
- sliced := []uint{1, 2, 3, 4}
- examples := []struct {
- in []uint
- n int
- out []uint
- }{{
- []uint{1, 2},
- 4,
- []uint{1, 2, 0, 0},
- }, {
- sliced[:2],
- 4,
- []uint{1, 2, 0, 0},
- }, {
- []uint{1, 2},
- 2,
- []uint{1, 2},
- }, {
- []uint{1, 2, 0},
- 2,
- []uint{1, 2},
- }}
-
- for i, tt := range examples {
- got := (&nat{tt.in}).expand(tt.n)
- if len(got.limbs) != len(tt.out) || got.cmpEq(&nat{tt.out}) != 1 {
- t.Errorf("%d: got %x, expected %x", i, got, tt.out)
- }
- }
-}
-
-func TestMod(t *testing.T) {
- m := modulusFromNat(new(nat).setBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}))
- x := new(nat).setBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
- out := new(nat)
- out.mod(x, m)
- expected := new(nat).setBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})
- if out.cmpEq(expected) != 1 {
- t.Errorf("%+v != %+v", out, expected)
- }
-}
-
-func TestModSub(t *testing.T) {
- m := modulusFromNat(&nat{[]uint{13}})
- x := &nat{[]uint{6}}
- y := &nat{[]uint{7}}
- x.modSub(y, m)
- expected := &nat{[]uint{12}}
- if x.cmpEq(expected) != 1 {
- t.Errorf("%+v != %+v", x, expected)
- }
- x.modSub(y, m)
- expected = &nat{[]uint{5}}
- if x.cmpEq(expected) != 1 {
- t.Errorf("%+v != %+v", x, expected)
- }
-}
-
-func TestModAdd(t *testing.T) {
- m := modulusFromNat(&nat{[]uint{13}})
- x := &nat{[]uint{6}}
- y := &nat{[]uint{7}}
- x.modAdd(y, m)
- expected := &nat{[]uint{0}}
- if x.cmpEq(expected) != 1 {
- t.Errorf("%+v != %+v", x, expected)
- }
- x.modAdd(y, m)
- expected = &nat{[]uint{7}}
- if x.cmpEq(expected) != 1 {
- t.Errorf("%+v != %+v", x, expected)
- }
-}
-
-func TestExp(t *testing.T) {
- m := modulusFromNat(&nat{[]uint{13}})
- x := &nat{[]uint{3}}
- out := &nat{[]uint{0}}
- out.exp(x, []byte{12}, m)
- expected := &nat{[]uint{1}}
- if out.cmpEq(expected) != 1 {
- t.Errorf("%+v != %+v", out, expected)
- }
-}
-
-func makeBenchmarkModulus() *modulus {
- m := make([]uint, 32)
- for i := 0; i < 32; i++ {
- m[i] = _MASK
- }
- return modulusFromNat(&nat{limbs: m})
-}
-
-func makeBenchmarkValue() *nat {
- x := make([]uint, 32)
- for i := 0; i < 32; i++ {
- x[i] = _MASK - 1
- }
- return &nat{limbs: x}
-}
-
-func makeBenchmarkExponent() []byte {
- e := make([]byte, 256)
- for i := 0; i < 32; i++ {
- e[i] = 0xFF
- }
- return e
-}
-
-func BenchmarkModAdd(b *testing.B) {
- x := makeBenchmarkValue()
- y := makeBenchmarkValue()
- m := makeBenchmarkModulus()
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- x.modAdd(y, m)
- }
-}
-
-func BenchmarkModSub(b *testing.B) {
- x := makeBenchmarkValue()
- y := makeBenchmarkValue()
- m := makeBenchmarkModulus()
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- x.modSub(y, m)
- }
-}
-
-func BenchmarkMontgomeryRepr(b *testing.B) {
- x := makeBenchmarkValue()
- m := makeBenchmarkModulus()
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- x.montgomeryRepresentation(m)
- }
-}
-
-func BenchmarkMontgomeryMul(b *testing.B) {
- x := makeBenchmarkValue()
- y := makeBenchmarkValue()
- out := makeBenchmarkValue()
- m := makeBenchmarkModulus()
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- out.montgomeryMul(x, y, m)
- }
-}
-
-func BenchmarkModMul(b *testing.B) {
- x := makeBenchmarkValue()
- y := makeBenchmarkValue()
- m := makeBenchmarkModulus()
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- x.modMul(y, m)
- }
-}
-
-func BenchmarkExpBig(b *testing.B) {
- out := new(big.Int)
- exponentBytes := makeBenchmarkExponent()
- x := new(big.Int).SetBytes(exponentBytes)
- e := new(big.Int).SetBytes(exponentBytes)
- n := new(big.Int).SetBytes(exponentBytes)
- one := new(big.Int).SetUint64(1)
- n.Add(n, one)
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- out.Exp(x, e, n)
- }
-}
-
-func BenchmarkExp(b *testing.B) {
- x := makeBenchmarkValue()
- e := makeBenchmarkExponent()
- out := makeBenchmarkValue()
- m := makeBenchmarkModulus()
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- out.exp(x, e, m)
- }
-}
return boring.EncryptRSANoPadding(bkey, em)
}
- return encrypt(pub, em), nil
+ return encrypt(pub, em)
}
// DecryptPKCS1v15 decrypts a plaintext using RSA and the padding scheme from PKCS #1 v1.5.
return ErrVerification
}
- em := encrypt(pub, sig)
+ em, err := encrypt(pub, sig)
+ if err != nil {
+ return ErrVerification
+ }
// EM = 0x00 || 0x01 || PS || 0x00 || T
ok := subtle.ConstantTimeByteEq(em[0], 0)
// given hash function. salt is a random sequence of bytes whose length will be
// later used to verify the signature.
func signPSSWithSalt(priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) {
- emBits := bigBitLen(priv.N) - 1
+ emBits := priv.N.BitLen() - 1
em, err := emsaPSSEncode(hashed, emBits, salt, hash.New())
if err != nil {
return nil, err
saltLength := opts.saltLength()
switch saltLength {
case PSSSaltLengthAuto:
- saltLength = (bigBitLen(priv.N)-1+7)/8 - 2 - hash.Size()
+ saltLength = (priv.N.BitLen()-1+7)/8 - 2 - hash.Size()
if saltLength < 0 {
return nil, ErrMessageTooLong
}
return invalidSaltLenErr
}
- emBits := bigBitLen(pub.N) - 1
+ emBits := pub.N.BitLen() - 1
emLen := (emBits + 7) / 8
- em := encrypt(pub, sig)
+ em, err := encrypt(pub, sig)
+ if err != nil {
+ return ErrVerification
+ }
// Like in signPSSWithSalt, deal with mismatches between emLen and the size
// of the modulus. The spec would have us wire emLen into the encoding
import (
"crypto"
+ "crypto/internal/bigmod"
"crypto/internal/boring"
"crypto/internal/boring/bbig"
"crypto/internal/randutil"
// Size returns the modulus size in bytes. Raw signatures and ciphertexts
// for or by this public key will have the same size.
func (pub *PublicKey) Size() int {
- return (bigBitLen(pub.N) + 7) / 8
+ return (pub.N.BitLen() + 7) / 8
}
// Equal reports whether pub and x have the same value.
// complexity.
CRTValues []CRTValue
- n, p, q *modulus // moduli for CRT with Montgomery precomputed constants
+ n, p, q *bigmod.Modulus // moduli for CRT with Montgomery precomputed constants
}
// CRTValue contains the precomputed Chinese remainder theorem values.
Dq: Dq,
Qinv: Qinv,
CRTValues: make([]CRTValue, 0), // non-nil, to match Precompute
- n: modulusFromNat(newNat().setBig(N)),
- p: modulusFromNat(newNat().setBig(P)),
- q: modulusFromNat(newNat().setBig(Q)),
+ n: bigmod.NewModulusFromBig(N),
+ p: bigmod.NewModulusFromBig(P),
+ q: bigmod.NewModulusFromBig(Q),
},
}
return key, nil
// be returned if the size of the salt is too large.
var ErrMessageTooLong = errors.New("crypto/rsa: message too long for RSA key size")
-func encrypt(pub *PublicKey, plaintext []byte) []byte {
+func encrypt(pub *PublicKey, plaintext []byte) ([]byte, error) {
boring.Unreachable()
- N := modulusFromNat(newNat().setBig(pub.N))
- m := newNat().setBytes(plaintext).expandFor(N)
+ N := bigmod.NewModulusFromBig(pub.N)
+ m, err := bigmod.NewNat().SetBytes(plaintext, N)
+ if err != nil {
+ return nil, err
+ }
e := intToBytes(pub.E)
- out := make([]byte, modulusSize(N))
- return newNat().exp(m, e, N).fillBytes(out)
+ return bigmod.NewNat().Exp(m, e, N).Bytes(N), nil
}
// intToBytes returns i as a big-endian slice of bytes with no leading zeroes,
return boring.EncryptRSANoPadding(bkey, em)
}
- return encrypt(pub, em), nil
+ return encrypt(pub, em)
}
// ErrDecryption represents a failure to decrypt a message.
// in the future.
func (priv *PrivateKey) Precompute() {
if priv.Precomputed.n == nil && len(priv.Primes) == 2 {
- priv.Precomputed.n = modulusFromNat(newNat().setBig(priv.N))
- priv.Precomputed.p = modulusFromNat(newNat().setBig(priv.Primes[0]))
- priv.Precomputed.q = modulusFromNat(newNat().setBig(priv.Primes[1]))
+ priv.Precomputed.n = bigmod.NewModulusFromBig(priv.N)
+ priv.Precomputed.p = bigmod.NewModulusFromBig(priv.Primes[0])
+ priv.Precomputed.q = bigmod.NewModulusFromBig(priv.Primes[1])
}
// Fill in the backwards-compatibility *big.Int values.
boring.Unreachable()
}
- N := priv.Precomputed.n
- if N == nil {
- N = modulusFromNat(newNat().setBig(priv.N))
- }
- c := newNat().setBytes(ciphertext).expandFor(N)
- if c.cmpGeq(N.nat) == 1 {
- return nil, ErrDecryption
- }
- if priv.N.Sign() == 0 {
- return nil, ErrDecryption
- }
-
- var m *nat
+ var (
+ err error
+ m, c *bigmod.Nat
+ N *bigmod.Modulus
+ t0 = bigmod.NewNat()
+ )
if priv.Precomputed.n == nil {
- m = newNat().exp(c, priv.D.Bytes(), N)
+ N = bigmod.NewModulusFromBig(priv.N)
+ c, err = bigmod.NewNat().SetBytes(ciphertext, N)
+ if err != nil {
+ return nil, ErrDecryption
+ }
+ m = bigmod.NewNat().Exp(c, priv.D.Bytes(), N)
} else {
- t0 := newNat()
+ N = priv.Precomputed.n
P, Q := priv.Precomputed.p, priv.Precomputed.q
+ Qinv, err := bigmod.NewNat().SetBytes(priv.Precomputed.Qinv.Bytes(), P)
+ if err != nil {
+ return nil, ErrDecryption
+ }
+ c, err = bigmod.NewNat().SetBytes(ciphertext, N)
+ if err != nil {
+ return nil, ErrDecryption
+ }
+
// m = c ^ Dp mod p
- m = newNat().exp(t0.mod(c, P), priv.Precomputed.Dp.Bytes(), P)
+ m = bigmod.NewNat().Exp(t0.Mod(c, P), priv.Precomputed.Dp.Bytes(), P)
// m2 = c ^ Dq mod q
- m2 := newNat().exp(t0.mod(c, Q), priv.Precomputed.Dq.Bytes(), Q)
+ m2 := bigmod.NewNat().Exp(t0.Mod(c, Q), priv.Precomputed.Dq.Bytes(), Q)
// m = m - m2 mod p
- m.modSub(t0.mod(m2, P), P)
+ m.Sub(t0.Mod(m2, P), P)
// m = m * Qinv mod p
- m.modMul(newNat().setBig(priv.Precomputed.Qinv).expandFor(P), P)
+ m.Mul(Qinv, P)
// m = m * q mod N
- m.expandFor(N).modMul(t0.mod(Q.nat, N), N)
+ m.ExpandFor(N).Mul(t0.Mod(Q.Nat(), N), N)
// m = m + m2 mod N
- m.modAdd(m2.expandFor(N), N)
+ m.Add(m2.ExpandFor(N), N)
}
if check {
- c1 := newNat().exp(m, intToBytes(priv.E), N)
- if c1.cmpEq(c) != 1 {
+ c1 := bigmod.NewNat().Exp(m, intToBytes(priv.E), N)
+ if c1.Equal(c) != 1 {
return nil, ErrDecryption
}
}
- out := make([]byte, modulusSize(N))
- return m.fillBytes(out), nil
+ return m.Bytes(N), nil
}
// DecryptOAEP decrypts ciphertext using RSA-OAEP.
}
hash[1] ^= 0x80
}
+
+ // Check that an input bigger than the modulus is handled correctly,
+ // whether it is longer than the byte size of the modulus or not.
+ c := bytes.Repeat([]byte{0xff}, priv.Size())
+ err = VerifyPSS(&priv.PublicKey, crypto.SHA256, hash[:], c, opts)
+ if err == nil {
+ t.Errorf("VerifyPSS accepted a large signature")
+ }
+ _, err = DecryptPKCS1v15(nil, priv, c)
+ if err == nil {
+ t.Errorf("DecryptPKCS1v15 accepted a large ciphertext")
+ }
+ c = append(c, 0xff)
+ err = VerifyPSS(&priv.PublicKey, crypto.SHA256, hash[:], c, opts)
+ if err == nil {
+ t.Errorf("VerifyPSS accepted a long signature")
+ }
+ _, err = DecryptPKCS1v15(nil, priv, c)
+ if err == nil {
+ t.Errorf("DecryptPKCS1v15 accepted a long ciphertext")
+ }
}
func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") }
< encoding/asn1
< golang.org/x/crypto/cryptobyte/asn1
< golang.org/x/crypto/cryptobyte
+ < crypto/internal/bigmod
< crypto/dsa, crypto/elliptic, crypto/rsa
< crypto/ecdsa
< CRYPTO-MATH;
// 0 if x == 0
// +1 if x > 0
func (x *Int) Sign() int {
+ // This function is used in cryptographic operations. It must not leak
+ // anything but the Int's sign and bit size through side-channels. Any
+ // changes must be reviewed by a security expert.
if len(x.abs) == 0 {
return 0
}
// Bits is intended to support implementation of missing low-level Int
// functionality outside this package; it should be avoided otherwise.
func (x *Int) Bits() []Word {
+ // This function is used in cryptographic operations. It must not leak
+ // anything but the Int's sign and bit size through side-channels. Any
+ // changes must be reviewed by a security expert.
return x.abs
}
//
// To use a fixed length slice, or a preallocated one, use FillBytes.
func (x *Int) Bytes() []byte {
+ // This function is used in cryptographic operations. It must not leak
+ // anything but the Int's sign and bit size through side-channels. Any
+ // changes must be reviewed by a security expert.
buf := make([]byte, len(x.abs)*_S)
return buf[x.abs.bytes(buf):]
}
// BitLen returns the length of the absolute value of x in bits.
// The bit length of 0 is 0.
func (x *Int) BitLen() int {
+ // This function is used in cryptographic operations. It must not leak
+ // anything but the Int's sign and bit size through side-channels. Any
+ // changes must be reviewed by a security expert.
return x.abs.bitLen()
}
// bitLen returns the length of x in bits.
// Unlike most methods, it works even if x is not normalized.
func (x nat) bitLen() int {
+ // This function is used in cryptographic operations. It must not leak
+ // anything but the Int's sign and bit size through side-channels. Any
+ // changes must be reviewed by a security expert.
+ //
+ // In particular, bits.Len and bits.LeadingZeros use a lookup table for the
+ // low-order bits on some architectures.
if i := len(x) - 1; i >= 0 {
- return i*_W + bits.Len(uint(x[i]))
+ l := i * _W
+ for top := x[i]; top != 0; top >>= 1 {
+ l++
+ }
+ return l
}
return 0
}
// cannot be represented in buf, bytes panics. The number i of unused
// bytes at the beginning of buf is returned as result.
func (z nat) bytes(buf []byte) (i int) {
+ // This function is used in cryptographic operations. It must not leak
+ // anything but the Int's sign and bit size through side-channels. Any
+ // changes must be reviewed by a security expert.
i = len(buf)
for _, d := range z {
for j := 0; j < _S; j++ {