Skip to content

Commit

Permalink
Reformat codes
Browse files Browse the repository at this point in the history
  • Loading branch information
gaow committed Mar 17, 2024
1 parent 308b496 commit 134d889
Show file tree
Hide file tree
Showing 24 changed files with 3,157 additions and 2,561 deletions.
306 changes: 173 additions & 133 deletions R/bayesian_multivariate_regression.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,192 +2,232 @@
#
#' @importFrom R6 R6Class
#' @importFrom stats optim
#'
BayesianMultivariateRegression <- R6Class("BayesianMultivariateRegression",
  inherit = BayesianSimpleRegression,

  public = list(
    # Create the regression object.
    #   J: number of effects; the posterior mean is initialised to a
    #     J by nrow(prior_variance) matrix of zeros.
    #   prior_variance: prior covariance matrix. It is multiplied by
    #     private$prior_variance_scalar (initialised to 1) when fitting.
    # Fields such as private$J and private$.posterior_b1 are not declared
    # in this class; presumably they come from BayesianSimpleRegression.
    # This returns the R6 object invisibly.
    initialize = function(J, prior_variance) {
      private$J <- J
      private$.prior_variance <- prior_variance
      private$.posterior_b1 <- matrix(0, J, nrow(prior_variance))
      private$prior_variance_scalar <- 1
      return(invisible(self))
    },

    # Compute posterior moments and log Bayes factors for every effect.
    #   d: data object. It must provide get_coef(), svs and svs_inv, and,
    #     when svs is numeric, n_effect, residual_variance and X2_sum.
    #   prior_weights: forwarded to private$estimate_prior_variance.
    #   use_residual: fit with residual instead of with Y.
    #     A special feature for when used with SuSiE algorithm.
    #   save_summary_stats: if TRUE, keep bhat and the square roots of
    #     the diagonals of sbhat2 in private$.bhat / private$.sbhat.
    #   save_var: if TRUE, keep the posterior covariances.
    #   estimate_prior_variance_method: NULL to leave the prior scalar
    #     untouched; "EM" to cache bhat and sbhat2 for a later EM update;
    #     any other value is passed on as `method` to
    #     private$estimate_prior_variance.
    #   check_null_threshold: forwarded to private$estimate_prior_variance.
    # This returns the R6 object invisibly.
    fit = function(d, prior_weights = NULL, use_residual = FALSE,
                   save_summary_stats = FALSE, save_var = FALSE,
                   estimate_prior_variance_method = NULL,
                   check_null_threshold = 0) {
      # bhat is J by R
      bhat <- d$get_coef(use_residual)
      if (is.numeric(d$svs)) {
        # X2_sum is a length J vector
        sbhat2 <- lapply(
          seq_len(d$n_effect),
          function(j) d$residual_variance / d$X2_sum[j]
        )
      } else {
        sbhat2 <- d$svs
      }

      # Replace undefined or infinite variances by a large finite value.
      for (j in seq_along(sbhat2)) {
        bad <- is.nan(sbhat2[[j]]) | is.infinite(sbhat2[[j]])
        sbhat2[[j]][bad] <- 1e6
      }
      if (save_summary_stats) {
        private$.bhat <- bhat
        private$.sbhat <- sqrt(do.call(rbind, lapply(
          seq_along(sbhat2),
          function(j) diag(sbhat2[[j]])
        )))
      }

      # Deal with prior variance: can be "estimated" across effects.
      if (!is.null(estimate_prior_variance_method)) {
        if (estimate_prior_variance_method == "EM") {
          # The EM update needs the PIPs, which are not available here;
          # only cache what estimate_prior_variance_em_inv_safe reads.
          private$cache <- list(b = bhat, s = sbhat2)
        } else {
          private$prior_variance_scalar <-
            private$estimate_prior_variance(bhat, sbhat2, prior_weights,
              method = estimate_prior_variance_method,
              check_null_threshold = check_null_threshold
            )
        }
      }

      # Posterior calculations.
      post <- multivariate_regression(
        bhat, sbhat2,
        private$.prior_variance * private$prior_variance_scalar,
        d$svs_inv
      )
      private$.posterior_b1 <- post$b1
      private$.posterior_b2 <- post$b2
      if (save_var) {
        private$.posterior_variance <- post$cov
      }
      private$.lbf <- post$lbf

      return(invisible(self))
    }
  ),

  private = list(
    .prior_variance = NULL,
    .prior_variance_inv = NULL,

    # Log of the prior-weighted sum of Bayes factors when the prior
    # covariance is scaled by `scalar` (the log_sum element returned by
    # compute_softmax).
    loglik = function(scalar, bhat, S, prior_weights) {
      U <- private$.prior_variance * scalar
      lbf <- multivariate_lbf(bhat, S, U)
      return(compute_softmax(lbf, prior_weights)$log_sum)
    },

    # Estimate the prior scalar by minimising
    # private$neg_loglik_logscale with optim(), starting from 0 on the
    # log scale; the optimum is exponentiated before it is returned.
    # Extra arguments in ... are forwarded to optim().
    estimate_prior_variance_optim = function(betahat, shat2,
                                             prior_weights, ...) {
      exp(optim(
        par = 0, fn = private$neg_loglik_logscale, betahat = betahat,
        shat2 = shat2, prior_weights = prior_weights, ...
      )$par)
    },

    estimate_prior_variance_em_direct_inv =
      function(pip, inv_function = pseudo_inverse) {
        # Update directly using inverse of prior matrix
        # This is very similar to updating the univariate case via EM,
        # \sigma_0^2 = \mathrm{tr}(S_0^{-1} E[bb^T])/r
        # where S_0 is prior variance, E[bb^T] is 2nd moment of SER effect:
        # that is, E[bb^T] = \sum_j alpha_j * b_jb_j^T where b_j is
        # posterior of j. Recall in univariate case it is \sigma_0^2 =
        # E[bb^T] directly.
        # The inverse is computed once and cached; inv_function must
        # return a list with elements `inv` and `rank`.
        if (is.null(private$.prior_variance_inv)) {
          private$.prior_variance_inv <- inv_function(private$.prior_variance)
        }
        if (length(dim(private$.posterior_b2)) == 3) {
          # When R > 1.
          mu2 <- Reduce("+", lapply(
            seq_along(pip),
            function(j) pip[j] * private$.posterior_b2[, , j]
          ))
        } else {
          # When R = 1 each post_b2 is a scalar. Now make it a matrix
          # to be compatible with later computations.
          if (ncol(private$.posterior_b2) != 1) {
            stop("Data dimension is incorrect for posterior_b2")
          }
          mu2 <- matrix(sum(pip * private$.posterior_b2[, 1]), 1, 1)
        }
        V <- sum(diag(private$.prior_variance_inv$inv %*% mu2)) /
          private$.prior_variance_inv$rank
        return(V)
      },

    estimate_prior_variance_em_inv_safe = function(pip) {
      # Instead of computing S_0^{-1} and E[bb^T] we compute them as
      # one quantity to avoid explicit inverse. We need S_inv a J
      # vector of R by R matrices (private$cache$s), bhat a J by R
      # vector (private$cache$b), the original prior matrix S_0
      # (private$.prior_variance) and the scalar from previous update
      # (private$prior_variance_scalar).
      # U = \sigma_0 S_0
      U <- private$prior_variance_scalar * private$.prior_variance
      effects <- seq_len(private$J)

      # The inverse has to be extracted per effect: applying `$inv` to
      # the unnamed list returned by lapply() yields NULL.
      # NOTE(review): invert_via_chol is assumed to return a list with an
      # `inv` element (as pseudo_inverse does); a bare matrix is also
      # accepted. Confirm against its definition.
      S_inv <- lapply(effects, function(j) {
        res <- invert_via_chol(private$cache$s[[j]])
        if (is.list(res)) res$inv else res
      })

      # posterior covariance pre-multiplied by U^{-1}
      post_cov_U <- lapply(
        effects,
        function(j) solve(diag(nrow(U)) + S_inv[[j]] %*% U)
      )

      # posterior first moment pre-multiplied by U^{-1}
      post_b1_U <- lapply(
        effects,
        function(j) {
          post_cov_U[[j]] %*%
            (S_inv[[j]] %*% private$cache$b[j, ])
        }
      )

      # Posterior 2nd moment pre-multiplied by S_0^{-1}.
      b2_U <- lapply(
        effects,
        function(j) {
          private$prior_variance_scalar *
            (tcrossprod(post_b1_U[[j]]) %*% U + post_cov_U[[j]])
        }
      )
      V <- sum(diag(Reduce("+", lapply(
        effects,
        function(j) pip[j] * b2_U[[j]]
      )))) / nrow(U)
      return(V)
    },

    # EM update of the prior scalar: try the direct (pseudo-)inverse
    # formula first and fall back to the inverse-free computation if it
    # raises an error.
    estimate_prior_variance_em = function(pip) {
      tryCatch(
        {
          private$estimate_prior_variance_em_direct_inv(pip,
            inv_function = pseudo_inverse
          )
        },
        error = function(e) {
          private$estimate_prior_variance_em_inv_safe(pip)
        }
      )
    },

    # "Simple" method: keep the prior scalar fixed at 1.
    estimate_prior_variance_simple = function() 1
  )
)

# Multivariate regression calculations.
#
#' @importFrom abind abind
# Posterior moments and log Bayes factors for J single-effect
# multivariate regressions.
#
# bhat:  J by R matrix of effect estimates (row j is effect j).
# S:     list of J matrices (R by R) holding the variances of bhat[j, ].
# U:     R by R prior covariance matrix.
# S_inv: list of the J inverses of S. If a numeric value is passed
#        instead, the inverses are computed here with invert_via_chol.
#
# Returns a list with
#   b1:  J by R matrix of posterior means;
#   b2:  posterior second moments, a J by 1 matrix when R = 1 and an
#        R by R by J array otherwise;
#   lbf: numeric vector of J log Bayes factors (NaN is replaced by 0);
#   cov: list of J posterior covariance matrices.
multivariate_regression <- function(bhat, S, U, S_inv) {
  effects <- seq_along(S)
  if (is.numeric(S_inv)) {
    # NOTE(review): other code in this file reads an `inv` element from
    # the result of invert_via_chol, so a list result is unwrapped here;
    # a bare matrix is used as is. Confirm against its definition.
    S_inv <- lapply(effects, function(j) {
      res <- invert_via_chol(S[[j]])
      if (is.list(res)) res$inv else res
    })
  }
  post_cov <- lapply(
    effects,
    function(j) U %*% solve(diag(nrow(U)) + S_inv[[j]] %*% U)
  )
  lbf <- vapply(
    effects,
    function(j) {
      # The quadratic form is a 1 by 1 matrix; reduce it to a scalar.
      quad <- t(bhat[j, ]) %*% S_inv[[j]] %*%
        post_cov[[j]] %*% S_inv[[j]] %*% bhat[j, ]
      (log(det(S[[j]])) - log(det(S[[j]] + U))) / 2 + as.numeric(quad) / 2
    },
    numeric(1)
  )
  lbf[is.nan(lbf)] <- 0

  # Using rbind here will end up with dimension issues for degenerated
  # case on J; have to use t(...(cbind, )) instead.
  post_b1 <- t(do.call(cbind, lapply(
    effects,
    function(j) post_cov[[j]] %*% (S_inv[[j]] %*% bhat[j, ])
  )))
  post_b2 <- lapply(
    seq_along(post_cov),
    function(j) tcrossprod(post_b1[j, ]) + post_cov[[j]]
  )

  # Deal with degenerate case with one condition.
  if (ncol(post_b1) == 1) {
    post_b2 <- matrix(unlist(post_b2), length(post_b2), 1)
  } else {
    post_b2 <- aperm(abind(post_b2, along = 3), c(2, 1, 3))
  }
  return(list(b1 = post_b1, b2 = post_b2, lbf = lbf, cov = post_cov))
}

#' @importFrom mvtnorm dmvnorm
# Log Bayes factor of each effect: the multivariate normal log density
# of bhat[j, ] with covariance S[[j]] + U minus the log density with
# covariance S[[j]].
#
# bhat: matrix of effect estimates, one row per effect.
# S:    list of variance matrices, one per row of bhat.
# U:    prior covariance matrix added to each S[[j]].
#
# Returns a numeric vector with one entry per element of S; NaN values
# are replaced by 0.
multivariate_lbf <- function(bhat, S, U) {
  lbf <- vapply(
    seq_along(S),
    function(j) {
      dmvnorm(x = bhat[j, ], sigma = S[[j]] + U, log = TRUE) -
        dmvnorm(x = bhat[j, ], sigma = S[[j]], log = TRUE)
    },
    numeric(1)
  )
  lbf[is.nan(lbf)] <- 0
  return(lbf)
}
Loading

0 comments on commit 134d889

Please sign in to comment.