Skip to content

Commit

Permalink
Merge pull request #1588 from CEED/jed/nvrtc-cubin
Browse files Browse the repository at this point in the history
backends/cuda: NVRTC compile to CUBIN when supported (resolve #1587)
  • Loading branch information
jedbrown authored May 23, 2024
2 parents 38f3b71 + 29ec485 commit 9e9230d
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions backends/cuda/ceed-cuda-compile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,19 @@ int CeedCompile_Cuda(Ceed ceed, const char *source, CUmodule *module, const Ceed
opts[0] = "-default-device";
CeedCallBackend(CeedGetData(ceed, &ceed_data));
CeedCallCuda(ceed, cudaGetDeviceProperties(&prop, ceed_data->device_id));
std::string arch_arg = "-arch=compute_" + std::to_string(prop.major) + std::to_string(prop.minor);
opts[1] = arch_arg.c_str();
opts[2] = "-Dint32_t=int";
std::string arch_arg =
#if CUDA_VERSION >= 11010
// NVRTC used to support only virtual architectures through the option
// -arch, since it was only emitting PTX. It will now support actual
// architectures as well to emit SASS.
// https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#dynamic-code-generation
"-arch=sm_"
#else
"-arch=compute_"
#endif
+ std::to_string(prop.major) + std::to_string(prop.minor);
opts[1] = arch_arg.c_str();
opts[2] = "-Dint32_t=int";

// Add string source argument provided in call
code << source;
Expand All @@ -106,9 +116,15 @@ int CeedCompile_Cuda(Ceed ceed, const char *source, CUmodule *module, const Ceed
return CeedError(ceed, CEED_ERROR_BACKEND, "%s\n%s", nvrtcGetErrorString(result), log);
}

#if CUDA_VERSION >= 11010
CeedCallNvrtc(ceed, nvrtcGetCUBINSize(prog, &ptx_size));
CeedCallBackend(CeedMalloc(ptx_size, &ptx));
CeedCallNvrtc(ceed, nvrtcGetCUBIN(prog, ptx));
#else
CeedCallNvrtc(ceed, nvrtcGetPTXSize(prog, &ptx_size));
CeedCallBackend(CeedMalloc(ptx_size, &ptx));
CeedCallNvrtc(ceed, nvrtcGetPTX(prog, ptx));
#endif
CeedCallNvrtc(ceed, nvrtcDestroyProgram(&prog));

CeedCallCuda(ceed, cuModuleLoadData(module, ptx));
Expand Down

0 comments on commit 9e9230d

Please sign in to comment.