Files
2026-05-21 17:03:51 +09:00

116 lines
3.0 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "741ede83",
"metadata": {},
"source": [
"# 06 Compare Surrogate Models\n",
"\n",
"? notebook? ?? model notebook?? ??? metrics JSON? ?? ??? ????. ??? ?? ???? ???."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10f43a71",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from pathlib import Path\n",
"\n",
"import pandas as pd\n",
"\n",
"from femsurrogate.plotting.comparison import plot_metric_comparison\n",
"\n",
"ROOT = Path.cwd().resolve()\n",
"if not (ROOT / \"pyproject.toml\").exists():\n",
" ROOT = ROOT.parent\n",
"assert (ROOT / \"pyproject.toml\").exists(), ROOT\n",
"RESULTS_DIR = ROOT / \"reports\" / \"results\"\n",
"FIGURES_DIR = ROOT / \"reports\" / \"figures\"\n",
"FIGURES_DIR.mkdir(parents=True, exist_ok=True)\n",
"MODEL_NAMES = [\"rsm\", \"gpr\", \"random_forest\", \"gradient_boosting\", \"mlp\"]"
]
},
{
"cell_type": "markdown",
"id": "9798e3bb",
"metadata": {},
"source": [
"## Metrics ??"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "55d2447d",
"metadata": {},
"outputs": [],
"source": [
"records = []\n",
"for model_name in MODEL_NAMES:\n",
" metrics_path = RESULTS_DIR / f\"{model_name}_metrics.json\"\n",
" assert metrics_path.exists(), metrics_path\n",
" records.append(json.loads(metrics_path.read_text(encoding=\"utf-8\")))\n",
"\n",
"comparison = pd.DataFrame(records).sort_values(\"rmse\").reset_index(drop=True)\n",
"comparison_path = RESULTS_DIR / \"model_comparison.csv\"\n",
"comparison.to_csv(comparison_path, index=False)\n",
"comparison"
]
},
{
"cell_type": "markdown",
"id": "731bb2f7",
"metadata": {},
"source": [
"## Metric ?? plot"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "64657b84",
"metadata": {},
"outputs": [],
"source": [
"figures = {}\n",
"for metric in [\"rmse\", \"mae\", \"r2\", \"fit_time_s\", \"predict_time_s\"]:\n",
" figure = plot_metric_comparison(comparison, metric=metric, title=f\"Surrogate {metric}\")\n",
" figure.savefig(FIGURES_DIR / f\"comparison_{metric}.png\", dpi=150, bbox_inches=\"tight\")\n",
" figures[metric] = figure\n",
"\n",
"figures[\"rmse\"]"
]
},
{
"cell_type": "markdown",
"id": "63cb5b8d",
"metadata": {},
"source": [
"## ?? ???\n",
"\n",
"- `rmse`? `mae`? ?? ??? ?? ????.\n",
"- `r2`? ??? residual plot?? ?? ?? ??? bias? ??? ????? ????.\n",
"- `fit_time_s`, `predict_time_s`? ?? ??? ??? loop?? ????.\n",
"- GPR? ?? dataset?? ???? sample ?? ??? ?? ??? ??? ?? ? ??."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}