]> Cypherpunks.ru repositories - gostls13.git/blob - src/cmd/compile/internal/inline/inlheur/analyze_func_callsites.go
cmd/compile/internal/inlheur: regionalize call site analysis
[gostls13.git] / src / cmd / compile / internal / inline / inlheur / analyze_func_callsites.go
1 // Copyright 2023 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package inlheur
6
7 import (
8         "cmd/compile/internal/base"
9         "cmd/compile/internal/ir"
10         "cmd/compile/internal/pgo"
11         "fmt"
12         "os"
13         "sort"
14         "strings"
15 )
16
17 type callSiteAnalyzer struct {
18         cstab    CallSiteTab
19         fn       *ir.Func
20         ptab     map[ir.Node]pstate
21         nstack   []ir.Node
22         loopNest int
23         isInit   bool
24 }
25
26 func makeCallSiteAnalyzer(fn *ir.Func, ptab map[ir.Node]pstate, loopNestingLevel int) *callSiteAnalyzer {
27         isInit := fn.IsPackageInit() || strings.HasPrefix(fn.Sym().Name, "init.")
28         return &callSiteAnalyzer{
29                 fn:       fn,
30                 cstab:    make(CallSiteTab),
31                 ptab:     ptab,
32                 isInit:   isInit,
33                 loopNest: loopNestingLevel,
34                 nstack:   []ir.Node{fn},
35         }
36 }
37
38 // computeCallSiteTable builds and returns a table of call sites for
39 // the specified region in function fn. A region here corresponds to a
40 // specific subtree within the AST for a function. The main intended
41 // use cases are for 'region' to be either A) an entire function body,
42 // or B) an inlined call expression.
43 func computeCallSiteTable(fn *ir.Func, region ir.Nodes, ptab map[ir.Node]pstate, loopNestingLevel int) CallSiteTab {
44         csa := makeCallSiteAnalyzer(fn, ptab, loopNestingLevel)
45         var doNode func(ir.Node) bool
46         doNode = func(n ir.Node) bool {
47                 csa.nodeVisitPre(n)
48                 ir.DoChildren(n, doNode)
49                 csa.nodeVisitPost(n)
50                 return false
51         }
52         for _, n := range region {
53                 doNode(n)
54         }
55         return csa.cstab
56 }
57
58 func (csa *callSiteAnalyzer) flagsForNode(call *ir.CallExpr) CSPropBits {
59         var r CSPropBits
60
61         if debugTrace&debugTraceCalls != 0 {
62                 fmt.Fprintf(os.Stderr, "=-= analyzing call at %s\n",
63                         fmtFullPos(call.Pos()))
64         }
65
66         // Set a bit if this call is within a loop.
67         if csa.loopNest > 0 {
68                 r |= CallSiteInLoop
69         }
70
71         // Set a bit if the call is within an init function (either
72         // compiler-generated or user-written).
73         if csa.isInit {
74                 r |= CallSiteInInitFunc
75         }
76
77         // Decide whether to apply the panic path heuristic. Hack: don't
78         // apply this heuristic in the function "main.main" (mostly just
79         // to avoid annoying users).
80         if !isMainMain(csa.fn) {
81                 r = csa.determinePanicPathBits(call, r)
82         }
83
84         return r
85 }
86
87 // determinePanicPathBits updates the CallSiteOnPanicPath bit within
88 // "r" if we think this call is on an unconditional path to
89 // panic/exit. Do this by walking back up the node stack to see if we
90 // can find either A) an enclosing panic, or B) a statement node that
91 // we've determined leads to a panic/exit.
92 func (csa *callSiteAnalyzer) determinePanicPathBits(call ir.Node, r CSPropBits) CSPropBits {
93         csa.nstack = append(csa.nstack, call)
94         defer func() {
95                 csa.nstack = csa.nstack[:len(csa.nstack)-1]
96         }()
97
98         for ri := range csa.nstack[:len(csa.nstack)-1] {
99                 i := len(csa.nstack) - ri - 1
100                 n := csa.nstack[i]
101                 _, isCallExpr := n.(*ir.CallExpr)
102                 _, isStmt := n.(ir.Stmt)
103                 if isCallExpr {
104                         isStmt = false
105                 }
106
107                 if debugTrace&debugTraceCalls != 0 {
108                         ps, inps := csa.ptab[n]
109                         fmt.Fprintf(os.Stderr, "=-= callpar %d op=%s ps=%s inptab=%v stmt=%v\n", i, n.Op().String(), ps.String(), inps, isStmt)
110                 }
111
112                 if n.Op() == ir.OPANIC {
113                         r |= CallSiteOnPanicPath
114                         break
115                 }
116                 if v, ok := csa.ptab[n]; ok {
117                         if v == psCallsPanic {
118                                 r |= CallSiteOnPanicPath
119                                 break
120                         }
121                         if isStmt {
122                                 break
123                         }
124                 }
125         }
126         return r
127 }
128
129 func (csa *callSiteAnalyzer) addCallSite(callee *ir.Func, call *ir.CallExpr) {
130         flags := csa.flagsForNode(call)
131         // FIXME: maybe bulk-allocate these?
132         cs := &CallSite{
133                 Call:   call,
134                 Callee: callee,
135                 Assign: csa.containingAssignment(call),
136                 Flags:  flags,
137                 ID:     uint(len(csa.cstab)),
138         }
139         if _, ok := csa.cstab[call]; ok {
140                 fmt.Fprintf(os.Stderr, "*** cstab duplicate entry at: %s\n",
141                         fmtFullPos(call.Pos()))
142                 fmt.Fprintf(os.Stderr, "*** call: %+v\n", call)
143                 panic("bad")
144         }
145         if callee.Inl != nil {
146                 // Set initial score for callsite to the cost computed
147                 // by CanInline; this score will be refined later based
148                 // on heuristics.
149                 cs.Score = int(callee.Inl.Cost)
150         }
151
152         csa.cstab[call] = cs
153         if debugTrace&debugTraceCalls != 0 {
154                 fmt.Fprintf(os.Stderr, "=-= added callsite: callee=%s call=%v\n",
155                         callee.Sym().Name, callee)
156         }
157 }
158
159 // ScoreCalls assigns numeric scores to each of the callsites in
160 // function fn; the lower the score, the more helpful we think it will
161 // be to inline.
162 //
163 // Unlike a lot of the other inline heuristics machinery, callsite
164 // scoring can't be done as part of the CanInline call for a function,
165 // due to fact that we may be working on a non-trivial SCC. So for
166 // example with this SCC:
167 //
168 //      func foo(x int) {           func bar(x int, f func()) {
169 //        if x != 0 {                  f()
170 //          bar(x, func(){})           foo(x-1)
171 //        }                         }
172 //      }
173 //
174 // We don't want to perform scoring for the 'foo' call in "bar" until
175 // after foo has been analyzed, but it's conceivable that CanInline
176 // might visit bar before foo for this SCC.
177 func ScoreCalls(fn *ir.Func) {
178         enableDebugTraceIfEnv()
179         defer disableDebugTrace()
180         if debugTrace&debugTraceScoring != 0 {
181                 fmt.Fprintf(os.Stderr, "=-= ScoreCalls(%v)\n", ir.FuncName(fn))
182         }
183
184         funcInlHeur, ok := fpmap[fn]
185         if !ok {
186                 // TODO: add an assert/panic here.
187                 return
188         }
189         scoreCallsRegion(fn, fn.Body, funcInlHeur.cstab)
190 }
191
192 // scoreCallsRegion assigns numeric scores to each of the callsites in
193 // region 'region' within function 'fn'. This can be called on
194 // an entire function, or with 'region' set to a chunk of
195 // code corresponding to an inlined call.
196 func scoreCallsRegion(fn *ir.Func, region ir.Nodes, cstab CallSiteTab) {
197         if debugTrace&debugTraceScoring != 0 {
198                 fmt.Fprintf(os.Stderr, "=-= scoreCallsRegion(%v, %s)\n",
199                         ir.FuncName(fn), region[0].Op().String())
200         }
201
202         resultNameTab := make(map[*ir.Name]resultPropAndCS)
203
204         // Sort callsites to avoid any surprises with non deterministic
205         // map iteration order (this is probably not needed, but here just
206         // in case).
207         csl := make([]*CallSite, 0, len(cstab))
208         for _, cs := range cstab {
209                 csl = append(csl, cs)
210         }
211         sort.Slice(csl, func(i, j int) bool {
212                 return csl[i].ID < csl[j].ID
213         })
214
215         // Score each call site.
216         for _, cs := range csl {
217                 var cprops *FuncProps
218                 fihcprops := false
219                 desercprops := false
220                 if funcInlHeur, ok := fpmap[cs.Callee]; ok {
221                         cprops = funcInlHeur.props
222                         fihcprops = true
223                 } else if cs.Callee.Inl != nil {
224                         cprops = DeserializeFromString(cs.Callee.Inl.Properties)
225                         desercprops = true
226                 } else {
227                         if base.Debug.DumpInlFuncProps != "" {
228                                 fmt.Fprintf(os.Stderr, "=-= *** unable to score call to %s from %s\n", cs.Callee.Sym().Name, fmtFullPos(cs.Call.Pos()))
229                                 panic("should never happen")
230                         } else {
231                                 continue
232                         }
233                 }
234                 cs.Score, cs.ScoreMask = computeCallSiteScore(cs.Callee, cprops, cs.Call, cs.Flags)
235
236                 examineCallResults(cs, resultNameTab)
237
238                 if debugTrace&debugTraceScoring != 0 {
239                         fmt.Fprintf(os.Stderr, "=-= examineCallResults at %s: flags=%d score=%d funcInlHeur=%v deser=%v\n", fmtFullPos(cs.Call.Pos()), cs.Flags, cs.Score, fihcprops, desercprops)
240                 }
241         }
242
243         rescoreBasedOnCallResultUses(fn, resultNameTab, cstab)
244 }
245
246 func (csa *callSiteAnalyzer) nodeVisitPre(n ir.Node) {
247         switch n.Op() {
248         case ir.ORANGE, ir.OFOR:
249                 if !hasTopLevelLoopBodyReturnOrBreak(loopBody(n)) {
250                         csa.loopNest++
251                 }
252         case ir.OCALLFUNC:
253                 ce := n.(*ir.CallExpr)
254                 callee := pgo.DirectCallee(ce.Fun)
255                 if callee != nil && callee.Inl != nil {
256                         csa.addCallSite(callee, ce)
257                 }
258         }
259         csa.nstack = append(csa.nstack, n)
260 }
261
262 func (csa *callSiteAnalyzer) nodeVisitPost(n ir.Node) {
263         csa.nstack = csa.nstack[:len(csa.nstack)-1]
264         switch n.Op() {
265         case ir.ORANGE, ir.OFOR:
266                 if !hasTopLevelLoopBodyReturnOrBreak(loopBody(n)) {
267                         csa.loopNest--
268                 }
269         }
270 }
271
272 func loopBody(n ir.Node) ir.Nodes {
273         if forst, ok := n.(*ir.ForStmt); ok {
274                 return forst.Body
275         }
276         if rst, ok := n.(*ir.RangeStmt); ok {
277                 return rst.Body
278         }
279         return nil
280 }
281
282 // hasTopLevelLoopBodyReturnOrBreak examines the body of a "for" or
283 // "range" loop to try to verify that it is a real loop, as opposed to
284 // a construct that is syntactically loopy but doesn't actually iterate
285 // multiple times, like:
286 //
287 //      for {
288 //        blah()
289 //        return 1
290 //      }
291 //
292 // [Remark: the pattern above crops up quite a bit in the source code
293 // for the compiler itself, e.g. the auto-generated rewrite code]
294 //
295 // Note that we don't look for GOTO statements here, so it's possible
296 // we'll get the wrong result for a loop with complicated control
297 // jumps via gotos.
298 func hasTopLevelLoopBodyReturnOrBreak(loopBody ir.Nodes) bool {
299         for _, n := range loopBody {
300                 if n.Op() == ir.ORETURN || n.Op() == ir.OBREAK {
301                         return true
302                 }
303         }
304         return false
305 }
306
307 // containingAssignment returns the top-level assignment statement
308 // for a statement level function call "n". Examples:
309 //
310 //      x := foo()
311 //      x, y := bar(z, baz())
312 //      if blah() { ...
313 //
314 // Here the top-level assignment statement for the foo() call is the
315 // statement assigning to "x"; the top-level assignment for "bar()"
316 // call is the assignment to x,y. For the baz() and blah() calls,
317 // there is no top level assignment statement.
318 //
319 // The unstated goal here is that we want to use the containing
320 // assignment to establish a connection between a given call and the
321 // variables to which its results/returns are being assigned.
322 //
323 // Note that for the "bar" command above, the front end sometimes
324 // decomposes this into two assignments, the first one assigning the
325 // call to a pair of auto-temps, then the second one assigning the
326 // auto-temps to the user-visible vars. This helper will return the
327 // second (outer) of these two.
328 func (csa *callSiteAnalyzer) containingAssignment(n ir.Node) ir.Node {
329         parent := csa.nstack[len(csa.nstack)-1]
330
331         // assignsOnlyAutoTemps returns TRUE of the specified OAS2FUNC
332         // node assigns only auto-temps.
333         assignsOnlyAutoTemps := func(x ir.Node) bool {
334                 alst := x.(*ir.AssignListStmt)
335                 oa2init := alst.Init()
336                 if len(oa2init) == 0 {
337                         return false
338                 }
339                 for _, v := range oa2init {
340                         d := v.(*ir.Decl)
341                         if !ir.IsAutoTmp(d.X) {
342                                 return false
343                         }
344                 }
345                 return true
346         }
347
348         // Simple case: x := foo()
349         if parent.Op() == ir.OAS {
350                 return parent
351         }
352
353         // Multi-return case: x, y := bar()
354         if parent.Op() == ir.OAS2FUNC {
355                 // Hack city: if the result vars are auto-temps, try looking
356                 // for an outer assignment in the tree. The code shape we're
357                 // looking for here is:
358                 //
359                 // OAS1({x,y},OCONVNOP(OAS2FUNC({auto1,auto2},OCALLFUNC(bar))))
360                 //
361                 if assignsOnlyAutoTemps(parent) {
362                         par2 := csa.nstack[len(csa.nstack)-2]
363                         if par2.Op() == ir.OAS2 {
364                                 return par2
365                         }
366                         if par2.Op() == ir.OCONVNOP {
367                                 par3 := csa.nstack[len(csa.nstack)-3]
368                                 if par3.Op() == ir.OAS2 {
369                                         return par3
370                                 }
371                         }
372                 }
373         }
374
375         return nil
376 }