initial commit FESurrogateModelTutorial
This commit is contained in:
@@ -0,0 +1,41 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from femsurrogate.plotting.comparison import metrics_table, plot_metric_comparison
|
||||
from femsurrogate.plotting.diagnostics import plot_parity, plot_residuals
|
||||
from femsurrogate.surrogates.common import MetricsReport
|
||||
|
||||
|
||||
def _predictions() -> pd.DataFrame:
|
||||
return pd.DataFrame(
|
||||
{
|
||||
"y_true": [-1.0, -2.0, -3.0],
|
||||
"y_pred": [-1.1, -1.9, -3.2],
|
||||
"residual": [0.1, -0.1, 0.2],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_diagnostic_plots_return_figures():
|
||||
parity = plot_parity(_predictions(), title="Parity")
|
||||
residuals = plot_residuals(_predictions(), title="Residuals")
|
||||
|
||||
assert isinstance(parity, Figure)
|
||||
assert isinstance(residuals, Figure)
|
||||
plt.close(parity)
|
||||
plt.close(residuals)
|
||||
|
||||
|
||||
def test_metrics_table_and_comparison_plot():
|
||||
reports = [
|
||||
MetricsReport("rsm", "tip_uy_m", 0.2, 0.1, 0.9, 0.01, 0.001),
|
||||
MetricsReport("gpr", "tip_uy_m", 0.1, 0.05, 0.95, 0.2, 0.01),
|
||||
]
|
||||
|
||||
table = metrics_table(reports)
|
||||
figure = plot_metric_comparison(table, metric="rmse", title="RMSE")
|
||||
|
||||
assert list(table["model_name"]) == ["gpr", "rsm"]
|
||||
assert isinstance(figure, Figure)
|
||||
plt.close(figure)
|
||||
Reference in New Issue
Block a user