"""Smoke tests for the femsurrogate plotting helpers.

These tests check that the diagnostic and comparison plotting functions
return matplotlib ``Figure`` objects, and that ``metrics_table`` orders its
rows as expected.
"""

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.figure import Figure

from femsurrogate.plotting.comparison import metrics_table, plot_metric_comparison
from femsurrogate.plotting.diagnostics import plot_parity, plot_residuals
from femsurrogate.surrogates.common import MetricsReport


def _predictions() -> pd.DataFrame:
    """Build a tiny predictions frame for the diagnostic plots.

    The frame has three rows with columns ``y_true``, ``y_pred`` and
    ``residual``. In every row ``residual == y_true - y_pred``, so the
    fixture stays internally consistent.
    """
    return pd.DataFrame(
        {
            "y_true": [-1.0, -2.0, -3.0],
            "y_pred": [-1.1, -1.9, -3.2],
            "residual": [0.1, -0.1, 0.2],
        }
    )


def test_diagnostic_plots_return_figures():
    """``plot_parity`` and ``plot_residuals`` each return a ``Figure``."""
    try:
        # Build a fresh frame per call so one plot cannot affect the other's input.
        parity = plot_parity(_predictions(), title="Parity")
        residuals = plot_residuals(_predictions(), title="Residuals")

        assert isinstance(parity, Figure)
        assert isinstance(residuals, Figure)
    finally:
        # Always release pyplot figures, even if an assertion fails, so
        # open figures do not leak into (and accumulate across) later tests.
        plt.close("all")


def test_metrics_table_and_comparison_plot():
    """``metrics_table`` orders rows as expected and the comparison plot renders."""
    # NOTE(review): positional fields of MetricsReport are defined elsewhere;
    # the first two appear to be model name and target, and the third is
    # presumably the RMSE used for ordering below -- confirm against
    # MetricsReport's definition.
    reports = [
        MetricsReport("rsm", "tip_uy_m", 0.2, 0.1, 0.9, 0.01, 0.001),
        MetricsReport("gpr", "tip_uy_m", 0.1, 0.05, 0.95, 0.2, 0.01),
    ]
    try:
        table = metrics_table(reports)
        figure = plot_metric_comparison(table, metric="rmse", title="RMSE")

        # "gpr" is listed second in the input but is expected first in the
        # table; presumably metrics_table sorts by a metric (e.g. ascending
        # RMSE) -- verify against metrics_table's implementation.
        assert list(table["model_name"]) == ["gpr", "rsm"]
        assert isinstance(figure, Figure)
    finally:
        # Release the figure even when an assertion above fails.
        plt.close("all")