# HG changeset patch # User Bernhard Urban # Date 1380713550 -7200 # Node ID 7763a42d1658f85034e803d9202a85d509c5c357 # Parent 4635fdfd8256c2dcc13584f00c335fbc23538c50 NewMemoryAwareScheduling: handle MemoryPhis properly remove earliest hack and some refactoring diff -r 4635fdfd8256 -r 7763a42d1658 graal/com.oracle.graal.phases/src/com/oracle/graal/phases/schedule/SchedulePhase.java --- a/graal/com.oracle.graal.phases/src/com/oracle/graal/phases/schedule/SchedulePhase.java Wed Oct 02 11:16:21 2013 +0200 +++ b/graal/com.oracle.graal.phases/src/com/oracle/graal/phases/schedule/SchedulePhase.java Wed Oct 02 13:32:30 2013 +0200 @@ -33,6 +33,7 @@ import com.oracle.graal.graph.*; import com.oracle.graal.graph.Node.Verbosity; import com.oracle.graal.nodes.*; +import com.oracle.graal.nodes.PhiNode.PhiType; import com.oracle.graal.nodes.calc.*; import com.oracle.graal.nodes.cfg.*; import com.oracle.graal.nodes.extended.*; @@ -99,7 +100,7 @@ @Override protected HashSet processBlock(Block block, HashSet currentState) { - for (Node node : getBlockToNodesMap().get(block)) { + for (Node node : blockToNodesMap.get(block)) { if (node instanceof FloatingReadNode) { currentState.add((FloatingReadNode) node); } else if (node instanceof MemoryCheckpoint.Single) { @@ -183,37 +184,49 @@ @Override protected Map processBlock(Block block, Map currentState) { - Map initKillMap = getBlockToKillMap().get(block); - initKillMap.putAll(currentState); + + if (block.getBeginNode() instanceof MergeNode) { + MergeNode mergeNode = (MergeNode) block.getBeginNode(); + for (PhiNode phi : mergeNode.usages().filter(PhiNode.class)) { + if (phi.type() == PhiType.Memory) { + LocationIdentity identity = (LocationIdentity) phi.getIdentity(); + locationKilledBy(identity, phi, currentState); + } + } + } + currentState.putAll(blockToKillMapInit.get(block)); for (Node node : block.getNodes()) { if (node instanceof MemoryCheckpoint.Single) { LocationIdentity identity = ((MemoryCheckpoint.Single) node).getLocationIdentity(); - initKillMap.put(identity, node); + locationKilledBy(identity, node, currentState); } else if (node instanceof MemoryCheckpoint.Multi) { for (LocationIdentity identity : ((MemoryCheckpoint.Multi) node).getLocationIdentities()) { - initKillMap.put(identity, node); + locationKilledBy(identity, node, currentState); } } assert MemoryCheckpoint.TypeAssertion.correctType(node); } - return cloneState(initKillMap); + blockToKillMap.put(block, currentState); + return cloneState(currentState); + } + + private void locationKilledBy(LocationIdentity identity, Node checkpoint, Map state) { + state.put(identity, checkpoint); + if (identity == ANY_LOCATION) { + for (LocationIdentity locid : state.keySet()) { + state.put(locid, checkpoint); + } + } } @Override protected Map merge(Block merge, List> states) { - return merge(merge, states, false); - } - - protected Map merge(Block merge, List> states, boolean loopbegin) { assert merge.getBeginNode() instanceof MergeNode; MergeNode mergeNode = (MergeNode) merge.getBeginNode(); Map initKillMap = new HashMap<>(); - if (loopbegin) { - initKillMap.putAll(getBlockToKillMap().get(merge)); - } for (Map state : states) { for (LocationIdentity locid : state.keySet()) { if (initKillMap.containsKey(locid)) { @@ -226,10 +239,7 @@ } } - getMergeToKillMap().set(mergeNode, cloneState(initKillMap)); - if (!loopbegin) { - initKillMap.putAll(getBlockToKillMap().get(merge)); - } + mergeToKillMap.set(mergeNode, cloneState(initKillMap)); return initKillMap; } @@ -243,15 +253,26 @@ LoopInfo> info = ReentrantBlockIterator.processLoop(this, loop, new HashMap<>(state)); assert loop.header.getBeginNode() instanceof LoopBeginNode; - Map headerState = merge(loop.header, info.endStates, true); - getBlockToKillMap().put(loop.header, headerState); + Map headerState = merge(loop.header, info.endStates); + // getBlockToKillMap().put(loop.header, headerState); - for (Map exitState : info.exitStates) { + int i = 0; + for (Block exit : loop.exits) { + Map exitState = info.exitStates.get(i++); for (LocationIdentity key : headerState.keySet()) { exitState.put(key, headerState.get(key)); } + + Node begin = exit.getBeginNode(); + assert begin instanceof LoopExitNode; + for (Node usage : begin.usages()) { + if (usage instanceof ProxyNode && ((ProxyNode) usage).type() == PhiType.Memory) { + ProxyNode proxy = (ProxyNode) usage; + LocationIdentity identity = (LocationIdentity) proxy.getIdentity(); + locationKilledBy(identity, proxy, exitState); + } + } } - return info.exitStates; } } @@ -263,6 +284,7 @@ * Map from blocks to the nodes in each block. */ private BlockMap> blockToNodesMap; + private BlockMap> blockToKillMapInit; private BlockMap> blockToKillMap; private NodeMap> mergeToKillMap; private final Map> phantomUsages = new IdentityHashMap<>(); @@ -315,8 +337,10 @@ } else if (memsched == MemoryScheduling.OPTIMAL && selectedStrategy != SchedulingStrategy.EARLIEST && graph.getNodes(FloatingReadNode.class).isNotEmpty()) { mergeToKillMap = graph.createNodeMap(); + blockToKillMapInit = new BlockMap<>(cfg); blockToKillMap = new BlockMap<>(cfg); for (Block b : cfg.getBlocks()) { + blockToKillMapInit.put(b, new HashMap()); blockToKillMap.put(b, new HashMap()); } @@ -328,7 +352,7 @@ Node first = n.lastLocationAccess(); assert first != null; - Map killMap = blockToKillMap.get(forKillLocation(first)); + Map killMap = blockToKillMapInit.get(forKillLocation(first)); killMap.put(n.location().getLocationIdentity(), first); } @@ -357,20 +381,27 @@ private void printSchedule(String desc) { Debug.printf("=== %s / %s / %s (%s) ===\n", getCFG().getStartBlock().getBeginNode().graph(), selectedStrategy, memsched, desc); for (Block b : getCFG().getBlocks()) { - Debug.printf("==== b: %s. ", b); + Debug.printf("==== b: %s (loopDepth: %s). ", b, b.getLoopDepth()); Debug.printf("dom: %s. ", b.getDominator()); Debug.printf("post-dom: %s. ", b.getPostdominator()); Debug.printf("preds: %s. ", b.getPredecessors()); Debug.printf("succs: %s ====\n", b.getSuccessors()); - BlockMap> killMaps = getBlockToKillMap(); + BlockMap> killMaps = blockToKillMap; if (killMaps != null) { + if (b.getBeginNode() instanceof MergeNode) { + MergeNode merge = (MergeNode) b.getBeginNode(); + Debug.printf("M merge kills: \n"); + for (LocationIdentity locId : mergeToKillMap.get(merge).keySet()) { + Debug.printf("M %s killed by %s\n", locId, mergeToKillMap.get(merge).get(locId)); + } + } Debug.printf("X block kills: \n"); for (LocationIdentity locId : killMaps.get(b).keySet()) { Debug.printf("X %s killed by %s\n", locId, killMaps.get(b).get(locId)); } } - if (getBlockToNodesMap().get(b) != null) { + if (blockToNodesMap.get(b) != null) { for (Node n : nodesFor(b)) { printNode(n); } @@ -414,14 +445,6 @@ return blockToNodesMap; } - public BlockMap> getBlockToKillMap() { - return blockToKillMap; - } - - public NodeMap> getMergeToKillMap() { - return mergeToKillMap; - } - /** * Gets the nodes in a given block. */ @@ -541,9 +564,8 @@ // iterate the dominator tree while (true) { iterations++; - assert earliestBlock.dominates(previousBlock) : "iterations: " + iterations; Node lastKill = blockToKillMap.get(currentBlock).get(locid); - boolean isAtEarliest = earliestBlock == previousBlock && previousBlock != currentBlock; + assert lastKill != null : "should be never null, due to init of killMaps: " + currentBlock + ", location: " + locid; if (lastKill.equals(upperBound)) { // assign node to the block which kills the location @@ -553,7 +575,6 @@ // schedule read out of the loop if possible, in terms of killMaps and earliest // schedule if (currentBlock != earliestBlock && previousBlock != earliestBlock) { - assert earliestBlock.dominates(currentBlock); Block t = currentBlock; while (t.getLoop() != null && t.getDominator() != null && earliestBlock.dominates(t)) { Block dom = t.getDominator(); @@ -568,17 +589,12 @@ if (!outOfLoop && previousBlock.getBeginNode() instanceof MergeNode) { // merges kill locations right at the beginning of a block. if a merge is the - // killing node, we assign it to the dominating node. + // killing node, we assign it to the dominating block. MergeNode merge = (MergeNode) previousBlock.getBeginNode(); - Node killer = getMergeToKillMap().get(merge).get(locid); + Node killer = mergeToKillMap.get(merge).get(locid); if (killer != null && killer == merge) { - // check if we violate earliest schedule condition - if (isAtEarliest) { - printIterations(iterations, "earliest bound in merge: " + earliestBlock); - return earliestBlock; - } printIterations(iterations, "kill by merge: " + currentBlock); return currentBlock; } @@ -590,11 +606,6 @@ return previousBlock; } - if (isAtEarliest) { - printIterations(iterations, "earliest bound: " + earliestBlock); - return earliestBlock; - } - if (upperBoundBlock == currentBlock) { printIterations(iterations, "upper bound: " + currentBlock + ", previous: " + previousBlock); return currentBlock;