]> Cypherpunks.ru repositories - gostls13.git/blob - src/cmd/compile/internal/inline/inlheur/analyze_func_callsites.go
f0e07d29fca387550072a149c5f07fb4868bab25
[gostls13.git] / src / cmd / compile / internal / inline / inlheur / analyze_func_callsites.go
1 // Copyright 2023 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package inlheur
6
7 import (
8         "cmd/compile/internal/base"
9         "cmd/compile/internal/ir"
10         "cmd/compile/internal/pgo"
11         "fmt"
12         "os"
13         "sort"
14         "strings"
15 )
16
17 type callSiteAnalyzer struct {
18         cstab    CallSiteTab
19         fn       *ir.Func
20         ptab     map[ir.Node]pstate
21         nstack   []ir.Node
22         loopNest int
23         isInit   bool
24 }
25
26 func makeCallSiteAnalyzer(fn *ir.Func, ptab map[ir.Node]pstate) *callSiteAnalyzer {
27         isInit := fn.IsPackageInit() || strings.HasPrefix(fn.Sym().Name, "init.")
28         return &callSiteAnalyzer{
29                 fn:     fn,
30                 cstab:  make(CallSiteTab),
31                 ptab:   ptab,
32                 isInit: isInit,
33         }
34 }
35
36 func computeCallSiteTable(fn *ir.Func, ptab map[ir.Node]pstate) CallSiteTab {
37         if debugTrace != 0 {
38                 fmt.Fprintf(os.Stderr, "=-= making callsite table for func %v:\n",
39                         fn.Sym().Name)
40         }
41         csa := makeCallSiteAnalyzer(fn, ptab)
42         var doNode func(ir.Node) bool
43         doNode = func(n ir.Node) bool {
44                 csa.nodeVisitPre(n)
45                 ir.DoChildren(n, doNode)
46                 csa.nodeVisitPost(n)
47                 return false
48         }
49         doNode(fn)
50         return csa.cstab
51 }
52
53 func (csa *callSiteAnalyzer) flagsForNode(call *ir.CallExpr) CSPropBits {
54         var r CSPropBits
55
56         if debugTrace&debugTraceCalls != 0 {
57                 fmt.Fprintf(os.Stderr, "=-= analyzing call at %s\n",
58                         fmtFullPos(call.Pos()))
59         }
60
61         // Set a bit if this call is within a loop.
62         if csa.loopNest > 0 {
63                 r |= CallSiteInLoop
64         }
65
66         // Set a bit if the call is within an init function (either
67         // compiler-generated or user-written).
68         if csa.isInit {
69                 r |= CallSiteInInitFunc
70         }
71
72         // Decide whether to apply the panic path heuristic. Hack: don't
73         // apply this heuristic in the function "main.main" (mostly just
74         // to avoid annoying users).
75         if !isMainMain(csa.fn) {
76                 r = csa.determinePanicPathBits(call, r)
77         }
78
79         return r
80 }
81
82 // determinePanicPathBits updates the CallSiteOnPanicPath bit within
83 // "r" if we think this call is on an unconditional path to
84 // panic/exit. Do this by walking back up the node stack to see if we
85 // can find either A) an enclosing panic, or B) a statement node that
86 // we've determined leads to a panic/exit.
87 func (csa *callSiteAnalyzer) determinePanicPathBits(call ir.Node, r CSPropBits) CSPropBits {
88         csa.nstack = append(csa.nstack, call)
89         defer func() {
90                 csa.nstack = csa.nstack[:len(csa.nstack)-1]
91         }()
92
93         for ri := range csa.nstack[:len(csa.nstack)-1] {
94                 i := len(csa.nstack) - ri - 1
95                 n := csa.nstack[i]
96                 _, isCallExpr := n.(*ir.CallExpr)
97                 _, isStmt := n.(ir.Stmt)
98                 if isCallExpr {
99                         isStmt = false
100                 }
101
102                 if debugTrace&debugTraceCalls != 0 {
103                         ps, inps := csa.ptab[n]
104                         fmt.Fprintf(os.Stderr, "=-= callpar %d op=%s ps=%s inptab=%v stmt=%v\n", i, n.Op().String(), ps.String(), inps, isStmt)
105                 }
106
107                 if n.Op() == ir.OPANIC {
108                         r |= CallSiteOnPanicPath
109                         break
110                 }
111                 if v, ok := csa.ptab[n]; ok {
112                         if v == psCallsPanic {
113                                 r |= CallSiteOnPanicPath
114                                 break
115                         }
116                         if isStmt {
117                                 break
118                         }
119                 }
120         }
121         return r
122 }
123
124 func (csa *callSiteAnalyzer) addCallSite(callee *ir.Func, call *ir.CallExpr) {
125         flags := csa.flagsForNode(call)
126         // FIXME: maybe bulk-allocate these?
127         cs := &CallSite{
128                 Call:   call,
129                 Callee: callee,
130                 Assign: csa.containingAssignment(call),
131                 Flags:  flags,
132                 ID:     uint(len(csa.cstab)),
133         }
134         if _, ok := csa.cstab[call]; ok {
135                 fmt.Fprintf(os.Stderr, "*** cstab duplicate entry at: %s\n",
136                         fmtFullPos(call.Pos()))
137                 fmt.Fprintf(os.Stderr, "*** call: %+v\n", call)
138                 panic("bad")
139         }
140         if callee.Inl != nil {
141                 // Set initial score for callsite to the cost computed
142                 // by CanInline; this score will be refined later based
143                 // on heuristics.
144                 cs.Score = int(callee.Inl.Cost)
145         }
146
147         csa.cstab[call] = cs
148         if debugTrace&debugTraceCalls != 0 {
149                 fmt.Fprintf(os.Stderr, "=-= added callsite: callee=%s call=%v\n",
150                         callee.Sym().Name, callee)
151         }
152 }
153
154 // ScoreCalls assigns numeric scores to each of the callsites in
155 // function 'fn'; the lower the score, the more helpful we think it
156 // will be to inline.
157 //
158 // Unlike a lot of the other inline heuristics machinery, callsite
159 // scoring can't be done as part of the CanInline call for a function,
160 // due to fact that we may be working on a non-trivial SCC. So for
161 // example with this SCC:
162 //
163 //      func foo(x int) {           func bar(x int, f func()) {
164 //        if x != 0 {                  f()
165 //          bar(x, func(){})           foo(x-1)
166 //        }                         }
167 //      }
168 //
169 // We don't want to perform scoring for the 'foo' call in "bar" until
170 // after foo has been analyzed, but it's conceivable that CanInline
171 // might visit bar before foo for this SCC.
172 func ScoreCalls(fn *ir.Func) {
173         enableDebugTraceIfEnv()
174         defer disableDebugTrace()
175         if debugTrace&debugTraceScoring != 0 {
176                 fmt.Fprintf(os.Stderr, "=-= ScoreCalls(%v)\n", ir.FuncName(fn))
177         }
178
179         funcInlHeur, ok := fpmap[fn]
180         if !ok {
181                 // TODO: add an assert/panic here.
182                 return
183         }
184
185         resultNameTab := make(map[*ir.Name]resultPropAndCS)
186
187         // Sort callsites to avoid any surprises with non deterministic
188         // map iteration order (this is probably not needed, but here just
189         // in case).
190         csl := make([]*CallSite, 0, len(funcInlHeur.cstab))
191         for _, cs := range funcInlHeur.cstab {
192                 csl = append(csl, cs)
193         }
194         sort.Slice(csl, func(i, j int) bool {
195                 return csl[i].ID < csl[j].ID
196         })
197
198         // Score each call site.
199         for _, cs := range csl {
200                 var cprops *FuncProps
201                 fihcprops := false
202                 desercprops := false
203                 if funcInlHeur, ok := fpmap[cs.Callee]; ok {
204                         cprops = funcInlHeur.props
205                         fihcprops = true
206                 } else if cs.Callee.Inl != nil {
207                         cprops = DeserializeFromString(cs.Callee.Inl.Properties)
208                         desercprops = true
209                 } else {
210                         if base.Debug.DumpInlFuncProps != "" {
211                                 fmt.Fprintf(os.Stderr, "=-= *** unable to score call to %s from %s\n", cs.Callee.Sym().Name, fmtFullPos(cs.Call.Pos()))
212                                 panic("should never happen")
213                         } else {
214                                 continue
215                         }
216                 }
217                 cs.Score, cs.ScoreMask = computeCallSiteScore(cs.Callee, cprops, cs.Call, cs.Flags)
218
219                 examineCallResults(cs, resultNameTab)
220
221                 if debugTrace&debugTraceScoring != 0 {
222                         fmt.Fprintf(os.Stderr, "=-= scoring call at %s: flags=%d score=%d funcInlHeur=%v deser=%v\n", fmtFullPos(cs.Call.Pos()), cs.Flags, cs.Score, fihcprops, desercprops)
223                 }
224         }
225
226         rescoreBasedOnCallResultUses(fn, resultNameTab, funcInlHeur.cstab)
227 }
228
229 func (csa *callSiteAnalyzer) nodeVisitPre(n ir.Node) {
230         switch n.Op() {
231         case ir.ORANGE, ir.OFOR:
232                 if !hasTopLevelLoopBodyReturnOrBreak(loopBody(n)) {
233                         csa.loopNest++
234                 }
235         case ir.OCALLFUNC:
236                 ce := n.(*ir.CallExpr)
237                 callee := pgo.DirectCallee(ce.Fun)
238                 if callee != nil && callee.Inl != nil {
239                         csa.addCallSite(callee, ce)
240                 }
241         }
242         csa.nstack = append(csa.nstack, n)
243 }
244
245 func (csa *callSiteAnalyzer) nodeVisitPost(n ir.Node) {
246         csa.nstack = csa.nstack[:len(csa.nstack)-1]
247         switch n.Op() {
248         case ir.ORANGE, ir.OFOR:
249                 if !hasTopLevelLoopBodyReturnOrBreak(loopBody(n)) {
250                         csa.loopNest--
251                 }
252         }
253 }
254
255 func loopBody(n ir.Node) ir.Nodes {
256         if forst, ok := n.(*ir.ForStmt); ok {
257                 return forst.Body
258         }
259         if rst, ok := n.(*ir.RangeStmt); ok {
260                 return rst.Body
261         }
262         return nil
263 }
264
265 // hasTopLevelLoopBodyReturnOrBreak examines the body of a "for" or
266 // "range" loop to try to verify that it is a real loop, as opposed to
267 // a construct that is syntactically loopy but doesn't actually iterate
268 // multiple times, like:
269 //
270 //      for {
271 //        blah()
272 //        return 1
273 //      }
274 //
275 // [Remark: the pattern above crops up quite a bit in the source code
276 // for the compiler itself, e.g. the auto-generated rewrite code]
277 //
278 // Note that we don't look for GOTO statements here, so it's possible
279 // we'll get the wrong result for a loop with complicated control
280 // jumps via gotos.
281 func hasTopLevelLoopBodyReturnOrBreak(loopBody ir.Nodes) bool {
282         for _, n := range loopBody {
283                 if n.Op() == ir.ORETURN || n.Op() == ir.OBREAK {
284                         return true
285                 }
286         }
287         return false
288 }
289
290 // containingAssignment returns the top-level assignment statement
291 // for a statement level function call "n". Examples:
292 //
293 //      x := foo()
294 //      x, y := bar(z, baz())
295 //      if blah() { ...
296 //
297 // Here the top-level assignment statement for the foo() call is the
298 // statement assigning to "x"; the top-level assignment for "bar()"
299 // call is the assignment to x,y. For the baz() and blah() calls,
300 // there is no top level assignment statement.
301 //
302 // The unstated goal here is that we want to use the containing
303 // assignment to establish a connection between a given call and the
304 // variables to which its results/returns are being assigned.
305 //
306 // Note that for the "bar" command above, the front end sometimes
307 // decomposes this into two assignments, the first one assigning the
308 // call to a pair of auto-temps, then the second one assigning the
309 // auto-temps to the user-visible vars. This helper will return the
310 // second (outer) of these two.
311 func (csa *callSiteAnalyzer) containingAssignment(n ir.Node) ir.Node {
312         parent := csa.nstack[len(csa.nstack)-1]
313
314         // assignsOnlyAutoTemps returns TRUE of the specified OAS2FUNC
315         // node assigns only auto-temps.
316         assignsOnlyAutoTemps := func(x ir.Node) bool {
317                 alst := x.(*ir.AssignListStmt)
318                 oa2init := alst.Init()
319                 if len(oa2init) == 0 {
320                         return false
321                 }
322                 for _, v := range oa2init {
323                         d := v.(*ir.Decl)
324                         if !ir.IsAutoTmp(d.X) {
325                                 return false
326                         }
327                 }
328                 return true
329         }
330
331         // Simple case: x := foo()
332         if parent.Op() == ir.OAS {
333                 return parent
334         }
335
336         // Multi-return case: x, y := bar()
337         if parent.Op() == ir.OAS2FUNC {
338                 // Hack city: if the result vars are auto-temps, try looking
339                 // for an outer assignment in the tree. The code shape we're
340                 // looking for here is:
341                 //
342                 // OAS1({x,y},OCONVNOP(OAS2FUNC({auto1,auto2},OCALLFUNC(bar))))
343                 //
344                 if assignsOnlyAutoTemps(parent) {
345                         par2 := csa.nstack[len(csa.nstack)-2]
346                         if par2.Op() == ir.OAS2 {
347                                 return par2
348                         }
349                         if par2.Op() == ir.OCONVNOP {
350                                 par3 := csa.nstack[len(csa.nstack)-3]
351                                 if par3.Op() == ir.OAS2 {
352                                         return par3
353                                 }
354                         }
355                 }
356         }
357
358         return nil
359 }