Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

list support for weights #92

Merged
merged 1 commit into from
Mar 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions R/trainDI.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,11 @@ trainDI <- function(model = NA,
weight <- t(data.frame(rep(1,length(variables))))
names(weight) <- variables
}
}else{ #check if manually given weights are correct. otherwise ignore (set to 1):
if(nrow(weight)!=1||ncol(weight)!=length(variables)){
message("variable weights are not correctly specified and will be ignored. See ?aoa")
weight <- t(data.frame(rep(1,length(variables))))
names(weight) <- variables
}
weight <- weight[,na.omit(match(variables, names(weight)))]
if (any(weight<0)){
weight[weight<0]<-0
message("negative weights were set to 0")
}
}else{


weight <- user_weights(weight, variables)

}

# get CV folds from model or from parameters
Expand Down Expand Up @@ -414,6 +408,39 @@ aoa_get_weights = function(model, variables){
}



# check user weight input
# make sure this function outputs a data.frame with
# one row and columns named after the variables

user_weights = function(weight, variables){

# list input support
if(inherits(weight, "list")){
# check if all list entries are in variables
weight = as.data.frame(weight)
}


#check if manually given weights are correct. otherwise ignore (set to 1):
if(nrow(weight)!=1 || !all(variables %in% names(weight))){
message("variable weights are not correctly specified and will be ignored. See ?aoa")
weight <- t(data.frame(rep(1,length(variables))))
names(weight) <- variables
}
weight <- weight[,na.omit(match(variables, names(weight)))]
if (any(weight<0)){
weight[weight<0]<-0
message("negative weights were set to 0")
}

return(weight)

}




# Get trainingdata from train object

aoa_get_train <- function(model){
Expand Down