diff --git a/src/med_bench/mediation.py b/src/med_bench/mediation.py index e65baf0..5848f2e 100644 --- a/src/med_bench/mediation.py +++ b/src/med_bench/mediation.py @@ -957,7 +957,7 @@ def r_mediation_DML(y, t, m, x, trim=0.05, order=1): Polynomials/interactions are created using the Generate. Powers command of the LARF package. """ - x_r, t_r, m_r, y_r = [base.as_matrix(_convert_array_to_R(uu)) for uu in + x_r, t_r, m_r, y_r = [_convert_array_to_R(uu) for uu in (x, t, m, y)] res = causalweight.medDML(y_r, t_r, m_r, x_r, trim=trim, order=order) raw_res_R = np.array(res.rx2('results')) diff --git a/src/med_bench/utils/utils.py b/src/med_bench/utils/utils.py index 05f1681..0e3a19f 100644 --- a/src/med_bench/utils/utils.py +++ b/src/med_bench/utils/utils.py @@ -2,7 +2,11 @@ import rpy2.robjects as robjects import rpy2.robjects.packages as rpackages -from rpy2.robjects import pandas2ri, numpy2ri +import rpy2.rinterface as ri + +ri.initr() +matrix = ri.baseenv["matrix"] + def get_interactions(interaction, *args): """ @@ -59,5 +63,5 @@ def _convert_array_to_R(x): if len(x.shape) == 1: return robjects.FloatVector(x) elif len(x.shape) == 2: - return robjects.r.matrix(robjects.FloatVector(x.ravel()), - nrow=x.shape[0], byrow='TRUE') + return matrix(robjects.FloatVector(x.ravel()), + nrow=x.shape[0], byrow='TRUE')