Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastiangrimberg committed Feb 15, 2023
1 parent 5a9d33d commit 063bea1
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 143 deletions.
16 changes: 8 additions & 8 deletions backends/ref/ceed-ref-basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ static int CeedBasisApply_Ref(CeedBasis basis, CeedInt num_elem, CeedTransposeMo
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
CeedFESpace fe_space;
CeedCall(CeedBasisGetFESpace(basis, &fe_space));
CeedTensorContract contract;
CeedCallBackend(CeedBasisGetTensorContract(basis, &contract));
const CeedInt add = (t_mode == CEED_TRANSPOSE);
Expand Down Expand Up @@ -195,12 +193,13 @@ static int CeedBasisApply_Ref(CeedBasis basis, CeedInt num_elem, CeedTransposeMo
switch (eval_mode) {
// Interpolate to/from quadrature points
case CEED_EVAL_INTERP: {
CeedInt qdim = (fe_space == CEED_FE_SPACE_H1) ? 1 : dim;
CeedInt P = num_nodes, Q = qdim * num_qpts;
CeedInt q_comp;
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, &q_comp));
CeedInt P = num_nodes, Q = q_comp * num_qpts;
const CeedScalar *interp;
CeedCallBackend(CeedBasisGetInterp(basis, &interp));
if (t_mode == CEED_TRANSPOSE) {
P = qdim * num_qpts;
P = q_comp * num_qpts;
Q = num_nodes;
}
CeedCallBackend(CeedTensorContractApply(contract, num_comp, P, num_elem, Q, interp, t_mode, add, u, v));
Expand Down Expand Up @@ -250,12 +249,13 @@ static int CeedBasisApply_Ref(CeedBasis basis, CeedInt num_elem, CeedTransposeMo
} break;
// Evaluate the curl to/from the quadrature points
case CEED_EVAL_CURL: {
CeedInt cdim = (dim < 3) ? 1 : dim;
CeedInt P = num_nodes, Q = cdim * num_qpts;
CeedInt curl_comp;
CeedCallBackend(CeedBasisGetNumCurlComponents(basis, &curl_comp));
CeedInt P = num_nodes, Q = curl_comp * num_qpts;
const CeedScalar *curl;
CeedCallBackend(CeedBasisGetCurl(basis, &curl));
if (t_mode == CEED_TRANSPOSE) {
P = cdim * num_qpts;
P = curl_comp * num_qpts;
Q = num_nodes;
}
CeedCallBackend(CeedTensorContractApply(contract, num_comp, P, num_elem, Q, curl, t_mode, add, u, v));
Expand Down
4 changes: 2 additions & 2 deletions include/ceed-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ struct CeedBasis_private {
CeedScalar *grad; /* row-major matrix of shape [dim * Q, P] matrix expressing derivatives of nodal basis functions at quadrature points */
CeedScalar *grad_1d; /* row-major matrix of shape [Q1d, P1d] matrix expressing derivatives of nodal basis functions at quadrature points */
CeedScalar *div; /* row-major matrix of shape [Q, P] expressing the divergence of basis functions at quadrature points for H(div) discretizations */
CeedScalar *curl; /* row-major matrix of shape [cdim * Q, P], cdim = 1 if dim < 3 else dim, expressing the curl of basis functions at quadrature
points for H(curl) discretizations */
CeedScalar *curl; /* row-major matrix of shape [curl_dim * Q, P], curl_dim = 1 if dim < 3 else dim, expressing the curl of basis functions at
quadrature points for H(curl) discretizations */
void *data; /* place for the backend to store any data */
};

Expand Down
5 changes: 5 additions & 0 deletions include/ceed/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,4 +280,9 @@ CEED_EXTERN int CeedOperatorSetData(CeedOperator op, void *data);
CEED_EXTERN int CeedOperatorReference(CeedOperator op);
CEED_EXTERN int CeedOperatorSetSetupDone(CeedOperator op);

CEED_EXTERN int CeedMatrixMatrixMultiply(Ceed ceed, const CeedScalar *mat_A, const CeedScalar *mat_B, CeedScalar *mat_C, CeedInt m, CeedInt n,
CeedInt kk);
CEED_EXTERN int CeedHouseholderApplyQ(CeedScalar *A, const CeedScalar *Q, const CeedScalar *tau, CeedTransposeMode t_mode, CeedInt m, CeedInt n,
CeedInt k, CeedInt row, CeedInt col);

#endif
6 changes: 2 additions & 4 deletions include/ceed/ceed.h
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,8 @@ CEED_EXTERN int CeedBasisGetCeed(CeedBasis basis, Ceed *ceed);
CEED_EXTERN int CeedBasisGetDimension(CeedBasis basis, CeedInt *dim);
CEED_EXTERN int CeedBasisGetTopology(CeedBasis basis, CeedElemTopology *topo);
CEED_EXTERN int CeedBasisGetNumComponents(CeedBasis basis, CeedInt *num_comp);
CEED_EXTERN int CeedBasisGetNumQuadratureComponents(CeedBasis basis, CeedInt *q_comp);
CEED_EXTERN int CeedBasisGetNumCurlComponents(CeedBasis basis, CeedInt *curl_comp);
CEED_EXTERN int CeedBasisGetNumNodes(CeedBasis basis, CeedInt *P);
CEED_EXTERN int CeedBasisGetNumNodes1D(CeedBasis basis, CeedInt *P_1d);
CEED_EXTERN int CeedBasisGetNumQuadraturePoints(CeedBasis basis, CeedInt *Q);
Expand All @@ -391,10 +393,6 @@ CEED_EXTERN int CeedLobattoQuadrature(CeedInt Q, CeedScalar *q_ref_1d, CeedScala
CEED_EXTERN int CeedQRFactorization(Ceed ceed, CeedScalar *mat, CeedScalar *tau, CeedInt m, CeedInt n);
CEED_EXTERN int CeedSymmetricSchurDecomposition(Ceed ceed, CeedScalar *mat, CeedScalar *lambda, CeedInt n);
CEED_EXTERN int CeedSimultaneousDiagonalization(Ceed ceed, CeedScalar *mat_A, CeedScalar *mat_B, CeedScalar *x, CeedScalar *lambda, CeedInt n);
CEED_EXTERN int CeedHouseholderApplyQ(CeedScalar *A, const CeedScalar *Q, const CeedScalar *tau, CeedTransposeMode t_mode, CeedInt m, CeedInt n,
CeedInt k, CeedInt row, CeedInt col);
CEED_EXTERN int CeedMatrixMatrixMultiply(Ceed ceed, const CeedScalar *mat_A, const CeedScalar *mat_B, CeedScalar *mat_C, CeedInt m, CeedInt n,
CeedInt kk);

/** Handle for the user provided CeedQFunction callback function

Expand Down
Loading

0 comments on commit 063bea1

Please sign in to comment.