Skip to content

Commit

Permalink
fix R conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
judithabk6 committed Dec 21, 2023
1 parent 2cd241c commit f4a87be
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/med_bench/mediation.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ def r_mediation_DML(y, t, m, x, trim=0.05, order=1):
Polynomials/interactions are created using the Generate.
Powers command of the LARF package.
"""
x_r, t_r, m_r, y_r = [base.as_matrix(_convert_array_to_R(uu)) for uu in
x_r, t_r, m_r, y_r = [_convert_array_to_R(uu) for uu in
(x, t, m, y)]
res = causalweight.medDML(y_r, t_r, m_r, x_r, trim=trim, order=order)
raw_res_R = np.array(res.rx2('results'))
Expand Down
10 changes: 7 additions & 3 deletions src/med_bench/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

import rpy2.robjects as robjects
import rpy2.robjects.packages as rpackages
from rpy2.robjects import pandas2ri, numpy2ri
import rpy2.rinterface as ri

ri.initr()
matrix = ri.baseenv["matrix"]


def get_interactions(interaction, *args):
"""
Expand Down Expand Up @@ -59,5 +63,5 @@ def _convert_array_to_R(x):
if len(x.shape) == 1:
return robjects.FloatVector(x)
elif len(x.shape) == 2:
return robjects.r.matrix(robjects.FloatVector(x.ravel()),
nrow=x.shape[0], byrow='TRUE')
return matrix(robjects.FloatVector(x.ravel()),
nrow=x.shape[0], byrow='TRUE')

0 comments on commit f4a87be

Please sign in to comment.