]> Cypherpunks.ru repositories - gostls13.git/commitdiff
[release-branch.go1.22] cmd/compile: fail noder.LookupFunc gracefully if function...
authorMichael Pratt <mpratt@google.com>
Thu, 8 Feb 2024 18:07:33 +0000 (13:07 -0500)
committerThan McIntosh <thanm@google.com>
Fri, 16 Feb 2024 15:54:58 +0000 (15:54 +0000)
PGO uses noder.LookupFunc to look for devirtualization targets in
export data.  LookupFunc does not support type-parameterized
functions, and will currently fail the build when attempting to lookup
a type-parameterized function because objIdx is passed the wrong
number of type arguments.

This doesn't usually come up, as a PGO profile will report a generic
function with a symbol name like Func[.go.shape.foo]. In export data,
this is just Func, so when we do LookupFunc("Func[.go.shape.foo]")
lookup simply fails because the name doesn't exist.

However, if Func is not generic when the profile is collected, but the
source has since changed to make Func generic, then LookupFunc("Func")
will find the object successfully, only to fail the build because we
failed to provide type arguments.

Handle this with a objIdxMayFail, which allows graceful failure if the
object requires type arguments.

Bumping the language version to 1.21 in pgo_devirtualize_test.go is
required for type inference of the uses of mult.MultFn in
cmd/compile/internal/test/testdata/pgo/devirtualize/devirt_test.go.

For #65615.
Fixes #65618.

Change-Id: I84d9344840b851182f5321b8f7a29a591221b29f
Reviewed-on: https://go-review.googlesource.com/c/go/+/562737
Reviewed-by: Cherry Mui <cherryyz@google.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
(cherry picked from commit 532c6f1c8d724975f578c8ec519f1f5b07d949da)
Reviewed-on: https://go-review.googlesource.com/c/go/+/563016

src/cmd/compile/internal/noder/reader.go
src/cmd/compile/internal/noder/unified.go
src/cmd/compile/internal/test/pgo_devirtualize_test.go

index f5d1fce50cf039d9e7c7a23076689a6eab871d9e..2dddd201659024f9aa45d928e84df3d2a15996d8 100644 (file)
@@ -663,9 +663,24 @@ func (pr *pkgReader) objInstIdx(info objInfo, dict *readerDict, shaped bool) ir.
 }
 
 // objIdx returns the specified object, instantiated with the given
-// type arguments, if any. If shaped is true, then the shaped variant
-// of the object is returned instead.
+// type arguments, if any.
+// If shaped is true, then the shaped variant of the object is returned
+// instead.
 func (pr *pkgReader) objIdx(idx pkgbits.Index, implicits, explicits []*types.Type, shaped bool) ir.Node {
+       n, err := pr.objIdxMayFail(idx, implicits, explicits, shaped)
+       if err != nil {
+               base.Fatalf("%v", err)
+       }
+       return n
+}
+
+// objIdxMayFail is equivalent to objIdx, but returns an error rather than
+// failing the build if this object requires type arguments and the incorrect
+// number of type arguments were passed.
+//
+// Other sources of internal failure (such as duplicate definitions) still fail
+// the build.
+func (pr *pkgReader) objIdxMayFail(idx pkgbits.Index, implicits, explicits []*types.Type, shaped bool) (ir.Node, error) {
        rname := pr.newReader(pkgbits.RelocName, idx, pkgbits.SyncObject1)
        _, sym := rname.qualifiedIdent()
        tag := pkgbits.CodeObj(rname.Code(pkgbits.SyncCodeObj))
@@ -674,22 +689,25 @@ func (pr *pkgReader) objIdx(idx pkgbits.Index, implicits, explicits []*types.Typ
                assert(!sym.IsBlank())
                switch sym.Pkg {
                case types.BuiltinPkg, types.UnsafePkg:
-                       return sym.Def.(ir.Node)
+                       return sym.Def.(ir.Node), nil
                }
                if pri, ok := objReader[sym]; ok {
-                       return pri.pr.objIdx(pri.idx, nil, explicits, shaped)
+                       return pri.pr.objIdxMayFail(pri.idx, nil, explicits, shaped)
                }
                if sym.Pkg.Path == "runtime" {
-                       return typecheck.LookupRuntime(sym.Name)
+                       return typecheck.LookupRuntime(sym.Name), nil
                }
                base.Fatalf("unresolved stub: %v", sym)
        }
 
-       dict := pr.objDictIdx(sym, idx, implicits, explicits, shaped)
+       dict, err := pr.objDictIdx(sym, idx, implicits, explicits, shaped)
+       if err != nil {
+               return nil, err
+       }
 
        sym = dict.baseSym
        if !sym.IsBlank() && sym.Def != nil {
-               return sym.Def.(*ir.Name)
+               return sym.Def.(*ir.Name), nil
        }
 
        r := pr.newReader(pkgbits.RelocObj, idx, pkgbits.SyncObject1)
@@ -725,7 +743,7 @@ func (pr *pkgReader) objIdx(idx pkgbits.Index, implicits, explicits []*types.Typ
                name := do(ir.OTYPE, false)
                setType(name, r.typ())
                name.SetAlias(true)
-               return name
+               return name, nil
 
        case pkgbits.ObjConst:
                name := do(ir.OLITERAL, false)
@@ -733,7 +751,7 @@ func (pr *pkgReader) objIdx(idx pkgbits.Index, implicits, explicits []*types.Typ
                val := FixValue(typ, r.Value())
                setType(name, typ)
                setValue(name, val)
-               return name
+               return name, nil
 
        case pkgbits.ObjFunc:
                if sym.Name == "init" {
@@ -768,7 +786,7 @@ func (pr *pkgReader) objIdx(idx pkgbits.Index, implicits, explicits []*types.Typ
                }
 
                rext.funcExt(name, nil)
-               return name
+               return name, nil
 
        case pkgbits.ObjType:
                name := do(ir.OTYPE, true)
@@ -805,13 +823,13 @@ func (pr *pkgReader) objIdx(idx pkgbits.Index, implicits, explicits []*types.Typ
                        r.needWrapper(typ)
                }
 
-               return name
+               return name, nil
 
        case pkgbits.ObjVar:
                name := do(ir.ONAME, false)
                setType(name, r.typ())
                rext.varExt(name)
-               return name
+               return name, nil
        }
 }
 
@@ -908,7 +926,7 @@ func shapify(targ *types.Type, basic bool) *types.Type {
 }
 
 // objDictIdx reads and returns the specified object dictionary.
-func (pr *pkgReader) objDictIdx(sym *types.Sym, idx pkgbits.Index, implicits, explicits []*types.Type, shaped bool) *readerDict {
+func (pr *pkgReader) objDictIdx(sym *types.Sym, idx pkgbits.Index, implicits, explicits []*types.Type, shaped bool) (*readerDict, error) {
        r := pr.newReader(pkgbits.RelocObjDict, idx, pkgbits.SyncObject1)
 
        dict := readerDict{
@@ -919,7 +937,7 @@ func (pr *pkgReader) objDictIdx(sym *types.Sym, idx pkgbits.Index, implicits, ex
        nexplicits := r.Len()
 
        if nimplicits > len(implicits) || nexplicits != len(explicits) {
-               base.Fatalf("%v has %v+%v params, but instantiated with %v+%v args", sym, nimplicits, nexplicits, len(implicits), len(explicits))
+               return nil, fmt.Errorf("%v has %v+%v params, but instantiated with %v+%v args", sym, nimplicits, nexplicits, len(implicits), len(explicits))
        }
 
        dict.targs = append(implicits[:nimplicits:nimplicits], explicits...)
@@ -984,7 +1002,7 @@ func (pr *pkgReader) objDictIdx(sym *types.Sym, idx pkgbits.Index, implicits, ex
                dict.itabs[i] = itabInfo{typ: r.typInfo(), iface: r.typInfo()}
        }
 
-       return &dict
+       return &dict, nil
 }
 
 func (r *reader) typeParamNames() {
@@ -2529,7 +2547,10 @@ func (pr *pkgReader) objDictName(idx pkgbits.Index, implicits, explicits []*type
                base.Fatalf("unresolved stub: %v", sym)
        }
 
-       dict := pr.objDictIdx(sym, idx, implicits, explicits, false)
+       dict, err := pr.objDictIdx(sym, idx, implicits, explicits, false)
+       if err != nil {
+               base.Fatalf("%v", err)
+       }
 
        return pr.dictNameOf(dict)
 }
index d2ca1f37a9f10ab5d713d90a2b4895ada32bd58d..492b00d256610da8ee0ae045fda8da574236a929 100644 (file)
@@ -80,7 +80,11 @@ func lookupFunction(pkg *types.Pkg, symName string) (*ir.Func, error) {
                return nil, fmt.Errorf("func sym %v missing objReader", sym)
        }
 
-       name := pri.pr.objIdx(pri.idx, nil, nil, false).(*ir.Name)
+       node, err := pri.pr.objIdxMayFail(pri.idx, nil, nil, false)
+       if err != nil {
+               return nil, fmt.Errorf("func sym %v lookup error: %w", sym, err)
+       }
+       name := node.(*ir.Name)
        if name.Op() != ir.ONAME || name.Class != ir.PFUNC {
                return nil, fmt.Errorf("func sym %v refers to non-function name: %v", sym, name)
        }
@@ -105,7 +109,11 @@ func lookupMethod(pkg *types.Pkg, symName string) (*ir.Func, error) {
                return nil, fmt.Errorf("type sym %v missing objReader", typ)
        }
 
-       name := pri.pr.objIdx(pri.idx, nil, nil, false).(*ir.Name)
+       node, err := pri.pr.objIdxMayFail(pri.idx, nil, nil, false)
+       if err != nil {
+               return nil, fmt.Errorf("func sym %v lookup error: %w", typ, err)
+       }
+       name := node.(*ir.Name)
        if name.Op() != ir.OTYPE {
                return nil, fmt.Errorf("type sym %v refers to non-type name: %v", typ, name)
        }
index c457478a1fb1611f1f99cd190835904b535b8cf0..f4512436834fc56ef5de13ad97a297942534a0b0 100644 (file)
@@ -14,8 +14,13 @@ import (
        "testing"
 )
 
+type devirtualization struct {
+       pos    string
+       callee string
+}
+
 // testPGODevirtualize tests that specific PGO devirtualize rewrites are performed.
-func testPGODevirtualize(t *testing.T, dir string) {
+func testPGODevirtualize(t *testing.T, dir string, want []devirtualization) {
        testenv.MustHaveGoRun(t)
        t.Parallel()
 
@@ -23,7 +28,7 @@ func testPGODevirtualize(t *testing.T, dir string) {
 
        // Add a go.mod so we have a consistent symbol names in this temp dir.
        goMod := fmt.Sprintf(`module %s
-go 1.19
+go 1.21
 `, pkg)
        if err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte(goMod), 0644); err != nil {
                t.Fatalf("error writing go.mod: %v", err)
@@ -60,51 +65,6 @@ go 1.19
                t.Fatalf("error starting go test: %v", err)
        }
 
-       type devirtualization struct {
-               pos    string
-               callee string
-       }
-
-       want := []devirtualization{
-               // ExerciseIface
-               {
-                       pos:    "./devirt.go:101:20",
-                       callee: "mult.Mult.Multiply",
-               },
-               {
-                       pos:    "./devirt.go:101:39",
-                       callee: "Add.Add",
-               },
-               // ExerciseFuncConcrete
-               {
-                       pos:    "./devirt.go:173:36",
-                       callee: "AddFn",
-               },
-               {
-                       pos:    "./devirt.go:173:15",
-                       callee: "mult.MultFn",
-               },
-               // ExerciseFuncField
-               {
-                       pos:    "./devirt.go:207:35",
-                       callee: "AddFn",
-               },
-               {
-                       pos:    "./devirt.go:207:19",
-                       callee: "mult.MultFn",
-               },
-               // ExerciseFuncClosure
-               // TODO(prattmic): Closure callees not implemented.
-               //{
-               //      pos:    "./devirt.go:249:27",
-               //      callee: "AddClosure.func1",
-               //},
-               //{
-               //      pos:    "./devirt.go:249:15",
-               //      callee: "mult.MultClosure.func1",
-               //},
-       }
-
        got := make(map[devirtualization]struct{})
 
        devirtualizedLine := regexp.MustCompile(`(.*): PGO devirtualizing \w+ call .* to (.*)`)
@@ -172,5 +132,130 @@ func TestPGODevirtualize(t *testing.T) {
                }
        }
 
-       testPGODevirtualize(t, dir)
+       want := []devirtualization{
+               // ExerciseIface
+               {
+                       pos:    "./devirt.go:101:20",
+                       callee: "mult.Mult.Multiply",
+               },
+               {
+                       pos:    "./devirt.go:101:39",
+                       callee: "Add.Add",
+               },
+               // ExerciseFuncConcrete
+               {
+                       pos:    "./devirt.go:173:36",
+                       callee: "AddFn",
+               },
+               {
+                       pos:    "./devirt.go:173:15",
+                       callee: "mult.MultFn",
+               },
+               // ExerciseFuncField
+               {
+                       pos:    "./devirt.go:207:35",
+                       callee: "AddFn",
+               },
+               {
+                       pos:    "./devirt.go:207:19",
+                       callee: "mult.MultFn",
+               },
+               // ExerciseFuncClosure
+               // TODO(prattmic): Closure callees not implemented.
+               //{
+               //      pos:    "./devirt.go:249:27",
+               //      callee: "AddClosure.func1",
+               //},
+               //{
+               //      pos:    "./devirt.go:249:15",
+               //      callee: "mult.MultClosure.func1",
+               //},
+       }
+
+       testPGODevirtualize(t, dir, want)
+}
+
+// Regression test for https://go.dev/issue/65615. If a target function changes
+// from non-generic to generic we can't devirtualize it (don't know the type
+// parameters), but the compiler should not crash.
+func TestLookupFuncGeneric(t *testing.T) {
+       wd, err := os.Getwd()
+       if err != nil {
+               t.Fatalf("error getting wd: %v", err)
+       }
+       srcDir := filepath.Join(wd, "testdata", "pgo", "devirtualize")
+
+       // Copy the module to a scratch location so we can add a go.mod.
+       dir := t.TempDir()
+       if err := os.Mkdir(filepath.Join(dir, "mult.pkg"), 0755); err != nil {
+               t.Fatalf("error creating dir: %v", err)
+       }
+       for _, file := range []string{"devirt.go", "devirt_test.go", "devirt.pprof", filepath.Join("mult.pkg", "mult.go")} {
+               if err := copyFile(filepath.Join(dir, file), filepath.Join(srcDir, file)); err != nil {
+                       t.Fatalf("error copying %s: %v", file, err)
+               }
+       }
+
+       // Change MultFn from a concrete function to a parameterized function.
+       if err := convertMultToGeneric(filepath.Join(dir, "mult.pkg", "mult.go")); err != nil {
+               t.Fatalf("error editing mult.go: %v", err)
+       }
+
+       // Same as TestPGODevirtualize except for MultFn, which we cannot
+       // devirtualize to because it has become generic.
+       //
+       // Note that the important part of this test is that the build is
+       // successful, not the specific devirtualizations.
+       want := []devirtualization{
+               // ExerciseIface
+               {
+                       pos:    "./devirt.go:101:20",
+                       callee: "mult.Mult.Multiply",
+               },
+               {
+                       pos:    "./devirt.go:101:39",
+                       callee: "Add.Add",
+               },
+               // ExerciseFuncConcrete
+               {
+                       pos:    "./devirt.go:173:36",
+                       callee: "AddFn",
+               },
+               // ExerciseFuncField
+               {
+                       pos:    "./devirt.go:207:35",
+                       callee: "AddFn",
+               },
+               // ExerciseFuncClosure
+               // TODO(prattmic): Closure callees not implemented.
+               //{
+               //      pos:    "./devirt.go:249:27",
+               //      callee: "AddClosure.func1",
+               //},
+               //{
+               //      pos:    "./devirt.go:249:15",
+               //      callee: "mult.MultClosure.func1",
+               //},
+       }
+
+       testPGODevirtualize(t, dir, want)
+}
+
+var multFnRe = regexp.MustCompile(`func MultFn\(a, b int64\) int64`)
+
+func convertMultToGeneric(path string) error {
+       content, err := os.ReadFile(path)
+       if err != nil {
+               return fmt.Errorf("error opening: %w", err)
+       }
+
+       if !multFnRe.Match(content) {
+               return fmt.Errorf("MultFn not found; update regexp?")
+       }
+
+       // Users of MultFn shouldn't need adjustment, type inference should
+       // work OK.
+       content = multFnRe.ReplaceAll(content, []byte(`func MultFn[T int32|int64](a, b T) T`))
+
+       return os.WriteFile(path, content, 0644)
 }