Files
ResearchProject/FESurrogateModelTutorial/tests/test_surrogate_models.py
T
2026-05-21 17:03:51 +09:00

48 lines
1.4 KiB
Python

import warnings
import numpy as np
import pandas as pd
from sklearn.exceptions import ConvergenceWarning
from femsurrogate.surrogates.registry import MODEL_NAMES, make_model
def _toy_regression_data():
X = pd.DataFrame(
{
"L_m": np.linspace(1.0, 3.0, 12),
"b_m": np.linspace(0.02, 0.08, 12),
"h_m": np.linspace(0.04, 0.16, 12),
"E_pa": np.linspace(100e9, 220e9, 12),
"P_n": np.linspace(100.0, 2000.0, 12),
}
)
y = -X["P_n"] * X["L_m"] ** 3 / (3.0 * X["E_pa"] * X["b_m"] * X["h_m"] ** 3)
return X, y
def test_registry_exposes_expected_surrogate_names():
    """The registry advertises exactly these surrogate names, in this order."""
    expected_names = (
        "rsm",
        "gpr",
        "random_forest",
        "gradient_boosting",
        "mlp",
    )
    assert MODEL_NAMES == expected_names
def test_make_model_builds_estimators_that_fit_and_predict():
    """Every registered surrogate can be built, fitted, and queried.

    Each name in ``MODEL_NAMES`` is built with ``make_model`` using a fixed
    ``random_state`` plus small-size overrides. The model is fitted on the toy
    dataset, and the test checks that it returns one finite prediction per
    query row. Assertion messages name the failing model.
    """
    X, y = _toy_regression_data()
    query = X.iloc[:3]
    # Keyword overrides forwarded to ``make_model`` to keep the slower
    # estimators cheap; models not listed get no extra keyword arguments.
    fast_overrides = {
        "random_forest": {"n_estimators": 5, "n_jobs": 1},
        "gradient_boosting": {"n_estimators": 5},
        "mlp": {"hidden_layer_sizes": (4,), "max_iter": 25, "early_stopping": False},
    }
    for model_name in MODEL_NAMES:
        model = make_model(
            model_name,
            random_state=20260521,
            **fast_overrides.get(model_name, {}),
        )
        # Convergence is not under test here. With tiny iteration budgets
        # (e.g. the MLP capped at max_iter=25), ConvergenceWarning is
        # expected, so it is silenced during fit only.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=ConvergenceWarning)
            model.fit(X, y)
        predictions = np.asarray(model.predict(query))
        assert predictions.shape == (len(query),), (
            f"{model_name}: unexpected prediction shape {predictions.shape}"
        )
        # A shape check alone would accept NaN/inf output from a broken fit.
        assert np.isfinite(predictions).all(), (
            f"{model_name}: non-finite predictions {predictions}"
        )