import numpy
import tqdm
from imblearn.over_sampling import ADASYN, SMOTE, SVMSMOTE, BorderlineSMOTE, KMeansSMOTE
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    fbeta_score,
    precision_score,
    recall_score,
)
from sklearn.model_selection import StratifiedKFold

# Keys of the metric dictionaries returned by both public functions, in order.
_METRIC_NAMES = (
    "accuracy",
    "f1_macro",
    "f2_macro",
    "recall_macro",
    "precision_macro",
    "f1_class0",
    "f1_class1",
    "f2_class0",
    "f2_class1",
    "recall_class0",
    "recall_class1",
    "precision_class0",
    "precision_class1",
)

# Lower-cased ``smote_method`` name -> factory returning a fresh sampler.
# ADASYN spells its neighbour count ``n_neighbors``; the others use
# ``k_neighbors``, which is why a plain class lookup is not enough.
_SAMPLER_FACTORIES = {
    "kmeans": lambda k_neighbors, random_state: KMeansSMOTE(
        k_neighbors=k_neighbors,
        cluster_balance_threshold=0.1,
        random_state=random_state,
    ),
    "smote": lambda k_neighbors, random_state: SMOTE(
        k_neighbors=k_neighbors, random_state=random_state
    ),
    "svmsmote": lambda k_neighbors, random_state: SVMSMOTE(
        k_neighbors=k_neighbors, random_state=random_state
    ),
    "borderline": lambda k_neighbors, random_state: BorderlineSMOTE(
        k_neighbors=k_neighbors, random_state=random_state
    ),
    "adasyn": lambda k_neighbors, random_state: ADASYN(
        n_neighbors=k_neighbors, random_state=random_state
    ),
}


def _select_rows(data, idx):
    """Return the rows of *data* at the integer positions *idx*."""
    # pandas objects need positional ``.iloc``; NumPy-style containers accept
    # an integer index array directly.
    if hasattr(data, "iloc"):
        return data.iloc[idx]
    return data[idx]


def _compute_metrics(y_true, y_pred):
    """Return accuracy plus macro and per-class F1, F2, recall and precision.

    The per-class scores are computed with ``pos_label=0`` and
    ``pos_label=1``, so the labels are assumed to be binary 0/1.
    """
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1_macro": f1_score(y_true, y_pred, average="macro"),
        "f2_macro": fbeta_score(y_true, y_pred, beta=2, average="macro"),
        "recall_macro": recall_score(y_true, y_pred, average="macro"),
        "precision_macro": precision_score(y_true, y_pred, average="macro"),
        "f1_class0": f1_score(y_true, y_pred, pos_label=0),
        "f1_class1": f1_score(y_true, y_pred, pos_label=1),
        "f2_class0": fbeta_score(y_true, y_pred, beta=2, pos_label=0),
        "f2_class1": fbeta_score(y_true, y_pred, beta=2, pos_label=1),
        "recall_class0": recall_score(y_true, y_pred, pos_label=0),
        "recall_class1": recall_score(y_true, y_pred, pos_label=1),
        "precision_class0": precision_score(y_true, y_pred, pos_label=0),
        "precision_class1": precision_score(y_true, y_pred, pos_label=1),
    }


def train_model_with_kfold(
    model,
    X,
    y,
    n_splits=5,
    random_state=42,
    smote=True,
    smote_method="kmeans",
    k_neighbors=15,
):
    """Cross-validate *model* with stratified K-fold and return mean metrics.

    For every fold the training split is optionally oversampled, *model* is
    fitted on it, and the predictions on the untouched validation split are
    scored. Oversampling is applied to the training split only.

    Args:
        model: Estimator exposing ``fit(X, y)`` and ``predict(X)``. The same
            object is refitted on every fold, so after the call it holds the
            fit from the last fold.
        X: Feature table; a pandas object (indexed through ``.iloc``) or any
            container that accepts an integer index array.
        y: Target labels, indexable the same way as *X*. The per-class
            metrics assume binary labels 0 and 1.
        n_splits: Number of stratified folds.
        random_state: Seed used for the fold shuffling and for the sampler.
        smote: Whether to oversample each training split.
        smote_method: One of ``"kmeans"``, ``"smote"``, ``"svmsmote"``,
            ``"borderline"`` or ``"adasyn"`` (case-insensitive). Ignored when
            *smote* is false.
        k_neighbors: Neighbour count handed to the sampler (passed as
            ``n_neighbors`` for ADASYN).

    Returns:
        dict mapping each metric name (``accuracy``, ``*_macro``,
        ``*_class0``, ``*_class1`` for f1, f2, recall and precision) to its
        mean over the folds.

    Raises:
        ValueError: If *smote* is true and *smote_method* is not recognised.
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    # Resolve the sampler once: the choice does not depend on the fold, and an
    # unknown name should fail before any training work starts.
    make_sampler = None
    if smote:
        make_sampler = _SAMPLER_FACTORIES.get(smote_method.lower())
        if make_sampler is None:
            raise ValueError(f"Unknown smote_method: {smote_method}")

    fold_metrics = []
    folds = tqdm.tqdm(skf.split(X, y), total=skf.n_splits, desc="Training Folds")
    for train_idx, val_idx in folds:
        X_train, X_val = _select_rows(X, train_idx), _select_rows(X, val_idx)
        y_train, y_val = _select_rows(y, train_idx), _select_rows(y, val_idx)

        if make_sampler is not None:
            # A fresh sampler per fold keeps the folds independent of each other.
            sampler = make_sampler(k_neighbors, random_state)
            X_train, y_train = sampler.fit_resample(X_train, y_train)

        model.fit(X_train, y_train)
        fold_metrics.append(_compute_metrics(y_val, model.predict(X_val)))

    return {
        name: numpy.mean([scores[name] for scores in fold_metrics])
        for name in _METRIC_NAMES
    }


def test_model(model, X_test, y_test):
    """Score an already fitted *model* on a held-out test set.

    Args:
        model: Fitted estimator exposing ``predict(X)``.
        X_test: Test features passed to ``model.predict``.
        y_test: True labels; the per-class metrics assume binary labels 0 and 1.

    Returns:
        dict with the same keys as ``train_model_with_kfold`` returns, holding
        the scores on this single test set (no averaging).
    """
    return _compute_metrics(y_test, model.predict(X_test))