Mercurial > hg > truffle
diff graal/com.oracle.truffle.api.dsl/src/com/oracle/truffle/api/dsl/internal/SpecializationNode.java @ 19757:e8d2f3f95dcd
Truffle-DSL: implemented duplication check for specializations with @Cached to avoid duplicates for multithreaded AST execution.
author | Christian Humer <christian.humer@gmail.com> |
---|---|
date | Tue, 10 Mar 2015 19:28:26 +0100 |
parents | f4792a544170 |
children | f682b9e6ca07 |
line wrap: on
line diff
--- a/graal/com.oracle.truffle.api.dsl/src/com/oracle/truffle/api/dsl/internal/SpecializationNode.java Tue Mar 10 13:47:46 2015 +0100 +++ b/graal/com.oracle.truffle.api.dsl/src/com/oracle/truffle/api/dsl/internal/SpecializationNode.java Tue Mar 10 19:28:26 2015 +0100 @@ -30,12 +30,13 @@ import com.oracle.truffle.api.*; import com.oracle.truffle.api.dsl.*; -import com.oracle.truffle.api.dsl.internal.RewriteEvent.RewriteEvent0; -import com.oracle.truffle.api.dsl.internal.RewriteEvent.RewriteEvent1; -import com.oracle.truffle.api.dsl.internal.RewriteEvent.RewriteEvent2; -import com.oracle.truffle.api.dsl.internal.RewriteEvent.RewriteEvent3; -import com.oracle.truffle.api.dsl.internal.RewriteEvent.RewriteEvent4; -import com.oracle.truffle.api.dsl.internal.RewriteEvent.RewriteEventN; +import com.oracle.truffle.api.dsl.internal.SlowPathEvent.SlowPathEvent0; +import com.oracle.truffle.api.dsl.internal.SlowPathEvent.SlowPathEvent1; +import com.oracle.truffle.api.dsl.internal.SlowPathEvent.SlowPathEvent2; +import com.oracle.truffle.api.dsl.internal.SlowPathEvent.SlowPathEvent3; +import com.oracle.truffle.api.dsl.internal.SlowPathEvent.SlowPathEvent4; +import com.oracle.truffle.api.dsl.internal.SlowPathEvent.SlowPathEvent5; +import com.oracle.truffle.api.dsl.internal.SlowPathEvent.SlowPathEventN; import com.oracle.truffle.api.frame.*; import com.oracle.truffle.api.nodes.*; import com.oracle.truffle.api.nodes.NodeUtil.NodeClass; @@ -44,13 +45,13 @@ /** * Internal implementation dependent base class for generated specialized nodes. */ +@NodeInfo(cost = NodeCost.NONE) @SuppressWarnings("unused") -@NodeInfo(cost = NodeCost.NONE) public abstract class SpecializationNode extends Node { @Child protected SpecializationNode next; - private final int index; + final int index; public SpecializationNode() { this(-1); @@ -92,10 +93,9 @@ } } - protected final SpecializationNode polymorphicMerge(SpecializationNode newNode) { - SpecializationNode merged = next.merge(newNode); - if (merged == newNode && !isSame(newNode) && count() <= 2) { - return removeSame(new RewriteEvent0(findRoot(), "merged polymorphic to monomorphic")); + protected final SpecializationNode polymorphicMerge(SpecializationNode newNode, SpecializationNode merged) { + if (merged == newNode && count() <= 2) { + return removeSame(new SlowPathEvent0(this, "merged polymorphic to monomorphic", null)); } return merged; } @@ -114,15 +114,85 @@ protected abstract Node[] getSuppliedChildren(); - protected SpecializationNode merge(SpecializationNode newNode) { - if (this.isSame(newNode)) { + protected SpecializationNode merge(SpecializationNode newNode, Frame frame) { + if (isIdentical(newNode, frame)) { + return this; + } + return next != null ? next.merge(newNode, frame) : newNode; + } + + protected SpecializationNode merge(SpecializationNode newNode, Frame frame, Object o1) { + if (isIdentical(newNode, frame, o1)) { + return this; + } + return next != null ? next.merge(newNode, frame, o1) : newNode; + } + + protected SpecializationNode merge(SpecializationNode newNode, Frame frame, Object o1, Object o2) { + if (isIdentical(newNode, frame, o1, o2)) { + return this; + } + return next != null ? next.merge(newNode, frame, o1, o2) : newNode; + } + + protected SpecializationNode merge(SpecializationNode newNode, Frame frame, Object o1, Object o2, Object o3) { + if (isIdentical(newNode, frame, o1, o2, o3)) { + return this; + } + return next != null ? next.merge(newNode, frame, o1, o2, o3) : newNode; + } + + protected SpecializationNode merge(SpecializationNode newNode, Frame frame, Object o1, Object o2, Object o3, Object o4) { + if (isIdentical(newNode, frame, o1, o2, o3, o4)) { + return this; + } + return next != null ? next.merge(newNode, frame, o1, o2, o3, o4) : newNode; + } + + protected SpecializationNode merge(SpecializationNode newNode, Frame frame, Object o1, Object o2, Object o3, Object o4, Object o5) { + if (isIdentical(newNode, frame, o1, o2, o3, o4, o5)) { return this; } - return next != null ? next.merge(newNode) : newNode; + return next != null ? next.merge(newNode, frame, o1, o2, o3, o4, o5) : newNode; + } + + protected SpecializationNode merge(SpecializationNode newNode, Frame frame, Object... args) { + if (isIdentical(newNode, frame, args)) { + return this; + } + return next != null ? next.merge(newNode, frame, args) : newNode; + } + + protected boolean isSame(SpecializationNode other) { + return getClass() == other.getClass(); + } + + protected boolean isIdentical(SpecializationNode newNode, Frame frame) { + return isSame(newNode); } - protected SpecializationNode mergeNoSame(SpecializationNode newNode) { - return next != null ? next.merge(newNode) : newNode; + protected boolean isIdentical(SpecializationNode newNode, Frame frame, Object o1) { + return isSame(newNode); + } + + protected boolean isIdentical(SpecializationNode newNode, Frame frame, Object o1, Object o2) { + return isSame(newNode); + } + + protected boolean isIdentical(SpecializationNode newNode, Frame frame, Object o1, Object o2, Object o3) { + return isSame(newNode); + } + + protected boolean isIdentical(SpecializationNode newNode, Frame frame, Object o1, Object o2, Object o3, Object o4) { + return isSame(newNode); + } + + protected boolean isIdentical(SpecializationNode newNode, Frame frame, Object o1, Object o2, Object o3, Object o4, Object o5) { + return isSame(newNode); + } + + protected boolean isIdentical(SpecializationNode newNode, Frame frame, Object... args) { + return isSame(newNode); } protected final int countSame(SpecializationNode node) { @@ -150,10 +220,6 @@ return index; } - protected boolean isSame(SpecializationNode other) { - return getClass() == other.getClass(); - } - private int count() { return next != null ? next.count() + 1 : 1; } @@ -226,8 +292,12 @@ return findStart().getParent(); } - private SpecializationNode removeSameImpl(SpecializationNode toRemove, CharSequence reason) { - SpecializationNode start = findStart(); + private SpecializedNode findSpecializedNode() { + return (SpecializedNode) findEnd().findStart().getParent(); + } + + private static SpecializationNode removeSameImpl(SpecializationNode toRemove, CharSequence reason) { + SpecializationNode start = toRemove.findStart(); SpecializationNode current = start; while (current != null) { if (current.isSame(toRemove)) { @@ -238,7 +308,7 @@ } current = current.next; } - return findEnd().findStart(); + return toRemove.findEnd().findStart(); } public Object acceptAndExecute(Frame frame) { @@ -314,7 +384,7 @@ if (nextSpecialization == null) { return unsupported(frame); } - return insertSpecialization(nextSpecialization, new RewriteEvent0(findRoot(), "inserted new specialization")).acceptAndExecute(frame); + return atomic(new InsertionEvent0(this, "insert new specialization", frame, nextSpecialization)).acceptAndExecute(frame); } protected final Object uninitialized(Frame frame, Object o1) { @@ -326,7 +396,7 @@ if (nextSpecialization == null) { return unsupported(frame, o1); } - return insertSpecialization(nextSpecialization, new RewriteEvent1(findRoot(), "inserted new specialization", o1)).acceptAndExecute(frame, o1); + return atomic(new InsertionEvent1(this, "insert new specialization", frame, o1, nextSpecialization)).acceptAndExecute(frame, o1); } protected final Object uninitialized(Frame frame, Object o1, Object o2) { @@ -338,7 +408,7 @@ if (nextSpecialization == null) { return unsupported(frame, o1, o2); } - return insertSpecialization(nextSpecialization, new RewriteEvent2(findRoot(), "inserted new specialization", o1, o2)).acceptAndExecute(frame, o1, o2); + return atomic(new InsertionEvent2(this, "insert new specialization", frame, o1, o2, nextSpecialization)).acceptAndExecute(frame, o1, o2); } protected final Object uninitialized(Frame frame, Object o1, Object o2, Object o3) { @@ -350,7 +420,7 @@ if (nextSpecialization == null) { return unsupported(frame, o1, o2, o3); } - return insertSpecialization(nextSpecialization, new RewriteEvent3(findRoot(), "inserted new specialization", o1, o2, o3)).acceptAndExecute(frame, o1, o2, o3); + return atomic(new InsertionEvent3(this, "insert new specialization", frame, o1, o2, o3, nextSpecialization)).acceptAndExecute(frame, o1, o2, o3); } protected final Object uninitialized(Frame frame, Object o1, Object o2, Object o3, Object o4) { @@ -362,7 +432,7 @@ if (nextSpecialization == null) { return unsupported(frame, o1, o2, o3, o4); } - return insertSpecialization(nextSpecialization, new RewriteEvent4(findRoot(), "inserts new specialization", o1, o2, o3, o4)).acceptAndExecute(frame, o1, o2, o3, o4); + return atomic(new InsertionEvent4(this, "insert new specialization", frame, o1, o2, o3, o4, nextSpecialization)).acceptAndExecute(frame, o1, o2, o3, o4); } protected final Object uninitialized(Frame frame, Object o1, Object o2, Object o3, Object o4, Object o5) { @@ -374,7 +444,7 @@ if (nextSpecialization == null) { unsupported(frame, o1, o2, o3, o4, o5); } - return insertSpecialization(nextSpecialization, new RewriteEventN(findRoot(), "inserts new specialization", o1, o2, o3, o4, o5)).acceptAndExecute(frame, o1, o2, o3, o4, o5); + return atomic(new InsertionEvent5(this, "insert new specialization", frame, o1, o2, o3, o4, o5, nextSpecialization)).acceptAndExecute(frame, o1, o2, o3, o4, o5); } protected final Object uninitialized(Frame frame, Object... args) { @@ -386,39 +456,35 @@ if (nextSpecialization == null) { unsupported(frame, args); } - return insertSpecialization(nextSpecialization, new RewriteEventN(findRoot(), "inserts new specialization", args)).acceptAndExecute(frame, args); - } - - private boolean needsPolymorphic() { - return findStart().count() == 2; + return atomic(new InsertionEventN(this, "insert new specialization", frame, args, nextSpecialization)).acceptAndExecute(frame, args); } protected final Object remove(String reason, Frame frame) { - return removeSame(new RewriteEvent0(findRoot(), reason)).acceptAndExecute(frame); + return atomic(new RemoveEvent0(this, reason, frame)).acceptAndExecute(frame); } protected final Object remove(String reason, Frame frame, Object o1) { - return removeSame(new RewriteEvent1(findRoot(), reason, o1)).acceptAndExecute(frame, o1); + return atomic(new RemoveEvent1(this, reason, frame, o1)).acceptAndExecute(frame, o1); } protected final Object remove(String reason, Frame frame, Object o1, Object o2) { - return removeSame(new RewriteEvent2(findRoot(), reason, o1, o2)).acceptAndExecute(frame, o1, o2); + return atomic(new RemoveEvent2(this, reason, frame, o1, o2)).acceptAndExecute(frame, o1, o2); } protected final Object remove(String reason, Frame frame, Object o1, Object o2, Object o3) { - return removeSame(new RewriteEvent3(findRoot(), reason, o1, o2, o3)).acceptAndExecute(frame, o1, o2, o3); + return atomic(new RemoveEvent3(this, reason, frame, o1, o2, o3)).acceptAndExecute(frame, o1, o2, o3); } protected final Object remove(String reason, Frame frame, Object o1, Object o2, Object o3, Object o4) { - return removeSame(new RewriteEvent4(findRoot(), reason, o1, o2, o3, o4)).acceptAndExecute(frame, o1, o2, o3, o4); + return atomic(new RemoveEvent4(this, reason, frame, o1, o2, o3, o4)).acceptAndExecute(frame, o1, o2, o3, o4); } protected final Object remove(String reason, Frame frame, Object o1, Object o2, Object o3, Object o4, Object o5) { - return removeSame(new RewriteEventN(findRoot(), reason, o1, o2, o3, o4, o5)).acceptAndExecute(frame, o1, o2, o3, o4, o5); + return atomic(new RemoveEvent5(this, reason, frame, o1, o2, o3, o4, o5)).acceptAndExecute(frame, o1, o2, o3, o4, o5); } protected final Object remove(String reason, Frame frame, Object... args) { - return removeSame(new RewriteEventN(findRoot(), reason, args)).acceptAndExecute(frame, args); + return atomic(new RemoveEventN(this, reason, frame, args)).acceptAndExecute(frame, args); } protected Object unsupported(Frame frame) { @@ -449,49 +515,24 @@ throw new UnsupportedSpecializationException(findRoot(), getSuppliedChildren(), args); } - private SpecializationNode insertSpecialization(final SpecializationNode generated, final CharSequence message) { - return atomic(new Callable<SpecializationNode>() { - public SpecializationNode call() { - return insert(generated, message); - } - }); - } - - private SpecializationNode insert(final SpecializationNode generated, CharSequence message) { - SpecializationNode start = findStart(); - if (start == this) { - // fast path for first insert - return insertBefore(this, generated, message); - } else { - return slowSortedInsert(start, generated, message); - } - } - - private static <T> SpecializationNode slowSortedInsert(SpecializationNode start, final SpecializationNode generated, final CharSequence message) { - final SpecializationNode merged = start.merge(generated); + static SpecializationNode insertSorted(SpecializationNode start, final SpecializationNode generated, final CharSequence message, final SpecializationNode merged) { if (merged == generated) { // new node if (start.count() == 2) { - insertBefore(start, start.createPolymorphic(), "insert polymorphic"); + insertAt(start, start.createPolymorphic(), "insert polymorphic"); } - SpecializationNode insertBefore = findInsertBeforeNode(generated.index, start); - return insertBefore(insertBefore, generated, message); + SpecializationNode current = start; + while (current != null && current.index < generated.index) { + current = current.next; + } + return insertAt(current, generated, message); } else { // existing node - merged.replace(merged, new RewriteEvent0(merged.findRoot(), "merged specialization")); return start; } } - private static SpecializationNode findInsertBeforeNode(int generatedIndex, SpecializationNode start) { - SpecializationNode current = start; - while (current != null && current.index < generatedIndex) { - current = current.next; - } - return current; - } - - private static <T> SpecializationNode insertBefore(SpecializationNode node, SpecializationNode insertBefore, CharSequence message) { + static <T> SpecializationNode insertAt(SpecializationNode node, SpecializationNode insertBefore, CharSequence message) { insertBefore.next = node; return node.replace(insertBefore, message); } @@ -545,7 +586,6 @@ b.append(")"); } - // utilities for generated code protected static void check(Assumption assumption) throws InvalidAssumptionException { if (assumption != null) { assumption.check(); @@ -580,4 +620,226 @@ return true; } + private static final class InsertionEvent0 extends SlowPathEvent0 implements Callable<SpecializationNode> { + + private final SpecializationNode next; + + public InsertionEvent0(SpecializationNode source, String reason, Frame frame, SpecializationNode next) { + super(source, reason, frame); + this.next = next; + } + + public SpecializationNode call() throws Exception { + SpecializationNode start = source.findStart(); + if (start.index == Integer.MAX_VALUE) { + return insertAt(start, next, this); + } else { + return insertSorted(start, next, this, start.merge(next, frame)); + } + } + + } + + private static final class InsertionEvent1 extends SlowPathEvent1 implements Callable<SpecializationNode> { + + private final SpecializationNode next; + + public InsertionEvent1(SpecializationNode source, String reason, Frame frame, Object o1, SpecializationNode next) { + super(source, reason, frame, o1); + this.next = next; + } + + public SpecializationNode call() throws Exception { + SpecializationNode start = source.findStart(); + if (start.index == Integer.MAX_VALUE) { + return insertAt(start, next, this); + } else { + return insertSorted(start, next, this, start.merge(next, frame, o1)); + } + } + + } + + private static final class InsertionEvent2 extends SlowPathEvent2 implements Callable<SpecializationNode> { + + private final SpecializationNode next; + + public InsertionEvent2(SpecializationNode source, String reason, Frame frame, Object o1, Object o2, SpecializationNode next) { + super(source, reason, frame, o1, o2); + this.next = next; + } + + public SpecializationNode call() throws Exception { + SpecializationNode start = source.findStart(); + if (start.index == Integer.MAX_VALUE) { + return insertAt(start, next, this); + } else { + return insertSorted(start, next, this, start.merge(next, frame, o1, o2)); + } + } + + } + + private static final class InsertionEvent3 extends SlowPathEvent3 implements Callable<SpecializationNode> { + + private final SpecializationNode next; + + public InsertionEvent3(SpecializationNode source, String reason, Frame frame, Object o1, Object o2, Object o3, SpecializationNode next) { + super(source, reason, frame, o1, o2, o3); + this.next = next; + } + + public SpecializationNode call() throws Exception { + SpecializationNode start = source.findStart(); + if (start.index == Integer.MAX_VALUE) { + return insertAt(start, next, this); + } else { + return insertSorted(start, next, this, start.merge(next, frame, o1, o2, o3)); + } + } + + } + + private static final class InsertionEvent4 extends SlowPathEvent4 implements Callable<SpecializationNode> { + + private final SpecializationNode next; + + public InsertionEvent4(SpecializationNode source, String reason, Frame frame, Object o1, Object o2, Object o3, Object o4, SpecializationNode next) { + super(source, reason, frame, o1, o2, o3, o4); + this.next = next; + } + + public SpecializationNode call() throws Exception { + SpecializationNode start = source.findStart(); + if (start.index == Integer.MAX_VALUE) { + return insertAt(start, next, this); + } else { + return insertSorted(start, next, this, start.merge(next, frame, o1, o2, o3, o4)); + } + } + + } + + private static final class InsertionEvent5 extends SlowPathEvent5 implements Callable<SpecializationNode> { + + private final SpecializationNode next; + + public InsertionEvent5(SpecializationNode source, String reason, Frame frame, Object o1, Object o2, Object o3, Object o4, Object o5, SpecializationNode next) { + super(source, reason, frame, o1, o2, o3, o4, o5); + this.next = next; + } + + public SpecializationNode call() throws Exception { + SpecializationNode start = source.findStart(); + if (start.index == Integer.MAX_VALUE) { + return insertAt(start, next, this); + } else { + return insertSorted(start, next, this, start.merge(next, frame, o1, o2, o3, o4, o5)); + } + } + + } + + private static final class InsertionEventN extends SlowPathEventN implements Callable<SpecializationNode> { + + private final SpecializationNode next; + + public InsertionEventN(SpecializationNode source, String reason, Frame frame, Object[] args, SpecializationNode next) { + super(source, reason, frame, args); + this.next = next; + } + + public SpecializationNode call() throws Exception { + SpecializationNode start = source.findStart(); + if (start.index == Integer.MAX_VALUE) { + return insertAt(start, next, this); + } else { + return insertSorted(start, next, this, start.merge(next, frame, args)); + } + } + } + + private static final class RemoveEvent0 extends SlowPathEvent0 implements Callable<SpecializationNode> { + + public RemoveEvent0(SpecializationNode source, String reason, Frame frame) { + super(source, reason, frame); + } + + public SpecializationNode call() throws Exception { + return removeSameImpl(source, this); + } + + } + + private static final class RemoveEvent1 extends SlowPathEvent1 implements Callable<SpecializationNode> { + + public RemoveEvent1(SpecializationNode source, String reason, Frame frame, Object o1) { + super(source, reason, frame, o1); + } + + public SpecializationNode call() throws Exception { + return removeSameImpl(source, this); + } + + } + + private static final class RemoveEvent2 extends SlowPathEvent2 implements Callable<SpecializationNode> { + + public RemoveEvent2(SpecializationNode source, String reason, Frame frame, Object o1, Object o2) { + super(source, reason, frame, o1, o2); + } + + public SpecializationNode call() throws Exception { + return removeSameImpl(source, this); + } + + } + + private static final class RemoveEvent3 extends SlowPathEvent3 implements Callable<SpecializationNode> { + + public RemoveEvent3(SpecializationNode source, String reason, Frame frame, Object o1, Object o2, Object o3) { + super(source, reason, frame, o1, o2, o3); + } + + public SpecializationNode call() throws Exception { + return removeSameImpl(source, this); + } + + } + + private static final class RemoveEvent4 extends SlowPathEvent4 implements Callable<SpecializationNode> { + + public RemoveEvent4(SpecializationNode source, String reason, Frame frame, Object o1, Object o2, Object o3, Object o4) { + super(source, reason, frame, o1, o2, o3, o4); + } + + public SpecializationNode call() throws Exception { + return removeSameImpl(source, this); + } + + } + + private static final class RemoveEvent5 extends SlowPathEvent5 implements Callable<SpecializationNode> { + + public RemoveEvent5(SpecializationNode source, String reason, Frame frame, Object o1, Object o2, Object o3, Object o4, Object o5) { + super(source, reason, frame, o1, o2, o3, o4, o5); + } + + public SpecializationNode call() throws Exception { + return removeSameImpl(source, this); + } + + } + + private static final class RemoveEventN extends SlowPathEventN implements Callable<SpecializationNode> { + + public RemoveEventN(SpecializationNode source, String reason, Frame frame, Object[] args) { + super(source, reason, frame, args); + } + + public SpecializationNode call() throws Exception { + return removeSameImpl(source, this); + } + } + }