diff graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/node/NodeCodeGenerator.java @ 11201:7fc3e1fb3965

Truffle-DSL: specialization group fixes.
author Christian Humer <christian.humer@gmail.com>
date Mon, 05 Aug 2013 19:50:34 +0200
parents 3479ab380552
children 80de3bbfa8b9
line wrap: on
line diff
--- a/graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/node/NodeCodeGenerator.java	Mon Aug 05 19:48:15 2013 +0200
+++ b/graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/node/NodeCodeGenerator.java	Mon Aug 05 19:50:34 2013 +0200
@@ -38,6 +38,7 @@
 import com.oracle.truffle.dsl.processor.ast.*;
 import com.oracle.truffle.dsl.processor.node.NodeChildData.Cardinality;
 import com.oracle.truffle.dsl.processor.node.NodeChildData.ExecutionKind;
+import com.oracle.truffle.dsl.processor.node.SpecializationGroup.TypeGuard;
 import com.oracle.truffle.dsl.processor.template.*;
 import com.oracle.truffle.dsl.processor.template.TemplateMethod.Signature;
 import com.oracle.truffle.dsl.processor.typesystem.*;
@@ -1103,8 +1104,9 @@
             NodeData node = specialization.getNode();
             CodeTypeElement clazz = getElement();
 
+            SpecializationGroup rootGroup = createSpecializationGroups(node);
+
             if (node.needsRewrites(context)) {
-
                 if (node.isPolymorphic()) {
 
                     CodeVariableElement var = new CodeVariableElement(modifiers(PROTECTED), clazz.asType(), "next0");
@@ -1129,12 +1131,12 @@
                     }
                 }
 
-                clazz.add(createGenericExecuteAndSpecialize(node));
+                clazz.add(createGenericExecuteAndSpecialize(node, rootGroup));
                 clazz.add(createInfoMessage(node));
             }
 
             if (node.getGenericSpecialization() != null && node.getGenericSpecialization().isReachable()) {
-                clazz.add(createGenericExecute(node));
+                clazz.add(createGenericExecute(node, rootGroup));
             }
         }
 
@@ -1415,7 +1417,7 @@
             return var;
         }
 
-        private CodeExecutableElement createGenericExecuteAndSpecialize(final NodeData node) {
+        private CodeExecutableElement createGenericExecuteAndSpecialize(final NodeData node, SpecializationGroup rootGroup) {
             TypeMirror genericReturnType = node.getGenericSpecialization().getReturnType().getType();
             CodeExecutableElement method = new CodeExecutableElement(modifiers(PROTECTED), genericReturnType, EXECUTE_SPECIALIZE_NAME);
             method.addParameter(new CodeVariableElement(getContext().getType(int.class), "minimumState"));
@@ -1442,6 +1444,30 @@
             addInternalValueParameterNames(builder, node.getGenericSpecialization(), node.getGenericSpecialization(), null, false, true);
             builder.end().end();
 
+            final String currentNodeVar = currentNode;
+            builder.tree(createExecuteTree(builder, node.getGenericSpecialization(), rootGroup, true, new CodeBlock<SpecializationData>() {
+
+                public CodeTree create(CodeTreeBuilder b, SpecializationData current) {
+                    return createGenericInvokeAndSpecialize(b, node.getGenericSpecialization(), current, currentNodeVar);
+                }
+            }));
+
+            boolean firstUnreachable = true;
+            for (SpecializationData current : node.getSpecializations()) {
+                if (current.isUninitialized() || current.isReachable()) {
+                    continue;
+                }
+                if (firstUnreachable) {
+                    emitEncounteredSynthetic(builder, current);
+                    firstUnreachable = false;
+                }
+            }
+            emitUnreachableSpecializations(builder, node);
+
+            return method;
+        }
+
+        private SpecializationGroup createSpecializationGroups(final NodeData node) {
             List<SpecializationData> specializations = node.getSpecializations();
             List<SpecializationData> filteredSpecializations = new ArrayList<>();
             for (SpecializationData current : specializations) {
@@ -1451,35 +1477,10 @@
                 filteredSpecializations.add(current);
             }
 
-            List<SpecializationGroup> groups = SpecializationGroup.create(filteredSpecializations);
-
-            final String currentNodeVar = currentNode;
-            for (SpecializationGroup group : groups) {
-                builder.tree(createExecuteTree(builder, node.getGenericSpecialization(), group, true, new CodeBlock<SpecializationData>() {
-
-                    public CodeTree create(CodeTreeBuilder b, SpecializationData current) {
-                        return createGenericInvokeAndSpecialize(b, node.getGenericSpecialization(), current, currentNodeVar);
-                    }
-                }));
-            }
-
-            boolean firstUnreachable = true;
-            for (SpecializationData current : specializations) {
-                if (current.isUninitialized() || current.isReachable()) {
-                    continue;
-                }
-                if (firstUnreachable) {
-                    emitEncounteredSynthetic(builder, current);
-                    firstUnreachable = false;
-                }
-
-                builder.string("// unreachable ").string(current.getId()).newLine();
-            }
-
-            return method;
+            return SpecializationGroup.create(filteredSpecializations);
         }
 
-        private CodeExecutableElement createGenericExecute(NodeData node) {
+        private CodeExecutableElement createGenericExecute(NodeData node, SpecializationGroup group) {
             TypeMirror genericReturnType = node.getGenericSpecialization().getReturnType().getType();
             CodeExecutableElement method = new CodeExecutableElement(modifiers(PROTECTED), genericReturnType, EXECUTE_GENERIC_NAME);
 
@@ -1488,34 +1489,25 @@
             addInternalValueParameters(method, node.getGenericSpecialization(), node.needsFrame(), false);
             final CodeTreeBuilder builder = method.createBuilder();
 
-            List<SpecializationData> specializations = node.getSpecializations();
-            List<SpecializationData> filteredSpecializations = new ArrayList<>();
-            for (SpecializationData current : specializations) {
-                if (current.isUninitialized() || !current.isReachable()) {
-                    continue;
+            builder.tree(createExecuteTree(builder, node.getGenericSpecialization(), group, false, new CodeBlock<SpecializationData>() {
+
+                public CodeTree create(CodeTreeBuilder b, SpecializationData current) {
+                    return createGenericInvoke(builder, current.getNode().getGenericSpecialization(), current);
                 }
-                filteredSpecializations.add(current);
-            }
-
-            List<SpecializationGroup> groups = SpecializationGroup.create(filteredSpecializations);
-
-            for (SpecializationGroup group : groups) {
-                builder.tree(createExecuteTree(builder, node.getGenericSpecialization(), group, false, new CodeBlock<SpecializationData>() {
-
-                    public CodeTree create(CodeTreeBuilder b, SpecializationData current) {
-                        return createGenericInvoke(builder, current.getNode().getGenericSpecialization(), current);
-                    }
-                }));
-            }
-
-            for (SpecializationData current : specializations) {
+            }));
+
+            emitUnreachableSpecializations(builder, node);
+
+            return method;
+        }
+
+        private void emitUnreachableSpecializations(final CodeTreeBuilder builder, NodeData node) {
+            for (SpecializationData current : node.getSpecializations()) {
                 if (current.isUninitialized() || current.isReachable()) {
                     continue;
                 }
                 builder.string("// unreachable ").string(current.getId()).newLine();
             }
-
-            return method;
         }
 
         private CodeTree createExecuteTree(CodeTreeBuilder outerParent, final SpecializationData source, final SpecializationGroup group, final boolean checkMinimumState,
@@ -1541,15 +1533,52 @@
         }
 
         private CodeTree guard(CodeTreeBuilder parent, SpecializationData source, SpecializationGroup group, boolean checkMinimumState, CodeBlock<Void> bodyBlock) {
+            CodeTreeBuilder builder = parent.create();
+
+            int ifCount = emitGuards(builder, source, group, checkMinimumState);
+
+            if (isReachableGroup(group, ifCount, checkMinimumState)) {
+                builder.tree(bodyBlock.create(builder, null));
+            }
+
+            builder.end(ifCount);
+
+            return builder.getRoot();
+        }
+
+        private boolean isReachableGroup(SpecializationGroup group, int ifCount, boolean checkMinimumState) {
+            if (ifCount != 0) {
+                return true;
+            }
+            SpecializationGroup previous = group.getPreviousGroup();
+            if (previous == null || previous.getElseConnectableGuard() == null) {
+                return true;
+            }
+
+            /*
+             * Hacky else case. In this case the specialization is not reachable due to previous
+             * else branch. This is only true if the minimum state is not checked.
+             */
+            if (previous.getGuards().size() == 1 && previous.getTypeGuards().isEmpty() && previous.getAssumptions().isEmpty() && !checkMinimumState &&
+                            (previous.getParent() == null || previous.getMaxSpecializationIndex() != previous.getParent().getMaxSpecializationIndex())) {
+                return false;
+            }
+
+            return true;
+        }
+
+        private int emitGuards(CodeTreeBuilder builder, SpecializationData source, SpecializationGroup group, boolean checkMinimumState) {
             NodeData node = source.getNode();
 
-            CodeTreeBuilder guardsBuilder = parent.create();
-            CodeTreeBuilder castBuilder = parent.create();
-            CodeTreeBuilder guardsCastBuilder = parent.create();
+            CodeTreeBuilder guardsBuilder = builder.create();
+            CodeTreeBuilder castBuilder = builder.create();
+            CodeTreeBuilder guardsCastBuilder = builder.create();
 
             String guardsAnd = "";
             String guardsCastAnd = "";
 
+            GuardData elseGuard = group.getElseConnectableGuard();
+
             boolean minimumState = checkMinimumState;
             if (minimumState) {
                 int groupMaxIndex = group.getMaxSpecializationIndex();
@@ -1585,11 +1614,8 @@
                 guardsAnd = " && ";
             }
 
-            int argOffset = group.getTypeGuardOffset();
-            int argIndex = argOffset;
-            for (TypeData typeData : group.getTypeGuards()) {
-
-                ActualParameter valueParam = source.getSignatureParameter(argIndex);
+            for (TypeGuard typeGuard : group.getTypeGuards()) {
+                ActualParameter valueParam = source.getSignatureParameter(typeGuard.getSignatureIndex());
 
                 if (valueParam == null) {
                     /*
@@ -1598,9 +1624,9 @@
                      * specialization.
                      */
                     if (group.getSpecialization() != null) {
-                        valueParam = group.getSpecialization().getSignatureParameter(argIndex);
+                        valueParam = group.getSpecialization().getSignatureParameter(typeGuard.getSignatureIndex());
                     } else {
-                        valueParam = node.getGenericSpecialization().getSignatureParameter(argIndex);
+                        valueParam = node.getGenericSpecialization().getSignatureParameter(typeGuard.getSignatureIndex());
                     }
                 }
 
@@ -1609,93 +1635,68 @@
                     throw new IllegalStateException();
                 }
 
-                CodeTree implicitGuard = createTypeGuard(guardsBuilder, child, valueParam, typeData);
+                CodeTree implicitGuard = createTypeGuard(guardsBuilder, child, valueParam, typeGuard.getType());
                 if (implicitGuard != null) {
                     guardsBuilder.string(guardsAnd);
                     guardsBuilder.tree(implicitGuard);
                     guardsAnd = " && ";
                 }
 
-                CodeTree cast = createCast(castBuilder, child, valueParam, typeData);
+                CodeTree cast = createCast(castBuilder, child, valueParam, typeGuard.getType());
                 if (cast != null) {
                     castBuilder.tree(cast);
                 }
-
-                argIndex++;
             }
-            CodeTreeBuilder builder = parent.create();
-
-            int ifCount = 0;
-            if (isElseConnectableGroup(group)) {
-                if (minimumState) {
-                    builder.startElseIf().tree(guardsBuilder.getRoot()).end().startBlock();
-                } else {
-                    builder.startElseBlock();
+
+            for (GuardData guard : group.getGuards()) {
+                if (elseGuard == guard) {
+                    continue;
                 }
-                ifCount++;
-
-            } else {
-                for (GuardData guard : group.getGuards()) {
-                    if (needsTypeGuard(source, group, guard)) {
-                        guardsCastBuilder.tree(createMethodGuard(parent, guardsCastAnd, source, guard));
-                        guardsCastAnd = " && ";
-                    } else {
-                        guardsBuilder.tree(createMethodGuard(parent, guardsAnd, source, guard));
-                        guardsAnd = " && ";
-                    }
-                }
-
-                if (!guardsBuilder.isEmpty()) {
-                    builder.startIf().tree(guardsBuilder.getRoot()).end().startBlock();
-                    ifCount++;
-                }
-                builder.tree(castBuilder.getRoot());
-
-                if (!guardsCastBuilder.isEmpty()) {
-                    builder.startIf().tree(guardsCastBuilder.getRoot()).end().startBlock();
-                    ifCount++;
+
+                if (needsTypeGuard(source, group, guard)) {
+                    guardsCastBuilder.tree(createMethodGuard(builder, guardsCastAnd, source, guard));
+                    guardsCastAnd = " && ";
+                } else {
+                    guardsBuilder.tree(createMethodGuard(builder, guardsAnd, source, guard));
+                    guardsAnd = " && ";
                 }
             }
 
-            builder.tree(bodyBlock.create(builder, null));
-
-            builder.end(ifCount);
-            return builder.getRoot();
+            int ifCount = startGuardIf(builder, guardsBuilder, 0, elseGuard);
+            builder.tree(castBuilder.getRoot());
+            ifCount = startGuardIf(builder, guardsCastBuilder, ifCount, elseGuard);
+            return ifCount;
         }
 
-        private boolean isElseConnectableGroup(SpecializationGroup group) {
-            if (!group.getTypeGuards().isEmpty() || !group.getAssumptions().isEmpty()) {
-                return false;
+        private int startGuardIf(CodeTreeBuilder builder, CodeTreeBuilder conditionBuilder, int ifCount, GuardData elseGuard) {
+            int newIfCount = ifCount;
+
+            if (!conditionBuilder.isEmpty()) {
+                if (ifCount == 0 && elseGuard != null) {
+                    builder.startElseIf();
+                } else {
+                    builder.startIf();
+                }
+                builder.tree(conditionBuilder.getRoot());
+                builder.end().startBlock();
+                newIfCount++;
+            } else if (ifCount == 0 && elseGuard != null) {
+                builder.startElseBlock();
+                newIfCount++;
             }
-
-            SpecializationGroup previousGroup = group.getPreviousGroup();
-            if (previousGroup != null && group.getGuards().size() == 1 && previousGroup.getGuards().size() == 1) {
-                GuardData guard = group.getGuards().get(0);
-                GuardData previousGuard = previousGroup.getGuards().get(0);
-
-                if (guard.getMethod().equals(previousGuard.getMethod())) {
-                    assert guard.isNegated() != previousGuard.isNegated();
-                    return true;
-                }
-            }
-            return false;
+            return newIfCount;
         }
 
         private boolean needsTypeGuard(SpecializationData source, SpecializationGroup group, GuardData guard) {
-            int offset = group.getTypeGuardOffset();
-            int argIndex = 0;
+            int signatureIndex = 0;
             for (ActualParameter parameter : guard.getParameters()) {
                 if (!parameter.getSpecification().isSignature()) {
                     continue;
                 }
-                if (argIndex < offset) {
-                    // type casted in parent group
-                    continue;
-                }
-
-                int guardIndex = argIndex - offset;
-                if (guardIndex < group.getTypeGuards().size()) {
-                    TypeData requiredType = group.getTypeGuards().get(guardIndex);
+
+                TypeGuard typeGuard = group.findTypeGuard(signatureIndex);
+                if (typeGuard != null) {
+                    TypeData requiredType = typeGuard.getType();
 
                     ActualParameter sourceParameter = source.findParameter(parameter.getLocalName());
                     if (sourceParameter == null) {
@@ -1706,7 +1707,8 @@
                         return true;
                     }
                 }
-                argIndex++;
+
+                signatureIndex++;
             }
             return false;
         }
@@ -1749,10 +1751,10 @@
             } else {
                 // simple rewrite
                 if (current.getExceptions().isEmpty()) {
-                    builder.tree(createGenericInvoke(builder, source, current, createReplaceCall(builder, current, currentNodeVar, currentNodeVar, null)));
+                    builder.tree(createGenericInvoke(builder, source, current, createReplaceCall(builder, current, currentNodeVar, currentNodeVar, null), null));
                 } else {
                     builder.startStatement().string(currentNodeVar).string(" = ").tree(createReplaceCall(builder, current, currentNodeVar, currentNodeVar, null)).end();
-                    builder.tree(createGenericInvoke(builder, source, current, CodeTreeBuilder.singleString(currentNodeVar)));
+                    builder.tree(createGenericInvoke(builder, source, current, null, CodeTreeBuilder.singleString(currentNodeVar)));
                 }
             }
             CodeTreeBuilder root = parent.create();
@@ -1770,7 +1772,7 @@
             builder.tree(createFindRoot(builder, node, false));
             builder.end();
             builder.end();
-            builder.tree(createGenericInvoke(builder, source, current, createReplaceCall(builder, current, "root", currentNode, null)));
+            builder.tree(createGenericInvoke(builder, source, current, createReplaceCall(builder, current, "root", "(" + baseClassName(node) + ") root", null), null));
             return builder.getRoot();
         }
 
@@ -1806,22 +1808,30 @@
             return builder.getRoot();
         }
 
-        protected CodeTree createGenericInvoke(CodeTreeBuilder parent, SpecializationData source, SpecializationData current, CodeTree replaceCall) {
+        protected CodeTree createGenericInvoke(CodeTreeBuilder parent, SpecializationData source, SpecializationData current, CodeTree replaceCall, CodeTree replaceVar) {
+            assert replaceCall == null || replaceVar == null;
             CodeTreeBuilder builder = parent.create();
+            CodeTree replace = replaceVar;
+            if (replace == null) {
+                replace = replaceCall;
+            }
             if (current.isGeneric()) {
-                builder.startReturn().tree(replaceCall).string(".").startCall(EXECUTE_GENERIC_NAME);
+                builder.startReturn().tree(replace).string(".").startCall(EXECUTE_GENERIC_NAME);
                 addInternalValueParameterNames(builder, source, current, null, current.getNode().needsFrame(), true);
                 builder.end().end();
-
             } else if (current.getMethod() == null) {
-                builder.statement(replaceCall);
+                if (replaceCall != null) {
+                    builder.statement(replaceCall);
+                }
                 emitEncounteredSynthetic(builder, current);
             } else if (!current.canBeAccessedByInstanceOf(getContext(), source.getNode().getNodeType())) {
-                builder.statement(replaceCall);
+                if (replaceCall != null) {
+                    builder.statement(replaceCall);
+                }
                 builder.startReturn().tree(createTemplateMethodCall(parent, null, source, current, null)).end();
             } else {
-                replaceCall.add(new CodeTree(CodeTreeKind.STRING, null, "."));
-                builder.startReturn().tree(createTemplateMethodCall(parent, replaceCall, source, current, null)).end();
+                replace.add(new CodeTree(CodeTreeKind.STRING, null, "."));
+                builder.startReturn().tree(createTemplateMethodCall(parent, replace, source, current, null)).end();
             }
             return builder.getRoot();
         }
@@ -2494,7 +2504,7 @@
             builder.startBlock();
             String message = ("Polymorphic limit reached (" + node.getPolymorphicDepth() + ")");
             builder.tree(createGenericInvoke(builder, node.getGenericPolymorphicSpecialization(), node.getGenericSpecialization(),
-                            createReplaceCall(builder, node.getGenericSpecialization(), "root", "this", message)));
+                            createReplaceCall(builder, node.getGenericSpecialization(), "root", "this", message), null));
             builder.end();
 
             builder.startElseBlock();