Files
Electrocardiogram/train.py
2025-12-06 00:14:59 +01:00

140 lines
5.0 KiB
Python

import numpy
import tqdm
from imblearn.over_sampling import ADASYN, SMOTE, SVMSMOTE, BorderlineSMOTE, KMeansSMOTE
from sklearn.metrics import (
accuracy_score,
f1_score,
fbeta_score,
precision_score,
recall_score,
)
from sklearn.model_selection import StratifiedKFold
def train_model_with_kfold(
    model,
    X,
    y,
    n_splits=5,
    random_state=42,
    smote=True,
    smote_method="kmeans",
    k_neighbors=15,
):
    """Cross-validate ``model`` with stratified K-fold and return mean metrics.

    For every fold the training split is optionally oversampled, ``model`` is
    fitted on it, and its predictions on the validation split are scored.
    Oversampling is applied to the training split only, so validation folds
    never contain synthetic samples.

    Args:
        model: Estimator exposing ``fit(X, y)`` and ``predict(X)``. It is
            refitted in place on every fold, so on return it holds the fit
            from the last fold.
        X: Features indexed positionally via ``.iloc`` (presumably a pandas
            DataFrame -- confirm against callers).
        y: Labels indexed via ``.iloc`` (presumably a pandas Series). The
            per-class metrics use ``pos_label=0`` / ``pos_label=1`` with
            sklearn's default binary averaging, so labels must be 0 and 1.
        n_splits: Number of folds for ``StratifiedKFold``.
        random_state: Seed for the fold shuffle and for the oversampler.
        smote: Whether to oversample the training split of each fold.
        smote_method: One of ``"kmeans"``, ``"smote"``, ``"svmsmote"``,
            ``"borderline"`` or ``"adasyn"`` (case-insensitive). Ignored
            when ``smote`` is false.
        k_neighbors: Neighbour count handed to the oversampler (passed as
            ``n_neighbors`` for ADASYN). Defaults to the previously
            hard-coded 15.

    Returns:
        dict mapping each metric name (``"accuracy"``, ``"f1_macro"``,
        ``"f2_macro"``, ``"recall_macro"``, ``"precision_macro"`` and the
        ``f1``/``f2``/``recall``/``precision`` ``_class0``/``_class1``
        variants) to its mean over all folds.

    Raises:
        ValueError: If ``smote`` is true and ``smote_method`` is unknown.
            This is raised before any fold is trained.
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    # (result key, sklearn scorer, extra keyword arguments), in the order the
    # keys appear in the returned dict. Per-class entries rely on the default
    # ``average="binary"``, which is why labels must be 0/1.
    metric_specs = (
        ("accuracy", accuracy_score, {}),
        ("f1_macro", f1_score, {"average": "macro"}),
        ("f2_macro", fbeta_score, {"beta": 2, "average": "macro"}),
        ("recall_macro", recall_score, {"average": "macro"}),
        ("precision_macro", precision_score, {"average": "macro"}),
        ("f1_class0", f1_score, {"pos_label": 0}),
        ("f1_class1", f1_score, {"pos_label": 1}),
        ("f2_class0", fbeta_score, {"beta": 2, "pos_label": 0}),
        ("f2_class1", fbeta_score, {"beta": 2, "pos_label": 1}),
        ("recall_class0", recall_score, {"pos_label": 0}),
        ("recall_class1", recall_score, {"pos_label": 1}),
        ("precision_class0", precision_score, {"pos_label": 0}),
        ("precision_class1", precision_score, {"pos_label": 1}),
    )

    make_sampler = None
    if smote:
        # Factories rather than instances: a fresh sampler is built for every
        # fold, so no fitted state carries over between folds.
        sampler_factories = {
            "kmeans": lambda: KMeansSMOTE(
                k_neighbors=k_neighbors,
                cluster_balance_threshold=0.1,
                random_state=random_state,
            ),
            "smote": lambda: SMOTE(
                k_neighbors=k_neighbors, random_state=random_state
            ),
            "svmsmote": lambda: SVMSMOTE(
                k_neighbors=k_neighbors, random_state=random_state
            ),
            "borderline": lambda: BorderlineSMOTE(
                k_neighbors=k_neighbors, random_state=random_state
            ),
            # ADASYN calls the same knob ``n_neighbors``.
            "adasyn": lambda: ADASYN(
                n_neighbors=k_neighbors, random_state=random_state
            ),
        }
        # Validate up front so a bad method name fails before any training.
        make_sampler = sampler_factories.get(smote_method.lower())
        if make_sampler is None:
            raise ValueError(f"Unknown smote_method: {smote_method}")

    fold_scores = {name: [] for name, _, _ in metric_specs}

    for train_idx, val_idx in tqdm.tqdm(
        skf.split(X, y), total=skf.n_splits, desc="Training Folds"
    ):
        X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
        y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]
        if make_sampler is not None:
            # Resample the training split only; validation stays untouched.
            X_train, y_train = make_sampler().fit_resample(X_train, y_train)
        model.fit(X_train, y_train)
        y_pred = model.predict(X_val)
        for name, scorer, kwargs in metric_specs:
            fold_scores[name].append(scorer(y_val, y_pred, **kwargs))

    return {name: numpy.mean(scores) for name, scores in fold_scores.items()}
def test_model(model, X_test, y_test):
    """Evaluate an already-fitted ``model`` on a held-out test set.

    Args:
        model: Fitted estimator exposing ``predict(X)``.
        X_test: Test features, passed unchanged to ``model.predict``.
        y_test: True test labels. The per-class scores use ``pos_label=0`` and
            ``pos_label=1`` with sklearn's default binary averaging, so the
            labels are expected to be 0 and 1.

    Returns:
        dict with keys ``"accuracy"``, ``"f1_macro"``, ``"f2_macro"``,
        ``"recall_macro"``, ``"precision_macro"`` and the per-class
        ``f1``/``f2``/``recall``/``precision`` ``_class0``/``_class1``
        scores, each computed once on the whole test set.
    """
    predictions = model.predict(X_test)

    def _score(metric, **options):
        # Every metric compares the true labels with the same predictions.
        return metric(y_test, predictions, **options)

    return {
        "accuracy": _score(accuracy_score),
        "f1_macro": _score(f1_score, average="macro"),
        "f2_macro": _score(fbeta_score, beta=2, average="macro"),
        "recall_macro": _score(recall_score, average="macro"),
        "precision_macro": _score(precision_score, average="macro"),
        "f1_class0": _score(f1_score, pos_label=0),
        "f1_class1": _score(f1_score, pos_label=1),
        "f2_class0": _score(fbeta_score, beta=2, pos_label=0),
        "f2_class1": _score(fbeta_score, beta=2, pos_label=1),
        "recall_class0": _score(recall_score, pos_label=0),
        "recall_class1": _score(recall_score, pos_label=1),
        "precision_class0": _score(precision_score, pos_label=0),
        "precision_class1": _score(precision_score, pos_label=1),
    }