]> Cypherpunks.ru repositories - gostls13.git/blob - src/cmd/compile/internal/inline/inlheur/analyze_func_params.go
cmd/compile/internal/inline: extend flag calculation via export data
[gostls13.git] / src / cmd / compile / internal / inline / inlheur / analyze_func_params.go
1 // Copyright 2023 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package inlheur
6
7 import (
8         "cmd/compile/internal/ir"
9         "fmt"
10         "os"
11 )
12
13 // paramsAnalyzer holds state information for the phase that computes
14 // flags for a Go functions parameters, for use in inline heuristics.
15 // Note that the params slice below includes entries for blanks.
16 type paramsAnalyzer struct {
17         fname  string
18         values []ParamPropBits
19         params []*ir.Name
20         top    []bool
21         *condLevelTracker
22 }
23
24 // dclParams returns a slice containing the non-blank, named params
25 // for the specific function (plus rcvr as well if applicable) in
26 // declaration order.
27 func dclParams(fn *ir.Func) []*ir.Name {
28         params := []*ir.Name{}
29         for _, n := range fn.Dcl {
30                 if n.Op() != ir.ONAME {
31                         continue
32                 }
33                 if n.Class != ir.PPARAM {
34                         continue
35                 }
36                 params = append(params, n)
37         }
38         return params
39 }
40
41 // getParams returns an *ir.Name slice containing all params for the
42 // function (plus rcvr as well if applicable). Note that this slice
43 // includes entries for blanks; entries in the returned slice corresponding
44 // to blanks or unnamed params will be nil.
45 func getParams(fn *ir.Func) []*ir.Name {
46         dclparms := dclParams(fn)
47         dclidx := 0
48         recvrParms := fn.Type().RecvParams()
49         params := make([]*ir.Name, len(recvrParms))
50         for i := range recvrParms {
51                 var v *ir.Name
52                 if recvrParms[i].Sym != nil &&
53                         !recvrParms[i].Sym.IsBlank() {
54                         v = dclparms[dclidx]
55                         dclidx++
56                 }
57                 params[i] = v
58         }
59         return params
60 }
61
62 func makeParamsAnalyzer(fn *ir.Func) *paramsAnalyzer {
63         params := getParams(fn) // includes receiver if applicable
64         vals := make([]ParamPropBits, len(params))
65         top := make([]bool, len(params))
66         for i, pn := range params {
67                 if pn == nil {
68                         continue
69                 }
70                 pt := pn.Type()
71                 if !pt.IsScalar() && !pt.HasNil() {
72                         // existing properties not applicable here (for things
73                         // like structs, arrays, slices, etc).
74                         continue
75                 }
76                 // If param is reassigned, skip it.
77                 if ir.Reassigned(pn) {
78                         continue
79                 }
80                 top[i] = true
81         }
82
83         if debugTrace&debugTraceParams != 0 {
84                 fmt.Fprintf(os.Stderr, "=-= param analysis of func %v:\n",
85                         fn.Sym().Name)
86                 for i := range vals {
87                         n := "_"
88                         if params[i] != nil {
89                                 n = params[i].Sym().String()
90                         }
91                         fmt.Fprintf(os.Stderr, "=-=  %d: %q %s\n",
92                                 i, n, vals[i].String())
93                 }
94         }
95
96         return &paramsAnalyzer{
97                 fname:            fn.Sym().Name,
98                 values:           vals,
99                 params:           params,
100                 top:              top,
101                 condLevelTracker: new(condLevelTracker),
102         }
103 }
104
105 func (pa *paramsAnalyzer) setResults(fp *FuncProps) {
106         fp.ParamFlags = pa.values
107 }
108
109 func (pa *paramsAnalyzer) findParamIdx(n *ir.Name) int {
110         if n == nil {
111                 panic("bad")
112         }
113         for i := range pa.params {
114                 if pa.params[i] == n {
115                         return i
116                 }
117         }
118         return -1
119 }
120
121 type testfType func(x ir.Node, param *ir.Name, idx int) (bool, bool)
122
123 // paramsAnalyzer invokes function 'testf' on the specified expression
124 // 'x' for each parameter, and if the result is TRUE, or's 'flag' into
125 // the flags for that param.
126 func (pa *paramsAnalyzer) checkParams(x ir.Node, flag ParamPropBits, mayflag ParamPropBits, testf testfType) {
127         for idx, p := range pa.params {
128                 if !pa.top[idx] && pa.values[idx] == ParamNoInfo {
129                         continue
130                 }
131                 result, may := testf(x, p, idx)
132                 if debugTrace&debugTraceParams != 0 {
133                         fmt.Fprintf(os.Stderr, "=-= test expr %v param %s result=%v flag=%s\n", x, p.Sym().Name, result, flag.String())
134                 }
135                 if result {
136                         v := flag
137                         if pa.condLevel != 0 || may {
138                                 v = mayflag
139                         }
140                         pa.values[idx] |= v
141                         pa.top[idx] = false
142                 }
143         }
144 }
145
146 // foldCheckParams checks expression 'x' (an 'if' condition or
147 // 'switch' stmt expr) to see if the expr would fold away if a
148 // specific parameter had a constant value.
149 func (pa *paramsAnalyzer) foldCheckParams(x ir.Node) {
150         pa.checkParams(x, ParamFeedsIfOrSwitch, ParamMayFeedIfOrSwitch,
151                 func(x ir.Node, p *ir.Name, idx int) (bool, bool) {
152                         return ShouldFoldIfNameConstant(x, []*ir.Name{p}), false
153                 })
154 }
155
156 // callCheckParams examines the target of call expression 'ce' to see
157 // if it is making a call to the value passed in for some parameter.
158 func (pa *paramsAnalyzer) callCheckParams(ce *ir.CallExpr) {
159         switch ce.Op() {
160         case ir.OCALLINTER:
161                 if ce.Op() != ir.OCALLINTER {
162                         return
163                 }
164                 sel := ce.X.(*ir.SelectorExpr)
165                 r := ir.StaticValue(sel.X)
166                 if r.Op() != ir.ONAME {
167                         return
168                 }
169                 name := r.(*ir.Name)
170                 if name.Class != ir.PPARAM {
171                         return
172                 }
173                 pa.checkParams(r, ParamFeedsInterfaceMethodCall,
174                         ParamMayFeedInterfaceMethodCall,
175                         func(x ir.Node, p *ir.Name, idx int) (bool, bool) {
176                                 name := x.(*ir.Name)
177                                 return name == p, false
178                         })
179         case ir.OCALLFUNC:
180                 if ce.X.Op() != ir.ONAME {
181                         return
182                 }
183                 called := ir.StaticValue(ce.X)
184                 if called.Op() != ir.ONAME {
185                         return
186                 }
187                 name := called.(*ir.Name)
188                 if name.Class == ir.PPARAM {
189                         pa.checkParams(called, ParamFeedsIndirectCall,
190                                 ParamMayFeedIndirectCall,
191                                 func(x ir.Node, p *ir.Name, idx int) (bool, bool) {
192                                         name := x.(*ir.Name)
193                                         return name == p, false
194                                 })
195                 } else {
196                         cname, isFunc, _ := isFuncName(called)
197                         if isFunc {
198                                 pa.deriveFlagsFromCallee(ce, cname.Func)
199                         }
200                 }
201         }
202 }
203
204 // deriveFlagsFromCallee tries to derive flags for the current
205 // function based on a call this function makes to some other
206 // function. Example:
207 //
208 //      /* Simple */                /* Derived from callee */
209 //      func foo(f func(int)) {     func foo(f func(int)) {
210 //        f(2)                        bar(32, f)
211 //      }                           }
212 //                                  func bar(x int, f func()) {
213 //                                    f(x)
214 //                                  }
215 //
216 // Here we can set the "param feeds indirect call" flag for
217 // foo's param 'f' since we know that bar has that flag set for
218 // its second param, and we're passing that param a function.
219 func (pa *paramsAnalyzer) deriveFlagsFromCallee(ce *ir.CallExpr, callee *ir.Func) {
220         calleeProps := propsForFunc(callee)
221         if calleeProps == nil {
222                 return
223         }
224         if debugTrace&debugTraceParams != 0 {
225                 fmt.Fprintf(os.Stderr, "=-= callee props for %v:\n%s",
226                         callee.Sym().Name, calleeProps.String())
227         }
228
229         must := []ParamPropBits{ParamFeedsInterfaceMethodCall, ParamFeedsIndirectCall, ParamFeedsIfOrSwitch}
230         may := []ParamPropBits{ParamMayFeedInterfaceMethodCall, ParamMayFeedIndirectCall, ParamMayFeedIfOrSwitch}
231
232         for pidx, arg := range ce.Args {
233                 // Does the callee param have any interesting properties?
234                 // If not we can skip this one.
235                 pflag := calleeProps.ParamFlags[pidx]
236                 if pflag == 0 {
237                         continue
238                 }
239                 // See if one of the caller's parameters is flowing unmodified
240                 // into this actual expression.
241                 r := ir.StaticValue(arg)
242                 if r.Op() != ir.ONAME {
243                         return
244                 }
245                 name := r.(*ir.Name)
246                 if name.Class != ir.PPARAM {
247                         return
248                 }
249                 callerParamIdx := pa.findParamIdx(name)
250                 if callerParamIdx == -1 || pa.params[callerParamIdx] == nil {
251                         panic("something went wrong")
252                 }
253                 if !pa.top[callerParamIdx] &&
254                         pa.values[callerParamIdx] == ParamNoInfo {
255                         continue
256                 }
257                 if debugTrace&debugTraceParams != 0 {
258                         fmt.Fprintf(os.Stderr, "=-= pflag for arg %d is %s\n",
259                                 pidx, pflag.String())
260                 }
261                 for i := range must {
262                         mayv := may[i]
263                         mustv := must[i]
264                         if pflag&mustv != 0 && pa.condLevel == 0 {
265                                 pa.values[callerParamIdx] |= mustv
266                         } else if pflag&(mustv|mayv) != 0 {
267                                 pa.values[callerParamIdx] |= mayv
268                         }
269                 }
270                 pa.top[callerParamIdx] = false
271         }
272 }
273
274 func (pa *paramsAnalyzer) nodeVisitPost(n ir.Node) {
275         if len(pa.values) == 0 {
276                 return
277         }
278         pa.condLevelTracker.post(n)
279         switch n.Op() {
280         case ir.OCALLFUNC:
281                 ce := n.(*ir.CallExpr)
282                 pa.callCheckParams(ce)
283         case ir.OCALLINTER:
284                 ce := n.(*ir.CallExpr)
285                 pa.callCheckParams(ce)
286         case ir.OIF:
287                 ifst := n.(*ir.IfStmt)
288                 pa.foldCheckParams(ifst.Cond)
289         case ir.OSWITCH:
290                 swst := n.(*ir.SwitchStmt)
291                 if swst.Tag != nil {
292                         pa.foldCheckParams(swst.Tag)
293                 }
294         }
295 }
296
297 func (pa *paramsAnalyzer) nodeVisitPre(n ir.Node) {
298         if len(pa.values) == 0 {
299                 return
300         }
301         pa.condLevelTracker.pre(n)
302 }
303
304 // condLevelTracker helps keeps track very roughly of "level of conditional
305 // nesting", e.g. how many "if" statements you have to go through to
306 // get to the point where a given stmt executes. Example:
307 //
308 //                            cond nesting level
309 //      func foo() {
310 //       G = 1                   0
311 //       if x < 10 {             0
312 //        if y < 10 {            1
313 //         G = 0                 2
314 //        }
315 //       }
316 //      }
317 //
318 // The intent here is to provide some sort of very abstract relative
319 // hotness metric, e.g. "G = 1" above is expected to be executed more
320 // often than "G = 0" (in the aggregate, across large numbers of
321 // functions).
322 type condLevelTracker struct {
323         condLevel int
324 }
325
326 func (c *condLevelTracker) pre(n ir.Node) {
327         // Increment level of "conditional testing" if we see
328         // an "if" or switch statement, and decrement if in
329         // a loop.
330         switch n.Op() {
331         case ir.OIF, ir.OSWITCH:
332                 c.condLevel++
333         case ir.OFOR, ir.ORANGE:
334                 c.condLevel--
335         }
336 }
337
338 func (c *condLevelTracker) post(n ir.Node) {
339         switch n.Op() {
340         case ir.OFOR, ir.ORANGE:
341                 c.condLevel++
342         case ir.OIF:
343                 c.condLevel--
344         case ir.OSWITCH:
345                 c.condLevel--
346         }
347 }