From 1c9251c42df8ab739c18dfac9ebee921ef26a3b4 Mon Sep 17 00:00:00 2001 From: Ludwigm6 Date: Tue, 12 Mar 2024 16:13:11 +0100 Subject: [PATCH] list support for weights --- R/trainDI.R | 49 ++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 38 insertions(+), 11 deletions(-) diff --git a/R/trainDI.R b/R/trainDI.R index 7b3aa6ec..5e65a356 100644 --- a/R/trainDI.R +++ b/R/trainDI.R @@ -130,17 +130,11 @@ trainDI <- function(model = NA, weight <- t(data.frame(rep(1,length(variables)))) names(weight) <- variables } - }else{ #check if manually given weights are correct. otherwise ignore (set to 1): - if(nrow(weight)!=1||ncol(weight)!=length(variables)){ - message("variable weights are not correctly specified and will be ignored. See ?aoa") - weight <- t(data.frame(rep(1,length(variables)))) - names(weight) <- variables - } - weight <- weight[,na.omit(match(variables, names(weight)))] - if (any(weight<0)){ - weight[weight<0]<-0 - message("negative weights were set to 0") - } + }else{ + + + weight <- user_weights(weight, variables) + } # get CV folds from model or from parameters @@ -414,6 +408,39 @@ aoa_get_weights = function(model, variables){ } + +# check user weight input +# make sure this function outputs a data.frame with +# one row and columns named after the variables + +user_weights = function(weight, variables){ + + # list input support + if(inherits(weight, "list")){ + # check if all list entries are in variables + weight = as.data.frame(weight) + } + + + #check if manually given weights are correct. otherwise ignore (set to 1): + if(nrow(weight)!=1 || !all(variables %in% names(weight))){ + message("variable weights are not correctly specified and will be ignored. See ?aoa") + weight <- t(data.frame(rep(1,length(variables)))) + names(weight) <- variables + } + weight <- weight[,na.omit(match(variables, names(weight)))] + if (any(weight<0)){ + weight[weight<0]<-0 + message("negative weights were set to 0") + } + + return(weight) + +} + + + + # Get trainingdata from train object aoa_get_train <- function(model){