]> Cypherpunks.ru repositories - gostls13.git/blob - src/cmd/compile/internal/inline/inlheur/analyze_func_callsites.go
cmd/compile/internal/inline: score call sites exposed by inlines
[gostls13.git] / src / cmd / compile / internal / inline / inlheur / analyze_func_callsites.go
1 // Copyright 2023 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package inlheur
6
7 import (
8         "cmd/compile/internal/ir"
9         "cmd/compile/internal/pgo"
10         "fmt"
11         "os"
12         "strings"
13 )
14
15 type callSiteAnalyzer struct {
16         cstab    CallSiteTab
17         fn       *ir.Func
18         ptab     map[ir.Node]pstate
19         nstack   []ir.Node
20         loopNest int
21         isInit   bool
22 }
23
24 func makeCallSiteAnalyzer(fn *ir.Func, cstab CallSiteTab, ptab map[ir.Node]pstate, loopNestingLevel int) *callSiteAnalyzer {
25         isInit := fn.IsPackageInit() || strings.HasPrefix(fn.Sym().Name, "init.")
26         return &callSiteAnalyzer{
27                 fn:       fn,
28                 cstab:    cstab,
29                 ptab:     ptab,
30                 isInit:   isInit,
31                 loopNest: loopNestingLevel,
32                 nstack:   []ir.Node{fn},
33         }
34 }
35
36 // computeCallSiteTable builds and returns a table of call sites for
37 // the specified region in function fn. A region here corresponds to a
38 // specific subtree within the AST for a function. The main intended
39 // use cases are for 'region' to be either A) an entire function body,
40 // or B) an inlined call expression.
41 func computeCallSiteTable(fn *ir.Func, region ir.Nodes, cstab CallSiteTab, ptab map[ir.Node]pstate, loopNestingLevel int) CallSiteTab {
42         csa := makeCallSiteAnalyzer(fn, cstab, ptab, loopNestingLevel)
43         var doNode func(ir.Node) bool
44         doNode = func(n ir.Node) bool {
45                 csa.nodeVisitPre(n)
46                 ir.DoChildren(n, doNode)
47                 csa.nodeVisitPost(n)
48                 return false
49         }
50         for _, n := range region {
51                 doNode(n)
52         }
53         return csa.cstab
54 }
55
56 func (csa *callSiteAnalyzer) flagsForNode(call *ir.CallExpr) CSPropBits {
57         var r CSPropBits
58
59         if debugTrace&debugTraceCalls != 0 {
60                 fmt.Fprintf(os.Stderr, "=-= analyzing call at %s\n",
61                         fmtFullPos(call.Pos()))
62         }
63
64         // Set a bit if this call is within a loop.
65         if csa.loopNest > 0 {
66                 r |= CallSiteInLoop
67         }
68
69         // Set a bit if the call is within an init function (either
70         // compiler-generated or user-written).
71         if csa.isInit {
72                 r |= CallSiteInInitFunc
73         }
74
75         // Decide whether to apply the panic path heuristic. Hack: don't
76         // apply this heuristic in the function "main.main" (mostly just
77         // to avoid annoying users).
78         if !isMainMain(csa.fn) {
79                 r = csa.determinePanicPathBits(call, r)
80         }
81
82         return r
83 }
84
85 // determinePanicPathBits updates the CallSiteOnPanicPath bit within
86 // "r" if we think this call is on an unconditional path to
87 // panic/exit. Do this by walking back up the node stack to see if we
88 // can find either A) an enclosing panic, or B) a statement node that
89 // we've determined leads to a panic/exit.
90 func (csa *callSiteAnalyzer) determinePanicPathBits(call ir.Node, r CSPropBits) CSPropBits {
91         csa.nstack = append(csa.nstack, call)
92         defer func() {
93                 csa.nstack = csa.nstack[:len(csa.nstack)-1]
94         }()
95
96         for ri := range csa.nstack[:len(csa.nstack)-1] {
97                 i := len(csa.nstack) - ri - 1
98                 n := csa.nstack[i]
99                 _, isCallExpr := n.(*ir.CallExpr)
100                 _, isStmt := n.(ir.Stmt)
101                 if isCallExpr {
102                         isStmt = false
103                 }
104
105                 if debugTrace&debugTraceCalls != 0 {
106                         ps, inps := csa.ptab[n]
107                         fmt.Fprintf(os.Stderr, "=-= callpar %d op=%s ps=%s inptab=%v stmt=%v\n", i, n.Op().String(), ps.String(), inps, isStmt)
108                 }
109
110                 if n.Op() == ir.OPANIC {
111                         r |= CallSiteOnPanicPath
112                         break
113                 }
114                 if v, ok := csa.ptab[n]; ok {
115                         if v == psCallsPanic {
116                                 r |= CallSiteOnPanicPath
117                                 break
118                         }
119                         if isStmt {
120                                 break
121                         }
122                 }
123         }
124         return r
125 }
126
127 func (csa *callSiteAnalyzer) addCallSite(callee *ir.Func, call *ir.CallExpr) {
128         flags := csa.flagsForNode(call)
129         // FIXME: maybe bulk-allocate these?
130         cs := &CallSite{
131                 Call:   call,
132                 Callee: callee,
133                 Assign: csa.containingAssignment(call),
134                 Flags:  flags,
135                 ID:     uint(len(csa.cstab)),
136         }
137         if _, ok := csa.cstab[call]; ok {
138                 fmt.Fprintf(os.Stderr, "*** cstab duplicate entry at: %s\n",
139                         fmtFullPos(call.Pos()))
140                 fmt.Fprintf(os.Stderr, "*** call: %+v\n", call)
141                 panic("bad")
142         }
143         if callee.Inl != nil {
144                 // Set initial score for callsite to the cost computed
145                 // by CanInline; this score will be refined later based
146                 // on heuristics.
147                 cs.Score = int(callee.Inl.Cost)
148         }
149
150         if csa.cstab == nil {
151                 csa.cstab = make(CallSiteTab)
152         }
153         csa.cstab[call] = cs
154         if debugTrace&debugTraceCalls != 0 {
155                 fmt.Fprintf(os.Stderr, "=-= added callsite: caller=%v callee=%v n=%s\n",
156                         csa.fn, callee, fmtFullPos(call.Pos()))
157         }
158 }
159
160 func (csa *callSiteAnalyzer) nodeVisitPre(n ir.Node) {
161         switch n.Op() {
162         case ir.ORANGE, ir.OFOR:
163                 if !hasTopLevelLoopBodyReturnOrBreak(loopBody(n)) {
164                         csa.loopNest++
165                 }
166         case ir.OCALLFUNC:
167                 ce := n.(*ir.CallExpr)
168                 callee := pgo.DirectCallee(ce.Fun)
169                 if callee != nil && callee.Inl != nil {
170                         csa.addCallSite(callee, ce)
171                 }
172         }
173         csa.nstack = append(csa.nstack, n)
174 }
175
176 func (csa *callSiteAnalyzer) nodeVisitPost(n ir.Node) {
177         csa.nstack = csa.nstack[:len(csa.nstack)-1]
178         switch n.Op() {
179         case ir.ORANGE, ir.OFOR:
180                 if !hasTopLevelLoopBodyReturnOrBreak(loopBody(n)) {
181                         csa.loopNest--
182                 }
183         }
184 }
185
186 func loopBody(n ir.Node) ir.Nodes {
187         if forst, ok := n.(*ir.ForStmt); ok {
188                 return forst.Body
189         }
190         if rst, ok := n.(*ir.RangeStmt); ok {
191                 return rst.Body
192         }
193         return nil
194 }
195
196 // hasTopLevelLoopBodyReturnOrBreak examines the body of a "for" or
197 // "range" loop to try to verify that it is a real loop, as opposed to
198 // a construct that is syntactically loopy but doesn't actually iterate
199 // multiple times, like:
200 //
201 //      for {
202 //        blah()
203 //        return 1
204 //      }
205 //
206 // [Remark: the pattern above crops up quite a bit in the source code
207 // for the compiler itself, e.g. the auto-generated rewrite code]
208 //
209 // Note that we don't look for GOTO statements here, so it's possible
210 // we'll get the wrong result for a loop with complicated control
211 // jumps via gotos.
212 func hasTopLevelLoopBodyReturnOrBreak(loopBody ir.Nodes) bool {
213         for _, n := range loopBody {
214                 if n.Op() == ir.ORETURN || n.Op() == ir.OBREAK {
215                         return true
216                 }
217         }
218         return false
219 }
220
221 // containingAssignment returns the top-level assignment statement
222 // for a statement level function call "n". Examples:
223 //
224 //      x := foo()
225 //      x, y := bar(z, baz())
226 //      if blah() { ...
227 //
228 // Here the top-level assignment statement for the foo() call is the
229 // statement assigning to "x"; the top-level assignment for "bar()"
230 // call is the assignment to x,y. For the baz() and blah() calls,
231 // there is no top level assignment statement.
232 //
233 // The unstated goal here is that we want to use the containing
234 // assignment to establish a connection between a given call and the
235 // variables to which its results/returns are being assigned.
236 //
237 // Note that for the "bar" command above, the front end sometimes
238 // decomposes this into two assignments, the first one assigning the
239 // call to a pair of auto-temps, then the second one assigning the
240 // auto-temps to the user-visible vars. This helper will return the
241 // second (outer) of these two.
242 func (csa *callSiteAnalyzer) containingAssignment(n ir.Node) ir.Node {
243         parent := csa.nstack[len(csa.nstack)-1]
244
245         // assignsOnlyAutoTemps returns TRUE of the specified OAS2FUNC
246         // node assigns only auto-temps.
247         assignsOnlyAutoTemps := func(x ir.Node) bool {
248                 alst := x.(*ir.AssignListStmt)
249                 oa2init := alst.Init()
250                 if len(oa2init) == 0 {
251                         return false
252                 }
253                 for _, v := range oa2init {
254                         d := v.(*ir.Decl)
255                         if !ir.IsAutoTmp(d.X) {
256                                 return false
257                         }
258                 }
259                 return true
260         }
261
262         // Simple case: x := foo()
263         if parent.Op() == ir.OAS {
264                 return parent
265         }
266
267         // Multi-return case: x, y := bar()
268         if parent.Op() == ir.OAS2FUNC {
269                 // Hack city: if the result vars are auto-temps, try looking
270                 // for an outer assignment in the tree. The code shape we're
271                 // looking for here is:
272                 //
273                 // OAS1({x,y},OCONVNOP(OAS2FUNC({auto1,auto2},OCALLFUNC(bar))))
274                 //
275                 if assignsOnlyAutoTemps(parent) {
276                         par2 := csa.nstack[len(csa.nstack)-2]
277                         if par2.Op() == ir.OAS2 {
278                                 return par2
279                         }
280                         if par2.Op() == ir.OCONVNOP {
281                                 par3 := csa.nstack[len(csa.nstack)-3]
282                                 if par3.Op() == ir.OAS2 {
283                                         return par3
284                                 }
285                         }
286                 }
287         }
288
289         return nil
290 }
291
292 // UpdateCallsiteTable handles updating of callerfn's call site table
293 // after an inlined has been carried out, e.g. the call at 'n' as been
294 // turned into the inlined call expression 'ic' within function
295 // callerfn. The chief thing of interest here is to make sure that any
296 // call nodes within 'ic' are added to the call site table for
297 // 'callerfn' and scored appropriately.
298 func UpdateCallsiteTable(callerfn *ir.Func, n *ir.CallExpr, ic *ir.InlinedCallExpr) {
299         enableDebugTraceIfEnv()
300         defer disableDebugTrace()
301
302         funcInlHeur, ok := fpmap[callerfn]
303         if !ok {
304                 // This can happen for compiler-generated wrappers.
305                 if debugTrace&debugTraceCalls != 0 {
306                         fmt.Fprintf(os.Stderr, "=-= early exit, no entry for caller fn %v\n", callerfn)
307                 }
308                 return
309         }
310
311         if debugTrace&debugTraceCalls != 0 {
312                 fmt.Fprintf(os.Stderr, "=-= UpdateCallsiteTable(caller=%v, cs=%s)\n",
313                         callerfn, fmtFullPos(n.Pos()))
314         }
315
316         // Mark the call in question as inlined.
317         oldcs, ok := funcInlHeur.cstab[n]
318         if !ok {
319                 // This can happen for compiler-generated wrappers.
320                 return
321         }
322         oldcs.aux |= csAuxInlined
323
324         if debugTrace&debugTraceCalls != 0 {
325                 fmt.Fprintf(os.Stderr, "=-= marked as inlined: callee=%v %s\n",
326                         oldcs.Callee, EncodeCallSiteKey(oldcs))
327         }
328
329         // Walk the inlined call region to collect new callsites.
330         var icp pstate
331         if oldcs.Flags&CallSiteOnPanicPath != 0 {
332                 icp = psCallsPanic
333         }
334         var loopNestLevel int
335         if oldcs.Flags&CallSiteInLoop != 0 {
336                 loopNestLevel = 1
337         }
338         ptab := map[ir.Node]pstate{ic: icp}
339         icstab := computeCallSiteTable(callerfn, ic.Body, nil, ptab, loopNestLevel)
340
341         // Record parent callsite. This is primarily for debug output.
342         for _, cs := range icstab {
343                 cs.parent = oldcs
344         }
345
346         // Score the calls in the inlined body. Note the setting of "doCallResults"
347         // to false here: at the moment there isn't any easy way to localize
348         // or region-ize the work done by "rescoreBasedOnCallResultUses", which
349         // currently does a walk over the entire function to look for uses
350         // of a given set of results.
351         const doCallResults = false
352         scoreCallsRegion(callerfn, ic.Body, icstab, doCallResults, ic)
353 }