Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Riesz Representer implementation, no conditional effects #147

Open
wants to merge 3 commits into
base: mlr3superlearner
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: lmtp
Title: Non-Parametric Causal Effects of Feasible Interventions Based on Modified Treatment Policies
Version: 1.4.1.9001
Version: 1.5.1.9001
Authors@R:
c(person(given = "Nicholas",
family = "Williams",
Expand All @@ -11,7 +11,11 @@ Authors@R:
family = "Díaz",
email = "[email protected]",
role = c("aut", "cph"),
comment = c(ORCID = "0000-0001-9056-2047")))
comment = c(ORCID = "0000-0001-9056-2047")),
person(given = "Herb",
family = "Susmann",
email = "[email protected]",
role = "ctb"))
Description: Non-parametric estimators for causal effects based on longitudinal modified treatment
policies as described in Diaz, Williams, Hoffman, and Schenck <doi:10.1080/01621459.2021.1955691>, traditional point treatment,
and traditional longitudinal effects. Continuous, binary, categorical treatments, and multivariate treatments are allowed as well are
Expand All @@ -35,8 +39,10 @@ Imports:
progressr,
data.table (>= 1.13.0),
checkmate (>= 2.1.0),
isotone,
mlr3superlearner
isotone,
mlr3superlearner,
torch,
coro
URL: https://beyondtheate.com/, https://github.com/nt-williams/lmtp
BugReports: https://github.com/nt-williams/lmtp/issues
Suggests:
Expand Down
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ export(lmtp_sdr)
export(lmtp_sub)
export(lmtp_survival)
export(lmtp_tmle)
export(sequential_module)
export(static_binary_off)
export(static_binary_on)
export(tidy)
importFrom(R6,R6Class)
importFrom(checkmate,`%??%`)
importFrom(data.table,.SD)
importFrom(data.table,`:=`)
importFrom(data.table,as.data.table)
Expand All @@ -28,6 +30,7 @@ importFrom(stats,binomial)
importFrom(stats,coef)
importFrom(stats,gaussian)
importFrom(stats,glm)
importFrom(stats,model.matrix)
importFrom(stats,na.omit)
importFrom(stats,plogis)
importFrom(stats,pnorm)
Expand All @@ -37,5 +40,6 @@ importFrom(stats,qnorm)
importFrom(stats,quantile)
importFrom(stats,runif)
importFrom(stats,sd)
importFrom(stats,setNames)
importFrom(stats,var)
importFrom(stats,weighted.mean)
4 changes: 2 additions & 2 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ check_ref_class <- function(x) {

assertRefClass <- checkmate::makeAssertionFunction(check_ref_class)

check_trt_type <- function(data, trt, mtp) {
check_trt_type <- function(data, trt, riesz, mtp) {
if (riesz) return(invisible())
is_decimal <- vector("logical", length(trt))
for (i in seq_along(trt)) {
a <- data[[trt[i]]]
Expand Down Expand Up @@ -200,4 +201,3 @@ check_same_weights <- function(weights) {
}

assertSameWeights <- checkmate::makeAssertionFunction(check_same_weights)

26 changes: 21 additions & 5 deletions R/estimators.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#' @param mtp \[\code{logical(1)}\]\cr
#' Is the intervention of interest a modified treatment policy?
#' Default is \code{FALSE}. If treatment variables are continuous this should be \code{TRUE}.
#' Ignored if \code{riesz = TRUE}.
#' @param outcome_type \[\code{character(1)}\]\cr
#' Outcome variable type (i.e., continuous, binomial, survival).
#' @param id \[\code{character(1)}\]\cr
Expand All @@ -54,6 +55,11 @@
#' @param learners_trt \[\code{character}\]\cr A vector of \code{mlr3superlearner} algorithms for estimation
#' of the density ratios. Default is \code{c("mean", "glm")}.
#' \bold{Only include candidate learners capable of binary classification}.
#' Ignored if \code{riesz = TRUE}.
#' @param riesz \[\code{logical(1)}\]\cr
#' Use Riesz representers to learn the density ratios? Default is \code{FALSE}.
#' @param module \[\code{function}\]\cr
#' A function that returns a neural network module. Only used if \code{riesz = TRUE}.
#' @param folds \[\code{integer(1)}\]\cr
#' The number of folds to be used for cross-fitting.
#' @param weights \[\code{numeric(nrow(data))}\]\cr
Expand Down Expand Up @@ -81,7 +87,7 @@
#' \item{shift}{The shift function specifying the treatment policy of interest.}
#' \item{outcome_reg}{An n x Tau + 1 matrix of outcome regression predictions.
#' The mean of the first column is used for calculating theta.}
#' \item{density_ratios}{An n x Tau matrix of the estimated, non-cumulative, density ratios.}
#' \item{density_ratios}{If \code{riesz = FALSE}, an n x Tau matrix of the estimated, non-cumulative, density ratios. If \code{riesz = TRUE}, an n x Tau matrix of the estimated, cumulative, density ratios.}
#' \item{fits_m}{A list the same length as \code{folds}, containing the fits at each time-point
#' for each fold for the outcome regression.}
#' \item{fits_r}{A list the same length as \code{folds}, containing the fits at each time-point
Expand All @@ -96,6 +102,8 @@ lmtp_tmle <- function(data, trt, outcome, baseline = NULL, time_vary = NULL,
id = NULL, bounds = NULL,
learners_outcome = c("mean", "glm"),
learners_trt = c("mean", "glm"),
riesz = FALSE,
module = sequential_module(),
folds = 10, weights = NULL,
control = lmtp_control()) {
assertNotDataTable(data)
Expand Down Expand Up @@ -124,7 +132,7 @@ lmtp_tmle <- function(data, trt, outcome, baseline = NULL, time_vary = NULL,
checkmate::assertNumber(control$.bound)
checkmate::assertNumber(control$.trim, upper = 1)
checkmate::assertLogical(control$.return_full_fits, len = 1)
check_trt_type(data, unlist(trt), mtp)
check_trt_type(data, unlist(trt), riesz, mtp)

task <- lmtp_task$new(
data = data,
Expand All @@ -146,10 +154,16 @@ lmtp_tmle <- function(data, trt, outcome, baseline = NULL, time_vary = NULL,

pb <- progressr::progressor(task$tau*folds*2)

ratios <- cf_r(task, learners_trt, mtp, control, pb)
if (isFALSE(riesz)) {
ratios <- cf_r(task, learners_trt, mtp, control, pb)
} else {
ratios <- cf_riesz(task, module, control, pb)
}

estims <- cf_tmle(task,
"tmp_lmtp_scaled_outcome",
ratios$ratios,
riesz,
learners_outcome,
control,
pb)
Expand All @@ -160,6 +174,7 @@ lmtp_tmle <- function(data, trt, outcome, baseline = NULL, time_vary = NULL,
m = list(natural = estims$natural, shifted = estims$shifted),
r = ratios$ratios,
tau = task$tau,
riesz = riesz,
folds = task$folds,
id = task$id,
outcome_type = task$outcome_type,
Expand Down Expand Up @@ -302,7 +317,7 @@ lmtp_sdr <- function(data, trt, outcome, baseline = NULL, time_vary = NULL,
checkmate::assertNumber(control$.bound)
checkmate::assertNumber(control$.trim, upper = 1)
checkmate::assertLogical(control$.return_full_fits, len = 1)
check_trt_type(data, unlist(trt), mtp)
check_trt_type(data, unlist(trt), FALSE, mtp)

task <- lmtp_task$new(
data = data,
Expand Down Expand Up @@ -338,6 +353,7 @@ lmtp_sdr <- function(data, trt, outcome, baseline = NULL, time_vary = NULL,
m = list(natural = estims$natural, shifted = estims$shifted),
r = ratios$ratios,
tau = task$tau,
riesz = FALSE,
folds = task$folds,
id = task$id,
outcome_type = task$outcome_type,
Expand Down Expand Up @@ -609,7 +625,7 @@ lmtp_ipw <- function(data, trt, outcome, baseline = NULL, time_vary = NULL, cens
checkmate::assertNumber(control$.bound)
checkmate::assertNumber(control$.trim, upper = 1)
checkmate::assertLogical(control$.return_full_fits, len = 1)
check_trt_type(data, unlist(trt), mtp)
check_trt_type(data, unlist(trt), FALSE, mtp)

task <- lmtp_task$new(
data = data,
Expand Down
2 changes: 1 addition & 1 deletion R/lmtp-package.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#' @importFrom stats runif as.formula coef glm plogis predict qlogis qnorm pnorm sd quantile var binomial gaussian na.omit weighted.mean
#' @importFrom stats runif as.formula coef glm plogis predict qlogis qnorm pnorm sd quantile var binomial gaussian na.omit weighted.mean model.matrix setNames
#' @keywords internal
"_PACKAGE"

Expand Down
24 changes: 20 additions & 4 deletions R/lmtp_control.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@
#' Use discrete or ensemble super learner?
#' @param .info \[\code{logical(1)}\]\cr
#' Print super learner fitting info to the console?
#' @param .epochs \[\code{integer(1)}\]\cr
#' The number of epochs to train the neural network.
#' @param .learning_rate \[\code{numeric(1)}\]\cr
#' The learning rate for the neural network.
#' @param .batch_size \[\code{integer(1)}\]\cr
#' The batch size for the neural network.
#' @param .device \[\code{character(1)}\]\cr
#' The device to train the neural network on. Default is \code{"cpu"}.
#'
#' @return A list of parameters controlling the estimation procedure.
#' @export
Expand All @@ -25,16 +33,24 @@
#' lmtp_control(.trim = 0.975)
lmtp_control <- function(.bound = 1e5,
.trim = 0.999,
.learners_outcome_folds = 10,
.learners_trt_folds = 10,
.learners_outcome_folds = NULL,
.learners_trt_folds = NULL,
.return_full_fits = FALSE,
.discrete = TRUE,
.info = FALSE) {
.info = FALSE,
.epochs = 100L,
.learning_rate = 0.01,
.batch_size = 64,
.device = c("cpu", "cuda", "mps")) {
list(.bound = .bound,
.trim = .trim,
.learners_outcome_folds = .learners_outcome_folds,
.learners_trt_folds = .learners_trt_folds,
.return_full_fits = .return_full_fits,
.discrete = .discrete,
.info = .info)
.info = .info,
.epochs = .epochs,
.learning_rate = .learning_rate,
.batch_size = .batch_size,
.device = match.arg(.device))
}
23 changes: 23 additions & 0 deletions R/make_dataset.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Construct a torch dataset from a named list of data frames.
#
# @param data A named list of data frames (e.g., list(data = <natural>,
#   data_1 = <shifted>), as built in `estimate_riesz()`); each non-empty
#   element is one-hot encoded and stored on the dataset object as a torch
#   tensor under the same name.
# @param x A character vector of column names to keep from each data frame.
# @param device The torch device tensors are created on (e.g., "cpu").
#
# @return An instantiated "tmp_lmtp_dataset" torch dataset.
make_dataset <- function(data, x, device) {
  # Dummy binding so R CMD check does not flag `self`, which torch::dataset
  # injects into the generator's methods at run time.
  self <- NULL
  dataset <- torch::dataset(
    name = "tmp_lmtp_dataset",
    initialize = function(data, x, device) {
      for (df in names(data)) {
        # Skip zero-column elements (nothing to encode).
        if (ncol(data[[df]]) > 0) {
          df_x <- data[[df]][, x, drop = FALSE]
          # One-hot encode, then move the resulting matrix onto `device`.
          self[[df]] <- one_hot_encode(df_x) |>
            as_torch(device = device)
        }
      }
    },
    .getitem = function(i) {
      # Tensor fields created in `initialize` have names containing "data"
      # ("data", "data_1", ...); return row i of each as a named list.
      # NOTE(review): assumes no other field on `self` matches "data".
      fields <- grep("data", names(self), value = TRUE)
      setNames(lapply(fields, function(x) self[[x]][i, ]), fields)
    },
    .length = function() {
      # Number of observations = first dimension of the "data" tensor.
      self$data$size()[1]
    }
  )
  # Return an instance (not the generator) built from the supplied arguments.
  dataset(data, x, device)
}
41 changes: 41 additions & 0 deletions R/nn_riesz.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Fit a neural-network Riesz representer by minimizing the empirical Riesz
# loss, E[alpha(X)^2 - 2 * w * .f(alpha, batch)], with Adam and a one-cycle
# learning-rate schedule.
#
# @param train Named list of data frames passed to `make_dataset()`
#   (e.g., list(data = natural, data_1 = shifted)).
# @param vars Character vector of columns used as network inputs.
# @param module A function that, given the input width, returns a torch module.
# @param .f Function of (model, batch) giving the linear term of the Riesz
#   loss (e.g., the model evaluated on the shifted data).
# @param weights Optional multiplier on the linear term; NULL defaults to 1
#   via `%??%`.
# @param batch_size,learning_rate,epochs Optimization settings.
# @param device Torch device to train on ("cpu", "cuda", or "mps").
#
# @return The trained torch module, switched to evaluation mode.
#' @importFrom checkmate `%??%`
nn_riesz <- function(train,
                     vars,
                     module,
                     .f,
                     weights = NULL,
                     batch_size,
                     learning_rate,
                     epochs,
                     device) {
  dataset <- make_dataset(train, vars, device = device)
  train_dl <- torch::dataloader(dataset, batch_size = batch_size)
  # Input width = number of columns after one-hot encoding.
  model <- module(ncol(dataset$data))
  model$to(device = device)

  # NULL-coalesce: unweighted loss by default.
  weights <- weights %??% 1

  optimizer <- torch::optim_adam(
    params = c(model$parameters),
    lr = learning_rate,
    weight_decay = 0.01
  )

  # One scheduler step per epoch, so total_steps = epochs.
  scheduler <- torch::lr_one_cycle(optimizer, max_lr = learning_rate, total_steps = epochs)

  for (epoch in 1:epochs) {
    coro::loop(for (b in train_dl) {
      # Riesz (regression) loss: squared representer minus twice the
      # (weighted) linear functional, averaged over the batch.
      loss <- (model(b$data)$pow(2) - (2 * weights * .f(model, b)))$mean(dtype = torch::torch_float())

      optimizer$zero_grad()
      loss$backward()

      optimizer$step()
    })
    scheduler$step()
  }

  # Switch to eval mode before the caller uses the model for prediction.
  model$eval()
  model
}
107 changes: 107 additions & 0 deletions R/riesz.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Cross-fitted Riesz representer estimation: run `estimate_riesz()` on every
# fold asynchronously and recombine the per-fold density ratios.
#
# @param task An `lmtp_task` holding the data, folds, and node structure.
# @param module Function returning a torch module, forwarded to `nn_riesz()`.
# @param control `lmtp_control()` list of tuning parameters.
# @param pb Progress-bar callback.
#
# @return The result of `recombine_ratios()`: fold-recombined ratios and fits.
cf_riesz <- function(task, module, control, pb) {
  # Launch one asynchronous future per cross-fitting fold.
  futures <- lapply(seq_along(task$folds), function(k) {
    future::future({
      estimate_riesz(get_folded_data(task$natural, task$folds, k),
                     get_folded_data(task$shifted, task$folds, k),
                     task$trt,
                     task$cens,
                     task$risk,
                     task$tau,
                     task$node_list$trt,
                     module,
                     control,
                     pb)
    },
    seed = TRUE)
  })

  # Block until every fold finishes, then stitch the results back together.
  recombine_ratios(future::value(futures), task$folds)
}

# Estimate cumulative density-ratio (Riesz representer) weights for a single
# cross-fitting fold.
#
# For each time point t = 1, ..., tau a neural-network Riesz representer is
# trained on the fold's training split (rows still uncensored and at risk at
# t), using the representer values from t - 1 as loss weights so the estimates
# accumulate over time. The fitted representer is evaluated on both splits:
# training-split values seed the next time point; validation-split values are
# returned as the fold's (cumulative) density ratios.
#
# @param natural List with `train` and `valid` data frames of observed data.
# @param shifted List with `train` and `valid` data frames under the shift.
# @param trt List of treatment variable names (length tau, or length 1 reused
#   across all time points).
# @param cens Character vector of censoring indicators, or NULL when there is
#   no censoring.
# @param risk Risk-set specification passed to `at_risk()`.
# @param tau Number of time points.
# @param node_list List of covariate-history variable names per time point.
# @param module Function returning a torch module, forwarded to `nn_riesz()`.
# @param control `lmtp_control()` list (.epochs, .batch_size, .device, ...).
# @param pb Progress-bar callback; invoked once per time point.
#
# @return A list with `ratios` (an n_valid x tau matrix of cumulative density
#   ratios) and `fits` (fitted models when `control$.return_full_fits`).
estimate_riesz <- function(natural,
                           shifted,
                           trt,
                           cens,
                           risk,
                           tau,
                           node_list,
                           module,
                           control,
                           pb) {
  # Cumulative representer values on the training split; column t - 1 supplies
  # the loss weights at time t.
  weights <- matrix(0, nrow = nrow(natural$train), ncol = tau)
  riesz_valid <- matrix(0, nrow = nrow(natural$valid), ncol = tau)
  fits <- vector("list", length = tau)

  for (t in 1:tau) {
    # Uncensored (j) and at-risk (d) indicators for each split at time t.
    jrt <- censored(natural$train, cens, t)$j
    drt <- at_risk(natural$train, risk, t)
    cens_valid <- censored(natural$valid, cens, t) # computed once, not twice
    jrv <- cens_valid$j
    drv <- at_risk(natural$valid, risk, t)

    # Treatment column(s) for this time point; a length-1 `trt` is shared
    # across all time points.
    if (length(trt) > 1) {
      trt_t <- trt[[t]]
    } else {
      trt_t <- trt[[1]]
    }

    # Model inputs: covariate history plus the censoring indicator at t.
    # `cens[t]` (not `cens[[t]]`) so a NULL `cens` contributes nothing instead
    # of erroring with "subscript out of bounds".
    vars <- c(node_list[[t]], cens[t])

    # Training data under the intervention: swap in the shifted treatment
    # values (and, when censoring is modeled, the shifted censoring indicator).
    shifted_train <- natural$train
    shifted_train[, trt_t] <- shifted$train[, trt_t]

    if (!is.null(cens)) {
      shifted_train[[cens[t]]] <- shifted$train[[cens[t]]]
    }

    # Loss weights: 1 at the first time point, otherwise the cumulative
    # representer values from t - 1 for the rows still in the sample.
    if (t == 1) {
      wts <- rep(1, nrow(natural$train))
    } else {
      wts <- weights[jrt & drt, t - 1]
    }

    model <- nn_riesz(
      train = list(data = natural$train[jrt & drt, vars, drop = FALSE],
                   data_1 = shifted_train[jrt & drt, vars, drop = FALSE]),
      vars = vars,
      module = module,
      .f = \(alpha, dl) alpha(dl[["data_1"]]),
      weights = wts,
      batch_size = control$.batch_size,
      learning_rate = control$.learning_rate,
      epochs = control$.epochs,
      device = control$.device
    )

    # Keep the fitted torch model only when requested. Note: assigning NULL
    # into a list element *deletes* it and shrinks the list, so guard instead
    # of writing `fits[[t]] <- NULL`.
    if (control$.return_full_fits) {
      fits[[t]] <- model
    }

    # Evaluate the trained representer on the training split (seeds t + 1)...
    weights[jrt & drt, t] <- as.numeric(
      model(
        as_torch(
          one_hot_encode(natural$train[jrt & drt, vars, drop = FALSE]),
          device = control$.device
        )
      )
    )

    # ...and on the validation split (the returned density ratios).
    riesz_valid[jrv & drv, t] <- as.numeric(
      model(
        as_torch(
          one_hot_encode(natural$valid[jrv & drv, vars, drop = FALSE]),
          device = control$.device
        )
      )
    )

    pb()
  }

  list(ratios = riesz_valid, fits = fits)
}
Loading