Files
Electrocardiogram/test_no_x.py

43 lines
1.1 KiB
Python

import pandas
from models.lightgbm_model import LIGHT_GBM
# Input tables for this run: the feature data the model is trained/evaluated on,
# and the results table of an earlier LightGBM hyperparameter search from which
# the best configuration is picked below.
DATA_PATH = "./data/Ketamine_icp_no_x.csv"
TUNING_RESULTS_PATH = "./lightgbm_tuning_results_no_x.csv"

data_frame = pandas.read_csv(DATA_PATH)
lgbm_results = pandas.read_csv(TUNING_RESULTS_PATH)
def get_best_params(results, metrics=("f2_class1", "f1_class1")):
    """Return the best row of a hyperparameter-tuning results table as a dict.

    Rows are ranked by ``metrics[0]`` (highest wins). Rows tied on that
    maximum are separated by ``metrics[1]`` (highest wins; on a further tie
    the first such row is taken).

    Args:
        results: DataFrame with one row per evaluated configuration; must
            contain the two columns named in ``metrics``.
        metrics: ``(primary, tie_breaker)`` column names. A tuple default is
            used so the default cannot be mutated between calls; any
            two-item sequence (e.g. a list) is still accepted.

    Returns:
        dict mapping every column of ``results`` to the winning row's value,
        including the metric columns themselves. Each value keeps its own
        column's dtype, so integer columns are not upcast to float.

    Raises:
        ValueError: if ``results`` has no rows with a non-NaN ``metrics[0]``.
        KeyError: if a named metric column is missing from ``results``.
    """
    primary, tie_breaker = metrics[0], metrics[1]
    best_score = results[primary].max()
    # reset_index makes the row labels unique, so the label idxmax returns
    # below identifies exactly one row even if `results` had duplicate labels.
    candidates = results[results[primary] == best_score].reset_index(drop=True)
    if candidates.empty:
        # max() is NaN for zero rows or an all-NaN column; NaN never compares
        # equal, so no candidate survives the filter.
        raise ValueError(f"no rows with a non-NaN {primary!r} value in results")
    winner = candidates[tie_breaker].idxmax()
    # Selecting with a list keeps a one-row DataFrame; to_dict("records") then
    # preserves per-column dtypes instead of coercing the row to one dtype.
    return candidates.loc[[winner]].to_dict("records")[0]
# Pick the winning hyperparameter row from the tuning results, using
# get_best_params' default ranking (f2_class1, ties broken by f1_class1).
# NOTE(review): the returned dict is the whole results row, so it also holds the
# metric columns (e.g. f2_class1, f1_class1); confirm LIGHT_GBM ignores keys
# that are not model hyperparameters.
lgbm_best_params = get_best_params(lgbm_results)
# Construct the project's LightGBM wrapper over the full data table.
# presumably the wrapper performs its own train/test split — TODO confirm in
# models/lightgbm_model.py.
lgbm_model = LIGHT_GBM(data_frame, params=lgbm_best_params)
lgbm_model.fit()
# The result is consumed as a mapping later in this script (it is passed to
# clean_metrics, which iterates .items()).
lgbm_test_metrics = lgbm_model.eval()
print(lgbm_test_metrics)
def clean_metrics(metrics):
    """Copy a metrics mapping, coercing scalar-like values to built-in float.

    A value is coerced with ``float()`` when it exposes an ``item`` attribute
    (as NumPy scalars do); all other values are copied through untouched.
    Key order is preserved.

    Args:
        metrics: mapping of metric name to value.

    Returns:
        A new dict with the same keys, in the same order.
    """
    cleaned = {}
    for name, value in metrics.items():
        if not hasattr(value, "item"):
            # Plain Python values (str, int, float, None, ...) pass through.
            cleaned[name] = value
            continue
        cleaned[name] = float(value)
    return cleaned
# Assemble a one-row comparison table (model name + cleaned test metrics) and
# write it to CSV in the working directory.
# NOTE(review): the filename mentions catboost, but only a lightgbm row is
# written here — confirm whether a catboost row is expected.
comparison_filename = "comparison_catboost_lightgbm_no_x.csv"
lgbm_test_metrics_clean = clean_metrics(lgbm_test_metrics)
comparison_rows = [{"model": "lightgbm", **lgbm_test_metrics_clean}]
comparison_df = pandas.DataFrame(comparison_rows)
comparison_df.to_csv(comparison_filename, index=False)
print(f"Comparison saved to: {comparison_filename}")