Skip to content

Commit

Permalink
Merge pull request #310 from awslabs/simlapointe/complex-pc
Browse files Browse the repository at this point in the history
Complex-valued coarse preconditioner
  • Loading branch information
simlapointe authored Dec 30, 2024
2 parents 7383601 + 0dbaaec commit ceb2154
Show file tree
Hide file tree
Showing 18 changed files with 142 additions and 42 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ The format of this changelog is based on
- Added adaptive time-stepping capability for transient simulations. The new ODE integrators
rely on the SUNDIALS library and can be specified by setting the
`config["Solver"]["Transient"]["Type"]` option to `"CVODE"` or `"ARKODE"`.
- Added an option to use the complex-valued system matrix for the coarse level solve (sparse
direct solve) instead of the real-valued approximation. This can be specified with
`config["Solver"]["Linear"]["ComplexCoarseSolve"]`.

## [0.13.0] - 2024-05-20

Expand Down
4 changes: 4 additions & 0 deletions docs/src/config/solver.md
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ under the directory specified by
"MGSmoothOrder": <int>,
"PCMatReal": <bool>,
"PCMatShifted": <bool>,
"ComplexCoarseSolve": <bool>,
"PCSide": <string>,
"DivFreeTol": <float>,
"DivFreeMaxIts": <float>,
Expand Down Expand Up @@ -420,6 +421,9 @@ domain problems using a positive definite approximation of the system matrix by
the sign for the mass matrix contribution, which can help performance at high frequencies
(relative to the lowest nonzero eigenfrequencies of the model).

`"ComplexCoarseSolve" [false]` : When set to `true`, the coarse-level solver uses the true
complex-valued system matrix. When set to `false`, the real-valued approximation is used.

`"PCSide" ["Default"]` : Side for preconditioning. Not all options are available for all
iterative solver choices, and the default choice depends on the iterative solver used.

Expand Down
2 changes: 1 addition & 1 deletion palace/drivers/basesolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ void BaseSolver::SolveEstimateMarkRefine(std::vector<std::unique_ptr<Mesh>> &mes

// Optionally rebalance and write the adapted mesh to file.
{
const auto ratio_pre = mesh::RebalanceMesh(*mesh.back(), iodata);
const auto ratio_pre = mesh::RebalanceMesh(iodata, *mesh.back());
if (ratio_pre > refinement.maximum_imbalance)
{
int min_elem, max_elem;
Expand Down
2 changes: 1 addition & 1 deletion palace/linalg/gmg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class GeometricMultigridSolver : public Solver<OperType>
const std::vector<const Operator *> *G, int cycle_it,
int smooth_it, int cheby_order, double cheby_sf_max,
double cheby_sf_min, bool cheby_4th_kind);
GeometricMultigridSolver(MPI_Comm comm, const IoData &iodata,
GeometricMultigridSolver(const IoData &iodata, MPI_Comm comm,
std::unique_ptr<Solver<OperType>> &&coarse_solver,
const std::vector<const Operator *> &P,
const std::vector<const Operator *> *G = nullptr)
Expand Down
29 changes: 15 additions & 14 deletions palace/linalg/ksp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ namespace
{

template <typename OperType>
std::unique_ptr<IterativeSolver<OperType>> ConfigureKrylovSolver(MPI_Comm comm,
const IoData &iodata)
std::unique_ptr<IterativeSolver<OperType>> ConfigureKrylovSolver(const IoData &iodata,
MPI_Comm comm)
{
// Create the solver.
std::unique_ptr<IterativeSolver<OperType>> ksp;
Expand Down Expand Up @@ -114,7 +114,7 @@ std::unique_ptr<IterativeSolver<OperType>> ConfigureKrylovSolver(MPI_Comm comm,
}

template <typename OperType, typename T, typename... U>
auto MakeWrapperSolver(U &&...args)
auto MakeWrapperSolver(const IoData &iodata, U &&...args)
{
// Sparse direct solver types copy the input matrix, so there is no need to save the
// parallel assembled operator.
Expand All @@ -131,12 +131,13 @@ auto MakeWrapperSolver(U &&...args)
#endif
false);
return std::make_unique<MfemWrapperSolver<OperType>>(
std::make_unique<T>(std::forward<U>(args)...), save_assembled);
std::make_unique<T>(iodata, std::forward<U>(args)...), save_assembled,
iodata.solver.linear.complex_coarse_solve);
}

template <typename OperType>
std::unique_ptr<Solver<OperType>>
ConfigurePreconditionerSolver(MPI_Comm comm, const IoData &iodata,
ConfigurePreconditionerSolver(const IoData &iodata, MPI_Comm comm,
FiniteElementSpaceHierarchy &fespaces,
FiniteElementSpaceHierarchy *aux_fespaces)
{
Expand All @@ -161,31 +162,31 @@ ConfigurePreconditionerSolver(MPI_Comm comm, const IoData &iodata,
break;
case config::LinearSolverData::Type::SUPERLU:
#if defined(MFEM_USE_SUPERLU)
pc = MakeWrapperSolver<OperType, SuperLUSolver>(comm, iodata, print);
pc = MakeWrapperSolver<OperType, SuperLUSolver>(iodata, comm, print);
#else
MFEM_ABORT("Solver was not built with SuperLU_DIST support, please choose a "
"different solver!");
#endif
break;
case config::LinearSolverData::Type::STRUMPACK:
#if defined(MFEM_USE_STRUMPACK)
pc = MakeWrapperSolver<OperType, StrumpackSolver>(comm, iodata, print);
pc = MakeWrapperSolver<OperType, StrumpackSolver>(iodata, comm, print);
#else
MFEM_ABORT("Solver was not built with STRUMPACK support, please choose a "
"different solver!");
#endif
break;
case config::LinearSolverData::Type::STRUMPACK_MP:
#if defined(MFEM_USE_STRUMPACK)
pc = MakeWrapperSolver<OperType, StrumpackMixedPrecisionSolver>(comm, iodata, print);
pc = MakeWrapperSolver<OperType, StrumpackMixedPrecisionSolver>(iodata, comm, print);
#else
MFEM_ABORT("Solver was not built with STRUMPACK support, please choose a "
"different solver!");
#endif
break;
case config::LinearSolverData::Type::MUMPS:
#if defined(MFEM_USE_MUMPS)
pc = MakeWrapperSolver<OperType, MumpsSolver>(comm, iodata, print);
pc = MakeWrapperSolver<OperType, MumpsSolver>(iodata, comm, print);
#else
MFEM_ABORT(
"Solver was not built with MUMPS support, please choose a different solver!");
Expand Down Expand Up @@ -214,12 +215,12 @@ ConfigurePreconditionerSolver(MPI_Comm comm, const IoData &iodata,
"primary space and auxiliary spaces for construction!");
const auto G = fespaces.GetDiscreteInterpolators(*aux_fespaces);
return std::make_unique<GeometricMultigridSolver<OperType>>(
comm, iodata, std::move(pc), fespaces.GetProlongationOperators(), &G);
iodata, comm, std::move(pc), fespaces.GetProlongationOperators(), &G);
}
else
{
return std::make_unique<GeometricMultigridSolver<OperType>>(
comm, iodata, std::move(pc), fespaces.GetProlongationOperators());
iodata, comm, std::move(pc), fespaces.GetProlongationOperators());
}
}();
gmg->EnableTimer(); // Enable timing for primary geometric multigrid solver
Expand All @@ -238,9 +239,9 @@ BaseKspSolver<OperType>::BaseKspSolver(const IoData &iodata,
FiniteElementSpaceHierarchy &fespaces,
FiniteElementSpaceHierarchy *aux_fespaces)
: BaseKspSolver(
ConfigureKrylovSolver<OperType>(fespaces.GetFinestFESpace().GetComm(), iodata),
ConfigurePreconditionerSolver<OperType>(fespaces.GetFinestFESpace().GetComm(),
iodata, fespaces, aux_fespaces))
ConfigureKrylovSolver<OperType>(iodata, fespaces.GetFinestFESpace().GetComm()),
ConfigurePreconditionerSolver<OperType>(
iodata, fespaces.GetFinestFESpace().GetComm(), fespaces, aux_fespaces))
{
use_timer = true;
}
Expand Down
2 changes: 1 addition & 1 deletion palace/linalg/mumps.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class MumpsSolver : public mfem::MUMPSSolver
public:
MumpsSolver(MPI_Comm comm, mfem::MUMPSSolver::MatType sym,
config::LinearSolverData::SymFactType reorder, double blr_tol, int print);
MumpsSolver(MPI_Comm comm, const IoData &iodata, int print)
MumpsSolver(const IoData &iodata, MPI_Comm comm, int print)
: MumpsSolver(comm,
(iodata.solver.linear.pc_mat_shifted ||
iodata.problem.type == config::ProblemData::Type::TRANSIENT ||
Expand Down
56 changes: 48 additions & 8 deletions palace/linalg/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,27 @@ void MfemWrapperSolver<ComplexOperator>::SetOperator(const ComplexOperator &op)
}
if (hAr && hAi)
{
A.reset(mfem::Add(1.0, *hAr, 1.0, *hAi));
if (complex_matrix)
{
// A = [Ar, -Ai]
// [Ai, Ar]
mfem::Array2D<const mfem::HypreParMatrix *> blocks(2, 2);
mfem::Array2D<double> block_coeffs(2, 2);
blocks(0, 0) = hAr;
blocks(0, 1) = hAi;
blocks(1, 0) = hAi;
blocks(1, 1) = hAr;
block_coeffs(0, 0) = 1.0;
block_coeffs(0, 1) = -1.0;
block_coeffs(1, 0) = 1.0;
block_coeffs(1, 1) = 1.0;
A.reset(mfem::HypreParMatrixFromBlocks(blocks, &block_coeffs));
}
else
{
// A = Ar + Ai.
A.reset(mfem::Add(1.0, *hAr, 1.0, *hAi));
}
if (PtAPr)
{
PtAPr->StealParallelAssemble();
Expand Down Expand Up @@ -101,13 +121,33 @@ template <>
void MfemWrapperSolver<ComplexOperator>::Mult(const ComplexVector &x,
ComplexVector &y) const
{
mfem::Array<const Vector *> X(2);
mfem::Array<Vector *> Y(2);
X[0] = &x.Real();
X[1] = &x.Imag();
Y[0] = &y.Real();
Y[1] = &y.Imag();
pc->ArrayMult(X, Y);
if (pc->Height() == x.Size())
{
mfem::Array<const Vector *> X(2);
mfem::Array<Vector *> Y(2);
X[0] = &x.Real();
X[1] = &x.Imag();
Y[0] = &y.Real();
Y[1] = &y.Imag();
pc->ArrayMult(X, Y);
}
else
{
const int Nx = x.Size(), Ny = y.Size();
Vector X(2 * Nx), Y(2 * Ny), yr, yi;
X.UseDevice(true);
Y.UseDevice(true);
yr.UseDevice(true);
yi.UseDevice(true);
linalg::SetSubVector(X, 0, x.Real());
linalg::SetSubVector(X, Nx, x.Imag());
pc->Mult(X, Y);
Y.ReadWrite();
yr.MakeRef(Y, 0, Ny);
yi.MakeRef(Y, Ny, Ny);
y.Real() = yr;
y.Imag() = yi;
}
}

} // namespace palace
11 changes: 8 additions & 3 deletions palace/linalg/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,22 @@ class MfemWrapperSolver : public Solver<OperType>
// The actual mfem::Solver.
std::unique_ptr<mfem::Solver> pc;

// Real-valued system matrix A = Ar + Ai in parallel assembled form.
// System matrix A in parallel assembled form.
std::unique_ptr<mfem::HypreParMatrix> A;

// Whether or not to save the parallel assembled matrix after calling
// mfem::Solver::SetOperator (some solvers copy their input).
bool save_assembled;

// Whether to use the exact complex-valued system matrix or the real-valued
// approximation A = Ar + Ai.
bool complex_matrix = true;

public:
MfemWrapperSolver(std::unique_ptr<mfem::Solver> &&pc, bool save_assembled = true)
MfemWrapperSolver(std::unique_ptr<mfem::Solver> &&pc, bool save_assembled = true,
bool complex_matrix = true)
: Solver<OperType>(pc->iterative_mode), pc(std::move(pc)),
save_assembled(save_assembled)
save_assembled(save_assembled), complex_matrix(complex_matrix)
{
}

Expand Down
2 changes: 1 addition & 1 deletion palace/linalg/strumpack.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class StrumpackSolverBase : public StrumpackSolverType
config::LinearSolverData::CompressionType compression, double lr_tol,
int butterfly_l, int lossy_prec, int print);

StrumpackSolverBase(MPI_Comm comm, const IoData &iodata, int print)
StrumpackSolverBase(const IoData &iodata, MPI_Comm comm, int print)
: StrumpackSolverBase(comm, iodata.solver.linear.sym_fact_type,
iodata.solver.linear.strumpack_compression_type,
iodata.solver.linear.strumpack_lr_tol,
Expand Down
2 changes: 1 addition & 1 deletion palace/linalg/superlu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class SuperLUSolver : public mfem::Solver
public:
SuperLUSolver(MPI_Comm comm, config::LinearSolverData::SymFactType reorder, bool use_3d,
int print);
SuperLUSolver(MPI_Comm comm, const IoData &iodata, int print)
SuperLUSolver(const IoData &iodata, MPI_Comm comm, int print)
: SuperLUSolver(comm, iodata.solver.linear.sym_fact_type,
iodata.solver.linear.superlu_3d, print)
{
Expand Down
37 changes: 36 additions & 1 deletion palace/linalg/vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,12 +520,47 @@ void SetSubVector(ComplexVector &x, const mfem::Array<int> &rows, const ComplexV
mfem::forall_switch(use_dev, N,
[=] MFEM_HOST_DEVICE(int i)
{
const auto id = idx[i];
const int id = idx[i];
XR[id] = YR[id];
XI[id] = YI[id];
});
}

template <>
void SetSubVector(Vector &x, int start, const Vector &y)
{
const bool use_dev = x.UseDevice();
const int N = y.Size();
MFEM_ASSERT(start >= 0 && start + N <= x.Size(), "Invalid range for SetSubVector!");
const auto *Y = y.Read(use_dev);
auto *X = x.ReadWrite(use_dev);
mfem::forall_switch(use_dev, N,
[=] MFEM_HOST_DEVICE(int i)
{
const int id = start + i;
X[id] = Y[i];
});
}

template <>
void SetSubVector(ComplexVector &x, int start, const ComplexVector &y)
{
const bool use_dev = x.UseDevice();
const int N = y.Size();
MFEM_ASSERT(start >= 0 && start + N <= x.Size(), "Invalid range for SetSubVector!");
const auto *YR = y.Real().Read(use_dev);
const auto *YI = y.Imag().Read(use_dev);
auto *XR = x.Real().ReadWrite(use_dev);
auto *XI = x.Imag().ReadWrite(use_dev);
mfem::forall_switch(use_dev, N,
[=] MFEM_HOST_DEVICE(int i)
{
const int id = start + i;
XR[id] = YR[i];
XI[id] = YI[i];
});
}

template <>
void SetSubVector(Vector &x, int start, int end, double s)
{
Expand Down
4 changes: 4 additions & 0 deletions palace/linalg/vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ void SetSubVector(VecType &x, const mfem::Array<int> &rows, double s);
template <typename VecType>
void SetSubVector(VecType &x, const mfem::Array<int> &rows, const VecType &y);

// Sets contiguous entries from start to the given vector.
template <typename VecType>
void SetSubVector(VecType &x, int start, const VecType &y);

// Sets all entries in the range [start, end) to the given value.
template <typename VecType>
void SetSubVector(VecType &x, int start, int end, double s);
Expand Down
2 changes: 1 addition & 1 deletion palace/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ int main(int argc, char *argv[])
std::vector<std::unique_ptr<Mesh>> mesh;
{
std::vector<std::unique_ptr<mfem::ParMesh>> mfem_mesh;
mfem_mesh.push_back(mesh::ReadMesh(world_comm, iodata));
mfem_mesh.push_back(mesh::ReadMesh(iodata, world_comm));
iodata.NondimensionalizeInputs(*mfem_mesh[0]);
mesh::RefineMesh(iodata, mfem_mesh);
for (auto &m : mfem_mesh)
Expand Down
3 changes: 3 additions & 0 deletions palace/utils/configfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1779,6 +1779,7 @@ void LinearSolverData::SetUp(json &solver)
// Preconditioner-specific options.
pc_mat_real = linear->value("PCMatReal", pc_mat_real);
pc_mat_shifted = linear->value("PCMatShifted", pc_mat_shifted);
complex_coarse_solve = linear->value("ComplexCoarseSolve", complex_coarse_solve);
pc_side_type = linear->value("PCSide", pc_side_type);
sym_fact_type = linear->value("ColumnOrdering", sym_fact_type);
strumpack_compression_type =
Expand Down Expand Up @@ -1821,6 +1822,7 @@ void LinearSolverData::SetUp(json &solver)

linear->erase("PCMatReal");
linear->erase("PCMatShifted");
linear->erase("ComplexCoarseSolve");
linear->erase("PCSide");
linear->erase("ColumnOrdering");
linear->erase("STRUMPACKCompressionType");
Expand Down Expand Up @@ -1865,6 +1867,7 @@ void LinearSolverData::SetUp(json &solver)

std::cout << "PCMatReal: " << pc_mat_real << '\n';
std::cout << "PCMatShifted: " << pc_mat_shifted << '\n';
std::cout << "ComplexCoarseSolve: " << complex_coarse_solve << '\n';
std::cout << "PCSide: " << pc_side_type << '\n';
std::cout << "ColumnOrdering: " << sym_fact_type << '\n';
std::cout << "STRUMPACKCompressionType: " << strumpack_compression_type << '\n';
Expand Down
4 changes: 4 additions & 0 deletions palace/utils/configfile.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,10 @@ struct LinearSolverData
// (makes the preconditoner matrix SPD).
int pc_mat_shifted = -1;

// For frequency domain applications, use the complex-valued system matrix in the sparse
// direct solver.
bool complex_coarse_solve = false;

// Choose left or right preconditioning.
enum class SideType
{
Expand Down
Loading

0 comments on commit ceb2154

Please sign in to comment.