Files
ResearchProject/FESurrogateModelTutorial/notebooks/02_gaussian_process_kriging_surrogate.ipynb
2026-05-21 17:03:51 +09:00

169 lines
4.5 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "144f56e0",
"metadata": {},
"source": [
"# Gaussian Process / Kriging Surrogate\n",
"\n",
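"Gaussian process regression (GPR, also known as Kriging) is well suited to smooth responses; this notebook fits it as a surrogate for the target `tip_uy_m`.\n",
"\n",
"It uses the same dataset, target, and train/test split seed as the other model notebooks, so the metrics are directly comparable."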
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a5a2974",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import warnings\n",
"from pathlib import Path\n",
"\n",
"import pandas as pd\n",
"from sklearn.exceptions import ConvergenceWarning\n",
"\n",
"from femsurrogate.data.schema import DEFAULT_RANDOM_SEED, PARAMETER_COLUMNS\n",
"from femsurrogate.plotting.diagnostics import plot_parity, plot_residuals\n",
"from femsurrogate.surrogates.common import evaluate_model, metrics_to_dict, split_dataset\n",
"from femsurrogate.surrogates.registry import make_model\n",
"\n",
"def find_project_root(start: Path) -> Path:\n",
"    \"\"\"Return the nearest directory at or above `start` that contains pyproject.toml.\n",
"\n",
"    Checking every ancestor (not only the immediate parent) lets the notebook run\n",
"    from the project root, from notebooks/, or from any deeper sub-directory.\n",
"\n",
"    Raises:\n",
"        FileNotFoundError: if no pyproject.toml exists in `start` or any parent.\n",
"    \"\"\"\n",
"    for candidate in (start, *start.parents):\n",
"        if (candidate / \"pyproject.toml\").exists():\n",
"            return candidate\n",
"    raise FileNotFoundError(f\"pyproject.toml not found in {start} or any parent directory\")\n",
"\n",
"\n",
"ROOT = find_project_root(Path.cwd().resolve())\n",
"DATASET_PATH = ROOT / \"data\" / \"reference\" / \"beam2d_lhs_300.csv\"\n",
"RESULTS_DIR = ROOT / \"reports\" / \"results\"\n",
"PREDICTIONS_DIR = ROOT / \"reports\" / \"predictions\"\n",
"FIGURES_DIR = ROOT / \"reports\" / \"figures\"\n",
"# Create the output directories up front so later cells can write without extra checks.\n",
"for directory in [RESULTS_DIR, PREDICTIONS_DIR, FIGURES_DIR]:\n",
"    directory.mkdir(parents=True, exist_ok=True)"
]
},
{
"cell_type": "markdown",
"id": "a2c2b3fe",
"metadata": {},
"source": [
"## Dataset and split"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ef13216f",
"metadata": {},
"outputs": [],
"source": [
"dataset = pd.read_csv(DATASET_PATH)\n",
"target_column = \"tip_uy_m\"\n",
"feature_columns = list(PARAMETER_COLUMNS)\n",
"TEST_SIZE = 0.2  # keep identical across model notebooks so their metrics are comparable\n",
"\n",
"# Fail early with a readable message rather than relying on whatever error split_dataset raises.\n",
"missing_columns = [c for c in [*feature_columns, target_column] if c not in dataset.columns]\n",
"if missing_columns:\n",
"    raise KeyError(f\"{DATASET_PATH.name} is missing expected columns: {missing_columns}\")\n",
"\n",
"split = split_dataset(\n",
"    dataset,\n",
"    feature_columns=feature_columns,\n",
"    target_column=target_column,\n",
"    test_size=TEST_SIZE,\n",
"    random_state=DEFAULT_RANDOM_SEED,\n",
")\n",
"\n",
"len(split.X_train), len(split.X_test)"
]
},
{
"cell_type": "markdown",
"id": "6c3be3f4",
"metadata": {},
"source": [
"## Model training and evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e09c8a",
"metadata": {},
"outputs": [],
"source": [
"MODEL_NAME = \"gpr\"\n",
"# Extra keyword arguments forwarded to make_model; an empty dict passes no extra options.\n",
"MODEL_PARAMS: dict = {}\n",
"model = make_model(MODEL_NAME, random_state=DEFAULT_RANDOM_SEED, **MODEL_PARAMS)\n",
"\n",
"# GPR hyperparameter optimisation can emit many ConvergenceWarnings. Record them instead of\n",
"# discarding them: the output stays short, but a poorly optimised fit is still flagged.\n",
"with warnings.catch_warnings(record=True) as caught_warnings:\n",
"    warnings.simplefilter(\"always\", ConvergenceWarning)\n",
"    result = evaluate_model(\n",
"        model,\n",
"        split.X_train,\n",
"        split.X_test,\n",
"        split.y_train,\n",
"        split.y_test,\n",
"        model_name=MODEL_NAME,\n",
"        target_column=target_column,\n",
"    )\n",
"\n",
"convergence_warnings = [w for w in caught_warnings if issubclass(w.category, ConvergenceWarning)]\n",
"# record=True captures every warning category; re-emit the unrelated ones so they are not lost.\n",
"for w in caught_warnings:\n",
"    if not issubclass(w.category, ConvergenceWarning):\n",
"        warnings.showwarning(w.message, w.category, w.filename, w.lineno)\n",
"if convergence_warnings:\n",
"    print(f\"{len(convergence_warnings)} ConvergenceWarning(s) while fitting; first: {convergence_warnings[0].message}\")\n",
"\n",
"metrics = metrics_to_dict(result.metrics)\n",
"metrics"
]
},
{
"cell_type": "markdown",
"id": "4b9e74af",
"metadata": {},
"source": [
"## Save results"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5af40658",
"metadata": {},
"outputs": [],
"source": [
"metrics_path = RESULTS_DIR / f\"{MODEL_NAME}_metrics.json\"\n",
"predictions_path = PREDICTIONS_DIR / f\"{MODEL_NAME}_predictions.csv\"\n",
"\n",
"# Serialise before opening the file so a non-JSON-serialisable value cannot truncate an existing file.\n",
"metrics_payload = json.dumps(metrics, indent=2)\n",
"with metrics_path.open(\"w\", encoding=\"utf-8\") as metrics_file:\n",
"    metrics_file.write(metrics_payload)\n",
"result.predictions.to_csv(predictions_path, index=False)\n",
"\n",
"dict(metrics_path=str(metrics_path), predictions_path=str(predictions_path))"
]
},
{
"cell_type": "markdown",
"id": "37a881b4",
"metadata": {},
"source": [
"## Diagnostic plots"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d156a54d",
"metadata": {},
"outputs": [],
"source": [
"# Both diagnostic figures are written with the same resolution and tight bounding box.\n",
"savefig_options = dict(dpi=150, bbox_inches=\"tight\")\n",
"parity_fig = plot_parity(result.predictions, title=f\"{MODEL_NAME} parity\")\n",
"residual_fig = plot_residuals(result.predictions, title=f\"{MODEL_NAME} residuals\")\n",
"parity_fig.savefig(FIGURES_DIR / f\"{MODEL_NAME}_parity.png\", **savefig_options)\n",
"residual_fig.savefig(FIGURES_DIR / f\"{MODEL_NAME}_residuals.png\", **savefig_options)\n",
"# NOTE(review): only parity_fig is returned for display. Whether residual_fig also renders (or\n",
"# parity_fig renders twice) depends on whether the plot helpers register figures with pyplot.\n",
"parity_fig"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}