refactor: extract math solver boundary
This commit is contained in:
@@ -0,0 +1,81 @@
|
||||
#pragma once
|
||||
|
||||
#include "fesa/Math/DenseMatrix.hpp"
|
||||
#include "fesa/Util/Diagnostics.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace fesa {
|
||||
|
||||
struct SolveResult {
|
||||
std::vector<Real> x;
|
||||
std::vector<Diagnostic> diagnostics;
|
||||
|
||||
bool ok() const {
|
||||
return !hasError(diagnostics);
|
||||
}
|
||||
};
|
||||
|
||||
class LinearSolver {
|
||||
public:
|
||||
virtual ~LinearSolver() = default;
|
||||
virtual SolveResult solve(DenseMatrix a, std::vector<Real> b) const = 0;
|
||||
};
|
||||
|
||||
class GaussianEliminationSolver final : public LinearSolver {
|
||||
public:
|
||||
SolveResult solve(DenseMatrix a, std::vector<Real> b) const override {
|
||||
const LocalIndex n = a.rows();
|
||||
SolveResult result;
|
||||
if (a.rows() != a.cols() || static_cast<LocalIndex>(b.size()) != n) {
|
||||
result.diagnostics.push_back(makeDiagnostic(Severity::Error, "FESA-SOLVER-SIZE", "Linear system size mismatch", "solver"));
|
||||
return result;
|
||||
}
|
||||
for (LocalIndex col = 0; col < n; ++col) {
|
||||
LocalIndex pivot = col;
|
||||
Real pivot_abs = std::fabs(a(col, col));
|
||||
for (LocalIndex row = col + 1; row < n; ++row) {
|
||||
const Real candidate = std::fabs(a(row, col));
|
||||
if (candidate > pivot_abs) {
|
||||
pivot_abs = candidate;
|
||||
pivot = row;
|
||||
}
|
||||
}
|
||||
if (pivot_abs < 1.0e-12) {
|
||||
result.diagnostics.push_back(makeDiagnostic(Severity::Error, "FESA-SINGULAR-SOLVER",
|
||||
"Reduced system is singular or ill-conditioned", "solver"));
|
||||
return result;
|
||||
}
|
||||
if (pivot != col) {
|
||||
for (LocalIndex j = col; j < n; ++j) {
|
||||
std::swap(a(col, j), a(pivot, j));
|
||||
}
|
||||
std::swap(b[static_cast<std::size_t>(col)], b[static_cast<std::size_t>(pivot)]);
|
||||
}
|
||||
const Real diag = a(col, col);
|
||||
for (LocalIndex row = col + 1; row < n; ++row) {
|
||||
const Real factor = a(row, col) / diag;
|
||||
a(row, col) = 0.0;
|
||||
for (LocalIndex j = col + 1; j < n; ++j) {
|
||||
a(row, j) -= factor * a(col, j);
|
||||
}
|
||||
b[static_cast<std::size_t>(row)] -= factor * b[static_cast<std::size_t>(col)];
|
||||
}
|
||||
}
|
||||
result.x.assign(static_cast<std::size_t>(n), 0.0);
|
||||
for (LocalIndex i = n; i-- > 0;) {
|
||||
Real sum = b[static_cast<std::size_t>(i)];
|
||||
for (LocalIndex j = i + 1; j < n; ++j) {
|
||||
sum -= a(i, j) * result.x[static_cast<std::size_t>(j)];
|
||||
}
|
||||
result.x[static_cast<std::size_t>(i)] = sum / a(i, i);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace fesa
|
||||
Reference in New Issue
Block a user