]> Cypherpunks.ru repositories - gostls13.git/blobdiff - src/cmd/compile/internal/ssa/loopbce.go
cmd/compile: fix findIndVar so it does not match disjointed loop headers
[gostls13.git] / src / cmd / compile / internal / ssa / loopbce.go
index 9bd2d3f0de97d5a8dc5afe8e60f0fa957be0ef7c..dd1f39dbef74398f7cf9e338d26d6f9b1a8f4ca7 100644 (file)
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
 package ssa
 
+import (
+       "cmd/compile/internal/base"
+       "cmd/compile/internal/types"
+       "fmt"
+)
+
+type indVarFlags uint8
+
+const (
+       indVarMinExc    indVarFlags = 1 << iota // minimum value is exclusive (default: inclusive)
+       indVarMaxInc                            // maximum value is inclusive (default: exclusive)
+       indVarCountDown                         // if set the iteration starts at max and count towards min (default: min towards max)
+)
+
 type indVar struct {
        ind   *Value // induction variable
-       inc   *Value // increment, a constant
-       nxt   *Value // ind+inc variable
-       min   *Value // minimum value. inclusive,
-       max   *Value // maximum value. exclusive.
+       nxt   *Value // the incremented variable
+       min   *Value // minimum value, inclusive/exclusive depends on flags
+       max   *Value // maximum value, inclusive/exclusive depends on flags
        entry *Block // entry block in the loop.
-       // Invariants: for all blocks dominated by entry:
-       //      min <= ind < max
-       //      min <= nxt <= max
+       flags indVarFlags
+       // Invariant: for all blocks strictly dominated by entry:
+       //      min <= ind <  max    [if flags == 0]
+       //      min <  ind <  max    [if flags == indVarMinExc]
+       //      min <= ind <= max    [if flags == indVarMaxInc]
+       //      min <  ind <= max    [if flags == indVarMinExc|indVarMaxInc]
+}
+
+// parseIndVar checks whether the SSA value passed as argument is a valid induction
+// variable, and, if so, extracts:
+//   - the minimum bound
+//   - the increment value
+//   - the "next" value (SSA value that is Phi'd into the induction variable every loop)
+//
+// Currently, we detect induction variables that match (Phi min nxt),
+// with nxt being (Add inc ind).
+// If it can't parse the induction variable correctly, it returns (nil, nil, nil).
+func parseIndVar(ind *Value) (min, inc, nxt *Value) {
+       if ind.Op != OpPhi {
+               return
+       }
+
+       if n := ind.Args[0]; (n.Op == OpAdd64 || n.Op == OpAdd32 || n.Op == OpAdd16 || n.Op == OpAdd8) && (n.Args[0] == ind || n.Args[1] == ind) {
+               min, nxt = ind.Args[1], n
+       } else if n := ind.Args[1]; (n.Op == OpAdd64 || n.Op == OpAdd32 || n.Op == OpAdd16 || n.Op == OpAdd8) && (n.Args[0] == ind || n.Args[1] == ind) {
+               min, nxt = ind.Args[0], n
+       } else {
+               // Not a recognized induction variable.
+               return
+       }
+
+       if nxt.Args[0] == ind { // nxt = ind + inc
+               inc = nxt.Args[1]
+       } else if nxt.Args[1] == ind { // nxt = inc + ind
+               inc = nxt.Args[0]
+       } else {
+               panic("unreachable") // one of the cases must be true from the above.
+       }
+
+       return
 }
 
 // findIndVar finds induction variables in a function.
 //
 // Look for variables and blocks that satisfy the following
 //
-// loop:
-//   ind = (Phi min nxt),
-//   if ind < max
-//     then goto enter_loop
-//     else goto exit_loop
+//      loop:
+//        ind = (Phi min nxt),
+//        if ind < max
+//          then goto enter_loop
+//          else goto exit_loop
 //
-//   enter_loop:
-//     do something
-//      nxt = inc + ind
-//     goto loop
+//        enter_loop:
+//             do something
+//           nxt = inc + ind
+//             goto loop
 //
-// exit_loop:
-//
-//
-// TODO: handle 32 bit operations
+//      exit_loop:
 func findIndVar(f *Func) []indVar {
        var iv []indVar
+       sdom := f.Sdom()
 
-nextb:
        for _, b := range f.Blocks {
                if b.Kind != BlockIf || len(b.Preds) != 2 {
                        continue
                }
 
-               var ind, max *Value // induction, and maximum
-               entry := -1         // which successor of b enters the loop
-
-               // Check thet the control if it either ind < max or max > ind.
-               // TODO: Handle Leq64, Geq64.
-               switch b.Control.Op {
-               case OpLess64:
-                       entry = 0
-                       ind, max = b.Control.Args[0], b.Control.Args[1]
-               case OpGreater64:
-                       entry = 0
-                       ind, max = b.Control.Args[1], b.Control.Args[0]
+               var ind *Value   // induction variable
+               var init *Value  // starting value
+               var limit *Value // ending value
+
+               // Check that the control if it either ind </<= limit or limit </<= ind.
+               // TODO: Handle unsigned comparisons?
+               c := b.Controls[0]
+               inclusive := false
+               switch c.Op {
+               case OpLeq64, OpLeq32, OpLeq16, OpLeq8:
+                       inclusive = true
+                       fallthrough
+               case OpLess64, OpLess32, OpLess16, OpLess8:
+                       ind, limit = c.Args[0], c.Args[1]
                default:
-                       continue nextb
+                       continue
                }
 
-               // Check that the induction variable is a phi that depends on itself.
-               if ind.Op != OpPhi {
-                       continue
+               // See if this is really an induction variable
+               less := true
+               init, inc, nxt := parseIndVar(ind)
+               if init == nil {
+                       // We failed to parse the induction variable. Before punting, we want to check
+                       // whether the control op was written with the induction variable on the RHS
+                       // instead of the LHS. This happens for the downwards case, like:
+                       //     for i := len(n)-1; i >= 0; i--
+                       init, inc, nxt = parseIndVar(limit)
+                       if init == nil {
+                               // No recognized induction variable on either operand
+                               continue
+                       }
+
+                       // Ok, the arguments were reversed. Swap them, and remember that we're
+                       // looking at an ind >/>= loop (so the induction must be decrementing).
+                       ind, limit = limit, ind
+                       less = false
                }
 
-               // Extract min and nxt knowing that nxt is an addition (e.g. Add64).
-               var min, nxt *Value // minimum, and next value
-               if n := ind.Args[0]; n.Op == OpAdd64 && (n.Args[0] == ind || n.Args[1] == ind) {
-                       min, nxt = ind.Args[1], n
-               } else if n := ind.Args[1]; n.Op == OpAdd64 && (n.Args[0] == ind || n.Args[1] == ind) {
-                       min, nxt = ind.Args[0], n
-               } else {
-                       // Not a recognized induction variable.
+               if ind.Block != b {
+                       // TODO: Could be extended to include disjointed loop headers.
+                       // I don't think this is causing missed optimizations in real world code often.
+                       // See https://go.dev/issue/63955
                        continue
                }
 
-               var inc *Value
-               if nxt.Args[0] == ind { // nxt = ind + inc
-                       inc = nxt.Args[1]
-               } else if nxt.Args[1] == ind { // nxt = inc + ind
-                       inc = nxt.Args[0]
-               } else {
-                       panic("unreachable") // one of the cases must be true from the above.
+               // Expect the increment to be a nonzero constant.
+               if !inc.isGenericIntConst() {
+                       continue
+               }
+               step := inc.AuxInt
+               if step == 0 {
+                       continue
                }
 
-               // Expect the increment to be a positive constant.
-               // TODO: handle negative increment.
-               if inc.Op != OpConst64 || inc.AuxInt <= 0 {
+               // Increment sign must match comparison direction.
+               // When incrementing, the termination comparison must be ind </<= limit.
+               // When decrementing, the termination comparison must be ind >/>= limit.
+               // See issue 26116.
+               if step > 0 && !less {
+                       continue
+               }
+               if step < 0 && less {
                        continue
                }
 
                // Up to now we extracted the induction variable (ind),
                // the increment delta (inc), the temporary sum (nxt),
-               // the mininum value (min) and the maximum value (max).
+               // the initial value (init) and the limiting value (limit).
                //
-               // We also know that ind has the form (Phi min nxt) where
+               // We also know that ind has the form (Phi init nxt) where
                // nxt is (Add inc nxt) which means: 1) inc dominates nxt
                // and 2) there is a loop starting at inc and containing nxt.
                //
                // We need to prove that the induction variable is incremented
-               // only when it's smaller than the maximum value.
+               // only when it's smaller than the limiting value.
                // Two conditions must happen listed below to accept ind
                // as an induction variable.
 
                // First condition: loop entry has a single predecessor, which
-               // is the header block.  This implies that b.Succs[entry] is
-               // reached iff ind < max.
-               if len(b.Succs[entry].Preds) != 1 {
-                       // b.Succs[1-entry] must exit the loop.
+               // is the header block.  This implies that b.Succs[0] is
+               // reached iff ind < limit.
+               if len(b.Succs[0].b.Preds) != 1 {
+                       // b.Succs[1] must exit the loop.
                        continue
                }
 
-               // Second condition: b.Succs[entry] dominates nxt so that
-               // nxt is computed when inc < max, meaning nxt <= max.
-               if !f.sdom.isAncestorEq(b.Succs[entry], nxt.Block) {
+               // Second condition: b.Succs[0] dominates nxt so that
+               // nxt is computed when inc < limit.
+               if !sdom.IsAncestorEq(b.Succs[0].b, nxt.Block) {
                        // inc+ind can only be reached through the branch that enters the loop.
                        continue
                }
 
-               // If max is c + SliceLen with c <= 0 then we drop c.
-               // Makes sure c + SliceLen doesn't overflow when SliceLen == 0.
-               // TODO: save c as an offset from max.
-               if w, c := dropAdd64(max); (w.Op == OpStringLen || w.Op == OpSliceLen) && 0 >= c && -c >= 0 {
-                       max = w
-               }
-
-               // We can only guarantee that the loops runs within limits of induction variable
-               // if the increment is 1 or when the limits are constants.
-               if inc.AuxInt != 1 {
-                       ok := false
-                       if min.Op == OpConst64 && max.Op == OpConst64 {
-                               if max.AuxInt > min.AuxInt && max.AuxInt%inc.AuxInt == min.AuxInt%inc.AuxInt { // handle overflow
-                                       ok = true
+               // Check for overflow/underflow. We need to make sure that inc never causes
+               // the induction variable to wrap around.
+               // We use a function wrapper here for easy return true / return false / keep going logic.
+               // This function returns true if the increment will never overflow/underflow.
+               ok := func() bool {
+                       if step > 0 {
+                               if limit.isGenericIntConst() {
+                                       // Figure out the actual largest value.
+                                       v := limit.AuxInt
+                                       if !inclusive {
+                                               if v == minSignedValue(limit.Type) {
+                                                       return false // < minint is never satisfiable.
+                                               }
+                                               v--
+                                       }
+                                       if init.isGenericIntConst() {
+                                               // Use stride to compute a better lower limit.
+                                               if init.AuxInt > v {
+                                                       return false
+                                               }
+                                               v = addU(init.AuxInt, diff(v, init.AuxInt)/uint64(step)*uint64(step))
+                                       }
+                                       if addWillOverflow(v, step) {
+                                               return false
+                                       }
+                                       if inclusive && v != limit.AuxInt || !inclusive && v+1 != limit.AuxInt {
+                                               // We know a better limit than the programmer did. Use our limit instead.
+                                               limit = f.constVal(limit.Op, limit.Type, v, true)
+                                               inclusive = true
+                                       }
+                                       return true
+                               }
+                               if step == 1 && !inclusive {
+                                       // Can't overflow because maxint is never a possible value.
+                                       return true
+                               }
+                               // If the limit is not a constant, check to see if it is a
+                               // negative offset from a known non-negative value.
+                               knn, k := findKNN(limit)
+                               if knn == nil || k < 0 {
+                                       return false
+                               }
+                               // limit == (something nonnegative) - k. That subtraction can't underflow, so
+                               // we can trust it.
+                               if inclusive {
+                                       // ind <= knn - k cannot overflow if step is at most k
+                                       return step <= k
+                               }
+                               // ind < knn - k cannot overflow if step is at most k+1
+                               return step <= k+1 && k != maxSignedValue(limit.Type)
+                       } else { // step < 0
+                               if limit.Op == OpConst64 {
+                                       // Figure out the actual smallest value.
+                                       v := limit.AuxInt
+                                       if !inclusive {
+                                               if v == maxSignedValue(limit.Type) {
+                                                       return false // > maxint is never satisfiable.
+                                               }
+                                               v++
+                                       }
+                                       if init.isGenericIntConst() {
+                                               // Use stride to compute a better lower limit.
+                                               if init.AuxInt < v {
+                                                       return false
+                                               }
+                                               v = subU(init.AuxInt, diff(init.AuxInt, v)/uint64(-step)*uint64(-step))
+                                       }
+                                       if subWillUnderflow(v, -step) {
+                                               return false
+                                       }
+                                       if inclusive && v != limit.AuxInt || !inclusive && v-1 != limit.AuxInt {
+                                               // We know a better limit than the programmer did. Use our limit instead.
+                                               limit = f.constVal(limit.Op, limit.Type, v, true)
+                                               inclusive = true
+                                       }
+                                       return true
+                               }
+                               if step == -1 && !inclusive {
+                                       // Can't underflow because minint is never a possible value.
+                                       return true
                                }
                        }
-                       if !ok {
-                               continue
-                       }
+                       return false
+
                }
 
-               if f.pass.debug > 1 {
-                       if min.Op == OpConst64 {
-                               b.Func.Config.Warnl(b.Line, "Induction variable with minimum %d and increment %d", min.AuxInt, inc.AuxInt)
+               if ok() {
+                       flags := indVarFlags(0)
+                       var min, max *Value
+                       if step > 0 {
+                               min = init
+                               max = limit
+                               if inclusive {
+                                       flags |= indVarMaxInc
+                               }
                        } else {
-                               b.Func.Config.Warnl(b.Line, "Induction variable with non-const minimum and increment %d", inc.AuxInt)
+                               min = limit
+                               max = init
+                               flags |= indVarMaxInc
+                               if !inclusive {
+                                       flags |= indVarMinExc
+                               }
+                               flags |= indVarCountDown
+                               step = -step
                        }
+                       if f.pass.debug >= 1 {
+                               printIndVar(b, ind, min, max, step, flags)
+                       }
+
+                       iv = append(iv, indVar{
+                               ind:   ind,
+                               nxt:   nxt,
+                               min:   min,
+                               max:   max,
+                               entry: b.Succs[0].b,
+                               flags: flags,
+                       })
+                       b.Logf("found induction variable %v (inc = %v, min = %v, max = %v)\n", ind, inc, min, max)
                }
 
-               iv = append(iv, indVar{
-                       ind:   ind,
-                       inc:   inc,
-                       nxt:   nxt,
-                       min:   min,
-                       max:   max,
-                       entry: b.Succs[entry],
-               })
-               b.Logf("found induction variable %v (inc = %v, min = %v, max = %v)\n", ind, inc, min, max)
+               // TODO: other unrolling idioms
+               // for i := 0; i < KNN - KNN % k ; i += k
+               // for i := 0; i < KNN&^(k-1) ; i += k // k a power of 2
+               // for i := 0; i < KNN&(-k) ; i += k // k a power of 2
        }
 
        return iv
 }
 
-// loopbce performs loop based bounds check elimination.
-func loopbce(f *Func) {
-       ivList := findIndVar(f)
-
-       m := make(map[*Value]indVar)
-       for _, iv := range ivList {
-               m[iv.ind] = iv
-       }
-
-       removeBoundsChecks(f, m)
+// addWillOverflow reports whether x+y would result in a value more than maxint.
+func addWillOverflow(x, y int64) bool {
+       return x+y < x
 }
 
-// removesBoundsChecks remove IsInBounds and IsSliceInBounds based on the induction variables.
-func removeBoundsChecks(f *Func, m map[*Value]indVar) {
-       for _, b := range f.Blocks {
-               if b.Kind != BlockIf {
-                       continue
-               }
+// subWillUnderflow reports whether x-y would result in a value less than minint.
+func subWillUnderflow(x, y int64) bool {
+       return x-y > x
+}
 
-               v := b.Control
-
-               // Simplify:
-               // (IsInBounds ind max) where 0 <= const == min <= ind < max.
-               // (IsSliceInBounds ind max) where 0 <= const == min <= ind < max.
-               // Found in:
-               //      for i := range a {
-               //              use a[i]
-               //              use a[i:]
-               //              use a[:i]
-               //      }
-               if v.Op == OpIsInBounds || v.Op == OpIsSliceInBounds {
-                       ind, add := dropAdd64(v.Args[0])
-                       if ind.Op != OpPhi {
-                               goto skip1
-                       }
-                       if v.Op == OpIsInBounds && add != 0 {
-                               goto skip1
-                       }
-                       if v.Op == OpIsSliceInBounds && (0 > add || add > 1) {
-                               goto skip1
-                       }
+// diff returns x-y as a uint64. Requires x>=y.
+func diff(x, y int64) uint64 {
+       if x < y {
+               base.Fatalf("diff %d - %d underflowed", x, y)
+       }
+       return uint64(x - y)
+}
 
-                       if iv, has := m[ind]; has && f.sdom.isAncestorEq(iv.entry, b) && isNonNegative(iv.min) {
-                               if v.Args[1] == iv.max {
-                                       if f.pass.debug > 0 {
-                                               f.Config.Warnl(b.Line, "Found redundant %s", v.Op)
-                                       }
-                                       goto simplify
-                               }
-                       }
+// addU returns x+y. Requires that x+y does not overflow an int64.
+func addU(x int64, y uint64) int64 {
+       if y >= 1<<63 {
+               if x >= 0 {
+                       base.Fatalf("addU overflowed %d + %d", x, y)
                }
-       skip1:
-
-               // Simplify:
-               // (IsSliceInBounds ind (SliceCap a)) where 0 <= min <= ind < max == (SliceLen a)
-               // Found in:
-               //      for i := range a {
-               //              use a[:i]
-               //              use a[:i+1]
-               //      }
-               if v.Op == OpIsSliceInBounds {
-                       ind, add := dropAdd64(v.Args[0])
-                       if ind.Op != OpPhi {
-                               goto skip2
-                       }
-                       if 0 > add || add > 1 {
-                               goto skip2
-                       }
+               x += 1<<63 - 1
+               x += 1
+               y -= 1 << 63
+       }
+       if addWillOverflow(x, int64(y)) {
+               base.Fatalf("addU overflowed %d + %d", x, y)
+       }
+       return x + int64(y)
+}
 
-                       if iv, has := m[ind]; has && f.sdom.isAncestorEq(iv.entry, b) && isNonNegative(iv.min) {
-                               if v.Args[1].Op == OpSliceCap && iv.max.Op == OpSliceLen && v.Args[1].Args[0] == iv.max.Args[0] {
-                                       if f.pass.debug > 0 {
-                                               f.Config.Warnl(b.Line, "Found redundant %s (len promoted to cap)", v.Op)
-                                       }
-                                       goto simplify
-                               }
-                       }
+// subU returns x-y. Requires that x-y does not underflow an int64.
+func subU(x int64, y uint64) int64 {
+       if y >= 1<<63 {
+               if x < 0 {
+                       base.Fatalf("subU underflowed %d - %d", x, y)
                }
-       skip2:
-
-               // Simplify
-               // (IsInBounds (Add64 ind) (Const64 [c])) where 0 <= min <= ind < max <= (Const64 [c])
-               // (IsSliceInBounds ind (Const64 [c])) where 0 <= min <= ind < max <= (Const64 [c])
-               if v.Op == OpIsInBounds || v.Op == OpIsSliceInBounds {
-                       ind, add := dropAdd64(v.Args[0])
-                       if ind.Op != OpPhi {
-                               goto skip3
-                       }
-
-                       // ind + add >= 0 <-> min + add >= 0 <-> min >= -add
-                       if iv, has := m[ind]; has && f.sdom.isAncestorEq(iv.entry, b) && isGreaterOrEqualThan(iv.min, -add) {
-                               if !v.Args[1].isGenericIntConst() || !iv.max.isGenericIntConst() {
-                                       goto skip3
-                               }
-
-                               limit := v.Args[1].AuxInt
-                               if v.Op == OpIsSliceInBounds {
-                                       // If limit++ overflows signed integer then 0 <= max && max <= limit will be false.
-                                       limit++
-                               }
+               x -= 1<<63 - 1
+               x -= 1
+               y -= 1 << 63
+       }
+       if subWillUnderflow(x, int64(y)) {
+               base.Fatalf("subU underflowed %d - %d", x, y)
+       }
+       return x - int64(y)
+}
 
-                               if max := iv.max.AuxInt + add; 0 <= max && max <= limit { // handle overflow
-                                       if f.pass.debug > 0 {
-                                               f.Config.Warnl(b.Line, "Found redundant (%s ind %d), ind < %d", v.Op, v.Args[1].AuxInt, iv.max.AuxInt+add)
-                                       }
-                                       goto simplify
-                               }
-                       }
+// if v is known to be x - c, where x is known to be nonnegative and c is a
+// constant, return x, c. Otherwise return nil, 0.
+func findKNN(v *Value) (*Value, int64) {
+       var x, y *Value
+       x = v
+       switch v.Op {
+       case OpSub64, OpSub32, OpSub16, OpSub8:
+               x = v.Args[0]
+               y = v.Args[1]
+
+       case OpAdd64, OpAdd32, OpAdd16, OpAdd8:
+               x = v.Args[0]
+               y = v.Args[1]
+               if x.isGenericIntConst() {
+                       x, y = y, x
                }
-       skip3:
-
-               continue
-
-       simplify:
-               f.Logf("removing bounds check %v at %v in %s\n", b.Control, b, f.Name)
-               b.Kind = BlockFirst
-               b.SetControl(nil)
        }
+       switch x.Op {
+       case OpSliceLen, OpStringLen, OpSliceCap:
+       default:
+               return nil, 0
+       }
+       if y == nil {
+               return x, 0
+       }
+       if !y.isGenericIntConst() {
+               return nil, 0
+       }
+       if v.Op == OpAdd64 || v.Op == OpAdd32 || v.Op == OpAdd16 || v.Op == OpAdd8 {
+               return x, -y.AuxInt
+       }
+       return x, y.AuxInt
 }
 
-func dropAdd64(v *Value) (*Value, int64) {
-       if v.Op == OpAdd64 && v.Args[0].Op == OpConst64 {
-               return v.Args[1], v.Args[0].AuxInt
+func printIndVar(b *Block, i, min, max *Value, inc int64, flags indVarFlags) {
+       mb1, mb2 := "[", "]"
+       if flags&indVarMinExc != 0 {
+               mb1 = "("
        }
-       if v.Op == OpAdd64 && v.Args[1].Op == OpConst64 {
-               return v.Args[0], v.Args[1].AuxInt
+       if flags&indVarMaxInc == 0 {
+               mb2 = ")"
        }
-       return v, 0
-}
 
-func isGreaterOrEqualThan(v *Value, c int64) bool {
-       if c == 0 {
-               return isNonNegative(v)
+       mlim1, mlim2 := fmt.Sprint(min.AuxInt), fmt.Sprint(max.AuxInt)
+       if !min.isGenericIntConst() {
+               if b.Func.pass.debug >= 2 {
+                       mlim1 = fmt.Sprint(min)
+               } else {
+                       mlim1 = "?"
+               }
+       }
+       if !max.isGenericIntConst() {
+               if b.Func.pass.debug >= 2 {
+                       mlim2 = fmt.Sprint(max)
+               } else {
+                       mlim2 = "?"
+               }
        }
-       if v.isGenericIntConst() && v.AuxInt >= c {
-               return true
+       extra := ""
+       if b.Func.pass.debug >= 2 {
+               extra = fmt.Sprintf(" (%s)", i)
        }
-       return false
+       b.Func.Warnl(b.Pos, "Induction variable: limits %v%v,%v%v, increment %d%s", mb1, mlim1, mlim2, mb2, inc, extra)
+}
+
+func minSignedValue(t *types.Type) int64 {
+       return -1 << (t.Size()*8 - 1)
+}
+
+func maxSignedValue(t *types.Type) int64 {
+       return 1<<((t.Size()*8)-1) - 1
 }