diff src/gpu/ptx/vm/gpu_ptx.cpp @ 11527:c99e65785936

Improvements to PTX codegen; allows more PTX tests that run on the device to pass.
author bharadwaj
date Wed, 04 Sep 2013 10:47:37 -0400
parents 49bb1bc983c6
children 91e5f927af63
line wrap: on
line diff
--- a/src/gpu/ptx/vm/gpu_ptx.cpp	Wed Sep 04 14:56:30 2013 +0200
+++ b/src/gpu/ptx/vm/gpu_ptx.cpp	Wed Sep 04 10:47:37 2013 -0400
@@ -38,6 +38,7 @@
 gpu::Ptx::cuda_cu_ctx_create_func_t gpu::Ptx::_cuda_cu_ctx_create;
 gpu::Ptx::cuda_cu_ctx_destroy_func_t gpu::Ptx::_cuda_cu_ctx_destroy;
 gpu::Ptx::cuda_cu_ctx_synchronize_func_t gpu::Ptx::_cuda_cu_ctx_synchronize;
+gpu::Ptx::cuda_cu_ctx_set_current_func_t gpu::Ptx::_cuda_cu_ctx_set_current;
 gpu::Ptx::cuda_cu_device_get_count_func_t gpu::Ptx::_cuda_cu_device_get_count;
 gpu::Ptx::cuda_cu_device_get_name_func_t gpu::Ptx::_cuda_cu_device_get_name;
 gpu::Ptx::cuda_cu_device_get_func_t gpu::Ptx::_cuda_cu_device_get;
@@ -87,7 +88,7 @@
     tty->print_cr("Failed to initialize CUDA device");
     return false;
   }
- 
+
   if (TraceGPUInteraction) {
     tty->print_cr("CUDA driver initialization: Success");
   }
@@ -108,7 +109,7 @@
   if (TraceGPUInteraction) {
     tty->print_cr("[CUDA] Number of compute-capable devices found: %d", device_count);
   }
-  
+
   /* Get the handle to the first compute device */
   int device_id = 0;
   /* Compute-capable device handle */
@@ -195,12 +196,6 @@
   jit_options[2] = GRAAL_CU_JIT_MAX_REGISTERS;
   jit_option_values[2] = (void *)(size_t)jit_register_count;
 
-  if (TraceGPUInteraction) {
-    tty->print_cr("[CUDA] PTX Kernel\n%s", code);
-    tty->print_cr("[CUDA] Function name : %s", name);
-
-  }
-
   /* Create CUDA context to compile and execute the kernel */
   int status = _cuda_cu_ctx_create(&_device_context, 0, _cu_device);
 
@@ -213,6 +208,23 @@
     tty->print_cr("[CUDA] Success: Created context for device: %d", _cu_device);
   }
 
+  status = _cuda_cu_ctx_set_current(_device_context);
+
+  if (status != GRAAL_CUDA_SUCCESS) {
+    tty->print_cr("[CUDA] Failed to set current context for device: %d", _cu_device);
+    return NULL;
+  }
+
+  if (TraceGPUInteraction) {
+    tty->print_cr("[CUDA] Success: Set current context for device: %d", _cu_device);
+  }
+
+  if (TraceGPUInteraction) {
+    tty->print_cr("[CUDA] PTX Kernel\n%s", code);
+    tty->print_cr("[CUDA] Function name : %s", name);
+
+  }
+
   /* Load module's data with compiler options */
   status = _cuda_cu_module_load_data_ex(&cu_module, (void*) code, jit_num_options,
                                             jit_options, (void **)jit_option_values);
@@ -220,7 +232,7 @@
     if (status == GRAAL_CUDA_ERROR_NO_BINARY_FOR_GPU) {
       tty->print_cr("[CUDA] Check for malformed PTX kernel or incorrect PTX compilation options");
     }
-    tty->print_cr("[CUDA] *** Error (%d) Failed to load module data with online compiler options for method %s", 
+    tty->print_cr("[CUDA] *** Error (%d) Failed to load module data with online compiler options for method %s",
                   status, name);
     return NULL;
   }
@@ -255,7 +267,7 @@
   unsigned int blockX = 1;
   unsigned int blockY = 1;
   unsigned int blockZ = 1;
-  
+
   struct CUfunc_st* cu_function = (struct CUfunc_st*) kernel;
 
   void * config[5] = {
@@ -366,7 +378,7 @@
   if (cuda_library_name != NULL) {
     char *buffer = (char*)malloc(STD_BUFFER_SIZE);
     void *handle = os::dll_load(cuda_library_name, buffer, STD_BUFFER_SIZE);
-	free(buffer);
+        free(buffer);
     if (handle != NULL) {
       _cuda_cu_init =
         CAST_TO_FN_PTR(cuda_cu_init_func_t, os::dll_lookup(handle, "cuInit"));
@@ -376,6 +388,8 @@
         CAST_TO_FN_PTR(cuda_cu_ctx_destroy_func_t, os::dll_lookup(handle, "cuCtxDestroy"));
       _cuda_cu_ctx_synchronize =
         CAST_TO_FN_PTR(cuda_cu_ctx_synchronize_func_t, os::dll_lookup(handle, "cuCtxSynchronize"));
+      _cuda_cu_ctx_set_current =
+        CAST_TO_FN_PTR(cuda_cu_ctx_set_current_func_t, os::dll_lookup(handle, "cuCtxSetCurrent"));
       _cuda_cu_device_get_count =
         CAST_TO_FN_PTR(cuda_cu_device_get_count_func_t, os::dll_lookup(handle, "cuDeviceGetCount"));
       _cuda_cu_device_get_name =
@@ -416,4 +430,3 @@
   tty->print_cr("Failed to find CUDA linkage");
   return false;
 }
-