001/*
002 * Copyright (c) 2015, 2015, Oracle and/or its affiliates. All rights reserved.
003 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
004 *
005 * This code is free software; you can redistribute it and/or modify it
006 * under the terms of the GNU General Public License version 2 only, as
007 * published by the Free Software Foundation.
008 *
009 * This code is distributed in the hope that it will be useful, but WITHOUT
010 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
011 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
012 * version 2 for more details (a copy is included in the LICENSE file that
013 * accompanied this code).
014 *
015 * You should have received a copy of the GNU General Public License version
016 * 2 along with this work; if not, write to the Free Software Foundation,
017 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
018 *
019 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
020 * or visit www.oracle.com if you need additional information or have any
021 * questions.
022 */
023package com.oracle.graal.phases.common;
024
025import static jdk.internal.jvmci.meta.DeoptimizationAction.*;
026import static jdk.internal.jvmci.meta.DeoptimizationReason.*;
027
028import java.util.*;
029import java.util.function.*;
030
031import com.oracle.graal.debug.*;
032import jdk.internal.jvmci.meta.*;
033
034import com.oracle.graal.compiler.common.cfg.*;
035import com.oracle.graal.compiler.common.type.*;
036import com.oracle.graal.graph.*;
037import com.oracle.graal.nodeinfo.*;
038import com.oracle.graal.nodes.*;
039import com.oracle.graal.nodes.cfg.*;
040import com.oracle.graal.nodes.extended.*;
041import com.oracle.graal.nodes.java.*;
042import com.oracle.graal.nodes.util.*;
043import com.oracle.graal.phases.*;
044import com.oracle.graal.phases.common.LoweringPhase.Frame;
045import com.oracle.graal.phases.schedule.*;
046
047public class DominatorConditionalEliminationPhase extends Phase {
048
049    private static final DebugMetric metricStampsRegistered = Debug.metric("StampsRegistered");
050    private static final DebugMetric metricStampsFound = Debug.metric("StampsFound");
051    private final boolean fullSchedule;
052
053    public DominatorConditionalEliminationPhase(boolean fullSchedule) {
054        this.fullSchedule = fullSchedule;
055    }
056
057    private static final class InfoElement {
058        private final Stamp stamp;
059        private final ValueNode guard;
060
061        public InfoElement(Stamp stamp, ValueNode guard) {
062            this.stamp = stamp;
063            this.guard = guard;
064        }
065
066        public Stamp getStamp() {
067            return stamp;
068        }
069
070        public ValueNode getGuard() {
071            return guard;
072        }
073
074        @Override
075        public String toString() {
076            return stamp + " -> " + guard;
077        }
078    }
079
080    private static final class Info {
081        private final ArrayList<InfoElement> infos;
082
083        public Info() {
084            infos = new ArrayList<>();
085        }
086
087        public Iterable<InfoElement> getElements() {
088            return infos;
089        }
090
091        public void pushElement(InfoElement element) {
092            infos.add(element);
093        }
094
095        public void popElement() {
096            infos.remove(infos.size() - 1);
097        }
098    }
099
100    @Override
101    protected void run(StructuredGraph graph) {
102
103        Function<Block, Iterable<? extends Node>> blockToNodes;
104        Function<Node, Block> nodeToBlock;
105        Block startBlock;
106
107        if (fullSchedule) {
108            SchedulePhase schedule = new SchedulePhase(SchedulePhase.SchedulingStrategy.EARLIEST);
109            schedule.apply(graph);
110            ControlFlowGraph cfg = schedule.getCFG();
111            cfg.computePostdominators();
112            blockToNodes = b -> schedule.getBlockToNodesMap().get(b);
113            nodeToBlock = n -> schedule.getNodeToBlockMap().get(n);
114            startBlock = cfg.getStartBlock();
115        } else {
116            ControlFlowGraph cfg = ControlFlowGraph.compute(graph, true, true, true, true);
117            cfg.computePostdominators();
118            BlockMap<List<FixedNode>> nodes = new BlockMap<>(cfg);
119            for (Block b : cfg.getBlocks()) {
120                ArrayList<FixedNode> curNodes = new ArrayList<>();
121                for (FixedNode node : b.getNodes()) {
122                    if (node instanceof AbstractBeginNode || node instanceof FixedGuardNode || node instanceof CheckCastNode || node instanceof ConditionAnchorNode || node instanceof IfNode) {
123                        curNodes.add(node);
124                    }
125                }
126                nodes.put(b, curNodes);
127            }
128            blockToNodes = b -> nodes.get(b);
129            nodeToBlock = n -> cfg.blockFor(n);
130            startBlock = cfg.getStartBlock();
131        }
132
133        Instance instance = new Instance(graph, blockToNodes, nodeToBlock);
134        instance.processBlock(startBlock);
135    }
136
137    private static class Instance {
138
139        private NodeMap<Info> map;
140        private Deque<LoopExitNode> loopExits;
141        private final Function<Block, Iterable<? extends Node>> blockToNodes;
142        private final Function<Node, Block> nodeToBlock;
143
144        public Instance(StructuredGraph graph, Function<Block, Iterable<? extends Node>> blockToNodes, Function<Node, Block> nodeToBlock) {
145            map = graph.createNodeMap();
146            loopExits = new ArrayDeque<>();
147            this.blockToNodes = blockToNodes;
148            this.nodeToBlock = nodeToBlock;
149        }
150
151        public void processBlock(Block startBlock) {
152            LoweringPhase.processBlock(new InstanceFrame(startBlock, null));
153        }
154
155        public class InstanceFrame extends LoweringPhase.Frame<InstanceFrame> {
156            List<Runnable> undoOperations = new ArrayList<>();
157
158            public InstanceFrame(Block block, InstanceFrame parent) {
159                super(block, parent);
160            }
161
162            @Override
163            public Frame<?> enter(Block b) {
164                return new InstanceFrame(b, this);
165            }
166
167            @Override
168            public void preprocess() {
169                Instance.this.preprocess(block, undoOperations);
170            }
171
172            @Override
173            public void postprocess() {
174                Instance.postprocess(undoOperations);
175            }
176        }
177
178        private static void postprocess(List<Runnable> undoOperations) {
179            for (Runnable r : undoOperations) {
180                r.run();
181            }
182        }
183
184        private void preprocess(Block block, List<Runnable> undoOperations) {
185            AbstractBeginNode beginNode = block.getBeginNode();
186            if (beginNode instanceof LoopExitNode && beginNode.isAlive()) {
187                LoopExitNode loopExitNode = (LoopExitNode) beginNode;
188                this.loopExits.push(loopExitNode);
189                undoOperations.add(() -> loopExits.pop());
190            } else if (block.getDominator() != null &&
191                            (block.getDominator().getLoopDepth() > block.getLoopDepth() || (block.getDominator().getLoopDepth() == block.getLoopDepth() && block.getDominator().getLoop() != block.getLoop()))) {
192                // We are exiting the loop, but there is not a single loop exit block along our
193                // dominator tree (e.g., we are a merge of two loop exits).
194                final NodeMap<Info> oldMap = map;
195                final Deque<LoopExitNode> oldLoopExits = loopExits;
196                map = map.graph().createNodeMap();
197                loopExits = new ArrayDeque<>();
198                undoOperations.add(() -> {
199                    map = oldMap;
200                    loopExits = oldLoopExits;
201                });
202            }
203            for (Node n : blockToNodes.apply(block)) {
204                if (n.isAlive()) {
205                    processNode(n, undoOperations);
206                }
207            }
208        }
209
210        private void processNode(Node node, List<Runnable> undoOperations) {
211            if (node instanceof AbstractBeginNode) {
212                processAbstractBegin((AbstractBeginNode) node, undoOperations);
213            } else if (node instanceof FixedGuardNode) {
214                processFixedGuard((FixedGuardNode) node, undoOperations);
215            } else if (node instanceof GuardNode) {
216                processGuard((GuardNode) node, undoOperations);
217            } else if (node instanceof CheckCastNode) {
218                processCheckCast((CheckCastNode) node);
219            } else if (node instanceof ConditionAnchorNode) {
220                processConditionAnchor((ConditionAnchorNode) node);
221            } else if (node instanceof IfNode) {
222                processIf((IfNode) node, undoOperations);
223            } else {
224                return;
225            }
226        }
227
228        private void processCheckCast(CheckCastNode node) {
229            for (InfoElement infoElement : getInfoElements(node.object())) {
230                TriState result = node.tryFold(infoElement.getStamp());
231                if (result.isKnown()) {
232                    if (rewireGuards(infoElement.getGuard(), result.toBoolean(), (guard, checkCastResult) -> {
233                        if (checkCastResult) {
234                            PiNode piNode = node.graph().unique(new PiNode(node.object(), node.stamp(), guard));
235                            node.replaceAtUsages(piNode);
236                            GraphUtil.unlinkFixedNode(node);
237                            node.safeDelete();
238                        } else {
239                            DeoptimizeNode deopt = node.graph().add(new DeoptimizeNode(InvalidateReprofile, UnreachedCode));
240                            node.replaceAtPredecessor(deopt);
241                            GraphUtil.killCFG(node);
242                        }
243                        return true;
244                    })) {
245                        return;
246                    }
247                }
248            }
249        }
250
251        private void processIf(IfNode node, List<Runnable> undoOperations) {
252            tryProofCondition(node.condition(), (guard, result) -> {
253                AbstractBeginNode survivingSuccessor = node.getSuccessor(result);
254                survivingSuccessor.replaceAtUsages(InputType.Guard, guard);
255                survivingSuccessor.replaceAtPredecessor(null);
256                node.replaceAtPredecessor(survivingSuccessor);
257                GraphUtil.killCFG(node);
258                if (survivingSuccessor instanceof BeginNode) {
259                    undoOperations.add(() -> {
260                        if (survivingSuccessor.isAlive()) {
261                            ((BeginNode) survivingSuccessor).trySimplify();
262                        }
263                    });
264                }
265                return true;
266            });
267        }
268
269        private void registerNewCondition(LogicNode condition, boolean negated, ValueNode guard, List<Runnable> undoOperations) {
270            if (condition instanceof UnaryOpLogicNode) {
271                UnaryOpLogicNode unaryLogicNode = (UnaryOpLogicNode) condition;
272                Stamp newStamp = unaryLogicNode.getSucceedingStampForValue(negated);
273                registerNewStamp(unaryLogicNode.getValue(), newStamp, guard, undoOperations);
274            } else if (condition instanceof BinaryOpLogicNode) {
275                BinaryOpLogicNode binaryOpLogicNode = (BinaryOpLogicNode) condition;
276                ValueNode x = binaryOpLogicNode.getX();
277                if (!x.isConstant()) {
278                    Stamp newStampX = binaryOpLogicNode.getSucceedingStampForX(negated);
279                    registerNewStamp(x, newStampX, guard, undoOperations);
280                }
281
282                ValueNode y = binaryOpLogicNode.getY();
283                if (!y.isConstant()) {
284                    Stamp newStampY = binaryOpLogicNode.getSucceedingStampForY(negated);
285                    registerNewStamp(y, newStampY, guard, undoOperations);
286                }
287            }
288            registerCondition(condition, negated, guard, undoOperations);
289        }
290
291        private void registerCondition(LogicNode condition, boolean negated, ValueNode guard, List<Runnable> undoOperations) {
292            registerNewStamp(condition, negated ? StampFactory.contradiction() : StampFactory.tautology(), guard, undoOperations);
293        }
294
295        private Iterable<InfoElement> getInfoElements(ValueNode proxiedValue) {
296            ValueNode value = GraphUtil.unproxify(proxiedValue);
297            Info info = map.get(value);
298            if (info == null) {
299                return Collections.emptyList();
300            } else {
301                return info.getElements();
302            }
303        }
304
305        private boolean rewireGuards(ValueNode guard, boolean result, GuardRewirer rewireGuardFunction) {
306            assert guard instanceof GuardingNode;
307            metricStampsFound.increment();
308            ValueNode proxiedGuard = proxyGuard(guard);
309            return rewireGuardFunction.rewire(proxiedGuard, result);
310        }
311
312        private ValueNode proxyGuard(ValueNode guard) {
313            ValueNode proxiedGuard = guard;
314            if (!this.loopExits.isEmpty()) {
315                while (proxiedGuard instanceof GuardProxyNode) {
316                    proxiedGuard = ((GuardProxyNode) proxiedGuard).value();
317                }
318                Block guardBlock = nodeToBlock.apply(proxiedGuard);
319                assert guardBlock != null;
320                for (Iterator<LoopExitNode> iter = loopExits.descendingIterator(); iter.hasNext();) {
321                    LoopExitNode loopExitNode = iter.next();
322                    Block loopExitBlock = nodeToBlock.apply(loopExitNode);
323                    if (guardBlock != loopExitBlock && AbstractControlFlowGraph.dominates(guardBlock, loopExitBlock)) {
324                        Block loopBeginBlock = nodeToBlock.apply(loopExitNode.loopBegin());
325                        if (!AbstractControlFlowGraph.dominates(guardBlock, loopBeginBlock) || guardBlock == loopBeginBlock) {
326                            proxiedGuard = proxiedGuard.graph().unique(new GuardProxyNode((GuardingNode) proxiedGuard, loopExitNode));
327                        }
328                    }
329                }
330            }
331            return proxiedGuard;
332        }
333
334        @FunctionalInterface
335        private interface GuardRewirer {
336            /**
337             * Called if the condition could be proven to have a constant value ({@code result})
338             * under {@code guard}.
339             *
340             * Return whether a transformation could be applied.
341             */
342            boolean rewire(ValueNode guard, boolean result);
343        }
344
345        private boolean tryProofCondition(LogicNode node, GuardRewirer rewireGuardFunction) {
346            for (InfoElement infoElement : getInfoElements(node)) {
347                Stamp stamp = infoElement.getStamp();
348                JavaConstant constant = (JavaConstant) stamp.asConstant();
349                if (constant != null) {
350                    return rewireGuards(infoElement.getGuard(), constant.asBoolean(), rewireGuardFunction);
351                }
352            }
353            if (node instanceof UnaryOpLogicNode) {
354                UnaryOpLogicNode unaryLogicNode = (UnaryOpLogicNode) node;
355                ValueNode value = unaryLogicNode.getValue();
356                for (InfoElement infoElement : getInfoElements(value)) {
357                    Stamp stamp = infoElement.getStamp();
358                    TriState result = unaryLogicNode.tryFold(stamp);
359                    if (result.isKnown()) {
360                        return rewireGuards(infoElement.getGuard(), result.toBoolean(), rewireGuardFunction);
361                    }
362                }
363            } else if (node instanceof BinaryOpLogicNode) {
364                BinaryOpLogicNode binaryOpLogicNode = (BinaryOpLogicNode) node;
365                for (InfoElement infoElement : getInfoElements(binaryOpLogicNode)) {
366                    if (infoElement.getStamp().equals(StampFactory.contradiction())) {
367                        return rewireGuards(infoElement.getGuard(), false, rewireGuardFunction);
368                    } else if (infoElement.getStamp().equals(StampFactory.tautology())) {
369                        return rewireGuards(infoElement.getGuard(), true, rewireGuardFunction);
370                    }
371                }
372
373                ValueNode x = binaryOpLogicNode.getX();
374                ValueNode y = binaryOpLogicNode.getY();
375                for (InfoElement infoElement : getInfoElements(x)) {
376                    TriState result = binaryOpLogicNode.tryFold(infoElement.getStamp(), y.stamp());
377                    if (result.isKnown()) {
378                        return rewireGuards(infoElement.getGuard(), result.toBoolean(), rewireGuardFunction);
379                    }
380                }
381
382                for (InfoElement infoElement : getInfoElements(y)) {
383                    TriState result = binaryOpLogicNode.tryFold(x.stamp(), infoElement.getStamp());
384                    if (result.isKnown()) {
385                        return rewireGuards(infoElement.getGuard(), result.toBoolean(), rewireGuardFunction);
386                    }
387                }
388            } else if (node instanceof ShortCircuitOrNode) {
389                final ShortCircuitOrNode shortCircuitOrNode = (ShortCircuitOrNode) node;
390                if (this.loopExits.isEmpty()) {
391                    return tryProofCondition(shortCircuitOrNode.getX(), (guard, result) -> {
392                        if (result == !shortCircuitOrNode.isXNegated()) {
393                            return rewireGuards(guard, true, rewireGuardFunction);
394                        } else {
395                            return tryProofCondition(shortCircuitOrNode.getY(), (innerGuard, innerResult) -> {
396                                if (innerGuard == guard) {
397                                    return rewireGuards(guard, innerResult ^ shortCircuitOrNode.isYNegated(), rewireGuardFunction);
398                                }
399                                return false;
400                            });
401                        }
402                    });
403                }
404            }
405
406            return false;
407        }
408
409        private void registerNewStamp(ValueNode proxiedValue, Stamp newStamp, ValueNode guard, List<Runnable> undoOperations) {
410            if (newStamp != null) {
411                ValueNode value = GraphUtil.unproxify(proxiedValue);
412                Info info = map.get(value);
413                if (info == null) {
414                    info = new Info();
415                    map.set(value, info);
416                }
417                metricStampsRegistered.increment();
418                final Info finalInfo = info;
419                finalInfo.pushElement(new InfoElement(newStamp, guard));
420                undoOperations.add(() -> finalInfo.popElement());
421            }
422        }
423
424        private void processConditionAnchor(ConditionAnchorNode node) {
425            tryProofCondition(node.condition(), (guard, result) -> {
426                if (result != node.isNegated()) {
427                    node.replaceAtUsages(guard);
428                    GraphUtil.unlinkFixedNode(node);
429                    GraphUtil.killWithUnusedFloatingInputs(node);
430                } else {
431                    ValueAnchorNode valueAnchor = node.graph().add(new ValueAnchorNode(null));
432                    node.replaceAtUsages(valueAnchor);
433                    node.graph().replaceFixedWithFixed(node, valueAnchor);
434                }
435                return true;
436            });
437        }
438
439        private void processGuard(GuardNode node, List<Runnable> undoOperations) {
440            if (!tryProofCondition(node.condition(), (guard, result) -> {
441                if (result != node.isNegated()) {
442                    node.replaceAndDelete(guard);
443                } else {
444                    DeoptimizeNode deopt = node.graph().add(new DeoptimizeNode(node.action(), node.reason()));
445                    AbstractBeginNode beginNode = (AbstractBeginNode) node.getAnchor();
446                    FixedNode next = beginNode.next();
447                    beginNode.setNext(deopt);
448                    GraphUtil.killCFG(next);
449                }
450                return true;
451            })) {
452                registerNewCondition(node.condition(), node.isNegated(), node, undoOperations);
453            }
454        }
455
456        private void processFixedGuard(FixedGuardNode node, List<Runnable> undoOperations) {
457            if (!tryProofCondition(node.condition(), (guard, result) -> {
458                if (result != node.isNegated()) {
459                    node.replaceAtUsages(guard);
460                    GraphUtil.unlinkFixedNode(node);
461                    GraphUtil.killWithUnusedFloatingInputs(node);
462                } else {
463                    DeoptimizeNode deopt = node.graph().add(new DeoptimizeNode(node.getAction(), node.getReason(), node.getSpeculation()));
464                    deopt.setStateBefore(node.stateBefore());
465                    node.replaceAtPredecessor(deopt);
466                    GraphUtil.killCFG(node);
467                }
468                return true;
469            })) {
470                registerNewCondition(node.condition(), node.isNegated(), node, undoOperations);
471            }
472        }
473
474        private void processAbstractBegin(AbstractBeginNode beginNode, List<Runnable> undoOperations) {
475            Node predecessor = beginNode.predecessor();
476            if (predecessor instanceof IfNode) {
477                IfNode ifNode = (IfNode) predecessor;
478                boolean negated = (ifNode.falseSuccessor() == beginNode);
479                LogicNode condition = ifNode.condition();
480                registerNewCondition(condition, negated, beginNode, undoOperations);
481            } else if (predecessor instanceof TypeSwitchNode) {
482                TypeSwitchNode typeSwitch = (TypeSwitchNode) predecessor;
483                processTypeSwitch(beginNode, undoOperations, predecessor, typeSwitch);
484            } else if (predecessor instanceof IntegerSwitchNode) {
485                IntegerSwitchNode integerSwitchNode = (IntegerSwitchNode) predecessor;
486                processIntegerSwitch(beginNode, undoOperations, predecessor, integerSwitchNode);
487            }
488        }
489
490        private void processIntegerSwitch(AbstractBeginNode beginNode, List<Runnable> undoOperations, Node predecessor, IntegerSwitchNode integerSwitchNode) {
491            Stamp stamp = null;
492            for (int i = 0; i < integerSwitchNode.keyCount(); i++) {
493                if (integerSwitchNode.keySuccessor(i) == predecessor) {
494                    if (stamp == null) {
495                        stamp = StampFactory.forConstant(integerSwitchNode.keyAt(i));
496                    } else {
497                        stamp = stamp.meet(StampFactory.forConstant(integerSwitchNode.keyAt(i)));
498                    }
499                }
500            }
501
502            if (stamp != null) {
503                registerNewStamp(integerSwitchNode.value(), stamp, beginNode, undoOperations);
504            }
505        }
506
507        private void processTypeSwitch(AbstractBeginNode beginNode, List<Runnable> undoOperations, Node predecessor, TypeSwitchNode typeSwitch) {
508            ValueNode hub = typeSwitch.value();
509            if (hub instanceof LoadHubNode) {
510                LoadHubNode loadHub = (LoadHubNode) hub;
511                Stamp stamp = null;
512                for (int i = 0; i < typeSwitch.keyCount(); i++) {
513                    if (typeSwitch.keySuccessor(i) == predecessor) {
514                        if (stamp == null) {
515                            stamp = StampFactory.exactNonNull(typeSwitch.typeAt(i));
516                        } else {
517                            stamp = stamp.meet(StampFactory.exactNonNull(typeSwitch.typeAt(i)));
518                        }
519                    }
520                }
521                if (stamp != null) {
522                    registerNewStamp(loadHub.getValue(), stamp, beginNode, undoOperations);
523                }
524            }
525        }
526    }
527}