diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..8fce603
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+data/
diff --git a/data.py b/data.py
new file mode 100644
index 0000000..a9cb4ac
--- /dev/null
+++ b/data.py
@@ -0,0 +1,23 @@
+"""
+Saeed Khosravi - 26 Nov 2025
+"""
+
+import os
+
+import pandas
+
+from utils import missing_value_handler, scaling_handler
+
+# STEP 1: handle missing values, drop the id column, and apply robust scaling.
+# The cleaned frame is cached on disk so preprocessing runs only once.
+if not os.path.exists("./data/Ketamine_icp_cleaned.csv"):
+    data_path = "./data/Ketamine_icp.csv"
+    data_frame_imputed = missing_value_handler(data_path)
+
+    data_frame = scaling_handler(data_frame_imputed, method="robust_scaling")
+
+    data_frame.to_csv("./data/Ketamine_icp_cleaned.csv", index=False)
+else:
+    data_frame = pandas.read_csv("./data/Ketamine_icp_cleaned.csv")
+
+# STEP 2: training and evaluation are driven by runner.py.
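Note on STEP 1 above: `scaling_handler` uses `RobustScaler`, which centers each feature on its median and divides by the interquartile range, so outliers do not distort the scale the way mean/std standardization would. A self-contained toy check (values are made up):

```python
import pandas
from sklearn.preprocessing import RobustScaler

# Toy single-feature frame with one extreme value.
toy = pandas.DataFrame({"icp": [8.0, 10.0, 12.0, 14.0, 100.0]})

# RobustScaler: (x - median) / IQR. Median is 12, IQR = 14 - 10 = 4,
# so the scaled values are [-1, -0.5, 0, 0.5, 22].
print(RobustScaler().fit_transform(toy).ravel())
```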
diff --git a/results.csv b/results.csv
new file mode 100644
index 0000000..38f0463
--- /dev/null
+++ b/results.csv
@@ -0,0 +1,35 @@
+model,stage,accuracy,f1_macro,f2_macro,recall_macro,precision_macro,f1_class0,f1_class1,f2_class0,f2_class1,recall_class0,recall_class1,precision_class0,precision_class1,TP,TN,FP,FN
+LGBM_SMOTE,train,0.98818499917456,0.8760126546483045,0.8399085858018045,0.8198112125154727,0.9598771423117448,0.9939444918466291,0.75808081744998,0.9967070984144822,0.6831100731891271,0.9985576102161305,0.6410648148148148,0.9893749043238158,0.9303793802996736,774,27036,2,34
+LGBM_SMOTE,test,0.9867724867724867,0.8559208843672739,0.8127209992015513,0.7896857910481889,0.965374580868779,0.9932270501198291,0.7186147186147186,0.9966541196152238,0.6287878787878788,0.9989520016767973,0.5804195804195804,0.9875673435557397,0.9431818181818182,83,4766,5,60
+LGBM_KMEANS_SMOTE,train,0.98818499917456,0.8760126546483045,0.8399085858018045,0.8198112125154727,0.9598771423117448,0.9939444918466291,0.75808081744998,0.9967070984144822,0.6831100731891271,0.9985576102161305,0.6410648148148148,0.9893749043238158,0.9303793802996736,774,27036,2,34
+LGBM_KMEANS_SMOTE,test,0.9867724867724867,0.8559208843672739,0.8127209992015513,0.7896857910481889,0.965374580868779,0.9932270501198291,0.7186147186147186,0.9966541196152238,0.6287878787878788,0.9989520016767973,0.5804195804195804,0.9875673435557397,0.9431818181818182,83,4766,5,60
+LGBM_SVM_SMOTE,train,0.98818499917456,0.8760126546483045,0.8399085858018045,0.8198112125154727,0.9598771423117448,0.9939444918466291,0.75808081744998,0.9967070984144822,0.6831100731891271,0.9985576102161305,0.6410648148148148,0.9893749043238158,0.9303793802996736,774,27036,2,34
+LGBM_SVM_SMOTE,test,0.9867724867724867,0.8559208843672739,0.8127209992015513,0.7896857910481889,0.965374580868779,0.9932270501198291,0.7186147186147186,0.9966541196152238,0.6287878787878788,0.9989520016767973,0.5804195804195804,0.9875673435557397,0.9431818181818182,83,4766,5,60
+LGBM_BORDERLINE_SMOTE,train,0.98818499917456,0.8760126546483045,0.8399085858018045,0.8198112125154727,0.9598771423117448,0.9939444918466291,0.75808081744998,0.9967070984144822,0.6831100731891271,0.9985576102161305,0.6410648148148148,0.9893749043238158,0.9303793802996736,774,27036,2,34
+LGBM_BORDERLINE_SMOTE,test,0.9867724867724867,0.8559208843672739,0.8127209992015513,0.7896857910481889,0.965374580868779,0.9932270501198291,0.7186147186147186,0.9966541196152238,0.6287878787878788,0.9989520016767973,0.5804195804195804,0.9875673435557397,0.9431818181818182,83,4766,5,60
+LGBM_ADASYN_SMOTE,train,0.98818499917456,0.8760126546483045,0.8399085858018045,0.8198112125154727,0.9598771423117448,0.9939444918466291,0.75808081744998,0.9967070984144822,0.6831100731891271,0.9985576102161305,0.6410648148148148,0.9893749043238158,0.9303793802996736,774,27036,2,34
+LGBM_ADASYN_SMOTE,test,0.9867724867724867,0.8559208843672739,0.8127209992015513,0.7896857910481889,0.965374580868779,0.9932270501198291,0.7186147186147186,0.9966541196152238,0.6287878787878788,0.9989520016767973,0.5804195804195804,0.9875673435557397,0.9431818181818182,83,4766,5,60
+LGBM_Balanced,train,0.9880772921438743,0.8819085772805074,0.8575020546207861,0.843178069610239,0.9330261074557831,0.9938798529106704,0.7699373016503442,0.9957519348782675,0.7192521743633049,0.9970042873686262,0.6893518518518518,0.990776780581843,0.8752754343297229,779,27033,5,29
+LGBM_Balanced,test,0.9861619861619861,0.8571081584060867,0.8246557274872268,0.8063299098721441,0.9299298722724962,0.9929048414023373,0.7213114754098361,0.9955224505168013,0.6537890044576523,0.997275204359673,0.6153846153846154,0.9885726158321213,0.8712871287128713,88,4758,13,55
+LGBM_DART,train,0.9880413597061434,0.8747819377304309,0.8394311701094301,0.8197295388100121,0.956851692191424,0.9938704735789482,0.7556934018819133,0.9965888563780074,0.6822734838408528,0.9984096949039746,0.6410493827160494,0.989373835326035,0.9243295490568126,685,27022,16,123
+LGBM_DART,test,0.9871794871794872,0.8660089064245453,0.8302067070505215,0.8102456126979287,0.9484753203607024,0.9934286012308334,0.7385892116182573,0.996234309623431,0.664179104477612,0.9981136030182352,0.6223776223776224,0.9887873754152824,0.9081632653061225,89,4762,9,54
+LGBM_GOSS,train,0.988185050764564,0.8767619724826792,0.8413979145609835,0.8216153002029628,0.9579916919528333,0.9939435911838446,0.7595803537815136,0.9966404297714536,0.6861553993505135,0.9984466497886416,0.644783950617284,0.9894818584294637,0.9265015254762032,775,27035,3,33
+LGBM_GOSS,test,0.9871794871794872,0.8615239070778549,0.8193657257149385,0.7966787980411958,0.9662106135986732,0.9934340802501302,0.7296137339055794,0.9967374937259494,0.6419939577039275,0.9989520016767973,0.5944055944055944,0.9879767827529021,0.9444444444444444,85,4766,5,58
+LGBM_RF,train,0.9858507062671537,0.8481335782755919,0.8099301710699093,0.7891912188845618,0.9416024225560872,0.99275226985963,0.7035148866915535,0.9958737030582172,0.6239866390816016,0.9979657711024567,0.5804166666666666,0.987593854471584,0.8956109906405902,507,26993,45,301
+LGBM_RF,test,0.9855514855514855,0.8465095837474215,0.809386193422338,0.7890569920542673,0.9352652953120861,0.9925972265665728,0.70042194092827,0.9956492637215528,0.6231231231231231,0.9976944036889541,0.5804195804195804,0.9875518672199171,0.8829787234042553,83,4760,11,60
+LGBM_scale_pos_weight,train,0.9883646355682123,0.8847082787258997,0.8597348514523183,0.8451146678749734,0.9373215094285751,0.9940273787240415,0.7753891787277578,0.9959218264040318,0.7235478765006047,0.9971892122931566,0.6930401234567901,0.9908875151444739,0.8837555037126762,783,27033,5,25
+LGBM_scale_pos_weight,test,0.9865689865689866,0.8624097011164226,0.8311362209667295,0.8133229168651512,0.9313849935322168,0.9931120851596744,0.7317073170731707,0.9956057752667922,0.6666666666666666,0.997275204359673,0.6293706293706294,0.9889835792974434,0.8737864077669902,90,4758,13,53
+LGBM_is_unbalance,train,0.9882927448977485,0.8841054144736109,0.8598757937903134,0.8456764784528097,0.9349710933550665,0.9939903291217178,0.774220499825504,0.9958405185143574,0.7239110690662693,0.9970782655475945,0.6942746913580247,0.9909237024663042,0.8790184842438284,779,27033,5,29
+LGBM_is_unbalance,test,0.9867724867724867,0.863955516465169,0.8317145687693159,0.8134277166974715,0.9356694049190748,0.9932171553793175,0.7346938775510204,0.9957733511884834,0.6676557863501483,0.9974848040243136,0.6293706293706294,0.9889858686616791,0.8823529411764706,90,4759,12,53
+XGB_scale_pos_weight,train,0.9881131600941002,0.8868996060146659,0.8705466767428707,0.8606000987409465,0.9188644360910505,0.9938914211334648,0.7799077908958669,0.9951365452837685,0.7459568082019727,0.9959687160004116,0.7252314814814815,0.9918245196483149,0.8459043525337859,781,27028,10,27
+XGB_scale_pos_weight,test,0.9839234839234839,0.8421764514805791,0.8211201077315871,0.8085688153808045,0.8847258771929825,0.9917458990701076,0.6926070038910506,0.9935521688159438,0.6486880466472303,0.9947600083839866,0.6223776223776224,0.98875,0.7807017543859649,89,4746,25,54
+CatBoost_balanced,train,0.9843784049402589,0.8696686267343388,0.8824472728294012,0.8916952848998795,0.8508242781484853,0.9919396338322237,0.7473976196364541,0.9908276010500254,0.7740669446087769,0.9900881006639566,0.7933024691358025,0.9938004847319636,0.7078480715650071,789,26898,140,19
+CatBoost_balanced,test,0.9802604802604803,0.8348421298822796,0.8461546793313885,0.8541662696976049,0.8176680164072361,0.9898162729658793,0.6798679867986799,0.988757446094471,0.703551912568306,0.9880528191154894,0.7202797202797203,0.991586032814472,0.64375,103,4714,57,40
+RandomForest_balanced,train,0.9780219876596711,0.69482483029087,0.6476767204797605,0.6284046291958619,0.9637732327782711,0.988805060981185,0.400844599600555,0.9952275898084167,0.3001258511511045,0.9995561719719707,0.2572530864197531,0.9782843830335899,0.9492620825229521,749,27037,1,59
+RandomForest_balanced,test,0.9774114774114774,0.6876170818855889,0.6409384493281152,0.6220632228806615,0.9494516644359052,0.9884938322794651,0.3867403314917127,0.994991652754591,0.28688524590163933,0.9993712010060783,0.24475524475524477,0.977850697292863,0.9210526315789473,35,4768,3,108
+BalancedRandomForest,train,0.939704118430013,0.7028311098818047,0.7781778852496469,0.8752773234871295,0.6471027486545569,0.9681428255443214,0.43751939421928815,0.9533096243066284,0.6030461461926654,0.9436719309248763,0.8068827160493829,0.9939232551592664,0.30028224214984744,790,25662,1376,18
+BalancedRandomForest,test,0.9318274318274318,0.6688915684702886,0.7336858411761832,0.8190487986128313,0.622796487015859,0.9639513612396428,0.37383177570093457,0.9486994831822418,0.5186721991701245,0.9387968979249633,0.6993006993006993,0.9904909332153914,0.25510204081632654,100,4479,292,43
+LogisticRegression_balanced,train,0.9390578117583936,0.7053593691260407,0.7842568337517993,0.888164180680811,0.6482233216447624,0.9677636915743127,0.44295504667776875,0.9522564007253292,0.6162572667782692,0.9421925588924864,0.8341358024691358,0.9947683858648215,0.3016782574247032,745,25582,1456,63
+LogisticRegression_balanced,test,0.9356939356939357,0.68489389075354,0.7544562768745361,0.844781921076199,0.6342667509156945,0.9660141966014196,0.4037735849056604,0.9510397695989158,0.5578727841501564,0.9413120939006497,0.7482517482517482,0.9920477137176938,0.27648578811369506,107,4491,280,36
+LGBM_FOCAL_LOSS,train,0.9874308822922471,0.8637586787288788,0.8216344627004155,0.7989945052063886,0.969041715035573,0.9935632802546319,0.7339540772031252,0.9968410405554005,0.6464278848454306,0.9990383931288267,0.5989506172839506,0.9881485197433785,0.9499349103277674,774,27036,2,34
+LGBM_FOCAL_LOSS,test,0.9855514855514855,0.838525460793502,0.7925445481202558,0.7687067700691679,0.9626827249232757,0.9926064771425596,0.6844444444444444,0.9964040809499917,0.5886850152905199,0.9989520016767973,0.5384615384615384,0.9863410596026491,0.9390243902439024,77,4766,5,66
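Reading the table: the F2 columns weight recall twice as heavily as precision, which is why the minority-class F2 trails F1 whenever recall is the weak side. Any cell can be cross-checked from the confusion counts; for the LGBM_SMOTE test row (TP=83, FP=5, FN=60):

```python
# F-beta from the LGBM_SMOTE test confusion counts.
TP, FP, FN = 83, 5, 60

precision = TP / (TP + FP)  # 0.9432 (precision_class1)
recall = TP / (TP + FN)     # 0.5804 (recall_class1)

beta = 2
f2 = (1 + beta**2) * precision * recall / (beta**2 * precision + recall)
print(f2)  # ~0.6288, matching the f2_class1 column
```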
diff --git a/runner.py b/runner.py
new file mode 100644
index 0000000..f07c7ba
--- /dev/null
+++ b/runner.py
@@ -0,0 +1,311 @@
+import os
+
+import pandas
+from catboost import CatBoostClassifier
+from imblearn.ensemble import BalancedRandomForestClassifier
+from lightgbm import LGBMClassifier
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import confusion_matrix
+from sklearn.model_selection import train_test_split
+from xgboost import XGBClassifier
+
+from custom_models.LGBMFocalWrapper import LGBMFocalWrapper
+from train import test_model, train_model_with_kfold
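`custom_models/LGBMFocalWrapper.py` is imported here but is not part of this diff, so its interface is unknown. Purely for orientation, a minimal sketch of how a focal loss can be attached to LightGBM as a custom objective; the gradient is exact, the Hessian is a common positive surrogate (exact only at `gamma=0`), and none of this should be read as the repository's actual implementation:

```python
import numpy
from lightgbm import LGBMClassifier
from scipy.special import expit  # numerically stable sigmoid


def focal_objective(gamma=2.0):
    """Binary focal loss, FL = -(1 - p_t)**gamma * log(p_t), for LightGBM."""

    def objective(y_true, raw_score):
        p = expit(raw_score)
        pt = numpy.clip(numpy.where(y_true == 1, p, 1 - p), 1e-12, 1 - 1e-12)

        # Exact first derivative: dFL/dz = dFL/dp_t * dp_t/dz,
        # with dp_t/dz = (2y - 1) * p * (1 - p).
        dfl_dpt = gamma * (1 - pt) ** (gamma - 1) * numpy.log(pt) - (1 - pt) ** gamma / pt
        grad = dfl_dpt * (2 * y_true - 1) * p * (1 - p)

        # Positive Hessian surrogate: logistic curvature damped by the focal
        # factor. Exact at gamma = 0, an approximation otherwise.
        hess = (1 - pt) ** gamma * p * (1 - p)
        return grad, hess

    return objective


# With a callable objective LightGBM returns raw scores, so a full wrapper
# would also apply a sigmoid in predict_proba; that plumbing is omitted here.
model = LGBMClassifier(objective=focal_objective(gamma=2.0), n_estimators=500)
```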
+
+data_frame = pandas.read_csv("./data/Ketamine_icp_cleaned.csv")
+y = data_frame["label"]
+X = data_frame.drop(columns=["label"])
+
+x_train, x_test, y_train, y_test = train_test_split(
+    X,
+    y,
+    test_size=0.15,
+    stratify=y,
+    random_state=42,
+)
+
+# Ratio of negatives to positives, reused by scale_pos_weight / class_weights.
+neg = sum(y_train == 0)
+pos = sum(y_train == 1)
+scale_pos = neg / pos if pos > 0 else 1.0
+
+models = [
+    {
+        "name": "LGBM_FOCAL_LOSS",
+        "model": LGBMFocalWrapper(
+            n_estimators=500,
+            learning_rate=0.05,
+            max_depth=-1,
+            subsample=0.8,
+            colsample_bytree=0.8,
+            random_state=42,
+        ),
+        "smote": True,
+        "smote_method": "kmeans",
+    },
+    {
+        "name": "LGBM_SMOTE",
+        "model": LGBMClassifier(
+            n_estimators=500,
+            learning_rate=0.05,
+            max_depth=-1,
+            subsample=0.8,
+            colsample_bytree=0.8,
+            random_state=42,
+            verbose=-1,
+            n_jobs=-1,
+        ),
+        "smote": True,
+        "smote_method": "smote",
+    },
+    {
+        "name": "LGBM_KMEANS_SMOTE",
+        "model": LGBMClassifier(
+            n_estimators=500,
+            learning_rate=0.05,
+            max_depth=-1,
+            subsample=0.8,
+            colsample_bytree=0.8,
+            random_state=42,
+            verbose=-1,
+            n_jobs=-1,
+        ),
+        "smote": True,
+        "smote_method": "kmeans",
+    },
+    {
+        "name": "LGBM_SVM_SMOTE",
+        "model": LGBMClassifier(
+            n_estimators=500,
+            learning_rate=0.05,
+            max_depth=-1,
+            subsample=0.8,
+            colsample_bytree=0.8,
+            random_state=42,
+            verbose=-1,
+            n_jobs=-1,
+        ),
+        "smote": True,
+        "smote_method": "svm",
+    },
+    {
+        "name": "LGBM_BORDERLINE_SMOTE",
+        "model": LGBMClassifier(
+            n_estimators=500,
+            learning_rate=0.05,
+            max_depth=-1,
+            subsample=0.8,
+            colsample_bytree=0.8,
+            random_state=42,
+            verbose=-1,
+            n_jobs=-1,
+        ),
+        "smote": True,
+        "smote_method": "borderline",
+    },
+    {
+        "name": "LGBM_ADASYN_SMOTE",
+        "model": LGBMClassifier(
+            n_estimators=500,
+            learning_rate=0.05,
+            max_depth=-1,
+            subsample=0.8,
+            colsample_bytree=0.8,
+            random_state=42,
+            verbose=-1,
+            n_jobs=-1,
+        ),
+        "smote": True,
+        "smote_method": "adasyn",
+    },
+    {
+        "name": "LGBM_Balanced",
+        "model": LGBMClassifier(
+            n_estimators=500,
+            learning_rate=0.05,
+            max_depth=-1,
+            subsample=0.8,
+            colsample_bytree=0.8,
+            class_weight="balanced",
+            random_state=42,
+            verbose=-1,
+            n_jobs=-1,
+        ),
+        "smote": False,
+    },
+    {
+        "name": "LGBM_DART",
+        "model": LGBMClassifier(
+            n_estimators=500,
+            learning_rate=0.05,
+            max_depth=-1,
+            subsample=0.8,
+            colsample_bytree=0.8,
+            boosting_type="dart",
+            random_state=42,
+            verbose=-1,
+            n_jobs=-1,
+        ),
+        "smote": True,
+        "smote_method": "kmeans",
+    },
+    {
+        "name": "LGBM_GOSS",
+        "model": LGBMClassifier(
+            n_estimators=500,
+            learning_rate=0.05,
+            max_depth=-1,
+            boosting_type="goss",
+            random_state=42,
+            verbose=-1,
+            n_jobs=-1,
+        ),
+        "smote": True,
+        "smote_method": "kmeans",
+    },
+    {
+        "name": "LGBM_RF",
+        "model": LGBMClassifier(
+            n_estimators=500,
+            learning_rate=0.05,
+            max_depth=-1,
+            boosting_type="rf",
+            subsample=0.8,
+            subsample_freq=1,  # rf boosting requires bagging to be enabled
+            colsample_bytree=0.8,
+            random_state=42,
+            verbose=-1,
+            n_jobs=-1,
+        ),
+        "smote": True,
+        "smote_method": "kmeans",
+    },
+    {
+        "name": "LGBM_scale_pos_weight",
+        "model": LGBMClassifier(
+            n_estimators=500,
+            learning_rate=0.05,
+            max_depth=-1,
+            scale_pos_weight=scale_pos,
+            random_state=42,
+            verbose=-1,
+            n_jobs=-1,
+        ),
+        "smote": False,
+    },
+    {
+        "name": "LGBM_is_unbalance",
+        "model": LGBMClassifier(
+            n_estimators=500,
+            learning_rate=0.05,
+            max_depth=-1,
+            is_unbalance=True,
+            random_state=42,
+            verbose=-1,
+            n_jobs=-1,
+        ),
+        "smote": False,
+    },
+    {
+        "name": "XGB_scale_pos_weight",
+        "model": XGBClassifier(
+            n_estimators=500,
+            learning_rate=0.05,
+            max_depth=6,
+            scale_pos_weight=scale_pos,
+            random_state=42,
+            n_jobs=-1,
+            eval_metric="logloss",
+        ),
+        "smote": False,
+    },
+    {
+        "name": "CatBoost_balanced",
+        "model": CatBoostClassifier(
+            iterations=500,
+            learning_rate=0.05,
+            depth=6,
+            class_weights=[1, scale_pos],
+            random_state=42,
+            verbose=0,
+        ),
+        "smote": False,
+    },
+    {
+        "name": "RandomForest_balanced",
+        "model": RandomForestClassifier(
+            n_estimators=500,
+            max_depth=None,
+            class_weight="balanced",
+            random_state=42,
+            n_jobs=-1,
+        ),
+        "smote": False,
+    },
+    {
+        "name": "BalancedRandomForest",
+        "model": BalancedRandomForestClassifier(
+            n_estimators=500,
+            max_depth=None,
+            random_state=42,
+            n_jobs=-1,
+        ),
+        "smote": False,
+    },
+    {
+        "name": "LogisticRegression_balanced",
+        "model": LogisticRegression(
+            max_iter=1000,
+            class_weight="balanced",
+            solver="liblinear",
+            random_state=42,
+        ),
+        "smote": False,
+    },
+]
+
+
+def compute_confusion(y_true, y_pred):
+    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
+    return {"TP": tp, "TN": tn, "FP": fp, "FN": fn}
+
+
+results_to_save = []
+
+for m in models:
+    print(f"\n===== Training model: {m['name']} =====")
+
+    train_results = train_model_with_kfold(
+        m["model"],
+        x_train,
+        y_train,
+        n_splits=10,
+        smote=m["smote"],
+        smote_method=m.get("smote_method", "kmeans"),
+    )
+
+    y_train_pred = m["model"].predict(x_train)
+    train_confusion = compute_confusion(y_train, y_train_pred)
+
+    test_results = test_model(m["model"], x_test, y_test)
+    y_test_pred = m["model"].predict(x_test)
+    test_confusion = compute_confusion(y_test, y_test_pred)
+
+    results_to_save.append(
+        {"model": m["name"], "stage": "train", **train_results, **train_confusion}
+    )
+    results_to_save.append(
+        {"model": m["name"], "stage": "test", **test_results, **test_confusion}
+    )
+
+results_df = pandas.DataFrame(results_to_save)
+csv_file = "results.csv"
+
+# pandas creates the file in append mode, so a FileNotFoundError never fires;
+# the header has to be written explicitly on the first run instead.
+results_df.to_csv(
+    csv_file, mode="a", index=False, header=not os.path.exists(csv_file)
+)
+
+print(f"\nAll results saved to {csv_file}")
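One behavioural detail when reading the results: `train_model_with_kfold` returns fold-averaged metrics but leaves the estimator fitted on whichever fold it saw last, and that last-fold fit is what the confusion counts and `test_model` score. If the intent is to evaluate a model trained on the full training split, a hypothetical refit step between cross-validation and testing (inside the loop above, mirroring the per-fold SMOTE in train.py) could look like:

```python
from imblearn.over_sampling import SMOTE

# Hypothetical refit on all of x_train before the held-out evaluation;
# only relevant when the SMOTE path was used during cross-validation.
X_res, y_res = SMOTE(k_neighbors=5, random_state=42).fit_resample(x_train, y_train)
m["model"].fit(X_res, y_res)
test_results = test_model(m["model"], x_test, y_test)
```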
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..f2df775
--- /dev/null
+++ b/train.py
@@ -0,0 +1,137 @@
+import numpy
+import tqdm
+from imblearn.over_sampling import ADASYN, SMOTE, SVMSMOTE, BorderlineSMOTE, KMeansSMOTE
+from sklearn.metrics import (
+    accuracy_score,
+    f1_score,
+    fbeta_score,
+    precision_score,
+    recall_score,
+)
+from sklearn.model_selection import StratifiedKFold
+
+
+def train_model_with_kfold(
+    model, X, y, n_splits=5, random_state=42, smote=True, smote_method="kmeans"
+):
+    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
+
+    accuracy_list = []
+    f1_macro_list = []
+    f2_macro_list = []
+    recall_macro_list = []
+    precision_macro_list = []
+
+    # per-class lists
+    f1_class0 = []
+    f1_class1 = []
+    f2_class0 = []
+    f2_class1 = []
+    recall_class0 = []
+    recall_class1 = []
+    precision_class0 = []
+    precision_class1 = []
+
+    for train_idx, val_idx in tqdm.tqdm(
+        skf.split(X, y), total=skf.n_splits, desc="Training Folds"
+    ):
+        X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
+        y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]
+
+        # Resample only the training fold; the validation fold stays untouched.
+        if smote:
+            method = smote_method.lower()
+            if method == "kmeans":
+                sampler = KMeansSMOTE(
+                    k_neighbors=5,
+                    cluster_balance_threshold=0.1,
+                    random_state=random_state,
+                )
+            elif method == "smote":
+                sampler = SMOTE(k_neighbors=5, random_state=random_state)
+            elif method in ("svm", "svmsmote"):
+                sampler = SVMSMOTE(k_neighbors=5, random_state=random_state)
+            elif method == "borderline":
+                sampler = BorderlineSMOTE(k_neighbors=5, random_state=random_state)
+            elif method == "adasyn":
+                sampler = ADASYN(n_neighbors=5, random_state=random_state)
+            else:
+                raise ValueError(f"Unknown smote_method: {smote_method}")
+
+            X_train, y_train = sampler.fit_resample(X_train, y_train)
+
+        model.fit(X_train, y_train)
+
+        y_pred = model.predict(X_val)
+
+        accuracy_list.append(accuracy_score(y_val, y_pred))
+        f1_macro_list.append(f1_score(y_val, y_pred, average="macro"))
+        f2_macro_list.append(fbeta_score(y_val, y_pred, beta=2, average="macro"))
+        recall_macro_list.append(recall_score(y_val, y_pred, average="macro"))
+        precision_macro_list.append(precision_score(y_val, y_pred, average="macro"))
+
+        f1_class0.append(f1_score(y_val, y_pred, pos_label=0))
+        f1_class1.append(f1_score(y_val, y_pred, pos_label=1))
+
+        f2_class0.append(fbeta_score(y_val, y_pred, beta=2, pos_label=0))
+        f2_class1.append(fbeta_score(y_val, y_pred, beta=2, pos_label=1))
+
+        recall_class0.append(recall_score(y_val, y_pred, pos_label=0))
+        recall_class1.append(recall_score(y_val, y_pred, pos_label=1))
+
+        precision_class0.append(precision_score(y_val, y_pred, pos_label=0))
+        precision_class1.append(precision_score(y_val, y_pred, pos_label=1))
+
+    return {
+        "accuracy": numpy.mean(accuracy_list),
+        "f1_macro": numpy.mean(f1_macro_list),
+        "f2_macro": numpy.mean(f2_macro_list),
+        "recall_macro": numpy.mean(recall_macro_list),
+        "precision_macro": numpy.mean(precision_macro_list),
+        "f1_class0": numpy.mean(f1_class0),
+        "f1_class1": numpy.mean(f1_class1),
+        "f2_class0": numpy.mean(f2_class0),
+        "f2_class1": numpy.mean(f2_class1),
+        "recall_class0": numpy.mean(recall_class0),
+        "recall_class1": numpy.mean(recall_class1),
+        "precision_class0": numpy.mean(precision_class0),
+        "precision_class1": numpy.mean(precision_class1),
+    }
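Resampling inside each fold, as above, is the safe order of operations: oversampling before the split would leak synthetic copies of validation rows into the training data. The same guarantee is available from imblearn's `Pipeline`, which applies samplers during `fit` only and skips them at predict time; a minimal equivalent sketch:

```python
import pandas
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline
from lightgbm import LGBMClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score

data_frame = pandas.read_csv("./data/Ketamine_icp_cleaned.csv")
y = data_frame["label"]
X = data_frame.drop(columns=["label"])

# The sampler runs on each training fold during fit only, matching the
# manual per-fold resampling in train_model_with_kfold.
pipe = Pipeline(
    [
        ("smote", SMOTE(k_neighbors=5, random_state=42)),
        ("clf", LGBMClassifier(n_estimators=500, random_state=42, verbose=-1)),
    ]
)
scores = cross_val_score(
    pipe, X, y, scoring="f1_macro",
    cv=StratifiedKFold(n_splits=10, shuffle=True, random_state=42),
)
print(scores.mean())
```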
+ "recall_class1": numpy.mean(recall_class1), + "precision_class0": numpy.mean(precision_class0), + "precision_class1": numpy.mean(precision_class1), + } + + +def test_model(model, X_test, y_test): + y_pred = model.predict(X_test) + + accuracy = accuracy_score(y_test, y_pred) + f1_macro = f1_score(y_test, y_pred, average="macro") + f2_macro = fbeta_score(y_test, y_pred, beta=2, average="macro") + recall_macro = recall_score(y_test, y_pred, average="macro") + precision_macro = precision_score(y_test, y_pred, average="macro") + + f1_class0 = f1_score(y_test, y_pred, pos_label=0) + f1_class1 = f1_score(y_test, y_pred, pos_label=1) + + f2_class0 = fbeta_score(y_test, y_pred, beta=2, pos_label=0) + f2_class1 = fbeta_score(y_test, y_pred, beta=2, pos_label=1) + + recall_class0 = recall_score(y_test, y_pred, pos_label=0) + recall_class1 = recall_score(y_test, y_pred, pos_label=1) + + precision_class0 = precision_score(y_test, y_pred, pos_label=0) + precision_class1 = precision_score(y_test, y_pred, pos_label=1) + + return { + "accuracy": accuracy, + "f1_macro": f1_macro, + "f2_macro": f2_macro, + "recall_macro": recall_macro, + "precision_macro": precision_macro, + "f1_class0": f1_class0, + "f1_class1": f1_class1, + "f2_class0": f2_class0, + "f2_class1": f2_class1, + "recall_class0": recall_class0, + "recall_class1": recall_class1, + "precision_class0": precision_class0, + "precision_class1": precision_class1, + } diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..b94545d --- /dev/null +++ b/utils.py @@ -0,0 +1,62 @@ +""" +Saeed Khosravi - 27 Nov 2025 +""" + + +def split_path(full_path): + import os + + directory = os.path.dirname(full_path) + filename = os.path.splitext(os.path.basename(full_path))[0] + return directory, filename + + +def write_textfile(path, data_list): + with open(path, "w") as file: + for data in data_list: + file.write(f"{data} \n") + + +def missing_value_handler(data_path): + import pandas + from sklearn.impute import KNNImputer + + data_directory, data_filename = split_path(data_path) + + data_frame = pandas.read_csv(data_path) + + columns = list(data_frame.head(0)) + # remove column id + if "id" in columns: + data_frame = data_frame.drop("id", axis="columns") + + columns = list(data_frame.head(0)) + write_textfile(f"{data_directory}/columns.txt", columns) + + # find missing values + missing_value_counts = data_frame.isna().sum() + write_textfile(f"{data_directory}/missing.txt", missing_value_counts) + + # fill missing values - KNNImputer + + imputer = KNNImputer(n_neighbors=5) + data_imputed = imputer.fit_transform(data_frame) + data_frame_imputed = pandas.DataFrame(data_imputed, columns=columns) + + missing_value_counts = data_frame_imputed.isna().sum() + write_textfile(f"{data_directory}/no_missing.txt", missing_value_counts) + return data_frame_imputed + + +def scaling_handler(data_frame, method="robust_scaling"): + if method == "robust_scaling": + import pandas + from sklearn.preprocessing import RobustScaler + + labels = data_frame["label"] + scaler = RobustScaler() + x = data_frame.drop("label", axis=1) + x_scale = scaler.fit_transform(x) + data_frame_scaled = pandas.DataFrame(x_scale, columns=x.columns) + data_frame_scaled["label"] = labels.values + return data_frame_scaled