changeset 13092:b334ca53f077

NewMemoryAwareScheduling: don't consider lastAccessLocation of a read as a kill. Also: don't accumulate the KillSet along the dominator path; we aren't interested in the information from previous blocks.
author Bernhard Urban <bernhard.urban@jku.at>
date Wed, 20 Nov 2013 20:33:22 +0100
parents d1c929751642
children ec0f6ecc0b7a
files graal/com.oracle.graal.phases/src/com/oracle/graal/phases/schedule/SchedulePhase.java
diffstat 1 files changed, 76 insertions(+), 23 deletions(-) [+]
line wrap: on
line diff
--- a/graal/com.oracle.graal.phases/src/com/oracle/graal/phases/schedule/SchedulePhase.java	Wed Nov 20 16:30:06 2013 +0100
+++ b/graal/com.oracle.graal.phases/src/com/oracle/graal/phases/schedule/SchedulePhase.java	Wed Nov 20 20:33:22 2013 +0100
@@ -205,6 +205,18 @@
     }
 
     private class NewMemoryScheduleClosure extends BlockIteratorClosure<KillSet> {
+        private Node excludeNode;
+        private Block upperBoundBlock;
+
+        public NewMemoryScheduleClosure(Node excludeNode, Block upperBoundBlock) {
+            this.excludeNode = excludeNode;
+            this.upperBoundBlock = upperBoundBlock;
+        }
+
+        public NewMemoryScheduleClosure() {
+            this(null, null);
+        }
+
         @Override
         protected KillSet getInitialState() {
             return cloneState(blockToKillSet.get(getCFG().getStartBlock()));
@@ -212,7 +224,8 @@
 
         @Override
         protected KillSet processBlock(Block block, KillSet currentState) {
-            currentState.addAll(computeKillSet(block));
+            assert block != null;
+            currentState.addAll(computeKillSet(block, block == upperBoundBlock ? excludeNode : null));
             return currentState;
         }
 
@@ -253,40 +266,75 @@
      * assumptions: {@link MemoryCheckpoint MemoryCheckPoints} are {@link FixedNode FixedNodes}.
      * 
      * @param block block to analyze
+     * @param excludeNode if null, compute normal set of kill locations. if != null, don't add kills
+     *            until we reach excludeNode.
      * @return all killed locations
      */
-    private KillSet computeKillSet(Block block) {
-        KillSet cachedSet = blockToKillSet.get(block);
-        if (cachedSet != null) {
-            return cachedSet;
+    private KillSet computeKillSet(Block block, Node excludeNode) {
+        // cache is only valid if we don't potentially exclude kills from the set
+        if (excludeNode == null) {
+            KillSet cachedSet = blockToKillSet.get(block);
+            if (cachedSet != null) {
+                return cachedSet;
+            }
         }
+
+        // add locations to excludedLocations until we reach the excluded node
+        boolean foundExcludeNode = excludeNode == null;
+
         KillSet set = new KillSet();
-        blockToKillSet.put(block, set);
-
+        KillSet excludedLocations = new KillSet();
         if (block.getBeginNode() instanceof MergeNode) {
             MergeNode mergeNode = (MergeNode) block.getBeginNode();
             for (PhiNode phi : mergeNode.usages().filter(PhiNode.class)) {
                 if (phi.type() == PhiType.Memory) {
-                    set.add(phi.getIdentity());
+                    if (foundExcludeNode) {
+                        set.add(phi.getIdentity());
+                    } else {
+                        excludedLocations.add(phi.getIdentity());
+                        foundExcludeNode = phi == excludeNode;
+                    }
                 }
             }
         }
 
+        AbstractBeginNode startNode = cfg.getStartBlock().getBeginNode();
+        assert startNode instanceof StartNode;
+
+        KillSet accm = foundExcludeNode ? set : excludedLocations;
         for (Node node : block.getNodes()) {
+            if (!foundExcludeNode && node == excludeNode) {
+                foundExcludeNode = true;
+            }
+            if (node == startNode) {
+                continue;
+            }
             if (node instanceof MemoryCheckpoint.Single) {
                 LocationIdentity identity = ((MemoryCheckpoint.Single) node).getLocationIdentity();
-                set.add(identity);
+                accm.add(identity);
             } else if (node instanceof MemoryCheckpoint.Multi) {
                 for (LocationIdentity identity : ((MemoryCheckpoint.Multi) node).getLocationIdentities()) {
-                    set.add(identity);
+                    accm.add(identity);
                 }
             }
             assert MemoryCheckpoint.TypeAssertion.correctType(node);
+
+            if (foundExcludeNode) {
+                accm = set;
+            }
         }
 
+        // merge it for the cache entry
+        excludedLocations.addAll(set);
+        blockToKillSet.put(block, excludedLocations);
+
         return set;
     }
 
+    private KillSet computeKillSet(Block block) {
+        return computeKillSet(block, null);
+    }
+
     private ControlFlowGraph cfg;
     private NodeMap<Block> earliestCache;
 
@@ -299,7 +347,6 @@
     private final Map<FixedNode, List<FloatingNode>> phantomInputs = new IdentityHashMap<>();
     private final SchedulingStrategy selectedStrategy;
     private final MemoryScheduling memsched;
-    private NewMemoryScheduleClosure maschedClosure;
 
     public SchedulePhase() {
         this(OptScheduleOutOfLoops.getValue() ? SchedulingStrategy.LATEST_OUT_OF_LOOPS : SchedulingStrategy.LATEST);
@@ -344,7 +391,6 @@
             printSchedule("after sorting nodes within blocks");
         } else if (memsched == MemoryScheduling.OPTIMAL && selectedStrategy != SchedulingStrategy.EARLIEST && graph.getNodes(FloatingReadNode.class).isNotEmpty()) {
             blockToKillSet = new BlockMap<>(cfg);
-            maschedClosure = new NewMemoryScheduleClosure();
 
             assignBlockToNodes(graph, selectedStrategy);
             printSchedule("after assign nodes to blocks");
@@ -483,6 +529,7 @@
                 if (scheduleRead) {
                     FloatingReadNode read = (FloatingReadNode) node;
                     block = optimalBlock(read, strategy);
+                    Debug.printf("schedule for %s: %s\n", read, block);
                     assert earliestBlock.dominates(block) : String.format("%s (%s) cannot be scheduled before earliest schedule (%s). location: %s", read, block, earliestBlock,
                                     read.getLocationIdentity());
                 } else {
@@ -568,7 +615,6 @@
         Stack<Block> path = computePathInDominatorTree(earliestBlock, latestBlock);
         Debug.printf("|path| is %d: %s\n", path.size(), path);
 
-        KillSet killSet = new KillSet();
         // follow path, start at earliest schedule
         while (path.size() > 0) {
             Block currentBlock = path.pop();
@@ -578,10 +624,18 @@
                 assert dominatedBlock.getBeginNode() instanceof MergeNode;
 
                 HashSet<Block> region = computeRegion(currentBlock, dominatedBlock);
-                Debug.printf("%s: region for %s -> %s: %s\n", n, currentBlock, dominatedBlock, region);
+                Debug.printf("> merge.  %s: region for %s -> %s: %s\n", n, currentBlock, dominatedBlock, region);
 
+                NewMemoryScheduleClosure closure = null;
+                if (currentBlock == upperBoundBlock) {
+                    assert earliestBlock == upperBoundBlock;
+                    // don't treat lastLocationAccess node as a kill for this read.
+                    closure = new NewMemoryScheduleClosure(n.getLastLocationAccess(), upperBoundBlock);
+                } else {
+                    closure = new NewMemoryScheduleClosure();
+                }
                 Map<FixedNode, KillSet> states;
-                states = ReentrantBlockIterator.apply(maschedClosure, currentBlock, new KillSet(killSet), region);
+                states = ReentrantBlockIterator.apply(closure, currentBlock, new KillSet(), region);
 
                 KillSet mergeState = states.get(dominatedBlock.getBeginNode());
                 if (mergeState.isKilled(locid)) {
@@ -589,17 +643,16 @@
                     // thus we've to move the read above it
                     return currentBlock;
                 }
-                killSet.addAll(mergeState);
             } else {
-                // trivial case
-                if (dominatedBlock == null) {
+                if (currentBlock == upperBoundBlock) {
+                    assert earliestBlock == upperBoundBlock;
+                    KillSet ks = computeKillSet(upperBoundBlock, n.getLastLocationAccess());
+                    if (ks.isKilled(locid)) {
+                        return upperBoundBlock;
+                    }
+                } else if (dominatedBlock == null || computeKillSet(currentBlock).isKilled(locid)) {
                     return currentBlock;
                 }
-                KillSet blockKills = computeKillSet(currentBlock);
-                if (blockKills.isKilled(locid)) {
-                    return currentBlock;
-                }
-                killSet.addAll(blockKills);
             }
         }
         assert false : "should have found a block for " + n;