changeset 11442:2868b55001d4

Truffle-DSL: fixed specializationg grouping failed with guards using base types.
author Christian Humer <christian.humer@gmail.com>
date Wed, 28 Aug 2013 01:45:13 +0200
parents fc509b6fbfdf
children b33783cbd8ce f49ee75d2a8b
files graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/node/NodeCodeGenerator.java graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/node/SpecializationGroup.java graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/template/TemplateMethod.java
diffstat 3 files changed, 93 insertions(+), 18 deletions(-) [+]
line wrap: on
line diff
--- a/graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/node/NodeCodeGenerator.java	Tue Aug 27 23:06:24 2013 +0200
+++ b/graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/node/NodeCodeGenerator.java	Wed Aug 28 01:45:13 2013 +0200
@@ -1545,7 +1545,7 @@
                 return true;
             }
             SpecializationGroup previous = group.getPreviousGroup();
-            if (previous == null || previous.getElseConnectableGuard() == null) {
+            if (previous == null || previous.getElseConnectableGuards().isEmpty()) {
                 return true;
             }
 
@@ -1571,7 +1571,7 @@
             String guardsAnd = "";
             String guardsCastAnd = "";
 
-            GuardData elseGuard = group.getElseConnectableGuard();
+            List<GuardData> elseGuards = group.getElseConnectableGuards();
 
             boolean minimumState = checkMinimumState;
             if (minimumState) {
@@ -1643,7 +1643,7 @@
             }
 
             for (GuardData guard : group.getGuards()) {
-                if (elseGuard == guard) {
+                if (elseGuards.contains(guard)) {
                     continue;
                 }
 
@@ -1656,17 +1656,17 @@
                 }
             }
 
-            int ifCount = startGuardIf(builder, guardsBuilder, 0, elseGuard);
+            int ifCount = startGuardIf(builder, guardsBuilder, 0, elseGuards);
             builder.tree(castBuilder.getRoot());
-            ifCount = startGuardIf(builder, guardsCastBuilder, ifCount, elseGuard);
+            ifCount = startGuardIf(builder, guardsCastBuilder, ifCount, elseGuards);
             return ifCount;
         }
 
-        private int startGuardIf(CodeTreeBuilder builder, CodeTreeBuilder conditionBuilder, int ifCount, GuardData elseGuard) {
+        private int startGuardIf(CodeTreeBuilder builder, CodeTreeBuilder conditionBuilder, int ifCount, List<GuardData> elseGuard) {
             int newIfCount = ifCount;
 
             if (!conditionBuilder.isEmpty()) {
-                if (ifCount == 0 && elseGuard != null) {
+                if (ifCount == 0 && !elseGuard.isEmpty()) {
                     builder.startElseIf();
                 } else {
                     builder.startIf();
@@ -1674,7 +1674,7 @@
                 builder.tree(conditionBuilder.getRoot());
                 builder.end().startBlock();
                 newIfCount++;
-            } else if (ifCount == 0 && elseGuard != null) {
+            } else if (ifCount == 0 && !elseGuard.isEmpty()) {
                 builder.startElseBlock();
                 newIfCount++;
             }
--- a/graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/node/SpecializationGroup.java	Tue Aug 27 23:06:24 2013 +0200
+++ b/graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/node/SpecializationGroup.java	Wed Aug 28 01:45:13 2013 +0200
@@ -24,6 +24,10 @@
 
 import java.util.*;
 
+import javax.lang.model.type.*;
+
+import com.oracle.truffle.dsl.processor.*;
+import com.oracle.truffle.dsl.processor.template.*;
 import com.oracle.truffle.dsl.processor.template.TemplateMethod.Signature;
 import com.oracle.truffle.dsl.processor.typesystem.*;
 
@@ -64,6 +68,15 @@
         updateChildren(children);
     }
 
+    public List<TypeGuard> getAllGuards() {
+        List<TypeGuard> collectedGuards = new ArrayList<>();
+        collectedGuards.addAll(typeGuards);
+        if (parent != null) {
+            collectedGuards.addAll(parent.getAllGuards());
+        }
+        return collectedGuards;
+    }
+
     public TypeGuard findTypeGuard(int signatureIndex) {
         for (TypeGuard guard : typeGuards) {
             if (guard.getSignatureIndex() == signatureIndex) {
@@ -73,15 +86,34 @@
         return null;
     }
 
-    public GuardData getElseConnectableGuard() {
+    public List<GuardData> getElseConnectableGuards() {
         if (!getTypeGuards().isEmpty() || !getAssumptions().isEmpty()) {
-            return null;
+            return Collections.emptyList();
+        }
+
+        if (getGuards().isEmpty()) {
+            return Collections.emptyList();
         }
-        SpecializationGroup previousGroup = getPreviousGroup();
-        if (previousGroup != null && getGuards().size() >= 1 && previousGroup.getGuards().size() == 1) {
-            GuardData guard = getGuards().get(0);
-            GuardData previousGuard = previousGroup.getGuards().get(0);
+
+        List<GuardData> elseConnectableGuards = new ArrayList<>();
+        int guardIndex = 0;
+        while (guardIndex < getGuards().size() && findNegatedGuardInPrevious(getGuards().get(guardIndex)) != null) {
+            elseConnectableGuards.add(getGuards().get(guardIndex));
+            guardIndex++;
+        }
 
+        return elseConnectableGuards;
+    }
+
+    private GuardData findNegatedGuardInPrevious(GuardData guard) {
+        SpecializationGroup previous = this;
+        while ((previous = previous.getPreviousGroup()) != null) {
+            List<GuardData> elseConnectedGuards = previous.getElseConnectableGuards();
+
+            if (previous == null || previous.getGuards().size() != elseConnectedGuards.size() + 1) {
+                return null;
+            }
+            GuardData previousGuard = previous.getGuards().get(elseConnectedGuards.size());
             if (guard.getMethod().equals(previousGuard.getMethod())) {
                 assert guard.isNegated() != previousGuard.isNegated();
                 return guard;
@@ -169,6 +201,30 @@
             guardMatches.add(guard);
         }
 
+        // check for guards for required type casts
+        for (Iterator<GuardData> iterator = guardMatches.iterator(); iterator.hasNext();) {
+            GuardData guardMatch = iterator.next();
+
+            List<TypeMirror> guardTypes = TemplateMethod.getSignatureTypes(guardMatch.getParameters());
+            for (int i = 0; i < guardTypes.size(); i++) {
+                TypeMirror guardType = guardTypes.get(i);
+                int signatureIndex = i + 1;
+
+                // object guards can be safely moved up
+                if (Utils.isObject(guardType)) {
+                    continue;
+                }
+
+                // signature index required for moving up guards
+                if (containsIndex(typeGuardsMatches, signatureIndex) || (first.getParent() != null && first.getParent().containsTypeGuardIndex(signatureIndex))) {
+                    continue;
+                }
+
+                iterator.remove();
+                break;
+            }
+        }
+
         if (assumptionMatches.isEmpty() && typeGuardsMatches.isEmpty() && guardMatches.isEmpty()) {
             return null;
         }
@@ -183,6 +239,25 @@
         return new SpecializationGroup(newChildren, assumptionMatches, typeGuardsMatches, guardMatches);
     }
 
+    private boolean containsTypeGuardIndex(int index) {
+        if (containsIndex(typeGuards, index)) {
+            return true;
+        }
+        if (parent != null) {
+            return parent.containsTypeGuardIndex(index);
+        }
+        return false;
+    }
+
+    private static boolean containsIndex(List<TypeGuard> typeGuards, int signatureIndex) {
+        for (TypeGuard guard : typeGuards) {
+            if (guard.signatureIndex == signatureIndex) {
+                return true;
+            }
+        }
+        return false;
+    }
+
     public static SpecializationGroup create(List<SpecializationData> specializations) {
         List<SpecializationGroup> groups = new ArrayList<>();
         for (SpecializationData specialization : specializations) {
--- a/graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/template/TemplateMethod.java	Tue Aug 27 23:06:24 2013 +0200
+++ b/graal/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/template/TemplateMethod.java	Wed Aug 28 01:45:13 2013 +0200
@@ -283,8 +283,8 @@
             throw new IllegalStateException("Cannot compare two methods with different type systems.");
         }
 
-        List<TypeMirror> signature1 = getSignatureTypes();
-        List<TypeMirror> signature2 = compareMethod.getSignatureTypes();
+        List<TypeMirror> signature1 = getSignatureTypes(getReturnTypeAndParameters());
+        List<TypeMirror> signature2 = getSignatureTypes(compareMethod.getReturnTypeAndParameters());
         if (signature1.size() != signature2.size()) {
             return signature2.size() - signature1.size();
         }
@@ -333,9 +333,9 @@
         return Utils.getSimpleName(signature1).compareTo(Utils.getSimpleName(signature2));
     }
 
-    public List<TypeMirror> getSignatureTypes() {
+    public static List<TypeMirror> getSignatureTypes(List<ActualParameter> params) {
         List<TypeMirror> types = new ArrayList<>();
-        for (ActualParameter param : getReturnTypeAndParameters()) {
+        for (ActualParameter param : params) {
             if (param.getSpecification().isSignature()) {
                 types.add(param.getType());
             }