Cypherpunks.ru repositories - gostls13.git/commitdiff
cmd/compile: add single-precision FMA code generation for riscv64
author: Meng Zhuo <mzh@golangcn.org>
Wed, 28 Jun 2023 08:45:07 +0000 (16:45 +0800)
committer: M Zhuo <mzh@golangcn.org>
Tue, 22 Aug 2023 12:05:36 +0000 (12:05 +0000)
This CL adds FMADDS, FMSUBS, FNMADDS, and FNMSUBS SSA support for riscv64.

Change-Id: I1e7dd322b46b9e0f4923dbba256303d69ed12066
Reviewed-on: https://go-review.googlesource.com/c/go/+/506616
Reviewed-by: Joel Sing <joel@sing.id.au>
Reviewed-by: David Chase <drchase@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Keith Randall <khr@google.com>
Run-TryBot: M Zhuo <mzh@golangcn.org>

src/cmd/compile/internal/riscv64/ssa.go
src/cmd/compile/internal/ssa/_gen/RISCV64.rules
src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
src/cmd/compile/internal/ssa/opGen.go
src/cmd/compile/internal/ssa/rewriteRISCV64.go
test/codegen/floats.go

index 143e7c525a3f09ab7e07decbe1468c6b890224d2..f8cf786920474e1def1a5c648fee0df835306c10 100644 (file)
@@ -332,7 +332,8 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
                p2.From.Reg = v.Reg1()
                p2.To.Type = obj.TYPE_REG
                p2.To.Reg = v.Reg1()
-       case ssa.OpRISCV64FMADDD, ssa.OpRISCV64FMSUBD, ssa.OpRISCV64FNMADDD, ssa.OpRISCV64FNMSUBD:
+       case ssa.OpRISCV64FMADDD, ssa.OpRISCV64FMSUBD, ssa.OpRISCV64FNMADDD, ssa.OpRISCV64FNMSUBD,
+               ssa.OpRISCV64FMADDS, ssa.OpRISCV64FMSUBS, ssa.OpRISCV64FNMADDS, ssa.OpRISCV64FNMSUBS:
                r := v.Reg()
                r1 := v.Args[0].Reg()
                r2 := v.Args[1].Reg()
index ac68dfed76e2083cfef063687c80ac98eb73085d..e0bf00d45d3a2984e7efbaa19f26e8b92e46e21a 100644 (file)
 (Select0 m:(LoweredMuluhilo x y)) && m.Uses == 1 => (MULHU x y)
 (Select1 m:(LoweredMuluhilo x y)) && m.Uses == 1 => (MUL x y)
 
-(FADDD a (FMULD x y)) && a.Block.Func.useFMA(v) => (FMADDD x y a)
-(FSUBD a (FMULD x y)) && a.Block.Func.useFMA(v) => (FNMSUBD x y a)
-(FSUBD (FMULD x y) a) && a.Block.Func.useFMA(v) => (FMSUBD x y a)
+(FADD(S|D) a (FMUL(S|D) x y)) && a.Block.Func.useFMA(v) => (FMADD(S|D) x y a)
+(FSUB(S|D) a (FMUL(S|D) x y)) && a.Block.Func.useFMA(v) => (FNMSUB(S|D) x y a)
+(FSUB(S|D) (FMUL(S|D) x y) a) && a.Block.Func.useFMA(v) => (FMSUB(S|D) x y a)
+
 // Merge negation into fused multiply-add and multiply-subtract.
 //
 // Key:
 //                D B
 //
 // Note: multiplication commutativity handled by rule generator.
+(F(MADD|NMADD|MSUB|NMSUB)S neg:(FNEGS x) y z) && neg.Uses == 1 => (F(NMSUB|MSUB|NMADD|MADD)S x y z)
+(F(MADD|NMADD|MSUB|NMSUB)S x y neg:(FNEGS z)) && neg.Uses == 1 => (F(MSUB|NMSUB|MADD|NMADD)S x y z)
 (F(MADD|NMADD|MSUB|NMSUB)D neg:(FNEGD x) y z) && neg.Uses == 1 => (F(NMSUB|MSUB|NMADD|MADD)D x y z)
 (F(MADD|NMADD|MSUB|NMSUB)D x y neg:(FNEGD z)) && neg.Uses == 1 => (F(MSUB|NMSUB|MADD|NMADD)D x y z)
index 69f2950a88fe0638942c745c3588ee3706b2397a..317e9150c9e756fc8de83d34b40633d17eea1362 100644 (file)
@@ -411,6 +411,10 @@ func init() {
                {name: "FSUBS", argLength: 2, reg: fp21, asm: "FSUBS", commutative: false, typ: "Float32"},                                          // arg0 - arg1
                {name: "FMULS", argLength: 2, reg: fp21, asm: "FMULS", commutative: true, typ: "Float32"},                                           // arg0 * arg1
                {name: "FDIVS", argLength: 2, reg: fp21, asm: "FDIVS", commutative: false, typ: "Float32"},                                          // arg0 / arg1
+               {name: "FMADDS", argLength: 3, reg: fp31, asm: "FMADDS", commutative: true, typ: "Float32"},                                         // (arg0 * arg1) + arg2
+               {name: "FMSUBS", argLength: 3, reg: fp31, asm: "FMSUBS", commutative: true, typ: "Float32"},                                         // (arg0 * arg1) - arg2
+               {name: "FNMADDS", argLength: 3, reg: fp31, asm: "FNMADDS", commutative: true, typ: "Float32"},                                       // -(arg0 * arg1) - arg2
+               {name: "FNMSUBS", argLength: 3, reg: fp31, asm: "FNMSUBS", commutative: true, typ: "Float32"},                                       // -(arg0 * arg1) + arg2
                {name: "FSQRTS", argLength: 1, reg: fp11, asm: "FSQRTS", typ: "Float32"},                                                            // sqrt(arg0)
                {name: "FNEGS", argLength: 1, reg: fp11, asm: "FNEGS", typ: "Float32"},                                                              // -arg0
                {name: "FMVSX", argLength: 1, reg: gpfp, asm: "FMVSX", typ: "Float32"},                                                              // reinterpret arg0 as float
index 12d8214ae1dd5c21ea59aa21d2c0bfd6711101e2..11a61383570e9945acabf99c9d83948f40485ff3 100644 (file)
@@ -2436,6 +2436,10 @@ const (
        OpRISCV64FSUBS
        OpRISCV64FMULS
        OpRISCV64FDIVS
+       OpRISCV64FMADDS
+       OpRISCV64FMSUBS
+       OpRISCV64FNMADDS
+       OpRISCV64FNMSUBS
        OpRISCV64FSQRTS
        OpRISCV64FNEGS
        OpRISCV64FMVSX
@@ -32673,6 +32677,70 @@ var opcodeTable = [...]opInfo{
                        },
                },
        },
+       {
+               name:        "FMADDS",
+               argLen:      3,
+               commutative: true,
+               asm:         riscv.AFMADDS,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+                       outputs: []outputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+               },
+       },
+       {
+               name:        "FMSUBS",
+               argLen:      3,
+               commutative: true,
+               asm:         riscv.AFMSUBS,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+                       outputs: []outputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+               },
+       },
+       {
+               name:        "FNMADDS",
+               argLen:      3,
+               commutative: true,
+               asm:         riscv.AFNMADDS,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+                       outputs: []outputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+               },
+       },
+       {
+               name:        "FNMSUBS",
+               argLen:      3,
+               commutative: true,
+               asm:         riscv.AFNMSUBS,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+                       outputs: []outputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+               },
+       },
        {
                name:   "FSQRTS",
                argLen: 1,
index 17af023db34de344bf10588387106086e7989e73..0ad6433bf4c05216862e1bfe33defa5c1aef117e 100644 (file)
@@ -442,16 +442,28 @@ func rewriteValueRISCV64(v *Value) bool {
                return rewriteValueRISCV64_OpRISCV64ANDI(v)
        case OpRISCV64FADDD:
                return rewriteValueRISCV64_OpRISCV64FADDD(v)
+       case OpRISCV64FADDS:
+               return rewriteValueRISCV64_OpRISCV64FADDS(v)
        case OpRISCV64FMADDD:
                return rewriteValueRISCV64_OpRISCV64FMADDD(v)
+       case OpRISCV64FMADDS:
+               return rewriteValueRISCV64_OpRISCV64FMADDS(v)
        case OpRISCV64FMSUBD:
                return rewriteValueRISCV64_OpRISCV64FMSUBD(v)
+       case OpRISCV64FMSUBS:
+               return rewriteValueRISCV64_OpRISCV64FMSUBS(v)
        case OpRISCV64FNMADDD:
                return rewriteValueRISCV64_OpRISCV64FNMADDD(v)
+       case OpRISCV64FNMADDS:
+               return rewriteValueRISCV64_OpRISCV64FNMADDS(v)
        case OpRISCV64FNMSUBD:
                return rewriteValueRISCV64_OpRISCV64FNMSUBD(v)
+       case OpRISCV64FNMSUBS:
+               return rewriteValueRISCV64_OpRISCV64FNMSUBS(v)
        case OpRISCV64FSUBD:
                return rewriteValueRISCV64_OpRISCV64FSUBD(v)
+       case OpRISCV64FSUBS:
+               return rewriteValueRISCV64_OpRISCV64FSUBS(v)
        case OpRISCV64MOVBUload:
                return rewriteValueRISCV64_OpRISCV64MOVBUload(v)
        case OpRISCV64MOVBUreg:
@@ -3364,6 +3376,31 @@ func rewriteValueRISCV64_OpRISCV64FADDD(v *Value) bool {
        }
        return false
 }
+func rewriteValueRISCV64_OpRISCV64FADDS(v *Value) bool {
+       v_1 := v.Args[1]
+       v_0 := v.Args[0]
+       // match: (FADDS a (FMULS x y))
+       // cond: a.Block.Func.useFMA(v)
+       // result: (FMADDS x y a)
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       a := v_0
+                       if v_1.Op != OpRISCV64FMULS {
+                               continue
+                       }
+                       y := v_1.Args[1]
+                       x := v_1.Args[0]
+                       if !(a.Block.Func.useFMA(v)) {
+                               continue
+                       }
+                       v.reset(OpRISCV64FMADDS)
+                       v.AddArg3(x, y, a)
+                       return true
+               }
+               break
+       }
+       return false
+}
 func rewriteValueRISCV64_OpRISCV64FMADDD(v *Value) bool {
        v_2 := v.Args[2]
        v_1 := v.Args[1]
@@ -3409,6 +3446,51 @@ func rewriteValueRISCV64_OpRISCV64FMADDD(v *Value) bool {
        }
        return false
 }
+func rewriteValueRISCV64_OpRISCV64FMADDS(v *Value) bool {
+       v_2 := v.Args[2]
+       v_1 := v.Args[1]
+       v_0 := v.Args[0]
+       // match: (FMADDS neg:(FNEGS x) y z)
+       // cond: neg.Uses == 1
+       // result: (FNMSUBS x y z)
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       neg := v_0
+                       if neg.Op != OpRISCV64FNEGS {
+                               continue
+                       }
+                       x := neg.Args[0]
+                       y := v_1
+                       z := v_2
+                       if !(neg.Uses == 1) {
+                               continue
+                       }
+                       v.reset(OpRISCV64FNMSUBS)
+                       v.AddArg3(x, y, z)
+                       return true
+               }
+               break
+       }
+       // match: (FMADDS x y neg:(FNEGS z))
+       // cond: neg.Uses == 1
+       // result: (FMSUBS x y z)
+       for {
+               x := v_0
+               y := v_1
+               neg := v_2
+               if neg.Op != OpRISCV64FNEGS {
+                       break
+               }
+               z := neg.Args[0]
+               if !(neg.Uses == 1) {
+                       break
+               }
+               v.reset(OpRISCV64FMSUBS)
+               v.AddArg3(x, y, z)
+               return true
+       }
+       return false
+}
 func rewriteValueRISCV64_OpRISCV64FMSUBD(v *Value) bool {
        v_2 := v.Args[2]
        v_1 := v.Args[1]
@@ -3454,6 +3536,51 @@ func rewriteValueRISCV64_OpRISCV64FMSUBD(v *Value) bool {
        }
        return false
 }
+func rewriteValueRISCV64_OpRISCV64FMSUBS(v *Value) bool {
+       v_2 := v.Args[2]
+       v_1 := v.Args[1]
+       v_0 := v.Args[0]
+       // match: (FMSUBS neg:(FNEGS x) y z)
+       // cond: neg.Uses == 1
+       // result: (FNMADDS x y z)
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       neg := v_0
+                       if neg.Op != OpRISCV64FNEGS {
+                               continue
+                       }
+                       x := neg.Args[0]
+                       y := v_1
+                       z := v_2
+                       if !(neg.Uses == 1) {
+                               continue
+                       }
+                       v.reset(OpRISCV64FNMADDS)
+                       v.AddArg3(x, y, z)
+                       return true
+               }
+               break
+       }
+       // match: (FMSUBS x y neg:(FNEGS z))
+       // cond: neg.Uses == 1
+       // result: (FMADDS x y z)
+       for {
+               x := v_0
+               y := v_1
+               neg := v_2
+               if neg.Op != OpRISCV64FNEGS {
+                       break
+               }
+               z := neg.Args[0]
+               if !(neg.Uses == 1) {
+                       break
+               }
+               v.reset(OpRISCV64FMADDS)
+               v.AddArg3(x, y, z)
+               return true
+       }
+       return false
+}
 func rewriteValueRISCV64_OpRISCV64FNMADDD(v *Value) bool {
        v_2 := v.Args[2]
        v_1 := v.Args[1]
@@ -3499,6 +3626,51 @@ func rewriteValueRISCV64_OpRISCV64FNMADDD(v *Value) bool {
        }
        return false
 }
+func rewriteValueRISCV64_OpRISCV64FNMADDS(v *Value) bool {
+       v_2 := v.Args[2]
+       v_1 := v.Args[1]
+       v_0 := v.Args[0]
+       // match: (FNMADDS neg:(FNEGS x) y z)
+       // cond: neg.Uses == 1
+       // result: (FMSUBS x y z)
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       neg := v_0
+                       if neg.Op != OpRISCV64FNEGS {
+                               continue
+                       }
+                       x := neg.Args[0]
+                       y := v_1
+                       z := v_2
+                       if !(neg.Uses == 1) {
+                               continue
+                       }
+                       v.reset(OpRISCV64FMSUBS)
+                       v.AddArg3(x, y, z)
+                       return true
+               }
+               break
+       }
+       // match: (FNMADDS x y neg:(FNEGS z))
+       // cond: neg.Uses == 1
+       // result: (FNMSUBS x y z)
+       for {
+               x := v_0
+               y := v_1
+               neg := v_2
+               if neg.Op != OpRISCV64FNEGS {
+                       break
+               }
+               z := neg.Args[0]
+               if !(neg.Uses == 1) {
+                       break
+               }
+               v.reset(OpRISCV64FNMSUBS)
+               v.AddArg3(x, y, z)
+               return true
+       }
+       return false
+}
 func rewriteValueRISCV64_OpRISCV64FNMSUBD(v *Value) bool {
        v_2 := v.Args[2]
        v_1 := v.Args[1]
@@ -3544,6 +3716,51 @@ func rewriteValueRISCV64_OpRISCV64FNMSUBD(v *Value) bool {
        }
        return false
 }
+func rewriteValueRISCV64_OpRISCV64FNMSUBS(v *Value) bool {
+       v_2 := v.Args[2]
+       v_1 := v.Args[1]
+       v_0 := v.Args[0]
+       // match: (FNMSUBS neg:(FNEGS x) y z)
+       // cond: neg.Uses == 1
+       // result: (FMADDS x y z)
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       neg := v_0
+                       if neg.Op != OpRISCV64FNEGS {
+                               continue
+                       }
+                       x := neg.Args[0]
+                       y := v_1
+                       z := v_2
+                       if !(neg.Uses == 1) {
+                               continue
+                       }
+                       v.reset(OpRISCV64FMADDS)
+                       v.AddArg3(x, y, z)
+                       return true
+               }
+               break
+       }
+       // match: (FNMSUBS x y neg:(FNEGS z))
+       // cond: neg.Uses == 1
+       // result: (FNMADDS x y z)
+       for {
+               x := v_0
+               y := v_1
+               neg := v_2
+               if neg.Op != OpRISCV64FNEGS {
+                       break
+               }
+               z := neg.Args[0]
+               if !(neg.Uses == 1) {
+                       break
+               }
+               v.reset(OpRISCV64FNMADDS)
+               v.AddArg3(x, y, z)
+               return true
+       }
+       return false
+}
 func rewriteValueRISCV64_OpRISCV64FSUBD(v *Value) bool {
        v_1 := v.Args[1]
        v_0 := v.Args[0]
@@ -3583,6 +3800,45 @@ func rewriteValueRISCV64_OpRISCV64FSUBD(v *Value) bool {
        }
        return false
 }
+func rewriteValueRISCV64_OpRISCV64FSUBS(v *Value) bool {
+       v_1 := v.Args[1]
+       v_0 := v.Args[0]
+       // match: (FSUBS a (FMULS x y))
+       // cond: a.Block.Func.useFMA(v)
+       // result: (FNMSUBS x y a)
+       for {
+               a := v_0
+               if v_1.Op != OpRISCV64FMULS {
+                       break
+               }
+               y := v_1.Args[1]
+               x := v_1.Args[0]
+               if !(a.Block.Func.useFMA(v)) {
+                       break
+               }
+               v.reset(OpRISCV64FNMSUBS)
+               v.AddArg3(x, y, a)
+               return true
+       }
+       // match: (FSUBS (FMULS x y) a)
+       // cond: a.Block.Func.useFMA(v)
+       // result: (FMSUBS x y a)
+       for {
+               if v_0.Op != OpRISCV64FMULS {
+                       break
+               }
+               y := v_0.Args[1]
+               x := v_0.Args[0]
+               a := v_1
+               if !(a.Block.Func.useFMA(v)) {
+                       break
+               }
+               v.reset(OpRISCV64FMSUBS)
+               v.AddArg3(x, y, a)
+               return true
+       }
+       return false
+}
 func rewriteValueRISCV64_OpRISCV64MOVBUload(v *Value) bool {
        v_1 := v.Args[1]
        v_0 := v.Args[0]
index 1c5fc8a31a99a3f7d620946e8b09533d20831951..7991174b66769b926f49367fe47eef7a171fd888 100644 (file)
@@ -70,17 +70,20 @@ func FusedAdd32(x, y, z float32) float32 {
        // s390x:"FMADDS\t"
        // ppc64x:"FMADDS\t"
        // arm64:"FMADDS"
+       // riscv64:"FMADDS\t"
        return x*y + z
 }
 
 func FusedSub32_a(x, y, z float32) float32 {
        // s390x:"FMSUBS\t"
        // ppc64x:"FMSUBS\t"
+       // riscv64:"FMSUBS\t"
        return x*y - z
 }
 
 func FusedSub32_b(x, y, z float32) float32 {
        // arm64:"FMSUBS"
+       // riscv64:"FNMSUBS\t"
        return z - x*y
 }