56 lines
1.6 KiB
Python
56 lines
1.6 KiB
Python
import pandas
|
|
from catboost_model import CAT_BOOST
|
|
from lightgbm_model import LIGHT_GBM
|
|
|
|
# Load the modelling dataset and the two hyper-parameter tuning result tables
# in one pass; the paths are read in the same order as the names on the left.
data_frame, cat_boost_results, lgbm_results = (
    pandas.read_csv(csv_path)
    for csv_path in (
        "./data/Ketamine_icp_no_missing.csv",
        "./cat_boost_tuning_results.csv",
        "./lightgbm_tuning_results.csv",
    )
)
|
|
|
|
|
|
def get_best_params(data_frame, metrics=("f2_class1", "f1_class1")):
    """Return the best tuning-result row of ``data_frame`` as a dict.

    Rows are ranked by the column ``metrics[0]``; when several rows share the
    maximum of that column, the one with the highest ``metrics[1]`` wins.

    Args:
        data_frame: Table of tuning results containing both metric columns.
        metrics: Two column names: the primary metric, then the tie-breaker.

    Returns:
        A dict mapping every column name of the winning row to its value.

    Raises:
        KeyError: If a metric column is missing from ``data_frame``.
        ValueError: If no row holds the maximum of the primary metric
            (for example, the table is empty or the column is all NaN).
    """
    primary, tie_breaker = metrics[0], metrics[1]

    # Rank the table that was passed in, so each caller gets the best row of
    # its own results rather than of some module-level table.
    top_score = data_frame[primary].max()
    candidates = data_frame[data_frame[primary] == top_score]
    if candidates.empty:
        raise ValueError(
            f"no rows with a valid {primary!r} value to choose best params from"
        )

    best_row = candidates.loc[candidates[tie_breaker].idxmax()]
    return best_row.to_dict()
|
|
|
|
|
|
# --- CatBoost: select tuned parameters, train, evaluate ---------------------
cat_boost_best_params = get_best_params(cat_boost_results)
cat_boost_model = CAT_BOOST(data_frame, params=cat_boost_best_params)
cat_boost_model.fit()
# NOTE(review): eval() is called without arguments here; presumably it scores
# the model's own held-out split -- confirm in catboost_model.
cat_test_metrics = cat_boost_model.eval()
print(cat_test_metrics)

# Keep the CatBoost model's test split so LightGBM is evaluated on the same data.
x_test = cat_boost_model.x_test
y_test = cat_boost_model.y_test

# --- LightGBM: select tuned parameters, train, evaluate ---------------------
lgbm_best_params = get_best_params(lgbm_results)
lgbm_model = LIGHT_GBM(data_frame, params=lgbm_best_params)
lgbm_model.fit()
lgbm_test_metrics = lgbm_model.eval(x_test, y_test)
print(lgbm_test_metrics)
|
|
|
|
import pandas as pd
|
|
|
|
|
|
def clean_metrics(metrics):
    """Return a copy of ``metrics`` with numpy-style scalars as Python floats.

    Any value exposing an ``item`` attribute is converted with ``float()``.
    Values that expose ``item`` but cannot be reduced to a single float
    (for example a multi-element array) are kept unchanged instead of
    aborting the whole conversion. All other values pass through untouched.

    Args:
        metrics: Mapping of metric name to value.

    Returns:
        A new dict with the same keys, in the same order, as ``metrics``.
    """
    cleaned = {}
    for name, value in metrics.items():
        if hasattr(value, "item"):
            try:
                value = float(value)
            except (TypeError, ValueError):
                # Not a single numeric scalar; keep the original object.
                pass
        cleaned[name] = value
    return cleaned
|
|
|
|
|
|
# Normalise both metric dicts before building the comparison table.
cat_test_metrics_clean = clean_metrics(cat_test_metrics)
lgbm_test_metrics_clean = clean_metrics(lgbm_test_metrics)

# One row per model: a "model" label followed by that model's metrics.
comparison_df = pd.DataFrame(
    [
        {"model": label, **scores}
        for label, scores in (
            ("catboost", cat_test_metrics_clean),
            ("lightgbm", lgbm_test_metrics_clean),
        )
    ]
)

# Write the table without the index column and report where it went.
comparison_filename = "comparison_catboost_lightgbm.csv"
comparison_df.to_csv(comparison_filename, index=False)
print(f"Comparison saved to: {comparison_filename}")
|