From 42091d6f24b4ea88fed110497d1491a7af400920 Mon Sep 17 00:00:00 2001 From: saeedkhosravi94 Date: Sat, 6 Dec 2025 00:14:59 +0100 Subject: [PATCH] LightGBM tuning --- LGBM_Tuning.py | 280 +++++++ __pycache__/train.cpython-312.pyc | Bin 5788 -> 5596 bytes __pycache__/utils.cpython-312.pyc | Bin 0 -> 4017 bytes catboost_info/catboost_training.json | 504 ------------ catboost_info/learn/events.out.tfevents | Bin 27370 -> 0 bytes catboost_info/learn_error.tsv | 501 ------------ catboost_info/time_left.tsv | 501 ------------ .../lgbm_vs_cat_kmeans_smote_k10_results.csv | 0 results.csv => results/results.csv | 0 results_lgbm_kmeans_smote_tuning_grid.csv | 769 ++++++++++++++++++ ...eans_smote_tuning_random_100_iteration.csv | 201 +++++ runner.py | 498 ++++++------ train.py | 7 +- utils.py | 62 +- 14 files changed, 1539 insertions(+), 1784 deletions(-) create mode 100644 LGBM_Tuning.py create mode 100644 __pycache__/utils.cpython-312.pyc delete mode 100644 catboost_info/catboost_training.json delete mode 100644 catboost_info/learn/events.out.tfevents delete mode 100644 catboost_info/learn_error.tsv delete mode 100644 catboost_info/time_left.tsv rename lgbm_vs_cat_kmeans_smote_k10_results.csv => results/lgbm_vs_cat_kmeans_smote_k10_results.csv (100%) rename results.csv => results/results.csv (100%) create mode 100644 results_lgbm_kmeans_smote_tuning_grid.csv create mode 100644 results_lgbm_kmeans_smote_tuning_random_100_iteration.csv diff --git a/LGBM_Tuning.py b/LGBM_Tuning.py new file mode 100644 index 0000000..ff7fec0 --- /dev/null +++ b/LGBM_Tuning.py @@ -0,0 +1,280 @@ +import itertools +import os +import random + +import pandas as pd +import tqdm +from imblearn.over_sampling import KMeansSMOTE +from lightgbm import LGBMClassifier +from sklearn.metrics import ( + accuracy_score, + f1_score, + fbeta_score, + precision_score, + recall_score, +) +from sklearn.model_selection import StratifiedKFold, train_test_split + +from utils import scaling_handler + +RESULTS_FILENAME = "results_lgbm_kmeans_smote_tuning_grid.csv" + + +def get_metrics(y_true, y_pred, prefix=""): + metrics = {} + + metrics[f"{prefix}accuracy"] = accuracy_score(y_true, y_pred) + metrics[f"{prefix}f1_macro"] = f1_score(y_true, y_pred, average="macro") + metrics[f"{prefix}f2_macro"] = fbeta_score(y_true, y_pred, beta=2, average="macro") + metrics[f"{prefix}recall_macro"] = recall_score(y_true, y_pred, average="macro") + metrics[f"{prefix}precision_macro"] = precision_score( + y_true, y_pred, average="macro" + ) + + f1_scores = f1_score(y_true, y_pred, average=None, zero_division=0) + f2_scores = fbeta_score(y_true, y_pred, beta=2, average=None, zero_division=0) + recall_scores = recall_score(y_true, y_pred, average=None, zero_division=0) + precision_scores = precision_score(y_true, y_pred, average=None, zero_division=0) + + metrics[f"{prefix}f1_class0"] = f1_scores[0] + metrics[f"{prefix}f1_class1"] = f1_scores[1] + metrics[f"{prefix}f2_class0"] = f2_scores[0] + metrics[f"{prefix}f2_class1"] = f2_scores[1] + metrics[f"{prefix}recall_class0"] = recall_scores[0] + metrics[f"{prefix}recall_class1"] = recall_scores[1] + metrics[f"{prefix}precision_class0"] = precision_scores[0] + metrics[f"{prefix}precision_class1"] = precision_scores[1] + + TP = sum((y_true == 1) & (y_pred == 1)) + TN = sum((y_true == 0) & (y_pred == 0)) + FP = sum((y_true == 0) & (y_pred == 1)) + FN = sum((y_true == 1) & (y_pred == 0)) + + metrics[f"{prefix}TP"] = TP + metrics[f"{prefix}TN"] = TN + metrics[f"{prefix}FP"] = FP + metrics[f"{prefix}FN"] = FN + + return metrics + + +try: + data_frame = pd.read_csv("./data/Ketamine_icp_no_missing.csv") +except FileNotFoundError: + print("Please ensure the data file exists at './data/Ketamine_icp_no_missing.csv'") + exit() + +random_state = 42 +n_split_kfold = 5 + +scaling_methods_list = [ + "standard_scaling", + "robust_scaling", + "minmax_scaling", + "yeo_johnson", +] + +boosting_type_list = ["gbdt", "dart"] +learning_rate_list = [0.03, 0.05, 0.1] +number_of_leaves_list = [100] +l2_regularization_lambda_list = [0.1, 0.5] +l1_regularization_alpha_list = [0.1, 0.5] +tree_subsample_tree_list = [0.8, 1.0] +subsample_list = [0.8, 1.0] +is_balanced_list = [True, False] + +kmeans_smote_k_neighbors_list = [10] +kmeans_smote_n_clusters_list = [5] + +param_combinations = list( + itertools.product( + scaling_methods_list, + boosting_type_list, + learning_rate_list, + number_of_leaves_list, + l2_regularization_lambda_list, + l1_regularization_alpha_list, + tree_subsample_tree_list, + subsample_list, + is_balanced_list, + kmeans_smote_k_neighbors_list, + kmeans_smote_n_clusters_list, + ) +) + +template_metrics = get_metrics( + pd.Series([0, 1, 0, 1]), + pd.Series([0, 1, 0, 1]), +) + +template_cols = ["iteration", "model", "params"] +for k in template_metrics.keys(): + template_cols.append(f"avg_val_{k}") + template_cols.append(f"test_{k}") +empty_df = pd.DataFrame(columns=template_cols) +if not os.path.exists(RESULTS_FILENAME): + empty_df.to_csv(RESULTS_FILENAME, index=False) + print(f"Initialized {RESULTS_FILENAME} with headers.") +else: + print(f"File {RESULTS_FILENAME} already exists. Appending to it.") + +iteration = 0 +for ( + scaling_method, + boosting_type, + learning_rate, + num_leaves, + reg_lambda, + reg_alpha, + colsample_bytree, + subsample, + is_balanced, + k_neighbors, + kmeans_estimator, +) in tqdm.tqdm(param_combinations): + skf = StratifiedKFold( + n_splits=n_split_kfold, shuffle=True, random_state=random_state + ) + + data_frame_scaled = scaling_handler(data_frame, scaling_method) + y = data_frame_scaled["label"] + X = data_frame_scaled.drop(columns=["label"]) + + x_train_val, x_test, y_train_val, y_test = train_test_split( + X, y, test_size=0.15, stratify=y, random_state=random_state + ) + + fold_results = [] + lgbm_classifier_params = None + + sampling_method = "none" + + for fold_idx, (train_index, val_index) in enumerate( + skf.split(x_train_val, y_train_val) + ): + x_train_fold, x_val = x_train_val.iloc[train_index], x_train_val.iloc[val_index] + y_train_fold, y_val = y_train_val.iloc[train_index], y_train_val.iloc[val_index] + + x_train = x_train_fold + y_train = y_train_fold + lgbm_classifier_params = None + + lgbm_base_params = { + "boosting_type": boosting_type, + "objective": "binary", + "learning_rate": learning_rate, + "n_jobs": -1, + "num_leaves": num_leaves, + "reg_lambda": reg_lambda, + "reg_alpha": reg_alpha, + "colsample_bytree": colsample_bytree, + "subsample": subsample, + } + + if is_balanced: + sampling_method = "KMeansSMOTE" + + smote_params = { + "sampling_strategy": "minority", + "k_neighbors": k_neighbors, + "kmeans_estimator": kmeans_estimator, + "cluster_balance_threshold": 0.001, + "random_state": random_state, + "n_jobs": -1, + } + + try: + smote = KMeansSMOTE(**smote_params) + x_train, y_train = smote.fit_resample(x_train_fold, y_train_fold) + lgbm_classifier_params = lgbm_base_params.copy() + + except RuntimeError as e: + print( + f"KMeansSMOTE failed with RuntimeError in fold {fold_idx} of iteration {iteration}: {e}. Skipping fold." + ) + continue + except ValueError as e: + print( + f"KMeansSMOTE failed with ValueError in fold {fold_idx} of iteration {iteration}: {e}. Skipping fold." + ) + continue + + else: + sampling_method = "class_weight" + + class_1_weight = int( + (y_train_fold.shape[0] - y_train_fold.sum()) / y_train_fold.sum() + ) + + lgbm_classifier_params = lgbm_base_params.copy() + lgbm_classifier_params["class_weight"] = {0: 1, 1: class_1_weight} + + if lgbm_classifier_params: + model = LGBMClassifier( + **lgbm_classifier_params, random_state=random_state, verbose=-1 + ) + model.fit(x_train, y_train) + y_pred_val = model.predict(x_val) + + val_metrics = get_metrics(y_val, y_pred_val) + fold_results.append(val_metrics) + + avg_val_metrics = {} + if fold_results: + val_df = pd.DataFrame(fold_results) + avg_val_metrics = val_df.mean().to_dict() + + test_metrics = {} + if lgbm_classifier_params: + x_train_final = x_train_val + y_train_final = y_train_val + + if is_balanced and sampling_method == "KMeansSMOTE": + try: + smote = KMeansSMOTE(**smote_params) + x_train_final, y_train_final = smote.fit_resample( + x_train_val, y_train_val + ) + except (RuntimeError, ValueError) as e: + print( + f"Final KMeansSMOTE failed for iteration {iteration}: {e}. Skipping test evaluation." + ) + lgbm_classifier_params = None + + if lgbm_classifier_params: + final_lgbm_params = lgbm_base_params.copy() + + test_model = LGBMClassifier( + **final_lgbm_params, random_state=random_state, verbose=-1 + ) + + test_model.fit(x_train_final, y_train_final) + y_pred_test = test_model.predict(x_test) + + test_metrics = get_metrics(y_test, y_pred_test, prefix="test_") + + if lgbm_classifier_params: + params_str = str(lgbm_base_params).replace("}", "") + if is_balanced: + params_str += f", 'smote_k_neighbors': {k_neighbors}, 'smote_n_clusters': {kmeans_estimator}" + + final_result_dict = { + "iteration": iteration, + "model": "LGBMClassifier", + "params": params_str + + f", 'sampling_method': '{sampling_method}', 'scaling_method': '{scaling_method}'}}", + } + + for k in template_metrics.keys(): + final_result_dict[f"avg_val_{k}"] = avg_val_metrics.get(k, float("nan")) + final_result_dict[f"test_{k}"] = test_metrics.get(f"test_{k}", float("nan")) + + result_row_df = pd.DataFrame([final_result_dict]) + + result_row_df = result_row_df.reindex(columns=template_cols, fill_value=None) + + result_row_df.to_csv(RESULTS_FILENAME, mode="a", header=False, index=False) + + iteration += 1 + +print(f"Finished: check {RESULTS_FILENAME}") diff --git a/__pycache__/train.cpython-312.pyc b/__pycache__/train.cpython-312.pyc index 4013c7803d386b3bc25b988f2c253923d85c0e41..b809b9b832347c49701ab14e8f7310e6edd1fa3a 100644 GIT binary patch delta 1689 zcma)+&rcIU6vuaVx7+S^w@XV4ZRxg@TA^a82^tcAGz6mu{soqY-n6HO25%)7^;U=^4i}MhwNt?ofa- zcvmWWl2SwM!4|V_Pn|#oS9H7w4P!XcV*g?Nu57M{wbFJ-zvbZ-i`6tbU{f~~dtDuY z6PH{+-BS^o@(gDK{dFpXu6QErva9lRp+oJcOQ8e*#WJpJ&}I{MiH_BKcF8IwM zXP3eZ+M9=u72X0SY`-yz9YuRyapp5i>3Jttb;|L&qv2~EAJdaMB>I%;^Ettld(oE@ zUAcB&F5t@j>B|kca`*Xu^@UuyCa)CE-TeQhmG`$STRfpCLF!}mQqf{$bwAvO8FOa}RH0j?4QS{wkKoep;o-{p6E2>sA z!GZ}^A?cKmv_Tk0JK_>g{$yCd1|_8=x+@ijOmCySRIh2! z#}kXYDC`4Z9yKFTZWd3#8T2KxmL|1Wf-MTp6W^kJH!=CPrWIiYEeu_g=K2wBg Gmgp~A5B7Ti delta 1851 zcma)6PfXiZ7=O>sY$uLGAY_pLF)6fxb+oHWM8%L8LhF4OM4?f@mP+dU*^8M*Wsj<- zN!tPK?J#L{;?P4PIJDxzanhvru*)i>X|<@f1FFPr6FpQt^}P^7VobFs`Fr1=-|v0j z`)q&vBO8~c-zCWh$f`a6X!+MCQVzmj@C^JD?&CP~9X!HSHnqmzy#xRnR|LhY@g>2G zL%;wdZbY{rr2?R^5hBfm9%`D@|>RE5~Sp=2E%{DtK~2s{AFbQSNTm7|>(q;c%Ii<(?KQaW3jHD^G`7 zs8k24w=HKtmrr70gZmjgWb9r%fJUmA!IPe(f1X^|>^GwZC=L=wCUemyo$f}&CYotO zC}!~_pUS1}RCWvk!&}UrvTRMiY->i$NaM#`U*6}68*RraqVnN(|CZ<2*Oo(c^fzx*#!ZhI~yQuM_n2aJSD$}%er#A z&v0pey`y{Q+_vWo^(@orIgS^s_i8Y47`}Ev9?Yu)^LW|D_pTXTJa21N7^5 zuZm-~*4*(!J}9;Y@h+d}u>Wz`g3uYi;;?5O_I-z4ci4Y8>^BZO?d{BW(_vS=2|l(V zeF3s@{G)fS57GIRg%&?uSt;u(I;5lXkiq|YKl2`T(HMiGc+96{8jR&xRP_S7K)elT zd8?ILUMnoBT7yAD#D^~7XTB#ej#tIexFs%CG)-O7mn-EOq9f5_OUq><3SZ;8SdGfo z2c?RlY8ITe;3}f2P&XBE{E5GW8~y~3Xzx)D|K^_>K0ZPqM^wDnw6GKt=uIj{=Szz(BrwXUU~# zLh(n@5i~n@?qg=|oqO)^?;eknfY$BVpY*v2;@_CC8dV)UZ->DGArc=F*9eh>o)jtQ zDbXgxg6s#B#Jj8T1Wmn21qBPq1FhhA2ftOh(X05ho^9Sx8IBu4Ph5 zU7FC5eMHbCJQrh-<3=`>iepKY0afGxaVeoI>Ljw?NTy6N4Rbo6eg9hjdzz$b{hA<2 zV*j9YyZ>|wq*Wy$sA5vNsS4@-te#A1eVIw*G#9Mc_J9TK9qj^i)5O<|XT#UDz|C_D z{5+p`lzh9V&zE`s-SOG+bv|0;qj~BJ?sIOvV{frz?@FlD@n(r1czC(UpO`*Zb~Vq9 z7F_KGru{isBNFuiVqwAz<1b*aKzzDjbV+3X3+Nuk@}Rf zoJ|sUf(+q+D~10oOC&K#bP@XqjU1;);(a(Ho|&NEC&tMbJ%q^dDXPO6qY$HHBpKPp zRR}?hQSto&UkkbrPbD=SG1%9L!4oxat8lpq^1{nX7`_n=4FXgR zm-WrDB5km@2{iy3e1kZUA}un@HY*ta+X}khvTKH%{gg6T%OXj@x&e#7`F&6wj&*fz_`=GU;*Bpl#lobnCLa2V{e)3Tg%8qWXQ zo?S2R87zkz76(#vi=Jh#wLbXF30TArsW_`*BywIH8akiOWOYeJR8mA-LP=%Qvi2p3 z7-1r*O}YE4GkyAm4pcg+X-WAerk>zL0UItkE=$Rqqa%t6dg^wVo#OjtC0-{xwWF^e zCs6-INf**dS&AnUnRuN{Lc4tiY}t~sC{4t?h|LJHC}_y5-AYM>4xhlqK4Bes<#CiBkCW_3++ecyB4ZzrX}v zk}MoNx5k|R)=u!D{N*+8uIaPo!1nb(Z!yrjw7nEKIQ@1x6kQL+ilNxj>tAt&P^=U> zS72HIF1^JuH*a$I8Kw}f^}HY zCAybAR=($4^2fmU*75lU{s+MKnS3Yc+eH6zQw`HxvIg9M$^DthZLa6O8q0ES!~IGO z_5dGO%!US~zXz(8*YFr#>~E301(2eVcfej{Mz7(lC1dS)3d^_P9?K5E_s%W`#yO|# z)obV|wi@JpV)Oa%0cp6#A|Tw*45wu`#~ajGw1Ewz!GjOBIK%R`S_FKCW2|EfO`k~P z{TV*M<6sq!S>%34m^eA24JQOhaG;VhOjIe~C1qSvuUIGs{J`a`Aj6YRs-JYftBhnd zy{hCL)-jN(VwJX~gb^wA1lMvEI|O-EQ#`K0k^t<;r&SedGhV>!v@lUe3lqYKR-^f! zS9bw;uBL?)Vk_S#yTLv@uL4R|hYAxH;W4EqN8k;r$jA-H5S4TsR?{WqtC`5kNfV+u zYv7ucl=zr3Dr<^7g^4{?0;&a@mh@2ts@8}y0a!aRg$H|JaPy=8{67EUFDu8Vg73+< zWMy3LsGhH*qH?rD#o>(kRgCqh#i}wHRm#ANSu%^9`Xk_}hj4cocSmq{6nAgpZYOkz zNht7~roly(0ziXjkn<-P{ZFf^qN*`0iami?vnFV;W#SM`Xm$uySL!)T;)0+ z`6HjS&$49)UkLQ9I(i=Yqa}Z5b)sw4(N!gagn?_7$k`43fQj;uP`{^jp12im@)sV3(P3%l|+EwJ z(7n+8@QimO5L`I+>ti!-mHEhGqR8*a3nhM6IkF>v;NG?6t|j^5aIx=XY4@pOBzEAT$$DGG~T=#$XAK%aKn7QVfIcMh5D51aq%^mBwsIQPv_GxY7mvhvI ziVYU)ZDM1!%f{ALM*5VBp3eXOU;m9y+Eh4WyQD_wzeY-!RZc1Lpk^i*pQKd}R;`GT z1vN%nB!E^Ttom5fAC!Jg;wD;&vg&YQA5e-@N>XXnlU3!#!l2}(MT2P7i&bR?owNSC z50|8_wxCsSRu$j>3aWektu$KoVbzxX4WM4Qy%|fZzN|_we-EnQc2N&niLt6=)Jssc zL!G>7)sIziijP6L`L?W}l{l-UMwEa$XC@g+D+yM~o-Y6u8GR^_R{dEO_xe64$x7Kr zw31|1%)~5ECb3@x9KBMk@=#6(wQ{%TSGpz5s@@AzL8b(AXa%Tya6g}dZa$BWLQ-?@EWK+UrRsHN|sfUk77Ytd^%G{tHG>F zu8ILAQuOmTt%k5lHZKZP#qDq>S`B5@IhjaM?TJ^W(`p#2l1wgvx_5toFRg~NYFj}# zC|A?p9kh~TRgLupP}Mai&uKM+RSBZ!LB0K1@e?T_um8KENXfHm?xYY<+Q%w7X{Erb zg7$Nu&S)Gnp_L-5yft}hreY92M{yG>v8qfc2wLXnH=Lte%B%_t3jnp&uhX1XBUzOe z;SXw+UC>2ZjbfE^^;u9o9508{YBa0zHF@fKi~B%Ysj$lEy$`f%ZFLiArOK+EwP!%x zf1eglD>YV~YVrc*Z8X|}R%2KdfAln{#Hxm!v{GkPkc=m&PZMB*D=qGz&WrkA zqSZK7&9vaDd^=?iT4}KAvz;5X?9Q8t(`r1cdR%q|^`xJ9cetm?Jb30kq0yCrBfl~oeA96{O0x#KkecV?T$s{5uq z6=?q&uf8}nomItqjzeo>{2gn0LLF8;K5-0G=Z>e!w3@-H_w`3W**`GRrPWMU8NWIV zs&vHnb+poD)#lke_266P0$R;tmG=yLXen;apGqq|R?Uz*1nNh{%Nw+s&8nz;J5css z-vwNW&|O>o;{%`yz87}Ut+{MVe!4BF!2Egvt$D0k)w&NmtRt41AfQnGxZA7d2tjfK)2UN}WBe!X_fK}f=?FO|(!l4_f*?-qI3t5$=%~QwZ zrqOrO39Do^T1PW*BS(wYDz(7_CfMRr%Etl$GrE z*|b{9s@xA-Kncf5;8hB@UCUUtMuewo5`Bv3)^b+)ecuGFj2#E?^9`=Gf>qIjc&a9; zB#myZWR>~~3uqP8|Lg9CYpr5ckF*V-`i34Cus2q-s-tZ^s2qg>#q@+`tdhIG7L-uv zFnq1#CNyVNP5c^AK9bG(bZZT(S`^Je1=NG9M z+hS<7iB&m6d8+hqwi~TBv+B;V70~ieDq2IUEv#x?yd2a)g(UnWklTlrteWJq4AkJR zay7cOl~pTWEd?ca-~I%xwz2AoC{HDchT%J3ZbBl^Lb9~XVr0ao^su6u?0=&_;*EQ%c{+V3!wF6iIEwt4zTL!4xSQ` zS%aS*a4kDl^~#zLt@j5N@p_(9|FCLwlmVzkVX63Dk5dO(wfvhtsBd$>U7@$@5UWak zdFr}mFg{0dEqhiC{yh&`W|PO_`!7y8uuAkaPnEeG4x}eM%&Ntcd1`Ou&jYkN!m6?E$dF8)p1sBEu9T2W{1vORqjnz#=FG% z|NAD3+lP+mRaUb&kEJDL&!w3YVo8>IP;;}}y^dxl5Hou>3z)2F*-M(8MC@ZQj|IHW z5T}_lVm(Lln9+LKi8OOT?4G|a4C&xy_=IMzhy~o92~6p#h!D-(5Yug+0j!`nWC+dN z5!1BhvAP?2Gic_4SVF1})GTGL>(I;-F@*^{)(~<+z@9vXSn~GiQ1jJ)TSM1QBeuF^ z8nBM8&(~?@g;?V<9xGWFSxB=ph)KFnh1xpd&|Wn2Ml3gH3NZVDOIv8>gIMO5$-s=> zT1e2$7co5p9_#k@QKZ>f#MGO#p*H93DLzzK~{ti0Sp72rQ{=nKaFU5Nq3^1?+>5od?a%AvWvP1Yqsm{RDj6!H8X! z=dtx3Uk&M62x1oAno!HCaujfS4MnV4l*hEc{!>QR&Leh7c|6q8SN2b*Sr}q1cQt@% zT}-@#SoYub(gnntXYtscLFJ=q7LJ&$(>Pwup)G@E5r}zJj|HaMD3(vNi-^tL!(%eh z7e3SM5@KSX)S(t$>@=Tdml2y8%VWtmtNPI_60!Nid2H`6Yjv7kK}jMPB*YMYCkYj>_=Zu_<}kG)qCO_elk)i4{zkMYCIo&5o7_Cew4$ zADX2iHsZ?&VDbqMooJSZ*onbBcF;vQQ!jLz&@fRQeAQ zux-x~`{!;iV3C328tK{##PkRASm3fcd?(E9$(M+QMD>JP-Sd$*>DnvA9J57%&HPn0 zoo25QbN?&?EHgW6J_W@gq0FEF@WQjOTn5j@srIY5K1 zy+urIkr32k<(9b7>>Xm_G2OcV{Z0BYgDbu?dym+MDP6$Ut6GkwSq);AW}U#UJ$W{b zX0?bd7XA&)9Hh|M|L2F%~9Ou#N_L@eE=6wW?)V6pR{Q9 z39*j7Uw{SlzS&K)&xq;zdrMPnhx-7&Aa?K2Ct#H?wSLgGuZTG-H35?imgt~a zGh%Z!8i7S?lmyZ28)9=DJ^)*_wN=3BwFR-X^jctE(pC6AfZHXlhy~lc2iCtx<|rC+ z_TR-*8)CwD-vS$6^h}>-?TBrDQ3XuryX!ETeMc-R{0*>R<38_b_5-m&dtL);QJfb| zvkt`MgkAwtNUh1E*-yk&_q_lXC!H~lX1@?CJ^mcn{#{oW)9g25q5jW+DXFcsp;;$l zvieVf{TdO~z+SvhA4u1n z5!0x608G#5MhMMZ5PMWn1Wc=ae7Si znq5Kcivo{nUaRt^SrlRoiWi}F&nrZfX3>b1O^pDiJ6i>R1GtYn2C+*k!+`}D{+>$L zt|GR={{k?-xJG=}!PR0B6SfHh_Uo3xaJm+USknFTz}((d;iqa`EgrGB`cPn9v+k79 zwQGpEPv)`p`DX=e+jYc}$A>^IGHu~mx^@FG@!i3|KDVxuq*(%D-Ua7?W#!HKO|wMA z&g${l*y|oCG`oq|nAjkwP4pU%SFhX$kc8O0j6h)bN`vt)2RN3Dn0<2qu*j>o@g<*Q zDTw(h@L0W-sSUktw-Ae)>kqYk{`GxmmWo)_20vgm@l)|qjT2uk&BpXqdTx3FIV83 zOpfItc6B+AiOOHYzxv}?K4STU-JsUz)qVqd$O6QE8@d8h+F;p>W`&5Iv33ELJFNr1 zS>c8(LTquCGq6d;Sr6&j1H_u+P6FEyw9%Mm#fT|AI{|E(tR=n=;D#*0OoPWBEKR`+ z433o|rhU!{YRA;(v_}5$7zXw+tVqq4JP%E!IdV$_0j}S9VJr2x6 z+aKR|b3>LRmb~B?Fz<;wRq0vxXXsvF!rlk);)xsbHDVio?FKea zOwgWugV?afUBLRf>08r7Rv~stb0;wC`P1-^;khBJ5vyyn24?p|7r$BI*jvOVL~aLW zuss^D@HqAkv8Q*ffbD;;5<_p>d&C;vZ3E`>xfj0c;A%C9RbSrfG>!djpea3!mj^%m+$?bU+R2COtyA4)H+=xwxewm z`pcRTyIHgfSa6X@BF(-bCX=}mSia`aSemsUHgPDAEtg5a2NySFD`JzAmP5@^p#iT; zIM#-k=l*5Dl+M%@(L=T)=IXx`m_|bTEt-8tOf1Y4*cLG>e4*lo{DD|#!xCUt6<7As zwGPBGel7-Pt`XZvv!96Rj_0w{ZWZ{}h#T@3VzCz&L2a*j^gg=w8?l$KOn?Q}PsGpg zxLPM-<4TNysohS%7b=eZK`iLH5ir@BG5AFq$GQ-cs51m+t}^Wxy=~oy87lMGEVT#t zF$-67ME|^^Y4ZZ870SLcp=(ZvxeM{wiegiI`@q#sAQtk}0BVtI7Aw%TlZYJ-(Ff)n z9^ywcXT)}?&I2~GGRBu?E{IL|JqK94d@X)n!fl%?VtZ!t*f;UQOuFWVm}1Cms2yn< zgRcQx%^fkPG(BLE2an_T0UYx{%-4Muuwi}Hw$np;A|`cO7g*-rqxfA1S38ARQT$9` zuFJ>a_W>L`jhNN18Nke5I91X^dLgD2rUNYD?G}9B&DG8zw&=lhVEw8RedwAuVvp*k z0h0}t7I5kCLF}>ZRA5u&KHwD|H>59OFGZ&S8zJejksk6aVnRnJ12bH$zJz9ehz%XX zW1UBC?4p@JV!xxbp*GcIvw&kV0I@bH9#fGn6>!A~MC^#~B&eDG^fsr53_|Ql$3$Qf z10wWkb`G%+ZypO2396%6Fk-)bw4k=B<|O{Rj@f@#K_Q53SLLzi1Jl(|jbovRRbHL| zwav0>dNeza*uHT*W*peMjAmho?YO52wZLA_`_b$IVjg37Y_ZKb37Ul?rtnY$Y8w{( z8bPxN#Ojuf116rn2R|?2_T)vx=0%SM_H4O^KV7?oSloh{?sM0$Vhq=@HGM5IdKp0!;teIQ+Jc+mq3V4V%ehKE66D z=~@h818>6TQ3zdNVW7zWt&8{P+EyZK^*7U_Mcey=z12OrricnLqE5pwx zIhKH!u%-gAcamXw^pJ^&4LczZ?BN8%<21X8Sghv=V3Fd6_~AD310VAYN{uhQ%e zVlUS4SkP@De6zxBTP9*NGGw70c{Zdz=qv;g~zcH z#Ox3Dhg$dBRQxQ6W2J~?wMhV*uD|+Uzl!+#xjjUz=RO`=zpurR-X&#-m0S{sTE@+F z0W^Dr*x2@dz%JW-3!_;%Vp5Jg7IonXUOaKTqyjOY{ygStxfCzxIrbPaAzd*Ta*_2P zd=ci@6U1Ed`vUX-*rPAKZIy^Es_p~KZJ{_`U~sjkh`l|+W9m=ccGI*t9_l@QXCAR)g5fCSj;~4cuBt z*J=?nUd&^zHOcr+n5)$x7V%67YTK8p;zvjvt4GZ10FSj5pC3vO`2jJv-0qqG{Y%2T zuQ$G-Sp#CdbGv|&??w0 0 else 1.0 models = [ - # { - # "name": "LGBM_FOCAL_LOSS", - # "model": LGBMFocalWrapper( - # n_estimators=500, - # learning_rate=0.05, - # max_depth=-1, - # subsample=0.8, - # colsample_bytree=0.8, - # random_state=42, - # ), - # "smote": True, - # "smote_method": "kmeans", - # }, - # { - # "name": "LGBM_SMOTE", - # "model": LGBMClassifier( - # n_estimators=500, - # learning_rate=0.05, - # max_depth=-1, - # subsample=0.8, - # colsample_bytree=0.8, - # random_state=42, - # verbose=-1, - # n_jobs=-1, - # ), - # "smote": True, - # "smote_method": "smote", - # }, - # { - # "name": "LGBM_KMEANS_SMOTE", - # "model": LGBMClassifier( - # n_estimators=500, - # learning_rate=0.05, - # max_depth=-1, - # subsample=0.8, - # colsample_bytree=0.8, - # random_state=42, - # verbose=-1, - # n_jobs=-1, - # ), - # "smote": True, - # "smote_method": "kmeans", - # }, - # { - # "name": "LGBM_SVM_SMOTE", - # "model": LGBMClassifier( - # n_estimators=500, - # learning_rate=0.05, - # max_depth=-1, - # subsample=0.8, - # colsample_bytree=0.8, - # random_state=42, - # verbose=-1, - # n_jobs=-1, - # ), - # "smote": True, - # "smote_method": "svm", - # }, - # { - # "name": "LGBM_BORDERLINE_SMOTE", - # "model": LGBMClassifier( - # n_estimators=500, - # learning_rate=0.05, - # max_depth=-1, - # subsample=0.8, - # colsample_bytree=0.8, - # random_state=42, - # verbose=-1, - # n_jobs=-1, - # ), - # "smote": True, - # "smote_method": "borderline", - # }, - # { - # "name": "LGBM_ADASYN_SMOTE", - # "model": LGBMClassifier( - # n_estimators=500, - # learning_rate=0.05, - # max_depth=-1, - # subsample=0.8, - # colsample_bytree=0.8, - # random_state=42, - # verbose=-1, - # n_jobs=-1, - # ), - # "smote": True, - # "smote_method": "adasyn", - # }, - # { - # "name": "LGBM_Balanced", - # "model": LGBMClassifier( - # n_estimators=500, - # learning_rate=0.05, - # max_depth=-1, - # subsample=0.8, - # colsample_bytree=0.8, - # class_weight="balanced", - # random_state=42, - # verbose=-1, - # n_jobs=-1, - # ), - # "smote": False, - # }, - # { - # "name": "LGBM_DART", - # "model": LGBMClassifier( - # n_estimators=500, - # learning_rate=0.05, - # max_depth=-1, - # subsample=0.8, - # colsample_bytree=0.8, - # boosting_type="dart", - # random_state=42, - # verbose=-1, - # n_jobs=-1, - # ), - # "smote": True, - # "smote_method": "kmeans", - # }, - # { - # "name": "LGBM_GOSS", - # "model": LGBMClassifier( - # n_estimators=500, - # learning_rate=0.05, - # max_depth=-1, - # boosting_type="goss", - # random_state=42, - # verbose=-1, - # n_jobs=-1, - # ), - # "smote": True, - # "smote_method": "kmeans", - # }, - # { - # "name": "LGBM_RF", - # "model": LGBMClassifier( - # n_estimators=500, - # learning_rate=0.05, - # max_depth=-1, - # boosting_type="rf", - # subsample=0.8, - # colsample_bytree=0.8, - # random_state=42, - # verbose=-1, - # n_jobs=-1, - # ), - # "smote": True, - # "smote_method": "kmeans", - # }, - # { - # "name": "LGBM_scale_pos_weight", - # "model": LGBMClassifier( - # n_estimators=500, - # learning_rate=0.05, - # max_depth=-1, - # scale_pos_weight=scale_pos, - # random_state=42, - # verbose=-1, - # n_jobs=-1, - # ), - # "smote": False, - # }, - # { - # "name": "LGBM_is_unbalance", - # "model": LGBMClassifier( - # n_estimators=500, - # learning_rate=0.05, - # max_depth=-1, - # is_unbalance=True, - # random_state=42, - # verbose=-1, - # n_jobs=-1, - # ), - # "smote": False, - # }, - # { - # "name": "LGBM_DART", - # "model": LGBMClassifier( - # n_estimators=500, - # learning_rate=0.05, - # max_depth=-1, - # subsample=0.8, - # colsample_bytree=0.8, - # boosting_type="dart", - # random_state=42, - # verbose=-1, - # n_jobs=-1, - # ), - # "smote": True, - # "smote_method": "kmeans", - # }, - # { - # "name": "XGB_scale_pos_weight", - # "model": XGBClassifier( - # n_estimators=500, - # learning_rate=0.05, - # max_depth=6, - # scale_pos_weight=scale_pos, - # random_state=42, - # n_jobs=-1, - # use_label_encoder=False, - # eval_metric="logloss", - # ), - # "smote": False, - # }, - # { - # "name": "CatBoost_balanced", - # "model": CatBoostClassifier( - # iterations=500, - # learning_rate=0.05, - # depth=6, - # class_weights=[1, scale_pos], - # random_state=42, - # verbose=0, - # ), - # "smote": False, - # }, - # { - # "name": "RandomForest_balanced", - # "model": RandomForestClassifier( - # n_estimators=500, - # max_depth=None, - # class_weight="balanced", - # random_state=42, - # n_jobs=-1, - # ), - # "smote": False, - # }, - # { - # "name": "BalancedRandomForest", - # "model": BalancedRandomForestClassifier( - # n_estimators=500, - # max_depth=None, - # random_state=42, - # n_jobs=-1, - # ), - # "smote": False, - # }, - # { - # "name": "LogisticRegression_balanced", - # "model": LogisticRegression( - # max_iter=1000, - # class_weight="balanced", - # solver="liblinear", - # random_state=42, - # ), - # "smote": False, - # }, { - "name": "CatBoost_balanced", - "model": CatBoostClassifier( - iterations=500, + "name": "LGBM_FOCAL_LOSS", + "model": LGBMFocalWrapper( + n_estimators=500, learning_rate=0.05, - depth=6, - class_weights=[1, scale_pos], + max_depth=-1, + subsample=0.8, + colsample_bytree=0.8, random_state=42, - verbose=0, ), - "smote": False, + "smote": True, + "smote_method": "kmeans", + }, + { + "name": "LGBM_SMOTE", + "model": LGBMClassifier( + n_estimators=500, + learning_rate=0.05, + max_depth=-1, + subsample=0.8, + colsample_bytree=0.8, + random_state=42, + verbose=-1, + n_jobs=-1, + ), + "smote": True, + "smote_method": "smote", }, { "name": "LGBM_KMEANS_SMOTE", @@ -302,8 +70,212 @@ models = [ ), "smote": True, "smote_method": "kmeans", - } - + }, + { + "name": "LGBM_SVM_SMOTE", + "model": LGBMClassifier( + n_estimators=500, + learning_rate=0.05, + max_depth=-1, + subsample=0.8, + colsample_bytree=0.8, + random_state=42, + verbose=-1, + n_jobs=-1, + ), + "smote": True, + "smote_method": "svm", + }, + { + "name": "LGBM_BORDERLINE_SMOTE", + "model": LGBMClassifier( + n_estimators=500, + learning_rate=0.05, + max_depth=-1, + subsample=0.8, + colsample_bytree=0.8, + random_state=42, + verbose=-1, + n_jobs=-1, + ), + "smote": True, + "smote_method": "borderline", + }, + { + "name": "LGBM_ADASYN_SMOTE", + "model": LGBMClassifier( + n_estimators=500, + learning_rate=0.05, + max_depth=-1, + subsample=0.8, + colsample_bytree=0.8, + random_state=42, + verbose=-1, + n_jobs=-1, + ), + "smote": True, + "smote_method": "adasyn", + }, + { + "name": "LGBM_Balanced", + "model": LGBMClassifier( + n_estimators=500, + learning_rate=0.05, + max_depth=-1, + subsample=0.8, + colsample_bytree=0.8, + class_weight="balanced", + random_state=42, + verbose=-1, + n_jobs=-1, + ), + "smote": False, + }, + { + "name": "LGBM_DART", + "model": LGBMClassifier( + n_estimators=500, + learning_rate=0.05, + max_depth=-1, + subsample=0.8, + colsample_bytree=0.8, + boosting_type="dart", + random_state=42, + verbose=-1, + n_jobs=-1, + ), + "smote": True, + "smote_method": "kmeans", + }, + { + "name": "LGBM_GOSS", + "model": LGBMClassifier( + n_estimators=500, + learning_rate=0.05, + max_depth=-1, + boosting_type="goss", + random_state=42, + verbose=-1, + n_jobs=-1, + ), + "smote": True, + "smote_method": "kmeans", + }, + { + "name": "LGBM_RF", + "model": LGBMClassifier( + n_estimators=500, + learning_rate=0.05, + max_depth=-1, + boosting_type="rf", + subsample=0.8, + colsample_bytree=0.8, + random_state=42, + verbose=-1, + n_jobs=-1, + ), + "smote": True, + "smote_method": "kmeans", + }, + { + "name": "LGBM_scale_pos_weight", + "model": LGBMClassifier( + n_estimators=500, + learning_rate=0.05, + max_depth=-1, + scale_pos_weight=scale_pos, + random_state=42, + verbose=-1, + n_jobs=-1, + ), + "smote": False, + }, + { + "name": "LGBM_is_unbalance", + "model": LGBMClassifier( + n_estimators=500, + learning_rate=0.05, + max_depth=-1, + is_unbalance=True, + random_state=42, + verbose=-1, + n_jobs=-1, + ), + "smote": False, + }, + { + "name": "LGBM_DART", + "model": LGBMClassifier( + n_estimators=500, + learning_rate=0.05, + max_depth=-1, + subsample=0.8, + colsample_bytree=0.8, + boosting_type="dart", + random_state=42, + verbose=-1, + n_jobs=-1, + ), + "smote": True, + "smote_method": "kmeans", + }, + { + "name": "XGB_scale_pos_weight", + "model": XGBClassifier( + n_estimators=500, + learning_rate=0.05, + max_depth=6, + scale_pos_weight=scale_pos, + random_state=42, + n_jobs=-1, + use_label_encoder=False, + eval_metric="logloss", + ), + "smote": False, + }, + { + "name": "CatBoost_balanced", + "model": CatBoostClassifier( + iterations=500, + learning_rate=0.05, + depth=6, + class_weights=[1, scale_pos], + random_state=42, + verbose=0, + ), + "smote": False, + }, + { + "name": "RandomForest_balanced", + "model": RandomForestClassifier( + n_estimators=500, + max_depth=None, + class_weight="balanced", + random_state=42, + n_jobs=-1, + ), + "smote": False, + }, + { + "name": "BalancedRandomForest", + "model": BalancedRandomForestClassifier( + n_estimators=500, + max_depth=None, + random_state=42, + n_jobs=-1, + ), + "smote": False, + }, + { + "name": "LogisticRegression_balanced", + "model": LogisticRegression( + max_iter=1000, + class_weight="balanced", + solver="liblinear", + random_state=42, + ), + "smote": False, + }, ] @@ -336,7 +308,7 @@ for m in models: ) results_df = pandas.DataFrame(results_to_save) -csv_file = "lgbm_vs_cat_kmeans_smote_k10_results.csv" +csv_file = "lightgbm_results.csv" try: results_df.to_csv(csv_file, mode="a", index=False, header=False) diff --git a/train.py b/train.py index ecb25bb..33bf10e 100644 --- a/train.py +++ b/train.py @@ -42,13 +42,8 @@ def train_model_with_kfold( if smote: if smote_method.lower() == "kmeans": - from collections import Counter - minority = Counter(y_train)[1] - - k_neighbors = min(10, max(2, minority // 10)) - sampler = KMeansSMOTE( - k_neighbors=k_neighbors, + k_neighbors=15, cluster_balance_threshold=0.1, random_state=random_state, ) diff --git a/utils.py b/utils.py index 6f9dea6..da1d0c0 100644 --- a/utils.py +++ b/utils.py @@ -40,18 +40,62 @@ def missing_value_handler(data_path): missing_value_counts = data_frame_imputed.isna().sum() write_textfile(f"{data_directory}/no_missing.txt", missing_value_counts) + + data_frame_imputed.to_csv("./data/Ketamine_icp_no_missing.csv", index=False) + return data_frame_imputed def scaling_handler(data_frame, method="robust_scaling"): - if method == "robust_scaling": - import pandas - from sklearn.preprocessing import RobustScaler + import pandas + from sklearn.preprocessing import ( + MaxAbsScaler, + MinMaxScaler, + PowerTransformer, + QuantileTransformer, + RobustScaler, + StandardScaler, + ) - labels = data_frame["label"] + # Separate features and label + labels = data_frame["label"] + X = data_frame.drop("label", axis=1) + + # Choose scaler/transformer + if method == "robust_scaling": scaler = RobustScaler() - x = data_frame.drop("label", axis=1) - x_scale = scaler.fit_transform(x) - data_frame_scaled = pandas.DataFrame(x_scale, columns=x.columns) - data_frame_scaled["label"] = labels.values - return data_frame_scaled + elif method == "standard_scaling": + scaler = StandardScaler() + elif method == "minmax_scaling": + scaler = MinMaxScaler() + elif method == "maxabs_scaling": + scaler = MaxAbsScaler() + elif method == "quantile_normal": + scaler = QuantileTransformer(output_distribution="normal", random_state=42) + elif method == "quantile_uniform": + scaler = QuantileTransformer(output_distribution="uniform", random_state=42) + elif method == "yeo_johnson": + scaler = PowerTransformer(method="yeo-johnson") + elif method == "box_cox": + # Box-Cox requires all positive values + scaler = PowerTransformer( + method="box-cox", + ) + X_pos = X.copy() + + min_per_column = X_pos.min() + + for col in X_pos.columns: + if min_per_column[col] <= 0: + X_pos[col] = X_pos[col] + abs(min_per_column[col]) + 1e-6 # tiny offset + + X = X_pos + else: + raise ValueError(f"Unknown scaling method: {method}") + + # Fit and transform + X_scaled = scaler.fit_transform(X) + data_frame_scaled = pandas.DataFrame(X_scaled, columns=X.columns) + data_frame_scaled["label"] = labels.values + + return data_frame_scaled