]> Cypherpunks.ru repositories - gostls13.git/commitdiff
cmd/compile: try to rewrite loops to count down
authorJorropo <jorropo.pgm@gmail.com>
Tue, 25 Jul 2023 14:19:10 +0000 (16:19 +0200)
committerGopher Robot <gobot@golang.org>
Mon, 31 Jul 2023 18:33:29 +0000 (18:33 +0000)
Fixes #61629

This reduce the pressure on regalloc because then the loop only keep alive
one value (the iterator) instead of the iterator and the upper bound since
the comparison now acts against an immediate, often zero which can be skipped.

This optimize things like:
  for i := 0; i < n; i++ {
Or a range over a slice where the index is not used:
  for _, v := range someSlice {
Or the new range over int from #61405:
  for range n {

It is hit in 975 unique places while doing ./make.bash.

Change-Id: I5facff8b267a0b60ea3c1b9a58c4d74cdb38f03f
Reviewed-on: https://go-review.googlesource.com/c/go/+/512935
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Jorropo <jorropo.pgm@gmail.com>
Reviewed-by: Keith Randall <khr@google.com>
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Keith Randall <khr@golang.org>
Auto-Submit: Keith Randall <khr@golang.org>

src/cmd/compile/internal/ssa/loopbce.go
src/cmd/compile/internal/ssa/prove.go
test/codegen/compare_and_branch.go
test/prove_invert_loop_with_unused_iterators.go [new file with mode: 0644]

index b7dfaa33e3bf46523e57d52b98474dc30d719837..3dbd7350ae9da37fe92319bb8ed3e671b83cb161 100644 (file)
@@ -13,12 +13,14 @@ import (
 type indVarFlags uint8
 
 const (
-       indVarMinExc indVarFlags = 1 << iota // minimum value is exclusive (default: inclusive)
-       indVarMaxInc                         // maximum value is inclusive (default: exclusive)
+       indVarMinExc    indVarFlags = 1 << iota // minimum value is exclusive (default: inclusive)
+       indVarMaxInc                            // maximum value is inclusive (default: exclusive)
+       indVarCountDown                         // if set the iteration starts at max and count towards min (default: min towards max)
 )
 
 type indVar struct {
        ind   *Value // induction variable
+       nxt   *Value // the incremented variable
        min   *Value // minimum value, inclusive/exclusive depends on flags
        max   *Value // maximum value, inclusive/exclusive depends on flags
        entry *Block // entry block in the loop.
@@ -277,6 +279,7 @@ func findIndVar(f *Func) []indVar {
                                if !inclusive {
                                        flags |= indVarMinExc
                                }
+                               flags |= indVarCountDown
                                step = -step
                        }
                        if f.pass.debug >= 1 {
@@ -285,6 +288,7 @@ func findIndVar(f *Func) []indVar {
 
                        iv = append(iv, indVar{
                                ind:   ind,
+                               nxt:   nxt,
                                min:   min,
                                max:   max,
                                entry: b.Succs[0].b,
index 38758c3361bf2cf86be3a75cc25e688db9fe1f0e..91f5fbe7652a6db56d3a1d785ad3ed329a821ba7 100644 (file)
@@ -798,6 +798,166 @@ func (ft *factsTable) cleanup(f *Func) {
 // its negation. If either leads to a contradiction, it can trim that
 // successor.
 func prove(f *Func) {
+       // Find induction variables. Currently, findIndVars
+       // is limited to one induction variable per block.
+       var indVars map[*Block]indVar
+       for _, v := range findIndVar(f) {
+               ind := v.ind
+               if len(ind.Args) != 2 {
+                       // the rewrite code assumes there is only ever two parents to loops
+                       panic("unexpected induction with too many parents")
+               }
+
+               nxt := v.nxt
+               if !(ind.Uses == 2 && // 2 used by comparison and next
+                       nxt.Uses == 1) { // 1 used by induction
+                       // ind or nxt is used inside the loop, add it for the facts table
+                       if indVars == nil {
+                               indVars = make(map[*Block]indVar)
+                       }
+                       indVars[v.entry] = v
+                       continue
+               } else {
+                       // Since this induction variable is not used for anything but counting the iterations,
+                       // no point in putting it into the facts table.
+               }
+
+               // try to rewrite to a downward counting loop checking against start if the
+               // loop body does not depends on ind or nxt and end is known before the loop.
+               // This reduce pressure on the register allocator because this do not need
+               // to use end on each iteration anymore. We compare against the start constant instead.
+               // That means this code:
+               //
+               //      loop:
+               //              ind = (Phi (Const [x]) nxt),
+               //              if ind < end
+               //              then goto enter_loop
+               //              else goto exit_loop
+               //
+               //      enter_loop:
+               //              do something without using ind nor nxt
+               //              nxt = inc + ind
+               //              goto loop
+               //
+               //      exit_loop:
+               //
+               // is rewritten to:
+               //
+               //      loop:
+               //              ind = (Phi end nxt)
+               //              if (Const [x]) < ind
+               //              then goto enter_loop
+               //              else goto exit_loop
+               //
+               //      enter_loop:
+               //              do something without using ind nor nxt
+               //              nxt = ind - inc
+               //              goto loop
+               //
+               //      exit_loop:
+               //
+               // this is better because it only require to keep ind then nxt alive while looping,
+               // while the original form keeps ind then nxt and end alive
+               start, end := v.min, v.max
+               if v.flags&indVarCountDown != 0 {
+                       start, end = end, start
+               }
+
+               if !(start.Op == OpConst8 || start.Op == OpConst16 || start.Op == OpConst32 || start.Op == OpConst64) {
+                       // if start is not a constant we would be winning nothing from inverting the loop
+                       continue
+               }
+               if end.Op == OpConst8 || end.Op == OpConst16 || end.Op == OpConst32 || end.Op == OpConst64 {
+                       // TODO: if both start and end are constants we should rewrite such that the comparison
+                       // is against zero and nxt is ++ or -- operation
+                       // That means:
+                       //      for i := 2; i < 11; i += 2 {
+                       // should be rewritten to:
+                       //      for i := 5; 0 < i; i-- {
+                       continue
+               }
+
+               header := ind.Block
+               check := header.Controls[0]
+               if check == nil {
+                       // we don't know how to rewrite a loop that not simple comparison
+                       continue
+               }
+               switch check.Op {
+               case OpLeq64, OpLeq32, OpLeq16, OpLeq8,
+                       OpLess64, OpLess32, OpLess16, OpLess8:
+               default:
+                       // we don't know how to rewrite a loop that not simple comparison
+                       continue
+               }
+               if !((check.Args[0] == ind && check.Args[1] == end) ||
+                       (check.Args[1] == ind && check.Args[0] == end)) {
+                       // we don't know how to rewrite a loop that not simple comparison
+                       continue
+               }
+               if end.Block == ind.Block {
+                       // we can't rewrite loops where the condition depends on the loop body
+                       // this simple check is forced to work because if this is true a Phi in ind.Block must exists
+                       continue
+               }
+
+               // invert the check
+               check.Args[0], check.Args[1] = check.Args[1], check.Args[0]
+
+               // invert start and end in the loop
+               for i, v := range check.Args {
+                       if v != end {
+                               continue
+                       }
+
+                       check.SetArg(i, start)
+                       goto replacedEnd
+               }
+               panic(fmt.Sprintf("unreachable, ind: %v, start: %v, end: %v", ind, start, end))
+       replacedEnd:
+
+               for i, v := range ind.Args {
+                       if v != start {
+                               continue
+                       }
+
+                       ind.SetArg(i, end)
+                       goto replacedStart
+               }
+               panic(fmt.Sprintf("unreachable, ind: %v, start: %v, end: %v", ind, start, end))
+       replacedStart:
+
+               if nxt.Args[0] != ind {
+                       // unlike additions subtractions are not commutative so be sure we get it right
+                       nxt.Args[0], nxt.Args[1] = nxt.Args[1], nxt.Args[0]
+               }
+
+               switch nxt.Op {
+               case OpAdd8:
+                       nxt.Op = OpSub8
+               case OpAdd16:
+                       nxt.Op = OpSub16
+               case OpAdd32:
+                       nxt.Op = OpSub32
+               case OpAdd64:
+                       nxt.Op = OpSub64
+               case OpSub8:
+                       nxt.Op = OpAdd8
+               case OpSub16:
+                       nxt.Op = OpAdd16
+               case OpSub32:
+                       nxt.Op = OpAdd32
+               case OpSub64:
+                       nxt.Op = OpAdd64
+               default:
+                       panic("unreachable")
+               }
+
+               if f.pass.debug > 0 {
+                       f.Warnl(ind.Pos, "Inverted loop iteration")
+               }
+       }
+
        ft := newFactsTable(f)
        ft.checkpoint()
 
@@ -933,15 +1093,6 @@ func prove(f *Func) {
                        }
                }
        }
-       // Find induction variables. Currently, findIndVars
-       // is limited to one induction variable per block.
-       var indVars map[*Block]indVar
-       for _, v := range findIndVar(f) {
-               if indVars == nil {
-                       indVars = make(map[*Block]indVar)
-               }
-               indVars[v.entry] = v
-       }
 
        // current node state
        type walkState int
index b3feef0eb75f3a6697a1f706589fd78096bf899f..3502a0302238b1095c762415e75dd8b028121390 100644 (file)
@@ -23,24 +23,25 @@ func si64(x, y chan int64) {
 }
 
 // Signed 64-bit compare-and-branch with 8-bit immediate.
-func si64x8() {
+func si64x8(doNotOptimize int64) {
+       // take in doNotOptimize as an argument to avoid the loops being rewritten to count down
        // s390x:"CGIJ\t[$]12, R[0-9]+, [$]127, "
-       for i := int64(0); i < 128; i++ {
+       for i := doNotOptimize; i < 128; i++ {
                dummy()
        }
 
        // s390x:"CGIJ\t[$]10, R[0-9]+, [$]-128, "
-       for i := int64(0); i > -129; i-- {
+       for i := doNotOptimize; i > -129; i-- {
                dummy()
        }
 
        // s390x:"CGIJ\t[$]2, R[0-9]+, [$]127, "
-       for i := int64(0); i >= 128; i++ {
+       for i := doNotOptimize; i >= 128; i++ {
                dummy()
        }
 
        // s390x:"CGIJ\t[$]4, R[0-9]+, [$]-128, "
-       for i := int64(0); i <= -129; i-- {
+       for i := doNotOptimize; i <= -129; i-- {
                dummy()
        }
 }
@@ -95,24 +96,25 @@ func si32(x, y chan int32) {
 }
 
 // Signed 32-bit compare-and-branch with 8-bit immediate.
-func si32x8() {
+func si32x8(doNotOptimize int32) {
+       // take in doNotOptimize as an argument to avoid the loops being rewritten to count down
        // s390x:"CIJ\t[$]12, R[0-9]+, [$]127, "
-       for i := int32(0); i < 128; i++ {
+       for i := doNotOptimize; i < 128; i++ {
                dummy()
        }
 
        // s390x:"CIJ\t[$]10, R[0-9]+, [$]-128, "
-       for i := int32(0); i > -129; i-- {
+       for i := doNotOptimize; i > -129; i-- {
                dummy()
        }
 
        // s390x:"CIJ\t[$]2, R[0-9]+, [$]127, "
-       for i := int32(0); i >= 128; i++ {
+       for i := doNotOptimize; i >= 128; i++ {
                dummy()
        }
 
        // s390x:"CIJ\t[$]4, R[0-9]+, [$]-128, "
-       for i := int32(0); i <= -129; i-- {
+       for i := doNotOptimize; i <= -129; i-- {
                dummy()
        }
 }
diff --git a/test/prove_invert_loop_with_unused_iterators.go b/test/prove_invert_loop_with_unused_iterators.go
new file mode 100644 (file)
index 0000000..f278e5a
--- /dev/null
@@ -0,0 +1,10 @@
+// +build amd64
+// errorcheck -0 -d=ssa/prove/debug=1
+
+package main
+
+func invert(b func(), n int) {
+       for i := 0; i < n; i++ { // ERROR "(Inverted loop iteration|Induction variable: limits \[0,\?\), increment 1)"
+               b()
+       }
+}