Skip to content

Commit

Permalink
Adding support for C++ re-indexer
Browse files Browse the repository at this point in the history
  • Loading branch information
beroy committed Oct 4, 2023
1 parent 767d4a1 commit 439f82b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 14 deletions.
20 changes: 16 additions & 4 deletions apis/python/src/tiledbsoma/_fast_csr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from concurrent import futures
from tiledbsoma.IntIndexer import IntIndexer
from typing import List, NamedTuple, Tuple, Type, cast

import numba
Expand All @@ -10,6 +11,7 @@
import pyarrow as pa
from scipy import sparse
import somacore
from typing import Union
from . import _eager_iter


Expand Down Expand Up @@ -61,14 +63,19 @@ def __init__(
obs_joinids: npt.NDArray[np.int64],
var_joinids: npt.NDArray[np.int64],
pool: futures.Executor,
use_int_indexer: bool = True,
):
self.obs_joinids = obs_joinids
self.var_joinids = var_joinids
self.pool = pool

self.shape: Tuple[int, int] = (len(self.obs_joinids), len(self.var_joinids))
self.obs_indexer: pd.Index = pd.Index(self.obs_joinids)
self.var_indexer: pd.Index = pd.Index(self.var_joinids)
self.obs_indexer: Union[IntIndexer, pd.Index] = \
IntIndexer.map_locations(self.obs_joinids) if use_int_indexer else pd.Index(self.obs_joinids)

self.var_indexer: Union[IntIndexer, pd.Index] = \
IntIndexer.map_locations(self.var_joinids) if use_int_indexer else pd.Index(self.var_joinids)

self.row_length: npt.NDArray[np.int64] = np.zeros(
(self.shape[0],), dtype=_select_dtype(self.shape[1])
)
Expand Down Expand Up @@ -242,10 +249,15 @@ def _select_dtype(


def _reindex_and_cast(
    index: Union["IntIndexer", pd.Index],
    ids: npt.NDArray[np.int64],
    target_dtype: npt.DTypeLike,
) -> npt.NDArray[np.int64]:
    """Map ``ids`` to positions through ``index`` and cast to ``target_dtype``.

    Args:
        index: Either a ``pd.Index`` (queried with ``get_indexer``) or an
            ``IntIndexer`` (queried with ``lookup``).
        ids: The values to look up.
        target_dtype: dtype the resulting positions are converted to.

    Returns:
        The looked-up positions as an array of ``target_dtype``.
    """
    # Dispatch on the real type instead of matching the class name inside
    # ``str(type(index))``, which would also match unrelated classes whose
    # name merely contains "IntIndexer".
    if isinstance(index, pd.Index):
        lookup_func = index.get_indexer
    else:
        lookup_func = index.lookup
    # copy=False lets astype skip the copy when the dtype already matches.
    positions = lookup_func(ids).astype(target_dtype, copy=False)
    return cast(npt.NDArray[np.int64], positions)


Expand Down
45 changes: 35 additions & 10 deletions apis/python/src/tiledbsoma/_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import pyarrow as pa
from somacore import data, query
from scipy import sparse


from tiledbsoma.IntIndexer import IntIndexer

from . import _fast_csr
Expand Down Expand Up @@ -87,32 +85,43 @@ class AxisIndexer:
query: ExperimentAxisQuery
_cached_obs: Optional[pd.Index] = None
_cached_var: Optional[pd.Index] = None
_use_int_indexer = True

@property
def _obs_index(self) -> IntIndexer:
    """Private. Return an index for the ``obs`` axis.

    The index is built from ``self.query.obs_joinids()`` on first access
    and cached in ``_cached_obs``. It is an ``IntIndexer`` when
    ``_use_int_indexer`` is true and a ``pd.Index`` otherwise.
    """
    # NOTE(review): the fallback branch yields a pd.Index, so the
    # ``IntIndexer`` return annotation is narrower than the actual result.
    if self._cached_obs is None:
        obs_joinids = self.query.obs_joinids().to_numpy()
        # ``_use_int_indexer`` is a class attribute; it is not visible as a
        # bare name inside a method and must be read through ``self``.
        if self._use_int_indexer:
            self._cached_obs = IntIndexer.map_locations(obs_joinids)
        else:
            # The obs index must be built from the obs joinids.
            self._cached_obs = pd.Index(data=obs_joinids)
    return self._cached_obs

@property
def _var_index(self) -> IntIndexer:
    """Private. Return an index for the ``var`` axis.

    The index is built from ``self.query.var_joinids()`` on first access
    and cached in ``_cached_var``. It is an ``IntIndexer`` when
    ``_use_int_indexer`` is true and a ``pd.Index`` otherwise.
    """
    # NOTE(review): the fallback branch yields a pd.Index, so the
    # ``IntIndexer`` return annotation is narrower than the actual result.
    if self._cached_var is None:
        var_joinids = self.query.var_joinids().to_numpy()
        # ``_use_int_indexer`` is a class attribute; it is not visible as a
        # bare name inside a method and must be read through ``self``.
        if self._use_int_indexer:
            self._cached_var = IntIndexer.map_locations(var_joinids)
        else:
            self._cached_var = pd.Index(data=var_joinids)
    return self._cached_var

def by_obs(self, coords: _Numpyable) -> npt.NDArray[np.intp]:
    """Reindex the coords (soma_joinids) over the ``obs`` axis.

    Uses ``lookup`` on the obs index when ``_use_int_indexer`` is true and
    ``get_indexer`` otherwise.
    """
    obs_coords = _to_numpy(coords)
    # The flag is a class attribute, so it has to be read through ``self``;
    # a bare ``_use_int_indexer`` raises NameError inside a method.
    if self._use_int_indexer:
        return self._obs_index.lookup(obs_coords)
    return self._obs_index.get_indexer(obs_coords)

def by_var(self, coords: _Numpyable) -> npt.NDArray[np.intp]:
    """Reindex the coords (soma_joinids) over the ``var`` axis.

    Uses ``lookup`` on the var index when ``_use_int_indexer`` is true and
    ``get_indexer`` otherwise.
    """
    var_coords = _to_numpy(coords)
    # The flag is a class attribute, so it has to be read through ``self``;
    # a bare ``_use_int_indexer`` raises NameError inside a method.
    if self._use_int_indexer:
        return self._var_index.lookup(var_coords)
    return self._var_index.get_indexer(var_coords)

@attrs.define(frozen=True)
class AxisQueryResult:
Expand All @@ -138,3 +147,19 @@ def to_anndata(self) -> anndata.AnnData:
X=self.X, obs=obs, var=var, layers=(self.X_layers or None)
)

def _to_numpy(it: _Numpyable) -> np.ndarray:
    """Return ``it`` as a NumPy array.

    An ``np.ndarray`` is handed back as-is (the same object); any other
    input is converted by calling its ``to_numpy()`` method.
    """
    return it if isinstance(it, np.ndarray) else it.to_numpy()


class _Experimentish(Protocol):
"""The API we need from an Experiment."""

@property
def ms(self) -> Mapping[str, measurement.Measurement]:
...

@property
def obs(self) -> data.DataFrame:
...

0 comments on commit 439f82b

Please sign in to comment.