]> Cypherpunks.ru repositories - gostls13.git/commitdiff
crypto/internal/bigmod: switch to saturated limbs
authorFilippo Valsorda <filippo@golang.org>
Sat, 25 Feb 2023 18:09:11 +0000 (19:09 +0100)
committerGopher Robot <gobot@golang.org>
Wed, 24 May 2023 22:37:58 +0000 (22:37 +0000)
Turns out that unsaturated limbs being more performant for Montgomery
multiplication was true in portable C89, but is now a misconception.
With add-with-carry instructions, it's possible to run the carry chain
across the limbs, instead of needing the limb-by-limb product to fit in
two words.

Switch to saturated limbs, and import the same Montgomery loop as
math/big, along with its assembly for some architectures. Since here we
know the sizes we care about, we can drop most of the assembly
scaffolding. For amd64, ported to avo, too.

We recover all the Go 1.20 performance loss on private key operations on
both Intel Xeon and AMD EPYC, with even a 10% improvement over Go 1.19
(which used variable-time math/big) for some operations.

goos: linux
goarch: amd64
pkg: crypto/rsa
cpu: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
                       │ go1.19.txt  │       go1.20.txt         │         new.txt          │
                       │   sec/op    │    sec/op      vs base   │    sec/op      vs base   │
DecryptPKCS1v15/2048-4   1.175m ± 0%     1.515m ± 0%    +28.95%     1.132m ± 0%     -3.59%
DecryptPKCS1v15/3072-4   3.428m ± 1%     4.516m ± 0%    +31.75%     3.198m ± 0%     -6.69%
DecryptPKCS1v15/4096-4   7.405m ± 0%    10.092m ± 0%    +36.29%     6.446m ± 0%    -12.95%
EncryptPKCS1v15/2048-4   7.426µ ± 0%   170.829µ ± 0%  +2200.57%   131.874µ ± 0%  +1675.97%
DecryptOAEP/2048-4       1.175m ± 0%     1.524m ± 0%    +29.68%     1.137m ± 0%     -3.26%
EncryptOAEP/2048-4       9.609µ ± 0%   173.008µ ± 0%  +1700.48%   132.344µ ± 0%  +1277.29%
SignPKCS1v15/2048-4      1.181m ± 0%     1.563m ± 0%    +32.34%     1.177m ± 0%     -0.37%
VerifyPKCS1v15/2048-4    6.452µ ± 0%   170.092µ ± 0%  +2536.06%   131.225µ ± 0%  +1933.70%
SignPSS/2048-4           1.184m ± 0%     1.574m ± 0%    +32.88%     1.175m ± 0%     -0.84%
VerifyPSS/2048-4         9.151µ ± 1%   172.909µ ± 0%  +1789.50%   132.391µ ± 0%  +1346.74%

                       │  go1.19.txt   │      go1.20.txt       │       new.txt         │
                       │     B/op      │     B/op      vs base │     B/op      vs base │
DecryptPKCS1v15/2048-4    24266.5 ± 0%     640.0 ± 0%  -97.36%     640.0 ± 0%  -97.36%
DecryptPKCS1v15/3072-4   45.465Ki ± 0%   3.375Ki ± 0%  -92.58%   4.688Ki ± 0%  -89.69%
DecryptPKCS1v15/4096-4   61.080Ki ± 0%   4.625Ki ± 0%  -92.43%   6.250Ki ± 0%  -89.77%
EncryptPKCS1v15/2048-4    3.138Ki ± 0%   1.146Ki ± 0%  -63.49%   1.082Ki ± 0%  -65.52%
DecryptOAEP/2048-4        24500.0 ± 0%     872.0 ± 0%  -96.44%     872.0 ± 0%  -96.44%
EncryptOAEP/2048-4        3.610Ki ± 0%   1.371Ki ± 0%  -62.02%   1.308Ki ± 0%  -63.78%
SignPKCS1v15/2048-4       26933.0 ± 0%     896.0 ± 0%  -96.67%     896.0 ± 0%  -96.67%
VerifyPKCS1v15/2048-4      3209.0 ± 0%     912.0 ± 0%  -71.58%     848.0 ± 0%  -73.57%
SignPSS/2048-4           26.940Ki ± 0%   1.266Ki ± 0%  -95.30%   1.266Ki ± 0%  -95.30%
VerifyPSS/2048-4          3.337Ki ± 0%   1.094Ki ± 0%  -67.22%   1.031Ki ± 0%  -69.10%

                       │  go1.19.txt  │     go1.20.txt      │      new.txt          │
                       │  allocs/op   │ allocs/op   vs base │ allocs/op   vs base   │
DecryptPKCS1v15/2048-4    97.000 ± 0%   4.000 ± 0%  -95.88%     4.000 ± 0%  -95.88%
DecryptPKCS1v15/3072-4    107.00 ± 0%   10.00 ± 0%  -90.65%     12.00 ± 0%  -88.79%
DecryptPKCS1v15/4096-4    113.00 ± 0%   10.00 ± 0%  -91.15%     12.00 ± 0%  -89.38%
EncryptPKCS1v15/2048-4     7.000 ± 0%   7.000 ± 0%        ~     7.000 ± 0%        ~
DecryptOAEP/2048-4        103.00 ± 0%   10.00 ± 0%  -90.29%     10.00 ± 0%  -90.29%
EncryptOAEP/2048-4         14.00 ± 0%   13.00 ± 0%   -7.14%     13.00 ± 0%   -7.14%
SignPKCS1v15/2048-4      102.000 ± 0%   5.000 ± 0%  -95.10%     5.000 ± 0%  -95.10%
VerifyPKCS1v15/2048-4      7.000 ± 0%   6.000 ± 0%  -14.29%     6.000 ± 0%  -14.29%
SignPSS/2048-4            108.00 ± 0%   10.00 ± 0%  -90.74%     10.00 ± 0%  -90.74%
VerifyPSS/2048-4           12.00 ± 0%   11.00 ± 0%   -8.33%     11.00 ± 0%   -8.33%

goos: linux
goarch: amd64
pkg: crypto/rsa
cpu: AMD EPYC 7R13 Processor
                       │ go1.19a.txt │       go1.20a.txt        │        newa.txt          │
                       │   sec/op    │    sec/op      vs base   │    sec/op      vs base   │
DecryptPKCS1v15/2048-4   970.0µ ± 0%    1667.6µ ± 0%    +71.92%     951.6µ ± 0%     -1.90%
DecryptPKCS1v15/3072-4   2.949m ± 0%     5.124m ± 0%    +73.75%     2.675m ± 0%     -9.29%
DecryptPKCS1v15/4096-4   6.350m ± 0%    11.660m ± 0%    +83.62%     5.746m ± 0%     -9.51%
EncryptPKCS1v15/2048-4   6.605µ ± 1%   183.807µ ± 0%  +2683.05%   123.720µ ± 0%  +1773.27%
DecryptOAEP/2048-4       973.8µ ± 0%    1670.8µ ± 0%    +71.57%     951.8µ ± 0%     -2.27%
EncryptOAEP/2048-4       8.444µ ± 1%   185.889µ ± 0%  +2101.56%   124.142µ ± 0%  +1370.27%
SignPKCS1v15/2048-4      976.8µ ± 0%    1725.5µ ± 0%    +76.65%     979.6µ ± 0%     +0.28%
VerifyPKCS1v15/2048-4    5.713µ ± 0%   182.983µ ± 0%  +3103.19%   122.737µ ± 0%  +2048.56%
SignPSS/2048-4           980.3µ ± 0%    1729.5µ ± 0%    +76.42%     985.7µ ± 3%     +0.55%
VerifyPSS/2048-4         8.168µ ± 1%   185.312µ ± 0%  +2168.76%   123.772µ ± 0%  +1415.33%

Fixes #59463
Fixes #59442
Updates #57752

Change-Id: I311a9c1f4f5288e47e53ca14f615a443f3132734
Reviewed-on: https://go-review.googlesource.com/c/go/+/471259
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
Run-TryBot: Filippo Valsorda <filippo@golang.org>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>

12 files changed:
src/crypto/internal/bigmod/_asm/nat_amd64_asm.go
src/crypto/internal/bigmod/nat.go
src/crypto/internal/bigmod/nat_386.s [new file with mode: 0644]
src/crypto/internal/bigmod/nat_amd64.go [deleted file]
src/crypto/internal/bigmod/nat_amd64.s
src/crypto/internal/bigmod/nat_arm.s [new file with mode: 0644]
src/crypto/internal/bigmod/nat_arm64.s [new file with mode: 0644]
src/crypto/internal/bigmod/nat_asm.go [new file with mode: 0644]
src/crypto/internal/bigmod/nat_noasm.go
src/crypto/internal/bigmod/nat_ppc64x.s [new file with mode: 0644]
src/crypto/internal/bigmod/nat_s390x.s [new file with mode: 0644]
src/crypto/internal/bigmod/nat_test.go

index 5690f04d1ee2683febc0063d81b45e73c56870df..bf64565d5c9fc84a7a09fb12fe30f6a47e9bb7ba 100644 (file)
-// Copyright 2022 The Go Authors. All rights reserved.
+// Copyright 2023 The Go Authors. All rights reserved.
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
 package main
 
 import (
+       "strconv"
+
        . "github.com/mmcloughlin/avo/build"
        . "github.com/mmcloughlin/avo/operand"
        . "github.com/mmcloughlin/avo/reg"
 )
 
-//go:generate go run . -out ../nat_amd64.s -stubs ../nat_amd64.go -pkg bigmod
+//go:generate go run . -out ../nat_amd64.s -pkg bigmod
 
 func main() {
        Package("crypto/internal/bigmod")
-       ConstraintExpr("amd64,gc,!purego")
-
-       Implement("montgomeryLoop")
-       Pragma("noescape")
-
-       size := Load(Param("d").Len(), GP64())
-       d := Mem{Base: Load(Param("d").Base(), GP64())}
-       b := Mem{Base: Load(Param("b").Base(), GP64())}
-       m := Mem{Base: Load(Param("m").Base(), GP64())}
-       m0inv := Load(Param("m0inv"), GP64())
-
-       overflow := zero()
-       i := zero()
-       Label("outerLoop")
-
-       ai := Load(Param("a").Base(), GP64())
-       MOVQ(Mem{Base: ai}.Idx(i, 8), ai)
-
-       z := uint128{GP64(), GP64()}
-       mul64(z, b, ai)
-       add64(z, d)
-       f := GP64()
-       MOVQ(m0inv, f)
-       IMULQ(z.lo, f)
-       _MASK(f)
-       addMul64(z, m, f)
-       carry := shiftBy63(z)
-
-       j := zero()
-       INCQ(j)
-       JMP(LabelRef("innerLoopCondition"))
-       Label("innerLoop")
-
-       // z = d[j] + a[i] * b[j] + f * m[j] + carry
-       z = uint128{GP64(), GP64()}
-       mul64(z, b.Idx(j, 8), ai)
-       addMul64(z, m.Idx(j, 8), f)
-       add64(z, d.Idx(j, 8))
-       add64(z, carry)
-       // d[j-1] = z_lo & _MASK
-       storeMasked(z.lo, d.Idx(j, 8).Offset(-8))
-       // carry = z_hi<<1 | z_lo>>_W
-       MOVQ(shiftBy63(z), carry)
-
-       INCQ(j)
-       Label("innerLoopCondition")
-       CMPQ(size, j)
-       JGT(LabelRef("innerLoop"))
-
-       ADDQ(carry, overflow)
-       storeMasked(overflow, d.Idx(size, 8).Offset(-8))
-       SHRQ(Imm(63), overflow)
-
-       INCQ(i)
-       CMPQ(size, i)
-       JGT(LabelRef("outerLoop"))
-
-       Store(overflow, ReturnIndex(0))
-       RET()
-       Generate()
-}
+       ConstraintExpr("!purego")
 
-// zero zeroes a new register and returns it.
-func zero() Register {
-       r := GP64()
-       XORQ(r, r)
-       return r
-}
-
-// _MASK masks out the top bit of r.
-func _MASK(r Register) {
-       BTRQ(Imm(63), r)
-}
-
-type uint128 struct {
-       hi, lo GPVirtual
-}
+       addMulVVW(1024)
+       addMulVVW(1536)
+       addMulVVW(2048)
 
-// storeMasked stores _MASK(src) in dst. It doesn't modify src.
-func storeMasked(src, dst Op) {
-       out := GP64()
-       MOVQ(src, out)
-       _MASK(out)
-       MOVQ(out, dst)
-}
-
-// shiftBy63 returns z >> 63. It reuses z.lo.
-func shiftBy63(z uint128) Register {
-       SHRQ(Imm(63), z.hi, z.lo)
-       result := z.lo
-       z.hi, z.lo = nil, nil
-       return result
-}
-
-// add64 sets r to r + a.
-func add64(r uint128, a Op) {
-       ADDQ(a, r.lo)
-       ADCQ(Imm(0), r.hi)
+       Generate()
 }
 
-// mul64 sets r to a * b.
-func mul64(r uint128, a, b Op) {
-       MOVQ(a, RAX)
-       MULQ(b) // RDX, RAX = RAX * b
-       MOVQ(RAX, r.lo)
-       MOVQ(RDX, r.hi)
-}
+func addMulVVW(bits int) {
+       if bits%64 != 0 {
+               panic("bit size unsupported")
+       }
+
+       Implement("addMulVVW" + strconv.Itoa(bits))
+
+       CMPB(Mem{Symbol: Symbol{Name: "·supportADX"}, Base: StaticBase}, Imm(1))
+       JEQ(LabelRef("adx"))
+
+       z := Mem{Base: Load(Param("z"), GP64())}
+       x := Mem{Base: Load(Param("x"), GP64())}
+       y := Load(Param("y"), GP64())
+
+       carry := GP64()
+       XORQ(carry, carry) // zero out carry
+
+       for i := 0; i < bits/64; i++ {
+               Comment("Iteration " + strconv.Itoa(i))
+               hi, lo := RDX, RAX // implicit MULQ inputs and outputs
+               MOVQ(x.Offset(i*8), lo)
+               MULQ(y)
+               ADDQ(z.Offset(i*8), lo)
+               ADCQ(Imm(0), hi)
+               ADDQ(carry, lo)
+               ADCQ(Imm(0), hi)
+               MOVQ(hi, carry)
+               MOVQ(lo, z.Offset(i*8))
+       }
+
+       Store(carry, ReturnIndex(0))
+       RET()
 
-// addMul64 sets r to r + a * b.
-func addMul64(r uint128, a, b Op) {
-       MOVQ(a, RAX)
-       MULQ(b) // RDX, RAX = RAX * b
-       ADDQ(RAX, r.lo)
-       ADCQ(RDX, r.hi)
+       Label("adx")
+
+       // The ADX strategy implements the following function, where c1 and c2 are
+       // the overflow and the carry flag respectively.
+       //
+       //    func addMulVVW(z, x []uint, y uint) (carry uint) {
+       //        var c1, c2 uint
+       //        for i := range z {
+       //            hi, lo := bits.Mul(x[i], y)
+       //            lo, c1 = bits.Add(lo, z[i], c1)
+       //            z[i], c2 = bits.Add(lo, carry, c2)
+       //            carry = hi
+       //        }
+       //        return carry + c1 + c2
+       //    }
+       //
+       // The loop is fully unrolled and the hi / carry registers are alternated
+       // instead of introducing a MOV.
+
+       z = Mem{Base: Load(Param("z"), GP64())}
+       x = Mem{Base: Load(Param("x"), GP64())}
+       Load(Param("y"), RDX) // implicit source of MULXQ
+
+       carry = GP64()
+       XORQ(carry, carry) // zero out carry
+       z0 := GP64()
+       XORQ(z0, z0) // unset flags and zero out z0
+
+       for i := 0; i < bits/64; i++ {
+               hi, lo := GP64(), GP64()
+
+               Comment("Iteration " + strconv.Itoa(i))
+               MULXQ(x.Offset(i*8), lo, hi)
+               ADCXQ(carry, lo)
+               ADOXQ(z.Offset(i*8), lo)
+               MOVQ(lo, z.Offset(i*8))
+
+               i++
+
+               Comment("Iteration " + strconv.Itoa(i))
+               MULXQ(x.Offset(i*8), lo, carry)
+               ADCXQ(hi, lo)
+               ADOXQ(z.Offset(i*8), lo)
+               MOVQ(lo, z.Offset(i*8))
+       }
+
+       Comment("Add back carry flags and return")
+       ADCXQ(z0, carry)
+       ADOXQ(z0, carry)
+
+       Store(carry, ReturnIndex(0))
+       RET()
 }
index 804316f50489a377c93ddae5fb0cbcb4a0efb722..3cad382b53f0186b25e131a949b0ca4c9eafd654 100644 (file)
@@ -5,16 +5,17 @@
 package bigmod
 
 import (
+       "encoding/binary"
        "errors"
        "math/big"
        "math/bits"
 )
 
 const (
-       // _W is the number of bits we use for our limbs.
-       _W = bits.UintSize - 1
-       // _MASK selects _W bits from a full machine word.
-       _MASK = (1 << _W) - 1
+       // _W is the size in bits of our limbs.
+       _W = bits.UintSize
+       // _S is the size in bytes of our limbs.
+       _S = _W / 8
 )
 
 // choice represents a constant-time boolean. The value of choice is always
@@ -27,15 +28,8 @@ func not(c choice) choice { return 1 ^ c }
 const yes = choice(1)
 const no = choice(0)
 
-// ctSelect returns x if on == 1, and y if on == 0. The execution time of this
-// function does not depend on its inputs. If on is any value besides 1 or 0,
-// the result is undefined.
-func ctSelect(on choice, x, y uint) uint {
-       // When on == 1, mask is 0b111..., otherwise mask is 0b000...
-       mask := -uint(on)
-       // When mask is all zeros, we just have y, otherwise, y cancels with itself.
-       return y ^ (mask & (y ^ x))
-}
+// ctMask is all 1s if on is yes, and all 0s otherwise.
+func ctMask(on choice) uint { return -uint(on) }
 
 // ctEq returns 1 if x == y, and 0 otherwise. The execution time of this
 // function does not depend on its inputs.
@@ -60,13 +54,7 @@ func ctGeq(x, y uint) choice {
 // Operations on this number are allowed to leak this length, but will not leak
 // any information about the values contained in those limbs.
 type Nat struct {
-       // limbs is a little-endian representation in base 2^W with
-       // W = bits.UintSize - 1. The top bit is always unset between operations.
-       //
-       // The top bit is left unset to optimize Montgomery multiplication, in the
-       // inner loop of exponentiation. Using fully saturated limbs would leave us
-       // working with 129-bit numbers on 64-bit platforms, wasting a lot of space,
-       // and thus time.
+       // limbs is little-endian in base 2^W with W = bits.UintSize.
        limbs []uint
 }
 
@@ -128,25 +116,10 @@ func (x *Nat) set(y *Nat) *Nat {
 // The announced length of x is set based on the actual bit size of the input,
 // ignoring leading zeroes.
 func (x *Nat) setBig(n *big.Int) *Nat {
-       requiredLimbs := (n.BitLen() + _W - 1) / _W
-       x.reset(requiredLimbs)
-
-       outI := 0
-       shift := 0
        limbs := n.Bits()
+       x.reset(len(limbs))
        for i := range limbs {
-               xi := uint(limbs[i])
-               x.limbs[outI] |= (xi << shift) & _MASK
-               outI++
-               if outI == requiredLimbs {
-                       return x
-               }
-               x.limbs[outI] = xi >> (_W - shift)
-               shift++ // this assumes bits.UintSize - _W = 1
-               if shift == _W {
-                       shift = 0
-                       outI++
-               }
+               x.limbs[i] = uint(limbs[i])
        }
        return x
 }
@@ -156,24 +129,20 @@ func (x *Nat) setBig(n *big.Int) *Nat {
 //
 // x must have the same size as m and it must be reduced modulo m.
 func (x *Nat) Bytes(m *Modulus) []byte {
-       bytes := make([]byte, m.Size())
-       shift := 0
-       outI := len(bytes) - 1
+       i := m.Size()
+       bytes := make([]byte, i)
        for _, limb := range x.limbs {
-               remainingBits := _W
-               for remainingBits >= 8 {
-                       bytes[outI] |= byte(limb) << shift
-                       consumed := 8 - shift
-                       limb >>= consumed
-                       remainingBits -= consumed
-                       shift = 0
-                       outI--
-                       if outI < 0 {
-                               return bytes
+               for j := 0; j < _S; j++ {
+                       i--
+                       if i < 0 {
+                               if limb == 0 {
+                                       break
+                               }
+                               panic("bigmod: modulus is smaller than nat")
                        }
+                       bytes[i] = byte(limb)
+                       limb >>= 8
                }
-               bytes[outI] = byte(limb)
-               shift = remainingBits
        }
        return bytes
 }
@@ -192,9 +161,9 @@ func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
        return x, nil
 }
 
-// SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes. SetOverflowingBytes
-// returns an error if b has a longer bit length than m, but reduces overflowing
-// values up to 2^⌈log2(m)⌉ - 1.
+// SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes.
+// SetOverflowingBytes returns an error if b has a longer bit length than m, but
+// reduces overflowing values up to 2^⌈log2(m)⌉ - 1.
 //
 // The output will be resized to the size of m and overwritten.
 func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
@@ -203,33 +172,35 @@ func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
        }
        leading := _W - bitLen(x.limbs[len(x.limbs)-1])
        if leading < m.leading {
-               return nil, errors.New("input overflows the modulus")
+               return nil, errors.New("input overflows the modulus size")
        }
-       x.sub(x.cmpGeq(m.nat), m.nat)
+       x.maybeSubtractModulus(no, m)
        return x, nil
 }
 
+// bigEndianUint returns the contents of buf interpreted as a
+// big-endian encoded uint value.
+func bigEndianUint(buf []byte) uint {
+       if _W == 64 {
+               return uint(binary.BigEndian.Uint64(buf))
+       }
+       return uint(binary.BigEndian.Uint32(buf))
+}
+
 func (x *Nat) setBytes(b []byte, m *Modulus) error {
-       outI := 0
-       shift := 0
        x.resetFor(m)
-       for i := len(b) - 1; i >= 0; i-- {
-               bi := b[i]
-               x.limbs[outI] |= uint(bi) << shift
-               shift += 8
-               if shift >= _W {
-                       shift -= _W
-                       x.limbs[outI] &= _MASK
-                       overflow := bi >> (8 - shift)
-                       outI++
-                       if outI >= len(x.limbs) {
-                               if overflow > 0 || i > 0 {
-                                       return errors.New("input overflows the modulus")
-                               }
-                               break
-                       }
-                       x.limbs[outI] = uint(overflow)
-               }
+       i, k := len(b), 0
+       for k < len(x.limbs) && i >= _S {
+               x.limbs[k] = bigEndianUint(b[i-_S : i])
+               i -= _S
+               k++
+       }
+       for s := 0; s < _W && k < len(x.limbs) && i > 0; s += 8 {
+               x.limbs[k] |= uint(b[i-1]) << s
+               i--
+       }
+       if i > 0 {
+               return errors.New("input overflows the modulus size")
        }
        return nil
 }
@@ -274,7 +245,7 @@ func (x *Nat) cmpGeq(y *Nat) choice {
 
        var c uint
        for i := 0; i < size; i++ {
-               c = (xLimbs[i] - yLimbs[i] - c) >> _W
+               _, c = bits.Sub(xLimbs[i], yLimbs[i], c)
        }
        // If there was a carry, then subtracting y underflowed, so
        // x is not greater than or equal to y.
@@ -290,44 +261,39 @@ func (x *Nat) assign(on choice, y *Nat) *Nat {
        xLimbs := x.limbs[:size]
        yLimbs := y.limbs[:size]
 
+       mask := ctMask(on)
        for i := 0; i < size; i++ {
-               xLimbs[i] = ctSelect(on, yLimbs[i], xLimbs[i])
+               xLimbs[i] ^= mask & (xLimbs[i] ^ yLimbs[i])
        }
        return x
 }
 
-// add computes x += y if on == 1, and does nothing otherwise. It returns the
-// carry of the addition regardless of on.
+// add computes x += y and returns the carry.
 //
 // Both operands must have the same announced length.
-func (x *Nat) add(on choice, y *Nat) (c uint) {
+func (x *Nat) add(y *Nat) (c uint) {
        // Eliminate bounds checks in the loop.
        size := len(x.limbs)
        xLimbs := x.limbs[:size]
        yLimbs := y.limbs[:size]
 
        for i := 0; i < size; i++ {
-               res := xLimbs[i] + yLimbs[i] + c
-               xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i])
-               c = res >> _W
+               xLimbs[i], c = bits.Add(xLimbs[i], yLimbs[i], c)
        }
        return
 }
 
-// sub computes x -= y if on == 1, and does nothing otherwise. It returns the
-// borrow of the subtraction regardless of on.
+// sub computes x -= y. It returns the borrow of the subtraction.
 //
 // Both operands must have the same announced length.
-func (x *Nat) sub(on choice, y *Nat) (c uint) {
+func (x *Nat) sub(y *Nat) (c uint) {
        // Eliminate bounds checks in the loop.
        size := len(x.limbs)
        xLimbs := x.limbs[:size]
        yLimbs := y.limbs[:size]
 
        for i := 0; i < size; i++ {
-               res := xLimbs[i] - yLimbs[i] - c
-               xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i])
-               c = res >> _W
+               xLimbs[i], c = bits.Sub(xLimbs[i], yLimbs[i], c)
        }
        return
 }
@@ -371,19 +337,20 @@ func minusInverseModW(x uint) uint {
        // Every iteration of this loop doubles the least-significant bits of
        // correct inverse in y. The first three bits are already correct (1⁻¹ = 1,
        // 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so doubling five times is enough
-       // for 61 bits (and wastes only one iteration for 31 bits).
+       // for 64 bits (and wastes only one iteration for 32 bits).
        //
        // See https://crypto.stackexchange.com/a/47496.
        y := x
        for i := 0; i < 5; i++ {
                y = y * (2 - x*y)
        }
-       return (1 << _W) - (y & _MASK)
+       return -y
 }
 
 // NewModulusFromBig creates a new Modulus from a [big.Int].
 //
-// The Int must be odd. The number of significant bits must be leakable.
+// The Int must be odd. The number of significant bits (and nothing else) is
+// leaked through timing side-channels.
 func NewModulusFromBig(n *big.Int) *Modulus {
        m := &Modulus{}
        m.nat = NewNat().setBig(n)
@@ -424,7 +391,7 @@ func (m *Modulus) Nat() *Nat {
 
 // shiftIn calculates x = x << _W + y mod m.
 //
-// This assumes that x is already reduced mod m, and that y < 2^_W.
+// This assumes that x is already reduced mod m.
 func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
        d := NewNat().resetFor(m)
 
@@ -440,25 +407,21 @@ func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
        //
        // To do the reduction, each iteration computes both 2x + b and 2x + b - m.
        // The next iteration (and finally the return line) will use either result
-       // based on whether the subtraction underflowed.
+       // based on whether 2x + b overflows m.
        needSubtraction := no
        for i := _W - 1; i >= 0; i-- {
                carry := (y >> i) & 1
                var borrow uint
+               mask := ctMask(needSubtraction)
                for i := 0; i < size; i++ {
-                       l := ctSelect(needSubtraction, dLimbs[i], xLimbs[i])
-
-                       res := l<<1 + carry
-                       xLimbs[i] = res & _MASK
-                       carry = res >> _W
-
-                       res = xLimbs[i] - mLimbs[i] - borrow
-                       dLimbs[i] = res & _MASK
-                       borrow = res >> _W
+                       l := xLimbs[i] ^ (mask & (xLimbs[i] ^ dLimbs[i]))
+                       xLimbs[i], carry = bits.Add(l, l, carry)
+                       dLimbs[i], borrow = bits.Sub(xLimbs[i], mLimbs[i], borrow)
                }
-               // See Add for how carry (aka overflow), borrow (aka underflow), and
-               // needSubtraction relate.
-               needSubtraction = ctEq(carry, borrow)
+               // Like in maybeSubtractModulus, we need the subtraction if either it
+               // didn't underflow (meaning 2x + b > m) or if computing 2x + b
+               // overflowed (meaning 2x + b > 2^_W*n > m).
+               needSubtraction = not(choice(borrow)) | choice(carry)
        }
        return x.assign(needSubtraction, d)
 }
@@ -494,11 +457,11 @@ func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
        return out
 }
 
-// ExpandFor ensures out has the right size to work with operations modulo m.
+// ExpandFor ensures x has the right size to work with operations modulo m.
 //
-// The announced size of out must be smaller than or equal to that of m.
-func (out *Nat) ExpandFor(m *Modulus) *Nat {
-       return out.expand(len(m.nat.limbs))
+// The announced size of x must be smaller than or equal to that of m.
+func (x *Nat) ExpandFor(m *Modulus) *Nat {
+       return x.expand(len(m.nat.limbs))
 }
 
 // resetFor ensures out has the right size to work with operations modulo m.
@@ -508,14 +471,34 @@ func (out *Nat) resetFor(m *Modulus) *Nat {
        return out.reset(len(m.nat.limbs))
 }
 
+// maybeSubtractModulus computes x -= m if and only if x >= m or if "always" is yes.
+//
+// It can be used to reduce modulo m a value up to 2m - 1, which is a common
+// range for results computed by higher level operations.
+//
+// always is usually a carry that indicates that the operation that produced x
+// overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m.
+//
+// x and m operands must have the same announced length.
+func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
+       t := NewNat().set(x)
+       underflow := t.sub(m.nat)
+       // We keep the result if x - m didn't underflow (meaning x >= m)
+       // or if always was set.
+       keep := not(choice(underflow)) | choice(always)
+       x.assign(keep, t)
+}
+
 // Sub computes x = x - y mod m.
 //
 // The length of both operands must be the same as the modulus. Both operands
 // must already be reduced modulo m.
 func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
-       underflow := x.sub(yes, y)
+       underflow := x.sub(y)
        // If the subtraction underflowed, add m.
-       x.add(choice(underflow), m.nat)
+       t := NewNat().set(x)
+       t.add(m.nat)
+       x.assign(choice(underflow), t)
        return x
 }
 
@@ -524,34 +507,8 @@ func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
 // The length of both operands must be the same as the modulus. Both operands
 // must already be reduced modulo m.
 func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
-       overflow := x.add(yes, y)
-       underflow := not(x.cmpGeq(m.nat)) // x < m
-
-       // Three cases are possible:
-       //
-       //   - overflow = 0, underflow = 0
-       //
-       // In this case, addition fits in our limbs, but we can still subtract away
-       // m without an underflow, so we need to perform the subtraction to reduce
-       // our result.
-       //
-       //   - overflow = 0, underflow = 1
-       //
-       // The addition fits in our limbs, but we can't subtract m without
-       // underflowing. The result is already reduced.
-       //
-       //   - overflow = 1, underflow = 1
-       //
-       // The addition does not fit in our limbs, and the subtraction's borrow
-       // would cancel out with the addition's carry. We need to subtract m to
-       // reduce our result.
-       //
-       // The overflow = 1, underflow = 0 case is not possible, because y is at
-       // most m - 1, and if adding m - 1 overflows, then subtracting m must
-       // necessarily underflow.
-       needSubtraction := ctEq(overflow, uint(underflow))
-
-       x.sub(needSubtraction, m.nat)
+       overflow := x.add(y)
+       x.maybeSubtractModulus(choice(overflow), m)
        return x
 }
 
@@ -582,65 +539,146 @@ func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
        return x.montgomeryMul(t0, t1, m)
 }
 
-// montgomeryMul calculates d = a * b / R mod m, with R = 2^(_W * n) and
-// n = len(m.nat.limbs), using the Montgomery Multiplication technique.
+// montgomeryMul calculates x = a * b / R mod m, with R = 2^(_W * n) and
+// n = len(m.nat.limbs), also known as a Montgomery multiplication.
 //
-// All inputs should be the same length, not aliasing d, and already
-// reduced modulo m. d will be resized to the size of m and overwritten.
-func (d *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
-       d.resetFor(m)
-       if len(a.limbs) != len(m.nat.limbs) || len(b.limbs) != len(m.nat.limbs) {
-               panic("bigmod: invalid montgomeryMul input")
-       }
+// All inputs should be the same length and already reduced modulo m.
+// x will be resized to the size of m and overwritten.
+func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
+       n := len(m.nat.limbs)
+       mLimbs := m.nat.limbs[:n]
+       aLimbs := a.limbs[:n]
+       bLimbs := b.limbs[:n]
+
+       switch n {
+       default:
+               // Attempt to use a stack-allocated backing array.
+               T := make([]uint, 0, preallocLimbs*2)
+               if cap(T) < n*2 {
+                       T = make([]uint, 0, n*2)
+               }
+               T = T[:n*2]
+
+               // This loop implements Word-by-Word Montgomery Multiplication, as
+               // described in Algorithm 4 (Fig. 3) of "Efficient Software
+               // Implementations of Modular Exponentiation" by Shay Gueron
+               // [https://eprint.iacr.org/2011/239.pdf].
+               var c uint
+               for i := 0; i < n; i++ {
+                       _ = T[n+i] // bounds check elimination hint
+
+                       // Step 1 (T = a × b) is computed as a large pen-and-paper column
+                       // multiplication of two numbers with n base-2^_W digits. If we just
+                       // wanted to produce 2n-wide T, we would do
+                       //
+                       //   for i := 0; i < n; i++ {
+                       //       d := bLimbs[i]
+                       //       T[n+i] = addMulVVW(T[i:n+i], aLimbs, d)
+                       //   }
+                       //
+                       // where d is a digit of the multiplier, T[i:n+i] is the shifted
+                       // position of the product of that digit, and T[n+i] is the final carry.
+                       // Note that T[i] isn't modified after processing the i-th digit.
+                       //
+                       // Instead of running two loops, one for Step 1 and one for Steps 2–6,
+                       // the result of Step 1 is computed during the next loop. This is
+                       // possible because each iteration only uses T[i] in Step 2 and then
+                       // discards it in Step 6.
+                       d := bLimbs[i]
+                       c1 := addMulVVW(T[i:n+i], aLimbs, d)
+
+                       // Step 6 is replaced by shifting the virtual window we operate
+                       // over: T of the algorithm is T[i:] for us. That means that T1 in
+                       // Step 2 (T mod 2^_W) is simply T[i]. k0 in Step 3 is our m0inv.
+                       Y := T[i] * m.m0inv
+
+                       // Step 4 and 5 add Y × m to T, which as mentioned above is stored
+                       // at T[i:]. The two carries (from a × d and Y × m) are added up in
+                       // the next word T[n+i], and the carry bit from that addition is
+                       // brought forward to the next iteration.
+                       c2 := addMulVVW(T[i:n+i], mLimbs, Y)
+                       T[n+i], c = bits.Add(c1, c2, c)
+               }
 
-       // See https://bearssl.org/bigint.html#montgomery-reduction-and-multiplication
-       // for a description of the algorithm implemented mostly in montgomeryLoop.
-       // See Add for how overflow, underflow, and needSubtraction relate.
-       overflow := montgomeryLoop(d.limbs, a.limbs, b.limbs, m.nat.limbs, m.m0inv)
-       underflow := not(d.cmpGeq(m.nat)) // d < m
-       needSubtraction := ctEq(overflow, uint(underflow))
-       d.sub(needSubtraction, m.nat)
+               // Finally for Step 7 we copy the final T window into x, and subtract m
+               // if necessary (which as explained in maybeSubtractModulus can be the
+               // case both if x >= m, or if x overflowed).
+               //
+               // The paper suggests in Section 4 that we can do an "Almost Montgomery
+               // Multiplication" by subtracting only in the overflow case, but the
+               // cost is very similar since the constant time subtraction tells us if
+               // x >= m as a side effect, and taking care of the broken invariant is
+               // highly undesirable (see https://go.dev/issue/13907).
+               copy(x.reset(n).limbs, T[n:])
+               x.maybeSubtractModulus(choice(c), m)
+
+       // The following specialized cases follow the exact same algorithm, but
+       // optimized for the sizes most used in RSA. addMulVVW is implemented in
+       // assembly with loop unrolling depending on the architecture and bounds
+       // checks are removed by the compiler thanks to the constant size.
+       case 1024 / _W:
+               const n = 1024 / _W // compiler hint
+               T := make([]uint, n*2)
+               var c uint
+               for i := 0; i < n; i++ {
+                       d := bLimbs[i]
+                       c1 := addMulVVW1024(&T[i], &aLimbs[0], d)
+                       Y := T[i] * m.m0inv
+                       c2 := addMulVVW1024(&T[i], &mLimbs[0], Y)
+                       T[n+i], c = bits.Add(c1, c2, c)
+               }
+               copy(x.reset(n).limbs, T[n:])
+               x.maybeSubtractModulus(choice(c), m)
+
+       case 1536 / _W:
+               const n = 1536 / _W // compiler hint
+               T := make([]uint, n*2)
+               var c uint
+               for i := 0; i < n; i++ {
+                       d := bLimbs[i]
+                       c1 := addMulVVW1536(&T[i], &aLimbs[0], d)
+                       Y := T[i] * m.m0inv
+                       c2 := addMulVVW1536(&T[i], &mLimbs[0], Y)
+                       T[n+i], c = bits.Add(c1, c2, c)
+               }
+               copy(x.reset(n).limbs, T[n:])
+               x.maybeSubtractModulus(choice(c), m)
+
+       case 2048 / _W:
+               const n = 2048 / _W // compiler hint
+               T := make([]uint, n*2)
+               var c uint
+               for i := 0; i < n; i++ {
+                       d := bLimbs[i]
+                       c1 := addMulVVW2048(&T[i], &aLimbs[0], d)
+                       Y := T[i] * m.m0inv
+                       c2 := addMulVVW2048(&T[i], &mLimbs[0], Y)
+                       T[n+i], c = bits.Add(c1, c2, c)
+               }
+               copy(x.reset(n).limbs, T[n:])
+               x.maybeSubtractModulus(choice(c), m)
+       }
 
-       return d
+       return x
 }
 
-func montgomeryLoopGeneric(d, a, b, m []uint, m0inv uint) (overflow uint) {
-       // Eliminate bounds checks in the loop.
-       size := len(d)
-       a = a[:size]
-       b = b[:size]
-       m = m[:size]
-
-       for _, ai := range a {
-               // This is an unrolled iteration of the loop below with j = 0.
-               hi, lo := bits.Mul(ai, b[0])
-               z_lo, c := bits.Add(d[0], lo, 0)
-               f := (z_lo * m0inv) & _MASK // (d[0] + a[i] * b[0]) * m0inv
-               z_hi, _ := bits.Add(0, hi, c)
-               hi, lo = bits.Mul(f, m[0])
-               z_lo, c = bits.Add(z_lo, lo, 0)
-               z_hi, _ = bits.Add(z_hi, hi, c)
-               carry := z_hi<<1 | z_lo>>_W
-
-               for j := 1; j < size; j++ {
-                       // z = d[j] + a[i] * b[j] + f * m[j] + carry <= 2^(2W+1) - 2^(W+1) + 2^W
-                       hi, lo := bits.Mul(ai, b[j])
-                       z_lo, c := bits.Add(d[j], lo, 0)
-                       z_hi, _ := bits.Add(0, hi, c)
-                       hi, lo = bits.Mul(f, m[j])
-                       z_lo, c = bits.Add(z_lo, lo, 0)
-                       z_hi, _ = bits.Add(z_hi, hi, c)
-                       z_lo, c = bits.Add(z_lo, carry, 0)
-                       z_hi, _ = bits.Add(z_hi, 0, c)
-                       d[j-1] = z_lo & _MASK
-                       carry = z_hi<<1 | z_lo>>_W // carry <= 2^(W+1) - 2
-               }
-
-               z := overflow + carry // z <= 2^(W+1) - 1
-               d[size-1] = z & _MASK
-               overflow = z >> _W // overflow <= 1
+// addMulVVW multiplies the multi-word value x by the single-word value y,
+// adding the result to the multi-word value z and returning the final carry.
+// It can be thought of as one row of a pen-and-paper column multiplication.
+func addMulVVW(z, x []uint, y uint) (carry uint) {
+       _ = x[len(z)-1] // bounds check elimination hint
+       for i := range z {
+               hi, lo := bits.Mul(x[i], y)
+               lo, c := bits.Add(lo, z[i], 0)
+               // We use bits.Add with zero to get an add-with-carry instruction that
+               // absorbs the carry from the previous bits.Add.
+               hi, _ = bits.Add(hi, 0, c)
+               lo, c = bits.Add(lo, carry, 0)
+               hi, _ = bits.Add(hi, 0, c)
+               carry = hi
+               z[i] = lo
        }
-       return
+       return carry
 }
 
 // Mul calculates x *= y mod m.
@@ -661,7 +699,8 @@ func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
 func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
        // We use a 4 bit window. For our RSA workload, 4 bit windows are faster
        // than 2 bit windows, but use an extra 12 nats worth of scratch space.
-       // Using bit sizes that don't divide 8 are more complex to implement.
+       // Using bit sizes that don't divide 8 are more complex to implement, but
+       // are likely to be more efficient if necessary.
 
        table := [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1)
                // newNat calls are unrolled so they are allocated on the stack.
@@ -681,7 +720,8 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
        t1 := NewNat().ExpandFor(m)
        for _, b := range e {
                for _, j := range []int{4, 0} {
-                       // Square four times.
+                       // Square four times. Optimization note: this can be implemented
+                       // more efficiently than with generic Montgomery multiplication.
                        t1.montgomeryMul(out, out, m)
                        out.montgomeryMul(t1, t1, m)
                        t1.montgomeryMul(out, out, m)
diff --git a/src/crypto/internal/bigmod/nat_386.s b/src/crypto/internal/bigmod/nat_386.s
new file mode 100644 (file)
index 0000000..0637d27
--- /dev/null
@@ -0,0 +1,47 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !purego
+
+#include "textflag.h"
+
+// func addMulVVW1024(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1024(SB), $0-16
+       MOVL    $32, BX
+       JMP             addMulVVWx(SB)
+
+// func addMulVVW1536(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1536(SB), $0-16
+       MOVL    $48, BX
+       JMP             addMulVVWx(SB)
+
+// func addMulVVW2048(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW2048(SB), $0-16
+       MOVL    $64, BX
+       JMP             addMulVVWx(SB)
+
+TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
+       MOVL z+0(FP), DI
+       MOVL x+4(FP), SI
+       MOVL y+8(FP), BP
+       LEAL (DI)(BX*4), DI
+       LEAL (SI)(BX*4), SI
+       NEGL BX                 // i = -n
+       MOVL $0, CX             // c = 0
+       JMP E6
+
+L6:    MOVL (SI)(BX*4), AX
+       MULL BP
+       ADDL CX, AX
+       ADCL $0, DX
+       ADDL AX, (DI)(BX*4)
+       ADCL $0, DX
+       MOVL DX, CX
+       ADDL $1, BX             // i++
+
+E6:    CMPL BX, $0             // i < 0
+       JL L6
+
+       MOVL CX, c+12(FP)
+       RET
diff --git a/src/crypto/internal/bigmod/nat_amd64.go b/src/crypto/internal/bigmod/nat_amd64.go
deleted file mode 100644 (file)
index e947782..0000000
+++ /dev/null
@@ -1,8 +0,0 @@
-// Code generated by command: go run nat_amd64_asm.go -out ../nat_amd64.s -stubs ../nat_amd64.go -pkg bigmod. DO NOT EDIT.
-
-//go:build amd64 && gc && !purego
-
-package bigmod
-
-//go:noescape
-func montgomeryLoop(d []uint, a []uint, b []uint, m []uint, m0inv uint) uint
index 12b7629984e304f2f62c61e677f2bb175d200a6e..ab94344e10a44268cd8cdeafb4d09c07a228f0d8 100644 (file)
-// Code generated by command: go run nat_amd64_asm.go -out ../nat_amd64.s -stubs ../nat_amd64.go -pkg bigmod. DO NOT EDIT.
-
-//go:build amd64 && gc && !purego
-
-// func montgomeryLoop(d []uint, a []uint, b []uint, m []uint, m0inv uint) uint
-TEXT ·montgomeryLoop(SB), $8-112
-       MOVQ d_len+8(FP), CX
-       MOVQ d_base+0(FP), BX
-       MOVQ b_base+48(FP), SI
-       MOVQ m_base+72(FP), DI
-       MOVQ m0inv+96(FP), R8
-       XORQ R9, R9
-       XORQ R10, R10
-
-outerLoop:
-       MOVQ  a_base+24(FP), R11
-       MOVQ  (R11)(R10*8), R11
-       MOVQ  (SI), AX
-       MULQ  R11
-       MOVQ  AX, R13
-       MOVQ  DX, R12
-       ADDQ  (BX), R13
-       ADCQ  $0x00, R12
-       MOVQ  R8, R14
-       IMULQ R13, R14
-       BTRQ  $0x3f, R14
-       MOVQ  (DI), AX
-       MULQ  R14
-       ADDQ  AX, R13
-       ADCQ  DX, R12
-       SHRQ  $0x3f, R12, R13
-       XORQ  R12, R12
-       INCQ  R12
-       JMP   innerLoopCondition
-
-innerLoop:
-       MOVQ (SI)(R12*8), AX
-       MULQ R11
-       MOVQ AX, BP
-       MOVQ DX, R15
-       MOVQ (DI)(R12*8), AX
-       MULQ R14
-       ADDQ AX, BP
-       ADCQ DX, R15
-       ADDQ (BX)(R12*8), BP
-       ADCQ $0x00, R15
-       ADDQ R13, BP
-       ADCQ $0x00, R15
-       MOVQ BP, AX
-       BTRQ $0x3f, AX
-       MOVQ AX, -8(BX)(R12*8)
-       SHRQ $0x3f, R15, BP
-       MOVQ BP, R13
-       INCQ R12
-
-innerLoopCondition:
-       CMPQ CX, R12
-       JGT  innerLoop
-       ADDQ R13, R9
-       MOVQ R9, AX
-       BTRQ $0x3f, AX
-       MOVQ AX, -8(BX)(CX*8)
-       SHRQ $0x3f, R9
-       INCQ R10
-       CMPQ CX, R10
-       JGT  outerLoop
-       MOVQ R9, ret+104(FP)
+// Code generated by command: go run nat_amd64_asm.go -out ../nat_amd64.s -pkg bigmod. DO NOT EDIT.
+
+//go:build !purego
+
+// func addMulVVW1024(z *uint, x *uint, y uint) (c uint)
+// Requires: ADX, BMI2
+TEXT ·addMulVVW1024(SB), $0-32
+       CMPB ·supportADX+0(SB), $0x01
+       JEQ  adx
+       MOVQ z+0(FP), CX
+       MOVQ x+8(FP), BX
+       MOVQ y+16(FP), SI
+       XORQ DI, DI
+
+       // Iteration 0
+       MOVQ (BX), AX
+       MULQ SI
+       ADDQ (CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, (CX)
+
+       // Iteration 1
+       MOVQ 8(BX), AX
+       MULQ SI
+       ADDQ 8(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 8(CX)
+
+       // Iteration 2
+       MOVQ 16(BX), AX
+       MULQ SI
+       ADDQ 16(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 16(CX)
+
+       // Iteration 3
+       MOVQ 24(BX), AX
+       MULQ SI
+       ADDQ 24(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 24(CX)
+
+       // Iteration 4
+       MOVQ 32(BX), AX
+       MULQ SI
+       ADDQ 32(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 32(CX)
+
+       // Iteration 5
+       MOVQ 40(BX), AX
+       MULQ SI
+       ADDQ 40(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 40(CX)
+
+       // Iteration 6
+       MOVQ 48(BX), AX
+       MULQ SI
+       ADDQ 48(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 48(CX)
+
+       // Iteration 7
+       MOVQ 56(BX), AX
+       MULQ SI
+       ADDQ 56(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 56(CX)
+
+       // Iteration 8
+       MOVQ 64(BX), AX
+       MULQ SI
+       ADDQ 64(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 64(CX)
+
+       // Iteration 9
+       MOVQ 72(BX), AX
+       MULQ SI
+       ADDQ 72(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 72(CX)
+
+       // Iteration 10
+       MOVQ 80(BX), AX
+       MULQ SI
+       ADDQ 80(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 80(CX)
+
+       // Iteration 11
+       MOVQ 88(BX), AX
+       MULQ SI
+       ADDQ 88(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 88(CX)
+
+       // Iteration 12
+       MOVQ 96(BX), AX
+       MULQ SI
+       ADDQ 96(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 96(CX)
+
+       // Iteration 13
+       MOVQ 104(BX), AX
+       MULQ SI
+       ADDQ 104(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 104(CX)
+
+       // Iteration 14
+       MOVQ 112(BX), AX
+       MULQ SI
+       ADDQ 112(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 112(CX)
+
+       // Iteration 15
+       MOVQ 120(BX), AX
+       MULQ SI
+       ADDQ 120(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 120(CX)
+       MOVQ DI, c+24(FP)
+       RET
+
+adx:
+       MOVQ z+0(FP), AX
+       MOVQ x+8(FP), CX
+       MOVQ y+16(FP), DX
+       XORQ BX, BX
+       XORQ SI, SI
+
+       // Iteration 0
+       MULXQ (CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ (AX), R8
+       MOVQ  R8, (AX)
+
+       // Iteration 1
+       MULXQ 8(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 8(AX), R8
+       MOVQ  R8, 8(AX)
+
+       // Iteration 2
+       MULXQ 16(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 16(AX), R8
+       MOVQ  R8, 16(AX)
+
+       // Iteration 3
+       MULXQ 24(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 24(AX), R8
+       MOVQ  R8, 24(AX)
+
+       // Iteration 4
+       MULXQ 32(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 32(AX), R8
+       MOVQ  R8, 32(AX)
+
+       // Iteration 5
+       MULXQ 40(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 40(AX), R8
+       MOVQ  R8, 40(AX)
+
+       // Iteration 6
+       MULXQ 48(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 48(AX), R8
+       MOVQ  R8, 48(AX)
+
+       // Iteration 7
+       MULXQ 56(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 56(AX), R8
+       MOVQ  R8, 56(AX)
+
+       // Iteration 8
+       MULXQ 64(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 64(AX), R8
+       MOVQ  R8, 64(AX)
+
+       // Iteration 9
+       MULXQ 72(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 72(AX), R8
+       MOVQ  R8, 72(AX)
+
+       // Iteration 10
+       MULXQ 80(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 80(AX), R8
+       MOVQ  R8, 80(AX)
+
+       // Iteration 11
+       MULXQ 88(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 88(AX), R8
+       MOVQ  R8, 88(AX)
+
+       // Iteration 12
+       MULXQ 96(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 96(AX), R8
+       MOVQ  R8, 96(AX)
+
+       // Iteration 13
+       MULXQ 104(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 104(AX), R8
+       MOVQ  R8, 104(AX)
+
+       // Iteration 14
+       MULXQ 112(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 112(AX), R8
+       MOVQ  R8, 112(AX)
+
+       // Iteration 15
+       MULXQ 120(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 120(AX), R8
+       MOVQ  R8, 120(AX)
+
+       // Add back carry flags and return
+       ADCXQ SI, BX
+       ADOXQ SI, BX
+       MOVQ  BX, c+24(FP)
+       RET
+
+// func addMulVVW1536(z *uint, x *uint, y uint) (c uint)
+// Requires: ADX, BMI2
+TEXT ·addMulVVW1536(SB), $0-32
+       CMPB ·supportADX+0(SB), $0x01
+       JEQ  adx
+       MOVQ z+0(FP), CX
+       MOVQ x+8(FP), BX
+       MOVQ y+16(FP), SI
+       XORQ DI, DI
+
+       // Iteration 0
+       MOVQ (BX), AX
+       MULQ SI
+       ADDQ (CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, (CX)
+
+       // Iteration 1
+       MOVQ 8(BX), AX
+       MULQ SI
+       ADDQ 8(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 8(CX)
+
+       // Iteration 2
+       MOVQ 16(BX), AX
+       MULQ SI
+       ADDQ 16(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 16(CX)
+
+       // Iteration 3
+       MOVQ 24(BX), AX
+       MULQ SI
+       ADDQ 24(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 24(CX)
+
+       // Iteration 4
+       MOVQ 32(BX), AX
+       MULQ SI
+       ADDQ 32(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 32(CX)
+
+       // Iteration 5
+       MOVQ 40(BX), AX
+       MULQ SI
+       ADDQ 40(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 40(CX)
+
+       // Iteration 6
+       MOVQ 48(BX), AX
+       MULQ SI
+       ADDQ 48(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 48(CX)
+
+       // Iteration 7
+       MOVQ 56(BX), AX
+       MULQ SI
+       ADDQ 56(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 56(CX)
+
+       // Iteration 8
+       MOVQ 64(BX), AX
+       MULQ SI
+       ADDQ 64(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 64(CX)
+
+       // Iteration 9
+       MOVQ 72(BX), AX
+       MULQ SI
+       ADDQ 72(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 72(CX)
+
+       // Iteration 10
+       MOVQ 80(BX), AX
+       MULQ SI
+       ADDQ 80(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 80(CX)
+
+       // Iteration 11
+       MOVQ 88(BX), AX
+       MULQ SI
+       ADDQ 88(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 88(CX)
+
+       // Iteration 12
+       MOVQ 96(BX), AX
+       MULQ SI
+       ADDQ 96(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 96(CX)
+
+       // Iteration 13
+       MOVQ 104(BX), AX
+       MULQ SI
+       ADDQ 104(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 104(CX)
+
+       // Iteration 14
+       MOVQ 112(BX), AX
+       MULQ SI
+       ADDQ 112(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 112(CX)
+
+       // Iteration 15
+       MOVQ 120(BX), AX
+       MULQ SI
+       ADDQ 120(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 120(CX)
+
+       // Iteration 16
+       MOVQ 128(BX), AX
+       MULQ SI
+       ADDQ 128(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 128(CX)
+
+       // Iteration 17
+       MOVQ 136(BX), AX
+       MULQ SI
+       ADDQ 136(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 136(CX)
+
+       // Iteration 18
+       MOVQ 144(BX), AX
+       MULQ SI
+       ADDQ 144(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 144(CX)
+
+       // Iteration 19
+       MOVQ 152(BX), AX
+       MULQ SI
+       ADDQ 152(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 152(CX)
+
+       // Iteration 20
+       MOVQ 160(BX), AX
+       MULQ SI
+       ADDQ 160(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 160(CX)
+
+       // Iteration 21
+       MOVQ 168(BX), AX
+       MULQ SI
+       ADDQ 168(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 168(CX)
+
+       // Iteration 22
+       MOVQ 176(BX), AX
+       MULQ SI
+       ADDQ 176(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 176(CX)
+
+       // Iteration 23
+       MOVQ 184(BX), AX
+       MULQ SI
+       ADDQ 184(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 184(CX)
+       MOVQ DI, c+24(FP)
+       RET
+
+adx:
+       MOVQ z+0(FP), AX
+       MOVQ x+8(FP), CX
+       MOVQ y+16(FP), DX
+       XORQ BX, BX
+       XORQ SI, SI
+
+       // Iteration 0
+       MULXQ (CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ (AX), R8
+       MOVQ  R8, (AX)
+
+       // Iteration 1
+       MULXQ 8(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 8(AX), R8
+       MOVQ  R8, 8(AX)
+
+       // Iteration 2
+       MULXQ 16(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 16(AX), R8
+       MOVQ  R8, 16(AX)
+
+       // Iteration 3
+       MULXQ 24(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 24(AX), R8
+       MOVQ  R8, 24(AX)
+
+       // Iteration 4
+       MULXQ 32(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 32(AX), R8
+       MOVQ  R8, 32(AX)
+
+       // Iteration 5
+       MULXQ 40(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 40(AX), R8
+       MOVQ  R8, 40(AX)
+
+       // Iteration 6
+       MULXQ 48(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 48(AX), R8
+       MOVQ  R8, 48(AX)
+
+       // Iteration 7
+       MULXQ 56(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 56(AX), R8
+       MOVQ  R8, 56(AX)
+
+       // Iteration 8
+       MULXQ 64(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 64(AX), R8
+       MOVQ  R8, 64(AX)
+
+       // Iteration 9
+       MULXQ 72(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 72(AX), R8
+       MOVQ  R8, 72(AX)
+
+       // Iteration 10
+       MULXQ 80(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 80(AX), R8
+       MOVQ  R8, 80(AX)
+
+       // Iteration 11
+       MULXQ 88(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 88(AX), R8
+       MOVQ  R8, 88(AX)
+
+       // Iteration 12
+       MULXQ 96(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 96(AX), R8
+       MOVQ  R8, 96(AX)
+
+       // Iteration 13
+       MULXQ 104(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 104(AX), R8
+       MOVQ  R8, 104(AX)
+
+       // Iteration 14
+       MULXQ 112(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 112(AX), R8
+       MOVQ  R8, 112(AX)
+
+       // Iteration 15
+       MULXQ 120(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 120(AX), R8
+       MOVQ  R8, 120(AX)
+
+       // Iteration 16
+       MULXQ 128(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 128(AX), R8
+       MOVQ  R8, 128(AX)
+
+       // Iteration 17
+       MULXQ 136(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 136(AX), R8
+       MOVQ  R8, 136(AX)
+
+       // Iteration 18
+       MULXQ 144(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 144(AX), R8
+       MOVQ  R8, 144(AX)
+
+       // Iteration 19
+       MULXQ 152(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 152(AX), R8
+       MOVQ  R8, 152(AX)
+
+       // Iteration 20
+       MULXQ 160(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 160(AX), R8
+       MOVQ  R8, 160(AX)
+
+       // Iteration 21
+       MULXQ 168(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 168(AX), R8
+       MOVQ  R8, 168(AX)
+
+       // Iteration 22
+       MULXQ 176(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 176(AX), R8
+       MOVQ  R8, 176(AX)
+
+       // Iteration 23
+       MULXQ 184(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 184(AX), R8
+       MOVQ  R8, 184(AX)
+
+       // Add back carry flags and return
+       ADCXQ SI, BX
+       ADOXQ SI, BX
+       MOVQ  BX, c+24(FP)
+       RET
+
+// func addMulVVW2048(z *uint, x *uint, y uint) (c uint)
+// Requires: ADX, BMI2
+TEXT ·addMulVVW2048(SB), $0-32
+       CMPB ·supportADX+0(SB), $0x01
+       JEQ  adx
+       MOVQ z+0(FP), CX
+       MOVQ x+8(FP), BX
+       MOVQ y+16(FP), SI
+       XORQ DI, DI
+
+       // Iteration 0
+       MOVQ (BX), AX
+       MULQ SI
+       ADDQ (CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, (CX)
+
+       // Iteration 1
+       MOVQ 8(BX), AX
+       MULQ SI
+       ADDQ 8(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 8(CX)
+
+       // Iteration 2
+       MOVQ 16(BX), AX
+       MULQ SI
+       ADDQ 16(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 16(CX)
+
+       // Iteration 3
+       MOVQ 24(BX), AX
+       MULQ SI
+       ADDQ 24(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 24(CX)
+
+       // Iteration 4
+       MOVQ 32(BX), AX
+       MULQ SI
+       ADDQ 32(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 32(CX)
+
+       // Iteration 5
+       MOVQ 40(BX), AX
+       MULQ SI
+       ADDQ 40(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 40(CX)
+
+       // Iteration 6
+       MOVQ 48(BX), AX
+       MULQ SI
+       ADDQ 48(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 48(CX)
+
+       // Iteration 7
+       MOVQ 56(BX), AX
+       MULQ SI
+       ADDQ 56(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 56(CX)
+
+       // Iteration 8
+       MOVQ 64(BX), AX
+       MULQ SI
+       ADDQ 64(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 64(CX)
+
+       // Iteration 9
+       MOVQ 72(BX), AX
+       MULQ SI
+       ADDQ 72(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 72(CX)
+
+       // Iteration 10
+       MOVQ 80(BX), AX
+       MULQ SI
+       ADDQ 80(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 80(CX)
+
+       // Iteration 11
+       MOVQ 88(BX), AX
+       MULQ SI
+       ADDQ 88(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 88(CX)
+
+       // Iteration 12
+       MOVQ 96(BX), AX
+       MULQ SI
+       ADDQ 96(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 96(CX)
+
+       // Iteration 13
+       MOVQ 104(BX), AX
+       MULQ SI
+       ADDQ 104(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 104(CX)
+
+       // Iteration 14
+       MOVQ 112(BX), AX
+       MULQ SI
+       ADDQ 112(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 112(CX)
+
+       // Iteration 15
+       MOVQ 120(BX), AX
+       MULQ SI
+       ADDQ 120(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 120(CX)
+
+       // Iteration 16
+       MOVQ 128(BX), AX
+       MULQ SI
+       ADDQ 128(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 128(CX)
+
+       // Iteration 17
+       MOVQ 136(BX), AX
+       MULQ SI
+       ADDQ 136(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 136(CX)
+
+       // Iteration 18
+       MOVQ 144(BX), AX
+       MULQ SI
+       ADDQ 144(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 144(CX)
+
+       // Iteration 19
+       MOVQ 152(BX), AX
+       MULQ SI
+       ADDQ 152(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 152(CX)
+
+       // Iteration 20
+       MOVQ 160(BX), AX
+       MULQ SI
+       ADDQ 160(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 160(CX)
+
+       // Iteration 21
+       MOVQ 168(BX), AX
+       MULQ SI
+       ADDQ 168(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 168(CX)
+
+       // Iteration 22
+       MOVQ 176(BX), AX
+       MULQ SI
+       ADDQ 176(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 176(CX)
+
+       // Iteration 23
+       MOVQ 184(BX), AX
+       MULQ SI
+       ADDQ 184(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 184(CX)
+
+       // Iteration 24
+       MOVQ 192(BX), AX
+       MULQ SI
+       ADDQ 192(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 192(CX)
+
+       // Iteration 25
+       MOVQ 200(BX), AX
+       MULQ SI
+       ADDQ 200(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 200(CX)
+
+       // Iteration 26
+       MOVQ 208(BX), AX
+       MULQ SI
+       ADDQ 208(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 208(CX)
+
+       // Iteration 27
+       MOVQ 216(BX), AX
+       MULQ SI
+       ADDQ 216(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 216(CX)
+
+       // Iteration 28
+       MOVQ 224(BX), AX
+       MULQ SI
+       ADDQ 224(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 224(CX)
+
+       // Iteration 29
+       MOVQ 232(BX), AX
+       MULQ SI
+       ADDQ 232(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 232(CX)
+
+       // Iteration 30
+       MOVQ 240(BX), AX
+       MULQ SI
+       ADDQ 240(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 240(CX)
+
+       // Iteration 31
+       MOVQ 248(BX), AX
+       MULQ SI
+       ADDQ 248(CX), AX
+       ADCQ $0x00, DX
+       ADDQ DI, AX
+       ADCQ $0x00, DX
+       MOVQ DX, DI
+       MOVQ AX, 248(CX)
+       MOVQ DI, c+24(FP)
+       RET
+
+adx:
+       MOVQ z+0(FP), AX
+       MOVQ x+8(FP), CX
+       MOVQ y+16(FP), DX
+       XORQ BX, BX
+       XORQ SI, SI
+
+       // Iteration 0
+       MULXQ (CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ (AX), R8
+       MOVQ  R8, (AX)
+
+       // Iteration 1
+       MULXQ 8(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 8(AX), R8
+       MOVQ  R8, 8(AX)
+
+       // Iteration 2
+       MULXQ 16(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 16(AX), R8
+       MOVQ  R8, 16(AX)
+
+       // Iteration 3
+       MULXQ 24(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 24(AX), R8
+       MOVQ  R8, 24(AX)
+
+       // Iteration 4
+       MULXQ 32(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 32(AX), R8
+       MOVQ  R8, 32(AX)
+
+       // Iteration 5
+       MULXQ 40(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 40(AX), R8
+       MOVQ  R8, 40(AX)
+
+       // Iteration 6
+       MULXQ 48(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 48(AX), R8
+       MOVQ  R8, 48(AX)
+
+       // Iteration 7
+       MULXQ 56(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 56(AX), R8
+       MOVQ  R8, 56(AX)
+
+       // Iteration 8
+       MULXQ 64(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 64(AX), R8
+       MOVQ  R8, 64(AX)
+
+       // Iteration 9
+       MULXQ 72(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 72(AX), R8
+       MOVQ  R8, 72(AX)
+
+       // Iteration 10
+       MULXQ 80(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 80(AX), R8
+       MOVQ  R8, 80(AX)
+
+       // Iteration 11
+       MULXQ 88(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 88(AX), R8
+       MOVQ  R8, 88(AX)
+
+       // Iteration 12
+       MULXQ 96(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 96(AX), R8
+       MOVQ  R8, 96(AX)
+
+       // Iteration 13
+       MULXQ 104(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 104(AX), R8
+       MOVQ  R8, 104(AX)
+
+       // Iteration 14
+       MULXQ 112(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 112(AX), R8
+       MOVQ  R8, 112(AX)
+
+       // Iteration 15
+       MULXQ 120(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 120(AX), R8
+       MOVQ  R8, 120(AX)
+
+       // Iteration 16
+       MULXQ 128(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 128(AX), R8
+       MOVQ  R8, 128(AX)
+
+       // Iteration 17
+       MULXQ 136(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 136(AX), R8
+       MOVQ  R8, 136(AX)
+
+       // Iteration 18
+       MULXQ 144(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 144(AX), R8
+       MOVQ  R8, 144(AX)
+
+       // Iteration 19
+       MULXQ 152(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 152(AX), R8
+       MOVQ  R8, 152(AX)
+
+       // Iteration 20
+       MULXQ 160(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 160(AX), R8
+       MOVQ  R8, 160(AX)
+
+       // Iteration 21
+       MULXQ 168(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 168(AX), R8
+       MOVQ  R8, 168(AX)
+
+       // Iteration 22
+       MULXQ 176(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 176(AX), R8
+       MOVQ  R8, 176(AX)
+
+       // Iteration 23
+       MULXQ 184(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 184(AX), R8
+       MOVQ  R8, 184(AX)
+
+       // Iteration 24
+       MULXQ 192(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 192(AX), R8
+       MOVQ  R8, 192(AX)
+
+       // Iteration 25
+       MULXQ 200(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 200(AX), R8
+       MOVQ  R8, 200(AX)
+
+       // Iteration 26
+       MULXQ 208(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 208(AX), R8
+       MOVQ  R8, 208(AX)
+
+       // Iteration 27
+       MULXQ 216(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 216(AX), R8
+       MOVQ  R8, 216(AX)
+
+       // Iteration 28
+       MULXQ 224(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 224(AX), R8
+       MOVQ  R8, 224(AX)
+
+       // Iteration 29
+       MULXQ 232(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 232(AX), R8
+       MOVQ  R8, 232(AX)
+
+       // Iteration 30
+       MULXQ 240(CX), R8, DI
+       ADCXQ BX, R8
+       ADOXQ 240(AX), R8
+       MOVQ  R8, 240(AX)
+
+       // Iteration 31
+       MULXQ 248(CX), R8, BX
+       ADCXQ DI, R8
+       ADOXQ 248(AX), R8
+       MOVQ  R8, 248(AX)
+
+       // Add back carry flags and return
+       ADCXQ SI, BX
+       ADOXQ SI, BX
+       MOVQ  BX, c+24(FP)
        RET
diff --git a/src/crypto/internal/bigmod/nat_arm.s b/src/crypto/internal/bigmod/nat_arm.s
new file mode 100644 (file)
index 0000000..c7397b8
--- /dev/null
@@ -0,0 +1,47 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !purego
+
+#include "textflag.h"
+
+// func addMulVVW1024(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1024(SB), $0-16
+       MOVW    $32, R5
+       JMP             addMulVVWx(SB)
+
+// func addMulVVW1536(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1536(SB), $0-16
+       MOVW    $48, R5
+       JMP             addMulVVWx(SB)
+
+// func addMulVVW2048(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW2048(SB), $0-16
+       MOVW    $64, R5
+       JMP             addMulVVWx(SB)
+
+TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
+       MOVW    $0, R0
+       MOVW    z+0(FP), R1
+       MOVW    x+4(FP), R2
+       MOVW    y+8(FP), R3
+       ADD     R5<<2, R1, R5
+       MOVW    $0, R4
+       B E9
+
+L9:    MOVW.P  4(R2), R6
+       MULLU   R6, R3, (R7, R6)
+       ADD.S   R4, R6
+       ADC     R0, R7
+       MOVW    0(R1), R4
+       ADD.S   R4, R6
+       ADC     R0, R7
+       MOVW.P  R6, 4(R1)
+       MOVW    R7, R4
+
+E9:    TEQ     R1, R5
+       BNE     L9
+
+       MOVW    R4, c+12(FP)
+       RET
diff --git a/src/crypto/internal/bigmod/nat_arm64.s b/src/crypto/internal/bigmod/nat_arm64.s
new file mode 100644 (file)
index 0000000..ba1e611
--- /dev/null
@@ -0,0 +1,69 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !purego
+
+#include "textflag.h"
+
+// func addMulVVW1024(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1024(SB), $0-32
+       MOVD    $16, R0
+       JMP             addMulVVWx(SB)
+
+// func addMulVVW1536(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1536(SB), $0-32
+       MOVD    $24, R0
+       JMP             addMulVVWx(SB)
+
+// func addMulVVW2048(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW2048(SB), $0-32
+       MOVD    $32, R0
+       JMP             addMulVVWx(SB)
+
+TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
+       MOVD    z+0(FP), R1
+       MOVD    x+8(FP), R2
+       MOVD    y+16(FP), R3
+       MOVD    $0, R4
+
+// The main loop of this code operates on a block of 4 words every iteration
+// performing [R4:R12:R11:R10:R9] = R4 + R3 * [R8:R7:R6:R5] + [R12:R11:R10:R9]
+// where R4 is carried from the previous iteration, R8:R7:R6:R5 hold the next
+// 4 words of x, R3 is y and R12:R11:R10:R9 are part of the result z.
+loop:
+       CBZ     R0, done
+
+       LDP.P   16(R2), (R5, R6)
+       LDP.P   16(R2), (R7, R8)
+
+       LDP     (R1), (R9, R10)
+       ADDS    R4, R9
+       MUL     R6, R3, R14
+       ADCS    R14, R10
+       MUL     R7, R3, R15
+       LDP     16(R1), (R11, R12)
+       ADCS    R15, R11
+       MUL     R8, R3, R16
+       ADCS    R16, R12
+       UMULH   R8, R3, R20
+       ADC     $0, R20
+
+       MUL     R5, R3, R13
+       ADDS    R13, R9
+       UMULH   R5, R3, R17
+       ADCS    R17, R10
+       UMULH   R6, R3, R21
+       STP.P   (R9, R10), 16(R1)
+       ADCS    R21, R11
+       UMULH   R7, R3, R19
+       ADCS    R19, R12
+       STP.P   (R11, R12), 16(R1)
+       ADC     $0, R20, R4
+
+       SUB     $4, R0
+       B       loop
+
+done:
+       MOVD    R4, c+24(FP)
+       RET
diff --git a/src/crypto/internal/bigmod/nat_asm.go b/src/crypto/internal/bigmod/nat_asm.go
new file mode 100644 (file)
index 0000000..5eb91e1
--- /dev/null
@@ -0,0 +1,28 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !purego && (386 || amd64 || arm || arm64 || ppc64 || ppc64le || s390x)
+
+package bigmod
+
+import "internal/cpu"
+
+// amd64 assembly uses ADCX/ADOX/MULX if ADX is available to run two carry
+// chains in the flags in parallel across the whole operation, and aggressively
+// unrolls loops. arm64 processes four words at a time.
+//
+// It's unclear why the assembly for all other architectures, as well as for
+// amd64 without ADX, perform better than the compiler output.
+// TODO(filippo): file cmd/compile performance issue.
+
+var supportADX = cpu.X86.HasADX && cpu.X86.HasBMI2
+
+//go:noescape
+func addMulVVW1024(z, x *uint, y uint) (c uint)
+
+//go:noescape
+func addMulVVW1536(z, x *uint, y uint) (c uint)
+
+//go:noescape
+func addMulVVW2048(z, x *uint, y uint) (c uint)
index 870b44519d9b67fd377c1d3a7868bf8ec7c0ead4..eff12536f910c689b04a1165c026eadc4e3180d8 100644 (file)
@@ -1,11 +1,21 @@
-// Copyright 2022 The Go Authors. All rights reserved.
+// Copyright 2023 The Go Authors. All rights reserved.
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-//go:build !amd64 || !gc || purego
+//go:build purego || !(386 || amd64 || arm || arm64 || ppc64 || ppc64le || s390x)
 
 package bigmod
 
-func montgomeryLoop(d, a, b, m []uint, m0inv uint) uint {
-       return montgomeryLoopGeneric(d, a, b, m, m0inv)
+import "unsafe"
+
+func addMulVVW1024(z, x *uint, y uint) (c uint) {
+       return addMulVVW(unsafe.Slice(z, 1024/_W), unsafe.Slice(x, 1024/_W), y)
+}
+
+func addMulVVW1536(z, x *uint, y uint) (c uint) {
+       return addMulVVW(unsafe.Slice(z, 1536/_W), unsafe.Slice(x, 1536/_W), y)
+}
+
+func addMulVVW2048(z, x *uint, y uint) (c uint) {
+       return addMulVVW(unsafe.Slice(z, 2048/_W), unsafe.Slice(x, 2048/_W), y)
 }
diff --git a/src/crypto/internal/bigmod/nat_ppc64x.s b/src/crypto/internal/bigmod/nat_ppc64x.s
new file mode 100644 (file)
index 0000000..974f4f9
--- /dev/null
@@ -0,0 +1,51 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !purego && (ppc64 || ppc64le)
+
+#include "textflag.h"
+
+// func addMulVVW1024(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1024(SB), $0-32
+       MOVD    $16, R22 // R22 = z_len
+       JMP             addMulVVWx(SB)
+
+// func addMulVVW1536(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1536(SB), $0-32
+       MOVD    $24, R22 // R22 = z_len
+       JMP             addMulVVWx(SB)
+
+// func addMulVVW2048(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW2048(SB), $0-32
+       MOVD    $32, R22 // R22 = z_len
+       JMP             addMulVVWx(SB)
+
+TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
+       MOVD z+0(FP), R10       // R10 = z[]
+       MOVD x+8(FP), R8        // R8 = x[]
+       MOVD y+16(FP), R9       // R9 = y
+
+       MOVD R0, R3             // R3 will be the index register
+       CMP  R0, R22
+       MOVD R0, R4             // R4 = c = 0
+       MOVD R22, CTR           // Initialize loop counter
+       BEQ  done
+       PCALIGN $16
+
+loop:
+       MOVD  (R8)(R3), R20     // Load x[i]
+       MOVD  (R10)(R3), R21    // Load z[i]
+       MULLD  R9, R20, R6      // R6 = Low-order(x[i]*y)
+       MULHDU R9, R20, R7      // R7 = High-order(x[i]*y)
+       ADDC   R21, R6          // R6 = z0
+       ADDZE  R7               // R7 = z1
+       ADDC   R4, R6           // R6 = z0 + c + 0
+       ADDZE  R7, R4           // c += z1
+       MOVD   R6, (R10)(R3)    // Store z[i]
+       ADD    $8, R3
+       BC  16, 0, loop         // bdnz
+
+done:
+       MOVD R4, c+24(FP)
+       RET
diff --git a/src/crypto/internal/bigmod/nat_s390x.s b/src/crypto/internal/bigmod/nat_s390x.s
new file mode 100644 (file)
index 0000000..0c07a0c
--- /dev/null
@@ -0,0 +1,85 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !purego
+
+#include "textflag.h"
+
+// func addMulVVW1024(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1024(SB), $0-32
+       MOVD    $16, R5
+       JMP             addMulVVWx(SB)
+
+// func addMulVVW1536(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW1536(SB), $0-32
+       MOVD    $24, R5
+       JMP             addMulVVWx(SB)
+
+// func addMulVVW2048(z, x *uint, y uint) (c uint)
+TEXT ·addMulVVW2048(SB), $0-32
+       MOVD    $32, R5
+       JMP             addMulVVWx(SB)
+
+TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
+       MOVD z+0(FP), R2
+       MOVD x+8(FP), R8
+       MOVD y+16(FP), R9
+
+       MOVD $0, R1 // i*8 = 0
+       MOVD $0, R7 // i = 0
+       MOVD $0, R0 // make sure it's zero
+       MOVD $0, R4 // c = 0
+
+       MOVD   R5, R12
+       AND    $-2, R12
+       CMPBGE R5, $2, A6
+       BR     E6
+
+A6:
+       MOVD   (R8)(R1*1), R6
+       MULHDU R9, R6
+       MOVD   (R2)(R1*1), R10
+       ADDC   R10, R11        // add to low order bits
+       ADDE   R0, R6
+       ADDC   R4, R11
+       ADDE   R0, R6
+       MOVD   R6, R4
+       MOVD   R11, (R2)(R1*1)
+
+       MOVD   (8)(R8)(R1*1), R6
+       MULHDU R9, R6
+       MOVD   (8)(R2)(R1*1), R10
+       ADDC   R10, R11           // add to low order bits
+       ADDE   R0, R6
+       ADDC   R4, R11
+       ADDE   R0, R6
+       MOVD   R6, R4
+       MOVD   R11, (8)(R2)(R1*1)
+
+       ADD $16, R1 // i*8 + 8
+       ADD $2, R7  // i++
+
+       CMPBLT R7, R12, A6
+       BR     E6
+
+L6:
+       // TODO: drop unused single-step loop.
+       MOVD   (R8)(R1*1), R6
+       MULHDU R9, R6
+       MOVD   (R2)(R1*1), R10
+       ADDC   R10, R11        // add to low order bits
+       ADDE   R0, R6
+       ADDC   R4, R11
+       ADDE   R0, R6
+       MOVD   R6, R4
+       MOVD   R11, (R2)(R1*1)
+
+       ADD $8, R1 // i*8 + 8
+       ADD $1, R7 // i++
+
+E6:
+       CMPBLT R7, R5, L6 // i < n
+
+       MOVD R4, c+24(FP)
+       RET
index 4593a2e4932795aee0b58c91e7e53c5bbafb3813..cc5ffe7bb7f5e4e52b82ef281f15cade551d6cc3 100644 (file)
@@ -5,14 +5,24 @@
 package bigmod
 
 import (
+       "fmt"
        "math/big"
        "math/bits"
        "math/rand"
        "reflect"
+       "strings"
        "testing"
        "testing/quick"
 )
 
+func (n *Nat) String() string {
+       var limbs []string
+       for i := range n.limbs {
+               limbs = append(limbs, fmt.Sprintf("%016X", n.limbs[len(n.limbs)-1-i]))
+       }
+       return "{" + strings.Join(limbs, " ") + "}"
+}
+
 // Generate generates an even nat. It's used by testing/quick to produce random
 // *nat values for quick.Check invocations.
 func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
@@ -54,21 +64,23 @@ func TestModSubThenAddIdentity(t *testing.T) {
        }
 }
 
-func testMontgomeryRoundtrip(a *Nat) bool {
-       one := &Nat{make([]uint, len(a.limbs))}
-       one.limbs[0] = 1
-       aPlusOne := new(big.Int).SetBytes(natBytes(a))
-       aPlusOne.Add(aPlusOne, big.NewInt(1))
-       m := NewModulusFromBig(aPlusOne)
-       monty := new(Nat).set(a)
-       monty.montgomeryRepresentation(m)
-       aAgain := new(Nat).set(monty)
-       aAgain.montgomeryMul(monty, one, m)
-       return a.Equal(aAgain) == 1
-}
-
 func TestMontgomeryRoundtrip(t *testing.T) {
-       err := quick.Check(testMontgomeryRoundtrip, &quick.Config{})
+       err := quick.Check(func(a *Nat) bool {
+               one := &Nat{make([]uint, len(a.limbs))}
+               one.limbs[0] = 1
+               aPlusOne := new(big.Int).SetBytes(natBytes(a))
+               aPlusOne.Add(aPlusOne, big.NewInt(1))
+               m := NewModulusFromBig(aPlusOne)
+               monty := new(Nat).set(a)
+               monty.montgomeryRepresentation(m)
+               aAgain := new(Nat).set(monty)
+               aAgain.montgomeryMul(monty, one, m)
+               if a.Equal(aAgain) != 1 {
+                       t.Errorf("%v != %v", a, aAgain)
+                       return false
+               }
+               return true
+       }, &quick.Config{})
        if err != nil {
                t.Error(err)
        }
@@ -84,30 +96,30 @@ func TestShiftIn(t *testing.T) {
        }{{
                m:        []byte{13},
                x:        []byte{0},
-               y:        0x7FFF_FFFF_FFFF_FFFF,
-               expected: []byte{7},
+               y:        0xFFFF_FFFF_FFFF_FFFF,
+               expected: []byte{2},
        }, {
                m:        []byte{13},
                x:        []byte{7},
-               y:        0x7FFF_FFFF_FFFF_FFFF,
-               expected: []byte{11},
+               y:        0xFFFF_FFFF_FFFF_FFFF,
+               expected: []byte{10},
        }, {
                m:        []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
                x:        make([]byte, 9),
-               y:        0x7FFF_FFFF_FFFF_FFFF,
-               expected: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+               y:        0xFFFF_FFFF_FFFF_FFFF,
+               expected: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
        }, {
                m:        []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
-               x:        []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+               x:        []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
                y:        0,
-               expected: []byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08},
+               expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06},
        }}
 
        for i, tt := range examples {
                m := modulusFromBytes(tt.m)
                got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m)
-               if got.Equal(natFromBytes(tt.expected).ExpandFor(m)) != 1 {
-                       t.Errorf("%d: got %x, expected %x", i, got, tt.expected)
+               if exp := natFromBytes(tt.expected).ExpandFor(m); got.Equal(exp) != 1 {
+                       t.Errorf("%d: got %v, expected %v", i, got, exp)
                }
        }
 }
@@ -186,7 +198,7 @@ func TestSetBytes(t *testing.T) {
                        continue
                }
                if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes {
-                       t.Errorf("%d: got %x, expected %x", i, got, expected)
+                       t.Errorf("%d: got %v, expected %v", i, got, expected)
                }
        }
 
@@ -228,7 +240,7 @@ func TestExpand(t *testing.T) {
        for i, tt := range examples {
                got := (&Nat{tt.in}).expand(tt.n)
                if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 {
-                       t.Errorf("%d: got %x, expected %x", i, got, tt.out)
+                       t.Errorf("%d: got %v, expected %v", i, got, tt.out)
                }
        }
 }
@@ -287,11 +299,40 @@ func TestExp(t *testing.T) {
        }
 }
 
+// TestMulReductions tests that Mul reduces results equal or slightly greater
+// than the modulus. Some Montgomery algorithms don't and need extra care to
+// return correct results. See https://go.dev/issue/13907.
+func TestMulReductions(t *testing.T) {
+       // Two short but multi-limb primes.
+       a, _ := new(big.Int).SetString("773608962677651230850240281261679752031633236267106044359907", 10)
+       b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10)
+       n := new(big.Int).Mul(a, b)
+
+       N := NewModulusFromBig(n)
+       A := NewNat().setBig(a).ExpandFor(N)
+       B := NewNat().setBig(b).ExpandFor(N)
+
+       if A.Mul(B, N).IsZero() != 1 {
+               t.Error("a * b mod (a * b) != 0")
+       }
+
+       i := new(big.Int).ModInverse(a, b)
+       N = NewModulusFromBig(b)
+       A = NewNat().setBig(a).ExpandFor(N)
+       I := NewNat().setBig(i).ExpandFor(N)
+       one := NewNat().setBig(big.NewInt(1)).ExpandFor(N)
+
+       if A.Mul(I, N).Equal(one) != 1 {
+               t.Error("a * inv(a) mod b != 1")
+       }
+}
+
 func natBytes(n *Nat) []byte {
        return n.Bytes(maxModulus(uint(len(n.limbs))))
 }
 
 func natFromBytes(b []byte) *Nat {
+       // Must not use Nat.SetBytes as it's used in TestSetBytes.
        bb := new(big.Int).SetBytes(b)
        return NewNat().setBig(bb)
 }
@@ -316,7 +357,7 @@ func makeBenchmarkModulus() *Modulus {
 func makeBenchmarkValue() *Nat {
        x := make([]uint, 32)
        for i := 0; i < 32; i++ {
-               x[i] = _MASK - 1
+               x[i]--
        }
        return &Nat{limbs: x}
 }