]> Cypherpunks.ru repositories - gostls13.git/commitdiff
cmd/compile: implement range over integer
authorRuss Cox <rsc@golang.org>
Wed, 14 Jun 2023 14:55:06 +0000 (10:55 -0400)
committerGopher Robot <gobot@golang.org>
Wed, 20 Sep 2023 14:52:33 +0000 (14:52 +0000)
Add compiler implementation of range over integers.
This is only reachable if GOEXPERIMENT=range is set,
because otherwise type checking will fail.

For proposal #61405 (but behind a GOEXPERIMENT).
For #61717.

Change-Id: I4e35a73c5df1ac57f61ffb54033a433967e5be51
Reviewed-on: https://go-review.googlesource.com/c/go/+/510538
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
Auto-Submit: Russ Cox <rsc@golang.org>

src/cmd/compile/internal/noder/writer.go
src/cmd/compile/internal/types2/stmt.go
src/cmd/compile/internal/walk/order.go
src/cmd/compile/internal/walk/range.go
test/range3.go [new file with mode: 0644]

index 79c884c22f45b546b941e61f8b574dc00535bf98..f68a3875dfe5b23672ed5eab9cd1e43d22f0895d 100644 (file)
@@ -1448,7 +1448,7 @@ func (w *writer) forStmt(stmt *syntax.ForStmt) {
                                w.convRTTI(src, dstType)
                        }
 
-                       keyType, valueType := w.p.rangeTypes(rang.X)
+                       keyType, valueType := types2.RangeKeyVal(w.p.typeOf(rang.X))
                        assign(0, keyType)
                        assign(1, valueType)
                }
@@ -1489,30 +1489,6 @@ func (w *writer) distinctVars(stmt *syntax.ForStmt) bool {
        return is122 || lv > 0 && lv != 3
 }
 
-// rangeTypes returns the types of values produced by ranging over
-// expr.
-func (pw *pkgWriter) rangeTypes(expr syntax.Expr) (key, value types2.Type) {
-       typ := pw.typeOf(expr)
-       switch typ := types2.CoreType(typ).(type) {
-       case *types2.Pointer: // must be pointer to array
-               return types2.Typ[types2.Int], types2.CoreType(typ.Elem()).(*types2.Array).Elem()
-       case *types2.Array:
-               return types2.Typ[types2.Int], typ.Elem()
-       case *types2.Slice:
-               return types2.Typ[types2.Int], typ.Elem()
-       case *types2.Basic:
-               if typ.Info()&types2.IsString != 0 {
-                       return types2.Typ[types2.Int], runeTypeName.Type()
-               }
-       case *types2.Map:
-               return typ.Key(), typ.Elem()
-       case *types2.Chan:
-               return typ.Elem(), nil
-       }
-       pw.fatalf(expr, "unexpected range type: %v", typ)
-       panic("unreachable")
-}
-
 func (w *writer) ifStmt(stmt *syntax.IfStmt) {
        cond := w.p.staticBool(&stmt.Cond)
 
index e00c72685f186eb7f2a10e9ba81767413c3f59bd..0797da19d48d20dda5de05a35cc003789b305faa 100644 (file)
@@ -967,6 +967,12 @@ func (check *Checker) rangeStmt(inner stmtContext, s *syntax.ForStmt, rclause *s
        check.stmt(inner, s.Body)
 }
 
+// RangeKeyVal returns the key and value types for a range over typ.
+func RangeKeyVal(typ Type) (Type, Type) {
+       key, val, _, _, _ := rangeKeyVal(typ)
+       return key, val
+}
+
 // rangeKeyVal returns the key and value type produced by a range clause
 // over an expression of type typ. If the range clause is not permitted,
 // rangeKeyVal returns ok = false. When ok = false, rangeKeyVal may also
index 8db9e919c70b905bb78e91e194bee509b3b52cfe..2517023908b68f086c4799c98bb5a9e060fb1cbd 100644 (file)
@@ -830,11 +830,14 @@ func (o *orderState) stmt(n ir.Node) {
 
                orderBody := true
                xt := typecheck.RangeExprType(n.X.Type())
-               switch xt.Kind() {
+               switch k := xt.Kind(); {
                default:
                        base.Fatalf("order.stmt range %v", n.Type())
 
-               case types.TARRAY, types.TSLICE:
+               case types.IsInt[k]:
+                       // Used only once, no need to copy.
+
+               case k == types.TARRAY, k == types.TSLICE:
                        if n.Value == nil || ir.IsBlank(n.Value) {
                                // for i := range x will only use x once, to compute len(x).
                                // No need to copy it.
@@ -842,7 +845,7 @@ func (o *orderState) stmt(n ir.Node) {
                        }
                        fallthrough
 
-               case types.TCHAN, types.TSTRING:
+               case k == types.TCHAN, k == types.TSTRING:
                        // chan, string, slice, array ranges use value multiple times.
                        // make copy.
                        r := n.X
@@ -855,7 +858,7 @@ func (o *orderState) stmt(n ir.Node) {
 
                        n.X = o.copyExpr(r)
 
-               case types.TMAP:
+               case k == types.TMAP:
                        if isMapClear(n) {
                                // Preserve the body of the map clear pattern so it can
                                // be detected during walk. The loop body will not be used
index 4e9908b5d1ae4a3275193d33eb1c2d91a70a13cf..93898b3a66f296a715ddd89e5fdf75a672a86df9 100644 (file)
@@ -74,11 +74,25 @@ func walkRange(nrange *ir.RangeStmt) ir.Node {
 
        var body []ir.Node
        var init []ir.Node
-       switch t.Kind() {
+       switch k := t.Kind(); {
        default:
                base.Fatalf("walkRange")
 
-       case types.TARRAY, types.TSLICE, types.TPTR: // TPTR is pointer-to-array
+       case types.IsInt[k]:
+               hv1 := typecheck.TempAt(base.Pos, ir.CurFunc, t)
+               hn := typecheck.TempAt(base.Pos, ir.CurFunc, t)
+
+               init = append(init, ir.NewAssignStmt(base.Pos, hv1, nil))
+               init = append(init, ir.NewAssignStmt(base.Pos, hn, a))
+
+               nfor.Cond = ir.NewBinaryExpr(base.Pos, ir.OLT, hv1, hn)
+               nfor.Post = ir.NewAssignStmt(base.Pos, hv1, ir.NewBinaryExpr(base.Pos, ir.OADD, hv1, ir.NewInt(base.Pos, 1)))
+
+               if v1 != nil {
+                       body = []ir.Node{rangeAssign(nrange, hv1)}
+               }
+
+       case k == types.TARRAY, k == types.TSLICE, k == types.TPTR: // TPTR is pointer-to-array
                if nn := arrayRangeClear(nrange, v1, v2, a); nn != nil {
                        base.Pos = lno
                        return nn
@@ -219,7 +233,7 @@ func walkRange(nrange *ir.RangeStmt) ir.Node {
                as := ir.NewAssignStmt(base.Pos, hu, ir.NewBinaryExpr(base.Pos, ir.OADD, huVal, ir.NewInt(base.Pos, elem.Size())))
                nfor.Post = ir.NewBlockStmt(base.Pos, []ir.Node{nfor.Post, as})
 
-       case types.TMAP:
+       case k == types.TMAP:
                // order.stmt allocated the iterator for us.
                // we only use a once, so no copy needed.
                ha := a
@@ -248,7 +262,7 @@ func walkRange(nrange *ir.RangeStmt) ir.Node {
                        body = []ir.Node{rangeAssign2(nrange, key, elem)}
                }
 
-       case types.TCHAN:
+       case k == types.TCHAN:
                // order.stmt arranged for a copy of the channel variable.
                ha := a
 
@@ -275,7 +289,7 @@ func walkRange(nrange *ir.RangeStmt) ir.Node {
                // See issue 15281.
                body = append(body, ir.NewAssignStmt(base.Pos, hv1, nil))
 
-       case types.TSTRING:
+       case k == types.TSTRING:
                // Transform string range statements like "for v1, v2 = range a" into
                //
                // ha := a
diff --git a/test/range3.go b/test/range3.go
new file mode 100644 (file)
index 0000000..80a4ac8
--- /dev/null
@@ -0,0 +1,58 @@
+// run -goexperiment range
+
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Test the 'for range' construct.
+
+package main
+
+// test range over integers
+
+func testint1() {
+       j := 0
+       for i := range int(4) {
+               if i != j {
+                       println("range var", i, "want", j)
+               }
+               j++
+       }
+       if j != 4 {
+               println("wrong count ranging over 4:", j)
+       }
+}
+
+func testint2() {
+       j := 0
+       for i := range 4 {
+               if i != j {
+                       println("range var", i, "want", j)
+               }
+               j++
+       }
+       if j != 4 {
+               println("wrong count ranging over 4:", j)
+       }
+}
+
+func testint3() {
+       type MyInt int
+
+       j := MyInt(0)
+       for i := range MyInt(4) {
+               if i != j {
+                       println("range var", i, "want", j)
+               }
+               j++
+       }
+       if j != 4 {
+               println("wrong count ranging over 4:", j)
+       }
+}
+
+func main() {
+       testint1()
+       testint2()
+       testint3()
+}