changeset 23275:a886e9dc3a47

GraphPE: when exploding nested loops, keep exploding exits of inner loops separately so that they merge separately with the outermost loop; fix bugs in loop detection and make test cases more challenging
author Christian Wimmer <christian.wimmer@oracle.com>
date Fri, 08 Jan 2016 13:12:43 -0800
parents 50079bd51344
children d5a1109b239b 4ba737504681
files graal/com.oracle.graal.nodes/src/com/oracle/graal/nodes/GraphDecoder.java graal/com.oracle.graal.truffle.test/src/com/oracle/graal/truffle/test/BytecodeInterpreterPartialEvaluationTest.java
diffstat 2 files changed, 97 insertions(+), 117 deletions(-) [+]
line wrap: on
line diff
--- a/graal/com.oracle.graal.nodes/src/com/oracle/graal/nodes/GraphDecoder.java	Fri Jan 08 13:10:52 2016 -0800
+++ b/graal/com.oracle.graal.nodes/src/com/oracle/graal/nodes/GraphDecoder.java	Fri Jan 08 13:12:43 2016 -0800
@@ -30,10 +30,11 @@
 import java.util.BitSet;
 import java.util.Deque;
 import java.util.HashMap;
-import java.util.IdentityHashMap;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 import jdk.vm.ci.code.Architecture;
 import jdk.vm.ci.meta.ResolvedJavaType;
@@ -168,7 +169,7 @@
             this.createdNodes = new Node[nodeCount];
         }
 
-        protected LoopScope(LoopScope outer, int loopDepth, int loopIteration, int loopBeginOrderId, Node[] initialCreatedNodes, Deque<LoopScope> nextIterations,
+        protected LoopScope(LoopScope outer, int loopDepth, int loopIteration, int loopBeginOrderId, Node[] initialCreatedNodes, Node[] createdNodes, Deque<LoopScope> nextIterations,
                         Map<LoopExplosionState, LoopExplosionState> iterationStates) {
             this.outer = outer;
             this.loopDepth = loopDepth;
@@ -178,7 +179,7 @@
             this.loopBeginOrderId = loopBeginOrderId;
             this.nodesToProcess = new BitSet(initialCreatedNodes.length);
             this.initialCreatedNodes = initialCreatedNodes;
-            this.createdNodes = Arrays.copyOf(initialCreatedNodes, initialCreatedNodes.length);
+            this.createdNodes = Arrays.copyOf(createdNodes, createdNodes.length);
         }
 
         @Override
@@ -365,7 +366,22 @@
         LoopScope successorAddScope = loopScope;
         boolean updatePredecessors = true;
         if (node instanceof LoopExitNode) {
-            successorAddScope = loopScope.outer;
+            if (methodScope.loopExplosion == LoopExplosionKind.MERGE_EXPLODE && loopScope.loopDepth > 1) {
+                /*
+                 * We do not want to merge loop exits of inner loops. Instead, we want to keep
+                 * exploding the outer loop separately for every loop exit and then merge the outer
+                 * loop. Therefore, we create a new LoopScope of the outer loop for every loop exit
+                 * of the inner loop.
+                 */
+                LoopScope outerScope = loopScope.outer;
+                int nextIterationNumber = outerScope.nextIterations.isEmpty() ? outerScope.loopIteration + 1 : outerScope.nextIterations.getLast().loopIteration + 1;
+                successorAddScope = new LoopScope(outerScope.outer, outerScope.loopDepth, nextIterationNumber, outerScope.loopBeginOrderId, outerScope.initialCreatedNodes,
+                                loopScope.initialCreatedNodes, outerScope.nextIterations, outerScope.iterationStates);
+                checkLoopExplosionIteration(methodScope, successorAddScope);
+                outerScope.nextIterations.addLast(successorAddScope);
+            } else {
+                successorAddScope = loopScope.outer;
+            }
             updatePredecessors = methodScope.loopExplosion == LoopExplosionKind.NONE;
         }
 
@@ -384,7 +400,7 @@
 
         } else if (node instanceof LoopExitNode) {
             if (methodScope.loopExplosion != LoopExplosionKind.NONE) {
-                handleLoopExplosionProxyNodes(methodScope, loopScope, (LoopExitNode) node, nodeOrderId);
+                handleLoopExplosionProxyNodes(methodScope, loopScope, successorAddScope, (LoopExitNode) node, nodeOrderId);
             } else {
                 handleProxyNodes(methodScope, loopScope, (LoopExitNode) node);
             }
@@ -405,7 +421,7 @@
 
                 if (merge instanceof LoopBeginNode) {
                     assert phiNodeScope == phiInputScope && phiNodeScope == loopScope;
-                    resultScope = new LoopScope(loopScope, loopScope.loopDepth + 1, 0, mergeOrderId, Arrays.copyOf(loopScope.createdNodes, loopScope.createdNodes.length), //
+                    resultScope = new LoopScope(loopScope, loopScope.loopDepth + 1, 0, mergeOrderId, Arrays.copyOf(loopScope.createdNodes, loopScope.createdNodes.length), loopScope.createdNodes, //
                                     methodScope.loopExplosion != LoopExplosionKind.NONE ? new ArrayDeque<>() : null, //
                                     methodScope.loopExplosion == LoopExplosionKind.MERGE_EXPLODE ? new HashMap<>() : null);
                     phiInputScope = resultScope;
@@ -530,7 +546,7 @@
         if (methodScope.loopExplosion != LoopExplosionKind.FULL_UNROLL || loopScope.nextIterations.isEmpty()) {
             int nextIterationNumber = loopScope.nextIterations.isEmpty() ? loopScope.loopIteration + 1 : loopScope.nextIterations.getLast().loopIteration + 1;
             LoopScope nextIterationScope = new LoopScope(loopScope.outer, loopScope.loopDepth, nextIterationNumber, loopScope.loopBeginOrderId, loopScope.initialCreatedNodes,
-                            loopScope.nextIterations, loopScope.iterationStates);
+                            loopScope.initialCreatedNodes, loopScope.nextIterations, loopScope.iterationStates);
             checkLoopExplosionIteration(methodScope, nextIterationScope);
             loopScope.nextIterations.addLast(nextIterationScope);
             registerNode(nextIterationScope, loopScope.loopBeginOrderId, null, true, true);
@@ -567,7 +583,7 @@
         }
     }
 
-    protected void handleLoopExplosionProxyNodes(MethodScope methodScope, LoopScope loopScope, LoopExitNode loopExit, int loopExitOrderId) {
+    protected void handleLoopExplosionProxyNodes(MethodScope methodScope, LoopScope loopScope, LoopScope outerScope, LoopExitNode loopExit, int loopExitOrderId) {
         assert loopExit.stateAfter() == null;
         int stateAfterOrderId = readOrderId(methodScope);
 
@@ -582,16 +598,16 @@
          * handle the merge.
          */
         MergeNode merge = null;
-        Node existingExit = lookupNode(loopScope.outer, loopExitOrderId);
+        Node existingExit = lookupNode(outerScope, loopExitOrderId);
         if (existingExit == null) {
             /* First loop iteration that exits. No merge necessary yet. */
-            registerNode(loopScope.outer, loopExitOrderId, begin, false, false);
+            registerNode(outerScope, loopExitOrderId, begin, false, false);
             begin.setNext(loopExitSuccessor);
 
         } else if (existingExit instanceof BeginNode) {
             /* Second loop iteration that exits. Create the merge. */
             merge = methodScope.graph.add(new MergeNode());
-            registerNode(loopScope.outer, loopExitOrderId, merge, true, false);
+            registerNode(outerScope, loopExitOrderId, merge, true, false);
             /* Add the first iteration. */
             EndNode firstEnd = methodScope.graph.add(new EndNode());
             ((BeginNode) existingExit).setNext(firstEnd);
@@ -622,13 +638,13 @@
             ValueNode phiInput = proxy.value();
             ValueNode replacement;
 
-            ValueNode existing = (ValueNode) loopScope.outer.createdNodes[proxyOrderId];
+            ValueNode existing = (ValueNode) outerScope.createdNodes[proxyOrderId];
             if (existing == null || existing == phiInput) {
                 /*
                  * We are at the first loop exit, or the proxy carries the same value for all exits.
                  * We do not need a phi node yet.
                  */
-                registerNode(loopScope.outer, proxyOrderId, phiInput, true, false);
+                registerNode(outerScope, proxyOrderId, phiInput, true, false);
                 replacement = phiInput;
 
             } else if (!merge.isPhiAtMerge(existing)) {
@@ -640,7 +656,7 @@
                 }
                 /* Add the input from this exit. */
                 phi.addInput(phiInput);
-                registerNode(loopScope.outer, proxyOrderId, phi, true, false);
+                registerNode(outerScope, proxyOrderId, phi, true, false);
                 replacement = phi;
                 phiCreated = true;
 
@@ -656,8 +672,8 @@
 
         if (merge != null && (merge.stateAfter() == null || phiCreated)) {
             FrameState oldStateAfter = merge.stateAfter();
-            registerNode(loopScope.outer, stateAfterOrderId, null, true, true);
-            merge.setStateAfter((FrameState) ensureNodeCreated(methodScope, loopScope.outer, stateAfterOrderId));
+            registerNode(outerScope, stateAfterOrderId, null, true, true);
+            merge.setStateAfter((FrameState) ensureNodeCreated(methodScope, outerScope, stateAfterOrderId));
             if (oldStateAfter != null) {
                 oldStateAfter.safeDelete();
             }
@@ -1015,11 +1031,13 @@
      */
 
     protected void detectLoops(StructuredGraph currentGraph, FixedNode startInstruction) {
+        Debug.dump(currentGraph, "Before detectLoops");
         NodeBitMap visited = currentGraph.createNodeBitMap();
         NodeBitMap active = currentGraph.createNodeBitMap();
         Deque<Node> stack = new ArrayDeque<>();
         stack.add(startInstruction);
         visited.mark(startInstruction);
+        Set<LoopBeginNode> newLoopBegins = new HashSet<>();
         while (!stack.isEmpty()) {
             Node next = stack.peek();
             assert next.isDeleted() || visited.isMarked(next);
@@ -1046,6 +1064,7 @@
                             LoopBeginNode newLoopBegin = appendLoopBegin(currentGraph, merge);
                             newLoopBegin.setNext(afterMerge);
                             newLoopBegin.setStateAfter(stateAfter);
+                            newLoopBegins.add(newLoopBegin);
                         }
                         LoopBeginNode loopBegin = (LoopBeginNode) ((EndNode) merge.next()).merge();
                         LoopEndNode loopEnd = currentGraph.add(new LoopEndNode(loopBegin));
@@ -1060,7 +1079,9 @@
             }
         }
 
-        insertLoopEnds(currentGraph, startInstruction);
+        Debug.dump(currentGraph, "Before insertLoopEnds");
+        insertLoopEnds(currentGraph, startInstruction, newLoopBegins);
+        Debug.dump(currentGraph, "After detectLoops");
     }
 
     private static LoopBeginNode appendLoopBegin(StructuredGraph currentGraph, FixedWithNextNode fixedWithNext) {
@@ -1072,7 +1093,7 @@
         return loopBegin;
     }
 
-    private static void insertLoopEnds(StructuredGraph currentGraph, FixedNode startInstruction) {
+    private static void insertLoopEnds(StructuredGraph currentGraph, FixedNode startInstruction, Set<LoopBeginNode> newLoopBegins) {
         NodeBitMap visited = currentGraph.createNodeBitMap();
         Deque<Node> stack = new ArrayDeque<>();
         stack.add(startInstruction);
@@ -1081,7 +1102,7 @@
         while (!stack.isEmpty()) {
             Node next = stack.pop();
             assert visited.isMarked(next);
-            if (next instanceof LoopBeginNode) {
+            if (next instanceof LoopBeginNode && newLoopBegins.contains(next)) {
                 loopBegins.add((LoopBeginNode) next);
             }
             for (Node n : next.cfgSuccessors()) {
@@ -1094,10 +1115,9 @@
             }
         }
 
-        IdentityHashMap<LoopBeginNode, List<LoopBeginNode>> innerLoopsMap = new IdentityHashMap<>();
         for (int i = loopBegins.size() - 1; i >= 0; --i) {
             LoopBeginNode loopBegin = loopBegins.get(i);
-            insertLoopExits(currentGraph, loopBegin, innerLoopsMap);
+            insertLoopExits(currentGraph, loopBegin);
         }
 
         // Remove degenerated merges with only one predecessor.
@@ -1109,7 +1129,7 @@
         }
     }
 
-    private static void insertLoopExits(StructuredGraph currentGraph, LoopBeginNode loopBegin, IdentityHashMap<LoopBeginNode, List<LoopBeginNode>> innerLoopsMap) {
+    private static void insertLoopExits(StructuredGraph currentGraph, LoopBeginNode loopBegin) {
         NodeBitMap visited = currentGraph.createNodeBitMap();
         Deque<Node> stack = new ArrayDeque<>();
         for (LoopEndNode loopEnd : loopBegin.loopEnds()) {
@@ -1118,7 +1138,6 @@
         }
 
         List<ControlSplitNode> controlSplits = new ArrayList<>();
-        List<LoopBeginNode> innerLoopBegins = new ArrayList<>();
 
         while (!stack.isEmpty()) {
             Node current = stack.pop();
@@ -1135,7 +1154,6 @@
                         if (!visited.isMarked(innerLoopBegin)) {
                             stack.push(innerLoopBegin);
                             visited.mark(innerLoopBegin);
-                            innerLoopBegins.add(innerLoopBegin);
                         }
                     } else {
                         if (pred instanceof ControlSplitNode) {
@@ -1158,27 +1176,6 @@
                 }
             }
         }
-
-        for (LoopBeginNode inner : innerLoopBegins) {
-            addLoopExits(currentGraph, loopBegin, inner, innerLoopsMap, visited);
-        }
-
-        innerLoopsMap.put(loopBegin, innerLoopBegins);
-    }
-
-    private static void addLoopExits(StructuredGraph currentGraph, LoopBeginNode loopBegin, LoopBeginNode inner, IdentityHashMap<LoopBeginNode, List<LoopBeginNode>> innerLoopsMap, NodeBitMap visited) {
-        for (LoopExitNode exit : inner.loopExits()) {
-            if (!visited.isMarked(exit)) {
-                LoopExitNode newLoopExit = currentGraph.add(new LoopExitNode(loopBegin));
-                FixedNode next = exit.next();
-                next.replaceAtPredecessor(newLoopExit);
-                newLoopExit.setNext(next);
-            }
-        }
-
-        for (LoopBeginNode innerInner : innerLoopsMap.get(inner)) {
-            addLoopExits(currentGraph, loopBegin, innerInner, innerLoopsMap, visited);
-        }
     }
 
     /**
--- a/graal/com.oracle.graal.truffle.test/src/com/oracle/graal/truffle/test/BytecodeInterpreterPartialEvaluationTest.java	Fri Jan 08 13:10:52 2016 -0800
+++ b/graal/com.oracle.graal.truffle.test/src/com/oracle/graal/truffle/test/BytecodeInterpreterPartialEvaluationTest.java	Fri Jan 08 13:12:43 2016 -0800
@@ -22,17 +22,11 @@
  */
 package com.oracle.graal.truffle.test;
 
-import java.util.Random;
-
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
-import org.junit.Ignore;
 import org.junit.Test;
 
-import com.oracle.graal.options.OptionValue;
-import com.oracle.graal.options.OptionValue.OverrideScope;
-import com.oracle.graal.truffle.PartialEvaluator;
 import com.oracle.truffle.api.CompilerAsserts;
 import com.oracle.truffle.api.CompilerDirectives;
 import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
@@ -73,6 +67,28 @@
 
     public static boolean TRACE = false;
 
+    /*
+     * A method with a non-exploded loop, which goes away after loop unrolling as long as the
+     * parameter is a compilation constant. The method is called from multiple places to inject
+     * additional loops into the test cases, i.e., to stress the partial evaluator and compiler
+     * optimizations.
+     */
+    static int nonExplodedLoop(int x) {
+        if (x >= 0 && x < 50) {
+            int result = 0;
+            for (int i = 0; i < x; i++) {
+                result++;
+                if (result > 100) {
+                    /* Dead branch because result < 50, just to complicate the loop structure. */
+                    CompilerDirectives.transferToInterpreter();
+                }
+            }
+            return result;
+        } else {
+            return x;
+        }
+    }
+
     public static class Program extends RootNode {
         private final String name;
         @CompilationFinal private final byte[] bytecodes;
@@ -124,7 +140,9 @@
             trace("Start program");
             int topOfStack = -1;
             int bci = 0;
-            outer: while (true) {
+            int result = 0;
+            boolean running = true;
+            outer: while (running) {
                 CompilerAsserts.partialEvaluationConstant(bci);
                 byte bc = bytecodes[bci];
                 byte value = 0;
@@ -132,13 +150,15 @@
                     case Bytecode.CONST:
                         value = bytecodes[bci + 1];
                         trace("%d (%d): CONST %s", bci, topOfStack, value);
-                        setInt(frame, ++topOfStack, value);
+                        setInt(frame, ++topOfStack, nonExplodedLoop(value));
                         bci = bci + 2;
                         continue;
 
                     case Bytecode.RETURN:
                         trace("%d (%d): RETURN", bci, topOfStack);
-                        return getInt(frame, topOfStack);
+                        result = nonExplodedLoop(getInt(frame, topOfStack));
+                        running = false;
+                        continue;
 
                     case Bytecode.ADD: {
                         int left = getInt(frame, topOfStack);
@@ -167,18 +187,11 @@
                         for (int i = 0; i < value; ++i) {
                             if (switchValue == i) {
                                 bci = bytecodes[bci + i + 2];
-                                // Need this seemingly useless condition here for two reasons:
-                                // 1. Bytecode analysis will consider the current block as
-                                // being within the inner loop.
-                                // 2. The if body will be an empty block that directly
-                                // jumps to the begin of the outer loop.
-                                if (i != -1) {
-                                    continue outer;
-                                }
+                                continue outer;
                             }
                         }
                         // Continue with the code after the switch.
-                        bci += value + 1;
+                        bci += value + 2;
                         continue;
 
                     case Bytecode.POP:
@@ -200,6 +213,7 @@
                         continue;
                 }
             }
+            return nonExplodedLoop(result);
         }
     }
 
@@ -391,7 +405,7 @@
 
             @Override
             public boolean execute(VirtualFrame frame) {
-                frame.setInt(slot, value);
+                frame.setInt(slot, nonExplodedLoop(value));
                 return true;
             }
 
@@ -512,7 +526,7 @@
                     ip = inst[ip].getFalseSucc();
                 }
             }
-            return FrameUtil.getIntSafe(frame, returnSlot);
+            return nonExplodedLoop(FrameUtil.getIntSafe(frame, returnSlot));
         }
     }
 
@@ -531,55 +545,6 @@
         assertPartialEvalEqualsAndRunsCorrect(new InstArrayProgram("instArraySimpleIfProgram", inst, returnSlot, fd));
     }
 
-    /**
-     * Slightly modified version to expose a partial evaluation bug with ExplodeLoop(merge=true).
-     */
-    public static class InstArrayProgram2 extends InstArrayProgram {
-        public InstArrayProgram2(String name, Inst[] inst, FrameSlot returnSlot, FrameDescriptor fd) {
-            super(name, inst, returnSlot, fd);
-        }
-
-        @Override
-        @ExplodeLoop(merge = true)
-        public Object execute(VirtualFrame frame) {
-            int ip = 0;
-            while (ip != -1) {
-                CompilerAsserts.partialEvaluationConstant(ip);
-                if (inst[ip].execute(frame)) {
-                    ip = inst[ip].getTrueSucc();
-                } else {
-                    ip = inst[ip].getFalseSucc();
-                }
-            }
-            if (frame.getArguments().length > 0) {
-                return new Random();
-            } else {
-                return FrameUtil.getIntSafe(frame, returnSlot);
-            }
-        }
-    }
-
-    @Ignore("produces a bad graph")
-    @Test
-    public void instArraySimpleIfProgram2() {
-        FrameDescriptor fd = new FrameDescriptor();
-        FrameSlot value1Slot = fd.addFrameSlot("value1", FrameSlotKind.Int);
-        FrameSlot value2Slot = fd.addFrameSlot("value2", FrameSlotKind.Int);
-        FrameSlot returnSlot = fd.addFrameSlot("return", FrameSlotKind.Int);
-        Inst[] inst = new Inst[]{
-        /* 0: */new Inst.Const(value1Slot, 100, 1),
-        /* 1: */new Inst.Const(value2Slot, 100, 2),
-        /* 2: */new Inst.IfLt(value1Slot, value2Slot, 3, 5),
-        /* 3: */new Inst.Const(returnSlot, 41, 4),
-        /* 4: */new Inst.Return(),
-        /* 5: */new Inst.Const(returnSlot, 42, 6),
-        /* 6: */new Inst.Return()};
-        InstArrayProgram program = new InstArrayProgram2("instArraySimpleIfProgram2", inst, returnSlot, fd);
-        program.execute(Truffle.getRuntime().createVirtualFrame(new Object[0], fd));
-        program.execute(Truffle.getRuntime().createVirtualFrame(new Object[1], fd));
-        assertPartialEvalEqualsAndRunsCorrect(program);
-    }
-
     @Test
     @SuppressWarnings("try")
     public void simpleSwitchProgram() {
@@ -600,8 +565,26 @@
         /* 13: */42,
         /* 14: */Bytecode.RETURN};
         Program program = new Program("simpleSwitchProgram", bytecodes, 0, 3);
-        try (OverrideScope s = OptionValue.override(PartialEvaluator.GraphPE, false)) {
-            assertPartialEvalEqualsAndRunsCorrect(program);
-        }
+        assertPartialEvalEqualsAndRunsCorrect(program);
+    }
+
+    @Test
+    @SuppressWarnings("try")
+    public void loopSwitchProgram() {
+        byte[] bytecodes = new byte[]{
+        /* 0: */Bytecode.CONST,
+        /* 1: */1,
+        /* 2: */Bytecode.SWITCH,
+        /* 3: */2,
+        /* 4: */0,
+        /* 5: */9,
+        /* 6: */Bytecode.CONST,
+        /* 7: */40,
+        /* 8: */Bytecode.RETURN,
+        /* 9: */Bytecode.CONST,
+        /* 10: */42,
+        /* 11: */Bytecode.RETURN};
+        Program program = new Program("loopSwitchProgram", bytecodes, 0, 3);
+        assertPartialEvalEqualsAndRunsCorrect(program);
     }
 }