]> Cypherpunks.ru repositories - gostls13.git/blobdiff - src/go/types/stmt.go
cmd/compile, go/types: typechecking of range over int, func
[gostls13.git] / src / go / types / stmt.go
index 3e56d415b62a75e2e4236a779fccb9af35dd6821..203205e19fe2b1c541eb40f64f0338df47240f1d 100644 (file)
@@ -10,6 +10,7 @@ import (
        "go/ast"
        "go/constant"
        "go/token"
+       "internal/buildcfg"
        . "internal/types/errors"
        "sort"
 )
@@ -827,136 +828,199 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
 
        case *ast.RangeStmt:
                inner |= breakOk | continueOk
+               check.rangeStmt(inner, s)
 
-               // check expression to iterate over
-               var x operand
-               check.expr(nil, &x, s.X)
+       default:
+               check.error(s, InvalidSyntaxTree, "invalid statement")
+       }
+}
 
-               // determine key/value types
-               var key, val Type
-               if x.mode != invalid {
-                       // Ranging over a type parameter is permitted if it has a core type.
-                       var cause string
-                       u := coreType(x.typ)
-                       switch t := u.(type) {
-                       case nil:
-                               cause = check.sprintf("%s has no core type", x.typ)
-                       case *Chan:
-                               if s.Value != nil {
-                                       check.softErrorf(s.Value, InvalidIterVar, "range over %s permits only one iteration variable", &x)
-                                       // ok to continue
-                               }
-                               if t.dir == SendOnly {
-                                       cause = "receive from send-only channel"
-                               }
-                       }
-                       key, val = rangeKeyVal(u)
-                       if key == nil || cause != "" {
-                               if cause == "" {
-                                       check.softErrorf(&x, InvalidRangeExpr, "cannot range over %s", &x)
-                               } else {
-                                       check.softErrorf(&x, InvalidRangeExpr, "cannot range over %s (%s)", &x, cause)
-                               }
-                               // ok to continue
+func (check *Checker) rangeStmt(inner stmtContext, s *ast.RangeStmt) {
+       // Convert go/ast form to local variables.
+       type expr = ast.Expr
+       type identType = ast.Ident
+       identName := func(n *identType) string { return n.Name }
+       sKey, sValue := s.Key, s.Value
+       var sExtra ast.Expr = nil
+       isDef := s.Tok == token.DEFINE
+       rangeVar := s.X
+       noNewVarPos := inNode(s, s.TokPos)
+
+       // Everything from here on is shared between cmd/compile/internal/types2 and go/types.
+
+       // check expression to iterate over
+       var x operand
+       check.expr(nil, &x, rangeVar)
+
+       // determine key/value types
+       var key, val Type
+       if x.mode != invalid {
+               // Ranging over a type parameter is permitted if it has a core type.
+               k, v, cause, isFunc, ok := rangeKeyVal(x.typ)
+               switch {
+               case !ok && cause != "":
+                       check.softErrorf(&x, InvalidRangeExpr, "cannot range over %s: %s", &x, cause)
+               case !ok:
+                       check.softErrorf(&x, InvalidRangeExpr, "cannot range over %s", &x)
+               case k == nil && sKey != nil:
+                       check.softErrorf(sKey, InvalidIterVar, "range over %s permits no iteration variables", &x)
+               case v == nil && sValue != nil:
+                       check.softErrorf(sValue, InvalidIterVar, "range over %s permits only one iteration variable", &x)
+               case sExtra != nil:
+                       check.softErrorf(sExtra, InvalidIterVar, "range clause permits at most two iteration variables")
+               case isFunc && ((k == nil) != (sKey == nil) || (v == nil) != (sValue == nil)):
+                       var count string
+                       switch {
+                       case k == nil:
+                               count = "no iteration variables"
+                       case v == nil:
+                               count = "one iteration variable"
+                       default:
+                               count = "two iteration variables"
                        }
+                       check.softErrorf(&x, InvalidIterVar, "range over %s must have %s", &x, count)
                }
+               key, val = k, v
+       }
 
-               // Open the for-statement block scope now, after the range clause.
-               // Iteration variables declared with := need to go in this scope (was go.dev/issue/51437).
-               check.openScope(s, "range")
-               defer check.closeScope()
-
-               // check assignment to/declaration of iteration variables
-               // (irregular assignment, cannot easily map to existing assignment checks)
+       // Open the for-statement block scope now, after the range clause.
+       // Iteration variables declared with := need to go in this scope (was go.dev/issue/51437).
+       check.openScope(s, "range")
+       defer check.closeScope()
 
-               // lhs expressions and initialization value (rhs) types
-               lhs := [2]ast.Expr{s.Key, s.Value}
-               rhs := [2]Type{key, val} // key, val may be nil
+       // check assignment to/declaration of iteration variables
+       // (irregular assignment, cannot easily map to existing assignment checks)
 
-               if s.Tok == token.DEFINE {
-                       // short variable declaration
-                       var vars []*Var
-                       for i, lhs := range lhs {
-                               if lhs == nil {
-                                       continue
-                               }
+       // lhs expressions and initialization value (rhs) types
+       lhs := [2]expr{sKey, sValue}
+       rhs := [2]Type{key, val} // key, val may be nil
 
-                               // determine lhs variable
-                               var obj *Var
-                               if ident, _ := lhs.(*ast.Ident); ident != nil {
-                                       // declare new variable
-                                       name := ident.Name
-                                       obj = NewVar(ident.Pos(), check.pkg, name, nil)
-                                       check.recordDef(ident, obj)
-                                       // _ variables don't count as new variables
-                                       if name != "_" {
-                                               vars = append(vars, obj)
-                                       }
-                               } else {
-                                       check.errorf(lhs, InvalidSyntaxTree, "cannot declare %s", lhs)
-                                       obj = NewVar(lhs.Pos(), check.pkg, "_", nil) // dummy variable
-                               }
+       if isDef {
+               // short variable declaration
+               var vars []*Var
+               for i, lhs := range lhs {
+                       if lhs == nil {
+                               continue
+                       }
 
-                               // initialize lhs variable
-                               if typ := rhs[i]; typ != nil {
-                                       x.mode = value
-                                       x.expr = lhs // we don't have a better rhs expression to use here
-                                       x.typ = typ
-                                       check.initVar(obj, &x, "range clause")
-                               } else {
-                                       obj.typ = Typ[Invalid]
-                                       obj.used = true // don't complain about unused variable
+                       // determine lhs variable
+                       var obj *Var
+                       if ident, _ := lhs.(*identType); ident != nil {
+                               // declare new variable
+                               name := identName(ident)
+                               obj = NewVar(ident.Pos(), check.pkg, name, nil)
+                               check.recordDef(ident, obj)
+                               // _ variables don't count as new variables
+                               if name != "_" {
+                                       vars = append(vars, obj)
                                }
+                       } else {
+                               check.errorf(lhs, InvalidSyntaxTree, "cannot declare %s", lhs)
+                               obj = NewVar(lhs.Pos(), check.pkg, "_", nil) // dummy variable
                        }
 
-                       // declare variables
-                       if len(vars) > 0 {
-                               scopePos := s.Body.Pos()
-                               for _, obj := range vars {
-                                       check.declare(check.scope, nil /* recordDef already called */, obj, scopePos)
-                               }
+                       // initialize lhs variable
+                       if typ := rhs[i]; typ != nil {
+                               x.mode = value
+                               x.expr = lhs // we don't have a better rhs expression to use here
+                               x.typ = typ
+                               check.initVar(obj, &x, "range clause")
                        } else {
-                               check.error(inNode(s, s.TokPos), NoNewVar, "no new variables on left side of :=")
+                               obj.typ = Typ[Invalid]
+                               obj.used = true // don't complain about unused variable
+                       }
+               }
+
+               // declare variables
+               if len(vars) > 0 {
+                       scopePos := s.Body.Pos()
+                       for _, obj := range vars {
+                               check.declare(check.scope, nil /* recordDef already called */, obj, scopePos)
                        }
                } else {
-                       // ordinary assignment
-                       for i, lhs := range lhs {
-                               if lhs == nil {
-                                       continue
-                               }
-                               if typ := rhs[i]; typ != nil {
-                                       x.mode = value
-                                       x.expr = lhs // we don't have a better rhs expression to use here
-                                       x.typ = typ
-                                       check.assignVar(lhs, nil, &x)
-                               }
+                       check.error(noNewVarPos, NoNewVar, "no new variables on left side of :=")
+               }
+       } else {
+               // ordinary assignment
+               for i, lhs := range lhs {
+                       if lhs == nil {
+                               continue
+                       }
+                       if typ := rhs[i]; typ != nil {
+                               x.mode = value
+                               x.expr = lhs // we don't have a better rhs expression to use here
+                               x.typ = typ
+                               check.assignVar(lhs, nil, &x)
                        }
                }
-
-               check.stmt(inner, s.Body)
-
-       default:
-               check.error(s, InvalidSyntaxTree, "invalid statement")
        }
+
+       check.stmt(inner, s.Body)
 }
 
 // rangeKeyVal returns the key and value type produced by a range clause
-// over an expression of type typ. If the range clause is not permitted
-// the results are nil.
-func rangeKeyVal(typ Type) (key, val Type) {
-       switch typ := arrayPtrDeref(typ).(type) {
+// over an expression of type typ. If the range clause is not permitted,
+// rangeKeyVal returns ok = false. When ok = false, rangeKeyVal may also
+// return a reason in cause.
+func rangeKeyVal(typ Type) (key, val Type, cause string, isFunc, ok bool) {
+       bad := func(cause string) (Type, Type, string, bool, bool) {
+               return Typ[Invalid], Typ[Invalid], cause, false, false
+       }
+       toSig := func(t Type) *Signature {
+               sig, _ := coreType(t).(*Signature)
+               return sig
+       }
+
+       orig := typ
+       switch typ := arrayPtrDeref(coreType(typ)).(type) {
+       case nil:
+               return bad("no core type")
        case *Basic:
                if isString(typ) {
-                       return Typ[Int], universeRune // use 'rune' name
+                       return Typ[Int], universeRune, "", false, true // use 'rune' name
+               }
+               if buildcfg.Experiment.Range && isInteger(typ) {
+                       return orig, nil, "", false, true
                }
        case *Array:
-               return Typ[Int], typ.elem
+               return Typ[Int], typ.elem, "", false, true
        case *Slice:
-               return Typ[Int], typ.elem
+               return Typ[Int], typ.elem, "", false, true
        case *Map:
-               return typ.key, typ.elem
+               return typ.key, typ.elem, "", false, true
        case *Chan:
-               return typ.elem, Typ[Invalid]
+               if typ.dir == SendOnly {
+                       return bad("receive from send-only channel")
+               }
+               return typ.elem, nil, "", false, true
+       case *Signature:
+               if !buildcfg.Experiment.Range {
+                       break
+               }
+               assert(typ.Recv() == nil)
+               switch {
+               case typ.Params().Len() != 1:
+                       return bad("func must be func(yield func(...) bool): wrong argument count")
+               case toSig(typ.Params().At(0).Type()) == nil:
+                       return bad("func must be func(yield func(...) bool): argument is not func")
+               case typ.Results().Len() != 0:
+                       return bad("func must be func(yield func(...) bool): unexpected results")
+               }
+               cb := toSig(typ.Params().At(0).Type())
+               assert(cb.Recv() == nil)
+               switch {
+               case cb.Params().Len() > 2:
+                       return bad("func must be func(yield func(...) bool): yield func has too many parameters")
+               case cb.Results().Len() != 1 || !isBoolean(cb.Results().At(0).Type()):
+                       return bad("func must be func(yield func(...) bool): yield func does not return bool")
+               }
+               if cb.Params().Len() >= 1 {
+                       key = cb.Params().At(0).Type()
+               }
+               if cb.Params().Len() >= 2 {
+                       val = cb.Params().At(1).Type()
+               }
+               return key, val, "", true, true
        }
        return
 }