main.R
@@ -4,14 +4,11 @@
 library(lightgbm)
 library(MLmetrics)
 
-# 1. Load your data
 df <- read.csv("./data/Ketamine_icp.csv")
 
-# --- 2. Data Preparation ---
 target_name <- "label"
 target_index <- which(names(df) == target_name)
 
-# Prepare target variable
 if (is.factor(df[, target_index])) {
   y <- as.numeric(df[, target_index]) - 1
 } else {
@@ -21,7 +18,6 @@ if (is.factor(df[, target_index])) {
 # Create the data matrix for features
 X <- as.matrix(df[, -target_index])
 
-# --- 3. Split Data into Training and Testing Sets ---
 set.seed(42)
 train_index <- sample(nrow(X), size = 0.8 * nrow(X))
 
@@ -30,7 +26,6 @@ X_test <- X[-train_index, ]
 y_train <- y[train_index]
 y_test <- y[-train_index]
 
-# --- 4. Get Feature Importance from Full Model ---
 lgb_train_full <- lgb.Dataset(data = X_train, label = y_train)
 
 params <- list(
@@ -64,9 +59,7 @@ results_df <- data.frame(
   Recall_class1 = numeric()
 )
 
-# --- 5. Loop through different numbers of top features ---
 cat("Training models with different numbers of top features...\n")
-cat("=====================================================\n")
 
 for (i in 1:num_features) {
   cat(paste("Training model with top", i, "features...\n"))
@@ -140,16 +133,12 @@ for (i in 1:num_features) {
       "| Recall:", round(recall, 4), "\n"))
 }
 
-cat("=====================================================\n")
 
-# --- 6. Display Results ---
-cat("\nSummary of Results:\n")
-cat("===================\n")
+cat("Summary of Results:\n")
 print(results_df)
 
 # Find best performing models based on different metrics
 cat("\nBest Performing Models:\n")
-cat("=======================\n")
 
 # Best by F1 score
 if (!all(is.na(results_df$F1_class1))) {
@@ -170,7 +159,6 @@ best_acc_idx <- which.max(results_df$Accuracy)
 cat(paste("Best Accuracy (", results_df$Accuracy[best_acc_idx],
           ") with", results_df$Num_Features[best_acc_idx], "features\n"))
 
-# --- 7. Optional: Plot metrics vs number of features ---
 if (require(ggplot2)) {
   library(ggplot2)
 
@@ -213,10 +201,7 @@ if (require(ggplot2)) {
   print(p3)
 }
 
-# --- 8. Save results to CSV ---
 write.csv(results_df, "feature_selection_results.csv", row.names = FALSE)
 cat("\nResults saved to 'feature_selection_results.csv'\n")
 
-# --- 9. Display top 20 feature importance plot ---
 lgb.plot.importance(importance, top_n = min(20, num_features))
-
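
Note: the hunks above show only the lines touched by this commit; the actual training and evaluation loop in main.R is elided. As a rough, hypothetical sketch of the pattern the diff context suggests (rank features with a full LightGBM model, then retrain on the top i features and score with MLmetrics), assuming X_train/X_test, y_train/y_test, and params exist as in the context lines; the names full_model, model_i, top_feats, prob, pred, acc, f1, and rec are invented for illustration:

# Illustrative sketch only — not the elided code from main.R.
# Assumes X_train/X_test are numeric matrices with column names,
# y_train/y_test are 0/1 labels, and params sets a binary objective.
library(lightgbm)
library(MLmetrics)

full_model <- lgb.train(params = params,
                        data = lgb.Dataset(data = X_train, label = y_train),
                        nrounds = 100)
importance <- as.data.frame(lgb.importance(full_model))  # Feature/Gain/Cover/Frequency table
importance <- importance[order(-importance$Gain), ]      # rank features by gain
num_features <- nrow(importance)

for (i in 1:num_features) {
  top_feats <- importance$Feature[1:i]
  model_i <- lgb.train(params = params,
                       data = lgb.Dataset(data = X_train[, top_feats, drop = FALSE],
                                          label = y_train),
                       nrounds = 100)
  prob <- predict(model_i, X_test[, top_feats, drop = FALSE])
  pred <- as.integer(prob > 0.5)
  # Metrics of the kind collected in results_df (Accuracy, F1_class1, Recall_class1)
  acc <- Accuracy(y_pred = pred, y_true = y_test)
  f1  <- F1_Score(y_true = y_test, y_pred = pred, positive = "1")
  rec <- Recall(y_true = y_test, y_pred = pred, positive = "1")
}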