]> Cypherpunks.ru repositories - gostls13.git/commitdiff
cmd/compile: use cache in front of type assert runtime call
authorKeith Randall <khr@golang.org>
Mon, 18 Sep 2023 20:31:49 +0000 (13:31 -0700)
committerKeith Randall <khr@golang.org>
Fri, 6 Oct 2023 17:02:53 +0000 (17:02 +0000)
That way we don't need to call into the runtime for every
type assertion (to an interface type).

name           old time/op  new time/op  delta
TypeAssert-24  3.78ns ± 3%  1.00ns ± 1%  -73.53%  (p=0.000 n=10+8)

Change-Id: I0ba308aaf0f24a5495b4e13c814d35af0c58bfde
Reviewed-on: https://go-review.googlesource.com/c/go/+/529316
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Keith Randall <khr@google.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
Reviewed-by: Cuong Manh Le <cuong.manhle.vn@gmail.com>
src/cmd/compile/internal/ssagen/ssa.go
src/cmd/compile/internal/test/iface_test.go
src/cmd/compile/internal/walk/expr.go
src/internal/abi/switch.go
src/runtime/iface.go
test/codegen/ifaces.go
test/codegen/switch.go

index a438cc77938e617ea1598713e2bc7ad06a5f9c19..93643af2943f904773e685a5f4d9d1c513ef0630 100644 (file)
@@ -6559,7 +6559,8 @@ func (s *state) dynamicDottype(n *ir.DynamicTypeAssertExpr, commaok bool) (res,
 // descriptor is a compiler-allocated internal/abi.TypeAssert whose address is passed to runtime.typeAssert when
 // the target type is a compile-time-known non-empty interface. It may be nil.
 func (s *state) dottype1(pos src.XPos, src, dst *types.Type, iface, source, target, targetItab *ssa.Value, commaok bool, descriptor *obj.LSym) (res, resok *ssa.Value) {
-       byteptr := s.f.Config.Types.BytePtr
+       typs := s.f.Config.Types
+       byteptr := typs.BytePtr
        if dst.IsInterface() {
                if dst.IsEmptyInterface() {
                        // Converting to an empty interface.
@@ -6638,16 +6639,10 @@ func (s *state) dottype1(pos src.XPos, src, dst *types.Type, iface, source, targ
                itab := s.newValue1(ssa.OpITab, byteptr, iface)
                data := s.newValue1(ssa.OpIData, types.Types[types.TUNSAFEPTR], iface)
 
-               if commaok {
-                       // Use a variable to hold the resulting itab. This allows us
-                       // to merge a value from the nil and non-nil branches.
-                       // (This assignment will be the nil result.)
-                       s.vars[typVar] = itab
-               }
-
                // First, check for nil.
                bNil := s.f.NewBlock(ssa.BlockPlain)
                bNonNil := s.f.NewBlock(ssa.BlockPlain)
+               bMerge := s.f.NewBlock(ssa.BlockPlain)
                cond := s.newValue2(ssa.OpNeqPtr, types.Types[types.TBOOL], itab, s.constNil(byteptr))
                b := s.endBlock()
                b.Kind = ssa.BlockIf
@@ -6656,9 +6651,13 @@ func (s *state) dottype1(pos src.XPos, src, dst *types.Type, iface, source, targ
                b.AddEdgeTo(bNonNil)
                b.AddEdgeTo(bNil)
 
-               if !commaok {
+               s.startBlock(bNil)
+               if commaok {
+                       s.vars[typVar] = itab // which will be nil
+                       b := s.endBlock()
+                       b.AddEdgeTo(bMerge)
+               } else {
                        // Panic if input is nil.
-                       s.startBlock(bNil)
                        s.rtcall(ir.Syms.Panicnildottype, false, nil, target)
                }
 
@@ -6669,9 +6668,96 @@ func (s *state) dottype1(pos src.XPos, src, dst *types.Type, iface, source, targ
                        typ = s.load(byteptr, s.newValue1I(ssa.OpOffPtr, byteptr, int64(types.PtrSize), itab))
                }
 
+               // Check the cache first.
+               var d *ssa.Value
+               if descriptor != nil {
+                       d = s.newValue1A(ssa.OpAddr, byteptr, descriptor, s.sb)
+                       if base.Flag.N == 0 && rtabi.UseInterfaceSwitchCache(Arch.LinkArch.Name) {
+                               // Note: we can only use the cache if we have the right atomic load instruction.
+                               // Double-check that here.
+                               if _, ok := intrinsics[intrinsicKey{Arch.LinkArch.Arch, "runtime/internal/atomic", "Loadp"}]; !ok {
+                                       s.Fatalf("atomic load not available")
+                               }
+                               // Pick right size ops.
+                               var mul, and, add, zext ssa.Op
+                               if s.config.PtrSize == 4 {
+                                       mul = ssa.OpMul32
+                                       and = ssa.OpAnd32
+                                       add = ssa.OpAdd32
+                                       zext = ssa.OpCopy
+                               } else {
+                                       mul = ssa.OpMul64
+                                       and = ssa.OpAnd64
+                                       add = ssa.OpAdd64
+                                       zext = ssa.OpZeroExt32to64
+                               }
+
+                               loopHead := s.f.NewBlock(ssa.BlockPlain)
+                               loopBody := s.f.NewBlock(ssa.BlockPlain)
+                               cacheHit := s.f.NewBlock(ssa.BlockPlain)
+                               cacheMiss := s.f.NewBlock(ssa.BlockPlain)
+
+                               // Load cache pointer out of descriptor, with an atomic load so
+                               // we ensure that we see a fully written cache.
+                               atomicLoad := s.newValue2(ssa.OpAtomicLoadPtr, types.NewTuple(typs.BytePtr, types.TypeMem), d, s.mem())
+                               cache := s.newValue1(ssa.OpSelect0, typs.BytePtr, atomicLoad)
+                               s.vars[memVar] = s.newValue1(ssa.OpSelect1, types.TypeMem, atomicLoad)
+
+                               // Load hash from type.
+                               hash := s.newValue2(ssa.OpLoad, typs.UInt32, s.newValue1I(ssa.OpOffPtr, typs.UInt32Ptr, 2*s.config.PtrSize, typ), s.mem())
+                               hash = s.newValue1(zext, typs.Uintptr, hash)
+                               s.vars[hashVar] = hash
+                               // Load mask from cache.
+                               mask := s.newValue2(ssa.OpLoad, typs.Uintptr, cache, s.mem())
+                               // Jump to loop head.
+                               b := s.endBlock()
+                               b.AddEdgeTo(loopHead)
+
+                               // At loop head, get pointer to the cache entry.
+                               //   e := &cache.Entries[hash&mask]
+                               s.startBlock(loopHead)
+                               idx := s.newValue2(and, typs.Uintptr, s.variable(hashVar, typs.Uintptr), mask)
+                               idx = s.newValue2(mul, typs.Uintptr, idx, s.uintptrConstant(uint64(2*s.config.PtrSize)))
+                               idx = s.newValue2(add, typs.Uintptr, idx, s.uintptrConstant(uint64(s.config.PtrSize)))
+                               e := s.newValue2(ssa.OpAddPtr, typs.UintptrPtr, cache, idx)
+                               //   hash++
+                               s.vars[hashVar] = s.newValue2(add, typs.Uintptr, s.variable(hashVar, typs.Uintptr), s.uintptrConstant(1))
+
+                               // Look for a cache hit.
+                               //   if e.Typ == typ { goto hit }
+                               eTyp := s.newValue2(ssa.OpLoad, typs.Uintptr, e, s.mem())
+                               cmp1 := s.newValue2(ssa.OpEqPtr, typs.Bool, typ, eTyp)
+                               b = s.endBlock()
+                               b.Kind = ssa.BlockIf
+                               b.SetControl(cmp1)
+                               b.AddEdgeTo(cacheHit)
+                               b.AddEdgeTo(loopBody)
+
+                               // Look for an empty entry, the tombstone for this hash table.
+                               //   if e.Typ == nil { goto miss }
+                               s.startBlock(loopBody)
+                               cmp2 := s.newValue2(ssa.OpEqPtr, typs.Bool, eTyp, s.constNil(typs.BytePtr))
+                               b = s.endBlock()
+                               b.Kind = ssa.BlockIf
+                               b.SetControl(cmp2)
+                               b.AddEdgeTo(cacheMiss)
+                               b.AddEdgeTo(loopHead)
+
+                               // On a hit, load the data fields of the cache entry.
+                               //   Itab = e.Itab
+                               s.startBlock(cacheHit)
+                               eItab := s.newValue2(ssa.OpLoad, typs.BytePtr, s.newValue1I(ssa.OpOffPtr, typs.BytePtrPtr, s.config.PtrSize, e), s.mem())
+                               s.vars[typVar] = eItab
+                               b = s.endBlock()
+                               b.AddEdgeTo(bMerge)
+
+                               // On a miss, call into the runtime to get the answer.
+                               s.startBlock(cacheMiss)
+                       }
+               }
+
                // Call into runtime to get itab for result.
                if descriptor != nil {
-                       d := s.newValue1A(ssa.OpAddr, byteptr, descriptor, s.sb)
                        itab = s.rtcall(ir.Syms.TypeAssert, true, []*types.Type{byteptr}, d, typ)[0]
                } else {
                        var fn *obj.LSym
@@ -6682,18 +6768,18 @@ func (s *state) dottype1(pos src.XPos, src, dst *types.Type, iface, source, targ
                        }
                        itab = s.rtcall(fn, true, []*types.Type{byteptr}, target, typ)[0]
                }
-               // Build result.
+               s.vars[typVar] = itab
+               b = s.endBlock()
+               b.AddEdgeTo(bMerge)
+
+               // Build resulting interface.
+               s.startBlock(bMerge)
+               itab = s.variable(typVar, byteptr)
+               var ok *ssa.Value
                if commaok {
-                       // Merge the nil result and the runtime call result.
-                       s.vars[typVar] = itab
-                       b := s.endBlock()
-                       b.AddEdgeTo(bNil)
-                       s.startBlock(bNil)
-                       itab = s.variable(typVar, byteptr)
-                       ok := s.newValue2(ssa.OpNeqPtr, types.Types[types.TBOOL], itab, s.constNil(byteptr))
-                       return s.newValue2(ssa.OpIMake, dst, itab, data), ok
+                       ok = s.newValue2(ssa.OpNeqPtr, types.Types[types.TBOOL], itab, s.constNil(byteptr))
                }
-               return s.newValue2(ssa.OpIMake, dst, itab, data), nil
+               return s.newValue2(ssa.OpIMake, dst, itab, data), ok
        }
 
        if base.Debug.TypeAssert > 0 {
index ebc4f891c963a17de4c5d668f7248e9cbc8ae47c..db41eb8e55c740c36652ae1c739779bac928168a 100644 (file)
@@ -124,3 +124,15 @@ func BenchmarkEfaceInteger(b *testing.B) {
 func i2int(i interface{}) int {
        return i.(int)
 }
+
+func BenchmarkTypeAssert(b *testing.B) {
+       e := any(Int(0))
+       r := true
+       for i := 0; i < b.N; i++ {
+               _, ok := e.(I)
+               if !ok {
+                       r = false
+               }
+       }
+       sink = r
+}
index a3caa4db36ddf6ecfbe1af25f1bddbf50b7bd2b8..914011d13553a8e850a13c4e41955de083d5039b 100644 (file)
@@ -731,10 +731,15 @@ func walkDotType(n *ir.TypeAssertExpr, init *ir.Nodes) ir.Node {
                lsym := types.LocalPkg.Lookup(fmt.Sprintf(".typeAssert.%d", typeAssertGen)).LinksymABI(obj.ABI0)
                typeAssertGen++
                off := 0
+               off = objw.SymPtr(lsym, off, typecheck.LookupRuntimeVar("emptyTypeAssertCache"), 0)
                off = objw.SymPtr(lsym, off, reflectdata.TypeSym(n.Type()).Linksym(), 0)
                off = objw.Bool(lsym, off, n.Op() == ir.ODOTTYPE2) // CanFail
                off += types.PtrSize - 1
-               objw.Global(lsym, int32(off), obj.LOCAL|obj.NOPTR)
+               objw.Global(lsym, int32(off), obj.LOCAL)
+               // Set the type to be just a single pointer, as the cache pointer is the
+               // only one that GC needs to see.
+               lsym.Gotype = reflectdata.TypeLinksym(types.Types[types.TUINT8].PtrTo())
+
                n.Descriptor = lsym
        }
        return n
index 495580f9df404bef72bc7c7151ae1fb50331078f..9669fe51d5aa75f11129ba5ebf42aa14e0117a5e 100644 (file)
@@ -44,6 +44,18 @@ func UseInterfaceSwitchCache(goarch string) bool {
 }
 
 type TypeAssert struct {
+       Cache   *TypeAssertCache
        Inter   *InterfaceType
        CanFail bool
 }
+type TypeAssertCache struct {
+       Mask    uintptr
+       Entries [1]TypeAssertCacheEntry
+}
+type TypeAssertCacheEntry struct {
+       // type of source value (a *runtime._type)
+       Typ uintptr
+       // itab to use for result (a *runtime.itab)
+       // nil if CanFail is set and conversion would fail.
+       Itab uintptr
+}
index 7a2c257b13176b9b30926d2f699b9f751772fd28..da6346a706c05b704cdd85206e128411802718d9 100644 (file)
@@ -436,15 +436,94 @@ func assertE2I2(inter *interfacetype, t *_type) *itab {
 // interface type s.Inter. If the conversion is not possible it
 // panics if s.CanFail is false and returns nil if s.CanFail is true.
 func typeAssert(s *abi.TypeAssert, t *_type) *itab {
+       var tab *itab
        if t == nil {
-               if s.CanFail {
-                       return nil
+               if !s.CanFail {
+                       panic(&TypeAssertionError{nil, nil, &s.Inter.Type, ""})
+               }
+       } else {
+               tab = getitab(s.Inter, t, s.CanFail)
+       }
+
+       if !abi.UseInterfaceSwitchCache(GOARCH) {
+               return tab
+       }
+
+       // Maybe update the cache, so the next time the generated code
+       // doesn't need to call into the runtime.
+       if fastrand()&1023 != 0 {
+               // Only bother updating the cache ~1 in 1000 times.
+               return tab
+       }
+       // Load the current cache.
+       oldC := (*abi.TypeAssertCache)(atomic.Loadp(unsafe.Pointer(&s.Cache)))
+
+       if fastrand()&uint32(oldC.Mask) != 0 {
+               // As cache gets larger, choose to update it less often
+               // so we can amortize the cost of building a new cache.
+               return tab
+       }
+
+       // Make a new cache.
+       newC := buildTypeAssertCache(oldC, t, tab)
+
+       // Update cache. Use compare-and-swap so if multiple threads
+       // are fighting to update the cache, at least one of their
+       // updates will stick.
+       atomic_casPointer((*unsafe.Pointer)(unsafe.Pointer(&s.Cache)), unsafe.Pointer(oldC), unsafe.Pointer(newC))
+
+       return tab
+}
+
+func buildTypeAssertCache(oldC *abi.TypeAssertCache, typ *_type, tab *itab) *abi.TypeAssertCache {
+       oldEntries := unsafe.Slice(&oldC.Entries[0], oldC.Mask+1)
+
+       // Count the number of entries we need.
+       n := 1
+       for _, e := range oldEntries {
+               if e.Typ != 0 {
+                       n++
                }
-               panic(&TypeAssertionError{nil, nil, &s.Inter.Type, ""})
        }
-       return getitab(s.Inter, t, s.CanFail)
+
+       // Figure out how big a table we need.
+       // We need at least one more slot than the number of entries
+       // so that we are guaranteed an empty slot (for termination).
+       newN := n * 2                         // make it at most 50% full
+       newN = 1 << sys.Len64(uint64(newN-1)) // round up to a power of 2
+
+       // Allocate the new table.
+       newSize := unsafe.Sizeof(abi.TypeAssertCache{}) + uintptr(newN-1)*unsafe.Sizeof(abi.TypeAssertCacheEntry{})
+       newC := (*abi.TypeAssertCache)(mallocgc(newSize, nil, true))
+       newC.Mask = uintptr(newN - 1)
+       newEntries := unsafe.Slice(&newC.Entries[0], newN)
+
+       // Fill the new table.
+       addEntry := func(typ *_type, tab *itab) {
+               h := int(typ.Hash) & (newN - 1)
+               for {
+                       if newEntries[h].Typ == 0 {
+                               newEntries[h].Typ = uintptr(unsafe.Pointer(typ))
+                               newEntries[h].Itab = uintptr(unsafe.Pointer(tab))
+                               return
+                       }
+                       h = (h + 1) & (newN - 1)
+               }
+       }
+       for _, e := range oldEntries {
+               if e.Typ != 0 {
+                       addEntry((*_type)(unsafe.Pointer(e.Typ)), (*itab)(unsafe.Pointer(e.Itab)))
+               }
+       }
+       addEntry(typ, tab)
+
+       return newC
 }
 
+// Empty type assert cache. Contains one entry with a nil Typ (which
+// causes a cache lookup to fail immediately.)
+var emptyTypeAssertCache = abi.TypeAssertCache{Mask: 0}
+
 // interfaceSwitch compares t against the list of cases in s.
 // If t matches case i, interfaceSwitch returns the case index i and
 // an itab for the pair <t, s.Cases[i]>.
index d773845e8ee781f017072bb1b920f7957c307fa8..2be3fa5146a7acfc3ab1b82e9a4641f27f6c955f 100644 (file)
@@ -6,16 +6,22 @@
 
 package codegen
 
-type I interface { M() }
+type I interface{ M() }
 
 func NopConvertIface(x I) I {
-        // amd64:-`.*runtime.convI2I`
+       // amd64:-`.*runtime.convI2I`
        return I(x)
 }
 
 func NopConvertGeneric[T any](x T) T {
-        // amd64:-`.*runtime.convI2I`
-        return T(x)
+       // amd64:-`.*runtime.convI2I`
+       return T(x)
 }
 
 var NopConvertGenericIface = NopConvertGeneric[I]
+
+func ConvToM(x any) I {
+       // amd64:`CALL\truntime.typeAssert`,`MOVL\t16\(.*\)`,`MOVQ\t8\(.*\)(.*\*1)`
+       // arm64:`CALL\truntime.typeAssert`,`LDAR`,`MOVWU`,`MOVD\t\(R.*\)\(R.*\)`
+       return x.(I)
+}
index 6778c65ab3a4fe3998bec8990aefd14bf6dad4bf..b0186ba5b7c1eb0c73ccb26325d031b62f027300 100644 (file)
@@ -139,3 +139,12 @@ func interfaceSwitch(x any) int {
                return 3
        }
 }
+
+func interfaceCast(x any) int {
+       // amd64:`CALL\truntime.typeAssert`,`MOVL\t16\(.*\)`,`MOVQ\t8\(.*\)(.*\*1)`
+       // arm64:`CALL\truntime.typeAssert`,`LDAR`,`MOVWU`,`MOVD\t\(R.*\)\(R.*\)`
+       if _, ok := x.(I); ok {
+               return 3
+       }
+       return 5
+}