changeset 7283:3964f3d4eb18

Extend loop unswicthing to Switch nodes (integer or type)
author Gilles Duboscq <duboscq@ssw.jku.at>
date Thu, 20 Dec 2012 12:06:58 +0100
parents f368ec89e231
children eea2ffb2efe7
files graal/com.oracle.graal.compiler.test/src/com/oracle/graal/compiler/test/LoopUnswitchTest.java graal/com.oracle.graal.graph/src/com/oracle/graal/graph/NodeClass.java graal/com.oracle.graal.loop/src/com/oracle/graal/loop/LoopPolicies.java graal/com.oracle.graal.loop/src/com/oracle/graal/loop/LoopTransformations.java graal/com.oracle.graal.loop/src/com/oracle/graal/loop/phases/LoopTransformLowPhase.java graal/com.oracle.graal.nodes/src/com/oracle/graal/nodes/extended/SwitchNode.java
diffstat 6 files changed, 130 insertions(+), 59 deletions(-) [+]
line wrap: on
line diff
--- a/graal/com.oracle.graal.compiler.test/src/com/oracle/graal/compiler/test/LoopUnswitchTest.java	Wed Dec 19 15:52:50 2012 +0100
+++ b/graal/com.oracle.graal.compiler.test/src/com/oracle/graal/compiler/test/LoopUnswitchTest.java	Thu Dec 20 12:06:58 2012 +0100
@@ -33,7 +33,6 @@
 
 public class LoopUnswitchTest extends GraalCompilerTest {
 
-    @SuppressWarnings("all")
     public static int referenceSnippet1(int a) {
         int sum = 0;
         if (a > 2) {
@@ -48,7 +47,7 @@
         return sum;
     }
 
-    @SuppressWarnings("all")
+
     public static int test1Snippet(int a) {
         int sum = 0;
         for (int i = 0; i < 1000; i++) {
@@ -61,11 +60,65 @@
         return sum;
     }
 
+    public static int referenceSnippet2(int a) {
+        int sum = 0;
+        switch(a) {
+        case 0:
+            for (int i = 0; i < 1000; i++) {
+                sum += System.currentTimeMillis();
+            }
+            break;
+        case 1:
+            for (int i = 0; i < 1000; i++) {
+                sum += 1;
+                sum += 5;
+            }
+            break;
+        case 55:
+            for (int i = 0; i < 1000; i++) {
+                sum += 5;
+            }
+            break;
+        default:
+            for (int i = 0; i < 1000; i++) {
+                //nothing
+            }
+            break;
+        }
+        return sum;
+    }
+
+    public static int test2Snippet(int a) {
+        int sum = 0;
+        for (int i = 0; i < 1000; i++) {
+            switch(a) {
+                case 0:
+                    sum += System.currentTimeMillis();
+                    break;
+                case 1:
+                    sum += 1;
+                    // fall through
+                case 55:
+                    sum += 5;
+                    break;
+                default:
+                    //nothing
+                    break;
+            }
+        }
+        return sum;
+    }
+
     @Test
     public void test1() {
         test("test1Snippet", "referenceSnippet1");
     }
 
+    @Test
+    public void test2() {
+        test("test2Snippet", "referenceSnippet2");
+    }
+
     private void test(String snippet, String referenceSnippet) {
         final StructuredGraph graph = parse(snippet);
         final StructuredGraph referenceGraph = parse(referenceSnippet);
--- a/graal/com.oracle.graal.graph/src/com/oracle/graal/graph/NodeClass.java	Wed Dec 19 15:52:50 2012 +0100
+++ b/graal/com.oracle.graal.graph/src/com/oracle/graal/graph/NodeClass.java	Thu Dec 20 12:06:58 2012 +0100
@@ -547,7 +547,7 @@
         return fieldNames.get(pos.input ? inputOffsets[pos.index] : successorOffsets[pos.index]);
     }
 
-    private void set(Node node, Position pos, Node x) {
+    public void set(Node node, Position pos, Node x) {
         long offset = pos.input ? inputOffsets[pos.index] : successorOffsets[pos.index];
         if (pos.subIndex == NOT_ITERABLE) {
             Node old = getNode(node,  offset);
--- a/graal/com.oracle.graal.loop/src/com/oracle/graal/loop/LoopPolicies.java	Wed Dec 19 15:52:50 2012 +0100
+++ b/graal/com.oracle.graal.loop/src/com/oracle/graal/loop/LoopPolicies.java	Thu Dec 20 12:06:58 2012 +0100
@@ -23,6 +23,7 @@
 package com.oracle.graal.loop;
 
 import com.oracle.graal.debug.*;
+import com.oracle.graal.graph.*;
 import com.oracle.graal.nodes.*;
 import com.oracle.graal.nodes.cfg.*;
 import com.oracle.graal.phases.*;
@@ -56,16 +57,24 @@
         return loop.loopBegin().unswitches() <= GraalOptions.LoopMaxUnswitch;
     }
 
-    public static boolean shouldUnswitch(LoopEx loop, IfNode ifNode) {
-        Block postDomBlock = loop.loopsData().controlFlowGraph().blockFor(ifNode).getPostdominator();
+    public static boolean shouldUnswitch(LoopEx loop, ControlSplitNode controlSplit) {
+        Block postDomBlock = loop.loopsData().controlFlowGraph().blockFor(controlSplit).getPostdominator();
         BeginNode postDom = postDomBlock != null ? postDomBlock.getBeginNode() : null;
-        int inTrueBranch = loop.nodesInLoopFrom(ifNode.trueSuccessor(), postDom).cardinality();
-        int inFalseBranch = loop.nodesInLoopFrom(ifNode.falseSuccessor(), postDom).cardinality();
         int loopTotal = loop.size();
-        int netDiff = loopTotal - (inTrueBranch + inFalseBranch);
-        double uncertainty = (0.5 - Math.abs(ifNode.probability(ifNode.trueSuccessor()) - 0.5)) * 2;
+        int inBranchTotal = 0;
+        double maxProbability = 0;
+        for (Node successor : controlSplit.successors()) {
+            BeginNode branch = (BeginNode) successor;
+            inBranchTotal += loop.nodesInLoopFrom(branch, postDom).cardinality(); //this may count twice because of fall-through in switches
+            double probability = controlSplit.probability(branch);
+            if (probability > maxProbability) {
+                maxProbability = probability;
+            }
+        }
+        int netDiff = loopTotal - (inBranchTotal);
+        double uncertainty = 1 - maxProbability;
         int maxDiff = GraalOptions.LoopUnswitchMaxIncrease + (int) (GraalOptions.LoopUnswitchUncertaintyBoost * loop.loopBegin().loopFrequency() * uncertainty);
-        Debug.log("shouldUnswitch(%s, %s) : delta=%d, max=%d, %.2f%% inside of if", loop, ifNode, netDiff, maxDiff, (double) (inTrueBranch + inFalseBranch) / loopTotal * 100);
+        Debug.log("shouldUnswitch(%s, %s) : delta=%d, max=%d, %.2f%% inside of branches", loop, controlSplit, netDiff, maxDiff, (double) (inBranchTotal) / loopTotal * 100);
         return netDiff <= maxDiff;
     }
 
--- a/graal/com.oracle.graal.loop/src/com/oracle/graal/loop/LoopTransformations.java	Wed Dec 19 15:52:50 2012 +0100
+++ b/graal/com.oracle.graal.loop/src/com/oracle/graal/loop/LoopTransformations.java	Thu Dec 20 12:06:58 2012 +0100
@@ -25,37 +25,16 @@
 import com.oracle.graal.api.code.*;
 import com.oracle.graal.api.meta.*;
 import com.oracle.graal.graph.*;
+import com.oracle.graal.graph.NodeClass.NodeClassIterator;
+import com.oracle.graal.graph.NodeClass.Position;
 import com.oracle.graal.nodes.*;
-import com.oracle.graal.nodes.spi.*;
-import com.oracle.graal.nodes.util.*;
+import com.oracle.graal.nodes.extended.*;
 import com.oracle.graal.phases.*;
 import com.oracle.graal.phases.common.*;
 
 
 public abstract class LoopTransformations {
     private static final int UNROLL_LIMIT = GraalOptions.FullUnrollMaxNodes * 2;
-    private static final SimplifierTool simplifier = new SimplifierTool() {
-        @Override
-        public TargetDescription target() {
-            return null;
-        }
-        @Override
-        public CodeCacheProvider runtime() {
-            return null;
-        }
-        @Override
-        public Assumptions assumptions() {
-            return null;
-        }
-        @Override
-        public void deleteBranch(FixedNode branch) {
-            branch.predecessor().replaceFirstSuccessor(branch, null);
-            GraphUtil.killCFG(branch);
-        }
-        @Override
-        public void addToWorkList(Node node) {
-        }
-    };
 
     private LoopTransformations() {
         // does not need to be instantiated
@@ -88,23 +67,31 @@
         }
     }
 
-    public static void unswitch(LoopEx loop, IfNode ifNode) {
-        // duplicate will be true case, original will be false case
-        loop.loopBegin().incUnswitches();
+    public static void unswitch(LoopEx loop, ControlSplitNode controlSplitNode) {
         LoopFragmentWhole originalLoop = loop.whole();
-        LoopFragmentWhole duplicateLoop = originalLoop.duplicate();
-        StructuredGraph graph = (StructuredGraph) ifNode.graph();
-        BeginNode tempBegin = graph.add(new BeginNode());
-        originalLoop.entryPoint().replaceAtPredecessor(tempBegin);
-        double takenProbability = ifNode.probability(ifNode.trueSuccessor());
-        IfNode newIf = graph.add(new IfNode(ifNode.condition(), duplicateLoop.entryPoint(), originalLoop.entryPoint(), takenProbability, ifNode.leafGraphId()));
-        tempBegin.setNext(newIf);
-        ifNode.setCondition(graph.unique(ConstantNode.forBoolean(false, graph)));
-        IfNode duplicateIf = duplicateLoop.getDuplicatedNode(ifNode);
-        duplicateIf.setCondition(graph.unique(ConstantNode.forBoolean(true, graph)));
-        ifNode.simplify(simplifier);
-        duplicateIf.simplify(simplifier);
-        // TODO (gd) probabilities need some amount of fixup.. (probably also in other transforms)
+        //create new control split out of loop
+        ControlSplitNode newControlSplit = (ControlSplitNode) controlSplitNode.copyWithInputs();
+        originalLoop.entryPoint().replaceAtPredecessor(newControlSplit);
+
+        NodeClassIterator successors = controlSplitNode.successors().iterator();
+        assert successors.hasNext();
+        //original loop is used as first successor
+        Position firstPosition = successors.nextPosition();
+        NodeClass controlSplitClass = controlSplitNode.getNodeClass();
+        controlSplitClass.set(newControlSplit, firstPosition, BeginNode.begin(originalLoop.entryPoint()));
+
+        StructuredGraph graph = (StructuredGraph) controlSplitNode.graph();
+        while (successors.hasNext()) {
+            Position position = successors.nextPosition();
+            // create a new loop duplicate, connect it and simplify it
+            LoopFragmentWhole duplicateLoop = originalLoop.duplicate();
+            controlSplitClass.set(newControlSplit, position, BeginNode.begin(duplicateLoop.entryPoint()));
+            ControlSplitNode duplicatedControlSplit = duplicateLoop.getDuplicatedNode(controlSplitNode);
+            graph.removeSplitPropagate(duplicatedControlSplit, (BeginNode) controlSplitClass.get(duplicatedControlSplit, position));
+        }
+        // original loop is simplified last to avoid deleting controlSplitNode too early
+        graph.removeSplitPropagate(controlSplitNode, (BeginNode) controlSplitClass.get(controlSplitNode, firstPosition));
+        //TODO (gd) probabilities need some amount of fixup.. (probably also in other transforms)
     }
 
     public static void unroll(LoopEx loop, int factor) {
@@ -128,12 +115,17 @@
         }
     }
 
-    public static IfNode findUnswitchableIf(LoopEx loop) {
+    public static ControlSplitNode findUnswitchable(LoopEx loop) {
         for (IfNode ifNode : loop.whole().nodes().filter(IfNode.class)) {
             if (loop.isOutsideLoop(ifNode.condition())) {
                 return ifNode;
             }
         }
+        for (SwitchNode switchNode : loop.whole().nodes().filter(SwitchNode.class)) {
+            if (switchNode.successors().count() > 1 && loop.isOutsideLoop(switchNode.value())) {
+                return switchNode;
+            }
+        }
         return null;
     }
 }
--- a/graal/com.oracle.graal.loop/src/com/oracle/graal/loop/phases/LoopTransformLowPhase.java	Wed Dec 19 15:52:50 2012 +0100
+++ b/graal/com.oracle.graal.loop/src/com/oracle/graal/loop/phases/LoopTransformLowPhase.java	Thu Dec 20 12:06:58 2012 +0100
@@ -23,6 +23,7 @@
 package com.oracle.graal.loop.phases;
 
 import com.oracle.graal.debug.*;
+import com.oracle.graal.graph.NodeClass.NodeClassIterator;
 import com.oracle.graal.loop.*;
 import com.oracle.graal.nodes.*;
 import com.oracle.graal.phases.*;
@@ -51,10 +52,12 @@
                     final LoopsData dataUnswitch = new LoopsData(graph);
                     for (LoopEx loop : dataUnswitch.loops()) {
                         if (LoopPolicies.shouldTryUnswitch(loop)) {
-                            IfNode ifNode = LoopTransformations.findUnswitchableIf(loop);
-                            if (ifNode != null && LoopPolicies.shouldUnswitch(loop, ifNode)) {
-                                Debug.log("Unswitching %s at %s [%f - %f]", loop, ifNode, ifNode.probability(ifNode.trueSuccessor()), ifNode.probability(ifNode.falseSuccessor()));
-                                LoopTransformations.unswitch(loop, ifNode);
+                            ControlSplitNode controlSplit = LoopTransformations.findUnswitchable(loop);
+                            if (controlSplit != null && LoopPolicies.shouldUnswitch(loop, controlSplit)) {
+                                if (Debug.isLogEnabled()) {
+                                    logUnswitch(loop, controlSplit);
+                                }
+                                LoopTransformations.unswitch(loop, controlSplit);
                                 UNSWITCHED.increment();
                                 Debug.dump(graph, "After unswitch %s", loop);
                                 unswitched = true;
@@ -66,4 +69,18 @@
             }
         }
     }
+
+    private static void logUnswitch(LoopEx loop, ControlSplitNode controlSplit) {
+        StringBuilder sb = new StringBuilder("Unswitching ");
+        sb.append(loop).append(" at ").append(controlSplit).append(" [");
+        NodeClassIterator it = controlSplit.successors().iterator();
+        while (it.hasNext()) {
+            sb.append(controlSplit.probability((BeginNode) it.next()));
+            if (it.hasNext()) {
+                sb.append(", ");
+            }
+        }
+        sb.append("]");
+        Debug.log(sb.toString());
+    }
 }
--- a/graal/com.oracle.graal.nodes/src/com/oracle/graal/nodes/extended/SwitchNode.java	Wed Dec 19 15:52:50 2012 +0100
+++ b/graal/com.oracle.graal.nodes/src/com/oracle/graal/nodes/extended/SwitchNode.java	Thu Dec 20 12:06:58 2012 +0100
@@ -39,10 +39,6 @@
     private double[] keyProbabilities;
     private int[] keySuccessors;
 
-    public ValueNode value() {
-        return value;
-    }
-
     /**
      * Constructs a new Switch.
      * @param value the instruction that provides the value to be switched over
@@ -69,6 +65,10 @@
         return sum;
     }
 
+    public ValueNode value() {
+        return value;
+    }
+
     /**
      * The number of distinct keys in this switch.
      */