Mercurial > hg > graal-compiler
changeset 22943:86bee10c31b0
Fold complex expressions during CE when possible
author | Tom Rodriguez <tom.rodriguez@oracle.com> |
---|---|
date | Tue, 03 Nov 2015 18:45:14 -0800 |
parents | 2a5b62614a96 |
children | d6f0245476e2 |
files | graal/com.oracle.graal.compiler.test/src/com/oracle/graal/compiler/test/ConditionalEliminationTest11.java graal/com.oracle.graal.phases.common/src/com/oracle/graal/phases/common/DominatorConditionalEliminationPhase.java |
diffstat | 2 files changed, 198 insertions(+), 7 deletions(-) [+] |
line wrap: on
line diff
--- a/graal/com.oracle.graal.compiler.test/src/com/oracle/graal/compiler/test/ConditionalEliminationTest11.java Tue Nov 03 15:04:20 2015 -0800 +++ b/graal/com.oracle.graal.compiler.test/src/com/oracle/graal/compiler/test/ConditionalEliminationTest11.java Tue Nov 03 18:45:14 2015 -0800 @@ -227,4 +227,76 @@ testConditionalElimination("test9Snippet", "reference9Snippet"); } + static class ByteHolder { + public byte b; + + byte byteValue() { + return b; + } + } + + public static int test10Snippet(ByteHolder b) { + int v = b.byteValue(); + long a = v & 0xffffffff; + if (v != 44) { + GraalDirectives.deoptimize(); + } + if ((a & 16) == 16) { + GraalDirectives.deoptimize(); + } + if ((a & 8) != 8) { + GraalDirectives.deoptimize(); + } + if ((a & 44) != 44) { + GraalDirectives.deoptimize(); + } + + return v; + } + + public static int reference10Snippet(ByteHolder b) { + byte v = b.byteValue(); + if (v != 44) { + GraalDirectives.deoptimize(); + } + return v; + } + + @Test + public void test10() { + testConditionalElimination("test10Snippet", "reference10Snippet"); + } + + public static int test11Snippet(ByteHolder b) { + int v = b.byteValue(); + long a = v & 0xffffffff; + + if ((a & 16) == 16) { + GraalDirectives.deoptimize(); + } + if ((a & 8) != 8) { + GraalDirectives.deoptimize(); + } + if ((a & 44) != 44) { + GraalDirectives.deoptimize(); + } + if (v != 44) { + GraalDirectives.deoptimize(); + } + return v; + } + + public static int reference11Snippet(ByteHolder b) { + byte v = b.byteValue(); + if (v != 44) { + GraalDirectives.deoptimize(); + } + return v; + } + + @Test + public void test11() { + testConditionalElimination("test11Snippet", "reference11Snippet"); + } + }
--- a/graal/com.oracle.graal.phases.common/src/com/oracle/graal/phases/common/DominatorConditionalEliminationPhase.java Tue Nov 03 15:04:20 2015 -0800 +++ b/graal/com.oracle.graal.phases.common/src/com/oracle/graal/phases/common/DominatorConditionalEliminationPhase.java Tue Nov 03 18:45:14 2015 -0800 @@ -64,6 +64,7 @@ import com.oracle.graal.nodes.IfNode; import com.oracle.graal.nodes.LogicNode; import com.oracle.graal.nodes.LoopExitNode; +import com.oracle.graal.nodes.ParameterNode; import com.oracle.graal.nodes.PiNode; import com.oracle.graal.nodes.ShortCircuitOrNode; import com.oracle.graal.nodes.StructuredGraph; @@ -375,6 +376,94 @@ registerCondition(condition, negated, guard, undoOperations); } + static class Pair<F, S> { + public final F first; + public final S second; + + public Pair(F first, S second) { + this.first = first; + this.second = second; + } + } + + @FunctionalInterface + interface InfoElementProvider { + Iterable<InfoElement> getInfoElements(ValueNode value); + } + + static Pair<InfoElement, Stamp> recursiveFoldStamp(Node node, InfoElementProvider info) { + if (node instanceof UnaryNode) { + UnaryNode unary = (UnaryNode) node; + ValueNode value = unary.getValue(); + for (InfoElement infoElement : info.getInfoElements(value)) { + Stamp result = unary.foldStamp(infoElement.getStamp()); + if (result != null) { + return new Pair<>(infoElement, result); + } + } + Pair<InfoElement, Stamp> foldResult = recursiveFoldStamp(value, info); + if (foldResult != null) { + Stamp result = unary.foldStamp(foldResult.second); + if (result != null) { + return new Pair<>(foldResult.first, result); + } + } + } else if (node instanceof BinaryNode) { + BinaryNode binary = (BinaryNode) node; + ValueNode y = binary.getY(); + ValueNode x = binary.getX(); + if (y.isConstant()) { + for (InfoElement infoElement : info.getInfoElements(x)) { + Stamp result = binary.foldStamp(infoElement.stamp, y.stamp()); + if (result != null) { + return new Pair<>(infoElement, result); + } + } + Pair<InfoElement, Stamp> foldResult = recursiveFoldStamp(x, info); + if (foldResult != null) { + Stamp result = binary.foldStamp(foldResult.second, y.stamp()); + if (result != null) { + return new Pair<>(foldResult.first, result); + } + } + } + } + return null; + } + + /** + * Recursively try to fold stamps within this expression using information from + * {@link #getInfoElements(ValueNode)}. It's only safe to use constants and one + * {@link InfoElement} otherwise more than one guard would be required. + * + * @param node + * @return the pair of the @{link InfoElement} used and the stamp produced for the whole + * expression + */ + Pair<InfoElement, Stamp> recursiveFoldStampFromInfo(Node node) { + return recursiveFoldStamp(node, (value) -> getInfoElements(value)); + } + + /** + * Recursively try to fold stamps within this expression using {@code newStamp} if the node + * {@code original} is encountered in the expression. It's only safe to use constants and + * the passed in stamp otherwise more than one guard would be required. + * + * @param node + * @param original + * @param newStamp + * @return the improved stamp or null is nothing could be done + */ + @SuppressWarnings("unchecked") + static Stamp recursiveFoldStamp(Node node, ValueNode original, Stamp newStamp) { + InfoElement element = new InfoElement(newStamp, original); + Pair<InfoElement, Stamp> result = recursiveFoldStamp(node, (value) -> value == original ? Collections.singleton(element) : Collections.EMPTY_LIST); + if (result != null) { + return result.second; + } + return null; + } + /** * Checks for safe nodes when moving pending tests up. */ @@ -393,7 +482,7 @@ return; } ValueNode curValue = (ValueNode) curNode; - if (curValue.isConstant() || curValue == value) { + if (curValue.isConstant() || curValue == value || curValue instanceof ParameterNode) { return; } if (curValue instanceof BinaryNode || curValue instanceof UnaryNode) { @@ -412,6 +501,12 @@ if (unaryLogicNode.getValue() == original) { result = unaryLogicNode.tryFold(newStamp); } + if (!result.isKnown()) { + Stamp foldResult = recursiveFoldStamp(unaryLogicNode.getValue(), original, newStamp); + if (foldResult != null) { + result = unaryLogicNode.tryFold(foldResult); + } + } } else if (pending.condition instanceof BinaryOpLogicNode) { BinaryOpLogicNode binaryOpLogicNode = (BinaryOpLogicNode) pending.condition; ValueNode x = binaryOpLogicNode.getX(); @@ -425,6 +520,12 @@ result = binaryOpLogicNode.tryFold(andOp.foldStamp(newStamp, y.stamp()), y.stamp()); } } + if (!result.isKnown() && y.isConstant()) { + Stamp foldResult = recursiveFoldStamp(x, original, newStamp); + if (foldResult != null) { + result = binaryOpLogicNode.tryFold(foldResult, y.stamp()); + } + } } if (result.isKnown()) { /* @@ -537,9 +638,16 @@ return rewireGuards(infoElement.getGuard(), result.toBoolean(), rewireGuardFunction); } } + Pair<InfoElement, Stamp> foldResult = recursiveFoldStampFromInfo(value); + if (foldResult != null) { + TriState result = unaryLogicNode.tryFold(foldResult.second); + if (result.isKnown()) { + return rewireGuards(foldResult.first.getGuard(), result.toBoolean(), rewireGuardFunction); + } + } if (thisGuard != null && unaryLogicNode.stamp() instanceof IntegerStamp) { Stamp newStamp = unaryLogicNode.getSucceedingStampForValue(thisGuard.isNegated()); - if (newStamp != null && foldPendingTest(thisGuard, unaryLogicNode.getValue(), newStamp, rewireGuardFunction)) { + if (newStamp != null && foldPendingTest(thisGuard, value, newStamp, rewireGuardFunction)) { return true; } @@ -563,12 +671,23 @@ } } - for (InfoElement infoElement : getInfoElements(y)) { - TriState result = binaryOpLogicNode.tryFold(x.stamp(), infoElement.getStamp()); - if (result.isKnown()) { - return rewireGuards(infoElement.getGuard(), result.toBoolean(), rewireGuardFunction); + if (y.isConstant()) { + Pair<InfoElement, Stamp> foldResult = recursiveFoldStampFromInfo(x); + if (foldResult != null) { + TriState result = binaryOpLogicNode.tryFold(foldResult.second, y.stamp()); + if (result.isKnown()) { + return rewireGuards(foldResult.first.getGuard(), result.toBoolean(), rewireGuardFunction); + } + } + } else { + for (InfoElement infoElement : getInfoElements(y)) { + TriState result = binaryOpLogicNode.tryFold(x.stamp(), infoElement.getStamp()); + if (result.isKnown()) { + return rewireGuards(infoElement.getGuard(), result.toBoolean(), rewireGuardFunction); + } } } + /* * For complex expressions involving constants, see if it's possible to fold the * tests by using stamps one level up in the expression. For instance, (x + n < y) @@ -580,7 +699,7 @@ BinaryArithmeticNode<?> binary = (BinaryArithmeticNode<?>) x; if (binary.getY().isConstant()) { for (InfoElement infoElement : getInfoElements(binary.getX())) { - Stamp newStampX = binary.tryFoldStamp(infoElement.getStamp(), binary.getY().stamp()); + Stamp newStampX = binary.foldStamp(infoElement.getStamp(), binary.getY().stamp()); TriState result = binaryOpLogicNode.tryFold(newStampX, y.stamp()); if (result.isKnown()) { return rewireGuards(infoElement.getGuard(), result.toBoolean(), rewireGuardFunction);