Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Zeller Group
SIAMCAT
Commits
44768174
Commit
44768174
authored
Apr 03, 2022
by
Jakob Wirbel
Browse files
changes because of the internal changes towards mlr3.
parent
4adab759
Changes
2
Expand all
Hide whitespace changes
Inline
Side-by-side
R/model_evaluation_plot.r
View file @
44768174
...
...
@@ -4,74 +4,90 @@
### Heidelberg 2012-2018 GNU GPL 3.0
#' @title Model Evaluation Plot
#'
@description Produces two plots for model evaluation. The first plot shows
#'
the Receiver Operating Characteristic (ROC)-curves, the other the
#'
Precision-recall (PR)-curves for the different cross-validation
#'
repetitions.
#'
@usage model.evaluation.plot(..., fn.plot = NULL,
#'
colours=NULL, show.all=FALSE, verbose = 1)
#'
#'
@description Produces plots for model evaluation.
#'
#'
@usage model.evaluation.plot(..., fn.plot = NULL,
#'
colours=NULL, show.all=FALSE, verbose = 1)
#'
#' @param ... one or more objects of class \link{siamcat-class}, can be named
#'
#' @param fn.plot string, filename for the pdf-plot
#' @param show.all boolean, Should all repeated cross-validation models be
#' plotted?
#'
#' @param show.all boolean, Should the results from repeated cross-validation
#' models be plotted? Defaults to \code{FALSE}, leading to a single line for
#' the mean across cross-validation repeats
#'
#' @param colours colour specification for the different \link{siamcat-class}-
#' objects, defaults to \code{NULL} which will cause the colours to be
#' picked from the \code{'Set1'} palette
#' objects, defaults to \code{NULL} which will cause the colours to be
#' picked from the \code{'Set1'} palette
#'
#' @param verbose control output: \code{0} for no output at all, \code{1}
#' for only information about progress and success, \code{2} for normal
#' level of information and \code{3} for full debug information,
#' defaults to \code{1}
#' for only information about progress and success, \code{2} for normal
#' level of information and \code{3} for full debug information,
#' defaults to \code{1}
#'
#' @keywords SIAMCAT model.evaluation.plot
#'
#' @section Binary classification problems:
#' The first plot shows the Receiver Operating Characteristic (ROC)-curve,
#' the other plot the Precision-recall (PR)-curve for the model. If
#' \code{show.all == FALSE} (which is the default), a single line representing
#' the mean across cross-validation repeats will be plotted, otherwise the
#' individual cross-validation repeats will be included as
#' lightly shaded lines.
#'
#' @section Regression problems:
#' For regression problems, this function will produce a scatter plot between
#' the real and predicted values. If several \link{siamcat-class}-objects are
#' supplied, a single plot for each object will be produced.
#'
#' @export
#'
#' @encoding UTF-8
#'
#' @return Does not return anything, but produces the model evaluation plot.
#'
#' @examples
#' data(siamcat_example)
#'
#' # simple working example
#' model.evaluation.plot(siamcat_example, fn.plot='./eval.pdf')
#'
#'
#' # plot several named SIAMCAT objects
#' #
(
although we use only one example object here
)
#' # although we use only one example object here
#' model.evaluation.plot('Example_1'=siamcat_example,
#' 'Example_2'=siamcat_example, colours=c('red', 'blue'),
#' fn.plot='./eval.pdf')
#'
#' # show individual cross-validation repeats
#' model.evaluation.plot(siamcat_example, fn.plot='./eval.pdf', show.all=TRUE)
model.evaluation.plot
<-
function
(
...
,
fn.plot
=
NULL
,
colours
=
NULL
,
show.all
=
FALSE
,
verbose
=
1
)
{
if
(
verbose
>
1
)
message
(
"+ starting model.evaluation.plot"
)
s.time
<-
proc.time
()[
3
]
if
(
!
is.null
(
fn.plot
))
pdf
(
fn.plot
,
onefile
=
TRUE
)
if
(
verbose
>
2
)
message
(
"+ plotting ROC"
)
if
(
is.null
(
fn.plot
))
par
(
ask
=
TRUE
)
par
(
mar
=
c
(
5.1
,
4.1
,
4.1
,
2.1
))
plot
(
NULL
,
xlim
=
c
(
0
,
1
),
ylim
=
c
(
0
,
1
),
xlab
=
"False positive rate"
,
ylab
=
"True positive rate"
,
type
=
"n"
)
title
(
paste
(
"ROC curve for the model"
,
sep
=
" "
))
abline
(
a
=
0
,
b
=
1
,
lty
=
3
)
args
<-
list
(
...
)
if
(
length
(
args
)
>
1
)
{
# checks
if
(
!
all
(
vapply
(
args
,
class
,
FUN.VALUE
=
character
(
1
))
==
'siamcat'
)){
stop
(
"Please supply only SIAMCAT objects. Exiting..."
)
FUN.VALUE
=
character
(
1
))
==
'siamcat'
)){
stop
(
"Please supply only SIAMCAT objects. Exiting..."
)
}
if
(
any
(
vapply
(
args
,
FUN
=
function
(
x
){
is.null
(
eval_data
(
x
,
verbose
=
0
))},
FUN.VALUE
=
logical
(
1
)))){
stop
(
"Not all SIAMCAT objects have evaluation data. Exiting..."
)
if
(
any
(
vapply
(
args
,
FUN
=
function
(
x
){
is.null
(
eval_data
(
x
,
verbose
=
0
))},
FUN.VALUE
=
logical
(
1
)))){
stop
(
"Not all SIAMCAT objects have evaluation data. Exiting..."
)
}
all.types
<-
vapply
(
args
,
FUN
=
function
(
x
){
label
(
x
)
$
type
},
FUN.VALUE
=
character
(
1
))
if
(
all
(
all.types
==
'BINARY'
)){
type
<-
'binary'
}
else
if
(
all
(
all.types
==
'CONTINUOUS'
)){
type
<-
'regression'
}
else
{
stop
(
paste0
(
"All SIAMCAT objects need to be either binary"
,
" classification or regression tasks"
))
}
n
<-
length
(
args
)
if
(
is.null
(
colours
))
{
if
(
n
>
9
)
{
...
...
@@ -85,118 +101,113 @@ model.evaluation.plot <- function(..., fn.plot=NULL, colours = NULL,
}
}
stopifnot
(
length
(
colours
)
==
n
)
# ROC
legend.val
<-
c
()
# plot each roc curve for each eval data object
for
(
i
in
seq_along
(
args
))
{
legend.val
<-
c
(
legend.val
,
as.numeric
(
single.roc.plot
(
args
[[
i
]],
colours
[
i
],
show.all
=
show.all
,
verbose
=
verbose
)))
}
if
(
!
is.null
(
names
(
args
)))
{
legend
(
'bottomright'
,
legend
=
paste0
(
names
(
args
),
' AUC: '
,
format
(
legend.val
,
digits
=
3
)),
col
=
colours
,
lty
=
1
,
lwd
=
2
,
cex
=
0.8
,
y.intersp
=
1.5
)
}
else
{
legend
(
'bottomright'
,
legend
=
paste0
(
'AUC: '
,
format
(
legend.val
,
digits
=
3
)),
col
=
colours
,
lty
=
1
,
lwd
=
2
,
cex
=
0.8
,
y.intersp
=
1.5
)
}
# PR
# precision recall curve
if
(
verbose
>
2
)
message
(
"+ plotting PRC"
)
plot
(
NULL
,
xlim
=
c
(
0
,
1
),
ylim
=
c
(
0
,
1
),
xlab
=
"Recall"
,
ylab
=
"Precision"
,
type
=
"n"
)
title
(
paste
(
"Precision-recall curve for the model"
,
sep
=
" "
))
legend.val
<-
c
()
# plot each roc curve for each eval data object
for
(
i
in
seq_along
(
args
))
{
legend.val
<-
c
(
legend.val
,
as.numeric
(
single.pr.plot
(
args
[[
i
]],
colours
[
i
],
show.all
=
show.all
,
verbose
=
verbose
)))
}
if
(
!
is.null
(
names
(
args
)))
{
legend
(
'bottomright'
,
legend
=
paste0
(
names
(
args
),
' AUC: '
,
format
(
legend.val
,
digits
=
3
)),
col
=
colours
,
lty
=
1
,
lwd
=
2
,
cex
=
0.8
,
y.intersp
=
1.5
)
}
else
{
legend
(
'bottomright'
,
legend
=
paste0
(
'AUC: '
,
format
(
legend.val
,
digits
=
3
)),
col
=
colours
,
lty
=
1
,
lwd
=
2
,
cex
=
0.8
,
y.intersp
=
1.5
)
}
}
else
if
(
length
(
args
)
==
1
)
{
}
else
if
(
length
(
args
)
==
1
)
{
# checks
if
(
!
all
(
is
(
args
[[
1
]],
'siamcat'
)))
stop
(
'Please supply a SIAMCAT object. Exiting...'
)
if
(
is.null
(
eval_data
(
args
[[
1
]],
verbose
=
0
))){
stop
(
'SIAMCAT object has no evaluation data. Exiting...'
)
}
if
(
is.null
(
data_split
(
args
[[
1
]],
verbose
=
0
))){
auc.text
<-
'Mean-prediction AUC:'
pr.text
<-
'Mean AUC:'
if
(
is.null
(
colours
))
colours
<-
'black'
if
(
label
(
args
[[
1
]])
$
type
==
'BINARY'
){
type
<-
'binary'
}
else
if
(
label
(
args
[[
1
]])
$
type
==
'CONTINUOUS'
){
type
<-
'regression'
}
else
{
if
(
data_split
(
args
[[
1
]])
$
num.resample
==
1
){
auc.text
<-
'AUC:'
pr.text
<-
'AUC:'
stop
(
paste0
(
"SIAMCAT object needs to be either binary"
,
" classification or regression task"
))
}
}
else
{
stop
(
'No SIAMCAT object supplied. Exiting...'
)
}
if
(
!
is.null
(
fn.plot
))
pdf
(
fn.plot
,
onefile
=
TRUE
)
if
(
is.null
(
fn.plot
))
par
(
ask
=
TRUE
)
par
(
mar
=
c
(
5.1
,
4.1
,
4.1
,
2.1
))
if
(
type
==
'binary'
){
# plot ROC(s)
plot
(
NULL
,
xlim
=
c
(
0
,
1
),
ylim
=
c
(
0
,
1
),
xlab
=
"False positive rate"
,
ylab
=
"True positive rate"
,
type
=
"n"
)
title
(
paste
(
"ROC curve for the model"
,
sep
=
" "
))
abline
(
a
=
0
,
b
=
1
,
lty
=
3
)
if
(
length
(
args
)
>
1
){
legend.val
<-
c
()
# plot each roc curve for each eval data object
for
(
i
in
seq_along
(
args
))
{
legend.val
<-
c
(
legend.val
,
as.numeric
(
single.roc.plot
(
args
[[
i
]],
colours
[
i
],
show.all
=
show.all
,
verbose
=
verbose
)))
}
if
(
!
is.null
(
names
(
args
)))
{
legend
(
'bottomright'
,
legend
=
paste0
(
names
(
args
),
' AUC: '
,
format
(
legend.val
,
digits
=
3
)),
col
=
colours
,
lty
=
1
,
lwd
=
2
,
cex
=
0.8
,
y.intersp
=
1.5
)
}
else
{
legend
(
'bottomright'
,
legend
=
paste0
(
'AUC: '
,
format
(
legend.val
,
digits
=
3
)),
col
=
colours
,
lty
=
1
,
lwd
=
2
,
cex
=
0.8
,
y.intersp
=
1.5
)
}
}
else
{
if
(
is.null
(
data_split
(
args
[[
1
]],
verbose
=
0
))){
auc.text
<-
'Mean-prediction AUC:'
pr.text
<-
'Mean AUC:'
}
else
{
if
(
data_split
(
args
[[
1
]])
$
num.resample
==
1
){
auc.text
<-
'AUC:'
pr.text
<-
'AUC:'
}
else
{
auc.text
<-
'Mean-prediction AUC:'
pr.text
<-
'Mean AUC:'
}
}
auroc
<-
single.roc.plot
(
args
[[
1
]],
colours
,
show.all
=
show.all
,
verbose
=
verbose
)
text
(
0.7
,
0.1
,
paste
(
auc.text
,
format
(
auroc
,
digits
=
3
)))
}
# ROC
if
(
is.null
(
colours
))
colours
<-
'black'
auroc
<-
single.roc.plot
(
args
[[
1
]],
colours
,
show.all
=
show.all
,
verbose
=
verbose
)
text
(
0.7
,
0.1
,
paste
(
auc.text
,
format
(
auroc
,
digits
=
3
)))
# PR
if
(
verbose
>
2
)
message
(
"+ plotting PRC"
)
plot
(
NULL
,
xlim
=
c
(
0
,
1
),
ylim
=
c
(
0
,
1
),
xlab
=
"Recall"
,
ylab
=
"Precision"
,
type
=
"n"
)
# plot PRC(s)
plot
(
NULL
,
xlim
=
c
(
0
,
1
),
ylim
=
c
(
0
,
1
),
xlab
=
"Recall"
,
ylab
=
"Precision"
,
type
=
"n"
)
title
(
paste
(
"Precision-recall curve for the model"
,
sep
=
" "
))
label
<-
label
(
args
[[
1
]])
abline
(
h
=
mean
(
label
$
label
==
max
(
label
$
info
)),
lty
=
3
)
auprc
<-
single.pr.plot
(
args
[[
1
]],
colours
,
show.all
=
show.all
,
verbose
=
verbose
)
text
(
0.7
,
0.1
,
paste
(
pr.text
,
format
(
auprc
,
digits
=
3
)))
}
else
{
stop
(
'No SIAMCAT object supplied. Exiting...'
)
if
(
length
(
args
)
>
1
){
legend.val
<-
c
()
# plot each roc curve for each eval data object
for
(
i
in
seq_along
(
args
))
{
legend.val
<-
c
(
legend.val
,
as.numeric
(
single.pr.plot
(
args
[[
i
]],
colours
[
i
],
show.all
=
show.all
,
verbose
=
verbose
)))
}
if
(
!
is.null
(
names
(
args
)))
{
legend
(
'bottomright'
,
legend
=
paste0
(
names
(
args
),
' AUC: '
,
format
(
legend.val
,
digits
=
3
)),
col
=
colours
,
lty
=
1
,
lwd
=
2
,
cex
=
0.8
,
y.intersp
=
1.5
)
}
else
{
legend
(
'bottomright'
,
legend
=
paste0
(
'AUC: '
,
format
(
legend.val
,
digits
=
3
)),
col
=
colours
,
lty
=
1
,
lwd
=
2
,
cex
=
0.8
,
y.intersp
=
1.5
)
}
}
else
{
auprc
<-
single.pr.plot
(
args
[[
1
]],
colours
,
show.all
=
show.all
,
verbose
=
verbose
)
text
(
0.7
,
0.1
,
paste
(
pr.text
,
format
(
auprc
,
digits
=
3
)))
}
}
else
if
(
type
==
'regression'
){
# plot Scatter
for
(
i
in
seq_along
(
args
)){
single.regression.plot
(
args
[[
i
]],
colours
[
i
],
names
(
args
)[
i
],
verbose
)
}
}
if
(
!
is.null
(
fn.plot
))
tmp
<-
dev.off
()
if
(
is.null
(
fn.plot
))
par
(
ask
=
FALSE
)
e.time
<-
proc.time
()[
3
]
if
(
verbose
>
1
)
message
(
paste
(
"+ finished model.evaluation.plot in"
,
message
(
paste
0
(
"+ finished model.evaluation.plot in
"
,
formatC
(
e.time
-
s.time
,
digits
=
3
),
"s"
))
...
...
@@ -211,15 +222,15 @@ model.evaluation.plot <- function(..., fn.plot=NULL, colours = NULL,
single.pr.plot
<-
function
(
siamcat
,
colour
,
show.all
,
verbose
)
{
eval.data
<-
eval_data
(
siamcat
)
aucspr.all
<-
eval.data
$
auprc.all
# pr curves for resampling
if
(
!
is.null
(
eval.data
$
prc.all
)
&
show.all
)
{
aucspr.all
=
eval.data
$
auprc.all
for
(
c
in
seq_len
(
length
(
eval.data
$
prc.all
)))
{
pr
=
eval.data
$
prc.all
[[
c
]]
lines
(
pr
$
recall
,
pr
$
precision
,
col
=
alpha
(
colour
,
alpha
=
0.5
))
if
(
verbose
>
2
)
message
(
paste
(
message
(
paste
0
(
"+++ AU-PRC (resampled run "
,
c
,
"): "
,
...
...
@@ -257,15 +268,15 @@ single.pr.plot <- function(siamcat, colour, show.all, verbose) {
single.roc.plot
<-
function
(
siamcat
,
colour
,
show.all
,
verbose
)
{
eval.data
<-
eval_data
(
siamcat
)
aucs
<-
eval.data
$
auroc.all
if
(
!
is.null
(
eval.data
$
roc.all
)
&
show.all
){
aucs
=
eval.data
$
auroc.all
for
(
c
in
seq_along
(
eval.data
$
roc.all
))
{
roc.c
=
eval.data
$
roc.all
[[
c
]]
lines
(
1
-
roc.c
$
specificities
,
roc.c
$
sensitivities
,
col
=
alpha
(
colour
,
alpha
=
0.5
))
if
(
verbose
>
2
)
{
message
(
paste
(
'+++ AU-ROC (resampled run '
,
message
(
paste
0
(
'+++ AU-ROC (resampled run '
,
c
,
"): "
,
format
(
aucs
[
c
],
digits
=
3
)))
}
}
...
...
@@ -303,3 +314,35 @@ single.roc.plot <- function(siamcat, colour, show.all, verbose) {
return
(
as.numeric
(
auroc
))
}
# Scatter plot of true values against model predictions for one SIAMCAT
# object (regression task).
#
# siamcat: object queried through eval_data(), pred_matrix(), label() and
#          data_split()
# colour:  colour passed to points() for the scatter
# title:   plot title; when NULL, "Model predictions" is used instead
# verbose: accepted for signature compatibility, not used in the body
#
# Returns the value of the final legend() call.
single.regression.plot <- function(siamcat, colour, title = NULL,
                                    verbose = 0) {
    metrics <- eval_data(siamcat)
    # one prediction per sample: row-wise mean over the prediction matrix
    avg.prediction <- rowMeans(pred_matrix(siamcat))

    # the first two entries of the label info are used as limits on both axes
    label.info <- label(siamcat)$info
    axis.limits <- c(label.info[1], label.info[2])

    # y-axis label prefix: "M" + "odel predictions" only when a data split
    # exists and reports exactly one resampling round, "Mean m" otherwise
    axis.prefix <- "Mean m"
    if (!is.null(data_split(siamcat, verbose = 0))) {
        if (data_split(siamcat)$num.resample == 1) {
            axis.prefix <- "M"
        }
    }

    # empty canvas with the identity line as reference
    plot(
        NULL,
        xlim = axis.limits,
        ylim = axis.limits,
        xlab = "True value",
        ylab = paste0(axis.prefix, "odel predictions"),
        type = "n"
    )
    abline(a = 0, b = 1, lty = 3)

    heading <- if (is.null(title)) "Model predictions" else title
    # `title` in call position still resolves to graphics::title(), because
    # R skips non-function bindings when looking up a function to call
    title(heading)

    points(label(siamcat)$label, avg.prediction, col = colour)

    # summary metrics as stored in the evaluation data
    metric.lines <- c(
        paste0("R-squared: ", format(metrics$r2, digits = 3)),
        paste0("Mean absolute error: ", format(metrics$mae, digits = 3)),
        paste0("Mean squared error: ", format(metrics$mse, digits = 3))
    )
    legend("bottomright", legend = metric.lines)
}
R/model_interpretation_plot.R
View file @
44768174
This diff is collapsed.
Click to expand it.
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment