Skip to content

Commit

Permalink
Merge pull request #1699 from CEED/jeremy/set-jit-defines
Browse files Browse the repository at this point in the history
Add CeedAddJitDefine
  • Loading branch information
jeremylt authored Oct 22, 2024
2 parents 1dc8b1e + 830fc37 commit e036be4
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 22 deletions.
28 changes: 23 additions & 5 deletions backends/cuda/ceed-cuda-compile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ int CeedCompile_Cuda(Ceed ceed, const char *source, CUmodule *module, const Ceed
size_t ptx_size;
char *ptx;
const int num_opts = 4;
CeedInt num_jit_source_dirs = 0;
CeedInt num_jit_source_dirs = 0, num_jit_defines = 0;
const char **opts;
nvrtcProgram prog;
struct cudaDeviceProp prop;
Expand Down Expand Up @@ -85,19 +85,34 @@ int CeedCompile_Cuda(Ceed ceed, const char *source, CUmodule *module, const Ceed
opts[1] = arch_arg.c_str();
opts[2] = "-Dint32_t=int";
opts[3] = "-DCEED_RUNNING_JIT_PASS=1";
// Additional include dirs
{
const char **jit_source_dirs;

CeedCallBackend(CeedGetJitSourceRoots(ceed, &num_jit_source_dirs, &jit_source_dirs));
CeedCallBackend(CeedRealloc(num_opts + num_jit_source_dirs, &opts));
for (CeedInt i = 0; i < num_jit_source_dirs; i++) {
std::ostringstream include_dirs_arg;
std::ostringstream include_dir_arg;

include_dirs_arg << "-I" << jit_source_dirs[i];
CeedCallBackend(CeedStringAllocCopy(include_dirs_arg.str().c_str(), (char **)&opts[num_opts + i]));
include_dir_arg << "-I" << jit_source_dirs[i];
CeedCallBackend(CeedStringAllocCopy(include_dir_arg.str().c_str(), (char **)&opts[num_opts + i]));
}
CeedCallBackend(CeedRestoreJitSourceRoots(ceed, &jit_source_dirs));
}
// User defines
{
const char **jit_defines;

CeedCallBackend(CeedGetJitDefines(ceed, &num_jit_defines, &jit_defines));
CeedCallBackend(CeedRealloc(num_opts + num_jit_source_dirs + num_jit_defines, &opts));
for (CeedInt i = 0; i < num_jit_defines; i++) {
std::ostringstream define_arg;

define_arg << "-D" << jit_defines[i];
CeedCallBackend(CeedStringAllocCopy(define_arg.str().c_str(), (char **)&opts[num_opts + num_jit_source_dirs + i]));
}
CeedCallBackend(CeedRestoreJitDefines(ceed, &jit_defines));
}

// Add string source argument provided in call
code << source;
Expand All @@ -106,11 +121,14 @@ int CeedCompile_Cuda(Ceed ceed, const char *source, CUmodule *module, const Ceed
CeedCallNvrtc(ceed, nvrtcCreateProgram(&prog, code.str().c_str(), NULL, 0, NULL, NULL));

// Compile kernel
nvrtcResult result = nvrtcCompileProgram(prog, num_opts + num_jit_source_dirs, opts);
nvrtcResult result = nvrtcCompileProgram(prog, num_opts + num_jit_source_dirs + num_jit_defines, opts);

for (CeedInt i = 0; i < num_jit_source_dirs; i++) {
CeedCallBackend(CeedFree(&opts[num_opts + i]));
}
for (CeedInt i = 0; i < num_jit_defines; i++) {
CeedCallBackend(CeedFree(&opts[num_opts + num_jit_source_dirs + i]));
}
CeedCallBackend(CeedFree(&opts));
if (result != NVRTC_SUCCESS) {
char *log;
Expand Down
28 changes: 23 additions & 5 deletions backends/hip/ceed-hip-compile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ int CeedCompile_Hip(Ceed ceed, const char *source, hipModule_t *module, const Ce
size_t ptx_size;
char *ptx;
const int num_opts = 4;
CeedInt num_jit_source_dirs = 0;
CeedInt num_jit_source_dirs = 0, num_jit_defines = 0;
const char **opts;
int runtime_version;
hiprtcProgram prog;
Expand Down Expand Up @@ -87,19 +87,34 @@ int CeedCompile_Hip(Ceed ceed, const char *source, hipModule_t *module, const Ce
opts[1] = arch_arg.c_str();
opts[2] = "-munsafe-fp-atomics";
opts[3] = "-DCEED_RUNNING_JIT_PASS=1";
// Additional include dirs
{
const char **jit_source_dirs;

CeedCallBackend(CeedGetJitSourceRoots(ceed, &num_jit_source_dirs, &jit_source_dirs));
CeedCallBackend(CeedRealloc(num_opts + num_jit_source_dirs, &opts));
for (CeedInt i = 0; i < num_jit_source_dirs; i++) {
std::ostringstream include_dirs_arg;
std::ostringstream include_dir_arg;

include_dirs_arg << "-I" << jit_source_dirs[i];
CeedCallBackend(CeedStringAllocCopy(include_dirs_arg.str().c_str(), (char **)&opts[num_opts + i]));
include_dir_arg << "-I" << jit_source_dirs[i];
CeedCallBackend(CeedStringAllocCopy(include_dir_arg.str().c_str(), (char **)&opts[num_opts + i]));
}
CeedCallBackend(CeedRestoreJitSourceRoots(ceed, &jit_source_dirs));
}
// User defines
{
const char **jit_defines;

CeedCallBackend(CeedGetJitDefines(ceed, &num_jit_defines, &jit_defines));
CeedCallBackend(CeedRealloc(num_opts + num_jit_source_dirs + num_jit_defines, &opts));
for (CeedInt i = 0; i < num_jit_defines; i++) {
std::ostringstream define_arg;

define_arg << "-D" << jit_defines[i];
CeedCallBackend(CeedStringAllocCopy(define_arg.str().c_str(), (char **)&opts[num_opts + num_jit_source_dirs + i]));
}
CeedCallBackend(CeedRestoreJitDefines(ceed, &jit_defines));
}

// Add string source argument provided in call
code << source;
Expand All @@ -108,11 +123,14 @@ int CeedCompile_Hip(Ceed ceed, const char *source, hipModule_t *module, const Ce
CeedCallHiprtc(ceed, hiprtcCreateProgram(&prog, code.str().c_str(), NULL, 0, NULL, NULL));

// Compile kernel
hiprtcResult result = hiprtcCompileProgram(prog, num_opts + num_jit_source_dirs, opts);
hiprtcResult result = hiprtcCompileProgram(prog, num_opts + num_jit_source_dirs + num_jit_defines, opts);

for (CeedInt i = 0; i < num_jit_source_dirs; i++) {
CeedCallBackend(CeedFree(&opts[num_opts + i]));
}
for (CeedInt i = 0; i < num_jit_defines; i++) {
CeedCallBackend(CeedFree(&opts[num_opts + num_jit_source_dirs + i]));
}
CeedCallBackend(CeedFree(&opts));
if (result != HIPRTC_SUCCESS) {
size_t log_size;
Expand Down
2 changes: 2 additions & 0 deletions doc/sphinx/source/releasenotes.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ On this page we provide a summary of the main API changes, new features and exam
- Add `CeedElemRestrictionGetLLayout` to provide L-vector layout for strided `CeedElemRestriction` created with `CEED_BACKEND_STRIDES`.
- Add `CeedVectorReturnCeed` and similar when parent `Ceed` context for a libCEED object is only needed once in a calling scope.
- Enable `#pragma once` for all JiT source; remove duplicate includes in JiT source string before compilation.
- Allow user to set additional compiler options for CUDA and HIP JiT.
Specifically, directories set with `CeedAddJitSourceRoot(ceed, "foo/bar")` will be used to set `-Ifoo/bar` and defines set with `CeedAddJitDefine(ceed, "foo=bar")` will be used to set `-Dfoo=bar`.

### Examples

Expand Down
4 changes: 3 additions & 1 deletion include/ceed-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ struct Ceed_private {
Ceed op_fallback_ceed, op_fallback_parent;
const char *op_fallback_resource;
char **jit_source_roots;
CeedInt num_jit_source_roots;
CeedInt num_jit_source_roots, max_jit_source_roots, num_jit_source_roots_readers;
char **jit_defines;
CeedInt num_jit_defines, max_jit_defines, num_jit_defines_readers;
int (*Error)(Ceed, const char *, int, const char *, int, const char *, va_list *);
int (*SetStream)(Ceed, void *);
int (*GetPreferredMemType)(CeedMemType *);
Expand Down
2 changes: 2 additions & 0 deletions include/ceed/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ CEED_EXTERN int CeedGetWorkVector(Ceed ceed, CeedSize len, CeedVector *vec);
CEED_EXTERN int CeedRestoreWorkVector(Ceed ceed, CeedVector *vec);
CEED_EXTERN int CeedGetJitSourceRoots(Ceed ceed, CeedInt *num_source_roots, const char ***jit_source_roots);
CEED_EXTERN int CeedRestoreJitSourceRoots(Ceed ceed, const char ***jit_source_roots);
CEED_EXTERN int CeedGetJitDefines(Ceed ceed, CeedInt *num_defines, const char ***jit_defines);
CEED_EXTERN int CeedRestoreJitDefines(Ceed ceed, const char ***jit_defines);

CEED_EXTERN int CeedVectorHasValidArray(CeedVector vec, bool *has_valid_array);
CEED_EXTERN int CeedVectorHasBorrowedArrayOfType(CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type);
Expand Down
1 change: 1 addition & 0 deletions include/ceed/ceed.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ CEED_EXTERN int CeedReferenceCopy(Ceed ceed, Ceed *ceed_copy);
CEED_EXTERN int CeedGetResource(Ceed ceed, const char **resource);
CEED_EXTERN int CeedIsDeterministic(Ceed ceed, bool *is_deterministic);
CEED_EXTERN int CeedAddJitSourceRoot(Ceed ceed, const char *jit_source_root);
CEED_EXTERN int CeedAddJitDefine(Ceed ceed, const char *jit_define);
CEED_EXTERN int CeedView(Ceed ceed, FILE *stream);
CEED_EXTERN int CeedDestroy(Ceed *ceed);
CEED_EXTERN int CeedErrorImpl(Ceed ceed, const char *filename, int lineno, const char *func, int ecode, const char *format, ...);
Expand Down
118 changes: 110 additions & 8 deletions interface/ceed.c
Original file line number Diff line number Diff line change
Expand Up @@ -659,14 +659,24 @@ int CeedGetOperatorFallbackCeed(Ceed ceed, Ceed *fallback_ceed) {
fallback_ceed->Error = ceed->Error;
ceed->op_fallback_ceed = fallback_ceed;
{
const char **jit_source_dirs;
CeedInt num_jit_source_dirs = 0;
const char **jit_source_roots;
CeedInt num_jit_source_roots = 0;

CeedCall(CeedGetJitSourceRoots(ceed, &num_jit_source_dirs, &jit_source_dirs));
for (CeedInt i = 0; i < num_jit_source_dirs; i++) {
CeedCall(CeedAddJitSourceRoot(fallback_ceed, jit_source_dirs[i]));
CeedCall(CeedGetJitSourceRoots(ceed, &num_jit_source_roots, &jit_source_roots));
for (CeedInt i = 0; i < num_jit_source_roots; i++) {
CeedCall(CeedAddJitSourceRoot(fallback_ceed, jit_source_roots[i]));
}
CeedCall(CeedRestoreJitSourceRoots(ceed, &jit_source_dirs));
CeedCall(CeedRestoreJitSourceRoots(ceed, &jit_source_roots));
}
{
  // Mirror the user-provided JiT defines onto the fallback context.
  // These are preprocessor defines (`foo=bar`), so they must be registered with CeedAddJitDefine();
  // registering them as JiT source roots would hand them to the backends as include directories.
  const char **jit_defines;
  CeedInt      num_jit_defines = 0;

  CeedCall(CeedGetJitDefines(ceed, &num_jit_defines, &jit_defines));
  for (CeedInt i = 0; i < num_jit_defines; i++) {
    CeedCall(CeedAddJitDefine(fallback_ceed, jit_defines[i]));
  }
  // Release read access so further defines may be added to `ceed` later
  CeedCall(CeedRestoreJitDefines(ceed, &jit_defines));
}
}
*fallback_ceed = ceed->op_fallback_ceed;
Expand Down Expand Up @@ -874,7 +884,7 @@ int CeedRestoreWorkVector(Ceed ceed, CeedVector *vec) {
}

/**
@brief Retrieve list ofadditional JiT source roots from `Ceed` context.
@brief Retrieve list of additional JiT source roots from `Ceed` context.
Note: The caller is responsible for restoring `jit_source_roots` with @ref CeedRestoreJitSourceRoots().
Expand All @@ -892,6 +902,7 @@ int CeedGetJitSourceRoots(Ceed ceed, CeedInt *num_source_roots, const char ***ji
CeedCall(CeedGetParent(ceed, &ceed_parent));
*num_source_roots = ceed_parent->num_jit_source_roots;
*jit_source_roots = (const char **)ceed_parent->jit_source_roots;
ceed_parent->num_jit_source_roots_readers++;
return CEED_ERROR_SUCCESS;
}

Expand All @@ -906,7 +917,53 @@ int CeedGetJitSourceRoots(Ceed ceed, CeedInt *num_source_roots, const char ***ji
@ref Backend
**/
int CeedRestoreJitSourceRoots(Ceed ceed, const char ***jit_source_roots) {
  Ceed parent = NULL;

  // The reader count is tracked on the context returned by CeedGetParent()
  CeedCall(CeedGetParent(ceed, &parent));
  parent->num_jit_source_roots_readers -= 1;
  // Invalidate the caller's handle to the list
  *jit_source_roots = NULL;
  return CEED_ERROR_SUCCESS;
}

/**
  @brief Get read access to the list of additional JiT defines stored on a `Ceed` context.

  Note: The caller must hand the list back with @ref CeedRestoreJitDefines().
        While read access is outstanding, @ref CeedAddJitDefine() fails with `CEED_ERROR_ACCESS`.

  @param[in]  ceed        `Ceed` context
  @param[out] num_defines Number of JiT defines
  @param[out] jit_defines Strings such as `foo=bar`, used as `-Dfoo=bar` in JiT

  @return An error code: 0 - success, otherwise - failure

  @ref Backend
**/
int CeedGetJitDefines(Ceed ceed, CeedInt *num_defines, const char ***jit_defines) {
  Ceed parent = NULL;

  // The defines are read from the context returned by CeedGetParent()
  CeedCall(CeedGetParent(ceed, &parent));
  // Register this reader, then expose the stored list without copying it
  parent->num_jit_defines_readers += 1;
  *jit_defines = (const char **)parent->jit_defines;
  *num_defines = parent->num_jit_defines;
  return CEED_ERROR_SUCCESS;
}

/**
  @brief Restore the list of additional JiT defines obtained with @ref CeedGetJitDefines()

  @param[in]  ceed        `Ceed` context
  @param[out] jit_defines List obtained from @ref CeedGetJitDefines(); set to `NULL` on return

  @return An error code: 0 - success, otherwise - failure

  @ref Backend
**/
int CeedRestoreJitDefines(Ceed ceed, const char ***jit_defines) {
  Ceed parent = NULL;

  // The reader count is tracked on the context returned by CeedGetParent()
  CeedCall(CeedGetParent(ceed, &parent));
  parent->num_jit_defines_readers -= 1;
  // Invalidate the caller's handle to the list
  *jit_defines = NULL;
  return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -1290,17 +1347,52 @@ int CeedAddJitSourceRoot(Ceed ceed, const char *jit_source_root) {
Ceed ceed_parent;

CeedCall(CeedGetParent(ceed, &ceed_parent));
CeedCheck(!ceed_parent->num_jit_source_roots_readers, ceed, CEED_ERROR_ACCESS, "Cannot add JiT source root, read access has not been restored");

CeedInt index = ceed_parent->num_jit_source_roots;
size_t path_length = strlen(jit_source_root);

CeedCall(CeedRealloc(index + 1, &ceed_parent->jit_source_roots));
if (ceed_parent->num_jit_source_roots == ceed_parent->max_jit_source_roots) {
if (ceed_parent->max_jit_source_roots == 0) ceed_parent->max_jit_source_roots = 1;
ceed_parent->max_jit_source_roots *= 2;
CeedCall(CeedRealloc(ceed_parent->max_jit_source_roots, &ceed_parent->jit_source_roots));
}
CeedCall(CeedCalloc(path_length + 1, &ceed_parent->jit_source_roots[index]));
memcpy(ceed_parent->jit_source_roots[index], jit_source_root, path_length);
ceed_parent->num_jit_source_roots++;
return CEED_ERROR_SUCCESS;
}

/**
  @brief Register an additional JiT compiler define on a `Ceed` context

  The string is copied, so the caller keeps ownership of `jit_define`.
  Fails with `CEED_ERROR_ACCESS` while read access granted by @ref CeedGetJitDefines() has not been restored.

  @param[in,out] ceed       `Ceed` context
  @param[in]     jit_define String such as `foo=bar`, used as `-Dfoo=bar` in JiT

  @return An error code: 0 - success, otherwise - failure

  @ref User
**/
int CeedAddJitDefine(Ceed ceed, const char *jit_define) {
  Ceed parent;

  CeedCall(CeedGetParent(ceed, &parent));
  CeedCheck(!parent->num_jit_defines_readers, ceed, CEED_ERROR_ACCESS, "Cannot add JiT define, read access has not been restored");

  const size_t  length = strlen(jit_define);
  const CeedInt slot   = parent->num_jit_defines;

  // Storage is full: double the capacity (an empty list grows to two slots)
  if (slot == parent->max_jit_defines) {
    const CeedInt current_capacity = parent->max_jit_defines == 0 ? 1 : parent->max_jit_defines;

    parent->max_jit_defines = 2 * current_capacity;
    CeedCall(CeedRealloc(parent->max_jit_defines, &parent->jit_defines));
  }
  // Zero-initialized buffer of length + 1 keeps the copied string null terminated
  CeedCall(CeedCalloc(length + 1, &parent->jit_defines[slot]));
  memcpy(parent->jit_defines[slot], jit_define, length);
  parent->num_jit_defines = slot + 1;
  return CEED_ERROR_SUCCESS;
}

/**
@brief View a `Ceed`
Expand Down Expand Up @@ -1338,6 +1430,11 @@ int CeedDestroy(Ceed *ceed) {
*ceed = NULL;
return CEED_ERROR_SUCCESS;
}

// Refuse to destroy the context while a caller still holds read access to either JiT list
CeedCheck(!(*ceed)->num_jit_source_roots_readers, *ceed, CEED_ERROR_ACCESS,
          "Cannot destroy ceed context, read access for JiT source roots has been granted");
CeedCheck(!(*ceed)->num_jit_defines_readers, *ceed, CEED_ERROR_ACCESS,
          "Cannot destroy ceed context, read access for JiT defines has been granted");

if ((*ceed)->delegate) CeedCall(CeedDestroy(&(*ceed)->delegate));

if ((*ceed)->obj_delegate_count > 0) {
Expand All @@ -1355,6 +1452,11 @@ int CeedDestroy(Ceed *ceed) {
}
CeedCall(CeedFree(&(*ceed)->jit_source_roots));

for (CeedInt i = 0; i < (*ceed)->num_jit_defines; i++) {
CeedCall(CeedFree(&(*ceed)->jit_defines[i]));
}
CeedCall(CeedFree(&(*ceed)->jit_defines));

CeedCall(CeedFree(&(*ceed)->f_offsets));
CeedCall(CeedFree(&(*ceed)->resource));
CeedCall(CeedDestroy(&(*ceed)->op_fallback_ceed));
Expand Down
5 changes: 3 additions & 2 deletions tests/t406-qfunction.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ int main(int argc, char **argv) {

memcpy(&file_path[last_slash - file_path], "/test-include/", 15);
CeedAddJitSourceRoot(ceed, file_path);
CeedAddJitDefine(ceed, "COMPILER_DEFINED_SCALE=42");
}

CeedVectorCreate(ceed, q, &w);
Expand Down Expand Up @@ -71,9 +72,9 @@ int main(int argc, char **argv) {

CeedVectorGetArrayRead(v, CEED_MEM_HOST, &v_array);
for (CeedInt i = 0; i < q; i++) {
if (fabs(5 * v_true[i] * sqrt(2.) - v_array[i]) > 1E3 * CEED_EPSILON) {
if (fabs(5 * COMPILER_DEFINED_SCALE * v_true[i] * sqrt(2.) - v_array[i]) > 5E3 * CEED_EPSILON) {
// LCOV_EXCL_START
printf("[%" CeedInt_FMT "] v_true %f != v %f\n", i, 5 * v_true[i] * sqrt(2.), v_array[i]);
printf("[%" CeedInt_FMT "] v_true %f != v %f\n", i, 5 * COMPILER_DEFINED_SCALE * v_true[i] * sqrt(2.), v_array[i]);
// LCOV_EXCL_STOP
}
}
Expand Down
7 changes: 6 additions & 1 deletion tests/t406-qfunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
# include "t406-qfunction-scales.h"
// clang-format on

// Extra define set via CeedAddJitDefine() during JiT
#ifndef CEED_RUNNING_JIT_PASS
#define COMPILER_DEFINED_SCALE 42
#endif

CEED_QFUNCTION(setup)(void *ctx, const CeedInt Q, const CeedScalar *const *in, CeedScalar *const *out) {
const CeedScalar *w = in[0];
CeedScalar *q_data = out[0];
Expand All @@ -36,7 +41,7 @@ CEED_QFUNCTION(mass)(void *ctx, const CeedInt Q, const CeedScalar *const *in, Ce
const CeedScalar *q_data = in[0], *u = in[1];
CeedScalar *v = out[0];
for (CeedInt i = 0; i < Q; i++) {
v[i] = q_data[i] * (times_two(u[i]) + times_three(u[i])) * sqrt(1.0 * SCALE_TWO);
v[i] = q_data[i] * COMPILER_DEFINED_SCALE * (times_two(u[i]) + times_three(u[i])) * sqrt(1.0 * SCALE_TWO);
}
return 0;
}

0 comments on commit e036be4

Please sign in to comment.