169 lines
4.5 KiB
Plaintext
169 lines
4.5 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "69c9cc92",
|
|
"metadata": {},
|
|
"source": [
|
|
"# MLP Neural Network Surrogate\n",
|
|
"\n",
|
|
"Fits an MLP surrogate on scaled inputs and target and evaluates it on the held-out test split.\n",
|
|
"\n",
|
|
"Uses the same dataset, target, and train/test split seed as the other surrogate notebooks."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "eed089f5",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import json\n",
"import warnings\n",
"from pathlib import Path\n",
"\n",
"import pandas as pd\n",
"from sklearn.exceptions import ConvergenceWarning\n",
"\n",
"from femsurrogate.data.schema import DEFAULT_RANDOM_SEED, PARAMETER_COLUMNS\n",
"from femsurrogate.plotting.diagnostics import plot_parity, plot_residuals\n",
"from femsurrogate.surrogates.common import evaluate_model, metrics_to_dict, split_dataset\n",
"from femsurrogate.surrogates.registry import make_model\n",
"\n",
"# Locate the repository root: the notebook is run either from the root itself or from a\n",
"# directory one level below it; anything else fails fast on the assert.\n",
"_cwd = Path.cwd().resolve()\n",
"ROOT = _cwd if (_cwd / \"pyproject.toml\").exists() else _cwd.parent\n",
"assert (ROOT / \"pyproject.toml\").exists(), ROOT\n",
"\n",
"DATASET_PATH = ROOT / \"data\" / \"reference\" / \"beam2d_lhs_300.csv\"\n",
"REPORTS_DIR = ROOT / \"reports\"\n",
"RESULTS_DIR = REPORTS_DIR / \"results\"\n",
"PREDICTIONS_DIR = REPORTS_DIR / \"predictions\"\n",
"FIGURES_DIR = REPORTS_DIR / \"figures\"\n",
"\n",
"# Create the report output folders up front so later cells can write into them.\n",
"for output_dir in (RESULTS_DIR, PREDICTIONS_DIR, FIGURES_DIR):\n",
"    output_dir.mkdir(parents=True, exist_ok=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "b2547495",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Dataset and split"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "039279d4",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"dataset = pd.read_csv(DATASET_PATH)\n",
"target_column = \"tip_uy_m\"\n",
"TEST_SIZE = 0.2  # held-out test fraction passed to split_dataset\n",
"\n",
"# Fail early with a readable message if the CSV lacks a feature or target column.\n",
"missing_columns = [col for col in (*PARAMETER_COLUMNS, target_column) if col not in dataset.columns]\n",
"assert not missing_columns, f\"{DATASET_PATH.name} is missing columns: {missing_columns}\"\n",
"\n",
"split = split_dataset(\n",
"    dataset,\n",
"    feature_columns=list(PARAMETER_COLUMNS),\n",
"    target_column=target_column,\n",
"    test_size=TEST_SIZE,\n",
"    random_state=DEFAULT_RANDOM_SEED,\n",
")\n",
"\n",
"len(split.X_train), len(split.X_test)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "71b3323f",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Model training and evaluation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "605253ef",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"MODEL_NAME = \"mlp\"\n",
"MAX_ITER = 500  # iteration budget forwarded to the MLP via make_model\n",
"model = make_model(MODEL_NAME, random_state=DEFAULT_RANDOM_SEED, max_iter=MAX_ITER)\n",
"\n",
"# Record ConvergenceWarning instead of discarding it, so a fit that stopped before\n",
"# converging is reported below rather than silently evaluated as if it had converged.\n",
"with warnings.catch_warnings(record=True) as caught_warnings:\n",
"    warnings.simplefilter(\"always\", ConvergenceWarning)\n",
"    result = evaluate_model(\n",
"        model,\n",
"        split.X_train,\n",
"        split.X_test,\n",
"        split.y_train,\n",
"        split.y_test,\n",
"        model_name=MODEL_NAME,\n",
"        target_column=target_column,\n",
"    )\n",
"\n",
"converged = True\n",
"for caught in caught_warnings:\n",
"    if issubclass(caught.category, ConvergenceWarning):\n",
"        converged = False\n",
"    else:\n",
"        # record=True captures every category; re-emit unrelated warnings so they stay visible.\n",
"        warnings.showwarning(caught.message, caught.category, caught.filename, caught.lineno)\n",
"\n",
"if not converged:\n",
"    print(f\"NOTE: {MODEL_NAME} raised ConvergenceWarning with max_iter={MAX_ITER}; \"\n",
"          \"the metrics below come from a fit that did not converge.\")\n",
"\n",
"metrics = metrics_to_dict(result.metrics)\n",
"metrics"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "f70ad9fd",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Save results"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "fb08797b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Persist the metrics (JSON) and the predictions table (CSV), named after the model.\n",
"metrics_path, predictions_path = (\n",
"    RESULTS_DIR / f\"{MODEL_NAME}_metrics.json\",\n",
"    PREDICTIONS_DIR / f\"{MODEL_NAME}_predictions.csv\",\n",
")\n",
"\n",
"metrics_json = json.dumps(metrics, indent=2)\n",
"metrics_path.write_text(metrics_json, encoding=\"utf-8\")\n",
"result.predictions.to_csv(predictions_path, index=False)\n",
"\n",
"{\"metrics_path\": str(metrics_path), \"predictions_path\": str(predictions_path)}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "39a34166",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Diagnostic plots"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "694a1081",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"parity_fig = plot_parity(result.predictions, title=f\"{MODEL_NAME} parity\")\n",
"residual_fig = plot_residuals(result.predictions, title=f\"{MODEL_NAME} residuals\")\n",
"\n",
"# Save both diagnostics as <model>_<kind>.png; only the parity plot is the cell's displayed value.\n",
"for fig_kind, fig_to_save in ((\"parity\", parity_fig), (\"residuals\", residual_fig)):\n",
"    fig_to_save.savefig(FIGURES_DIR / f\"{MODEL_NAME}_{fig_kind}.png\", dpi=150, bbox_inches=\"tight\")\n",
"\n",
"parity_fig"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"name": "python",
|
|
"pygments_lexer": "ipython3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|