001/*
002 * Copyright (c) 2013, 2014, Oracle and/or its affiliates. All rights reserved.
003 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
004 *
005 * This code is free software; you can redistribute it and/or modify it
006 * under the terms of the GNU General Public License version 2 only, as
007 * published by the Free Software Foundation.
008 *
009 * This code is distributed in the hope that it will be useful, but WITHOUT
010 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
011 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
012 * version 2 for more details (a copy is included in the LICENSE file that
013 * accompanied this code).
014 *
015 * You should have received a copy of the GNU General Public License version
016 * 2 along with this work; if not, write to the Free Software Foundation,
017 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
018 *
019 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
020 * or visit www.oracle.com if you need additional information or have any
021 * questions.
022 */
023package com.oracle.graal.truffle;
024
025import java.util.*;
026import java.util.stream.*;
027
028import com.oracle.truffle.api.*;
029import com.oracle.truffle.api.nodes.*;
030import com.oracle.truffle.api.nodes.Node;
031
032public class TruffleInlining implements Iterable<TruffleInliningDecision> {
033
034    private final List<TruffleInliningDecision> callSites;
035
036    protected TruffleInlining(List<TruffleInliningDecision> callSites) {
037        this.callSites = callSites;
038    }
039
040    public TruffleInlining(OptimizedCallTarget sourceTarget, TruffleInliningPolicy policy) {
041        this(createDecisions(sourceTarget, policy, sourceTarget.getRootNode().getCompilerOptions()));
042
043    }
044
045    private static List<TruffleInliningDecision> createDecisions(OptimizedCallTarget sourceTarget, TruffleInliningPolicy policy, CompilerOptions options) {
046        int nodeCount = sourceTarget.getNonTrivialNodeCount();
047        List<TruffleInliningDecision> exploredCallSites = exploreCallSites(new ArrayList<>(Arrays.asList(sourceTarget)), nodeCount, policy);
048        return decideInlining(exploredCallSites, policy, nodeCount, options);
049    }
050
051    private static List<TruffleInliningDecision> exploreCallSites(List<OptimizedCallTarget> stack, int callStackNodeCount, TruffleInliningPolicy policy) {
052        List<TruffleInliningDecision> exploredCallSites = new ArrayList<>();
053        OptimizedCallTarget parentTarget = stack.get(stack.size() - 1);
054        for (OptimizedDirectCallNode callNode : parentTarget.getCallNodes()) {
055            OptimizedCallTarget currentTarget = callNode.getCurrentCallTarget();
056            stack.add(currentTarget); // push
057            exploredCallSites.add(exploreCallSite(stack, callStackNodeCount, policy, callNode));
058            stack.remove(stack.size() - 1); // pop
059        }
060        return exploredCallSites;
061    }
062
063    private static TruffleInliningDecision exploreCallSite(List<OptimizedCallTarget> callStack, int callStackNodeCount, TruffleInliningPolicy policy, OptimizedDirectCallNode callNode) {
064        OptimizedCallTarget parentTarget = callStack.get(callStack.size() - 2);
065        OptimizedCallTarget currentTarget = callStack.get(callStack.size() - 1);
066
067        List<TruffleInliningDecision> childCallSites = Collections.emptyList();
068        double frequency = calculateFrequency(parentTarget, callNode);
069        int nodeCount = callNode.getCurrentCallTarget().getNonTrivialNodeCount();
070
071        int recursions = countRecursions(callStack);
072        int deepNodeCount = nodeCount;
073        if (callStack.size() < 15 && recursions <= TruffleCompilerOptions.TruffleMaximumRecursiveInlining.getValue()) {
074            /*
075             * We make a preliminary optimistic inlining decision with best possible characteristics
076             * to avoid the exploration of unnecessary paths in the inlining tree.
077             */
078            final CompilerOptions options = callNode.getRootNode().getCompilerOptions();
079            if (policy.isAllowed(new TruffleInliningProfile(callNode, nodeCount, nodeCount, frequency, recursions), callStackNodeCount, options)) {
080                List<TruffleInliningDecision> exploredCallSites = exploreCallSites(callStack, callStackNodeCount + nodeCount, policy);
081                childCallSites = decideInlining(exploredCallSites, policy, nodeCount, options);
082                for (TruffleInliningDecision childCallSite : childCallSites) {
083                    if (childCallSite.isInline()) {
084                        deepNodeCount += childCallSite.getProfile().getDeepNodeCount();
085                    } else {
086                        /* we don't need those anymore. */
087                        childCallSite.getCallSites().clear();
088                    }
089                }
090            }
091        }
092
093        TruffleInliningProfile profile = new TruffleInliningProfile(callNode, nodeCount, deepNodeCount, frequency, recursions);
094        profile.setScore(policy.calculateScore(profile));
095        return new TruffleInliningDecision(currentTarget, profile, childCallSites);
096    }
097
098    private static double calculateFrequency(OptimizedCallTarget target, OptimizedDirectCallNode ocn) {
099        return (double) Math.max(1, ocn.getCallCount()) / (double) Math.max(1, target.getCompilationProfile().getInterpreterCallCount());
100    }
101
102    private static int countRecursions(List<OptimizedCallTarget> stack) {
103        int count = 0;
104        OptimizedCallTarget top = stack.get(stack.size() - 1);
105        for (int i = 0; i < stack.size() - 1; i++) {
106            if (stack.get(i) == top) {
107                count++;
108            }
109        }
110
111        return count;
112    }
113
114    private static List<TruffleInliningDecision> decideInlining(List<TruffleInliningDecision> callSites, TruffleInliningPolicy policy, int nodeCount, CompilerOptions options) {
115        int deepNodeCount = nodeCount;
116        int index = 0;
117        for (TruffleInliningDecision callSite : callSites.stream().sorted().collect(Collectors.toList())) {
118            TruffleInliningProfile profile = callSite.getProfile();
119            profile.setQueryIndex(index++);
120            if (policy.isAllowed(profile, deepNodeCount, options)) {
121                callSite.setInline(true);
122                deepNodeCount += profile.getDeepNodeCount();
123            }
124        }
125        return callSites;
126    }
127
128    public int getInlinedNodeCount() {
129        return getCallSites().stream().filter(callSite -> callSite.isInline()).mapToInt(callSite -> callSite.getProfile().getDeepNodeCount()).sum();
130    }
131
132    public int countCalls() {
133        return getCallSites().stream().mapToInt(callSite -> callSite.isInline() ? callSite.countCalls() + 1 : 1).sum();
134    }
135
136    public int countInlinedCalls() {
137        return getCallSites().stream().filter(TruffleInliningDecision::isInline).mapToInt(callSite -> callSite.countInlinedCalls() + 1).sum();
138    }
139
140    public final List<TruffleInliningDecision> getCallSites() {
141        return callSites;
142    }
143
144    public Iterator<TruffleInliningDecision> iterator() {
145        return callSites.iterator();
146    }
147
148    public TruffleInliningDecision findByCall(OptimizedDirectCallNode callNode) {
149        for (TruffleInliningDecision d : getCallSites()) {
150            if (d.getProfile().getCallNode() == callNode) {
151                return d;
152            }
153        }
154        return null;
155    }
156
157    /**
158     * Visits all nodes of the {@link CallTarget} and all of its inlined calls.
159     */
160    public void accept(OptimizedCallTarget target, NodeVisitor visitor) {
161        target.getRootNode().accept(new CallTreeNodeVisitorImpl(target, visitor));
162    }
163
164    /**
165     * Creates an iterator for all nodes of the {@link CallTarget} and all of its inlined calls.
166     */
167    public Iterator<Node> makeNodeIterator(OptimizedCallTarget target) {
168        return new CallTreeNodeIterator(target);
169    }
170
171    /**
172     * This visitor extends the {@link NodeVisitor} interface to be usable for traversing the full
173     * call tree.
174     */
175    public interface CallTreeNodeVisitor extends NodeVisitor {
176
177        boolean visit(List<TruffleInlining> decisionStack, Node node);
178
179        default boolean visit(Node node) {
180            return visit(null, node);
181        }
182
183        static int getNodeDepth(List<TruffleInlining> decisionStack, Node node) {
184            int depth = calculateNodeDepth(node);
185            if (decisionStack != null) {
186                for (int i = decisionStack.size() - 1; i > 0; i--) {
187                    TruffleInliningDecision decision = (TruffleInliningDecision) decisionStack.get(i);
188                    depth += calculateNodeDepth(decision.getProfile().getCallNode());
189                }
190            }
191            return depth;
192        }
193
194        static int calculateNodeDepth(Node node) {
195            int depth = 0;
196            Node traverseNode = node;
197            while (traverseNode != null) {
198                depth++;
199                traverseNode = traverseNode.getParent();
200            }
201            return depth;
202        }
203
204        static TruffleInliningDecision getCurrentInliningDecision(List<TruffleInlining> decisionStack) {
205            if (decisionStack == null || decisionStack.size() <= 1) {
206                return null;
207            }
208            return (TruffleInliningDecision) decisionStack.get(decisionStack.size() - 1);
209        }
210
211    }
212
213    /**
214     * This visitor wraps an existing {@link NodeVisitor} or {@link CallTreeNodeVisitor} and
215     * traverses the full Truffle tree including inlined call sites.
216     */
217    private static final class CallTreeNodeVisitorImpl implements NodeVisitor {
218
219        protected final List<TruffleInlining> stack = new ArrayList<>();
220        private final NodeVisitor visitor;
221        private boolean continueTraverse = true;
222
223        public CallTreeNodeVisitorImpl(OptimizedCallTarget target, NodeVisitor visitor) {
224            stack.add(target.getInlining());
225            this.visitor = visitor;
226        }
227
228        public boolean visit(Node node) {
229            if (node instanceof OptimizedDirectCallNode) {
230                OptimizedDirectCallNode callNode = (OptimizedDirectCallNode) node;
231                TruffleInlining inlining = stack.get(stack.size() - 1);
232                if (inlining != null) {
233                    TruffleInliningDecision childInlining = inlining.findByCall(callNode);
234                    if (childInlining != null) {
235                        stack.add(childInlining);
236                        continueTraverse = visitNode(node);
237                        if (continueTraverse && childInlining.isInline()) {
238                            childInlining.getTarget().getRootNode().accept(this);
239                        }
240                        stack.remove(stack.size() - 1);
241                    }
242                }
243                return continueTraverse;
244            } else {
245                continueTraverse = visitNode(node);
246                return continueTraverse;
247            }
248        }
249
250        private boolean visitNode(Node node) {
251            if (visitor instanceof CallTreeNodeVisitor) {
252                return ((CallTreeNodeVisitor) visitor).visit(stack, node);
253            } else {
254                return visitor.visit(node);
255            }
256        }
257    }
258
259    private static final class CallTreeNodeIterator implements Iterator<Node> {
260
261        private List<TruffleInlining> inliningDecisionStack = new ArrayList<>();
262        private List<Iterator<Node>> iteratorStack = new ArrayList<>();
263
264        public CallTreeNodeIterator(OptimizedCallTarget target) {
265            inliningDecisionStack.add(target.getInlining());
266            iteratorStack.add(NodeUtil.makeRecursiveIterator(target.getRootNode()));
267        }
268
269        public boolean hasNext() {
270            return peekIterator() != null;
271        }
272
273        public Node next() {
274            Iterator<Node> iterator = peekIterator();
275            if (iterator == null) {
276                throw new NoSuchElementException();
277            }
278
279            Node node = iterator.next();
280            if (node instanceof OptimizedDirectCallNode) {
281                visitInlinedCall(node);
282            }
283            return node;
284        }
285
286        private void visitInlinedCall(Node node) {
287            TruffleInlining currentDecision = inliningDecisionStack.get(inliningDecisionStack.size() - 1);
288            if (currentDecision == null) {
289                return;
290            }
291            TruffleInliningDecision decision = currentDecision.findByCall((OptimizedDirectCallNode) node);
292            if (decision.isInline()) {
293                inliningDecisionStack.add(decision);
294                iteratorStack.add(NodeUtil.makeRecursiveIterator(decision.getTarget().getRootNode()));
295            }
296        }
297
298        private Iterator<Node> peekIterator() {
299            int tos = iteratorStack.size() - 1;
300            while (tos >= 0) {
301                Iterator<Node> iterable = iteratorStack.get(tos);
302                if (iterable.hasNext()) {
303                    return iterable;
304                } else {
305                    iteratorStack.remove(tos);
306                    inliningDecisionStack.remove(tos--);
307                }
308            }
309            return null;
310        }
311
312        public void remove() {
313            throw new UnsupportedOperationException();
314        }
315
316    }
317
318}