changeset 16460:1da834bdfda2

let FloatingReadPhase deal with existing MemoryPhiNodes
author Lukas Stadler <lukas.stadler@oracle.com>
date Thu, 10 Jul 2014 17:07:35 +0200
parents a039ae7e0e50
children be0ad9b9aefe
files graal/com.oracle.graal.phases.common/src/com/oracle/graal/phases/common/FloatingReadPhase.java graal/com.oracle.graal.replacements/src/com/oracle/graal/replacements/SnippetTemplate.java
diffstat 2 files changed, 70 insertions(+), 24 deletions(-) [+]
line wrap: on
line diff
--- a/graal/com.oracle.graal.phases.common/src/com/oracle/graal/phases/common/FloatingReadPhase.java	Thu Jul 10 16:46:19 2014 +0200
+++ b/graal/com.oracle.graal.phases.common/src/com/oracle/graal/phases/common/FloatingReadPhase.java	Thu Jul 10 17:07:35 2014 +0200
@@ -43,10 +43,9 @@
 
 public class FloatingReadPhase extends Phase {
 
-    public enum ExecutionMode {
-        ANALYSIS_ONLY,
-        CREATE_FLOATING_READS
-    }
+    private boolean createFloatingReads;
+    private boolean createMemoryMapNodes;
+    private boolean updateExistingPhis;
 
     public static class MemoryMapImpl implements MemoryMap {
 
@@ -90,14 +89,23 @@
         }
     }
 
-    private final ExecutionMode execmode;
-
     public FloatingReadPhase() {
-        this(ExecutionMode.CREATE_FLOATING_READS);
+        this(true, false, false);
     }
 
-    public FloatingReadPhase(ExecutionMode execmode) {
-        this.execmode = execmode;
+    /**
+     * @param createFloatingReads specifies whether {@link FloatableAccessNode}s like
+     *            {@link ReadNode} should be converted into floating nodes (e.g.,
+     *            {@link FloatingReadNode}s) where possible
+     * @param createMemoryMapNodes a {@link MemoryMapNode} will be created for each return if this
+     *            is true
+     * @param updateExistingPhis if true, then existing {@link MemoryPhiNode}s in the graph will be
+     *            updated
+     */
+    public FloatingReadPhase(boolean createFloatingReads, boolean createMemoryMapNodes, boolean updateExistingPhis) {
+        this.createFloatingReads = createFloatingReads;
+        this.createMemoryMapNodes = createMemoryMapNodes;
+        this.updateExistingPhis = updateExistingPhis;
     }
 
     /**
@@ -127,7 +135,7 @@
         ReentrantNodeIterator.apply(new CollectMemoryCheckpointsClosure(modifiedInLoops), graph.start(), new HashSet<LocationIdentity>());
         HashSetNodeEventListener listener = new HashSetNodeEventListener(EnumSet.of(NODE_ADDED, ZERO_USAGES));
         try (NodeEventScope nes = graph.trackNodeEvents(listener)) {
-            ReentrantNodeIterator.apply(new FloatingReadClosure(modifiedInLoops, execmode), graph.start(), new MemoryMapImpl(graph.start()));
+            ReentrantNodeIterator.apply(new FloatingReadClosure(modifiedInLoops, createFloatingReads, createMemoryMapNodes, updateExistingPhis), graph.start(), new MemoryMapImpl(graph.start()));
         }
 
         for (Node n : removeExternallyUsedNodes(listener.getNodes())) {
@@ -136,13 +144,13 @@
                 GraphUtil.killWithUnusedFloatingInputs(n);
             }
         }
-        if (execmode == ExecutionMode.CREATE_FLOATING_READS) {
+        if (createFloatingReads) {
             assert !graph.isAfterFloatingReadPhase();
             graph.setAfterFloatingReadPhase(true);
         }
     }
 
-    public static MemoryMapImpl mergeMemoryMaps(MergeNode merge, List<? extends MemoryMap> states) {
+    public static MemoryMapImpl mergeMemoryMaps(MergeNode merge, List<? extends MemoryMap> states, boolean updateExistingPhis) {
         MemoryMapImpl newState = new MemoryMapImpl();
 
         Set<LocationIdentity> keys = new HashSet<>();
@@ -151,6 +159,17 @@
         }
         assert !keys.contains(FINAL_LOCATION);
 
+        Map<LocationIdentity, MemoryPhiNode> existingPhis = null;
+        if (updateExistingPhis) {
+            for (MemoryPhiNode phi : merge.phis().filter(MemoryPhiNode.class)) {
+                if (existingPhis == null) {
+                    existingPhis = newIdentityMap();
+                }
+                phi.values().clear();
+                existingPhis.put(phi.getLocationIdentity(), phi);
+            }
+        }
+
         for (LocationIdentity key : keys) {
             int mergedStatesCount = 0;
             boolean isPhi = false;
@@ -165,7 +184,10 @@
                     } else if (merged == null) {
                         merged = last;
                     } else {
-                        MemoryPhiNode phi = merge.graph().addWithoutUnique(new MemoryPhiNode(merge, key));
+                        MemoryPhiNode phi = null;
+                        if (existingPhis == null || (phi = existingPhis.remove(key)) == null) {
+                            phi = merge.graph().addWithoutUnique(new MemoryPhiNode(merge, key));
+                        }
                         for (int j = 0; j < mergedStatesCount; j++) {
                             phi.addInput(ValueNodeUtil.asNode(merged));
                         }
@@ -178,6 +200,11 @@
             }
             newState.lastMemorySnapshot.put(key, merged);
         }
+        if (existingPhis != null) {
+            for (Map.Entry<LocationIdentity, MemoryPhiNode> entry : existingPhis.entrySet()) {
+                entry.getValue().replaceAndDelete(newState.getLastLocationAccess(entry.getKey()).asNode());
+            }
+        }
         return newState;
 
     }
@@ -237,11 +264,15 @@
     private static class FloatingReadClosure extends NodeIteratorClosure<MemoryMapImpl> {
 
         private final Map<LoopBeginNode, Set<LocationIdentity>> modifiedInLoops;
-        private final ExecutionMode execmode;
+        private boolean createFloatingReads;
+        private boolean createMemoryMapNodes;
+        private boolean updateExistingPhis;
 
-        public FloatingReadClosure(Map<LoopBeginNode, Set<LocationIdentity>> modifiedInLoops, ExecutionMode execmode) {
+        public FloatingReadClosure(Map<LoopBeginNode, Set<LocationIdentity>> modifiedInLoops, boolean createFloatingReads, boolean createMemoryMapNodes, boolean updateExistingPhis) {
             this.modifiedInLoops = modifiedInLoops;
-            this.execmode = execmode;
+            this.createFloatingReads = createFloatingReads;
+            this.createMemoryMapNodes = createMemoryMapNodes;
+            this.updateExistingPhis = updateExistingPhis;
         }
 
         @Override
@@ -250,7 +281,7 @@
                 processAccess((MemoryAccess) node, state);
             }
 
-            if (node instanceof FloatableAccessNode && execmode == ExecutionMode.CREATE_FLOATING_READS) {
+            if (createFloatingReads & node instanceof FloatableAccessNode) {
                 processFloatable((FloatableAccessNode) node, state);
             } else if (node instanceof MemoryCheckpoint.Single) {
                 processCheckpoint((MemoryCheckpoint.Single) node, state);
@@ -259,7 +290,7 @@
             }
             assert MemoryCheckpoint.TypeAssertion.correctType(node) : node;
 
-            if (execmode == ExecutionMode.ANALYSIS_ONLY && node instanceof ReturnNode) {
+            if (createMemoryMapNodes && node instanceof ReturnNode) {
                 ((ReturnNode) node).setMemoryMap(node.graph().unique(new MemoryMapNode(state.lastMemorySnapshot)));
             }
             return state;
@@ -309,7 +340,7 @@
 
         @Override
         protected MemoryMapImpl merge(MergeNode merge, List<MemoryMapImpl> states) {
-            return mergeMemoryMaps(merge, states);
+            return mergeMemoryMaps(merge, states, updateExistingPhis);
         }
 
         @Override
@@ -339,10 +370,25 @@
             }
 
             Map<LocationIdentity, MemoryPhiNode> phis = new HashMap<>();
+
+            if (updateExistingPhis) {
+                for (MemoryPhiNode phi : loop.phis().filter(MemoryPhiNode.class)) {
+                    if (modifiedLocations.contains(phi.getLocationIdentity())) {
+                        phi.values().clear();
+                        phi.addInput(ValueNodeUtil.asNode(initialState.getLastLocationAccess(phi.getLocationIdentity())));
+                        phis.put(phi.getLocationIdentity(), phi);
+                    } else {
+                        phi.replaceAndDelete(initialState.getLastLocationAccess(phi.getLocationIdentity()).asNode());
+                    }
+                }
+            }
+
             for (LocationIdentity location : modifiedLocations) {
-                MemoryPhiNode phi = loop.graph().addWithoutUnique(new MemoryPhiNode(loop, location));
-                phi.addInput(ValueNodeUtil.asNode(initialState.getLastLocationAccess(location)));
-                phis.put(location, phi);
+                if (!updateExistingPhis || !phis.containsKey(location)) {
+                    MemoryPhiNode phi = loop.graph().addWithoutUnique(new MemoryPhiNode(loop, location));
+                    phi.addInput(ValueNodeUtil.asNode(initialState.getLastLocationAccess(location)));
+                    phis.put(location, phi);
+                }
             }
             for (Map.Entry<LocationIdentity, MemoryPhiNode> entry : phis.entrySet()) {
                 initialState.lastMemorySnapshot.put(entry.getKey(), entry.getValue());
--- a/graal/com.oracle.graal.replacements/src/com/oracle/graal/replacements/SnippetTemplate.java	Thu Jul 10 16:46:19 2014 +0200
+++ b/graal/com.oracle.graal.replacements/src/com/oracle/graal/replacements/SnippetTemplate.java	Thu Jul 10 17:07:35 2014 +0200
@@ -669,7 +669,7 @@
 
         assert checkAllVarargPlaceholdersAreDeleted(parameterCount, placeholders);
 
-        new FloatingReadPhase(FloatingReadPhase.ExecutionMode.ANALYSIS_ONLY).apply(snippetCopy);
+        new FloatingReadPhase(false, true, false).apply(snippetCopy);
 
         MemoryAnchorNode memoryAnchor = snippetCopy.add(new MemoryAnchorNode());
         snippetCopy.start().replaceAtUsages(InputType.Memory, memoryAnchor);
@@ -694,7 +694,7 @@
             List<MemoryMapNode> memMaps = returnNodes.stream().map(n -> n.getMemoryMap()).collect(Collectors.toList());
             ValueNode returnValue = InliningUtil.mergeReturns(merge, returnNodes, null);
             this.returnNode = snippet.add(new ReturnNode(returnValue));
-            MemoryMapImpl mmap = FloatingReadPhase.mergeMemoryMaps(merge, memMaps);
+            MemoryMapImpl mmap = FloatingReadPhase.mergeMemoryMaps(merge, memMaps, false);
             MemoryMapNode memoryMap = snippet.unique(new MemoryMapNode(mmap.getMap()));
             this.returnNode.setMemoryMap(memoryMap);
             for (MemoryMapNode mm : memMaps) {