main.R
library(lightgbm)
library(MLmetrics)

# 1. Load your data
df <- read.csv("./data/Ketamine_icp.csv")

# --- 2. Data Preparation ---
target_name <- "label"
target_index <- which(names(df) == target_name)

# Prepare target variable
if (is.factor(df[, target_index])) {
  y <- as.numeric(df[, target_index]) - 1
} else {
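  # One line elided by the diff; presumably the non-factor case uses the
  # column as-is (an assumption, not shown in the source):
  y <- df[, target_index]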
}

# Create the data matrix for features
X <- as.matrix(df[, -target_index])

# --- 3. Split Data into Training and Testing Sets ---
set.seed(42)
train_index <- sample(nrow(X), size = 0.8 * nrow(X))

X_train <- X[train_index, ]
X_test <- X[-train_index, ]
y_train <- y[train_index]
y_test <- y[-train_index]

# --- 4. Get Feature Importance from Full Model ---
lgb_train_full <- lgb.Dataset(data = X_train, label = y_train)

params <- list(
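  # The diff elides this hunk: the rest of params, the full-model training,
  # and the start of the results table. A hedged reconstruction follows; the
  # parameter values and the name bst_full are assumptions, while importance,
  # num_features, and the results_df columns are inferred from references
  # later in the file.
  objective = "binary",
  metric = "binary_logloss",
  learning_rate = 0.1
)

# Train on all features and rank them by importance (consistent with
# importance$Feature in the loop and lgb.plot.importance() in step 9)
bst_full <- lgb.train(params = params, data = lgb_train_full,
                      nrounds = 100, verbose = -1)
importance <- lgb.importance(bst_full)
num_features <- nrow(importance)

# Empty results table; the column set matches the rbind() inside the loop
results_df <- data.frame(
  Num_Features = integer(),
  Accuracy = numeric(),
  F1_class1 = numeric(),
  F2_class1 = numeric(),
  Precision_class1 = numeric(),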
  Recall_class1 = numeric()
)

# --- 5. Loop through different numbers of top features ---
cat("Training models with different numbers of top features...\n")
cat("=====================================================\n")

for (i in 1:num_features) {
  cat(paste("Training model with top", i, "features...\n"))

  # Select top i features
  top_features <- importance$Feature[1:i]

  # Subset training and test data
  X_train_sub <- X_train[, top_features, drop = FALSE]
  X_test_sub <- X_test[, top_features, drop = FALSE]

  # Create LightGBM dataset
  lgb_train_sub <- lgb.Dataset(data = X_train_sub, label = y_train)

  # Train model with subset of features
  bst_sub <- lgb.train(
    params = params,
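    # One argument elided by the diff; given the surrounding call it is
    # almost certainly the training data (an inference, not shown here):
    data = lgb_train_sub,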
    nrounds = 100,
    verbose = -1
  )

  # Make predictions
  pred_prob_sub <- predict(bst_sub, X_test_sub)
  pred_class_sub <- as.numeric(pred_prob_sub > 0.5)

  # Calculate metrics
  accuracy <- mean(pred_class_sub == y_test)

  # For binary classification
  if (length(unique(y_test)) == 2) {
    # F1 score for class 1
    f1 <- F1_Score(y_true = y_test, y_pred = pred_class_sub, positive = 1)

    # Precision and Recall for class 1
    precision <- Precision(y_true = y_test, y_pred = pred_class_sub, positive = 1)
    recall <- Recall(y_true = y_test, y_pred = pred_class_sub, positive = 1)

    # F2-score (beta = 2)
    beta <- 2
    f2 <- (1 + beta^2) * (precision * recall) / (beta^2 * precision + recall)

    # Handle cases where precision or recall might be NaN
    if (is.na(f2)) {
      f2 <- 0
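      # Five lines elided by the diff; presumably they close the NaN guard
      # and open a non-binary fallback that sets the class-1 metrics to NA
      # (a reconstruction, not shown in the source):
    }
  } else {
    f1 <- NA
    precision <- NA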
    recall <- NA
    f2 <- NA
  }

  # Store results
  results_df <- rbind(results_df, data.frame(
    Num_Features = i,
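    # Four lines elided by the diff; three columns are recoverable from the
    # references to results_df below (a reconstruction):
    Accuracy = round(accuracy, 4),
    F1_class1 = round(f1, 4),
    F2_class1 = round(f2, 4),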
    Precision_class1 = round(precision, 4),
    Recall_class1 = round(recall, 4)
  ))

  # Print progress
  cat(paste("  Accuracy:", round(accuracy, 4),
            "| F1:", round(f1, 4),
            "| F2:", round(f2, 4),
            "| Precision:", round(precision, 4),
            "| Recall:", round(recall, 4), "\n"))
}

cat("=====================================================\n")

# --- 6. Display Results ---
cat("\nSummary of Results:\n")
cat("===================\n")
print(results_df)

# Find best performing models based on different metrics
cat("\nBest Performing Models:\n")
cat("=======================\n")

# Best by F1 score
if (!all(is.na(results_df$F1_class1))) {
  best_f1_idx <- which.max(results_df$F1_class1)
  cat(paste("Best F1-score (", results_df$F1_class1[best_f1_idx],
            ") with", results_df$Num_Features[best_f1_idx], "features\n"))
}

# Best by F2 score
if (!all(is.na(results_df$F2_class1))) {
  best_f2_idx <- which.max(results_df$F2_class1)
  cat(paste("Best F2-score (", results_df$F2_class1[best_f2_idx],
            ") with", results_df$Num_Features[best_f2_idx], "features\n"))
}

# Best by Accuracy
best_acc_idx <- which.max(results_df$Accuracy)
cat(paste("Best Accuracy (", results_df$Accuracy[best_acc_idx],
          ") with", results_df$Num_Features[best_acc_idx], "features\n"))

# --- 7. Optional: Plot metrics vs number of features ---
if (require(ggplot2)) {
  # require() above already attaches ggplot2, so no separate library() call
  # is needed here

  # Plot F1 and F2 scores
  p1 <- ggplot(results_df, aes(x = Num_Features)) +
    geom_line(aes(y = F1_class1, color = "F1 Score"), size = 1) +
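    # Five lines elided by the diff; presumably the F2 series and the labs()
    # header, mirroring the Precision/Recall plot below (a reconstruction):
    geom_line(aes(y = F2_class1, color = "F2 Score"), size = 1) +
    labs(title = "F1 and F2 Scores vs Number of Top Features",
         x = "Number of Top Features",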
y = "Score Value") +
|
y = "Score Value") +
|
||||||
theme_minimal() +
|
theme_minimal() +
|
||||||
scale_color_manual(values = c("F1 Score" = "blue", "F2 Score" = "red"))
|
scale_color_manual(values = c("F1 Score" = "blue", "F2 Score" = "red"))
|
||||||
|
|
||||||
# Plot Accuracy
|
# Plot Accuracy
|
||||||
p2 <- ggplot(results_df, aes(x = Num_Features, y = Accuracy)) +
|
p2 <- ggplot(results_df, aes(x = Num_Features, y = Accuracy)) +
|
||||||
geom_line(color = "darkgreen", size = 1) +
|
geom_line(color = "darkgreen", size = 1) +
|
||||||
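    # Two lines elided by the diff; presumably the labs() header
    # (a reconstruction):
    labs(title = "Accuracy vs Number of Top Features",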
x = "Number of Top Features",
|
x = "Number of Top Features",
|
||||||
y = "Accuracy") +
|
y = "Accuracy") +
|
||||||
theme_minimal()
|
theme_minimal()
|
||||||
|
|
||||||
# Plot Precision and Recall
|
# Plot Precision and Recall
|
||||||
p3 <- ggplot(results_df, aes(x = Num_Features)) +
|
p3 <- ggplot(results_df, aes(x = Num_Features)) +
|
||||||
geom_line(aes(y = Precision_class1, color = "Precision"), size = 1) +
|
geom_line(aes(y = Precision_class1, color = "Precision"), size = 1) +
|
||||||
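    # Five lines elided by the diff; presumably the Recall series and the
    # labs() header, mirroring the F1/F2 plot above (a reconstruction):
    geom_line(aes(y = Recall_class1, color = "Recall"), size = 1) +
    labs(title = "Precision and Recall vs Number of Top Features",
         x = "Number of Top Features",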
y = "Score Value") +
|
y = "Score Value") +
|
||||||
theme_minimal() +
|
theme_minimal() +
|
||||||
scale_color_manual(values = c("Precision" = "purple", "Recall" = "orange"))
|
scale_color_manual(values = c("Precision" = "purple", "Recall" = "orange"))
|
||||||
|
|
||||||
# Display plots
|
# Display plots
|
||||||
print(p1)
|
print(p1)
|
||||||
print(p2)
|
print(p2)
|
||||||
print(p3)
|
print(p3)
|
||||||
}
|
}
|
||||||
|
|
||||||
# --- 8. Save results to CSV ---
write.csv(results_df, "feature_selection_results.csv", row.names = FALSE)
cat("\nResults saved to 'feature_selection_results.csv'\n")

# --- 9. Display top 20 feature importance plot ---
lgb.plot.importance(importance, top_n = min(20, num_features))