Mercurial > hg > truffle
diff graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/generator/NodeGenFactory.java @ 19289:62c43fcf5be2
Truffle-DSL: implement @Cached and fixes for the new guard expression syntax.
author | Christian Humer <christian.humer@gmail.com> |
---|---|
date | Tue, 03 Feb 2015 15:07:07 +0100 |
parents | 08aa0372dad4 |
children | f4792a544170 |
line wrap: on
line diff
--- a/graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/generator/NodeGenFactory.java Mon Dec 29 18:32:03 2014 +0100 +++ b/graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/generator/NodeGenFactory.java Tue Feb 03 15:07:07 2015 +0100 @@ -42,11 +42,13 @@ import com.oracle.truffle.api.frame.*; import com.oracle.truffle.api.nodes.*; import com.oracle.truffle.api.nodes.Node.Child; +import com.oracle.truffle.api.nodes.Node.Children; import com.oracle.truffle.dsl.processor.*; import com.oracle.truffle.dsl.processor.expression.*; import com.oracle.truffle.dsl.processor.expression.DSLExpression.Variable; import com.oracle.truffle.dsl.processor.java.*; import com.oracle.truffle.dsl.processor.java.model.*; +import com.oracle.truffle.dsl.processor.java.model.CodeTypeMirror.ArrayCodeTypeMirror; import com.oracle.truffle.dsl.processor.model.*; import com.oracle.truffle.dsl.processor.parser.*; import com.oracle.truffle.dsl.processor.parser.SpecializationGroup.TypeGuard; @@ -54,8 +56,8 @@ public class NodeGenFactory { private static final String FRAME_VALUE = TemplateMethod.FRAME_NAME; - private static final String NAME_SUFFIX = "_"; + private static final String NODE_SUFFIX = "NodeGen"; private final ProcessorContext context; private final NodeData node; @@ -90,7 +92,7 @@ } public static String nodeTypeName(NodeData node) { - return resolveNodeId(node) + "NodeGen"; + return resolveNodeId(node) + NODE_SUFFIX; } private static String resolveNodeId(NodeData node) { @@ -242,7 +244,7 @@ private Element createUnsupportedMethod() { LocalContext locals = LocalContext.load(this); - CodeExecutableElement method = locals.createMethod(modifiers(PRIVATE), getType(UnsupportedSpecializationException.class), "unsupported"); + CodeExecutableElement method = locals.createMethod(modifiers(PROTECTED), getType(UnsupportedSpecializationException.class), "unsupported"); CodeTreeBuilder builder = method.createBuilder(); builder.startReturn(); @@ -320,7 +322,7 @@ return constructor; } - public static boolean mayBeExcluded(SpecializationData specialization) { + private static boolean mayBeExcluded(SpecializationData specialization) { return !specialization.getExceptions().isEmpty() || !specialization.getExcludedBy().isEmpty(); } @@ -334,7 +336,7 @@ List<SpecializationData> generateSpecializations = new ArrayList<>(); generateSpecializations.add(node.getUninitializedSpecialization()); - if (needsPolymorphic(reachableSpecializations)) { + if (needsPolymorphic()) { generateSpecializations.add(node.getPolymorphicSpecialization()); } generateSpecializations.addAll(reachableSpecializations); @@ -350,6 +352,26 @@ return node.getUninitializedSpecialization(); } + private boolean needsPolymorphic() { + List<SpecializationData> reachableSpecializations = getReachableSpecializations(); + if (reachableSpecializations.size() != 1) { + return true; + } + + SpecializationData specialization = reachableSpecializations.get(0); + for (Parameter parameter : specialization.getSignatureParameters()) { + TypeData type = parameter.getTypeSystemType(); + if (type != null && type.hasImplicitSourceTypes()) { + return true; + } + } + if (specialization.hasMultipleInstances()) { + return true; + } + return false; + + } + // create specialization private CodeTypeElement createBaseSpecialization() { @@ -492,7 +514,8 @@ } private Element createMergeMethod(SpecializationData specialization) { - if (specialization.getExcludedBy().isEmpty() && !specialization.isPolymorphic()) { + boolean cacheBoundGuard = specialization.hasMultipleInstances(); + if (specialization.getExcludedBy().isEmpty() && !specialization.isPolymorphic() && !cacheBoundGuard) { return null; } TypeMirror specializationNodeType = getType(SpecializationNode.class); @@ -513,7 +536,11 @@ builder.statement("removeSame(\"Contained by " + containedSpecialization.createReferenceName() + "\")"); builder.end(); } - builder.statement("return super.merge(newNode)"); + if (cacheBoundGuard) { + builder.statement("return super.mergeNoSame(newNode)"); + } else { + builder.statement("return super.merge(newNode)"); + } } return executable; @@ -555,21 +582,6 @@ return executable; } - private boolean needsPolymorphic(List<SpecializationData> reachableSpecializations) { - if (reachableSpecializations.size() > 1) { - return true; - } - if (options.implicitCastOptimization().isDuplicateTail()) { - SpecializationData specialization = reachableSpecializations.get(0); - for (Parameter parameter : specialization.getSignatureParameters()) { - if (parameter.getTypeSystemType().hasImplicitSourceTypes()) { - return true; - } - } - } - return false; - } - private Element createCreateFallback(Map<SpecializationData, CodeTypeElement> generatedSpecializationClasses) { SpecializationData fallback = node.getGenericSpecialization(); if (fallback == null) { @@ -614,7 +626,7 @@ if (generatedType == null) { throw new AssertionError("No generated type for " + specialization); } - return createSlowPathExecute(specialization, locals); + return createSlowPathExecute(specialization, values); } public boolean isFastPath() { @@ -624,7 +636,7 @@ builder.tree(execution); - if (hasFallthrough(group, genericType, locals, false)) { + if (hasFallthrough(group, genericType, locals, false, null)) { builder.returnNull(); } return method; @@ -713,8 +725,6 @@ return evaluatedCount; } - // create specialization - private Element createUnsupported() { SpecializationData fallback = node.getGenericSpecialization(); if (fallback == null || optimizeFallback(fallback) || fallback.getMethod() == null) { @@ -738,12 +748,20 @@ if (reachableSpecializations.size() != 1) { return false; } - for (Parameter parameter : reachableSpecializations.get(0).getSignatureParameters()) { + + SpecializationData specialization = reachableSpecializations.get(0); + + for (Parameter parameter : specialization.getSignatureParameters()) { TypeData type = parameter.getTypeSystemType(); if (type != null && type.hasImplicitSourceTypes()) { return false; } } + if (specialization.getCaches().size() > 0) { + // TODO chumer: caches do not yet support single specialization. + // it could be worthwhile to explore if this is possible + return false; + } return true; } @@ -1014,6 +1032,7 @@ if (specialization.isFallback()) { return builder.returnNull().build(); } + if (node.isFrameUsedByAnyGuard()) { builder.tree(createTransferToInterpreterAndInvalidate()); } @@ -1029,7 +1048,27 @@ } } - builder.startReturn().tree(createCallCreateMethod(specialization, null, currentValues)).end(); + CodeTree create = createCallCreateMethod(specialization, null, currentValues); + + if (specialization.hasMultipleInstances()) { + builder.declaration(getType(SpecializationNode.class), "s", create); + + DSLExpression limitExpression = specialization.getLimitExpression(); + CodeTree limitExpressionTree; + if (limitExpression == null) { + limitExpressionTree = CodeTreeBuilder.singleString("3"); + } else { + limitExpressionTree = DSLExpressionGenerator.write(limitExpression, accessParent(null), // + castBoundTypes(bindExpressionValues(limitExpression, specialization, currentValues))); + } + + builder.startIf().string("countSame(s) < ").tree(limitExpressionTree).end().startBlock(); + builder.statement("return s"); + builder.end(); + + } else { + builder.startReturn().tree(create).end(); + } if (mayBeExcluded(specialization)) { CodeTreeBuilder checkHasSeenBuilder = builder.create(); @@ -1041,7 +1080,7 @@ return builder.build(); } - private static boolean hasFallthrough(SpecializationGroup group, TypeData forType, LocalContext currentValues, boolean fastPath) { + private boolean hasFallthrough(SpecializationGroup group, TypeData forType, LocalContext currentValues, boolean fastPath, List<GuardExpression> ignoreGuards) { for (TypeGuard guard : group.getTypeGuards()) { if (currentValues.getValue(guard.getSignatureIndex()) == null) { // not evaluated @@ -1053,9 +1092,23 @@ } } - List<GuardExpression> expressions = new ArrayList<>(group.getGuards()); - expressions.removeAll(group.findElseConnectableGuards()); - if (!expressions.isEmpty()) { + List<GuardExpression> guards = new ArrayList<>(group.getGuards()); + List<GuardExpression> elseConnectable = group.findElseConnectableGuards(); + guards.removeAll(elseConnectable); + if (ignoreGuards != null) { + guards.removeAll(ignoreGuards); + } + SpecializationData specialization = group.getSpecialization(); + if (specialization != null && fastPath) { + for (ListIterator<GuardExpression> iterator = guards.listIterator(); iterator.hasNext();) { + GuardExpression guard = iterator.next(); + if (!specialization.isDynamicParameterBound(guard.getExpression())) { + iterator.remove(); + } + } + } + + if (!guards.isEmpty()) { return true; } @@ -1063,12 +1116,20 @@ return true; } - if (!fastPath && group.getSpecialization() != null && mayBeExcluded(group.getSpecialization())) { + if (!fastPath && specialization != null && mayBeExcluded(specialization)) { return true; } - if (!group.getChildren().isEmpty()) { - return hasFallthrough(group.getChildren().get(group.getChildren().size() - 1), forType, currentValues, fastPath); + if (!elseConnectable.isEmpty()) { + SpecializationGroup previous = group.getPrevious(); + if (previous != null && hasFallthrough(previous, forType, currentValues, fastPath, previous.getGuards())) { + return true; + } + } + + List<SpecializationGroup> groupChildren = group.getChildren(); + if (!groupChildren.isEmpty()) { + return hasFallthrough(groupChildren.get(groupChildren.size() - 1), forType, currentValues, fastPath, ignoreGuards); } return false; @@ -1119,12 +1180,20 @@ } if (currentValues != null) { for (Parameter p : specialization.getSignatureParameters()) { - LocalVariable local = currentValues.get(p.getLocalName()); CodeVariableElement var = createImplicitProfileParameter(p.getSpecification().getExecution(), p.getTypeSystemType()); if (var != null) { - builder.tree(local.createReference()); + // we need the original name here + builder.tree(LocalVariable.fromParameter(p).createReference()); } } + for (CacheExpression cache : specialization.getCaches()) { + LocalVariable variable = currentValues.get(cache.getParameter().getLocalName()); + if (variable == null) { + throw new AssertionError("Could not bind cached value " + cache.getParameter().getLocalName() + ": " + currentValues); + } + builder.tree(variable.createReference()); + } + } builder.end(); @@ -1201,6 +1270,23 @@ } } } + for (CacheExpression cache : specialization.getCaches()) { + String name = cache.getParameter().getLocalName(); + TypeMirror type = cache.getParameter().getType(); + + if (ElementUtils.isAssignable(type, new ArrayCodeTypeMirror(getType(Node.class)))) { + CodeVariableElement var = clazz.add(new CodeVariableElement(modifiers(PRIVATE, FINAL), type, name)); + var.addAnnotationMirror(new CodeAnnotationMirror(context.getDeclaredType(Children.class))); + } else if (ElementUtils.isAssignable(type, getType(Node.class))) { + CodeVariableElement var = clazz.add(new CodeVariableElement(modifiers(PRIVATE), type, name)); + var.addAnnotationMirror(new CodeAnnotationMirror(context.getDeclaredType(Child.class))); + } else { + clazz.add(new CodeVariableElement(modifiers(PRIVATE, FINAL), type, name)); + } + constructor.addParameter(new CodeVariableElement(type, name)); + builder.startStatement().string("this.").string(name).string(" = ").string(name).end(); + } + } if (constructor.getParameters().isEmpty()) { @@ -1302,6 +1388,10 @@ TypeData type = forType == null ? genericType : forType; LocalContext currentLocals = LocalContext.load(this, evaluatedArguments, varArgsThreshold); + if (specialization != null) { + currentLocals.loadFastPathCachedValues(specialization); + } + CodeExecutableElement executable = currentLocals.createMethod(modifiers(PUBLIC), type.getPrimitiveType(), TypeSystemNodeFactory.executeName(forType), FRAME_VALUE); executable.getAnnotationMirrors().add(new CodeAnnotationMirror(context.getDeclaredType(Override.class))); @@ -1354,7 +1444,7 @@ } }; builder.tree(createGuardAndCast(group, type, currentLocals, executionFactory)); - if (hasFallthrough(group, type, originalValues, true) || group.getSpecialization().isFallback()) { + if (hasFallthrough(group, type, originalValues, true, null) || group.getSpecialization().isFallback()) { builder.tree(createCallNext(type, originalValues)); } } @@ -1428,19 +1518,24 @@ } else { castGuards = new HashSet<>(); for (TypeGuard castGuard : group.getTypeGuards()) { - if (isTypeGuardUsedInAnyGuardBelow(group, currentValues, castGuard)) { + if (isTypeGuardUsedInAnyGuardOrCacheBelow(group, currentValues, castGuard)) { castGuards.add(castGuard); } } } - CodeTree[] checkAndCast = createTypeCheckAndCast(group.getTypeGuards(), castGuards, currentValues, execution); + + SpecializationData specialization = group.getSpecialization(); + CodeTree[] checkAndCast = createTypeCheckCastAndCaches(specialization, group.getTypeGuards(), castGuards, currentValues, execution); + CodeTree check = checkAndCast[0]; CodeTree cast = checkAndCast[1]; List<GuardExpression> elseGuardExpressions = group.findElseConnectableGuards(); List<GuardExpression> guardExpressions = new ArrayList<>(group.getGuards()); guardExpressions.removeAll(elseGuardExpressions); - CodeTree methodGuards = createMethodGuardCheck(guardExpressions, group.getSpecialization(), currentValues); + CodeTree[] methodGuardAndAssertions = createMethodGuardCheck(guardExpressions, specialization, currentValues, execution.isFastPath()); + CodeTree methodGuards = methodGuardAndAssertions[0]; + CodeTree guardAssertions = methodGuardAndAssertions[1]; if (!group.getAssumptions().isEmpty()) { if (execution.isFastPath() && !forType.isGeneric()) { @@ -1460,6 +1555,9 @@ if (!cast.isEmpty()) { builder.tree(cast); } + if (!guardAssertions.isEmpty()) { + builder.tree(guardAssertions); + } boolean elseIf = !elseGuardExpressions.isEmpty(); if (!methodGuards.isEmpty()) { builder.startIf(elseIf); @@ -1476,7 +1574,6 @@ for (SpecializationGroup child : group.getChildren()) { builder.tree(createGuardAndCast(child, forType, currentValues.copy(), execution)); } - SpecializationData specialization = group.getSpecialization(); if (specialization != null) { builder.tree(execution.createExecute(specialization, currentValues)); } @@ -1534,21 +1631,25 @@ return true; } - private boolean isTypeGuardUsedInAnyGuardBelow(SpecializationGroup group, LocalContext currentValues, TypeGuard typeGuard) { - LocalVariable localVariable = currentValues.getValue(typeGuard.getSignatureIndex()); + private boolean isTypeGuardUsedInAnyGuardOrCacheBelow(SpecializationGroup group, LocalContext currentValues, TypeGuard typeGuard) { + String localName = currentValues.getValue(typeGuard.getSignatureIndex()).getName(); + SpecializationData specialization = group.getSpecialization(); for (GuardExpression guard : group.getGuards()) { - Map<Variable, LocalVariable> boundValues = bindLocalValues(guard.getExpression(), group.getSpecialization(), currentValues); - for (Variable var : guard.getExpression().findBoundVariables()) { - LocalVariable target = boundValues.get(var); - if (localVariable.getName().equals(target.getName())) { + if (isVariableBoundIn(specialization, guard.getExpression(), localName, currentValues)) { + return true; + } + } + if (specialization != null) { + for (CacheExpression cache : specialization.getCaches()) { + if (isVariableBoundIn(specialization, cache.getExpression(), localName, currentValues)) { return true; } } } for (SpecializationGroup child : group.getChildren()) { - if (isTypeGuardUsedInAnyGuardBelow(child, currentValues, typeGuard)) { + if (isTypeGuardUsedInAnyGuardOrCacheBelow(child, currentValues, typeGuard)) { return true; } } @@ -1556,6 +1657,17 @@ return false; } + private static boolean isVariableBoundIn(SpecializationData specialization, DSLExpression expression, String localName, LocalContext currentValues) throws AssertionError { + Map<Variable, LocalVariable> boundValues = bindExpressionValues(expression, specialization, currentValues); + for (Variable var : expression.findBoundVariables()) { + LocalVariable target = boundValues.get(var); + if (target != null && localName.equals(target.getName())) { + return true; + } + } + return false; + } + private CodeExecutableElement createExecuteChildMethod(NodeExecutionData execution, TypeData targetType) { LocalContext locals = LocalContext.load(this, 0, varArgsThreshold); @@ -1947,33 +2059,48 @@ return builder.build(); } - private CodeTree createMethodGuardCheck(List<GuardExpression> guardExpressions, SpecializationData specialization, LocalContext currentValues) { - CodeTreeBuilder builder = CodeTreeBuilder.createBuilder(); + private CodeTree[] createMethodGuardCheck(List<GuardExpression> guardExpressions, SpecializationData specialization, LocalContext currentValues, boolean fastPath) { + CodeTreeBuilder expressionBuilder = CodeTreeBuilder.createBuilder(); + CodeTreeBuilder assertionBuilder = CodeTreeBuilder.createBuilder(); String and = ""; for (GuardExpression guard : guardExpressions) { DSLExpression expression = guard.getExpression(); - Map<Variable, LocalVariable> bindings = bindLocalValues(expression, specialization, currentValues); - Map<Variable, CodeTree> resolvedBindings = new HashMap<>(); - for (Variable variable : bindings.keySet()) { - LocalVariable localVariable = bindings.get(variable); - CodeTree resolved = CodeTreeBuilder.singleString(localVariable.getName()); - if (!ElementUtils.typeEquals(variable.getResolvedType(), localVariable.getTypeMirror())) { - resolved = CodeTreeBuilder.createBuilder().cast(variable.getResolvedType(), resolved).build(); - } - resolvedBindings.put(variable, resolved); + + Map<Variable, CodeTree> resolvedBindings = castBoundTypes(bindExpressionValues(expression, specialization, currentValues)); + CodeTree expressionCode = DSLExpressionGenerator.write(expression, accessParent(null), resolvedBindings); + + if (!specialization.isDynamicParameterBound(expression) && fastPath) { + /* + * Guards where no dynamic parameters are bound can just be executed on the fast + * path. + */ + assertionBuilder.startAssert().tree(expressionCode).end(); + } else { + expressionBuilder.string(and); + expressionBuilder.tree(expressionCode); + and = " && "; } - - builder.string(and); - builder.tree(DSLExpressionGenerator.write(expression, accessParent(null), resolvedBindings)); - and = " && "; } - return builder.build(); + return new CodeTree[]{expressionBuilder.build(), assertionBuilder.build()}; } - private static Map<Variable, LocalVariable> bindLocalValues(DSLExpression expression, SpecializationData specialization, LocalContext currentValues) throws AssertionError { + private static Map<Variable, CodeTree> castBoundTypes(Map<Variable, LocalVariable> bindings) { + Map<Variable, CodeTree> resolvedBindings = new HashMap<>(); + for (Variable variable : bindings.keySet()) { + LocalVariable localVariable = bindings.get(variable); + CodeTree resolved = localVariable.createReference(); + if (!ElementUtils.typeEquals(variable.getResolvedType(), localVariable.getTypeMirror())) { + resolved = CodeTreeBuilder.createBuilder().cast(variable.getResolvedType(), resolved).build(); + } + resolvedBindings.put(variable, resolved); + } + return resolvedBindings; + } + + private static Map<Variable, LocalVariable> bindExpressionValues(DSLExpression expression, SpecializationData specialization, LocalContext currentValues) throws AssertionError { Map<Variable, LocalVariable> bindings = new HashMap<>(); - List<Variable> boundVariables = expression.findBoundVariables(); + Set<Variable> boundVariables = expression.findBoundVariables(); if (specialization == null && !boundVariables.isEmpty()) { throw new AssertionError("Cannot bind guard variable in non-specialization group. yet."); } @@ -1989,18 +2116,18 @@ } else { localVariable = currentValues.get(resolvedParameter.getLocalName()); } - if (localVariable == null) { - throw new AssertionError("Could not resolve local for execution."); + if (localVariable != null) { + bindings.put(variable, localVariable); } - bindings.put(variable, localVariable); } } return bindings; } - private CodeTree[] createTypeCheckAndCast(List<TypeGuard> typeGuards, Set<TypeGuard> castGuards, LocalContext currentValues, SpecializationExecution specializationExecution) { + private CodeTree[] createTypeCheckCastAndCaches(SpecializationData specialization, List<TypeGuard> typeGuards, Set<TypeGuard> castGuards, LocalContext currentValues, + SpecializationExecution specializationExecution) { CodeTreeBuilder checksBuilder = CodeTreeBuilder.createBuilder(); - CodeTreeBuilder castsBuilder = CodeTreeBuilder.createBuilder(); + CodeTreeBuilder castsAndCaches = CodeTreeBuilder.createBuilder(); for (TypeGuard typeGuard : typeGuards) { int signatureIndex = typeGuard.getSignatureIndex(); LocalVariable value = currentValues.getValue(signatureIndex); @@ -2061,12 +2188,26 @@ if (castGuards == null || castGuards.contains(typeGuard)) { LocalVariable castVariable = currentValues.getValue(execution).nextName().newType(typeGuard.getType()).accessWith(null); currentValues.setValue(execution, castVariable); - castsBuilder.tree(castVariable.createDeclaration(castBuilder.build())); + castsAndCaches.tree(castVariable.createDeclaration(castBuilder.build())); } checksBuilder.tree(checkBuilder.build()); } - return new CodeTree[]{checksBuilder.build(), castsBuilder.build()}; + + if (specialization != null && !specializationExecution.isFastPath()) { + for (CacheExpression cache : specialization.getCaches()) { + CodeTree initializer = DSLExpressionGenerator.write(cache.getExpression(), accessParent(null), + castBoundTypes(bindExpressionValues(cache.getExpression(), specialization, currentValues))); + String name = cache.getParameter().getLocalName(); + // multiple specializations might use the same name + String varName = name + specialization.getIndex(); + TypeMirror type = cache.getParameter().getType(); + castsAndCaches.declaration(type, varName, initializer); + currentValues.set(name, new LocalVariable(null, type, varName, null)); + } + } + + return new CodeTree[]{checksBuilder.build(), castsAndCaches.build()}; } public static final class LocalContext { @@ -2078,6 +2219,14 @@ this.factory = factory; } + public void loadFastPathCachedValues(SpecializationData specialization) { + for (CacheExpression cache : specialization.getCaches()) { + Parameter cacheParameter = cache.getParameter(); + String name = cacheParameter.getVariableElement().getSimpleName().toString(); + set(cacheParameter.getLocalName(), new LocalVariable(cacheParameter.getTypeSystemType(), cacheParameter.getType(), name, CodeTreeBuilder.singleString("this." + name))); + } + } + public CodeExecutableElement createMethod(Set<Modifier> modifiers, TypeMirror returnType, String name, String... optionalArguments) { CodeExecutableElement method = new CodeExecutableElement(modifiers, returnType, name); addParametersTo(method, optionalArguments);