diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py
index ab97e0827..fdecfa245 100644
--- a/src/levanter/data/loader.py
+++ b/src/levanter/data/loader.py
@@ -21,7 +21,8 @@ from levanter.data.utils import batched
 from levanter.shapes import NamedShapeSpec, ShapeSpec, to_raw_shape
 from levanter.utils.background_iterable import BackgroundIterable
-from levanter.utils.thread_utils import blocking_wait
+from levanter.utils.jax_utils import local_cpu_mesh
+from levanter.utils.thread_utils import AsyncIteratorWrapper, blocking_wait
 
 
 Ex = TypeVar("Ex")
 
@@ -98,6 +99,8 @@ def __iter__(self):
         return self.iter_from_step(None)
 
     def iter_from_step(self, start_from_batch: Optional[int] = None):
+        # sometimes we pass in an array for the start_from_batch, so we need to check for that
+        start_from_batch = int(start_from_batch) if start_from_batch is not None else None
         return DataLoaderIterator(self, start_from_batch=start_from_batch)
 
 
@@ -109,115 +112,131 @@ def __init__(self, data_loader: DataLoader, start_from_batch: Optional[int] = No
         if self.mapping is None:
             self.mapping = hax.partitioning.current_thread_local_mapping()
 
-        # TODO: bring back non-prefetching version
         buffered_batches = self.dl.max_buffered_batches
-        self._batches = iter(BackgroundIterable(self._produce_batches, max_capacity=buffered_batches))
+        if buffered_batches == 0:
+            self._batches = AsyncIteratorWrapper(self._produce_batches())
+        else:
+            self._batches = iter(BackgroundIterable(self._produce_batches, max_capacity=buffered_batches))
 
     def __next__(self):
         time_start = time.time()
-        out = next(self._batches)
+        individual_data_batch = next(self._batches)
+        data_for_this_batch = {index: datum for index, datum in zip(self.dl._local_indices, individual_data_batch)}
+        batch = self._batchify_local_data(data_for_this_batch)
+
         time_end = time.time()
         if (time_end - time_start) > 0.5:
             logger.info(f"Prefetch wasn't fast enough: {time_end - time_start:.3f}")
-        return out
+        return batch
 
     async def _produce_batches(self):
         batch_number = self._start_from_batch or 0
-        total_ex_loaded = 0
         done = False
         while not done:
-            next_batch_numbers = []
-            for i in range(self.dl.prefetch_size):
-                if self.dl.data_store.is_finite():
-                    next_end = (batch_number + 1) * self.dl.batch_size
-                    available_len = await self.dl.data_store.wait_until_len_at_least(next_end)
-                    if available_len < next_end:
-                        done = True
-                        break
-
-                next_batch_numbers.append(batch_number)
-                batch_number += 1
+            target_next_batch_number = batch_number + self.dl.prefetch_size
+            max_achievable_batch_number = await self._dataset_get_available_batch_number(target_next_batch_number)
+            if max_achievable_batch_number < target_next_batch_number:
+                done = True
+
+            next_batch_numbers = list(range(batch_number, min(target_next_batch_number, max_achievable_batch_number)))
+
+            if len(next_batch_numbers) == 0:
+                break
+
+            batch_number = next_batch_numbers[-1] + 1
 
             async for batch in self._retrieve_batches(next_batch_numbers):
                 yield batch
 
-            total_ex_loaded += self.dl.batch_size * len(next_batch_numbers)
+    async def _dataset_get_available_batch_number(self, target_max_batch_number: int) -> int:
+        if self.dl.data_store.is_finite():
+            next_end = (target_max_batch_number + 1) * self.dl.batch_size
+            available_len = await self.dl.data_store.wait_until_len_at_least(next_end)
+            max_achievable_batch_number = available_len // self.dl.batch_size
 
-    async def _retrieve_batches(self, batch_numbers: list[int]):
-        with hax.axis_mapping(self.mapping), self.dl.mesh:
-            indices_for_this_batch_of_batches: list[int] = []
-            for bn in batch_numbers:
-                indices_this_batch = range(bn * self.dl.batch_size, (bn + 1) * self.dl.batch_size, 1)
-                indices_this_batch_this_process = [indices_this_batch[i] for i in self.dl._local_indices]
-                indices_for_this_batch_of_batches.extend(indices_this_batch_this_process)
+            return max_achievable_batch_number
+
+        return target_max_batch_number
 
+    async def _retrieve_batches(self, batch_numbers: list[int]):
+        with local_cpu_mesh():
             time_start = time.time()
-            individual_datums = await self.dl.data_store.get_batch(indices_for_this_batch_of_batches)
+            individual_datums_for_each_batch = await self._do_retrieve_batch_of_batches(batch_numbers)
+            # reshape to be per batch
             time_end = time.time()
             logger.debug(f"Time to get {len(batch_numbers)} batches: {time_end - time_start:.3f}")
 
-            time_start = time.time()
-            # reshape to be per batch
-            individual_datums = list(batched(individual_datums, len(self.dl._local_indices)))
-
-            # below we're gonna get the indices relative to this batch (i.e. 0 to batch_size)
-            index_to_datum = [
-                {index: datum for index, datum in zip(self.dl._local_indices, individual_data_batch)}
-                for individual_data_batch in individual_datums
-            ]
-
-            def get_local_batch(bn: int, begin: int, end: int) -> list:
-                # TODO: if we ever do "big data" (i.e. huge examples) we might want to be able to load part of an example
-                # which will require support from the datastore (i.e. tensorstore)
-                device_batch = _stack_tree(self.dl.Batch.name, [index_to_datum[bn][i] for i in range(begin, end)])
-                batch_leaves = hax.tree_util.tree_leaves(device_batch)
-                return batch_leaves
-
-            def get_local_data_for_leaf(bn, indices: _TensorSliceIndex, leaf_index: int) -> Array:
-                batch_slice = indices[0]
-                begin, end, stride = batch_slice.indices(self.dl.batch_size)
-                if stride != 1:
-                    raise ValueError("Stride must be 1")
-
-                leaf_data = (get_local_batch(bn, begin, end))[leaf_index]
-
-                if isinstance(leaf_data, hax.NamedArray):
-                    # select out the batch axis
-                    batch_index = index_where(lambda ax: ax.name == self.dl.Batch.name, leaf_data.axes)
-                    new_indices = list(indices)
-                    new_indices[batch_index] = slice(None)
-                    return leaf_data.array[tuple(new_indices)]
+            for data in individual_datums_for_each_batch:
+                yield data
+
+    def _batchify_local_data(self, data_for_this_batch: dict[int, Array]):
+        cache: dict[tuple[int, int], list[Array | hax.NamedArray]] = {}
+
+        def get_local_batch(begin: int, end: int) -> list:
+            if (begin, end) in cache:
+                return cache[(begin, end)]
+
+            # TODO: if we ever do "big data" (i.e. huge examples) we might want to be able to load part of an example
+            # which will require support from the datastore (i.e. tensorstore)
+            device_batch = _stack_tree(self.dl.Batch.name, [data_for_this_batch[i] for i in range(begin, end)])
+            batch_leaves = hax.tree_util.tree_leaves(device_batch)
+
+            cache[(begin, end)] = batch_leaves
+
+            return batch_leaves
+
+        def get_local_data_for_leaf(indices: _TensorSliceIndex, leaf_index: int) -> Array:
+            batch_slice = indices[0]
+            begin, end, stride = batch_slice.indices(self.dl.batch_size)
+            if stride != 1:
+                raise ValueError("Stride must be 1")
+
+            leaf_data = get_local_batch(begin, end)[leaf_index]
+
+            if isinstance(leaf_data, hax.NamedArray):
+                # select out the batch axis
+                batch_index = index_where(lambda ax: ax.name == self.dl.Batch.name, leaf_data.axes)
+                new_indices = list(indices)
+                new_indices[batch_index] = slice(None)
+                return leaf_data.array[tuple(new_indices)]
+            else:
+                other_indices = indices[1:]
+                if all(idx == slice(None) for idx in other_indices):
+                    return leaf_data
                 else:
-                    other_indices = indices[1:]
-                    if all(idx == slice(None) for idx in other_indices):
-                        return leaf_data
-                    else:
-                        # TODO: this doesn't work with named axes
-                        return leaf_data[(..., *other_indices)]
-
-            for batch_offset, bn in enumerate(batch_numbers):
-
-                def make_global_array_for_leaf(leaf_index, item_leaf_shape: ShapeSpec | NamedShapeSpec):
-                    def get_data(indices):
-                        return get_local_data_for_leaf(batch_offset, indices, leaf_index)
-
-                    raw_array = jax.make_array_from_callback(
-                        to_raw_shape(item_leaf_shape),
-                        jax.sharding.NamedSharding(self.dl.mesh, self._pspec_for(item_leaf_shape)),
-                        get_data,
-                    )
-                    if isinstance(item_leaf_shape, NamedShapeSpec):
-                        return hax.NamedArray(raw_array, item_leaf_shape.shape)
-                    else:
-                        return raw_array
-
-                gda_leaves = [
-                    make_global_array_for_leaf(leaf_index, _batchified_shape(self.dl.Batch, item_leaf))
-                    for leaf_index, item_leaf in enumerate(self.dl._ex_leaves)
-                ]
-
-                gda_tree = jax.tree.unflatten(self.dl._ex_structure, gda_leaves)
-                yield gda_tree
+                    # TODO: this doesn't work with named axes
+                    return leaf_data[(..., *other_indices)]
+
+        def make_global_array_for_leaf(leaf_index, item_leaf_shape: ShapeSpec | NamedShapeSpec):
+            def get_data(indices):
+                return get_local_data_for_leaf(indices, leaf_index)
+
+            raw_array = jax.make_array_from_callback(
+                to_raw_shape(item_leaf_shape),
+                jax.sharding.NamedSharding(self.dl.mesh, self._pspec_for(item_leaf_shape)),
+                get_data,
+            )
+            if isinstance(item_leaf_shape, NamedShapeSpec):
+                return hax.NamedArray(raw_array, item_leaf_shape.shape)
+            else:
+                return raw_array
+
+        gda_leaves = [
+            make_global_array_for_leaf(leaf_index, _batchified_shape(self.dl.Batch, item_leaf))
+            for leaf_index, item_leaf in enumerate(self.dl._ex_leaves)
+        ]
+        gda_tree = jax.tree.unflatten(self.dl._ex_structure, gda_leaves)
+        return gda_tree
+
+    async def _do_retrieve_batch_of_batches(self, batch_numbers):
+        indices_for_this_batch_of_batches: list[int] = []
+        for bn in batch_numbers:
+            indices_this_batch = range(bn * self.dl.batch_size, (bn + 1) * self.dl.batch_size, 1)
+            indices_this_batch_this_process = [indices_this_batch[i] for i in self.dl._local_indices]
+            indices_for_this_batch_of_batches.extend(indices_this_batch_this_process)
+        individual_datums = await self.dl.data_store.get_batch(indices_for_this_batch_of_batches)
+        individual_datums_for_each_batch = list(batched(individual_datums, len(self.dl._local_indices)))
+        return individual_datums_for_each_batch
 
     def _pspec_for(self, shape_spec: ShapeSpec | NamedShapeSpec) -> PartitionSpec:
         if isinstance(shape_spec, ShapeSpec):  # type: ignore
@@ -246,14 +265,6 @@ def _batchified_shape(Batch, leaf: hax.NamedArray | Array) -> ShapeSpec | NamedS
         return ShapeSpec((Batch.size,) + leaf.shape, leaf.dtype)
 
 
-def _pspec_for(self, shape_spec: ShapeSpec | NamedShapeSpec) -> PartitionSpec:
-    if isinstance(shape_spec, ShapeSpec):  # type: ignore
-        batch_name = hax.partitioning.physical_axis_name(self.Batch, self.axis_resources)
-        return PartitionSpec(batch_name, *((None,) * (len(shape_spec.shape) - 1)))
-    else:
-        return hax.partitioning.pspec_for_axis(shape_spec.shape, self.axis_resources)  # type: ignore
-
-
 @functools.partial(jax.jit, static_argnums=(0,))
 def _stack_tree(batch_name, individual_datums):
     def _stack_leaves_unchecked(*leaves):
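
Reviewer note: the reworked path fetches raw examples inside local_cpu_mesh() via _do_retrieve_batch_of_batches, and _batchify_local_data then assembles each globally sharded batch through jax.make_array_from_callback. Below is a minimal, self-contained sketch of that callback pattern, assuming a toy 8x4 float batch and a one-device "batch" mesh; the names, sizes, and mesh here are illustrative stand-ins, not code from this patch.

# Sketch of the jax.make_array_from_callback pattern used above (hypothetical data and mesh).
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

global_batch_size, feature_dim = 8, 4
# Stand-in for the examples this process has already loaded and stacked locally.
local_examples = np.arange(global_batch_size * feature_dim, dtype=np.float32).reshape(global_batch_size, feature_dim)

# A single-device mesh with one "batch" axis keeps the sketch runnable anywhere;
# the real loader uses the trainer's device mesh and _pspec_for(...) instead.
mesh = Mesh(np.array(jax.devices()[:1]), ("batch",))
sharding = NamedSharding(mesh, PartitionSpec("batch", None))

def get_data(indices):
    # `indices` is a tuple of slices selecting one device's shard of the global array;
    # return the matching rows from the locally held data.
    return local_examples[indices]

global_batch = jax.make_array_from_callback((global_batch_size, feature_dim), sharding, get_data)
print(global_batch.shape)  # (8, 4), laid out according to `sharding`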