changeset 19623:490f2c54c28a

Support for creating merges during partial evaluation of bytecode interpreters.
author Thomas Wuerthinger <thomas.wuerthinger@oracle.com>
date Fri, 27 Feb 2015 16:01:12 +0100
parents 0e90dbf0b9fd
children cdbb43aae6fd
files graal/com.oracle.graal.java/src/com/oracle/graal/java/GraphBuilderPhase.java graal/com.oracle.graal.java/src/com/oracle/graal/java/HIRFrameStateBuilder.java graal/com.oracle.graal.truffle.test/src/com/oracle/graal/truffle/test/BytecodeInterpreterPartialEvaluationTest.java
diffstat 3 files changed, 123 insertions(+), 23 deletions(-) [+]
line wrap: on
line diff
--- a/graal/com.oracle.graal.java/src/com/oracle/graal/java/GraphBuilderPhase.java	Fri Feb 27 14:06:36 2015 +0100
+++ b/graal/com.oracle.graal.java/src/com/oracle/graal/java/GraphBuilderPhase.java	Fri Feb 27 16:01:12 2015 +0100
@@ -204,6 +204,7 @@
             private FixedWithNextNode lastInstr;                 // the last instruction added
             private final boolean explodeLoops;
             private final boolean mergeExplosions;
+            private final Map<HIRFrameStateBuilder, Integer> mergeExplosionsMap;
             private Stack<ExplodedLoopContext> explodeLoopsContext;
             private int nextPeelIteration = 1;
             private boolean controlFlowSplit;
@@ -233,12 +234,15 @@
                     explodeLoops = loopExplosionPlugin.shouldExplodeLoops(method);
                     if (explodeLoops) {
                         mergeExplosions = loopExplosionPlugin.shouldMergeExplosions(method);
+                        mergeExplosionsMap = new HashMap<>();
                     } else {
                         mergeExplosions = false;
+                        mergeExplosionsMap = null;
                     }
                 } else {
                     explodeLoops = false;
                     mergeExplosions = false;
+                    mergeExplosionsMap = null;
                 }
             }
 
@@ -349,6 +353,9 @@
                 ExplodedLoopContext context = new ExplodedLoopContext();
                 context.header = header;
                 context.peelIteration = this.getCurrentDimension();
+                if (this.mergeExplosions) {
+                    this.addToMergeCache(getEntryState(context.header, context.peelIteration), context.peelIteration);
+                }
                 explodeLoopsContext.push(context);
                 if (Debug.isDumpEnabled() && DumpDuringGraphBuilding.getValue()) {
                     Debug.dump(currentGraph, "before loop explosion dimension " + context.peelIteration);
@@ -359,9 +366,12 @@
                 return header.loopEnd + 1;
             }
 
+            private void addToMergeCache(HIRFrameStateBuilder key, int dimension) {
+                mergeExplosionsMap.put(key, dimension);
+            }
+
             private void peelIteration(BciBlock[] blocks, BciBlock header, ExplodedLoopContext context) {
                 while (true) {
-
                     processBlock(this, header);
                     for (int j = header.getId() + 1; j <= header.loopEnd; ++j) {
                         BciBlock block = blocks[j];
@@ -1372,7 +1382,7 @@
                 assert block != null && state != null;
                 assert !block.isExceptionEntry || state.stackSize() == 1;
 
-                int operatingDimension = findOperatingDimension(block);
+                int operatingDimension = findOperatingDimension(block, state);
 
                 if (getFirstInstruction(block, operatingDimension) == null) {
                     /*
@@ -1417,8 +1427,8 @@
                     Debug.log("createTarget %s: merging backward branch to loop header %s, result: %s", block, loopBegin, result);
                     return result;
                 }
-                assert currentBlock == null || currentBlock.getId() < block.getId() : "must not be backward branch";
-                assert getFirstInstruction(block, operatingDimension).next() == null : "bytecodes already parsed for block";
+                assert currentBlock == null || currentBlock.getId() < block.getId() || this.mergeExplosions : "must not be backward branch";
+                assert getFirstInstruction(block, operatingDimension).next() == null || this.mergeExplosions : "bytecodes already parsed for block";
 
                 if (getFirstInstruction(block, operatingDimension) instanceof AbstractBeginNode && !(getFirstInstruction(block, operatingDimension) instanceof AbstractMergeNode)) {
                     /*
@@ -1459,19 +1469,27 @@
                 return result;
             }
 
-            private int findOperatingDimension(BciBlock block) {
+            private int findOperatingDimension(BciBlock block, HIRFrameStateBuilder state) {
                 if (this.explodeLoops && this.explodeLoopsContext != null && !this.explodeLoopsContext.isEmpty()) {
-                    return findOperatingDimensionWithLoopExplosion(block);
+                    return findOperatingDimensionWithLoopExplosion(block, state);
                 }
                 return this.getCurrentDimension();
             }
 
-            private int findOperatingDimensionWithLoopExplosion(BciBlock block) {
+            private int findOperatingDimensionWithLoopExplosion(BciBlock block, HIRFrameStateBuilder state) {
                 int i;
                 for (i = explodeLoopsContext.size() - 1; i >= 0; --i) {
                     ExplodedLoopContext context = explodeLoopsContext.elementAt(i);
                     if (context.header == block) {
 
+                        if (this.mergeExplosions) {
+                            state.clearNonLiveLocals(block, liveness, true);
+                            Integer cachedDimension = mergeExplosionsMap.get(state);
+                            if (cachedDimension != null) {
+                                return cachedDimension;
+                            }
+                        }
+
                         // We have a hit on our current explosion context loop begin.
                         if (context.targetPeelIteration == null) {
                             context.targetPeelIteration = new int[1];
@@ -1482,6 +1500,9 @@
                         // This is the first hit => allocate a new dimension and at the same
                         // time mark the context loop begin as hit during the current
                         // iteration.
+                        if (this.mergeExplosions) {
+                            this.addToMergeCache(state.copy(), nextPeelIteration);
+                        }
                         context.targetPeelIteration[context.targetPeelIteration.length - 1] = nextPeelIteration++;
                         if (nextPeelIteration > MaximumLoopExplosionCount.getValue()) {
                             String message = "too many loop explosion interations - does the explosion not terminate for method " + method + "?";
--- a/graal/com.oracle.graal.java/src/com/oracle/graal/java/HIRFrameStateBuilder.java	Fri Feb 27 14:06:36 2015 +0100
+++ b/graal/com.oracle.graal.java/src/com/oracle/graal/java/HIRFrameStateBuilder.java	Fri Feb 27 16:01:12 2015 +0100
@@ -873,4 +873,61 @@
         assert xpeek() == null;
         return true;
     }
+
+    @Override
+    public int hashCode() {
+        int result = hashCode(locals, locals.length);
+        result *= 13;
+        result += hashCode(stack, this.stackSize);
+        return result;
+    }
+
+    private static int hashCode(Object[] a, int length) {
+        int result = 1;
+        for (int i = 0; i < length; ++i) {
+            Object element = a[i];
+            result = 31 * result + (element == null ? 0 : System.identityHashCode(element));
+        }
+        return result;
+    }
+
+    private static boolean equals(ValueNode[] a, ValueNode[] b, int length) {
+        for (int i = 0; i < length; ++i) {
+            if (a[i] != b[i]) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    @Override
+    public boolean equals(Object otherObject) {
+        if (otherObject instanceof HIRFrameStateBuilder) {
+            HIRFrameStateBuilder other = (HIRFrameStateBuilder) otherObject;
+            if (other.method != method) {
+                return false;
+            }
+            if (other.stackSize != stackSize) {
+                return false;
+            }
+            if (other.checkTypes != checkTypes) {
+                return false;
+            }
+            if (other.rethrowException != rethrowException) {
+                return false;
+            }
+            if (other.graph != graph) {
+                return false;
+            }
+            if (other.outerFrameStateSupplier != outerFrameStateSupplier) {
+                return false;
+            }
+            if (other.locals.length != locals.length) {
+                return false;
+            }
+            return equals(other.locals, locals, locals.length) && equals(other.stack, stack, stackSize) && equals(other.lockedObjects, lockedObjects, lockedObjects.length) &&
+                            equals(other.monitorIds, monitorIds, monitorIds.length);
+        }
+        return false;
+    }
 }
--- a/graal/com.oracle.graal.truffle.test/src/com/oracle/graal/truffle/test/BytecodeInterpreterPartialEvaluationTest.java	Fri Feb 27 14:06:36 2015 +0100
+++ b/graal/com.oracle.graal.truffle.test/src/com/oracle/graal/truffle/test/BytecodeInterpreterPartialEvaluationTest.java	Fri Feb 27 16:01:12 2015 +0100
@@ -26,7 +26,6 @@
 
 import com.oracle.truffle.api.*;
 import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
-import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
 import com.oracle.truffle.api.frame.*;
 import com.oracle.truffle.api.nodes.*;
 
@@ -37,14 +36,17 @@
         public static final byte RETURN = 1;
         public static final byte ADD = 2;
         public static final byte IFZERO = 3;
+        public static final byte POP = 4;
     }
 
     public static class Program extends RootNode {
-        @CompilationFinal final byte[] bytecodes;
+        private final String name;
+        @CompilationFinal private final byte[] bytecodes;
         @CompilationFinal private final FrameSlot[] locals;
         @CompilationFinal private final FrameSlot[] stack;
 
-        public Program(byte[] bytecodes, int maxLocals, int maxStack) {
+        public Program(String name, byte[] bytecodes, int maxLocals, int maxStack) {
+            this.name = name;
             this.bytecodes = bytecodes;
             locals = new FrameSlot[maxLocals];
             stack = new FrameSlot[maxStack];
@@ -70,9 +72,9 @@
             }
         }
 
-        @TruffleBoundary
-        public void print(String name, int value) {
-            System.out.println(name + "=" + value);
+        @Override
+        public String toString() {
+            return name;
         }
 
         @Override
@@ -89,14 +91,14 @@
                         value = bytecodes[bci + 1];
                         setInt(frame, ++topOfStack, value);
                         bci = bci + 2;
-                        break;
+                        continue;
                     case Bytecode.RETURN:
                         return getInt(frame, topOfStack);
                     case Bytecode.ADD:
                         setInt(frame, topOfStack - 1, getInt(frame, topOfStack) + getInt(frame, topOfStack - 1));
                         topOfStack--;
                         bci = bci + 1;
-                        break;
+                        continue;
                     case Bytecode.IFZERO:
                         if (getInt(frame, topOfStack--) == 0) {
                             bci = bytecodes[bci + 1];
@@ -105,6 +107,10 @@
                             bci = bci + 2;
                             continue;
                         }
+                    case Bytecode.POP:
+                        topOfStack--;
+                        bci++;
+                        continue;
                 }
             }
         }
@@ -115,16 +121,16 @@
     }
 
     @Test
-    public void simpleProgram() {
+    public void constReturnProgram() {
         byte[] bytecodes = new byte[]{
         /* 0: */Bytecode.CONST,
         /* 1: */42,
         /* 2: */Bytecode.RETURN};
-        assertPartialEvalEquals("constant42", new Program(bytecodes, 0, 2));
+        assertPartialEvalEquals("constant42", new Program("constReturnProgram", bytecodes, 0, 2));
     }
 
     @Test
-    public void simpleProgramWithAdd() {
+    public void constAddProgram() {
         byte[] bytecodes = new byte[]{
         /* 0: */Bytecode.CONST,
         /* 1: */40,
@@ -132,11 +138,11 @@
         /* 3: */2,
         /* 4: */Bytecode.ADD,
         /* 5: */Bytecode.RETURN};
-        assertPartialEvalEquals("constant42", new Program(bytecodes, 0, 2));
+        assertPartialEvalEquals("constant42", new Program("constAddProgram", bytecodes, 0, 2));
     }
 
     @Test
-    public void simpleProgramWithIf() {
+    public void simpleIfProgram() {
         byte[] bytecodes = new byte[]{
         /* 0: */Bytecode.CONST,
         /* 1: */40,
@@ -147,11 +153,27 @@
         /* 6: */Bytecode.CONST,
         /* 7: */42,
         /* 8: */Bytecode.RETURN};
-        assertPartialEvalEquals("constant42", new Program(bytecodes, 0, 3));
+        assertPartialEvalEquals("constant42", new Program("simpleIfProgram", bytecodes, 0, 3));
+    }
+
+    @Test
+    public void ifAndPopProgram() {
+        byte[] bytecodes = new byte[]{
+        /* 0: */Bytecode.CONST,
+        /* 1: */40,
+        /* 2: */Bytecode.CONST,
+        /* 3: */1,
+        /* 4: */Bytecode.IFZERO,
+        /* 5: */9,
+        /* 6: */Bytecode.POP,
+        /* 7: */Bytecode.CONST,
+        /* 8: */42,
+        /* 9: */Bytecode.RETURN};
+        assertPartialEvalEquals("constant42", new Program("ifAndPopProgram", bytecodes, 0, 3));
     }
 
     @Test(timeout = 1000)
-    public void simpleProgramWithManyIfs() {
+    public void manyIfsProgram() {
         byte[] bytecodes = new byte[]{
         /* 0: */Bytecode.CONST,
         /* 1: */40,
@@ -198,6 +220,6 @@
         /* 42: */Bytecode.CONST,
         /* 43: */42,
         /* 44: */Bytecode.RETURN};
-        assertPartialEvalEquals("constant42", new Program(bytecodes, 0, 3));
+        assertPartialEvalEquals("constant42", new Program("manyIfsProgram", bytecodes, 0, 3));
     }
 }