Skip to content

Commit

Permalink
[python] Link lifetimes of SOMAArray and ManagedQuery (#3516)
Browse files Browse the repository at this point in the history
  • Loading branch information
nguyenv authored and github-actions[bot] committed Jan 7, 2025
1 parent c757020 commit 159ed72
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 55 deletions.
12 changes: 6 additions & 6 deletions apis/python/src/tiledbsoma/_dense_nd_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ._arrow_types import pyarrow_to_carrow_type
from ._common_nd_array import NDArray
from ._exception import SOMAError, map_exception_for_create
from ._read_iters import TableReadIter
from ._read_iters import ManagedQuery, TableReadIter
from ._tdb_handles import DenseNDArrayWrapper
from ._types import OpenTimestamp, Slice
from ._util import dense_indices_to_shape
Expand Down Expand Up @@ -313,11 +313,11 @@ def write(
input = np.ascontiguousarray(input)
order = clib.ResultOrder.rowmajor

mq = clib.ManagedQuery(clib_handle, clib_handle.context())
mq.set_layout(order)
_util._set_coords(mq, clib_handle, new_coords)
mq.set_soma_data(input)
mq.submit_write()
mq = ManagedQuery(self, platform_config)
mq._handle.set_layout(order)
_util._set_coords(mq, new_coords)
mq._handle.set_soma_data(input)
mq._handle.submit_write()

tiledb_write_options = TileDBWriteOptions.from_platform_config(platform_config)
if tiledb_write_options.consolidate_and_vacuum:
Expand Down
60 changes: 34 additions & 26 deletions apis/python/src/tiledbsoma/_read_iters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
cast,
)

import attrs
import numpy as np
import numpy.typing as npt
import pyarrow as pa
Expand Down Expand Up @@ -470,6 +471,27 @@ def _cs_reader(
yield sp, indices


@attrs.define(frozen=True)
class ManagedQuery:
    """Keep the lifetime of the SOMAArray tethered to ManagedQuery.

    By holding a reference to the owning SOMAArray wrapper on this object,
    the underlying array handle is kept alive for at least as long as the
    query (e.g. while a read iterator built on it is still being consumed).
    """

    # The owning SOMA array wrapper; held to pin its lifetime to this query.
    _array: SOMAArray
    # Optional per-query config overrides; None means reuse the array's
    # existing context unchanged.
    _platform_config: options.PlatformConfig | None
    # Underlying clib query handle; not an init argument — it is constructed
    # in __attrs_post_init__ from the array handle and context.
    _handle: clib.ManagedQuery = attrs.field(init=False)

    def __attrs_post_init__(self) -> None:
        # Reach through the Python-level wrapper to the raw clib array handle.
        clib_handle = self._array._handle._handle

        if self._platform_config is not None:
            # Build a fresh context: start from the array's current config,
            # then layer the caller-supplied overrides on top.
            cfg = clib_handle.context().config()
            cfg.update(self._platform_config)
            ctx = clib.SOMAContext(cfg)
        else:
            ctx = clib_handle.context()

        # frozen=True blocks normal attribute assignment, so the handle must
        # be installed via object.__setattr__.
        object.__setattr__(self, "_handle", clib.ManagedQuery(clib_handle, ctx))


class SparseTensorReadIterBase(somacore.ReadIter[_RT], metaclass=abc.ABCMeta):
"""Private implementation class"""

Expand All @@ -487,27 +509,18 @@ def __init__(
self.result_order = result_order
self.platform_config = platform_config

clib_handle = array._handle._handle
self.mq = ManagedQuery(array, platform_config)

if platform_config is not None:
cfg = clib_handle.context().config()
cfg.update(platform_config)
ctx = clib.SOMAContext(cfg)
else:
ctx = clib_handle.context()
self.mq._handle.set_layout(result_order)

self.mq = clib.ManagedQuery(clib_handle, ctx)

self.mq.set_layout(result_order)

_util._set_coords(self.mq, clib_handle, coords)
_util._set_coords(self.mq, coords)

@abc.abstractmethod
def _from_table(self, arrow_table: pa.Table) -> _RT:
raise NotImplementedError()

def __next__(self) -> _RT:
return self._from_table(self.mq.next())
return self._from_table(self.mq._handle.next())

def concat(self) -> _RT:
"""Returns all the requested data in a single operation.
Expand Down Expand Up @@ -556,27 +569,22 @@ def __init__(
):
clib_handle = array._handle._handle

if platform_config is not None:
cfg = clib_handle.context().config()
cfg.update(platform_config)
ctx = clib.SOMAContext(cfg)
else:
ctx = clib_handle.context()
self.mq = ManagedQuery(array, platform_config)

self.mq = clib.ManagedQuery(clib_handle, ctx)

self.mq.set_layout(result_order)
self.mq._handle.set_layout(result_order)

if column_names is not None:
self.mq.select_columns(list(column_names))
self.mq._handle.select_columns(list(column_names))

if value_filter is not None:
self.mq.set_condition(QueryCondition(value_filter), clib_handle.schema)
self.mq._handle.set_condition(
QueryCondition(value_filter), clib_handle.schema
)

_util._set_coords(self.mq, clib_handle, coords)
_util._set_coords(self.mq, coords)

def __next__(self) -> pa.Table:
return self.mq.next()
return self.mq._handle.next()


def _coords_strider(
Expand Down
46 changes: 23 additions & 23 deletions apis/python/src/tiledbsoma/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from concurrent.futures import Future
from itertools import zip_longest
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Expand All @@ -38,6 +39,9 @@
_DictFilterSpec,
)

if TYPE_CHECKING:
from ._read_iters import ManagedQuery

Check warning on line 43 in apis/python/src/tiledbsoma/_util.py

View check run for this annotation

Codecov / codecov/patch

apis/python/src/tiledbsoma/_util.py#L43

Added line #L43 was not covered by tests

_JSONFilter = Union[str, Dict[str, Union[str, Union[int, float]]]]
_JSONFilterList = Union[str, List[_JSONFilter]]

Expand Down Expand Up @@ -465,48 +469,44 @@ def _cast_domainish(domainish: List[Any]) -> Tuple[Tuple[object, object], ...]:
return tuple(result)


def _set_coords(
mq: clib.ManagedQuery, sarr: clib.SOMAArray, coords: options.SparseNDCoords
) -> None:
def _set_coords(mq: ManagedQuery, coords: options.SparseNDCoords) -> None:
if not is_nonstringy_sequence(coords):
raise TypeError(
f"coords type {type(coords)} must be a regular sequence,"
" not str or bytes"
)

if len(coords) > len(sarr.dimension_names):
if len(coords) > len(mq._array._handle._handle.dimension_names):
raise ValueError(
f"coords ({len(coords)} elements) must be shorter than ndim"
f" ({len(sarr.dimension_names)})"
f" ({len(mq._array._handle._handle.dimension_names)})"
)

for i, coord in enumerate(coords):
_set_coord(i, mq, sarr, coord)
_set_coord(i, mq, coord)


def _set_coord(
dim_idx: int, mq: clib.ManagedQuery, sarr: clib.SOMAArray, coord: object
) -> None:
def _set_coord(dim_idx: int, mq: ManagedQuery, coord: object) -> None:
if coord is None:
return

dim = sarr.schema.field(dim_idx)
dom = _cast_domainish(sarr.domain())[dim_idx]
dim = mq._array._handle._handle.schema.field(dim_idx)
dom = _cast_domainish(mq._array._handle._handle.domain())[dim_idx]

if isinstance(coord, (str, bytes)):
mq.set_dim_points_string_or_bytes(dim.name, [coord])
mq._handle.set_dim_points_string_or_bytes(dim.name, [coord])
return

if isinstance(coord, (pa.Array, pa.ChunkedArray)):
mq.set_dim_points_arrow(dim.name, coord)
mq._handle.set_dim_points_arrow(dim.name, coord)
return

if isinstance(coord, (Sequence, np.ndarray)):
_set_coord_by_py_seq_or_np_array(mq, dim, coord)
return

if isinstance(coord, int):
mq.set_dim_points_int64(dim.name, [coord])
mq._handle.set_dim_points_int64(dim.name, [coord])
return

# Note: slice(None, None) matches the is_slice_of part, unless we also check
Expand All @@ -521,11 +521,11 @@ def _set_coord(
if coord.stop is None:
# There's no way to specify "to infinity" for strings.
# We have to get the nonempty domain and use that as the end.\
ned = _cast_domainish(sarr.non_empty_domain())
ned = _cast_domainish(mq._array._handle._handle.non_empty_domain())
_, stop = ned[dim_idx]
else:
stop = coord.stop
mq.set_dim_ranges_string_or_bytes(dim.name, [(start, stop)])
mq._handle.set_dim_ranges_string_or_bytes(dim.name, [(start, stop)])
return

# Note: slice(None, None) matches the is_slice_of part, unless we also check
Expand All @@ -548,7 +548,7 @@ def _set_coord(
else:
istop = ts_dom[1].as_py()

mq.set_dim_ranges_int64(dim.name, [(istart, istop)])
mq._handle.set_dim_ranges_int64(dim.name, [(istart, istop)])
return

if isinstance(coord, slice):
Expand All @@ -562,7 +562,7 @@ def _set_coord(


def _set_coord_by_py_seq_or_np_array(
mq: clib.ManagedQuery, dim: pa.Field, coord: object
mq: ManagedQuery, dim: pa.Field, coord: object
) -> None:
if isinstance(coord, np.ndarray):
if coord.ndim != 1:
Expand All @@ -571,7 +571,7 @@ def _set_coord_by_py_seq_or_np_array(
)

try:
set_dim_points = getattr(mq, f"set_dim_points_{dim.type}")
set_dim_points = getattr(mq._handle, f"set_dim_points_{dim.type}")
except AttributeError:
# We have to handle this type specially below
pass
Expand All @@ -580,7 +580,7 @@ def _set_coord_by_py_seq_or_np_array(
return

if pa_types_is_string_or_bytes(dim.type):
mq.set_dim_points_string_or_bytes(dim.name, coord)
mq._handle.set_dim_points_string_or_bytes(dim.name, coord)
return

if pa.types.is_timestamp(dim.type):
Expand All @@ -591,14 +591,14 @@ def _set_coord_by_py_seq_or_np_array(
icoord = [
int(e.astype("int64")) if isinstance(e, np.datetime64) else e for e in coord
]
mq.set_dim_points_int64(dim.name, icoord)
mq._handle.set_dim_points_int64(dim.name, icoord)
return

raise ValueError(f"unhandled type {dim.type} for index column named {dim.name}")


def _set_coord_by_numeric_slice(
mq: clib.ManagedQuery, dim: pa.Field, dom: Tuple[object, object], coord: Slice[Any]
mq: ManagedQuery, dim: pa.Field, dom: Tuple[object, object], coord: Slice[Any]
) -> None:
try:
lo_hi = slice_to_numeric_range(coord, dom)
Expand All @@ -609,7 +609,7 @@ def _set_coord_by_numeric_slice(
return

try:
set_dim_range = getattr(mq, f"set_dim_ranges_{dim.type}")
set_dim_range = getattr(mq._handle, f"set_dim_ranges_{dim.type}")
set_dim_range(dim.name, [lo_hi])
return
except AttributeError:
Expand Down
27 changes: 27 additions & 0 deletions apis/python/tests/test_sparse_nd_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1952,3 +1952,30 @@ def test_pass_configs(tmp_path):
}
).tables()
)


def test_iter(tmp_path: pathlib.Path):
    """The SOMAArray backing a read iterator must remain open while the
    iterator is still in use, and independent iterators over the same
    array must not interfere with one another."""
    uri = tmp_path.as_uri()
    tensor = create_random_tensor("table", (1,), np.float32(), density=1)

    with soma.SparseNDArray.create(uri, type=pa.float64(), shape=(1,)) as arr:
        arr.write(tensor)

    # The opened handle is dropped immediately; the iterator alone must keep
    # the underlying SOMAArray alive until it is exhausted.
    it = soma.open(uri, mode="r").read().tables()
    assert next(it)
    with pytest.raises(StopIteration):
        next(it)

    # Two independent iterators over the same array: exhausting one must not
    # disturb the other.
    first = soma.open(uri, mode="r").read().tables()
    second = soma.open(uri, mode="r").read().tables()
    assert next(first)
    assert next(second)
    with pytest.raises(StopIteration):
        next(first)
    with pytest.raises(StopIteration):
        next(second)

0 comments on commit 159ed72

Please sign in to comment.