changeset 3210:27ae76ed33ca

Finish implementation of loop inversion
author Gilles Duboscq <gilles.duboscq@oracle.com>
date Tue, 12 Jul 2011 13:10:33 +0200
parents ccd6318f294d
children 76507b87dd25 1e5ca59c8769
files graal/com.oracle.max.graal.compiler/src/com/oracle/max/graal/compiler/GraalMetrics.java graal/com.oracle.max.graal.compiler/src/com/oracle/max/graal/compiler/phases/LoopPhase.java graal/com.oracle.max.graal.compiler/src/com/oracle/max/graal/compiler/util/LoopUtil.java
diffstat 3 files changed, 178 insertions(+), 130 deletions(-) [+]
line wrap: on
line diff
--- a/graal/com.oracle.max.graal.compiler/src/com/oracle/max/graal/compiler/GraalMetrics.java	Tue Jul 12 13:10:11 2011 +0200
+++ b/graal/com.oracle.max.graal.compiler/src/com/oracle/max/graal/compiler/GraalMetrics.java	Tue Jul 12 13:10:33 2011 +0200
@@ -63,6 +63,7 @@
     public static int FrameStateValuesCreated;
     public static int NodesCanonicalized;
     public static int LoopsPeeled;
+    public static int LoopsInverted;
 
     public static void print() {
         for (Entry<String, GraalMetrics> m : map.entrySet()) {
--- a/graal/com.oracle.max.graal.compiler/src/com/oracle/max/graal/compiler/phases/LoopPhase.java	Tue Jul 12 13:10:11 2011 +0200
+++ b/graal/com.oracle.max.graal.compiler/src/com/oracle/max/graal/compiler/phases/LoopPhase.java	Tue Jul 12 13:10:33 2011 +0200
@@ -61,7 +61,18 @@
                         continue peeling;
                     }
                 }
-                LoopUtil.peelLoop(loop);
+                boolean canInvert = false;
+                if (loop.loopBegin().next() instanceof If) {
+                    If ifNode = (If) loop.loopBegin().next();
+                    if (loop.exits().isMarked(ifNode.trueSuccessor()) || loop.exits().isMarked(ifNode.falseSuccessor())) {
+                        canInvert = true;
+                    }
+                }
+                if (canInvert) {
+                    LoopUtil.inverseLoop(loop, (If) loop.loopBegin().next());
+                } else {
+                    LoopUtil.peelLoop(loop);
+                }
             }
         } else {
 //            loops = LoopUtil.computeLoops(graph); // TODO (gd) avoid recomputing loops
@@ -161,7 +172,7 @@
                             useCounterAfterAdd = true;
                         }
                     }
-                    if (stride != null && !loopNodes.isNew(stride) &&  !loopNodes.isMarked(stride)) {
+                    if (stride != null && loopNodes.isNotNewNotMarked(stride)) { // same test as the replaced isNew/isMarked pair: stride is an existing node not marked in loopNodes
                         Graph graph = loopBegin.graph();
                         LoopCounter counter = new LoopCounter(init.kind, init, stride, loopBegin, graph);
                         counters.add(counter);
--- a/graal/com.oracle.max.graal.compiler/src/com/oracle/max/graal/compiler/util/LoopUtil.java	Tue Jul 12 13:10:11 2011 +0200
+++ b/graal/com.oracle.max.graal.compiler/src/com/oracle/max/graal/compiler/util/LoopUtil.java	Tue Jul 12 13:10:33 2011 +0200
@@ -120,7 +120,8 @@
         public final NodeMap<Node> phiInits;
         public final NodeMap<Node> dataOut;
         public final NodeBitMap exitFrameStates;
-        public PeelingResult(FixedNode begin, FixedNode end, NodeMap<StateSplit> exits, NodeMap<Placeholder> phis, NodeMap<Node> phiInits, NodeMap<Node> dataOut, NodeBitMap unaffectedExits, NodeBitMap exitFrameStates) {
+        public final NodeBitMap peeledNodes;
+        public PeelingResult(FixedNode begin, FixedNode end, NodeMap<StateSplit> exits, NodeMap<Placeholder> phis, NodeMap<Node> phiInits, NodeMap<Node> dataOut, NodeBitMap unaffectedExits, NodeBitMap exitFrameStates, NodeBitMap peeledNodes) {
             this.begin = begin;
             this.end = end;
             this.exits = exits;
@@ -129,6 +130,7 @@
             this.dataOut = dataOut;
             this.unaffectedExits = unaffectedExits;
             this.exitFrameStates = exitFrameStates;
+            this.peeledNodes = peeledNodes;
         }
     }
 
@@ -167,82 +169,13 @@
         return exits;
     }
 
-    public static NodeBitMap computeLoopNodes(LoopBegin loopBegin) {
-        return computeLoopNodesFrom(loopBegin, loopBegin.loopEnd());
-    }
     private static boolean recurse = false;
-    public static NodeBitMap computeLoopNodesFrom(LoopBegin loopBegin, FixedNode from) {
-        NodeFlood workData1 = loopBegin.graph().createNodeFlood();
-        NodeFlood workData2 = loopBegin.graph().createNodeFlood();
+    public static NodeBitMap computeLoopNodesFrom(Loop loop, FixedNode from) {
+        LoopBegin loopBegin = loop.loopBegin();
         NodeBitMap loopNodes = markUpCFG(loopBegin, from);
         loopNodes.mark(loopBegin);
-        for (Node n : loopNodes) {
-            workData1.add(n);
-            workData2.add(n);
-        }
-        NodeBitMap inOrAfter = loopBegin.graph().createNodeBitMap();
-        for (Node n : workData1) {
-            markWithState(n, inOrAfter);
-            for (Node usage : n.dataUsages()) {
-                if (usage instanceof Phi) { // filter out data graph cycles
-                    Phi phi = (Phi) usage;
-                    if (phi.type() == PhiType.Value) {
-                        Merge merge = phi.merge();
-                        if (merge instanceof LoopBegin) {
-                            LoopBegin phiLoop = (LoopBegin) merge;
-                            int backIndex = phiLoop.phiPredecessorIndex(phiLoop.loopEnd());
-                            if (phi.valueAt(backIndex) == n) {
-                                continue;
-                            }
-                        }
-                    }
-                }
-                workData1.add(usage);
-            }
-        }
-        NodeBitMap inOrBefore = loopBegin.graph().createNodeBitMap();
-        for (Node n : workData2) {
-            //markWithState(n, inOrBefore);
-            inOrBefore.mark(n);
-            if (n instanceof Phi) { // filter out data graph cycles
-                Phi phi = (Phi) n;
-                if (phi.type() == PhiType.Value) {
-                    int backIndex = -1;
-                    Merge merge = phi.merge();
-                    if (!loopNodes.isMarked(merge) && merge instanceof LoopBegin) {
-                        LoopBegin phiLoop = (LoopBegin) merge;
-                        backIndex = phiLoop.phiPredecessorIndex(phiLoop.loopEnd());
-                    }
-                    for (int i = 0; i < phi.valueCount(); i++) {
-                        if (i != backIndex) {
-                            workData2.add(phi.valueAt(i));
-                        }
-                    }
-                }
-            } else {
-                for (Node input : n.dataInputs()) {
-                    workData2.add(input);
-                }
-            }
-            if (n instanceof Merge) { //add phis & counters
-                for (Node usage : n.dataUsages()) {
-                    if (!(usage instanceof LoopEnd)) {
-                        workData2.add(usage);
-                    }
-                }
-            }
-            if (n instanceof AbstractVectorNode) {
-                for (Node usage : n.usages()) {
-                    workData2.add(usage);
-                }
-            }
-            if (n instanceof StateSplit) {
-                FrameState stateAfter = ((StateSplit) n).stateAfter();
-                if (stateAfter != null) {
-                    workData2.add(stateAfter);
-                }
-            }
-        }
+        NodeBitMap inOrAfter = inOrAfter(loop, loopNodes, false);
+        NodeBitMap inOrBefore = inOrBefore(loop, inOrAfter, loopNodes, false);
         /*if (!recurse) {
             recurse = true;
             GraalCompilation compilation = GraalCompilation.compilation();
@@ -257,6 +190,9 @@
         }*/
         inOrAfter.setIntersect(inOrBefore);
         loopNodes.setUnion(inOrAfter);
+        if (from == loopBegin.loopEnd()) { // fill the Loop cache value for loop nodes this is correct even if after/before were partial
+            loop.nodes = loopNodes;
+        }
         return loopNodes;
     }
 
@@ -292,8 +228,64 @@
         assert loop.cfgNodes().isMarked(noExit);
 
         PeelingResult peeling = preparePeeling(loop, split);
-        rewirePeeling(peeling, loop, split);
-        // TODO (gd) move peeled part to the end, rewire dataOut
+        rewirePeeling(peeling, loop, split, true);
+
+        // move peeled part to the end
+        LoopBegin loopBegin = loop.loopBegin();
+        LoopEnd loopEnd = loopBegin.loopEnd();
+        FixedNode lastNode = (FixedNode) loopEnd.singlePredecessor();
+        Graph graph = loopBegin.graph();
+        if (loopBegin.next() != lastNode) {
+            lastNode.successors().replace(loopEnd, loopBegin.next());
+            loopBegin.setNext(noExit);
+            split.successors().replace(noExit, loopEnd);
+        }
+
+        // rewire dataOut
+        NodeBitMap exitMergesPhis = graph.createNodeBitMap();
+        for (Entry<Node, StateSplit> entry : peeling.exits.entries()) {
+            StateSplit newExit = entry.getValue();
+            Merge merge = ((EndNode) newExit.next()).merge();
+            exitMergesPhis.markAll(merge.phis());
+        }
+        for (Entry<Node, Node> entry : peeling.dataOut.entries()) {
+            Value originalValue = (Value) entry.getKey();
+            if (originalValue instanceof Phi && ((Phi) originalValue).merge() == loopBegin) {
+                continue;
+            }
+            Value newValue = (Value) entry.getValue();
+            Phi phi = null;
+            List<Node> usages = new ArrayList<Node>(originalValue.usages());
+            for (Node usage : usages) {
+                if (exitMergesPhis.isMarked(usage) || (
+                                loop.nodes().isNotNewMarked(usage)
+                                && peeling.peeledNodes.isNotNewNotMarked(usage)
+                                && !(usage instanceof Phi && ((Phi) usage).merge() == loopBegin))
+                                && !(usage instanceof FrameState && ((FrameState) usage).block() == loopBegin)) {
+                    if (phi == null) {
+                        phi = new Phi(originalValue.kind, loopBegin, PhiType.Value, graph);
+                        phi.addInput(newValue);
+                        phi.addInput(originalValue);
+                    }
+                    usage.inputs().replace(originalValue, phi);
+                }
+            }
+        }
+        //rewire phi usage in peeled part
+        int backIndex = loopBegin.phiPredecessorIndex(loopBegin.loopEnd());
+        for (Phi phi : loopBegin.phis()) {
+            Value backValue = phi.valueAt(backIndex);
+            if (loop.nodes().isMarked(backValue) && peeling.peeledNodes.isNotNewNotMarked(backValue)) {
+                List<Node> usages = new ArrayList<Node>(phi.usages());
+                for (Node usage : usages) {
+                    if (peeling.peeledNodes.isNotNewMarked(usage)) {
+                        usage.inputs().replace(phi, backValue);
+                    }
+                }
+            }
+        }
+
+        GraalMetrics.LoopsInverted++;
     }
 
     public static void peelLoop(Loop loop) {
@@ -322,14 +314,14 @@
         for (Entry<Node, Node> entry : peeling.dataOut.entries()) {
             System.out.println("  - " + entry.getKey() + " -> " + entry.getValue());
         }*/
-        rewirePeeling(peeling, loop, loopEnd);
+        rewirePeeling(peeling, loop, loopEnd, false);
         /*if (compilation.compiler.isObserved()) {
             compilation.compiler.fireCompilationEvent(new CompilationEvent(compilation, "After rewirePeeling", loopEnd.graph(), true, false));
         }*/
         // update parents
         Loop parent = loop.parent();
         while (parent != null) {
-            parent.cfgNodes = computeLoopNodes(parent.loopBegin);
+            parent.cfgNodes = markUpCFG(parent.loopBegin, parent.loopBegin.loopEnd());
             parent.invalidateCached();
             parent.exits = computeLoopExits(parent.loopBegin, parent.cfgNodes);
             parent = parent.parent;
@@ -337,7 +329,7 @@
         GraalMetrics.LoopsPeeled++;
     }
 
-    private static void rewirePeeling(PeelingResult peeling, Loop loop, FixedNode from) {
+    private static void rewirePeeling(PeelingResult peeling, Loop loop, FixedNode from, boolean inversion) {
         LoopBegin loopBegin = loop.loopBegin();
         Graph graph = loopBegin.graph();
         Node loopPred = loopBegin.singlePredecessor();
@@ -407,6 +399,7 @@
             exitPoints.add(newExit);
         }
 
+        int phiBackIndex = loopBegin.phiPredecessorIndex(loopBegin.loopEnd());
         for (Entry<Node, StateSplit> entry : peeling.exits.entries()) {
             StateSplit original = (StateSplit) entry.getKey();
             EndNode oEnd = (EndNode) original.next();
@@ -421,7 +414,18 @@
                     phiMap = graph.createNodeMap();
                     newExitValues.set(originalValue, phiMap);
                 }
-                phiMap.set(original, (Value) originalValue);
+                Value backValue = null;
+                if (inversion && originalValue instanceof Phi && ((Phi) originalValue).merge() == loopBegin) {
+                    backValue = ((Phi) originalValue).valueAt(phiBackIndex);
+                    if (!loop.nodes().isMarked(backValue) || peeling.peeledNodes.isNotNewMarked(backValue)) {
+                        backValue = null;
+                    }
+                }
+                if (backValue != null) {
+                    phiMap.set(original, backValue);
+                } else {
+                    phiMap.set(original, (Value) originalValue);
+                }
                 phiMap.set(newExit, (Value) newValue);
             }
         }
@@ -551,8 +555,8 @@
                 System.out.println(" - !inOrBefore : " + (!inOrBefore.isNew(n) && !inOrBefore.isMarked(n)));
                 System.out.println(" - inputs > 0 : " + (n.inputs().size() > 0));
                 System.out.println(" - !danglingMergeFrameState : " + (!danglingMergeFrameState(n)));*/
-                return (!exitFrameStates.isNew(n) && exitFrameStates.isMarked(n))
-                || (!inOrBefore.isNew(n) && !inOrBefore.isMarked(n) && n.inputs().size() > 0 && !afterColoringFramestate(n)); //TODO (gd) hum
+                return exitFrameStates.isNotNewMarked(n)
+                || (inOrBefore.isNotNewNotMarked(n) && n.inputs().size() > 0 && !afterColoringFramestate(n)); //TODO (gd) hum
             }
             public boolean afterColoringFramestate(Node n) {
                 if (!(n instanceof FrameState)) {
@@ -634,7 +638,7 @@
     private static PeelingResult preparePeeling(Loop loop, FixedNode from) {
         LoopBegin loopBegin = loop.loopBegin();
         Graph graph = loopBegin.graph();
-        NodeBitMap marked = computeLoopNodesFrom(loopBegin, from);
+        NodeBitMap marked = computeLoopNodesFrom(loop, from);
         GraalCompilation compilation = GraalCompilation.compilation();
         /*if (compilation.compiler.isObserved()) {
             Map<String, Object> debug = new HashMap<String, Object>();
@@ -653,7 +657,7 @@
         NodeBitMap exitFrameStates = graph.createNodeBitMap();
         for (Node exit : loop.exits()) {
             if (marked.isMarked(exit.singlePredecessor())) {
-                StateSplit pExit = findNearestMergableExitPoint(exit, marked);
+                StateSplit pExit = findNearestMergableExitPoint((FixedNode) exit, marked);
                 markWithState(pExit, marked);
                 clonedExits.mark(pExit);
                 FrameState stateAfter = pExit.stateAfter();
@@ -722,44 +726,61 @@
         }
 
         NodeMap<Node> phiInits = graph.createNodeMap();
-        if (from == loopBegin.loopEnd()) {
-            int backIndex = loopBegin.phiPredecessorIndex(loopBegin.loopEnd());
-            int fowardIndex = loopBegin.phiPredecessorIndex(loopBegin.forwardEdge());
-            for (Phi phi : loopBegin.phis()) {
-                Value backValue = phi.valueAt(backIndex);
-                if (marked.isMarked(backValue)) {
-                    phiInits.set(phi, duplicates.get(backValue));
-                } else if (backValue instanceof Phi && ((Phi) backValue).merge() == loopBegin) {
-                    Phi backPhi = (Phi) backValue;
-                    phiInits.set(phi, backPhi.valueAt(fowardIndex));
-                } else {
-                    phiInits.set(phi, backValue);
-                }
+        int backIndex = loopBegin.phiPredecessorIndex(loopBegin.loopEnd());
+        int fowardIndex = loopBegin.phiPredecessorIndex(loopBegin.forwardEdge());
+        for (Phi phi : loopBegin.phis()) {
+            Value backValue = phi.valueAt(backIndex);
+            if (marked.isMarked(backValue)) {
+                phiInits.set(phi, duplicates.get(backValue));
+            } else if (from == loopBegin.loopEnd() && backValue instanceof Phi && ((Phi) backValue).merge() == loopBegin) {
+                Phi backPhi = (Phi) backValue;
+                phiInits.set(phi, backPhi.valueAt(fowardIndex));
             }
         }
 
         FixedNode newBegin = (FixedNode) duplicates.get(loopBegin.next());
         FixedNode newFrom = (FixedNode) duplicates.get(from == loopBegin.loopEnd() ? from.singlePredecessor() : from);
-        return new PeelingResult(newBegin, newFrom, exits, phis, phiInits, dataOutMapping, unaffectedExits, exitFrameStates);
+        return new PeelingResult(newBegin, newFrom, exits, phis, phiInits, dataOutMapping, unaffectedExits, exitFrameStates, marked);
     }
 
-    private static StateSplit findNearestMergableExitPoint(Node exit, NodeBitMap marked) {
-        // TODO (gd) find appropriate point : will be useful if a loop exit goes "up" as a result of making a branch dead in the loop body
-        return (StateSplit) exit;
+    /**
+     * Returns the point at which a loop exit can be merged. Only the exit node itself is
+     * considered for now: it qualifies when it is a {@link StateSplit} whose
+     * {@code stateAfter()} is not null. Fails fast instead of looping forever otherwise.
+     * TODO (gd) find appropriate point : will be useful if a loop exit goes "up" as a result of making a branch dead in the loop body
+     *
+     * @param exit the loop exit node to examine
+     * @param marked node bitmap supplied by the caller (not consulted yet)
+     * @throws IllegalStateException if {@code exit} is not a state split carrying a frame state
+     */
+    private static StateSplit findNearestMergableExitPoint(FixedNode exit, NodeBitMap marked) {
+        if (exit instanceof StateSplit && ((StateSplit) exit).stateAfter() != null) {
+            return (StateSplit) exit;
+        }
+        throw new IllegalStateException("No mergable exit point found for loop exit " + exit);
+    }
 
     private static NodeBitMap inOrAfter(Loop loop) {
+        return inOrAfter(loop, loop.cfgNodes());
+    }
+
+    private static NodeBitMap inOrAfter(Loop loop, NodeBitMap cfgNodes) {
+        return inOrAfter(loop, cfgNodes, true);
+    }
+
+    private static NodeBitMap inOrAfter(Loop loop, NodeBitMap cfgNodes, boolean full) {
         Graph graph = loop.loopBegin().graph();
         NodeBitMap inOrAfter = graph.createNodeBitMap();
         NodeFlood work = graph.createNodeFlood();
-        NodeBitMap loopNodes = loop.cfgNodes();
-        work.addAll(loopNodes);
+        work.addAll(cfgNodes);
         for (Node n : work) {
             //inOrAfter.mark(n);
             markWithState(n, inOrAfter);
-            for (Node sux : n.successors()) {
-                if (sux != null) {
-                    work.add(sux);
+            if (full) {
+                for (Node sux : n.successors()) {
+                    if (sux != null) {
+                        work.add(sux);
+                    }
                 }
             }
             for (Node usage : n.usages()) {
@@ -780,23 +801,36 @@
         return inOrAfter;
     }
 
+    private static NodeBitMap inOrBefore(Loop loop) {
+        return inOrBefore(loop, inOrAfter(loop));
+    }
+
     private static NodeBitMap inOrBefore(Loop loop, NodeBitMap inOrAfter) {
+        return inOrBefore(loop, inOrAfter, loop.cfgNodes());
+    }
+
+    private static NodeBitMap inOrBefore(Loop loop, NodeBitMap inOrAfter, NodeBitMap cfgNodes) {
+        return inOrBefore(loop, inOrAfter, cfgNodes, true);
+    }
+
+    private static NodeBitMap inOrBefore(Loop loop, NodeBitMap inOrAfter, NodeBitMap cfgNodes, boolean full) {
         Graph graph = loop.loopBegin().graph();
         NodeBitMap inOrBefore = graph.createNodeBitMap();
         NodeFlood work = graph.createNodeFlood();
-        NodeBitMap loopNodes = loop.cfgNodes();
-        work.addAll(loopNodes);
+        work.addAll(cfgNodes);
         for (Node n : work) {
             inOrBefore.mark(n);
-            for (Node pred : n.predecessors()) {
-                work.add(pred);
+            if (full) {
+                for (Node pred : n.predecessors()) {
+                    work.add(pred);
+                }
             }
             if (n instanceof Phi) { // filter out data graph cycles
                 Phi phi = (Phi) n;
                 if (phi.type() == PhiType.Value) {
                     int backIndex = -1;
                     Merge merge = phi.merge();
-                    if (!loopNodes.isNew(merge) && !loopNodes.isMarked(merge) && merge instanceof LoopBegin) {
+                    if (merge instanceof LoopBegin && cfgNodes.isNotNewNotMarked(((LoopBegin) merge).loopEnd())) {
                         LoopBegin phiLoop = (LoopBegin) merge;
                         backIndex = phiLoop.phiPredecessorIndex(phiLoop.loopEnd());
                     }
@@ -812,9 +846,25 @@
                         work.add(in);
                     }
                 }
-                for (Node sux : n.cfgSuccessors()) { // go down into branches that are not 'inOfAfter'
-                    if (sux != null && !inOrAfter.isMarked(sux)) {
-                        work.add(sux);
+                if (full) {
+                    for (Node sux : n.cfgSuccessors()) { // go down into branches that are not 'inOfAfter'
+                        if (sux != null && !inOrAfter.isMarked(sux)) {
+                            work.add(sux);
+                        }
+                    }
+                    if (n instanceof LoopBegin && n != loop.loopBegin()) {
+                        Loop p = loop.parent;
+                        boolean isParent = false;
+                        while (p != null) {
+                            if (p.loopBegin() == n) {
+                                isParent = true;
+                                break;
+                            }
+                            p = p.parent;
+                        }
+                        if (!isParent) {
+                            work.add(((LoopBegin) n).loopEnd());
+                        }
                     }
                 }
                 if (n instanceof Merge) { //add phis & counters
@@ -829,20 +879,6 @@
                         work.add(usage);
                     }
                 }
-                if (n instanceof LoopBegin && n != loop.loopBegin()) {
-                    Loop p = loop.parent;
-                    boolean isParent = false;
-                    while (p != null) {
-                        if (p.loopBegin() == n) {
-                            isParent = true;
-                            break;
-                        }
-                        p = p.parent;
-                    }
-                    if (!isParent) {
-                        work.add(((LoopBegin) n).loopEnd());
-                    }
-                }
             }
         }
         return inOrBefore;