]> Cypherpunks.ru repositories - gostls13.git/blob - src/cmd/compile/internal/ssa/loopbce.go
cmd/compile: try to rewrite loops to count down
[gostls13.git] / src / cmd / compile / internal / ssa / loopbce.go
1 // Copyright 2018 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package ssa
6
7 import (
8         "cmd/compile/internal/base"
9         "cmd/compile/internal/types"
10         "fmt"
11 )
12
13 type indVarFlags uint8
14
15 const (
16         indVarMinExc    indVarFlags = 1 << iota // minimum value is exclusive (default: inclusive)
17         indVarMaxInc                            // maximum value is inclusive (default: exclusive)
18         indVarCountDown                         // if set the iteration starts at max and count towards min (default: min towards max)
19 )
20
21 type indVar struct {
22         ind   *Value // induction variable
23         nxt   *Value // the incremented variable
24         min   *Value // minimum value, inclusive/exclusive depends on flags
25         max   *Value // maximum value, inclusive/exclusive depends on flags
26         entry *Block // entry block in the loop.
27         flags indVarFlags
28         // Invariant: for all blocks strictly dominated by entry:
29         //      min <= ind <  max    [if flags == 0]
30         //      min <  ind <  max    [if flags == indVarMinExc]
31         //      min <= ind <= max    [if flags == indVarMaxInc]
32         //      min <  ind <= max    [if flags == indVarMinExc|indVarMaxInc]
33 }
34
35 // parseIndVar checks whether the SSA value passed as argument is a valid induction
36 // variable, and, if so, extracts:
37 //   - the minimum bound
38 //   - the increment value
39 //   - the "next" value (SSA value that is Phi'd into the induction variable every loop)
40 //
41 // Currently, we detect induction variables that match (Phi min nxt),
42 // with nxt being (Add inc ind).
43 // If it can't parse the induction variable correctly, it returns (nil, nil, nil).
44 func parseIndVar(ind *Value) (min, inc, nxt *Value) {
45         if ind.Op != OpPhi {
46                 return
47         }
48
49         if n := ind.Args[0]; (n.Op == OpAdd64 || n.Op == OpAdd32 || n.Op == OpAdd16 || n.Op == OpAdd8) && (n.Args[0] == ind || n.Args[1] == ind) {
50                 min, nxt = ind.Args[1], n
51         } else if n := ind.Args[1]; (n.Op == OpAdd64 || n.Op == OpAdd32 || n.Op == OpAdd16 || n.Op == OpAdd8) && (n.Args[0] == ind || n.Args[1] == ind) {
52                 min, nxt = ind.Args[0], n
53         } else {
54                 // Not a recognized induction variable.
55                 return
56         }
57
58         if nxt.Args[0] == ind { // nxt = ind + inc
59                 inc = nxt.Args[1]
60         } else if nxt.Args[1] == ind { // nxt = inc + ind
61                 inc = nxt.Args[0]
62         } else {
63                 panic("unreachable") // one of the cases must be true from the above.
64         }
65
66         return
67 }
68
69 // findIndVar finds induction variables in a function.
70 //
71 // Look for variables and blocks that satisfy the following
72 //
73 //       loop:
74 //         ind = (Phi min nxt),
75 //         if ind < max
76 //           then goto enter_loop
77 //           else goto exit_loop
78 //
79 //         enter_loop:
80 //              do something
81 //            nxt = inc + ind
82 //              goto loop
83 //
84 //       exit_loop:
85 func findIndVar(f *Func) []indVar {
86         var iv []indVar
87         sdom := f.Sdom()
88
89         for _, b := range f.Blocks {
90                 if b.Kind != BlockIf || len(b.Preds) != 2 {
91                         continue
92                 }
93
94                 var ind *Value   // induction variable
95                 var init *Value  // starting value
96                 var limit *Value // ending value
97
98                 // Check that the control if it either ind </<= limit or limit </<= ind.
99                 // TODO: Handle unsigned comparisons?
100                 c := b.Controls[0]
101                 inclusive := false
102                 switch c.Op {
103                 case OpLeq64, OpLeq32, OpLeq16, OpLeq8:
104                         inclusive = true
105                         fallthrough
106                 case OpLess64, OpLess32, OpLess16, OpLess8:
107                         ind, limit = c.Args[0], c.Args[1]
108                 default:
109                         continue
110                 }
111
112                 // See if this is really an induction variable
113                 less := true
114                 init, inc, nxt := parseIndVar(ind)
115                 if init == nil {
116                         // We failed to parse the induction variable. Before punting, we want to check
117                         // whether the control op was written with the induction variable on the RHS
118                         // instead of the LHS. This happens for the downwards case, like:
119                         //     for i := len(n)-1; i >= 0; i--
120                         init, inc, nxt = parseIndVar(limit)
121                         if init == nil {
122                                 // No recognized induction variable on either operand
123                                 continue
124                         }
125
126                         // Ok, the arguments were reversed. Swap them, and remember that we're
127                         // looking at an ind >/>= loop (so the induction must be decrementing).
128                         ind, limit = limit, ind
129                         less = false
130                 }
131
132                 // Expect the increment to be a nonzero constant.
133                 if !inc.isGenericIntConst() {
134                         continue
135                 }
136                 step := inc.AuxInt
137                 if step == 0 {
138                         continue
139                 }
140
141                 // Increment sign must match comparison direction.
142                 // When incrementing, the termination comparison must be ind </<= limit.
143                 // When decrementing, the termination comparison must be ind >/>= limit.
144                 // See issue 26116.
145                 if step > 0 && !less {
146                         continue
147                 }
148                 if step < 0 && less {
149                         continue
150                 }
151
152                 // Up to now we extracted the induction variable (ind),
153                 // the increment delta (inc), the temporary sum (nxt),
154                 // the initial value (init) and the limiting value (limit).
155                 //
156                 // We also know that ind has the form (Phi init nxt) where
157                 // nxt is (Add inc nxt) which means: 1) inc dominates nxt
158                 // and 2) there is a loop starting at inc and containing nxt.
159                 //
160                 // We need to prove that the induction variable is incremented
161                 // only when it's smaller than the limiting value.
162                 // Two conditions must happen listed below to accept ind
163                 // as an induction variable.
164
165                 // First condition: loop entry has a single predecessor, which
166                 // is the header block.  This implies that b.Succs[0] is
167                 // reached iff ind < limit.
168                 if len(b.Succs[0].b.Preds) != 1 {
169                         // b.Succs[1] must exit the loop.
170                         continue
171                 }
172
173                 // Second condition: b.Succs[0] dominates nxt so that
174                 // nxt is computed when inc < limit.
175                 if !sdom.IsAncestorEq(b.Succs[0].b, nxt.Block) {
176                         // inc+ind can only be reached through the branch that enters the loop.
177                         continue
178                 }
179
180                 // Check for overflow/underflow. We need to make sure that inc never causes
181                 // the induction variable to wrap around.
182                 // We use a function wrapper here for easy return true / return false / keep going logic.
183                 // This function returns true if the increment will never overflow/underflow.
184                 ok := func() bool {
185                         if step > 0 {
186                                 if limit.isGenericIntConst() {
187                                         // Figure out the actual largest value.
188                                         v := limit.AuxInt
189                                         if !inclusive {
190                                                 if v == minSignedValue(limit.Type) {
191                                                         return false // < minint is never satisfiable.
192                                                 }
193                                                 v--
194                                         }
195                                         if init.isGenericIntConst() {
196                                                 // Use stride to compute a better lower limit.
197                                                 if init.AuxInt > v {
198                                                         return false
199                                                 }
200                                                 v = addU(init.AuxInt, diff(v, init.AuxInt)/uint64(step)*uint64(step))
201                                         }
202                                         if addWillOverflow(v, step) {
203                                                 return false
204                                         }
205                                         if inclusive && v != limit.AuxInt || !inclusive && v+1 != limit.AuxInt {
206                                                 // We know a better limit than the programmer did. Use our limit instead.
207                                                 limit = f.constVal(limit.Op, limit.Type, v, true)
208                                                 inclusive = true
209                                         }
210                                         return true
211                                 }
212                                 if step == 1 && !inclusive {
213                                         // Can't overflow because maxint is never a possible value.
214                                         return true
215                                 }
216                                 // If the limit is not a constant, check to see if it is a
217                                 // negative offset from a known non-negative value.
218                                 knn, k := findKNN(limit)
219                                 if knn == nil || k < 0 {
220                                         return false
221                                 }
222                                 // limit == (something nonnegative) - k. That subtraction can't underflow, so
223                                 // we can trust it.
224                                 if inclusive {
225                                         // ind <= knn - k cannot overflow if step is at most k
226                                         return step <= k
227                                 }
228                                 // ind < knn - k cannot overflow if step is at most k+1
229                                 return step <= k+1 && k != maxSignedValue(limit.Type)
230                         } else { // step < 0
231                                 if limit.Op == OpConst64 {
232                                         // Figure out the actual smallest value.
233                                         v := limit.AuxInt
234                                         if !inclusive {
235                                                 if v == maxSignedValue(limit.Type) {
236                                                         return false // > maxint is never satisfiable.
237                                                 }
238                                                 v++
239                                         }
240                                         if init.isGenericIntConst() {
241                                                 // Use stride to compute a better lower limit.
242                                                 if init.AuxInt < v {
243                                                         return false
244                                                 }
245                                                 v = subU(init.AuxInt, diff(init.AuxInt, v)/uint64(-step)*uint64(-step))
246                                         }
247                                         if subWillUnderflow(v, -step) {
248                                                 return false
249                                         }
250                                         if inclusive && v != limit.AuxInt || !inclusive && v-1 != limit.AuxInt {
251                                                 // We know a better limit than the programmer did. Use our limit instead.
252                                                 limit = f.constVal(limit.Op, limit.Type, v, true)
253                                                 inclusive = true
254                                         }
255                                         return true
256                                 }
257                                 if step == -1 && !inclusive {
258                                         // Can't underflow because minint is never a possible value.
259                                         return true
260                                 }
261                         }
262                         return false
263
264                 }
265
266                 if ok() {
267                         flags := indVarFlags(0)
268                         var min, max *Value
269                         if step > 0 {
270                                 min = init
271                                 max = limit
272                                 if inclusive {
273                                         flags |= indVarMaxInc
274                                 }
275                         } else {
276                                 min = limit
277                                 max = init
278                                 flags |= indVarMaxInc
279                                 if !inclusive {
280                                         flags |= indVarMinExc
281                                 }
282                                 flags |= indVarCountDown
283                                 step = -step
284                         }
285                         if f.pass.debug >= 1 {
286                                 printIndVar(b, ind, min, max, step, flags)
287                         }
288
289                         iv = append(iv, indVar{
290                                 ind:   ind,
291                                 nxt:   nxt,
292                                 min:   min,
293                                 max:   max,
294                                 entry: b.Succs[0].b,
295                                 flags: flags,
296                         })
297                         b.Logf("found induction variable %v (inc = %v, min = %v, max = %v)\n", ind, inc, min, max)
298                 }
299
300                 // TODO: other unrolling idioms
301                 // for i := 0; i < KNN - KNN % k ; i += k
302                 // for i := 0; i < KNN&^(k-1) ; i += k // k a power of 2
303                 // for i := 0; i < KNN&(-k) ; i += k // k a power of 2
304         }
305
306         return iv
307 }
308
309 // addWillOverflow reports whether x+y would result in a value more than maxint.
310 func addWillOverflow(x, y int64) bool {
311         return x+y < x
312 }
313
314 // subWillUnderflow reports whether x-y would result in a value less than minint.
315 func subWillUnderflow(x, y int64) bool {
316         return x-y > x
317 }
318
319 // diff returns x-y as a uint64. Requires x>=y.
320 func diff(x, y int64) uint64 {
321         if x < y {
322                 base.Fatalf("diff %d - %d underflowed", x, y)
323         }
324         return uint64(x - y)
325 }
326
327 // addU returns x+y. Requires that x+y does not overflow an int64.
328 func addU(x int64, y uint64) int64 {
329         if y >= 1<<63 {
330                 if x >= 0 {
331                         base.Fatalf("addU overflowed %d + %d", x, y)
332                 }
333                 x += 1<<63 - 1
334                 x += 1
335                 y -= 1 << 63
336         }
337         if addWillOverflow(x, int64(y)) {
338                 base.Fatalf("addU overflowed %d + %d", x, y)
339         }
340         return x + int64(y)
341 }
342
343 // subU returns x-y. Requires that x-y does not underflow an int64.
344 func subU(x int64, y uint64) int64 {
345         if y >= 1<<63 {
346                 if x < 0 {
347                         base.Fatalf("subU underflowed %d - %d", x, y)
348                 }
349                 x -= 1<<63 - 1
350                 x -= 1
351                 y -= 1 << 63
352         }
353         if subWillUnderflow(x, int64(y)) {
354                 base.Fatalf("subU underflowed %d - %d", x, y)
355         }
356         return x - int64(y)
357 }
358
359 // if v is known to be x - c, where x is known to be nonnegative and c is a
360 // constant, return x, c. Otherwise return nil, 0.
361 func findKNN(v *Value) (*Value, int64) {
362         var x, y *Value
363         x = v
364         switch v.Op {
365         case OpSub64, OpSub32, OpSub16, OpSub8:
366                 x = v.Args[0]
367                 y = v.Args[1]
368
369         case OpAdd64, OpAdd32, OpAdd16, OpAdd8:
370                 x = v.Args[0]
371                 y = v.Args[1]
372                 if x.isGenericIntConst() {
373                         x, y = y, x
374                 }
375         }
376         switch x.Op {
377         case OpSliceLen, OpStringLen, OpSliceCap:
378         default:
379                 return nil, 0
380         }
381         if y == nil {
382                 return x, 0
383         }
384         if !y.isGenericIntConst() {
385                 return nil, 0
386         }
387         if v.Op == OpAdd64 || v.Op == OpAdd32 || v.Op == OpAdd16 || v.Op == OpAdd8 {
388                 return x, -y.AuxInt
389         }
390         return x, y.AuxInt
391 }
392
393 func printIndVar(b *Block, i, min, max *Value, inc int64, flags indVarFlags) {
394         mb1, mb2 := "[", "]"
395         if flags&indVarMinExc != 0 {
396                 mb1 = "("
397         }
398         if flags&indVarMaxInc == 0 {
399                 mb2 = ")"
400         }
401
402         mlim1, mlim2 := fmt.Sprint(min.AuxInt), fmt.Sprint(max.AuxInt)
403         if !min.isGenericIntConst() {
404                 if b.Func.pass.debug >= 2 {
405                         mlim1 = fmt.Sprint(min)
406                 } else {
407                         mlim1 = "?"
408                 }
409         }
410         if !max.isGenericIntConst() {
411                 if b.Func.pass.debug >= 2 {
412                         mlim2 = fmt.Sprint(max)
413                 } else {
414                         mlim2 = "?"
415                 }
416         }
417         extra := ""
418         if b.Func.pass.debug >= 2 {
419                 extra = fmt.Sprintf(" (%s)", i)
420         }
421         b.Func.Warnl(b.Pos, "Induction variable: limits %v%v,%v%v, increment %d%s", mb1, mlim1, mlim2, mb2, inc, extra)
422 }
423
424 func minSignedValue(t *types.Type) int64 {
425         return -1 << (t.Size()*8 - 1)
426 }
427
428 func maxSignedValue(t *types.Type) int64 {
429         return 1<<((t.Size()*8)-1) - 1
430 }