]> Cypherpunks.ru repositories - gostls13.git/commitdiff
go/types, types2: implement reverse type inference for function arguments
authorRobert Griesemer <gri@golang.org>
Wed, 12 Apr 2023 01:02:26 +0000 (18:02 -0700)
committerGopher Robot <gobot@golang.org>
Wed, 3 May 2023 18:47:48 +0000 (18:47 +0000)
Allow function-typed function arguments to be generic and collect
their type parameters together with the callee's type parameters
(if any). Use a single inference step to infer the type arguments
for all type parameters simultaneously.

Requires Go 1.21 and that Config.EnableReverseTypeInference is set.
Does not yet support partially instantiated generic function arguments.
Not yet enabled in the compiler.

Known bug: inference may produce an incorrect result is the same
           generic function is passed twice in the same function
           call.

For #59338.

Change-Id: Ia1faa27a28c6353f0bbfd7f81feafc21bd36652c
Reviewed-on: https://go-review.googlesource.com/c/go/+/483935
Auto-Submit: Robert Griesemer <gri@google.com>
Reviewed-by: Robert Findley <rfindley@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Robert Griesemer <gri@google.com>
Run-TryBot: Robert Griesemer <gri@google.com>

14 files changed:
src/cmd/compile/internal/types2/api.go
src/cmd/compile/internal/types2/api_test.go
src/cmd/compile/internal/types2/call.go
src/cmd/compile/internal/types2/expr.go
src/cmd/compile/internal/types2/infer.go
src/go/types/api.go
src/go/types/api_test.go
src/go/types/call.go
src/go/types/check_test.go
src/go/types/expr.go
src/go/types/infer.go
src/internal/types/testdata/examples/inference2.go
src/internal/types/testdata/fixedbugs/issue59338a.go [new file with mode: 0644]
src/internal/types/testdata/fixedbugs/issue59338b.go [new file with mode: 0644]

index bd87945295be6777c0d958379f1e0504d5d321a6..0ee9a4bd06c62e0c398b616df8ff62d66c757296 100644 (file)
@@ -174,7 +174,7 @@ type Config struct {
        // partially instantiated generic functions may be assigned
        // (incl. returned) to variables of function type and type
        // inference will attempt to infer the missing type arguments.
-       // Experimental. Needs a proposal.
+       // See proposal go.dev/issue/59338.
        EnableReverseTypeInference bool
 }
 
index a13f43111c4c0edbc3ac81aa5176ee318abb432b..ae253623e6436ada2c9533129ff8d29ccdd8d5f4 100644 (file)
@@ -550,18 +550,55 @@ type T[P any] []P
                {`package issue51803; func foo[T any](T) {}; func _() { foo[int]( /* leave arg away on purpose */ ) }`,
                        []testInst{{`foo`, []string{`int`}, `func(int)`}},
                },
+
+               // reverse type parameter inference
+               {`package reverse1a; var f func(int) = g; func g[P any](P) {}`,
+                       []testInst{{`g`, []string{`int`}, `func(int)`}},
+               },
+               {`package reverse1b; func f(func(int)) {}; func g[P any](P) {}; func _() { f(g) }`,
+                       []testInst{{`g`, []string{`int`}, `func(int)`}},
+               },
+               {`package reverse2a; var f func(int) string = g; func g[P, Q any](P) Q { var q Q; return q }`,
+                       []testInst{{`g`, []string{`int`, `string`}, `func(int) string`}},
+               },
+               {`package reverse2b; func f(func(int) string) {}; func g[P, Q any](P) Q { var q Q; return q }; func _() { f(g) }`,
+                       []testInst{{`g`, []string{`int`, `string`}, `func(int) string`}},
+               },
+               // reverse3a not possible (cannot assign to generic function outside of argument passing)
+               {`package reverse3b; func f[R any](func(int) R) {}; func g[P any](P) string { return "" }; func _() { f(g) }`,
+                       []testInst{
+                               {`f`, []string{`string`}, `func(func(int) string)`},
+                               {`g`, []string{`int`}, `func(int) string`},
+                       },
+               },
+               {`package reverse4a; var _, _ func([]int, *float32) = g, h; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}`,
+                       []testInst{
+                               {`g`, []string{`int`, `float32`}, `func([]int, *float32)`},
+                               {`h`, []string{`int`}, `func([]int, *float32)`},
+                       },
+               },
+               {`package reverse4b; func f(_, _ func([]int, *float32)) {}; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}; func _() { f(g, h) }`,
+                       []testInst{
+                               {`g`, []string{`int`, `float32`}, `func([]int, *float32)`},
+                               {`h`, []string{`int`}, `func([]int, *float32)`},
+                       },
+               },
        }
 
        for _, test := range tests {
                imports := make(testImporter)
                conf := Config{
-                       Importer: imports,
-                       Error:    func(error) {}, // ignore errors
+                       Importer:                   imports,
+                       EnableReverseTypeInference: true,
                }
                instMap := make(map[*syntax.Name]Instance)
                useMap := make(map[*syntax.Name]Object)
                makePkg := func(src string) *Package {
-                       pkg, _ := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap})
+                       pkg, err := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap})
+                       // allow error for issue51803
+                       if err != nil && (pkg == nil || pkg.Name() != "issue51803") {
+                               t.Fatal(err)
+                       }
                        imports[pkg.Name()] = pkg
                        return pkg
                }
index 20cde9f44ed3a2d5b5df4cc606ada37e065e0b59..8ad7744ab4afdb916c5e66c1e6bf91428fc59932 100644 (file)
@@ -23,14 +23,16 @@ import (
 func (check *Checker) funcInst(tsig *Signature, pos syntax.Pos, x *operand, inst *syntax.IndexExpr) {
        assert(tsig != nil || inst != nil)
 
+       var versionErr bool  // set if version error was reported
+       var instErrPos poser // position for instantion error
+       if inst != nil {
+               instErrPos = inst.Pos()
+       } else {
+               instErrPos = pos
+       }
        if !check.allowVersion(check.pkg, pos, 1, 18) {
-               var posn poser
-               if inst != nil {
-                       posn = inst.Pos()
-               } else {
-                       posn = pos
-               }
-               check.versionErrorf(posn, "go1.18", "function instantiation")
+               check.versionErrorf(instErrPos, "go1.18", "function instantiation")
+               versionErr = true
        }
 
        // targs and xlist are the type arguments and corresponding type expressions, or nil.
@@ -72,6 +74,13 @@ func (check *Checker) funcInst(tsig *Signature, pos syntax.Pos, x *operand, inst
                        // of a synthetic function f where f's parameters are the parameters and results
                        // of x and where the arguments to the call of f are values of the parameter and
                        // result types of x.
+                       if !versionErr && !check.allowVersion(check.pkg, pos, 1, 21) {
+                               if inst != nil {
+                                       check.versionErrorf(instErrPos, "go1.21", "partially instantiated function in assignment")
+                               } else {
+                                       check.versionErrorf(instErrPos, "go1.21", "implicitly instantiated function in assignment")
+                               }
+                       }
                        n := tsig.params.Len()
                        m := tsig.results.Len()
                        args = make([]*operand, n+m)
@@ -303,7 +312,7 @@ func (check *Checker) callExpr(x *operand, call *syntax.CallExpr) exprKind {
        }
 
        // evaluate arguments
-       args := check.exprList(call.ArgList)
+       args := check.genericExprList(call.ArgList)
        sig = check.arguments(call, sig, targs, args, xlist)
 
        if wasGeneric && sig.TypeParams().Len() == 0 {
@@ -338,6 +347,8 @@ func (check *Checker) callExpr(x *operand, call *syntax.CallExpr) exprKind {
        return statement
 }
 
+// exprList evaluates a list of expressions and returns the corresponding operands.
+// A single-element expression list may evaluate to multiple operands.
 func (check *Checker) exprList(elist []syntax.Expr) (xlist []*operand) {
        switch len(elist) {
        case 0:
@@ -356,6 +367,25 @@ func (check *Checker) exprList(elist []syntax.Expr) (xlist []*operand) {
        return
 }
 
+// genericExprList is like exprList but result operands may be generic (not fully instantiated).
+func (check *Checker) genericExprList(elist []syntax.Expr) (xlist []*operand) {
+       switch len(elist) {
+       case 0:
+               // nothing to do
+       case 1:
+               xlist = check.genericMultiExpr(elist[0])
+       default:
+               // multiple (possibly invalid) values
+               xlist = make([]*operand, len(elist))
+               for i, e := range elist {
+                       var x operand
+                       check.genericExpr(&x, e)
+                       xlist[i] = &x
+               }
+       }
+       return
+}
+
 // xlist is the list of type argument expressions supplied in the source code.
 func (check *Checker) arguments(call *syntax.CallExpr, sig *Signature, targs []Type, args []*operand, xlist []syntax.Expr) (rsig *Signature) {
        rsig = sig
@@ -386,7 +416,7 @@ func (check *Checker) arguments(call *syntax.CallExpr, sig *Signature, targs []T
 
        // set up parameters
        sigParams := sig.params // adjusted for variadic functions (may be nil for empty parameter lists!)
-       adjusted := false       // indicates if sigParams is different from t.params
+       adjusted := false       // indicates if sigParams is different from sig.params
        if sig.variadic {
                if ddd {
                        // variadic_func(a, b, c...)
@@ -451,8 +481,12 @@ func (check *Checker) arguments(call *syntax.CallExpr, sig *Signature, targs []T
                return
        }
 
-       // infer type arguments and instantiate signature if necessary
-       if sig.TypeParams().Len() > 0 {
+       // collect type parameters of callee and generic function arguments
+       var tparams []*TypeParam
+
+       // collect type parameters of callee
+       n := sig.TypeParams().Len()
+       if n > 0 {
                if !check.allowVersion(check.pkg, call.Pos(), 1, 18) {
                        if iexpr, _ := call.Fun.(*syntax.IndexExpr); iexpr != nil {
                                check.versionErrorf(iexpr.Pos(), "go1.18", "function instantiation")
@@ -460,29 +494,72 @@ func (check *Checker) arguments(call *syntax.CallExpr, sig *Signature, targs []T
                                check.versionErrorf(call.Pos(), "go1.18", "implicit function instantiation")
                        }
                }
-
-               // Rename type parameters to avoid problems with recursive calls.
-               var tparams []*TypeParam
+               // rename type parameters to avoid problems with recursive calls
                tparams, sigParams = check.renameTParams(call.Pos(), sig.TypeParams().list(), sigParams)
+       }
 
-               targs := check.infer(call.Pos(), tparams, targs, sigParams, args)
+       // collect type parameters from generic function arguments
+       var genericArgs []int // indices of generic function arguments
+       if check.conf.EnableReverseTypeInference {
+               for i, arg := range args {
+                       // generic arguments cannot have a defined (*Named) type - no need for underlying type below
+                       if asig, _ := arg.typ.(*Signature); asig != nil && asig.TypeParams().Len() > 0 {
+                               // TODO(gri) need to also rename type parameters for cases like f(g, g)
+                               tparams = append(tparams, asig.TypeParams().list()...)
+                               genericArgs = append(genericArgs, i)
+                       }
+               }
+       }
+       if len(genericArgs) > 0 && !check.allowVersion(check.pkg, call.Pos(), 1, 21) {
+               // at the moment we only support implicit instantiations of argument functions
+               check.versionErrorf(args[genericArgs[0]].Pos(), "go1.21", "implicitly instantiated function as argument")
+       }
+
+       // tparams holds the type parameters of the callee and generic function arguments, if any:
+       // the first n type parameters belong to the callee, followed by mi type parameters for each
+       // of the generic function arguments, where mi = args[i].typ.(*Signature).TypeParams().Len().
+
+       // infer missing type arguments of callee and function arguments
+       if len(tparams) > 0 {
+               targs = check.infer(call.Pos(), tparams, targs, sigParams, args)
                if targs == nil {
+                       // TODO(gri) If infer inferred the first targs[:n], consider instantiating
+                       //           the call signature for better error messages/gopls behavior.
+                       //           Perhaps instantiate as much as we can, also for arguments.
+                       //           This will require changes to how infer returns its results.
                        return // error already reported
                }
 
-               // compute result signature
-               rsig = check.instantiateSignature(call.Pos(), sig, targs, xlist)
-               assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore
-               check.recordInstance(call.Fun, targs, rsig)
+               // compute result signature: instantiate if needed
+               rsig = sig
+               if n > 0 {
+                       rsig = check.instantiateSignature(call.Pos(), sig, targs[:n], xlist)
+                       assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore
+                       check.recordInstance(call.Fun, targs[:n], rsig)
+               }
 
-               // Optimization: Only if the parameter list was adjusted do we
-               // need to compute it from the adjusted list; otherwise we can
-               // simply use the result signature's parameter list.
-               if adjusted {
-                       sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams, targs), nil, check.context()).(*Tuple)
+               // Optimization: Only if the callee's parameter list was adjusted do we need to
+               // compute it from the adjusted list; otherwise we can simply use the result
+               // signature's parameter list. We only need the n type parameters and arguments
+               // of the callee.
+               if n > 0 && adjusted {
+                       sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams[:n], targs[:n]), nil, check.context()).(*Tuple)
                } else {
                        sigParams = rsig.params
                }
+
+               // compute argument signatures: instantiate if needed
+               j := n
+               for _, i := range genericArgs {
+                       asig := args[i].typ.(*Signature)
+                       k := j + asig.TypeParams().Len()
+                       // targs[j:k] are the inferred type arguments for asig
+                       asig = check.instantiateSignature(call.Pos(), asig, targs[j:k], nil) // TODO(gri) provide xlist if possible (partial instantiations)
+                       assert(asig.TypeParams().Len() == 0)                                 // signature is not generic anymore
+                       args[i].typ = asig
+                       check.recordInstance(args[i].expr, targs[j:k], asig)
+                       j = k
+               }
        }
 
        // check arguments
index 7c3b40f086811e2e7cd5ef02769fa186234240a6..51b944eeade7d51beeeb2eb197d8fb2cb632782d 100644 (file)
@@ -1882,6 +1882,27 @@ func (check *Checker) multiExpr(e syntax.Expr, allowCommaOk bool) (list []*opera
        return
 }
 
+// genericMultiExpr is like multiExpr but a one-element result may also be generic
+// and potential comma-ok expressions are returned as single values.
+func (check *Checker) genericMultiExpr(e syntax.Expr) (list []*operand) {
+       var x operand
+       check.rawExpr(nil, &x, e, nil, true)
+       check.exclude(&x, 1<<novalue|1<<builtin|1<<typexpr)
+
+       if t, ok := x.typ.(*Tuple); ok && x.mode != invalid {
+               // multiple values - cannot be generic
+               list = make([]*operand, t.Len())
+               for i, v := range t.vars {
+                       list[i] = &operand{mode: value, expr: e, typ: v.typ}
+               }
+               return
+       }
+
+       // exactly one (possible invalid or generic) value
+       list = []*operand{&x}
+       return
+}
+
 // exprWithHint typechecks expression e and initializes x with the expression value;
 // hint is the type of a composite literal element.
 // If an error occurred, x.mode is set to invalid.
index dbe621cded56189c2cc2035c756f8118bcaad659..9c1022c46fef541d8717d64f8d0e2789ee2360d4 100644 (file)
@@ -140,17 +140,17 @@ func (check *Checker) infer(pos syntax.Pos, tparams []*TypeParam, targs []Type,
        }
 
        for i, arg := range args {
+               if arg.mode == invalid {
+                       // An error was reported earlier. Ignore this arg
+                       // and continue, we may still be able to infer all
+                       // targs resulting in fewer follow-on errors.
+                       // TODO(gri) determine if we still need this check
+                       continue
+               }
                par := params.At(i)
-               // If we permit bidirectional unification, this conditional code needs to be
-               // executed even if par.typ is not parameterized since the argument may be a
-               // generic function (for which we want to infer its type arguments).
-               if isParameterized(tparams, par.typ) {
-                       if arg.mode == invalid {
-                               // An error was reported earlier. Ignore this targ
-                               // and continue, we may still be able to infer all
-                               // targs resulting in fewer follow-on errors.
-                               continue
-                       }
+               if isParameterized(tparams, par.typ) || isParameterized(tparams, arg.typ) {
+                       // Function parameters are always typed. Arguments may be untyped.
+                       // Collect the indices of untyped arguments and handle them later.
                        if isTyped(arg.typ) {
                                if !u.unify(par.typ, arg.typ) {
                                        errorf("type", par.typ, arg.typ, arg)
@@ -263,7 +263,7 @@ func (check *Checker) infer(pos syntax.Pos, tparams []*TypeParam, targs []Type,
        }
 
        // --- 3 ---
-       // use information from untyped contants
+       // use information from untyped constants
 
        if traceInference {
                u.tracef("== untyped arguments: %v", untyped)
@@ -541,7 +541,6 @@ func (w *tpWalker) isParameterized(typ Type) (res bool) {
                }
 
        case *TypeParam:
-               // t must be one of w.tparams
                return tparamIndex(w.tparams, t) >= 0
 
        default:
index 7af84fd244a0cea115d145486ae17e7a4ab3c8f0..e202d6dea8eea26780f97c7b49ca99cbcb41620f 100644 (file)
@@ -175,7 +175,7 @@ type Config struct {
        // partially instantiated generic functions may be assigned
        // (incl. returned) to variables of function type and type
        // inference will attempt to infer the missing type arguments.
-       // Experimental. Needs a proposal.
+       // See proposal go.dev/issue/59338.
        _EnableReverseTypeInference bool
 }
 
index ae1a7e50a7b29d8b271e67a977cfa77d20e6204f..02e26c3f025d1283bf31f8f9013f53eaafa1613f 100644 (file)
@@ -550,18 +550,57 @@ type T[P any] []P
                {`package issue51803; func foo[T any](T) {}; func _() { foo[int]( /* leave arg away on purpose */ ) }`,
                        []testInst{{`foo`, []string{`int`}, `func(int)`}},
                },
+
+               // reverse type parameter inference
+               {`package reverse1a; var f func(int) = g; func g[P any](P) {}`,
+                       []testInst{{`g`, []string{`int`}, `func(int)`}},
+               },
+               {`package reverse1b; func f(func(int)) {}; func g[P any](P) {}; func _() { f(g) }`,
+                       []testInst{{`g`, []string{`int`}, `func(int)`}},
+               },
+               {`package reverse2a; var f func(int) string = g; func g[P, Q any](P) Q { var q Q; return q }`,
+                       []testInst{{`g`, []string{`int`, `string`}, `func(int) string`}},
+               },
+               {`package reverse2b; func f(func(int) string) {}; func g[P, Q any](P) Q { var q Q; return q }; func _() { f(g) }`,
+                       []testInst{{`g`, []string{`int`, `string`}, `func(int) string`}},
+               },
+               // reverse3a not possible (cannot assign to generic function outside of argument passing)
+               {`package reverse3b; func f[R any](func(int) R) {}; func g[P any](P) string { return "" }; func _() { f(g) }`,
+                       []testInst{
+                               {`f`, []string{`string`}, `func(func(int) string)`},
+                               {`g`, []string{`int`}, `func(int) string`},
+                       },
+               },
+               {`package reverse4a; var _, _ func([]int, *float32) = g, h; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}`,
+                       []testInst{
+                               {`g`, []string{`int`, `float32`}, `func([]int, *float32)`},
+                               {`h`, []string{`int`}, `func([]int, *float32)`},
+                       },
+               },
+               {`package reverse4b; func f(_, _ func([]int, *float32)) {}; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}; func _() { f(g, h) }`,
+                       []testInst{
+                               {`g`, []string{`int`, `float32`}, `func([]int, *float32)`},
+                               {`h`, []string{`int`}, `func([]int, *float32)`},
+                       },
+               },
        }
 
        for _, test := range tests {
                imports := make(testImporter)
                conf := Config{
                        Importer: imports,
-                       Error:    func(error) {}, // ignore errors
+                       // Unexported field: set below with boolFieldAddr
+                       // _EnableReverseTypeInference: true,
                }
+               *boolFieldAddr(&conf, "_EnableReverseTypeInference") = true
                instMap := make(map[*ast.Ident]Instance)
                useMap := make(map[*ast.Ident]Object)
                makePkg := func(src string) *Package {
-                       pkg, _ := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap})
+                       pkg, err := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap})
+                       // allow error for issue51803
+                       if err != nil && (pkg == nil || pkg.Name() != "issue51803") {
+                               t.Fatal(err)
+                       }
                        imports[pkg.Name()] = pkg
                        return pkg
                }
index 979de2338f27293b30863dfeed9c13827e51d1e3..02b6038cccf23260d4a8104575abac2dea49fec0 100644 (file)
@@ -25,14 +25,16 @@ import (
 func (check *Checker) funcInst(tsig *Signature, pos token.Pos, x *operand, ix *typeparams.IndexExpr) {
        assert(tsig != nil || ix != nil)
 
+       var versionErr bool       // set if version error was reported
+       var instErrPos positioner // position for instantion error
+       if ix != nil {
+               instErrPos = inNode(ix.Orig, ix.Lbrack)
+       } else {
+               instErrPos = atPos(pos)
+       }
        if !check.allowVersion(check.pkg, pos, 1, 18) {
-               var posn positioner
-               if ix != nil {
-                       posn = inNode(ix.Orig, ix.Lbrack)
-               } else {
-                       posn = atPos(pos)
-               }
-               check.softErrorf(posn, UnsupportedFeature, "function instantiation requires go1.18 or later")
+               check.softErrorf(instErrPos, UnsupportedFeature, "function instantiation requires go1.18 or later")
+               versionErr = true
        }
 
        // targs and xlist are the type arguments and corresponding type expressions, or nil.
@@ -74,6 +76,13 @@ func (check *Checker) funcInst(tsig *Signature, pos token.Pos, x *operand, ix *t
                        // of a synthetic function f where f's parameters are the parameters and results
                        // of x and where the arguments to the call of f are values of the parameter and
                        // result types of x.
+                       if !versionErr && !check.allowVersion(check.pkg, pos, 1, 21) {
+                               if ix != nil {
+                                       check.softErrorf(instErrPos, UnsupportedFeature, "partially instantiated function in assignment requires go1.21 or later")
+                               } else {
+                                       check.softErrorf(instErrPos, UnsupportedFeature, "implicitly instantiated function in assignment requires go1.21 or later")
+                               }
+                       }
                        n := tsig.params.Len()
                        m := tsig.results.Len()
                        args = make([]*operand, n+m)
@@ -308,7 +317,7 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind {
        }
 
        // evaluate arguments
-       args := check.exprList(call.Args)
+       args := check.genericExprList(call.Args)
        sig = check.arguments(call, sig, targs, args, xlist)
 
        if wasGeneric && sig.TypeParams().Len() == 0 {
@@ -343,6 +352,8 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind {
        return statement
 }
 
+// exprList evaluates a list of expressions and returns the corresponding operands.
+// A single-element expression list may evaluate to multiple operands.
 func (check *Checker) exprList(elist []ast.Expr) (xlist []*operand) {
        switch len(elist) {
        case 0:
@@ -361,6 +372,25 @@ func (check *Checker) exprList(elist []ast.Expr) (xlist []*operand) {
        return
 }
 
+// genericExprList is like exprList but result operands may be generic (not fully instantiated).
+func (check *Checker) genericExprList(elist []ast.Expr) (xlist []*operand) {
+       switch len(elist) {
+       case 0:
+               // nothing to do
+       case 1:
+               xlist = check.genericMultiExpr(elist[0])
+       default:
+               // multiple (possibly invalid) values
+               xlist = make([]*operand, len(elist))
+               for i, e := range elist {
+                       var x operand
+                       check.genericExpr(&x, e)
+                       xlist[i] = &x
+               }
+       }
+       return
+}
+
 // xlist is the list of type argument expressions supplied in the source code.
 func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type, args []*operand, xlist []ast.Expr) (rsig *Signature) {
        rsig = sig
@@ -391,7 +421,7 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type
 
        // set up parameters
        sigParams := sig.params // adjusted for variadic functions (may be nil for empty parameter lists!)
-       adjusted := false       // indicates if sigParams is different from t.params
+       adjusted := false       // indicates if sigParams is different from sig.params
        if sig.variadic {
                if ddd {
                        // variadic_func(a, b, c...)
@@ -452,8 +482,12 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type
                return
        }
 
-       // infer type arguments and instantiate signature if necessary
-       if sig.TypeParams().Len() > 0 {
+       // collect type parameters of callee and generic function arguments
+       var tparams []*TypeParam
+
+       // collect type parameters of callee
+       n := sig.TypeParams().Len()
+       if n > 0 {
                if !check.allowVersion(check.pkg, call.Pos(), 1, 18) {
                        switch call.Fun.(type) {
                        case *ast.IndexExpr, *ast.IndexListExpr:
@@ -463,29 +497,72 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type
                                check.softErrorf(inNode(call, call.Lparen), UnsupportedFeature, "implicit function instantiation requires go1.18 or later")
                        }
                }
-
-               // Rename type parameters to avoid problems with recursive calls.
-               var tparams []*TypeParam
+               // rename type parameters to avoid problems with recursive calls
                tparams, sigParams = check.renameTParams(call.Pos(), sig.TypeParams().list(), sigParams)
+       }
 
-               targs := check.infer(call, tparams, targs, sigParams, args)
+       // collect type parameters from generic function arguments
+       var genericArgs []int // indices of generic function arguments
+       if check.conf._EnableReverseTypeInference {
+               for i, arg := range args {
+                       // generic arguments cannot have a defined (*Named) type - no need for underlying type below
+                       if asig, _ := arg.typ.(*Signature); asig != nil && asig.TypeParams().Len() > 0 {
+                               // TODO(gri) need to also rename type parameters for cases like f(g, g)
+                               tparams = append(tparams, asig.TypeParams().list()...)
+                               genericArgs = append(genericArgs, i)
+                       }
+               }
+       }
+       if len(genericArgs) > 0 && !check.allowVersion(check.pkg, call.Pos(), 1, 21) {
+               // at the moment we only support implicit instantiations of argument functions
+               check.softErrorf(inNode(call, call.Lparen), UnsupportedFeature, "implicitly instantiated function as argument requires go1.21 or later")
+       }
+
+       // tparams holds the type parameters of the callee and generic function arguments, if any:
+       // the first n type parameters belong to the callee, followed by mi type parameters for each
+       // of the generic function arguments, where mi = args[i].typ.(*Signature).TypeParams().Len().
+
+       // infer missing type arguments of callee and function arguments
+       if len(tparams) > 0 {
+               targs = check.infer(call, tparams, targs, sigParams, args)
                if targs == nil {
+                       // TODO(gri) If infer inferred the first targs[:n], consider instantiating
+                       //           the call signature for better error messages/gopls behavior.
+                       //           Perhaps instantiate as much as we can, also for arguments.
+                       //           This will require changes to how infer returns its results.
                        return // error already reported
                }
 
-               // compute result signature
-               rsig = check.instantiateSignature(call.Pos(), sig, targs, xlist)
-               assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore
-               check.recordInstance(call.Fun, targs, rsig)
+               // compute result signature: instantiate if needed
+               rsig = sig
+               if n > 0 {
+                       rsig = check.instantiateSignature(call.Pos(), sig, targs[:n], xlist)
+                       assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore
+                       check.recordInstance(call.Fun, targs[:n], rsig)
+               }
 
-               // Optimization: Only if the parameter list was adjusted do we
-               // need to compute it from the adjusted list; otherwise we can
-               // simply use the result signature's parameter list.
-               if adjusted {
-                       sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams, targs), nil, check.context()).(*Tuple)
+               // Optimization: Only if the callee's parameter list was adjusted do we need to
+               // compute it from the adjusted list; otherwise we can simply use the result
+               // signature's parameter list. We only need the n type parameters and arguments
+               // of the callee.
+               if n > 0 && adjusted {
+                       sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams[:n], targs[:n]), nil, check.context()).(*Tuple)
                } else {
                        sigParams = rsig.params
                }
+
+               // compute argument signatures: instantiate if needed
+               j := n
+               for _, i := range genericArgs {
+                       asig := args[i].typ.(*Signature)
+                       k := j + asig.TypeParams().Len()
+                       // targs[j:k] are the inferred type arguments for asig
+                       asig = check.instantiateSignature(call.Pos(), asig, targs[j:k], nil) // TODO(gri) provide xlist if possible (partial instantiations)
+                       assert(asig.TypeParams().Len() == 0)                                 // signature is not generic anymore
+                       args[i].typ = asig
+                       check.recordInstance(args[i].expr, targs[j:k], asig)
+                       j = k
+               }
        }
 
        // check arguments
index 0f4c320a47bc4d5369441014bfa61a2cb3f57e8f..cda052f4d39de929a0432d167ffa0835f211a7d7 100644 (file)
@@ -39,6 +39,7 @@ import (
        "go/scanner"
        "go/token"
        "internal/testenv"
+       "internal/types/errors"
        "os"
        "path/filepath"
        "reflect"
@@ -295,9 +296,9 @@ func testFiles(t *testing.T, sizes Sizes, filenames []string, srcs [][]byte, man
        }
 }
 
-func readCode(err Error) int {
+func readCode(err Error) errors.Code {
        v := reflect.ValueOf(err)
-       return int(v.FieldByName("go116code").Int())
+       return errors.Code(v.FieldByName("go116code").Int())
 }
 
 // boolFieldAddr(conf, name) returns the address of the boolean field conf.<name>.
index 0db80ca44b31554c4b74f83ad0dba8d2e7ce8d98..891153ba8d1a672a20c5b92a4ecb2b2f238de943 100644 (file)
@@ -1829,6 +1829,27 @@ func (check *Checker) multiExpr(e ast.Expr, allowCommaOk bool) (list []*operand,
        return
 }
 
+// genericMultiExpr is like multiExpr but a one-element result may also be generic
+// and potential comma-ok expressions are returned as single values.
+func (check *Checker) genericMultiExpr(e ast.Expr) (list []*operand) {
+       var x operand
+       check.rawExpr(nil, &x, e, nil, true)
+       check.exclude(&x, 1<<novalue|1<<builtin|1<<typexpr)
+
+       if t, ok := x.typ.(*Tuple); ok && x.mode != invalid {
+               // multiple values - cannot be generic
+               list = make([]*operand, t.Len())
+               for i, v := range t.vars {
+                       list[i] = &operand{mode: value, expr: e, typ: v.typ}
+               }
+               return
+       }
+
+       // exactly one (possible invalid or generic) value
+       list = []*operand{&x}
+       return
+}
+
 // exprWithHint typechecks expression e and initializes x with the expression value;
 // hint is the type of a composite literal element.
 // If an error occurred, x.mode is set to invalid.
index 3aa66105c409ec835d21953ed94ac6152ae7d5df..39bf4a14f73422c0afc9ec0c8d362904f65aefba 100644 (file)
@@ -142,17 +142,17 @@ func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type,
        }
 
        for i, arg := range args {
+               if arg.mode == invalid {
+                       // An error was reported earlier. Ignore this arg
+                       // and continue, we may still be able to infer all
+                       // targs resulting in fewer follow-on errors.
+                       // TODO(gri) determine if we still need this check
+                       continue
+               }
                par := params.At(i)
-               // If we permit bidirectional unification, this conditional code needs to be
-               // executed even if par.typ is not parameterized since the argument may be a
-               // generic function (for which we want to infer its type arguments).
-               if isParameterized(tparams, par.typ) {
-                       if arg.mode == invalid {
-                               // An error was reported earlier. Ignore this targ
-                               // and continue, we may still be able to infer all
-                               // targs resulting in fewer follow-on errors.
-                               continue
-                       }
+               if isParameterized(tparams, par.typ) || isParameterized(tparams, arg.typ) {
+                       // Function parameters are always typed. Arguments may be untyped.
+                       // Collect the indices of untyped arguments and handle them later.
                        if isTyped(arg.typ) {
                                if !u.unify(par.typ, arg.typ) {
                                        errorf("type", par.typ, arg.typ, arg)
@@ -265,7 +265,7 @@ func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type,
        }
 
        // --- 3 ---
-       // use information from untyped contants
+       // use information from untyped constants
 
        if traceInference {
                u.tracef("== untyped arguments: %v", untyped)
@@ -543,7 +543,6 @@ func (w *tpWalker) isParameterized(typ Type) (res bool) {
                }
 
        case *TypeParam:
-               // t must be one of w.tparams
                return tparamIndex(w.tparams, t) >= 0
 
        default:
index d309a00c0c2cb79e34c5616dc88dfa49e77cbf95..80acc828ddd6c43426f959dab6e9c5f26e769be7 100644 (file)
 
 package p
 
-func f1[P any](P)      {}
-func f2[P any]() P     { var x P; return x }
-func f3[P, Q any](P) Q { var x Q; return x }
-func f4[P any](P, P)   {}
-func f5[P any](P) []P  { return nil }
+func f1[P any](P)        {}
+func f2[P any]() P       { var x P; return x }
+func f3[P, Q any](P) Q   { var x Q; return x }
+func f4[P any](P, P)     {}
+func f5[P any](P) []P    { return nil }
+func f6[P any](int) P    { var x P; return x }
+func f7[P any](P) string { return "" }
 
 // initialization expressions
 var (
@@ -71,3 +73,22 @@ func _() func(string) []int {
 
 func _() (_, _ func(int)) { return f1, f1 }
 func _() (_, _ func(int)) { return f1, f2 /* ERROR "cannot infer P" */ }
+
+// Argument passing
+func g1(func(int))                           {}
+func g2(func(int, int))                      {}
+func g3(func(int) string)                    {}
+func g4[P any](func(P) string)               {}
+func g5[P, Q any](func(P) string, func(P) Q) {}
+func g6(func(int), func(string))             {}
+
+func _() {
+       g1(f1)
+       g1(f2 /* ERROR "cannot infer P" */)
+       g2(f4)
+       g4(f6)
+       g5(f6, f7)
+
+       // TODO(gri) this should work (requires type parameter renaming for f1)
+       g6(f1, f1 /* ERROR "type func[P any](P) of f1 does not match func(string)" */)
+}
diff --git a/src/internal/types/testdata/fixedbugs/issue59338a.go b/src/internal/types/testdata/fixedbugs/issue59338a.go
new file mode 100644 (file)
index 0000000..fd37586
--- /dev/null
@@ -0,0 +1,21 @@
+// -reverseTypeInference -lang=go1.20
+
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package p
+
+func g[P any](P)      {}
+func h[P, Q any](P) Q { panic(0) }
+
+var _ func(int) = g /* ERROR "implicitly instantiated function in assignment requires go1.21 or later" */
+var _ func(int) string = h[ /* ERROR "partially instantiated function in assignment requires go1.21 or later" */ int]
+
+func f1(func(int))      {}
+func f2(int, func(int)) {}
+
+func _() {
+       f1( /* ERROR "implicitly instantiated function as argument requires go1.21 or later" */ g)
+       f2( /* ERROR "implicitly instantiated function as argument requires go1.21 or later" */ 0, g)
+}
diff --git a/src/internal/types/testdata/fixedbugs/issue59338b.go b/src/internal/types/testdata/fixedbugs/issue59338b.go
new file mode 100644 (file)
index 0000000..ea321bc
--- /dev/null
@@ -0,0 +1,21 @@
+// -reverseTypeInference
+
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package p
+
+func g[P any](P)      {}
+func h[P, Q any](P) Q { panic(0) }
+
+var _ func(int) = g
+var _ func(int) string = h[int]
+
+func f1(func(int))      {}
+func f2(int, func(int)) {}
+
+func _() {
+       f1(g)
+       f2(0, g)
+}