Mercurial > hg > graal-compiler
changeset 23275:a886e9dc3a47
GraphPE: when exploding nested loops, keep exploding exits of inner loops separately so that they merge separately with the outermost loop; fix bugs in loop detection and make test cases more challenging
author | Christian Wimmer <christian.wimmer@oracle.com> |
---|---|
date | Fri, 08 Jan 2016 13:12:43 -0800 |
parents | 50079bd51344 |
children | d5a1109b239b 4ba737504681 |
files | graal/com.oracle.graal.nodes/src/com/oracle/graal/nodes/GraphDecoder.java graal/com.oracle.graal.truffle.test/src/com/oracle/graal/truffle/test/BytecodeInterpreterPartialEvaluationTest.java |
diffstat | 2 files changed, 97 insertions(+), 117 deletions(-) |
line wrap: on
line diff
--- a/graal/com.oracle.graal.nodes/src/com/oracle/graal/nodes/GraphDecoder.java Fri Jan 08 13:10:52 2016 -0800 +++ b/graal/com.oracle.graal.nodes/src/com/oracle/graal/nodes/GraphDecoder.java Fri Jan 08 13:12:43 2016 -0800 @@ -30,10 +30,11 @@ import java.util.BitSet; import java.util.Deque; import java.util.HashMap; -import java.util.IdentityHashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Set; import jdk.vm.ci.code.Architecture; import jdk.vm.ci.meta.ResolvedJavaType; @@ -168,7 +169,7 @@ this.createdNodes = new Node[nodeCount]; } - protected LoopScope(LoopScope outer, int loopDepth, int loopIteration, int loopBeginOrderId, Node[] initialCreatedNodes, Deque<LoopScope> nextIterations, + protected LoopScope(LoopScope outer, int loopDepth, int loopIteration, int loopBeginOrderId, Node[] initialCreatedNodes, Node[] createdNodes, Deque<LoopScope> nextIterations, Map<LoopExplosionState, LoopExplosionState> iterationStates) { this.outer = outer; this.loopDepth = loopDepth; @@ -178,7 +179,7 @@ this.loopBeginOrderId = loopBeginOrderId; this.nodesToProcess = new BitSet(initialCreatedNodes.length); this.initialCreatedNodes = initialCreatedNodes; - this.createdNodes = Arrays.copyOf(initialCreatedNodes, initialCreatedNodes.length); + this.createdNodes = Arrays.copyOf(createdNodes, createdNodes.length); } @Override @@ -365,7 +366,22 @@ LoopScope successorAddScope = loopScope; boolean updatePredecessors = true; if (node instanceof LoopExitNode) { - successorAddScope = loopScope.outer; + if (methodScope.loopExplosion == LoopExplosionKind.MERGE_EXPLODE && loopScope.loopDepth > 1) { + /* + * We do not want to merge loop exits of inner loops. Instead, we want to keep + * exploding the outer loop separately for every loop exit and then merge the outer + * loop. Therefore, we create a new LoopScope of the outer loop for every loop exit + * of the inner loop. 
+ */ + LoopScope outerScope = loopScope.outer; + int nextIterationNumber = outerScope.nextIterations.isEmpty() ? outerScope.loopIteration + 1 : outerScope.nextIterations.getLast().loopIteration + 1; + successorAddScope = new LoopScope(outerScope.outer, outerScope.loopDepth, nextIterationNumber, outerScope.loopBeginOrderId, outerScope.initialCreatedNodes, + loopScope.initialCreatedNodes, outerScope.nextIterations, outerScope.iterationStates); + checkLoopExplosionIteration(methodScope, successorAddScope); + outerScope.nextIterations.addLast(successorAddScope); + } else { + successorAddScope = loopScope.outer; + } updatePredecessors = methodScope.loopExplosion == LoopExplosionKind.NONE; } @@ -384,7 +400,7 @@ } else if (node instanceof LoopExitNode) { if (methodScope.loopExplosion != LoopExplosionKind.NONE) { - handleLoopExplosionProxyNodes(methodScope, loopScope, (LoopExitNode) node, nodeOrderId); + handleLoopExplosionProxyNodes(methodScope, loopScope, successorAddScope, (LoopExitNode) node, nodeOrderId); } else { handleProxyNodes(methodScope, loopScope, (LoopExitNode) node); } @@ -405,7 +421,7 @@ if (merge instanceof LoopBeginNode) { assert phiNodeScope == phiInputScope && phiNodeScope == loopScope; - resultScope = new LoopScope(loopScope, loopScope.loopDepth + 1, 0, mergeOrderId, Arrays.copyOf(loopScope.createdNodes, loopScope.createdNodes.length), // + resultScope = new LoopScope(loopScope, loopScope.loopDepth + 1, 0, mergeOrderId, Arrays.copyOf(loopScope.createdNodes, loopScope.createdNodes.length), loopScope.createdNodes, // methodScope.loopExplosion != LoopExplosionKind.NONE ? new ArrayDeque<>() : null, // methodScope.loopExplosion == LoopExplosionKind.MERGE_EXPLODE ? new HashMap<>() : null); phiInputScope = resultScope; @@ -530,7 +546,7 @@ if (methodScope.loopExplosion != LoopExplosionKind.FULL_UNROLL || loopScope.nextIterations.isEmpty()) { int nextIterationNumber = loopScope.nextIterations.isEmpty() ? 
loopScope.loopIteration + 1 : loopScope.nextIterations.getLast().loopIteration + 1; LoopScope nextIterationScope = new LoopScope(loopScope.outer, loopScope.loopDepth, nextIterationNumber, loopScope.loopBeginOrderId, loopScope.initialCreatedNodes, - loopScope.nextIterations, loopScope.iterationStates); + loopScope.initialCreatedNodes, loopScope.nextIterations, loopScope.iterationStates); checkLoopExplosionIteration(methodScope, nextIterationScope); loopScope.nextIterations.addLast(nextIterationScope); registerNode(nextIterationScope, loopScope.loopBeginOrderId, null, true, true); @@ -567,7 +583,7 @@ } } - protected void handleLoopExplosionProxyNodes(MethodScope methodScope, LoopScope loopScope, LoopExitNode loopExit, int loopExitOrderId) { + protected void handleLoopExplosionProxyNodes(MethodScope methodScope, LoopScope loopScope, LoopScope outerScope, LoopExitNode loopExit, int loopExitOrderId) { assert loopExit.stateAfter() == null; int stateAfterOrderId = readOrderId(methodScope); @@ -582,16 +598,16 @@ * handle the merge. */ MergeNode merge = null; - Node existingExit = lookupNode(loopScope.outer, loopExitOrderId); + Node existingExit = lookupNode(outerScope, loopExitOrderId); if (existingExit == null) { /* First loop iteration that exits. No merge necessary yet. */ - registerNode(loopScope.outer, loopExitOrderId, begin, false, false); + registerNode(outerScope, loopExitOrderId, begin, false, false); begin.setNext(loopExitSuccessor); } else if (existingExit instanceof BeginNode) { /* Second loop iteration that exits. Create the merge. */ merge = methodScope.graph.add(new MergeNode()); - registerNode(loopScope.outer, loopExitOrderId, merge, true, false); + registerNode(outerScope, loopExitOrderId, merge, true, false); /* Add the first iteration. 
*/ EndNode firstEnd = methodScope.graph.add(new EndNode()); ((BeginNode) existingExit).setNext(firstEnd); @@ -622,13 +638,13 @@ ValueNode phiInput = proxy.value(); ValueNode replacement; - ValueNode existing = (ValueNode) loopScope.outer.createdNodes[proxyOrderId]; + ValueNode existing = (ValueNode) outerScope.createdNodes[proxyOrderId]; if (existing == null || existing == phiInput) { /* * We are at the first loop exit, or the proxy carries the same value for all exits. * We do not need a phi node yet. */ - registerNode(loopScope.outer, proxyOrderId, phiInput, true, false); + registerNode(outerScope, proxyOrderId, phiInput, true, false); replacement = phiInput; } else if (!merge.isPhiAtMerge(existing)) { @@ -640,7 +656,7 @@ } /* Add the input from this exit. */ phi.addInput(phiInput); - registerNode(loopScope.outer, proxyOrderId, phi, true, false); + registerNode(outerScope, proxyOrderId, phi, true, false); replacement = phi; phiCreated = true; @@ -656,8 +672,8 @@ if (merge != null && (merge.stateAfter() == null || phiCreated)) { FrameState oldStateAfter = merge.stateAfter(); - registerNode(loopScope.outer, stateAfterOrderId, null, true, true); - merge.setStateAfter((FrameState) ensureNodeCreated(methodScope, loopScope.outer, stateAfterOrderId)); + registerNode(outerScope, stateAfterOrderId, null, true, true); + merge.setStateAfter((FrameState) ensureNodeCreated(methodScope, outerScope, stateAfterOrderId)); if (oldStateAfter != null) { oldStateAfter.safeDelete(); } @@ -1015,11 +1031,13 @@ */ protected void detectLoops(StructuredGraph currentGraph, FixedNode startInstruction) { + Debug.dump(currentGraph, "Before detectLoops"); NodeBitMap visited = currentGraph.createNodeBitMap(); NodeBitMap active = currentGraph.createNodeBitMap(); Deque<Node> stack = new ArrayDeque<>(); stack.add(startInstruction); visited.mark(startInstruction); + Set<LoopBeginNode> newLoopBegins = new HashSet<>(); while (!stack.isEmpty()) { Node next = stack.peek(); assert next.isDeleted() || 
visited.isMarked(next); @@ -1046,6 +1064,7 @@ LoopBeginNode newLoopBegin = appendLoopBegin(currentGraph, merge); newLoopBegin.setNext(afterMerge); newLoopBegin.setStateAfter(stateAfter); + newLoopBegins.add(newLoopBegin); } LoopBeginNode loopBegin = (LoopBeginNode) ((EndNode) merge.next()).merge(); LoopEndNode loopEnd = currentGraph.add(new LoopEndNode(loopBegin)); @@ -1060,7 +1079,9 @@ } } - insertLoopEnds(currentGraph, startInstruction); + Debug.dump(currentGraph, "Before insertLoopEnds"); + insertLoopEnds(currentGraph, startInstruction, newLoopBegins); + Debug.dump(currentGraph, "After detectLoops"); } private static LoopBeginNode appendLoopBegin(StructuredGraph currentGraph, FixedWithNextNode fixedWithNext) { @@ -1072,7 +1093,7 @@ return loopBegin; } - private static void insertLoopEnds(StructuredGraph currentGraph, FixedNode startInstruction) { + private static void insertLoopEnds(StructuredGraph currentGraph, FixedNode startInstruction, Set<LoopBeginNode> newLoopBegins) { NodeBitMap visited = currentGraph.createNodeBitMap(); Deque<Node> stack = new ArrayDeque<>(); stack.add(startInstruction); @@ -1081,7 +1102,7 @@ while (!stack.isEmpty()) { Node next = stack.pop(); assert visited.isMarked(next); - if (next instanceof LoopBeginNode) { + if (next instanceof LoopBeginNode && newLoopBegins.contains(next)) { loopBegins.add((LoopBeginNode) next); } for (Node n : next.cfgSuccessors()) { @@ -1094,10 +1115,9 @@ } } - IdentityHashMap<LoopBeginNode, List<LoopBeginNode>> innerLoopsMap = new IdentityHashMap<>(); for (int i = loopBegins.size() - 1; i >= 0; --i) { LoopBeginNode loopBegin = loopBegins.get(i); - insertLoopExits(currentGraph, loopBegin, innerLoopsMap); + insertLoopExits(currentGraph, loopBegin); } // Remove degenerated merges with only one predecessor. 
@@ -1109,7 +1129,7 @@ } } - private static void insertLoopExits(StructuredGraph currentGraph, LoopBeginNode loopBegin, IdentityHashMap<LoopBeginNode, List<LoopBeginNode>> innerLoopsMap) { + private static void insertLoopExits(StructuredGraph currentGraph, LoopBeginNode loopBegin) { NodeBitMap visited = currentGraph.createNodeBitMap(); Deque<Node> stack = new ArrayDeque<>(); for (LoopEndNode loopEnd : loopBegin.loopEnds()) { @@ -1118,7 +1138,6 @@ } List<ControlSplitNode> controlSplits = new ArrayList<>(); - List<LoopBeginNode> innerLoopBegins = new ArrayList<>(); while (!stack.isEmpty()) { Node current = stack.pop(); @@ -1135,7 +1154,6 @@ if (!visited.isMarked(innerLoopBegin)) { stack.push(innerLoopBegin); visited.mark(innerLoopBegin); - innerLoopBegins.add(innerLoopBegin); } } else { if (pred instanceof ControlSplitNode) { @@ -1158,27 +1176,6 @@ } } } - - for (LoopBeginNode inner : innerLoopBegins) { - addLoopExits(currentGraph, loopBegin, inner, innerLoopsMap, visited); - } - - innerLoopsMap.put(loopBegin, innerLoopBegins); - } - - private static void addLoopExits(StructuredGraph currentGraph, LoopBeginNode loopBegin, LoopBeginNode inner, IdentityHashMap<LoopBeginNode, List<LoopBeginNode>> innerLoopsMap, NodeBitMap visited) { - for (LoopExitNode exit : inner.loopExits()) { - if (!visited.isMarked(exit)) { - LoopExitNode newLoopExit = currentGraph.add(new LoopExitNode(loopBegin)); - FixedNode next = exit.next(); - next.replaceAtPredecessor(newLoopExit); - newLoopExit.setNext(next); - } - } - - for (LoopBeginNode innerInner : innerLoopsMap.get(inner)) { - addLoopExits(currentGraph, loopBegin, innerInner, innerLoopsMap, visited); - } } /**
--- a/graal/com.oracle.graal.truffle.test/src/com/oracle/graal/truffle/test/BytecodeInterpreterPartialEvaluationTest.java Fri Jan 08 13:10:52 2016 -0800 +++ b/graal/com.oracle.graal.truffle.test/src/com/oracle/graal/truffle/test/BytecodeInterpreterPartialEvaluationTest.java Fri Jan 08 13:12:43 2016 -0800 @@ -22,17 +22,11 @@ */ package com.oracle.graal.truffle.test; -import java.util.Random; - import org.junit.After; import org.junit.Assert; import org.junit.Before; -import org.junit.Ignore; import org.junit.Test; -import com.oracle.graal.options.OptionValue; -import com.oracle.graal.options.OptionValue.OverrideScope; -import com.oracle.graal.truffle.PartialEvaluator; import com.oracle.truffle.api.CompilerAsserts; import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; @@ -73,6 +67,28 @@ public static boolean TRACE = false; + /* + * A method with a non-exploded loop, which goes away after loop unrolling as long as the + * parameter is a compilation constant. The method is called from multiple places to inject + * additional loops into the test cases, i.e., to stress the partial evaluator and compiler + * optimizations. + */ + static int nonExplodedLoop(int x) { + if (x >= 0 && x < 50) { + int result = 0; + for (int i = 0; i < x; i++) { + result++; + if (result > 100) { + /* Dead branch because result < 50, just to complicate the loop structure. 
*/ + CompilerDirectives.transferToInterpreter(); + } + } + return result; + } else { + return x; + } + } + public static class Program extends RootNode { private final String name; @CompilationFinal private final byte[] bytecodes; @@ -124,7 +140,9 @@ trace("Start program"); int topOfStack = -1; int bci = 0; - outer: while (true) { + int result = 0; + boolean running = true; + outer: while (running) { CompilerAsserts.partialEvaluationConstant(bci); byte bc = bytecodes[bci]; byte value = 0; @@ -132,13 +150,15 @@ case Bytecode.CONST: value = bytecodes[bci + 1]; trace("%d (%d): CONST %s", bci, topOfStack, value); - setInt(frame, ++topOfStack, value); + setInt(frame, ++topOfStack, nonExplodedLoop(value)); bci = bci + 2; continue; case Bytecode.RETURN: trace("%d (%d): RETURN", bci, topOfStack); - return getInt(frame, topOfStack); + result = nonExplodedLoop(getInt(frame, topOfStack)); + running = false; + continue; case Bytecode.ADD: { int left = getInt(frame, topOfStack); @@ -167,18 +187,11 @@ for (int i = 0; i < value; ++i) { if (switchValue == i) { bci = bytecodes[bci + i + 2]; - // Need this seemingly useless condition here for two reasons: - // 1. Bytecode analysis will consider the current block as - // being within the inner loop. - // 2. The if body will be an empty block that directly - // jumps to the begin of the outer loop. - if (i != -1) { - continue outer; - } + continue outer; } } // Continue with the code after the switch. 
- bci += value + 1; + bci += value + 2; continue; case Bytecode.POP: @@ -200,6 +213,7 @@ continue; } } + return nonExplodedLoop(result); } } @@ -391,7 +405,7 @@ @Override public boolean execute(VirtualFrame frame) { - frame.setInt(slot, value); + frame.setInt(slot, nonExplodedLoop(value)); return true; } @@ -512,7 +526,7 @@ ip = inst[ip].getFalseSucc(); } } - return FrameUtil.getIntSafe(frame, returnSlot); + return nonExplodedLoop(FrameUtil.getIntSafe(frame, returnSlot)); } } @@ -531,55 +545,6 @@ assertPartialEvalEqualsAndRunsCorrect(new InstArrayProgram("instArraySimpleIfProgram", inst, returnSlot, fd)); } - /** - * Slightly modified version to expose a partial evaluation bug with ExplodeLoop(merge=true). - */ - public static class InstArrayProgram2 extends InstArrayProgram { - public InstArrayProgram2(String name, Inst[] inst, FrameSlot returnSlot, FrameDescriptor fd) { - super(name, inst, returnSlot, fd); - } - - @Override - @ExplodeLoop(merge = true) - public Object execute(VirtualFrame frame) { - int ip = 0; - while (ip != -1) { - CompilerAsserts.partialEvaluationConstant(ip); - if (inst[ip].execute(frame)) { - ip = inst[ip].getTrueSucc(); - } else { - ip = inst[ip].getFalseSucc(); - } - } - if (frame.getArguments().length > 0) { - return new Random(); - } else { - return FrameUtil.getIntSafe(frame, returnSlot); - } - } - } - - @Ignore("produces a bad graph") - @Test - public void instArraySimpleIfProgram2() { - FrameDescriptor fd = new FrameDescriptor(); - FrameSlot value1Slot = fd.addFrameSlot("value1", FrameSlotKind.Int); - FrameSlot value2Slot = fd.addFrameSlot("value2", FrameSlotKind.Int); - FrameSlot returnSlot = fd.addFrameSlot("return", FrameSlotKind.Int); - Inst[] inst = new Inst[]{ - /* 0: */new Inst.Const(value1Slot, 100, 1), - /* 1: */new Inst.Const(value2Slot, 100, 2), - /* 2: */new Inst.IfLt(value1Slot, value2Slot, 3, 5), - /* 3: */new Inst.Const(returnSlot, 41, 4), - /* 4: */new Inst.Return(), - /* 5: */new Inst.Const(returnSlot, 42, 6), - /* 6: 
*/new Inst.Return()}; - InstArrayProgram program = new InstArrayProgram2("instArraySimpleIfProgram2", inst, returnSlot, fd); - program.execute(Truffle.getRuntime().createVirtualFrame(new Object[0], fd)); - program.execute(Truffle.getRuntime().createVirtualFrame(new Object[1], fd)); - assertPartialEvalEqualsAndRunsCorrect(program); - } - @Test @SuppressWarnings("try") public void simpleSwitchProgram() { @@ -600,8 +565,26 @@ /* 13: */42, /* 14: */Bytecode.RETURN}; Program program = new Program("simpleSwitchProgram", bytecodes, 0, 3); - try (OverrideScope s = OptionValue.override(PartialEvaluator.GraphPE, false)) { - assertPartialEvalEqualsAndRunsCorrect(program); - } + assertPartialEvalEqualsAndRunsCorrect(program); + } + + @Test + @SuppressWarnings("try") + public void loopSwitchProgram() { + byte[] bytecodes = new byte[]{ + /* 0: */Bytecode.CONST, + /* 1: */1, + /* 2: */Bytecode.SWITCH, + /* 3: */2, + /* 4: */0, + /* 5: */9, + /* 6: */Bytecode.CONST, + /* 7: */40, + /* 8: */Bytecode.RETURN, + /* 9: */Bytecode.CONST, + /* 10: */42, + /* 11: */Bytecode.RETURN}; + Program program = new Program("loopSwitchProgram", bytecodes, 0, 3); + assertPartialEvalEqualsAndRunsCorrect(program); } }