Skip to content

Commit

Permalink
Fully implement keys + use ndarrays
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzi committed Nov 28, 2023
1 parent 7f792dc commit dbdeb57
Showing 1 changed file with 88 additions and 17 deletions.
105 changes: 88 additions & 17 deletions python-spec/src/somacore/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,10 @@ def to_anndata(
*,
column_names: Optional[AxisColumnNames] = None,
X_layers: Sequence[str] = (),
obsm_keys: Sequence[str] = [],
obsm_keys: Sequence[str] = (),
obsp_keys: Sequence[str] = (),
varm_keys: Sequence[str] = (),
varp_keys: Sequence[str] = (),
) -> anndata.AnnData:
"""
Executes the query and return result as an ``AnnData`` in-memory object.
Expand All @@ -278,6 +281,9 @@ def to_anndata(
column_names=column_names or AxisColumnNames(obs=None, var=None),
X_layers=X_layers,
obsm_keys=obsm_keys,
obsp_keys=obsp_keys,
varm_keys=varm_keys,
varp_keys=varp_keys,
).to_anndata()

# Context management
Expand Down Expand Up @@ -319,9 +325,12 @@ def _read(
*,
column_names: AxisColumnNames,
X_layers: Sequence[str],
obsm_keys: Sequence[str] = [], # TODO: Add obsp_keys, varm_keys, varp_keys
obsm_keys: Sequence[str] = (),
obsp_keys: Sequence[str] = (),
varm_keys: Sequence[str] = (),
varp_keys: Sequence[str] = (),
) -> "_AxisQueryResult":
"""Reads the entire query result into in-memory Arrow tables.
"""Reads the entire query result in memory.
This is a low-level routine intended to be used by loaders for other
in-core formats, such as AnnData, which can be created from the
Expand All @@ -333,6 +342,14 @@ def _read(
to read.
X_layers: Additional X layers to read and return
in the ``layers`` slot.
obsm_keys:
Additional obsm keys to read and return in the obsm slot.
obsp_keys:
Additional obsp keys to read and return in the obsp slot.
varm_keys:
Additional varm keys to read and return in the varm slot.
varp_keys:
Additional varp keys to read and return in the varp slot.
"""
x_collection = self._ms.X
all_x_names = [X_name] + list(X_layers)
Expand All @@ -359,11 +376,31 @@ def _read(
# TODO: do it in parallel?
obsm = dict()
for key in obsm_keys:
obsm[key] = self._axism_inner_csr(_Axis.OBS, key)
obsm[key] = self._axism_inner_ndarray(_Axis.OBS, key)

obsp = dict()
for key in obsp_keys:
obsp[key] = self._axisp_inner_ndarray(_Axis.OBS, key)

varm = dict()
for key in varm_keys:
varm[key] = self._axism_inner_ndarray(_Axis.VAR, key)

varp = dict()
for key in varp_keys:
varp[key] = self._axisp_inner_ndarray(_Axis.VAR, key)

x = x_matrices.pop(X_name)

return _AxisQueryResult(
obs=obs_table, var=var_table, X=x, obsm=obsm, X_layers=x_matrices
obs=obs_table,
var=var_table,
X=x,
obsm=obsm,
obsp=obsp,
varm=varm,
varp=varp,
X_layers=x_matrices
)

def _read_both_axes(
Expand Down Expand Up @@ -475,18 +512,46 @@ def _axism_inner(

joinids = getattr(self._joinids, axis.value)
return axism[layer].read((joinids, col_joinids))

def _axisp_inner_ndarray(
    self,
    axis: "_Axis",
    layer: str,
) -> np.ndarray:
    """Read one obsp/varp layer for the query slice as a dense 2D ndarray.

    Args:
        axis: The axis (``_Axis.OBS`` or ``_Axis.VAR``) whose pairwise
            matrix collection (``obsp`` / ``varp``) is read.
        layer: Name of the layer within that collection.

    Returns:
        A dense ``(n, n)`` ndarray where ``n`` is the number of joinids
        selected on ``axis``; cells absent from the sparse read are 0.

    Raises:
        ValueError: If the measurement has no obsp/varp collection, or the
            named layer is missing.
    """
    key = axis.value + "p"

    if key not in self._ms:
        raise ValueError(f"Measurement does not contain {key} data")

    is_obs = axis is _Axis.OBS

    axisp = self._ms.obsp if is_obs else self._ms.varp
    if not (layer and layer in axisp):
        raise ValueError(f"Must specify '{key}' layer")

    joinids = getattr(self._joinids, axis.value)

    # The matrix is square over the *queried* axis — size by that axis's
    # joinids. (Previously this always used len(self._joinids.obs), which
    # mis-sized varp whenever the obs and var selections differed.)
    n_row = n_col = len(joinids)

    tbl = axisp[layer].read((joinids, joinids)).tables().concat()

    # Both soma_dim_0 and soma_dim_1 hold soma joinids for this axis, not
    # dense positions 0..n-1: map BOTH through the axis indexer before
    # scattering. (Previously soma_dim_1 was used raw, which is only
    # correct when joinids happen to equal their positions.)
    indexer = self.indexer.by_obs if is_obs else self.indexer.by_var
    row_pos = np.asarray(indexer(tbl["soma_dim_0"]))
    col_pos = np.asarray(indexer(tbl["soma_dim_1"]))

    # zeros, not empty: coordinates absent from the sparse read must come
    # back as 0.0, not uninitialized memory.
    # NOTE(review): dtype is left as the numpy default (float64) to match
    # prior behavior; the float32 hint elsewhere is unverified — confirm.
    dense = np.zeros(n_row * n_col)
    np.put(dense, row_pos * n_col + col_pos, tbl["soma_data"])
    return dense.reshape(n_row, n_col)

def _axism_inner_ndarray(
self,
axis: "_Axis",
layer: str,
) -> data.SparseRead:
) -> np.ndarray[np.float32]:
key = axis.value + "m"

if key not in self._ms:
raise ValueError(f"Measurement does not contain {key} data")

is_obs = axis is _Axis.OBS

axism = self._ms.obsm if axis is _Axis.OBS else self._ms.varm
axism = self._ms.obsm if is_obs else self._ms.varm
if not (layer and layer in axism):
raise ValueError(f"Must specify '{key}' layer")

Expand All @@ -495,7 +560,13 @@ def _axism_inner_csr(

joinids = getattr(self._joinids, axis.value)

return _fast_csr.read_scipy_csr(axism[layer], joinids, col_idx)
n_row = len(self._joinids.obs)

T = axism[layer].read((joinids, col_idx)).tables().concat()
idx = (self.indexer.by_obs if is_obs else self.indexer.by_var)(T["soma_dim_0"])
Z = np.empty(n_row * n_col)
np.put(Z, idx * n_col + T["soma_dim_1"], T["soma_data"])
return Z.reshape(n_row, n_col)

@property
def _obs_df(self) -> data.DataFrame:
Expand Down Expand Up @@ -535,14 +606,14 @@ class _AxisQueryResult:
"""Experiment.ms[...].X[...] query slice, as an SciPy sparse.csr_matrix """
X_layers: Dict[str, sparse.csr_matrix] = attrs.field(factory=dict)
"""Any additional X layers requested, as SciPy sparse.csr_matrix(s)"""
obsm: Dict[str, sparse.csr_matrix] = attrs.field(factory=dict)
"""Experiment.obsm query slice, as SciPy sparse.csr_matrix(s)"""
obsp: Dict[str, sparse.csr_matrix] = attrs.field(factory=dict)
"""Experiment.obsp query slice, as SciPy sparse.csr_matrix(s)"""
varm: Dict[str, sparse.csr_matrix] = attrs.field(factory=dict)
"""Experiment.varm query slice, as SciPy sparse.csr_matrix(s)"""
varp: Dict[str, sparse.csr_matrix] = attrs.field(factory=dict)
"""Experiment.varp query slice, as SciPy sparse.csr_matrix(s)"""
obsm: Dict[str, np.ndarray[np.float32]] = attrs.field(factory=dict)
"""Experiment.obsm query slice, as a numpy ndarray"""
obsp: Dict[str, np.ndarray[np.float32]] = attrs.field(factory=dict)
"""Experiment.obsp query slice, as a numpy ndarray"""
varm: Dict[str, np.ndarray[np.float32]] = attrs.field(factory=dict)
"""Experiment.varm query slice, as a numpy ndarray"""
varp: Dict[str, np.ndarray[np.float32]] = attrs.field(factory=dict)
"""Experiment.varp query slice, as a numpy ndarray"""

def to_anndata(self) -> anndata.AnnData:
obs = self.obs.to_pandas()
Expand Down

0 comments on commit dbdeb57

Please sign in to comment.