diff --git a/.gitignore b/.gitignore index 6936905..7850657 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ data/ -.DS_Store \ No newline at end of file +.DS_Store +.Rproj.user diff --git a/38323473-2aab-4553-b759-81bc15506d7e.png b/38323473-2aab-4553-b759-81bc15506d7e.png new file mode 100644 index 0000000..2cb47e4 Binary files /dev/null and b/38323473-2aab-4553-b759-81bc15506d7e.png differ diff --git a/Electrocardiogram.Rproj b/Electrocardiogram.Rproj new file mode 100644 index 0000000..8e3c2eb --- /dev/null +++ b/Electrocardiogram.Rproj @@ -0,0 +1,13 @@ +Version: 1.0 + +RestoreWorkspace: Default +SaveWorkspace: Default +AlwaysSaveHistory: Default + +EnableCodeIndexing: Yes +UseSpacesForTab: Yes +NumSpacesForTab: 2 +Encoding: UTF-8 + +RnwWeave: Sweave +LaTeX: pdfLaTeX diff --git a/feature_selection_results.csv b/feature_selection_results.csv new file mode 100644 index 0000000..2d59e13 --- /dev/null +++ b/feature_selection_results.csv @@ -0,0 +1,32 @@ +"Num_Features","Features","Accuracy","F1_class1","F2_class1","Precision_class1","Recall_class1" +1,"freq_median_freq",0.9773,0.4111,0.3103,0.8966,0.2667 +2,"freq_median_freq, time_kurtosis",0.9806,0.5544,0.454,0.8778,0.4051 +3,"freq_median_freq, time_kurtosis, freq_peak_freq",0.9812,0.5859,0.4932,0.8529,0.4462 +4,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay",0.9808,0.5743,0.4824,0.8416,0.4359 +5,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp",0.9844,0.6928,0.627,0.8394,0.5897 +6,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid",0.9863,0.7305,0.6638,0.8777,0.6256 +7,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp",0.9847,0.7024,0.6406,0.8369,0.6051 +8,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95",0.9855,0.7164,0.6522,0.8571,0.6154 +9,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs",0.985,0.7118,0.6541,0.8345,0.6205 +10,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std",0.986,0.7278,0.6663,0.8601,0.6308 +11,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio",0.9863,0.7353,0.6757,0.8621,0.641 +12,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min",0.9852,0.7104,0.6467,0.85,0.6103 +13,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min, time_skew",0.9867,0.7449,0.6857,0.8699,0.6513 +14,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min, time_skew, acf_integral_time",0.986,0.7278,0.6663,0.8601,0.6308 +15,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min, time_skew, acf_integral_time, ar_r2",0.9875,0.7602,0.7012,0.8844,0.6667 +16,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min, time_skew, acf_integral_time, ar_r2, time_var",0.987,0.7493,0.6872,0.8819,0.6513 +17,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min, time_skew, acf_integral_time, ar_r2, time_var, acf_val_dom",0.986,0.7294,0.6703,0.8552,0.6359 +18,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min, time_skew, acf_integral_time, ar_r2, time_var, acf_val_dom, ar_resid_var",0.9864,0.739,0.6803,0.863,0.6462 +19,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min, time_skew, acf_integral_time, ar_r2, time_var, acf_val_dom, ar_resid_var, time_iqr",0.9855,0.7181,0.6562,0.8521,0.6205 +20,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min, time_skew, acf_integral_time, ar_r2, time_var, acf_val_dom, ar_resid_var, time_iqr, freq_bandwidth",0.9858,0.7257,0.6656,0.8542,0.6308 +21,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min, time_skew, acf_integral_time, ar_r2, time_var, acf_val_dom, ar_resid_var, time_iqr, freq_bandwidth, time_max",0.9861,0.7362,0.6828,0.8467,0.6513 +22,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min, time_skew, acf_integral_time, ar_r2, time_var, acf_val_dom, ar_resid_var, time_iqr, freq_bandwidth, time_max, freq_entropy",0.9863,0.7368,0.6796,0.8571,0.6462 +23,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min, time_skew, acf_integral_time, ar_r2, time_var, acf_val_dom, ar_resid_var, time_iqr, freq_bandwidth, time_max, freq_entropy, time_p5",0.986,0.7278,0.6663,0.8601,0.6308 +24,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min, time_skew, acf_integral_time, ar_r2, time_var, acf_val_dom, ar_resid_var, time_iqr, freq_bandwidth, time_max, freq_entropy, time_p5, ar1_coef",0.9864,0.739,0.6803,0.863,0.6462 +25,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min, time_skew, acf_integral_time, ar_r2, time_var, acf_val_dom, ar_resid_var, time_iqr, freq_bandwidth, time_max, freq_entropy, time_p5, ar1_coef, time_p95",0.9857,0.7219,0.6609,0.8531,0.6256 +26,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min, time_skew, acf_integral_time, ar_r2, time_var, acf_val_dom, ar_resid_var, time_iqr, freq_bandwidth, time_max, freq_entropy, time_p5, ar1_coef, time_p95, time_mean",0.9857,0.7251,0.6688,0.8435,0.6359 +27,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min, time_skew, acf_integral_time, ar_r2, time_var, acf_val_dom, ar_resid_var, time_iqr, freq_bandwidth, time_max, freq_entropy, time_p5, ar1_coef, time_p95, time_mean, time_p25",0.9861,0.7347,0.6789,0.8514,0.6462 +28,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min, time_skew, acf_integral_time, ar_r2, time_var, acf_val_dom, ar_resid_var, time_iqr, freq_bandwidth, time_max, freq_entropy, time_p5, ar1_coef, time_p95, time_mean, time_p25, freq_top1_freq",0.9866,0.7381,0.6732,0.8794,0.6359 +29,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min, time_skew, acf_integral_time, ar_r2, time_var, acf_val_dom, ar_resid_var, time_iqr, freq_bandwidth, time_max, freq_entropy, time_p5, ar1_coef, time_p95, time_mean, time_p25, freq_top1_freq, time_median",0.9855,0.7214,0.6641,0.8425,0.6308 +30,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min, time_skew, acf_integral_time, ar_r2, time_var, acf_val_dom, ar_resid_var, time_iqr, freq_bandwidth, time_max, freq_entropy, time_p5, ar1_coef, time_p95, time_mean, time_p25, freq_top1_freq, time_median, time_p75",0.9861,0.7284,0.663,0.8714,0.6256 +31,"freq_median_freq, time_kurtosis, freq_peak_freq, acf_time_decay, time_ptp, freq_centroid, freq_top1_amp, freq_rolloff_95, acf_mean_abs, time_std, time_range_ratio, time_min, time_skew, acf_integral_time, ar_r2, time_var, acf_val_dom, ar_resid_var, time_iqr, freq_bandwidth, time_max, freq_entropy, time_p5, ar1_coef, time_p95, time_mean, time_p25, freq_top1_freq, time_median, time_p75, time_rms",0.9869,0.7485,0.6904,0.8707,0.6564 diff --git a/main.R b/main.R new file mode 100644 index 0000000..ae120fe --- /dev/null +++ b/main.R @@ -0,0 +1,222 @@ +#install.packages("lightgbm", repos = "https://cran.r-project.org") +#install.packages("MLmetrics") + +library(lightgbm) +library(MLmetrics) + +# 1. Load your data +df <- read.csv("./data/Ketamine_icp_no_x.csv") + +# --- 2. Data Preparation --- +target_name <- "label" +target_index <- which(names(df) == target_name) + +# Prepare target variable +if (is.factor(df[, target_index])) { + y <- as.numeric(df[, target_index]) - 1 +} else { + y <- df[, target_index] +} + +# Create the data matrix for features +X <- as.matrix(df[, -target_index]) + +# --- 3. Split Data into Training and Testing Sets --- +set.seed(42) +train_index <- sample(nrow(X), size = 0.8 * nrow(X)) + +X_train <- X[train_index, ] +X_test <- X[-train_index, ] +y_train <- y[train_index] +y_test <- y[-train_index] + +# --- 4. Get Feature Importance from Full Model --- +lgb_train_full <- lgb.Dataset(data = X_train, label = y_train) + +params <- list( + objective = "binary", + metric = "binary_logloss", + boosting_type = "gbdt", + num_leaves = 20, + learning_rate = 0.05, + feature_fraction = 0.8 +) + +bst_full <- lgb.train( + params = params, + data = lgb_train_full, + nrounds = 100, + verbose = -1 +) + +# Get feature importance +importance <- lgb.importance(bst_full) +num_features <- nrow(importance) + +# Create a data frame to store results +results_df <- data.frame( + Num_Features = integer(), + Features = character(), + Accuracy = numeric(), + F1_class1 = numeric(), + F2_class1 = numeric(), + Precision_class1 = numeric(), + Recall_class1 = numeric() +) + +# --- 5. Loop through different numbers of top features --- +cat("Training models with different numbers of top features...\n") +cat("=====================================================\n") + +for (i in 1:num_features) { + cat(paste("Training model with top", i, "features...\n")) + + # Select top i features + top_features <- importance$Feature[1:i] + + # Subset training and test data + X_train_sub <- X_train[, top_features, drop = FALSE] + X_test_sub <- X_test[, top_features, drop = FALSE] + + # Create LightGBM dataset + lgb_train_sub <- lgb.Dataset(data = X_train_sub, label = y_train) + + # Train model with subset of features + bst_sub <- lgb.train( + params = params, + data = lgb_train_sub, + nrounds = 100, + verbose = -1 + ) + + # Make predictions + pred_prob_sub <- predict(bst_sub, X_test_sub) + pred_class_sub <- as.numeric(pred_prob_sub > 0.5) + + # Calculate metrics + accuracy <- mean(pred_class_sub == y_test) + + # For binary classification + if (length(unique(y_test)) == 2) { + # F1 score for class 1 + f1 <- F1_Score(y_true = y_test, y_pred = pred_class_sub, positive = 1) + + # Precision and Recall for class 1 + precision <- Precision(y_true = y_test, y_pred = pred_class_sub, positive = 1) + recall <- Recall(y_true = y_test, y_pred = pred_class_sub, positive = 1) + + # F2-score (beta = 2) + beta <- 2 + f2 <- (1 + beta^2) * (precision * recall) / (beta^2 * precision + recall) + + # Handle cases where precision or recall might be NaN + if (is.na(f2)) { + f2 <- 0 + } + } else { + # For multi-class classification + f1 <- NA + precision <- NA + recall <- NA + f2 <- NA + } + + # Store results + results_df <- rbind(results_df, data.frame( + Num_Features = i, + Features = paste(top_features, collapse = ", "), + Accuracy = round(accuracy, 4), + F1_class1 = round(f1, 4), + F2_class1 = round(f2, 4), + Precision_class1 = round(precision, 4), + Recall_class1 = round(recall, 4) + )) + + # Print progress + cat(paste(" Accuracy:", round(accuracy, 4), + "| F1:", round(f1, 4), + "| F2:", round(f2, 4), + "| Precision:", round(precision, 4), + "| Recall:", round(recall, 4), "\n")) +} + +cat("=====================================================\n") + +# --- 6. Display Results --- +cat("\nSummary of Results:\n") +cat("===================\n") +print(results_df) + +# Find best performing models based on different metrics +cat("\nBest Performing Models:\n") +cat("=======================\n") + +# Best by F1 score +if (!all(is.na(results_df$F1_class1))) { + best_f1_idx <- which.max(results_df$F1_class1) + cat(paste("Best F1-score (", results_df$F1_class1[best_f1_idx], + ") with", results_df$Num_Features[best_f1_idx], "features\n")) +} + +# Best by F2 score +if (!all(is.na(results_df$F2_class1))) { + best_f2_idx <- which.max(results_df$F2_class1) + cat(paste("Best F2-score (", results_df$F2_class1[best_f2_idx], + ") with", results_df$Num_Features[best_f2_idx], "features\n")) +} + +# Best by Accuracy +best_acc_idx <- which.max(results_df$Accuracy) +cat(paste("Best Accuracy (", results_df$Accuracy[best_acc_idx], + ") with", results_df$Num_Features[best_acc_idx], "features\n")) + +# --- 7. Optional: Plot metrics vs number of features --- +if (require(ggplot2)) { + library(ggplot2) + + # Plot F1 and F2 scores + p1 <- ggplot(results_df, aes(x = Num_Features)) + + geom_line(aes(y = F1_class1, color = "F1 Score"), size = 1) + + geom_line(aes(y = F2_class1, color = "F2 Score"), size = 1) + + geom_point(aes(y = F1_class1, color = "F1 Score"), size = 2) + + geom_point(aes(y = F2_class1, color = "F2 Score"), size = 2) + + labs(title = "F1 and F2 Scores vs Number of Features", + x = "Number of Top Features", + y = "Score Value") + + theme_minimal() + + scale_color_manual(values = c("F1 Score" = "blue", "F2 Score" = "red")) + + # Plot Accuracy + p2 <- ggplot(results_df, aes(x = Num_Features, y = Accuracy)) + + geom_line(color = "darkgreen", size = 1) + + geom_point(color = "darkgreen", size = 2) + + labs(title = "Accuracy vs Number of Features", + x = "Number of Top Features", + y = "Accuracy") + + theme_minimal() + + # Plot Precision and Recall + p3 <- ggplot(results_df, aes(x = Num_Features)) + + geom_line(aes(y = Precision_class1, color = "Precision"), size = 1) + + geom_line(aes(y = Recall_class1, color = "Recall"), size = 1) + + geom_point(aes(y = Precision_class1, color = "Precision"), size = 2) + + geom_point(aes(y = Recall_class1, color = "Recall"), size = 2) + + labs(title = "Precision and Recall (Class 1) vs Number of Features", + x = "Number of Top Features", + y = "Score Value") + + theme_minimal() + + scale_color_manual(values = c("Precision" = "purple", "Recall" = "orange")) + + # Display plots + print(p1) + print(p2) + print(p3) +} + +# --- 8. Save results to CSV --- +write.csv(results_df, "feature_selection_results.csv", row.names = FALSE) +cat("\nResults saved to 'feature_selection_results.csv'\n") + +# --- 9. Display top 20 feature importance plot --- +lgb.plot.importance(importance, top_n = min(20, num_features)) +