]> Cypherpunks.ru repositories - gostls13.git/blobdiff - src/cmd/compile/internal/walk/switch.go
cmd/compile: improve interface type switches
[gostls13.git] / src / cmd / compile / internal / walk / switch.go
index 2fc8aefe5fa7e34f8ae90d4ac8633937ddfc9fd4..2f7eb5486c5ab970ea9b4b7c7ee8eb0c012d62db 100644 (file)
@@ -5,6 +5,7 @@
 package walk
 
 import (
+       "fmt"
        "go/constant"
        "go/token"
        "math/bits"
@@ -12,9 +13,12 @@ import (
 
        "cmd/compile/internal/base"
        "cmd/compile/internal/ir"
+       "cmd/compile/internal/objw"
+       "cmd/compile/internal/reflectdata"
        "cmd/compile/internal/ssagen"
        "cmd/compile/internal/typecheck"
        "cmd/compile/internal/types"
+       "cmd/internal/obj"
        "cmd/internal/src"
 )
 
@@ -379,17 +383,19 @@ func endsInFallthrough(stmts []ir.Node) (bool, src.XPos) {
 // type switch.
 func walkSwitchType(sw *ir.SwitchStmt) {
        var s typeSwitch
-       s.facename = sw.Tag.(*ir.TypeSwitchGuard).X
-       sw.Tag = nil
-
-       s.facename = walkExpr(s.facename, sw.PtrInit())
-       s.facename = copyExpr(s.facename, s.facename.Type(), &sw.Compiled)
-       s.okname = typecheck.TempAt(base.Pos, ir.CurFunc, types.Types[types.TBOOL])
+       s.srcName = sw.Tag.(*ir.TypeSwitchGuard).X
+       s.srcName = walkExpr(s.srcName, sw.PtrInit())
+       s.srcName = copyExpr(s.srcName, s.srcName.Type(), &sw.Compiled)
+       s.okName = typecheck.TempAt(base.Pos, ir.CurFunc, types.Types[types.TBOOL])
+       s.itabName = typecheck.TempAt(base.Pos, ir.CurFunc, types.Types[types.TUINT8].PtrTo())
 
        // Get interface descriptor word.
        // For empty interfaces this will be the type.
        // For non-empty interfaces this will be the itab.
-       itab := ir.NewUnaryExpr(base.Pos, ir.OITAB, s.facename)
+       srcItab := ir.NewUnaryExpr(base.Pos, ir.OITAB, s.srcName)
+       srcData := ir.NewUnaryExpr(base.Pos, ir.OIDATA, s.srcName)
+       srcData.SetType(types.Types[types.TUINT8].PtrTo())
+       srcData.SetTypecheck(1)
 
        // For empty interfaces, do:
        //     if e._type == nil {
@@ -398,42 +404,49 @@ func walkSwitchType(sw *ir.SwitchStmt) {
        //     h := e._type.hash
        // Use a similar strategy for non-empty interfaces.
        ifNil := ir.NewIfStmt(base.Pos, nil, nil, nil)
-       ifNil.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, itab, typecheck.NodNil())
+       ifNil.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, srcItab, typecheck.NodNil())
        base.Pos = base.Pos.WithNotStmt() // disable statement marks after the first check.
        ifNil.Cond = typecheck.Expr(ifNil.Cond)
        ifNil.Cond = typecheck.DefaultLit(ifNil.Cond, nil)
-       // ifNil.Nbody assigned at end.
+       // ifNil.Nbody assigned later.
        sw.Compiled.Append(ifNil)
 
        // Load hash from type or itab.
-       dotHash := typeHashFieldOf(base.Pos, itab)
-       s.hashname = copyExpr(dotHash, dotHash.Type(), &sw.Compiled)
+       dotHash := typeHashFieldOf(base.Pos, srcItab)
+       s.hashName = copyExpr(dotHash, dotHash.Type(), &sw.Compiled)
+
+       // Make a label for each case body.
+       labels := make([]*types.Sym, len(sw.Cases))
+       for i := range sw.Cases {
+               labels[i] = typecheck.AutoLabel(".s")
+       }
 
+       // "jump" to execute if no case matches.
        br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)
-       var defaultGoto, nilGoto ir.Node
-       var body ir.Nodes
-       for _, ncase := range sw.Cases {
-               caseVar := ncase.Var
-
-               // For single-type cases with an interface type,
-               // we initialize the case variable as part of the type assertion.
-               // In other cases, we initialize it in the body.
-               var singleType *types.Type
-               if len(ncase.List) == 1 && ncase.List[0].Op() == ir.OTYPE {
-                       singleType = ncase.List[0].Type()
-               }
-               caseVarInitialized := false
 
-               label := typecheck.AutoLabel(".s")
-               jmp := ir.NewBranchStmt(ncase.Pos(), ir.OGOTO, label)
+       // Assemble a list of all the types we're looking for.
+       // This pass flattens the case lists, as well as handles
+       // some unusual cases, like default and nil cases.
+       type oneCase struct {
+               pos src.XPos
+               jmp ir.Node // jump to body of selected case
 
+               // The case we're matching. Normally the type we're looking for
+               // is typ.Type(), but when typ is ODYNAMICTYPE the actual type
+               // we're looking for is not a compile-time constant (typ.Type()
+               // will be its shape).
+               typ ir.Node
+       }
+       var cases []oneCase
+       var defaultGoto, nilGoto ir.Node
+       for i, ncase := range sw.Cases {
+               jmp := ir.NewBranchStmt(ncase.Pos(), ir.OGOTO, labels[i])
                if len(ncase.List) == 0 { // default:
                        if defaultGoto != nil {
                                base.Fatalf("duplicate default case not detected during typechecking")
                        }
                        defaultGoto = jmp
                }
-
                for _, n1 := range ncase.List {
                        if ir.IsNil(n1) { // case nil:
                                if nilGoto != nil {
@@ -442,60 +455,228 @@ func walkSwitchType(sw *ir.SwitchStmt) {
                                nilGoto = jmp
                                continue
                        }
+                       if n1.Op() == ir.ODYNAMICTYPE {
+                               // Convert dynamic to static, if the dynamic is actually static.
+                               // TODO: why isn't this OTYPE to begin with?
+                               dt := n1.(*ir.DynamicType)
+                               if dt.RType != nil && dt.RType.Op() == ir.OADDR {
+                                       addr := dt.RType.(*ir.AddrExpr)
+                                       if addr.X.Op() == ir.OLINKSYMOFFSET {
+                                               n1 = ir.TypeNode(n1.Type())
+                                       }
+                               }
+                               if dt.ITab != nil && dt.ITab.Op() == ir.OADDR {
+                                       addr := dt.ITab.(*ir.AddrExpr)
+                                       if addr.X.Op() == ir.OLINKSYMOFFSET {
+                                               n1 = ir.TypeNode(n1.Type())
+                                       }
+                               }
+                       }
+                       cases = append(cases, oneCase{
+                               pos: ncase.Pos(),
+                               typ: n1,
+                               jmp: jmp,
+                       })
+               }
+       }
+       if defaultGoto == nil {
+               defaultGoto = br
+       }
+       if nilGoto == nil {
+               nilGoto = defaultGoto
+       }
+       ifNil.Body = []ir.Node{nilGoto}
+
+       // Now go through the list of cases, processing groups as we find them.
+       var concreteCases []oneCase
+       var interfaceCases []oneCase
+       flush := func() {
+               // Process all the concrete types first. Because we handle shadowing
+               // below, it is correct to do all the concrete types before all of
+               // the interface types.
+               // The concrete cases can all be handled without a runtime call.
+               if len(concreteCases) > 0 {
+                       var clauses []typeClause
+                       for _, c := range concreteCases {
+                               as := ir.NewAssignListStmt(c.pos, ir.OAS2,
+                                       []ir.Node{ir.BlankNode, s.okName},                               // _, ok =
+                                       []ir.Node{ir.NewTypeAssertExpr(c.pos, s.srcName, c.typ.Type())}) // iface.(type)
+                               nif := ir.NewIfStmt(c.pos, s.okName, []ir.Node{c.jmp}, nil)
+                               clauses = append(clauses, typeClause{
+                                       hash: types.TypeHash(c.typ.Type()),
+                                       body: []ir.Node{typecheck.Stmt(as), typecheck.Stmt(nif)},
+                               })
+                       }
+                       s.flush(clauses, &sw.Compiled)
+                       concreteCases = concreteCases[:0]
+               }
+
+               // The "any" case, if it exists, must be the last interface case, because
+               // it would shadow all subsequent cases. Strip it off here so the runtime
+               // call only needs to handle non-empty interfaces.
+               var anyGoto ir.Node
+               if len(interfaceCases) > 0 && interfaceCases[len(interfaceCases)-1].typ.Type().IsEmptyInterface() {
+                       anyGoto = interfaceCases[len(interfaceCases)-1].jmp
+                       interfaceCases = interfaceCases[:len(interfaceCases)-1]
+               }
 
-                       if singleType != nil && singleType.IsInterface() {
-                               s.Add(ncase.Pos(), n1, caseVar, jmp)
-                               caseVarInitialized = true
+               // Next, process all the interface types with a single call to the runtime.
+               if len(interfaceCases) > 0 {
+
+                       // Build an internal/abi.InterfaceSwitch descriptor to pass to the runtime.
+                       lsym := types.LocalPkg.Lookup(fmt.Sprintf(".interfaceSwitch.%d", interfaceSwitchGen)).LinksymABI(obj.ABI0)
+                       interfaceSwitchGen++
+                       off := 0
+                       off = objw.Uintptr(lsym, off, uint64(len(interfaceCases)))
+                       for _, c := range interfaceCases {
+                               off = objw.SymPtr(lsym, off, reflectdata.TypeSym(c.typ.Type()).Linksym(), 0)
+                       }
+                       // Note: it has pointers, just not ones the GC cares about.
+                       objw.Global(lsym, int32(off), obj.LOCAL|obj.NOPTR)
+
+                       // Call runtime to do switch
+                       // case, itab = runtime.interfaceSwitch(&descriptor, typeof(arg))
+                       var typeArg ir.Node
+                       if s.srcName.Type().IsEmptyInterface() {
+                               typeArg = ir.NewConvExpr(base.Pos, ir.OCONVNOP, types.Types[types.TUINT8].PtrTo(), srcItab)
                        } else {
-                               s.Add(ncase.Pos(), n1, nil, jmp)
+                               typeArg = itabType(srcItab)
+                       }
+                       caseVar := typecheck.TempAt(base.Pos, ir.CurFunc, types.Types[types.TINT])
+                       isw := ir.NewInterfaceSwitchStmt(base.Pos, caseVar, s.itabName, typeArg, lsym)
+                       sw.Compiled.Append(isw)
+
+                       // Switch on the result of the call.
+                       var newCases []*ir.CaseClause
+                       for i, c := range interfaceCases {
+                               newCases = append(newCases, &ir.CaseClause{
+                                       List: []ir.Node{ir.NewInt(base.Pos, int64(i))},
+                                       Body: []ir.Node{c.jmp},
+                               })
                        }
+                       // TODO: add len(newCases) case, mark switch as bounded
+                       sw2 := ir.NewSwitchStmt(base.Pos, caseVar, newCases)
+                       sw.Compiled.Append(typecheck.Stmt(sw2))
+                       interfaceCases = interfaceCases[:0]
                }
 
-               body.Append(ir.NewLabelStmt(ncase.Pos(), label))
-               if caseVar != nil && !caseVarInitialized {
-                       val := s.facename
-                       if singleType != nil {
-                               // We have a single concrete type. Extract the data.
-                               if singleType.IsInterface() {
-                                       base.Fatalf("singleType interface should have been handled in Add")
-                               }
-                               val = ifaceData(ncase.Pos(), s.facename, singleType)
+               if anyGoto != nil {
+                       // We've already handled the nil case, so everything
+                       // that reaches here matches the "any" case.
+                       sw.Compiled.Append(anyGoto)
+               }
+       }
+caseLoop:
+       for _, c := range cases {
+               if c.typ.Op() == ir.ODYNAMICTYPE {
+                       flush() // process all previous cases
+                       dt := c.typ.(*ir.DynamicType)
+                       dot := ir.NewDynamicTypeAssertExpr(c.pos, ir.ODYNAMICDOTTYPE, s.srcName, dt.RType)
+                       dot.ITab = dt.ITab
+                       dot.SetType(c.typ.Type())
+                       dot.SetTypecheck(1)
+
+                       as := ir.NewAssignListStmt(c.pos, ir.OAS2, nil, nil)
+                       as.Lhs = []ir.Node{ir.BlankNode, s.okName} // _, ok =
+                       as.Rhs = []ir.Node{dot}
+                       typecheck.Stmt(as)
+
+                       nif := ir.NewIfStmt(c.pos, s.okName, []ir.Node{c.jmp}, nil)
+                       sw.Compiled.Append(as, nif)
+                       continue
+               }
+
+               // Check for shadowing (a case that will never fire because
+               // a previous case would have always fired first). This check
+               // allows us to reorder concrete and interface cases.
+               // (TODO: these should be vet failures, maybe?)
+               for _, ic := range interfaceCases {
+                       // An interface type case will shadow all
+                       // subsequent types that implement that interface.
+                       if typecheck.Implements(c.typ.Type(), ic.typ.Type()) {
+                               continue caseLoop
                        }
-                       if len(ncase.List) == 1 && ncase.List[0].Op() == ir.ODYNAMICTYPE {
-                               dt := ncase.List[0].(*ir.DynamicType)
-                               x := ir.NewDynamicTypeAssertExpr(ncase.Pos(), ir.ODYNAMICDOTTYPE, val, dt.RType)
-                               x.ITab = dt.ITab
-                               x.SetType(caseVar.Type())
-                               x.SetTypecheck(1)
-                               val = x
+                       // Note that we don't need to worry about:
+                       // 1. Two concrete types shadowing each other. That's
+                       //    disallowed by the spec.
+                       // 2. A concrete type shadowing an interface type.
+                       //    That can never happen, as interface types can
+                       //    be satisfied by an infinite set of concrete types.
+                       // The correctness of this step also depends on handling
+                       // the dynamic type cases separately, as we do above.
+               }
+
+               if c.typ.Type().IsInterface() {
+                       interfaceCases = append(interfaceCases, c)
+               } else {
+                       concreteCases = append(concreteCases, c)
+               }
+       }
+       flush()
+
+       sw.Compiled.Append(defaultGoto) // if none of the cases matched
+
+       // Now generate all the case bodies
+       for i, ncase := range sw.Cases {
+               sw.Compiled.Append(ir.NewLabelStmt(ncase.Pos(), labels[i]))
+               if caseVar := ncase.Var; caseVar != nil {
+                       val := s.srcName
+                       if len(ncase.List) == 1 {
+                               // single type. We have to downcast the input value to the target type.
+                               if ncase.List[0].Op() == ir.OTYPE { // single compile-time known type
+                                       t := ncase.List[0].Type()
+                                       if t.IsInterface() {
+                                               // This case is an interface. Build case value from input interface.
+                                               // The data word will always be the same, but the itab/type changes.
+                                               if t.IsEmptyInterface() {
+                                                       var typ ir.Node
+                                                       if s.srcName.Type().IsEmptyInterface() {
+                                                               // E->E, nothing to do, type is already correct.
+                                                               typ = srcItab
+                                                       } else {
+                                                               // I->E, load type out of itab
+                                                               typ = itabType(srcItab)
+                                                               typ.SetPos(ncase.Pos())
+                                                       }
+                                                       val = ir.NewBinaryExpr(ncase.Pos(), ir.OMAKEFACE, typ, srcData)
+                                               } else {
+                                                       // The itab we need was returned by a runtime.interfaceSwitch call.
+                                                       val = ir.NewBinaryExpr(ncase.Pos(), ir.OMAKEFACE, s.itabName, srcData)
+                                               }
+                                       } else {
+                                               // This case is a concrete type, just read its value out of the interface.
+                                               val = ifaceData(ncase.Pos(), s.srcName, t)
+                                       }
+                               } else if ncase.List[0].Op() == ir.ODYNAMICTYPE { // single runtime known type
+                                       dt := ncase.List[0].(*ir.DynamicType)
+                                       x := ir.NewDynamicTypeAssertExpr(ncase.Pos(), ir.ODYNAMICDOTTYPE, val, dt.RType)
+                                       x.ITab = dt.ITab
+                                       val = x
+                               } else if ir.IsNil(ncase.List[0]) {
+                               } else {
+                                       base.Fatalf("unhandled type switch case %v", ncase.List[0])
+                               }
+                               val.SetType(caseVar.Type())
+                               val.SetTypecheck(1)
                        }
                        l := []ir.Node{
                                ir.NewDecl(ncase.Pos(), ir.ODCL, caseVar),
                                ir.NewAssignStmt(ncase.Pos(), caseVar, val),
                        }
                        typecheck.Stmts(l)
-                       body.Append(l...)
+                       sw.Compiled.Append(l...)
                }
-               body.Append(ncase.Body...)
-               body.Append(br)
-       }
-       sw.Cases = nil
-
-       if defaultGoto == nil {
-               defaultGoto = br
+               sw.Compiled.Append(ncase.Body...)
+               sw.Compiled.Append(br)
        }
-       if nilGoto == nil {
-               nilGoto = defaultGoto
-       }
-       ifNil.Body = []ir.Node{nilGoto}
-
-       s.Emit(&sw.Compiled)
-       sw.Compiled.Append(defaultGoto)
-       sw.Compiled.Append(body.Take()...)
 
        walkStmtList(sw.Compiled)
+       sw.Tag = nil
+       sw.Cases = nil
 }
 
+var interfaceSwitchGen int
+
 // typeHashFieldOf returns an expression to select the type hash field
 // from an interface's descriptor word (whether a *runtime._type or
 // *runtime.itab pointer).
@@ -525,12 +706,10 @@ var rtypeHashField, itabHashField *types.Field
 // A typeSwitch walks a type switch.
 type typeSwitch struct {
        // Temporary variables (i.e., ONAMEs) used by type switch dispatch logic:
-       facename ir.Node // value being type-switched on
-       hashname ir.Node // type hash of the value being type-switched on
-       okname   ir.Node // boolean used for comma-ok type assertions
-
-       done    ir.Nodes
-       clauses []typeClause
+       srcName  ir.Node // value being type-switched on
+       hashName ir.Node // type hash of the value being type-switched on
+       okName   ir.Node // boolean used for comma-ok type assertions
+       itabName ir.Node // itab value to use for first word of non-empty interface
 }
 
 type typeClause struct {
@@ -538,68 +717,7 @@ type typeClause struct {
        body ir.Nodes
 }
 
-func (s *typeSwitch) Add(pos src.XPos, n1 ir.Node, caseVar *ir.Name, jmp ir.Node) {
-       typ := n1.Type()
-       var body ir.Nodes
-       if caseVar != nil {
-               l := []ir.Node{
-                       ir.NewDecl(pos, ir.ODCL, caseVar),
-                       ir.NewAssignStmt(pos, caseVar, nil),
-               }
-               typecheck.Stmts(l)
-               body.Append(l...)
-       } else {
-               caseVar = ir.BlankNode
-       }
-
-       // cv, ok = iface.(type)
-       as := ir.NewAssignListStmt(pos, ir.OAS2, nil, nil)
-       as.Lhs = []ir.Node{caseVar, s.okname} // cv, ok =
-       switch n1.Op() {
-       case ir.OTYPE:
-               // Static type assertion (non-generic)
-               dot := ir.NewTypeAssertExpr(pos, s.facename, typ) // iface.(type)
-               as.Rhs = []ir.Node{dot}
-       case ir.ODYNAMICTYPE:
-               // Dynamic type assertion (generic)
-               dt := n1.(*ir.DynamicType)
-               dot := ir.NewDynamicTypeAssertExpr(pos, ir.ODYNAMICDOTTYPE, s.facename, dt.RType)
-               dot.ITab = dt.ITab
-               dot.SetType(typ)
-               dot.SetTypecheck(1)
-               as.Rhs = []ir.Node{dot}
-       default:
-               base.Fatalf("unhandled type case %s", n1.Op())
-       }
-       appendWalkStmt(&body, as)
-
-       // if ok { goto label }
-       nif := ir.NewIfStmt(pos, nil, nil, nil)
-       nif.Cond = s.okname
-       nif.Body = []ir.Node{jmp}
-       body.Append(nif)
-
-       if n1.Op() == ir.OTYPE && !typ.IsInterface() {
-               // Defer static, noninterface cases so they can be binary searched by hash.
-               s.clauses = append(s.clauses, typeClause{
-                       hash: types.TypeHash(n1.Type()),
-                       body: body,
-               })
-               return
-       }
-
-       s.flush()
-       s.done.Append(body.Take()...)
-}
-
-func (s *typeSwitch) Emit(out *ir.Nodes) {
-       s.flush()
-       out.Append(s.done.Take()...)
-}
-
-func (s *typeSwitch) flush() {
-       cc := s.clauses
-       s.clauses = nil
+func (s *typeSwitch) flush(cc []typeClause, compiled *ir.Nodes) {
        if len(cc) == 0 {
                return
        }
@@ -618,18 +736,18 @@ func (s *typeSwitch) flush() {
        }
        cc = merged
 
-       if s.tryJumpTable(cc, &s.done) {
+       if s.tryJumpTable(cc, compiled) {
                return
        }
-       binarySearch(len(cc), &s.done,
+       binarySearch(len(cc), compiled,
                func(i int) ir.Node {
-                       return ir.NewBinaryExpr(base.Pos, ir.OLE, s.hashname, ir.NewInt(base.Pos, int64(cc[i-1].hash)))
+                       return ir.NewBinaryExpr(base.Pos, ir.OLE, s.hashName, ir.NewInt(base.Pos, int64(cc[i-1].hash)))
                },
                func(i int, nif *ir.IfStmt) {
                        // TODO(mdempsky): Omit hash equality check if
                        // there's only one type.
                        c := cc[i]
-                       nif.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, s.hashname, ir.NewInt(base.Pos, int64(c.hash)))
+                       nif.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, s.hashName, ir.NewInt(base.Pos, int64(c.hash)))
                        nif.Body.Append(c.body.Take()...)
                },
        )
@@ -670,7 +788,7 @@ func (s *typeSwitch) tryJumpTable(cc []typeClause, out *ir.Nodes) bool {
                        }
 
                        // All hashes are distinct. Use these values of b and i.
-                       h := s.hashname
+                       h := s.hashName
                        if i != 0 {
                                h = ir.NewBinaryExpr(base.Pos, ir.ORSH, h, ir.NewInt(base.Pos, int64(i)))
                        }