{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "741ede83",
   "metadata": {},
   "source": [
    "# 06 Compare Surrogate Models\n",
    "\n",
    "This notebook reads the metrics JSON saved by each model notebook and compares the surrogate models side by side.\n",
    "Run the individual model notebooks first: every model in `MODEL_NAMES` needs a `reports/results/<model>_metrics.json` file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10f43a71",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "from femsurrogate.plotting.comparison import plot_metric_comparison\n",
    "\n",
    "# Project root = nearest directory (the cwd or any ancestor) that contains pyproject.toml,\n",
    "# so the notebook runs whether it is launched from the repo root or from a subdirectory.\n",
    "CWD = Path.cwd().resolve()\n",
    "ROOT = next((path for path in (CWD, *CWD.parents) if (path / \"pyproject.toml\").exists()), CWD)\n",
    "assert (ROOT / \"pyproject.toml\").exists(), f\"pyproject.toml not found in {CWD} or any parent\"\n",
    "RESULTS_DIR = ROOT / \"reports\" / \"results\"\n",
    "FIGURES_DIR = ROOT / \"reports\" / \"figures\"\n",
    "FIGURES_DIR.mkdir(parents=True, exist_ok=True)\n",
    "MODEL_NAMES = [\"rsm\", \"gpr\", \"random_forest\", \"gradient_boosting\", \"mlp\"]\n",
    "# Columns of the comparison table that get one figure each in the plotting section.\n",
    "COMPARISON_METRICS = [\"rmse\", \"mae\", \"r2\", \"fit_time_s\", \"predict_time_s\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9798e3bb",
   "metadata": {},
   "source": [
    "## Load metrics\n",
    "\n",
    "Collect every model's metrics into one table, sorted by `rmse` (lowest first), and save it as `reports/results/model_comparison.csv`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55d2447d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check all inputs up front so a single failure lists every missing file, not just the first.\n",
    "metrics_paths = [RESULTS_DIR / f\"{model_name}_metrics.json\" for model_name in MODEL_NAMES]\n",
    "missing_paths = [str(path) for path in metrics_paths if not path.exists()]\n",
    "assert not missing_paths, f\"Missing metrics files (run the model notebooks first): {missing_paths}\"\n",
    "\n",
    "records = [json.loads(path.read_text(encoding=\"utf-8\")) for path in metrics_paths]\n",
    "\n",
    "# One row per metrics file, best (lowest) RMSE first.\n",
    "comparison = pd.DataFrame(records).sort_values(\"rmse\").reset_index(drop=True)\n",
    "comparison_path = RESULTS_DIR / \"model_comparison.csv\"\n",
    "comparison.to_csv(comparison_path, index=False)\n",
    "comparison"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "731bb2f7",
   "metadata": {},
   "source": [
    "## Metric comparison plots\n",
    "\n",
    "One figure per metric in `COMPARISON_METRICS` is saved to `reports/figures/comparison_<metric>.png`; the `rmse` figure is displayed at the end of the cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64657b84",
   "metadata": {},
   "outputs": [],
   "source": [
    "figures = {}\n",
    "for metric in COMPARISON_METRICS:\n",
    "    figure = plot_metric_comparison(comparison, metric=metric, title=f\"Surrogate {metric}\")\n",
    "    figure.savefig(FIGURES_DIR / f\"comparison_{metric}.png\", dpi=150, bbox_inches=\"tight\")\n",
    "    figures[metric] = figure\n",
    "\n",
    "figures[\"rmse\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63cb5b8d",
   "metadata": {},
   "source": [
    "## Interpretation notes\n",
    "\n",
    "- `rmse` and `mae`: lower is better.\n",
    "- `r2`: higher is better, but read it together with the residual plots; a good `r2` can still hide bias in parts of the input range.\n",
    "- `fit_time_s` and `predict_time_s` matter when the surrogate is called repeatedly, e.g. inside an optimization loop.\n",
    "- GPR tends to do well on small datasets, but its fit time can grow quickly as the number of samples increases."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}