initial commit FESurrogateModelTutorial
This commit is contained in:
@@ -0,0 +1,67 @@
|
||||
import pandas as pd
|
||||
from sklearn.linear_model import LinearRegression
|
||||
|
||||
from femsurrogate.surrogates.common import evaluate_model, split_dataset
|
||||
|
||||
|
||||
def _toy_dataset() -> pd.DataFrame:
|
||||
return pd.DataFrame(
|
||||
{
|
||||
"x1": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
|
||||
"x2": [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5],
|
||||
"target": [1.0, 2.5, 4.0, 5.5, 7.0, 8.5, 10.0, 11.5],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_split_dataset_is_reproducible():
|
||||
dataset = _toy_dataset()
|
||||
first = split_dataset(
|
||||
dataset,
|
||||
feature_columns=["x1", "x2"],
|
||||
target_column="target",
|
||||
test_size=0.25,
|
||||
random_state=20260521,
|
||||
)
|
||||
second = split_dataset(
|
||||
dataset,
|
||||
feature_columns=["x1", "x2"],
|
||||
target_column="target",
|
||||
test_size=0.25,
|
||||
random_state=20260521,
|
||||
)
|
||||
|
||||
pd.testing.assert_frame_equal(first.X_train, second.X_train)
|
||||
pd.testing.assert_frame_equal(first.X_test, second.X_test)
|
||||
pd.testing.assert_series_equal(first.y_train, second.y_train)
|
||||
pd.testing.assert_series_equal(first.y_test, second.y_test)
|
||||
|
||||
|
||||
def test_evaluate_model_returns_metrics_and_predictions():
|
||||
dataset = _toy_dataset()
|
||||
split = split_dataset(
|
||||
dataset,
|
||||
feature_columns=["x1", "x2"],
|
||||
target_column="target",
|
||||
test_size=0.25,
|
||||
random_state=20260521,
|
||||
)
|
||||
result = evaluate_model(
|
||||
LinearRegression(),
|
||||
split.X_train,
|
||||
split.X_test,
|
||||
split.y_train,
|
||||
split.y_test,
|
||||
model_name="linear",
|
||||
target_column="target",
|
||||
)
|
||||
|
||||
assert result.metrics.model_name == "linear"
|
||||
assert result.metrics.target_column == "target"
|
||||
assert result.metrics.rmse >= 0.0
|
||||
assert result.metrics.mae >= 0.0
|
||||
assert result.metrics.r2 <= 1.0
|
||||
assert result.metrics.fit_time_s >= 0.0
|
||||
assert result.metrics.predict_time_s >= 0.0
|
||||
assert list(result.predictions.columns) == ["y_true", "y_pred", "residual"]
|
||||
assert len(result.predictions) == len(split.y_test)
|
||||
Reference in New Issue
Block a user