Skip to content

Commit

Permalink
Extending CSR Accumulator to reuse query context instead of creating its own
Browse files Browse the repository at this point in the history
  • Loading branch information
beroy committed Dec 4, 2023
1 parent 1d71cc1 commit e808e51
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
9 changes: 6 additions & 3 deletions python-spec/src/somacore/query/_fast_csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy.typing as npt
import pyarrow as pa
from scipy import sparse

from typing import Optional, Any
from .. import data as scd
from . import _eager_iter
from . import types
Expand All @@ -19,6 +19,7 @@ def read_csr(
obs_joinids: pa.Array,
var_joinids: pa.Array,
index_factory: types.IndexFactory,
context: Optional[Any]
) -> "AccumulatedCSR":
if not isinstance(matrix, scd.SparseNDArray) or matrix.ndim != 2:
raise TypeError("Can only read from a 2D SparseNDArray")
Expand All @@ -30,6 +31,7 @@ def read_csr(
var_joinids=var_joinids,
pool=pool,
index_factory=index_factory,
context = context,
)
for tbl in _eager_iter.EagerIterator(
matrix.read((obs_joinids, var_joinids)).tables(),
Expand Down Expand Up @@ -86,14 +88,15 @@ def __init__(
var_joinids: npt.NDArray[np.int64],
pool: futures.Executor,
index_factory: types.IndexFactory,
context: Optional[Any],
):
self.obs_joinids = obs_joinids
self.var_joinids = var_joinids
self.pool = pool

self.shape: Tuple[int, int] = (len(self.obs_joinids), len(self.var_joinids))
self.obs_indexer = index_factory(self.obs_joinids)
self.var_indexer = index_factory(self.var_joinids)
self.obs_indexer = index_factory(self.obs_joinids, context)
self.var_indexer = index_factory(self.var_joinids, context)
self.row_length: npt.NDArray[np.int64] = np.zeros(
(self.shape[0],), dtype=_select_dtype(self.shape[1])
)
Expand Down
1 change: 1 addition & 0 deletions python-spec/src/somacore/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ def _read(
self.obs_joinids(),
self.var_joinids(),
index_factory=self._index_factory,
context=self.experiment.context,
).to_scipy()
for _xname in all_x_arrays
}
Expand Down
4 changes: 2 additions & 2 deletions python-spec/src/somacore/query/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import numpy.typing as npt
from typing_extensions import Protocol

from typing import Optional, Any

class IndexLike(Protocol):
"""The basics of what we expect an Index to be.
Expand All @@ -26,7 +26,7 @@ def get_indexer(
"""Something compatible with Pandas' Index.get_indexer method."""


IndexFactory = Callable[[npt.NDArray[np.int64]], "IndexLike"]
IndexFactory = Callable[[npt.NDArray[np.int64], Optional[Any]], "IndexLike"]
"""Function that builds an index over the given NDArray.
This interface is implemented by the callable ``pandas.Index``.
Expand Down

0 comments on commit e808e51

Please sign in to comment.