changeset 20081:8529bfcef6f5

Correctly re-wire memory edges in snippets.
author Roland Schatz <roland.schatz@oracle.com>
date Mon, 30 Mar 2015 16:51:26 +0200
parents 826a51b9c5d1
children 5a42f9b582c6 0b2bd777d933 5ea03a00828a
files graal/com.oracle.graal.replacements/src/com/oracle/graal/replacements/InstanceOfSnippetsTemplates.java graal/com.oracle.graal.replacements/src/com/oracle/graal/replacements/SnippetTemplate.java
diffstat 2 files changed, 118 insertions(+), 60 deletions(-) [+]
line wrap: on
line diff
--- a/graal/com.oracle.graal.replacements/src/com/oracle/graal/replacements/InstanceOfSnippetsTemplates.java	Mon Mar 30 14:28:50 2015 +0200
+++ b/graal/com.oracle.graal.replacements/src/com/oracle/graal/replacements/InstanceOfSnippetsTemplates.java	Mon Mar 30 16:51:26 2015 +0200
@@ -208,7 +208,7 @@
         }
 
         @Override
-        public void replace(ValueNode oldNode, ValueNode newNode, MemoryMap mmap) {
+        public void replace(ValueNode oldNode, ValueNode newNode) {
             assert newNode instanceof PhiNode;
             assert oldNode == instanceOf;
             newNode.inferStamp();
@@ -239,7 +239,7 @@
         }
 
         @Override
-        public void replace(ValueNode oldNode, ValueNode newNode, MemoryMap mmap) {
+        public void replace(ValueNode oldNode, ValueNode newNode) {
             assert newNode instanceof PhiNode;
             assert oldNode == instanceOf;
             newNode.inferStamp();
--- a/graal/com.oracle.graal.replacements/src/com/oracle/graal/replacements/SnippetTemplate.java	Mon Mar 30 14:28:50 2015 +0200
+++ b/graal/com.oracle.graal.replacements/src/com/oracle/graal/replacements/SnippetTemplate.java	Mon Mar 30 16:51:26 2015 +0200
@@ -722,18 +722,20 @@
 
         new FloatingReadPhase(false, true).apply(snippetCopy);
 
-        MemoryAnchorNode memoryAnchor = snippetCopy.add(new MemoryAnchorNode());
-        snippetCopy.start().replaceAtUsages(InputType.Memory, memoryAnchor);
+        MemoryAnchorNode anchor = snippetCopy.add(new MemoryAnchorNode());
+        snippetCopy.start().replaceAtUsages(InputType.Memory, anchor);
 
         this.snippet = snippetCopy;
 
         Debug.dump(snippet, "SnippetTemplate after fixing memory anchoring");
 
         StartNode entryPointNode = snippet.start();
-        if (memoryAnchor.hasNoUsages()) {
-            memoryAnchor.safeDelete();
+        if (anchor.hasNoUsages()) {
+            anchor.safeDelete();
+            this.memoryAnchor = null;
         } else {
-            snippetCopy.addAfterFixed(snippetCopy.start(), memoryAnchor);
+            snippetCopy.addAfterFixed(snippetCopy.start(), anchor);
+            this.memoryAnchor = anchor;
         }
         List<ReturnNode> returnNodes = snippet.getNodes(ReturnNode.TYPE).snapshot();
         if (returnNodes.isEmpty()) {
@@ -832,6 +834,11 @@
     private final ReturnNode returnNode;
 
     /**
+     * The memory anchor (if any) of the snippet.
+     */
+    private final MemoryAnchorNode memoryAnchor;
+
+    /**
      * Nodes that inherit the {@link StateSplit#stateAfter()} from the replacee during
      * instantiation.
      */
@@ -948,7 +955,7 @@
         /**
          * Replaces all usages of {@code oldNode} with direct or indirect usages of {@code newNode}.
          */
-        void replace(ValueNode oldNode, ValueNode newNode, MemoryMap mmap);
+        void replace(ValueNode oldNode, ValueNode newNode);
     }
 
     /**
@@ -957,48 +964,8 @@
      */
     public static final UsageReplacer DEFAULT_REPLACER = new UsageReplacer() {
 
-        private LocationIdentity getLocationIdentity(Node node) {
-            if (node instanceof MemoryAccess) {
-                return ((MemoryAccess) node).getLocationIdentity();
-            } else if (node instanceof MemoryProxy) {
-                return ((MemoryProxy) node).getLocationIdentity();
-            } else if (node instanceof MemoryPhiNode) {
-                return ((MemoryPhiNode) node).getLocationIdentity();
-            } else {
-                return null;
-            }
-        }
-
         @Override
-        public void replace(ValueNode oldNode, ValueNode newNode, MemoryMap mmap) {
-            if (mmap != null) {
-                for (Node usage : oldNode.usages().snapshot()) {
-                    LocationIdentity identity = getLocationIdentity(usage);
-                    boolean usageReplaced = false;
-                    if (identity != null && !identity.isImmutable()) {
-                        // lastLocationAccess points into the snippet graph. find a proper
-                        // MemoryCheckPoint inside the snippet graph
-                        MemoryNode lastAccess = mmap.getLastLocationAccess(identity);
-
-                        assert lastAccess != null : "no mapping found for lowerable node " + oldNode + ". (No node in the snippet kills the same location as the lowerable node?)";
-                        if (usage instanceof MemoryAccess) {
-                            MemoryAccess access = (MemoryAccess) usage;
-                            if (access.getLastLocationAccess() == oldNode) {
-                                assert oldNode.graph().isAfterFloatingReadPhase();
-                                access.setLastLocationAccess(lastAccess);
-                                usageReplaced = true;
-                            }
-                        } else {
-                            assert usage instanceof MemoryProxy || usage instanceof MemoryPhiNode;
-                            usage.replaceFirstInput(oldNode, lastAccess.asNode());
-                            usageReplaced = true;
-                        }
-                    }
-                    if (!usageReplaced) {
-                        assert newNode != null : "this branch is only valid if we have a newNode for replacement";
-                    }
-                }
-            }
+        public void replace(ValueNode oldNode, ValueNode newNode) {
             if (newNode == null) {
                 assert oldNode.hasNoUsages();
             } else {
@@ -1066,14 +1033,48 @@
         return true;
     }
 
-    private class DuplicateMapper implements MemoryMap {
+    private static class MemoryInputMap implements MemoryMap {
+
+        private final LocationIdentity locationIdentity;
+        private final MemoryNode lastLocationAccess;
+
+        public MemoryInputMap(ValueNode replacee) {
+            if (replacee instanceof MemoryAccess) {
+                MemoryAccess access = (MemoryAccess) replacee;
+                locationIdentity = access.getLocationIdentity();
+                lastLocationAccess = access.getLastLocationAccess();
+            } else {
+                locationIdentity = null;
+                lastLocationAccess = null;
+            }
+        }
+
+        @Override
+        public MemoryNode getLastLocationAccess(LocationIdentity location) {
+            if (locationIdentity != null && locationIdentity.equals(location)) {
+                return lastLocationAccess;
+            } else {
+                return null;
+            }
+        }
+
+        @Override
+        public Collection<LocationIdentity> getLocations() {
+            if (locationIdentity == null) {
+                return Collections.emptySet();
+            } else {
+                return Collections.singleton(locationIdentity);
+            }
+        }
+    }
+
+    private class MemoryOutputMap extends MemoryInputMap {
 
         private final Map<Node, Node> duplicates;
-        private StartNode replaceeStart;
 
-        public DuplicateMapper(Map<Node, Node> duplicates, StartNode replaceeStart) {
+        public MemoryOutputMap(ValueNode replacee, Map<Node, Node> duplicates) {
+            super(replacee);
             this.duplicates = duplicates;
-            this.replaceeStart = replaceeStart;
         }
 
         @Override
@@ -1081,9 +1082,9 @@
             MemoryMapNode memoryMap = returnNode.getMemoryMap();
             assert memoryMap != null : "no memory map stored for this snippet graph (snippet doesn't have a ReturnNode?)";
             MemoryNode lastLocationAccess = memoryMap.getLastLocationAccess(locationIdentity);
-            assert lastLocationAccess != null;
-            if (lastLocationAccess instanceof StartNode) {
-                return replaceeStart;
+            assert lastLocationAccess != null : locationIdentity;
+            if (lastLocationAccess == memoryAnchor) {
+                return super.getLastLocationAccess(locationIdentity);
             } else {
                 return (MemoryNode) duplicates.get(ValueNodeUtil.asNode(lastLocationAccess));
             }
@@ -1095,6 +1096,60 @@
         }
     }
 
+    private void rewireMemoryGraph(ValueNode replacee, Map<Node, Node> duplicates) {
+        // rewire outgoing memory edges
+        replaceMemoryUsages(replacee, new MemoryOutputMap(replacee, duplicates));
+
+        ReturnNode ret = (ReturnNode) duplicates.get(returnNode);
+        MemoryMapNode memoryMap = ret.getMemoryMap();
+        ret.setMemoryMap(null);
+        memoryMap.safeDelete();
+
+        if (memoryAnchor != null) {
+            // rewire incoming memory edges
+            MemoryAnchorNode memoryDuplicate = (MemoryAnchorNode) duplicates.get(memoryAnchor);
+            replaceMemoryUsages(memoryDuplicate, new MemoryInputMap(replacee));
+
+            if (memoryDuplicate.hasNoUsages()) {
+                memoryDuplicate.graph().removeFixed(memoryDuplicate);
+            }
+        }
+    }
+
+    private static LocationIdentity getLocationIdentity(Node node) {
+        if (node instanceof MemoryAccess) {
+            return ((MemoryAccess) node).getLocationIdentity();
+        } else if (node instanceof MemoryProxy) {
+            return ((MemoryProxy) node).getLocationIdentity();
+        } else if (node instanceof MemoryPhiNode) {
+            return ((MemoryPhiNode) node).getLocationIdentity();
+        } else {
+            return null;
+        }
+    }
+
+    private static void replaceMemoryUsages(ValueNode node, MemoryMap map) {
+        for (Node usage : node.usages().snapshot()) {
+            if (usage instanceof MemoryMapNode) {
+                continue;
+            }
+
+            LocationIdentity location = getLocationIdentity(usage);
+            if (location != null) {
+                NodePosIterator iter = usage.inputs().iterator();
+                while (iter.hasNext()) {
+                    Position pos = iter.nextPosition();
+                    if (pos.getInputType() == InputType.Memory && pos.get(usage) == node) {
+                        MemoryNode replacement = map.getLastLocationAccess(location);
+                        if (replacement != null) {
+                            pos.set(usage, replacement.asNode());
+                        }
+                    }
+                }
+            }
+        }
+    }
+
     /**
      * Replaces a given fixed node with this specialized snippet.
      *
@@ -1181,17 +1236,18 @@
 
             updateStamps(replacee, duplicates);
 
+            rewireMemoryGraph(replacee, duplicates);
+
             // Replace all usages of the replacee with the value returned by the snippet
             ValueNode returnValue = null;
             if (returnNode != null && !(replacee instanceof ControlSinkNode)) {
                 ReturnNode returnDuplicate = (ReturnNode) duplicates.get(returnNode);
                 returnValue = returnDuplicate.result();
-                MemoryMap mmap = new DuplicateMapper(duplicates, replaceeGraph.start());
                 if (returnValue == null && replacee.usages().isNotEmpty() && replacee instanceof MemoryCheckpoint) {
-                    replacer.replace(replacee, null, mmap);
+                    replacer.replace(replacee, null);
                 } else {
                     assert returnValue != null || replacee.hasNoUsages();
-                    replacer.replace(replacee, returnValue, mmap);
+                    replacer.replace(replacee, returnValue);
                 }
                 if (returnDuplicate.isAlive()) {
                     FixedNode next = null;
@@ -1284,11 +1340,13 @@
             }
             updateStamps(replacee, duplicates);
 
+            rewireMemoryGraph(replacee, duplicates);
+
             // Replace all usages of the replacee with the value returned by the snippet
             ReturnNode returnDuplicate = (ReturnNode) duplicates.get(returnNode);
             ValueNode returnValue = returnDuplicate.result();
             assert returnValue != null || replacee.hasNoUsages();
-            replacer.replace(replacee, returnValue, new DuplicateMapper(duplicates, replaceeGraph.start()));
+            replacer.replace(replacee, returnValue);
 
             if (returnDuplicate.isAlive()) {
                 returnDuplicate.replaceAndDelete(next);