From cc01490f54a4a12b4802c98e485a75449011dbb1 Mon Sep 17 00:00:00 2001 From: Nicholas Williams Date: Thu, 12 Sep 2024 10:24:32 -0700 Subject: [PATCH 1/2] Riesz implementation, with batches, no conditional effects --- DESCRIPTION | 18 ++++--- NAMESPACE | 4 ++ R/checks.R | 4 +- R/estimators.R | 24 ++++++++-- R/lmtp-package.R | 2 +- R/lmtp_control.R | 24 ++++++++-- R/make_dataset.R | 23 +++++++++ R/nn_riesz.R | 41 ++++++++++++++++ R/riesz.R | 101 +++++++++++++++++++++++++++++++++++++++ R/sequential_module.R | 27 +++++++++++ R/sl.R | 19 +++++--- R/theta.R | 4 +- R/tmle.R | 11 +++-- R/utils.R | 9 ++++ man/lmtp-package.Rd | 5 ++ man/lmtp_control.Rd | 18 ++++++- man/lmtp_tmle.Rd | 16 +++++-- man/sequential_module.Rd | 24 ++++++++++ 18 files changed, 339 insertions(+), 35 deletions(-) create mode 100644 R/make_dataset.R create mode 100644 R/nn_riesz.R create mode 100644 R/riesz.R create mode 100644 R/sequential_module.R create mode 100644 man/sequential_module.Rd diff --git a/DESCRIPTION b/DESCRIPTION index fbfa5ca..a2d222f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: lmtp Title: Non-Parametric Causal Effects of Feasible Interventions Based on Modified Treatment Policies -Version: 1.4.1.9001 +Version: 1.5.1.9001 Authors@R: c(person(given = "Nicholas", family = "Williams", @@ -11,7 +11,11 @@ Authors@R: family = "Díaz", email = "Ivan.Diaz@nyulangone.org", role = c("aut", "cph"), - comment = c(ORCID = "0000-0001-9056-2047"))) + comment = c(ORCID = "0000-0001-9056-2047")), + person(given = "Herb", + family = "Susmann", + email = "herbert.susmann@nyulangone.org", + role = "ctb")) Description: Non-parametric estimators for casual effects based on longitudinal modified treatment policies as described in Diaz, Williams, Hoffman, and Schenck , traditional point treatment, and traditional longitudinal effects. Continuous, binary, categorical treatments, and multivariate treatments are allowed as well are @@ -19,16 +23,14 @@ Description: Non-parametric estimators for casual effects based on longitudinal irrespective of treatment variable type. For both continuous and binary outcomes, additive treatment effects can be calculated and relative risks and odds ratios may be calculated for binary outcomes. Depends: - mlr3superlearner, R (>= 2.10) License: AGPL-3 Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.1 +RoxygenNote: 7.3.2 Imports: stats, - nnls, cli, R6, generics, @@ -37,8 +39,10 @@ Imports: progressr, data.table (>= 1.13.0), checkmate (>= 2.1.0), - isotone, - mlr3superlearner + isotone, + mlr3superlearner, + torch, + coro URL: https://beyondtheate.com/, https://github.com/nt-williams/lmtp BugReports: https://github.com/nt-williams/lmtp/issues Suggests: diff --git a/NAMESPACE b/NAMESPACE index f190e0c..32e1679 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -15,10 +15,12 @@ export(lmtp_sdr) export(lmtp_sub) export(lmtp_survival) export(lmtp_tmle) +export(sequential_module) export(static_binary_off) export(static_binary_on) export(tidy) importFrom(R6,R6Class) +importFrom(checkmate,`%??%`) importFrom(data.table,.SD) importFrom(data.table,`:=`) importFrom(data.table,as.data.table) @@ -28,6 +30,7 @@ importFrom(stats,binomial) importFrom(stats,coef) importFrom(stats,gaussian) importFrom(stats,glm) +importFrom(stats,model.matrix) importFrom(stats,na.omit) importFrom(stats,plogis) importFrom(stats,pnorm) @@ -37,5 +40,6 @@ importFrom(stats,qnorm) importFrom(stats,quantile) importFrom(stats,runif) importFrom(stats,sd) +importFrom(stats,setNames) importFrom(stats,var) importFrom(stats,weighted.mean) diff --git a/R/checks.R b/R/checks.R index f6ae1ce..3acd03a 100644 --- a/R/checks.R +++ b/R/checks.R @@ -172,7 +172,8 @@ check_ref_class <- function(x) { assertRefClass <- checkmate::makeAssertionFunction(check_ref_class) -check_trt_type <- function(data, trt, mtp) { +check_trt_type <- function(data, trt, riesz, mtp) { + if (riesz) return(invisible()) is_decimal <- vector("logical", length(trt)) for (i in seq_along(trt)) { a <- data[[trt[i]]] @@ -200,4 +201,3 @@ check_same_weights <- function(weights) { } assertSameWeights <- checkmate::makeAssertionFunction(check_same_weights) - diff --git a/R/estimators.R b/R/estimators.R index b82734c..5cd03fa 100644 --- a/R/estimators.R +++ b/R/estimators.R @@ -41,6 +41,7 @@ #' @param mtp \[\code{logical(1)}\]\cr #' Is the intervention of interest a modified treatment policy? #' Default is \code{FALSE}. If treatment variables are continuous this should be \code{TRUE}. +#' Ignored if \code{riesz = TRUE}. #' @param outcome_type \[\code{character(1)}\]\cr #' Outcome variable type (i.e., continuous, binomial, survival). #' @param id \[\code{character(1)}\]\cr @@ -54,6 +55,11 @@ #' @param learners_trt \[\code{character}\]\cr A vector of \code{mlr3superlearner} algorithms for estimation #' of the outcome regression. Default is \code{c("mean", "glm")}. #' \bold{Only include candidate learners capable of binary classification}. +#' Ignored if \code{riesz = TRUE}. +#' @param riesz \[\code{logical(1)}\]\cr +#' Use Riesz representers to learn the density ratios? Default is \code{FALSE}. +#' @param module \[\code{function}\]\cr +#' A function that returns a neural network module. Only used if \code{riesz = TRUE}. #' @param folds \[\code{integer(1)}\]\cr #' The number of folds to be used for cross-fitting. #' @param weights \[\code{numeric(nrow(data))}\]\cr @@ -81,7 +87,7 @@ #' \item{shift}{The shift function specifying the treatment policy of interest.} #' \item{outcome_reg}{An n x Tau + 1 matrix of outcome regression predictions. #' The mean of the first column is used for calculating theta.} -#' \item{density_ratios}{An n x Tau matrix of the estimated, non-cumulative, density ratios.} +#' \item{density_ratios}{If \code{riesz = FALSE}, an n x Tau matrix of the estimated, non-cumulative, density ratios. If \code{riesz = TRUE}, an n x Tau matrix of the estimated, cumulative, density ratios.} #' \item{fits_m}{A list the same length as \code{folds}, containing the fits at each time-point #' for each fold for the outcome regression.} #' \item{fits_r}{A list the same length as \code{folds}, containing the fits at each time-point @@ -96,6 +102,8 @@ lmtp_tmle <- function(data, trt, outcome, baseline = NULL, time_vary = NULL, id = NULL, bounds = NULL, learners_outcome = c("mean", "glm"), learners_trt = c("mean", "glm"), + riesz = FALSE, + module = sequential_module(), folds = 10, weights = NULL, control = lmtp_control()) { assertNotDataTable(data) @@ -124,7 +132,7 @@ lmtp_tmle <- function(data, trt, outcome, baseline = NULL, time_vary = NULL, checkmate::assertNumber(control$.bound) checkmate::assertNumber(control$.trim, upper = 1) checkmate::assertLogical(control$.return_full_fits, len = 1) - check_trt_type(data, unlist(trt), mtp) + check_trt_type(data, unlist(trt), riesz, mtp) task <- lmtp_task$new( data = data, @@ -146,10 +154,16 @@ lmtp_tmle <- function(data, trt, outcome, baseline = NULL, time_vary = NULL, pb <- progressr::progressor(task$tau*folds*2) - ratios <- cf_r(task, learners_trt, mtp, control, pb) + if (isFALSE(riesz)) { + ratios <- cf_r(task, learners_trt, mtp, control, pb) + } else { + ratios <- cf_riesz(task, module, mtp, control, pb) + } + estims <- cf_tmle(task, "tmp_lmtp_scaled_outcome", ratios$ratios, + riesz, learners_outcome, control, pb) @@ -302,7 +316,7 @@ lmtp_sdr <- function(data, trt, outcome, baseline = NULL, time_vary = NULL, checkmate::assertNumber(control$.bound) checkmate::assertNumber(control$.trim, upper = 1) checkmate::assertLogical(control$.return_full_fits, len = 1) - check_trt_type(data, unlist(trt), mtp) + check_trt_type(data, unlist(trt), FALSE, mtp) task <- lmtp_task$new( data = data, @@ -609,7 +623,7 @@ lmtp_ipw <- function(data, trt, outcome, baseline = NULL, time_vary = NULL, cens checkmate::assertNumber(control$.bound) checkmate::assertNumber(control$.trim, upper = 1) checkmate::assertLogical(control$.return_full_fits, len = 1) - check_trt_type(data, unlist(trt), mtp) + check_trt_type(data, unlist(trt), FALSE, mtp) task <- lmtp_task$new( data = data, diff --git a/R/lmtp-package.R b/R/lmtp-package.R index cf0c16f..3d5d6f0 100644 --- a/R/lmtp-package.R +++ b/R/lmtp-package.R @@ -1,4 +1,4 @@ -#' @importFrom stats runif as.formula coef glm plogis predict qlogis qnorm pnorm sd quantile var binomial gaussian na.omit weighted.mean +#' @importFrom stats runif as.formula coef glm plogis predict qlogis qnorm pnorm sd quantile var binomial gaussian na.omit weighted.mean model.matrix setNames #' @keywords internal "_PACKAGE" diff --git a/R/lmtp_control.R b/R/lmtp_control.R index 4c8ad6b..6a900a3 100644 --- a/R/lmtp_control.R +++ b/R/lmtp_control.R @@ -13,6 +13,14 @@ #' The number of cross-validation folds for \code{learners_trt}. #' @param .return_full_fits \[\code{logical(1)}\]\cr #' Return full SuperLearner fits? Default is \code{FALSE}, return only SuperLearner weights. +#' @param .epochs \[\code{integer(1)}\]\cr +#' The number of epochs to train the neural network. +#' @param .learning_rate \[\code{numeric(1)}\]\cr +#' The learning rate for the neural network. +#' @param .batch_size \[\code{integer(1)}\]\cr +#' The batch size for the neural network. +#' @param .device \[\code{character(1)}\]\cr +#' The device to train the neural network on. Default is \code{"cpu"}. #' #' @return A list of parameters controlling the estimation procedure. #' @export @@ -21,12 +29,20 @@ #' lmtp_control(.trim = 0.975) lmtp_control <- function(.bound = 1e5, .trim = 0.999, - .learners_outcome_folds = 10, - .learners_trt_folds = 10, - .return_full_fits = FALSE) { + .learners_outcome_folds = NULL, + .learners_trt_folds = NULL, + .return_full_fits = FALSE, + .epochs = 100L, + .learning_rate = 0.01, + .batch_size = 64, + .device = c("cpu", "cuda", "mps")) { list(.bound = .bound, .trim = .trim, .learners_outcome_folds = .learners_outcome_folds, .learners_trt_folds = .learners_trt_folds, - .return_full_fits = .return_full_fits) + .return_full_fits = .return_full_fits, + .epochs = .epochs, + .learning_rate = .learning_rate, + .batch_size = .batch_size, + .device = match.arg(.device)) } diff --git a/R/make_dataset.R b/R/make_dataset.R new file mode 100644 index 0000000..c611c4f --- /dev/null +++ b/R/make_dataset.R @@ -0,0 +1,23 @@ +make_dataset <- function(data, x, device) { + self <- NULL + dataset <- torch::dataset( + name = "tmp_lmtp_dataset", + initialize = function(data, x, device) { + for (df in names(data)) { + if (ncol(data[[df]]) > 0) { + df_x <- data[[df]][, x, drop = FALSE] + self[[df]] <- one_hot_encode(df_x) |> + as_torch(device = device) + } + } + }, + .getitem = function(i) { + fields <- grep("data", names(self), value = TRUE) + setNames(lapply(fields, function(x) self[[x]][i, ]), fields) + }, + .length = function() { + self$data$size()[1] + } + ) + dataset(data, x, device) +} diff --git a/R/nn_riesz.R b/R/nn_riesz.R new file mode 100644 index 0000000..d3f66f8 --- /dev/null +++ b/R/nn_riesz.R @@ -0,0 +1,41 @@ +#' @importFrom checkmate `%??%` +nn_riesz <- function(train, + vars, + module, + .f, + weights = NULL, + batch_size, + learning_rate, + epochs, + device) { + dataset <- make_dataset(train, vars, device = device) + train_dl <- torch::dataloader(dataset, batch_size = batch_size) + model <- module(ncol(dataset$data)) + model$to(device = device) + + weights <- weights %??% 1 + + optimizer <- torch::optim_adam( + params = c(model$parameters), + lr = learning_rate, + weight_decay = 0.01 + ) + + scheduler <- torch::lr_one_cycle(optimizer, max_lr = learning_rate, total_steps = epochs) + + for (epoch in 1:epochs) { + coro::loop(for (b in train_dl) { + # Regression loss + loss <- (model(b$data)$pow(2) - (2 * weights * .f(model, b)))$mean(dtype = torch::torch_float()) + + optimizer$zero_grad() + loss$backward() + + optimizer$step() + }) + scheduler$step() + } + + model$eval() + model +} diff --git a/R/riesz.R b/R/riesz.R new file mode 100644 index 0000000..fddbd12 --- /dev/null +++ b/R/riesz.R @@ -0,0 +1,101 @@ +cf_riesz <- function(task, module, mtp, control, pb) { + out <- vector("list", length = length(task$folds)) + for (fold in seq_along(task$folds)) { + out[[fold]] <- future::future({ + estimate_riesz(get_folded_data(task$natural, task$folds, fold), + get_folded_data(task$shifted, task$folds, fold), + task$trt, + task$cens, + task$risk, + task$tau, + task$node_list$trt, + module, + mtp, + control, + pb) + }, + seed = TRUE) + } + + out <- future::value(out) + recombine_ratios(out, task$folds) +} + +estimate_riesz <- function(natural, + shifted, + trt, + cens, + risk, + tau, + node_list, + module, + mtp, + control, + pb) { + weights <- rep(1, nrow(natural$train)) + riesz_valid <- matrix(data = 0, nrow = nrow(natural$valid), ncol = tau) + fits <- vector("list", length = tau) + + for (t in 1:tau) { + jrt <- censored(natural$train, cens, t)$j + drt <- at_risk(natural$train, risk, t) + irv <- censored(natural$valid, cens, t)$i + jrv <- censored(natural$valid, cens, t)$j + drv <- at_risk(natural$valid, risk, t) + + if (length(trt) > 1) { + trt_t <- trt[[t]] + } else { + trt_t <- trt[[1]] + } + + frv <- followed_rule(natural$valid[, trt_t], shifted$valid[, trt_t], mtp) + + vars <- c(node_list[[t]], cens[[t]]) + + new_shifted_train <- natural$train + new_shifted_train[, trt_t] <- shifted$train[, trt_t] + + model <- nn_riesz( + train = list(data = natural$train[jrt & drt, vars, drop = FALSE], + data_1 = new_shifted_train[jrt & drt, vars, drop = FALSE]), + vars = vars, + module = module, + .f = \(alpha, dl) alpha(dl[["data_1"]]), + weights = weights, + batch_size = control$.batch_size, + learning_rate = control$.learning_rate, + epochs = control$.epochs, + device = control$.device + ) + + # Return the full model object or return nothing + if (control$.return_full_fits) { + fits[[t]] <- model + } else { + fits[[t]] <- NULL + } + + weights <- as.numeric( + model( + as_torch( + one_hot_encode(natural$train[jrv & drv, vars, drop = FALSE]), + device = control$.device + ) + ) + ) + + riesz_valid[jrv & drv, t] <- as.numeric( + model( + as_torch( + one_hot_encode(natural$valid[jrv & drv, vars, drop = FALSE]), + device = control$.device + ) + ) + ) + + pb() + } + + list(ratios = riesz_valid, fits = fits) +} diff --git a/R/sequential_module.R b/R/sequential_module.R new file mode 100644 index 0000000..7ac767b --- /dev/null +++ b/R/sequential_module.R @@ -0,0 +1,27 @@ +#' Sequential neural network module function factory +#' +#' @param layers \[numeric(1)\]\cr Number of hidden layers. +#' @param hidden \[numeric(1)\]\cr Number of hidden units. +#' @param dropout \[numeric(1)\]\cr Dropout rate. +#' +#' @return A function that returns a sequential neural network module. +#' @export +#' +#' @examples +#' if (torch::torch_is_installed()) sequential_module() +sequential_module <- function(layers = 1, hidden = 20, dropout = 0.1) { + function(d_in) { + d_out <- 1 + + middle_layers <- lapply(1:layers, \(x) torch::nn_sequential(torch::nn_linear(hidden, hidden), torch::nn_elu())) + + torch::nn_sequential( + torch::nn_linear(d_in, hidden), + torch::nn_elu(), + do.call(torch::nn_sequential, middle_layers), + torch::nn_linear(hidden, d_out), + torch::nn_dropout(dropout), + torch::nn_softplus() + ) + } +} diff --git a/R/sl.R b/R/sl.R index 7e57ed3..853c187 100644 --- a/R/sl.R +++ b/R/sl.R @@ -1,10 +1,17 @@ run_ensemble <- function(data, y, learners, outcome_type, id, folds) { - fit <- mlr3superlearner(data = data, - target = y, - library = learners, - outcome_type = outcome_type, - folds = folds, - group = id) + fit <- mlr3superlearner::mlr3superlearner( + data = data, + target = y, + library = learners, + outcome_type = outcome_type, + folds = folds, + group = { + if (length(unique(data[[id]])) == nrow(data)) + NULL + else + id + } + ) fit } diff --git a/R/theta.R b/R/theta.R index 7df2f8f..3cd0bad 100644 --- a/R/theta.R +++ b/R/theta.R @@ -59,8 +59,8 @@ theta_ipw <- function(eta) { } eif <- function(r, tau, shifted, natural) { - natural[is.na(natural)] <- -999 - shifted[is.na(shifted)] <- -999 + natural[is.na(natural)] <- 0 + shifted[is.na(shifted)] <- 0 m <- shifted[, 2:(tau + 1), drop = FALSE] - natural[, 1:tau, drop = FALSE] rowSums(compute_weights(r, 1, tau) * m, na.rm = TRUE) + shifted[, 1] } diff --git a/R/tmle.R b/R/tmle.R index a65cc45..7ca4ed5 100644 --- a/R/tmle.R +++ b/R/tmle.R @@ -1,8 +1,11 @@ -cf_tmle <- function(task, outcome, ratios, learners, control, pb) { +cf_tmle <- function(task, outcome, ratios, riesz, learners, control, pb) { out <- vector("list", length = length(task$folds)) - ratios <- matrix(t(apply(ratios, 1, cumprod)), - nrow = nrow(ratios), - ncol = ncol(ratios)) + + if (isFALSE(riesz)) { + ratios <- matrix(t(apply(ratios, 1, cumprod)), + nrow = nrow(ratios), + ncol = ncol(ratios)) + } for (fold in seq_along(task$folds)) { out[[fold]] <- future::future({ diff --git a/R/utils.R b/R/utils.R index fe6ee88..3438b2d 100644 --- a/R/utils.R +++ b/R/utils.R @@ -245,3 +245,12 @@ is_decimal <- function(x) { test <- floor(x) !(x == test) } + +as_torch <- function(data, device) { + torch::torch_tensor(as.matrix(data), dtype = torch::torch_float(), device = device) +} + +one_hot_encode <- function(data, vars) { + tmp <- data[, vars, drop = FALSE] + as.data.frame(model.matrix(~ ., data = tmp))[, -1, drop = FALSE] +} diff --git a/man/lmtp-package.Rd b/man/lmtp-package.Rd index 21ffd1d..6ddfc24 100644 --- a/man/lmtp-package.Rd +++ b/man/lmtp-package.Rd @@ -25,5 +25,10 @@ Authors: \item Iván Díaz \email{Ivan.Diaz@nyulangone.org} (\href{https://orcid.org/0000-0001-9056-2047}{ORCID}) [copyright holder] } +Other contributors: +\itemize{ + \item Herb Susmann \email{herbert.susmann@nyulangone.org} [contributor] +} + } \keyword{internal} diff --git a/man/lmtp_control.Rd b/man/lmtp_control.Rd index 2e1d10a..8e6d25f 100644 --- a/man/lmtp_control.Rd +++ b/man/lmtp_control.Rd @@ -9,7 +9,11 @@ lmtp_control( .trim = 0.999, .learners_outcome_folds = 10, .learners_trt_folds = 10, - .return_full_fits = FALSE + .return_full_fits = FALSE, + .epochs = 100L, + .learning_rate = 0.01, + .batch_size = 64, + .device = c("cpu", "cuda", "mps") ) } \arguments{ @@ -30,6 +34,18 @@ The number of cross-validation folds for \code{learners_trt}.} \item{.return_full_fits}{[\code{logical(1)}]\cr Return full SuperLearner fits? Default is \code{FALSE}, return only SuperLearner weights.} + +\item{.epochs}{[\code{integer(1)}]\cr +The number of epochs to train the neural network.} + +\item{.learning_rate}{[\code{numeric(1)}]\cr +The learning rate for the neural network.} + +\item{.batch_size}{[\code{integer(1)}]\cr +The batch size for the neural network.} + +\item{.device}{[\code{character(1)}]\cr +The device to train the neural network on. Default is \code{"cpu"}.} } \value{ A list of parameters controlling the estimation procedure. diff --git a/man/lmtp_tmle.Rd b/man/lmtp_tmle.Rd index 37bbf17..01b7901 100644 --- a/man/lmtp_tmle.Rd +++ b/man/lmtp_tmle.Rd @@ -20,6 +20,8 @@ lmtp_tmle( bounds = NULL, learners_outcome = c("mean", "glm"), learners_trt = c("mean", "glm"), + riesz = FALSE, + module = sequential_module(), folds = 10, weights = NULL, control = lmtp_control() @@ -71,7 +73,8 @@ all time points.} \item{mtp}{[\code{logical(1)}]\cr Is the intervention of interest a modified treatment policy? -Default is \code{FALSE}. If treatment variables are continuous this should be \code{TRUE}.} +Default is \code{FALSE}. If treatment variables are continuous this should be \code{TRUE}. +Ignored if \code{riesz = TRUE}.} \item{outcome_type}{[\code{character(1)}]\cr Outcome variable type (i.e., continuous, binomial, survival).} @@ -89,7 +92,14 @@ of the outcome regression. Default is \code{c("mean", "glm")}.} \item{learners_trt}{[\code{character}]\cr A vector of \code{mlr3superlearner} algorithms for estimation of the outcome regression. Default is \code{c("mean", "glm")}. -\bold{Only include candidate learners capable of binary classification}.} +\bold{Only include candidate learners capable of binary classification}. +Ignored if \code{riesz = TRUE}.} + +\item{riesz}{[\code{logical(1)}]\cr +Use Riesz representers to learn the density ratios? Default is \code{FALSE}.} + +\item{module}{[\code{function}]\cr +A function that returns a neural network module. Only used if \code{riesz = TRUE}.} \item{folds}{[\code{integer(1)}]\cr The number of folds to be used for cross-fitting.} @@ -112,7 +122,7 @@ A list of class \code{lmtp} containing the following components: \item{shift}{The shift function specifying the treatment policy of interest.} \item{outcome_reg}{An n x Tau + 1 matrix of outcome regression predictions. The mean of the first column is used for calculating theta.} -\item{density_ratios}{An n x Tau matrix of the estimated, non-cumulative, density ratios.} +\item{density_ratios}{If \code{riesz = FALSE}, an n x Tau matrix of the estimated, non-cumulative, density ratios. If \code{riesz = TRUE}, an n x Tau matrix of the estimated, cumulative, density ratios.} \item{fits_m}{A list the same length as \code{folds}, containing the fits at each time-point for each fold for the outcome regression.} \item{fits_r}{A list the same length as \code{folds}, containing the fits at each time-point diff --git a/man/sequential_module.Rd b/man/sequential_module.Rd new file mode 100644 index 0000000..68801b9 --- /dev/null +++ b/man/sequential_module.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/sequential_module.R +\name{sequential_module} +\alias{sequential_module} +\title{Sequential neural network module function factory} +\usage{ +sequential_module(layers = 1, hidden = 20, dropout = 0.1) +} +\arguments{ +\item{layers}{[numeric(1)]\cr Number of hidden layers.} + +\item{hidden}{[numeric(1)]\cr Number of hidden units.} + +\item{dropout}{[numeric(1)]\cr Dropout rate.} +} +\value{ +A function that returns a sequential neural network module. +} +\description{ +Sequential neural network module function factory +} +\examples{ +if (torch::torch_is_installed()) sequential_module() +} From fee435f20e5532a4447671e6107cce66f8b2b8bb Mon Sep 17 00:00:00 2001 From: Nicholas Williams Date: Wed, 18 Sep 2024 12:32:27 -0700 Subject: [PATCH 2/2] Bug fixes --- R/estimators.R | 4 +++- R/riesz.R | 30 ++++++++++++++++++------------ R/theta.R | 10 ++++++++-- 3 files changed, 29 insertions(+), 15 deletions(-) diff --git a/R/estimators.R b/R/estimators.R index 5cd03fa..d7050d7 100644 --- a/R/estimators.R +++ b/R/estimators.R @@ -157,7 +157,7 @@ lmtp_tmle <- function(data, trt, outcome, baseline = NULL, time_vary = NULL, if (isFALSE(riesz)) { ratios <- cf_r(task, learners_trt, mtp, control, pb) } else { - ratios <- cf_riesz(task, module, mtp, control, pb) + ratios <- cf_riesz(task, module, control, pb) } estims <- cf_tmle(task, @@ -174,6 +174,7 @@ lmtp_tmle <- function(data, trt, outcome, baseline = NULL, time_vary = NULL, m = list(natural = estims$natural, shifted = estims$shifted), r = ratios$ratios, tau = task$tau, + riesz = riesz, folds = task$folds, id = task$id, outcome_type = task$outcome_type, @@ -352,6 +353,7 @@ lmtp_sdr <- function(data, trt, outcome, baseline = NULL, time_vary = NULL, m = list(natural = estims$natural, shifted = estims$shifted), r = ratios$ratios, tau = task$tau, + riesz = FALSE, folds = task$folds, id = task$id, outcome_type = task$outcome_type, diff --git a/R/riesz.R b/R/riesz.R index fddbd12..b2cfdaf 100644 --- a/R/riesz.R +++ b/R/riesz.R @@ -1,4 +1,4 @@ -cf_riesz <- function(task, module, mtp, control, pb) { +cf_riesz <- function(task, module, control, pb) { out <- vector("list", length = length(task$folds)) for (fold in seq_along(task$folds)) { out[[fold]] <- future::future({ @@ -10,7 +10,6 @@ cf_riesz <- function(task, module, mtp, control, pb) { task$tau, task$node_list$trt, module, - mtp, control, pb) }, @@ -29,10 +28,9 @@ estimate_riesz <- function(natural, tau, node_list, module, - mtp, control, pb) { - weights <- rep(1, nrow(natural$train)) + weights <- matrix(0, nrow(natural$train), ncol = tau) riesz_valid <- matrix(data = 0, nrow = nrow(natural$valid), ncol = tau) fits <- vector("list", length = tau) @@ -49,20 +47,28 @@ estimate_riesz <- function(natural, trt_t <- trt[[1]] } - frv <- followed_rule(natural$valid[, trt_t], shifted$valid[, trt_t], mtp) - vars <- c(node_list[[t]], cens[[t]]) - new_shifted_train <- natural$train - new_shifted_train[, trt_t] <- shifted$train[, trt_t] + shifted_train <- natural$train + shifted_train[, trt_t] <- shifted$train[, trt_t] + + if (!is.null(cens)) { + shifted_train[[cens[t]]] <- shifted$train[[cens[t]]] + } + + if ((t - 1) == 0) { + wts <- rep(1, nrow(natural$train)) + } else { + wts <- weights[jrt & drt, t - 1] + } model <- nn_riesz( train = list(data = natural$train[jrt & drt, vars, drop = FALSE], - data_1 = new_shifted_train[jrt & drt, vars, drop = FALSE]), + data_1 = shifted_train[jrt & drt, vars, drop = FALSE]), vars = vars, module = module, .f = \(alpha, dl) alpha(dl[["data_1"]]), - weights = weights, + weights = wts, batch_size = control$.batch_size, learning_rate = control$.learning_rate, epochs = control$.epochs, @@ -76,10 +82,10 @@ estimate_riesz <- function(natural, fits[[t]] <- NULL } - weights <- as.numeric( + weights[jrt & drt, t] <- as.numeric( model( as_torch( - one_hot_encode(natural$train[jrv & drv, vars, drop = FALSE]), + one_hot_encode(natural$train[jrt & drt, vars, drop = FALSE]), device = control$.device ) ) diff --git a/R/theta.R b/R/theta.R index 3cd0bad..ce1c1fa 100644 --- a/R/theta.R +++ b/R/theta.R @@ -58,16 +58,22 @@ theta_ipw <- function(eta) { out } -eif <- function(r, tau, shifted, natural) { +eif <- function(r, tau, cumprod, shifted, natural) { natural[is.na(natural)] <- 0 shifted[is.na(shifted)] <- 0 m <- shifted[, 2:(tau + 1), drop = FALSE] - natural[, 1:tau, drop = FALSE] - rowSums(compute_weights(r, 1, tau) * m, na.rm = TRUE) + shifted[, 1] + if (cumprod) { + out <- rowSums(compute_weights(r, 1, tau) * m, na.rm = TRUE) + shifted[, 1] + } else { + out <- rowSums(r * m, na.rm = TRUE) + shifted[, 1] + } + out } theta_dr <- function(eta, augmented = FALSE) { inflnce <- eif(r = eta$r, tau = eta$tau, + cumprod = !eta$riesz, shifted = eta$m$shifted, natural = eta$m$natural) theta <- {