1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
12 "cmd/compile/internal/base"
13 "cmd/compile/internal/ir"
14 "cmd/compile/internal/ssagen"
15 "cmd/compile/internal/typecheck"
16 "cmd/compile/internal/types"
20 // walkSwitch walks a switch statement.
21 func walkSwitch(sw *ir.SwitchStmt) {
22 // Guard against double walk, see #25776.
// NOTE(review): interior lines are elided in this excerpt; the guard's
// condition is not visible here — confirm against the full file.
24 return // Was fatal, but eliminating every possible source of double-walking is hard
// A tag of OTYPESW marks a type switch; expression switches take the other branch.
28 if sw.Tag != nil && sw.Tag.Op() == ir.OTYPESW {
35 // walkSwitchExpr generates an AST implementing sw. sw is an
// NOTE(review): the doc comment above is truncated in this excerpt.
37 func walkSwitchExpr(sw *ir.SwitchStmt) {
43 // convert switch {...} to switch true {...}
45 cond = ir.NewBool(base.Pos, true)
46 cond = typecheck.Expr(cond)
47 cond = typecheck.DefaultLit(cond, nil)
50 // Given "switch string(byteslice)",
51 // with all cases being side-effect free,
52 // use a zero-cost alias of the byte slice.
53 // Do this before calling walkExpr on cond,
54 // because walkExpr will lower the string
55 // conversion into a runtime call.
56 // See issue 24937 for more discussion.
57 if cond.Op() == ir.OBYTES2STR && allCaseExprsAreSideEffectFree(sw) {
58 cond := cond.(*ir.ConvExpr)
59 cond.SetOp(ir.OBYTES2STRTMP)
// Evaluate a non-constant tag once into a temporary so every case
// comparison reuses the same value.
62 cond = walkExpr(cond, sw.PtrInit())
63 if cond.Op() != ir.OLITERAL && cond.Op() != ir.ONIL {
64 cond = copyExpr(cond, cond.Type(), &sw.Compiled)
74 var defaultGoto ir.Node
76 for _, ncase := range sw.Cases {
// Each case body gets its own label; matching dispatch jumps to it.
77 label := typecheck.AutoLabel(".s")
78 jmp := ir.NewBranchStmt(ncase.Pos(), ir.OGOTO, label)
80 // Process case dispatch.
81 if len(ncase.List) == 0 {
82 if defaultGoto != nil {
83 base.Fatalf("duplicate default case not detected during typechecking")
88 for i, n1 := range ncase.List {
// RTypes, when populated, carries the *runtime._type used by the
// generated OEQ comparison for this case value (see exprClause.rtype).
90 if i < len(ncase.RTypes) {
91 rtype = ncase.RTypes[i]
93 s.Add(ncase.Pos(), n1, rtype, jmp)
// Case bodies are emitted after the dispatch logic, each preceded by
// its label and (unless it ends in fallthrough) followed by a break.
97 body.Append(ir.NewLabelStmt(ncase.Pos(), label))
98 body.Append(ncase.Body...)
99 if fall, pos := endsInFallthrough(ncase.Body); !fall {
100 br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)
107 if defaultGoto == nil {
// No explicit default case: "default" is simply a break out of the switch.
108 br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)
109 br.SetPos(br.Pos().WithNotStmt())
114 sw.Compiled.Append(defaultGoto)
115 sw.Compiled.Append(body.Take()...)
116 walkStmtList(sw.Compiled)
119 // An exprSwitch walks an expression switch.
120 type exprSwitch struct {
// NOTE(review): additional fields (e.g. clauses, done, pos used by the
// methods below) are elided in this excerpt.
122 exprname ir.Node // value being switched on
// An exprClause is a single (possibly merged) case clause: a [lo, hi]
// value range plus the jump taken on a match. Fields other than rtype
// are elided in this excerpt.
128 type exprClause struct {
131 rtype ir.Node // *runtime._type for OEQ node
// Add registers a single case value expr that jumps to jmp when matched.
// rtype, when non-nil, is the *runtime._type used for the equality test.
135 func (s *exprSwitch) Add(pos src.XPos, expr, rtype, jmp ir.Node) {
136 c := exprClause{pos: pos, lo: expr, hi: expr, rtype: rtype, jmp: jmp}
// Ordered-type literals are buffered so flush can sort and merge them;
// other values take the (elided) fall-through path.
137 if types.IsOrdered[s.exprname.Type().Kind()] && expr.Op() == ir.OLITERAL {
138 s.clauses = append(s.clauses, c)
143 s.clauses = append(s.clauses, c)
// Emit moves all accumulated dispatch statements into out.
147 func (s *exprSwitch) Emit(out *ir.Nodes) {
149 out.Append(s.done.Take()...)
// flush lowers the buffered clauses into dispatch code: strings are first
// grouped by length, integers are sorted and consecutive cases merged, and
// the result is handed to search.
152 func (s *exprSwitch) flush() {
159 // Caution: If len(cc) == 1, then cc[0] might not an OLITERAL.
160 // The code below is structured to implicitly handle this case
161 // (e.g., sort.Slice doesn't need to invoke the less function
162 // when there's only a single slice element).
164 if s.exprname.Type().IsString() && len(cc) >= 2 {
165 // Sort strings by length and then by value. It is
166 // much cheaper to compare lengths than values, and
167 // all we need here is consistency. We respect this
169 sort.Slice(cc, func(i, j int) bool {
170 si := ir.StringVal(cc[i].lo)
171 sj := ir.StringVal(cc[j].lo)
172 if len(si) != len(sj) {
173 return len(si) < len(sj)
178 // runLen returns the string length associated with a
179 // particular run of exprClauses.
180 runLen := func(run []exprClause) int64 { return int64(len(ir.StringVal(run[0].lo))) }
182 // Collapse runs of consecutive strings with the same length.
183 var runs [][]exprClause
185 for i := 1; i < len(cc); i++ {
186 if runLen(cc[start:]) != runLen(cc[i:]) {
187 runs = append(runs, cc[start:i])
191 runs = append(runs, cc[start:])
193 // We have strings of more than one length. Generate an
194 // outer switch which switches on the length of the string
195 // and an inner switch in each case which resolves all the
196 // strings of the same length. The code looks something like this:
200 // ... search among length 5 strings ...
203 // ... search among length 8 strings ...
205 // ... other lengths ...
210 // ... other lengths ...
214 outerLabel := typecheck.AutoLabel(".s")
215 endLabel := typecheck.AutoLabel(".s")
217 // Jump around all the individual switches for each length.
218 s.done.Append(ir.NewBranchStmt(s.pos, ir.OGOTO, outerLabel))
// The outer switch dispatches on len(s.exprname), typed as int.
221 outer.exprname = ir.NewUnaryExpr(s.pos, ir.OLEN, s.exprname)
222 outer.exprname.SetType(types.Types[types.TINT])
224 for _, run := range runs {
225 // Target label to jump to when we match this length.
226 label := typecheck.AutoLabel(".s")
228 // Search within this run of same-length strings.
230 s.done.Append(ir.NewLabelStmt(pos, label))
231 stringSearch(s.exprname, run, &s.done)
232 s.done.Append(ir.NewBranchStmt(pos, ir.OGOTO, endLabel))
234 // Add length case to outer switch.
235 cas := ir.NewBasicLit(pos, constant.MakeInt64(runLen(run)))
236 jmp := ir.NewBranchStmt(pos, ir.OGOTO, label)
237 outer.Add(pos, cas, nil, jmp)
239 s.done.Append(ir.NewLabelStmt(s.pos, outerLabel))
241 s.done.Append(ir.NewLabelStmt(s.pos, endLabel))
// Non-string ordered cases: sort ascending by constant value.
245 sort.Slice(cc, func(i, j int) bool {
246 return constant.Compare(cc[i].lo.Val(), token.LSS, cc[j].lo.Val())
249 // Merge consecutive integer cases.
250 if s.exprname.Type().IsInteger() {
251 consecutive := func(last, next constant.Value) bool {
252 delta := constant.BinaryOp(next, token.SUB, last)
253 return constant.Compare(delta, token.EQL, constant.MakeInt64(1))
// Adjacent cases that share a jump target and differ by exactly 1
// are folded into a single [lo, hi] range clause.
257 for _, c := range cc[1:] {
258 last := &merged[len(merged)-1]
259 if last.jmp == c.jmp && consecutive(last.hi.Val(), c.lo.Val()) {
262 merged = append(merged, c)
268 s.search(cc, &s.done)
// search emits dispatch code for the sorted clauses cc: a jump table when
// profitable, otherwise a binary search over the clause ranges.
271 func (s *exprSwitch) search(cc []exprClause, out *ir.Nodes) {
272 if s.tryJumpTable(cc, out) {
275 binarySearch(len(cc), out,
276 func(i int) ir.Node {
// less(i): exprname <= cc[i-1].hi sends the search left of i.
277 return ir.NewBinaryExpr(base.Pos, ir.OLE, s.exprname, cc[i-1].hi)
279 func(i int, nif *ir.IfStmt) {
281 nif.Cond = c.test(s.exprname)
282 nif.Body = []ir.Node{c.jmp}
287 // Try to implement the clauses with a jump table. Returns true if successful.
288 func (s *exprSwitch) tryJumpTable(cc []exprClause, out *ir.Nodes) bool {
289 const minCases = 8 // have at least minCases cases in the switch
290 const minDensity = 4 // use at least 1 out of every minDensity entries
// Jump tables are off when optimizations are disabled (-N), when the
// target architecture can't emit them, or under retpoline mitigation.
292 if base.Flag.N != 0 || !ssagen.Arch.LinkArch.CanJumpTable || base.Ctxt.Retpoline {
295 if len(cc) < minCases {
296 return false // not enough cases for it to be worth it
298 if cc[0].lo.Val().Kind() != constant.Int {
299 return false // e.g. float
301 if s.exprname.Type().Size() > int64(types.PtrSize) {
302 return false // 64-bit switches on 32-bit archs
// Density check: width = max - min + 1 must not exceed len(cc)*minDensity.
304 min := cc[0].lo.Val()
305 max := cc[len(cc)-1].hi.Val()
306 width := constant.BinaryOp(constant.BinaryOp(max, token.SUB, min), token.ADD, constant.MakeInt64(1))
307 limit := constant.MakeInt64(int64(len(cc)) * minDensity)
308 if constant.Compare(width, token.GTR, limit) {
309 // We disable jump tables if we use less than a minimum fraction of the entries.
310 // i.e. for switch x {case 0: case 1000: case 2000:} we don't want to use a jump table.
313 jt := ir.NewJumpTableStmt(base.Pos, s.exprname)
314 for _, c := range cc {
315 jmp := c.jmp.(*ir.BranchStmt)
316 if jmp.Op() != ir.OGOTO || jmp.Label == nil {
317 panic("bad switch case body")
// Expand each [lo, hi] range clause into one table entry per value.
319 for i := c.lo.Val(); constant.Compare(i, token.LEQ, c.hi.Val()); i = constant.BinaryOp(i, token.ADD, constant.MakeInt64(1)) {
320 jt.Cases = append(jt.Cases, i)
321 jt.Targets = append(jt.Targets, jmp.Label)
// test returns the boolean expression that checks whether exprname
// matches this clause: a range test for merged clauses, an optimized
// form for constant-bool switches, and a plain equality otherwise.
328 func (c *exprClause) test(exprname ir.Node) ir.Node {
// Range clause (lo != hi): lo <= exprname && exprname <= hi.
331 low := ir.NewBinaryExpr(c.pos, ir.OGE, exprname, c.lo)
332 high := ir.NewBinaryExpr(c.pos, ir.OLE, exprname, c.hi)
333 return ir.NewLogicalExpr(c.pos, ir.OANDAND, low, high)
336 // Optimize "switch true { ...}" and "switch false { ... }".
337 if ir.IsConst(exprname, constant.Bool) && !c.lo.Type().IsInterface() {
338 if ir.BoolVal(exprname) {
// switch false: the case matches when the case expression is false.
341 return ir.NewUnaryExpr(c.pos, ir.ONOT, c.lo)
345 n := ir.NewBinaryExpr(c.pos, ir.OEQ, exprname, c.lo)
// allCaseExprsAreSideEffectFree reports whether every case expression in
// sw is an OLITERAL, i.e. trivially free of side effects.
350 func allCaseExprsAreSideEffectFree(sw *ir.SwitchStmt) bool {
351 // In theory, we could be more aggressive, allowing any
352 // side-effect-free expressions in cases, but it's a bit
353 // tricky because some of that information is unavailable due
354 // to the introduction of temporaries during order.
355 // Restricting to constants is simple and probably powerful
358 for _, ncase := range sw.Cases {
359 for _, v := range ncase.List {
360 if v.Op() != ir.OLITERAL {
368 // endsInFallthrough reports whether stmts ends with a "fallthrough" statement.
369 func endsInFallthrough(stmts []ir.Node) (bool, src.XPos) {
// Empty body: no fallthrough, and no position to report.
371 return false, src.NoXPos
374 return stmts[i].Op() == ir.OFALL, stmts[i].Pos()
377 // walkSwitchType generates an AST that implements sw, where sw is a
// NOTE(review): the doc comment above is truncated in this excerpt
// (sw is a type switch).
379 func walkSwitchType(sw *ir.SwitchStmt) {
// Evaluate the switched-on interface value once into a temporary.
381 s.facename = sw.Tag.(*ir.TypeSwitchGuard).X
384 s.facename = walkExpr(s.facename, sw.PtrInit())
385 s.facename = copyExpr(s.facename, s.facename.Type(), &sw.Compiled)
386 s.okname = typecheck.Temp(types.Types[types.TBOOL])
388 // Get interface descriptor word.
389 // For empty interfaces this will be the type.
390 // For non-empty interfaces this will be the itab.
391 itab := ir.NewUnaryExpr(base.Pos, ir.OITAB, s.facename)
393 // For empty interfaces, do:
394 // if e._type == nil {
395 // do nil case if it exists, otherwise default
398 // Use a similar strategy for non-empty interfaces.
399 ifNil := ir.NewIfStmt(base.Pos, nil, nil, nil)
400 ifNil.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, itab, typecheck.NodNil())
401 base.Pos = base.Pos.WithNotStmt() // disable statement marks after the first check.
402 ifNil.Cond = typecheck.Expr(ifNil.Cond)
403 ifNil.Cond = typecheck.DefaultLit(ifNil.Cond, nil)
404 // ifNil.Nbody assigned at end.
405 sw.Compiled.Append(ifNil)
407 // Load hash from type or itab.
408 dotHash := typeHashFieldOf(base.Pos, itab)
409 s.hashname = copyExpr(dotHash, dotHash.Type(), &sw.Compiled)
411 br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)
412 var defaultGoto, nilGoto ir.Node
414 for _, ncase := range sw.Cases {
417 // For single-type cases with an interface type,
418 // we initialize the case variable as part of the type assertion.
419 // In other cases, we initialize it in the body.
420 var singleType *types.Type
421 if len(ncase.List) == 1 && ncase.List[0].Op() == ir.OTYPE {
422 singleType = ncase.List[0].Type()
424 caseVarInitialized := false
426 label := typecheck.AutoLabel(".s")
427 jmp := ir.NewBranchStmt(ncase.Pos(), ir.OGOTO, label)
429 if len(ncase.List) == 0 { // default:
430 if defaultGoto != nil {
431 base.Fatalf("duplicate default case not detected during typechecking")
436 for _, n1 := range ncase.List {
437 if ir.IsNil(n1) { // case nil:
439 base.Fatalf("duplicate nil case not detected during typechecking")
// Interface-typed single cases let Add assign the case variable
// during the type assertion itself.
445 if singleType != nil && singleType.IsInterface() {
446 s.Add(ncase.Pos(), n1, caseVar, jmp)
447 caseVarInitialized = true
449 s.Add(ncase.Pos(), n1, nil, jmp)
453 body.Append(ir.NewLabelStmt(ncase.Pos(), label))
454 if caseVar != nil && !caseVarInitialized {
456 if singleType != nil {
457 // We have a single concrete type. Extract the data.
458 if singleType.IsInterface() {
459 base.Fatalf("singleType interface should have been handled in Add")
461 val = ifaceData(ncase.Pos(), s.facename, singleType)
// Generic (dynamic-type) case: re-assert to obtain a value of
// the case variable's type.
463 if len(ncase.List) == 1 && ncase.List[0].Op() == ir.ODYNAMICTYPE {
464 dt := ncase.List[0].(*ir.DynamicType)
465 x := ir.NewDynamicTypeAssertExpr(ncase.Pos(), ir.ODYNAMICDOTTYPE, val, dt.RType)
467 x.SetType(caseVar.Type())
472 ir.NewDecl(ncase.Pos(), ir.ODCL, caseVar),
473 ir.NewAssignStmt(ncase.Pos(), caseVar, val),
478 body.Append(ncase.Body...)
483 if defaultGoto == nil {
// With no explicit nil case, a nil interface falls to the default.
487 nilGoto = defaultGoto
489 ifNil.Body = []ir.Node{nilGoto}
492 sw.Compiled.Append(defaultGoto)
493 sw.Compiled.Append(body.Take()...)
495 walkStmtList(sw.Compiled)
498 // typeHashFieldOf returns an expression to select the type hash field
499 // from an interface's descriptor word (whether a *runtime._type or
500 // *runtime.itab pointer).
501 func typeHashFieldOf(pos src.XPos, itab *ir.UnaryExpr) *ir.SelectorExpr {
502 if itab.Op() != ir.OITAB {
503 base.Fatalf("expected OITAB, got %v", itab.Op())
505 var hashField *types.Field
506 if itab.X.Type().IsEmptyInterface() {
507 // runtime._type's hash field
// Lazily built once and cached in the package-level variable; the
// offset (2*PtrSize) must stay in sync with the runtime's layout.
508 if rtypeHashField == nil {
509 rtypeHashField = runtimeField("hash", int64(2*types.PtrSize), types.Types[types.TUINT32])
511 hashField = rtypeHashField
513 // runtime.itab's hash field
514 if itabHashField == nil {
515 itabHashField = runtimeField("hash", int64(2*types.PtrSize), types.Types[types.TUINT32])
517 hashField = itabHashField
519 return boundedDotPtr(pos, itab, hashField)
// Cached synthetic field descriptors for the runtime._type and runtime.itab
// hash fields, lazily initialized by typeHashFieldOf.
522 var rtypeHashField, itabHashField *types.Field
524 // A typeSwitch walks a type switch.
525 type typeSwitch struct {
526 // Temporary variables (i.e., ONAMEs) used by type switch dispatch logic:
527 facename ir.Node // value being type-switched on
528 hashname ir.Node // type hash of the value being type-switched on
529 okname ir.Node // boolean used for comma-ok type assertions
// A typeClause is one deferred concrete-type case: its type hash plus the
// statements that test and dispatch it (fields elided in this excerpt).
535 type typeClause struct {
// Add registers one type-switch case: n1 is the case's type expression
// (OTYPE or ODYNAMICTYPE), caseVar the optional per-case variable, and
// jmp the branch taken on a match. Static non-interface cases are buffered
// in s.clauses for hash-based binary search; others are emitted directly.
540 func (s *typeSwitch) Add(pos src.XPos, n1 ir.Node, caseVar *ir.Name, jmp ir.Node) {
545 ir.NewDecl(pos, ir.ODCL, caseVar),
546 ir.NewAssignStmt(pos, caseVar, nil),
// No case variable: assign the assertion result to the blank identifier.
551 caseVar = ir.BlankNode
554 // cv, ok = iface.(type)
555 as := ir.NewAssignListStmt(pos, ir.OAS2, nil, nil)
556 as.Lhs = []ir.Node{caseVar, s.okname} // cv, ok =
559 // Static type assertion (non-generic)
560 dot := ir.NewTypeAssertExpr(pos, s.facename, typ) // iface.(type)
561 as.Rhs = []ir.Node{dot}
562 case ir.ODYNAMICTYPE:
563 // Dynamic type assertion (generic)
564 dt := n1.(*ir.DynamicType)
565 dot := ir.NewDynamicTypeAssertExpr(pos, ir.ODYNAMICDOTTYPE, s.facename, dt.RType)
569 as.Rhs = []ir.Node{dot}
571 base.Fatalf("unhandled type case %s", n1.Op())
573 appendWalkStmt(&body, as)
575 // if ok { goto label }
576 nif := ir.NewIfStmt(pos, nil, nil, nil)
578 nif.Body = []ir.Node{jmp}
581 if n1.Op() == ir.OTYPE && !typ.IsInterface() {
582 // Defer static, noninterface cases so they can be binary searched by hash.
583 s.clauses = append(s.clauses, typeClause{
584 hash: types.TypeHash(n1.Type()),
591 s.done.Append(body.Take()...)
// Emit moves all accumulated dispatch statements into out.
594 func (s *typeSwitch) Emit(out *ir.Nodes) {
596 out.Append(s.done.Take()...)
// flush lowers the buffered concrete-type clauses: sorted by type hash,
// merged on hash collisions, then dispatched via binary search on the
// precomputed hash in s.hashname.
599 func (s *typeSwitch) flush() {
606 sort.Slice(cc, func(i, j int) bool { return cc[i].hash < cc[j].hash })
608 // Combine adjacent cases with the same hash.
// Colliding types share one hash bucket; their individual assertion
// tests are concatenated and tried in order.
610 for _, c := range cc[1:] {
611 last := &merged[len(merged)-1]
612 if last.hash == c.hash {
613 last.body.Append(c.body.Take()...)
615 merged = append(merged, c)
620 // TODO: figure out if we could use a jump table using some low bits of the type hashes.
621 binarySearch(len(cc), &s.done,
622 func(i int) ir.Node {
623 return ir.NewBinaryExpr(base.Pos, ir.OLE, s.hashname, ir.NewInt(base.Pos, int64(cc[i-1].hash)))
625 func(i int, nif *ir.IfStmt) {
626 // TODO(mdempsky): Omit hash equality check if
627 // there's only one type.
629 nif.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, s.hashname, ir.NewInt(base.Pos, int64(c.hash)))
630 nif.Body.Append(c.body.Take()...)
635 // binarySearch constructs a binary search tree for handling n cases,
636 // and appends it to out. It's used for efficiently implementing
637 // switch statements.
639 // less(i) should return a boolean expression. If it evaluates true,
640 // then cases before i will be tested; otherwise, cases i and later.
642 // leaf(i, nif) should setup nif (an OIF node) to test case i. In
643 // particular, it should set nif.Cond and nif.Body.
644 func binarySearch(n int, out *ir.Nodes, less func(i int) ir.Node, leaf func(i int, nif *ir.IfStmt)) {
645 const binarySearchMin = 4 // minimum number of cases for binary search
647 var do func(lo, hi int, out *ir.Nodes)
648 do = func(lo, hi int, out *ir.Nodes) {
// Small ranges: emit a linear chain of if statements instead.
650 if n < binarySearchMin {
651 for i := lo; i < hi; i++ {
652 nif := ir.NewIfStmt(base.Pos, nil, nil, nil)
654 base.Pos = base.Pos.WithNotStmt()
655 nif.Cond = typecheck.Expr(nif.Cond)
656 nif.Cond = typecheck.DefaultLit(nif.Cond, nil)
// Recursive case: test less(half), then search each half.
664 nif := ir.NewIfStmt(base.Pos, nil, nil, nil)
665 nif.Cond = less(half)
666 base.Pos = base.Pos.WithNotStmt()
667 nif.Cond = typecheck.Expr(nif.Cond)
668 nif.Cond = typecheck.DefaultLit(nif.Cond, nil)
669 do(lo, half, &nif.Body)
670 do(half, hi, &nif.Else)
// stringSearch emits dispatch code for cc, a set of clauses whose case
// values are string constants of equal length, by recursively splitting
// the set on a single byte comparison.
677 func stringSearch(expr ir.Node, cc []exprClause, out *ir.Nodes) {
679 // Short list, just do brute force equality checks.
680 for _, c := range cc {
681 nif := ir.NewIfStmt(base.Pos.WithNotStmt(), typecheck.DefaultLit(typecheck.Expr(c.test(expr)), nil), []ir.Node{c.jmp}, nil)
688 // The strategy here is to find a simple test to divide the set of possible strings
689 // that might match expr approximately in half.
690 // The test we're going to use is to do an ordered comparison of a single byte
691 // of expr to a constant. We will pick the index of that byte and the value we're
692 // comparing against to make the split as even as possible.
693 // if expr[3] <= 'd' { ... search strings with expr[3] at 'd' or lower ... }
694 // else { ... search strings with expr[3] at 'e' or higher ... }
696 // To add complication, we will do the ordered comparison in the signed domain.
697 // The reason for this is to prevent CSE from merging the load used for the
698 // ordered comparison with the load used for the later equality check.
699 // if expr[3] <= 'd' { ... if expr[0] == 'f' && expr[1] == 'o' && expr[2] == 'o' && expr[3] == 'd' { ... } }
700 // If we did both expr[3] loads in the unsigned domain, they would be CSEd, and that
701 // would in turn defeat the combining of expr[0]...expr[3] into a single 4-byte load.
703 // By using signed loads for the ordered comparison and unsigned loads for the
704 // equality comparison, they don't get CSEd and the equality comparisons will be
705 // done using wider loads.
707 n := len(ir.StringVal(cc[0].lo)) // Length of the constant strings.
708 bestScore := int64(0) // measure of how good the split is.
709 bestIdx := 0 // split using expr[bestIdx]
710 bestByte := int8(0) // compare expr[bestIdx] against bestByte
// Exhaustively try every (byte index, signed threshold) pair; the score
// le*(len(cc)-le) is maximized by the most balanced split.
711 for idx := 0; idx < n; idx++ {
712 for b := int8(-128); b < 127; b++ {
714 for _, c := range cc {
715 s := ir.StringVal(c.lo)
716 if int8(s[idx]) <= b {
720 score := int64(le) * int64(len(cc)-le)
721 if score > bestScore {
729 // The split must be at least 1:n-1 because we have at least 2 distinct strings; they
730 // have to be different somewhere.
731 // TODO: what if the best split is still pretty bad?
733 base.Fatalf("unable to split string set")
736 // Convert expr to a []int8
737 slice := ir.NewConvExpr(base.Pos, ir.OSTR2BYTESTMP, types.NewSlice(types.Types[types.TINT8]), expr)
738 slice.SetTypecheck(1) // legacy typechecker doesn't handle this op
739 // Load the byte we're splitting on.
740 load := ir.NewIndexExpr(base.Pos, slice, ir.NewInt(base.Pos, int64(bestIdx)))
741 // Compare with the value we're splitting on.
742 cmp := ir.Node(ir.NewBinaryExpr(base.Pos, ir.OLE, load, ir.NewInt(base.Pos, int64(bestByte))))
743 cmp = typecheck.DefaultLit(typecheck.Expr(cmp), nil)
744 nif := ir.NewIfStmt(base.Pos, cmp, nil, nil)
// Partition the clauses by the chosen test and recurse on each half.
748 for _, c := range cc {
749 s := ir.StringVal(c.lo)
750 if int8(s[bestIdx]) <= bestByte {
756 stringSearch(expr, le, &nif.Body)
757 stringSearch(expr, gt, &nif.Else)
760 // TODO: if expr[bestIdx] has enough different possible values, use a jump table.