Mercurial > hg > truffle
changeset 12773:b6e04d6fe3a7
NewMemoryAwareScheduling: rewrite to set based approach
author | Bernhard Urban <bernhard.urban@jku.at> |
---|---|
date | Mon, 18 Nov 2013 17:22:37 +0100 |
parents | 6a7b6dcb7f67 |
children | 1729072a893a |
files | graal/com.oracle.graal.phases/src/com/oracle/graal/phases/graph/ReentrantBlockIterator.java graal/com.oracle.graal.phases/src/com/oracle/graal/phases/schedule/SchedulePhase.java |
diffstat | 2 files changed, 210 insertions(+), 194 deletions(-) [+] |
line wrap: on
line diff
--- a/graal/com.oracle.graal.phases/src/com/oracle/graal/phases/graph/ReentrantBlockIterator.java Mon Nov 18 17:22:30 2013 +0100 +++ b/graal/com.oracle.graal.phases/src/com/oracle/graal/phases/graph/ReentrantBlockIterator.java Mon Nov 18 17:22:37 2013 +0100 @@ -76,7 +76,7 @@ apply(closure, start, closure.getInitialState(), null); } - private static <StateT> IdentityHashMap<FixedNode, StateT> apply(BlockIteratorClosure<StateT> closure, Block start, StateT initialState, Set<Block> boundary) { + public static <StateT> IdentityHashMap<FixedNode, StateT> apply(BlockIteratorClosure<StateT> closure, Block start, StateT initialState, Set<Block> boundary) { Deque<Block> blockQueue = new ArrayDeque<>(); /* * States are stored on EndNodes before merges, and on BeginNodes after ControlSplitNodes. @@ -173,6 +173,7 @@ mergedStates.add(states.get(end)); } state = closure.merge(current, mergedStates); + states.put(merge, state); } } }
--- a/graal/com.oracle.graal.phases/src/com/oracle/graal/phases/schedule/SchedulePhase.java Mon Nov 18 17:22:30 2013 +0100 +++ b/graal/com.oracle.graal.phases/src/com/oracle/graal/phases/schedule/SchedulePhase.java Mon Nov 18 17:22:37 2013 +0100 @@ -175,104 +175,87 @@ } } - private class NewMemoryScheduleClosure extends BlockIteratorClosure<Map<LocationIdentity, Node>> { + private class NewMemoryScheduleClosure extends BlockIteratorClosure<Set<LocationIdentity>> { + @Override + protected Set<LocationIdentity> getInitialState() { + return cloneState(blockToKillSet.get(getCFG().getStartBlock())); + } + + @Override + protected Set<LocationIdentity> processBlock(Block block, Set<LocationIdentity> currentState) { + currentState.addAll(computeKillSet(block)); + return currentState; + } @Override - protected Map<LocationIdentity, Node> getInitialState() { - return cloneState(blockToKillMap.get(getCFG().getStartBlock())); + protected Set<LocationIdentity> merge(Block merge, List<Set<LocationIdentity>> states) { + assert merge.getBeginNode() instanceof MergeNode; + + Set<LocationIdentity> initKillSet = new HashSet<>(); + for (Set<LocationIdentity> state : states) { + initKillSet.addAll(state); + } + + return initKillSet; + } + + @Override + protected Set<LocationIdentity> cloneState(Set<LocationIdentity> state) { + return new HashSet<>(state); } @Override - protected Map<LocationIdentity, Node> processBlock(Block block, Map<LocationIdentity, Node> currentState) { + protected List<Set<LocationIdentity>> processLoop(Loop loop, Set<LocationIdentity> state) { + LoopInfo<Set<LocationIdentity>> info = ReentrantBlockIterator.processLoop(this, loop, cloneState(state)); + + assert loop.header.getBeginNode() instanceof LoopBeginNode; + Set<LocationIdentity> headerState = merge(loop.header, info.endStates); - if (block.getBeginNode() instanceof MergeNode) { - MergeNode mergeNode = (MergeNode) block.getBeginNode(); - for (PhiNode phi : mergeNode.usages().filter(PhiNode.class)) { - if (phi.type() == PhiType.Memory) { - LocationIdentity identity = phi.getIdentity(); - locationKilledBy(identity, phi, currentState); - } - } - } - currentState.putAll(blockToKillMapInit.get(block)); + // second iteration, for propagating information to loop exits + info = ReentrantBlockIterator.processLoop(this, loop, cloneState(headerState)); + + return info.exitStates; + } + } - for (Node node : block.getNodes()) { - if (node instanceof MemoryCheckpoint.Single) { - LocationIdentity identity = ((MemoryCheckpoint.Single) node).getLocationIdentity(); - locationKilledBy(identity, node, currentState); - } else if (node instanceof MemoryCheckpoint.Multi) { - for (LocationIdentity identity : ((MemoryCheckpoint.Multi) node).getLocationIdentities()) { - locationKilledBy(identity, node, currentState); - } - } - assert MemoryCheckpoint.TypeAssertion.correctType(node); - } + /** + * gather all kill locations by iterating trough the nodes assigned to a block. + * + * assumptions: {@link MemoryCheckpoint MemoryCheckPoints} are {@link FixedNode FixedNodes}. + * + * @param block block to analyze + * @return all killed locations + */ + private Set<LocationIdentity> computeKillSet(Block block) { + Set<LocationIdentity> cachedSet = blockToKillSet.get(block); + if (cachedSet != null) { + return cachedSet; + } + HashSet<LocationIdentity> set = new HashSet<>(); + blockToKillSet.put(block, set); - blockToKillMap.put(block, currentState); - return cloneState(currentState); - } - - private void locationKilledBy(LocationIdentity identity, Node checkpoint, Map<LocationIdentity, Node> state) { - state.put(identity, checkpoint); - if (identity == ANY_LOCATION) { - for (LocationIdentity locid : state.keySet()) { - state.put(locid, checkpoint); + if (block.getBeginNode() instanceof MergeNode) { + MergeNode mergeNode = (MergeNode) block.getBeginNode(); + for (PhiNode phi : mergeNode.usages().filter(PhiNode.class)) { + if (phi.type() == PhiType.Memory) { + set.add(phi.getIdentity()); } } } - @Override - protected Map<LocationIdentity, Node> merge(Block merge, List<Map<LocationIdentity, Node>> states) { - assert merge.getBeginNode() instanceof MergeNode; - MergeNode mergeNode = (MergeNode) merge.getBeginNode(); - - Map<LocationIdentity, Node> initKillMap = new HashMap<>(); - for (Map<LocationIdentity, Node> state : states) { - for (LocationIdentity locid : state.keySet()) { - if (initKillMap.containsKey(locid)) { - if (!initKillMap.get(locid).equals(state.get(locid))) { - initKillMap.put(locid, mergeNode); - } - } else { - initKillMap.put(locid, state.get(locid)); - } + for (Node node : block.getNodes()) { + if (node instanceof MemoryCheckpoint.Single) { + LocationIdentity identity = ((MemoryCheckpoint.Single) node).getLocationIdentity(); + set.add(identity); + } else if (node instanceof MemoryCheckpoint.Multi) { + for (LocationIdentity identity : ((MemoryCheckpoint.Multi) node).getLocationIdentities()) { + set.add(identity); } } - - mergeToKillMap.set(mergeNode, cloneState(initKillMap)); - return initKillMap; + assert MemoryCheckpoint.TypeAssertion.correctType(node); } - @Override - protected Map<LocationIdentity, Node> cloneState(Map<LocationIdentity, Node> state) { - return new HashMap<>(state); - } - - @Override - protected List<Map<LocationIdentity, Node>> processLoop(Loop loop, Map<LocationIdentity, Node> state) { - LoopInfo<Map<LocationIdentity, Node>> info = ReentrantBlockIterator.processLoop(this, loop, cloneState(state)); - - assert loop.header.getBeginNode() instanceof LoopBeginNode; - Map<LocationIdentity, Node> headerState = merge(loop.header, info.endStates); - // second iteration, for computing information at loop exits - info = ReentrantBlockIterator.processLoop(this, loop, cloneState(headerState)); - - int i = 0; - for (Block exit : loop.exits) { - Map<LocationIdentity, Node> exitState = info.exitStates.get(i++); - - Node begin = exit.getBeginNode(); - assert begin instanceof LoopExitNode; - for (Node usage : begin.usages()) { - if (usage instanceof ProxyNode && ((ProxyNode) usage).type() == PhiType.Memory) { - ProxyNode proxy = (ProxyNode) usage; - LocationIdentity identity = proxy.getIdentity(); - locationKilledBy(identity, proxy, exitState); - } - } - } - return info.exitStates; - } + return set; } private ControlFlowGraph cfg; @@ -282,13 +265,12 @@ * Map from blocks to the nodes in each block. */ private BlockMap<List<ScheduledNode>> blockToNodesMap; - private BlockMap<Map<LocationIdentity, Node>> blockToKillMapInit; - private BlockMap<Map<LocationIdentity, Node>> blockToKillMap; - private NodeMap<Map<LocationIdentity, Node>> mergeToKillMap; + private BlockMap<Set<LocationIdentity>> blockToKillSet; private final Map<FloatingNode, List<FixedNode>> phantomUsages = new IdentityHashMap<>(); private final Map<FixedNode, List<FloatingNode>> phantomInputs = new IdentityHashMap<>(); private final SchedulingStrategy selectedStrategy; private final MemoryScheduling memsched; + private NewMemoryScheduleClosure maschedClosure; public SchedulePhase() { this(OptScheduleOutOfLoops.getValue() ? SchedulingStrategy.LATEST_OUT_OF_LOOPS : SchedulingStrategy.LATEST); @@ -323,8 +305,7 @@ assignBlockToNodes(graph, SchedulingStrategy.EARLIEST); sortNodesWithinBlocks(graph, SchedulingStrategy.EARLIEST); - MemoryScheduleClosure closure = new MemoryScheduleClosure(); - ReentrantBlockIterator.apply(closure, getCFG().getStartBlock()); + ReentrantBlockIterator.apply(new MemoryScheduleClosure(), getCFG().getStartBlock()); cfg.clearNodeToBlock(); blockToNodesMap = new BlockMap<>(cfg); @@ -333,31 +314,8 @@ sortNodesWithinBlocks(graph, selectedStrategy); printSchedule("after sorting nodes within blocks"); } else if (memsched == MemoryScheduling.OPTIMAL && selectedStrategy != SchedulingStrategy.EARLIEST && graph.getNodes(FloatingReadNode.class).isNotEmpty()) { - mergeToKillMap = graph.createNodeMap(); - - blockToKillMapInit = new BlockMap<>(cfg); - blockToKillMap = new BlockMap<>(cfg); - for (Block b : cfg.getBlocks()) { - blockToKillMapInit.put(b, new HashMap<LocationIdentity, Node>()); - blockToKillMap.put(b, new HashMap<LocationIdentity, Node>()); - } - - // initialize killMaps with lastLocationAccess - for (FloatingReadNode n : graph.getNodes(FloatingReadNode.class)) { - if (n.location().getLocationIdentity() == FINAL_LOCATION) { - continue; - } - Node first = n.getLastLocationAccess(); - assert first != null; - - Map<LocationIdentity, Node> killMap = blockToKillMapInit.get(forKillLocation(first)); - killMap.put(n.location().getLocationIdentity(), first); - } - - // distribute and compute killMaps for all blocks - NewMemoryScheduleClosure closure = new NewMemoryScheduleClosure(); - ReentrantBlockIterator.apply(closure, getCFG().getStartBlock()); - printSchedule("after computing killMaps"); + blockToKillSet = new BlockMap<>(cfg); + maschedClosure = new NewMemoryScheduleClosure(); assignBlockToNodes(graph, selectedStrategy); printSchedule("after assign nodes to blocks"); @@ -370,46 +328,43 @@ } } - private Block forKillLocation(Node n) { + private Block blockForFixedNode(Node n) { Block b = cfg.getNodeToBlock().get(n); assert b != null : "all lastAccess locations should have a block assignment from CFG"; return b; } private void printSchedule(String desc) { - Debug.printf("=== %s / %s / %s (%s) ===\n", getCFG().getStartBlock().getBeginNode().graph(), selectedStrategy, memsched, desc); - for (Block b : getCFG().getBlocks()) { - Debug.printf("==== b: %s (loopDepth: %s). ", b, b.getLoopDepth()); - Debug.printf("dom: %s. ", b.getDominator()); - Debug.printf("post-dom: %s. ", b.getPostdominator()); - Debug.printf("preds: %s. ", b.getPredecessors()); - Debug.printf("succs: %s ====\n", b.getSuccessors()); - BlockMap<Map<LocationIdentity, Node>> killMaps = blockToKillMap; - if (killMaps != null) { - if (b.getBeginNode() instanceof MergeNode) { - MergeNode merge = (MergeNode) b.getBeginNode(); - Debug.printf("M merge kills: \n"); - for (LocationIdentity locId : mergeToKillMap.get(merge).keySet()) { - Debug.printf("M %s killed by %s\n", locId, mergeToKillMap.get(merge).get(locId)); + if (Debug.isEnabled()) { + Debug.printf("=== %s / %s / %s (%s) ===\n", getCFG().getStartBlock().getBeginNode().graph(), selectedStrategy, memsched, desc); + for (Block b : getCFG().getBlocks()) { + Debug.printf("==== b: %s (loopDepth: %s). ", b, b.getLoopDepth()); + Debug.printf("dom: %s. ", b.getDominator()); + Debug.printf("post-dom: %s. ", b.getPostdominator()); + Debug.printf("preds: %s. ", b.getPredecessors()); + Debug.printf("succs: %s ====\n", b.getSuccessors()); + BlockMap<Set<LocationIdentity>> killMaps = blockToKillSet; + if (killMaps != null) { + Debug.printf("X block kills: \n"); + if (killMaps.get(b) != null) { + for (LocationIdentity locId : killMaps.get(b)) { + Debug.printf("X %s killed by %s\n", locId, "dunno anymore"); + } } } - Debug.printf("X block kills: \n"); - for (LocationIdentity locId : killMaps.get(b).keySet()) { - Debug.printf("X %s killed by %s\n", locId, killMaps.get(b).get(locId)); + + if (blockToNodesMap.get(b) != null) { + for (Node n : nodesFor(b)) { + printNode(n); + } + } else { + for (Node n : b.getNodes()) { + printNode(n); + } } } - - if (blockToNodesMap.get(b) != null) { - for (Node n : nodesFor(b)) { - printNode(n); - } - } else { - for (Node n : b.getNodes()) { - printNode(n); - } - } + Debug.printf("\n\n"); } - Debug.printf("\n\n"); } private static void printNode(Node n) { @@ -487,7 +442,8 @@ } Block earliestBlock = earliestBlock(node); - Block block; + Block block = null; + Block latest = null; switch (strategy) { case EARLIEST: block = earliestBlock; @@ -496,7 +452,10 @@ case LATEST_OUT_OF_LOOPS: boolean scheduleRead = memsched == MemoryScheduling.OPTIMAL && node instanceof FloatingReadNode && ((FloatingReadNode) node).location().getLocationIdentity() != FINAL_LOCATION; if (scheduleRead) { - block = optimalBlock((FloatingReadNode) node, strategy); + FloatingReadNode read = (FloatingReadNode) node; + block = optimalBlock(read, strategy); + assert earliestBlock.dominates(block) : String.format("%s (%s) cannot be scheduled before earliest schedule (%s). location: %s", read, block, earliestBlock, + read.getLocationIdentity()); } else { block = latestBlock(node, strategy); } @@ -504,11 +463,22 @@ block = earliestBlock; } else if (strategy == SchedulingStrategy.LATEST_OUT_OF_LOOPS && !(node instanceof VirtualObjectNode)) { // schedule at the latest position possible in the outermost loop possible + latest = block; block = scheduleOutOfLoops(node, block, earliestBlock); } - assert !scheduleRead || forKillLocation(((FloatingReadNode) node).getLastLocationAccess()).dominates(block) : "out of loop violated memory semantics for " + node + ". moved to " + - block + ", but upper bound is " + forKillLocation(((FloatingReadNode) node).getLastLocationAccess()); + if (assertionEnabled()) { + if (scheduleRead) { + FloatingReadNode read = (FloatingReadNode) node; + Node lastLocationAccess = read.getLastLocationAccess(); + Block upperBound = blockForFixedNode(lastLocationAccess); + if (!blockForFixedNode(lastLocationAccess).dominates(block)) { + assert false : String.format("out of loop movement voilated memory semantics for %s (location %s). moved to %s but upper bound is %s (earliest: %s, latest: %s)", read, + read.getLocationIdentity(), block, upperBound, earliestBlock, latest); + } + + } + } break; default: throw new GraalInternalError("unknown scheduling strategy"); @@ -520,11 +490,19 @@ blockToNodesMap.get(block).add(node); } + @SuppressWarnings("all") + private static boolean assertionEnabled() { + boolean enabled = false; + assert enabled = true; + return enabled; + } + /** - * this method tries to find the latest position for a read, by taking the information gathered - * by {@link NewMemoryScheduleClosure} into account. + * this method tries to find the "optimal" schedule for a read, by pushing it down towards its + * latest schedule starting by the earliest schedule. By doing this, it takes care of memory + * dependencies using kill sets. * - * The idea is to iterate the dominator tree starting with the latest schedule of the read. + * In terms of domination relation, it looks like this: * * <pre> * U upperbound block, defined by last access location of the floating read @@ -536,10 +514,7 @@ * L latest block * </pre> * - * i.e. <code>upperbound `dom` earliest `dom` optimal `dom` latest</code>. However, there're - * cases where <code>earliest `dom` optimal</code> is not true, because the position is - * (impliclitly) bounded by an anchor of the read's guard. In such cases, the earliest schedule - * is taken. + * i.e. <code>upperbound `dom` earliest `dom` optimal `dom` latest</code>. * */ private Block optimalBlock(FloatingReadNode n, SchedulingStrategy strategy) { @@ -548,59 +523,99 @@ LocationIdentity locid = n.location().getLocationIdentity(); assert locid != FINAL_LOCATION; - Node upperBound = n.getLastLocationAccess(); - Block upperBoundBlock = forKillLocation(upperBound); + Block upperBoundBlock = blockForFixedNode(n.getLastLocationAccess()); Block earliestBlock = earliestBlock(n); assert upperBoundBlock.dominates(earliestBlock) : "upper bound (" + upperBoundBlock + ") should dominate earliest (" + earliestBlock + ")"; - Block currentBlock = latestBlock(n, strategy); - assert currentBlock != null && earliestBlock.dominates(currentBlock) : "earliest (" + earliestBlock + ") should dominate latest block (" + currentBlock + ")"; - Block previousBlock = currentBlock; + Block latestBlock = latestBlock(n, strategy); + assert latestBlock != null && earliestBlock.dominates(latestBlock) : "earliest (" + earliestBlock + ") should dominate latest block (" + latestBlock + ")"; - Debug.printf("processing %s (accessing %s): latest %s, earliest %s, upper bound %s (%s)\n", n, locid, currentBlock, earliestBlock, upperBoundBlock, upperBound); + Debug.printf("processing %s (accessing %s): latest %s, earliest %s, upper bound %s (%s)\n", n, locid, latestBlock, earliestBlock, upperBoundBlock, n.getLastLocationAccess()); + if (earliestBlock == latestBlock) { + // read is fixed to this block, nothing to schedule + return latestBlock; + } + + Stack<Block> path = computePathInDominatorTree(earliestBlock, latestBlock); + Debug.printf("|path| is %d: %s\n", path.size(), path); - int iterations = 0; - // iterate the dominator tree - while (true) { - iterations++; - Node lastKill = blockToKillMap.get(currentBlock).get(locid); - assert lastKill != null : "should be never null, due to init of killMaps: " + currentBlock + ", location: " + locid; + Set<LocationIdentity> killSet = new HashSet<>(); + // follow path, start at earliest schedule + while (path.size() > 0) { + Block currentBlock = path.pop(); + Block dominatedBlock = path.size() == 0 ? null : path.peek(); + if (dominatedBlock != null && !currentBlock.getSuccessors().contains(dominatedBlock)) { + // the dominated block is not a successor -> we have a split + assert dominatedBlock.getBeginNode() instanceof MergeNode; + + HashSet<Block> region = computeRegion(currentBlock, dominatedBlock); + Debug.printf("%s: region for %s -> %s: %s\n", n, currentBlock, dominatedBlock, region); + + Map<FixedNode, Set<LocationIdentity>> states; + states = ReentrantBlockIterator.apply(maschedClosure, currentBlock, new HashSet<>(killSet), region); - if (lastKill.equals(upperBound)) { - // assign node to the block which kills the location - - if (previousBlock.getBeginNode() instanceof MergeNode) { - // merges kill locations right at the beginning of a block. if a merge is the - // killing node, we assign it to the dominating block. + Set<LocationIdentity> mergeState = states.get(dominatedBlock.getBeginNode()); + if (mergeState.contains(locid)) { + // location got killed somewhere in the branches, + // thus we've to move the read above it + return currentBlock; + } + killSet.addAll(mergeState); + } else { + // trivial case + if (dominatedBlock == null) { + return currentBlock; + } + Set<LocationIdentity> blockKills = computeKillSet(currentBlock); + if (blockKills.contains(locid)) { + return currentBlock; + } + killSet.addAll(blockKills); + } + } + assert false : "should have found a block for " + n; + return null; + } - MergeNode merge = (MergeNode) previousBlock.getBeginNode(); - Node killer = mergeToKillMap.get(merge).get(locid); + /** + * compute path in dominator tree from earliest schedule to latest schedule. + * + * @return the order of the stack is such as the first element is the earliest schedule. + */ + private static Stack<Block> computePathInDominatorTree(Block earliestBlock, Block latestBlock) { + Stack<Block> path = new Stack<>(); + Block currentBlock = latestBlock; + while (currentBlock != null && earliestBlock.dominates(currentBlock)) { + path.push(currentBlock); + currentBlock = currentBlock.getDominator(); + } + assert path.peek() == earliestBlock; + return path; + } - if (killer != null && killer == merge) { - printIterations(iterations, "kill by merge: " + currentBlock); - return currentBlock; + /** + * compute a set that contains all blocks in a region spanned by dominatorBlock and + * dominatedBlock (exclusive the dominatedBlock). + */ + private static HashSet<Block> computeRegion(Block dominatorBlock, Block dominatedBlock) { + HashSet<Block> region = new HashSet<>(); + Stack<Block> workList = new Stack<>(); + + region.add(dominatorBlock); + workList.addAll(0, dominatorBlock.getSuccessors()); + while (workList.size() > 0) { + Block current = workList.pop(); + if (current != dominatedBlock) { + region.add(current); + for (Block b : current.getSuccessors()) { + if (!region.contains(b) && !workList.contains(b)) { + workList.add(b); } } - - // current block matches last access, that means the previous (dominated) block - // kills the location, therefore schedule read to previous block. - printIterations(iterations, "regular kill: " + previousBlock); - return previousBlock; } - - if (upperBoundBlock == currentBlock) { - printIterations(iterations, "upper bound: " + currentBlock + ", previous: " + previousBlock); - return currentBlock; - } - - previousBlock = currentBlock; - currentBlock = currentBlock.getDominator(); - assert currentBlock != null; } - } - - private static void printIterations(int iterations, String desc) { - Debug.printf("iterations: %d, %s\n", iterations, desc); + assert !region.contains(dominatedBlock) && region.containsAll(dominatedBlock.getPredecessors()); + return region; } /**