initial commit FESurrogateModelTutorial
This commit is contained in:
@@ -0,0 +1,115 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "741ede83",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 06 Compare Surrogate Models\n",
|
||||
"\n",
|
||||
"? notebook? ?? model notebook?? ??? metrics JSON? ?? ??? ????. ??? ?? ???? ???."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "10f43a71",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"from femsurrogate.plotting.comparison import plot_metric_comparison\n",
|
||||
"\n",
|
||||
"ROOT = Path.cwd().resolve()\n",
|
||||
"if not (ROOT / \"pyproject.toml\").exists():\n",
|
||||
" ROOT = ROOT.parent\n",
|
||||
"assert (ROOT / \"pyproject.toml\").exists(), ROOT\n",
|
||||
"RESULTS_DIR = ROOT / \"reports\" / \"results\"\n",
|
||||
"FIGURES_DIR = ROOT / \"reports\" / \"figures\"\n",
|
||||
"FIGURES_DIR.mkdir(parents=True, exist_ok=True)\n",
|
||||
"MODEL_NAMES = [\"rsm\", \"gpr\", \"random_forest\", \"gradient_boosting\", \"mlp\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9798e3bb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Metrics ??"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "55d2447d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"records = []\n",
|
||||
"for model_name in MODEL_NAMES:\n",
|
||||
" metrics_path = RESULTS_DIR / f\"{model_name}_metrics.json\"\n",
|
||||
" assert metrics_path.exists(), metrics_path\n",
|
||||
" records.append(json.loads(metrics_path.read_text(encoding=\"utf-8\")))\n",
|
||||
"\n",
|
||||
"comparison = pd.DataFrame(records).sort_values(\"rmse\").reset_index(drop=True)\n",
|
||||
"comparison_path = RESULTS_DIR / \"model_comparison.csv\"\n",
|
||||
"comparison.to_csv(comparison_path, index=False)\n",
|
||||
"comparison"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "731bb2f7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Metric ?? plot"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "64657b84",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"figures = {}\n",
|
||||
"for metric in [\"rmse\", \"mae\", \"r2\", \"fit_time_s\", \"predict_time_s\"]:\n",
|
||||
" figure = plot_metric_comparison(comparison, metric=metric, title=f\"Surrogate {metric}\")\n",
|
||||
" figure.savefig(FIGURES_DIR / f\"comparison_{metric}.png\", dpi=150, bbox_inches=\"tight\")\n",
|
||||
" figures[metric] = figure\n",
|
||||
"\n",
|
||||
"figures[\"rmse\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "63cb5b8d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## ?? ???\n",
|
||||
"\n",
|
||||
"- `rmse`? `mae`? ?? ??? ?? ????.\n",
|
||||
"- `r2`? ??? residual plot?? ?? ?? ??? bias? ??? ????? ????.\n",
|
||||
"- `fit_time_s`, `predict_time_s`? ?? ??? ??? loop?? ????.\n",
|
||||
"- GPR? ?? dataset?? ???? sample ?? ??? ?? ??? ??? ?? ? ??."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"pygments_lexer": "ipython3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user