]> Cypherpunks.ru repositories - gostls13.git/blobdiff - src/cmd/compile/internal/types2/stmt.go
cmd/compile, go/types: typechecking of range over int, func
[gostls13.git] / src / cmd / compile / internal / types2 / stmt.go
index a671002e1214c149cd31e06d9a1de29fed8004bc..e00c72685f186eb7f2a10e9ba81767413c3f59bd 100644 (file)
@@ -9,6 +9,7 @@ package types2
 import (
        "cmd/compile/internal/syntax"
        "go/constant"
+       "internal/buildcfg"
        . "internal/types/errors"
        "sort"
 )
@@ -828,7 +829,10 @@ func (check *Checker) typeSwitchStmt(inner stmtContext, s *syntax.SwitchStmt, gu
 }
 
 func (check *Checker) rangeStmt(inner stmtContext, s *syntax.ForStmt, rclause *syntax.RangeClause) {
-       // determine lhs, if any
+       // Convert syntax form to local variables.
+       type expr = syntax.Expr
+       type identType = syntax.Name
+       identName := func(n *identType) string { return n.Value }
        sKey := rclause.Lhs // possibly nil
        var sValue, sExtra syntax.Expr
        if p, _ := sKey.(*syntax.ListExpr); p != nil {
@@ -844,43 +848,48 @@ func (check *Checker) rangeStmt(inner stmtContext, s *syntax.ForStmt, rclause *s
                        sExtra = p.ElemList[2]
                }
        }
+       isDef := rclause.Def
+       rangeVar := rclause.X
+       noNewVarPos := s
+
+       // Do not use rclause anymore.
+       rclause = nil
+
+       // Everything from here on is shared between cmd/compile/internal/types2 and go/types.
 
        // check expression to iterate over
        var x operand
-       check.expr(nil, &x, rclause.X)
+       check.expr(nil, &x, rangeVar)
 
        // determine key/value types
        var key, val Type
        if x.mode != invalid {
                // Ranging over a type parameter is permitted if it has a core type.
-               var cause string
-               u := coreType(x.typ)
-               if t, _ := u.(*Chan); t != nil {
-                       if sValue != nil {
-                               check.softErrorf(sValue, InvalidIterVar, "range over %s permits only one iteration variable", &x)
-                               // ok to continue
-                       }
-                       if t.dir == SendOnly {
-                               cause = "receive from send-only channel"
-                       }
-               } else {
-                       if sExtra != nil {
-                               check.softErrorf(sExtra, InvalidIterVar, "range clause permits at most two iteration variables")
-                               // ok to continue
-                       }
-                       if u == nil {
-                               cause = check.sprintf("%s has no core type", x.typ)
+               k, v, cause, isFunc, ok := rangeKeyVal(x.typ)
+               switch {
+               case !ok && cause != "":
+                       check.softErrorf(&x, InvalidRangeExpr, "cannot range over %s: %s", &x, cause)
+               case !ok:
+                       check.softErrorf(&x, InvalidRangeExpr, "cannot range over %s", &x)
+               case k == nil && sKey != nil:
+                       check.softErrorf(sKey, InvalidIterVar, "range over %s permits no iteration variables", &x)
+               case v == nil && sValue != nil:
+                       check.softErrorf(sValue, InvalidIterVar, "range over %s permits only one iteration variable", &x)
+               case sExtra != nil:
+                       check.softErrorf(sExtra, InvalidIterVar, "range clause permits at most two iteration variables")
+               case isFunc && ((k == nil) != (sKey == nil) || (v == nil) != (sValue == nil)):
+                       var count string
+                       switch {
+                       case k == nil:
+                               count = "no iteration variables"
+                       case v == nil:
+                               count = "one iteration variable"
+                       default:
+                               count = "two iteration variables"
                        }
+                       check.softErrorf(&x, InvalidIterVar, "range over %s must have %s", &x, count)
                }
-               key, val = rangeKeyVal(u)
-               if key == nil || cause != "" {
-                       if cause == "" {
-                               check.softErrorf(&x, InvalidRangeExpr, "cannot range over %s", &x)
-                       } else {
-                               check.softErrorf(&x, InvalidRangeExpr, "cannot range over %s (%s)", &x, cause)
-                       }
-                       // ok to continue
-               }
+               key, val = k, v
        }
 
        // Open the for-statement block scope now, after the range clause.
@@ -892,10 +901,10 @@ func (check *Checker) rangeStmt(inner stmtContext, s *syntax.ForStmt, rclause *s
        // (irregular assignment, cannot easily map to existing assignment checks)
 
        // lhs expressions and initialization value (rhs) types
-       lhs := [2]syntax.Expr{sKey, sValue}
+       lhs := [2]expr{sKey, sValue}
        rhs := [2]Type{key, val} // key, val may be nil
 
-       if rclause.Def {
+       if isDef {
                // short variable declaration
                var vars []*Var
                for i, lhs := range lhs {
@@ -905,9 +914,9 @@ func (check *Checker) rangeStmt(inner stmtContext, s *syntax.ForStmt, rclause *s
 
                        // determine lhs variable
                        var obj *Var
-                       if ident, _ := lhs.(*syntax.Name); ident != nil {
+                       if ident, _ := lhs.(*identType); ident != nil {
                                // declare new variable
-                               name := ident.Value
+                               name := identName(ident)
                                obj = NewVar(ident.Pos(), check.pkg, name, nil)
                                check.recordDef(ident, obj)
                                // _ variables don't count as new variables
@@ -938,7 +947,7 @@ func (check *Checker) rangeStmt(inner stmtContext, s *syntax.ForStmt, rclause *s
                                check.declare(check.scope, nil /* recordDef already called */, obj, scopePos)
                        }
                } else {
-                       check.error(s, NoNewVar, "no new variables on left side of :=")
+                       check.error(noNewVarPos, NoNewVar, "no new variables on left side of :=")
                }
        } else {
                // ordinary assignment
@@ -959,22 +968,68 @@ func (check *Checker) rangeStmt(inner stmtContext, s *syntax.ForStmt, rclause *s
 }
 
 // rangeKeyVal returns the key and value type produced by a range clause
-// over an expression of type typ. If the range clause is not permitted
-// the results are nil.
-func rangeKeyVal(typ Type) (key, val Type) {
-       switch typ := arrayPtrDeref(typ).(type) {
+// over an expression of type typ. If the range clause is not permitted,
+// rangeKeyVal returns ok = false. When ok = false, rangeKeyVal may also
+// return a reason in cause.
+func rangeKeyVal(typ Type) (key, val Type, cause string, isFunc, ok bool) {
+       bad := func(cause string) (Type, Type, string, bool, bool) {
+               return Typ[Invalid], Typ[Invalid], cause, false, false
+       }
+       toSig := func(t Type) *Signature {
+               sig, _ := coreType(t).(*Signature)
+               return sig
+       }
+
+       orig := typ
+       switch typ := arrayPtrDeref(coreType(typ)).(type) {
+       case nil:
+               return bad("no core type")
        case *Basic:
                if isString(typ) {
-                       return Typ[Int], universeRune // use 'rune' name
+                       return Typ[Int], universeRune, "", false, true // use 'rune' name
+               }
+               if buildcfg.Experiment.Range && isInteger(typ) {
+                       return orig, nil, "", false, true
                }
        case *Array:
-               return Typ[Int], typ.elem
+               return Typ[Int], typ.elem, "", false, true
        case *Slice:
-               return Typ[Int], typ.elem
+               return Typ[Int], typ.elem, "", false, true
        case *Map:
-               return typ.key, typ.elem
+               return typ.key, typ.elem, "", false, true
        case *Chan:
-               return typ.elem, Typ[Invalid]
+               if typ.dir == SendOnly {
+                       return bad("receive from send-only channel")
+               }
+               return typ.elem, nil, "", false, true
+       case *Signature:
+               if !buildcfg.Experiment.Range {
+                       break
+               }
+               assert(typ.Recv() == nil)
+               switch {
+               case typ.Params().Len() != 1:
+                       return bad("func must be func(yield func(...) bool): wrong argument count")
+               case toSig(typ.Params().At(0).Type()) == nil:
+                       return bad("func must be func(yield func(...) bool): argument is not func")
+               case typ.Results().Len() != 0:
+                       return bad("func must be func(yield func(...) bool): unexpected results")
+               }
+               cb := toSig(typ.Params().At(0).Type())
+               assert(cb.Recv() == nil)
+               switch {
+               case cb.Params().Len() > 2:
+                       return bad("func must be func(yield func(...) bool): yield func has too many parameters")
+               case cb.Results().Len() != 1 || !isBoolean(cb.Results().At(0).Type()):
+                       return bad("func must be func(yield func(...) bool): yield func does not return bool")
+               }
+               if cb.Params().Len() >= 1 {
+                       key = cb.Params().At(0).Type()
+               }
+               if cb.Params().Len() >= 2 {
+                       val = cb.Params().At(1).Type()
+               }
+               return key, val, "", true, true
        }
        return
 }