diff graal/com.oracle.truffle.api.dsl/src/com/oracle/truffle/api/dsl/internal/SpecializationNode.java @ 19757:e8d2f3f95dcd

Truffle-DSL: implemented duplication check for specializations with @Cached to avoid duplicates for multithreaded AST execution.
author Christian Humer <christian.humer@gmail.com>
date Tue, 10 Mar 2015 19:28:26 +0100
parents f4792a544170
children f682b9e6ca07
line wrap: on
line diff
--- a/graal/com.oracle.truffle.api.dsl/src/com/oracle/truffle/api/dsl/internal/SpecializationNode.java	Tue Mar 10 13:47:46 2015 +0100
+++ b/graal/com.oracle.truffle.api.dsl/src/com/oracle/truffle/api/dsl/internal/SpecializationNode.java	Tue Mar 10 19:28:26 2015 +0100
@@ -30,12 +30,13 @@
 
 import com.oracle.truffle.api.*;
 import com.oracle.truffle.api.dsl.*;
-import com.oracle.truffle.api.dsl.internal.RewriteEvent.RewriteEvent0;
-import com.oracle.truffle.api.dsl.internal.RewriteEvent.RewriteEvent1;
-import com.oracle.truffle.api.dsl.internal.RewriteEvent.RewriteEvent2;
-import com.oracle.truffle.api.dsl.internal.RewriteEvent.RewriteEvent3;
-import com.oracle.truffle.api.dsl.internal.RewriteEvent.RewriteEvent4;
-import com.oracle.truffle.api.dsl.internal.RewriteEvent.RewriteEventN;
+import com.oracle.truffle.api.dsl.internal.SlowPathEvent.SlowPathEvent0;
+import com.oracle.truffle.api.dsl.internal.SlowPathEvent.SlowPathEvent1;
+import com.oracle.truffle.api.dsl.internal.SlowPathEvent.SlowPathEvent2;
+import com.oracle.truffle.api.dsl.internal.SlowPathEvent.SlowPathEvent3;
+import com.oracle.truffle.api.dsl.internal.SlowPathEvent.SlowPathEvent4;
+import com.oracle.truffle.api.dsl.internal.SlowPathEvent.SlowPathEvent5;
+import com.oracle.truffle.api.dsl.internal.SlowPathEvent.SlowPathEventN;
 import com.oracle.truffle.api.frame.*;
 import com.oracle.truffle.api.nodes.*;
 import com.oracle.truffle.api.nodes.NodeUtil.NodeClass;
@@ -44,13 +45,13 @@
 /**
  * Internal implementation dependent base class for generated specialized nodes.
  */
+@NodeInfo(cost = NodeCost.NONE)
 @SuppressWarnings("unused")
-@NodeInfo(cost = NodeCost.NONE)
 public abstract class SpecializationNode extends Node {
 
     @Child protected SpecializationNode next;
 
-    private final int index;
+    final int index;
 
     public SpecializationNode() {
         this(-1);
@@ -92,10 +93,9 @@
         }
     }
 
-    protected final SpecializationNode polymorphicMerge(SpecializationNode newNode) {
-        SpecializationNode merged = next.merge(newNode);
-        if (merged == newNode && !isSame(newNode) && count() <= 2) {
-            return removeSame(new RewriteEvent0(findRoot(), "merged polymorphic to monomorphic"));
+    protected final SpecializationNode polymorphicMerge(SpecializationNode newNode, SpecializationNode merged) {
+        if (merged == newNode && count() <= 2) {
+            return removeSame(new SlowPathEvent0(this, "merged polymorphic to monomorphic", null));
         }
         return merged;
     }
@@ -114,15 +114,85 @@
 
     protected abstract Node[] getSuppliedChildren();
 
-    protected SpecializationNode merge(SpecializationNode newNode) {
-        if (this.isSame(newNode)) {
+    protected SpecializationNode merge(SpecializationNode newNode, Frame frame) {
+        if (isIdentical(newNode, frame)) {
+            return this;
+        }
+        return next != null ? next.merge(newNode, frame) : newNode;
+    }
+
+    protected SpecializationNode merge(SpecializationNode newNode, Frame frame, Object o1) {
+        if (isIdentical(newNode, frame, o1)) {
+            return this;
+        }
+        return next != null ? next.merge(newNode, frame, o1) : newNode;
+    }
+
+    protected SpecializationNode merge(SpecializationNode newNode, Frame frame, Object o1, Object o2) {
+        if (isIdentical(newNode, frame, o1, o2)) {
+            return this;
+        }
+        return next != null ? next.merge(newNode, frame, o1, o2) : newNode;
+    }
+
+    protected SpecializationNode merge(SpecializationNode newNode, Frame frame, Object o1, Object o2, Object o3) {
+        if (isIdentical(newNode, frame, o1, o2, o3)) {
+            return this;
+        }
+        return next != null ? next.merge(newNode, frame, o1, o2, o3) : newNode;
+    }
+
+    protected SpecializationNode merge(SpecializationNode newNode, Frame frame, Object o1, Object o2, Object o3, Object o4) {
+        if (isIdentical(newNode, frame, o1, o2, o3, o4)) {
+            return this;
+        }
+        return next != null ? next.merge(newNode, frame, o1, o2, o3, o4) : newNode;
+    }
+
+    protected SpecializationNode merge(SpecializationNode newNode, Frame frame, Object o1, Object o2, Object o3, Object o4, Object o5) {
+        if (isIdentical(newNode, frame, o1, o2, o3, o4, o5)) {
             return this;
         }
-        return next != null ? next.merge(newNode) : newNode;
+        return next != null ? next.merge(newNode, frame, o1, o2, o3, o4, o5) : newNode;
+    }
+
+    protected SpecializationNode merge(SpecializationNode newNode, Frame frame, Object... args) {
+        if (isIdentical(newNode, frame, args)) {
+            return this;
+        }
+        return next != null ? next.merge(newNode, frame, args) : newNode;
+    }
+
+    protected boolean isSame(SpecializationNode other) {
+        return getClass() == other.getClass();
+    }
+
+    protected boolean isIdentical(SpecializationNode newNode, Frame frame) {
+        return isSame(newNode);
     }
 
-    protected SpecializationNode mergeNoSame(SpecializationNode newNode) {
-        return next != null ? next.merge(newNode) : newNode;
+    protected boolean isIdentical(SpecializationNode newNode, Frame frame, Object o1) {
+        return isSame(newNode);
+    }
+
+    protected boolean isIdentical(SpecializationNode newNode, Frame frame, Object o1, Object o2) {
+        return isSame(newNode);
+    }
+
+    protected boolean isIdentical(SpecializationNode newNode, Frame frame, Object o1, Object o2, Object o3) {
+        return isSame(newNode);
+    }
+
+    protected boolean isIdentical(SpecializationNode newNode, Frame frame, Object o1, Object o2, Object o3, Object o4) {
+        return isSame(newNode);
+    }
+
+    protected boolean isIdentical(SpecializationNode newNode, Frame frame, Object o1, Object o2, Object o3, Object o4, Object o5) {
+        return isSame(newNode);
+    }
+
+    protected boolean isIdentical(SpecializationNode newNode, Frame frame, Object... args) {
+        return isSame(newNode);
     }
 
     protected final int countSame(SpecializationNode node) {
@@ -150,10 +220,6 @@
         return index;
     }
 
-    protected boolean isSame(SpecializationNode other) {
-        return getClass() == other.getClass();
-    }
-
     private int count() {
         return next != null ? next.count() + 1 : 1;
     }
@@ -226,8 +292,12 @@
         return findStart().getParent();
     }
 
-    private SpecializationNode removeSameImpl(SpecializationNode toRemove, CharSequence reason) {
-        SpecializationNode start = findStart();
+    private SpecializedNode findSpecializedNode() {
+        return (SpecializedNode) findEnd().findStart().getParent();
+    }
+
+    private static SpecializationNode removeSameImpl(SpecializationNode toRemove, CharSequence reason) {
+        SpecializationNode start = toRemove.findStart();
         SpecializationNode current = start;
         while (current != null) {
             if (current.isSame(toRemove)) {
@@ -238,7 +308,7 @@
             }
             current = current.next;
         }
-        return findEnd().findStart();
+        return toRemove.findEnd().findStart();
     }
 
     public Object acceptAndExecute(Frame frame) {
@@ -314,7 +384,7 @@
         if (nextSpecialization == null) {
             return unsupported(frame);
         }
-        return insertSpecialization(nextSpecialization, new RewriteEvent0(findRoot(), "inserted new specialization")).acceptAndExecute(frame);
+        return atomic(new InsertionEvent0(this, "insert new specialization", frame, nextSpecialization)).acceptAndExecute(frame);
     }
 
     protected final Object uninitialized(Frame frame, Object o1) {
@@ -326,7 +396,7 @@
         if (nextSpecialization == null) {
             return unsupported(frame, o1);
         }
-        return insertSpecialization(nextSpecialization, new RewriteEvent1(findRoot(), "inserted new specialization", o1)).acceptAndExecute(frame, o1);
+        return atomic(new InsertionEvent1(this, "insert new specialization", frame, o1, nextSpecialization)).acceptAndExecute(frame, o1);
     }
 
     protected final Object uninitialized(Frame frame, Object o1, Object o2) {
@@ -338,7 +408,7 @@
         if (nextSpecialization == null) {
             return unsupported(frame, o1, o2);
         }
-        return insertSpecialization(nextSpecialization, new RewriteEvent2(findRoot(), "inserted new specialization", o1, o2)).acceptAndExecute(frame, o1, o2);
+        return atomic(new InsertionEvent2(this, "insert new specialization", frame, o1, o2, nextSpecialization)).acceptAndExecute(frame, o1, o2);
     }
 
     protected final Object uninitialized(Frame frame, Object o1, Object o2, Object o3) {
@@ -350,7 +420,7 @@
         if (nextSpecialization == null) {
             return unsupported(frame, o1, o2, o3);
         }
-        return insertSpecialization(nextSpecialization, new RewriteEvent3(findRoot(), "inserted new specialization", o1, o2, o3)).acceptAndExecute(frame, o1, o2, o3);
+        return atomic(new InsertionEvent3(this, "insert new specialization", frame, o1, o2, o3, nextSpecialization)).acceptAndExecute(frame, o1, o2, o3);
     }
 
     protected final Object uninitialized(Frame frame, Object o1, Object o2, Object o3, Object o4) {
@@ -362,7 +432,7 @@
         if (nextSpecialization == null) {
             return unsupported(frame, o1, o2, o3, o4);
         }
-        return insertSpecialization(nextSpecialization, new RewriteEvent4(findRoot(), "inserts new specialization", o1, o2, o3, o4)).acceptAndExecute(frame, o1, o2, o3, o4);
+        return atomic(new InsertionEvent4(this, "insert new specialization", frame, o1, o2, o3, o4, nextSpecialization)).acceptAndExecute(frame, o1, o2, o3, o4);
     }
 
     protected final Object uninitialized(Frame frame, Object o1, Object o2, Object o3, Object o4, Object o5) {
@@ -374,7 +444,7 @@
         if (nextSpecialization == null) {
             unsupported(frame, o1, o2, o3, o4, o5);
         }
-        return insertSpecialization(nextSpecialization, new RewriteEventN(findRoot(), "inserts new specialization", o1, o2, o3, o4, o5)).acceptAndExecute(frame, o1, o2, o3, o4, o5);
+        return atomic(new InsertionEvent5(this, "insert new specialization", frame, o1, o2, o3, o4, o5, nextSpecialization)).acceptAndExecute(frame, o1, o2, o3, o4, o5);
     }
 
     protected final Object uninitialized(Frame frame, Object... args) {
@@ -386,39 +456,35 @@
         if (nextSpecialization == null) {
             unsupported(frame, args);
         }
-        return insertSpecialization(nextSpecialization, new RewriteEventN(findRoot(), "inserts new specialization", args)).acceptAndExecute(frame, args);
-    }
-
-    private boolean needsPolymorphic() {
-        return findStart().count() == 2;
+        return atomic(new InsertionEventN(this, "insert new specialization", frame, args, nextSpecialization)).acceptAndExecute(frame, args);
     }
 
     protected final Object remove(String reason, Frame frame) {
-        return removeSame(new RewriteEvent0(findRoot(), reason)).acceptAndExecute(frame);
+        return atomic(new RemoveEvent0(this, reason, frame)).acceptAndExecute(frame);
     }
 
     protected final Object remove(String reason, Frame frame, Object o1) {
-        return removeSame(new RewriteEvent1(findRoot(), reason, o1)).acceptAndExecute(frame, o1);
+        return atomic(new RemoveEvent1(this, reason, frame, o1)).acceptAndExecute(frame, o1);
     }
 
     protected final Object remove(String reason, Frame frame, Object o1, Object o2) {
-        return removeSame(new RewriteEvent2(findRoot(), reason, o1, o2)).acceptAndExecute(frame, o1, o2);
+        return atomic(new RemoveEvent2(this, reason, frame, o1, o2)).acceptAndExecute(frame, o1, o2);
     }
 
     protected final Object remove(String reason, Frame frame, Object o1, Object o2, Object o3) {
-        return removeSame(new RewriteEvent3(findRoot(), reason, o1, o2, o3)).acceptAndExecute(frame, o1, o2, o3);
+        return atomic(new RemoveEvent3(this, reason, frame, o1, o2, o3)).acceptAndExecute(frame, o1, o2, o3);
     }
 
     protected final Object remove(String reason, Frame frame, Object o1, Object o2, Object o3, Object o4) {
-        return removeSame(new RewriteEvent4(findRoot(), reason, o1, o2, o3, o4)).acceptAndExecute(frame, o1, o2, o3, o4);
+        return atomic(new RemoveEvent4(this, reason, frame, o1, o2, o3, o4)).acceptAndExecute(frame, o1, o2, o3, o4);
     }
 
     protected final Object remove(String reason, Frame frame, Object o1, Object o2, Object o3, Object o4, Object o5) {
-        return removeSame(new RewriteEventN(findRoot(), reason, o1, o2, o3, o4, o5)).acceptAndExecute(frame, o1, o2, o3, o4, o5);
+        return atomic(new RemoveEvent5(this, reason, frame, o1, o2, o3, o4, o5)).acceptAndExecute(frame, o1, o2, o3, o4, o5);
     }
 
     protected final Object remove(String reason, Frame frame, Object... args) {
-        return removeSame(new RewriteEventN(findRoot(), reason, args)).acceptAndExecute(frame, args);
+        return atomic(new RemoveEventN(this, reason, frame, args)).acceptAndExecute(frame, args);
     }
 
     protected Object unsupported(Frame frame) {
@@ -449,49 +515,24 @@
         throw new UnsupportedSpecializationException(findRoot(), getSuppliedChildren(), args);
     }
 
-    private SpecializationNode insertSpecialization(final SpecializationNode generated, final CharSequence message) {
-        return atomic(new Callable<SpecializationNode>() {
-            public SpecializationNode call() {
-                return insert(generated, message);
-            }
-        });
-    }
-
-    private SpecializationNode insert(final SpecializationNode generated, CharSequence message) {
-        SpecializationNode start = findStart();
-        if (start == this) {
-            // fast path for first insert
-            return insertBefore(this, generated, message);
-        } else {
-            return slowSortedInsert(start, generated, message);
-        }
-    }
-
-    private static <T> SpecializationNode slowSortedInsert(SpecializationNode start, final SpecializationNode generated, final CharSequence message) {
-        final SpecializationNode merged = start.merge(generated);
+    static SpecializationNode insertSorted(SpecializationNode start, final SpecializationNode generated, final CharSequence message, final SpecializationNode merged) {
         if (merged == generated) {
             // new node
             if (start.count() == 2) {
-                insertBefore(start, start.createPolymorphic(), "insert polymorphic");
+                insertAt(start, start.createPolymorphic(), "insert polymorphic");
             }
-            SpecializationNode insertBefore = findInsertBeforeNode(generated.index, start);
-            return insertBefore(insertBefore, generated, message);
+            SpecializationNode current = start;
+            while (current != null && current.index < generated.index) {
+                current = current.next;
+            }
+            return insertAt(current, generated, message);
         } else {
             // existing node
-            merged.replace(merged, new RewriteEvent0(merged.findRoot(), "merged specialization"));
             return start;
         }
     }
 
-    private static SpecializationNode findInsertBeforeNode(int generatedIndex, SpecializationNode start) {
-        SpecializationNode current = start;
-        while (current != null && current.index < generatedIndex) {
-            current = current.next;
-        }
-        return current;
-    }
-
-    private static <T> SpecializationNode insertBefore(SpecializationNode node, SpecializationNode insertBefore, CharSequence message) {
+    static <T> SpecializationNode insertAt(SpecializationNode node, SpecializationNode insertBefore, CharSequence message) {
         insertBefore.next = node;
         return node.replace(insertBefore, message);
     }
@@ -545,7 +586,6 @@
         b.append(")");
     }
 
-    // utilities for generated code
     protected static void check(Assumption assumption) throws InvalidAssumptionException {
         if (assumption != null) {
             assumption.check();
@@ -580,4 +620,226 @@
         return true;
     }
 
+    private static final class InsertionEvent0 extends SlowPathEvent0 implements Callable<SpecializationNode> {
+
+        private final SpecializationNode next;
+
+        public InsertionEvent0(SpecializationNode source, String reason, Frame frame, SpecializationNode next) {
+            super(source, reason, frame);
+            this.next = next;
+        }
+
+        public SpecializationNode call() throws Exception {
+            SpecializationNode start = source.findStart();
+            if (start.index == Integer.MAX_VALUE) {
+                return insertAt(start, next, this);
+            } else {
+                return insertSorted(start, next, this, start.merge(next, frame));
+            }
+        }
+
+    }
+
+    private static final class InsertionEvent1 extends SlowPathEvent1 implements Callable<SpecializationNode> {
+
+        private final SpecializationNode next;
+
+        public InsertionEvent1(SpecializationNode source, String reason, Frame frame, Object o1, SpecializationNode next) {
+            super(source, reason, frame, o1);
+            this.next = next;
+        }
+
+        public SpecializationNode call() throws Exception {
+            SpecializationNode start = source.findStart();
+            if (start.index == Integer.MAX_VALUE) {
+                return insertAt(start, next, this);
+            } else {
+                return insertSorted(start, next, this, start.merge(next, frame, o1));
+            }
+        }
+
+    }
+
+    private static final class InsertionEvent2 extends SlowPathEvent2 implements Callable<SpecializationNode> {
+
+        private final SpecializationNode next;
+
+        public InsertionEvent2(SpecializationNode source, String reason, Frame frame, Object o1, Object o2, SpecializationNode next) {
+            super(source, reason, frame, o1, o2);
+            this.next = next;
+        }
+
+        public SpecializationNode call() throws Exception {
+            SpecializationNode start = source.findStart();
+            if (start.index == Integer.MAX_VALUE) {
+                return insertAt(start, next, this);
+            } else {
+                return insertSorted(start, next, this, start.merge(next, frame, o1, o2));
+            }
+        }
+
+    }
+
+    private static final class InsertionEvent3 extends SlowPathEvent3 implements Callable<SpecializationNode> {
+
+        private final SpecializationNode next;
+
+        public InsertionEvent3(SpecializationNode source, String reason, Frame frame, Object o1, Object o2, Object o3, SpecializationNode next) {
+            super(source, reason, frame, o1, o2, o3);
+            this.next = next;
+        }
+
+        public SpecializationNode call() throws Exception {
+            SpecializationNode start = source.findStart();
+            if (start.index == Integer.MAX_VALUE) {
+                return insertAt(start, next, this);
+            } else {
+                return insertSorted(start, next, this, start.merge(next, frame, o1, o2, o3));
+            }
+        }
+
+    }
+
+    private static final class InsertionEvent4 extends SlowPathEvent4 implements Callable<SpecializationNode> {
+
+        private final SpecializationNode next;
+
+        public InsertionEvent4(SpecializationNode source, String reason, Frame frame, Object o1, Object o2, Object o3, Object o4, SpecializationNode next) {
+            super(source, reason, frame, o1, o2, o3, o4);
+            this.next = next;
+        }
+
+        public SpecializationNode call() throws Exception {
+            SpecializationNode start = source.findStart();
+            if (start.index == Integer.MAX_VALUE) {
+                return insertAt(start, next, this);
+            } else {
+                return insertSorted(start, next, this, start.merge(next, frame, o1, o2, o3, o4));
+            }
+        }
+
+    }
+
+    private static final class InsertionEvent5 extends SlowPathEvent5 implements Callable<SpecializationNode> {
+
+        private final SpecializationNode next;
+
+        public InsertionEvent5(SpecializationNode source, String reason, Frame frame, Object o1, Object o2, Object o3, Object o4, Object o5, SpecializationNode next) {
+            super(source, reason, frame, o1, o2, o3, o4, o5);
+            this.next = next;
+        }
+
+        public SpecializationNode call() throws Exception {
+            SpecializationNode start = source.findStart();
+            if (start.index == Integer.MAX_VALUE) {
+                return insertAt(start, next, this);
+            } else {
+                return insertSorted(start, next, this, start.merge(next, frame, o1, o2, o3, o4, o5));
+            }
+        }
+
+    }
+
+    private static final class InsertionEventN extends SlowPathEventN implements Callable<SpecializationNode> {
+
+        private final SpecializationNode next;
+
+        public InsertionEventN(SpecializationNode source, String reason, Frame frame, Object[] args, SpecializationNode next) {
+            super(source, reason, frame, args);
+            this.next = next;
+        }
+
+        public SpecializationNode call() throws Exception {
+            SpecializationNode start = source.findStart();
+            if (start.index == Integer.MAX_VALUE) {
+                return insertAt(start, next, this);
+            } else {
+                return insertSorted(start, next, this, start.merge(next, frame, args));
+            }
+        }
+    }
+
+    private static final class RemoveEvent0 extends SlowPathEvent0 implements Callable<SpecializationNode> {
+
+        public RemoveEvent0(SpecializationNode source, String reason, Frame frame) {
+            super(source, reason, frame);
+        }
+
+        public SpecializationNode call() throws Exception {
+            return removeSameImpl(source, this);
+        }
+
+    }
+
+    private static final class RemoveEvent1 extends SlowPathEvent1 implements Callable<SpecializationNode> {
+
+        public RemoveEvent1(SpecializationNode source, String reason, Frame frame, Object o1) {
+            super(source, reason, frame, o1);
+        }
+
+        public SpecializationNode call() throws Exception {
+            return removeSameImpl(source, this);
+        }
+
+    }
+
+    private static final class RemoveEvent2 extends SlowPathEvent2 implements Callable<SpecializationNode> {
+
+        public RemoveEvent2(SpecializationNode source, String reason, Frame frame, Object o1, Object o2) {
+            super(source, reason, frame, o1, o2);
+        }
+
+        public SpecializationNode call() throws Exception {
+            return removeSameImpl(source, this);
+        }
+
+    }
+
+    private static final class RemoveEvent3 extends SlowPathEvent3 implements Callable<SpecializationNode> {
+
+        public RemoveEvent3(SpecializationNode source, String reason, Frame frame, Object o1, Object o2, Object o3) {
+            super(source, reason, frame, o1, o2, o3);
+        }
+
+        public SpecializationNode call() throws Exception {
+            return removeSameImpl(source, this);
+        }
+
+    }
+
+    private static final class RemoveEvent4 extends SlowPathEvent4 implements Callable<SpecializationNode> {
+
+        public RemoveEvent4(SpecializationNode source, String reason, Frame frame, Object o1, Object o2, Object o3, Object o4) {
+            super(source, reason, frame, o1, o2, o3, o4);
+        }
+
+        public SpecializationNode call() throws Exception {
+            return removeSameImpl(source, this);
+        }
+
+    }
+
+    private static final class RemoveEvent5 extends SlowPathEvent5 implements Callable<SpecializationNode> {
+
+        public RemoveEvent5(SpecializationNode source, String reason, Frame frame, Object o1, Object o2, Object o3, Object o4, Object o5) {
+            super(source, reason, frame, o1, o2, o3, o4, o5);
+        }
+
+        public SpecializationNode call() throws Exception {
+            return removeSameImpl(source, this);
+        }
+
+    }
+
+    private static final class RemoveEventN extends SlowPathEventN implements Callable<SpecializationNode> {
+
+        public RemoveEventN(SpecializationNode source, String reason, Frame frame, Object[] args) {
+            super(source, reason, frame, args);
+        }
+
+        public SpecializationNode call() throws Exception {
+            return removeSameImpl(source, this);
+        }
+    }
+
 }