initial commit FESurrogateModelTutorial
This commit is contained in:
@@ -0,0 +1,168 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ad3ce4ad",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Gradient Boosting Surrogate\n",
|
||||
"\n",
|
||||
"Gradient Boosting? shallow tree? ????? ?? residual pattern? ??? ensemble??.\n",
|
||||
"\n",
|
||||
"?? ?? notebook? ?? dataset, target, train/test split seed? ????."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0fec75b1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"import warnings\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"from sklearn.exceptions import ConvergenceWarning\n",
|
||||
"\n",
|
||||
"from femsurrogate.data.schema import DEFAULT_RANDOM_SEED, PARAMETER_COLUMNS\n",
|
||||
"from femsurrogate.plotting.diagnostics import plot_parity, plot_residuals\n",
|
||||
"from femsurrogate.surrogates.common import evaluate_model, metrics_to_dict, split_dataset\n",
|
||||
"from femsurrogate.surrogates.registry import make_model\n",
|
||||
"\n",
|
||||
"ROOT = Path.cwd().resolve()\n",
|
||||
"if not (ROOT / \"pyproject.toml\").exists():\n",
|
||||
" ROOT = ROOT.parent\n",
|
||||
"assert (ROOT / \"pyproject.toml\").exists(), ROOT\n",
|
||||
"DATASET_PATH = ROOT / \"data\" / \"reference\" / \"beam2d_lhs_300.csv\"\n",
|
||||
"RESULTS_DIR = ROOT / \"reports\" / \"results\"\n",
|
||||
"PREDICTIONS_DIR = ROOT / \"reports\" / \"predictions\"\n",
|
||||
"FIGURES_DIR = ROOT / \"reports\" / \"figures\"\n",
|
||||
"for directory in [RESULTS_DIR, PREDICTIONS_DIR, FIGURES_DIR]:\n",
|
||||
" directory.mkdir(parents=True, exist_ok=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1d93d9fb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Dataset? split"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c905b0aa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset = pd.read_csv(DATASET_PATH)\n",
|
||||
"target_column = \"tip_uy_m\"\n",
|
||||
"split = split_dataset(\n",
|
||||
" dataset,\n",
|
||||
" feature_columns=list(PARAMETER_COLUMNS),\n",
|
||||
" target_column=target_column,\n",
|
||||
" test_size=0.2,\n",
|
||||
" random_state=DEFAULT_RANDOM_SEED,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"len(split.X_train), len(split.X_test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c5e7af1a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## ?? ??? ??"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f5168f3e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"MODEL_NAME = \"gradient_boosting\"\n",
|
||||
"model = make_model(MODEL_NAME, random_state=DEFAULT_RANDOM_SEED, **{})\n",
|
||||
"\n",
|
||||
"with warnings.catch_warnings():\n",
|
||||
" warnings.filterwarnings(\"ignore\", category=ConvergenceWarning)\n",
|
||||
" result = evaluate_model(\n",
|
||||
" model,\n",
|
||||
" split.X_train,\n",
|
||||
" split.X_test,\n",
|
||||
" split.y_train,\n",
|
||||
" split.y_test,\n",
|
||||
" model_name=MODEL_NAME,\n",
|
||||
" target_column=target_column,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"metrics = metrics_to_dict(result.metrics)\n",
|
||||
"metrics"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f02aead6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## ?? ??"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1daaf0f6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"metrics_path = RESULTS_DIR / f\"{MODEL_NAME}_metrics.json\"\n",
|
||||
"predictions_path = PREDICTIONS_DIR / f\"{MODEL_NAME}_predictions.csv\"\n",
|
||||
"\n",
|
||||
"metrics_path.write_text(json.dumps(metrics, indent=2), encoding=\"utf-8\")\n",
|
||||
"result.predictions.to_csv(predictions_path, index=False)\n",
|
||||
"\n",
|
||||
"{\"metrics_path\": str(metrics_path), \"predictions_path\": str(predictions_path)}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "14f71882",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## ?? plot"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "807c1025",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"parity_fig = plot_parity(result.predictions, title=f\"{MODEL_NAME} parity\")\n",
|
||||
"residual_fig = plot_residuals(result.predictions, title=f\"{MODEL_NAME} residuals\")\n",
|
||||
"parity_fig.savefig(FIGURES_DIR / f\"{MODEL_NAME}_parity.png\", dpi=150, bbox_inches=\"tight\")\n",
|
||||
"residual_fig.savefig(FIGURES_DIR / f\"{MODEL_NAME}_residuals.png\", dpi=150, bbox_inches=\"tight\")\n",
|
||||
"parity_fig"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"pygments_lexer": "ipython3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user