# HG changeset patch # User Christian Humer # Date 1391180990 -3600 # Node ID e86d32f4803f276b905d665e378c97d3fc03bb32 # Parent 1e72cd05b77e67883bb5203e0c0f525fa16070a7 Truffle: Implement cache for truffle inlining heuristic. diff -r 1e72cd05b77e -r e86d32f4803f graal/com.oracle.graal.truffle/src/com/oracle/graal/truffle/OptimizedCallTarget.java --- a/graal/com.oracle.graal.truffle/src/com/oracle/graal/truffle/OptimizedCallTarget.java Fri Jan 31 16:04:33 2014 +0200 +++ b/graal/com.oracle.graal.truffle/src/com/oracle/graal/truffle/OptimizedCallTarget.java Fri Jan 31 16:09:50 2014 +0100 @@ -259,7 +259,7 @@ continue; } - int notInlinedCallSiteCount = TruffleInliningImpl.getInlinableCallSites(callTarget).size(); + int notInlinedCallSiteCount = TruffleInliningImpl.getInlinableCallSites(callTarget, callTarget).size(); int nodeCount = NodeUtil.countNodes(callTarget.getRootNode(), null, true); int inlinedCallSiteCount = countInlinedNodes(callTarget.getRootNode()); String comment = callTarget.installedCode == null ? " int" : ""; diff -r 1e72cd05b77e -r e86d32f4803f graal/com.oracle.graal.truffle/src/com/oracle/graal/truffle/TruffleInliningImpl.java --- a/graal/com.oracle.graal.truffle/src/com/oracle/graal/truffle/TruffleInliningImpl.java Fri Jan 31 16:04:33 2014 +0200 +++ b/graal/com.oracle.graal.truffle/src/com/oracle/graal/truffle/TruffleInliningImpl.java Fri Jan 31 16:09:50 2014 +0100 @@ -30,6 +30,7 @@ import com.oracle.graal.debug.*; import com.oracle.truffle.api.*; import com.oracle.truffle.api.nodes.*; +import com.oracle.truffle.api.nodes.CallNode.*; class TruffleInliningImpl implements TruffleInlining { @@ -45,35 +46,41 @@ return MIN_INVOKES_AFTER_INLINING; } + private static void refresh(InliningPolicy policy, List infos) { + for (InlinableCallSiteInfo info : infos) { + info.refresh(policy); + } + } + @Override public boolean performInlining(OptimizedCallTarget target) { final InliningPolicy policy = new InliningPolicy(target); if (!policy.continueInlining()) { if (TraceTruffleInliningDetails.getValue()) { - List inlinableCallSites = getInlinableCallSites(target); + List inlinableCallSites = getInlinableCallSites(policy, target); if (!inlinableCallSites.isEmpty()) { OUT.printf("[truffle] inlining hit caller size limit (%3d >= %3d).%3d remaining call sites in %s:\n", policy.callerNodeCount, TruffleInliningMaxCallerSize.getValue(), inlinableCallSites.size(), target.getRootNode()); - policy.sortByRelevance(inlinableCallSites); + InliningPolicy.sortByRelevance(inlinableCallSites); printCallSiteInfo(policy, inlinableCallSites, ""); } } return false; } - List inlinableCallSites = getInlinableCallSites(target); + List inlinableCallSites = getInlinableCallSites(policy, target); if (inlinableCallSites.isEmpty()) { return false; } - policy.sortByRelevance(inlinableCallSites); + InliningPolicy.sortByRelevance(inlinableCallSites); boolean inlined = false; for (InlinableCallSiteInfo inlinableCallSite : inlinableCallSites) { - if (!policy.isWorthInlining(inlinableCallSite)) { + if (!inlinableCallSite.isWorth()) { break; } - if (inlinableCallSite.getCallSite().inline()) { + if (inlinableCallSite.getCallNode().inline()) { if (TraceTruffleInlining.getValue()) { printCallSiteInfo(policy, inlinableCallSite, "inlined"); } @@ -84,7 +91,10 @@ if (inlined) { for (InlinableCallSiteInfo callSite : inlinableCallSites) { - CallNode.internalResetCallCount(callSite.getCallSite()); + CompilerCallView callView = callSite.getCallNode().getCompilerCallView(); + if (callView != null) { + callView.resetCallCount(); + } } } else { if (TraceTruffleInliningDetails.getValue()) { @@ -105,7 +115,7 @@ private static void printCallSiteInfo(InliningPolicy policy, InlinableCallSiteInfo callSite, String msg) { String calls = String.format("%4s/%4s", callSite.getCallCount(), policy.callerInvocationCount); String nodes = String.format("%3s/%3s", callSite.getInlineNodeCount(), policy.callerNodeCount); - OUT.printf("[truffle] %-9s %-50s |Nodes %6s |Calls %6s %7.3f |%s\n", msg, callSite.getCallSite(), nodes, calls, policy.metric(callSite), callSite.getCallSite().getCallTarget()); + OUT.printf("[truffle] %-9s %-50s |Nodes %6s |Calls %6s %7.3f |%s\n", msg, callSite.getCallNode(), nodes, calls, policy.metric(callSite), callSite.getCallNode().getCallTarget()); } private static final class InliningPolicy { @@ -138,16 +148,19 @@ return (double) callSite.getCallCount() / (double) callerInvocationCount; } - public void sortByRelevance(List inlinableCallSites) { + private static void sortByRelevance(List inlinableCallSites) { Collections.sort(inlinableCallSites, new Comparator() { @Override public int compare(InlinableCallSiteInfo cs1, InlinableCallSiteInfo cs2) { - int result = (isWorthInlining(cs2) ? 1 : 0) - (isWorthInlining(cs1) ? 1 : 0); - if (result == 0) { - return Double.compare(metric(cs2), metric(cs1)); + boolean cs1Worth = cs1.isWorth(); + boolean cs2Worth = cs2.isWorth(); + if (cs1Worth && cs2Worth) { + return Double.compare(cs2.getScore(), cs1.getScore()); + } else if (cs1Worth ^ cs2Worth) { + return cs1Worth ? -1 : 1; } - return result; + return 0; } }); } @@ -155,16 +168,20 @@ private static final class InlinableCallSiteInfo { - private final CallNode callSite; - private final int callCount; + private final CallNode callNode; private final int nodeCount; private final int recursiveDepth; - public InlinableCallSiteInfo(CallNode callSite) { - assert callSite.isInlinable(); - this.callSite = callSite; - this.callCount = CallNode.internalGetCallCount(callSite); - RootCallTarget target = (RootCallTarget) callSite.getCallTarget(); + private int callCount; + private boolean worth; + private double score; + + @SuppressWarnings("unused") + public InlinableCallSiteInfo(InliningPolicy policy, CallNode callNode) { + assert callNode.isInlinable(); + this.callNode = callNode; + RootCallTarget target = (RootCallTarget) callNode.getCallTarget(); + this.nodeCount = target.getRootNode().getInlineNodeCount(); this.recursiveDepth = calculateRecursiveDepth(); } @@ -173,14 +190,31 @@ return recursiveDepth; } + public void refresh(InliningPolicy policy) { + this.callCount = callNode.getCompilerCallView().getCallCount(); + this.worth = policy.isWorthInlining(this); + if (worth) { + this.score = policy.metric(this); + } + // TODO shall we refresh the node count as well? + } + + public boolean isWorth() { + return worth; + } + + public double getScore() { + return score; + } + private int calculateRecursiveDepth() { int depth = 0; - Node parent = callSite.getParent(); + Node parent = callNode.getParent(); while (parent != null) { if (parent instanceof RootNode) { RootNode root = ((RootNode) parent); - if (root.getCallTarget() == callSite.getCallTarget()) { + if (root.getCallTarget() == callNode.getCallTarget()) { depth++; } parent = root.getParentInlinedCall(); @@ -191,8 +225,8 @@ return depth; } - public CallNode getCallSite() { - return callSite; + public CallNode getCallNode() { + return callNode; } public int getCallCount() { @@ -204,7 +238,7 @@ } } - static List getInlinableCallSites(final RootCallTarget target) { + private static List getInlinableCallSites(final InliningPolicy policy, final RootCallTarget target) { final ArrayList inlinableCallSites = new ArrayList<>(); target.getRootNode().accept(new NodeVisitor() { @@ -212,9 +246,16 @@ public boolean visit(Node node) { if (node instanceof CallNode) { CallNode callNode = (CallNode) node; + if (!callNode.isInlined()) { if (callNode.isInlinable()) { - inlinableCallSites.add(new InlinableCallSiteInfo(callNode)); + CompilerCallView view = callNode.getCompilerCallView(); + InlinableCallSiteInfo info = (InlinableCallSiteInfo) view.load(); + if (info == null) { + info = new InlinableCallSiteInfo(policy, callNode); + view.store(info); + } + inlinableCallSites.add(info); } } else { callNode.getInlinedRoot().accept(this); @@ -223,6 +264,11 @@ return true; } }); + refresh(policy, inlinableCallSites); return inlinableCallSites; } + + static List getInlinableCallSites(final OptimizedCallTarget target, final RootCallTarget root) { + return getInlinableCallSites(new InliningPolicy(target), root); + } } diff -r 1e72cd05b77e -r e86d32f4803f graal/com.oracle.truffle.api/src/com/oracle/truffle/api/nodes/CallNode.java --- a/graal/com.oracle.truffle.api/src/com/oracle/truffle/api/nodes/CallNode.java Fri Jan 31 16:04:33 2014 +0200 +++ b/graal/com.oracle.truffle.api/src/com/oracle/truffle/api/nodes/CallNode.java Fri Jan 31 16:09:50 2014 +0100 @@ -102,7 +102,7 @@ */ public static CallNode create(CallTarget target) { if (isInlinable(target)) { - return new InlinableCallNode(target); + return new InlinableCallNode((RootCallTarget) target); } else { return new DefaultCallNode(target); } @@ -111,26 +111,27 @@ /** * Warning: this is internal API and may change without notice. */ - public static int internalGetCallCount(CallNode callNode) { - if (callNode.isInlinable() && !callNode.isInlined()) { - return ((InlinableCallNode) callNode).getCallCount(); - } - throw new UnsupportedOperationException(); + public interface CompilerCallView { + + int getCallCount(); + + void resetCallCount(); + + void store(Object value); + + Object load(); } /** * Warning: this is internal API and may change without notice. */ - public static void internalResetCallCount(CallNode callNode) { - if (callNode.isInlinable() && !callNode.isInlined()) { - ((InlinableCallNode) callNode).resetCallCount(); - return; - } + public CompilerCallView getCompilerCallView() { + return null; } private static boolean isInlinable(CallTarget callTarget) { - if (callTarget instanceof DefaultCallTarget) { - return (((DefaultCallTarget) callTarget).getRootNode()).isInlinable(); + if (callTarget instanceof RootCallTarget) { + return (((RootCallTarget) callTarget).getRootNode()).isInlinable(); } return false; } @@ -168,11 +169,11 @@ } - static final class InlinableCallNode extends CallNode { + static final class InlinableCallNode extends CallNode implements CompilerCallView { private int callCount; - public InlinableCallNode(CallTarget target) { + public InlinableCallNode(RootCallTarget target) { super(target); } @@ -208,16 +209,31 @@ return true; } + @Override + public CompilerCallView getCompilerCallView() { + return this; + } + /* Truffle internal API. */ - int getCallCount() { + public int getCallCount() { return callCount; } /* Truffle internal API. */ - void resetCallCount() { + public void resetCallCount() { callCount = 0; } + private Object storedCompilerInfo; + + public void store(Object value) { + this.storedCompilerInfo = value; + } + + public Object load() { + return storedCompilerInfo; + } + } static final class InlinedCallNode extends CallNode {