changeset 11863:7763a42d1658

NewMemoryAwareScheduling: handle MemoryPhis properly remove earliest hack and some refactoring
author Bernhard Urban <bernhard.urban@jku.at>
date Wed, 02 Oct 2013 13:32:30 +0200
parents 4635fdfd8256
children fbcde297f87a
files graal/com.oracle.graal.phases/src/com/oracle/graal/phases/schedule/SchedulePhase.java
diffstat 1 files changed, 59 insertions(+), 48 deletions(-) [+]
line wrap: on
line diff
--- a/graal/com.oracle.graal.phases/src/com/oracle/graal/phases/schedule/SchedulePhase.java	Wed Oct 02 11:16:21 2013 +0200
+++ b/graal/com.oracle.graal.phases/src/com/oracle/graal/phases/schedule/SchedulePhase.java	Wed Oct 02 13:32:30 2013 +0200
@@ -33,6 +33,7 @@
 import com.oracle.graal.graph.*;
 import com.oracle.graal.graph.Node.Verbosity;
 import com.oracle.graal.nodes.*;
+import com.oracle.graal.nodes.PhiNode.PhiType;
 import com.oracle.graal.nodes.calc.*;
 import com.oracle.graal.nodes.cfg.*;
 import com.oracle.graal.nodes.extended.*;
@@ -99,7 +100,7 @@
 
         @Override
         protected HashSet<FloatingReadNode> processBlock(Block block, HashSet<FloatingReadNode> currentState) {
-            for (Node node : getBlockToNodesMap().get(block)) {
+            for (Node node : blockToNodesMap.get(block)) {
                 if (node instanceof FloatingReadNode) {
                     currentState.add((FloatingReadNode) node);
                 } else if (node instanceof MemoryCheckpoint.Single) {
@@ -183,37 +184,49 @@
 
         @Override
         protected Map<LocationIdentity, Node> processBlock(Block block, Map<LocationIdentity, Node> currentState) {
-            Map<LocationIdentity, Node> initKillMap = getBlockToKillMap().get(block);
-            initKillMap.putAll(currentState);
+
+            if (block.getBeginNode() instanceof MergeNode) {
+                MergeNode mergeNode = (MergeNode) block.getBeginNode();
+                for (PhiNode phi : mergeNode.usages().filter(PhiNode.class)) {
+                    if (phi.type() == PhiType.Memory) {
+                        LocationIdentity identity = (LocationIdentity) phi.getIdentity();
+                        locationKilledBy(identity, phi, currentState);
+                    }
+                }
+            }
+            currentState.putAll(blockToKillMapInit.get(block));
 
             for (Node node : block.getNodes()) {
                 if (node instanceof MemoryCheckpoint.Single) {
                     LocationIdentity identity = ((MemoryCheckpoint.Single) node).getLocationIdentity();
-                    initKillMap.put(identity, node);
+                    locationKilledBy(identity, node, currentState);
                 } else if (node instanceof MemoryCheckpoint.Multi) {
                     for (LocationIdentity identity : ((MemoryCheckpoint.Multi) node).getLocationIdentities()) {
-                        initKillMap.put(identity, node);
+                        locationKilledBy(identity, node, currentState);
                     }
                 }
                 assert MemoryCheckpoint.TypeAssertion.correctType(node);
             }
 
-            return cloneState(initKillMap);
+            blockToKillMap.put(block, currentState);
+            return cloneState(currentState);
+        }
+
+        private void locationKilledBy(LocationIdentity identity, Node checkpoint, Map<LocationIdentity, Node> state) {
+            state.put(identity, checkpoint);
+            if (identity == ANY_LOCATION) {
+                for (LocationIdentity locid : state.keySet()) {
+                    state.put(locid, checkpoint);
+                }
+            }
         }
 
         @Override
         protected Map<LocationIdentity, Node> merge(Block merge, List<Map<LocationIdentity, Node>> states) {
-            return merge(merge, states, false);
-        }
-
-        protected Map<LocationIdentity, Node> merge(Block merge, List<Map<LocationIdentity, Node>> states, boolean loopbegin) {
             assert merge.getBeginNode() instanceof MergeNode;
             MergeNode mergeNode = (MergeNode) merge.getBeginNode();
 
             Map<LocationIdentity, Node> initKillMap = new HashMap<>();
-            if (loopbegin) {
-                initKillMap.putAll(getBlockToKillMap().get(merge));
-            }
             for (Map<LocationIdentity, Node> state : states) {
                 for (LocationIdentity locid : state.keySet()) {
                     if (initKillMap.containsKey(locid)) {
@@ -226,10 +239,7 @@
                 }
             }
 
-            getMergeToKillMap().set(mergeNode, cloneState(initKillMap));
-            if (!loopbegin) {
-                initKillMap.putAll(getBlockToKillMap().get(merge));
-            }
+            mergeToKillMap.set(mergeNode, cloneState(initKillMap));
             return initKillMap;
         }
 
@@ -243,15 +253,26 @@
             LoopInfo<Map<LocationIdentity, Node>> info = ReentrantBlockIterator.processLoop(this, loop, new HashMap<>(state));
 
             assert loop.header.getBeginNode() instanceof LoopBeginNode;
-            Map<LocationIdentity, Node> headerState = merge(loop.header, info.endStates, true);
-            getBlockToKillMap().put(loop.header, headerState);
+            Map<LocationIdentity, Node> headerState = merge(loop.header, info.endStates);
+            // getBlockToKillMap().put(loop.header, headerState);
 
-            for (Map<LocationIdentity, Node> exitState : info.exitStates) {
+            int i = 0;
+            for (Block exit : loop.exits) {
+                Map<LocationIdentity, Node> exitState = info.exitStates.get(i++);
                 for (LocationIdentity key : headerState.keySet()) {
                     exitState.put(key, headerState.get(key));
                 }
+
+                Node begin = exit.getBeginNode();
+                assert begin instanceof LoopExitNode;
+                for (Node usage : begin.usages()) {
+                    if (usage instanceof ProxyNode && ((ProxyNode) usage).type() == PhiType.Memory) {
+                        ProxyNode proxy = (ProxyNode) usage;
+                        LocationIdentity identity = (LocationIdentity) proxy.getIdentity();
+                        locationKilledBy(identity, proxy, exitState);
+                    }
+                }
             }
-
             return info.exitStates;
         }
     }
@@ -263,6 +284,7 @@
      * Map from blocks to the nodes in each block.
      */
     private BlockMap<List<ScheduledNode>> blockToNodesMap;
+    private BlockMap<Map<LocationIdentity, Node>> blockToKillMapInit;
     private BlockMap<Map<LocationIdentity, Node>> blockToKillMap;
     private NodeMap<Map<LocationIdentity, Node>> mergeToKillMap;
     private final Map<FloatingNode, List<FixedNode>> phantomUsages = new IdentityHashMap<>();
@@ -315,8 +337,10 @@
         } else if (memsched == MemoryScheduling.OPTIMAL && selectedStrategy != SchedulingStrategy.EARLIEST && graph.getNodes(FloatingReadNode.class).isNotEmpty()) {
             mergeToKillMap = graph.createNodeMap();
 
+            blockToKillMapInit = new BlockMap<>(cfg);
             blockToKillMap = new BlockMap<>(cfg);
             for (Block b : cfg.getBlocks()) {
+                blockToKillMapInit.put(b, new HashMap<LocationIdentity, Node>());
                 blockToKillMap.put(b, new HashMap<LocationIdentity, Node>());
             }
 
@@ -328,7 +352,7 @@
                 Node first = n.lastLocationAccess();
                 assert first != null;
 
-                Map<LocationIdentity, Node> killMap = blockToKillMap.get(forKillLocation(first));
+                Map<LocationIdentity, Node> killMap = blockToKillMapInit.get(forKillLocation(first));
                 killMap.put(n.location().getLocationIdentity(), first);
             }
 
@@ -357,20 +381,27 @@
     private void printSchedule(String desc) {
         Debug.printf("=== %s / %s / %s (%s) ===\n", getCFG().getStartBlock().getBeginNode().graph(), selectedStrategy, memsched, desc);
         for (Block b : getCFG().getBlocks()) {
-            Debug.printf("==== b: %s. ", b);
+            Debug.printf("==== b: %s (loopDepth: %s). ", b, b.getLoopDepth());
             Debug.printf("dom: %s. ", b.getDominator());
             Debug.printf("post-dom: %s. ", b.getPostdominator());
             Debug.printf("preds: %s. ", b.getPredecessors());
             Debug.printf("succs: %s ====\n", b.getSuccessors());
-            BlockMap<Map<LocationIdentity, Node>> killMaps = getBlockToKillMap();
+            BlockMap<Map<LocationIdentity, Node>> killMaps = blockToKillMap;
             if (killMaps != null) {
+                if (b.getBeginNode() instanceof MergeNode) {
+                    MergeNode merge = (MergeNode) b.getBeginNode();
+                    Debug.printf("M merge kills: \n");
+                    for (LocationIdentity locId : mergeToKillMap.get(merge).keySet()) {
+                        Debug.printf("M %s killed by %s\n", locId, mergeToKillMap.get(merge).get(locId));
+                    }
+                }
                 Debug.printf("X block kills: \n");
                 for (LocationIdentity locId : killMaps.get(b).keySet()) {
                     Debug.printf("X %s killed by %s\n", locId, killMaps.get(b).get(locId));
                 }
             }
 
-            if (getBlockToNodesMap().get(b) != null) {
+            if (blockToNodesMap.get(b) != null) {
                 for (Node n : nodesFor(b)) {
                     printNode(n);
                 }
@@ -414,14 +445,6 @@
         return blockToNodesMap;
     }
 
-    public BlockMap<Map<LocationIdentity, Node>> getBlockToKillMap() {
-        return blockToKillMap;
-    }
-
-    public NodeMap<Map<LocationIdentity, Node>> getMergeToKillMap() {
-        return mergeToKillMap;
-    }
-
     /**
      * Gets the nodes in a given block.
      */
@@ -541,9 +564,8 @@
         // iterate the dominator tree
         while (true) {
             iterations++;
-            assert earliestBlock.dominates(previousBlock) : "iterations: " + iterations;
             Node lastKill = blockToKillMap.get(currentBlock).get(locid);
-            boolean isAtEarliest = earliestBlock == previousBlock && previousBlock != currentBlock;
+            assert lastKill != null : "should be never null, due to init of killMaps: " + currentBlock + ", location: " + locid;
 
             if (lastKill.equals(upperBound)) {
                 // assign node to the block which kills the location
@@ -553,7 +575,6 @@
                 // schedule read out of the loop if possible, in terms of killMaps and earliest
                 // schedule
                 if (currentBlock != earliestBlock && previousBlock != earliestBlock) {
-                    assert earliestBlock.dominates(currentBlock);
                     Block t = currentBlock;
                     while (t.getLoop() != null && t.getDominator() != null && earliestBlock.dominates(t)) {
                         Block dom = t.getDominator();
@@ -568,17 +589,12 @@
 
                 if (!outOfLoop && previousBlock.getBeginNode() instanceof MergeNode) {
                     // merges kill locations right at the beginning of a block. if a merge is the
-                    // killing node, we assign it to the dominating node.
+                    // killing node, we assign it to the dominating block.
 
                     MergeNode merge = (MergeNode) previousBlock.getBeginNode();
-                    Node killer = getMergeToKillMap().get(merge).get(locid);
+                    Node killer = mergeToKillMap.get(merge).get(locid);
 
                     if (killer != null && killer == merge) {
-                        // check if we violate earliest schedule condition
-                        if (isAtEarliest) {
-                            printIterations(iterations, "earliest bound in merge: " + earliestBlock);
-                            return earliestBlock;
-                        }
                         printIterations(iterations, "kill by merge: " + currentBlock);
                         return currentBlock;
                     }
@@ -590,11 +606,6 @@
                 return previousBlock;
             }
 
-            if (isAtEarliest) {
-                printIterations(iterations, "earliest bound: " + earliestBlock);
-                return earliestBlock;
-            }
-
             if (upperBoundBlock == currentBlock) {
                 printIterations(iterations, "upper bound: " + currentBlock + ", previous: " + previousBlock);
                 return currentBlock;