]> Cypherpunks.ru repositories - gostls13.git/blob - src/cmd/compile/internal/inline/inlheur/analyze_func_callsites.go
cmd/compile/internal/inline: add call site flag generation
[gostls13.git] / src / cmd / compile / internal / inline / inlheur / analyze_func_callsites.go
1 // Copyright 2023 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package inlheur
6
7 import (
8         "cmd/compile/internal/ir"
9         "cmd/compile/internal/pgo"
10         "fmt"
11         "os"
12         "strings"
13 )
14
15 type callSiteAnalyzer struct {
16         cstab    CallSiteTab
17         fn       *ir.Func
18         ptab     map[ir.Node]pstate
19         nstack   []ir.Node
20         loopNest int
21         isInit   bool
22 }
23
24 func makeCallSiteAnalyzer(fn *ir.Func, ptab map[ir.Node]pstate) *callSiteAnalyzer {
25         isInit := fn.IsPackageInit() || strings.HasPrefix(fn.Sym().Name, "init.")
26         return &callSiteAnalyzer{
27                 fn:     fn,
28                 cstab:  make(CallSiteTab),
29                 ptab:   ptab,
30                 isInit: isInit,
31         }
32 }
33
34 func computeCallSiteTable(fn *ir.Func, ptab map[ir.Node]pstate) CallSiteTab {
35         if debugTrace != 0 {
36                 fmt.Fprintf(os.Stderr, "=-= making callsite table for func %v:\n",
37                         fn.Sym().Name)
38         }
39         csa := makeCallSiteAnalyzer(fn, ptab)
40         var doNode func(ir.Node) bool
41         doNode = func(n ir.Node) bool {
42                 csa.nodeVisitPre(n)
43                 ir.DoChildren(n, doNode)
44                 csa.nodeVisitPost(n)
45                 return false
46         }
47         doNode(fn)
48         return csa.cstab
49 }
50
51 func (csa *callSiteAnalyzer) flagsForNode(call *ir.CallExpr) CSPropBits {
52         var r CSPropBits
53
54         if debugTrace&debugTraceCalls != 0 {
55                 fmt.Fprintf(os.Stderr, "=-= analyzing call at %s\n",
56                         fmtFullPos(call.Pos()))
57         }
58
59         // Set a bit if this call is within a loop.
60         if csa.loopNest > 0 {
61                 r |= CallSiteInLoop
62         }
63
64         // Set a bit if the call is within an init function (either
65         // compiler-generated or user-written).
66         if csa.isInit {
67                 r |= CallSiteInInitFunc
68         }
69
70         // Decide whether to apply the panic path heuristic. Hack: don't
71         // apply this heuristic in the function "main.main" (mostly just
72         // to avoid annoying users).
73         if !isMainMain(csa.fn) {
74                 r = csa.determinePanicPathBits(call, r)
75         }
76
77         return r
78 }
79
80 // determinePanicPathBits updates the CallSiteOnPanicPath bit within
81 // "r" if we think this call is on an unconditional path to
82 // panic/exit. Do this by walking back up the node stack to see if we
83 // can find either A) an enclosing panic, or B) a statement node that
84 // we've determined leads to a panic/exit.
85 func (csa *callSiteAnalyzer) determinePanicPathBits(call ir.Node, r CSPropBits) CSPropBits {
86         csa.nstack = append(csa.nstack, call)
87         defer func() {
88                 csa.nstack = csa.nstack[:len(csa.nstack)-1]
89         }()
90
91         for ri := range csa.nstack[:len(csa.nstack)-1] {
92                 i := len(csa.nstack) - ri - 1
93                 n := csa.nstack[i]
94                 _, isCallExpr := n.(*ir.CallExpr)
95                 _, isStmt := n.(ir.Stmt)
96                 if isCallExpr {
97                         isStmt = false
98                 }
99
100                 if debugTrace&debugTraceCalls != 0 {
101                         ps, inps := csa.ptab[n]
102                         fmt.Fprintf(os.Stderr, "=-= callpar %d op=%s ps=%s inptab=%v stmt=%v\n", i, n.Op().String(), ps.String(), inps, isStmt)
103                 }
104
105                 if n.Op() == ir.OPANIC {
106                         r |= CallSiteOnPanicPath
107                         break
108                 }
109                 if v, ok := csa.ptab[n]; ok {
110                         if v == psCallsPanic {
111                                 r |= CallSiteOnPanicPath
112                                 break
113                         }
114                         if isStmt {
115                                 break
116                         }
117                 }
118         }
119         return r
120 }
121
122 func (csa *callSiteAnalyzer) addCallSite(callee *ir.Func, call *ir.CallExpr) {
123         // FIXME: maybe bulk-allocate these?
124         cs := &CallSite{
125                 Call:   call,
126                 Callee: callee,
127                 Assign: csa.containingAssignment(call),
128                 Flags:  csa.flagsForNode(call),
129                 Id:     uint(len(csa.cstab)),
130         }
131         if _, ok := csa.cstab[call]; ok {
132                 fmt.Fprintf(os.Stderr, "*** cstab duplicate entry at: %s\n",
133                         fmtFullPos(call.Pos()))
134                 fmt.Fprintf(os.Stderr, "*** call: %+v\n", call)
135                 panic("bad")
136         }
137         if debugTrace&debugTraceCalls != 0 {
138                 fmt.Fprintf(os.Stderr, "=-= added callsite: callee=%s call=%v\n",
139                         callee.Sym().Name, callee)
140         }
141
142         csa.cstab[call] = cs
143 }
144
145 func (csa *callSiteAnalyzer) nodeVisitPre(n ir.Node) {
146         switch n.Op() {
147         case ir.ORANGE, ir.OFOR:
148                 if !hasTopLevelLoopBodyReturnOrBreak(loopBody(n)) {
149                         csa.loopNest++
150                 }
151         case ir.OCALLFUNC:
152                 ce := n.(*ir.CallExpr)
153                 callee := pgo.DirectCallee(ce.X)
154                 if callee != nil && callee.Inl != nil {
155                         csa.addCallSite(callee, ce)
156                 }
157         }
158         csa.nstack = append(csa.nstack, n)
159 }
160
161 func (csa *callSiteAnalyzer) nodeVisitPost(n ir.Node) {
162         csa.nstack = csa.nstack[:len(csa.nstack)-1]
163         switch n.Op() {
164         case ir.ORANGE, ir.OFOR:
165                 if !hasTopLevelLoopBodyReturnOrBreak(loopBody(n)) {
166                         csa.loopNest--
167                 }
168         }
169 }
170
171 func loopBody(n ir.Node) ir.Nodes {
172         if forst, ok := n.(*ir.ForStmt); ok {
173                 return forst.Body
174         }
175         if rst, ok := n.(*ir.RangeStmt); ok {
176                 return rst.Body
177         }
178         return nil
179 }
180
181 // hasTopLevelLoopBodyReturnOrBreak examines the body of a "for" or
182 // "range" loop to try to verify that it is a real loop, as opposed to
183 // a construct that is syntactically loopy but doesn't actually iterate
184 // multiple times, like:
185 //
186 //      for {
187 //        blah()
188 //        return 1
189 //      }
190 //
191 // [Remark: the pattern above crops up quite a bit in the source code
192 // for the compiler itself, e.g. the auto-generated rewrite code]
193 //
194 // Note that we don't look for GOTO statements here, so it's possible
195 // we'll get the wrong result for a loop with complicated control
196 // jumps via gotos.
197 func hasTopLevelLoopBodyReturnOrBreak(loopBody ir.Nodes) bool {
198         for _, n := range loopBody {
199                 if n.Op() == ir.ORETURN || n.Op() == ir.OBREAK {
200                         return true
201                 }
202         }
203         return false
204 }
205
206 // containingAssignment returns the top-level assignment statement
207 // for a statement level function call "n". Examples:
208 //
209 //      x := foo()
210 //      x, y := bar(z, baz())
211 //      if blah() { ...
212 //
213 // Here the top-level assignment statement for the foo() call is the
214 // statement assigning to "x"; the top-level assignment for "bar()"
215 // call is the assignment to x,y.   For the baz() and blah() calls,
216 // there is no top level assignment statement.
217 //
218 // The unstated goal here is that we want to use the containing assignment
219 // to establish a connection between a given call and the variables
220 // to which its results/returns are being assigned.
221 //
222 // Note that for the "bar" command above, the front end sometimes
223 // decomposes this into two assignments, the first one assigning the
224 // call to a pair of auto-temps, then the second one assigning the
225 // auto-temps to the user-visible vars. This helper will return the
226 // second (outer) of these two.
227 func (csa *callSiteAnalyzer) containingAssignment(n ir.Node) ir.Node {
228         parent := csa.nstack[len(csa.nstack)-1]
229
230         // assignsOnlyAutoTemps returns TRUE of the specified OAS2FUNC
231         // node assigns only auto-temps.
232         assignsOnlyAutoTemps := func(x ir.Node) bool {
233                 alst := x.(*ir.AssignListStmt)
234                 oa2init := alst.Init()
235                 if len(oa2init) == 0 {
236                         return false
237                 }
238                 for _, v := range oa2init {
239                         d := v.(*ir.Decl)
240                         if !ir.IsAutoTmp(d.X) {
241                                 return false
242                         }
243                 }
244                 return true
245         }
246
247         // Simple case: x := foo()
248         if parent.Op() == ir.OAS {
249                 return parent
250         }
251
252         // Multi-return case: x, y := bar()
253         if parent.Op() == ir.OAS2FUNC {
254                 // Hack city: if the result vars are auto-temps, try looking
255                 // for an outer assignment in the tree. The code shape we're
256                 // looking for here is:
257                 //
258                 // OAS1({x,y},OCONVNOP(OAS2FUNC({auto1,auto2},OCALLFUNC(bar))))
259                 //
260                 if assignsOnlyAutoTemps(parent) {
261                         par2 := csa.nstack[len(csa.nstack)-2]
262                         if par2.Op() == ir.OAS2 {
263                                 return par2
264                         }
265                         if par2.Op() == ir.OCONVNOP {
266                                 par3 := csa.nstack[len(csa.nstack)-3]
267                                 if par3.Op() == ir.OAS2 {
268                                         return par3
269                                 }
270                         }
271                 }
272         }
273
274         return nil
275 }