changeset 22943:86bee10c31b0

Fold complex expressions during CE when possible
author Tom Rodriguez <tom.rodriguez@oracle.com>
date Tue, 03 Nov 2015 18:45:14 -0800
parents 2a5b62614a96
children d6f0245476e2
files graal/com.oracle.graal.compiler.test/src/com/oracle/graal/compiler/test/ConditionalEliminationTest11.java graal/com.oracle.graal.phases.common/src/com/oracle/graal/phases/common/DominatorConditionalEliminationPhase.java
diffstat 2 files changed, 198 insertions(+), 7 deletions(-) [+]
line wrap: on
line diff
--- a/graal/com.oracle.graal.compiler.test/src/com/oracle/graal/compiler/test/ConditionalEliminationTest11.java	Tue Nov 03 15:04:20 2015 -0800
+++ b/graal/com.oracle.graal.compiler.test/src/com/oracle/graal/compiler/test/ConditionalEliminationTest11.java	Tue Nov 03 18:45:14 2015 -0800
@@ -227,4 +227,76 @@
         testConditionalElimination("test9Snippet", "reference9Snippet");
     }
 
+    static class ByteHolder {
+        public byte b;
+
+        byte byteValue() {
+            return b;
+        }
+    }
+
+    public static int test10Snippet(ByteHolder b) {
+        int v = b.byteValue();
+        long a = v & 0xffffffff;
+        if (v != 44) {
+            GraalDirectives.deoptimize();
+        }
+        if ((a & 16) == 16) {
+            GraalDirectives.deoptimize();
+        }
+        if ((a & 8) != 8) {
+            GraalDirectives.deoptimize();
+        }
+        if ((a & 44) != 44) {
+            GraalDirectives.deoptimize();
+        }
+
+        return v;
+    }
+
+    public static int reference10Snippet(ByteHolder b) {
+        byte v = b.byteValue();
+        if (v != 44) {
+            GraalDirectives.deoptimize();
+        }
+        return v;
+    }
+
+    @Test
+    public void test10() {
+        testConditionalElimination("test10Snippet", "reference10Snippet");
+    }
+
+    public static int test11Snippet(ByteHolder b) {
+        int v = b.byteValue();
+        long a = v & 0xffffffff;
+
+        if ((a & 16) == 16) {
+            GraalDirectives.deoptimize();
+        }
+        if ((a & 8) != 8) {
+            GraalDirectives.deoptimize();
+        }
+        if ((a & 44) != 44) {
+            GraalDirectives.deoptimize();
+        }
+        if (v != 44) {
+            GraalDirectives.deoptimize();
+        }
+        return v;
+    }
+
+    public static int reference11Snippet(ByteHolder b) {
+        byte v = b.byteValue();
+        if (v != 44) {
+            GraalDirectives.deoptimize();
+        }
+        return v;
+    }
+
+    @Test
+    public void test11() {
+        testConditionalElimination("test11Snippet", "reference11Snippet");
+    }
+
 }
--- a/graal/com.oracle.graal.phases.common/src/com/oracle/graal/phases/common/DominatorConditionalEliminationPhase.java	Tue Nov 03 15:04:20 2015 -0800
+++ b/graal/com.oracle.graal.phases.common/src/com/oracle/graal/phases/common/DominatorConditionalEliminationPhase.java	Tue Nov 03 18:45:14 2015 -0800
@@ -64,6 +64,7 @@
 import com.oracle.graal.nodes.IfNode;
 import com.oracle.graal.nodes.LogicNode;
 import com.oracle.graal.nodes.LoopExitNode;
+import com.oracle.graal.nodes.ParameterNode;
 import com.oracle.graal.nodes.PiNode;
 import com.oracle.graal.nodes.ShortCircuitOrNode;
 import com.oracle.graal.nodes.StructuredGraph;
@@ -375,6 +376,94 @@
             registerCondition(condition, negated, guard, undoOperations);
         }
 
+        static class Pair<F, S> {
+            public final F first;
+            public final S second;
+
+            public Pair(F first, S second) {
+                this.first = first;
+                this.second = second;
+            }
+        }
+
+        @FunctionalInterface
+        interface InfoElementProvider {
+            Iterable<InfoElement> getInfoElements(ValueNode value);
+        }
+
+        static Pair<InfoElement, Stamp> recursiveFoldStamp(Node node, InfoElementProvider info) {
+            if (node instanceof UnaryNode) {
+                UnaryNode unary = (UnaryNode) node;
+                ValueNode value = unary.getValue();
+                for (InfoElement infoElement : info.getInfoElements(value)) {
+                    Stamp result = unary.foldStamp(infoElement.getStamp());
+                    if (result != null) {
+                        return new Pair<>(infoElement, result);
+                    }
+                }
+                Pair<InfoElement, Stamp> foldResult = recursiveFoldStamp(value, info);
+                if (foldResult != null) {
+                    Stamp result = unary.foldStamp(foldResult.second);
+                    if (result != null) {
+                        return new Pair<>(foldResult.first, result);
+                    }
+                }
+            } else if (node instanceof BinaryNode) {
+                BinaryNode binary = (BinaryNode) node;
+                ValueNode y = binary.getY();
+                ValueNode x = binary.getX();
+                if (y.isConstant()) {
+                    for (InfoElement infoElement : info.getInfoElements(x)) {
+                        Stamp result = binary.foldStamp(infoElement.stamp, y.stamp());
+                        if (result != null) {
+                            return new Pair<>(infoElement, result);
+                        }
+                    }
+                    Pair<InfoElement, Stamp> foldResult = recursiveFoldStamp(x, info);
+                    if (foldResult != null) {
+                        Stamp result = binary.foldStamp(foldResult.second, y.stamp());
+                        if (result != null) {
+                            return new Pair<>(foldResult.first, result);
+                        }
+                    }
+                }
+            }
+            return null;
+        }
+
+        /**
+         * Recursively try to fold stamps within this expression using information from
+         * {@link #getInfoElements(ValueNode)}. It's only safe to use constants and one
+         * {@link InfoElement} otherwise more than one guard would be required.
+         *
+         * @param node
+         * @return the pair of the @{link InfoElement} used and the stamp produced for the whole
+         *         expression
+         */
+        Pair<InfoElement, Stamp> recursiveFoldStampFromInfo(Node node) {
+            return recursiveFoldStamp(node, (value) -> getInfoElements(value));
+        }
+
+        /**
+         * Recursively try to fold stamps within this expression using {@code newStamp} if the node
+         * {@code original} is encountered in the expression. It's only safe to use constants and
+         * the passed in stamp otherwise more than one guard would be required.
+         *
+         * @param node
+         * @param original
+         * @param newStamp
+         * @return the improved stamp or null is nothing could be done
+         */
+        @SuppressWarnings("unchecked")
+        static Stamp recursiveFoldStamp(Node node, ValueNode original, Stamp newStamp) {
+            InfoElement element = new InfoElement(newStamp, original);
+            Pair<InfoElement, Stamp> result = recursiveFoldStamp(node, (value) -> value == original ? Collections.singleton(element) : Collections.EMPTY_LIST);
+            if (result != null) {
+                return result.second;
+            }
+            return null;
+        }
+
         /**
          * Checks for safe nodes when moving pending tests up.
          */
@@ -393,7 +482,7 @@
                     return;
                 }
                 ValueNode curValue = (ValueNode) curNode;
-                if (curValue.isConstant() || curValue == value) {
+                if (curValue.isConstant() || curValue == value || curValue instanceof ParameterNode) {
                     return;
                 }
                 if (curValue instanceof BinaryNode || curValue instanceof UnaryNode) {
@@ -412,6 +501,12 @@
                     if (unaryLogicNode.getValue() == original) {
                         result = unaryLogicNode.tryFold(newStamp);
                     }
+                    if (!result.isKnown()) {
+                        Stamp foldResult = recursiveFoldStamp(unaryLogicNode.getValue(), original, newStamp);
+                        if (foldResult != null) {
+                            result = unaryLogicNode.tryFold(foldResult);
+                        }
+                    }
                 } else if (pending.condition instanceof BinaryOpLogicNode) {
                     BinaryOpLogicNode binaryOpLogicNode = (BinaryOpLogicNode) pending.condition;
                     ValueNode x = binaryOpLogicNode.getX();
@@ -425,6 +520,12 @@
                             result = binaryOpLogicNode.tryFold(andOp.foldStamp(newStamp, y.stamp()), y.stamp());
                         }
                     }
+                    if (!result.isKnown() && y.isConstant()) {
+                        Stamp foldResult = recursiveFoldStamp(x, original, newStamp);
+                        if (foldResult != null) {
+                            result = binaryOpLogicNode.tryFold(foldResult, y.stamp());
+                        }
+                    }
                 }
                 if (result.isKnown()) {
                     /*
@@ -537,9 +638,16 @@
                         return rewireGuards(infoElement.getGuard(), result.toBoolean(), rewireGuardFunction);
                     }
                 }
+                Pair<InfoElement, Stamp> foldResult = recursiveFoldStampFromInfo(value);
+                if (foldResult != null) {
+                    TriState result = unaryLogicNode.tryFold(foldResult.second);
+                    if (result.isKnown()) {
+                        return rewireGuards(foldResult.first.getGuard(), result.toBoolean(), rewireGuardFunction);
+                    }
+                }
                 if (thisGuard != null && unaryLogicNode.stamp() instanceof IntegerStamp) {
                     Stamp newStamp = unaryLogicNode.getSucceedingStampForValue(thisGuard.isNegated());
-                    if (newStamp != null && foldPendingTest(thisGuard, unaryLogicNode.getValue(), newStamp, rewireGuardFunction)) {
+                    if (newStamp != null && foldPendingTest(thisGuard, value, newStamp, rewireGuardFunction)) {
                         return true;
                     }
 
@@ -563,12 +671,23 @@
                     }
                 }
 
-                for (InfoElement infoElement : getInfoElements(y)) {
-                    TriState result = binaryOpLogicNode.tryFold(x.stamp(), infoElement.getStamp());
-                    if (result.isKnown()) {
-                        return rewireGuards(infoElement.getGuard(), result.toBoolean(), rewireGuardFunction);
+                if (y.isConstant()) {
+                    Pair<InfoElement, Stamp> foldResult = recursiveFoldStampFromInfo(x);
+                    if (foldResult != null) {
+                        TriState result = binaryOpLogicNode.tryFold(foldResult.second, y.stamp());
+                        if (result.isKnown()) {
+                            return rewireGuards(foldResult.first.getGuard(), result.toBoolean(), rewireGuardFunction);
+                        }
+                    }
+                } else {
+                    for (InfoElement infoElement : getInfoElements(y)) {
+                        TriState result = binaryOpLogicNode.tryFold(x.stamp(), infoElement.getStamp());
+                        if (result.isKnown()) {
+                            return rewireGuards(infoElement.getGuard(), result.toBoolean(), rewireGuardFunction);
+                        }
                     }
                 }
+
                 /*
                  * For complex expressions involving constants, see if it's possible to fold the
                  * tests by using stamps one level up in the expression. For instance, (x + n < y)
@@ -580,7 +699,7 @@
                     BinaryArithmeticNode<?> binary = (BinaryArithmeticNode<?>) x;
                     if (binary.getY().isConstant()) {
                         for (InfoElement infoElement : getInfoElements(binary.getX())) {
-                            Stamp newStampX = binary.tryFoldStamp(infoElement.getStamp(), binary.getY().stamp());
+                            Stamp newStampX = binary.foldStamp(infoElement.getStamp(), binary.getY().stamp());
                             TriState result = binaryOpLogicNode.tryFold(newStampX, y.stamp());
                             if (result.isKnown()) {
                                 return rewireGuards(infoElement.getGuard(), result.toBoolean(), rewireGuardFunction);