]> Cypherpunks.ru repositories - gostls13.git/commitdiff
cmd/compile: sparse conditional constant propagation
authorYi Yang <qingfeng.yy@alibaba-inc.com>
Thu, 31 Aug 2023 02:48:34 +0000 (02:48 +0000)
committerKeith Randall <khr@golang.org>
Tue, 12 Sep 2023 21:01:50 +0000 (21:01 +0000)
sparse conditional constant propagation can discover optimization
opportunities that cannot be found by just combining constant folding
and constant propagation and dead code elimination separately.

This is a re-submit of PR#59575, which fix a broken dominance relationship caught by ssacheck

Updates https://github.com/golang/go/issues/59399

Change-Id: I57482dee38f8e80a610aed4f64295e60c38b7a47
GitHub-Last-Rev: 830016f24e3a5320c6c127a48ab7c84e2fc672eb
GitHub-Pull-Request: golang/go#60469
Reviewed-on: https://go-review.googlesource.com/c/go/+/498795
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Keith Randall <khr@google.com>
Reviewed-by: Heschi Kreinick <heschi@google.com>
Reviewed-by: Keith Randall <khr@golang.org>
src/cmd/compile/internal/ssa/_gen/genericOps.go
src/cmd/compile/internal/ssa/block.go
src/cmd/compile/internal/ssa/compile.go
src/cmd/compile/internal/ssa/sccp.go [new file with mode: 0644]
src/cmd/compile/internal/ssa/sccp_test.go [new file with mode: 0644]
test/checkbce.go
test/codegen/compare_and_branch.go
test/loopbce.go

index fb183192634258442460a9211e9fcb3361446bd2..a182afbaa83e4f887f293d189f8727deb07dd423 100644 (file)
@@ -649,6 +649,8 @@ var genericOps = []opData{
 //    Plain                []            [next]
 //       If   [boolean Value]      [then, else]
 //    First                []   [always, never]
+//    Defer             [mem]  [nopanic, panic]                  (control opcode should be OpStaticCall to runtime.deferproc)
+//JumpTable   [integer Value]  [succ1,succ2,..]
 
 var genericBlocks = []blockData{
        {name: "Plain"},                  // a single successor
index e7776b231643c6c2d222b99a6b936ae1b0612b66..6d391ab0110da2ded47595a1e201e453c53af167 100644 (file)
@@ -112,13 +112,6 @@ func (e Edge) String() string {
 }
 
 // BlockKind is the kind of SSA block.
-//
-//       kind          controls        successors
-//     ------------------------------------------
-//       Exit      [return mem]                []
-//      Plain                []            [next]
-//         If   [boolean Value]      [then, else]
-//      Defer             [mem]  [nopanic, panic]  (control opcode should be OpStaticCall to runtime.deferproc)
 type BlockKind int16
 
 // short form print
@@ -275,8 +268,7 @@ func (b *Block) truncateValues(i int) {
        b.Values = b.Values[:i]
 }
 
-// AddEdgeTo adds an edge from block b to block c. Used during building of the
-// SSA graph; do not use on an already-completed SSA graph.
+// AddEdgeTo adds an edge from block b to block c.
 func (b *Block) AddEdgeTo(c *Block) {
        i := len(b.Succs)
        j := len(c.Preds)
index 10984d508be4f30218cf400804a85bb8ea4f6453..625c98bb1f8fa58bdf02ac6503529fdd776ec2c2 100644 (file)
@@ -477,6 +477,7 @@ var passes = [...]pass{
        {name: "softfloat", fn: softfloat, required: true},
        {name: "late opt", fn: opt, required: true}, // TODO: split required rules and optimizing rules
        {name: "dead auto elim", fn: elimDeadAutosGeneric},
+       {name: "sccp", fn: sccp},
        {name: "generic deadcode", fn: deadcode, required: true}, // remove dead stores, which otherwise mess up store chain
        {name: "check bce", fn: checkbce},
        {name: "branchelim", fn: branchelim},
diff --git a/src/cmd/compile/internal/ssa/sccp.go b/src/cmd/compile/internal/ssa/sccp.go
new file mode 100644 (file)
index 0000000..3c10954
--- /dev/null
@@ -0,0 +1,578 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssa
+
+import (
+       "fmt"
+)
+
+// ----------------------------------------------------------------------------
+// Sparse Conditional Constant Propagation
+//
+// Described in
+// Mark N. Wegman, F. Kenneth Zadeck: Constant Propagation with Conditional Branches.
+// TOPLAS 1991.
+//
+// This algorithm uses three level lattice for SSA value
+//
+//      Top        undefined
+//     / | \
+// .. 1  2  3 ..   constant
+//     \ | /
+//     Bottom      not constant
+//
+// It starts with optimistically assuming that all SSA values are initially Top
+// and then propagates constant facts only along reachable control flow paths.
+// Since some basic blocks are not visited yet, corresponding inputs of phi become
+// Top, we use the meet(phi) to compute its lattice.
+//
+//       Top ∩ any = any
+//       Bottom ∩ any = Bottom
+//       ConstantA ∩ ConstantA = ConstantA
+//       ConstantA ∩ ConstantB = Bottom
+//
+// Each lattice value is lowered most twice(Top to Constant, Constant to Bottom)
+// due to lattice depth, resulting in a fast convergence speed of the algorithm.
+// In this way, sccp can discover optimization opportunities that cannot be found
+// by just combining constant folding and constant propagation and dead code
+// elimination separately.
+
+// Three level lattice holds compile time knowledge about SSA value
+const (
+       top      int8 = iota // undefined
+       constant             // constant
+       bottom               // not a constant
+)
+
+type lattice struct {
+       tag int8   // lattice type
+       val *Value // constant value
+}
+
+type worklist struct {
+       f            *Func               // the target function to be optimized out
+       edges        []Edge              // propagate constant facts through edges
+       uses         []*Value            // re-visiting set
+       visited      map[Edge]bool       // visited edges
+       latticeCells map[*Value]lattice  // constant lattices
+       defUse       map[*Value][]*Value // def-use chains for some values
+       defBlock     map[*Value][]*Block // use blocks of def
+       visitedBlock []bool              // visited block
+}
+
+// sccp stands for sparse conditional constant propagation, it propagates constants
+// through CFG conditionally and applies constant folding, constant replacement and
+// dead code elimination all together.
+func sccp(f *Func) {
+       var t worklist
+       t.f = f
+       t.edges = make([]Edge, 0)
+       t.visited = make(map[Edge]bool)
+       t.edges = append(t.edges, Edge{f.Entry, 0})
+       t.defUse = make(map[*Value][]*Value)
+       t.defBlock = make(map[*Value][]*Block)
+       t.latticeCells = make(map[*Value]lattice)
+       t.visitedBlock = f.Cache.allocBoolSlice(f.NumBlocks())
+       defer f.Cache.freeBoolSlice(t.visitedBlock)
+
+       // build it early since we rely heavily on the def-use chain later
+       t.buildDefUses()
+
+       // pick up either an edge or SSA value from worklilst, process it
+       for {
+               if len(t.edges) > 0 {
+                       edge := t.edges[0]
+                       t.edges = t.edges[1:]
+                       if _, exist := t.visited[edge]; !exist {
+                               dest := edge.b
+                               destVisited := t.visitedBlock[dest.ID]
+
+                               // mark edge as visited
+                               t.visited[edge] = true
+                               t.visitedBlock[dest.ID] = true
+                               for _, val := range dest.Values {
+                                       if val.Op == OpPhi || !destVisited {
+                                               t.visitValue(val)
+                                       }
+                               }
+                               // propagates constants facts through CFG, taking condition test
+                               // into account
+                               if !destVisited {
+                                       t.propagate(dest)
+                               }
+                       }
+                       continue
+               }
+               if len(t.uses) > 0 {
+                       use := t.uses[0]
+                       t.uses = t.uses[1:]
+                       t.visitValue(use)
+                       continue
+               }
+               break
+       }
+
+       // apply optimizations based on discovered constants
+       constCnt, rewireCnt := t.replaceConst()
+       if f.pass.debug > 0 {
+               if constCnt > 0 || rewireCnt > 0 {
+                       fmt.Printf("Phase SCCP for %v : %v constants, %v dce\n", f.Name, constCnt, rewireCnt)
+               }
+       }
+}
+
+func equals(a, b lattice) bool {
+       if a == b {
+               // fast path
+               return true
+       }
+       if a.tag != b.tag {
+               return false
+       }
+       if a.tag == constant {
+               // The same content of const value may be different, we should
+               // compare with auxInt instead
+               v1 := a.val
+               v2 := b.val
+               if v1.Op == v2.Op && v1.AuxInt == v2.AuxInt {
+                       return true
+               } else {
+                       return false
+               }
+       }
+       return true
+}
+
+// possibleConst checks if Value can be fold to const. For those Values that can
+// never become constants(e.g. StaticCall), we don't make futile efforts.
+func possibleConst(val *Value) bool {
+       if isConst(val) {
+               return true
+       }
+       switch val.Op {
+       case OpCopy:
+               return true
+       case OpPhi:
+               return true
+       case
+               // negate
+               OpNeg8, OpNeg16, OpNeg32, OpNeg64, OpNeg32F, OpNeg64F,
+               OpCom8, OpCom16, OpCom32, OpCom64,
+               // math
+               OpFloor, OpCeil, OpTrunc, OpRoundToEven, OpSqrt,
+               // conversion
+               OpTrunc16to8, OpTrunc32to8, OpTrunc32to16, OpTrunc64to8,
+               OpTrunc64to16, OpTrunc64to32, OpCvt32to32F, OpCvt32to64F,
+               OpCvt64to32F, OpCvt64to64F, OpCvt32Fto32, OpCvt32Fto64,
+               OpCvt64Fto32, OpCvt64Fto64, OpCvt32Fto64F, OpCvt64Fto32F,
+               OpCvtBoolToUint8,
+               OpZeroExt8to16, OpZeroExt8to32, OpZeroExt8to64, OpZeroExt16to32,
+               OpZeroExt16to64, OpZeroExt32to64, OpSignExt8to16, OpSignExt8to32,
+               OpSignExt8to64, OpSignExt16to32, OpSignExt16to64, OpSignExt32to64,
+               // bit
+               OpCtz8, OpCtz16, OpCtz32, OpCtz64,
+               // mask
+               OpSlicemask,
+               // safety check
+               OpIsNonNil,
+               // not
+               OpNot:
+               return true
+       case
+               // add
+               OpAdd64, OpAdd32, OpAdd16, OpAdd8,
+               OpAdd32F, OpAdd64F,
+               // sub
+               OpSub64, OpSub32, OpSub16, OpSub8,
+               OpSub32F, OpSub64F,
+               // mul
+               OpMul64, OpMul32, OpMul16, OpMul8,
+               OpMul32F, OpMul64F,
+               // div
+               OpDiv32F, OpDiv64F,
+               OpDiv8, OpDiv16, OpDiv32, OpDiv64,
+               OpDiv8u, OpDiv16u, OpDiv32u, OpDiv64u,
+               OpMod8, OpMod16, OpMod32, OpMod64,
+               OpMod8u, OpMod16u, OpMod32u, OpMod64u,
+               // compare
+               OpEq64, OpEq32, OpEq16, OpEq8,
+               OpEq32F, OpEq64F,
+               OpLess64, OpLess32, OpLess16, OpLess8,
+               OpLess64U, OpLess32U, OpLess16U, OpLess8U,
+               OpLess32F, OpLess64F,
+               OpLeq64, OpLeq32, OpLeq16, OpLeq8,
+               OpLeq64U, OpLeq32U, OpLeq16U, OpLeq8U,
+               OpLeq32F, OpLeq64F,
+               OpEqB, OpNeqB,
+               // shift
+               OpLsh64x64, OpRsh64x64, OpRsh64Ux64, OpLsh32x64,
+               OpRsh32x64, OpRsh32Ux64, OpLsh16x64, OpRsh16x64,
+               OpRsh16Ux64, OpLsh8x64, OpRsh8x64, OpRsh8Ux64,
+               // safety check
+               OpIsInBounds, OpIsSliceInBounds,
+               // bit
+               OpAnd8, OpAnd16, OpAnd32, OpAnd64,
+               OpOr8, OpOr16, OpOr32, OpOr64,
+               OpXor8, OpXor16, OpXor32, OpXor64:
+               return true
+       default:
+               return false
+       }
+}
+
+func (t *worklist) getLatticeCell(val *Value) lattice {
+       if !possibleConst(val) {
+               // they are always worst
+               return lattice{bottom, nil}
+       }
+       lt, exist := t.latticeCells[val]
+       if !exist {
+               return lattice{top, nil} // optimistically for un-visited value
+       }
+       return lt
+}
+
+func isConst(val *Value) bool {
+       switch val.Op {
+       case OpConst64, OpConst32, OpConst16, OpConst8,
+               OpConstBool, OpConst32F, OpConst64F:
+               return true
+       default:
+               return false
+       }
+}
+
+// buildDefUses builds def-use chain for some values early, because once the
+// lattice of a value is changed, we need to update lattices of use. But we don't
+// need all uses of it, only uses that can become constants would be added into
+// re-visit worklist since no matter how many times they are revisited, uses which
+// can't become constants lattice remains unchanged, i.e. Bottom.
+func (t *worklist) buildDefUses() {
+       for _, block := range t.f.Blocks {
+               for _, val := range block.Values {
+                       for _, arg := range val.Args {
+                               // find its uses, only uses that can become constants take into account
+                               if possibleConst(arg) && possibleConst(val) {
+                                       if _, exist := t.defUse[arg]; !exist {
+                                               t.defUse[arg] = make([]*Value, 0, arg.Uses)
+                                       }
+                                       t.defUse[arg] = append(t.defUse[arg], val)
+                               }
+                       }
+               }
+               for _, ctl := range block.ControlValues() {
+                       // for control values that can become constants, find their use blocks
+                       if possibleConst(ctl) {
+                               t.defBlock[ctl] = append(t.defBlock[ctl], block)
+                       }
+               }
+       }
+}
+
+// addUses finds all uses of value and appends them into work list for further process
+func (t *worklist) addUses(val *Value) {
+       for _, use := range t.defUse[val] {
+               if val == use {
+                       // Phi may refer to itself as uses, ignore them to avoid re-visiting phi
+                       // for performance reason
+                       continue
+               }
+               t.uses = append(t.uses, use)
+       }
+       for _, block := range t.defBlock[val] {
+               if t.visitedBlock[block.ID] {
+                       t.propagate(block)
+               }
+       }
+}
+
+// meet meets all of phi arguments and computes result lattice
+func (t *worklist) meet(val *Value) lattice {
+       optimisticLt := lattice{top, nil}
+       for i := 0; i < len(val.Args); i++ {
+               edge := Edge{val.Block, i}
+               // If incoming edge for phi is not visited, assume top optimistically.
+               // According to rules of meet:
+               //              Top ∩ any = any
+               // Top participates in meet() but does not affect the result, so here
+               // we will ignore Top and only take other lattices into consideration.
+               if _, exist := t.visited[edge]; exist {
+                       lt := t.getLatticeCell(val.Args[i])
+                       if lt.tag == constant {
+                               if optimisticLt.tag == top {
+                                       optimisticLt = lt
+                               } else {
+                                       if !equals(optimisticLt, lt) {
+                                               // ConstantA ∩ ConstantB = Bottom
+                                               return lattice{bottom, nil}
+                                       }
+                               }
+                       } else if lt.tag == bottom {
+                               // Bottom ∩ any = Bottom
+                               return lattice{bottom, nil}
+                       } else {
+                               // Top ∩ any = any
+                       }
+               } else {
+                       // Top ∩ any = any
+               }
+       }
+
+       // ConstantA ∩ ConstantA = ConstantA or Top ∩ any = any
+       return optimisticLt
+}
+
+func computeLattice(f *Func, val *Value, args ...*Value) lattice {
+       // In general, we need to perform constant evaluation based on constant args:
+       //
+       //  res := lattice{constant, nil}
+       //      switch op {
+       //      case OpAdd16:
+       //              res.val = newConst(argLt1.val.AuxInt16() + argLt2.val.AuxInt16())
+       //      case OpAdd32:
+       //              res.val = newConst(argLt1.val.AuxInt32() + argLt2.val.AuxInt32())
+       //      case OpDiv8:
+       //              if !isDivideByZero(argLt2.val.AuxInt8()) {
+       //                      res.val = newConst(argLt1.val.AuxInt8() / argLt2.val.AuxInt8())
+       //              }
+       //  ...
+       //      }
+       //
+       // However, this would create a huge switch for all opcodes that can be
+       // evaluated during compile time. Moreover, some operations can be evaluated
+       // only if its arguments satisfy additional conditions(e.g. divide by zero).
+       // It's fragile and error prone. We did a trick by reusing the existing rules
+       // in generic rules for compile-time evaluation. But generic rules rewrite
+       // original value, this behavior is undesired, because the lattice of values
+       // may change multiple times, once it was rewritten, we lose the opportunity
+       // to change it permanently, which can lead to errors. For example, We cannot
+       // change its value immediately after visiting Phi, because some of its input
+       // edges may still not be visited at this moment.
+       constValue := f.newValue(val.Op, val.Type, f.Entry, val.Pos)
+       constValue.AddArgs(args...)
+       matched := rewriteValuegeneric(constValue)
+       if matched {
+               if isConst(constValue) {
+                       return lattice{constant, constValue}
+               }
+       }
+       // Either we can not match generic rules for given value or it does not
+       // satisfy additional constraints(e.g. divide by zero), in these cases, clean
+       // up temporary value immediately in case they are not dominated by their args.
+       constValue.reset(OpInvalid)
+       return lattice{bottom, nil}
+}
+
+func (t *worklist) visitValue(val *Value) {
+       if !possibleConst(val) {
+               // fast fail for always worst Values, i.e. there is no lowering happen
+               // on them, their lattices must be initially worse Bottom.
+               return
+       }
+
+       oldLt := t.getLatticeCell(val)
+       defer func() {
+               // re-visit all uses of value if its lattice is changed
+               newLt := t.getLatticeCell(val)
+               if !equals(newLt, oldLt) {
+                       if int8(oldLt.tag) > int8(newLt.tag) {
+                               t.f.Fatalf("Must lower lattice\n")
+                       }
+                       t.addUses(val)
+               }
+       }()
+
+       switch val.Op {
+       // they are constant values, aren't they?
+       case OpConst64, OpConst32, OpConst16, OpConst8,
+               OpConstBool, OpConst32F, OpConst64F: //TODO: support ConstNil ConstString etc
+               t.latticeCells[val] = lattice{constant, val}
+       // lattice value of copy(x) actually means lattice value of (x)
+       case OpCopy:
+               t.latticeCells[val] = t.getLatticeCell(val.Args[0])
+       // phi should be processed specially
+       case OpPhi:
+               t.latticeCells[val] = t.meet(val)
+       // fold 1-input operations:
+       case
+               // negate
+               OpNeg8, OpNeg16, OpNeg32, OpNeg64, OpNeg32F, OpNeg64F,
+               OpCom8, OpCom16, OpCom32, OpCom64,
+               // math
+               OpFloor, OpCeil, OpTrunc, OpRoundToEven, OpSqrt,
+               // conversion
+               OpTrunc16to8, OpTrunc32to8, OpTrunc32to16, OpTrunc64to8,
+               OpTrunc64to16, OpTrunc64to32, OpCvt32to32F, OpCvt32to64F,
+               OpCvt64to32F, OpCvt64to64F, OpCvt32Fto32, OpCvt32Fto64,
+               OpCvt64Fto32, OpCvt64Fto64, OpCvt32Fto64F, OpCvt64Fto32F,
+               OpCvtBoolToUint8,
+               OpZeroExt8to16, OpZeroExt8to32, OpZeroExt8to64, OpZeroExt16to32,
+               OpZeroExt16to64, OpZeroExt32to64, OpSignExt8to16, OpSignExt8to32,
+               OpSignExt8to64, OpSignExt16to32, OpSignExt16to64, OpSignExt32to64,
+               // bit
+               OpCtz8, OpCtz16, OpCtz32, OpCtz64,
+               // mask
+               OpSlicemask,
+               // safety check
+               OpIsNonNil,
+               // not
+               OpNot:
+               lt1 := t.getLatticeCell(val.Args[0])
+
+               if lt1.tag == constant {
+                       // here we take a shortcut by reusing generic rules to fold constants
+                       t.latticeCells[val] = computeLattice(t.f, val, lt1.val)
+               } else {
+                       t.latticeCells[val] = lattice{lt1.tag, nil}
+               }
+       // fold 2-input operations
+       case
+               // add
+               OpAdd64, OpAdd32, OpAdd16, OpAdd8,
+               OpAdd32F, OpAdd64F,
+               // sub
+               OpSub64, OpSub32, OpSub16, OpSub8,
+               OpSub32F, OpSub64F,
+               // mul
+               OpMul64, OpMul32, OpMul16, OpMul8,
+               OpMul32F, OpMul64F,
+               // div
+               OpDiv32F, OpDiv64F,
+               OpDiv8, OpDiv16, OpDiv32, OpDiv64,
+               OpDiv8u, OpDiv16u, OpDiv32u, OpDiv64u, //TODO: support div128u
+               // mod
+               OpMod8, OpMod16, OpMod32, OpMod64,
+               OpMod8u, OpMod16u, OpMod32u, OpMod64u,
+               // compare
+               OpEq64, OpEq32, OpEq16, OpEq8,
+               OpEq32F, OpEq64F,
+               OpLess64, OpLess32, OpLess16, OpLess8,
+               OpLess64U, OpLess32U, OpLess16U, OpLess8U,
+               OpLess32F, OpLess64F,
+               OpLeq64, OpLeq32, OpLeq16, OpLeq8,
+               OpLeq64U, OpLeq32U, OpLeq16U, OpLeq8U,
+               OpLeq32F, OpLeq64F,
+               OpEqB, OpNeqB,
+               // shift
+               OpLsh64x64, OpRsh64x64, OpRsh64Ux64, OpLsh32x64,
+               OpRsh32x64, OpRsh32Ux64, OpLsh16x64, OpRsh16x64,
+               OpRsh16Ux64, OpLsh8x64, OpRsh8x64, OpRsh8Ux64,
+               // safety check
+               OpIsInBounds, OpIsSliceInBounds,
+               // bit
+               OpAnd8, OpAnd16, OpAnd32, OpAnd64,
+               OpOr8, OpOr16, OpOr32, OpOr64,
+               OpXor8, OpXor16, OpXor32, OpXor64:
+               lt1 := t.getLatticeCell(val.Args[0])
+               lt2 := t.getLatticeCell(val.Args[1])
+
+               if lt1.tag == constant && lt2.tag == constant {
+                       // here we take a shortcut by reusing generic rules to fold constants
+                       t.latticeCells[val] = computeLattice(t.f, val, lt1.val, lt2.val)
+               } else {
+                       if lt1.tag == bottom || lt2.tag == bottom {
+                               t.latticeCells[val] = lattice{bottom, nil}
+                       } else {
+                               t.latticeCells[val] = lattice{top, nil}
+                       }
+               }
+       default:
+               // Any other type of value cannot be a constant, they are always worst(Bottom)
+       }
+}
+
+// propagate propagates constants facts through CFG. If the block has single successor,
+// add the successor anyway. If the block has multiple successors, only add the
+// branch destination corresponding to lattice value of condition value.
+func (t *worklist) propagate(block *Block) {
+       switch block.Kind {
+       case BlockExit, BlockRet, BlockRetJmp, BlockInvalid:
+               // control flow ends, do nothing then
+               break
+       case BlockDefer:
+               // we know nothing about control flow, add all branch destinations
+               t.edges = append(t.edges, block.Succs...)
+       case BlockFirst:
+               fallthrough // always takes the first branch
+       case BlockPlain:
+               t.edges = append(t.edges, block.Succs[0])
+       case BlockIf, BlockJumpTable:
+               cond := block.ControlValues()[0]
+               condLattice := t.getLatticeCell(cond)
+               if condLattice.tag == bottom {
+                       // we know nothing about control flow, add all branch destinations
+                       t.edges = append(t.edges, block.Succs...)
+               } else if condLattice.tag == constant {
+                       // add branchIdx destinations depends on its condition
+                       var branchIdx int64
+                       if block.Kind == BlockIf {
+                               branchIdx = 1 - condLattice.val.AuxInt
+                       } else {
+                               branchIdx = condLattice.val.AuxInt
+                       }
+                       t.edges = append(t.edges, block.Succs[branchIdx])
+               } else {
+                       // condition value is not visited yet, don't propagate it now
+               }
+       default:
+               t.f.Fatalf("All kind of block should be processed above.")
+       }
+}
+
+// rewireSuccessor rewires corresponding successors according to constant value
+// discovered by previous analysis. As the result, some successors become unreachable
+// and thus can be removed in further deadcode phase
+func rewireSuccessor(block *Block, constVal *Value) bool {
+       switch block.Kind {
+       case BlockIf:
+               block.removeEdge(int(constVal.AuxInt))
+               block.Kind = BlockPlain
+               block.Likely = BranchUnknown
+               block.ResetControls()
+               return true
+       case BlockJumpTable:
+               idx := int(constVal.AuxInt)
+               targetBlock := block.Succs[idx].b
+               for len(block.Succs) > 0 {
+                       block.removeEdge(0)
+               }
+               block.AddEdgeTo(targetBlock)
+               block.Kind = BlockPlain
+               block.Likely = BranchUnknown
+               block.ResetControls()
+               return true
+       default:
+               return false
+       }
+}
+
+// replaceConst will replace non-constant values that have been proven by sccp
+// to be constants.
+func (t *worklist) replaceConst() (int, int) {
+       constCnt, rewireCnt := 0, 0
+       for val, lt := range t.latticeCells {
+               if lt.tag == constant {
+                       if !isConst(val) {
+                               if t.f.pass.debug > 0 {
+                                       fmt.Printf("Replace %v with %v\n", val.LongString(), lt.val.LongString())
+                               }
+                               val.reset(lt.val.Op)
+                               val.AuxInt = lt.val.AuxInt
+                               constCnt++
+                       }
+                       // If const value controls this block, rewires successors according to its value
+                       ctrlBlock := t.defBlock[val]
+                       for _, block := range ctrlBlock {
+                               if rewireSuccessor(block, lt.val) {
+                                       rewireCnt++
+                                       if t.f.pass.debug > 0 {
+                                               fmt.Printf("Rewire %v %v successors\n", block.Kind, block)
+                                       }
+                               }
+                       }
+               }
+       }
+       return constCnt, rewireCnt
+}
diff --git a/src/cmd/compile/internal/ssa/sccp_test.go b/src/cmd/compile/internal/ssa/sccp_test.go
new file mode 100644 (file)
index 0000000..70c23e7
--- /dev/null
@@ -0,0 +1,95 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssa
+
+import (
+       "cmd/compile/internal/types"
+       "strings"
+       "testing"
+)
+
+func TestSCCPBasic(t *testing.T) {
+       c := testConfig(t)
+       fun := c.Fun("b1",
+               Bloc("b1",
+                       Valu("mem", OpInitMem, types.TypeMem, 0, nil),
+                       Valu("v1", OpConst64, c.config.Types.Int64, 20, nil),
+                       Valu("v2", OpConst64, c.config.Types.Int64, 21, nil),
+                       Valu("v3", OpConst64F, c.config.Types.Float64, 21.0, nil),
+                       Valu("v4", OpConstBool, c.config.Types.Bool, 1, nil),
+                       Valu("t1", OpAdd64, c.config.Types.Int64, 0, nil, "v1", "v2"),
+                       Valu("t2", OpDiv64, c.config.Types.Int64, 0, nil, "t1", "v1"),
+                       Valu("t3", OpAdd64, c.config.Types.Int64, 0, nil, "t1", "t2"),
+                       Valu("t4", OpSub64, c.config.Types.Int64, 0, nil, "t3", "v2"),
+                       Valu("t5", OpMul64, c.config.Types.Int64, 0, nil, "t4", "v2"),
+                       Valu("t6", OpMod64, c.config.Types.Int64, 0, nil, "t5", "v2"),
+                       Valu("t7", OpAnd64, c.config.Types.Int64, 0, nil, "t6", "v2"),
+                       Valu("t8", OpOr64, c.config.Types.Int64, 0, nil, "t7", "v2"),
+                       Valu("t9", OpXor64, c.config.Types.Int64, 0, nil, "t8", "v2"),
+                       Valu("t10", OpNeg64, c.config.Types.Int64, 0, nil, "t9"),
+                       Valu("t11", OpCom64, c.config.Types.Int64, 0, nil, "t10"),
+                       Valu("t12", OpNeg64, c.config.Types.Int64, 0, nil, "t11"),
+                       Valu("t13", OpFloor, c.config.Types.Float64, 0, nil, "v3"),
+                       Valu("t14", OpSqrt, c.config.Types.Float64, 0, nil, "t13"),
+                       Valu("t15", OpCeil, c.config.Types.Float64, 0, nil, "t14"),
+                       Valu("t16", OpTrunc, c.config.Types.Float64, 0, nil, "t15"),
+                       Valu("t17", OpRoundToEven, c.config.Types.Float64, 0, nil, "t16"),
+                       Valu("t18", OpTrunc64to32, c.config.Types.Int64, 0, nil, "t12"),
+                       Valu("t19", OpCvt64Fto64, c.config.Types.Float64, 0, nil, "t17"),
+                       Valu("t20", OpCtz64, c.config.Types.Int64, 0, nil, "v2"),
+                       Valu("t21", OpSlicemask, c.config.Types.Int64, 0, nil, "t20"),
+                       Valu("t22", OpIsNonNil, c.config.Types.Int64, 0, nil, "v2"),
+                       Valu("t23", OpNot, c.config.Types.Bool, 0, nil, "v4"),
+                       Valu("t24", OpEq64, c.config.Types.Bool, 0, nil, "v1", "v2"),
+                       Valu("t25", OpLess64, c.config.Types.Bool, 0, nil, "v1", "v2"),
+                       Valu("t26", OpLeq64, c.config.Types.Bool, 0, nil, "v1", "v2"),
+                       Valu("t27", OpEqB, c.config.Types.Bool, 0, nil, "v4", "v4"),
+                       Valu("t28", OpLsh64x64, c.config.Types.Int64, 0, nil, "v2", "v1"),
+                       Valu("t29", OpIsInBounds, c.config.Types.Int64, 0, nil, "v2", "v1"),
+                       Valu("t30", OpIsSliceInBounds, c.config.Types.Int64, 0, nil, "v2", "v1"),
+                       Goto("b2")),
+               Bloc("b2",
+                       Exit("mem")))
+       sccp(fun.f)
+       CheckFunc(fun.f)
+       for name, value := range fun.values {
+               if strings.HasPrefix(name, "t") {
+                       if !isConst(value) {
+                               t.Errorf("Must be constant: %v", value.LongString())
+                       }
+               }
+       }
+}
+
+func TestSCCPIf(t *testing.T) {
+       c := testConfig(t)
+       fun := c.Fun("b1",
+               Bloc("b1",
+                       Valu("mem", OpInitMem, types.TypeMem, 0, nil),
+                       Valu("v1", OpConst64, c.config.Types.Int64, 0, nil),
+                       Valu("v2", OpConst64, c.config.Types.Int64, 1, nil),
+                       Valu("cmp", OpLess64, c.config.Types.Bool, 0, nil, "v1", "v2"),
+                       If("cmp", "b2", "b3")),
+               Bloc("b2",
+                       Valu("v3", OpConst64, c.config.Types.Int64, 3, nil),
+                       Goto("b4")),
+               Bloc("b3",
+                       Valu("v4", OpConst64, c.config.Types.Int64, 4, nil),
+                       Goto("b4")),
+               Bloc("b4",
+                       Valu("merge", OpPhi, c.config.Types.Int64, 0, nil, "v3", "v4"),
+                       Exit("mem")))
+       sccp(fun.f)
+       CheckFunc(fun.f)
+       for _, b := range fun.blocks {
+               for _, v := range b.Values {
+                       if v == fun.values["merge"] {
+                               if !isConst(v) {
+                                       t.Errorf("Must be constant: %v", v.LongString())
+                               }
+                       }
+               }
+       }
+}
index 6a126099bc89accd5f2d3519664648188c8bfdf1..ab31d9528386d0f0d142c1f31163ac83243d4f10 100644 (file)
@@ -137,6 +137,10 @@ func g4(a [100]int) {
                useInt(a[i+50])
 
                // The following are out of bounds.
+               if a[0] == 0xdeadbeef {
+                       // This is a trick to prohibit sccp to optimize out the following out of bound check
+                       continue
+               }
                useInt(a[i-11]) // ERROR "Found IsInBounds$"
                useInt(a[i+51]) // ERROR "Found IsInBounds$"
        }
index 3502a0302238b1095c762415e75dd8b028121390..c121f1d2cc235a4b2b7eb9fb9387b5760fab74b3 100644 (file)
@@ -72,7 +72,7 @@ func ui64x8() {
        }
 
        // s390x:"CLGIJ\t[$]2, R[0-9]+, [$]255, "
-       for i := uint64(0); i >= 256; i-- {
+       for i := uint64(257); i >= 256; i-- {
                dummy()
        }
 
@@ -145,7 +145,7 @@ func ui32x8() {
        }
 
        // s390x:"CLIJ\t[$]2, R[0-9]+, [$]255, "
-       for i := uint32(0); i >= 256; i-- {
+       for i := uint32(257); i >= 256; i-- {
                dummy()
        }
 
index fcf0d8d90dbdc15c1b21d960864b8fd6dcf34065..1119aaa65aa3c8dc3a8479d8eb9644300ea91073 100644 (file)
@@ -58,7 +58,7 @@ func f4(a [10]int) int {
 func f5(a [10]int) int {
        x := 0
        for i := -10; i < len(a); i += 2 { // ERROR "Induction variable: limits \[-10,8\], increment 2$"
-               x += a[i]
+               x += a[i+10]
        }
        return x
 }
@@ -66,7 +66,7 @@ func f5(a [10]int) int {
 func f5_int32(a [10]int) int {
        x := 0
        for i := int32(-10); i < int32(len(a)); i += 2 { // ERROR "Induction variable: limits \[-10,8\], increment 2$"
-               x += a[i]
+               x += a[i+10]
        }
        return x
 }
@@ -74,7 +74,7 @@ func f5_int32(a [10]int) int {
 func f5_int16(a [10]int) int {
        x := 0
        for i := int16(-10); i < int16(len(a)); i += 2 { // ERROR "Induction variable: limits \[-10,8\], increment 2$"
-               x += a[i]
+               x += a[i+10]
        }
        return x
 }
@@ -82,7 +82,7 @@ func f5_int16(a [10]int) int {
 func f5_int8(a [10]int) int {
        x := 0
        for i := int8(-10); i < int8(len(a)); i += 2 { // ERROR "Induction variable: limits \[-10,8\], increment 2$"
-               x += a[i]
+               x += a[i+10]
        }
        return x
 }
@@ -201,6 +201,10 @@ func h2(a []byte) {
 
 func k0(a [100]int) [100]int {
        for i := 10; i < 90; i++ { // ERROR "Induction variable: limits \[10,90\), increment 1$"
+               if a[0] == 0xdeadbeef {
+                       // This is a trick to prohibit sccp to optimize out the following out of bound check
+                       continue
+               }
                a[i-11] = i
                a[i-10] = i // ERROR "(\([0-9]+\) )?Proved IsInBounds$"
                a[i-5] = i  // ERROR "(\([0-9]+\) )?Proved IsInBounds$"
@@ -214,6 +218,10 @@ func k0(a [100]int) [100]int {
 
 func k1(a [100]int) [100]int {
        for i := 10; i < 90; i++ { // ERROR "Induction variable: limits \[10,90\), increment 1$"
+               if a[0] == 0xdeadbeef {
+                       // This is a trick to prohibit sccp to optimize out the following out of bound check
+                       continue
+               }
                useSlice(a[:i-11])
                useSlice(a[:i-10]) // ERROR "(\([0-9]+\) )?Proved IsSliceInBounds$"
                useSlice(a[:i-5])  // ERROR "(\([0-9]+\) )?Proved IsSliceInBounds$"
@@ -229,6 +237,10 @@ func k1(a [100]int) [100]int {
 
 func k2(a [100]int) [100]int {
        for i := 10; i < 90; i++ { // ERROR "Induction variable: limits \[10,90\), increment 1$"
+               if a[0] == 0xdeadbeef {
+                       // This is a trick to prohibit sccp to optimize out the following out of bound check
+                       continue
+               }
                useSlice(a[i-11:])
                useSlice(a[i-10:]) // ERROR "(\([0-9]+\) )?Proved IsSliceInBounds$"
                useSlice(a[i-5:])  // ERROR "(\([0-9]+\) )?Proved IsSliceInBounds$"
@@ -243,6 +255,10 @@ func k2(a [100]int) [100]int {
 
 func k3(a [100]int) [100]int {
        for i := -10; i < 90; i++ { // ERROR "Induction variable: limits \[-10,90\), increment 1$"
+               if a[0] == 0xdeadbeef {
+                       // This is a trick to prohibit sccp to optimize out the following out of bound check
+                       continue
+               }
                a[i+9] = i
                a[i+10] = i // ERROR "(\([0-9]+\) )?Proved IsInBounds$"
                a[i+11] = i
@@ -252,6 +268,10 @@ func k3(a [100]int) [100]int {
 
 func k3neg(a [100]int) [100]int {
        for i := 89; i > -11; i-- { // ERROR "Induction variable: limits \(-11,89\], increment 1$"
+               if a[0] == 0xdeadbeef {
+                       // This is a trick to prohibit sccp to optimize out the following out of bound check
+                       continue
+               }
                a[i+9] = i
                a[i+10] = i // ERROR "(\([0-9]+\) )?Proved IsInBounds$"
                a[i+11] = i
@@ -261,6 +281,10 @@ func k3neg(a [100]int) [100]int {
 
 func k3neg2(a [100]int) [100]int {
        for i := 89; i >= -10; i-- { // ERROR "Induction variable: limits \[-10,89\], increment 1$"
+               if a[0] == 0xdeadbeef {
+                       // This is a trick to prohibit sccp to optimize out the following out of bound check
+                       continue
+               }
                a[i+9] = i
                a[i+10] = i // ERROR "(\([0-9]+\) )?Proved IsInBounds$"
                a[i+11] = i
@@ -411,7 +435,6 @@ func nobce3(a [100]int64) [100]int64 {
        min := int64((-1) << 63)
        max := int64((1 << 63) - 1)
        for i := min; i < max; i++ { // ERROR "Induction variable: limits \[-9223372036854775808,9223372036854775807\), increment 1$"
-               a[i] = i
        }
        return a
 }