]> Cypherpunks.ru repositories - gostls13.git/commitdiff
cmd/compile: use jump tables for large type switches
authorKeith Randall <khr@golang.org>
Thu, 17 Aug 2023 21:20:21 +0000 (14:20 -0700)
committerGopher Robot <gobot@golang.org>
Wed, 23 Aug 2023 00:30:54 +0000 (00:30 +0000)
For large interface -> concrete type switches, we can use a jump
table on some bits of the type hash instead of a binary search on
the type hash.

name                        old time/op  new time/op  delta
SwitchTypePredictable-24    1.99ns ± 2%  1.78ns ± 5%  -10.87%  (p=0.000 n=10+10)
SwitchTypeUnpredictable-24  11.0ns ± 1%   9.1ns ± 2%  -17.55%  (p=0.000 n=7+9)

Change-Id: Ida4768e5d62c3ce1c2701288b72664aaa9e64259
Reviewed-on: https://go-review.googlesource.com/c/go/+/521497
Reviewed-by: Keith Randall <khr@google.com>
Auto-Submit: Keith Randall <khr@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cherry Mui <cherryyz@google.com>
Run-TryBot: Keith Randall <khr@golang.org>

src/cmd/compile/internal/test/switch_test.go
src/cmd/compile/internal/walk/switch.go
test/codegen/switch.go

index 30dee6257e3804164ff9daa133a423bf2f8e885a..ddb39bfe5fa32179ed63c2db38ec50fbdf1625f4 100644 (file)
@@ -120,6 +120,48 @@ func benchmarkSwitchString(b *testing.B, predictable bool) {
        sink = n
 }
 
+func BenchmarkSwitchTypePredictable(b *testing.B) {
+       benchmarkSwitchType(b, true)
+}
+func BenchmarkSwitchTypeUnpredictable(b *testing.B) {
+       benchmarkSwitchType(b, false)
+}
+func benchmarkSwitchType(b *testing.B, predictable bool) {
+       a := []any{
+               int8(1),
+               int16(2),
+               int32(3),
+               int64(4),
+               uint8(5),
+               uint16(6),
+               uint32(7),
+               uint64(8),
+       }
+       n := 0
+       rng := newRNG()
+       for i := 0; i < b.N; i++ {
+               rng = rng.next(predictable)
+               switch a[rng.value()&7].(type) {
+               case int8:
+                       n += 1
+               case int16:
+                       n += 2
+               case int32:
+                       n += 3
+               case int64:
+                       n += 4
+               case uint8:
+                       n += 5
+               case uint16:
+                       n += 6
+               case uint32:
+                       n += 7
+               case uint64:
+                       n += 8
+               }
+       }
+}
+
 // A simple random number generator used to make switches conditionally predictable.
 type rng uint64
 
index ebd3128251aa3d3d887bc6a457d006d9bae75c21..67ccb2e5d166748ed6e8c24b69f6ee2374fbc470 100644 (file)
@@ -7,6 +7,7 @@ package walk
 import (
        "go/constant"
        "go/token"
+       "math/bits"
        "sort"
 
        "cmd/compile/internal/base"
@@ -617,7 +618,9 @@ func (s *typeSwitch) flush() {
        }
        cc = merged
 
-       // TODO: figure out if we could use a jump table using some low bits of the type hashes.
+       if s.tryJumpTable(cc, &s.done) {
+               return
+       }
        binarySearch(len(cc), &s.done,
                func(i int) ir.Node {
                        return ir.NewBinaryExpr(base.Pos, ir.OLE, s.hashname, ir.NewInt(base.Pos, int64(cc[i-1].hash)))
@@ -632,6 +635,83 @@ func (s *typeSwitch) flush() {
        )
 }
 
+// Try to implement the clauses with a jump table. Returns true if successful.
+func (s *typeSwitch) tryJumpTable(cc []typeClause, out *ir.Nodes) bool {
+       const minCases = 5 // have at least minCases cases in the switch
+       if base.Flag.N != 0 || !ssagen.Arch.LinkArch.CanJumpTable || base.Ctxt.Retpoline {
+               return false
+       }
+       if len(cc) < minCases {
+               return false // not enough cases for it to be worth it
+       }
+       hashes := make([]uint32, len(cc))
+       // b = # of bits to use. Start with the minimum number of
+       // bits possible, but try a few larger sizes if needed.
+       b0 := bits.Len(uint(len(cc) - 1))
+       for b := b0; b < b0+3; b++ {
+       pickI:
+               for i := 0; i <= 32-b; i++ { // starting bit position
+                       // Compute the hash we'd get from all the cases,
+                       // selecting b bits starting at bit i.
+                       hashes = hashes[:0]
+                       for _, c := range cc {
+                               h := c.hash >> i & (1<<b - 1)
+                               hashes = append(hashes, h)
+                       }
+                       // Order by increasing hash.
+                       sort.Slice(hashes, func(j, k int) bool {
+                               return hashes[j] < hashes[k]
+                       })
+                       for j := 1; j < len(hashes); j++ {
+                               if hashes[j] == hashes[j-1] {
+                                       // There is a duplicate hash; try a different b/i pair.
+                                       continue pickI
+                               }
+                       }
+
+                       // All hashes are distinct. Use these values of b and i.
+                       h := s.hashname
+                       if i != 0 {
+                               h = ir.NewBinaryExpr(base.Pos, ir.ORSH, h, ir.NewInt(base.Pos, int64(i)))
+                       }
+                       h = ir.NewBinaryExpr(base.Pos, ir.OAND, h, ir.NewInt(base.Pos, int64(1<<b-1)))
+                       h = typecheck.Expr(h)
+
+                       // Build jump table.
+                       jt := ir.NewJumpTableStmt(base.Pos, h)
+                       jt.Cases = make([]constant.Value, 1<<b)
+                       jt.Targets = make([]*types.Sym, 1<<b)
+                       out.Append(jt)
+
+                       // Start with all hashes going to the didn't-match target.
+                       noMatch := typecheck.AutoLabel(".s")
+                       for j := 0; j < 1<<b; j++ {
+                               jt.Cases[j] = constant.MakeInt64(int64(j))
+                               jt.Targets[j] = noMatch
+                       }
+                       // This statement is not reachable, but it will make it obvious that we don't
+                       // fall through to the first case.
+                       out.Append(ir.NewBranchStmt(base.Pos, ir.OGOTO, noMatch))
+
+                       // Emit each of the actual cases.
+                       for _, c := range cc {
+                               h := c.hash >> i & (1<<b - 1)
+                               label := typecheck.AutoLabel(".s")
+                               jt.Targets[h] = label
+                               out.Append(ir.NewLabelStmt(base.Pos, label))
+                               out.Append(c.body...)
+                               // We reach here if the hash matches but the type equality test fails.
+                               out.Append(ir.NewBranchStmt(base.Pos, ir.OGOTO, noMatch))
+                       }
+                       // Emit point to go to if type doesn't match any case.
+                       out.Append(ir.NewLabelStmt(base.Pos, noMatch))
+                       return true
+               }
+       }
+       // Couldn't find a perfect hash. Fall back to binary search.
+       return false
+}
+
 // binarySearch constructs a binary search tree for handling n cases,
 // and appends it to out. It's used for efficiently implementing
 // switch statements.
index 603e0befbbdc752df4a3ba2aee2c5ab3ef8654ce..556d02a162cd0519133bc0d24c930389f1c223af 100644 (file)
@@ -99,3 +99,22 @@ func mimetype(ext string) string {
                return ""
        }
 }
+
+// use jump tables for type switches to concrete types.
+func typeSwitch(x any) int {
+       // amd64:`JMP\s\(.*\)\(.*\)$`
+       // arm64:`MOVD\s\(R.*\)\(R.*<<3\)`,`JMP\s\(R.*\)$`
+       switch x.(type) {
+       case int:
+               return 0
+       case int8:
+               return 1
+       case int16:
+               return 2
+       case int32:
+               return 3
+       case int64:
+               return 4
+       }
+       return 7
+}