]> Cypherpunks.ru repositories - gostls13.git/blobdiff - src/cmd/compile/internal/ssa/loopbce.go
cmd/compile: fix findIndVar so it does not match disjointed loop headers
[gostls13.git] / src / cmd / compile / internal / ssa / loopbce.go
index 092e7aa35ba36e7ff9eaef8a83a74f511e439c3f..dd1f39dbef74398f7cf9e338d26d6f9b1a8f4ca7 100644 (file)
@@ -5,19 +5,22 @@
 package ssa
 
 import (
+       "cmd/compile/internal/base"
+       "cmd/compile/internal/types"
        "fmt"
-       "math"
 )
 
 type indVarFlags uint8
 
 const (
-       indVarMinExc indVarFlags = 1 << iota // minimum value is exclusive (default: inclusive)
-       indVarMaxInc                         // maximum value is inclusive (default: exclusive)
+       indVarMinExc    indVarFlags = 1 << iota // minimum value is exclusive (default: inclusive)
+       indVarMaxInc                            // maximum value is inclusive (default: exclusive)
+       indVarCountDown                         // if set the iteration starts at max and count towards min (default: min towards max)
 )
 
 type indVar struct {
        ind   *Value // induction variable
+       nxt   *Value // the incremented variable
        min   *Value // minimum value, inclusive/exclusive depends on flags
        max   *Value // maximum value, inclusive/exclusive depends on flags
        entry *Block // entry block in the loop.
@@ -29,89 +32,112 @@ type indVar struct {
        //      min <  ind <= max    [if flags == indVarMinExc|indVarMaxInc]
 }
 
+// parseIndVar checks whether the SSA value passed as argument is a valid induction
+// variable, and, if so, extracts:
+//   - the minimum bound
+//   - the increment value
+//   - the "next" value (SSA value that is Phi'd into the induction variable every loop)
+//
+// Currently, we detect induction variables that match (Phi min nxt),
+// with nxt being (Add inc ind).
+// If it can't parse the induction variable correctly, it returns (nil, nil, nil).
+func parseIndVar(ind *Value) (min, inc, nxt *Value) {
+       if ind.Op != OpPhi {
+               return
+       }
+
+       if n := ind.Args[0]; (n.Op == OpAdd64 || n.Op == OpAdd32 || n.Op == OpAdd16 || n.Op == OpAdd8) && (n.Args[0] == ind || n.Args[1] == ind) {
+               min, nxt = ind.Args[1], n
+       } else if n := ind.Args[1]; (n.Op == OpAdd64 || n.Op == OpAdd32 || n.Op == OpAdd16 || n.Op == OpAdd8) && (n.Args[0] == ind || n.Args[1] == ind) {
+               min, nxt = ind.Args[0], n
+       } else {
+               // Not a recognized induction variable.
+               return
+       }
+
+       if nxt.Args[0] == ind { // nxt = ind + inc
+               inc = nxt.Args[1]
+       } else if nxt.Args[1] == ind { // nxt = inc + ind
+               inc = nxt.Args[0]
+       } else {
+               panic("unreachable") // one of the cases must be true from the above.
+       }
+
+       return
+}
+
 // findIndVar finds induction variables in a function.
 //
 // Look for variables and blocks that satisfy the following
 //
-// loop:
-//   ind = (Phi min nxt),
-//   if ind < max
-//     then goto enter_loop
-//     else goto exit_loop
+//      loop:
+//        ind = (Phi min nxt),
+//        if ind < max
+//          then goto enter_loop
+//          else goto exit_loop
 //
-//   enter_loop:
-//     do something
-//      nxt = inc + ind
-//     goto loop
+//        enter_loop:
+//             do something
+//           nxt = inc + ind
+//             goto loop
 //
-// exit_loop:
-//
-//
-// TODO: handle 32 bit operations
+//      exit_loop:
 func findIndVar(f *Func) []indVar {
        var iv []indVar
-       sdom := f.sdom()
+       sdom := f.Sdom()
 
        for _, b := range f.Blocks {
                if b.Kind != BlockIf || len(b.Preds) != 2 {
                        continue
                }
 
-               var flags indVarFlags
-               var ind, max *Value // induction, and maximum
+               var ind *Value   // induction variable
+               var init *Value  // starting value
+               var limit *Value // ending value
 
-               // Check thet the control if it either ind </<= max or max >/>= ind.
-               // TODO: Handle 32-bit comparisons.
+               // Check that the control if it either ind </<= limit or limit </<= ind.
                // TODO: Handle unsigned comparisons?
-               switch b.Control.Op {
-               case OpLeq64:
-                       flags |= indVarMaxInc
-                       fallthrough
-               case OpLess64:
-                       ind, max = b.Control.Args[0], b.Control.Args[1]
-               case OpGeq64:
-                       flags |= indVarMaxInc
+               c := b.Controls[0]
+               inclusive := false
+               switch c.Op {
+               case OpLeq64, OpLeq32, OpLeq16, OpLeq8:
+                       inclusive = true
                        fallthrough
-               case OpGreater64:
-                       ind, max = b.Control.Args[1], b.Control.Args[0]
+               case OpLess64, OpLess32, OpLess16, OpLess8:
+                       ind, limit = c.Args[0], c.Args[1]
                default:
                        continue
                }
 
-               // See if the arguments are reversed (i < len() <=> len() > i)
+               // See if this is really an induction variable
                less := true
-               if max.Op == OpPhi {
-                       ind, max = max, ind
-                       less = false
-               }
+               init, inc, nxt := parseIndVar(ind)
+               if init == nil {
+                       // We failed to parse the induction variable. Before punting, we want to check
+                       // whether the control op was written with the induction variable on the RHS
+                       // instead of the LHS. This happens for the downwards case, like:
+                       //     for i := len(n)-1; i >= 0; i--
+                       init, inc, nxt = parseIndVar(limit)
+                       if init == nil {
+                               // No recognized induction variable on either operand
+                               continue
+                       }
 
-               // Check that the induction variable is a phi that depends on itself.
-               if ind.Op != OpPhi {
-                       continue
+                       // Ok, the arguments were reversed. Swap them, and remember that we're
+                       // looking at an ind >/>= loop (so the induction must be decrementing).
+                       ind, limit = limit, ind
+                       less = false
                }
 
-               // Extract min and nxt knowing that nxt is an addition (e.g. Add64).
-               var min, nxt *Value // minimum, and next value
-               if n := ind.Args[0]; n.Op == OpAdd64 && (n.Args[0] == ind || n.Args[1] == ind) {
-                       min, nxt = ind.Args[1], n
-               } else if n := ind.Args[1]; n.Op == OpAdd64 && (n.Args[0] == ind || n.Args[1] == ind) {
-                       min, nxt = ind.Args[0], n
-               } else {
-                       // Not a recognized induction variable.
+               if ind.Block != b {
+                       // TODO: Could be extended to include disjointed loop headers.
+                       // I don't think this is causing missed optimizations in real world code often.
+                       // See https://go.dev/issue/63955
                        continue
                }
 
-               var inc *Value
-               if nxt.Args[0] == ind { // nxt = ind + inc
-                       inc = nxt.Args[1]
-               } else if nxt.Args[1] == ind { // nxt = inc + ind
-                       inc = nxt.Args[0]
-               } else {
-                       panic("unreachable") // one of the cases must be true from the above.
-               }
-
                // Expect the increment to be a nonzero constant.
-               if inc.Op != OpConst64 {
+               if !inc.isGenericIntConst() {
                        continue
                }
                step := inc.AuxInt
@@ -120,8 +146,8 @@ func findIndVar(f *Func) []indVar {
                }
 
                // Increment sign must match comparison direction.
-               // When incrementing, the termination comparison must be ind </<= max.
-               // When decrementing, the termination comparison must be ind >/>= max.
+               // When incrementing, the termination comparison must be ind </<= limit.
+               // When decrementing, the termination comparison must be ind >/>= limit.
                // See issue 26116.
                if step > 0 && !less {
                        continue
@@ -130,170 +156,245 @@ func findIndVar(f *Func) []indVar {
                        continue
                }
 
-               // If the increment is negative, swap min/max and their flags
-               if step < 0 {
-                       min, max = max, min
-                       oldf := flags
-                       flags = indVarMaxInc
-                       if oldf&indVarMaxInc == 0 {
-                               flags |= indVarMinExc
-                       }
-                       step = -step
-               }
-
                // Up to now we extracted the induction variable (ind),
                // the increment delta (inc), the temporary sum (nxt),
-               // the mininum value (min) and the maximum value (max).
+               // the initial value (init) and the limiting value (limit).
                //
-               // We also know that ind has the form (Phi min nxt) where
+               // We also know that ind has the form (Phi init nxt) where
                // nxt is (Add inc nxt) which means: 1) inc dominates nxt
                // and 2) there is a loop starting at inc and containing nxt.
                //
                // We need to prove that the induction variable is incremented
-               // only when it's smaller than the maximum value.
+               // only when it's smaller than the limiting value.
                // Two conditions must happen listed below to accept ind
                // as an induction variable.
 
                // First condition: loop entry has a single predecessor, which
                // is the header block.  This implies that b.Succs[0] is
-               // reached iff ind < max.
+               // reached iff ind < limit.
                if len(b.Succs[0].b.Preds) != 1 {
                        // b.Succs[1] must exit the loop.
                        continue
                }
 
                // Second condition: b.Succs[0] dominates nxt so that
-               // nxt is computed when inc < max, meaning nxt <= max.
-               if !sdom.isAncestorEq(b.Succs[0].b, nxt.Block) {
+               // nxt is computed when inc < limit.
+               if !sdom.IsAncestorEq(b.Succs[0].b, nxt.Block) {
                        // inc+ind can only be reached through the branch that enters the loop.
                        continue
                }
 
-               // We can only guarantee that the loop runs within limits of induction variable
-               // if (one of)
-               // (1) the increment is ±1
-               // (2) the limits are constants
-               // (3) loop is of the form k0 upto Known_not_negative-k inclusive, step <= k
-               // (4) loop is of the form k0 upto Known_not_negative-k exclusive, step <= k+1
-               // (5) loop is of the form Known_not_negative downto k0, minint+step < k0
-               if step > 1 {
-                       ok := false
-                       if min.Op == OpConst64 && max.Op == OpConst64 {
-                               if max.AuxInt > min.AuxInt && max.AuxInt%step == min.AuxInt%step { // handle overflow
-                                       ok = true
-                               }
-                       }
-                       // Handle induction variables of these forms.
-                       // KNN is known-not-negative.
-                       // SIGNED ARITHMETIC ONLY. (see switch on b.Control.Op above)
-                       // Possibilities for KNN are len and cap; perhaps we can infer others.
-                       // for i := 0; i <= KNN-k    ; i += k
-                       // for i := 0; i <  KNN-(k-1); i += k
-                       // Also handle decreasing.
-
-                       // "Proof" copied from https://go-review.googlesource.com/c/go/+/104041/10/src/cmd/compile/internal/ssa/loopbce.go#164
-                       //
-                       //      In the case of
-                       //      // PC is Positive Constant
-                       //      L := len(A)-PC
-                       //      for i := 0; i < L; i = i+PC
-                       //
-                       //      we know:
-                       //
-                       //      0 + PC does not over/underflow.
-                       //      len(A)-PC does not over/underflow
-                       //      maximum value for L is MaxInt-PC
-                       //      i < L <= MaxInt-PC means i + PC < MaxInt hence no overflow.
-
-                       // To match in SSA:
-                       // if  (a) min.Op == OpConst64(k0)
-                       // and (b) k0 >= MININT + step
-                       // and (c) max.Op == OpSubtract(Op{StringLen,SliceLen,SliceCap}, k)
-                       // or  (c) max.Op == OpAdd(Op{StringLen,SliceLen,SliceCap}, -k)
-                       // or  (c) max.Op == Op{StringLen,SliceLen,SliceCap}
-                       // and (d) if upto loop, require indVarMaxInc && step <= k or !indVarMaxInc && step-1 <= k
-
-                       if min.Op == OpConst64 && min.AuxInt >= step+math.MinInt64 {
-                               knn := max
-                               k := int64(0)
-                               var kArg *Value
-
-                               switch max.Op {
-                               case OpSub64:
-                                       knn = max.Args[0]
-                                       kArg = max.Args[1]
-
-                               case OpAdd64:
-                                       knn = max.Args[0]
-                                       kArg = max.Args[1]
-                                       if knn.Op == OpConst64 {
-                                               knn, kArg = kArg, knn
+               // Check for overflow/underflow. We need to make sure that inc never causes
+               // the induction variable to wrap around.
+               // We use a function wrapper here for easy return true / return false / keep going logic.
+               // This function returns true if the increment will never overflow/underflow.
+               ok := func() bool {
+                       if step > 0 {
+                               if limit.isGenericIntConst() {
+                                       // Figure out the actual largest value.
+                                       v := limit.AuxInt
+                                       if !inclusive {
+                                               if v == minSignedValue(limit.Type) {
+                                                       return false // < minint is never satisfiable.
+                                               }
+                                               v--
+                                       }
+                                       if init.isGenericIntConst() {
+                                               // Use stride to compute a better lower limit.
+                                               if init.AuxInt > v {
+                                                       return false
+                                               }
+                                               v = addU(init.AuxInt, diff(v, init.AuxInt)/uint64(step)*uint64(step))
+                                       }
+                                       if addWillOverflow(v, step) {
+                                               return false
+                                       }
+                                       if inclusive && v != limit.AuxInt || !inclusive && v+1 != limit.AuxInt {
+                                               // We know a better limit than the programmer did. Use our limit instead.
+                                               limit = f.constVal(limit.Op, limit.Type, v, true)
+                                               inclusive = true
                                        }
+                                       return true
                                }
-                               switch knn.Op {
-                               case OpSliceLen, OpStringLen, OpSliceCap:
-                               default:
-                                       knn = nil
+                               if step == 1 && !inclusive {
+                                       // Can't overflow because maxint is never a possible value.
+                                       return true
                                }
-
-                               if kArg != nil && kArg.Op == OpConst64 {
-                                       k = kArg.AuxInt
-                                       if max.Op == OpAdd64 {
-                                               k = -k
-                                       }
+                               // If the limit is not a constant, check to see if it is a
+                               // negative offset from a known non-negative value.
+                               knn, k := findKNN(limit)
+                               if knn == nil || k < 0 {
+                                       return false
                                }
-                               if k >= 0 && knn != nil {
-                                       if inc.AuxInt > 0 { // increasing iteration
-                                               // The concern for the relation between step and k is to ensure that iv never exceeds knn
-                                               // i.e., iv < knn-(K-1) ==> iv + K <= knn; iv <= knn-K ==> iv +K < knn
-                                               if step <= k || flags&indVarMaxInc == 0 && step-1 == k {
-                                                       ok = true
+                               // limit == (something nonnegative) - k. That subtraction can't underflow, so
+                               // we can trust it.
+                               if inclusive {
+                                       // ind <= knn - k cannot overflow if step is at most k
+                                       return step <= k
+                               }
+                               // ind < knn - k cannot overflow if step is at most k+1
+                               return step <= k+1 && k != maxSignedValue(limit.Type)
+                       } else { // step < 0
+                               if limit.Op == OpConst64 {
+                                       // Figure out the actual smallest value.
+                                       v := limit.AuxInt
+                                       if !inclusive {
+                                               if v == maxSignedValue(limit.Type) {
+                                                       return false // > maxint is never satisfiable.
                                                }
-                                       } else { // decreasing iteration
-                                               // Will be decrementing from max towards min; max is knn-k; will only attempt decrement if
-                                               // knn-k >[=] min; underflow is only a concern if min-step is not smaller than min.
-                                               // This all assumes signed integer arithmetic
-                                               // This is already assured by the test above: min.AuxInt >= step+math.MinInt64
-                                               ok = true
+                                               v++
                                        }
+                                       if init.isGenericIntConst() {
+                                               // Use stride to compute a better lower limit.
+                                               if init.AuxInt < v {
+                                                       return false
+                                               }
+                                               v = subU(init.AuxInt, diff(init.AuxInt, v)/uint64(-step)*uint64(-step))
+                                       }
+                                       if subWillUnderflow(v, -step) {
+                                               return false
+                                       }
+                                       if inclusive && v != limit.AuxInt || !inclusive && v-1 != limit.AuxInt {
+                                               // We know a better limit than the programmer did. Use our limit instead.
+                                               limit = f.constVal(limit.Op, limit.Type, v, true)
+                                               inclusive = true
+                                       }
+                                       return true
+                               }
+                               if step == -1 && !inclusive {
+                                       // Can't underflow because minint is never a possible value.
+                                       return true
                                }
                        }
+                       return false
 
-                       // TODO: other unrolling idioms
-                       // for i := 0; i < KNN - KNN % k ; i += k
-                       // for i := 0; i < KNN&^(k-1) ; i += k // k a power of 2
-                       // for i := 0; i < KNN&(-k) ; i += k // k a power of 2
+               }
 
-                       if !ok {
-                               continue
+               if ok() {
+                       flags := indVarFlags(0)
+                       var min, max *Value
+                       if step > 0 {
+                               min = init
+                               max = limit
+                               if inclusive {
+                                       flags |= indVarMaxInc
+                               }
+                       } else {
+                               min = limit
+                               max = init
+                               flags |= indVarMaxInc
+                               if !inclusive {
+                                       flags |= indVarMinExc
+                               }
+                               flags |= indVarCountDown
+                               step = -step
+                       }
+                       if f.pass.debug >= 1 {
+                               printIndVar(b, ind, min, max, step, flags)
                        }
-               }
 
-               if f.pass.debug >= 1 {
-                       printIndVar(b, ind, min, max, step, flags)
+                       iv = append(iv, indVar{
+                               ind:   ind,
+                               nxt:   nxt,
+                               min:   min,
+                               max:   max,
+                               entry: b.Succs[0].b,
+                               flags: flags,
+                       })
+                       b.Logf("found induction variable %v (inc = %v, min = %v, max = %v)\n", ind, inc, min, max)
                }
 
-               iv = append(iv, indVar{
-                       ind:   ind,
-                       min:   min,
-                       max:   max,
-                       entry: b.Succs[0].b,
-                       flags: flags,
-               })
-               b.Logf("found induction variable %v (inc = %v, min = %v, max = %v)\n", ind, inc, min, max)
+               // TODO: other unrolling idioms
+               // for i := 0; i < KNN - KNN % k ; i += k
+               // for i := 0; i < KNN&^(k-1) ; i += k // k a power of 2
+               // for i := 0; i < KNN&(-k) ; i += k // k a power of 2
        }
 
        return iv
 }
 
-func dropAdd64(v *Value) (*Value, int64) {
-       if v.Op == OpAdd64 && v.Args[0].Op == OpConst64 {
-               return v.Args[1], v.Args[0].AuxInt
+// addWillOverflow reports whether x+y would result in a value more than maxint.
+func addWillOverflow(x, y int64) bool {
+       return x+y < x
+}
+
+// subWillUnderflow reports whether x-y would result in a value less than minint.
+func subWillUnderflow(x, y int64) bool {
+       return x-y > x
+}
+
+// diff returns x-y as a uint64. Requires x>=y.
+func diff(x, y int64) uint64 {
+       if x < y {
+               base.Fatalf("diff %d - %d underflowed", x, y)
+       }
+       return uint64(x - y)
+}
+
+// addU returns x+y. Requires that x+y does not overflow an int64.
+func addU(x int64, y uint64) int64 {
+       if y >= 1<<63 {
+               if x >= 0 {
+                       base.Fatalf("addU overflowed %d + %d", x, y)
+               }
+               x += 1<<63 - 1
+               x += 1
+               y -= 1 << 63
        }
-       if v.Op == OpAdd64 && v.Args[1].Op == OpConst64 {
-               return v.Args[0], v.Args[1].AuxInt
+       if addWillOverflow(x, int64(y)) {
+               base.Fatalf("addU overflowed %d + %d", x, y)
        }
-       return v, 0
+       return x + int64(y)
+}
+
+// subU returns x-y. Requires that x-y does not underflow an int64.
+func subU(x int64, y uint64) int64 {
+       if y >= 1<<63 {
+               if x < 0 {
+                       base.Fatalf("subU underflowed %d - %d", x, y)
+               }
+               x -= 1<<63 - 1
+               x -= 1
+               y -= 1 << 63
+       }
+       if subWillUnderflow(x, int64(y)) {
+               base.Fatalf("subU underflowed %d - %d", x, y)
+       }
+       return x - int64(y)
+}
+
+// if v is known to be x - c, where x is known to be nonnegative and c is a
+// constant, return x, c. Otherwise return nil, 0.
+func findKNN(v *Value) (*Value, int64) {
+       var x, y *Value
+       x = v
+       switch v.Op {
+       case OpSub64, OpSub32, OpSub16, OpSub8:
+               x = v.Args[0]
+               y = v.Args[1]
+
+       case OpAdd64, OpAdd32, OpAdd16, OpAdd8:
+               x = v.Args[0]
+               y = v.Args[1]
+               if x.isGenericIntConst() {
+                       x, y = y, x
+               }
+       }
+       switch x.Op {
+       case OpSliceLen, OpStringLen, OpSliceCap:
+       default:
+               return nil, 0
+       }
+       if y == nil {
+               return x, 0
+       }
+       if !y.isGenericIntConst() {
+               return nil, 0
+       }
+       if v.Op == OpAdd64 || v.Op == OpAdd32 || v.Op == OpAdd16 || v.Op == OpAdd8 {
+               return x, -y.AuxInt
+       }
+       return x, y.AuxInt
 }
 
 func printIndVar(b *Block, i, min, max *Value, inc int64, flags indVarFlags) {
@@ -326,3 +427,11 @@ func printIndVar(b *Block, i, min, max *Value, inc int64, flags indVarFlags) {
        }
        b.Func.Warnl(b.Pos, "Induction variable: limits %v%v,%v%v, increment %d%s", mb1, mlim1, mlim2, mb2, inc, extra)
 }
+
+func minSignedValue(t *types.Type) int64 {
+       return -1 << (t.Size()*8 - 1)
+}
+
+func maxSignedValue(t *types.Type) int64 {
+       return 1<<((t.Size()*8)-1) - 1
+}