comparing models
train.py (new file)
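# Dependencies (assumed installed; this commit does not pin them): numpy, tqdm,
# scikit-learn, and imbalanced-learn (imported as `imblearn`).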
import numpy
import tqdm
from imblearn.over_sampling import ADASYN, SMOTE, SVMSMOTE, BorderlineSMOTE, KMeansSMOTE
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    fbeta_score,
    precision_score,
    recall_score,
)
from sklearn.model_selection import StratifiedKFold


def train_model_with_kfold(
    model, X, y, n_splits=5, random_state=42, smote=True, smote_method="kmeans"
):
    """Cross-validate `model` with stratified k-fold, optionally oversampling
    each training fold, and return metrics averaged over the folds."""
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    # Macro-averaged metrics, one entry per fold.
    accuracy_list = []
    f1_macro_list = []
    f2_macro_list = []
    recall_macro_list = []
    precision_macro_list = []

    # Per-class metrics, one entry per fold.
    f1_class0 = []
    f1_class1 = []
    f2_class0 = []
    f2_class1 = []
    recall_class0 = []
    recall_class1 = []
    precision_class0 = []
    precision_class1 = []

    for train_idx, val_idx in tqdm.tqdm(
        skf.split(X, y), total=skf.n_splits, desc="Training Folds"
    ):
        X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
        y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]

        if smote:
            # Oversample only the training fold so no synthetic samples leak
            # into the validation split.
            if smote_method.lower() == "kmeans":
                sampler = KMeansSMOTE(
                    k_neighbors=5,
                    cluster_balance_threshold=0.1,
                    random_state=random_state,
                )
            elif smote_method.lower() == "smote":
                sampler = SMOTE(k_neighbors=5, random_state=random_state)
            elif smote_method.lower() == "svmsmote":
                sampler = SVMSMOTE(k_neighbors=5, random_state=random_state)
            elif smote_method.lower() == "borderline":
                sampler = BorderlineSMOTE(k_neighbors=5, random_state=random_state)
            elif smote_method.lower() == "adasyn":
                sampler = ADASYN(n_neighbors=5, random_state=random_state)
            else:
                raise ValueError(f"Unknown smote_method: {smote_method}")

            X_train, y_train = sampler.fit_resample(X_train, y_train)

        model.fit(X_train, y_train)
        y_pred = model.predict(X_val)

        # Fold-level macro metrics.
        accuracy_list.append(accuracy_score(y_val, y_pred))
        f1_macro_list.append(f1_score(y_val, y_pred, average="macro"))
        f2_macro_list.append(fbeta_score(y_val, y_pred, beta=2, average="macro"))
        recall_macro_list.append(recall_score(y_val, y_pred, average="macro"))
        precision_macro_list.append(precision_score(y_val, y_pred, average="macro"))

        # Fold-level per-class metrics (binary labels 0/1).
        f1_class0.append(f1_score(y_val, y_pred, pos_label=0))
        f1_class1.append(f1_score(y_val, y_pred, pos_label=1))

        f2_class0.append(fbeta_score(y_val, y_pred, beta=2, pos_label=0))
        f2_class1.append(fbeta_score(y_val, y_pred, beta=2, pos_label=1))

        recall_class0.append(recall_score(y_val, y_pred, pos_label=0))
        recall_class1.append(recall_score(y_val, y_pred, pos_label=1))

        precision_class0.append(precision_score(y_val, y_pred, pos_label=0))
        precision_class1.append(precision_score(y_val, y_pred, pos_label=1))

    # Average every metric over the folds.
    return {
        "accuracy": numpy.mean(accuracy_list),
        "f1_macro": numpy.mean(f1_macro_list),
        "f2_macro": numpy.mean(f2_macro_list),
        "recall_macro": numpy.mean(recall_macro_list),
        "precision_macro": numpy.mean(precision_macro_list),
        "f1_class0": numpy.mean(f1_class0),
        "f1_class1": numpy.mean(f1_class1),
        "f2_class0": numpy.mean(f2_class0),
        "f2_class1": numpy.mean(f2_class1),
        "recall_class0": numpy.mean(recall_class0),
        "recall_class1": numpy.mean(recall_class1),
        "precision_class0": numpy.mean(precision_class0),
        "precision_class1": numpy.mean(precision_class1),
    }


def test_model(model, X_test, y_test):
    """Evaluate a fitted model on the held-out test set and return the same
    metric keys as `train_model_with_kfold`."""
    y_pred = model.predict(X_test)

    accuracy = accuracy_score(y_test, y_pred)
    f1_macro = f1_score(y_test, y_pred, average="macro")
    f2_macro = fbeta_score(y_test, y_pred, beta=2, average="macro")
    recall_macro = recall_score(y_test, y_pred, average="macro")
    precision_macro = precision_score(y_test, y_pred, average="macro")

    f1_class0 = f1_score(y_test, y_pred, pos_label=0)
    f1_class1 = f1_score(y_test, y_pred, pos_label=1)

    f2_class0 = fbeta_score(y_test, y_pred, beta=2, pos_label=0)
    f2_class1 = fbeta_score(y_test, y_pred, beta=2, pos_label=1)

    recall_class0 = recall_score(y_test, y_pred, pos_label=0)
    recall_class1 = recall_score(y_test, y_pred, pos_label=1)

    precision_class0 = precision_score(y_test, y_pred, pos_label=0)
    precision_class1 = precision_score(y_test, y_pred, pos_label=1)

    return {
        "accuracy": accuracy,
        "f1_macro": f1_macro,
        "f2_macro": f2_macro,
        "recall_macro": recall_macro,
        "precision_macro": precision_macro,
        "f1_class0": f1_class0,
        "f1_class1": f1_class1,
        "f2_class0": f2_class0,
        "f2_class1": f2_class1,
        "recall_class0": recall_class0,
        "recall_class1": recall_class1,
        "precision_class0": precision_class0,
        "precision_class1": precision_class1,
    }
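

# --- Usage sketch (illustrative only; the names below are hypothetical) ---
# A minimal example of how these helpers might be wired together to compare
# models, assuming pandas inputs (the .iloc indexing above requires a
# DataFrame/Series) and scikit-learn-style classifiers:
#
#     from sklearn.ensemble import RandomForestClassifier
#     from sklearn.linear_model import LogisticRegression
#
#     candidates = {
#         "logreg": LogisticRegression(max_iter=1000),
#         "forest": RandomForestClassifier(random_state=42),
#     }
#     for name, candidate in candidates.items():
#         cv_scores = train_model_with_kfold(candidate, X_train, y_train)
#         print(name, cv_scores["f2_macro"])
#
#     best = candidates["forest"].fit(X_train, y_train)
#     print(test_model(best, X_test, y_test))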