# HG changeset patch # User Tom Rodriguez # Date 1446605114 28800 # Node ID 86bee10c31b062b94db7c64db8007e13dd240c69 # Parent 2a5b62614a96c4ec775521fce138a494482f421b Fold complex expressions during CE when possible diff -r 2a5b62614a96 -r 86bee10c31b0 graal/com.oracle.graal.compiler.test/src/com/oracle/graal/compiler/test/ConditionalEliminationTest11.java --- a/graal/com.oracle.graal.compiler.test/src/com/oracle/graal/compiler/test/ConditionalEliminationTest11.java Tue Nov 03 15:04:20 2015 -0800 +++ b/graal/com.oracle.graal.compiler.test/src/com/oracle/graal/compiler/test/ConditionalEliminationTest11.java Tue Nov 03 18:45:14 2015 -0800 @@ -227,4 +227,76 @@ testConditionalElimination("test9Snippet", "reference9Snippet"); } + static class ByteHolder { + public byte b; + + byte byteValue() { + return b; + } + } + + public static int test10Snippet(ByteHolder b) { + int v = b.byteValue(); + long a = v & 0xffffffff; + if (v != 44) { + GraalDirectives.deoptimize(); + } + if ((a & 16) == 16) { + GraalDirectives.deoptimize(); + } + if ((a & 8) != 8) { + GraalDirectives.deoptimize(); + } + if ((a & 44) != 44) { + GraalDirectives.deoptimize(); + } + + return v; + } + + public static int reference10Snippet(ByteHolder b) { + byte v = b.byteValue(); + if (v != 44) { + GraalDirectives.deoptimize(); + } + return v; + } + + @Test + public void test10() { + testConditionalElimination("test10Snippet", "reference10Snippet"); + } + + public static int test11Snippet(ByteHolder b) { + int v = b.byteValue(); + long a = v & 0xffffffff; + + if ((a & 16) == 16) { + GraalDirectives.deoptimize(); + } + if ((a & 8) != 8) { + GraalDirectives.deoptimize(); + } + if ((a & 44) != 44) { + GraalDirectives.deoptimize(); + } + if (v != 44) { + GraalDirectives.deoptimize(); + } + return v; + } + + public static int reference11Snippet(ByteHolder b) { + byte v = b.byteValue(); + if (v != 44) { + GraalDirectives.deoptimize(); + } + return v; + } + + @Test + public void test11() { + testConditionalElimination("test11Snippet", "reference11Snippet"); + } + } diff -r 2a5b62614a96 -r 86bee10c31b0 graal/com.oracle.graal.phases.common/src/com/oracle/graal/phases/common/DominatorConditionalEliminationPhase.java --- a/graal/com.oracle.graal.phases.common/src/com/oracle/graal/phases/common/DominatorConditionalEliminationPhase.java Tue Nov 03 15:04:20 2015 -0800 +++ b/graal/com.oracle.graal.phases.common/src/com/oracle/graal/phases/common/DominatorConditionalEliminationPhase.java Tue Nov 03 18:45:14 2015 -0800 @@ -64,6 +64,7 @@ import com.oracle.graal.nodes.IfNode; import com.oracle.graal.nodes.LogicNode; import com.oracle.graal.nodes.LoopExitNode; +import com.oracle.graal.nodes.ParameterNode; import com.oracle.graal.nodes.PiNode; import com.oracle.graal.nodes.ShortCircuitOrNode; import com.oracle.graal.nodes.StructuredGraph; @@ -375,6 +376,94 @@ registerCondition(condition, negated, guard, undoOperations); } + static class Pair { + public final F first; + public final S second; + + public Pair(F first, S second) { + this.first = first; + this.second = second; + } + } + + @FunctionalInterface + interface InfoElementProvider { + Iterable getInfoElements(ValueNode value); + } + + static Pair recursiveFoldStamp(Node node, InfoElementProvider info) { + if (node instanceof UnaryNode) { + UnaryNode unary = (UnaryNode) node; + ValueNode value = unary.getValue(); + for (InfoElement infoElement : info.getInfoElements(value)) { + Stamp result = unary.foldStamp(infoElement.getStamp()); + if (result != null) { + return new Pair<>(infoElement, result); + } + } + Pair foldResult = recursiveFoldStamp(value, info); + if (foldResult != null) { + Stamp result = unary.foldStamp(foldResult.second); + if (result != null) { + return new Pair<>(foldResult.first, result); + } + } + } else if (node instanceof BinaryNode) { + BinaryNode binary = (BinaryNode) node; + ValueNode y = binary.getY(); + ValueNode x = binary.getX(); + if (y.isConstant()) { + for (InfoElement infoElement : info.getInfoElements(x)) { + Stamp result = binary.foldStamp(infoElement.stamp, y.stamp()); + if (result != null) { + return new Pair<>(infoElement, result); + } + } + Pair foldResult = recursiveFoldStamp(x, info); + if (foldResult != null) { + Stamp result = binary.foldStamp(foldResult.second, y.stamp()); + if (result != null) { + return new Pair<>(foldResult.first, result); + } + } + } + } + return null; + } + + /** + * Recursively try to fold stamps within this expression using information from + * {@link #getInfoElements(ValueNode)}. It's only safe to use constants and one + * {@link InfoElement} otherwise more than one guard would be required. + * + * @param node + * @return the pair of the @{link InfoElement} used and the stamp produced for the whole + * expression + */ + Pair recursiveFoldStampFromInfo(Node node) { + return recursiveFoldStamp(node, (value) -> getInfoElements(value)); + } + + /** + * Recursively try to fold stamps within this expression using {@code newStamp} if the node + * {@code original} is encountered in the expression. It's only safe to use constants and + * the passed in stamp otherwise more than one guard would be required. + * + * @param node + * @param original + * @param newStamp + * @return the improved stamp or null is nothing could be done + */ + @SuppressWarnings("unchecked") + static Stamp recursiveFoldStamp(Node node, ValueNode original, Stamp newStamp) { + InfoElement element = new InfoElement(newStamp, original); + Pair result = recursiveFoldStamp(node, (value) -> value == original ? Collections.singleton(element) : Collections.EMPTY_LIST); + if (result != null) { + return result.second; + } + return null; + } + /** * Checks for safe nodes when moving pending tests up. */ @@ -393,7 +482,7 @@ return; } ValueNode curValue = (ValueNode) curNode; - if (curValue.isConstant() || curValue == value) { + if (curValue.isConstant() || curValue == value || curValue instanceof ParameterNode) { return; } if (curValue instanceof BinaryNode || curValue instanceof UnaryNode) { @@ -412,6 +501,12 @@ if (unaryLogicNode.getValue() == original) { result = unaryLogicNode.tryFold(newStamp); } + if (!result.isKnown()) { + Stamp foldResult = recursiveFoldStamp(unaryLogicNode.getValue(), original, newStamp); + if (foldResult != null) { + result = unaryLogicNode.tryFold(foldResult); + } + } } else if (pending.condition instanceof BinaryOpLogicNode) { BinaryOpLogicNode binaryOpLogicNode = (BinaryOpLogicNode) pending.condition; ValueNode x = binaryOpLogicNode.getX(); @@ -425,6 +520,12 @@ result = binaryOpLogicNode.tryFold(andOp.foldStamp(newStamp, y.stamp()), y.stamp()); } } + if (!result.isKnown() && y.isConstant()) { + Stamp foldResult = recursiveFoldStamp(x, original, newStamp); + if (foldResult != null) { + result = binaryOpLogicNode.tryFold(foldResult, y.stamp()); + } + } } if (result.isKnown()) { /* @@ -537,9 +638,16 @@ return rewireGuards(infoElement.getGuard(), result.toBoolean(), rewireGuardFunction); } } + Pair foldResult = recursiveFoldStampFromInfo(value); + if (foldResult != null) { + TriState result = unaryLogicNode.tryFold(foldResult.second); + if (result.isKnown()) { + return rewireGuards(foldResult.first.getGuard(), result.toBoolean(), rewireGuardFunction); + } + } if (thisGuard != null && unaryLogicNode.stamp() instanceof IntegerStamp) { Stamp newStamp = unaryLogicNode.getSucceedingStampForValue(thisGuard.isNegated()); - if (newStamp != null && foldPendingTest(thisGuard, unaryLogicNode.getValue(), newStamp, rewireGuardFunction)) { + if (newStamp != null && foldPendingTest(thisGuard, value, newStamp, rewireGuardFunction)) { return true; } @@ -563,12 +671,23 @@ } } - for (InfoElement infoElement : getInfoElements(y)) { - TriState result = binaryOpLogicNode.tryFold(x.stamp(), infoElement.getStamp()); - if (result.isKnown()) { - return rewireGuards(infoElement.getGuard(), result.toBoolean(), rewireGuardFunction); + if (y.isConstant()) { + Pair foldResult = recursiveFoldStampFromInfo(x); + if (foldResult != null) { + TriState result = binaryOpLogicNode.tryFold(foldResult.second, y.stamp()); + if (result.isKnown()) { + return rewireGuards(foldResult.first.getGuard(), result.toBoolean(), rewireGuardFunction); + } + } + } else { + for (InfoElement infoElement : getInfoElements(y)) { + TriState result = binaryOpLogicNode.tryFold(x.stamp(), infoElement.getStamp()); + if (result.isKnown()) { + return rewireGuards(infoElement.getGuard(), result.toBoolean(), rewireGuardFunction); + } } } + /* * For complex expressions involving constants, see if it's possible to fold the * tests by using stamps one level up in the expression. For instance, (x + n < y) @@ -580,7 +699,7 @@ BinaryArithmeticNode binary = (BinaryArithmeticNode) x; if (binary.getY().isConstant()) { for (InfoElement infoElement : getInfoElements(binary.getX())) { - Stamp newStampX = binary.tryFoldStamp(infoElement.getStamp(), binary.getY().stamp()); + Stamp newStampX = binary.foldStamp(infoElement.getStamp(), binary.getY().stamp()); TriState result = binaryOpLogicNode.tryFold(newStampX, y.stamp()); if (result.isKnown()) { return rewireGuards(infoElement.getGuard(), result.toBoolean(), rewireGuardFunction);