changeset 15640:2208a130d636

HSAIL Deopt support for VirtualObjects. Only create the host graph if there are deopts. Add a test provided by Tom Deneau.
author Gilles Duboscq <duboscq@ssw.jku.at>
date Sun, 04 May 2014 18:58:16 +0200
parents 19ec9885ce6e
children e500d6900328
files graal/com.oracle.graal.compiler.hsail.test.infra/src/com/oracle/graal/compiler/hsail/test/infra/GraalKernelTester.java graal/com.oracle.graal.compiler.hsail.test/src/com/oracle/graal/compiler/hsail/test/lambda/VecmathNBodyDeoptTest.java graal/com.oracle.graal.hotspot.hsail/src/com/oracle/graal/hotspot/hsail/HSAILHotSpotBackend.java graal/com.oracle.graal.virtual/src/com/oracle/graal/virtual/nodes/VirtualObjectState.java
diffstat 4 files changed, 223 insertions(+), 32 deletions(-) [+]
line wrap: on
line diff
--- a/graal/com.oracle.graal.compiler.hsail.test.infra/src/com/oracle/graal/compiler/hsail/test/infra/GraalKernelTester.java	Wed May 14 12:37:39 2014 +0200
+++ b/graal/com.oracle.graal.compiler.hsail.test.infra/src/com/oracle/graal/compiler/hsail/test/infra/GraalKernelTester.java	Sun May 04 18:58:16 2014 +0200
@@ -135,7 +135,7 @@
      * with HSAIL code.
      */
     public boolean canHandleDeoptVirtualObjects() {
-        return false;
+        return true;
     }
 
     /**
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/graal/com.oracle.graal.compiler.hsail.test/src/com/oracle/graal/compiler/hsail/test/lambda/VecmathNBodyDeoptTest.java	Sun May 04 18:58:16 2014 +0200
@@ -0,0 +1,138 @@
+/*
+ * Copyright (c) 2009, 2012, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package com.oracle.graal.compiler.hsail.test.lambda;
+
+import java.util.*;
+import org.junit.*;
+import com.oracle.graal.compiler.hsail.test.infra.GraalKernelTester;
+import javax.vecmath.*;
+
+/**
+ * Tests NBody using javax.vecmath (all objects non-escaping), calling forceDeopt for a single gid.
+ */
+public class VecmathNBodyDeoptTest extends GraalKernelTester {
+    static final int bodies = 1024;
+    static final float delT = .005f;
+    static final float espSqr = 1.0f;
+    static final float mass = 5f;
+    static final int width = 768;
+    static final int height = 768;
+
+    static class Body extends Vector3f {
+
+        /**
+         *
+         */
+        private static final long serialVersionUID = 1L;
+
+        public Body(float _x, float _y, float _z, float _m) {
+            super(_x, _y, _z);
+            m = _m;
+            v = new Vector3f(0, 0, 0);
+        }
+
+        float m;
+        Vector3f v;
+
+        public float getM() {
+            return m;
+        }
+
+        public Vector3f computeAcc(Body[] in_bodies, float espSqr1, float delT1) {
+            Vector3f acc = new Vector3f();
+
+            for (Body b : in_bodies) {
+                Vector3f d = new Vector3f();
+                d.sub(b, this);
+                float invDist = 1.0f / (float) Math.sqrt(d.lengthSquared() + espSqr1);
+                float s = b.getM() * invDist * invDist * invDist;
+                acc.scaleAdd(s, d, acc);
+            }
+
+            // now return acc scaled by delT
+            acc.scale(delT1);
+            return acc;
+        }
+    }
+
+    @Result Body[] in_bodies = new Body[bodies];
+    @Result Body[] out_bodies = new Body[bodies];
+
+    static Body[] seed_bodies = new Body[bodies];
+
+    static {
+        java.util.Random randgen = new Random(0);
+        final float maxDist = width / 4;
+        for (int body = 0; body < bodies; body++) {
+            final float theta = (float) (randgen.nextFloat() * Math.PI * 2);
+            final float phi = (float) (randgen.nextFloat() * Math.PI * 2);
+            final float radius = randgen.nextFloat() * maxDist;
+            float x = (float) (radius * Math.cos(theta) * Math.sin(phi)) + width / 2;
+            float y = (float) (radius * Math.sin(theta) * Math.sin(phi)) + height / 2;
+            float z = (float) (radius * Math.cos(phi));
+            seed_bodies[body] = new Body(x, y, z, mass);
+        }
+    }
+
+    @Override
+    public void runTest() {
+        System.arraycopy(seed_bodies, 0, in_bodies, 0, seed_bodies.length);
+        for (int b = 0; b < bodies; b++) {
+            out_bodies[b] = new Body(0, 0, 0, mass);
+        }
+        // no local copies of arrays so we make it an instance lambda
+
+        dispatchLambdaKernel(bodies, (gid) -> {
+            Body inb = in_bodies[gid];
+            Body outb = out_bodies[gid];
+            Vector3f acc = inb.computeAcc(in_bodies, espSqr, delT);
+
+            Vector3f tmpPos = new Vector3f();
+            tmpPos.scaleAdd(delT, inb.v, inb);
+            if (gid == bodies / 2) {
+                tmpPos.x += forceDeopt(gid);
+            }
+            tmpPos.scaleAdd(0.5f * delT, acc, tmpPos);
+            outb.set(tmpPos);
+
+            outb.v.add(inb.v, acc);
+        });
+    }
+
+    @Override
+    protected boolean supportsRequiredCapabilities() {
+        return (canHandleDeoptVirtualObjects());
+    }
+
+    @Test
+    public void test() {
+        testGeneratedHsail();
+    }
+
+    @Test
+    public void testUsingLambdaMethod() {
+        testGeneratedHsailUsingLambdaMethod();
+    }
+
+}
--- a/graal/com.oracle.graal.hotspot.hsail/src/com/oracle/graal/hotspot/hsail/HSAILHotSpotBackend.java	Wed May 14 12:37:39 2014 +0200
+++ b/graal/com.oracle.graal.hotspot.hsail/src/com/oracle/graal/hotspot/hsail/HSAILHotSpotBackend.java	Sun May 04 18:58:16 2014 +0200
@@ -32,6 +32,9 @@
 
 import java.lang.reflect.*;
 import java.util.*;
+import java.util.Map.Entry;
+import java.util.function.*;
+import java.util.stream.*;
 
 import com.amd.okra.*;
 import com.oracle.graal.api.code.*;
@@ -69,9 +72,11 @@
 import com.oracle.graal.nodes.extended.*;
 import com.oracle.graal.nodes.java.*;
 import com.oracle.graal.nodes.spi.*;
+import com.oracle.graal.nodes.virtual.*;
 import com.oracle.graal.options.*;
 import com.oracle.graal.phases.*;
 import com.oracle.graal.phases.tiers.*;
+import com.oracle.graal.virtual.nodes.*;
 
 /**
  * HSAIL specific backend.
@@ -266,7 +271,7 @@
         StructuredGraph hostGraph = hsailCode.getHostGraph();
         if (hostGraph != null) {
             // TODO get rid of the unverified entry point in the host code
-            try (Scope ds = Debug.scope("GeneratingHostGraph")) {
+            try (Scope ds = Debug.scope("GeneratingHostGraph", new DebugDumpScope("HostGraph"))) {
                 HotSpotBackend hostBackend = getRuntime().getHostBackend();
                 JavaType[] parameterTypes = new JavaType[hostGraph.getNodes(ParameterNode.class).count()];
                 Debug.log("Param count: %d", parameterTypes.length);
@@ -726,6 +731,8 @@
         asm.emitString(spillsegStringFinal, spillsegDeclarationPosition);
         // Emit the epilogue.
 
+        HSAILHotSpotLIRGenerationResult lirGenRes = ((HSAILCompilationResultBuilder) crb).lirGenRes;
+
         int numSRegs = 0;
         int numDRegs = 0;
         int numStackSlotBytes = 0;
@@ -736,31 +743,39 @@
             Set<Register> infoUsedRegs = new TreeSet<>();
             Set<StackSlot> infoUsedStackSlots = new HashSet<>();
             List<Infopoint> infoList = crb.compilationResult.getInfopoints();
+            Queue<Value[]> workList = new LinkedList<>();
             for (Infopoint info : infoList) {
                 BytecodeFrame frame = info.debugInfo.frame();
                 while (frame != null) {
-                    for (int i = 0; i < frame.numLocals + frame.numStack; i++) {
-                        Value val = frame.values[i];
-                        if (isLegal(val)) {
-                            if (isRegister(val)) {
-                                Register reg = asRegister(val);
-                                infoUsedRegs.add(reg);
-                                if (hsailRegConfig.isAllocatableSReg(reg)) {
-                                    numSRegs = Math.max(numSRegs, reg.encoding + 1);
-                                } else if (hsailRegConfig.isAllocatableDReg(reg)) {
-                                    numDRegs = Math.max(numDRegs, reg.encoding + 1);
-                                }
-                            } else if (isStackSlot(val)) {
-                                StackSlot slot = asStackSlot(val);
-                                Kind slotKind = slot.getKind();
-                                int slotSizeBytes = (slotKind.isObject() ? 8 : slotKind.getByteCount());
-                                int slotOffsetMax = HSAIL.getStackOffsetStart(slot, slotSizeBytes * 8) + slotSizeBytes;
-                                numStackSlotBytes = Math.max(numStackSlotBytes, slotOffsetMax);
-                                infoUsedStackSlots.add(slot);
+                    workList.add(frame.values);
+                    frame = frame.caller();
+                }
+            }
+            while (!workList.isEmpty()) {
+                Value[] values = workList.poll();
+                for (Value val : values) {
+                    if (isLegal(val)) {
+                        if (isRegister(val)) {
+                            Register reg = asRegister(val);
+                            infoUsedRegs.add(reg);
+                            if (hsailRegConfig.isAllocatableSReg(reg)) {
+                                numSRegs = Math.max(numSRegs, reg.encoding + 1);
+                            } else if (hsailRegConfig.isAllocatableDReg(reg)) {
+                                numDRegs = Math.max(numDRegs, reg.encoding + 1);
                             }
+                        } else if (isStackSlot(val)) {
+                            StackSlot slot = asStackSlot(val);
+                            Kind slotKind = slot.getKind();
+                            int slotSizeBytes = (slotKind.isObject() ? 8 : slotKind.getByteCount());
+                            int slotOffsetMax = HSAIL.getStackOffsetStart(slot, slotSizeBytes * 8) + slotSizeBytes;
+                            numStackSlotBytes = Math.max(numStackSlotBytes, slotOffsetMax);
+                            infoUsedStackSlots.add(slot);
+                        } else if (isVirtualObject(val)) {
+                            workList.add(((VirtualObject) val).getValues());
+                        } else {
+                            assert isConstant(val) : "Unsupported value: " + val;
                         }
                     }
-                    frame = frame.caller();
                 }
             }
 
@@ -923,8 +938,9 @@
         asm.emitString0("}; \n");
 
         ExternalCompilationResult compilationResult = (ExternalCompilationResult) crb.compilationResult;
-        HSAILHotSpotLIRGenerationResult lirGenRes = ((HSAILCompilationResultBuilder) crb).lirGenRes;
-        compilationResult.setHostGraph(prepareHostGraph(method, lirGenRes.getDeopts(), getProviders(), config, numSRegs, numDRegs));
+        if (useHSAILDeoptimization) {
+            compilationResult.setHostGraph(prepareHostGraph(method, lirGenRes.getDeopts(), getProviders(), config, numSRegs, numDRegs));
+        }
     }
 
     private static class OopMapArrayBuilder {
@@ -1090,26 +1106,53 @@
     }
 
     private static FrameState createFrameState(BytecodeFrame lowLevelFrame, ParameterNode hsailFrame, HotSpotProviders providers, HotSpotVMConfig config, int numSRegs, int numDRegs) {
+        return createFrameState(lowLevelFrame, hsailFrame, providers, config, numSRegs, numDRegs, new HashMap<VirtualObject, VirtualObjectNode>());
+    }
+
+    private static FrameState createFrameState(BytecodeFrame lowLevelFrame, ParameterNode hsailFrame, HotSpotProviders providers, HotSpotVMConfig config, int numSRegs, int numDRegs,
+                    Map<VirtualObject, VirtualObjectNode> virtualObjects) {
+        FrameState outterFrameState = null;
+        if (lowLevelFrame.caller() != null) {
+            outterFrameState = createFrameState(lowLevelFrame.caller(), hsailFrame, providers, config, numSRegs, numDRegs, virtualObjects);
+        }
         StructuredGraph hostGraph = hsailFrame.graph();
+        Function<? super Value, ? extends ValueNode> lirValueToHirNode = v -> getNodeForValueFromFrame(v, hsailFrame, hostGraph, providers, config, numSRegs, numDRegs, virtualObjects);
         ValueNode[] locals = new ValueNode[lowLevelFrame.numLocals];
         for (int i = 0; i < lowLevelFrame.numLocals; i++) {
-            locals[i] = getNodeForValueFromFrame(lowLevelFrame.getLocalValue(i), hsailFrame, hostGraph, providers, config, numSRegs, numDRegs);
+            locals[i] = lirValueToHirNode.apply(lowLevelFrame.getLocalValue(i));
         }
         List<ValueNode> stack = new ArrayList<>(lowLevelFrame.numStack);
         for (int i = 0; i < lowLevelFrame.numStack; i++) {
-            stack.add(getNodeForValueFromFrame(lowLevelFrame.getStackValue(i), hsailFrame, hostGraph, providers, config, numSRegs, numDRegs));
+            stack.add(lirValueToHirNode.apply(lowLevelFrame.getStackValue(i)));
         }
         ValueNode[] locks = new ValueNode[lowLevelFrame.numLocks];
         MonitorIdNode[] monitorIds = new MonitorIdNode[lowLevelFrame.numLocks];
         for (int i = 0; i < lowLevelFrame.numLocks; i++) {
             HotSpotMonitorValue lockValue = (HotSpotMonitorValue) lowLevelFrame.getLockValue(i);
-            locks[i] = getNodeForValueFromFrame(lockValue, hsailFrame, hostGraph, providers, config, numSRegs, numDRegs);
+            locks[i] = lirValueToHirNode.apply(lockValue);
             monitorIds[i] = getMonitorIdForHotSpotMonitorValueFromFrame(lockValue, hsailFrame, hostGraph);
         }
         FrameState frameState = hostGraph.add(new FrameState(lowLevelFrame.getMethod(), lowLevelFrame.getBCI(), locals, stack, locks, monitorIds, lowLevelFrame.rethrowException, false));
-        if (lowLevelFrame.caller() != null) {
-            frameState.setOuterFrameState(createFrameState(lowLevelFrame.caller(), hsailFrame, providers, config, numSRegs, numDRegs));
+        if (outterFrameState != null) {
+            frameState.setOuterFrameState(outterFrameState);
         }
+        Map<VirtualObject, VirtualObjectNode> virtualObjectsCopy;
+        // TODO this could be implemented more efficiently with a mark into the map;
+        // unfortunately, LinkedHashMap doesn't seem to provide that.
+        List<VirtualObjectState> virtualStates = new ArrayList<>(virtualObjects.size());
+        do {
+            virtualObjectsCopy = new HashMap<>(virtualObjects);
+            virtualStates.clear();
+            for (Entry<VirtualObject, VirtualObjectNode> entry : virtualObjectsCopy.entrySet()) {
+                VirtualObject virtualObject = entry.getKey();
+                VirtualObjectNode virtualObjectNode = entry.getValue();
+                List<ValueNode> fieldValues = Arrays.stream(virtualObject.getValues()).map(lirValueToHirNode).collect(Collectors.toList());
+                virtualStates.add(new VirtualObjectState(virtualObjectNode, fieldValues));
+            }
+            // New virtual objects may have been discovered while processing the previous set.
+            // Repeat until a fixed point is reached.
+        } while (virtualObjectsCopy.size() < virtualObjects.size());
+        virtualStates.forEach(vos -> frameState.addVirtualObjectMapping(hostGraph.unique(vos)));
         return frameState;
     }
 
@@ -1122,18 +1165,18 @@
     }
 
     private static ValueNode getNodeForValueFromFrame(Value localValue, ParameterNode hsailFrame, StructuredGraph hostGraph, HotSpotProviders providers, HotSpotVMConfig config, int numSRegs,
-                    int numDRegs) {
+                    int numDRegs, Map<VirtualObject, VirtualObjectNode> virtualObjects) {
         ValueNode valueNode;
         if (localValue instanceof Constant) {
             valueNode = ConstantNode.forConstant((Constant) localValue, providers.getMetaAccess(), hostGraph);
         } else if (localValue instanceof VirtualObject) {
-            throw GraalInternalError.unimplemented();
+            valueNode = getNodeForVirtualObjectFromFrame((VirtualObject) localValue, virtualObjects, hostGraph);
         } else if (localValue instanceof StackSlot) {
             StackSlot slot = (StackSlot) localValue;
             valueNode = getNodeForStackSlotFromFrame(slot, localValue.getKind(), hsailFrame, hostGraph, providers, config, numSRegs, numDRegs);
         } else if (localValue instanceof HotSpotMonitorValue) {
             HotSpotMonitorValue hotSpotMonitorValue = (HotSpotMonitorValue) localValue;
-            return getNodeForValueFromFrame(hotSpotMonitorValue.getOwner(), hsailFrame, hostGraph, providers, config, numSRegs, numDRegs);
+            return getNodeForValueFromFrame(hotSpotMonitorValue.getOwner(), hsailFrame, hostGraph, providers, config, numSRegs, numDRegs, virtualObjects);
         } else if (localValue instanceof RegisterValue) {
             RegisterValue registerValue = (RegisterValue) localValue;
             int regNumber = registerValue.getRegister().number;
@@ -1146,6 +1189,16 @@
         return valueNode;
     }
 
+    private static ValueNode getNodeForVirtualObjectFromFrame(VirtualObject virtualObject, Map<VirtualObject, VirtualObjectNode> virtualObjects, StructuredGraph hostGraph) {
+        return virtualObjects.computeIfAbsent(virtualObject, vo -> {
+            if (vo.getType().isArray()) {
+                return hostGraph.add(new VirtualArrayNode(vo.getType().getComponentType(), vo.getValues().length));
+            } else {
+                return hostGraph.add(new VirtualInstanceNode(vo.getType(), true));
+            }
+        });
+    }
+
     private static ValueNode getNodeForRegisterFromFrame(int regNumber, Kind valueKind, ParameterNode hsailFrame, StructuredGraph hostGraph, HotSpotProviders providers, HotSpotVMConfig config,
                     int numSRegs) {
         ValueNode valueNode;
--- a/graal/com.oracle.graal.virtual/src/com/oracle/graal/virtual/nodes/VirtualObjectState.java	Wed May 14 12:37:39 2014 +0200
+++ b/graal/com.oracle.graal.virtual/src/com/oracle/graal/virtual/nodes/VirtualObjectState.java	Sun May 04 18:58:16 2014 +0200
@@ -45,7 +45,7 @@
         this.fieldValues = new NodeInputList<>(this, fieldValues);
     }
 
-    private VirtualObjectState(VirtualObjectNode object, List<ValueNode> fieldValues) {
+    public VirtualObjectState(VirtualObjectNode object, List<ValueNode> fieldValues) {
         super(object);
         assert object.entryCount() == fieldValues.size();
         this.fieldValues = new NodeInputList<>(this, fieldValues);