Commit 83b2edb3 authored by Jean-Karim Heriche's avatar Jean-Karim Heriche

Improved UI, added option to save model predictions in the table. Use xgboost via caret.

parent 3d0ef60f
......@@ -8,36 +8,44 @@
ui_feature_selection <- function(id) {
  # UI for the "Classification and feature selection" tab (Shiny module).
  #
  # @param id Module id; every input/output id below is namespaced with NS(id)
  #   so the matching server module can find it.
  # @return A shiny fluidPage tag list.
  #
  # NOTE(review): the diff residue here contained both the pre- and post-commit
  # bodies; the old `column(12, ...)` block was built and then discarded. Only
  # the post-commit `fluidPage(...)` (the value actually returned) is kept.
  ns <- NS(id)
  fluidPage(
    # First row: feature/target selection controls plus result displays.
    fluidRow(
      column(12,
        fluidRow(
          box(width = 3,
              title = "Features to consider", solidHeader = TRUE, status = "primary",
              uiOutput(ns("featuresToProcess"))
          ),
          box(width = 3,
              title = "Target annotation column", solidHeader = TRUE, status = "primary",
              uiOutput(ns("targetCol"))
          ),
          # box(width = 3,
          #     title = "Parameters", solidHeader = TRUE, status = "primary",
          #     uiOutput(ns("featSelectionParams"))),
          actionButton(ns("featSelectionButton"), "Start", icon("check"),
                       style = "color: #fff; background-color: #3C8DBC; border-color: #3C8DBC")
        ),
        fluidRow(
          box(width = 6, align = "left",
              title = "Plot of feature importance", solidHeader = TRUE, status = "primary",
              plotlyOutput(ns("plot"), height = "500px")
          ), # End box
          box(width = 6,
              title = "Model information", solidHeader = TRUE, status = "primary",
              # Filled server-side with the confusion matrix / model summary.
              verbatimTextOutput(ns("accuracy"))
          )
        )
      ) # End column
    ), # End first row
    # Second row: prediction button, rendered server-side only once a model exists.
    fluidRow(
      column(12,
        column(width = 3,
               uiOutput(ns("predictButton"))
        )
      )
    ) # End second row
  ) # End fluidPage
} # End ui_feature_selection()
feature_selection_server <- function(input, output, session, rv, session_parent) {
......@@ -51,7 +59,7 @@ feature_selection_server <- function(input, output, session, rv, session_parent
output$featuresToProcess <- renderUI({
tagList(
selectizeInput(inputId = ns("featuresToProcess"),
label = "Feature selection will be applied to these features",
label = "Classification will use these features",
multiple = TRUE,
choices = names(rv$data)),
checkboxInput(ns("check"), "Select all numeric variables", FALSE)
......@@ -63,7 +71,8 @@ feature_selection_server <- function(input, output, session, rv, session_parent
selectizeInput(inputId = ns("targetCol"),
label = "Select column containing target annotations",
multiple = FALSE,
choices = names(rv$data))
choices = c("", names(rv$data))),
checkboxInput(ns("nocv"), "Don't use cross-validation (faster but may be less accurate)", FALSE)
)
})
......@@ -81,16 +90,23 @@ feature_selection_server <- function(input, output, session, rv, session_parent
}
})
####################################
## Feature selection with XGBoost ##
####################################
#######################################################
## Classification and feature selection with XGBoost ##
#######################################################
classifier.data <- reactiveValues(model = NULL, classes = NULL)
action_button_feature_selection <- function(){
shinyjs::disable("featSelectionButton")
req(rv$data)
tmp <- rv$data[,c(input$targetCol, input$featuresToProcess)]
# Remove rows with NAs and infinite values
tmp <- tmp[is.finite(rowSums(tmp[,input$featuresToProcess])),]
# Target vector
target <- rv$data[, input$targetCol]
target <- tmp[, input$targetCol]
tmp <- tmp[, input$featuresToProcess]
classes <- levels(as.factor(target))
classifier.data$classes <- classes
# Extract data with annotations
idx.to.keep <- which(!is.na(target) & tolower(target) != 'none' & target != "")
......@@ -99,58 +115,79 @@ feature_selection_server <- function(input, output, session, rv, session_parent
# If this becomes a problem, we can use functions from the caret package
train.idx <- sample(idx.to.keep, floor(0.67*length(idx.to.keep)))
test.idx <- idx.to.keep[-train.idx]
# Convert annotations to integer values starting from 0
# This is required by xgboost
target <- as.integer(as.factor(target)) - 1
# Form training and test sets
train.data <- xgb.DMatrix(data = as.matrix(rv$data[train.idx, input$featuresToProcess]), label = target[train.idx])
test.data <- xgb.DMatrix(data = as.matrix(rv$data[test.idx, input$featuresToProcess]), label = target[test.idx])
# Determine classification type
# Use binary:logistic for binary classification
# and multi:softprob for multiple classes
# Both output probabilities of belonging to each class
nb.class <- length(unique(target))
if(nb.class == 2) {
classification_type <- "binary:logistic"
train.data <- as.matrix(tmp[train.idx, input$featuresToProcess])
train.labels <- as.factor(target[train.idx])
test.data <- as.matrix(tmp[test.idx, input$featuresToProcess])
test.labels <- as.factor(target[test.idx])
# Tune xgboost hyperparameters using caret
nrounds <- seq(from = 100, to = 500, by = 50)
eta <- c(0.025, 0.05, 0.1, 0.3, 0.4)
depth <- c(3, 4, 5, 6)
if(input$nocv) { # No cross validation: use default parameters
paramGrid <- expand.grid(
nrounds = 500,
eta = 0.3,
max_depth = 6,
gamma = 0,
colsample_bytree = 1,
min_child_weight = 1,
subsample = 1
)
nfolds <- 1
} else {
classification_type <- "multi:softprob"
paramGrid <- expand.grid(
nrounds = nrounds,
eta = eta,
max_depth = depth,
gamma = 0,
colsample_bytree = 1,
min_child_weight = 1,
subsample = 1
)
nfolds <- 5
}
# xgboost parameters with default values
params <- list(booster = "gbtree",
objective = classification_type,
eta = 0.3,
gamma = 0,
max_depth = 6,
min_child_weight = 1,
subsample = 1,
colsample_bytree = 1)
# Learn a model
xgb <- xgb.train (params = params,
data = train.data,
nrounds = 1000,
early_stopping_rounds = 10,
watchlist = list(train=train.data, test=test.data),
verbose = 0,
num_class = nb.class)
# Make predictions
# Outputs a matrix of probabilities per class (classes are in columns)
xgbpred <- predict(xgb, test.data, reshape = TRUE)
# Assign labels based on highest probability
xgbpred.classes <- max.col(xgbpred) - 1
accuracy <- sprintf("%1.2f%%", sum(xgbpred.classes==target[test.idx])/length(xgbpred.classes) * 100)
# Get feature importance
# This also clusters them. The number of clusters is automatically determined (using BIC)
feature.importance <- xgb.importance(model = xgb, feature_names = input$featuresToProcess)
trainCtrl <- caret::trainControl(
method = "cv", # cross-validation
number = nfolds, # with n folds
verboseIter = FALSE, # no training log
allowParallel = TRUE # FALSE for reproducible results
)
xgbModel <- caret::train(
x = train.data,
y = train.labels,
trControl = trainCtrl,
tuneGrid = paramGrid,
method = "xgbTree",
verbose = FALSE
)
classifier.data$model <- xgbModel
# Evaluate on held out data
xgbpred <- predict(xgbModel, newdata = test.data)
confusion.matrix <- confusionMatrix(xgbpred, test.labels, mode = "everything")
# Get feature importance using the xgboost library
# This also clusters the features. The number of clusters is automatically determined (using BIC)
feature.importance <- xgb.importance(model = xgbModel$finalModel, feature_names = xgbModel$finalModel$feature_names)
p <- xgb.ggplot.importance(feature.importance)
p <- p + ggtitle("") + theme(legend.title = element_blank())
remove_modal_spinner()
output$plot <- renderPlotly({
ggplotly(p, tooltip = "none", source = "featureImportance") %>%
layout(legend = list(orientation = "v", x = 1, y = 0.8, title=list(text='Clusters'))) %>%
config(p = ., staticPlot = FALSE, doubleClick = "reset+autosize", autosizable = TRUE, displayModeBar = TRUE,
sendData = FALSE, displaylogo = FALSE,
modeBarButtonsToRemove = c("sendDataToCloud", "hoverCompareCartesian", "select2d", "lasso2d")) # Control Plotly's tool bar
})
output$accuracy <- renderPrint( confusion.matrix )
output$plot <- renderPlotly(p)
output$accuracy <- renderText( accuracy )
output$predictButton <- renderUI(actionButton(ns("predict"),
"Make predictions and add them to the table"))
remove_modal_spinner()
shinyjs::enable("featSelectionButton")
}
......@@ -160,10 +197,12 @@ feature_selection_server <- function(input, output, session, rv, session_parent
# Check that we have at least two columns selected and that they are numeric
# then start processing
Check_variable <- function(buttonaction){
if(length(input$featuresToProcess)<2) {
showNotification("You need to select at least 2 variables/columns.", type = "error")
if(is.null(input$targetCol) || input$targetCol == "") {
showNotification("You need to select an annotation column with class assignments.", type = "error")
} else if(length(input$featuresToProcess)<2) {
showNotification("You need to select at least 2 features/columns.", type = "error")
} else if(!any(unlist(lapply(rv$data[,input$featuresToProcess], is.numeric)))) {
showNotification("Non-numeric variables/columns selected.", type = "error")
showNotification("Non-numeric features/columns selected.", type = "error")
}
else{
show_modal_spinner(text = "It looks like this is going to take a while. Please wait...",
......@@ -172,5 +211,19 @@ feature_selection_server <- function(input, output, session, rv, session_parent
}
}
}
observeEvent(input$predict, {
req(rv$data)
if(!is.null(classifier.data$model)) {
if(!("xgboost.predictions" %in% colnames(rv$data))) {
rv$data$xgboost.predictions <- NA
}
tmp <- rv$data[,input$featuresToProcess]
# Remove rows with NAs and infinite values
tmp <- as.matrix(tmp[is.finite(rowSums(tmp)),])
preds <- predict(classifier.data$model, newdata = tmp)
rv$data[rownames(tmp),]$xgboost.predictions <- classifier.data$classes[preds]
showNotification("Model predictions have been added to the data.", type = "warning")
}
})
}
\ No newline at end of file
......@@ -15,7 +15,8 @@ cran.mirrors <- c('https://ftp.gwdg.de/pub/misc/cran/', 'https://cloud.r-project
## Install required packages if missing
## CRAN packages
pkg <- c("devtools", "BiocManager", "DT", "shiny", "shinyFiles", "shinycssloaders", "shinydashboard", "shinyjs",
"shinyWidgets", "shinybusy", "assertthat", "ggplot2", "plotly", "RANN", "MASS", "uwot", "xgboost", "Ckmeans.1d.dp")
"shinyWidgets", "shinybusy", "assertthat", "ggplot2", "plotly", "RANN", "MASS", "uwot", "xgboost", "Ckmeans.1d.dp",
"caret")
new.pkg <- pkg[!(pkg %in% installed.packages())]
if (length(new.pkg)) {
message(paste0("Installing ", new.pkg, "\n"))
......@@ -126,7 +127,7 @@ server <- function(input,output,session){
menuItem(HTML("&nbsp;&nbsp;Annotate"), tabName = "annotate", icon = icon("edit")),
menuItem(HTML("&nbsp;&nbsp;Dimensionality<br>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;reduction"),
tabName = "dimensionality_reduction", icon = icon("cube")),
menuItem(HTML("&nbsp;&nbsp;Feature<br>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;selection"),
menuItem(HTML("&nbsp;&nbsp;Classification and <br>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;feature selection"),
tabName = "feature_selection", icon = icon("filter")),
menuItem(HTML("&nbsp;&nbsp;Clustering"), tabName = "cluster", icon = icon("object-group"))
)})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment