Mercurial > hg > truffle
view graal/com.oracle.graal.hotspot.hsail/src/com/oracle/graal/hotspot/hsail/ForEachToGraal.java @ 18408:2c3666f44855
Truffle: initial commit of object API implementation
author | Andreas Woess <andreas.woess@jku.at> |
---|---|
date | Tue, 18 Nov 2014 23:19:43 +0100 |
parents | c88ab4f1f04a |
children |
line wrap: on
line source
/* * Copyright (c) 2009, 2011, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it * under the terms of the GNU General Public License version 2 only, as * published by the Free Software Foundation. * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * * You should have received a copy of the GNU General Public License version * 2 along with this work; if not, write to the Free Software Foundation, * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. * * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA * or visit www.oracle.com if you need additional information or have any * questions. */ package com.oracle.graal.hotspot.hsail; import static com.oracle.graal.hotspot.HotSpotGraalRuntime.*; import java.lang.reflect.*; import java.util.*; import java.util.concurrent.*; import java.util.function.*; import com.oracle.graal.api.code.*; import com.oracle.graal.api.meta.*; import com.oracle.graal.compiler.common.*; import com.oracle.graal.compiler.hsail.*; import com.oracle.graal.compiler.target.*; import com.oracle.graal.debug.*; import com.oracle.graal.debug.internal.*; import com.oracle.graal.gpu.*; import com.oracle.graal.graph.iterators.*; import com.oracle.graal.hotspot.*; import com.oracle.graal.hotspot.meta.*; import com.oracle.graal.hsail.*; import com.oracle.graal.java.*; import com.oracle.graal.nodes.*; import com.oracle.graal.nodes.java.*; import com.oracle.graal.phases.*; import com.oracle.graal.phases.util.*; import com.oracle.graal.printer.*; /** * Implements compile and dispatch of Java code containing lambda constructs. Currently only used by * JDK interception code that offloads to the GPU. */ public class ForEachToGraal implements CompileAndDispatch { private static HSAILHotSpotBackend getHSAILBackend() { Backend backend = runtime().getBackend(HSAIL.class); return (HSAILHotSpotBackend) backend; } ConcurrentHashMap<Class<?>, String> resolvedConsumerTargetMethods = new ConcurrentHashMap<>(); /** * Returns the name of the reduction method given a class implementing {@link IntConsumer}. * * @param opClass a class implementing {@link IntConsumer}. * @return the name of the reduction method */ public String getIntReduceTargetName(Class<?> opClass) { String cachedMethodName = resolvedConsumerTargetMethods.get(Objects.requireNonNull(opClass)); if (cachedMethodName != null) { return cachedMethodName; } else { Method acceptMethod = null; for (Method m : opClass.getMethods()) { if (m.getName().equals("applyAsInt")) { assert acceptMethod == null : "found more than one implementation of applyAsInt in " + opClass; acceptMethod = m; } } // Ensure a debug configuration for this thread is initialized if (DebugScope.getConfig() == null) { DebugEnvironment.initialize(System.out); } HSAILHotSpotBackend backend = getHSAILBackend(); Providers providers = backend.getProviders(); StructuredGraph graph = new StructuredGraph(((HotSpotMetaAccessProvider) providers.getMetaAccess()).lookupJavaMethod(acceptMethod)); new GraphBuilderPhase.Instance(providers.getMetaAccess(), GraphBuilderConfiguration.getDefault(), OptimisticOptimizations.ALL).apply(graph); NodeIterable<MethodCallTargetNode> calls = graph.getNodes(MethodCallTargetNode.class); assert calls.count() == 1; ResolvedJavaMethod lambdaMethod = calls.first().targetMethod(); Debug.log("target ... %s", lambdaMethod); String className = lambdaMethod.getDeclaringClass().getName(); if (!className.equals("Ljava/lang/Integer;")) { return null; } resolvedConsumerTargetMethods.put(opClass, lambdaMethod.getName()); return lambdaMethod.getName().intern(); } } /** * Gets a compiled and installed kernel for the lambda called by the * {@link IntConsumer#accept(int)} method in a class implementing {@link IntConsumer}. * * @param intConsumerClass a class implementing {@link IntConsumer} * @return a {@link HotSpotNmethod} handle to the compiled and installed kernel */ private static HotSpotNmethod getCompiledLambda(Class<?> intConsumerClass) { Method acceptMethod = null; for (Method m : intConsumerClass.getMethods()) { if (m.getName().equals("accept")) { assert acceptMethod == null : "found more than one implementation of accept(int) in " + intConsumerClass; acceptMethod = m; } } // Ensure a debug configuration for this thread is initialized if (DebugScope.getConfig() == null) { DebugEnvironment.initialize(System.out); } HSAILHotSpotBackend backend = getHSAILBackend(); Providers providers = backend.getProviders(); StructuredGraph graph = new StructuredGraph(((HotSpotMetaAccessProvider) providers.getMetaAccess()).lookupJavaMethod(acceptMethod)); new GraphBuilderPhase.Instance(providers.getMetaAccess(), GraphBuilderConfiguration.getDefault(), OptimisticOptimizations.ALL).apply(graph); NodeIterable<MethodCallTargetNode> calls = graph.getNodes(MethodCallTargetNode.class); assert calls.count() == 1; ResolvedJavaMethod lambdaMethod = calls.first().targetMethod(); Debug.log("target ... %s", lambdaMethod); if (lambdaMethod == null) { Debug.log("Did not find call in accept()"); return null; } assert lambdaMethod.getName().startsWith("lambda$"); ExternalCompilationResult hsailCode = backend.compileKernel(lambdaMethod, true); return backend.installKernel(lambdaMethod, hsailCode); } @Override public Object createKernel(Class<?> consumerClass) { try { return getCompiledLambda(consumerClass); } catch (Throwable e) { // If Graal compilation throws an exception, we want to revert to regular Java Debug.log("WARNING: Graal compilation failed"); e.printStackTrace(); return null; } } @Override public Object createKernelFromHsailString(String code, String methodName) { ExternalCompilationResult hsailCode = new ExternalCompilationResult(); try (Debug.Scope ds = Debug.scope("GeneratingKernelBinary")) { HSAILHotSpotBackend backend = getHSAILBackend(); Providers providers = backend.getProviders(); Method integerOffloadMethod = null; for (Method m : Integer.class.getMethods()) { if (m.getName().equals(methodName)) { integerOffloadMethod = m; break; } } if (integerOffloadMethod != null) { ResolvedJavaMethod rm = ((HotSpotMetaAccessProvider) providers.getMetaAccess()).lookupJavaMethod(integerOffloadMethod); long kernel = HSAILHotSpotBackend.generateKernel(code.getBytes(), "Integer::" + methodName); if (kernel == 0) { throw new GraalInternalError("Failed to compile HSAIL kernel from String"); } hsailCode.setEntryPoint(kernel); return backend.installKernel(rm, hsailCode); // is a HotSpotNmethod } else { return null; } } catch (Throwable e) { throw Debug.handle(e); } } @Override public boolean dispatchKernel(Object kernel, int jobSize, Object[] args) { HotSpotNmethod code = (HotSpotNmethod) kernel; if (code != null) { try { // No return value from HSAIL kernels getHSAILBackend().executeKernel(code, jobSize, args); return true; } catch (InvalidInstalledCodeException iice) { Debug.log("WARNING: Invalid installed code at exec time: %s", iice); iice.printStackTrace(); return false; } } else { // Should throw something sensible here return false; } } /** * Running with a larger global size seems to increase the performance for sum, but it might be * different for other reductions so it is a knob. */ private static final int GlobalSize = 1024 * Integer.getInteger("com.amd.sumatra.reduce.globalsize.multiple", 1); @Override public Integer offloadIntReduceImpl(Object okraKernel, int identity, int[] streamSource) { // NOTE - this reduce requires local size of 64 which is the SumatraUtils default // Handmade reduce does not support +UseCompressedOops HotSpotVMConfig config = runtime().getConfig(); if (config.useCompressedOops == true || config.useHSAILDeoptimization == true) { throw new GraalInternalError("Reduce offload not compatible with +UseCompressedOops or +UseHSAILDeoptimization"); } try { assert streamSource.length >= GlobalSize : "Input array length=" + streamSource.length + " smaller than requested global_size=" + GlobalSize; int[] result = {identity}; Object[] args = {streamSource, result, streamSource.length}; args[0] = streamSource; dispatchKernel(okraKernel, GlobalSize, args); // kernel result is result[0]. return result[0]; } catch (Exception e) { System.err.println(e); e.printStackTrace(); } return null; } @Override public String getIntegerReduceIntrinsic(String reducerName) { // Note all of these depend on group size of 256 String reduceOp = "/* Invalid */ "; String atomicResultProduction = "/* Invalid */ "; if (reducerName.equals("sum")) { reduceOp = "add_u32 "; atomicResultProduction = "atomicnoret_add_global_u32 "; } else if (reducerName.equals("max")) { reduceOp = "max_s32 "; atomicResultProduction = "atomicnoret_max_global_s32 "; } else if (reducerName.equals("min")) { reduceOp = "min_s32 "; atomicResultProduction = "atomicnoret_min_global_s32 "; } else { return "/* Invalid */ "; } // @formatter:off return new String( "version 0:95:$full:$large; // BRIG Object Format Version 0:4" + "\n" + "" + "\n" + "kernel &run(" + "\n" + " align 8 kernarg_u64 %arg_p3," + "\n" + " align 8 kernarg_u64 %arg_p4," + "\n" + " align 4 kernarg_u32 %arg_p5)" + "\n" + "{" + "\n" + "" + "\n" + " align 4 group_u32 %reduce_cllocal_scratch[256];" + "\n" + "" + "\n" + " workitemabsid_u32 $s2, 0;" + "\n" + "" + "\n" + " ld_kernarg_u32 $s1, [%arg_p5];" + "\n" + " ld_kernarg_u64 $d0, [%arg_p4];" + "\n" + " ld_kernarg_u64 $d1, [%arg_p3];" + "\n" + "" + "\n" + " add_u64 $d0, $d0, 24; // adjust over obj array headers" + "\n" + " add_u64 $d1, $d1, 24;" + "\n" + " cmp_ge_b1_s32 $c0, $s2, $s1; // if(gloId < length){" + "\n" + " cbr $c0, @BB0_1;" + "\n" + " gridsize_u32 $s0, 0; // s0 is globalsize" + "\n" + " add_u32 $s0, $s0, $s2; // gx += globalsize" + "\n" + " cvt_s64_s32 $d2, $s2; // s2 is global id" + "\n" + " shl_u64 $d2, $d2, 2;" + "\n" + " add_u64 $d2, $d1, $d2;" + "\n" + " ld_global_u32 $s3, [$d2]; // load this element from input" + "\n" + " brn @BB0_3;" + "\n" + "" + "\n" + "@BB0_1:" + "\n" + " mov_b32 $s0, $s2;" + "\n" + "" + "\n" + "@BB0_3:" + "\n" + " cmp_ge_b1_s32 $c1, $s0, $s1; // while (gx < length)" + "\n" + " cbr $c1, @BB0_6;" + "\n" + " gridsize_u32 $s2, 0;" + "\n" + "" + "\n" + "@BB0_5:" + "\n" + " cvt_s64_s32 $d2, $s0;" + "\n" + " shl_u64 $d2, $d2, 2;" + "\n" + " add_u64 $d2, $d1, $d2;" + "\n" + " ld_global_u32 $s4, [$d2];" + "\n" + reduceOp + " $s3, $s3, $s4;" + "\n" + " add_u32 $s0, $s0, $s2;" + "\n" + " cmp_lt_b1_s32 $c1, $s0, $s1;" + "\n" + " cbr $c1, @BB0_5;" + "\n" + "" + "\n" + "@BB0_6:" + "\n" + " workgroupid_u32 $s0, 0;" + "\n" + " workgroupsize_u32 $s2, 0;" + "\n" + " mul_u32 $s2, $s2, $s0;" + "\n" + " sub_u32 $s2, $s1, $s2;" + "\n" + " workitemid_u32 $s1, 0;" + "\n" + " add_u32 $s4, $s1, 128;" + "\n" + " cmp_lt_b1_u32 $c1, $s4, $s2;" + "\n" + " cmp_lt_b1_s32 $c2, $s1, 128;" + "\n" + " and_b1 $c1, $c2, $c1;" + "\n" + " cvt_s64_s32 $d1, $s1;" + "\n" + " shl_u64 $d1, $d1, 2;" + "\n" + " lda_group_u64 $d2, [%reduce_cllocal_scratch];" + "\n" + " add_u64 $d1, $d2, $d1;" + "\n" + " st_group_u32 $s3, [$d1];" + "\n" + " barrier_fgroup;" + "\n" + " not_b1 $c1, $c1;" + "\n" + " cbr $c1, @BB0_8;" + "\n" + " ld_group_u32 $s3, [$d1];" + "\n" + " cvt_s64_s32 $d3, $s4;" + "\n" + " shl_u64 $d3, $d3, 2;" + "\n" + " add_u64 $d3, $d2, $d3;" + "\n" + " ld_group_u32 $s4, [$d3];" + "\n" + reduceOp + " $s3, $s3, $s4;" + "\n" + " st_group_u32 $s3, [$d1];" + "\n" + "" + "\n" + "@BB0_8:" + "\n" + " add_u32 $s3, $s1, 64;" + "\n" + " cmp_lt_b1_u32 $c1, $s3, $s2;" + "\n" + " cmp_lt_b1_s32 $c2, $s1, 64;" + "\n" + " and_b1 $c1, $c2, $c1;" + "\n" + " barrier_fgroup;" + "\n" + " not_b1 $c1, $c1;" + "\n" + " cbr $c1, @BB0_10;" + "\n" + " ld_group_u32 $s4, [$d1];" + "\n" + " cvt_s64_s32 $d3, $s3;" + "\n" + " shl_u64 $d3, $d3, 2;" + "\n" + " add_u64 $d3, $d2, $d3;" + "\n" + " ld_group_u32 $s3, [$d3];" + "\n" + reduceOp + " $s3, $s3, $s4;" + "\n" + " st_group_u32 $s3, [$d1];" + "\n" + "" + "\n" + "@BB0_10:" + "\n" + " add_u32 $s3, $s1, 32;" + "\n" + " cmp_lt_b1_u32 $c1, $s3, $s2;" + "\n" + " cmp_lt_b1_s32 $c2, $s1, 32;" + "\n" + " and_b1 $c1, $c2, $c1;" + "\n" + " barrier_fgroup;" + "\n" + " not_b1 $c1, $c1;" + "\n" + " cbr $c1, @BB0_12;" + "\n" + " ld_group_u32 $s4, [$d1];" + "\n" + " cvt_s64_s32 $d3, $s3;" + "\n" + " shl_u64 $d3, $d3, 2;" + "\n" + " add_u64 $d3, $d2, $d3;" + "\n" + " ld_group_u32 $s3, [$d3];" + "\n" + reduceOp + " $s3, $s3, $s4;" + "\n" + " st_group_u32 $s3, [$d1];" + "\n" + "" + "\n" + "@BB0_12:" + "\n" + " add_u32 $s3, $s1, 16;" + "\n" + " cmp_lt_b1_u32 $c1, $s3, $s2;" + "\n" + " cmp_lt_b1_s32 $c2, $s1, 16;" + "\n" + " and_b1 $c1, $c2, $c1;" + "\n" + " barrier_fgroup;" + "\n" + " not_b1 $c1, $c1;" + "\n" + " cbr $c1, @BB0_14;" + "\n" + " ld_group_u32 $s4, [$d1];" + "\n" + " cvt_s64_s32 $d3, $s3;" + "\n" + " shl_u64 $d3, $d3, 2;" + "\n" + " add_u64 $d3, $d2, $d3;" + "\n" + " ld_group_u32 $s3, [$d3];" + "\n" + reduceOp + " $s3, $s3, $s4;" + "\n" + " st_group_u32 $s3, [$d1];" + "\n" + "" + "\n" + "@BB0_14:" + "\n" + " add_u32 $s3, $s1, 8;" + "\n" + " cmp_lt_b1_u32 $c1, $s3, $s2;" + "\n" + " cmp_lt_b1_s32 $c2, $s1, 8;" + "\n" + " and_b1 $c1, $c2, $c1;" + "\n" + " barrier_fgroup;" + "\n" + " not_b1 $c1, $c1;" + "\n" + " cbr $c1, @BB0_16;" + "\n" + " ld_group_u32 $s4, [$d1];" + "\n" + " cvt_s64_s32 $d3, $s3;" + "\n" + " shl_u64 $d3, $d3, 2;" + "\n" + " add_u64 $d3, $d2, $d3;" + "\n" + " ld_group_u32 $s3, [$d3];" + "\n" + reduceOp + " $s3, $s3, $s4;" + "\n" + " st_group_u32 $s3, [$d1];" + "\n" + "" + "\n" + "@BB0_16:" + "\n" + " add_u32 $s3, $s1, 4;" + "\n" + " cmp_lt_b1_u32 $c1, $s3, $s2;" + "\n" + " cmp_lt_b1_s32 $c2, $s1, 4;" + "\n" + " and_b1 $c1, $c2, $c1;" + "\n" + " barrier_fgroup;" + "\n" + " not_b1 $c1, $c1;" + "\n" + " cbr $c1, @BB0_18;" + "\n" + " ld_group_u32 $s4, [$d1];" + "\n" + " cvt_s64_s32 $d3, $s3;" + "\n" + " shl_u64 $d3, $d3, 2;" + "\n" + " add_u64 $d3, $d2, $d3;" + "\n" + " ld_group_u32 $s3, [$d3];" + "\n" + reduceOp + " $s3, $s3, $s4;" + "\n" + " st_group_u32 $s3, [$d1];" + "\n" + "" + "\n" + "@BB0_18:" + "\n" + " add_u32 $s3, $s1, 2;" + "\n" + " cmp_lt_b1_u32 $c1, $s3, $s2;" + "\n" + " cmp_lt_b1_s32 $c2, $s1, 2;" + "\n" + " and_b1 $c1, $c2, $c1;" + "\n" + " barrier_fgroup;" + "\n" + " not_b1 $c1, $c1;" + "\n" + " cbr $c1, @BB0_20;" + "\n" + " ld_group_u32 $s4, [$d1];" + "\n" + " cvt_s64_s32 $d3, $s3;" + "\n" + " shl_u64 $d3, $d3, 2;" + "\n" + " add_u64 $d3, $d2, $d3;" + "\n" + " ld_group_u32 $s3, [$d3];" + "\n" + reduceOp + " $s3, $s3, $s4;" + "\n" + " st_group_u32 $s3, [$d1];" + "\n" + "" + "\n" + "@BB0_20:" + "\n" + " add_u32 $s3, $s1, 1;" + "\n" + " cmp_lt_b1_u32 $c1, $s3, $s2;" + "\n" + " cmp_lt_b1_s32 $c2, $s1, 1;" + "\n" + " and_b1 $c1, $c2, $c1;" + "\n" + " barrier_fgroup;" + "\n" + " not_b1 $c1, $c1;" + "\n" + " cbr $c1, @BB0_22;" + "\n" + " ld_group_u32 $s4, [$d1];" + "\n" + " cvt_s64_s32 $d3, $s3;" + "\n" + " shl_u64 $d3, $d3, 2;" + "\n" + " add_u64 $d2, $d2, $d3;" + "\n" + " ld_group_u32 $s3, [$d2];" + "\n" + reduceOp + " $s3, $s3, $s4;" + "\n" + " st_group_u32 $s3, [$d1];" + "\n" + "" + "\n" + "@BB0_22:" + "\n" + " cmp_gt_b1_u32 $c0, $s1, 0; // s1 is local id, done if > 0" + "\n" + " cbr $c0, @BB0_24;" + "\n" + "" + "\n" + " ld_group_u32 $s2, [%reduce_cllocal_scratch]; // s2 is result[get_group_id(0)];" + "\n" + atomicResultProduction + " [$d0], $s2; // build global result from local results" + "\n" + "" + "\n" + "@BB0_24:" + "\n" + " ret;" + "\n" + "};" + "\n"); //@formatter:on } }