{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1638813f",
   "metadata": {},
   "source": [
    "# Random Forest Surrogate\n",
    "\n",
    "A random forest is a tree ensemble that needs no feature scaling and can capture nonlinear effects and feature interactions.\n",
    "\n",
    "Like the other model notebooks, this notebook uses the same dataset, target, and train/test split seed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8bdcff3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import warnings\n",
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "from sklearn.exceptions import ConvergenceWarning\n",
    "\n",
    "from femsurrogate.data.schema import DEFAULT_RANDOM_SEED, PARAMETER_COLUMNS\n",
    "from femsurrogate.plotting.diagnostics import plot_parity, plot_residuals\n",
    "from femsurrogate.surrogates.common import evaluate_model, metrics_to_dict, split_dataset\n",
    "from femsurrogate.surrogates.registry import make_model\n",
    "\n",
    "# Locate the repository root by walking up from the working directory until a\n",
    "# pyproject.toml is found, so the notebook runs from any depth inside the repo.\n",
    "START_DIR = Path.cwd().resolve()\n",
    "ROOT = next(\n",
    "    (\n",
    "        candidate\n",
    "        for candidate in (START_DIR, *START_DIR.parents)\n",
    "        if (candidate / \"pyproject.toml\").exists()\n",
    "    ),\n",
    "    None,\n",
    ")\n",
    "assert ROOT is not None, f\"pyproject.toml not found in {START_DIR} or any parent directory\"\n",
    "\n",
    "# Input dataset and output locations, all relative to the repository root.\n",
    "DATASET_PATH = ROOT / \"data\" / \"reference\" / \"beam2d_lhs_300.csv\"\n",
    "RESULTS_DIR = ROOT / \"reports\" / \"results\"\n",
    "PREDICTIONS_DIR = ROOT / \"reports\" / \"predictions\"\n",
    "FIGURES_DIR = ROOT / \"reports\" / \"figures\"\n",
    "for directory in [RESULTS_DIR, PREDICTIONS_DIR, FIGURES_DIR]:\n",
    "    directory.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6957e06b",
   "metadata": {},
   "source": [
    "## Dataset and split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "209f40be",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = pd.read_csv(DATASET_PATH)\n",
    "target_column = \"tip_uy_m\"\n",
    "\n",
    "# Fail early with a readable message if the CSV lacks a required column.\n",
    "required_columns = [*PARAMETER_COLUMNS, target_column]\n",
    "missing_columns = [column for column in required_columns if column not in dataset.columns]\n",
    "assert not missing_columns, f\"{DATASET_PATH.name} is missing columns: {missing_columns}\"\n",
    "\n",
    "split = split_dataset(\n",
    "    dataset,\n",
    "    feature_columns=list(PARAMETER_COLUMNS),\n",
    "    target_column=target_column,\n",
    "    test_size=0.2,\n",
    "    random_state=DEFAULT_RANDOM_SEED,\n",
    ")\n",
    "\n",
    "len(split.X_train), len(split.X_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdc2fbed",
   "metadata": {},
   "source": [
    "## Model training and evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63ed6060",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME = \"random_forest\"\n",
    "model = make_model(MODEL_NAME, random_state=DEFAULT_RANDOM_SEED)\n",
    "\n",
    "# Suppress ConvergenceWarning for the duration of the evaluation call only.\n",
    "with warnings.catch_warnings():\n",
    "    warnings.filterwarnings(\"ignore\", category=ConvergenceWarning)\n",
    "    result = evaluate_model(\n",
    "        model,\n",
    "        split.X_train,\n",
    "        split.X_test,\n",
    "        split.y_train,\n",
    "        split.y_test,\n",
    "        model_name=MODEL_NAME,\n",
    "        target_column=target_column,\n",
    "    )\n",
    "\n",
    "metrics = metrics_to_dict(result.metrics)\n",
    "metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25dafcbf",
   "metadata": {},
   "source": [
    "## Save results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8357de5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output file names are derived from MODEL_NAME.\n",
    "metrics_path = RESULTS_DIR / f\"{MODEL_NAME}_metrics.json\"\n",
    "predictions_path = PREDICTIONS_DIR / f\"{MODEL_NAME}_predictions.csv\"\n",
    "\n",
    "metrics_path.write_text(json.dumps(metrics, indent=2), encoding=\"utf-8\")\n",
    "result.predictions.to_csv(predictions_path, index=False)\n",
    "\n",
    "{\"metrics_path\": str(metrics_path), \"predictions_path\": str(predictions_path)}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "069f56dd",
   "metadata": {},
   "source": [
    "## Diagnostic plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00159cc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build both diagnostic figures from the predictions table and write them to FIGURES_DIR.\n",
    "parity_fig = plot_parity(result.predictions, title=f\"{MODEL_NAME} parity\")\n",
    "residual_fig = plot_residuals(result.predictions, title=f\"{MODEL_NAME} residuals\")\n",
    "parity_fig.savefig(FIGURES_DIR / f\"{MODEL_NAME}_parity.png\", dpi=150, bbox_inches=\"tight\")\n",
    "residual_fig.savefig(FIGURES_DIR / f\"{MODEL_NAME}_residuals.png\", dpi=150, bbox_inches=\"tight\")\n",
    "parity_fig"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}