-// Copyright 2022 The Go Authors. All rights reserved.
+// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
+ "strconv"
+
. "github.com/mmcloughlin/avo/build"
. "github.com/mmcloughlin/avo/operand"
. "github.com/mmcloughlin/avo/reg"
)
-//go:generate go run . -out ../nat_amd64.s -stubs ../nat_amd64.go -pkg bigmod
+//go:generate go run . -out ../nat_amd64.s -pkg bigmod
func main() {
Package("crypto/internal/bigmod")
- ConstraintExpr("amd64,gc,!purego")
-
- Implement("montgomeryLoop")
- Pragma("noescape")
-
- size := Load(Param("d").Len(), GP64())
- d := Mem{Base: Load(Param("d").Base(), GP64())}
- b := Mem{Base: Load(Param("b").Base(), GP64())}
- m := Mem{Base: Load(Param("m").Base(), GP64())}
- m0inv := Load(Param("m0inv"), GP64())
-
- overflow := zero()
- i := zero()
- Label("outerLoop")
-
- ai := Load(Param("a").Base(), GP64())
- MOVQ(Mem{Base: ai}.Idx(i, 8), ai)
-
- z := uint128{GP64(), GP64()}
- mul64(z, b, ai)
- add64(z, d)
- f := GP64()
- MOVQ(m0inv, f)
- IMULQ(z.lo, f)
- _MASK(f)
- addMul64(z, m, f)
- carry := shiftBy63(z)
-
- j := zero()
- INCQ(j)
- JMP(LabelRef("innerLoopCondition"))
- Label("innerLoop")
-
- // z = d[j] + a[i] * b[j] + f * m[j] + carry
- z = uint128{GP64(), GP64()}
- mul64(z, b.Idx(j, 8), ai)
- addMul64(z, m.Idx(j, 8), f)
- add64(z, d.Idx(j, 8))
- add64(z, carry)
- // d[j-1] = z_lo & _MASK
- storeMasked(z.lo, d.Idx(j, 8).Offset(-8))
- // carry = z_hi<<1 | z_lo>>_W
- MOVQ(shiftBy63(z), carry)
-
- INCQ(j)
- Label("innerLoopCondition")
- CMPQ(size, j)
- JGT(LabelRef("innerLoop"))
-
- ADDQ(carry, overflow)
- storeMasked(overflow, d.Idx(size, 8).Offset(-8))
- SHRQ(Imm(63), overflow)
-
- INCQ(i)
- CMPQ(size, i)
- JGT(LabelRef("outerLoop"))
-
- Store(overflow, ReturnIndex(0))
- RET()
- Generate()
-}
+ ConstraintExpr("!purego")
-// zero zeroes a new register and returns it.
-func zero() Register {
- r := GP64()
- XORQ(r, r)
- return r
-}
-
-// _MASK masks out the top bit of r.
-func _MASK(r Register) {
- BTRQ(Imm(63), r)
-}
-
-type uint128 struct {
- hi, lo GPVirtual
-}
+ addMulVVW(1024)
+ addMulVVW(1536)
+ addMulVVW(2048)
-// storeMasked stores _MASK(src) in dst. It doesn't modify src.
-func storeMasked(src, dst Op) {
- out := GP64()
- MOVQ(src, out)
- _MASK(out)
- MOVQ(out, dst)
-}
-
-// shiftBy63 returns z >> 63. It reuses z.lo.
-func shiftBy63(z uint128) Register {
- SHRQ(Imm(63), z.hi, z.lo)
- result := z.lo
- z.hi, z.lo = nil, nil
- return result
-}
-
-// add64 sets r to r + a.
-func add64(r uint128, a Op) {
- ADDQ(a, r.lo)
- ADCQ(Imm(0), r.hi)
+ Generate()
}
-// mul64 sets r to a * b.
-func mul64(r uint128, a, b Op) {
- MOVQ(a, RAX)
- MULQ(b) // RDX, RAX = RAX * b
- MOVQ(RAX, r.lo)
- MOVQ(RDX, r.hi)
-}
+func addMulVVW(bits int) {
+ if bits%64 != 0 {
+ panic("bit size unsupported")
+ }
+
+ Implement("addMulVVW" + strconv.Itoa(bits))
+
+ CMPB(Mem{Symbol: Symbol{Name: "·supportADX"}, Base: StaticBase}, Imm(1))
+ JEQ(LabelRef("adx"))
+
+ z := Mem{Base: Load(Param("z"), GP64())}
+ x := Mem{Base: Load(Param("x"), GP64())}
+ y := Load(Param("y"), GP64())
+
+ carry := GP64()
+ XORQ(carry, carry) // zero out carry
+
+ for i := 0; i < bits/64; i++ {
+ Comment("Iteration " + strconv.Itoa(i))
+ hi, lo := RDX, RAX // implicit MULQ inputs and outputs
+ MOVQ(x.Offset(i*8), lo)
+ MULQ(y)
+ ADDQ(z.Offset(i*8), lo)
+ ADCQ(Imm(0), hi)
+ ADDQ(carry, lo)
+ ADCQ(Imm(0), hi)
+ MOVQ(hi, carry)
+ MOVQ(lo, z.Offset(i*8))
+ }
+
+ Store(carry, ReturnIndex(0))
+ RET()
-// addMul64 sets r to r + a * b.
-func addMul64(r uint128, a, b Op) {
- MOVQ(a, RAX)
- MULQ(b) // RDX, RAX = RAX * b
- ADDQ(RAX, r.lo)
- ADCQ(RDX, r.hi)
+ Label("adx")
+
+ // The ADX strategy implements the following function, where c1 and c2 are
+ // the overflow and the carry flag respectively.
+ //
+ // func addMulVVW(z, x []uint, y uint) (carry uint) {
+ // var c1, c2 uint
+ // for i := range z {
+ // hi, lo := bits.Mul(x[i], y)
+ // lo, c1 = bits.Add(lo, z[i], c1)
+ // z[i], c2 = bits.Add(lo, carry, c2)
+ // carry = hi
+ // }
+ // return carry + c1 + c2
+ // }
+ //
+ // The loop is fully unrolled and the hi / carry registers are alternated
+ // instead of introducing a MOV.
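+	//
+	// As a sanity check on the return line, the final carry always fits in
+	// one word: z + x*y is at most (2^(W*n)-1) + (2^(W*n)-1)*(2^W-1) =
+	// 2^(W*(n+1)) - 2^W, so the top word carry + c1 + c2 can't itself
+	// overflow.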
+
+ z = Mem{Base: Load(Param("z"), GP64())}
+ x = Mem{Base: Load(Param("x"), GP64())}
+ Load(Param("y"), RDX) // implicit source of MULXQ
+
+ carry = GP64()
+ XORQ(carry, carry) // zero out carry
+ z0 := GP64()
+ XORQ(z0, z0) // unset flags and zero out z0
+
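+	// Each loop iteration below emits two unrolled word iterations,
+	// alternating the hi and carry registers, so bits/64 is assumed to be
+	// even (as it is for 1024, 1536, and 2048).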
+ for i := 0; i < bits/64; i++ {
+ hi, lo := GP64(), GP64()
+
+ Comment("Iteration " + strconv.Itoa(i))
+ MULXQ(x.Offset(i*8), lo, hi)
+ ADCXQ(carry, lo)
+ ADOXQ(z.Offset(i*8), lo)
+ MOVQ(lo, z.Offset(i*8))
+
+ i++
+
+ Comment("Iteration " + strconv.Itoa(i))
+ MULXQ(x.Offset(i*8), lo, carry)
+ ADCXQ(hi, lo)
+ ADOXQ(z.Offset(i*8), lo)
+ MOVQ(lo, z.Offset(i*8))
+ }
+
+ Comment("Add back carry flags and return")
+ ADCXQ(z0, carry)
+ ADOXQ(z0, carry)
+
+ Store(carry, ReturnIndex(0))
+ RET()
}
package bigmod
import (
+ "encoding/binary"
"errors"
"math/big"
"math/bits"
)
const (
- // _W is the number of bits we use for our limbs.
- _W = bits.UintSize - 1
- // _MASK selects _W bits from a full machine word.
- _MASK = (1 << _W) - 1
+ // _W is the size in bits of our limbs.
+ _W = bits.UintSize
+ // _S is the size in bytes of our limbs.
+ _S = _W / 8
)
// choice represents a constant-time boolean. The value of choice is always
const yes = choice(1)
const no = choice(0)
-// ctSelect returns x if on == 1, and y if on == 0. The execution time of this
-// function does not depend on its inputs. If on is any value besides 1 or 0,
-// the result is undefined.
-func ctSelect(on choice, x, y uint) uint {
- // When on == 1, mask is 0b111..., otherwise mask is 0b000...
- mask := -uint(on)
- // When mask is all zeros, we just have y, otherwise, y cancels with itself.
- return y ^ (mask & (y ^ x))
-}
+// ctMask is all 1s if on is yes, and all 0s otherwise.
+func ctMask(on choice) uint { return -uint(on) }
// ctEq returns 1 if x == y, and 0 otherwise. The execution time of this
// function does not depend on its inputs.
// Operations on this number are allowed to leak this length, but will not leak
// any information about the values contained in those limbs.
type Nat struct {
- // limbs is a little-endian representation in base 2^W with
- // W = bits.UintSize - 1. The top bit is always unset between operations.
- //
- // The top bit is left unset to optimize Montgomery multiplication, in the
- // inner loop of exponentiation. Using fully saturated limbs would leave us
- // working with 129-bit numbers on 64-bit platforms, wasting a lot of space,
- // and thus time.
+ // limbs is little-endian in base 2^W with W = bits.UintSize.
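+	//
+	// For example, with _W = 64, the value 2^64 + 3 is stored as
+	// limbs = []uint{3, 1}.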
limbs []uint
}
// The announced length of x is set based on the actual bit size of the input,
// ignoring leading zeroes.
func (x *Nat) setBig(n *big.Int) *Nat {
- requiredLimbs := (n.BitLen() + _W - 1) / _W
- x.reset(requiredLimbs)
-
- outI := 0
- shift := 0
limbs := n.Bits()
+ x.reset(len(limbs))
for i := range limbs {
- xi := uint(limbs[i])
- x.limbs[outI] |= (xi << shift) & _MASK
- outI++
- if outI == requiredLimbs {
- return x
- }
- x.limbs[outI] = xi >> (_W - shift)
- shift++ // this assumes bits.UintSize - _W = 1
- if shift == _W {
- shift = 0
- outI++
- }
+ x.limbs[i] = uint(limbs[i])
}
return x
}
//
// x must have the same size as m and it must be reduced modulo m.
func (x *Nat) Bytes(m *Modulus) []byte {
- bytes := make([]byte, m.Size())
- shift := 0
- outI := len(bytes) - 1
+ i := m.Size()
+ bytes := make([]byte, i)
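+	// Write the limbs out least-significant first, filling the big-endian
+	// output buffer from its end.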
for _, limb := range x.limbs {
- remainingBits := _W
- for remainingBits >= 8 {
- bytes[outI] |= byte(limb) << shift
- consumed := 8 - shift
- limb >>= consumed
- remainingBits -= consumed
- shift = 0
- outI--
- if outI < 0 {
- return bytes
+ for j := 0; j < _S; j++ {
+ i--
+ if i < 0 {
+ if limb == 0 {
+ break
+ }
+ panic("bigmod: modulus is smaller than nat")
}
+ bytes[i] = byte(limb)
+ limb >>= 8
}
- bytes[outI] = byte(limb)
- shift = remainingBits
}
return bytes
}
return x, nil
}
-// SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes. SetOverflowingBytes
-// returns an error if b has a longer bit length than m, but reduces overflowing
-// values up to 2^⌈log2(m)⌉ - 1.
+// SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes.
+// SetOverflowingBytes returns an error if b has a longer bit length than m, but
+// reduces overflowing values up to 2^⌈log2(m)⌉ - 1.
//
// The output will be resized to the size of m and overwritten.
func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
}
leading := _W - bitLen(x.limbs[len(x.limbs)-1])
if leading < m.leading {
- return nil, errors.New("input overflows the modulus")
+ return nil, errors.New("input overflows the modulus size")
}
- x.sub(x.cmpGeq(m.nat), m.nat)
+ x.maybeSubtractModulus(no, m)
return x, nil
}
+// bigEndianUint returns the contents of buf interpreted as a
+// big-endian encoded uint value.
+func bigEndianUint(buf []byte) uint {
+ if _W == 64 {
+ return uint(binary.BigEndian.Uint64(buf))
+ }
+ return uint(binary.BigEndian.Uint32(buf))
+}
+
func (x *Nat) setBytes(b []byte, m *Modulus) error {
- outI := 0
- shift := 0
x.resetFor(m)
- for i := len(b) - 1; i >= 0; i-- {
- bi := b[i]
- x.limbs[outI] |= uint(bi) << shift
- shift += 8
- if shift >= _W {
- shift -= _W
- x.limbs[outI] &= _MASK
- overflow := bi >> (8 - shift)
- outI++
- if outI >= len(x.limbs) {
- if overflow > 0 || i > 0 {
- return errors.New("input overflows the modulus")
- }
- break
- }
- x.limbs[outI] = uint(overflow)
- }
+ i, k := len(b), 0
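+	// Copy whole _S-byte limbs from the end of the big-endian input.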
+ for k < len(x.limbs) && i >= _S {
+ x.limbs[k] = bigEndianUint(b[i-_S : i])
+ i -= _S
+ k++
+ }
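+	// Assemble the remaining leading bytes (fewer than _S) into the next limb.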
+ for s := 0; s < _W && k < len(x.limbs) && i > 0; s += 8 {
+ x.limbs[k] |= uint(b[i-1]) << s
+ i--
+ }
+ if i > 0 {
+ return errors.New("input overflows the modulus size")
}
return nil
}
var c uint
for i := 0; i < size; i++ {
- c = (xLimbs[i] - yLimbs[i] - c) >> _W
+ _, c = bits.Sub(xLimbs[i], yLimbs[i], c)
}
// If there was a carry, then subtracting y underflowed, so
// x is not greater than or equal to y.
xLimbs := x.limbs[:size]
yLimbs := y.limbs[:size]
+ mask := ctMask(on)
for i := 0; i < size; i++ {
- xLimbs[i] = ctSelect(on, yLimbs[i], xLimbs[i])
+ xLimbs[i] ^= mask & (xLimbs[i] ^ yLimbs[i])
}
return x
}
-// add computes x += y if on == 1, and does nothing otherwise. It returns the
-// carry of the addition regardless of on.
+// add computes x += y and returns the carry.
//
// Both operands must have the same announced length.
-func (x *Nat) add(on choice, y *Nat) (c uint) {
+func (x *Nat) add(y *Nat) (c uint) {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
yLimbs := y.limbs[:size]
for i := 0; i < size; i++ {
- res := xLimbs[i] + yLimbs[i] + c
- xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i])
- c = res >> _W
+ xLimbs[i], c = bits.Add(xLimbs[i], yLimbs[i], c)
}
return
}
-// sub computes x -= y if on == 1, and does nothing otherwise. It returns the
-// borrow of the subtraction regardless of on.
+// sub computes x -= y. It returns the borrow of the subtraction.
//
// Both operands must have the same announced length.
-func (x *Nat) sub(on choice, y *Nat) (c uint) {
+func (x *Nat) sub(y *Nat) (c uint) {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
yLimbs := y.limbs[:size]
for i := 0; i < size; i++ {
- res := xLimbs[i] - yLimbs[i] - c
- xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i])
- c = res >> _W
+ xLimbs[i], c = bits.Sub(xLimbs[i], yLimbs[i], c)
}
return
}
// Every iteration of this loop doubles the least-significant bits of
// correct inverse in y. The first three bits are already correct (1⁻¹ = 1,
// 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so doubling five times is enough
- // for 61 bits (and wastes only one iteration for 31 bits).
+ // for 64 bits (and wastes only one iteration for 32 bits).
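+	//
+	// For example, for x = 3 the starting y = 3 is correct to three bits
+	// (3*3 = 9 ≡ 1 mod 8); one iteration gives y = 3*(2-9) = -21, and
+	// 3 * -21 = -63 ≡ 1 mod 64, so six bits are now correct.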
//
// See https://crypto.stackexchange.com/a/47496.
y := x
for i := 0; i < 5; i++ {
y = y * (2 - x*y)
}
- return (1 << _W) - (y & _MASK)
+ return -y
}
// NewModulusFromBig creates a new Modulus from a [big.Int].
//
-// The Int must be odd. The number of significant bits must be leakable.
+// The Int must be odd. The number of significant bits (and nothing else) is
+// leaked through timing side-channels.
func NewModulusFromBig(n *big.Int) *Modulus {
m := &Modulus{}
m.nat = NewNat().setBig(n)
// shiftIn calculates x = x << _W + y mod m.
//
-// This assumes that x is already reduced mod m, and that y < 2^_W.
+// This assumes that x is already reduced mod m.
func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
d := NewNat().resetFor(m)
//
// To do the reduction, each iteration computes both 2x + b and 2x + b - m.
// The next iteration (and finally the return line) will use either result
- // based on whether the subtraction underflowed.
+	// based on whether 2x + b is at least m.
needSubtraction := no
for i := _W - 1; i >= 0; i-- {
carry := (y >> i) & 1
var borrow uint
+ mask := ctMask(needSubtraction)
for i := 0; i < size; i++ {
- l := ctSelect(needSubtraction, dLimbs[i], xLimbs[i])
-
- res := l<<1 + carry
- xLimbs[i] = res & _MASK
- carry = res >> _W
-
- res = xLimbs[i] - mLimbs[i] - borrow
- dLimbs[i] = res & _MASK
- borrow = res >> _W
+ l := xLimbs[i] ^ (mask & (xLimbs[i] ^ dLimbs[i]))
+ xLimbs[i], carry = bits.Add(l, l, carry)
+ dLimbs[i], borrow = bits.Sub(xLimbs[i], mLimbs[i], borrow)
}
- // See Add for how carry (aka overflow), borrow (aka underflow), and
- // needSubtraction relate.
- needSubtraction = ctEq(carry, borrow)
+		// Like in maybeSubtractModulus, we need the subtraction if either it
+		// didn't underflow (meaning 2x + b >= m) or if computing 2x + b
+		// overflowed (meaning 2x + b >= 2^(_W*n) > m).
+ needSubtraction = not(choice(borrow)) | choice(carry)
}
return x.assign(needSubtraction, d)
}
return out
}
-// ExpandFor ensures out has the right size to work with operations modulo m.
+// ExpandFor ensures x has the right size to work with operations modulo m.
//
-// The announced size of out must be smaller than or equal to that of m.
-func (out *Nat) ExpandFor(m *Modulus) *Nat {
- return out.expand(len(m.nat.limbs))
+// The announced size of x must be smaller than or equal to that of m.
+func (x *Nat) ExpandFor(m *Modulus) *Nat {
+ return x.expand(len(m.nat.limbs))
}
// resetFor ensures out has the right size to work with operations modulo m.
return out.reset(len(m.nat.limbs))
}
+// maybeSubtractModulus computes x -= m if and only if x >= m or if "always" is yes.
+//
+// It can be used to reduce modulo m a value up to 2m - 1, which is a common
+// range for results computed by higher level operations.
+//
+// always is usually a carry that indicates that the operation that produced x
+// overflowed its size, meaning abstractly x >= 2^(_W*n) > m even if the
+// truncated x is less than m.
+//
+// The x and m operands must have the same announced length.
+func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
+ t := NewNat().set(x)
+ underflow := t.sub(m.nat)
+ // We keep the result if x - m didn't underflow (meaning x >= m)
+ // or if always was set.
+ keep := not(choice(underflow)) | choice(always)
+ x.assign(keep, t)
+}
+
// Sub computes x = x - y mod m.
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
- underflow := x.sub(yes, y)
+ underflow := x.sub(y)
// If the subtraction underflowed, add m.
- x.add(choice(underflow), m.nat)
+ t := NewNat().set(x)
+ t.add(m.nat)
+ x.assign(choice(underflow), t)
return x
}
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
- overflow := x.add(yes, y)
- underflow := not(x.cmpGeq(m.nat)) // x < m
-
- // Three cases are possible:
- //
- // - overflow = 0, underflow = 0
- //
- // In this case, addition fits in our limbs, but we can still subtract away
- // m without an underflow, so we need to perform the subtraction to reduce
- // our result.
- //
- // - overflow = 0, underflow = 1
- //
- // The addition fits in our limbs, but we can't subtract m without
- // underflowing. The result is already reduced.
- //
- // - overflow = 1, underflow = 1
- //
- // The addition does not fit in our limbs, and the subtraction's borrow
- // would cancel out with the addition's carry. We need to subtract m to
- // reduce our result.
- //
- // The overflow = 1, underflow = 0 case is not possible, because y is at
- // most m - 1, and if adding m - 1 overflows, then subtracting m must
- // necessarily underflow.
- needSubtraction := ctEq(overflow, uint(underflow))
-
- x.sub(needSubtraction, m.nat)
+ overflow := x.add(y)
+ x.maybeSubtractModulus(choice(overflow), m)
return x
}
return x.montgomeryMul(t0, t1, m)
}
-// montgomeryMul calculates d = a * b / R mod m, with R = 2^(_W * n) and
-// n = len(m.nat.limbs), using the Montgomery Multiplication technique.
+// montgomeryMul calculates x = a * b / R mod m, with R = 2^(_W * n) and
+// n = len(m.nat.limbs), also known as a Montgomery multiplication.
//
-// All inputs should be the same length, not aliasing d, and already
-// reduced modulo m. d will be resized to the size of m and overwritten.
-func (d *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
- d.resetFor(m)
- if len(a.limbs) != len(m.nat.limbs) || len(b.limbs) != len(m.nat.limbs) {
- panic("bigmod: invalid montgomeryMul input")
- }
+// All inputs should be the same length and already reduced modulo m.
+// x will be resized to the size of m and overwritten.
+func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
+ n := len(m.nat.limbs)
+ mLimbs := m.nat.limbs[:n]
+ aLimbs := a.limbs[:n]
+ bLimbs := b.limbs[:n]
+
+ switch n {
+ default:
+ // Attempt to use a stack-allocated backing array.
+ T := make([]uint, 0, preallocLimbs*2)
+ if cap(T) < n*2 {
+ T = make([]uint, 0, n*2)
+ }
+ T = T[:n*2]
+
+ // This loop implements Word-by-Word Montgomery Multiplication, as
+ // described in Algorithm 4 (Fig. 3) of "Efficient Software
+ // Implementations of Modular Exponentiation" by Shay Gueron
+ // [https://eprint.iacr.org/2011/239.pdf].
+ var c uint
+ for i := 0; i < n; i++ {
+ _ = T[n+i] // bounds check elimination hint
+
+ // Step 1 (T = a × b) is computed as a large pen-and-paper column
+ // multiplication of two numbers with n base-2^_W digits. If we just
+ // wanted to produce 2n-wide T, we would do
+ //
+ // for i := 0; i < n; i++ {
+ // d := bLimbs[i]
+ // T[n+i] = addMulVVW(T[i:n+i], aLimbs, d)
+ // }
+ //
+ // where d is a digit of the multiplier, T[i:n+i] is the shifted
+ // position of the product of that digit, and T[n+i] is the final carry.
+ // Note that T[i] isn't modified after processing the i-th digit.
+ //
+ // Instead of running two loops, one for Step 1 and one for Steps 2–6,
+ // the result of Step 1 is computed during the next loop. This is
+ // possible because each iteration only uses T[i] in Step 2 and then
+ // discards it in Step 6.
+ d := bLimbs[i]
+ c1 := addMulVVW(T[i:n+i], aLimbs, d)
+
+ // Step 6 is replaced by shifting the virtual window we operate
+ // over: T of the algorithm is T[i:] for us. That means that T1 in
+ // Step 2 (T mod 2^_W) is simply T[i]. k0 in Step 3 is our m0inv.
+ Y := T[i] * m.m0inv
+
+		// Steps 4 and 5 add Y × m to T, which as mentioned above is stored
+ // at T[i:]. The two carries (from a × d and Y × m) are added up in
+ // the next word T[n+i], and the carry bit from that addition is
+ // brought forward to the next iteration.
+ c2 := addMulVVW(T[i:n+i], mLimbs, Y)
+ T[n+i], c = bits.Add(c1, c2, c)
+ }
- // See https://bearssl.org/bigint.html#montgomery-reduction-and-multiplication
- // for a description of the algorithm implemented mostly in montgomeryLoop.
- // See Add for how overflow, underflow, and needSubtraction relate.
- overflow := montgomeryLoop(d.limbs, a.limbs, b.limbs, m.nat.limbs, m.m0inv)
- underflow := not(d.cmpGeq(m.nat)) // d < m
- needSubtraction := ctEq(overflow, uint(underflow))
- d.sub(needSubtraction, m.nat)
+ // Finally for Step 7 we copy the final T window into x, and subtract m
+ // if necessary (which as explained in maybeSubtractModulus can be the
+ // case both if x >= m, or if x overflowed).
+ //
+ // The paper suggests in Section 4 that we can do an "Almost Montgomery
+ // Multiplication" by subtracting only in the overflow case, but the
+ // cost is very similar since the constant time subtraction tells us if
+ // x >= m as a side effect, and taking care of the broken invariant is
+ // highly undesirable (see https://go.dev/issue/13907).
+ copy(x.reset(n).limbs, T[n:])
+ x.maybeSubtractModulus(choice(c), m)
+
+ // The following specialized cases follow the exact same algorithm, but
+ // optimized for the sizes most used in RSA. addMulVVW is implemented in
+ // assembly with loop unrolling depending on the architecture and bounds
+ // checks are removed by the compiler thanks to the constant size.
+ case 1024 / _W:
+ const n = 1024 / _W // compiler hint
+ T := make([]uint, n*2)
+ var c uint
+ for i := 0; i < n; i++ {
+ d := bLimbs[i]
+ c1 := addMulVVW1024(&T[i], &aLimbs[0], d)
+ Y := T[i] * m.m0inv
+ c2 := addMulVVW1024(&T[i], &mLimbs[0], Y)
+ T[n+i], c = bits.Add(c1, c2, c)
+ }
+ copy(x.reset(n).limbs, T[n:])
+ x.maybeSubtractModulus(choice(c), m)
+
+ case 1536 / _W:
+ const n = 1536 / _W // compiler hint
+ T := make([]uint, n*2)
+ var c uint
+ for i := 0; i < n; i++ {
+ d := bLimbs[i]
+ c1 := addMulVVW1536(&T[i], &aLimbs[0], d)
+ Y := T[i] * m.m0inv
+ c2 := addMulVVW1536(&T[i], &mLimbs[0], Y)
+ T[n+i], c = bits.Add(c1, c2, c)
+ }
+ copy(x.reset(n).limbs, T[n:])
+ x.maybeSubtractModulus(choice(c), m)
+
+ case 2048 / _W:
+ const n = 2048 / _W // compiler hint
+ T := make([]uint, n*2)
+ var c uint
+ for i := 0; i < n; i++ {
+ d := bLimbs[i]
+ c1 := addMulVVW2048(&T[i], &aLimbs[0], d)
+ Y := T[i] * m.m0inv
+ c2 := addMulVVW2048(&T[i], &mLimbs[0], Y)
+ T[n+i], c = bits.Add(c1, c2, c)
+ }
+ copy(x.reset(n).limbs, T[n:])
+ x.maybeSubtractModulus(choice(c), m)
+ }
- return d
+ return x
}
-func montgomeryLoopGeneric(d, a, b, m []uint, m0inv uint) (overflow uint) {
- // Eliminate bounds checks in the loop.
- size := len(d)
- a = a[:size]
- b = b[:size]
- m = m[:size]
-
- for _, ai := range a {
- // This is an unrolled iteration of the loop below with j = 0.
- hi, lo := bits.Mul(ai, b[0])
- z_lo, c := bits.Add(d[0], lo, 0)
- f := (z_lo * m0inv) & _MASK // (d[0] + a[i] * b[0]) * m0inv
- z_hi, _ := bits.Add(0, hi, c)
- hi, lo = bits.Mul(f, m[0])
- z_lo, c = bits.Add(z_lo, lo, 0)
- z_hi, _ = bits.Add(z_hi, hi, c)
- carry := z_hi<<1 | z_lo>>_W
-
- for j := 1; j < size; j++ {
- // z = d[j] + a[i] * b[j] + f * m[j] + carry <= 2^(2W+1) - 2^(W+1) + 2^W
- hi, lo := bits.Mul(ai, b[j])
- z_lo, c := bits.Add(d[j], lo, 0)
- z_hi, _ := bits.Add(0, hi, c)
- hi, lo = bits.Mul(f, m[j])
- z_lo, c = bits.Add(z_lo, lo, 0)
- z_hi, _ = bits.Add(z_hi, hi, c)
- z_lo, c = bits.Add(z_lo, carry, 0)
- z_hi, _ = bits.Add(z_hi, 0, c)
- d[j-1] = z_lo & _MASK
- carry = z_hi<<1 | z_lo>>_W // carry <= 2^(W+1) - 2
- }
-
- z := overflow + carry // z <= 2^(W+1) - 1
- d[size-1] = z & _MASK
- overflow = z >> _W // overflow <= 1
+// addMulVVW multiplies the multi-word value x by the single-word value y,
+// adding the result to the multi-word value z and returning the final carry.
+// It can be thought of as one row of a pen-and-paper column multiplication.
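+//
+// For example, with words small enough not to overflow, z = [1, 2], x = [3, 4]
+// and y = 5 leave z = [16, 22] and return a carry of 0.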
+func addMulVVW(z, x []uint, y uint) (carry uint) {
+ _ = x[len(z)-1] // bounds check elimination hint
+ for i := range z {
+ hi, lo := bits.Mul(x[i], y)
+ lo, c := bits.Add(lo, z[i], 0)
+ // We use bits.Add with zero to get an add-with-carry instruction that
+ // absorbs the carry from the previous bits.Add.
+ hi, _ = bits.Add(hi, 0, c)
+ lo, c = bits.Add(lo, carry, 0)
+ hi, _ = bits.Add(hi, 0, c)
+ carry = hi
+ z[i] = lo
}
- return
+ return carry
}
// Mul calculates x *= y mod m.
func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
// We use a 4 bit window. For our RSA workload, 4 bit windows are faster
// than 2 bit windows, but use an extra 12 nats worth of scratch space.
- // Using bit sizes that don't divide 8 are more complex to implement.
+	// Using bit sizes that don't divide 8 is more complex to implement, but
+	// could be more efficient if it ever proved necessary.
table := [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1)
// newNat calls are unrolled so they are allocated on the stack.
t1 := NewNat().ExpandFor(m)
for _, b := range e {
for _, j := range []int{4, 0} {
- // Square four times.
+ // Square four times. Optimization note: this can be implemented
+ // more efficiently than with generic Montgomery multiplication.
t1.montgomeryMul(out, out, m)
out.montgomeryMul(t1, t1, m)
t1.montgomeryMul(out, out, m)
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !purego
+
+#include "textflag.h"
+
+// func addMulVVW1024(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1024(SB), $0-16
+ MOVL $32, BX
+ JMP addMulVVWx(SB)
+
+// func addMulVVW1536(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1536(SB), $0-16
+ MOVL $48, BX
+ JMP addMulVVWx(SB)
+
+// func addMulVVW2048(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW2048(SB), $0-16
+ MOVL $64, BX
+ JMP addMulVVWx(SB)
+
+TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
+ MOVL z+0(FP), DI
+ MOVL x+4(FP), SI
+ MOVL y+8(FP), BP
+ LEAL (DI)(BX*4), DI
+ LEAL (SI)(BX*4), SI
+ NEGL BX // i = -n
+ MOVL $0, CX // c = 0
+ JMP E6
+
+L6: MOVL (SI)(BX*4), AX
+ MULL BP
+ ADDL CX, AX
+ ADCL $0, DX
+ ADDL AX, (DI)(BX*4)
+ ADCL $0, DX
+ MOVL DX, CX
+ ADDL $1, BX // i++
+
+E6: CMPL BX, $0 // i < 0
+ JL L6
+
+ MOVL CX, c+12(FP)
+ RET
+++ /dev/null
-// Code generated by command: go run nat_amd64_asm.go -out ../nat_amd64.s -stubs ../nat_amd64.go -pkg bigmod. DO NOT EDIT.
-
-//go:build amd64 && gc && !purego
-
-package bigmod
-
-//go:noescape
-func montgomeryLoop(d []uint, a []uint, b []uint, m []uint, m0inv uint) uint
-// Code generated by command: go run nat_amd64_asm.go -out ../nat_amd64.s -stubs ../nat_amd64.go -pkg bigmod. DO NOT EDIT.
-
-//go:build amd64 && gc && !purego
-
-// func montgomeryLoop(d []uint, a []uint, b []uint, m []uint, m0inv uint) uint
-TEXT ·montgomeryLoop(SB), $8-112
- MOVQ d_len+8(FP), CX
- MOVQ d_base+0(FP), BX
- MOVQ b_base+48(FP), SI
- MOVQ m_base+72(FP), DI
- MOVQ m0inv+96(FP), R8
- XORQ R9, R9
- XORQ R10, R10
-
-outerLoop:
- MOVQ a_base+24(FP), R11
- MOVQ (R11)(R10*8), R11
- MOVQ (SI), AX
- MULQ R11
- MOVQ AX, R13
- MOVQ DX, R12
- ADDQ (BX), R13
- ADCQ $0x00, R12
- MOVQ R8, R14
- IMULQ R13, R14
- BTRQ $0x3f, R14
- MOVQ (DI), AX
- MULQ R14
- ADDQ AX, R13
- ADCQ DX, R12
- SHRQ $0x3f, R12, R13
- XORQ R12, R12
- INCQ R12
- JMP innerLoopCondition
-
-innerLoop:
- MOVQ (SI)(R12*8), AX
- MULQ R11
- MOVQ AX, BP
- MOVQ DX, R15
- MOVQ (DI)(R12*8), AX
- MULQ R14
- ADDQ AX, BP
- ADCQ DX, R15
- ADDQ (BX)(R12*8), BP
- ADCQ $0x00, R15
- ADDQ R13, BP
- ADCQ $0x00, R15
- MOVQ BP, AX
- BTRQ $0x3f, AX
- MOVQ AX, -8(BX)(R12*8)
- SHRQ $0x3f, R15, BP
- MOVQ BP, R13
- INCQ R12
-
-innerLoopCondition:
- CMPQ CX, R12
- JGT innerLoop
- ADDQ R13, R9
- MOVQ R9, AX
- BTRQ $0x3f, AX
- MOVQ AX, -8(BX)(CX*8)
- SHRQ $0x3f, R9
- INCQ R10
- CMPQ CX, R10
- JGT outerLoop
- MOVQ R9, ret+104(FP)
+// Code generated by command: go run nat_amd64_asm.go -out ../nat_amd64.s -pkg bigmod. DO NOT EDIT.
+
+//go:build !purego
+
+// func addMulVVW1024(z *uint, x *uint, y uint) (c uint)
+// Requires: ADX, BMI2
+TEXT ·addMulVVW1024(SB), $0-32
+ CMPB ·supportADX+0(SB), $0x01
+ JEQ adx
+ MOVQ z+0(FP), CX
+ MOVQ x+8(FP), BX
+ MOVQ y+16(FP), SI
+ XORQ DI, DI
+
+ // Iteration 0
+ MOVQ (BX), AX
+ MULQ SI
+ ADDQ (CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, (CX)
+
+ // Iteration 1
+ MOVQ 8(BX), AX
+ MULQ SI
+ ADDQ 8(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 8(CX)
+
+ // Iteration 2
+ MOVQ 16(BX), AX
+ MULQ SI
+ ADDQ 16(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 16(CX)
+
+ // Iteration 3
+ MOVQ 24(BX), AX
+ MULQ SI
+ ADDQ 24(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 24(CX)
+
+ // Iteration 4
+ MOVQ 32(BX), AX
+ MULQ SI
+ ADDQ 32(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 32(CX)
+
+ // Iteration 5
+ MOVQ 40(BX), AX
+ MULQ SI
+ ADDQ 40(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 40(CX)
+
+ // Iteration 6
+ MOVQ 48(BX), AX
+ MULQ SI
+ ADDQ 48(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 48(CX)
+
+ // Iteration 7
+ MOVQ 56(BX), AX
+ MULQ SI
+ ADDQ 56(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 56(CX)
+
+ // Iteration 8
+ MOVQ 64(BX), AX
+ MULQ SI
+ ADDQ 64(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 64(CX)
+
+ // Iteration 9
+ MOVQ 72(BX), AX
+ MULQ SI
+ ADDQ 72(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 72(CX)
+
+ // Iteration 10
+ MOVQ 80(BX), AX
+ MULQ SI
+ ADDQ 80(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 80(CX)
+
+ // Iteration 11
+ MOVQ 88(BX), AX
+ MULQ SI
+ ADDQ 88(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 88(CX)
+
+ // Iteration 12
+ MOVQ 96(BX), AX
+ MULQ SI
+ ADDQ 96(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 96(CX)
+
+ // Iteration 13
+ MOVQ 104(BX), AX
+ MULQ SI
+ ADDQ 104(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 104(CX)
+
+ // Iteration 14
+ MOVQ 112(BX), AX
+ MULQ SI
+ ADDQ 112(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 112(CX)
+
+ // Iteration 15
+ MOVQ 120(BX), AX
+ MULQ SI
+ ADDQ 120(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 120(CX)
+ MOVQ DI, c+24(FP)
+ RET
+
+adx:
+ MOVQ z+0(FP), AX
+ MOVQ x+8(FP), CX
+ MOVQ y+16(FP), DX
+ XORQ BX, BX
+ XORQ SI, SI
+
+ // Iteration 0
+ MULXQ (CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ (AX), R8
+ MOVQ R8, (AX)
+
+ // Iteration 1
+ MULXQ 8(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 8(AX), R8
+ MOVQ R8, 8(AX)
+
+ // Iteration 2
+ MULXQ 16(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 16(AX), R8
+ MOVQ R8, 16(AX)
+
+ // Iteration 3
+ MULXQ 24(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 24(AX), R8
+ MOVQ R8, 24(AX)
+
+ // Iteration 4
+ MULXQ 32(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 32(AX), R8
+ MOVQ R8, 32(AX)
+
+ // Iteration 5
+ MULXQ 40(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 40(AX), R8
+ MOVQ R8, 40(AX)
+
+ // Iteration 6
+ MULXQ 48(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 48(AX), R8
+ MOVQ R8, 48(AX)
+
+ // Iteration 7
+ MULXQ 56(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 56(AX), R8
+ MOVQ R8, 56(AX)
+
+ // Iteration 8
+ MULXQ 64(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 64(AX), R8
+ MOVQ R8, 64(AX)
+
+ // Iteration 9
+ MULXQ 72(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 72(AX), R8
+ MOVQ R8, 72(AX)
+
+ // Iteration 10
+ MULXQ 80(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 80(AX), R8
+ MOVQ R8, 80(AX)
+
+ // Iteration 11
+ MULXQ 88(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 88(AX), R8
+ MOVQ R8, 88(AX)
+
+ // Iteration 12
+ MULXQ 96(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 96(AX), R8
+ MOVQ R8, 96(AX)
+
+ // Iteration 13
+ MULXQ 104(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 104(AX), R8
+ MOVQ R8, 104(AX)
+
+ // Iteration 14
+ MULXQ 112(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 112(AX), R8
+ MOVQ R8, 112(AX)
+
+ // Iteration 15
+ MULXQ 120(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 120(AX), R8
+ MOVQ R8, 120(AX)
+
+ // Add back carry flags and return
+ ADCXQ SI, BX
+ ADOXQ SI, BX
+ MOVQ BX, c+24(FP)
+ RET
+
+// func addMulVVW1536(z *uint, x *uint, y uint) (c uint)
+// Requires: ADX, BMI2
+TEXT ·addMulVVW1536(SB), $0-32
+ CMPB ·supportADX+0(SB), $0x01
+ JEQ adx
+ MOVQ z+0(FP), CX
+ MOVQ x+8(FP), BX
+ MOVQ y+16(FP), SI
+ XORQ DI, DI
+
+ // Iteration 0
+ MOVQ (BX), AX
+ MULQ SI
+ ADDQ (CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, (CX)
+
+ // Iteration 1
+ MOVQ 8(BX), AX
+ MULQ SI
+ ADDQ 8(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 8(CX)
+
+ // Iteration 2
+ MOVQ 16(BX), AX
+ MULQ SI
+ ADDQ 16(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 16(CX)
+
+ // Iteration 3
+ MOVQ 24(BX), AX
+ MULQ SI
+ ADDQ 24(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 24(CX)
+
+ // Iteration 4
+ MOVQ 32(BX), AX
+ MULQ SI
+ ADDQ 32(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 32(CX)
+
+ // Iteration 5
+ MOVQ 40(BX), AX
+ MULQ SI
+ ADDQ 40(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 40(CX)
+
+ // Iteration 6
+ MOVQ 48(BX), AX
+ MULQ SI
+ ADDQ 48(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 48(CX)
+
+ // Iteration 7
+ MOVQ 56(BX), AX
+ MULQ SI
+ ADDQ 56(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 56(CX)
+
+ // Iteration 8
+ MOVQ 64(BX), AX
+ MULQ SI
+ ADDQ 64(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 64(CX)
+
+ // Iteration 9
+ MOVQ 72(BX), AX
+ MULQ SI
+ ADDQ 72(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 72(CX)
+
+ // Iteration 10
+ MOVQ 80(BX), AX
+ MULQ SI
+ ADDQ 80(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 80(CX)
+
+ // Iteration 11
+ MOVQ 88(BX), AX
+ MULQ SI
+ ADDQ 88(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 88(CX)
+
+ // Iteration 12
+ MOVQ 96(BX), AX
+ MULQ SI
+ ADDQ 96(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 96(CX)
+
+ // Iteration 13
+ MOVQ 104(BX), AX
+ MULQ SI
+ ADDQ 104(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 104(CX)
+
+ // Iteration 14
+ MOVQ 112(BX), AX
+ MULQ SI
+ ADDQ 112(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 112(CX)
+
+ // Iteration 15
+ MOVQ 120(BX), AX
+ MULQ SI
+ ADDQ 120(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 120(CX)
+
+ // Iteration 16
+ MOVQ 128(BX), AX
+ MULQ SI
+ ADDQ 128(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 128(CX)
+
+ // Iteration 17
+ MOVQ 136(BX), AX
+ MULQ SI
+ ADDQ 136(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 136(CX)
+
+ // Iteration 18
+ MOVQ 144(BX), AX
+ MULQ SI
+ ADDQ 144(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 144(CX)
+
+ // Iteration 19
+ MOVQ 152(BX), AX
+ MULQ SI
+ ADDQ 152(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 152(CX)
+
+ // Iteration 20
+ MOVQ 160(BX), AX
+ MULQ SI
+ ADDQ 160(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 160(CX)
+
+ // Iteration 21
+ MOVQ 168(BX), AX
+ MULQ SI
+ ADDQ 168(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 168(CX)
+
+ // Iteration 22
+ MOVQ 176(BX), AX
+ MULQ SI
+ ADDQ 176(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 176(CX)
+
+ // Iteration 23
+ MOVQ 184(BX), AX
+ MULQ SI
+ ADDQ 184(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 184(CX)
+ MOVQ DI, c+24(FP)
+ RET
+
+adx:
+ MOVQ z+0(FP), AX
+ MOVQ x+8(FP), CX
+ MOVQ y+16(FP), DX
+ XORQ BX, BX
+ XORQ SI, SI
+
+ // Iteration 0
+ MULXQ (CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ (AX), R8
+ MOVQ R8, (AX)
+
+ // Iteration 1
+ MULXQ 8(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 8(AX), R8
+ MOVQ R8, 8(AX)
+
+ // Iteration 2
+ MULXQ 16(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 16(AX), R8
+ MOVQ R8, 16(AX)
+
+ // Iteration 3
+ MULXQ 24(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 24(AX), R8
+ MOVQ R8, 24(AX)
+
+ // Iteration 4
+ MULXQ 32(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 32(AX), R8
+ MOVQ R8, 32(AX)
+
+ // Iteration 5
+ MULXQ 40(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 40(AX), R8
+ MOVQ R8, 40(AX)
+
+ // Iteration 6
+ MULXQ 48(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 48(AX), R8
+ MOVQ R8, 48(AX)
+
+ // Iteration 7
+ MULXQ 56(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 56(AX), R8
+ MOVQ R8, 56(AX)
+
+ // Iteration 8
+ MULXQ 64(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 64(AX), R8
+ MOVQ R8, 64(AX)
+
+ // Iteration 9
+ MULXQ 72(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 72(AX), R8
+ MOVQ R8, 72(AX)
+
+ // Iteration 10
+ MULXQ 80(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 80(AX), R8
+ MOVQ R8, 80(AX)
+
+ // Iteration 11
+ MULXQ 88(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 88(AX), R8
+ MOVQ R8, 88(AX)
+
+ // Iteration 12
+ MULXQ 96(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 96(AX), R8
+ MOVQ R8, 96(AX)
+
+ // Iteration 13
+ MULXQ 104(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 104(AX), R8
+ MOVQ R8, 104(AX)
+
+ // Iteration 14
+ MULXQ 112(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 112(AX), R8
+ MOVQ R8, 112(AX)
+
+ // Iteration 15
+ MULXQ 120(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 120(AX), R8
+ MOVQ R8, 120(AX)
+
+ // Iteration 16
+ MULXQ 128(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 128(AX), R8
+ MOVQ R8, 128(AX)
+
+ // Iteration 17
+ MULXQ 136(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 136(AX), R8
+ MOVQ R8, 136(AX)
+
+ // Iteration 18
+ MULXQ 144(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 144(AX), R8
+ MOVQ R8, 144(AX)
+
+ // Iteration 19
+ MULXQ 152(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 152(AX), R8
+ MOVQ R8, 152(AX)
+
+ // Iteration 20
+ MULXQ 160(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 160(AX), R8
+ MOVQ R8, 160(AX)
+
+ // Iteration 21
+ MULXQ 168(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 168(AX), R8
+ MOVQ R8, 168(AX)
+
+ // Iteration 22
+ MULXQ 176(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 176(AX), R8
+ MOVQ R8, 176(AX)
+
+ // Iteration 23
+ MULXQ 184(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 184(AX), R8
+ MOVQ R8, 184(AX)
+
+ // Add back carry flags and return
+ ADCXQ SI, BX
+ ADOXQ SI, BX
+ MOVQ BX, c+24(FP)
+ RET
+
+// func addMulVVW2048(z *uint, x *uint, y uint) (c uint)
+// Requires: ADX, BMI2
+TEXT ·addMulVVW2048(SB), $0-32
+ CMPB ·supportADX+0(SB), $0x01
+ JEQ adx
+ MOVQ z+0(FP), CX
+ MOVQ x+8(FP), BX
+ MOVQ y+16(FP), SI
+ XORQ DI, DI
+
+ // Iteration 0
+ MOVQ (BX), AX
+ MULQ SI
+ ADDQ (CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, (CX)
+
+ // Iteration 1
+ MOVQ 8(BX), AX
+ MULQ SI
+ ADDQ 8(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 8(CX)
+
+ // Iteration 2
+ MOVQ 16(BX), AX
+ MULQ SI
+ ADDQ 16(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 16(CX)
+
+ // Iteration 3
+ MOVQ 24(BX), AX
+ MULQ SI
+ ADDQ 24(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 24(CX)
+
+ // Iteration 4
+ MOVQ 32(BX), AX
+ MULQ SI
+ ADDQ 32(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 32(CX)
+
+ // Iteration 5
+ MOVQ 40(BX), AX
+ MULQ SI
+ ADDQ 40(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 40(CX)
+
+ // Iteration 6
+ MOVQ 48(BX), AX
+ MULQ SI
+ ADDQ 48(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 48(CX)
+
+ // Iteration 7
+ MOVQ 56(BX), AX
+ MULQ SI
+ ADDQ 56(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 56(CX)
+
+ // Iteration 8
+ MOVQ 64(BX), AX
+ MULQ SI
+ ADDQ 64(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 64(CX)
+
+ // Iteration 9
+ MOVQ 72(BX), AX
+ MULQ SI
+ ADDQ 72(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 72(CX)
+
+ // Iteration 10
+ MOVQ 80(BX), AX
+ MULQ SI
+ ADDQ 80(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 80(CX)
+
+ // Iteration 11
+ MOVQ 88(BX), AX
+ MULQ SI
+ ADDQ 88(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 88(CX)
+
+ // Iteration 12
+ MOVQ 96(BX), AX
+ MULQ SI
+ ADDQ 96(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 96(CX)
+
+ // Iteration 13
+ MOVQ 104(BX), AX
+ MULQ SI
+ ADDQ 104(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 104(CX)
+
+ // Iteration 14
+ MOVQ 112(BX), AX
+ MULQ SI
+ ADDQ 112(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 112(CX)
+
+ // Iteration 15
+ MOVQ 120(BX), AX
+ MULQ SI
+ ADDQ 120(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 120(CX)
+
+ // Iteration 16
+ MOVQ 128(BX), AX
+ MULQ SI
+ ADDQ 128(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 128(CX)
+
+ // Iteration 17
+ MOVQ 136(BX), AX
+ MULQ SI
+ ADDQ 136(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 136(CX)
+
+ // Iteration 18
+ MOVQ 144(BX), AX
+ MULQ SI
+ ADDQ 144(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 144(CX)
+
+ // Iteration 19
+ MOVQ 152(BX), AX
+ MULQ SI
+ ADDQ 152(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 152(CX)
+
+ // Iteration 20
+ MOVQ 160(BX), AX
+ MULQ SI
+ ADDQ 160(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 160(CX)
+
+ // Iteration 21
+ MOVQ 168(BX), AX
+ MULQ SI
+ ADDQ 168(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 168(CX)
+
+ // Iteration 22
+ MOVQ 176(BX), AX
+ MULQ SI
+ ADDQ 176(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 176(CX)
+
+ // Iteration 23
+ MOVQ 184(BX), AX
+ MULQ SI
+ ADDQ 184(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 184(CX)
+
+ // Iteration 24
+ MOVQ 192(BX), AX
+ MULQ SI
+ ADDQ 192(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 192(CX)
+
+ // Iteration 25
+ MOVQ 200(BX), AX
+ MULQ SI
+ ADDQ 200(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 200(CX)
+
+ // Iteration 26
+ MOVQ 208(BX), AX
+ MULQ SI
+ ADDQ 208(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 208(CX)
+
+ // Iteration 27
+ MOVQ 216(BX), AX
+ MULQ SI
+ ADDQ 216(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 216(CX)
+
+ // Iteration 28
+ MOVQ 224(BX), AX
+ MULQ SI
+ ADDQ 224(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 224(CX)
+
+ // Iteration 29
+ MOVQ 232(BX), AX
+ MULQ SI
+ ADDQ 232(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 232(CX)
+
+ // Iteration 30
+ MOVQ 240(BX), AX
+ MULQ SI
+ ADDQ 240(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 240(CX)
+
+ // Iteration 31
+ MOVQ 248(BX), AX
+ MULQ SI
+ ADDQ 248(CX), AX
+ ADCQ $0x00, DX
+ ADDQ DI, AX
+ ADCQ $0x00, DX
+ MOVQ DX, DI
+ MOVQ AX, 248(CX)
+ MOVQ DI, c+24(FP)
+ RET
+
+adx:
+ MOVQ z+0(FP), AX
+ MOVQ x+8(FP), CX
+ MOVQ y+16(FP), DX
+ XORQ BX, BX
+ XORQ SI, SI
+
+ // Iteration 0
+ MULXQ (CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ (AX), R8
+ MOVQ R8, (AX)
+
+ // Iteration 1
+ MULXQ 8(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 8(AX), R8
+ MOVQ R8, 8(AX)
+
+ // Iteration 2
+ MULXQ 16(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 16(AX), R8
+ MOVQ R8, 16(AX)
+
+ // Iteration 3
+ MULXQ 24(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 24(AX), R8
+ MOVQ R8, 24(AX)
+
+ // Iteration 4
+ MULXQ 32(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 32(AX), R8
+ MOVQ R8, 32(AX)
+
+ // Iteration 5
+ MULXQ 40(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 40(AX), R8
+ MOVQ R8, 40(AX)
+
+ // Iteration 6
+ MULXQ 48(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 48(AX), R8
+ MOVQ R8, 48(AX)
+
+ // Iteration 7
+ MULXQ 56(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 56(AX), R8
+ MOVQ R8, 56(AX)
+
+ // Iteration 8
+ MULXQ 64(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 64(AX), R8
+ MOVQ R8, 64(AX)
+
+ // Iteration 9
+ MULXQ 72(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 72(AX), R8
+ MOVQ R8, 72(AX)
+
+ // Iteration 10
+ MULXQ 80(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 80(AX), R8
+ MOVQ R8, 80(AX)
+
+ // Iteration 11
+ MULXQ 88(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 88(AX), R8
+ MOVQ R8, 88(AX)
+
+ // Iteration 12
+ MULXQ 96(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 96(AX), R8
+ MOVQ R8, 96(AX)
+
+ // Iteration 13
+ MULXQ 104(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 104(AX), R8
+ MOVQ R8, 104(AX)
+
+ // Iteration 14
+ MULXQ 112(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 112(AX), R8
+ MOVQ R8, 112(AX)
+
+ // Iteration 15
+ MULXQ 120(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 120(AX), R8
+ MOVQ R8, 120(AX)
+
+ // Iteration 16
+ MULXQ 128(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 128(AX), R8
+ MOVQ R8, 128(AX)
+
+ // Iteration 17
+ MULXQ 136(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 136(AX), R8
+ MOVQ R8, 136(AX)
+
+ // Iteration 18
+ MULXQ 144(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 144(AX), R8
+ MOVQ R8, 144(AX)
+
+ // Iteration 19
+ MULXQ 152(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 152(AX), R8
+ MOVQ R8, 152(AX)
+
+ // Iteration 20
+ MULXQ 160(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 160(AX), R8
+ MOVQ R8, 160(AX)
+
+ // Iteration 21
+ MULXQ 168(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 168(AX), R8
+ MOVQ R8, 168(AX)
+
+ // Iteration 22
+ MULXQ 176(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 176(AX), R8
+ MOVQ R8, 176(AX)
+
+ // Iteration 23
+ MULXQ 184(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 184(AX), R8
+ MOVQ R8, 184(AX)
+
+ // Iteration 24
+ MULXQ 192(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 192(AX), R8
+ MOVQ R8, 192(AX)
+
+ // Iteration 25
+ MULXQ 200(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 200(AX), R8
+ MOVQ R8, 200(AX)
+
+ // Iteration 26
+ MULXQ 208(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 208(AX), R8
+ MOVQ R8, 208(AX)
+
+ // Iteration 27
+ MULXQ 216(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 216(AX), R8
+ MOVQ R8, 216(AX)
+
+ // Iteration 28
+ MULXQ 224(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 224(AX), R8
+ MOVQ R8, 224(AX)
+
+ // Iteration 29
+ MULXQ 232(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 232(AX), R8
+ MOVQ R8, 232(AX)
+
+ // Iteration 30
+ MULXQ 240(CX), R8, DI
+ ADCXQ BX, R8
+ ADOXQ 240(AX), R8
+ MOVQ R8, 240(AX)
+
+ // Iteration 31
+ MULXQ 248(CX), R8, BX
+ ADCXQ DI, R8
+ ADOXQ 248(AX), R8
+ MOVQ R8, 248(AX)
+
+ // Add back carry flags and return
+ ADCXQ SI, BX
+ ADOXQ SI, BX
+ MOVQ BX, c+24(FP)
RET
--- /dev/null
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !purego
+
+#include "textflag.h"
+
+// func addMulVVW1024(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1024(SB), $0-16
+ MOVW $32, R5
+ JMP addMulVVWx(SB)
+
+// func addMulVVW1536(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1536(SB), $0-16
+ MOVW $48, R5
+ JMP addMulVVWx(SB)
+
+// func addMulVVW2048(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW2048(SB), $0-16
+ MOVW $64, R5
+ JMP addMulVVWx(SB)
+
+TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
+ MOVW $0, R0
+ MOVW z+0(FP), R1
+ MOVW x+4(FP), R2
+ MOVW y+8(FP), R3
+ ADD R5<<2, R1, R5
+ MOVW $0, R4
+ B E9
+
+L9: MOVW.P 4(R2), R6
+ MULLU R6, R3, (R7, R6)
+ ADD.S R4, R6
+ ADC R0, R7
+ MOVW 0(R1), R4
+ ADD.S R4, R6
+ ADC R0, R7
+ MOVW.P R6, 4(R1)
+ MOVW R7, R4
+
+E9: TEQ R1, R5
+ BNE L9
+
+ MOVW R4, c+12(FP)
+ RET
--- /dev/null
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !purego
+
+#include "textflag.h"
+
+// func addMulVVW1024(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1024(SB), $0-32
+ MOVD $16, R0
+ JMP addMulVVWx(SB)
+
+// func addMulVVW1536(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1536(SB), $0-32
+ MOVD $24, R0
+ JMP addMulVVWx(SB)
+
+// func addMulVVW2048(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW2048(SB), $0-32
+ MOVD $32, R0
+ JMP addMulVVWx(SB)
+
+TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
+ MOVD z+0(FP), R1
+ MOVD x+8(FP), R2
+ MOVD y+16(FP), R3
+ MOVD $0, R4
+
+// The main loop of this code operates on a block of 4 words every iteration
+// performing [R4:R12:R11:R10:R9] = R4 + R3 * [R8:R7:R6:R5] + [R12:R11:R10:R9]
+// where R4 is carried from the previous iteration, R8:R7:R6:R5 hold the next
+// 4 words of x, R3 is y and R12:R11:R10:R9 are part of the result z.
+loop:
+ CBZ R0, done
+
+ LDP.P 16(R2), (R5, R6)
+ LDP.P 16(R2), (R7, R8)
+
+ LDP (R1), (R9, R10)
+ ADDS R4, R9
+ MUL R6, R3, R14
+ ADCS R14, R10
+ MUL R7, R3, R15
+ LDP 16(R1), (R11, R12)
+ ADCS R15, R11
+ MUL R8, R3, R16
+ ADCS R16, R12
+ UMULH R8, R3, R20
+ ADC $0, R20
+
+ MUL R5, R3, R13
+ ADDS R13, R9
+ UMULH R5, R3, R17
+ ADCS R17, R10
+ UMULH R6, R3, R21
+ STP.P (R9, R10), 16(R1)
+ ADCS R21, R11
+ UMULH R7, R3, R19
+ ADCS R19, R12
+ STP.P (R11, R12), 16(R1)
+ ADC $0, R20, R4
+
+ SUB $4, R0
+ B loop
+
+done:
+ MOVD R4, c+24(FP)
+ RET
--- /dev/null
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !purego && (386 || amd64 || arm || arm64 || ppc64 || ppc64le || s390x)
+
+package bigmod
+
+import "internal/cpu"
+
+// amd64 assembly uses ADCX/ADOX/MULX if ADX is available to run two carry
+// chains in the flags in parallel across the whole operation, and aggressively
+// unrolls loops. arm64 processes four words at a time.
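+// ADCX only reads and writes the carry flag, while ADOX only reads and writes
+// the overflow flag, which is what lets the two carry chains proceed without
+// interfering.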
+//
+// It's unclear why the assembly for all other architectures, as well as for
+// amd64 without ADX, performs better than the compiler output.
+// TODO(filippo): file cmd/compile performance issue.
+
+var supportADX = cpu.X86.HasADX && cpu.X86.HasBMI2
+
+//go:noescape
+func addMulVVW1024(z, x *uint, y uint) (c uint)
+
+//go:noescape
+func addMulVVW1536(z, x *uint, y uint) (c uint)
+
+//go:noescape
+func addMulVVW2048(z, x *uint, y uint) (c uint)
-// Copyright 2022 The Go Authors. All rights reserved.
+// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !amd64 || !gc || purego
+//go:build purego || !(386 || amd64 || arm || arm64 || ppc64 || ppc64le || s390x)
package bigmod
-func montgomeryLoop(d, a, b, m []uint, m0inv uint) uint {
- return montgomeryLoopGeneric(d, a, b, m, m0inv)
+import "unsafe"
+
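+// These wrappers recover a slice from each pointer argument, relying on the
+// caller (montgomeryMul) to pass pointers into buffers of at least 1024/_W,
+// 1536/_W, and 2048/_W words, respectively.
+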
+func addMulVVW1024(z, x *uint, y uint) (c uint) {
+ return addMulVVW(unsafe.Slice(z, 1024/_W), unsafe.Slice(x, 1024/_W), y)
+}
+
+func addMulVVW1536(z, x *uint, y uint) (c uint) {
+ return addMulVVW(unsafe.Slice(z, 1536/_W), unsafe.Slice(x, 1536/_W), y)
+}
+
+func addMulVVW2048(z, x *uint, y uint) (c uint) {
+ return addMulVVW(unsafe.Slice(z, 2048/_W), unsafe.Slice(x, 2048/_W), y)
}
--- /dev/null
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !purego && (ppc64 || ppc64le)
+
+#include "textflag.h"
+
+// func addMulVVW1024(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1024(SB), $0-32
+ MOVD $16, R22 // R22 = z_len
+ JMP addMulVVWx(SB)
+
+// func addMulVVW1536(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1536(SB), $0-32
+ MOVD $24, R22 // R22 = z_len
+ JMP addMulVVWx(SB)
+
+// func addMulVVW2048(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW2048(SB), $0-32
+ MOVD $32, R22 // R22 = z_len
+ JMP addMulVVWx(SB)
+
+TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
+ MOVD z+0(FP), R10 // R10 = z[]
+ MOVD x+8(FP), R8 // R8 = x[]
+ MOVD y+16(FP), R9 // R9 = y
+
+ MOVD R0, R3 // R3 will be the index register
+ CMP R0, R22
+ MOVD R0, R4 // R4 = c = 0
+ MOVD R22, CTR // Initialize loop counter
+ BEQ done
+ PCALIGN $16
+
+loop:
+ MOVD (R8)(R3), R20 // Load x[i]
+ MOVD (R10)(R3), R21 // Load z[i]
+ MULLD R9, R20, R6 // R6 = Low-order(x[i]*y)
+ MULHDU R9, R20, R7 // R7 = High-order(x[i]*y)
+ ADDC R21, R6 // R6 = z0
+ ADDZE R7 // R7 = z1
+ ADDC R4, R6 // R6 = z0 + c + 0
+ ADDZE R7, R4 // c += z1
+ MOVD R6, (R10)(R3) // Store z[i]
+ ADD $8, R3
+ BC 16, 0, loop // bdnz
+
+done:
+ MOVD R4, c+24(FP)
+ RET
--- /dev/null
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !purego
+
+#include "textflag.h"
+
+// func addMulVVW1024(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1024(SB), $0-32
+ MOVD $16, R5
+ JMP addMulVVWx(SB)
+
+// func addMulVVW1536(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1536(SB), $0-32
+ MOVD $24, R5
+ JMP addMulVVWx(SB)
+
+// func addMulVVW2048(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW2048(SB), $0-32
+ MOVD $32, R5
+ JMP addMulVVWx(SB)
+
+TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
+ MOVD z+0(FP), R2
+ MOVD x+8(FP), R8
+ MOVD y+16(FP), R9
+
+ MOVD $0, R1 // i*8 = 0
+ MOVD $0, R7 // i = 0
+ MOVD $0, R0 // make sure it's zero
+ MOVD $0, R4 // c = 0
+
+ MOVD R5, R12
+ AND $-2, R12
+ CMPBGE R5, $2, A6
+ BR E6
+
+A6:
+ MOVD (R8)(R1*1), R6
+ MULHDU R9, R6
+ MOVD (R2)(R1*1), R10
+ ADDC R10, R11 // add to low order bits
+ ADDE R0, R6
+ ADDC R4, R11
+ ADDE R0, R6
+ MOVD R6, R4
+ MOVD R11, (R2)(R1*1)
+
+ MOVD (8)(R8)(R1*1), R6
+ MULHDU R9, R6
+ MOVD (8)(R2)(R1*1), R10
+ ADDC R10, R11 // add to low order bits
+ ADDE R0, R6
+ ADDC R4, R11
+ ADDE R0, R6
+ MOVD R6, R4
+ MOVD R11, (8)(R2)(R1*1)
+
+ ADD $16, R1 // i*8 + 8
+ ADD $2, R7 // i++
+
+ CMPBLT R7, R12, A6
+ BR E6
+
+L6:
+ // TODO: drop unused single-step loop.
+ MOVD (R8)(R1*1), R6
+ MULHDU R9, R6
+ MOVD (R2)(R1*1), R10
+ ADDC R10, R11 // add to low order bits
+ ADDE R0, R6
+ ADDC R4, R11
+ ADDE R0, R6
+ MOVD R6, R4
+ MOVD R11, (R2)(R1*1)
+
+ ADD $8, R1 // i*8 + 8
+ ADD $1, R7 // i++
+
+E6:
+ CMPBLT R7, R5, L6 // i < n
+
+ MOVD R4, c+24(FP)
+ RET
package bigmod
import (
+ "fmt"
"math/big"
"math/bits"
"math/rand"
"reflect"
+ "strings"
"testing"
"testing/quick"
)
+func (n *Nat) String() string {
+ var limbs []string
+ for i := range n.limbs {
+ limbs = append(limbs, fmt.Sprintf("%016X", n.limbs[len(n.limbs)-1-i]))
+ }
+ return "{" + strings.Join(limbs, " ") + "}"
+}
+
// Generate generates an even nat. It's used by testing/quick to produce random
// *nat values for quick.Check invocations.
func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
}
}
-func testMontgomeryRoundtrip(a *Nat) bool {
- one := &Nat{make([]uint, len(a.limbs))}
- one.limbs[0] = 1
- aPlusOne := new(big.Int).SetBytes(natBytes(a))
- aPlusOne.Add(aPlusOne, big.NewInt(1))
- m := NewModulusFromBig(aPlusOne)
- monty := new(Nat).set(a)
- monty.montgomeryRepresentation(m)
- aAgain := new(Nat).set(monty)
- aAgain.montgomeryMul(monty, one, m)
- return a.Equal(aAgain) == 1
-}
-
func TestMontgomeryRoundtrip(t *testing.T) {
- err := quick.Check(testMontgomeryRoundtrip, &quick.Config{})
+ err := quick.Check(func(a *Nat) bool {
+ one := &Nat{make([]uint, len(a.limbs))}
+ one.limbs[0] = 1
+ aPlusOne := new(big.Int).SetBytes(natBytes(a))
+ aPlusOne.Add(aPlusOne, big.NewInt(1))
+ m := NewModulusFromBig(aPlusOne)
+ monty := new(Nat).set(a)
+ monty.montgomeryRepresentation(m)
+ aAgain := new(Nat).set(monty)
+ aAgain.montgomeryMul(monty, one, m)
+ if a.Equal(aAgain) != 1 {
+ t.Errorf("%v != %v", a, aAgain)
+ return false
+ }
+ return true
+ }, &quick.Config{})
if err != nil {
t.Error(err)
}
}{{
m: []byte{13},
x: []byte{0},
- y: 0x7FFF_FFFF_FFFF_FFFF,
- expected: []byte{7},
+ y: 0xFFFF_FFFF_FFFF_FFFF,
+ expected: []byte{2},
}, {
m: []byte{13},
x: []byte{7},
- y: 0x7FFF_FFFF_FFFF_FFFF,
- expected: []byte{11},
+ y: 0xFFFF_FFFF_FFFF_FFFF,
+ expected: []byte{10},
}, {
m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
x: make([]byte, 9),
- y: 0x7FFF_FFFF_FFFF_FFFF,
- expected: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ y: 0xFFFF_FFFF_FFFF_FFFF,
+ expected: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
}, {
m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
- x: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ x: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
y: 0,
- expected: []byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08},
+ expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06},
}}
for i, tt := range examples {
m := modulusFromBytes(tt.m)
got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m)
- if got.Equal(natFromBytes(tt.expected).ExpandFor(m)) != 1 {
- t.Errorf("%d: got %x, expected %x", i, got, tt.expected)
+ if exp := natFromBytes(tt.expected).ExpandFor(m); got.Equal(exp) != 1 {
+ t.Errorf("%d: got %v, expected %v", i, got, exp)
}
}
}
continue
}
if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes {
- t.Errorf("%d: got %x, expected %x", i, got, expected)
+ t.Errorf("%d: got %v, expected %v", i, got, expected)
}
}
for i, tt := range examples {
got := (&Nat{tt.in}).expand(tt.n)
if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 {
- t.Errorf("%d: got %x, expected %x", i, got, tt.out)
+ t.Errorf("%d: got %v, expected %v", i, got, tt.out)
}
}
}
}
}
+// TestMulReductions tests that Mul reduces results equal or slightly greater
+// than the modulus. Some Montgomery algorithms don't and need extra care to
+// return correct results. See https://go.dev/issue/13907.
+func TestMulReductions(t *testing.T) {
+ // Two short but multi-limb primes.
+ a, _ := new(big.Int).SetString("773608962677651230850240281261679752031633236267106044359907", 10)
+ b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10)
+ n := new(big.Int).Mul(a, b)
+
+ N := NewModulusFromBig(n)
+ A := NewNat().setBig(a).ExpandFor(N)
+ B := NewNat().setBig(b).ExpandFor(N)
+
+ if A.Mul(B, N).IsZero() != 1 {
+ t.Error("a * b mod (a * b) != 0")
+ }
+
+ i := new(big.Int).ModInverse(a, b)
+ N = NewModulusFromBig(b)
+ A = NewNat().setBig(a).ExpandFor(N)
+ I := NewNat().setBig(i).ExpandFor(N)
+ one := NewNat().setBig(big.NewInt(1)).ExpandFor(N)
+
+ if A.Mul(I, N).Equal(one) != 1 {
+ t.Error("a * inv(a) mod b != 1")
+ }
+}
+
func natBytes(n *Nat) []byte {
return n.Bytes(maxModulus(uint(len(n.limbs))))
}
func natFromBytes(b []byte) *Nat {
+ // Must not use Nat.SetBytes as it's used in TestSetBytes.
bb := new(big.Int).SetBytes(b)
return NewNat().setBig(bb)
}
func makeBenchmarkValue() *Nat {
x := make([]uint, 32)
for i := 0; i < 32; i++ {
- x[i] = _MASK - 1
+		x[i]-- // wraps around to the largest uint value
}
return &Nat{limbs: x}
}