From: Jorropo Date: Tue, 25 Jul 2023 14:19:10 +0000 (+0200) Subject: cmd/compile: try to rewrite loops to count down X-Git-Tag: go1.22rc1~1472 X-Git-Url: http://www.git.cypherpunks.ru/?a=commitdiff_plain;h=bac4e2f241ca8df3d5be6ddf83214b9a681f4086;p=gostls13.git cmd/compile: try to rewrite loops to count down Fixes #61629 This reduce the pressure on regalloc because then the loop only keep alive one value (the iterator) instead of the iterator and the upper bound since the comparison now acts against an immediate, often zero which can be skipped. This optimize things like: for i := 0; i < n; i++ { Or a range over a slice where the index is not used: for _, v := range someSlice { Or the new range over int from #61405: for range n { It is hit in 975 unique places while doing ./make.bash. Change-Id: I5facff8b267a0b60ea3c1b9a58c4d74cdb38f03f Reviewed-on: https://go-review.googlesource.com/c/go/+/512935 TryBot-Result: Gopher Robot Run-TryBot: Jorropo Reviewed-by: Keith Randall Reviewed-by: David Chase Reviewed-by: Keith Randall Auto-Submit: Keith Randall --- diff --git a/src/cmd/compile/internal/ssa/loopbce.go b/src/cmd/compile/internal/ssa/loopbce.go index b7dfaa33e3..3dbd7350ae 100644 --- a/src/cmd/compile/internal/ssa/loopbce.go +++ b/src/cmd/compile/internal/ssa/loopbce.go @@ -13,12 +13,14 @@ import ( type indVarFlags uint8 const ( - indVarMinExc indVarFlags = 1 << iota // minimum value is exclusive (default: inclusive) - indVarMaxInc // maximum value is inclusive (default: exclusive) + indVarMinExc indVarFlags = 1 << iota // minimum value is exclusive (default: inclusive) + indVarMaxInc // maximum value is inclusive (default: exclusive) + indVarCountDown // if set the iteration starts at max and count towards min (default: min towards max) ) type indVar struct { ind *Value // induction variable + nxt *Value // the incremented variable min *Value // minimum value, inclusive/exclusive depends on flags max *Value // maximum value, inclusive/exclusive depends on flags entry *Block // entry block in the loop. @@ -277,6 +279,7 @@ func findIndVar(f *Func) []indVar { if !inclusive { flags |= indVarMinExc } + flags |= indVarCountDown step = -step } if f.pass.debug >= 1 { @@ -285,6 +288,7 @@ func findIndVar(f *Func) []indVar { iv = append(iv, indVar{ ind: ind, + nxt: nxt, min: min, max: max, entry: b.Succs[0].b, diff --git a/src/cmd/compile/internal/ssa/prove.go b/src/cmd/compile/internal/ssa/prove.go index 38758c3361..91f5fbe765 100644 --- a/src/cmd/compile/internal/ssa/prove.go +++ b/src/cmd/compile/internal/ssa/prove.go @@ -798,6 +798,166 @@ func (ft *factsTable) cleanup(f *Func) { // its negation. If either leads to a contradiction, it can trim that // successor. func prove(f *Func) { + // Find induction variables. Currently, findIndVars + // is limited to one induction variable per block. + var indVars map[*Block]indVar + for _, v := range findIndVar(f) { + ind := v.ind + if len(ind.Args) != 2 { + // the rewrite code assumes there is only ever two parents to loops + panic("unexpected induction with too many parents") + } + + nxt := v.nxt + if !(ind.Uses == 2 && // 2 used by comparison and next + nxt.Uses == 1) { // 1 used by induction + // ind or nxt is used inside the loop, add it for the facts table + if indVars == nil { + indVars = make(map[*Block]indVar) + } + indVars[v.entry] = v + continue + } else { + // Since this induction variable is not used for anything but counting the iterations, + // no point in putting it into the facts table. + } + + // try to rewrite to a downward counting loop checking against start if the + // loop body does not depends on ind or nxt and end is known before the loop. + // This reduce pressure on the register allocator because this do not need + // to use end on each iteration anymore. We compare against the start constant instead. + // That means this code: + // + // loop: + // ind = (Phi (Const [x]) nxt), + // if ind < end + // then goto enter_loop + // else goto exit_loop + // + // enter_loop: + // do something without using ind nor nxt + // nxt = inc + ind + // goto loop + // + // exit_loop: + // + // is rewritten to: + // + // loop: + // ind = (Phi end nxt) + // if (Const [x]) < ind + // then goto enter_loop + // else goto exit_loop + // + // enter_loop: + // do something without using ind nor nxt + // nxt = ind - inc + // goto loop + // + // exit_loop: + // + // this is better because it only require to keep ind then nxt alive while looping, + // while the original form keeps ind then nxt and end alive + start, end := v.min, v.max + if v.flags&indVarCountDown != 0 { + start, end = end, start + } + + if !(start.Op == OpConst8 || start.Op == OpConst16 || start.Op == OpConst32 || start.Op == OpConst64) { + // if start is not a constant we would be winning nothing from inverting the loop + continue + } + if end.Op == OpConst8 || end.Op == OpConst16 || end.Op == OpConst32 || end.Op == OpConst64 { + // TODO: if both start and end are constants we should rewrite such that the comparison + // is against zero and nxt is ++ or -- operation + // That means: + // for i := 2; i < 11; i += 2 { + // should be rewritten to: + // for i := 5; 0 < i; i-- { + continue + } + + header := ind.Block + check := header.Controls[0] + if check == nil { + // we don't know how to rewrite a loop that not simple comparison + continue + } + switch check.Op { + case OpLeq64, OpLeq32, OpLeq16, OpLeq8, + OpLess64, OpLess32, OpLess16, OpLess8: + default: + // we don't know how to rewrite a loop that not simple comparison + continue + } + if !((check.Args[0] == ind && check.Args[1] == end) || + (check.Args[1] == ind && check.Args[0] == end)) { + // we don't know how to rewrite a loop that not simple comparison + continue + } + if end.Block == ind.Block { + // we can't rewrite loops where the condition depends on the loop body + // this simple check is forced to work because if this is true a Phi in ind.Block must exists + continue + } + + // invert the check + check.Args[0], check.Args[1] = check.Args[1], check.Args[0] + + // invert start and end in the loop + for i, v := range check.Args { + if v != end { + continue + } + + check.SetArg(i, start) + goto replacedEnd + } + panic(fmt.Sprintf("unreachable, ind: %v, start: %v, end: %v", ind, start, end)) + replacedEnd: + + for i, v := range ind.Args { + if v != start { + continue + } + + ind.SetArg(i, end) + goto replacedStart + } + panic(fmt.Sprintf("unreachable, ind: %v, start: %v, end: %v", ind, start, end)) + replacedStart: + + if nxt.Args[0] != ind { + // unlike additions subtractions are not commutative so be sure we get it right + nxt.Args[0], nxt.Args[1] = nxt.Args[1], nxt.Args[0] + } + + switch nxt.Op { + case OpAdd8: + nxt.Op = OpSub8 + case OpAdd16: + nxt.Op = OpSub16 + case OpAdd32: + nxt.Op = OpSub32 + case OpAdd64: + nxt.Op = OpSub64 + case OpSub8: + nxt.Op = OpAdd8 + case OpSub16: + nxt.Op = OpAdd16 + case OpSub32: + nxt.Op = OpAdd32 + case OpSub64: + nxt.Op = OpAdd64 + default: + panic("unreachable") + } + + if f.pass.debug > 0 { + f.Warnl(ind.Pos, "Inverted loop iteration") + } + } + ft := newFactsTable(f) ft.checkpoint() @@ -933,15 +1093,6 @@ func prove(f *Func) { } } } - // Find induction variables. Currently, findIndVars - // is limited to one induction variable per block. - var indVars map[*Block]indVar - for _, v := range findIndVar(f) { - if indVars == nil { - indVars = make(map[*Block]indVar) - } - indVars[v.entry] = v - } // current node state type walkState int diff --git a/test/codegen/compare_and_branch.go b/test/codegen/compare_and_branch.go index b3feef0eb7..3502a03022 100644 --- a/test/codegen/compare_and_branch.go +++ b/test/codegen/compare_and_branch.go @@ -23,24 +23,25 @@ func si64(x, y chan int64) { } // Signed 64-bit compare-and-branch with 8-bit immediate. -func si64x8() { +func si64x8(doNotOptimize int64) { + // take in doNotOptimize as an argument to avoid the loops being rewritten to count down // s390x:"CGIJ\t[$]12, R[0-9]+, [$]127, " - for i := int64(0); i < 128; i++ { + for i := doNotOptimize; i < 128; i++ { dummy() } // s390x:"CGIJ\t[$]10, R[0-9]+, [$]-128, " - for i := int64(0); i > -129; i-- { + for i := doNotOptimize; i > -129; i-- { dummy() } // s390x:"CGIJ\t[$]2, R[0-9]+, [$]127, " - for i := int64(0); i >= 128; i++ { + for i := doNotOptimize; i >= 128; i++ { dummy() } // s390x:"CGIJ\t[$]4, R[0-9]+, [$]-128, " - for i := int64(0); i <= -129; i-- { + for i := doNotOptimize; i <= -129; i-- { dummy() } } @@ -95,24 +96,25 @@ func si32(x, y chan int32) { } // Signed 32-bit compare-and-branch with 8-bit immediate. -func si32x8() { +func si32x8(doNotOptimize int32) { + // take in doNotOptimize as an argument to avoid the loops being rewritten to count down // s390x:"CIJ\t[$]12, R[0-9]+, [$]127, " - for i := int32(0); i < 128; i++ { + for i := doNotOptimize; i < 128; i++ { dummy() } // s390x:"CIJ\t[$]10, R[0-9]+, [$]-128, " - for i := int32(0); i > -129; i-- { + for i := doNotOptimize; i > -129; i-- { dummy() } // s390x:"CIJ\t[$]2, R[0-9]+, [$]127, " - for i := int32(0); i >= 128; i++ { + for i := doNotOptimize; i >= 128; i++ { dummy() } // s390x:"CIJ\t[$]4, R[0-9]+, [$]-128, " - for i := int32(0); i <= -129; i-- { + for i := doNotOptimize; i <= -129; i-- { dummy() } } diff --git a/test/prove_invert_loop_with_unused_iterators.go b/test/prove_invert_loop_with_unused_iterators.go new file mode 100644 index 0000000000..f278e5aee0 --- /dev/null +++ b/test/prove_invert_loop_with_unused_iterators.go @@ -0,0 +1,10 @@ +// +build amd64 +// errorcheck -0 -d=ssa/prove/debug=1 + +package main + +func invert(b func(), n int) { + for i := 0; i < n; i++ { // ERROR "(Inverted loop iteration|Induction variable: limits \[0,\?\), increment 1)" + b() + } +}