comparing models
This commit is contained in:
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
data/
|
||||
25
data.py
Normal file
25
data.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
Saeed Khosravi - 26 Nov 2025
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pandas
|
||||
|
||||
from train import test_model, train_model_with_kfold
|
||||
from utils import missing_value_handler, scaling_handler
|
||||
|
||||
# STEP 1: handle missing values + remove id column + robust scaling
# The cleaned CSV doubles as a cache: preprocessing only runs when it is absent.
cleaned_csv_path = "./data/Ketamin_icp_cleaned.csv"

if os.path.exists(cleaned_csv_path):
    # A previous run already produced the cleaned data; just load it.
    data_frame = pandas.read_csv(cleaned_csv_path)
else:
    data_path = "./data/Ketamine_icp.csv"

    # Impute first, then scale, then persist the result for later runs.
    data_frame_imputed = missing_value_handler(data_path)
    data_frame = scaling_handler(data_frame_imputed, method="robust_scaling")
    data_frame.to_csv(cleaned_csv_path, index=False)
|
||||
|
||||
|
||||
# STEP 2: Train and Test models
|
||||
35
results.csv
Normal file
35
results.csv
Normal file
@@ -0,0 +1,35 @@
|
||||
model,stage,accuracy,f1_macro,f2_macro,recall_macro,precision_macro,f1_class0,f1_class1,f2_class0,f2_class1,recall_class0,recall_class1,precision_class0,precision_class1,TP,TN,FP,FN
|
||||
LGBM_SMOTE,train,0.98818499917456,0.8760126546483045,0.8399085858018045,0.8198112125154727,0.9598771423117448,0.9939444918466291,0.75808081744998,0.9967070984144822,0.6831100731891271,0.9985576102161305,0.6410648148148148,0.9893749043238158,0.9303793802996736,774,27036,2,34
|
||||
LGBM_SMOTE,test,0.9867724867724867,0.8559208843672739,0.8127209992015513,0.7896857910481889,0.965374580868779,0.9932270501198291,0.7186147186147186,0.9966541196152238,0.6287878787878788,0.9989520016767973,0.5804195804195804,0.9875673435557397,0.9431818181818182,83,4766,5,60
|
||||
LGBM_KMEANS_SMOTE,train,0.98818499917456,0.8760126546483045,0.8399085858018045,0.8198112125154727,0.9598771423117448,0.9939444918466291,0.75808081744998,0.9967070984144822,0.6831100731891271,0.9985576102161305,0.6410648148148148,0.9893749043238158,0.9303793802996736,774,27036,2,34
|
||||
LGBM_KMEANS_SMOTE,test,0.9867724867724867,0.8559208843672739,0.8127209992015513,0.7896857910481889,0.965374580868779,0.9932270501198291,0.7186147186147186,0.9966541196152238,0.6287878787878788,0.9989520016767973,0.5804195804195804,0.9875673435557397,0.9431818181818182,83,4766,5,60
|
||||
LGBM_SVM_SMOTE,train,0.98818499917456,0.8760126546483045,0.8399085858018045,0.8198112125154727,0.9598771423117448,0.9939444918466291,0.75808081744998,0.9967070984144822,0.6831100731891271,0.9985576102161305,0.6410648148148148,0.9893749043238158,0.9303793802996736,774,27036,2,34
|
||||
LGBM_SVM_SMOTE,test,0.9867724867724867,0.8559208843672739,0.8127209992015513,0.7896857910481889,0.965374580868779,0.9932270501198291,0.7186147186147186,0.9966541196152238,0.6287878787878788,0.9989520016767973,0.5804195804195804,0.9875673435557397,0.9431818181818182,83,4766,5,60
|
||||
LGBM_BORDERLINE_SMOTE,train,0.98818499917456,0.8760126546483045,0.8399085858018045,0.8198112125154727,0.9598771423117448,0.9939444918466291,0.75808081744998,0.9967070984144822,0.6831100731891271,0.9985576102161305,0.6410648148148148,0.9893749043238158,0.9303793802996736,774,27036,2,34
|
||||
LGBM_BORDERLINE_SMOTE,test,0.9867724867724867,0.8559208843672739,0.8127209992015513,0.7896857910481889,0.965374580868779,0.9932270501198291,0.7186147186147186,0.9966541196152238,0.6287878787878788,0.9989520016767973,0.5804195804195804,0.9875673435557397,0.9431818181818182,83,4766,5,60
|
||||
LGBM_ADASYN_SMOTE,train,0.98818499917456,0.8760126546483045,0.8399085858018045,0.8198112125154727,0.9598771423117448,0.9939444918466291,0.75808081744998,0.9967070984144822,0.6831100731891271,0.9985576102161305,0.6410648148148148,0.9893749043238158,0.9303793802996736,774,27036,2,34
|
||||
LGBM_ADASYN_SMOTE,test,0.9867724867724867,0.8559208843672739,0.8127209992015513,0.7896857910481889,0.965374580868779,0.9932270501198291,0.7186147186147186,0.9966541196152238,0.6287878787878788,0.9989520016767973,0.5804195804195804,0.9875673435557397,0.9431818181818182,83,4766,5,60
|
||||
LGBM_Balanced,train,0.9880772921438743,0.8819085772805074,0.8575020546207861,0.843178069610239,0.9330261074557831,0.9938798529106704,0.7699373016503442,0.9957519348782675,0.7192521743633049,0.9970042873686262,0.6893518518518518,0.990776780581843,0.8752754343297229,779,27033,5,29
|
||||
LGBM_Balanced,test,0.9861619861619861,0.8571081584060867,0.8246557274872268,0.8063299098721441,0.9299298722724962,0.9929048414023373,0.7213114754098361,0.9955224505168013,0.6537890044576523,0.997275204359673,0.6153846153846154,0.9885726158321213,0.8712871287128713,88,4758,13,55
|
||||
LGBM_DART,train,0.9880413597061434,0.8747819377304309,0.8394311701094301,0.8197295388100121,0.956851692191424,0.9938704735789482,0.7556934018819133,0.9965888563780074,0.6822734838408528,0.9984096949039746,0.6410493827160494,0.989373835326035,0.9243295490568126,685,27022,16,123
|
||||
LGBM_DART,test,0.9871794871794872,0.8660089064245453,0.8302067070505215,0.8102456126979287,0.9484753203607024,0.9934286012308334,0.7385892116182573,0.996234309623431,0.664179104477612,0.9981136030182352,0.6223776223776224,0.9887873754152824,0.9081632653061225,89,4762,9,54
|
||||
LGBM_GOSS,train,0.988185050764564,0.8767619724826792,0.8413979145609835,0.8216153002029628,0.9579916919528333,0.9939435911838446,0.7595803537815136,0.9966404297714536,0.6861553993505135,0.9984466497886416,0.644783950617284,0.9894818584294637,0.9265015254762032,775,27035,3,33
|
||||
LGBM_GOSS,test,0.9871794871794872,0.8615239070778549,0.8193657257149385,0.7966787980411958,0.9662106135986732,0.9934340802501302,0.7296137339055794,0.9967374937259494,0.6419939577039275,0.9989520016767973,0.5944055944055944,0.9879767827529021,0.9444444444444444,85,4766,5,58
|
||||
LGBM_RF,train,0.9858507062671537,0.8481335782755919,0.8099301710699093,0.7891912188845618,0.9416024225560872,0.99275226985963,0.7035148866915535,0.9958737030582172,0.6239866390816016,0.9979657711024567,0.5804166666666666,0.987593854471584,0.8956109906405902,507,26993,45,301
|
||||
LGBM_RF,test,0.9855514855514855,0.8465095837474215,0.809386193422338,0.7890569920542673,0.9352652953120861,0.9925972265665728,0.70042194092827,0.9956492637215528,0.6231231231231231,0.9976944036889541,0.5804195804195804,0.9875518672199171,0.8829787234042553,83,4760,11,60
|
||||
LGBM_scale_pos_weight,train,0.9883646355682123,0.8847082787258997,0.8597348514523183,0.8451146678749734,0.9373215094285751,0.9940273787240415,0.7753891787277578,0.9959218264040318,0.7235478765006047,0.9971892122931566,0.6930401234567901,0.9908875151444739,0.8837555037126762,783,27033,5,25
|
||||
LGBM_scale_pos_weight,test,0.9865689865689866,0.8624097011164226,0.8311362209667295,0.8133229168651512,0.9313849935322168,0.9931120851596744,0.7317073170731707,0.9956057752667922,0.6666666666666666,0.997275204359673,0.6293706293706294,0.9889835792974434,0.8737864077669902,90,4758,13,53
|
||||
LGBM_is_unbalance,train,0.9882927448977485,0.8841054144736109,0.8598757937903134,0.8456764784528097,0.9349710933550665,0.9939903291217178,0.774220499825504,0.9958405185143574,0.7239110690662693,0.9970782655475945,0.6942746913580247,0.9909237024663042,0.8790184842438284,779,27033,5,29
|
||||
LGBM_is_unbalance,test,0.9867724867724867,0.863955516465169,0.8317145687693159,0.8134277166974715,0.9356694049190748,0.9932171553793175,0.7346938775510204,0.9957733511884834,0.6676557863501483,0.9974848040243136,0.6293706293706294,0.9889858686616791,0.8823529411764706,90,4759,12,53
|
||||
XGB_scale_pos_weight,train,0.9881131600941002,0.8868996060146659,0.8705466767428707,0.8606000987409465,0.9188644360910505,0.9938914211334648,0.7799077908958669,0.9951365452837685,0.7459568082019727,0.9959687160004116,0.7252314814814815,0.9918245196483149,0.8459043525337859,781,27028,10,27
|
||||
XGB_scale_pos_weight,test,0.9839234839234839,0.8421764514805791,0.8211201077315871,0.8085688153808045,0.8847258771929825,0.9917458990701076,0.6926070038910506,0.9935521688159438,0.6486880466472303,0.9947600083839866,0.6223776223776224,0.98875,0.7807017543859649,89,4746,25,54
|
||||
CatBoost_balanced,train,0.9843784049402589,0.8696686267343388,0.8824472728294012,0.8916952848998795,0.8508242781484853,0.9919396338322237,0.7473976196364541,0.9908276010500254,0.7740669446087769,0.9900881006639566,0.7933024691358025,0.9938004847319636,0.7078480715650071,789,26898,140,19
|
||||
CatBoost_balanced,test,0.9802604802604803,0.8348421298822796,0.8461546793313885,0.8541662696976049,0.8176680164072361,0.9898162729658793,0.6798679867986799,0.988757446094471,0.703551912568306,0.9880528191154894,0.7202797202797203,0.991586032814472,0.64375,103,4714,57,40
|
||||
RandomForest_balanced,train,0.9780219876596711,0.69482483029087,0.6476767204797605,0.6284046291958619,0.9637732327782711,0.988805060981185,0.400844599600555,0.9952275898084167,0.3001258511511045,0.9995561719719707,0.2572530864197531,0.9782843830335899,0.9492620825229521,749,27037,1,59
|
||||
RandomForest_balanced,test,0.9774114774114774,0.6876170818855889,0.6409384493281152,0.6220632228806615,0.9494516644359052,0.9884938322794651,0.3867403314917127,0.994991652754591,0.28688524590163933,0.9993712010060783,0.24475524475524477,0.977850697292863,0.9210526315789473,35,4768,3,108
|
||||
BalancedRandomForest,train,0.939704118430013,0.7028311098818047,0.7781778852496469,0.8752773234871295,0.6471027486545569,0.9681428255443214,0.43751939421928815,0.9533096243066284,0.6030461461926654,0.9436719309248763,0.8068827160493829,0.9939232551592664,0.30028224214984744,790,25662,1376,18
|
||||
BalancedRandomForest,test,0.9318274318274318,0.6688915684702886,0.7336858411761832,0.8190487986128313,0.622796487015859,0.9639513612396428,0.37383177570093457,0.9486994831822418,0.5186721991701245,0.9387968979249633,0.6993006993006993,0.9904909332153914,0.25510204081632654,100,4479,292,43
|
||||
LogisticRegression_balanced,train,0.9390578117583936,0.7053593691260407,0.7842568337517993,0.888164180680811,0.6482233216447624,0.9677636915743127,0.44295504667776875,0.9522564007253292,0.6162572667782692,0.9421925588924864,0.8341358024691358,0.9947683858648215,0.3016782574247032,745,25582,1456,63
|
||||
LogisticRegression_balanced,test,0.9356939356939357,0.68489389075354,0.7544562768745361,0.844781921076199,0.6342667509156945,0.9660141966014196,0.4037735849056604,0.9510397695989158,0.5578727841501564,0.9413120939006497,0.7482517482517482,0.9920477137176938,0.27648578811369506,107,4491,280,36
|
||||
LGBM_FOCAL_LOSS,train,0.9874308822922471,0.8637586787288788,0.8216344627004155,0.7989945052063886,0.969041715035573,0.9935632802546319,0.7339540772031252,0.9968410405554005,0.6464278848454306,0.9990383931288267,0.5989506172839506,0.9881485197433785,0.9499349103277674,774,27036,2,34
|
||||
LGBM_FOCAL_LOSS,test,0.9855514855514855,0.838525460793502,0.7925445481202558,0.7687067700691679,0.9626827249232757,0.9926064771425596,0.6844444444444444,0.9964040809499917,0.5886850152905199,0.9989520016767973,0.5384615384615384,0.9863410596026491,0.9390243902439024,77,4766,5,66
|
||||
|
318
runner.py
Normal file
318
runner.py
Normal file
@@ -0,0 +1,318 @@
|
||||
import pandas
|
||||
from catboost import CatBoostClassifier
|
||||
from imblearn.ensemble import BalancedRandomForestClassifier
|
||||
from lightgbm import LGBMClassifier
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import confusion_matrix
|
||||
from sklearn.model_selection import train_test_split
|
||||
from xgboost import XGBClassifier
|
||||
|
||||
from custom_models.LGBMFocalWrapper import LGBMFocalWrapper
|
||||
from train import test_model, train_model_with_kfold
|
||||
|
||||
# Load the cleaned dataset and separate the features from the binary target.
data_frame = pandas.read_csv("./data/Ketamin_icp_cleaned.csv")
y = data_frame["label"]
X = data_frame.drop(columns=["label"])

# Stratified hold-out split: 15% of the rows are kept aside for the final test.
x_train, x_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.15,
    stratify=y,
    random_state=42,
)

# Negative/positive ratio of the training split, used as the positive-class
# weight by the cost-sensitive models below.
neg = int((y_train == 0).sum())
pos = int((y_train == 1).sum())
scale_pos = neg / pos if pos > 0 else 1.0

# Row/column sub-sampling shared by most (not all) LightGBM configurations.
_ROW_COL_SAMPLING = {"subsample": 0.8, "colsample_bytree": 0.8}


def _lgbm(**overrides):
    """Build an LGBMClassifier from the settings common to every LightGBM run.

    Keyword arguments extend or override the shared defaults.
    """
    params = {
        "n_estimators": 500,
        "learning_rate": 0.05,
        "max_depth": -1,
        "random_state": 42,
        "verbose": -1,
        "n_jobs": -1,
    }
    params.update(overrides)
    return LGBMClassifier(**params)


# One entry per experiment. "smote" toggles over-sampling inside each CV fold
# and "smote_method" selects the sampler (keys understood by
# train_model_with_kfold: kmeans, smote, svmsmote, borderline, adasyn).
models = [
    {
        "name": "LGBM_FOCAL_LOSS",
        "model": LGBMFocalWrapper(
            n_estimators=500,
            learning_rate=0.05,
            max_depth=-1,
            subsample=0.8,
            colsample_bytree=0.8,
            random_state=42,
        ),
        "smote": True,
        "smote_method": "kmeans",
    },
    {
        "name": "LGBM_SMOTE",
        "model": _lgbm(**_ROW_COL_SAMPLING),
        "smote": True,
        "smote_method": "smote",
    },
    {
        "name": "LGBM_KMEANS_SMOTE",
        "model": _lgbm(**_ROW_COL_SAMPLING),
        "smote": True,
        "smote_method": "kmeans",
    },
    {
        "name": "LGBM_SVM_SMOTE",
        "model": _lgbm(**_ROW_COL_SAMPLING),
        "smote": True,
        # Was "svm", which train_model_with_kfold does not recognise.
        "smote_method": "svmsmote",
    },
    {
        "name": "LGBM_BORDERLINE_SMOTE",
        "model": _lgbm(**_ROW_COL_SAMPLING),
        "smote": True,
        "smote_method": "borderline",
    },
    {
        "name": "LGBM_ADASYN_SMOTE",
        "model": _lgbm(**_ROW_COL_SAMPLING),
        "smote": True,
        "smote_method": "adasyn",
    },
    {
        "name": "LGBM_Balanced",
        "model": _lgbm(class_weight="balanced", **_ROW_COL_SAMPLING),
        "smote": False,
    },
    {
        # This configuration was listed twice; it now runs only once.
        "name": "LGBM_DART",
        "model": _lgbm(boosting_type="dart", **_ROW_COL_SAMPLING),
        "smote": True,
        "smote_method": "kmeans",
    },
    {
        "name": "LGBM_GOSS",
        "model": _lgbm(boosting_type="goss"),
        "smote": True,
        "smote_method": "kmeans",
    },
    {
        "name": "LGBM_RF",
        "model": _lgbm(boosting_type="rf", **_ROW_COL_SAMPLING),
        "smote": True,
        "smote_method": "kmeans",
    },
    {
        "name": "LGBM_scale_pos_weight",
        "model": _lgbm(scale_pos_weight=scale_pos),
        "smote": False,
    },
    {
        "name": "LGBM_is_unbalance",
        "model": _lgbm(is_unbalance=True),
        "smote": False,
    },
    {
        "name": "XGB_scale_pos_weight",
        "model": XGBClassifier(
            n_estimators=500,
            learning_rate=0.05,
            max_depth=6,
            scale_pos_weight=scale_pos,
            random_state=42,
            n_jobs=-1,
            use_label_encoder=False,
            eval_metric="logloss",
        ),
        "smote": False,
    },
    {
        "name": "CatBoost_balanced",
        "model": CatBoostClassifier(
            iterations=500,
            learning_rate=0.05,
            depth=6,
            class_weights=[1, scale_pos],
            random_state=42,
            verbose=0,
        ),
        "smote": False,
    },
    {
        "name": "RandomForest_balanced",
        "model": RandomForestClassifier(
            n_estimators=500,
            max_depth=None,
            class_weight="balanced",
            random_state=42,
            n_jobs=-1,
        ),
        "smote": False,
    },
    {
        "name": "BalancedRandomForest",
        "model": BalancedRandomForestClassifier(
            n_estimators=500,
            max_depth=None,
            random_state=42,
            n_jobs=-1,
        ),
        "smote": False,
    },
    {
        "name": "LogisticRegression_balanced",
        "model": LogisticRegression(
            max_iter=1000,
            class_weight="balanced",
            solver="liblinear",
            random_state=42,
        ),
        "smote": False,
    },
]


def compute_confusion(y_true, y_pred):
    """Return the binary confusion-matrix counts as a TP/TN/FP/FN dict.

    ``labels=[0, 1]`` forces a 2x2 matrix so the four-way unpacking still works
    when one of the classes is missing from ``y_true`` and ``y_pred``.
    """
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    return {"TP": tp, "TN": tn, "FP": fp, "FN": fn}


results_to_save = []

for m in models:
    print(f"\n===== Training model: {m['name']} =====")

    # Forward the configured sampler; previously it was dropped, so every
    # SMOTE variant silently fell back to the "kmeans" default.
    train_results = train_model_with_kfold(
        m["model"],
        x_train,
        y_train,
        n_splits=10,
        smote=m["smote"],
        smote_method=m.get("smote_method", "kmeans"),
    )

    # The "train" metrics are cross-validation means, whereas this confusion
    # matrix comes from the model as left fitted by the last fold, applied to
    # the whole training split.
    y_train_pred = m["model"].predict(x_train)
    train_confusion = compute_confusion(y_train, y_train_pred)

    test_results = test_model(m["model"], x_test, y_test)
    y_test_pred = m["model"].predict(x_test)
    test_confusion = compute_confusion(y_test, y_test_pred)

    results_to_save.append(
        {"model": m["name"], "stage": "train", **train_results, **train_confusion}
    )
    results_to_save.append(
        {"model": m["name"], "stage": "test", **test_results, **test_confusion}
    )

results_df = pandas.DataFrame(results_to_save)
csv_file = "lightgbm_results.csv"

# Append to an existing results file, but emit the header row when the file is
# new or empty. Opening in append mode never raises FileNotFoundError, so the
# header has to be decided by probing the file explicitly.
try:
    with open(csv_file) as existing:
        needs_header = existing.read(1) == ""
except FileNotFoundError:
    needs_header = True

results_df.to_csv(csv_file, mode="a", index=False, header=needs_header)

print(f"\nAll results saved to {csv_file}")
|
||||
139
train.py
Normal file
139
train.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import numpy
|
||||
import tqdm
|
||||
from imblearn.over_sampling import ADASYN, SMOTE, SVMSMOTE, BorderlineSMOTE, KMeansSMOTE
|
||||
from sklearn.metrics import (
|
||||
accuracy_score,
|
||||
f1_score,
|
||||
fbeta_score,
|
||||
precision_score,
|
||||
recall_score,
|
||||
)
|
||||
from sklearn.model_selection import StratifiedKFold
|
||||
|
||||
|
||||
def _build_sampler(smote_method, random_state):
    """Return a fresh over-sampler for ``smote_method`` (case-insensitive).

    Accepted keys: "kmeans", "smote", "svmsmote" (alias "svm"), "borderline"
    and "adasyn". Raises ValueError for anything else.
    """
    key = smote_method.lower()
    if key == "kmeans":
        return KMeansSMOTE(
            k_neighbors=5,
            cluster_balance_threshold=0.1,
            random_state=random_state,
        )
    if key == "smote":
        return SMOTE(k_neighbors=5, random_state=random_state)
    if key in ("svm", "svmsmote"):
        return SVMSMOTE(k_neighbors=5, random_state=random_state)
    if key == "borderline":
        return BorderlineSMOTE(k_neighbors=5, random_state=random_state)
    if key == "adasyn":
        return ADASYN(n_neighbors=5, random_state=random_state)
    raise ValueError(f"Unknown smote_method: {smote_method}")


def _fold_scores(y_true, y_pred):
    """Compute accuracy plus macro and per-class F1/F2/recall/precision."""
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1_macro": f1_score(y_true, y_pred, average="macro"),
        "f2_macro": fbeta_score(y_true, y_pred, beta=2, average="macro"),
        "recall_macro": recall_score(y_true, y_pred, average="macro"),
        "precision_macro": precision_score(y_true, y_pred, average="macro"),
        "f1_class0": f1_score(y_true, y_pred, pos_label=0),
        "f1_class1": f1_score(y_true, y_pred, pos_label=1),
        "f2_class0": fbeta_score(y_true, y_pred, beta=2, pos_label=0),
        "f2_class1": fbeta_score(y_true, y_pred, beta=2, pos_label=1),
        "recall_class0": recall_score(y_true, y_pred, pos_label=0),
        "recall_class1": recall_score(y_true, y_pred, pos_label=1),
        "precision_class0": precision_score(y_true, y_pred, pos_label=0),
        "precision_class1": precision_score(y_true, y_pred, pos_label=1),
    }


def train_model_with_kfold(
    model, X, y, n_splits=5, random_state=42, smote=True, smote_method="kmeans"
):
    """Cross-validate ``model`` with stratified k-fold and return mean metrics.

    For every fold the model is refitted in place on the training part
    (optionally over-sampled first) and scored on the untouched validation
    part. The model is therefore left fitted on the last fold's training data
    when the function returns.

    Args:
        model: Estimator exposing ``fit`` and ``predict``.
        X: Feature table, indexed positionally with ``.iloc``.
        y: Binary target (classes 0 and 1), indexed positionally with ``.iloc``.
        n_splits: Number of stratified folds.
        random_state: Seed for the fold shuffling and for the sampler.
        smote: Whether to over-sample each fold's training part.
        smote_method: Sampler key, see ``_build_sampler``; ignored when
            ``smote`` is false.

    Returns:
        Dict mapping each metric name to its mean over the folds.

    Raises:
        ValueError: If ``smote`` is true and ``smote_method`` is unknown.
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    per_fold = []

    for train_idx, val_idx in tqdm.tqdm(
        skf.split(X, y), total=skf.n_splits, desc="Training Folds"
    ):
        X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
        y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]

        if smote:
            # Only the training part is resampled, so the validation part
            # keeps the original class distribution.
            sampler = _build_sampler(smote_method, random_state)
            X_train, y_train = sampler.fit_resample(X_train, y_train)

        model.fit(X_train, y_train)
        per_fold.append(_fold_scores(y_val, model.predict(X_val)))

    # Average every metric across folds, preserving the metric order.
    return {
        name: numpy.mean([scores[name] for scores in per_fold])
        for name in per_fold[0]
    }
|
||||
|
||||
|
||||
def test_model(model, X_test, y_test):
    """Score an already fitted ``model`` on a held-out set.

    Returns a dict with accuracy, the macro-averaged F1/F2/recall/precision,
    and the same four metrics computed separately for class 0 and class 1.
    """
    predictions = model.predict(X_test)

    def f2_score(y_true, y_hat, **options):
        # F-beta with beta=2.
        return fbeta_score(y_true, y_hat, beta=2, **options)

    scorers = (
        ("f1", f1_score),
        ("f2", f2_score),
        ("recall", recall_score),
        ("precision", precision_score),
    )

    report = {"accuracy": accuracy_score(y_test, predictions)}

    # Macro averages first, then the per-class values, matching the column
    # order of the results table.
    for prefix, scorer in scorers:
        report[f"{prefix}_macro"] = scorer(y_test, predictions, average="macro")

    for prefix, scorer in scorers:
        for cls in (0, 1):
            report[f"{prefix}_class{cls}"] = scorer(
                y_test, predictions, pos_label=cls
            )

    return report
|
||||
62
utils.py
Normal file
62
utils.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Saeed Khosravi - 27 Nov 2025
|
||||
"""
|
||||
|
||||
|
||||
def split_path(full_path):
    """Split ``full_path`` into (directory, file name without its last extension)."""
    import os

    parent, leaf = os.path.split(full_path)
    stem, _extension = os.path.splitext(leaf)
    return parent, stem
|
||||
|
||||
|
||||
def write_textfile(path, data_list):
    """Write every item of ``data_list`` to ``path``, one per line.

    The file is overwritten. Each item is formatted with ``str`` semantics and
    followed by a space and a newline.
    """
    with open(path, "w") as handle:
        handle.writelines(f"{item} \n" for item in data_list)
|
||||
|
||||
|
||||
def missing_value_handler(data_path):
    """Load a CSV, drop its ``id`` column and fill missing values with KNN.

    Three diagnostic text files are written next to the input file:
    ``columns.txt`` (column names), ``missing.txt`` (per-column missing counts
    before imputation) and ``no_missing.txt`` (the same counts afterwards).

    Args:
        data_path: Path of the CSV file to read.

    Returns:
        A new DataFrame with the same columns (minus ``id``) and the values
        produced by ``KNNImputer(n_neighbors=5)``.
    """
    import os

    import pandas
    from sklearn.impute import KNNImputer

    data_directory, _ = split_path(data_path)

    data_frame = pandas.read_csv(data_path)

    # remove column id
    if "id" in data_frame.columns:
        data_frame = data_frame.drop("id", axis="columns")

    columns = list(data_frame.columns)
    # os.path.join keeps the reports beside the input even when data_path has
    # no directory part; a plain "/" join would then point at the filesystem
    # root.
    write_textfile(os.path.join(data_directory, "columns.txt"), columns)

    # find missing values
    missing_value_counts = data_frame.isna().sum()
    write_textfile(os.path.join(data_directory, "missing.txt"), missing_value_counts)

    # fill missing values - KNNImputer
    # NOTE(review): every remaining column, including the label if present, is
    # passed to the imputer - confirm that this is intended.
    imputer = KNNImputer(n_neighbors=5)
    data_imputed = imputer.fit_transform(data_frame)
    data_frame_imputed = pandas.DataFrame(data_imputed, columns=columns)

    missing_value_counts = data_frame_imputed.isna().sum()
    write_textfile(
        os.path.join(data_directory, "no_missing.txt"), missing_value_counts
    )
    return data_frame_imputed
|
||||
|
||||
|
||||
def scaling_handler(data_frame, method="robust_scaling"):
    """Scale every feature column of ``data_frame`` and keep ``label`` as is.

    Args:
        data_frame: DataFrame containing the feature columns and a ``label``
            column.
        method: Scaling strategy; only ``"robust_scaling"`` is supported.

    Returns:
        A new DataFrame with the scaled features followed by the original
        ``label`` values. The scaler is fitted on the whole frame passed in.

    Raises:
        ValueError: If ``method`` is not a supported strategy. Previously the
            function silently returned ``None`` in that case.
    """
    if method != "robust_scaling":
        raise ValueError(f"Unknown scaling method: {method}")

    import pandas
    from sklearn.preprocessing import RobustScaler

    labels = data_frame["label"]
    features = data_frame.drop("label", axis=1)

    scaled_values = RobustScaler().fit_transform(features)
    data_frame_scaled = pandas.DataFrame(scaled_values, columns=features.columns)
    # Assign by position: the scaled frame has a fresh 0..n-1 index.
    data_frame_scaled["label"] = labels.values
    return data_frame_scaled
|
||||
Reference in New Issue
Block a user