Skip to content

Commit

Permalink
first version of spatial error profile
Browse files Browse the repository at this point in the history
  • Loading branch information
HannaMeyer committed Mar 11, 2024
1 parent 2d2d9fa commit 03b38a1
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 25 deletions.
63 changes: 45 additions & 18 deletions R/errorProfiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#' @description Performance metrics are calculated for moving windows of dissimilarity values based on cross-validated training data
#' @param model the model used to get the AOA
#' @param trainDI the result of \code{\link{trainDI}} or aoa object \code{\link{aoa}}
#' @param locations Optional. sf object with the locations of the training data used in model. Only used if variable=="geodist". Note that the locations must be in the same order as the rows of model$trainingData.
#' @param variable Character. Which dissimilarity or distance measure to use for the error metric. Current options are "DI", "LPD" or "geodist" (the latter requires \code{locations})
#' @param multiCV Logical. Re-run model fitting and validation with different CV strategies. See details.
#' @param window.size Numeric. Size of the moving window. See \code{\link{rollapply}}.
Expand Down Expand Up @@ -31,7 +32,8 @@


errorProfiles <- function(model,
trainDI,
trainDI=NULL,
locations=NULL,
variable = "DI",
multiCV=FALSE,
length.out = 10,
Expand All @@ -47,13 +49,17 @@ errorProfiles <- function(model,
trainDI = trainDI$parameters
}

if(!is.null(locations)&variable=="geodist"){
message("warning: Please ensure that the order of the locations matches to model$trainingData")
}


# get DIs and Errormetrics OR calculate new ones from multiCV
if(!multiCV){
preds_all <- get_preds_all(model, trainDI, variable)
preds_all <- get_preds_all(model, trainDI, locations, variable)
}
if(multiCV){
preds_all <- multiCV(model, length.out, method, useWeight, variable)
preds_all <- multiCV(model, locations, length.out, method, useWeight, variable)
}

# train model between DI and Errormetric
Expand Down Expand Up @@ -126,36 +132,45 @@ errorModel <- function(preds_all, model, window.size, calib, k, m, variable){
errormodel <- lm(metric ~ DI, data = performance)
} else if (variable == "LPD") {
errormodel <- lm(metric ~ LPD, data = performance)
} else if (variable=="geodist"){
errormodel <- lm(metric ~ geodist, data = performance)
}
}
if(calib=="scam"){
if (!requireNamespace("scam", quietly = TRUE)) {
stop("Package \"scam\" needed for this function to work. Please install it.",
call. = FALSE)
}
if (variable == "DI") {
if (variable %in% c("DI","geodist")) {
if (model$maximize){ # e.g. accuracy, kappa, r2
bs="mpd"
}else{
bs="mpi" #e.g. RMSE
}
errormodel <- scam::scam(metric~s(DI, k=k, bs=bs, m=m),
data=performance,
family=stats::gaussian(link="identity"))
if(variable=="DI"){
errormodel <- scam::scam(metric~s(DI, k=k, bs=bs, m=m),
data=performance,
family=stats::gaussian(link="identity"))
}else if (variable=="geodist"){
errormodel <- scam::scam(metric~s(geodist, k=k, bs=bs, m=m),
data=performance,
family=stats::gaussian(link="identity"))
}

} else if (variable == "LPD") {
if (model$maximize){ # e.g. accuracy, kappa, r2
bs="mpi"
}else{
bs="mpd" #e.g. RMSE
}
errormodel <- scam::scam(metric~s(LPD, k=k, bs=bs, m=m),
data=performance,
family=stats::gaussian(link="identity"))
data=performance,
family=stats::gaussian(link="identity"))
}
}
if(calib=="exp"){
if (variable == "DI") {
stop("Eexponential model currently only implemented for LPD")
if (variable %in% c("DI","geodist")) {
stop("Exponential model currently only implemented for LPD")
} else if (variable == "LPD") {
errormodel <- lm(metric ~ log(LPD), data = performance)
}
Expand All @@ -168,7 +183,7 @@ errorModel <- function(preds_all, model, window.size, calib, k, m, variable){


# MultiCV
multiCV <- function(model, length.out, method, useWeight, variable,...){
multiCV <- function(model, locations, length.out, method, useWeight, variable,...){

preds_all <- data.frame()
train_predictors <- model$trainingData[,-which(names(model$trainingData)==".outcome")]
Expand Down Expand Up @@ -199,6 +214,10 @@ multiCV <- function(model, length.out, method, useWeight, variable,...){
trainDI_new <- trainDI(model_new, method=method, useWeight=useWeight)
} else if (variable == "LPD") {
trainDI_new <- trainDI(model_new, method=method, useWeight=useWeight, LPD = TRUE)
} else if (variable=="geodist"){
tmp_gd_new <- CAST::geodist(locations,modeldomain=locations,cvfolds = model$control$indexOut)
geodist_new <- tmp_gd_new[tmp_gd_new$what=="CV-distances","dist"]

}


Expand All @@ -213,21 +232,26 @@ multiCV <- function(model, length.out, method, useWeight, variable,...){
preds_dat_tmp <- data.frame(preds,"LPD"=trainDI_new$trainLPD)
preds_dat_tmp <- preds_dat_tmp[preds_dat_tmp$LPD > 0,]
preds_all <- rbind(preds_all,preds_dat_tmp)
} else if (variable == "geodist"){
preds_dat_tmp <- data.frame(preds,"geodist"=geodist_new)
preds_all <- rbind(preds_all,preds_dat_tmp)
# NO AOA used here
}
}

attr(preds_all, "AOA_threshold") <- trainDI_new$threshold
if(variable%in%c("DI","LPD")){
attr(preds_all, "AOA_threshold") <- trainDI_new$threshold
message(paste0("Note: multiCV=TRUE calculated new AOA threshold of ", round(trainDI_new$threshold, 5),
"\nThreshold is stored in the attributes, access with attr(error_model, 'AOA_threshold').",
"\nPlease refere to examples and details for further information."))
}
attr(preds_all, "variable") <- variable
attr(preds_all, "metric") <- model$metric
message(paste0("Note: multiCV=TRUE calculated new AOA threshold of ", round(trainDI_new$threshold, 5),
"\nThreshold is stored in the attributes, access with attr(error_model, 'AOA_threshold').",
"\nPlease refere to examples and details for further information."))
return(preds_all)
}


# Get Preds all
get_preds_all <- function(model, trainDI, variable){
get_preds_all <- function(model, trainDI, locations, variable){

if(is.null(model$pred)){
stop("no cross-predictions can be retrieved from the model. Train with savePredictions=TRUE or provide calibration data")
Expand All @@ -252,6 +276,9 @@ get_preds_all <- function(model, trainDI, variable){
preds_all$LPD <- trainDI$trainLPD[!is.na(trainDI$trainLPD)]
## only take predictions from inside the AOA:
preds_all <- preds_all[preds_all$LPD>0,]
} else if(variable=="geodist"){
tmp_gd <- CAST::geodist(locations,modeldomain=locations,cvfolds = model$control$indexOut)
preds_all$geodist <- tmp_gd[tmp_gd$what=="CV-distances","dist"]
}

attr(preds_all, "AOA_threshold") <- trainDI$threshold
Expand Down
23 changes: 16 additions & 7 deletions inst/examples/ex_errorProfiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,41 @@


data(splotdata)
splotdata <- st_drop_geometry(splotdata)
predictors <- terra::rast(system.file("extdata","predictors_chile.tif", package="CAST"))

model <- caret::train(splotdata[,6:16], splotdata$Species_richness, ntree = 10,
model <- caret::train(st_drop_geometry(splotdata)[,6:16], splotdata$Species_richness, ntree = 10,
trControl = trainControl(method = "cv", savePredictions = TRUE))

AOA <- aoa(predictors, model, LPD = TRUE, maxLPD = 1)

# DI ~ error
### DI ~ error
errormodel_DI <- errorProfiles(model, AOA, variable = "DI")
plot(errormodel_DI)

expected_error_DI = terra::predict(AOA$DI, errormodel_DI)
plot(expected_error_DI)

# LPD ~ error
### LPD ~ error
errormodel_LPD <- errorProfiles(model, AOA, variable = "LPD")
plot(errormodel_LPD)

expected_error_LPD = terra::predict(AOA$LPD, errormodel_LPD)
plot(expected_error_LPD)


# with multiCV = TRUE (for DI ~ error)

### geodist ~ error
errormodel_geodist = errorProfiles(model, locations=splotdata,
variable = "geodist")
plot(errormodel_geodist)

dist <- terra::distance(predictors[[1]],vect(splotdata))
names(dist) <- "geodist"
expected_error_DI <- terra::predict(dist, errormodel_geodist)
plot(expected_error_DI)


### with multiCV = TRUE (for DI ~ error)
errormodel_DI = errorProfiles(model, AOA, multiCV = TRUE, length.out = 3, variable = "DI")
plot(errormodel_DI)

Expand All @@ -43,7 +54,5 @@
plot(mask_aoa)




}

0 comments on commit 03b38a1

Please sign in to comment.