42 lines
1.2 KiB
Python
42 lines
1.2 KiB
Python
import matplotlib.pyplot as plt
|
|
import pandas as pd
|
|
from matplotlib.figure import Figure
|
|
|
|
from femsurrogate.plotting.comparison import metrics_table, plot_metric_comparison
|
|
from femsurrogate.plotting.diagnostics import plot_parity, plot_residuals
|
|
from femsurrogate.surrogates.common import MetricsReport
|
|
|
|
|
|
def _predictions() -> pd.DataFrame:
|
|
return pd.DataFrame(
|
|
{
|
|
"y_true": [-1.0, -2.0, -3.0],
|
|
"y_pred": [-1.1, -1.9, -3.2],
|
|
"residual": [0.1, -0.1, 0.2],
|
|
}
|
|
)
|
|
|
|
|
|
def test_diagnostic_plots_return_figures():
    """Check that ``plot_parity`` and ``plot_residuals`` each return a Figure.

    Each plotting call gets its own freshly built predictions frame, so one
    call cannot observe changes the other might make to its input.
    """
    figures = []
    try:
        figures.append(plot_parity(_predictions(), title="Parity"))
        figures.append(plot_residuals(_predictions(), title="Residuals"))

        parity, residuals = figures
        assert isinstance(parity, Figure)
        assert isinstance(residuals, Figure)
    finally:
        # Close figures even when a plot call or an assertion fails, so open
        # pyplot figures do not pile up across the test session. Only real
        # Figures are closed: plt.close() raises TypeError for other objects,
        # which would mask the original assertion failure.
        for figure in figures:
            if isinstance(figure, Figure):
                plt.close(figure)
|
|
|
|
|
|
def test_metrics_table_and_comparison_plot():
    """Check ``metrics_table`` row order and that the comparison plot is a Figure.

    Two reports for the same target are tabulated, and the table is expected
    to list ``gpr`` before ``rsm``. The table is then plotted by its ``rmse``
    column.
    """
    # The first positional field becomes the table's ``model_name`` column.
    # NOTE(review): the remaining positional fields are presumably the target
    # name followed by numeric metrics -- confirm against MetricsReport.
    reports = [
        MetricsReport("rsm", "tip_uy_m", 0.2, 0.1, 0.9, 0.01, 0.001),
        MetricsReport("gpr", "tip_uy_m", 0.1, 0.05, 0.95, 0.2, 0.01),
    ]

    table = metrics_table(reports)
    figure = plot_metric_comparison(table, metric="rmse", title="RMSE")

    try:
        # NOTE(review): gpr sorts first both alphabetically and by its smaller
        # third field; confirm which ordering metrics_table guarantees.
        assert list(table["model_name"]) == ["gpr", "rsm"]
        assert isinstance(figure, Figure)
    finally:
        # Release the figure even when an assertion fails. The isinstance
        # guard keeps plt.close() from raising TypeError on a non-Figure,
        # which would hide the real assertion error.
        if isinstance(figure, Figure):
            plt.close(figure)
|