changeset 5478:4a99bfc329f0

Add posibility to provide a replacement function instead of map for duplication. Also added validity check so that only valid slots (Position) get patched when replacing during duplication
author Gilles Duboscq <duboscq@ssw.jku.at>
date Fri, 01 Jun 2012 17:27:31 +0200
parents a7c79bcf55ac
children af838558e9e5
files graal/com.oracle.graal.graph/src/com/oracle/graal/graph/Graph.java graal/com.oracle.graal.graph/src/com/oracle/graal/graph/NodeClass.java
diffstat 2 files changed, 75 insertions(+), 25 deletions(-) [+]
line wrap: on
line diff
--- a/graal/com.oracle.graal.graph/src/com/oracle/graal/graph/Graph.java	Fri Jun 01 17:22:59 2012 +0200
+++ b/graal/com.oracle.graal.graph/src/com/oracle/graal/graph/Graph.java	Fri Jun 01 17:27:31 2012 +0200
@@ -520,10 +520,44 @@
      * @param replacements the replacement map (can be null if no replacement is to be performed)
      * @return a map which associates the original nodes from {@code nodes} to their duplicates
      */
+    public Map<Node, Node> addDuplicates(Iterable<Node> newNodes, Map<Node, Node> replacementsMap) {
+        DuplicationReplacement replacements;
+        if (replacementsMap == null) {
+            replacements = null;
+        } else {
+            replacements = new MapReplacement(replacementsMap);
+        }
+        return addDuplicates(newNodes, replacements);
+    }
+
+    public interface DuplicationReplacement {
+        Node replacement(Node original);
+    }
+
+    private static final class MapReplacement implements DuplicationReplacement {
+        private final Map<Node, Node> map;
+        public MapReplacement(Map<Node, Node> map) {
+            this.map = map;
+        }
+        @Override
+        public Node replacement(Node original) {
+            Node replacement = map.get(original);
+            return replacement != null ? replacement : original;
+        }
+
+    }
+
+    private static final DuplicationReplacement NO_REPLACEMENT = new DuplicationReplacement() {
+        @Override
+        public Node replacement(Node original) {
+            return original;
+        }
+    };
+
     @SuppressWarnings("all")
-    public Map<Node, Node> addDuplicates(Iterable<Node> newNodes, Map<Node, Node> replacements) {
+    public Map<Node, Node> addDuplicates(Iterable<Node> newNodes, DuplicationReplacement replacements) {
         if (replacements == null) {
-            replacements = Collections.emptyMap();
+            replacements = NO_REPLACEMENT;
         }
         return NodeClass.addGraphDuplicate(this, newNodes, replacements);
     }
--- a/graal/com.oracle.graal.graph/src/com/oracle/graal/graph/NodeClass.java	Fri Jun 01 17:22:59 2012 +0200
+++ b/graal/com.oracle.graal.graph/src/com/oracle/graal/graph/NodeClass.java	Fri Jun 01 17:27:31 2012 +0200
@@ -27,6 +27,8 @@
 import java.util.Map.*;
 import java.util.concurrent.ConcurrentHashMap;
 
+import com.oracle.graal.graph.Graph.DuplicationReplacement;
+
 import sun.misc.Unsafe;
 
 public class NodeClass {
@@ -910,15 +912,21 @@
         return directSuccessorCount;
     }
 
-    static Map<Node, Node> addGraphDuplicate(Graph graph, Iterable<Node> nodes, Map<Node, Node> replacements) {
+    static Map<Node, Node> addGraphDuplicate(Graph graph, Iterable<Node> nodes, DuplicationReplacement replacements) {
         Map<Node, Node> newNodes = new IdentityHashMap<>();
+        Map<Node, Node> replacementsMap = new IdentityHashMap<>();
         // create node duplicates
         for (Node node : nodes) {
-            if (node != null && !replacements.containsKey(node)) {
+            if (node != null) {
                 assert !node.isDeleted() : "trying to duplicate deleted node";
-                Node newNode = node.clone(graph);
-                assert newNode.getClass() == node.getClass();
-                newNodes.put(node, newNode);
+                Node replacement = replacements.replacement(node);
+                if (replacement != node) {
+                    replacementsMap.put(node, replacement);
+                } else {
+                    Node newNode = node.clone(graph);
+                    assert newNode.getClass() == node.getClass();
+                    newNodes.put(node, newNode);
+                }
             }
         }
         // re-wire inputs
@@ -928,24 +936,29 @@
             for (NodeClassIterator iter = oldNode.inputs().iterator(); iter.hasNext();) {
                 Position pos = iter.nextPosition();
                 Node input = oldNode.getNodeClass().get(oldNode, pos);
-                Node target = replacements.get(input);
+                Node target = replacementsMap.get(input);
                 if (target == null) {
-                    target = newNodes.get(input);
+                    Node replacement = replacements.replacement(input);
+                    if (replacement != input) {
+                        replacementsMap.put(input, replacement);
+                        target = replacement;
+                    } else {
+                        target = newNodes.get(input);
+                    }
                 }
                 node.getNodeClass().set(node, pos, target);
             }
         }
-        for (Entry<Node, Node> entry : replacements.entrySet()) {
+        for (Entry<Node, Node> entry : replacementsMap.entrySet()) {
             Node oldNode = entry.getKey();
             Node node = entry.getValue();
-            if (oldNode == node) {
-                continue;
-            }
             for (NodeClassIterator iter = oldNode.inputs().iterator(); iter.hasNext();) {
                 Position pos = iter.nextPosition();
-                Node input = oldNode.getNodeClass().get(oldNode, pos);
-                if (newNodes.containsKey(input)) {
-                    node.getNodeClass().set(node, pos, newNodes.get(input));
+                if (pos.isValidFor(node, oldNode)) {
+                    Node input = oldNode.getNodeClass().get(oldNode, pos);
+                    if (newNodes.containsKey(input)) {
+                        node.getNodeClass().set(node, pos, newNodes.get(input));
+                    }
                 }
             }
         }
@@ -957,24 +970,27 @@
             for (NodeClassIterator iter = oldNode.successors().iterator(); iter.hasNext();) {
                 Position pos = iter.nextPosition();
                 Node succ = oldNode.getNodeClass().get(oldNode, pos);
-                Node target = replacements.get(succ);
-                if (target == null) {
+                Node target = replacementsMap.get(succ);
+                Node replacement = replacements.replacement(succ);
+                if (replacement != succ) {
+                    replacementsMap.put(succ, replacement);
+                    target = replacement;
+                } else {
                     target = newNodes.get(succ);
                 }
                 node.getNodeClass().set(node, pos, target);
             }
         }
-        for (Entry<Node, Node> entry : replacements.entrySet()) {
+        for (Entry<Node, Node> entry : replacementsMap.entrySet()) {
             Node oldNode = entry.getKey();
             Node node = entry.getValue();
-            if (oldNode == node) {
-                continue;
-            }
             for (NodeClassIterator iter = oldNode.successors().iterator(); iter.hasNext();) {
                 Position pos = iter.nextPosition();
-                Node succ = oldNode.getNodeClass().get(oldNode, pos);
-                if (newNodes.containsKey(succ)) {
-                    node.getNodeClass().set(node, pos, newNodes.get(succ));
+                if (pos.isValidFor(node, oldNode)) {
+                    Node succ = oldNode.getNodeClass().get(oldNode, pos);
+                    if (newNodes.containsKey(succ)) {
+                        node.getNodeClass().set(node, pos, newNodes.get(succ));
+                    }
                 }
             }
         }