]> Cypherpunks.ru repositories - gostls13.git/commitdiff
cmd/compile: improve FP FMA performance on riscv64
authorMeng Zhuo <mzh@golangcn.org>
Mon, 26 Jun 2023 12:46:49 +0000 (20:46 +0800)
committerM Zhuo <mzh@golangcn.org>
Tue, 22 Aug 2023 08:38:08 +0000 (08:38 +0000)
FMADD/FMSUB/FNSUB are an efficient FP FMA instructions, which can
be used by the compiler to improve FP performance.

Erf               188.0n ± 2%   139.5n ± 2%  -25.82% (p=0.000 n=10)
Erfc              193.6n ± 1%   143.2n ± 1%  -26.01% (p=0.000 n=10)
Erfinv            244.4n ± 2%   172.6n ± 0%  -29.40% (p=0.000 n=10)
Erfcinv           244.7n ± 2%   173.0n ± 1%  -29.31% (p=0.000 n=10)
geomean           216.0n        156.3n       -27.65%

Ref: The RISC-V Instruction Set Manual Volume I: Unprivileged ISA
11.6 Single-Precision Floating-Point Computational Instructions

Change-Id: I89aa3a4df7576fdd47f4a6ee608ac16feafd093c
Reviewed-on: https://go-review.googlesource.com/c/go/+/506036
Reviewed-by: Joel Sing <joel@sing.id.au>
Run-TryBot: M Zhuo <mzh@golangcn.org>
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Keith Randall <khr@golang.org>
Reviewed-by: Keith Randall <khr@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>

src/cmd/compile/internal/riscv64/ssa.go
src/cmd/compile/internal/ssa/_gen/RISCV64.rules
src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
src/cmd/compile/internal/ssa/opGen.go
src/cmd/compile/internal/ssa/rewriteRISCV64.go
test/codegen/floats.go

index 2eb1e7ffa017d3a67f305e9d336456973c45b354..143e7c525a3f09ab7e07decbe1468c6b890224d2 100644 (file)
@@ -694,6 +694,9 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
                p.To.Sym = ir.Syms.Duffcopy
                p.To.Offset = v.AuxInt
 
+       case ssa.OpRISCV64LoweredRound32F, ssa.OpRISCV64LoweredRound64F:
+               // input is already rounded
+
        case ssa.OpClobber, ssa.OpClobberReg:
                // TODO: implement for clobberdead experiment. Nop is ok for now.
 
index 181b46a7ced98db8c96a51c03a8629acfa6dca12..ac68dfed76e2083cfef063687c80ac98eb73085d 100644 (file)
 
 (CvtBoolToUint8 ...) => (Copy ...)
 
-(Round(64|32)F ...) => (Copy ...)
+(Round(32|64)F ...) => (LoweredRound(32|64)F ...)
 
 (Slicemask <t> x) => (SRAI [63] (NEG <t> x))
 
 (Select0 m:(LoweredMuluhilo x y)) && m.Uses == 1 => (MULHU x y)
 (Select1 m:(LoweredMuluhilo x y)) && m.Uses == 1 => (MUL x y)
 
+(FADDD a (FMULD x y)) && a.Block.Func.useFMA(v) => (FMADDD x y a)
+(FSUBD a (FMULD x y)) && a.Block.Func.useFMA(v) => (FNMSUBD x y a)
+(FSUBD (FMULD x y) a) && a.Block.Func.useFMA(v) => (FMSUBD x y a)
 // Merge negation into fused multiply-add and multiply-subtract.
 //
 // Key:
index 52e87cbe72c17d1c5d78d28208a69b14c6177e96..69f2950a88fe0638942c745c3588ee3706b2397a 100644 (file)
@@ -237,6 +237,10 @@ func init() {
                // gets correctly ordered with respect to GC safepoints.
                {name: "MOVconvert", argLength: 2, reg: gp11, asm: "MOV"}, // arg0, but converted to int/ptr as appropriate; arg1=mem
 
+               // Round ops to block fused-multiply-add extraction.
+               {name: "LoweredRound32F", argLength: 1, reg: fp11, resultInArg0: true},
+               {name: "LoweredRound64F", argLength: 1, reg: fp11, resultInArg0: true},
+
                // Calls
                {name: "CALLstatic", argLength: -1, reg: call, aux: "CallOff", call: true},               // call static function aux.(*gc.Sym). last arg=mem, auxint=argsize, returns mem
                {name: "CALLtail", argLength: -1, reg: call, aux: "CallOff", call: true, tailCall: true}, // tail call static function aux.(*gc.Sym). last arg=mem, auxint=argsize, returns mem
index b599df05259cd20f4e2d4ac7cb0087e78b2ec4e0..12d8214ae1dd5c21ea59aa21d2c0bfd6711101e2 100644 (file)
@@ -2400,6 +2400,8 @@ const (
        OpRISCV64SLTU
        OpRISCV64SLTIU
        OpRISCV64MOVconvert
+       OpRISCV64LoweredRound32F
+       OpRISCV64LoweredRound64F
        OpRISCV64CALLstatic
        OpRISCV64CALLtail
        OpRISCV64CALLclosure
@@ -32198,6 +32200,32 @@ var opcodeTable = [...]opInfo{
                        },
                },
        },
+       {
+               name:         "LoweredRound32F",
+               argLen:       1,
+               resultInArg0: true,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+                       outputs: []outputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+               },
+       },
+       {
+               name:         "LoweredRound64F",
+               argLen:       1,
+               resultInArg0: true,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+                       outputs: []outputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+               },
+       },
        {
                name:    "CALLstatic",
                auxType: auxCallOff,
index e8002599efe88e5a15af668c3205487137acb2b7..17af023db34de344bf10588387106086e7989e73 100644 (file)
@@ -440,6 +440,8 @@ func rewriteValueRISCV64(v *Value) bool {
                return rewriteValueRISCV64_OpRISCV64AND(v)
        case OpRISCV64ANDI:
                return rewriteValueRISCV64_OpRISCV64ANDI(v)
+       case OpRISCV64FADDD:
+               return rewriteValueRISCV64_OpRISCV64FADDD(v)
        case OpRISCV64FMADDD:
                return rewriteValueRISCV64_OpRISCV64FMADDD(v)
        case OpRISCV64FMSUBD:
@@ -448,6 +450,8 @@ func rewriteValueRISCV64(v *Value) bool {
                return rewriteValueRISCV64_OpRISCV64FNMADDD(v)
        case OpRISCV64FNMSUBD:
                return rewriteValueRISCV64_OpRISCV64FNMSUBD(v)
+       case OpRISCV64FSUBD:
+               return rewriteValueRISCV64_OpRISCV64FSUBD(v)
        case OpRISCV64MOVBUload:
                return rewriteValueRISCV64_OpRISCV64MOVBUload(v)
        case OpRISCV64MOVBUreg:
@@ -541,10 +545,10 @@ func rewriteValueRISCV64(v *Value) bool {
        case OpRotateLeft8:
                return rewriteValueRISCV64_OpRotateLeft8(v)
        case OpRound32F:
-               v.Op = OpCopy
+               v.Op = OpRISCV64LoweredRound32F
                return true
        case OpRound64F:
-               v.Op = OpCopy
+               v.Op = OpRISCV64LoweredRound64F
                return true
        case OpRsh16Ux16:
                return rewriteValueRISCV64_OpRsh16Ux16(v)
@@ -3335,6 +3339,31 @@ func rewriteValueRISCV64_OpRISCV64ANDI(v *Value) bool {
        }
        return false
 }
+func rewriteValueRISCV64_OpRISCV64FADDD(v *Value) bool {
+       v_1 := v.Args[1]
+       v_0 := v.Args[0]
+       // match: (FADDD a (FMULD x y))
+       // cond: a.Block.Func.useFMA(v)
+       // result: (FMADDD x y a)
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       a := v_0
+                       if v_1.Op != OpRISCV64FMULD {
+                               continue
+                       }
+                       y := v_1.Args[1]
+                       x := v_1.Args[0]
+                       if !(a.Block.Func.useFMA(v)) {
+                               continue
+                       }
+                       v.reset(OpRISCV64FMADDD)
+                       v.AddArg3(x, y, a)
+                       return true
+               }
+               break
+       }
+       return false
+}
 func rewriteValueRISCV64_OpRISCV64FMADDD(v *Value) bool {
        v_2 := v.Args[2]
        v_1 := v.Args[1]
@@ -3515,6 +3544,45 @@ func rewriteValueRISCV64_OpRISCV64FNMSUBD(v *Value) bool {
        }
        return false
 }
+func rewriteValueRISCV64_OpRISCV64FSUBD(v *Value) bool {
+       v_1 := v.Args[1]
+       v_0 := v.Args[0]
+       // match: (FSUBD a (FMULD x y))
+       // cond: a.Block.Func.useFMA(v)
+       // result: (FNMSUBD x y a)
+       for {
+               a := v_0
+               if v_1.Op != OpRISCV64FMULD {
+                       break
+               }
+               y := v_1.Args[1]
+               x := v_1.Args[0]
+               if !(a.Block.Func.useFMA(v)) {
+                       break
+               }
+               v.reset(OpRISCV64FNMSUBD)
+               v.AddArg3(x, y, a)
+               return true
+       }
+       // match: (FSUBD (FMULD x y) a)
+       // cond: a.Block.Func.useFMA(v)
+       // result: (FMSUBD x y a)
+       for {
+               if v_0.Op != OpRISCV64FMULD {
+                       break
+               }
+               y := v_0.Args[1]
+               x := v_0.Args[0]
+               a := v_1
+               if !(a.Block.Func.useFMA(v)) {
+                       break
+               }
+               v.reset(OpRISCV64FMSUBD)
+               v.AddArg3(x, y, a)
+               return true
+       }
+       return false
+}
 func rewriteValueRISCV64_OpRISCV64MOVBUload(v *Value) bool {
        v_1 := v.Args[1]
        v_0 := v.Args[0]
index 9cb62e031a73f7db67e436e40744c8e47e4f478b..1c5fc8a31a99a3f7d620946e8b09533d20831951 100644 (file)
@@ -88,17 +88,20 @@ func FusedAdd64(x, y, z float64) float64 {
        // s390x:"FMADD\t"
        // ppc64x:"FMADD\t"
        // arm64:"FMADDD"
+       // riscv64:"FMADDD\t"
        return x*y + z
 }
 
 func FusedSub64_a(x, y, z float64) float64 {
        // s390x:"FMSUB\t"
        // ppc64x:"FMSUB\t"
+       // riscv64:"FMSUBD\t"
        return x*y - z
 }
 
 func FusedSub64_b(x, y, z float64) float64 {
        // arm64:"FMSUBD"
+       // riscv64:"FNMSUBD\t"
        return z - x*y
 }