diff graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/generator/NodeGenFactory.java @ 19291:f4792a544170

Truffle-DSL: implement new assumptions semantics.
author Christian Humer <christian.humer@gmail.com>
date Wed, 11 Feb 2015 12:13:44 +0100
parents 62c43fcf5be2
children 21b9b9941775
line wrap: on
line diff
--- a/graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/generator/NodeGenFactory.java	Wed Feb 11 12:13:44 2015 +0100
+++ b/graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/generator/NodeGenFactory.java	Wed Feb 11 12:13:44 2015 +0100
@@ -32,7 +32,6 @@
 import javax.lang.model.type.*;
 import javax.lang.model.util.*;
 
-import com.oracle.truffle.api.*;
 import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
 import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
 import com.oracle.truffle.api.dsl.*;
@@ -74,11 +73,10 @@
         this.genericType = typeSystem.getGenericTypeData();
         this.options = typeSystem.getOptions();
         this.singleSpecializable = isSingleSpecializableImpl();
-        this.varArgsThreshold = calculateVarArgsThresHold();
-
+        this.varArgsThreshold = calculateVarArgsThreshold();
     }
 
-    private int calculateVarArgsThresHold() {
+    private int calculateVarArgsThreshold() {
         TypeMirror specialization = context.getType(SpecializationNode.class);
         TypeElement specializationType = fromTypeMirror(specialization);
 
@@ -95,6 +93,10 @@
         return resolveNodeId(node) + NODE_SUFFIX;
     }
 
+    private static String assumptionName(AssumptionExpression assumption) {
+        return assumption.getId() + NAME_SUFFIX;
+    }
+
     private static String resolveNodeId(NodeData node) {
         String nodeid = node.getNodeId();
         if (nodeid.endsWith("Node") && !nodeid.equals("Node")) {
@@ -157,18 +159,10 @@
         }
     }
 
-    private static String assumptionName(String assumption) {
-        return assumption + "_";
-    }
-
     public CodeTypeElement create() {
         CodeTypeElement clazz = GeneratorUtils.createClass(node, null, modifiers(FINAL), nodeTypeName(node), node.getTemplateType().asType());
         ElementUtils.setVisibility(clazz.getModifiers(), ElementUtils.getVisibility(node.getTemplateType().getModifiers()));
 
-        for (String assumption : node.getAssumptions()) {
-            clazz.add(new CodeVariableElement(modifiers(PRIVATE, FINAL), getType(Assumption.class), assumptionName(assumption)));
-        }
-
         for (NodeChildData child : node.getChildren()) {
             clazz.addOptional(createAccessChildMethod(child));
         }
@@ -757,6 +751,11 @@
                 return false;
             }
         }
+
+        if (!specialization.getAssumptionExpressions().isEmpty()) {
+            return false;
+        }
+
         if (specialization.getCaches().size() > 0) {
             // TODO chumer: caches do not yet support single specialization.
             // it could be worthwhile to explore if this is possible
@@ -830,7 +829,7 @@
             if (wrappedExecutableType != null) {
                 builder.startReturn().tree(callTemplateMethod(null, wrappedExecutableType, locals)).end();
             } else {
-                builder.tree(createFastPathExecute(builder, specialization, execType.getType(), locals));
+                builder.tree(createFastPath(builder, specialization, execType.getType(), locals));
             }
         } else {
             // create acceptAndExecute
@@ -988,7 +987,8 @@
                     continue;
                 }
             }
-            builder.string(parameter.getLocalName());
+
+            builder.defaultValue(parameter.getType());
         }
         builder.end();
         return builder.build();
@@ -1036,6 +1036,35 @@
         if (node.isFrameUsedByAnyGuard()) {
             builder.tree(createTransferToInterpreterAndInvalidate());
         }
+
+        boolean hasAssumptions = !specialization.getAssumptionExpressions().isEmpty();
+        if (hasAssumptions) {
+            for (AssumptionExpression assumption : specialization.getAssumptionExpressions()) {
+                CodeTree assumptions = DSLExpressionGenerator.write(assumption.getExpression(), accessParent(null),
+                                castBoundTypes(bindExpressionValues(assumption.getExpression(), specialization, currentValues)));
+                String name = assumptionName(assumption);
+                // needs specialization index for assumption to make unique
+                String varName = name + specialization.getIndex();
+                TypeMirror type = assumption.getExpression().getResolvedType();
+                builder.declaration(type, varName, assumptions);
+                currentValues.set(name, new LocalVariable(null, type, varName, null));
+            }
+
+            builder.startIf();
+            String sep = "";
+            for (AssumptionExpression assumption : specialization.getAssumptionExpressions()) {
+                LocalVariable assumptionVar = currentValues.get(assumptionName(assumption));
+                if (assumptionVar == null) {
+                    throw new AssertionError("assumption var not resolved");
+                }
+                builder.string(sep);
+                builder.startCall("isValid").tree(assumptionVar.createReference()).end();
+                sep = " && ";
+            }
+            builder.end();
+            builder.startBlock();
+        }
+
         for (SpecializationData otherSpeciailzation : node.getSpecializations()) {
             if (otherSpeciailzation == specialization) {
                 continue;
@@ -1052,7 +1081,6 @@
 
         if (specialization.hasMultipleInstances()) {
             builder.declaration(getType(SpecializationNode.class), "s", create);
-
             DSLExpression limitExpression = specialization.getLimitExpression();
             CodeTree limitExpressionTree;
             if (limitExpression == null) {
@@ -1065,11 +1093,14 @@
             builder.startIf().string("countSame(s) < ").tree(limitExpressionTree).end().startBlock();
             builder.statement("return s");
             builder.end();
-
         } else {
             builder.startReturn().tree(create).end();
         }
 
+        if (hasAssumptions) {
+            builder.end();
+        }
+
         if (mayBeExcluded(specialization)) {
             CodeTreeBuilder checkHasSeenBuilder = builder.create();
             checkHasSeenBuilder.startIf().string("!").tree(accessParent(excludedFieldName(specialization))).end().startBlock();
@@ -1112,7 +1143,7 @@
             return true;
         }
 
-        if ((!fastPath || forType.isGeneric()) && !group.getAssumptions().isEmpty()) {
+        if (!fastPath && specialization != null && !specialization.getAssumptionExpressions().isEmpty()) {
             return true;
         }
 
@@ -1193,7 +1224,13 @@
                 }
                 builder.tree(variable.createReference());
             }
-
+            for (AssumptionExpression assumption : specialization.getAssumptionExpressions()) {
+                LocalVariable variable = currentValues.get(assumptionName(assumption));
+                if (variable == null) {
+                    throw new AssertionError("Could not bind assumption value " + assumption.getId() + ": " + currentValues);
+                }
+                builder.tree(variable.createReference());
+            }
         }
         builder.end();
 
@@ -1287,6 +1324,13 @@
                 builder.startStatement().string("this.").string(name).string(" = ").string(name).end();
             }
 
+            for (AssumptionExpression assumption : specialization.getAssumptionExpressions()) {
+                String name = assumptionName(assumption);
+                TypeMirror type = assumption.getExpression().getResolvedType();
+                clazz.add(new CodeVariableElement(modifiers(PRIVATE, FINAL), type, name));
+                constructor.addParameter(new CodeVariableElement(type, name));
+                builder.startStatement().string("this.").string(name).string(" = ").string(name).end();
+            }
         }
 
         if (constructor.getParameters().isEmpty()) {
@@ -1346,9 +1390,12 @@
         return builder.build();
     }
 
-    private static CodeTree createCallDelegate(String methodName, TypeData forType, LocalContext currentValues) {
+    private static CodeTree createCallDelegate(String methodName, String reason, TypeData forType, LocalContext currentValues) {
         CodeTreeBuilder builder = CodeTreeBuilder.createBuilder();
         builder.startCall(methodName);
+        if (reason != null) {
+            builder.doubleQuote(reason);
+        }
         currentValues.addReferencesTo(builder, FRAME_VALUE);
         builder.end();
 
@@ -1389,7 +1436,7 @@
         LocalContext currentLocals = LocalContext.load(this, evaluatedArguments, varArgsThreshold);
 
         if (specialization != null) {
-            currentLocals.loadFastPathCachedValues(specialization);
+            currentLocals.loadFastPathState(specialization);
         }
 
         CodeExecutableElement executable = currentLocals.createMethod(modifiers(PUBLIC), type.getPrimitiveType(), TypeSystemNodeFactory.executeName(forType), FRAME_VALUE);
@@ -1400,12 +1447,12 @@
         }
 
         CodeTreeBuilder builder = executable.createBuilder();
-        builder.tree(createFastPathExecute(builder, specialization, type, currentLocals));
+        builder.tree(createFastPath(builder, specialization, type, currentLocals));
 
         return executable;
     }
 
-    private CodeTree createFastPathExecute(CodeTreeBuilder parent, SpecializationData specialization, TypeData type, LocalContext currentLocals) {
+    private CodeTree createFastPath(CodeTreeBuilder parent, SpecializationData specialization, TypeData type, LocalContext currentLocals) {
         final CodeTreeBuilder builder = parent.create();
 
         for (NodeExecutionData execution : node.getChildExecutions()) {
@@ -1426,11 +1473,11 @@
 
         LocalContext originalValues = currentLocals.copy();
         if (specialization == null) {
-            builder.startReturn().tree(createCallDelegate("acceptAndExecute", type, currentLocals)).end();
+            builder.startReturn().tree(createCallDelegate("acceptAndExecute", null, type, currentLocals)).end();
         } else if (specialization.isPolymorphic()) {
             builder.tree(createCallNext(type, currentLocals));
         } else if (specialization.isUninitialized()) {
-            builder.startReturn().tree(createCallDelegate("uninitialized", type, currentLocals)).end();
+            builder.startReturn().tree(createCallDelegate("uninitialized", null, type, currentLocals)).end();
         } else {
             final TypeData finalType = type;
             SpecializationGroup group = SpecializationGroup.create(specialization);
@@ -1495,6 +1542,27 @@
             ifCount++;
         }
         CodeTreeBuilder execute = builder.create();
+
+        if (!specialization.getAssumptionExpressions().isEmpty()) {
+            builder.startTryBlock();
+            for (AssumptionExpression assumption : specialization.getAssumptionExpressions()) {
+                LocalVariable assumptionVar = currentValues.get(assumptionName(assumption));
+                if (assumptionVar == null) {
+                    throw new AssertionError("Could not resolve assumption var " + currentValues);
+                }
+                builder.startStatement().startCall("check").tree(assumptionVar.createReference()).end().end();
+            }
+            builder.end().startCatchBlock(getType(InvalidAssumptionException.class), "ae");
+            builder.startReturn();
+            List<String> assumptionIds = new ArrayList<>();
+            for (AssumptionExpression assumption : specialization.getAssumptionExpressions()) {
+                assumptionIds.add(assumption.getId());
+            }
+            builder.tree(createCallDelegate("removeThis", String.format("Assumption %s invalidated", assumptionIds), forType, currentValues));
+            builder.end();
+            builder.end();
+        }
+
         execute.startReturn();
         if (specialization.getMethod() == null) {
             execute.startCall("unsupported");
@@ -1525,7 +1593,7 @@
         }
 
         SpecializationData specialization = group.getSpecialization();
-        CodeTree[] checkAndCast = createTypeCheckCastAndCaches(specialization, group.getTypeGuards(), castGuards, currentValues, execution);
+        CodeTree[] checkAndCast = createTypeCheckAndLocals(specialization, group.getTypeGuards(), castGuards, currentValues, execution);
 
         CodeTree check = checkAndCast[0];
         CodeTree cast = checkAndCast[1];
@@ -1537,14 +1605,6 @@
         CodeTree methodGuards = methodGuardAndAssertions[0];
         CodeTree guardAssertions = methodGuardAndAssertions[1];
 
-        if (!group.getAssumptions().isEmpty()) {
-            if (execution.isFastPath() && !forType.isGeneric()) {
-                cast = appendAssumptionFastPath(cast, group.getAssumptions(), forType, currentValues);
-            } else {
-                methodGuards = appendAssumptionSlowPath(methodGuards, group.getAssumptions());
-            }
-        }
-
         int ifCount = 0;
         if (!check.isEmpty()) {
             builder.startIf();
@@ -1583,33 +1643,6 @@
         return builder.build();
     }
 
-    private CodeTree appendAssumptionSlowPath(CodeTree methodGuards, List<String> assumptions) {
-        CodeTreeBuilder builder = CodeTreeBuilder.createBuilder();
-
-        builder.tree(methodGuards);
-        String connect = methodGuards.isEmpty() ? "" : " && ";
-        for (String assumption : assumptions) {
-            builder.string(connect);
-            builder.startCall(accessParent(assumptionName(assumption)), "isValid").end();
-            connect = " && ";
-        }
-
-        return builder.build();
-    }
-
-    private CodeTree appendAssumptionFastPath(CodeTree casts, List<String> assumptions, TypeData forType, LocalContext currentValues) {
-        CodeTreeBuilder builder = CodeTreeBuilder.createBuilder();
-        builder.tree(casts);
-        builder.startTryBlock();
-        for (String assumption : assumptions) {
-            builder.startStatement().startCall(accessParent(assumptionName(assumption)), "check").end().end();
-        }
-        builder.end().startCatchBlock(getType(InvalidAssumptionException.class), "ae");
-        builder.tree(createCallNext(forType, currentValues));
-        builder.end();
-        return builder.build();
-    }
-
     private static boolean isReachableGroup(SpecializationGroup group, int ifCount) {
         if (ifCount != 0) {
             return true;
@@ -1623,7 +1656,7 @@
          * Hacky else case. In this case the specialization is not reachable due to previous else
          * branch. This is only true if the minimum state is not checked.
          */
-        if (previous.getGuards().size() == 1 && previous.getTypeGuards().isEmpty() && previous.getAssumptions().isEmpty() &&
+        if (previous.getGuards().size() == 1 && previous.getTypeGuards().isEmpty() &&
                         (previous.getParent() == null || previous.getMaxSpecializationIndex() != previous.getParent().getMaxSpecializationIndex())) {
             return false;
         }
@@ -2124,10 +2157,10 @@
         return bindings;
     }
 
-    private CodeTree[] createTypeCheckCastAndCaches(SpecializationData specialization, List<TypeGuard> typeGuards, Set<TypeGuard> castGuards, LocalContext currentValues,
+    private CodeTree[] createTypeCheckAndLocals(SpecializationData specialization, List<TypeGuard> typeGuards, Set<TypeGuard> castGuards, LocalContext currentValues,
                     SpecializationExecution specializationExecution) {
         CodeTreeBuilder checksBuilder = CodeTreeBuilder.createBuilder();
-        CodeTreeBuilder castsAndCaches = CodeTreeBuilder.createBuilder();
+        CodeTreeBuilder localsBuilder = CodeTreeBuilder.createBuilder();
         for (TypeGuard typeGuard : typeGuards) {
             int signatureIndex = typeGuard.getSignatureIndex();
             LocalVariable value = currentValues.getValue(signatureIndex);
@@ -2188,7 +2221,7 @@
             if (castGuards == null || castGuards.contains(typeGuard)) {
                 LocalVariable castVariable = currentValues.getValue(execution).nextName().newType(typeGuard.getType()).accessWith(null);
                 currentValues.setValue(execution, castVariable);
-                castsAndCaches.tree(castVariable.createDeclaration(castBuilder.build()));
+                localsBuilder.tree(castVariable.createDeclaration(castBuilder.build()));
             }
 
             checksBuilder.tree(checkBuilder.build());
@@ -2202,12 +2235,12 @@
                 // multiple specializations might use the same name
                 String varName = name + specialization.getIndex();
                 TypeMirror type = cache.getParameter().getType();
-                castsAndCaches.declaration(type, varName, initializer);
+                localsBuilder.declaration(type, varName, initializer);
                 currentValues.set(name, new LocalVariable(null, type, varName, null));
             }
         }
 
-        return new CodeTree[]{checksBuilder.build(), castsAndCaches.build()};
+        return new CodeTree[]{checksBuilder.build(), localsBuilder.build()};
     }
 
     public static final class LocalContext {
@@ -2219,12 +2252,18 @@
             this.factory = factory;
         }
 
-        public void loadFastPathCachedValues(SpecializationData specialization) {
+        public void loadFastPathState(SpecializationData specialization) {
             for (CacheExpression cache : specialization.getCaches()) {
                 Parameter cacheParameter = cache.getParameter();
                 String name = cacheParameter.getVariableElement().getSimpleName().toString();
                 set(cacheParameter.getLocalName(), new LocalVariable(cacheParameter.getTypeSystemType(), cacheParameter.getType(), name, CodeTreeBuilder.singleString("this." + name)));
             }
+
+            for (AssumptionExpression assumption : specialization.getAssumptionExpressions()) {
+                String name = assumptionName(assumption);
+                TypeMirror type = assumption.getExpression().getResolvedType();
+                set(name, new LocalVariable(null, type, name, CodeTreeBuilder.singleString("this." + name)));
+            }
         }
 
         public CodeExecutableElement createMethod(Set<Modifier> modifiers, TypeMirror returnType, String name, String... optionalArguments) {
@@ -2425,6 +2464,11 @@
             return values.get(shortCircuitName(execution));
         }
 
+        @Override
+        public String toString() {
+            return "LocalContext [values=" + values + "]";
+        }
+
     }
 
     public static final class LocalVariable {