This commit is contained in:
2025-12-10 22:03:58 +01:00
parent 05dfcb36de
commit ccd94812bc

17
main.R
View File

@@ -4,14 +4,11 @@
# Gradient-boosted tree models; MLmetrics supplies F1/precision/recall helpers
library(lightgbm)
library(MLmetrics)
# 1. Load your data
# NOTE(review): relative path — assumes the script runs from the project
# root; confirm the working directory before running.
df <- read.csv("./data/Ketamine_icp.csv")
# --- 2. Data Preparation ---
# Locate the target column by name. `which` returns every matching index,
# so this assumes "label" appears exactly once in the header.
target_name <- "label"
target_index <- which(names(df) == target_name)
# Prepare target variable (converted to a 0-based numeric vector below)
if (is.factor(df[, target_index])) {
y <- as.numeric(df[, target_index]) - 1
} else {
@@ -21,7 +18,6 @@ if (is.factor(df[, target_index])) {
# Feature matrix: every column except the target, coerced to numeric matrix
X <- as.matrix(df[, -target_index])
# --- 3. Split Data into Training and Testing Sets ---
# Fixed seed keeps the 80/20 split reproducible across runs
set.seed(42)
n_obs <- nrow(X)
train_index <- sample(n_obs, size = 0.8 * n_obs)
@@ -30,7 +26,6 @@ X_test <- X[-train_index, ]
# Labels aligned with the train/test row partition above
y_test <- y[-train_index]
y_train <- y[train_index]
# --- 4. Get Feature Importance from Full Model ---
# Wrap the full training matrix in LightGBM's native dataset format
lgb_train_full <- lgb.Dataset(label = y_train, data = X_train)
params <- list(
@@ -64,9 +59,7 @@ results_df <- data.frame(
Recall_class1 = numeric()
)
# --- 5. Loop through different numbers of top features ---
# Progress banner for the sweep (one cat call, empty separator — output
# bytes are identical to printing the two lines separately)
cat("Training models with different numbers of top features...\n",
    "=====================================================\n",
    sep = "")
for (i in 1:num_features) {
cat(paste("Training model with top", i, "features...\n"))
@@ -140,16 +133,12 @@ for (i in 1:num_features) {
"| Recall:", round(recall, 4), "\n"))
}
cat("=====================================================\n")
# --- 6. Display Results ---
# Print the metrics table collected across all feature-count runs.
# Fix: the "Summary of Results:" header was printed twice (once with a
# leading newline, once without); keep a single underlined header.
cat("\nSummary of Results:\n")
cat("===================\n")
print(results_df)
# Find best performing models based on different metrics
cat("\nBest Performing Models:\n")
cat("=======================\n")
# Best by F1 score
if (!all(is.na(results_df$F1_class1))) {
@@ -170,7 +159,6 @@ best_acc_idx <- which.max(results_df$Accuracy)
# Report the feature count that maximizes accuracy (paste's default
# single-space separator is preserved, so the printed text is unchanged)
best_acc_msg <- paste(
  "Best Accuracy (", results_df$Accuracy[best_acc_idx],
  ") with", results_df$Num_Features[best_acc_idx], "features\n"
)
cat(best_acc_msg)
# --- 7. Optional: Plot metrics vs number of features ---
if (require(ggplot2)) {
library(ggplot2)
@@ -213,10 +201,7 @@ if (require(ggplot2)) {
print(p3)
}
# --- 8. Save results to CSV ---
# Persist the sweep metrics so runs can be compared offline
results_path <- "feature_selection_results.csv"
write.csv(results_df, results_path, row.names = FALSE)
cat("\nResults saved to 'feature_selection_results.csv'\n")
# --- 9. Display top 20 feature importance plot ---
# Cap at the number of available features when fewer than 20 exist
lgb.plot.importance(importance, top_n = min(20, num_features))