Skip to content

Commit

Permalink
Merge pull request #1450 from CEED/sjg/jit-speedup-dev
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastiangrimberg authored Jan 25, 2024
2 parents 6069a98 + 1b7492f commit b698cfb
Show file tree
Hide file tree
Showing 24 changed files with 1,497 additions and 1,200 deletions.
132 changes: 104 additions & 28 deletions backends/cuda-ref/ceed-cuda-ref-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ static int CeedOperatorDestroy_Cuda(CeedOperator op) {
Ceed ceed;

CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
CeedCallCuda(ceed, cuModuleUnload(impl->diag->module));
if (impl->diag->module) {
CeedCallCuda(ceed, cuModuleUnload(impl->diag->module));
}
if (impl->diag->module_point_block) {
CeedCallCuda(ceed, cuModuleUnload(impl->diag->module_point_block));
}
CeedCallCuda(ceed, cudaFree(impl->diag->d_eval_modes_in));
CeedCallCuda(ceed, cudaFree(impl->diag->d_eval_modes_out));
CeedCallCuda(ceed, cudaFree(impl->diag->d_identity));
Expand Down Expand Up @@ -580,11 +585,10 @@ static int CeedOperatorLinearAssembleQFunctionUpdate_Cuda(CeedOperator op, CeedV
//------------------------------------------------------------------------------
// Assemble Diagonal Setup
//------------------------------------------------------------------------------
static inline int CeedOperatorAssembleDiagonalSetup_Cuda(CeedOperator op, CeedInt use_ceedsize_idx) {
static inline int CeedOperatorAssembleDiagonalSetup_Cuda(CeedOperator op) {
Ceed ceed;
char *diagonal_kernel_path, *diagonal_kernel_source;
CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0;
CeedInt num_comp, q_comp, num_nodes, num_qpts;
CeedInt q_comp, num_nodes, num_qpts;
CeedEvalMode *eval_modes_in = NULL, *eval_modes_out = NULL;
CeedBasis basis_in = NULL, basis_out = NULL;
CeedQFunctionField *qf_fields;
Expand Down Expand Up @@ -653,24 +657,10 @@ static inline int CeedOperatorAssembleDiagonalSetup_Cuda(CeedOperator op, CeedIn
CeedCallBackend(CeedCalloc(1, &impl->diag));
CeedOperatorDiag_Cuda *diag = impl->diag;

// Assemble kernel
// Basis matrices
CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes));
CeedCallBackend(CeedBasisGetNumComponents(basis_in, &num_comp));
if (basis_in == CEED_BASIS_NONE) num_qpts = num_nodes;
else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts));
CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/cuda/cuda-ref-operator-assemble-diagonal.h", &diagonal_kernel_path));
CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Diagonal Assembly Kernel Source -----\n");
CeedCallBackend(CeedLoadSourceToBuffer(ceed, diagonal_kernel_path, &diagonal_kernel_source));
CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Diagonal Assembly Source Complete! -----\n");
CeedCallCuda(
ceed, CeedCompile_Cuda(ceed, diagonal_kernel_source, &diag->module, 6, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT",
num_eval_modes_out, "NUM_COMP", num_comp, "NUM_NODES", num_nodes, "NUM_QPTS", num_qpts, "CEED_SIZE", use_ceedsize_idx));
CeedCallCuda(ceed, CeedGetKernel_Cuda(ceed, diag->module, "LinearDiagonal", &diag->LinearDiagonal));
CeedCallCuda(ceed, CeedGetKernel_Cuda(ceed, diag->module, "LinearPointBlockDiagonal", &diag->LinearPointBlock));
CeedCallBackend(CeedFree(&diagonal_kernel_path));
CeedCallBackend(CeedFree(&diagonal_kernel_source));

// Basis matrices
const CeedInt interp_bytes = num_nodes * num_qpts * sizeof(CeedScalar);
const CeedInt eval_modes_bytes = sizeof(CeedEvalMode);
bool has_eval_none = false;
Expand Down Expand Up @@ -774,13 +764,92 @@ static inline int CeedOperatorAssembleDiagonalSetup_Cuda(CeedOperator op, CeedIn
return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Assemble Diagonal Setup (Compilation)
//------------------------------------------------------------------------------
static inline int CeedOperatorAssembleDiagonalSetupCompile_Cuda(CeedOperator op, CeedInt use_ceedsize_idx, const bool is_point_block) {
Ceed ceed;
char *diagonal_kernel_path, *diagonal_kernel_source;
CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0;
CeedInt num_comp, q_comp, num_nodes, num_qpts;
CeedBasis basis_in = NULL, basis_out = NULL;
CeedQFunctionField *qf_fields;
CeedQFunction qf;
CeedOperatorField *op_fields;
CeedOperator_Cuda *impl;

CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields));

// Determine active input basis
CeedCallBackend(CeedOperatorGetFields(op, NULL, &op_fields, NULL, NULL));
CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_fields, NULL, NULL));
for (CeedInt i = 0; i < num_input_fields; i++) {
CeedVector vec;

CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec));
if (vec == CEED_VECTOR_ACTIVE) {
CeedEvalMode eval_mode;

CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis_in));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_in, eval_mode, &q_comp));
if (eval_mode != CEED_EVAL_WEIGHT) {
num_eval_modes_in += q_comp;
}
}
}

// Determine active output basis
CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields));
CeedCallBackend(CeedQFunctionGetFields(qf, NULL, NULL, NULL, &qf_fields));
for (CeedInt i = 0; i < num_output_fields; i++) {
CeedVector vec;

CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec));
if (vec == CEED_VECTOR_ACTIVE) {
CeedEvalMode eval_mode;

CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis_out));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis_out, eval_mode, &q_comp));
if (eval_mode != CEED_EVAL_WEIGHT) {
num_eval_modes_out += q_comp;
}
}
}

// Operator data struct
CeedCallBackend(CeedOperatorGetData(op, &impl));
CeedOperatorDiag_Cuda *diag = impl->diag;

// Assemble kernel
CUmodule *module = is_point_block ? &diag->module_point_block : &diag->module;
CeedInt elems_per_block = 1;
CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes));
CeedCallBackend(CeedBasisGetNumComponents(basis_in, &num_comp));
if (basis_in == CEED_BASIS_NONE) num_qpts = num_nodes;
else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts));
CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/cuda/cuda-ref-operator-assemble-diagonal.h", &diagonal_kernel_path));
CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Diagonal Assembly Kernel Source -----\n");
CeedCallBackend(CeedLoadSourceToBuffer(ceed, diagonal_kernel_path, &diagonal_kernel_source));
CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Diagonal Assembly Source Complete! -----\n");
CeedCallCuda(ceed, CeedCompile_Cuda(ceed, diagonal_kernel_source, module, 8, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT",
num_eval_modes_out, "NUM_COMP", num_comp, "NUM_NODES", num_nodes, "NUM_QPTS", num_qpts, "USE_CEEDSIZE",
use_ceedsize_idx, "USE_POINT_BLOCK", is_point_block ? 1 : 0, "BLOCK_SIZE", num_nodes * elems_per_block));
CeedCallCuda(ceed, CeedGetKernel_Cuda(ceed, *module, "LinearDiagonal", is_point_block ? &diag->LinearPointBlock : &diag->LinearDiagonal));
CeedCallBackend(CeedFree(&diagonal_kernel_path));
CeedCallBackend(CeedFree(&diagonal_kernel_source));
return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Assemble Diagonal Core
//------------------------------------------------------------------------------
static inline int CeedOperatorAssembleDiagonalCore_Cuda(CeedOperator op, CeedVector assembled, CeedRequest *request, const bool is_point_block) {
Ceed ceed;
CeedSize assembled_length, assembled_qf_length;
CeedInt use_ceedsize_idx = 0, num_elem, num_nodes;
CeedInt num_elem, num_nodes;
CeedScalar *elem_diag_array;
const CeedScalar *assembled_qf_array;
CeedVector assembled_qf = NULL, elem_diag;
Expand All @@ -795,16 +864,23 @@ static inline int CeedOperatorAssembleDiagonalCore_Cuda(CeedOperator op, CeedVec
CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr));
CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array));

CeedCallBackend(CeedVectorGetLength(assembled, &assembled_length));
CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length));
if ((assembled_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1;

// Setup
if (!impl->diag) CeedCallBackend(CeedOperatorAssembleDiagonalSetup_Cuda(op, use_ceedsize_idx));
if (!impl->diag) CeedCallBackend(CeedOperatorAssembleDiagonalSetup_Cuda(op));
CeedOperatorDiag_Cuda *diag = impl->diag;

assert(diag != NULL);

// Assemble kernel if needed
if ((!is_point_block && !diag->LinearDiagonal) || (is_point_block && !diag->LinearPointBlock)) {
CeedSize assembled_length, assembled_qf_length;
CeedInt use_ceedsize_idx = 0;
CeedCallBackend(CeedVectorGetLength(assembled, &assembled_length));
CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length));
if ((assembled_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1;

CeedCallBackend(CeedOperatorAssembleDiagonalSetupCompile_Cuda(op, use_ceedsize_idx, is_point_block));
}

// Restriction and diagonal vector
CeedCallBackend(CeedOperatorGetActiveElemRestrictions(op, &rstr_in, &rstr_out));
CeedCheck(rstr_in == rstr_out, ceed, CEED_ERROR_BACKEND,
Expand Down Expand Up @@ -981,8 +1057,8 @@ static int CeedSingleOperatorAssembleSetup_Cuda(CeedOperator op, CeedInt use_cee
CeedCallBackend(CeedCompile_Cuda(ceed, assembly_kernel_source, &asmb->module, 10, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT",
num_eval_modes_out, "NUM_COMP_IN", num_comp_in, "NUM_COMP_OUT", num_comp_out, "NUM_NODES_IN", elem_size_in,
"NUM_NODES_OUT", elem_size_out, "NUM_QPTS", num_qpts_in, "BLOCK_SIZE",
asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block, "BLOCK_SIZE_Y", asmb->block_size_y, "CEED_SIZE",
use_ceedsize_idx));
asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block, "BLOCK_SIZE_Y", asmb->block_size_y,
"USE_CEEDSIZE", use_ceedsize_idx));
CeedCallBackend(CeedGetKernel_Cuda(ceed, asmb->module, "LinearAssemble", &asmb->LinearAssemble));
CeedCallBackend(CeedFree(&assembly_kernel_path));
CeedCallBackend(CeedFree(&assembly_kernel_source));
Expand Down
Loading

0 comments on commit b698cfb

Please sign in to comment.