diff graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/parser/NodeParser.java @ 19289:62c43fcf5be2

Truffle-DSL: implement @Cached and fixes for the new guard expression syntax.
author Christian Humer <christian.humer@gmail.com>
date Tue, 03 Feb 2015 15:07:07 +0100
parents 259a416388d7
children f4792a544170
line wrap: on
line diff
--- a/graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/parser/NodeParser.java	Mon Dec 29 18:32:03 2014 +0100
+++ b/graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/parser/NodeParser.java	Tue Feb 03 15:07:07 2015 +0100
@@ -36,6 +36,7 @@
 import com.oracle.truffle.dsl.processor.expression.*;
 import com.oracle.truffle.dsl.processor.java.*;
 import com.oracle.truffle.dsl.processor.java.compiler.*;
+import com.oracle.truffle.dsl.processor.java.model.*;
 import com.oracle.truffle.dsl.processor.model.*;
 import com.oracle.truffle.dsl.processor.model.NodeChildData.Cardinality;
 import com.oracle.truffle.dsl.processor.model.SpecializationData.SpecializationKind;
@@ -99,7 +100,8 @@
         } catch (CompileErrorException e) {
             throw e;
         } catch (Throwable e) {
-            throw new RuntimeException(String.format("Parsing of Node %s failed.", ElementUtils.getQualifiedName(rootType)), e);
+            e.addSuppressed(new RuntimeException(String.format("Parsing of Node %s failed.", ElementUtils.getQualifiedName(rootType)), e));
+            throw e;
         }
         if (node == null && !enclosedNodes.isEmpty()) {
             node = new NodeData(context, rootType);
@@ -209,29 +211,44 @@
                     node.addError(importAnnotation, importClassesValue, "The specified import guard class '%s' is not a declared type.", ElementUtils.getQualifiedName(importGuardClass));
                     continue;
                 }
+
                 TypeElement typeElement = ElementUtils.fromTypeMirror(importGuardClass);
-
-                // hack to reload type is necessary for incremental compiling in eclipse.
-                // otherwise methods inside of import guard types are just not found.
-                typeElement = ElementUtils.fromTypeMirror(context.reloadType(typeElement.asType()));
-
                 if (typeElement.getEnclosingElement().getKind().isClass() && !typeElement.getModifiers().contains(Modifier.PUBLIC)) {
                     node.addError(importAnnotation, importClassesValue, "The specified import guard class '%s' must be public.", ElementUtils.getQualifiedName(importGuardClass));
                     continue;
                 }
-
-                List<? extends ExecutableElement> importMethods = ElementFilter.methodsIn(processingEnv.getElementUtils().getAllMembers(typeElement));
-
-                for (ExecutableElement importMethod : importMethods) {
-                    if (!importMethod.getModifiers().contains(Modifier.PUBLIC) || !importMethod.getModifiers().contains(Modifier.STATIC)) {
-                        continue;
-                    }
-                    elements.add(importMethod);
-                }
+                elements.addAll(importPublicStaticMembers(typeElement, false));
             }
         }
     }
 
+    private List<? extends Element> importPublicStaticMembers(TypeElement importGuardClass, boolean includeConstructors) {
+        // hack to reload type is necessary for incremental compiling in eclipse.
+        // otherwise methods inside of import guard types are just not found.
+        TypeElement typeElement = ElementUtils.fromTypeMirror(context.reloadType(importGuardClass.asType()));
+
+        List<Element> members = new ArrayList<>();
+        for (Element importElement : processingEnv.getElementUtils().getAllMembers(typeElement)) {
+            if (!importElement.getModifiers().contains(Modifier.PUBLIC)) {
+                continue;
+            }
+
+            if (includeConstructors && importElement.getKind() == ElementKind.CONSTRUCTOR) {
+                members.add(importElement);
+            }
+
+            if (!importElement.getModifiers().contains(Modifier.STATIC)) {
+                continue;
+            }
+
+            ElementKind kind = importElement.getKind();
+            if (kind.isField() || kind == ElementKind.METHOD) {
+                members.add(importElement);
+            }
+        }
+        return members;
+    }
+
     private NodeData parseNodeData(TypeElement templateType, List<TypeElement> typeHierarchy) {
         AnnotationMirror typeSystemMirror = findFirstAnnotation(typeHierarchy, TypeSystemReference.class);
         if (typeSystemMirror == null) {
@@ -286,7 +303,7 @@
             }
             if (field.getModifiers().contains(Modifier.PUBLIC) || field.getModifiers().contains(Modifier.PROTECTED)) {
                 String name = field.getSimpleName().toString();
-                fields.add(new NodeFieldData(field, null, field.asType(), name, false));
+                fields.add(new NodeFieldData(field, null, field, false));
                 names.add(name);
             }
         }
@@ -301,7 +318,7 @@
                 String name = ElementUtils.firstLetterLowerCase(ElementUtils.getAnnotationValue(String.class, mirror, "name"));
                 TypeMirror type = ElementUtils.getAnnotationValue(TypeMirror.class, mirror, "type");
 
-                NodeFieldData field = new NodeFieldData(typeElement, mirror, type, name, true);
+                NodeFieldData field = new NodeFieldData(typeElement, mirror, new CodeVariableElement(type, name), true);
                 if (name.isEmpty()) {
                     field.addError(ElementUtils.getAnnotationValue(mirror, "name"), "Field name cannot be empty.");
                 } else if (names.contains(name)) {
@@ -469,6 +486,7 @@
             }
         }
 
+        TypeMirror cacheAnnotation = context.getType(Cached.class);
         List<TypeMirror> frameTypes = context.getFrameTypes();
         // pre-parse specializations to find signature size
         for (ExecutableElement method : methods) {
@@ -476,10 +494,10 @@
             if (mirror == null) {
                 continue;
             }
-
             int currentArgumentIndex = 0;
             boolean skipShortCircuit = false;
             outer: for (VariableElement var : method.getParameters()) {
+
                 TypeMirror type = var.asType();
                 if (currentArgumentIndex == 0) {
                     // skip optionals
@@ -489,6 +507,11 @@
                         }
                     }
                 }
+
+                if (ElementUtils.findAnnotationMirror(var.getAnnotationMirrors(), cacheAnnotation) != null) {
+                    continue outer;
+                }
+
                 int childIndex = currentArgumentIndex < children.size() ? currentArgumentIndex : children.size() - 1;
                 if (childIndex == -1) {
                     continue;
@@ -669,7 +692,9 @@
         // Declaration order is not required for child nodes.
         List<? extends Element> members = processingEnv.getElementUtils().getAllMembers(templateType);
         NodeData node = parseNodeData(templateType, lookupTypes);
-
+        if (node.hasErrors()) {
+            return node;
+        }
         node.setExecutableTypes(groupExecutableTypes(new ExecutableTypeMethodParser(context, node, createAllowedChildFrameTypes(parentNode)).parse(members)));
         node.setFrameType(parentNode.getFrameType());
         return node;
@@ -691,6 +716,7 @@
         }
 
         initializeExpressions(elements, node);
+
         if (node.hasErrors()) {
             return;
         }
@@ -701,19 +727,17 @@
         initializePolymorphism(node); // requires specializations
         initializeReachability(node);
         initializeContains(node);
-
         resolveContains(node);
 
         List<SpecializationData> specializations = node.getSpecializations();
         for (SpecializationData cur : specializations) {
-            for (SpecializationData child : specializations) {
-                if (child != null && child != cur && child.getContains().contains(cur)) {
-                    cur.getExcludedBy().add(child);
+            for (SpecializationData contained : cur.getContains()) {
+                if (contained != cur) {
+                    contained.getExcludedBy().add(cur);
                 }
             }
         }
 
-        // verify specialization parameter length
         initializeSpecializationIdsWithMethodNames(node.getSpecializations());
     }
 
@@ -778,15 +802,11 @@
                     AnnotationValue value = ElementUtils.getAnnotationValue(specialization.getMarkerAnnotation(), "contains");
                     specialization.addError(value, "The referenced specialization '%s' could not be found.", includeName);
                 } else {
-                    if (!foundSpecialization.isContainedBy(specialization)) {
+                    if (foundSpecialization.compareTo(specialization) > 0) {
                         AnnotationValue value = ElementUtils.getAnnotationValue(specialization.getMarkerAnnotation(), "contains");
                         if (foundSpecialization.compareTo(specialization) > 0) {
                             specialization.addError(value, "The contained specialization '%s' must be declared before the containing specialization.", includeName);
-                        }/*
-                          * else { specialization.addError(value,
-                          * "The contained specialization '%s' is not fully compatible. The contained specialization must be strictly more generic than the containing one."
-                          * , includeName); }
-                          */
+                        }
 
                     }
                     resolvedSpecializations.add(foundSpecialization);
@@ -935,48 +955,143 @@
     private void initializeExpressions(List<? extends Element> elements, NodeData node) {
         List<Element> members = filterNotAccessibleElements(node.getTemplateType(), elements);
 
-        final TypeMirror booleanType = context.getType(boolean.class);
+        List<VariableElement> fields = new ArrayList<>();
+        for (NodeFieldData field : node.getFields()) {
+            fields.add(field.getVariable());
+        }
 
         for (SpecializationData specialization : node.getSpecializations()) {
             if (specialization.getMethod() == null) {
                 continue;
             }
 
-            List<Element> specializationMembers = new ArrayList<>();
+            List<Element> specializationMembers = new ArrayList<>(members.size() + specialization.getParameters().size() + fields.size());
             for (Parameter p : specialization.getParameters()) {
-                VariableElement variableElement = p.getVariableElement();
-                if (variableElement == null) {
-                    throw new AssertionError("All parameters must be resolved for " + specialization);
+                specializationMembers.add(p.getVariableElement());
+            }
+            specializationMembers.addAll(fields);
+            specializationMembers.addAll(members);
+            DSLExpressionResolver resolver = new DSLExpressionResolver(context, specializationMembers);
+
+            initializeCaches(specialization, resolver);
+            initializeGuards(specialization, resolver);
+            if (specialization.hasErrors()) {
+                continue;
+            }
+            initializeLimit(specialization, resolver);
+        }
+    }
+
+    private void initializeLimit(SpecializationData specialization, DSLExpressionResolver resolver) {
+        AnnotationValue annotationValue = ElementUtils.getAnnotationValue(specialization.getMessageAnnotation(), "limit");
+
+        String limitValue;
+        if (annotationValue == null) {
+            limitValue = "";
+        } else {
+            limitValue = (String) annotationValue.getValue();
+        }
+        if (limitValue.isEmpty()) {
+            limitValue = "3";
+        } else if (!specialization.hasMultipleInstances()) {
+            specialization.addWarning(annotationValue, "The limit expression has no effect. Multiple specialization instantiations are impossible for this specialization.");
+            return;
+        }
+
+        TypeMirror expectedType = context.getType(int.class);
+        try {
+            DSLExpression expression = DSLExpression.parse(limitValue);
+            expression.accept(resolver);
+            if (!ElementUtils.typeEquals(expression.getResolvedType(), expectedType)) {
+                specialization.addError(annotationValue, "Incompatible return type %s. Limit expressions must return %s.", ElementUtils.getSimpleName(expression.getResolvedType()),
+                                ElementUtils.getSimpleName(expectedType));
+            }
+            if (specialization.isDynamicParameterBound(expression)) {
+                specialization.addError(annotationValue, "Limit expressions must not bind dynamic parameter values.");
+            }
+
+            specialization.setLimitExpression(expression);
+        } catch (InvalidExpressionException e) {
+            specialization.addError(annotationValue, "Error parsing expression '%s': %s", limitValue, e.getMessage());
+        }
+    }
+
+    private void initializeCaches(SpecializationData specialization, DSLExpressionResolver resolver) {
+        TypeMirror cacheMirror = context.getType(Cached.class);
+        List<CacheExpression> expressions = new ArrayList<>();
+        for (Parameter parameter : specialization.getParameters()) {
+            AnnotationMirror annotationMirror = ElementUtils.findAnnotationMirror(parameter.getVariableElement().getAnnotationMirrors(), cacheMirror);
+            if (annotationMirror != null) {
+                String initializer = ElementUtils.getAnnotationValue(String.class, annotationMirror, "value");
+
+                TypeMirror parameterType = parameter.getType();
+
+                DSLExpressionResolver localResolver = resolver;
+                if (parameterType.getKind() == TypeKind.DECLARED) {
+                    localResolver = localResolver.copy(importPublicStaticMembers(ElementUtils.fromTypeMirror(parameterType), true));
                 }
 
-                specializationMembers.add(variableElement);
-            }
-            specializationMembers.addAll(members);
-
-            DSLExpressionResolver resolver = new DSLExpressionResolver(context, specializationMembers);
-            List<String> guardDefinitions = ElementUtils.getAnnotationValueList(String.class, specialization.getMarkerAnnotation(), "guards");
-            List<GuardExpression> guardExpressions = new ArrayList<>();
-            for (String guard : guardDefinitions) {
-                GuardExpression guardExpression;
+                CacheExpression cacheExpression;
                 DSLExpression expression = null;
                 try {
-                    expression = DSLExpression.parse(guard);
-                    expression.accept(resolver);
-                    guardExpression = new GuardExpression(specialization, expression);
-                    if (!ElementUtils.typeEquals(expression.getResolvedType(), booleanType)) {
-                        guardExpression.addError("Incompatible return type %s. Guard expressions must return boolean.", ElementUtils.getSimpleName(expression.getResolvedType()));
+                    expression = DSLExpression.parse(initializer);
+                    expression.accept(localResolver);
+                    cacheExpression = new CacheExpression(parameter, annotationMirror, expression);
+                    if (!ElementUtils.typeEquals(expression.getResolvedType(), parameter.getType())) {
+                        cacheExpression.addError("Incompatible return type %s. The expression type must be equal to the parameter type %s.", ElementUtils.getSimpleName(expression.getResolvedType()),
+                                        ElementUtils.getSimpleName(parameter.getType()));
                     }
-// guardExpression.addWarning("Expression:%s", expression.toString());
                 } catch (InvalidExpressionException e) {
-                    guardExpression = new GuardExpression(specialization, null);
-                    guardExpression.addError("Error parsing expression '%s': %s", guard, e.getMessage());
+                    cacheExpression = new CacheExpression(parameter, annotationMirror, null);
+                    cacheExpression.addError("Error parsing expression '%s': %s", initializer, e.getMessage());
+                }
+                expressions.add(cacheExpression);
+            }
+        }
+        specialization.setCaches(expressions);
+
+        if (specialization.hasErrors()) {
+            return;
+        }
+
+        // verify that cache expressions are bound in the correct order.
+        for (int i = 0; i < expressions.size(); i++) {
+            CacheExpression currentExpression = expressions.get(i);
+            Set<VariableElement> boundVariables = currentExpression.getExpression().findBoundVariableElements();
+            for (int j = i + 1; j < expressions.size(); j++) {
+                CacheExpression boundExpression = expressions.get(j);
+                if (boundVariables.contains(boundExpression.getParameter().getVariableElement())) {
+                    currentExpression.addError("The initializer expression of parameter '%s' binds unitialized parameter '%s. Reorder the parameters to resolve the problem.",
+                                    currentExpression.getParameter().getLocalName(), boundExpression.getParameter().getLocalName());
+                    break;
                 }
-                guardExpressions.add(guardExpression);
             }
-            specialization.setGuards(guardExpressions);
         }
     }
 
+    private void initializeGuards(SpecializationData specialization, DSLExpressionResolver resolver) {
+        final TypeMirror booleanType = context.getType(boolean.class);
+        List<String> guardDefinitions = ElementUtils.getAnnotationValueList(String.class, specialization.getMarkerAnnotation(), "guards");
+        List<GuardExpression> guardExpressions = new ArrayList<>();
+        for (String guard : guardDefinitions) {
+            GuardExpression guardExpression;
+            DSLExpression expression = null;
+            try {
+                expression = DSLExpression.parse(guard);
+                expression.accept(resolver);
+                guardExpression = new GuardExpression(specialization, expression);
+                if (!ElementUtils.typeEquals(expression.getResolvedType(), booleanType)) {
+                    guardExpression.addError("Incompatible return type %s. Guards must return %s.", ElementUtils.getSimpleName(expression.getResolvedType()), ElementUtils.getSimpleName(booleanType));
+                }
+            } catch (InvalidExpressionException e) {
+                guardExpression = new GuardExpression(specialization, null);
+                guardExpression.addError("Error parsing expression '%s': %s", guard, e.getMessage());
+            }
+            guardExpressions.add(guardExpression);
+        }
+        specialization.setGuards(guardExpressions);
+    }
+
     private static List<Element> filterNotAccessibleElements(TypeElement templateType, List<? extends Element> elements) {
         String packageName = ElementUtils.getPackageName(templateType);
         List<Element> filteredElements = new ArrayList<>(elements);
@@ -1030,10 +1145,10 @@
         GenericParser parser = new GenericParser(context, node);
         MethodSpec specification = parser.createDefaultMethodSpec(node.getSpecializations().iterator().next().getMethod(), null, true, null);
 
-        List<TypeMirror> parameterTypes = new ArrayList<>();
+        List<VariableElement> parameterTypes = new ArrayList<>();
         int signatureIndex = 1;
         for (ParameterSpec spec : specification.getRequired()) {
-            parameterTypes.add(createGenericType(spec, node.getSpecializations(), signatureIndex));
+            parameterTypes.add(new CodeVariableElement(createGenericType(spec, node.getSpecializations(), signatureIndex), "arg" + signatureIndex));
             if (spec.isSignature()) {
                 signatureIndex++;
             }