]> Cypherpunks.ru repositories - gostls13.git/blob - src/cmd/compile/internal/inline/inlheur/analyze_func_params.go
cmd/compile/internal/inline/inlheur: remove pkg-level call site table
[gostls13.git] / src / cmd / compile / internal / inline / inlheur / analyze_func_params.go
1 // Copyright 2023 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package inlheur
6
7 import (
8         "cmd/compile/internal/ir"
9         "fmt"
10         "os"
11 )
12
13 // paramsAnalyzer holds state information for the phase that computes
14 // flags for a Go functions parameters, for use in inline heuristics.
15 // Note that the params slice below includes entries for blanks.
16 type paramsAnalyzer struct {
17         fname  string
18         values []ParamPropBits
19         params []*ir.Name
20         top    []bool
21         *condLevelTracker
22 }
23
24 // getParams returns an *ir.Name slice containing all params for the
25 // function (plus rcvr as well if applicable).
26 func getParams(fn *ir.Func) []*ir.Name {
27         sig := fn.Type()
28         numParams := sig.NumRecvs() + sig.NumParams()
29         return fn.Dcl[:numParams]
30 }
31
32 func makeParamsAnalyzer(fn *ir.Func) *paramsAnalyzer {
33         params := getParams(fn) // includes receiver if applicable
34         vals := make([]ParamPropBits, len(params))
35         top := make([]bool, len(params))
36         for i, pn := range params {
37                 if pn == nil {
38                         continue
39                 }
40                 pt := pn.Type()
41                 if !pt.IsScalar() && !pt.HasNil() {
42                         // existing properties not applicable here (for things
43                         // like structs, arrays, slices, etc).
44                         continue
45                 }
46                 // If param is reassigned, skip it.
47                 if ir.Reassigned(pn) {
48                         continue
49                 }
50                 top[i] = true
51         }
52
53         if debugTrace&debugTraceParams != 0 {
54                 fmt.Fprintf(os.Stderr, "=-= param analysis of func %v:\n",
55                         fn.Sym().Name)
56                 for i := range vals {
57                         n := "_"
58                         if params[i] != nil {
59                                 n = params[i].Sym().String()
60                         }
61                         fmt.Fprintf(os.Stderr, "=-=  %d: %q %s\n",
62                                 i, n, vals[i].String())
63                 }
64         }
65
66         return &paramsAnalyzer{
67                 fname:            fn.Sym().Name,
68                 values:           vals,
69                 params:           params,
70                 top:              top,
71                 condLevelTracker: new(condLevelTracker),
72         }
73 }
74
75 func (pa *paramsAnalyzer) setResults(funcProps *FuncProps) {
76         funcProps.ParamFlags = pa.values
77 }
78
79 func (pa *paramsAnalyzer) findParamIdx(n *ir.Name) int {
80         if n == nil {
81                 panic("bad")
82         }
83         for i := range pa.params {
84                 if pa.params[i] == n {
85                         return i
86                 }
87         }
88         return -1
89 }
90
91 type testfType func(x ir.Node, param *ir.Name, idx int) (bool, bool)
92
93 // paramsAnalyzer invokes function 'testf' on the specified expression
94 // 'x' for each parameter, and if the result is TRUE, or's 'flag' into
95 // the flags for that param.
96 func (pa *paramsAnalyzer) checkParams(x ir.Node, flag ParamPropBits, mayflag ParamPropBits, testf testfType) {
97         for idx, p := range pa.params {
98                 if !pa.top[idx] && pa.values[idx] == ParamNoInfo {
99                         continue
100                 }
101                 result, may := testf(x, p, idx)
102                 if debugTrace&debugTraceParams != 0 {
103                         fmt.Fprintf(os.Stderr, "=-= test expr %v param %s result=%v flag=%s\n", x, p.Sym().Name, result, flag.String())
104                 }
105                 if result {
106                         v := flag
107                         if pa.condLevel != 0 || may {
108                                 v = mayflag
109                         }
110                         pa.values[idx] |= v
111                         pa.top[idx] = false
112                 }
113         }
114 }
115
116 // foldCheckParams checks expression 'x' (an 'if' condition or
117 // 'switch' stmt expr) to see if the expr would fold away if a
118 // specific parameter had a constant value.
119 func (pa *paramsAnalyzer) foldCheckParams(x ir.Node) {
120         pa.checkParams(x, ParamFeedsIfOrSwitch, ParamMayFeedIfOrSwitch,
121                 func(x ir.Node, p *ir.Name, idx int) (bool, bool) {
122                         return ShouldFoldIfNameConstant(x, []*ir.Name{p}), false
123                 })
124 }
125
126 // callCheckParams examines the target of call expression 'ce' to see
127 // if it is making a call to the value passed in for some parameter.
128 func (pa *paramsAnalyzer) callCheckParams(ce *ir.CallExpr) {
129         switch ce.Op() {
130         case ir.OCALLINTER:
131                 if ce.Op() != ir.OCALLINTER {
132                         return
133                 }
134                 sel := ce.Fun.(*ir.SelectorExpr)
135                 r := ir.StaticValue(sel.X)
136                 if r.Op() != ir.ONAME {
137                         return
138                 }
139                 name := r.(*ir.Name)
140                 if name.Class != ir.PPARAM {
141                         return
142                 }
143                 pa.checkParams(r, ParamFeedsInterfaceMethodCall,
144                         ParamMayFeedInterfaceMethodCall,
145                         func(x ir.Node, p *ir.Name, idx int) (bool, bool) {
146                                 name := x.(*ir.Name)
147                                 return name == p, false
148                         })
149         case ir.OCALLFUNC:
150                 if ce.Fun.Op() != ir.ONAME {
151                         return
152                 }
153                 called := ir.StaticValue(ce.Fun)
154                 if called.Op() != ir.ONAME {
155                         return
156                 }
157                 name := called.(*ir.Name)
158                 if name.Class == ir.PPARAM {
159                         pa.checkParams(called, ParamFeedsIndirectCall,
160                                 ParamMayFeedIndirectCall,
161                                 func(x ir.Node, p *ir.Name, idx int) (bool, bool) {
162                                         name := x.(*ir.Name)
163                                         return name == p, false
164                                 })
165                 } else {
166                         cname, isFunc, _ := isFuncName(called)
167                         if isFunc {
168                                 pa.deriveFlagsFromCallee(ce, cname.Func)
169                         }
170                 }
171         }
172 }
173
174 // deriveFlagsFromCallee tries to derive flags for the current
175 // function based on a call this function makes to some other
176 // function. Example:
177 //
178 //      /* Simple */                /* Derived from callee */
179 //      func foo(f func(int)) {     func foo(f func(int)) {
180 //        f(2)                        bar(32, f)
181 //      }                           }
182 //                                  func bar(x int, f func()) {
183 //                                    f(x)
184 //                                  }
185 //
186 // Here we can set the "param feeds indirect call" flag for
187 // foo's param 'f' since we know that bar has that flag set for
188 // its second param, and we're passing that param a function.
189 func (pa *paramsAnalyzer) deriveFlagsFromCallee(ce *ir.CallExpr, callee *ir.Func) {
190         calleeProps := propsForFunc(callee)
191         if calleeProps == nil {
192                 return
193         }
194         if debugTrace&debugTraceParams != 0 {
195                 fmt.Fprintf(os.Stderr, "=-= callee props for %v:\n%s",
196                         callee.Sym().Name, calleeProps.String())
197         }
198
199         must := []ParamPropBits{ParamFeedsInterfaceMethodCall, ParamFeedsIndirectCall, ParamFeedsIfOrSwitch}
200         may := []ParamPropBits{ParamMayFeedInterfaceMethodCall, ParamMayFeedIndirectCall, ParamMayFeedIfOrSwitch}
201
202         for pidx, arg := range ce.Args {
203                 // Does the callee param have any interesting properties?
204                 // If not we can skip this one.
205                 pflag := calleeProps.ParamFlags[pidx]
206                 if pflag == 0 {
207                         continue
208                 }
209                 // See if one of the caller's parameters is flowing unmodified
210                 // into this actual expression.
211                 r := ir.StaticValue(arg)
212                 if r.Op() != ir.ONAME {
213                         return
214                 }
215                 name := r.(*ir.Name)
216                 if name.Class != ir.PPARAM {
217                         return
218                 }
219                 callerParamIdx := pa.findParamIdx(name)
220                 if callerParamIdx == -1 || pa.params[callerParamIdx] == nil {
221                         panic("something went wrong")
222                 }
223                 if !pa.top[callerParamIdx] &&
224                         pa.values[callerParamIdx] == ParamNoInfo {
225                         continue
226                 }
227                 if debugTrace&debugTraceParams != 0 {
228                         fmt.Fprintf(os.Stderr, "=-= pflag for arg %d is %s\n",
229                                 pidx, pflag.String())
230                 }
231                 for i := range must {
232                         mayv := may[i]
233                         mustv := must[i]
234                         if pflag&mustv != 0 && pa.condLevel == 0 {
235                                 pa.values[callerParamIdx] |= mustv
236                         } else if pflag&(mustv|mayv) != 0 {
237                                 pa.values[callerParamIdx] |= mayv
238                         }
239                 }
240                 pa.top[callerParamIdx] = false
241         }
242 }
243
244 func (pa *paramsAnalyzer) nodeVisitPost(n ir.Node) {
245         if len(pa.values) == 0 {
246                 return
247         }
248         pa.condLevelTracker.post(n)
249         switch n.Op() {
250         case ir.OCALLFUNC:
251                 ce := n.(*ir.CallExpr)
252                 pa.callCheckParams(ce)
253         case ir.OCALLINTER:
254                 ce := n.(*ir.CallExpr)
255                 pa.callCheckParams(ce)
256         case ir.OIF:
257                 ifst := n.(*ir.IfStmt)
258                 pa.foldCheckParams(ifst.Cond)
259         case ir.OSWITCH:
260                 swst := n.(*ir.SwitchStmt)
261                 if swst.Tag != nil {
262                         pa.foldCheckParams(swst.Tag)
263                 }
264         }
265 }
266
267 func (pa *paramsAnalyzer) nodeVisitPre(n ir.Node) {
268         if len(pa.values) == 0 {
269                 return
270         }
271         pa.condLevelTracker.pre(n)
272 }
273
274 // condLevelTracker helps keeps track very roughly of "level of conditional
275 // nesting", e.g. how many "if" statements you have to go through to
276 // get to the point where a given stmt executes. Example:
277 //
278 //                            cond nesting level
279 //      func foo() {
280 //       G = 1                   0
281 //       if x < 10 {             0
282 //        if y < 10 {            1
283 //         G = 0                 2
284 //        }
285 //       }
286 //      }
287 //
288 // The intent here is to provide some sort of very abstract relative
289 // hotness metric, e.g. "G = 1" above is expected to be executed more
290 // often than "G = 0" (in the aggregate, across large numbers of
291 // functions).
292 type condLevelTracker struct {
293         condLevel int
294 }
295
296 func (c *condLevelTracker) pre(n ir.Node) {
297         // Increment level of "conditional testing" if we see
298         // an "if" or switch statement, and decrement if in
299         // a loop.
300         switch n.Op() {
301         case ir.OIF, ir.OSWITCH:
302                 c.condLevel++
303         case ir.OFOR, ir.ORANGE:
304                 c.condLevel--
305         }
306 }
307
308 func (c *condLevelTracker) post(n ir.Node) {
309         switch n.Op() {
310         case ir.OFOR, ir.ORANGE:
311                 c.condLevel++
312         case ir.OIF:
313                 c.condLevel--
314         case ir.OSWITCH:
315                 c.condLevel--
316         }
317 }