]> Cypherpunks.ru repositories - gostls13.git/blobdiff - src/cmd/compile/internal/inline/inlheur/analyze_func_callsites.go
cmd/compile/internal/inline: rework call scoring for non-inlinable funcs
[gostls13.git] / src / cmd / compile / internal / inline / inlheur / analyze_func_callsites.go
index e494d03e0ac7f82f8aee7214f8939ac6d2c598ce..e59ee265312b1f2112caa655b98d057ce4a94af2 100644 (file)
@@ -9,25 +9,37 @@ import (
        "cmd/compile/internal/pgo"
        "fmt"
        "os"
+       "strings"
 )
 
 type callSiteAnalyzer struct {
-       cstab  CallSiteTab
-       nstack []ir.Node
+       cstab    CallSiteTab
+       fn       *ir.Func
+       ptab     map[ir.Node]pstate
+       nstack   []ir.Node
+       loopNest int
+       isInit   bool
 }
 
-func makeCallSiteAnalyzer(fn *ir.Func) *callSiteAnalyzer {
+func makeCallSiteAnalyzer(fn *ir.Func, cstab CallSiteTab, ptab map[ir.Node]pstate, loopNestingLevel int) *callSiteAnalyzer {
+       isInit := fn.IsPackageInit() || strings.HasPrefix(fn.Sym().Name, "init.")
        return &callSiteAnalyzer{
-               cstab: make(CallSiteTab),
+               fn:       fn,
+               cstab:    cstab,
+               ptab:     ptab,
+               isInit:   isInit,
+               loopNest: loopNestingLevel,
+               nstack:   []ir.Node{fn},
        }
 }
 
-func computeCallSiteTable(fn *ir.Func) CallSiteTab {
-       if debugTrace&debugTraceCalls != 0 {
-               fmt.Fprintf(os.Stderr, "=-= making callsite table for func %v:\n",
-                       fn.Sym().Name)
-       }
-       csa := makeCallSiteAnalyzer(fn)
+// computeCallSiteTable builds and returns a table of call sites for
+// the specified region in function fn. A region here corresponds to a
+// specific subtree within the AST for a function. The main intended
+// use cases are for 'region' to be either A) an entire function body,
+// or B) an inlined call expression.
+func computeCallSiteTable(fn *ir.Func, region ir.Nodes, cstab CallSiteTab, ptab map[ir.Node]pstate, loopNestingLevel int) CallSiteTab {
+       csa := makeCallSiteAnalyzer(fn, cstab, ptab, loopNestingLevel)
        var doNode func(ir.Node) bool
        doNode = func(n ir.Node) bool {
                csa.nodeVisitPre(n)
@@ -35,22 +47,92 @@ func computeCallSiteTable(fn *ir.Func) CallSiteTab {
                csa.nodeVisitPost(n)
                return false
        }
-       doNode(fn)
+       for _, n := range region {
+               doNode(n)
+       }
        return csa.cstab
 }
 
 func (csa *callSiteAnalyzer) flagsForNode(call *ir.CallExpr) CSPropBits {
-       return 0
+       var r CSPropBits
+
+       if debugTrace&debugTraceCalls != 0 {
+               fmt.Fprintf(os.Stderr, "=-= analyzing call at %s\n",
+                       fmtFullPos(call.Pos()))
+       }
+
+       // Set a bit if this call is within a loop.
+       if csa.loopNest > 0 {
+               r |= CallSiteInLoop
+       }
+
+       // Set a bit if the call is within an init function (either
+       // compiler-generated or user-written).
+       if csa.isInit {
+               r |= CallSiteInInitFunc
+       }
+
+       // Decide whether to apply the panic path heuristic. Hack: don't
+       // apply this heuristic in the function "main.main" (mostly just
+       // to avoid annoying users).
+       if !isMainMain(csa.fn) {
+               r = csa.determinePanicPathBits(call, r)
+       }
+
+       return r
+}
+
+// determinePanicPathBits updates the CallSiteOnPanicPath bit within
+// "r" if we think this call is on an unconditional path to
+// panic/exit. Do this by walking back up the node stack to see if we
+// can find either A) an enclosing panic, or B) a statement node that
+// we've determined leads to a panic/exit.
+func (csa *callSiteAnalyzer) determinePanicPathBits(call ir.Node, r CSPropBits) CSPropBits {
+       csa.nstack = append(csa.nstack, call)
+       defer func() {
+               csa.nstack = csa.nstack[:len(csa.nstack)-1]
+       }()
+
+       for ri := range csa.nstack[:len(csa.nstack)-1] {
+               i := len(csa.nstack) - ri - 1
+               n := csa.nstack[i]
+               _, isCallExpr := n.(*ir.CallExpr)
+               _, isStmt := n.(ir.Stmt)
+               if isCallExpr {
+                       isStmt = false
+               }
+
+               if debugTrace&debugTraceCalls != 0 {
+                       ps, inps := csa.ptab[n]
+                       fmt.Fprintf(os.Stderr, "=-= callpar %d op=%s ps=%s inptab=%v stmt=%v\n", i, n.Op().String(), ps.String(), inps, isStmt)
+               }
+
+               if n.Op() == ir.OPANIC {
+                       r |= CallSiteOnPanicPath
+                       break
+               }
+               if v, ok := csa.ptab[n]; ok {
+                       if v == psCallsPanic {
+                               r |= CallSiteOnPanicPath
+                               break
+                       }
+                       if isStmt {
+                               break
+                       }
+               }
+       }
+       return r
 }
 
 func (csa *callSiteAnalyzer) addCallSite(callee *ir.Func, call *ir.CallExpr) {
+       flags := csa.flagsForNode(call)
        // FIXME: maybe bulk-allocate these?
        cs := &CallSite{
                Call:   call,
                Callee: callee,
                Assign: csa.containingAssignment(call),
-               Flags:  csa.flagsForNode(call),
-               Id:     uint(len(csa.cstab)),
+               Flags:  flags,
+               ID:     uint(len(csa.cstab)),
        }
        if _, ok := csa.cstab[call]; ok {
                fmt.Fprintf(os.Stderr, "*** cstab duplicate entry at: %s\n",
@@ -58,19 +140,31 @@ func (csa *callSiteAnalyzer) addCallSite(callee *ir.Func, call *ir.CallExpr) {
                fmt.Fprintf(os.Stderr, "*** call: %+v\n", call)
                panic("bad")
        }
-       if debugTrace&debugTraceCalls != 0 {
-               fmt.Fprintf(os.Stderr, "=-= added callsite: callee=%s call=%v\n",
-                       callee.Sym().Name, callee)
+       if callee.Inl != nil {
+               // Set initial score for callsite to the cost computed
+               // by CanInline; this score will be refined later based
+               // on heuristics.
+               cs.Score = int(callee.Inl.Cost)
        }
 
+       if csa.cstab == nil {
+               csa.cstab = make(CallSiteTab)
+       }
        csa.cstab[call] = cs
+       if debugTrace&debugTraceCalls != 0 {
+               fmt.Fprintf(os.Stderr, "=-= added callsite at %s: callee=%s call[%p]=%v\n", fmtFullPos(call.Pos()), callee.Sym().Name, call, call)
+       }
 }
 
 func (csa *callSiteAnalyzer) nodeVisitPre(n ir.Node) {
        switch n.Op() {
+       case ir.ORANGE, ir.OFOR:
+               if !hasTopLevelLoopBodyReturnOrBreak(loopBody(n)) {
+                       csa.loopNest++
+               }
        case ir.OCALLFUNC:
                ce := n.(*ir.CallExpr)
-               callee := pgo.DirectCallee(ce.X)
+               callee := pgo.DirectCallee(ce.Fun)
                if callee != nil && callee.Inl != nil {
                        csa.addCallSite(callee, ce)
                }
@@ -80,6 +174,47 @@ func (csa *callSiteAnalyzer) nodeVisitPre(n ir.Node) {
 
 func (csa *callSiteAnalyzer) nodeVisitPost(n ir.Node) {
        csa.nstack = csa.nstack[:len(csa.nstack)-1]
+       switch n.Op() {
+       case ir.ORANGE, ir.OFOR:
+               if !hasTopLevelLoopBodyReturnOrBreak(loopBody(n)) {
+                       csa.loopNest--
+               }
+       }
+}
+
+func loopBody(n ir.Node) ir.Nodes {
+       if forst, ok := n.(*ir.ForStmt); ok {
+               return forst.Body
+       }
+       if rst, ok := n.(*ir.RangeStmt); ok {
+               return rst.Body
+       }
+       return nil
+}
+
+// hasTopLevelLoopBodyReturnOrBreak examines the body of a "for" or
+// "range" loop to try to verify that it is a real loop, as opposed to
+// a construct that is syntactically loopy but doesn't actually iterate
+// multiple times, like:
+//
+//     for {
+//       blah()
+//       return 1
+//     }
+//
+// [Remark: the pattern above crops up quite a bit in the source code
+// for the compiler itself, e.g. the auto-generated rewrite code]
+//
+// Note that we don't look for GOTO statements here, so it's possible
+// we'll get the wrong result for a loop with complicated control
+// jumps via gotos.
+func hasTopLevelLoopBodyReturnOrBreak(loopBody ir.Nodes) bool {
+       for _, n := range loopBody {
+               if n.Op() == ir.ORETURN || n.Op() == ir.OBREAK {
+                       return true
+               }
+       }
+       return false
 }
 
 // containingAssignment returns the top-level assignment statement
@@ -91,12 +226,12 @@ func (csa *callSiteAnalyzer) nodeVisitPost(n ir.Node) {
 //
 // Here the top-level assignment statement for the foo() call is the
 // statement assigning to "x"; the top-level assignment for "bar()"
-// call is the assignment to x,y.   For the baz() and blah() calls,
+// call is the assignment to x,y. For the baz() and blah() calls,
 // there is no top level assignment statement.
 //
-// The unstated goal here is that we want to use the containing assignment
-// to establish a connection between a given call and the variables
-// to which its results/returns are being assigned.
+// The unstated goal here is that we want to use the containing
+// assignment to establish a connection between a given call and the
+// variables to which its results/returns are being assigned.
 //
 // Note that for the "bar" command above, the front end sometimes
 // decomposes this into two assignments, the first one assigning the