comparing models

This commit is contained in:
2025-11-30 23:02:25 +01:00
parent 602cf15dc6
commit db2c405bac
7 changed files with 580 additions and 0 deletions

BIN
.DS_Store vendored Normal file

Binary file not shown.

1
.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
data/

25
data.py Normal file
View File

@@ -0,0 +1,25 @@
"""
Saeed Khosravi - 26 Nov 2025
"""
import os
import pandas
from train import test_model, train_model_with_kfold
from utils import missing_value_handler, scaling_handler
# STEP 1: handle missing values + remove id column + robust scaling
if not os.path.exists("./data/Ketamin_icp_cleaned.csv"):
data_path = "./data/Ketamine_icp.csv"
data_frame_imputed = missing_value_handler(data_path)
data_frame = scaling_handler(data_frame_imputed, method="robust_scaling")
data_frame.to_csv("./data/Ketamin_icp_cleaned.csv", index=False)
else:
data_frame = pandas.read_csv("./data/Ketamin_icp_cleaned.csv")
# STEP 2: Train and Test models

35
results.csv Normal file
View File

@@ -0,0 +1,35 @@
model,stage,accuracy,f1_macro,f2_macro,recall_macro,precision_macro,f1_class0,f1_class1,f2_class0,f2_class1,recall_class0,recall_class1,precision_class0,precision_class1,TP,TN,FP,FN
LGBM_SMOTE,train,0.98818499917456,0.8760126546483045,0.8399085858018045,0.8198112125154727,0.9598771423117448,0.9939444918466291,0.75808081744998,0.9967070984144822,0.6831100731891271,0.9985576102161305,0.6410648148148148,0.9893749043238158,0.9303793802996736,774,27036,2,34
LGBM_SMOTE,test,0.9867724867724867,0.8559208843672739,0.8127209992015513,0.7896857910481889,0.965374580868779,0.9932270501198291,0.7186147186147186,0.9966541196152238,0.6287878787878788,0.9989520016767973,0.5804195804195804,0.9875673435557397,0.9431818181818182,83,4766,5,60
LGBM_KMEANS_SMOTE,train,0.98818499917456,0.8760126546483045,0.8399085858018045,0.8198112125154727,0.9598771423117448,0.9939444918466291,0.75808081744998,0.9967070984144822,0.6831100731891271,0.9985576102161305,0.6410648148148148,0.9893749043238158,0.9303793802996736,774,27036,2,34
LGBM_KMEANS_SMOTE,test,0.9867724867724867,0.8559208843672739,0.8127209992015513,0.7896857910481889,0.965374580868779,0.9932270501198291,0.7186147186147186,0.9966541196152238,0.6287878787878788,0.9989520016767973,0.5804195804195804,0.9875673435557397,0.9431818181818182,83,4766,5,60
LGBM_SVM_SMOTE,train,0.98818499917456,0.8760126546483045,0.8399085858018045,0.8198112125154727,0.9598771423117448,0.9939444918466291,0.75808081744998,0.9967070984144822,0.6831100731891271,0.9985576102161305,0.6410648148148148,0.9893749043238158,0.9303793802996736,774,27036,2,34
LGBM_SVM_SMOTE,test,0.9867724867724867,0.8559208843672739,0.8127209992015513,0.7896857910481889,0.965374580868779,0.9932270501198291,0.7186147186147186,0.9966541196152238,0.6287878787878788,0.9989520016767973,0.5804195804195804,0.9875673435557397,0.9431818181818182,83,4766,5,60
LGBM_BORDERLINE_SMOTE,train,0.98818499917456,0.8760126546483045,0.8399085858018045,0.8198112125154727,0.9598771423117448,0.9939444918466291,0.75808081744998,0.9967070984144822,0.6831100731891271,0.9985576102161305,0.6410648148148148,0.9893749043238158,0.9303793802996736,774,27036,2,34
LGBM_BORDERLINE_SMOTE,test,0.9867724867724867,0.8559208843672739,0.8127209992015513,0.7896857910481889,0.965374580868779,0.9932270501198291,0.7186147186147186,0.9966541196152238,0.6287878787878788,0.9989520016767973,0.5804195804195804,0.9875673435557397,0.9431818181818182,83,4766,5,60
LGBM_ADASYN_SMOTE,train,0.98818499917456,0.8760126546483045,0.8399085858018045,0.8198112125154727,0.9598771423117448,0.9939444918466291,0.75808081744998,0.9967070984144822,0.6831100731891271,0.9985576102161305,0.6410648148148148,0.9893749043238158,0.9303793802996736,774,27036,2,34
LGBM_ADASYN_SMOTE,test,0.9867724867724867,0.8559208843672739,0.8127209992015513,0.7896857910481889,0.965374580868779,0.9932270501198291,0.7186147186147186,0.9966541196152238,0.6287878787878788,0.9989520016767973,0.5804195804195804,0.9875673435557397,0.9431818181818182,83,4766,5,60
LGBM_Balanced,train,0.9880772921438743,0.8819085772805074,0.8575020546207861,0.843178069610239,0.9330261074557831,0.9938798529106704,0.7699373016503442,0.9957519348782675,0.7192521743633049,0.9970042873686262,0.6893518518518518,0.990776780581843,0.8752754343297229,779,27033,5,29
LGBM_Balanced,test,0.9861619861619861,0.8571081584060867,0.8246557274872268,0.8063299098721441,0.9299298722724962,0.9929048414023373,0.7213114754098361,0.9955224505168013,0.6537890044576523,0.997275204359673,0.6153846153846154,0.9885726158321213,0.8712871287128713,88,4758,13,55
LGBM_DART,train,0.9880413597061434,0.8747819377304309,0.8394311701094301,0.8197295388100121,0.956851692191424,0.9938704735789482,0.7556934018819133,0.9965888563780074,0.6822734838408528,0.9984096949039746,0.6410493827160494,0.989373835326035,0.9243295490568126,685,27022,16,123
LGBM_DART,test,0.9871794871794872,0.8660089064245453,0.8302067070505215,0.8102456126979287,0.9484753203607024,0.9934286012308334,0.7385892116182573,0.996234309623431,0.664179104477612,0.9981136030182352,0.6223776223776224,0.9887873754152824,0.9081632653061225,89,4762,9,54
LGBM_GOSS,train,0.988185050764564,0.8767619724826792,0.8413979145609835,0.8216153002029628,0.9579916919528333,0.9939435911838446,0.7595803537815136,0.9966404297714536,0.6861553993505135,0.9984466497886416,0.644783950617284,0.9894818584294637,0.9265015254762032,775,27035,3,33
LGBM_GOSS,test,0.9871794871794872,0.8615239070778549,0.8193657257149385,0.7966787980411958,0.9662106135986732,0.9934340802501302,0.7296137339055794,0.9967374937259494,0.6419939577039275,0.9989520016767973,0.5944055944055944,0.9879767827529021,0.9444444444444444,85,4766,5,58
LGBM_RF,train,0.9858507062671537,0.8481335782755919,0.8099301710699093,0.7891912188845618,0.9416024225560872,0.99275226985963,0.7035148866915535,0.9958737030582172,0.6239866390816016,0.9979657711024567,0.5804166666666666,0.987593854471584,0.8956109906405902,507,26993,45,301
LGBM_RF,test,0.9855514855514855,0.8465095837474215,0.809386193422338,0.7890569920542673,0.9352652953120861,0.9925972265665728,0.70042194092827,0.9956492637215528,0.6231231231231231,0.9976944036889541,0.5804195804195804,0.9875518672199171,0.8829787234042553,83,4760,11,60
LGBM_scale_pos_weight,train,0.9883646355682123,0.8847082787258997,0.8597348514523183,0.8451146678749734,0.9373215094285751,0.9940273787240415,0.7753891787277578,0.9959218264040318,0.7235478765006047,0.9971892122931566,0.6930401234567901,0.9908875151444739,0.8837555037126762,783,27033,5,25
LGBM_scale_pos_weight,test,0.9865689865689866,0.8624097011164226,0.8311362209667295,0.8133229168651512,0.9313849935322168,0.9931120851596744,0.7317073170731707,0.9956057752667922,0.6666666666666666,0.997275204359673,0.6293706293706294,0.9889835792974434,0.8737864077669902,90,4758,13,53
LGBM_is_unbalance,train,0.9882927448977485,0.8841054144736109,0.8598757937903134,0.8456764784528097,0.9349710933550665,0.9939903291217178,0.774220499825504,0.9958405185143574,0.7239110690662693,0.9970782655475945,0.6942746913580247,0.9909237024663042,0.8790184842438284,779,27033,5,29
LGBM_is_unbalance,test,0.9867724867724867,0.863955516465169,0.8317145687693159,0.8134277166974715,0.9356694049190748,0.9932171553793175,0.7346938775510204,0.9957733511884834,0.6676557863501483,0.9974848040243136,0.6293706293706294,0.9889858686616791,0.8823529411764706,90,4759,12,53
XGB_scale_pos_weight,train,0.9881131600941002,0.8868996060146659,0.8705466767428707,0.8606000987409465,0.9188644360910505,0.9938914211334648,0.7799077908958669,0.9951365452837685,0.7459568082019727,0.9959687160004116,0.7252314814814815,0.9918245196483149,0.8459043525337859,781,27028,10,27
XGB_scale_pos_weight,test,0.9839234839234839,0.8421764514805791,0.8211201077315871,0.8085688153808045,0.8847258771929825,0.9917458990701076,0.6926070038910506,0.9935521688159438,0.6486880466472303,0.9947600083839866,0.6223776223776224,0.98875,0.7807017543859649,89,4746,25,54
CatBoost_balanced,train,0.9843784049402589,0.8696686267343388,0.8824472728294012,0.8916952848998795,0.8508242781484853,0.9919396338322237,0.7473976196364541,0.9908276010500254,0.7740669446087769,0.9900881006639566,0.7933024691358025,0.9938004847319636,0.7078480715650071,789,26898,140,19
CatBoost_balanced,test,0.9802604802604803,0.8348421298822796,0.8461546793313885,0.8541662696976049,0.8176680164072361,0.9898162729658793,0.6798679867986799,0.988757446094471,0.703551912568306,0.9880528191154894,0.7202797202797203,0.991586032814472,0.64375,103,4714,57,40
RandomForest_balanced,train,0.9780219876596711,0.69482483029087,0.6476767204797605,0.6284046291958619,0.9637732327782711,0.988805060981185,0.400844599600555,0.9952275898084167,0.3001258511511045,0.9995561719719707,0.2572530864197531,0.9782843830335899,0.9492620825229521,749,27037,1,59
RandomForest_balanced,test,0.9774114774114774,0.6876170818855889,0.6409384493281152,0.6220632228806615,0.9494516644359052,0.9884938322794651,0.3867403314917127,0.994991652754591,0.28688524590163933,0.9993712010060783,0.24475524475524477,0.977850697292863,0.9210526315789473,35,4768,3,108
BalancedRandomForest,train,0.939704118430013,0.7028311098818047,0.7781778852496469,0.8752773234871295,0.6471027486545569,0.9681428255443214,0.43751939421928815,0.9533096243066284,0.6030461461926654,0.9436719309248763,0.8068827160493829,0.9939232551592664,0.30028224214984744,790,25662,1376,18
BalancedRandomForest,test,0.9318274318274318,0.6688915684702886,0.7336858411761832,0.8190487986128313,0.622796487015859,0.9639513612396428,0.37383177570093457,0.9486994831822418,0.5186721991701245,0.9387968979249633,0.6993006993006993,0.9904909332153914,0.25510204081632654,100,4479,292,43
LogisticRegression_balanced,train,0.9390578117583936,0.7053593691260407,0.7842568337517993,0.888164180680811,0.6482233216447624,0.9677636915743127,0.44295504667776875,0.9522564007253292,0.6162572667782692,0.9421925588924864,0.8341358024691358,0.9947683858648215,0.3016782574247032,745,25582,1456,63
LogisticRegression_balanced,test,0.9356939356939357,0.68489389075354,0.7544562768745361,0.844781921076199,0.6342667509156945,0.9660141966014196,0.4037735849056604,0.9510397695989158,0.5578727841501564,0.9413120939006497,0.7482517482517482,0.9920477137176938,0.27648578811369506,107,4491,280,36
LGBM_FOCAL_LOSS,train,0.9874308822922471,0.8637586787288788,0.8216344627004155,0.7989945052063886,0.969041715035573,0.9935632802546319,0.7339540772031252,0.9968410405554005,0.6464278848454306,0.9990383931288267,0.5989506172839506,0.9881485197433785,0.9499349103277674,774,27036,2,34
LGBM_FOCAL_LOSS,test,0.9855514855514855,0.838525460793502,0.7925445481202558,0.7687067700691679,0.9626827249232757,0.9926064771425596,0.6844444444444444,0.9964040809499917,0.5886850152905199,0.9989520016767973,0.5384615384615384,0.9863410596026491,0.9390243902439024,77,4766,5,66
1 model stage accuracy f1_macro f2_macro recall_macro precision_macro f1_class0 f1_class1 f2_class0 f2_class1 recall_class0 recall_class1 precision_class0 precision_class1 TP TN FP FN
2 LGBM_SMOTE train 0.98818499917456 0.8760126546483045 0.8399085858018045 0.8198112125154727 0.9598771423117448 0.9939444918466291 0.75808081744998 0.9967070984144822 0.6831100731891271 0.9985576102161305 0.6410648148148148 0.9893749043238158 0.9303793802996736 774 27036 2 34
3 LGBM_SMOTE test 0.9867724867724867 0.8559208843672739 0.8127209992015513 0.7896857910481889 0.965374580868779 0.9932270501198291 0.7186147186147186 0.9966541196152238 0.6287878787878788 0.9989520016767973 0.5804195804195804 0.9875673435557397 0.9431818181818182 83 4766 5 60
4 LGBM_KMEANS_SMOTE train 0.98818499917456 0.8760126546483045 0.8399085858018045 0.8198112125154727 0.9598771423117448 0.9939444918466291 0.75808081744998 0.9967070984144822 0.6831100731891271 0.9985576102161305 0.6410648148148148 0.9893749043238158 0.9303793802996736 774 27036 2 34
5 LGBM_KMEANS_SMOTE test 0.9867724867724867 0.8559208843672739 0.8127209992015513 0.7896857910481889 0.965374580868779 0.9932270501198291 0.7186147186147186 0.9966541196152238 0.6287878787878788 0.9989520016767973 0.5804195804195804 0.9875673435557397 0.9431818181818182 83 4766 5 60
6 LGBM_SVM_SMOTE train 0.98818499917456 0.8760126546483045 0.8399085858018045 0.8198112125154727 0.9598771423117448 0.9939444918466291 0.75808081744998 0.9967070984144822 0.6831100731891271 0.9985576102161305 0.6410648148148148 0.9893749043238158 0.9303793802996736 774 27036 2 34
7 LGBM_SVM_SMOTE test 0.9867724867724867 0.8559208843672739 0.8127209992015513 0.7896857910481889 0.965374580868779 0.9932270501198291 0.7186147186147186 0.9966541196152238 0.6287878787878788 0.9989520016767973 0.5804195804195804 0.9875673435557397 0.9431818181818182 83 4766 5 60
8 LGBM_BORDERLINE_SMOTE train 0.98818499917456 0.8760126546483045 0.8399085858018045 0.8198112125154727 0.9598771423117448 0.9939444918466291 0.75808081744998 0.9967070984144822 0.6831100731891271 0.9985576102161305 0.6410648148148148 0.9893749043238158 0.9303793802996736 774 27036 2 34
9 LGBM_BORDERLINE_SMOTE test 0.9867724867724867 0.8559208843672739 0.8127209992015513 0.7896857910481889 0.965374580868779 0.9932270501198291 0.7186147186147186 0.9966541196152238 0.6287878787878788 0.9989520016767973 0.5804195804195804 0.9875673435557397 0.9431818181818182 83 4766 5 60
10 LGBM_ADASYN_SMOTE train 0.98818499917456 0.8760126546483045 0.8399085858018045 0.8198112125154727 0.9598771423117448 0.9939444918466291 0.75808081744998 0.9967070984144822 0.6831100731891271 0.9985576102161305 0.6410648148148148 0.9893749043238158 0.9303793802996736 774 27036 2 34
11 LGBM_ADASYN_SMOTE test 0.9867724867724867 0.8559208843672739 0.8127209992015513 0.7896857910481889 0.965374580868779 0.9932270501198291 0.7186147186147186 0.9966541196152238 0.6287878787878788 0.9989520016767973 0.5804195804195804 0.9875673435557397 0.9431818181818182 83 4766 5 60
12 LGBM_Balanced train 0.9880772921438743 0.8819085772805074 0.8575020546207861 0.843178069610239 0.9330261074557831 0.9938798529106704 0.7699373016503442 0.9957519348782675 0.7192521743633049 0.9970042873686262 0.6893518518518518 0.990776780581843 0.8752754343297229 779 27033 5 29
13 LGBM_Balanced test 0.9861619861619861 0.8571081584060867 0.8246557274872268 0.8063299098721441 0.9299298722724962 0.9929048414023373 0.7213114754098361 0.9955224505168013 0.6537890044576523 0.997275204359673 0.6153846153846154 0.9885726158321213 0.8712871287128713 88 4758 13 55
14 LGBM_DART train 0.9880413597061434 0.8747819377304309 0.8394311701094301 0.8197295388100121 0.956851692191424 0.9938704735789482 0.7556934018819133 0.9965888563780074 0.6822734838408528 0.9984096949039746 0.6410493827160494 0.989373835326035 0.9243295490568126 685 27022 16 123
15 LGBM_DART test 0.9871794871794872 0.8660089064245453 0.8302067070505215 0.8102456126979287 0.9484753203607024 0.9934286012308334 0.7385892116182573 0.996234309623431 0.664179104477612 0.9981136030182352 0.6223776223776224 0.9887873754152824 0.9081632653061225 89 4762 9 54
16 LGBM_GOSS train 0.988185050764564 0.8767619724826792 0.8413979145609835 0.8216153002029628 0.9579916919528333 0.9939435911838446 0.7595803537815136 0.9966404297714536 0.6861553993505135 0.9984466497886416 0.644783950617284 0.9894818584294637 0.9265015254762032 775 27035 3 33
17 LGBM_GOSS test 0.9871794871794872 0.8615239070778549 0.8193657257149385 0.7966787980411958 0.9662106135986732 0.9934340802501302 0.7296137339055794 0.9967374937259494 0.6419939577039275 0.9989520016767973 0.5944055944055944 0.9879767827529021 0.9444444444444444 85 4766 5 58
18 LGBM_RF train 0.9858507062671537 0.8481335782755919 0.8099301710699093 0.7891912188845618 0.9416024225560872 0.99275226985963 0.7035148866915535 0.9958737030582172 0.6239866390816016 0.9979657711024567 0.5804166666666666 0.987593854471584 0.8956109906405902 507 26993 45 301
19 LGBM_RF test 0.9855514855514855 0.8465095837474215 0.809386193422338 0.7890569920542673 0.9352652953120861 0.9925972265665728 0.70042194092827 0.9956492637215528 0.6231231231231231 0.9976944036889541 0.5804195804195804 0.9875518672199171 0.8829787234042553 83 4760 11 60
20 LGBM_scale_pos_weight train 0.9883646355682123 0.8847082787258997 0.8597348514523183 0.8451146678749734 0.9373215094285751 0.9940273787240415 0.7753891787277578 0.9959218264040318 0.7235478765006047 0.9971892122931566 0.6930401234567901 0.9908875151444739 0.8837555037126762 783 27033 5 25
21 LGBM_scale_pos_weight test 0.9865689865689866 0.8624097011164226 0.8311362209667295 0.8133229168651512 0.9313849935322168 0.9931120851596744 0.7317073170731707 0.9956057752667922 0.6666666666666666 0.997275204359673 0.6293706293706294 0.9889835792974434 0.8737864077669902 90 4758 13 53
22 LGBM_is_unbalance train 0.9882927448977485 0.8841054144736109 0.8598757937903134 0.8456764784528097 0.9349710933550665 0.9939903291217178 0.774220499825504 0.9958405185143574 0.7239110690662693 0.9970782655475945 0.6942746913580247 0.9909237024663042 0.8790184842438284 779 27033 5 29
23 LGBM_is_unbalance test 0.9867724867724867 0.863955516465169 0.8317145687693159 0.8134277166974715 0.9356694049190748 0.9932171553793175 0.7346938775510204 0.9957733511884834 0.6676557863501483 0.9974848040243136 0.6293706293706294 0.9889858686616791 0.8823529411764706 90 4759 12 53
24 XGB_scale_pos_weight train 0.9881131600941002 0.8868996060146659 0.8705466767428707 0.8606000987409465 0.9188644360910505 0.9938914211334648 0.7799077908958669 0.9951365452837685 0.7459568082019727 0.9959687160004116 0.7252314814814815 0.9918245196483149 0.8459043525337859 781 27028 10 27
25 XGB_scale_pos_weight test 0.9839234839234839 0.8421764514805791 0.8211201077315871 0.8085688153808045 0.8847258771929825 0.9917458990701076 0.6926070038910506 0.9935521688159438 0.6486880466472303 0.9947600083839866 0.6223776223776224 0.98875 0.7807017543859649 89 4746 25 54
26 CatBoost_balanced train 0.9843784049402589 0.8696686267343388 0.8824472728294012 0.8916952848998795 0.8508242781484853 0.9919396338322237 0.7473976196364541 0.9908276010500254 0.7740669446087769 0.9900881006639566 0.7933024691358025 0.9938004847319636 0.7078480715650071 789 26898 140 19
27 CatBoost_balanced test 0.9802604802604803 0.8348421298822796 0.8461546793313885 0.8541662696976049 0.8176680164072361 0.9898162729658793 0.6798679867986799 0.988757446094471 0.703551912568306 0.9880528191154894 0.7202797202797203 0.991586032814472 0.64375 103 4714 57 40
28 RandomForest_balanced train 0.9780219876596711 0.69482483029087 0.6476767204797605 0.6284046291958619 0.9637732327782711 0.988805060981185 0.400844599600555 0.9952275898084167 0.3001258511511045 0.9995561719719707 0.2572530864197531 0.9782843830335899 0.9492620825229521 749 27037 1 59
29 RandomForest_balanced test 0.9774114774114774 0.6876170818855889 0.6409384493281152 0.6220632228806615 0.9494516644359052 0.9884938322794651 0.3867403314917127 0.994991652754591 0.28688524590163933 0.9993712010060783 0.24475524475524477 0.977850697292863 0.9210526315789473 35 4768 3 108
30 BalancedRandomForest train 0.939704118430013 0.7028311098818047 0.7781778852496469 0.8752773234871295 0.6471027486545569 0.9681428255443214 0.43751939421928815 0.9533096243066284 0.6030461461926654 0.9436719309248763 0.8068827160493829 0.9939232551592664 0.30028224214984744 790 25662 1376 18
31 BalancedRandomForest test 0.9318274318274318 0.6688915684702886 0.7336858411761832 0.8190487986128313 0.622796487015859 0.9639513612396428 0.37383177570093457 0.9486994831822418 0.5186721991701245 0.9387968979249633 0.6993006993006993 0.9904909332153914 0.25510204081632654 100 4479 292 43
32 LogisticRegression_balanced train 0.9390578117583936 0.7053593691260407 0.7842568337517993 0.888164180680811 0.6482233216447624 0.9677636915743127 0.44295504667776875 0.9522564007253292 0.6162572667782692 0.9421925588924864 0.8341358024691358 0.9947683858648215 0.3016782574247032 745 25582 1456 63
33 LogisticRegression_balanced test 0.9356939356939357 0.68489389075354 0.7544562768745361 0.844781921076199 0.6342667509156945 0.9660141966014196 0.4037735849056604 0.9510397695989158 0.5578727841501564 0.9413120939006497 0.7482517482517482 0.9920477137176938 0.27648578811369506 107 4491 280 36
34 LGBM_FOCAL_LOSS train 0.9874308822922471 0.8637586787288788 0.8216344627004155 0.7989945052063886 0.969041715035573 0.9935632802546319 0.7339540772031252 0.9968410405554005 0.6464278848454306 0.9990383931288267 0.5989506172839506 0.9881485197433785 0.9499349103277674 774 27036 2 34
35 LGBM_FOCAL_LOSS test 0.9855514855514855 0.838525460793502 0.7925445481202558 0.7687067700691679 0.9626827249232757 0.9926064771425596 0.6844444444444444 0.9964040809499917 0.5886850152905199 0.9989520016767973 0.5384615384615384 0.9863410596026491 0.9390243902439024 77 4766 5 66

318
runner.py Normal file
View File

@@ -0,0 +1,318 @@
import pandas
from catboost import CatBoostClassifier
from imblearn.ensemble import BalancedRandomForestClassifier
from lightgbm import LGBMClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
from custom_models.LGBMFocalWrapper import LGBMFocalWrapper
from train import test_model, train_model_with_kfold
data_frame = pandas.read_csv("./data/Ketamin_icp_cleaned.csv")
y = data_frame["label"]
X = data_frame.drop(columns=["label"])
x_train, x_test, y_train, y_test = train_test_split(
X,
y,
test_size=0.15,
stratify=y,
random_state=42,
)
neg = sum(y_train == 0)
pos = sum(y_train == 1)
scale_pos = neg / pos if pos > 0 else 1.0
models = [
{
"name": "LGBM_FOCAL_LOSS",
"model": LGBMFocalWrapper(
n_estimators=500,
learning_rate=0.05,
max_depth=-1,
subsample=0.8,
colsample_bytree=0.8,
random_state=42,
),
"smote": True,
"smote_method": "kmeans",
},
{
"name": "LGBM_SMOTE",
"model": LGBMClassifier(
n_estimators=500,
learning_rate=0.05,
max_depth=-1,
subsample=0.8,
colsample_bytree=0.8,
random_state=42,
verbose=-1,
n_jobs=-1,
),
"smote": True,
"smote_method": "smote",
},
{
"name": "LGBM_KMEANS_SMOTE",
"model": LGBMClassifier(
n_estimators=500,
learning_rate=0.05,
max_depth=-1,
subsample=0.8,
colsample_bytree=0.8,
random_state=42,
verbose=-1,
n_jobs=-1,
),
"smote": True,
"smote_method": "kmeans",
},
{
"name": "LGBM_SVM_SMOTE",
"model": LGBMClassifier(
n_estimators=500,
learning_rate=0.05,
max_depth=-1,
subsample=0.8,
colsample_bytree=0.8,
random_state=42,
verbose=-1,
n_jobs=-1,
),
"smote": True,
"smote_method": "svm",
},
{
"name": "LGBM_BORDERLINE_SMOTE",
"model": LGBMClassifier(
n_estimators=500,
learning_rate=0.05,
max_depth=-1,
subsample=0.8,
colsample_bytree=0.8,
random_state=42,
verbose=-1,
n_jobs=-1,
),
"smote": True,
"smote_method": "borderline",
},
{
"name": "LGBM_ADASYN_SMOTE",
"model": LGBMClassifier(
n_estimators=500,
learning_rate=0.05,
max_depth=-1,
subsample=0.8,
colsample_bytree=0.8,
random_state=42,
verbose=-1,
n_jobs=-1,
),
"smote": True,
"smote_method": "adasyn",
},
{
"name": "LGBM_Balanced",
"model": LGBMClassifier(
n_estimators=500,
learning_rate=0.05,
max_depth=-1,
subsample=0.8,
colsample_bytree=0.8,
class_weight="balanced",
random_state=42,
verbose=-1,
n_jobs=-1,
),
"smote": False,
},
{
"name": "LGBM_DART",
"model": LGBMClassifier(
n_estimators=500,
learning_rate=0.05,
max_depth=-1,
subsample=0.8,
colsample_bytree=0.8,
boosting_type="dart",
random_state=42,
verbose=-1,
n_jobs=-1,
),
"smote": True,
"smote_method": "kmeans",
},
{
"name": "LGBM_GOSS",
"model": LGBMClassifier(
n_estimators=500,
learning_rate=0.05,
max_depth=-1,
boosting_type="goss",
random_state=42,
verbose=-1,
n_jobs=-1,
),
"smote": True,
"smote_method": "kmeans",
},
{
"name": "LGBM_RF",
"model": LGBMClassifier(
n_estimators=500,
learning_rate=0.05,
max_depth=-1,
boosting_type="rf",
subsample=0.8,
colsample_bytree=0.8,
random_state=42,
verbose=-1,
n_jobs=-1,
),
"smote": True,
"smote_method": "kmeans",
},
{
"name": "LGBM_scale_pos_weight",
"model": LGBMClassifier(
n_estimators=500,
learning_rate=0.05,
max_depth=-1,
scale_pos_weight=scale_pos,
random_state=42,
verbose=-1,
n_jobs=-1,
),
"smote": False,
},
{
"name": "LGBM_is_unbalance",
"model": LGBMClassifier(
n_estimators=500,
learning_rate=0.05,
max_depth=-1,
is_unbalance=True,
random_state=42,
verbose=-1,
n_jobs=-1,
),
"smote": False,
},
{
"name": "LGBM_DART",
"model": LGBMClassifier(
n_estimators=500,
learning_rate=0.05,
max_depth=-1,
subsample=0.8,
colsample_bytree=0.8,
boosting_type="dart",
random_state=42,
verbose=-1,
n_jobs=-1,
),
"smote": True,
"smote_method": "kmeans",
},
{
"name": "XGB_scale_pos_weight",
"model": XGBClassifier(
n_estimators=500,
learning_rate=0.05,
max_depth=6,
scale_pos_weight=scale_pos,
random_state=42,
n_jobs=-1,
use_label_encoder=False,
eval_metric="logloss",
),
"smote": False,
},
{
"name": "CatBoost_balanced",
"model": CatBoostClassifier(
iterations=500,
learning_rate=0.05,
depth=6,
class_weights=[1, scale_pos],
random_state=42,
verbose=0,
),
"smote": False,
},
{
"name": "RandomForest_balanced",
"model": RandomForestClassifier(
n_estimators=500,
max_depth=None,
class_weight="balanced",
random_state=42,
n_jobs=-1,
),
"smote": False,
},
{
"name": "BalancedRandomForest",
"model": BalancedRandomForestClassifier(
n_estimators=500,
max_depth=None,
random_state=42,
n_jobs=-1,
),
"smote": False,
},
{
"name": "LogisticRegression_balanced",
"model": LogisticRegression(
max_iter=1000,
class_weight="balanced",
solver="liblinear",
random_state=42,
),
"smote": False,
},
]
def compute_confusion(y_true, y_pred):
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
return {"TP": tp, "TN": tn, "FP": fp, "FN": fn}
results_to_save = []
for m in models:
print(f"\n===== Training model: {m['name']} =====")
train_results = train_model_with_kfold(
m["model"], x_train, y_train, n_splits=10, smote=m["smote"]
)
y_train_pred = m["model"].predict(x_train)
train_confusion = compute_confusion(y_train, y_train_pred)
test_results = test_model(m["model"], x_test, y_test)
y_test_pred = m["model"].predict(x_test)
test_confusion = compute_confusion(y_test, y_test_pred)
results_to_save.append(
{"model": m["name"], "stage": "train", **train_results, **train_confusion}
)
results_to_save.append(
{"model": m["name"], "stage": "test", **test_results, **test_confusion}
)
results_df = pandas.DataFrame(results_to_save)
csv_file = "lightgbm_results.csv"
try:
results_df.to_csv(csv_file, mode="a", index=False, header=False)
except FileNotFoundError:
results_df.to_csv(csv_file, mode="w", index=False)
print(f"\nAll results saved to {csv_file}")

139
train.py Normal file
View File

@@ -0,0 +1,139 @@
import numpy
import tqdm
from imblearn.over_sampling import ADASYN, SMOTE, SVMSMOTE, BorderlineSMOTE, KMeansSMOTE
from sklearn.metrics import (
accuracy_score,
f1_score,
fbeta_score,
precision_score,
recall_score,
)
from sklearn.model_selection import StratifiedKFold
def train_model_with_kfold(
model, X, y, n_splits=5, random_state=42, smote=True, smote_method="kmeans"
):
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
accuracy_list = []
f1_macro_list = []
f2_macro_list = []
recall_macro_list = []
precision_macro_list = []
# per-class lists
f1_class0 = []
f1_class1 = []
f2_class0 = []
f2_class1 = []
recall_class0 = []
recall_class1 = []
precision_class0 = []
precision_class1 = []
fold_num = 1
for fold_num, (train_idx, val_idx) in enumerate(
tqdm.tqdm(skf.split(X, y), total=skf.n_splits, desc="Training Folds"), start=1
):
X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]
if smote:
if smote_method.lower() == "kmeans":
sampler = KMeansSMOTE(
k_neighbors=5,
cluster_balance_threshold=0.1,
random_state=random_state,
)
elif smote_method.lower() == "smote":
sampler = SMOTE(k_neighbors=5, random_state=random_state)
elif smote_method.lower() == "svmsmote":
sampler = SVMSMOTE(k_neighbors=5, random_state=random_state)
elif smote_method.lower() == "borderline":
sampler = BorderlineSMOTE(k_neighbors=5, random_state=random_state)
elif smote_method.lower() == "adasyn":
sampler = ADASYN(n_neighbors=5, random_state=random_state)
else:
raise ValueError(f"Unknown smote_method: {smote_method}")
X_train, y_train = sampler.fit_resample(X_train, y_train)
model.fit(X_train, y_train)
y_pred = model.predict(X_val)
accuracy_list.append(accuracy_score(y_val, y_pred))
f1_macro_list.append(f1_score(y_val, y_pred, average="macro"))
f2_macro_list.append(fbeta_score(y_val, y_pred, beta=2, average="macro"))
recall_macro_list.append(recall_score(y_val, y_pred, average="macro"))
precision_macro_list.append(precision_score(y_val, y_pred, average="macro"))
f1_class0.append(f1_score(y_val, y_pred, pos_label=0))
f1_class1.append(f1_score(y_val, y_pred, pos_label=1))
f2_class0.append(fbeta_score(y_val, y_pred, beta=2, pos_label=0))
f2_class1.append(fbeta_score(y_val, y_pred, beta=2, pos_label=1))
recall_class0.append(recall_score(y_val, y_pred, pos_label=0))
recall_class1.append(recall_score(y_val, y_pred, pos_label=1))
precision_class0.append(precision_score(y_val, y_pred, pos_label=0))
precision_class1.append(precision_score(y_val, y_pred, pos_label=1))
fold_num += 1
return {
"accuracy": numpy.mean(accuracy_list),
"f1_macro": numpy.mean(f1_macro_list),
"f2_macro": numpy.mean(f2_macro_list),
"recall_macro": numpy.mean(recall_macro_list),
"precision_macro": numpy.mean(precision_macro_list),
"f1_class0": numpy.mean(f1_class0),
"f1_class1": numpy.mean(f1_class1),
"f2_class0": numpy.mean(f2_class0),
"f2_class1": numpy.mean(f2_class1),
"recall_class0": numpy.mean(recall_class0),
"recall_class1": numpy.mean(recall_class1),
"precision_class0": numpy.mean(precision_class0),
"precision_class1": numpy.mean(precision_class1),
}
def test_model(model, X_test, y_test):
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
f1_macro = f1_score(y_test, y_pred, average="macro")
f2_macro = fbeta_score(y_test, y_pred, beta=2, average="macro")
recall_macro = recall_score(y_test, y_pred, average="macro")
precision_macro = precision_score(y_test, y_pred, average="macro")
f1_class0 = f1_score(y_test, y_pred, pos_label=0)
f1_class1 = f1_score(y_test, y_pred, pos_label=1)
f2_class0 = fbeta_score(y_test, y_pred, beta=2, pos_label=0)
f2_class1 = fbeta_score(y_test, y_pred, beta=2, pos_label=1)
recall_class0 = recall_score(y_test, y_pred, pos_label=0)
recall_class1 = recall_score(y_test, y_pred, pos_label=1)
precision_class0 = precision_score(y_test, y_pred, pos_label=0)
precision_class1 = precision_score(y_test, y_pred, pos_label=1)
return {
"accuracy": accuracy,
"f1_macro": f1_macro,
"f2_macro": f2_macro,
"recall_macro": recall_macro,
"precision_macro": precision_macro,
"f1_class0": f1_class0,
"f1_class1": f1_class1,
"f2_class0": f2_class0,
"f2_class1": f2_class1,
"recall_class0": recall_class0,
"recall_class1": recall_class1,
"precision_class0": precision_class0,
"precision_class1": precision_class1,
}

62
utils.py Normal file
View File

@@ -0,0 +1,62 @@
"""
Saeed Khosravi - 27 Nov 2025
"""
def split_path(full_path):
import os
directory = os.path.dirname(full_path)
filename = os.path.splitext(os.path.basename(full_path))[0]
return directory, filename
def write_textfile(path, data_list):
with open(path, "w") as file:
for data in data_list:
file.write(f"{data} \n")
def missing_value_handler(data_path):
import pandas
from sklearn.impute import KNNImputer
data_directory, data_filename = split_path(data_path)
data_frame = pandas.read_csv(data_path)
columns = list(data_frame.head(0))
# remove column id
if "id" in columns:
data_frame = data_frame.drop("id", axis="columns")
columns = list(data_frame.head(0))
write_textfile(f"{data_directory}/columns.txt", columns)
# find missing values
missing_value_counts = data_frame.isna().sum()
write_textfile(f"{data_directory}/missing.txt", missing_value_counts)
# fill missing values - KNNImputer
imputer = KNNImputer(n_neighbors=5)
data_imputed = imputer.fit_transform(data_frame)
data_frame_imputed = pandas.DataFrame(data_imputed, columns=columns)
missing_value_counts = data_frame_imputed.isna().sum()
write_textfile(f"{data_directory}/no_missing.txt", missing_value_counts)
return data_frame_imputed
def scaling_handler(data_frame, method="robust_scaling"):
if method == "robust_scaling":
import pandas
from sklearn.preprocessing import RobustScaler
labels = data_frame["label"]
scaler = RobustScaler()
x = data_frame.drop("label", axis=1)
x_scale = scaler.fit_transform(x)
data_frame_scaled = pandas.DataFrame(x_scale, columns=x.columns)
data_frame_scaled["label"] = labels.values
return data_frame_scaled