]> Cypherpunks.ru repositories - gostls13.git/blob - src/cmd/compile/internal/ssa/loopbce.go
cmd/compile: fix findIndVar so it does not match disjointed loop headers
[gostls13.git] / src / cmd / compile / internal / ssa / loopbce.go
1 // Copyright 2018 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package ssa
6
7 import (
8         "cmd/compile/internal/base"
9         "cmd/compile/internal/types"
10         "fmt"
11 )
12
13 type indVarFlags uint8
14
15 const (
16         indVarMinExc    indVarFlags = 1 << iota // minimum value is exclusive (default: inclusive)
17         indVarMaxInc                            // maximum value is inclusive (default: exclusive)
18         indVarCountDown                         // if set the iteration starts at max and count towards min (default: min towards max)
19 )
20
21 type indVar struct {
22         ind   *Value // induction variable
23         nxt   *Value // the incremented variable
24         min   *Value // minimum value, inclusive/exclusive depends on flags
25         max   *Value // maximum value, inclusive/exclusive depends on flags
26         entry *Block // entry block in the loop.
27         flags indVarFlags
28         // Invariant: for all blocks strictly dominated by entry:
29         //      min <= ind <  max    [if flags == 0]
30         //      min <  ind <  max    [if flags == indVarMinExc]
31         //      min <= ind <= max    [if flags == indVarMaxInc]
32         //      min <  ind <= max    [if flags == indVarMinExc|indVarMaxInc]
33 }
34
35 // parseIndVar checks whether the SSA value passed as argument is a valid induction
36 // variable, and, if so, extracts:
37 //   - the minimum bound
38 //   - the increment value
39 //   - the "next" value (SSA value that is Phi'd into the induction variable every loop)
40 //
41 // Currently, we detect induction variables that match (Phi min nxt),
42 // with nxt being (Add inc ind).
43 // If it can't parse the induction variable correctly, it returns (nil, nil, nil).
44 func parseIndVar(ind *Value) (min, inc, nxt *Value) {
45         if ind.Op != OpPhi {
46                 return
47         }
48
49         if n := ind.Args[0]; (n.Op == OpAdd64 || n.Op == OpAdd32 || n.Op == OpAdd16 || n.Op == OpAdd8) && (n.Args[0] == ind || n.Args[1] == ind) {
50                 min, nxt = ind.Args[1], n
51         } else if n := ind.Args[1]; (n.Op == OpAdd64 || n.Op == OpAdd32 || n.Op == OpAdd16 || n.Op == OpAdd8) && (n.Args[0] == ind || n.Args[1] == ind) {
52                 min, nxt = ind.Args[0], n
53         } else {
54                 // Not a recognized induction variable.
55                 return
56         }
57
58         if nxt.Args[0] == ind { // nxt = ind + inc
59                 inc = nxt.Args[1]
60         } else if nxt.Args[1] == ind { // nxt = inc + ind
61                 inc = nxt.Args[0]
62         } else {
63                 panic("unreachable") // one of the cases must be true from the above.
64         }
65
66         return
67 }
68
69 // findIndVar finds induction variables in a function.
70 //
71 // Look for variables and blocks that satisfy the following
72 //
73 //       loop:
74 //         ind = (Phi min nxt),
75 //         if ind < max
76 //           then goto enter_loop
77 //           else goto exit_loop
78 //
79 //         enter_loop:
80 //              do something
81 //            nxt = inc + ind
82 //              goto loop
83 //
84 //       exit_loop:
85 func findIndVar(f *Func) []indVar {
86         var iv []indVar
87         sdom := f.Sdom()
88
89         for _, b := range f.Blocks {
90                 if b.Kind != BlockIf || len(b.Preds) != 2 {
91                         continue
92                 }
93
94                 var ind *Value   // induction variable
95                 var init *Value  // starting value
96                 var limit *Value // ending value
97
98                 // Check that the control if it either ind </<= limit or limit </<= ind.
99                 // TODO: Handle unsigned comparisons?
100                 c := b.Controls[0]
101                 inclusive := false
102                 switch c.Op {
103                 case OpLeq64, OpLeq32, OpLeq16, OpLeq8:
104                         inclusive = true
105                         fallthrough
106                 case OpLess64, OpLess32, OpLess16, OpLess8:
107                         ind, limit = c.Args[0], c.Args[1]
108                 default:
109                         continue
110                 }
111
112                 // See if this is really an induction variable
113                 less := true
114                 init, inc, nxt := parseIndVar(ind)
115                 if init == nil {
116                         // We failed to parse the induction variable. Before punting, we want to check
117                         // whether the control op was written with the induction variable on the RHS
118                         // instead of the LHS. This happens for the downwards case, like:
119                         //     for i := len(n)-1; i >= 0; i--
120                         init, inc, nxt = parseIndVar(limit)
121                         if init == nil {
122                                 // No recognized induction variable on either operand
123                                 continue
124                         }
125
126                         // Ok, the arguments were reversed. Swap them, and remember that we're
127                         // looking at an ind >/>= loop (so the induction must be decrementing).
128                         ind, limit = limit, ind
129                         less = false
130                 }
131
132                 if ind.Block != b {
133                         // TODO: Could be extended to include disjointed loop headers.
134                         // I don't think this is causing missed optimizations in real world code often.
135                         // See https://go.dev/issue/63955
136                         continue
137                 }
138
139                 // Expect the increment to be a nonzero constant.
140                 if !inc.isGenericIntConst() {
141                         continue
142                 }
143                 step := inc.AuxInt
144                 if step == 0 {
145                         continue
146                 }
147
148                 // Increment sign must match comparison direction.
149                 // When incrementing, the termination comparison must be ind </<= limit.
150                 // When decrementing, the termination comparison must be ind >/>= limit.
151                 // See issue 26116.
152                 if step > 0 && !less {
153                         continue
154                 }
155                 if step < 0 && less {
156                         continue
157                 }
158
159                 // Up to now we extracted the induction variable (ind),
160                 // the increment delta (inc), the temporary sum (nxt),
161                 // the initial value (init) and the limiting value (limit).
162                 //
163                 // We also know that ind has the form (Phi init nxt) where
164                 // nxt is (Add inc nxt) which means: 1) inc dominates nxt
165                 // and 2) there is a loop starting at inc and containing nxt.
166                 //
167                 // We need to prove that the induction variable is incremented
168                 // only when it's smaller than the limiting value.
169                 // Two conditions must happen listed below to accept ind
170                 // as an induction variable.
171
172                 // First condition: loop entry has a single predecessor, which
173                 // is the header block.  This implies that b.Succs[0] is
174                 // reached iff ind < limit.
175                 if len(b.Succs[0].b.Preds) != 1 {
176                         // b.Succs[1] must exit the loop.
177                         continue
178                 }
179
180                 // Second condition: b.Succs[0] dominates nxt so that
181                 // nxt is computed when inc < limit.
182                 if !sdom.IsAncestorEq(b.Succs[0].b, nxt.Block) {
183                         // inc+ind can only be reached through the branch that enters the loop.
184                         continue
185                 }
186
187                 // Check for overflow/underflow. We need to make sure that inc never causes
188                 // the induction variable to wrap around.
189                 // We use a function wrapper here for easy return true / return false / keep going logic.
190                 // This function returns true if the increment will never overflow/underflow.
191                 ok := func() bool {
192                         if step > 0 {
193                                 if limit.isGenericIntConst() {
194                                         // Figure out the actual largest value.
195                                         v := limit.AuxInt
196                                         if !inclusive {
197                                                 if v == minSignedValue(limit.Type) {
198                                                         return false // < minint is never satisfiable.
199                                                 }
200                                                 v--
201                                         }
202                                         if init.isGenericIntConst() {
203                                                 // Use stride to compute a better lower limit.
204                                                 if init.AuxInt > v {
205                                                         return false
206                                                 }
207                                                 v = addU(init.AuxInt, diff(v, init.AuxInt)/uint64(step)*uint64(step))
208                                         }
209                                         if addWillOverflow(v, step) {
210                                                 return false
211                                         }
212                                         if inclusive && v != limit.AuxInt || !inclusive && v+1 != limit.AuxInt {
213                                                 // We know a better limit than the programmer did. Use our limit instead.
214                                                 limit = f.constVal(limit.Op, limit.Type, v, true)
215                                                 inclusive = true
216                                         }
217                                         return true
218                                 }
219                                 if step == 1 && !inclusive {
220                                         // Can't overflow because maxint is never a possible value.
221                                         return true
222                                 }
223                                 // If the limit is not a constant, check to see if it is a
224                                 // negative offset from a known non-negative value.
225                                 knn, k := findKNN(limit)
226                                 if knn == nil || k < 0 {
227                                         return false
228                                 }
229                                 // limit == (something nonnegative) - k. That subtraction can't underflow, so
230                                 // we can trust it.
231                                 if inclusive {
232                                         // ind <= knn - k cannot overflow if step is at most k
233                                         return step <= k
234                                 }
235                                 // ind < knn - k cannot overflow if step is at most k+1
236                                 return step <= k+1 && k != maxSignedValue(limit.Type)
237                         } else { // step < 0
238                                 if limit.Op == OpConst64 {
239                                         // Figure out the actual smallest value.
240                                         v := limit.AuxInt
241                                         if !inclusive {
242                                                 if v == maxSignedValue(limit.Type) {
243                                                         return false // > maxint is never satisfiable.
244                                                 }
245                                                 v++
246                                         }
247                                         if init.isGenericIntConst() {
248                                                 // Use stride to compute a better lower limit.
249                                                 if init.AuxInt < v {
250                                                         return false
251                                                 }
252                                                 v = subU(init.AuxInt, diff(init.AuxInt, v)/uint64(-step)*uint64(-step))
253                                         }
254                                         if subWillUnderflow(v, -step) {
255                                                 return false
256                                         }
257                                         if inclusive && v != limit.AuxInt || !inclusive && v-1 != limit.AuxInt {
258                                                 // We know a better limit than the programmer did. Use our limit instead.
259                                                 limit = f.constVal(limit.Op, limit.Type, v, true)
260                                                 inclusive = true
261                                         }
262                                         return true
263                                 }
264                                 if step == -1 && !inclusive {
265                                         // Can't underflow because minint is never a possible value.
266                                         return true
267                                 }
268                         }
269                         return false
270
271                 }
272
273                 if ok() {
274                         flags := indVarFlags(0)
275                         var min, max *Value
276                         if step > 0 {
277                                 min = init
278                                 max = limit
279                                 if inclusive {
280                                         flags |= indVarMaxInc
281                                 }
282                         } else {
283                                 min = limit
284                                 max = init
285                                 flags |= indVarMaxInc
286                                 if !inclusive {
287                                         flags |= indVarMinExc
288                                 }
289                                 flags |= indVarCountDown
290                                 step = -step
291                         }
292                         if f.pass.debug >= 1 {
293                                 printIndVar(b, ind, min, max, step, flags)
294                         }
295
296                         iv = append(iv, indVar{
297                                 ind:   ind,
298                                 nxt:   nxt,
299                                 min:   min,
300                                 max:   max,
301                                 entry: b.Succs[0].b,
302                                 flags: flags,
303                         })
304                         b.Logf("found induction variable %v (inc = %v, min = %v, max = %v)\n", ind, inc, min, max)
305                 }
306
307                 // TODO: other unrolling idioms
308                 // for i := 0; i < KNN - KNN % k ; i += k
309                 // for i := 0; i < KNN&^(k-1) ; i += k // k a power of 2
310                 // for i := 0; i < KNN&(-k) ; i += k // k a power of 2
311         }
312
313         return iv
314 }
315
316 // addWillOverflow reports whether x+y would result in a value more than maxint.
317 func addWillOverflow(x, y int64) bool {
318         return x+y < x
319 }
320
321 // subWillUnderflow reports whether x-y would result in a value less than minint.
322 func subWillUnderflow(x, y int64) bool {
323         return x-y > x
324 }
325
326 // diff returns x-y as a uint64. Requires x>=y.
327 func diff(x, y int64) uint64 {
328         if x < y {
329                 base.Fatalf("diff %d - %d underflowed", x, y)
330         }
331         return uint64(x - y)
332 }
333
334 // addU returns x+y. Requires that x+y does not overflow an int64.
335 func addU(x int64, y uint64) int64 {
336         if y >= 1<<63 {
337                 if x >= 0 {
338                         base.Fatalf("addU overflowed %d + %d", x, y)
339                 }
340                 x += 1<<63 - 1
341                 x += 1
342                 y -= 1 << 63
343         }
344         if addWillOverflow(x, int64(y)) {
345                 base.Fatalf("addU overflowed %d + %d", x, y)
346         }
347         return x + int64(y)
348 }
349
350 // subU returns x-y. Requires that x-y does not underflow an int64.
351 func subU(x int64, y uint64) int64 {
352         if y >= 1<<63 {
353                 if x < 0 {
354                         base.Fatalf("subU underflowed %d - %d", x, y)
355                 }
356                 x -= 1<<63 - 1
357                 x -= 1
358                 y -= 1 << 63
359         }
360         if subWillUnderflow(x, int64(y)) {
361                 base.Fatalf("subU underflowed %d - %d", x, y)
362         }
363         return x - int64(y)
364 }
365
366 // if v is known to be x - c, where x is known to be nonnegative and c is a
367 // constant, return x, c. Otherwise return nil, 0.
368 func findKNN(v *Value) (*Value, int64) {
369         var x, y *Value
370         x = v
371         switch v.Op {
372         case OpSub64, OpSub32, OpSub16, OpSub8:
373                 x = v.Args[0]
374                 y = v.Args[1]
375
376         case OpAdd64, OpAdd32, OpAdd16, OpAdd8:
377                 x = v.Args[0]
378                 y = v.Args[1]
379                 if x.isGenericIntConst() {
380                         x, y = y, x
381                 }
382         }
383         switch x.Op {
384         case OpSliceLen, OpStringLen, OpSliceCap:
385         default:
386                 return nil, 0
387         }
388         if y == nil {
389                 return x, 0
390         }
391         if !y.isGenericIntConst() {
392                 return nil, 0
393         }
394         if v.Op == OpAdd64 || v.Op == OpAdd32 || v.Op == OpAdd16 || v.Op == OpAdd8 {
395                 return x, -y.AuxInt
396         }
397         return x, y.AuxInt
398 }
399
400 func printIndVar(b *Block, i, min, max *Value, inc int64, flags indVarFlags) {
401         mb1, mb2 := "[", "]"
402         if flags&indVarMinExc != 0 {
403                 mb1 = "("
404         }
405         if flags&indVarMaxInc == 0 {
406                 mb2 = ")"
407         }
408
409         mlim1, mlim2 := fmt.Sprint(min.AuxInt), fmt.Sprint(max.AuxInt)
410         if !min.isGenericIntConst() {
411                 if b.Func.pass.debug >= 2 {
412                         mlim1 = fmt.Sprint(min)
413                 } else {
414                         mlim1 = "?"
415                 }
416         }
417         if !max.isGenericIntConst() {
418                 if b.Func.pass.debug >= 2 {
419                         mlim2 = fmt.Sprint(max)
420                 } else {
421                         mlim2 = "?"
422                 }
423         }
424         extra := ""
425         if b.Func.pass.debug >= 2 {
426                 extra = fmt.Sprintf(" (%s)", i)
427         }
428         b.Func.Warnl(b.Pos, "Induction variable: limits %v%v,%v%v, increment %d%s", mb1, mlim1, mlim2, mb2, inc, extra)
429 }
430
431 func minSignedValue(t *types.Type) int64 {
432         return -1 << (t.Size()*8 - 1)
433 }
434
435 func maxSignedValue(t *types.Type) int64 {
436         return 1<<((t.Size()*8)-1) - 1
437 }