]> Cypherpunks.ru repositories - gostls13.git/commitdiff
cmd/compile: implement range over func
authorRuss Cox <rsc@golang.org>
Wed, 14 Jun 2023 14:56:49 +0000 (10:56 -0400)
committerGopher Robot <gobot@golang.org>
Wed, 20 Sep 2023 14:52:38 +0000 (14:52 +0000)
Add compiler support for range over functions.
See the large comment at the top of
cmd/compile/internal/rangefunc/rewrite.go for details.

This is only reachable if GOEXPERIMENT=range is set,
because otherwise type checking will fail.

For proposal #61405 (but behind a GOEXPERIMENT).
For #61717.

Change-Id: I05717f94e63089c503acc49b28b47edeb4e011b4
Reviewed-on: https://go-review.googlesource.com/c/go/+/510541
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
Auto-Submit: Russ Cox <rsc@golang.org>

17 files changed:
src/cmd/compile/internal/escape/call.go
src/cmd/compile/internal/ir/expr.go
src/cmd/compile/internal/ir/stmt.go
src/cmd/compile/internal/ir/symtab.go
src/cmd/compile/internal/noder/irgen.go
src/cmd/compile/internal/noder/reader.go
src/cmd/compile/internal/noder/writer.go
src/cmd/compile/internal/rangefunc/rewrite.go [new file with mode: 0644]
src/cmd/compile/internal/ssagen/ssa.go
src/cmd/compile/internal/syntax/nodes.go
src/cmd/compile/internal/syntax/tokens.go
src/cmd/compile/internal/typecheck/_builtin/runtime.go
src/cmd/compile/internal/typecheck/builtin.go
src/cmd/compile/internal/walk/expr.go
src/cmd/compile/internal/walk/stmt.go
test/range3.go
test/rangegen.go [new file with mode: 0644]

index 360094ee44f258a32a7549106acba7107dbebc81..2bc87d4f8ee39236ee0a4439ffe6d6805b1930a0 100644 (file)
@@ -185,7 +185,7 @@ func (e *escape) call(ks []hole, call ir.Node) {
 // goDeferStmt analyzes a "go" or "defer" statement.
 func (e *escape) goDeferStmt(n *ir.GoDeferStmt) {
        k := e.heapHole()
-       if n.Op() == ir.ODEFER && e.loopDepth == 1 {
+       if n.Op() == ir.ODEFER && e.loopDepth == 1 && n.DeferAt == nil {
                // Top-level defer arguments don't escape to the heap,
                // but they do need to last until they're invoked.
                k = e.later(e.discardHole())
index 04398112ddc3baa6cd6013c926ef63067ee88b48..1436170a431b46ea7f64247cb1a2b100f4918565 100644 (file)
@@ -186,6 +186,7 @@ type CallExpr struct {
        miniExpr
        X         Node
        Args      Nodes
+       DeferAt   Node
        RType     Node    `mknode:"-"` // see reflectdata/helpers.go
        KeepAlive []*Name // vars to be kept alive until call returns
        IsDDD     bool
index 3e925b9db2e78823ff933a0158d3a6688022b93a..b3d6b0fbd5c4d71dd12ed6db747ced7e8da38f72 100644 (file)
@@ -242,7 +242,8 @@ func NewForStmt(pos src.XPos, init Node, cond, post Node, body []Node, distinctV
 // in a different context (a separate goroutine or a later time).
 type GoDeferStmt struct {
        miniStmt
-       Call Node
+       Call    Node
+       DeferAt Expr
 }
 
 func NewGoDeferStmt(pos src.XPos, op Op, call Node) *GoDeferStmt {
index 1e2fc8d97448a4bd4766f832c8249ab4fa238b01..4021035aa84782b2849fc714cddf1381f4a3d748 100644 (file)
@@ -10,7 +10,9 @@ import (
 )
 
 // Syms holds known symbols.
-var Syms struct {
+var Syms symsStruct
+
+type symsStruct struct {
        AssertE2I         *obj.LSym
        AssertE2I2        *obj.LSym
        AssertI2I         *obj.LSym
@@ -21,6 +23,7 @@ var Syms struct {
        CgoCheckPtrWrite  *obj.LSym
        CheckPtrAlignment *obj.LSym
        Deferproc         *obj.LSym
+       Deferprocat       *obj.LSym
        DeferprocStack    *obj.LSym
        Deferreturn       *obj.LSym
        Duffcopy          *obj.LSym
index f7b6d191fb85223ae7148790b96764506db2028b..c09a79d4f5d36108bbe94ea2aca69dbc0a674642 100644 (file)
@@ -12,6 +12,7 @@ import (
        "sort"
 
        "cmd/compile/internal/base"
+       "cmd/compile/internal/rangefunc"
        "cmd/compile/internal/syntax"
        "cmd/compile/internal/types2"
        "cmd/internal/src"
@@ -70,6 +71,10 @@ func checkFiles(m posMap, noders []*noder) (*types2.Package, *types2.Info) {
        }
 
        pkg, err := conf.Check(base.Ctxt.Pkgpath, files, info)
+       base.ExitIfErrors()
+       if err != nil {
+               base.FatalfAt(src.NoXPos, "conf.Check error: %v", err)
+       }
 
        // Check for anonymous interface cycles (#56103).
        if base.Debug.InterfaceCycles == 0 {
@@ -90,6 +95,7 @@ func checkFiles(m posMap, noders []*noder) (*types2.Package, *types2.Info) {
                        })
                }
        }
+       base.ExitIfErrors()
 
        // Implementation restriction: we don't allow not-in-heap types to
        // be used as type arguments (#54765).
@@ -115,11 +121,16 @@ func checkFiles(m posMap, noders []*noder) (*types2.Package, *types2.Info) {
                        base.ErrorfAt(targ.pos, 0, "cannot use incomplete (or unallocatable) type as a type argument: %v", targ.typ)
                }
        }
-
        base.ExitIfErrors()
-       if err != nil {
-               base.FatalfAt(src.NoXPos, "conf.Check error: %v", err)
-       }
+
+       // Rewrite range over function to explicit function calls
+       // with the loop bodies converted into new implicit closures.
+       // We do this now, before serialization to unified IR, so that if the
+       // implicit closures are inlined, we will have the unified IR form.
+       // If we do the rewrite in the back end, like between typecheck and walk,
+       // then the new implicit closure will not have a unified IR inline body,
+       // and bodyReaderFor will fail.
+       rangefunc.Rewrite(pkg, info, files)
 
        return pkg, info
 }
index 40efce139aef09da29689937f5c5a98b1e80fa09..f25c4afb2d8dc20b9c15bc6a3e1be7dc18e5ed3d 100644 (file)
@@ -677,6 +677,9 @@ func (pr *pkgReader) objIdx(idx pkgbits.Index, implicits, explicits []*types.Typ
                if pri, ok := objReader[sym]; ok {
                        return pri.pr.objIdx(pri.idx, nil, explicits, shaped)
                }
+               if sym.Pkg.Path == "runtime" {
+                       return typecheck.LookupRuntime(sym.Name)
+               }
                base.Fatalf("unresolved stub: %v", sym)
        }
 
@@ -1646,7 +1649,14 @@ func (r *reader) stmt1(tag codeStmt, out *ir.Nodes) ir.Node {
                pos := r.pos()
                op := r.op()
                call := r.expr()
-               return ir.NewGoDeferStmt(pos, op, call)
+               stmt := ir.NewGoDeferStmt(pos, op, call)
+               if op == ir.ODEFER {
+                       x := r.optExpr()
+                       if x != nil {
+                               stmt.DeferAt = x.(ir.Expr)
+                       }
+               }
+               return stmt
 
        case stmtExpr:
                return r.expr()
index f68a3875dfe5b23672ed5eab9cd1e43d22f0895d..6d7bd4c782183e42fbccb7d8903ad8fd47cc8670 100644 (file)
@@ -10,6 +10,7 @@ import (
        "go/token"
        "internal/buildcfg"
        "internal/pkgbits"
+       "os"
 
        "cmd/compile/internal/base"
        "cmd/compile/internal/ir"
@@ -1055,6 +1056,9 @@ func (w *writer) funcExt(obj *types2.Func) {
 
        sig, block := obj.Type().(*types2.Signature), decl.Body
        body, closureVars := w.p.bodyIdx(sig, block, w.dict)
+       if len(closureVars) > 0 {
+               fmt.Fprintln(os.Stderr, "CLOSURE", closureVars)
+       }
        assert(len(closureVars) == 0)
 
        w.Sync(pkgbits.SyncFuncExt)
@@ -1266,6 +1270,9 @@ func (w *writer) stmt1(stmt syntax.Stmt) {
                w.pos(stmt)
                w.op(callOps[stmt.Tok])
                w.expr(stmt.Call)
+               if stmt.Tok == syntax.Defer {
+                       w.optExpr(stmt.DeferAt)
+               }
 
        case *syntax.DeclStmt:
                for _, decl := range stmt.DeclList {
@@ -2300,6 +2307,10 @@ type posVar struct {
        var_ *types2.Var
 }
 
+func (p posVar) String() string {
+       return p.pos.String() + ":" + p.var_.String()
+}
+
 func (w *writer) exprList(expr syntax.Expr) {
        w.Sync(pkgbits.SyncExprList)
        w.exprs(syntax.UnpackListExpr(expr))
diff --git a/src/cmd/compile/internal/rangefunc/rewrite.go b/src/cmd/compile/internal/rangefunc/rewrite.go
new file mode 100644 (file)
index 0000000..ac12c53
--- /dev/null
@@ -0,0 +1,1142 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+Package rangefunc rewrites range-over-func to code that doesn't use range-over-funcs.
+Rewriting the construct in the front end, before noder, means the functions generated during
+the rewrite are available in a noder-generated representation for inlining by the back end.
+
+# Theory of Operation
+
+The basic idea is to rewrite
+
+       for x := range f {
+               ...
+       }
+
+into
+
+       f(func(x T) bool {
+               ...
+       })
+
+But it's not usually that easy.
+
+# Range variables
+
+For a range not using :=, the assigned variables cannot be function parameters
+in the generated body function. Instead, we allocate fake parameters and
+start the body with an assignment. For example:
+
+       for expr1, expr2 = range f {
+               ...
+       }
+
+becomes
+
+       f(func(#p1 T1, #p2 T2) bool {
+               expr1, expr2 = #p1, #p2
+               ...
+       })
+
+(All the generated variables have a # at the start to signal that they
+are internal variables when looking at the generated code in a
+debugger. Because variables have all been resolved to the specific
+objects they represent, there is no danger of using plain "p1" and
+colliding with a Go variable named "p1"; the # is just nice to have,
+not for correctness.)
+
+It can also happen that there are fewer range variables than function
+arguments, in which case we end up with something like
+
+       f(func(x T1, _ T2) bool {
+               ...
+       })
+
+or
+
+       f(func(#p1 T1, #p2 T2, _ T3) bool {
+               expr1, expr2 = #p1, #p2
+               ...
+       })
+
+# Return
+
+If the body contains a "break", that break turns into "return false",
+to tell f to stop. And if the body contains a "continue", that turns
+into "return true", to tell f to proceed with the next value.
+Those are the easy cases.
+
+If the body contains a return or a break/continue/goto L, then we need
+to rewrite that into code that breaks out of the loop and then
+triggers that control flow. In general we rewrite
+
+       for x := range f {
+               ...
+       }
+
+into
+
+       {
+               var #next int
+               f(func(x T1) bool {
+                       ...
+                       return true
+               })
+               ... check #next ...
+       }
+
+The variable #next is an integer code that says what to do when f
+returns. Each difficult statement sets #next and then returns false to
+stop f.
+
+A plain "return" rewrites to {#next = -1; return false}.
+The return false breaks the loop. Then when f returns, the "check
+#next" section includes
+
+       if #next == -1 { return }
+
+which causes the return we want.
+
+Return with arguments is more involved. We need somewhere to store the
+arguments while we break out of f, so we add them to the var
+declaration, like:
+
+       {
+               var (
+                       #next int
+                       #r1 type1
+                       #r2 type2
+               )
+               f(func(x T1) bool {
+                       ...
+                       {
+                               // return a, b
+                               #r1, #r2 = a, b
+                               #next = -2
+                               return false
+                       }
+                       ...
+                       return true
+               })
+               if #next == -2 { return #r1, #r2 }
+       }
+
+TODO: What about:
+
+       func f() (x bool) {
+               for range g(&x) {
+                       return true
+               }
+       }
+
+       func g(p *bool) func(func() bool) {
+               return func(yield func() bool) {
+                       yield()
+                       // Is *p true or false here?
+               }
+       }
+
+With this rewrite the "return true" is not visible after yield returns,
+but maybe it should be?
+
+# Nested Loops
+
+So far we've only considered a single loop. If a function contains a
+sequence of loops, each can be translated individually. But loops can
+be nested. It would work to translate the innermost loop and then
+translate the loop around it, and so on, except that there'd be a lot
+of rewriting of rewritten code and the overall traversals could end up
+taking time quadratic in the depth of the nesting. To avoid all that,
+we use a single rewriting pass that handles a top-most range-over-func
+loop and all the range-over-func loops it contains at the same time.
+
+If we need to return from inside a doubly-nested loop, the rewrites
+above stay the same, but the check after the inner loop only says
+
+       if #next < 0 { return false }
+
+to stop the outer loop so it can do the actual return. That is,
+
+       for range f {
+               for range g {
+                       ...
+                       return a, b
+                       ...
+               }
+       }
+
+becomes
+
+       {
+               var (
+                       #next int
+                       #r1 type1
+                       #r2 type2
+               )
+               f(func() {
+                       g(func() {
+                               ...
+                               {
+                                       // return a, b
+                                       #r1, #r2 = a, b
+                                       #next = -2
+                                       return false
+                               }
+                               ...
+                               return true
+                       })
+                       if #next < 0 {
+                               return false
+                       }
+                       return true
+               })
+               if #next == -2 {
+                       return #r1, #r2
+               }
+       }
+
+Note that the #next < 0 after the inner loop handles both kinds of
+return with a single check.
+
+# Labeled break/continue of range-over-func loops
+
+For a labeled break or continue of an outer range-over-func, we
+use positive #next values. Any such labeled break or continue
+really means "do N breaks" or "do N breaks and 1 continue".
+We encode that as 2*N or 2*N+1 respectively.
+Loops that might need to propagate a labeled break or continue
+add one or both of these to the #next checks:
+
+       if #next >= 2 {
+               #next -= 2
+               return false
+       }
+
+       if #next == 1 {
+               #next = 0
+               return true
+       }
+
+For example
+
+       F: for range f {
+               for range g {
+                       for range h {
+                               ...
+                               break F
+                               ...
+                               ...
+                               continue F
+                               ...
+                       }
+               }
+               ...
+       }
+
+becomes
+
+       {
+               var #next int
+               f(func() {
+                       g(func() {
+                               h(func() {
+                                       ...
+                                       {
+                                               // break F
+                                               #next = 4
+                                               return false
+                                       }
+                                       ...
+                                       {
+                                               // continue F
+                                               #next = 3
+                                               return false
+                                       }
+                                       ...
+                                       return true
+                               })
+                               if #next >= 2 {
+                                       #next -= 2
+                                       return false
+                               }
+                               return true
+                       })
+                       if #next >= 2 {
+                               #next -= 2
+                               return false
+                       }
+                       if #next == 1 {
+                               #next = 0
+                               return true
+                       }
+                       ...
+                       return true
+               })
+       }
+
+Note that the post-h checks only consider a break,
+since no generated code tries to continue g.
+
+# Gotos and other labeled break/continue
+
+The final control flow translations are goto and break/continue of a
+non-range-over-func statement. In both cases, we may need to break out
+of one or more range-over-func loops before we can do the actual
+control flow statement. Each such break/continue/goto L statement is
+assigned a unique negative #next value (below -2, since -1 and -2 are
+for the two kinds of return). Then the post-checks for a given loop
+test for the specific codes that refer to labels directly targetable
+from that block. Otherwise, the generic
+
+       if #next < 0 { return false }
+
+check handles stopping the next loop to get one step closer to the label.
+
+For example
+
+       Top: print("start\n")
+       for range f {
+               for range g {
+                       for range h {
+                               ...
+                               goto Top
+                               ...
+                       }
+               }
+       }
+
+becomes
+
+       Top: print("start\n")
+       {
+               var #next int
+               f(func() {
+                       g(func() {
+                               h(func() {
+                                       ...
+                                       {
+                                               // goto Top
+                                               #next = -3
+                                               return false
+                                       }
+                                       ...
+                                       return true
+                               })
+                               if #next < 0 {
+                                       return false
+                               }
+                               return true
+                       })
+                       if #next < 0 {
+                               return false
+                       }
+                       return true
+               })
+               if #next == -3 {
+                       #next = 0
+                       goto Top
+               }
+       }
+
+Labeled break/continue to non-range-over-funcs are handled the same
+way as goto.
+
+# Defers
+
+The last wrinkle is handling defer statements. If we have
+
+       for range f {
+               defer print("A")
+       }
+
+we cannot rewrite that into
+
+       f(func() {
+               defer print("A")
+       })
+
+because the deferred code will run at the end of the iteration, not
+the end of the containing function. To fix that, the runtime provides
+a special hook that lets us obtain a defer "token" representing the
+outer function and then use it in a later defer to attach the deferred
+code to that outer function.
+
+Normally,
+
+       defer print("A")
+
+compiles to
+
+       runtime.deferproc(func() { print("A") })
+
+This changes in a range-over-func. For example:
+
+       for range f {
+               defer print("A")
+       }
+
+compiles to
+
+       var #defers = runtime.deferrangefunc()
+       f(func() {
+               runtime.deferprocat(func() { print("A") }, #defers)
+       })
+
+For this rewriting phase, we insert the explicit initialization of
+#defers and then attach the #defers variable to the CallStmt
+representing the defer. That variable will be propagated to the
+backend and will cause the backend to compile the defer using
+deferprocat instead of an ordinary deferproc.
+
+TODO: Could call runtime.deferrangefuncend after f.
+*/
+package rangefunc
+
+import (
+       "cmd/compile/internal/base"
+       "cmd/compile/internal/syntax"
+       "cmd/compile/internal/types2"
+       "fmt"
+       "go/constant"
+       "os"
+)
+
+// nopos is the zero syntax.Pos.
+var nopos syntax.Pos
+
+// A rewriter implements rewriting the range-over-funcs in a given function.
+type rewriter struct {
+       pkg   *types2.Package
+       info  *types2.Info
+       outer *syntax.FuncType
+       body  *syntax.BlockStmt
+
+       // References to important types and values.
+       any   types2.Object
+       bool  types2.Object
+       int   types2.Object
+       true  types2.Object
+       false types2.Object
+
+       // Branch numbering, computed as needed.
+       branchNext map[branch]int             // branch -> #next value
+       labelLoop  map[string]*syntax.ForStmt // label -> innermost rangefunc loop it is declared inside (nil for no loop)
+
+       // Stack of nodes being visited.
+       stack    []syntax.Node // all nodes
+       forStack []*forLoop    // range-over-func loops
+
+       rewritten map[*syntax.ForStmt]syntax.Stmt
+
+       // Declared variables in generated code for outermost loop.
+       declStmt *syntax.DeclStmt
+       nextVar  types2.Object
+       retVars  []types2.Object
+       defers   types2.Object
+}
+
+// A branch is a single labeled branch.
+type branch struct {
+       tok   syntax.Token
+       label string
+}
+
+// A forLoop describes a single range-over-func loop being processed.
+type forLoop struct {
+       nfor *syntax.ForStmt // actual syntax
+
+       checkRet      bool     // add check for "return" after loop
+       checkRetArgs  bool     // add check for "return args" after loop
+       checkBreak    bool     // add check for "break" after loop
+       checkContinue bool     // add check for "continue" after loop
+       checkBranch   []branch // add check for labeled branch after loop
+}
+
+// Rewrite rewrites all the range-over-funcs in the files.
+func Rewrite(pkg *types2.Package, info *types2.Info, files []*syntax.File) {
+       for _, file := range files {
+               syntax.Inspect(file, func(n syntax.Node) bool {
+                       switch n := n.(type) {
+                       case *syntax.FuncDecl:
+                               rewriteFunc(pkg, info, n.Type, n.Body)
+                               return false
+                       case *syntax.FuncLit:
+                               rewriteFunc(pkg, info, n.Type, n.Body)
+                               return false
+                       }
+                       return true
+               })
+       }
+}
+
+// rewriteFunc rewrites all the range-over-funcs in a single function (a top-level func or a func literal).
+// The typ and body are the function's type and body.
+func rewriteFunc(pkg *types2.Package, info *types2.Info, typ *syntax.FuncType, body *syntax.BlockStmt) {
+       if body == nil {
+               return
+       }
+       r := &rewriter{
+               pkg:   pkg,
+               info:  info,
+               outer: typ,
+               body:  body,
+       }
+       syntax.Inspect(body, r.inspect)
+       if (base.Flag.W != 0) && r.forStack != nil {
+               syntax.Fdump(os.Stderr, body)
+       }
+}
+
+// inspect is a callback for syntax.Inspect that drives the actual rewriting.
+// If it sees a func literal, it kicks off a separate rewrite for that literal.
+// Otherwise, it maintains a stack of range-over-func loops and
+// converts each in turn.
+func (r *rewriter) inspect(n syntax.Node) bool {
+       switch n := n.(type) {
+       case *syntax.FuncLit:
+               rewriteFunc(r.pkg, r.info, n.Type, n.Body)
+               return false
+
+       default:
+               // Push n onto stack.
+               r.stack = append(r.stack, n)
+               if nfor, ok := forRangeFunc(n); ok {
+                       loop := &forLoop{nfor: nfor}
+                       r.forStack = append(r.forStack, loop)
+                       r.startLoop(loop)
+               }
+
+       case nil:
+               // n == nil signals that we are done visiting
+               // the top-of-stack node's children. Find it.
+               n = r.stack[len(r.stack)-1]
+
+               // If we are inside a range-over-func,
+               // take this moment to replace any break/continue/goto/return
+               // statements directly contained in this node.
+               // Also replace any converted for statements
+               // with the rewritten block.
+               switch n := n.(type) {
+               case *syntax.BlockStmt:
+                       for i, s := range n.List {
+                               n.List[i] = r.editStmt(s)
+                       }
+               case *syntax.CaseClause:
+                       for i, s := range n.Body {
+                               n.Body[i] = r.editStmt(s)
+                       }
+               case *syntax.CommClause:
+                       for i, s := range n.Body {
+                               n.Body[i] = r.editStmt(s)
+                       }
+               case *syntax.LabeledStmt:
+                       n.Stmt = r.editStmt(n.Stmt)
+               }
+
+               // Pop n.
+               if len(r.forStack) > 0 && r.stack[len(r.stack)-1] == r.forStack[len(r.forStack)-1].nfor {
+                       r.endLoop(r.forStack[len(r.forStack)-1])
+                       r.forStack = r.forStack[:len(r.forStack)-1]
+               }
+               r.stack = r.stack[:len(r.stack)-1]
+       }
+       return true
+}
+
+// startLoop sets up for converting a range-over-func loop.
+func (r *rewriter) startLoop(loop *forLoop) {
+       // For first loop in function, allocate syntax for any, bool, int, true, and false.
+       if r.any == nil {
+               r.any = types2.Universe.Lookup("any")
+               r.bool = types2.Universe.Lookup("bool")
+               r.int = types2.Universe.Lookup("int")
+               r.true = types2.Universe.Lookup("true")
+               r.false = types2.Universe.Lookup("false")
+               r.rewritten = make(map[*syntax.ForStmt]syntax.Stmt)
+       }
+}
+
+// editStmt returns the replacement for the statement x,
+// or x itself if it should be left alone.
+// This includes the for loops we are converting,
+// as left in x.rewritten by r.endLoop.
+func (r *rewriter) editStmt(x syntax.Stmt) syntax.Stmt {
+       if x, ok := x.(*syntax.ForStmt); ok {
+               if s := r.rewritten[x]; s != nil {
+                       return s
+               }
+       }
+
+       if len(r.forStack) > 0 {
+               switch x := x.(type) {
+               case *syntax.BranchStmt:
+                       return r.editBranch(x)
+               case *syntax.CallStmt:
+                       if x.Tok == syntax.Defer {
+                               return r.editDefer(x)
+                       }
+               case *syntax.ReturnStmt:
+                       return r.editReturn(x)
+               }
+       }
+
+       return x
+}
+
+// editDefer returns the replacement for the defer statement x.
+// See the "Defers" section in the package doc comment above for more context.
+func (r *rewriter) editDefer(x *syntax.CallStmt) syntax.Stmt {
+       if r.defers == nil {
+               // Declare and initialize the #defers token.
+               init := &syntax.CallExpr{
+                       Fun: runtimeSym(r.info, "deferrangefunc"),
+               }
+               tv := syntax.TypeAndValue{Type: r.any.Type()}
+               tv.SetIsValue()
+               init.SetTypeInfo(tv)
+               r.defers = r.declVar("#defers", r.any.Type(), init)
+       }
+
+       // Attach the token as an "extra" argument to the defer.
+       x.DeferAt = r.useVar(r.defers)
+       setPos(x.DeferAt, x.Pos())
+       return x
+}
+
+// editReturn returns the replacement for the return statement x.
+// See the "Return" section in the package doc comment above for more context.
+func (r *rewriter) editReturn(x *syntax.ReturnStmt) syntax.Stmt {
+       // #next = -1 is return with no arguments; -2 is return with arguments.
+       var next int
+       if x.Results == nil {
+               next = -1
+               r.forStack[0].checkRet = true
+       } else {
+               next = -2
+               r.forStack[0].checkRetArgs = true
+       }
+
+       // Tell the loops along the way to check for a return.
+       for _, loop := range r.forStack[1:] {
+               loop.checkRet = true
+       }
+
+       // Assign results, set #next, and return false.
+       bl := &syntax.BlockStmt{}
+       if x.Results != nil {
+               if r.retVars == nil {
+                       for i, a := range r.outer.ResultList {
+                               obj := r.declVar(fmt.Sprintf("#r%d", i+1), a.Type.GetTypeInfo().Type, nil)
+                               r.retVars = append(r.retVars, obj)
+                       }
+               }
+               bl.List = append(bl.List, &syntax.AssignStmt{Lhs: r.useList(r.retVars), Rhs: x.Results})
+       }
+       bl.List = append(bl.List, &syntax.AssignStmt{Lhs: r.next(), Rhs: r.intConst(next)})
+       bl.List = append(bl.List, &syntax.ReturnStmt{Results: r.useVar(r.false)})
+       setPos(bl, x.Pos())
+       return bl
+}
+
+// editBranch returns the replacement for the branch statement x,
+// or x itself if it should be left alone.
+// See the package doc comment above for more context.
+func (r *rewriter) editBranch(x *syntax.BranchStmt) syntax.Stmt {
+       if x.Tok == syntax.Fallthrough {
+               // Fallthrough is unaffected by the rewrite.
+               return x
+       }
+
+       // Find target of break/continue/goto in r.forStack.
+       // (The target may not be in r.forStack at all.)
+       targ := x.Target
+       i := len(r.forStack) - 1
+       if x.Label == nil && r.forStack[i].nfor != targ {
+               // Unlabeled break or continue that's not nfor must be inside nfor. Leave alone.
+               return x
+       }
+       for i >= 0 && r.forStack[i].nfor != targ {
+               i--
+       }
+
+       // Compute the value to assign to #next and the specific return to use.
+       var next int
+       var ret *syntax.ReturnStmt
+       if x.Tok == syntax.Goto || i < 0 {
+               // goto Label
+               // or break/continue of labeled non-range-over-func loop.
+               // We may be able to leave it alone, or we may have to break
+               // out of one or more nested loops and then use #next to signal
+               // to complete the break/continue/goto.
+               // Figure out which range-over-func loop contains the label.
+               r.computeBranchNext()
+               nfor := r.forStack[len(r.forStack)-1].nfor
+               label := x.Label.Value
+               targ := r.labelLoop[label]
+               if nfor == targ {
+                       // Label is in the innermost range-over-func loop; use it directly.
+                       return x
+               }
+
+               // Set #next to the code meaning break/continue/goto label.
+               next = r.branchNext[branch{x.Tok, label}]
+
+               // Break out of nested loops up to targ.
+               i := len(r.forStack) - 1
+               for i >= 0 && r.forStack[i].nfor != targ {
+                       i--
+               }
+
+               // Mark loop we exit to get to targ to check for that branch.
+               // When i==-1 that's the outermost func body
+               top := r.forStack[i+1]
+               top.checkBranch = append(top.checkBranch, branch{x.Tok, label})
+
+               // Mark loops along the way to check for a plain return, so they break.
+               for j := i + 2; j < len(r.forStack); j++ {
+                       r.forStack[j].checkRet = true
+               }
+
+               // In the innermost loop, use a plain "return false".
+               ret = &syntax.ReturnStmt{Results: r.useVar(r.false)}
+       } else {
+               // break/continue of labeled range-over-func loop.
+               depth := len(r.forStack) - 1 - i
+
+               // For continue of innermost loop, use "return true".
+               // Otherwise we are breaking the innermost loop, so "return false".
+               retVal := r.false
+               if depth == 0 && x.Tok == syntax.Continue {
+                       retVal = r.true
+               }
+               ret = &syntax.ReturnStmt{Results: r.useVar(retVal)}
+
+               // If we're only operating on the innermost loop, the return is all we need.
+               if depth == 0 {
+                       setPos(ret, x.Pos())
+                       return ret
+               }
+
+               // The loop inside the one we are break/continue-ing
+               // needs to make that happen when we break out of it.
+               if x.Tok == syntax.Continue {
+                       r.forStack[i+1].checkContinue = true
+               } else {
+                       r.forStack[i+1].checkBreak = true
+               }
+
+               // The loops along the way just need to break.
+               for j := i + 2; j < len(r.forStack); j++ {
+                       r.forStack[j].checkBreak = true
+               }
+
+               // Set next to break the appropriate number of times;
+               // the final time may be a continue, not a break.
+               next = 2 * depth
+               if x.Tok == syntax.Continue {
+                       next--
+               }
+       }
+
+       // Assign #next = next and do the return.
+       as := &syntax.AssignStmt{Lhs: r.next(), Rhs: r.intConst(next)}
+       bl := &syntax.BlockStmt{
+               List: []syntax.Stmt{as, ret},
+       }
+       setPos(bl, x.Pos())
+       return bl
+}
+
+// computeBranchNext computes the branchNext numbering
+// and determines which labels end up inside which range-over-func loop bodies.
+func (r *rewriter) computeBranchNext() {
+       if r.labelLoop != nil {
+               return
+       }
+
+       r.labelLoop = make(map[string]*syntax.ForStmt)
+       r.branchNext = make(map[branch]int)
+
+       var labels []string
+       var stack []syntax.Node
+       var forStack []*syntax.ForStmt
+       forStack = append(forStack, nil)
+       syntax.Inspect(r.body, func(n syntax.Node) bool {
+               if n != nil {
+                       stack = append(stack, n)
+                       if nfor, ok := forRangeFunc(n); ok {
+                               forStack = append(forStack, nfor)
+                       }
+                       if n, ok := n.(*syntax.LabeledStmt); ok {
+                               l := n.Label.Value
+                               labels = append(labels, l)
+                               f := forStack[len(forStack)-1]
+                               r.labelLoop[l] = f
+                       }
+               } else {
+                       n := stack[len(stack)-1]
+                       stack = stack[:len(stack)-1]
+                       if n == forStack[len(forStack)-1] {
+                               forStack = forStack[:len(forStack)-1]
+                       }
+               }
+               return true
+       })
+
+       // Assign numbers to all the labels we observed.
+       used := -2
+       for _, l := range labels {
+               used -= 3
+               r.branchNext[branch{syntax.Break, l}] = used
+               r.branchNext[branch{syntax.Continue, l}] = used + 1
+               r.branchNext[branch{syntax.Goto, l}] = used + 2
+       }
+}
+
+// endLoop finishes the conversion of a range-over-func loop.
+// We have inspected and rewritten the body of the loop and can now
+// construct the body function and rewrite the for loop into a call
+// bracketed by any declarations and checks it requires.
+func (r *rewriter) endLoop(loop *forLoop) {
+       // Pick apart for range X { ... }
+       nfor := loop.nfor
+       start, end := nfor.Pos(), nfor.Body.Rbrace // start, end position of for loop
+       rclause := nfor.Init.(*syntax.RangeClause)
+       rfunc := types2.CoreType(rclause.X.GetTypeInfo().Type).(*types2.Signature) // type of X - func(func(...)bool)
+       if rfunc.Params().Len() != 1 {
+               base.Fatalf("invalid typecheck of range func")
+       }
+       ftyp := rfunc.Params().At(0).Type().(*types2.Signature) // func(...) bool
+       if ftyp.Results().Len() != 1 {
+               base.Fatalf("invalid typecheck of range func")
+       }
+
+       // Build X(bodyFunc)
+       call := &syntax.ExprStmt{
+               X: &syntax.CallExpr{
+                       Fun: rclause.X,
+                       ArgList: []syntax.Expr{
+                               r.bodyFunc(nfor.Body.List, syntax.UnpackListExpr(rclause.Lhs), rclause.Def, ftyp, start, end),
+                       },
+               },
+       }
+       setPos(call, start)
+
+       // Build checks based on #next after X(bodyFunc)
+       checks := r.checks(loop, end)
+
+       // Rewrite for vars := range X { ... } to
+       //
+       //      {
+       //              r.declStmt
+       //              call
+       //              checks
+       //      }
+       //
+       // The r.declStmt can be added to by this loop or any inner loop
+       // during the creation of r.bodyFunc; it is only emitted in the outermost
+       // converted range loop.
+       block := &syntax.BlockStmt{Rbrace: end}
+       setPos(block, start)
+       if len(r.forStack) == 1 && r.declStmt != nil {
+               setPos(r.declStmt, start)
+               block.List = append(block.List, r.declStmt)
+       }
+       block.List = append(block.List, call)
+       block.List = append(block.List, checks...)
+
+       if len(r.forStack) == 1 { // ending an outermost loop
+               r.declStmt = nil
+               r.nextVar = nil
+               r.retVars = nil
+               r.defers = nil
+       }
+
+       r.rewritten[nfor] = block
+}
+
+// bodyFunc converts the loop body (control flow has already been updated)
+// to a func literal that can be passed to the range function.
+//
+// vars is the range variables from the range statement.
+// def indicates whether this is a := range statement.
+// ftyp is the type of the function we are creating
+// start and end are the syntax positions to use for new nodes
+// that should be at the start or end of the loop.
+func (r *rewriter) bodyFunc(body []syntax.Stmt, lhs []syntax.Expr, def bool, ftyp *types2.Signature, start, end syntax.Pos) *syntax.FuncLit {
+       // Starting X(bodyFunc); build up bodyFunc first.
+       var params, results []*types2.Var
+       results = append(results, types2.NewVar(start, nil, "", r.bool.Type()))
+       bodyFunc := &syntax.FuncLit{
+               // Note: Type is ignored but needs to be non-nil to avoid panic in syntax.Inspect.
+               Type: &syntax.FuncType{},
+               Body: &syntax.BlockStmt{
+                       List:   []syntax.Stmt{},
+                       Rbrace: end,
+               },
+       }
+       setPos(bodyFunc, start)
+
+       for i := 0; i < ftyp.Params().Len(); i++ {
+               typ := ftyp.Params().At(i).Type()
+               var paramVar *types2.Var
+               if i < len(lhs) && def {
+                       // Reuse range variable as parameter.
+                       x := lhs[i]
+                       paramVar = r.info.Defs[x.(*syntax.Name)].(*types2.Var)
+               } else {
+                       // Declare new parameter and assign it to range expression.
+                       paramVar = types2.NewVar(start, r.pkg, fmt.Sprintf("#p%d", 1+i), typ)
+                       if i < len(lhs) {
+                               x := lhs[i]
+                               as := &syntax.AssignStmt{Lhs: x, Rhs: r.useVar(paramVar)}
+                               as.SetPos(x.Pos())
+                               setPos(as.Rhs, x.Pos())
+                               bodyFunc.Body.List = append(bodyFunc.Body.List, as)
+                       }
+               }
+               params = append(params, paramVar)
+       }
+
+       tv := syntax.TypeAndValue{
+               Type: types2.NewSignatureType(nil, nil, nil,
+                       types2.NewTuple(params...),
+                       types2.NewTuple(results...),
+                       false),
+       }
+       tv.SetIsValue()
+       bodyFunc.SetTypeInfo(tv)
+
+       // Original loop body (already rewritten by editStmt during inspect).
+       bodyFunc.Body.List = append(bodyFunc.Body.List, body...)
+
+       // return true to continue at end of loop body
+       ret := &syntax.ReturnStmt{Results: r.useVar(r.true)}
+       ret.SetPos(end)
+       bodyFunc.Body.List = append(bodyFunc.Body.List, ret)
+
+       return bodyFunc
+}
+
+// checks returns the post-call checks that need to be done for the given loop.
+func (r *rewriter) checks(loop *forLoop, pos syntax.Pos) []syntax.Stmt {
+       var list []syntax.Stmt
+       if len(loop.checkBranch) > 0 {
+               did := make(map[branch]bool)
+               for _, br := range loop.checkBranch {
+                       if did[br] {
+                               continue
+                       }
+                       did[br] = true
+                       doBranch := &syntax.BranchStmt{Tok: br.tok, Label: &syntax.Name{Value: br.label}}
+                       list = append(list, r.ifNext(syntax.Eql, r.branchNext[br], doBranch))
+               }
+       }
+       if len(r.forStack) == 1 {
+               if loop.checkRetArgs {
+                       list = append(list, r.ifNext(syntax.Eql, -2, retStmt(r.useList(r.retVars))))
+               }
+               if loop.checkRet {
+                       list = append(list, r.ifNext(syntax.Eql, -1, retStmt(nil)))
+               }
+       } else {
+               if loop.checkRetArgs || loop.checkRet {
+                       // Note: next < 0 also handles gotos handled by outer loops.
+                       // We set checkRet in that case to trigger this check.
+                       list = append(list, r.ifNext(syntax.Lss, 0, retStmt(r.useVar(r.false))))
+               }
+               if loop.checkBreak {
+                       list = append(list, r.ifNext(syntax.Geq, 2, retStmt(r.useVar(r.false))))
+               }
+               if loop.checkContinue {
+                       list = append(list, r.ifNext(syntax.Eql, 1, retStmt(r.useVar(r.true))))
+               }
+       }
+
+       for _, j := range list {
+               setPos(j, pos)
+       }
+       return list
+}
+
+// retStmt returns a return statement returning the given return values.
+func retStmt(results syntax.Expr) *syntax.ReturnStmt {
+       return &syntax.ReturnStmt{Results: results}
+}
+
+// ifNext returns the statement:
+//
+//     if #next op c { adjust; then }
+//
+// When op is >=, adjust is #next -= c.
+// When op is == and c is not -1 or -2, adjust is #next = 0.
+// Otherwise adjust is omitted.
+func (r *rewriter) ifNext(op syntax.Operator, c int, then syntax.Stmt) syntax.Stmt {
+       nif := &syntax.IfStmt{
+               Cond: &syntax.Operation{Op: op, X: r.next(), Y: r.intConst(c)},
+               Then: &syntax.BlockStmt{
+                       List: []syntax.Stmt{then},
+               },
+       }
+       tv := syntax.TypeAndValue{Type: r.bool.Type()}
+       tv.SetIsValue()
+       nif.Cond.SetTypeInfo(tv)
+
+       if op == syntax.Geq {
+               sub := &syntax.AssignStmt{
+                       Op:  syntax.Sub,
+                       Lhs: r.next(),
+                       Rhs: r.intConst(c),
+               }
+               nif.Then.List = []syntax.Stmt{sub, then}
+       }
+       if op == syntax.Eql && c != -1 && c != -2 {
+               clr := &syntax.AssignStmt{
+                       Lhs: r.next(),
+                       Rhs: r.intConst(0),
+               }
+               nif.Then.List = []syntax.Stmt{clr, then}
+       }
+
+       return nif
+}
+
+// next returns a reference to the #next variable.
+func (r *rewriter) next() *syntax.Name {
+       if r.nextVar == nil {
+               r.nextVar = r.declVar("#next", r.int.Type(), nil)
+       }
+       return r.useVar(r.nextVar)
+}
+
+// forRangeFunc checks whether n is a range-over-func.
+// If so, it returns n.(*syntax.ForStmt), true.
+// Otherwise it returns nil, false.
+func forRangeFunc(n syntax.Node) (*syntax.ForStmt, bool) {
+       nfor, ok := n.(*syntax.ForStmt)
+       if !ok {
+               return nil, false
+       }
+       nrange, ok := nfor.Init.(*syntax.RangeClause)
+       if !ok {
+               return nil, false
+       }
+       _, ok = types2.CoreType(nrange.X.GetTypeInfo().Type).(*types2.Signature)
+       if !ok {
+               return nil, false
+       }
+       return nfor, true
+}
+
+// intConst returns syntax for an integer literal with the given value.
+func (r *rewriter) intConst(c int) *syntax.BasicLit {
+       lit := &syntax.BasicLit{
+               Value: fmt.Sprint(c),
+               Kind:  syntax.IntLit,
+       }
+       tv := syntax.TypeAndValue{Type: r.int.Type(), Value: constant.MakeInt64(int64(c))}
+       tv.SetIsValue()
+       lit.SetTypeInfo(tv)
+       return lit
+}
+
+// useVar returns syntax for a reference to decl, which should be its declaration.
+func (r *rewriter) useVar(obj types2.Object) *syntax.Name {
+       n := syntax.NewName(nopos, obj.Name())
+       tv := syntax.TypeAndValue{Type: obj.Type()}
+       tv.SetIsValue()
+       n.SetTypeInfo(tv)
+       r.info.Uses[n] = obj
+       return n
+}
+
+// useList is useVar for a list of decls.
+func (r *rewriter) useList(vars []types2.Object) syntax.Expr {
+       var new []syntax.Expr
+       for _, obj := range vars {
+               new = append(new, r.useVar(obj))
+       }
+       if len(new) == 1 {
+               return new[0]
+       }
+       return &syntax.ListExpr{ElemList: new}
+}
+
+// declVar declares a variable with a given name type and initializer value.
+func (r *rewriter) declVar(name string, typ types2.Type, init syntax.Expr) *types2.Var {
+       if r.declStmt == nil {
+               r.declStmt = &syntax.DeclStmt{}
+       }
+       stmt := r.declStmt
+       obj := types2.NewVar(stmt.Pos(), r.pkg, name, typ)
+       n := syntax.NewName(stmt.Pos(), name)
+       tv := syntax.TypeAndValue{Type: typ}
+       tv.SetIsValue()
+       n.SetTypeInfo(tv)
+       r.info.Defs[n] = obj
+       stmt.DeclList = append(stmt.DeclList, &syntax.VarDecl{
+               NameList: []*syntax.Name{n},
+               // Note: Type is ignored
+               Values: init,
+       })
+       return obj
+}
+
+// declType declares a type with the given name and type.
+// This is more like "type name = typ" than "type name typ".
+func declType(pos syntax.Pos, name string, typ types2.Type) *syntax.Name {
+       n := syntax.NewName(pos, name)
+       n.SetTypeInfo(syntax.TypeAndValue{Type: typ})
+       return n
+}
+
+// runtimePkg is a fake runtime package that contains what we need to refer to in package runtime.
+var runtimePkg = func() *types2.Package {
+       var nopos syntax.Pos
+       pkg := types2.NewPackage("runtime", "runtime")
+       anyType := types2.Universe.Lookup("any").Type()
+
+       // func deferrangefunc() unsafe.Pointer
+       obj := types2.NewVar(nopos, pkg, "deferrangefunc", types2.NewSignatureType(nil, nil, nil, nil, types2.NewTuple(types2.NewParam(nopos, pkg, "extra", anyType)), false))
+       pkg.Scope().Insert(obj)
+
+       return pkg
+}()
+
+// runtimeSym returns a reference to a symbol in the fake runtime package.
+func runtimeSym(info *types2.Info, name string) *syntax.Name {
+       obj := runtimePkg.Scope().Lookup(name)
+       n := syntax.NewName(nopos, "runtime."+name)
+       tv := syntax.TypeAndValue{Type: obj.Type()}
+       tv.SetIsValue()
+       n.SetTypeInfo(tv)
+       info.Uses[n] = obj
+       return n
+}
+
+// setPos walks the top structure of x that has no position assigned
+// and assigns it all to have position pos.
+// When setPos encounters a syntax node with a position assigned,
+// setPos does not look inside that node.
+// setPos only needs to handle syntax we create in this package;
+// all other syntax should have positions assigned already.
+func setPos(x syntax.Node, pos syntax.Pos) {
+       if x == nil {
+               return
+       }
+       syntax.Inspect(x, func(n syntax.Node) bool {
+               if n == nil || n.Pos() != nopos {
+                       return false
+               }
+               n.SetPos(pos)
+               switch n := n.(type) {
+               case *syntax.BlockStmt:
+                       if n.Rbrace == nopos {
+                               n.Rbrace = pos
+                       }
+               }
+               return true
+       })
+}
index d79c6fdd00b9c3a5a1f24c4eb4c0d5c37100f556..a9d80552cb3d03179f929df82203cd6333acaf93 100644 (file)
@@ -103,6 +103,7 @@ func InitConfig() {
        ir.Syms.CgoCheckPtrWrite = typecheck.LookupRuntimeFunc("cgoCheckPtrWrite")
        ir.Syms.CheckPtrAlignment = typecheck.LookupRuntimeFunc("checkptrAlignment")
        ir.Syms.Deferproc = typecheck.LookupRuntimeFunc("deferproc")
+       ir.Syms.Deferprocat = typecheck.LookupRuntimeFunc("deferprocat")
        ir.Syms.DeferprocStack = typecheck.LookupRuntimeFunc("deferprocStack")
        ir.Syms.Deferreturn = typecheck.LookupRuntimeFunc("deferreturn")
        ir.Syms.Duffcopy = typecheck.LookupRuntimeFunc("duffcopy")
@@ -1491,10 +1492,10 @@ func (s *state) stmt(n ir.Node) {
                        s.openDeferRecord(n.Call.(*ir.CallExpr))
                } else {
                        d := callDefer
-                       if n.Esc() == ir.EscNever {
+                       if n.Esc() == ir.EscNever && n.DeferAt == nil {
                                d = callDeferStack
                        }
-                       s.callResult(n.Call.(*ir.CallExpr), d)
+                       s.call(n.Call.(*ir.CallExpr), d, false, n.DeferAt)
                }
        case ir.OGO:
                n := n.(*ir.GoDeferStmt)
@@ -5182,20 +5183,21 @@ func (s *state) openDeferExit() {
 }
 
 func (s *state) callResult(n *ir.CallExpr, k callKind) *ssa.Value {
-       return s.call(n, k, false)
+       return s.call(n, k, false, nil)
 }
 
 func (s *state) callAddr(n *ir.CallExpr, k callKind) *ssa.Value {
-       return s.call(n, k, true)
+       return s.call(n, k, true, nil)
 }
 
 // Calls the function n using the specified call type.
 // Returns the address of the return value (or nil if none).
-func (s *state) call(n *ir.CallExpr, k callKind, returnResultAddr bool) *ssa.Value {
+func (s *state) call(n *ir.CallExpr, k callKind, returnResultAddr bool, deferExtra ir.Expr) *ssa.Value {
        s.prevCall = nil
        var callee *ir.Name    // target function (if static)
        var closure *ssa.Value // ptr to closure to run (if dynamic)
        var codeptr *ssa.Value // ptr to target code (if dynamic)
+       var dextra *ssa.Value  // defer extra arg
        var rcvr *ssa.Value    // receiver to set
        fn := n.X
        var ACArgs []*types.Type    // AuxCall args
@@ -5251,6 +5253,9 @@ func (s *state) call(n *ir.CallExpr, k callKind, returnResultAddr bool) *ssa.Val
                        closure = iclosure
                }
        }
+       if deferExtra != nil {
+               dextra = s.expr(deferExtra)
+       }
 
        params := callABI.ABIAnalyze(n.X.Type(), false /* Do not set (register) nNames from caller side -- can cause races. */)
        types.CalcSize(fn.Type())
@@ -5293,6 +5298,13 @@ func (s *state) call(n *ir.CallExpr, k callKind, returnResultAddr bool) *ssa.Val
                        callArgs = append(callArgs, closure)
                        stksize += int64(types.PtrSize)
                        argStart += int64(types.PtrSize)
+                       if dextra != nil {
+                               // Extra token of type any for deferproc
+                               ACArgs = append(ACArgs, types.Types[types.TINTER])
+                               callArgs = append(callArgs, dextra)
+                               stksize += 2 * int64(types.PtrSize)
+                               argStart += 2 * int64(types.PtrSize)
+                       }
                }
 
                // Set receiver (for interface calls).
@@ -5328,11 +5340,15 @@ func (s *state) call(n *ir.CallExpr, k callKind, returnResultAddr bool) *ssa.Val
                // call target
                switch {
                case k == callDefer:
-                       aux := ssa.StaticAuxCall(ir.Syms.Deferproc, s.f.ABIDefault.ABIAnalyzeTypes(ACArgs, ACResults)) // TODO paramResultInfo for DeferProc
+                       sym := ir.Syms.Deferproc
+                       if dextra != nil {
+                               sym = ir.Syms.Deferprocat
+                       }
+                       aux := ssa.StaticAuxCall(sym, s.f.ABIDefault.ABIAnalyzeTypes(ACArgs, ACResults)) // TODO paramResultInfo for Deferproc(at)
                        call = s.newValue0A(ssa.OpStaticLECall, aux.LateExpansionResultType(), aux)
                case k == callGo:
                        aux := ssa.StaticAuxCall(ir.Syms.Newproc, s.f.ABIDefault.ABIAnalyzeTypes(ACArgs, ACResults))
-                       call = s.newValue0A(ssa.OpStaticLECall, aux.LateExpansionResultType(), aux) // TODO paramResultInfo for NewProc
+                       call = s.newValue0A(ssa.OpStaticLECall, aux.LateExpansionResultType(), aux) // TODO paramResultInfo for Newproc
                case closure != nil:
                        // rawLoad because loading the code pointer from a
                        // closure is always safe, but IsSanitizerSafeAddr
index 6580f053c778c7f42024b67ab1fc3d05edd32af8..de277fc3d8cdabe845538ee6c6fb5629ee97932a 100644 (file)
@@ -17,6 +17,7 @@ type Node interface {
        //    associated with that production; usually the left-most one
        //    ('[' for IndexExpr, 'if' for IfStmt, etc.)
        Pos() Pos
+       SetPos(Pos)
        aNode()
 }
 
@@ -26,8 +27,9 @@ type node struct {
        pos Pos
 }
 
-func (n *node) Pos() Pos { return n.pos }
-func (*node) aNode()     {}
+func (n *node) Pos() Pos       { return n.pos }
+func (n *node) SetPos(pos Pos) { n.pos = pos }
+func (*node) aNode()           {}
 
 // ----------------------------------------------------------------------------
 // Files
@@ -389,8 +391,9 @@ type (
        }
 
        CallStmt struct {
-               Tok  token // Go or Defer
-               Call Expr
+               Tok     token // Go or Defer
+               Call    Expr
+               DeferAt Expr // argument to runtime.deferprocat
                stmt
        }
 
index 6dece1aa5bacbfe80af0c8f465ada3f89b963612..b08f699582fb6595fcd6046e31ee857f3766e808 100644 (file)
@@ -4,7 +4,9 @@
 
 package syntax
 
-type token uint
+type Token uint
+
+type token = Token
 
 //go:generate stringer -type token -linecomment tokens.go
 
index 850873dfa788c67eef45f5027689b75731a9a950..6bace1e6bbba8ba9eb10360fb4172a4150017431 100644 (file)
@@ -117,6 +117,9 @@ func panicnildottype(want *byte)
 func ifaceeq(tab *uintptr, x, y unsafe.Pointer) (ret bool)
 func efaceeq(typ *uintptr, x, y unsafe.Pointer) (ret bool)
 
+// defer in range over func
+func deferrangefunc() interface{}
+
 func fastrand() uint32
 
 // *byte is really *runtime.Type
index 48c27566e5eee055b99e9951dc18a586491e2604..cbf1a4275210e44b271bbdb369154f47d5328863 100644 (file)
@@ -103,139 +103,140 @@ var runtimeDecls = [...]struct {
        {"panicnildottype", funcTag, 72},
        {"ifaceeq", funcTag, 73},
        {"efaceeq", funcTag, 73},
-       {"fastrand", funcTag, 74},
-       {"makemap64", funcTag, 76},
-       {"makemap", funcTag, 77},
-       {"makemap_small", funcTag, 78},
-       {"mapaccess1", funcTag, 79},
-       {"mapaccess1_fast32", funcTag, 80},
-       {"mapaccess1_fast64", funcTag, 81},
-       {"mapaccess1_faststr", funcTag, 82},
-       {"mapaccess1_fat", funcTag, 83},
-       {"mapaccess2", funcTag, 84},
-       {"mapaccess2_fast32", funcTag, 85},
-       {"mapaccess2_fast64", funcTag, 86},
-       {"mapaccess2_faststr", funcTag, 87},
-       {"mapaccess2_fat", funcTag, 88},
-       {"mapassign", funcTag, 79},
-       {"mapassign_fast32", funcTag, 80},
-       {"mapassign_fast32ptr", funcTag, 89},
-       {"mapassign_fast64", funcTag, 81},
-       {"mapassign_fast64ptr", funcTag, 89},
-       {"mapassign_faststr", funcTag, 82},
-       {"mapiterinit", funcTag, 90},
-       {"mapdelete", funcTag, 90},
-       {"mapdelete_fast32", funcTag, 91},
-       {"mapdelete_fast64", funcTag, 92},
-       {"mapdelete_faststr", funcTag, 93},
-       {"mapiternext", funcTag, 94},
-       {"mapclear", funcTag, 95},
-       {"makechan64", funcTag, 97},
-       {"makechan", funcTag, 98},
-       {"chanrecv1", funcTag, 100},
-       {"chanrecv2", funcTag, 101},
-       {"chansend1", funcTag, 103},
+       {"deferrangefunc", funcTag, 74},
+       {"fastrand", funcTag, 75},
+       {"makemap64", funcTag, 77},
+       {"makemap", funcTag, 78},
+       {"makemap_small", funcTag, 79},
+       {"mapaccess1", funcTag, 80},
+       {"mapaccess1_fast32", funcTag, 81},
+       {"mapaccess1_fast64", funcTag, 82},
+       {"mapaccess1_faststr", funcTag, 83},
+       {"mapaccess1_fat", funcTag, 84},
+       {"mapaccess2", funcTag, 85},
+       {"mapaccess2_fast32", funcTag, 86},
+       {"mapaccess2_fast64", funcTag, 87},
+       {"mapaccess2_faststr", funcTag, 88},
+       {"mapaccess2_fat", funcTag, 89},
+       {"mapassign", funcTag, 80},
+       {"mapassign_fast32", funcTag, 81},
+       {"mapassign_fast32ptr", funcTag, 90},
+       {"mapassign_fast64", funcTag, 82},
+       {"mapassign_fast64ptr", funcTag, 90},
+       {"mapassign_faststr", funcTag, 83},
+       {"mapiterinit", funcTag, 91},
+       {"mapdelete", funcTag, 91},
+       {"mapdelete_fast32", funcTag, 92},
+       {"mapdelete_fast64", funcTag, 93},
+       {"mapdelete_faststr", funcTag, 94},
+       {"mapiternext", funcTag, 95},
+       {"mapclear", funcTag, 96},
+       {"makechan64", funcTag, 98},
+       {"makechan", funcTag, 99},
+       {"chanrecv1", funcTag, 101},
+       {"chanrecv2", funcTag, 102},
+       {"chansend1", funcTag, 104},
        {"closechan", funcTag, 30},
-       {"writeBarrier", varTag, 105},
-       {"typedmemmove", funcTag, 106},
-       {"typedmemclr", funcTag, 107},
-       {"typedslicecopy", funcTag, 108},
-       {"selectnbsend", funcTag, 109},
-       {"selectnbrecv", funcTag, 110},
-       {"selectsetpc", funcTag, 111},
-       {"selectgo", funcTag, 112},
+       {"writeBarrier", varTag, 106},
+       {"typedmemmove", funcTag, 107},
+       {"typedmemclr", funcTag, 108},
+       {"typedslicecopy", funcTag, 109},
+       {"selectnbsend", funcTag, 110},
+       {"selectnbrecv", funcTag, 111},
+       {"selectsetpc", funcTag, 112},
+       {"selectgo", funcTag, 113},
        {"block", funcTag, 9},
-       {"makeslice", funcTag, 113},
-       {"makeslice64", funcTag, 114},
-       {"makeslicecopy", funcTag, 115},
-       {"growslice", funcTag, 117},
-       {"unsafeslicecheckptr", funcTag, 118},
+       {"makeslice", funcTag, 114},
+       {"makeslice64", funcTag, 115},
+       {"makeslicecopy", funcTag, 116},
+       {"growslice", funcTag, 118},
+       {"unsafeslicecheckptr", funcTag, 119},
        {"panicunsafeslicelen", funcTag, 9},
        {"panicunsafeslicenilptr", funcTag, 9},
-       {"unsafestringcheckptr", funcTag, 119},
+       {"unsafestringcheckptr", funcTag, 120},
        {"panicunsafestringlen", funcTag, 9},
        {"panicunsafestringnilptr", funcTag, 9},
-       {"mulUintptr", funcTag, 120},
-       {"memmove", funcTag, 121},
-       {"memclrNoHeapPointers", funcTag, 122},
-       {"memclrHasPointers", funcTag, 122},
-       {"memequal", funcTag, 123},
-       {"memequal0", funcTag, 124},
-       {"memequal8", funcTag, 124},
-       {"memequal16", funcTag, 124},
-       {"memequal32", funcTag, 124},
-       {"memequal64", funcTag, 124},
-       {"memequal128", funcTag, 124},
-       {"f32equal", funcTag, 125},
-       {"f64equal", funcTag, 125},
-       {"c64equal", funcTag, 125},
-       {"c128equal", funcTag, 125},
-       {"strequal", funcTag, 125},
-       {"interequal", funcTag, 125},
-       {"nilinterequal", funcTag, 125},
-       {"memhash", funcTag, 126},
-       {"memhash0", funcTag, 127},
-       {"memhash8", funcTag, 127},
-       {"memhash16", funcTag, 127},
-       {"memhash32", funcTag, 127},
-       {"memhash64", funcTag, 127},
-       {"memhash128", funcTag, 127},
-       {"f32hash", funcTag, 128},
-       {"f64hash", funcTag, 128},
-       {"c64hash", funcTag, 128},
-       {"c128hash", funcTag, 128},
-       {"strhash", funcTag, 128},
-       {"interhash", funcTag, 128},
-       {"nilinterhash", funcTag, 128},
-       {"int64div", funcTag, 129},
-       {"uint64div", funcTag, 130},
-       {"int64mod", funcTag, 129},
-       {"uint64mod", funcTag, 130},
-       {"float64toint64", funcTag, 131},
-       {"float64touint64", funcTag, 132},
-       {"float64touint32", funcTag, 133},
-       {"int64tofloat64", funcTag, 134},
-       {"int64tofloat32", funcTag, 136},
-       {"uint64tofloat64", funcTag, 137},
-       {"uint64tofloat32", funcTag, 138},
-       {"uint32tofloat64", funcTag, 139},
-       {"complex128div", funcTag, 140},
-       {"getcallerpc", funcTag, 141},
-       {"getcallersp", funcTag, 141},
+       {"mulUintptr", funcTag, 121},
+       {"memmove", funcTag, 122},
+       {"memclrNoHeapPointers", funcTag, 123},
+       {"memclrHasPointers", funcTag, 123},
+       {"memequal", funcTag, 124},
+       {"memequal0", funcTag, 125},
+       {"memequal8", funcTag, 125},
+       {"memequal16", funcTag, 125},
+       {"memequal32", funcTag, 125},
+       {"memequal64", funcTag, 125},
+       {"memequal128", funcTag, 125},
+       {"f32equal", funcTag, 126},
+       {"f64equal", funcTag, 126},
+       {"c64equal", funcTag, 126},
+       {"c128equal", funcTag, 126},
+       {"strequal", funcTag, 126},
+       {"interequal", funcTag, 126},
+       {"nilinterequal", funcTag, 126},
+       {"memhash", funcTag, 127},
+       {"memhash0", funcTag, 128},
+       {"memhash8", funcTag, 128},
+       {"memhash16", funcTag, 128},
+       {"memhash32", funcTag, 128},
+       {"memhash64", funcTag, 128},
+       {"memhash128", funcTag, 128},
+       {"f32hash", funcTag, 129},
+       {"f64hash", funcTag, 129},
+       {"c64hash", funcTag, 129},
+       {"c128hash", funcTag, 129},
+       {"strhash", funcTag, 129},
+       {"interhash", funcTag, 129},
+       {"nilinterhash", funcTag, 129},
+       {"int64div", funcTag, 130},
+       {"uint64div", funcTag, 131},
+       {"int64mod", funcTag, 130},
+       {"uint64mod", funcTag, 131},
+       {"float64toint64", funcTag, 132},
+       {"float64touint64", funcTag, 133},
+       {"float64touint32", funcTag, 134},
+       {"int64tofloat64", funcTag, 135},
+       {"int64tofloat32", funcTag, 137},
+       {"uint64tofloat64", funcTag, 138},
+       {"uint64tofloat32", funcTag, 139},
+       {"uint32tofloat64", funcTag, 140},
+       {"complex128div", funcTag, 141},
+       {"getcallerpc", funcTag, 142},
+       {"getcallersp", funcTag, 142},
        {"racefuncenter", funcTag, 31},
        {"racefuncexit", funcTag, 9},
        {"raceread", funcTag, 31},
        {"racewrite", funcTag, 31},
-       {"racereadrange", funcTag, 142},
-       {"racewriterange", funcTag, 142},
-       {"msanread", funcTag, 142},
-       {"msanwrite", funcTag, 142},
-       {"msanmove", funcTag, 143},
-       {"asanread", funcTag, 142},
-       {"asanwrite", funcTag, 142},
-       {"checkptrAlignment", funcTag, 144},
-       {"checkptrArithmetic", funcTag, 146},
-       {"libfuzzerTraceCmp1", funcTag, 147},
-       {"libfuzzerTraceCmp2", funcTag, 148},
-       {"libfuzzerTraceCmp4", funcTag, 149},
-       {"libfuzzerTraceCmp8", funcTag, 150},
-       {"libfuzzerTraceConstCmp1", funcTag, 147},
-       {"libfuzzerTraceConstCmp2", funcTag, 148},
-       {"libfuzzerTraceConstCmp4", funcTag, 149},
-       {"libfuzzerTraceConstCmp8", funcTag, 150},
-       {"libfuzzerHookStrCmp", funcTag, 151},
-       {"libfuzzerHookEqualFold", funcTag, 151},
-       {"addCovMeta", funcTag, 153},
+       {"racereadrange", funcTag, 143},
+       {"racewriterange", funcTag, 143},
+       {"msanread", funcTag, 143},
+       {"msanwrite", funcTag, 143},
+       {"msanmove", funcTag, 144},
+       {"asanread", funcTag, 143},
+       {"asanwrite", funcTag, 143},
+       {"checkptrAlignment", funcTag, 145},
+       {"checkptrArithmetic", funcTag, 147},
+       {"libfuzzerTraceCmp1", funcTag, 148},
+       {"libfuzzerTraceCmp2", funcTag, 149},
+       {"libfuzzerTraceCmp4", funcTag, 150},
+       {"libfuzzerTraceCmp8", funcTag, 151},
+       {"libfuzzerTraceConstCmp1", funcTag, 148},
+       {"libfuzzerTraceConstCmp2", funcTag, 149},
+       {"libfuzzerTraceConstCmp4", funcTag, 150},
+       {"libfuzzerTraceConstCmp8", funcTag, 151},
+       {"libfuzzerHookStrCmp", funcTag, 152},
+       {"libfuzzerHookEqualFold", funcTag, 152},
+       {"addCovMeta", funcTag, 154},
        {"x86HasPOPCNT", varTag, 6},
        {"x86HasSSE41", varTag, 6},
        {"x86HasFMA", varTag, 6},
        {"armHasVFPv4", varTag, 6},
        {"arm64HasATOMICS", varTag, 6},
-       {"asanregisterglobals", funcTag, 122},
+       {"asanregisterglobals", funcTag, 123},
 }
 
 func runtimeTypes() []*types.Type {
-       var typs [154]*types.Type
+       var typs [155]*types.Type
        typs[0] = types.ByteType
        typs[1] = types.NewPtr(typs[0])
        typs[2] = types.Types[types.TANY]
@@ -310,86 +311,87 @@ func runtimeTypes() []*types.Type {
        typs[71] = newSig(params(typs[1], typs[1], typs[1]), nil)
        typs[72] = newSig(params(typs[1]), nil)
        typs[73] = newSig(params(typs[57], typs[7], typs[7]), params(typs[6]))
-       typs[74] = newSig(nil, params(typs[62]))
-       typs[75] = types.NewMap(typs[2], typs[2])
-       typs[76] = newSig(params(typs[1], typs[22], typs[3]), params(typs[75]))
-       typs[77] = newSig(params(typs[1], typs[15], typs[3]), params(typs[75]))
-       typs[78] = newSig(nil, params(typs[75]))
-       typs[79] = newSig(params(typs[1], typs[75], typs[3]), params(typs[3]))
-       typs[80] = newSig(params(typs[1], typs[75], typs[62]), params(typs[3]))
-       typs[81] = newSig(params(typs[1], typs[75], typs[24]), params(typs[3]))
-       typs[82] = newSig(params(typs[1], typs[75], typs[28]), params(typs[3]))
-       typs[83] = newSig(params(typs[1], typs[75], typs[3], typs[1]), params(typs[3]))
-       typs[84] = newSig(params(typs[1], typs[75], typs[3]), params(typs[3], typs[6]))
-       typs[85] = newSig(params(typs[1], typs[75], typs[62]), params(typs[3], typs[6]))
-       typs[86] = newSig(params(typs[1], typs[75], typs[24]), params(typs[3], typs[6]))
-       typs[87] = newSig(params(typs[1], typs[75], typs[28]), params(typs[3], typs[6]))
-       typs[88] = newSig(params(typs[1], typs[75], typs[3], typs[1]), params(typs[3], typs[6]))
-       typs[89] = newSig(params(typs[1], typs[75], typs[7]), params(typs[3]))
-       typs[90] = newSig(params(typs[1], typs[75], typs[3]), nil)
-       typs[91] = newSig(params(typs[1], typs[75], typs[62]), nil)
-       typs[92] = newSig(params(typs[1], typs[75], typs[24]), nil)
-       typs[93] = newSig(params(typs[1], typs[75], typs[28]), nil)
-       typs[94] = newSig(params(typs[3]), nil)
-       typs[95] = newSig(params(typs[1], typs[75]), nil)
-       typs[96] = types.NewChan(typs[2], types.Cboth)
-       typs[97] = newSig(params(typs[1], typs[22]), params(typs[96]))
-       typs[98] = newSig(params(typs[1], typs[15]), params(typs[96]))
-       typs[99] = types.NewChan(typs[2], types.Crecv)
-       typs[100] = newSig(params(typs[99], typs[3]), nil)
-       typs[101] = newSig(params(typs[99], typs[3]), params(typs[6]))
-       typs[102] = types.NewChan(typs[2], types.Csend)
-       typs[103] = newSig(params(typs[102], typs[3]), nil)
-       typs[104] = types.NewArray(typs[0], 3)
-       typs[105] = types.NewStruct([]*types.Field{types.NewField(src.NoXPos, Lookup("enabled"), typs[6]), types.NewField(src.NoXPos, Lookup("pad"), typs[104]), types.NewField(src.NoXPos, Lookup("needed"), typs[6]), types.NewField(src.NoXPos, Lookup("cgo"), typs[6]), types.NewField(src.NoXPos, Lookup("alignme"), typs[24])})
-       typs[106] = newSig(params(typs[1], typs[3], typs[3]), nil)
-       typs[107] = newSig(params(typs[1], typs[3]), nil)
-       typs[108] = newSig(params(typs[1], typs[3], typs[15], typs[3], typs[15]), params(typs[15]))
-       typs[109] = newSig(params(typs[102], typs[3]), params(typs[6]))
-       typs[110] = newSig(params(typs[3], typs[99]), params(typs[6], typs[6]))
-       typs[111] = newSig(params(typs[57]), nil)
-       typs[112] = newSig(params(typs[1], typs[1], typs[57], typs[15], typs[15], typs[6]), params(typs[15], typs[6]))
-       typs[113] = newSig(params(typs[1], typs[15], typs[15]), params(typs[7]))
-       typs[114] = newSig(params(typs[1], typs[22], typs[22]), params(typs[7]))
-       typs[115] = newSig(params(typs[1], typs[15], typs[15], typs[7]), params(typs[7]))
-       typs[116] = types.NewSlice(typs[2])
-       typs[117] = newSig(params(typs[3], typs[15], typs[15], typs[15], typs[1]), params(typs[116]))
-       typs[118] = newSig(params(typs[1], typs[7], typs[22]), nil)
-       typs[119] = newSig(params(typs[7], typs[22]), nil)
-       typs[120] = newSig(params(typs[5], typs[5]), params(typs[5], typs[6]))
-       typs[121] = newSig(params(typs[3], typs[3], typs[5]), nil)
-       typs[122] = newSig(params(typs[7], typs[5]), nil)
-       typs[123] = newSig(params(typs[3], typs[3], typs[5]), params(typs[6]))
-       typs[124] = newSig(params(typs[3], typs[3]), params(typs[6]))
-       typs[125] = newSig(params(typs[7], typs[7]), params(typs[6]))
-       typs[126] = newSig(params(typs[3], typs[5], typs[5]), params(typs[5]))
-       typs[127] = newSig(params(typs[7], typs[5]), params(typs[5]))
-       typs[128] = newSig(params(typs[3], typs[5]), params(typs[5]))
-       typs[129] = newSig(params(typs[22], typs[22]), params(typs[22]))
-       typs[130] = newSig(params(typs[24], typs[24]), params(typs[24]))
-       typs[131] = newSig(params(typs[20]), params(typs[22]))
-       typs[132] = newSig(params(typs[20]), params(typs[24]))
-       typs[133] = newSig(params(typs[20]), params(typs[62]))
-       typs[134] = newSig(params(typs[22]), params(typs[20]))
-       typs[135] = types.Types[types.TFLOAT32]
-       typs[136] = newSig(params(typs[22]), params(typs[135]))
-       typs[137] = newSig(params(typs[24]), params(typs[20]))
-       typs[138] = newSig(params(typs[24]), params(typs[135]))
-       typs[139] = newSig(params(typs[62]), params(typs[20]))
-       typs[140] = newSig(params(typs[26], typs[26]), params(typs[26]))
-       typs[141] = newSig(nil, params(typs[5]))
-       typs[142] = newSig(params(typs[5], typs[5]), nil)
-       typs[143] = newSig(params(typs[5], typs[5], typs[5]), nil)
-       typs[144] = newSig(params(typs[7], typs[1], typs[5]), nil)
-       typs[145] = types.NewSlice(typs[7])
-       typs[146] = newSig(params(typs[7], typs[145]), nil)
-       typs[147] = newSig(params(typs[66], typs[66], typs[17]), nil)
-       typs[148] = newSig(params(typs[60], typs[60], typs[17]), nil)
-       typs[149] = newSig(params(typs[62], typs[62], typs[17]), nil)
-       typs[150] = newSig(params(typs[24], typs[24], typs[17]), nil)
-       typs[151] = newSig(params(typs[28], typs[28], typs[17]), nil)
-       typs[152] = types.NewArray(typs[0], 16)
-       typs[153] = newSig(params(typs[7], typs[62], typs[152], typs[28], typs[15], typs[66], typs[66]), params(typs[62]))
+       typs[74] = newSig(nil, params(typs[10]))
+       typs[75] = newSig(nil, params(typs[62]))
+       typs[76] = types.NewMap(typs[2], typs[2])
+       typs[77] = newSig(params(typs[1], typs[22], typs[3]), params(typs[76]))
+       typs[78] = newSig(params(typs[1], typs[15], typs[3]), params(typs[76]))
+       typs[79] = newSig(nil, params(typs[76]))
+       typs[80] = newSig(params(typs[1], typs[76], typs[3]), params(typs[3]))
+       typs[81] = newSig(params(typs[1], typs[76], typs[62]), params(typs[3]))
+       typs[82] = newSig(params(typs[1], typs[76], typs[24]), params(typs[3]))
+       typs[83] = newSig(params(typs[1], typs[76], typs[28]), params(typs[3]))
+       typs[84] = newSig(params(typs[1], typs[76], typs[3], typs[1]), params(typs[3]))
+       typs[85] = newSig(params(typs[1], typs[76], typs[3]), params(typs[3], typs[6]))
+       typs[86] = newSig(params(typs[1], typs[76], typs[62]), params(typs[3], typs[6]))
+       typs[87] = newSig(params(typs[1], typs[76], typs[24]), params(typs[3], typs[6]))
+       typs[88] = newSig(params(typs[1], typs[76], typs[28]), params(typs[3], typs[6]))
+       typs[89] = newSig(params(typs[1], typs[76], typs[3], typs[1]), params(typs[3], typs[6]))
+       typs[90] = newSig(params(typs[1], typs[76], typs[7]), params(typs[3]))
+       typs[91] = newSig(params(typs[1], typs[76], typs[3]), nil)
+       typs[92] = newSig(params(typs[1], typs[76], typs[62]), nil)
+       typs[93] = newSig(params(typs[1], typs[76], typs[24]), nil)
+       typs[94] = newSig(params(typs[1], typs[76], typs[28]), nil)
+       typs[95] = newSig(params(typs[3]), nil)
+       typs[96] = newSig(params(typs[1], typs[76]), nil)
+       typs[97] = types.NewChan(typs[2], types.Cboth)
+       typs[98] = newSig(params(typs[1], typs[22]), params(typs[97]))
+       typs[99] = newSig(params(typs[1], typs[15]), params(typs[97]))
+       typs[100] = types.NewChan(typs[2], types.Crecv)
+       typs[101] = newSig(params(typs[100], typs[3]), nil)
+       typs[102] = newSig(params(typs[100], typs[3]), params(typs[6]))
+       typs[103] = types.NewChan(typs[2], types.Csend)
+       typs[104] = newSig(params(typs[103], typs[3]), nil)
+       typs[105] = types.NewArray(typs[0], 3)
+       typs[106] = types.NewStruct([]*types.Field{types.NewField(src.NoXPos, Lookup("enabled"), typs[6]), types.NewField(src.NoXPos, Lookup("pad"), typs[105]), types.NewField(src.NoXPos, Lookup("needed"), typs[6]), types.NewField(src.NoXPos, Lookup("cgo"), typs[6]), types.NewField(src.NoXPos, Lookup("alignme"), typs[24])})
+       typs[107] = newSig(params(typs[1], typs[3], typs[3]), nil)
+       typs[108] = newSig(params(typs[1], typs[3]), nil)
+       typs[109] = newSig(params(typs[1], typs[3], typs[15], typs[3], typs[15]), params(typs[15]))
+       typs[110] = newSig(params(typs[103], typs[3]), params(typs[6]))
+       typs[111] = newSig(params(typs[3], typs[100]), params(typs[6], typs[6]))
+       typs[112] = newSig(params(typs[57]), nil)
+       typs[113] = newSig(params(typs[1], typs[1], typs[57], typs[15], typs[15], typs[6]), params(typs[15], typs[6]))
+       typs[114] = newSig(params(typs[1], typs[15], typs[15]), params(typs[7]))
+       typs[115] = newSig(params(typs[1], typs[22], typs[22]), params(typs[7]))
+       typs[116] = newSig(params(typs[1], typs[15], typs[15], typs[7]), params(typs[7]))
+       typs[117] = types.NewSlice(typs[2])
+       typs[118] = newSig(params(typs[3], typs[15], typs[15], typs[15], typs[1]), params(typs[117]))
+       typs[119] = newSig(params(typs[1], typs[7], typs[22]), nil)
+       typs[120] = newSig(params(typs[7], typs[22]), nil)
+       typs[121] = newSig(params(typs[5], typs[5]), params(typs[5], typs[6]))
+       typs[122] = newSig(params(typs[3], typs[3], typs[5]), nil)
+       typs[123] = newSig(params(typs[7], typs[5]), nil)
+       typs[124] = newSig(params(typs[3], typs[3], typs[5]), params(typs[6]))
+       typs[125] = newSig(params(typs[3], typs[3]), params(typs[6]))
+       typs[126] = newSig(params(typs[7], typs[7]), params(typs[6]))
+       typs[127] = newSig(params(typs[3], typs[5], typs[5]), params(typs[5]))
+       typs[128] = newSig(params(typs[7], typs[5]), params(typs[5]))
+       typs[129] = newSig(params(typs[3], typs[5]), params(typs[5]))
+       typs[130] = newSig(params(typs[22], typs[22]), params(typs[22]))
+       typs[131] = newSig(params(typs[24], typs[24]), params(typs[24]))
+       typs[132] = newSig(params(typs[20]), params(typs[22]))
+       typs[133] = newSig(params(typs[20]), params(typs[24]))
+       typs[134] = newSig(params(typs[20]), params(typs[62]))
+       typs[135] = newSig(params(typs[22]), params(typs[20]))
+       typs[136] = types.Types[types.TFLOAT32]
+       typs[137] = newSig(params(typs[22]), params(typs[136]))
+       typs[138] = newSig(params(typs[24]), params(typs[20]))
+       typs[139] = newSig(params(typs[24]), params(typs[136]))
+       typs[140] = newSig(params(typs[62]), params(typs[20]))
+       typs[141] = newSig(params(typs[26], typs[26]), params(typs[26]))
+       typs[142] = newSig(nil, params(typs[5]))
+       typs[143] = newSig(params(typs[5], typs[5]), nil)
+       typs[144] = newSig(params(typs[5], typs[5], typs[5]), nil)
+       typs[145] = newSig(params(typs[7], typs[1], typs[5]), nil)
+       typs[146] = types.NewSlice(typs[7])
+       typs[147] = newSig(params(typs[7], typs[146]), nil)
+       typs[148] = newSig(params(typs[66], typs[66], typs[17]), nil)
+       typs[149] = newSig(params(typs[60], typs[60], typs[17]), nil)
+       typs[150] = newSig(params(typs[62], typs[62], typs[17]), nil)
+       typs[151] = newSig(params(typs[24], typs[24], typs[17]), nil)
+       typs[152] = newSig(params(typs[28], typs[28], typs[17]), nil)
+       typs[153] = types.NewArray(typs[0], 16)
+       typs[154] = newSig(params(typs[7], typs[62], typs[153], typs[28], typs[15], typs[66], typs[66]), params(typs[62]))
        return typs[:]
 }
 
index 9047211879d5323e8ef1799a7f989c42acbccf71..d3557d2f942779e46cec3627f10975c9da9961ac 100644 (file)
@@ -583,6 +583,17 @@ func walkCall(n *ir.CallExpr, init *ir.Nodes) ir.Node {
                return e
        }
 
+       if name, ok := n.X.(*ir.Name); ok {
+               sym := name.Sym()
+               if sym.Pkg.Path == "go.runtime" && sym.Name == "deferrangefunc" {
+                       // Call to runtime.deferrangefunc is being shared with a range-over-func
+                       // body that might add defers to this frame, so we cannot use open-coded defers
+                       // and we need to call deferreturn even if we don't see any other explicit defers.
+                       ir.CurFunc.SetHasDefer(true)
+                       ir.CurFunc.SetOpenCodedDeferDisallowed(true)
+               }
+       }
+
        walkCall1(n, init)
        return n
 }
index 2356f803d3261a1fc2321b307ecb3598b69d701b..6a22bfcb87d9cf0fb2e1253171df4283e53d75a9 100644 (file)
@@ -55,6 +55,7 @@ func walkStmt(n ir.Node) ir.Node {
                if n.Typecheck() == 0 {
                        base.Fatalf("missing typecheck: %+v", n)
                }
+
                init := ir.TakeInit(n)
                n = walkExpr(n, &init)
                if n.Op() == ir.ONAME {
@@ -104,10 +105,11 @@ func walkStmt(n ir.Node) ir.Node {
                n := n.(*ir.GoDeferStmt)
                ir.CurFunc.SetHasDefer(true)
                ir.CurFunc.NumDefers++
-               if ir.CurFunc.NumDefers > maxOpenDefers {
+               if ir.CurFunc.NumDefers > maxOpenDefers || n.DeferAt != nil {
                        // Don't allow open-coded defers if there are more than
                        // 8 defers in the function, since we use a single
                        // byte to record active defers.
+                       // Also don't allow if we need to use deferprocat.
                        ir.CurFunc.SetOpenCodedDeferDisallowed(true)
                }
                if n.Esc() != ir.EscNever {
index 80a4ac84163734ed36e336ad7924edf6ab341886..613d7a53f6cbd1ab73d23fa116ed5f38c22456ad 100644 (file)
@@ -11,43 +11,365 @@ package main
 // test range over integers
 
 func testint1() {
+       bad := false
        j := 0
        for i := range int(4) {
                if i != j {
                        println("range var", i, "want", j)
+                       bad = true
                }
                j++
        }
        if j != 4 {
                println("wrong count ranging over 4:", j)
+               bad = true
+       }
+       if bad {
+               panic("testint1")
        }
 }
 
 func testint2() {
+       bad := false
        j := 0
        for i := range 4 {
                if i != j {
                        println("range var", i, "want", j)
+                       bad = true
                }
                j++
        }
        if j != 4 {
                println("wrong count ranging over 4:", j)
+               bad = true
+       }
+       if bad {
+               panic("testint2")
        }
 }
 
 func testint3() {
+       bad := false
        type MyInt int
-
        j := MyInt(0)
        for i := range MyInt(4) {
                if i != j {
                        println("range var", i, "want", j)
+                       bad = true
                }
                j++
        }
        if j != 4 {
                println("wrong count ranging over 4:", j)
+               bad = true
+       }
+       if bad {
+               panic("testint3")
+       }
+}
+
+// test range over functions
+
+var gj int
+
+func yield4x(yield func() bool) {
+       _ = yield() && yield() && yield() && yield()
+}
+
+func yield4(yield func(int) bool) {
+       _ = yield(1) && yield(2) && yield(3) && yield(4)
+}
+
+func yield3(yield func(int) bool) {
+       _ = yield(1) && yield(2) && yield(3)
+}
+
+func yield2(yield func(int) bool) {
+       _ = yield(1) && yield(2)
+}
+
+func testfunc0() {
+       j := 0
+       for range yield4x {
+               j++
+       }
+       if j != 4 {
+               println("wrong count ranging over yield4x:", j)
+               panic("testfunc0")
+       }
+
+       j = 0
+       for _ = range yield4 {
+               j++
+       }
+       if j != 4 {
+               println("wrong count ranging over yield4:", j)
+               panic("testfunc0")
+       }
+}
+
+func testfunc1() {
+       bad := false
+       j := 1
+       for i := range yield4 {
+               if i != j {
+                       println("range var", i, "want", j)
+                       bad = true
+               }
+               j++
+       }
+       if j != 5 {
+               println("wrong count ranging over f:", j)
+               bad = true
+       }
+       if bad {
+               panic("testfunc1")
+       }
+}
+
+func testfunc2() {
+       bad := false
+       j := 1
+       var i int
+       for i = range yield4 {
+               if i != j {
+                       println("range var", i, "want", j)
+                       bad = true
+               }
+               j++
+       }
+       if j != 5 {
+               println("wrong count ranging over f:", j)
+               bad = true
+       }
+       if i != 4 {
+               println("wrong final i ranging over f:", i)
+               bad = true
+       }
+       if bad {
+               panic("testfunc2")
+       }
+}
+
+func testfunc3() {
+       bad := false
+       j := 1
+       var i int
+       for i = range yield4 {
+               if i != j {
+                       println("range var", i, "want", j)
+                       bad = true
+               }
+               j++
+               if i == 2 {
+                       break
+               }
+               continue
+       }
+       if j != 3 {
+               println("wrong count ranging over f:", j)
+               bad = true
+       }
+       if i != 2 {
+               println("wrong final i ranging over f:", i)
+               bad = true
+       }
+       if bad {
+               panic("testfunc3")
+       }
+}
+
+func testfunc4() {
+       bad := false
+       j := 1
+       var i int
+       func() {
+               for i = range yield4 {
+                       if i != j {
+                               println("range var", i, "want", j)
+                               bad = true
+                       }
+                       j++
+                       if i == 2 {
+                               return
+                       }
+               }
+       }()
+       if j != 3 {
+               println("wrong count ranging over f:", j)
+               bad = true
+       }
+       if i != 2 {
+               println("wrong final i ranging over f:", i)
+               bad = true
+       }
+       if bad {
+               panic("testfunc3")
+       }
+}
+
+func func5() (int, int) {
+       for i := range yield4 {
+               return 10, i
+       }
+       panic("still here")
+}
+
+func testfunc5() {
+       x, y := func5()
+       if x != 10 || y != 1 {
+               println("wrong results", x, y, "want", 10, 1)
+               panic("testfunc5")
+       }
+}
+
+func func6() (z, w int) {
+       for i := range yield4 {
+               z = 10
+               w = i
+               return
+       }
+       panic("still here")
+}
+
+func testfunc6() {
+       x, y := func6()
+       if x != 10 || y != 1 {
+               println("wrong results", x, y, "want", 10, 1)
+               panic("testfunc6")
+       }
+}
+
+var saved []int
+
+func save(x int) {
+       saved = append(saved, x)
+}
+
+func printslice(s []int) {
+       print("[")
+       for i, x := range s {
+               if i > 0 {
+                       print(", ")
+               }
+               print(x)
+       }
+       print("]")
+}
+
+func eqslice(s, t []int) bool {
+       if len(s) != len(t) {
+               return false
+       }
+       for i, x := range s {
+               if x != t[i] {
+                       return false
+               }
+       }
+       return true
+}
+
+func func7() {
+       defer save(-1)
+       for i := range yield4 {
+               defer save(i)
+       }
+       defer save(5)
+}
+
+func checkslice(name string, saved, want []int) {
+       if !eqslice(saved, want) {
+               print("wrong results ")
+               printslice(saved)
+               print(" want ")
+               printslice(want)
+               print("\n")
+               panic(name)
+       }
+}
+
+func testfunc7() {
+       saved = nil
+       func7()
+       want := []int{5, 4, 3, 2, 1, -1}
+       checkslice("testfunc7", saved, want)
+}
+
+func func8() {
+       defer save(-1)
+       for i := range yield2 {
+               for j := range yield3 {
+                       defer save(i*10 + j)
+               }
+               defer save(i)
+       }
+       defer save(-2)
+       for i := range yield4 {
+               defer save(i)
+       }
+       defer save(-3)
+}
+
+func testfunc8() {
+       saved = nil
+       func8()
+       want := []int{-3, 4, 3, 2, 1, -2, 2, 23, 22, 21, 1, 13, 12, 11, -1}
+       checkslice("testfunc8", saved, want)
+}
+
+func func9() {
+       n := 0
+       for _ = range yield2 {
+               for _ = range yield3 {
+                       n++
+                       defer save(n)
+               }
+       }
+}
+
+func testfunc9() {
+       saved = nil
+       func9()
+       want := []int{6, 5, 4, 3, 2, 1}
+       checkslice("testfunc9", saved, want)
+}
+
+// test that range evaluates the index and value expressions
+// exactly once per iteration.
+
+var ncalls = 0
+
+func getvar(p *int) *int {
+       ncalls++
+       return p
+}
+
+func iter2(list ...int) func(func(int, int) bool) {
+       return func(yield func(int, int) bool) {
+               for i, x := range list {
+                       if !yield(i, x) {
+                               return
+                       }
+               }
+       }
+}
+
+func testcalls() {
+       var i, v int
+       ncalls = 0
+       si := 0
+       sv := 0
+       for *getvar(&i), *getvar(&v) = range iter2(1, 2) {
+               si += i
+               sv += v
+       }
+       if ncalls != 4 {
+               println("wrong number of calls:", ncalls, "!= 4")
+               panic("fail")
+       }
+       if si != 1 || sv != 3 {
+               println("wrong sum in testcalls", si, sv)
+               panic("fail")
        }
 }
 
@@ -55,4 +377,15 @@ func main() {
        testint1()
        testint2()
        testint3()
+       testfunc0()
+       testfunc1()
+       testfunc2()
+       testfunc3()
+       testfunc4()
+       testfunc5()
+       testfunc6()
+       testfunc7()
+       testfunc8()
+       testfunc9()
+       testcalls()
 }
diff --git a/test/rangegen.go b/test/rangegen.go
new file mode 100644 (file)
index 0000000..7916ed2
--- /dev/null
@@ -0,0 +1,333 @@
+// runoutput -goexperiment range
+
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Torture test for range-over-func.
+//
+// cmd/internal/testdir runs this like
+//
+//     go run rangegen.go >x.go
+//     go run x.go
+//
+// but a longer version can be run using
+//
+//     go run rangegen.go long
+//
+// In that second form, rangegen takes care of compiling
+// and running the code it generates, in batches.
+// That form takes 10-20 minutes to run.
+
+package main
+
+import (
+       "bytes"
+       "fmt"
+       "log"
+       "os"
+       "os/exec"
+       "strings"
+)
+
+const verbose = false
+
+func main() {
+       long := len(os.Args) > 1 && os.Args[1] == "long"
+       log.SetFlags(0)
+       log.SetPrefix("rangegen: ")
+
+       b := new(bytes.Buffer)
+       tests := ""
+       flush := func(force bool) {
+               if !long || (strings.Count(tests, "\n") < 1000 && !force) {
+                       return
+               }
+               p(b, mainCode, tests)
+               err := os.WriteFile("tmp.go", b.Bytes(), 0666)
+               if err != nil {
+                       log.Fatal(err)
+               }
+               out, err := exec.Command("go", "run", "tmp.go").CombinedOutput()
+               if err != nil {
+                       log.Fatalf("go run tmp.go: %v\n%s", err, out)
+               }
+               print(".")
+               if force {
+                       print("\nPASS\n")
+               }
+               b.Reset()
+               tests = ""
+               p(b, "package main\n\n")
+               p(b, "const verbose = %v\n\n", verbose)
+       }
+
+       p(b, "package main\n\n")
+       p(b, "const verbose = %v\n\n", verbose)
+       max := 2
+       if !long {
+               max = 5
+       }
+       for i := 1; i <= max; i++ {
+               maxDouble := -1
+               if long {
+                       maxDouble = i
+               }
+               for double := -1; double <= maxDouble; double++ {
+                       code := gen(new(bytes.Buffer), "", "", "", i, double, func(c int) bool { return true })
+                       for j := 0; j < code; j++ {
+                               hi := j + 1
+                               if long {
+                                       hi = code
+                               }
+                               for k := j; k < hi && k < code; k++ {
+                                       s := fmt.Sprintf("%d_%d_%d_%d", i, double+1, j, k)
+                                       code0 := gen(b, "testFunc"+s, "", "yield2", i, double, func(c int) bool { return c == j || c == k })
+                                       code1 := gen(b, "testSlice"+s, "_, ", "slice2", i, double, func(c int) bool { return c == j || c == k })
+                                       if code0 != code1 {
+                                               panic("bad generator")
+                                       }
+                                       tests += "test" + s + "()\n"
+                                       p(b, testCode, "test"+s, []int{j, k}, "testFunc"+s, "testSlice"+s)
+                                       flush(false)
+                               }
+                       }
+               }
+       }
+       for i := 1; i <= max; i++ {
+               maxDouble := -1
+               if long {
+                       maxDouble = i
+               }
+               for double := -1; double <= maxDouble; double++ {
+                       s := fmt.Sprintf("%d_%d", i, double+1)
+                       code := gen(b, "testFunc"+s, "", "yield2", i, double, func(c int) bool { return true })
+                       code1 := gen(b, "testSlice"+s, "_, ", "slice2", i, double, func(c int) bool { return true })
+                       if code != code1 {
+                               panic("bad generator")
+                       }
+                       tests += "test" + s + "()\n"
+                       var all []int
+                       for j := 0; j < code; j++ {
+                               all = append(all, j)
+                       }
+                       p(b, testCode, "test"+s, all, "testFunc"+s, "testSlice"+s)
+                       flush(false)
+               }
+       }
+       if long {
+               flush(true)
+               os.Remove("tmp.go")
+               return
+       }
+
+       p(b, mainCode, tests)
+
+       os.Stdout.Write(b.Bytes())
+}
+
+func p(b *bytes.Buffer, format string, args ...any) {
+       fmt.Fprintf(b, format, args...)
+}
+
+func gen(b *bytes.Buffer, name, prefix, rangeExpr string, depth, double int, allowed func(int) bool) int {
+       p(b, "func %s(o *output, code int) int {\n", name)
+       p(b, "  dfr := 0; _ = dfr\n")
+       code := genLoop(b, 0, prefix, rangeExpr, depth, double, 0, "", allowed)
+       p(b, "  return 0\n")
+       p(b, "}\n\n")
+       return code
+}
+
+func genLoop(b *bytes.Buffer, d int, prefix, rangeExpr string, depth, double, code int, labelSuffix string, allowed func(int) bool) int {
+       limit := 1
+       if d == double {
+               limit = 2
+       }
+       for rep := 0; rep < limit; rep++ {
+               if rep == 1 {
+                       labelSuffix = "R"
+               }
+               s := fmt.Sprintf("%d%s", d, labelSuffix)
+               p(b, "  o.log(`top%s`)\n", s)
+               p(b, "  l%sa := 0\n", s)
+               p(b, "goto L%sa; L%sa:  o.log(`L%sa`)\n", s, s, s)
+               p(b, "  if l%sa++; l%sa >= 2 { o.log(`loop L%sa`); return -1 }\n", s, s, s)
+               p(b, "  l%sfor := 0\n", s)
+               p(b, "goto L%sfor; L%sfor: for f := 0; f < 1; f++ { o.log(`L%sfor`)\n", s, s, s)
+               p(b, "  if l%sfor++; l%sfor >= 2 { o.log(`loop L%sfor`); return -1 }\n", s, s, s)
+               p(b, "  l%ssw := 0\n", s)
+               p(b, "goto L%ssw; L%ssw: switch { default: o.log(`L%ssw`)\n", s, s, s)
+               p(b, "  if l%ssw++; l%ssw >= 2 { o.log(`loop L%ssw`); return -1 }\n", s, s, s)
+               p(b, "  l%ssel := 0\n", s)
+               p(b, "goto L%ssel; L%ssel: select { default: o.log(`L%ssel`)\n", s, s, s)
+               p(b, "  if l%ssel++; l%ssel >= 2 { o.log(`loop L%ssel`); return -1 }\n", s, s, s)
+               p(b, "  l%s := 0\n", s)
+               p(b, "goto L%s; L%s:    for %s i%s := range %s {\n", s, s, prefix, s, rangeExpr)
+               p(b, "  o.log1(`L%s top`, i%s)\n", s, s)
+               p(b, "  if l%s++; l%s >= 4 { o.log(`loop L%s`); return -1 }\n", s, s, s)
+               printTests := func() {
+                       if code++; allowed(code) {
+                               p(b, "  if code == %v { break }\n", code)
+                       }
+                       if code++; allowed(code) {
+                               p(b, "  if code == %v { continue }\n", code)
+                       }
+                       if code++; allowed(code) {
+                               p(b, "  switch { case code == %v: continue }\n", code)
+                       }
+                       if code++; allowed(code) {
+                               p(b, "  if code == %v { return %[1]v }\n", code)
+                       }
+                       if code++; allowed(code) {
+                               p(b, "  if code == %v { select { default: break } }\n", code)
+                       }
+                       if code++; allowed(code) {
+                               p(b, "  if code == %v { switch { default: break } }\n", code)
+                       }
+                       if code++; allowed(code) {
+                               p(b, "  if code == %v { dfr++; defer o.log1(`defer %d`, dfr) }\n", code, code)
+                       }
+                       for i := d; i > 0; i-- {
+                               suffix := labelSuffix
+                               if i < double {
+                                       suffix = ""
+                               }
+                               if code++; allowed(code) {
+                                       p(b, "  if code == %v { break L%d%s }\n", code, i, suffix)
+                               }
+                               if code++; allowed(code) {
+                                       p(b, "  if code == %v { select { default: break L%d%s } }\n", code, i, suffix)
+                               }
+                               if code++; allowed(code) {
+                                       p(b, "  if code == %v { break L%d%s }\n", code, i, suffix)
+                               }
+                               if code++; allowed(code) {
+                                       p(b, "  if code == %v { break L%d%ssw }\n", code, i, suffix)
+                               }
+                               if code++; allowed(code) {
+                                       p(b, "  if code == %v { break L%d%ssel }\n", code, i, suffix)
+                               }
+                               if code++; allowed(code) {
+                                       p(b, "  if code == %v { break L%d%sfor }\n", code, i, suffix)
+                               }
+                               if code++; allowed(code) {
+                                       p(b, "  if code == %v { continue L%d%sfor }\n", code, i, suffix)
+                               }
+                               if code++; allowed(code) {
+                                       p(b, "  if code == %v { goto L%d%sa }\n", code, i, suffix)
+                               }
+                               if code++; allowed(code) {
+                                       p(b, "  if code == %v { goto L%d%s }\n", code, i, suffix)
+                               }
+                               if code++; allowed(code) {
+                                       p(b, "  if code == %v { goto L%d%sb }\n", code, i, suffix)
+                               }
+                       }
+               }
+               printTests()
+               if d < depth {
+                       if rep == 1 {
+                               double = d // signal to children to use the rep=1 labels
+                       }
+                       code = genLoop(b, d+1, prefix, rangeExpr, depth, double, code, labelSuffix, allowed)
+                       printTests()
+               }
+               p(b, "  o.log(`L%s bot`)\n", s)
+               p(b, "  }\n")
+               p(b, "  o.log(`L%ssel bot`)\n", s)
+               p(b, "  }\n")
+               p(b, "  o.log(`L%ssw bot`)\n", s)
+               p(b, "  }\n")
+               p(b, "  o.log(`L%sfor bot`)\n", s)
+               p(b, "  }\n")
+               p(b, "  o.log(`done%s`)\n", s)
+               p(b, "goto L%sb; L%sb: o.log(`L%sb`)\n", s, s, s)
+       }
+       return code
+}
+
+var testCode = `
+func %s() {
+       all := %#v
+       for i := 0; i < len(all); i++ {
+               c := all[i]
+               outFunc := run(%s, c)
+               outSlice := run(%s, c)
+               if !outFunc.eq(outSlice) {
+                       println("mismatch", "%[3]s", "%[4]s", c)
+                       println()
+                       println("func:")
+                       outFunc.print()
+                       println()
+                       println("slice:")
+                       outSlice.print()
+                       panic("mismatch")
+               }
+       }
+       if verbose {
+               println("did", "%[3]s", "%[4]s", len(all))
+       }
+}
+`
+
+var mainCode = `
+
+func main() {
+       if verbose {
+               println("main")
+       }
+       %s
+}
+
+func yield2(yield func(int)bool) { _ = yield(1) && yield(2) }
+var slice2 = []int{1,2}
+
+type output struct {
+       ret int
+       trace []any
+}
+
+func (o *output) log(x any) {
+       o.trace = append(o.trace, x)
+}
+
+func (o *output) log1(x, y any) {
+       o.trace = append(o.trace, x, y)
+}
+
+func (o *output) eq(p *output) bool{
+       if o.ret != p.ret  || len(o.trace) != len(p.trace) {
+               return false
+       }
+       for i ,x := range o.trace {
+               if x != p.trace[i] {
+                       return false
+               }
+       }
+       return true
+}
+
+func (o *output) print() {
+       println("ret", o.ret, "trace-len", len(o.trace))
+       for i := 0; i < len(o.trace); i++ {
+               print("#", i, " ")
+               switch x := o.trace[i].(type) {
+               case int:
+                       print(x)
+               case string:
+                       print(x)
+               default:
+                       print(x)
+               }
+               print("\n")
+       }
+}
+
+func run(f func(*output, int)int, i int) *output {
+       o := &output{}
+       o.ret = f(o, i)
+       return o
+}
+
+`