fix crash in data loader caused by using stale array
dlwh committed Oct 14, 2024
1 parent fc26c74 commit fef836a
Showing 1 changed file with 106 additions and 95 deletions.
201 changes: 106 additions & 95 deletions src/levanter/data/loader.py
@@ -21,7 +21,8 @@
from levanter.data.utils import batched
from levanter.shapes import NamedShapeSpec, ShapeSpec, to_raw_shape
from levanter.utils.background_iterable import BackgroundIterable
from levanter.utils.thread_utils import blocking_wait
from levanter.utils.jax_utils import local_cpu_mesh
from levanter.utils.thread_utils import AsyncIteratorWrapper, blocking_wait


Ex = TypeVar("Ex")
@@ -98,6 +99,8 @@ def __iter__(self):
return self.iter_from_step(None)

def iter_from_step(self, start_from_batch: Optional[int] = None):
# sometimes we pass in an array for the start_from_batch, so we need to check for that
start_from_batch = int(start_from_batch) if start_from_batch is not None else None
return DataLoaderIterator(self, start_from_batch=start_from_batch)
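
The int() coercion above exists because the resume step can arrive as a JAX scalar array rather than a Python int; converting it once on the host avoids keeping a reference to a device array whose buffer may later be invalidated, which is presumably the "stale array" the commit title refers to. A minimal sketch of the calling pattern, with the value and names below chosen purely for illustration:

import jax.numpy as jnp

resume_step = jnp.asarray(1_000)       # e.g. a step counter restored from checkpointed state
start_from_batch = int(resume_step)    # materialize once on the host; no device array is retained
# loader.iter_from_step(start_from_batch) would then proceed with a plain Python int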


@@ -109,115 +112,131 @@ def __init__(self, data_loader: DataLoader, start_from_batch: Optional[int] = None):
if self.mapping is None:
self.mapping = hax.partitioning.current_thread_local_mapping()

# TODO: bring back non-prefetching version
buffered_batches = self.dl.max_buffered_batches
self._batches = iter(BackgroundIterable(self._produce_batches, max_capacity=buffered_batches))
if buffered_batches == 0:
self._batches = AsyncIteratorWrapper(self._produce_batches())
else:
self._batches = iter(BackgroundIterable(self._produce_batches, max_capacity=buffered_batches))
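
When max_buffered_batches is 0, the async generator is consumed directly through AsyncIteratorWrapper instead of a background prefetch thread. The real wrapper lives in levanter.utils.thread_utils; as a rough mental model only (an illustrative stand-in, not levanter's implementation), such a bridge typically runs an event loop on a helper thread and turns __anext__ into a blocking __next__:

import asyncio
import threading

class SyncOverAsync:  # illustrative sketch only, not AsyncIteratorWrapper itself
    def __init__(self, async_iter):
        self._aiter = async_iter
        self._loop = asyncio.new_event_loop()
        threading.Thread(target=self._loop.run_forever, daemon=True).start()

    def __iter__(self):
        return self

    def __next__(self):
        # schedule __anext__ on the helper loop and block until it yields
        future = asyncio.run_coroutine_threadsafe(self._aiter.__anext__(), self._loop)
        try:
            return future.result()
        except StopAsyncIteration:
            raise StopIteration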

def __next__(self):
time_start = time.time()
out = next(self._batches)
individual_data_batch = next(self._batches)
data_for_this_batch = {index: datum for index, datum in zip(self.dl._local_indices, individual_data_batch)}
batch = self._batchify_local_data(data_for_this_batch)

time_end = time.time()
if (time_end - time_start) > 0.5:
logger.info(f"Prefetch wasn't fast enough: {time_end - time_start:.3f}")
return out
return batch

async def _produce_batches(self):
batch_number = self._start_from_batch or 0
total_ex_loaded = 0
done = False
while not done:
next_batch_numbers = []
for i in range(self.dl.prefetch_size):
if self.dl.data_store.is_finite():
next_end = (batch_number + 1) * self.dl.batch_size
available_len = await self.dl.data_store.wait_until_len_at_least(next_end)
if available_len < next_end:
done = True
break

next_batch_numbers.append(batch_number)
batch_number += 1
target_next_batch_number = batch_number + self.dl.prefetch_size
max_achievable_batch_number = await self._dataset_get_available_batch_number(target_next_batch_number)
if max_achievable_batch_number < target_next_batch_number:
done = True

next_batch_numbers = list(range(batch_number, min(target_next_batch_number, max_achievable_batch_number)))

if len(next_batch_numbers) == 0:
break

batch_number = next_batch_numbers[-1] + 1

async for batch in self._retrieve_batches(next_batch_numbers):
yield batch

total_ex_loaded += self.dl.batch_size * len(next_batch_numbers)
async def _dataset_get_available_batch_number(self, target_max_batch_number: int) -> int:
if self.dl.data_store.is_finite():
next_end = (target_max_batch_number + 1) * self.dl.batch_size
available_len = await self.dl.data_store.wait_until_len_at_least(next_end)
max_achievable_batch_number = available_len // self.dl.batch_size

async def _retrieve_batches(self, batch_numbers: list[int]):
with hax.axis_mapping(self.mapping), self.dl.mesh:
indices_for_this_batch_of_batches: list[int] = []
for bn in batch_numbers:
indices_this_batch = range(bn * self.dl.batch_size, (bn + 1) * self.dl.batch_size, 1)
indices_this_batch_this_process = [indices_this_batch[i] for i in self.dl._local_indices]
indices_for_this_batch_of_batches.extend(indices_this_batch_this_process)
return max_achievable_batch_number

return target_max_batch_number

async def _retrieve_batches(self, batch_numbers: list[int]):
with local_cpu_mesh():
time_start = time.time()
individual_datums = await self.dl.data_store.get_batch(indices_for_this_batch_of_batches)
individual_datums_for_each_batch = await self._do_retrieve_batch_of_batches(batch_numbers)
# reshape to be per batch
time_end = time.time()
logger.debug(f"Time to get {len(batch_numbers)} batches: {time_end - time_start:.3f}")
time_start = time.time()
# reshape to be per batch
individual_datums = list(batched(individual_datums, len(self.dl._local_indices)))

# below we're gonna get the indices relative to this batch (i.e. 0 to batch_size)
index_to_datum = [
{index: datum for index, datum in zip(self.dl._local_indices, individual_data_batch)}
for individual_data_batch in individual_datums
]

def get_local_batch(bn: int, begin: int, end: int) -> list:
# TODO: if we ever do "big data" (i.e. huge examples) we might want to be able to load part of an example
# which will require support from the datastore (i.e. tensorstore)
device_batch = _stack_tree(self.dl.Batch.name, [index_to_datum[bn][i] for i in range(begin, end)])
batch_leaves = hax.tree_util.tree_leaves(device_batch)
return batch_leaves

def get_local_data_for_leaf(bn, indices: _TensorSliceIndex, leaf_index: int) -> Array:
batch_slice = indices[0]
begin, end, stride = batch_slice.indices(self.dl.batch_size)
if stride != 1:
raise ValueError("Stride must be 1")

leaf_data = (get_local_batch(bn, begin, end))[leaf_index]

if isinstance(leaf_data, hax.NamedArray):
# select out the batch axis
batch_index = index_where(lambda ax: ax.name == self.dl.Batch.name, leaf_data.axes)
new_indices = list(indices)
new_indices[batch_index] = slice(None)
return leaf_data.array[tuple(new_indices)]

for data in individual_datums_for_each_batch:
yield data

def _batchify_local_data(self, data_for_this_batch: dict[int, Array]):
cache: dict[tuple[int, int], list[Array | hax.NamedArray]] = {}

def get_local_batch(begin: int, end: int) -> list:
if (begin, end) in cache:
return cache[(begin, end)]

# TODO: if we ever do "big data" (i.e. huge examples) we might want to be able to load part of an example
# which will require support from the datastore (i.e. tensorstore)
device_batch = _stack_tree(self.dl.Batch.name, [data_for_this_batch[i] for i in range(begin, end)])
batch_leaves = hax.tree_util.tree_leaves(device_batch)

cache[(begin, end)] = batch_leaves

return batch_leaves

def get_local_data_for_leaf(indices: _TensorSliceIndex, leaf_index: int) -> Array:
batch_slice = indices[0]
begin, end, stride = batch_slice.indices(self.dl.batch_size)
if stride != 1:
raise ValueError("Stride must be 1")

leaf_data = get_local_batch(begin, end)[leaf_index]

if isinstance(leaf_data, hax.NamedArray):
# select out the batch axis
batch_index = index_where(lambda ax: ax.name == self.dl.Batch.name, leaf_data.axes)
new_indices = list(indices)
new_indices[batch_index] = slice(None)
return leaf_data.array[tuple(new_indices)]
else:
other_indices = indices[1:]
if all(idx == slice(None) for idx in other_indices):
return leaf_data
else:
other_indices = indices[1:]
if all(idx == slice(None) for idx in other_indices):
return leaf_data
else:
# TODO: this doesn't work with named axes
return leaf_data[(..., *other_indices)]

for batch_offset, bn in enumerate(batch_numbers):

def make_global_array_for_leaf(leaf_index, item_leaf_shape: ShapeSpec | NamedShapeSpec):
def get_data(indices):
return get_local_data_for_leaf(batch_offset, indices, leaf_index)

raw_array = jax.make_array_from_callback(
to_raw_shape(item_leaf_shape),
jax.sharding.NamedSharding(self.dl.mesh, self._pspec_for(item_leaf_shape)),
get_data,
)
if isinstance(item_leaf_shape, NamedShapeSpec):
return hax.NamedArray(raw_array, item_leaf_shape.shape)
else:
return raw_array

gda_leaves = [
make_global_array_for_leaf(leaf_index, _batchified_shape(self.dl.Batch, item_leaf))
for leaf_index, item_leaf in enumerate(self.dl._ex_leaves)
]

gda_tree = jax.tree.unflatten(self.dl._ex_structure, gda_leaves)
yield gda_tree
# TODO: this doesn't work with named axes
return leaf_data[(..., *other_indices)]

def make_global_array_for_leaf(leaf_index, item_leaf_shape: ShapeSpec | NamedShapeSpec):
def get_data(indices):
return get_local_data_for_leaf(indices, leaf_index)

raw_array = jax.make_array_from_callback(
to_raw_shape(item_leaf_shape),
jax.sharding.NamedSharding(self.dl.mesh, self._pspec_for(item_leaf_shape)),
get_data,
)
if isinstance(item_leaf_shape, NamedShapeSpec):
return hax.NamedArray(raw_array, item_leaf_shape.shape)
else:
return raw_array

gda_leaves = [
make_global_array_for_leaf(leaf_index, _batchified_shape(self.dl.Batch, item_leaf))
for leaf_index, item_leaf in enumerate(self.dl._ex_leaves)
]
gda_tree = jax.tree.unflatten(self.dl._ex_structure, gda_leaves)
return gda_tree
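
jax.make_array_from_callback builds the globally sharded batch without gathering it on one host: JAX calls get_data once per addressable shard, passing a tuple of slices into the global shape, and the callback returns just that shard's data. A self-contained illustration of the same pattern (single host, a 1-D "batch" mesh over the local devices; the names and sizes here are made up):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.local_devices()), ("batch",))
global_shape = (8 * len(jax.local_devices()), 4)
host_data = np.arange(np.prod(global_shape), dtype=np.float32).reshape(global_shape)

def get_data(indices):
    # indices is a tuple of slices selecting this device's shard of the global array
    return host_data[indices]

arr = jax.make_array_from_callback(
    global_shape,
    NamedSharding(mesh, PartitionSpec("batch", None)),
    get_data,
)
assert arr.shape == global_shape  # a sharded jax.Array assembled from per-shard callbacks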

async def _do_retrieve_batch_of_batches(self, batch_numbers):
indices_for_this_batch_of_batches: list[int] = []
for bn in batch_numbers:
indices_this_batch = range(bn * self.dl.batch_size, (bn + 1) * self.dl.batch_size, 1)
indices_this_batch_this_process = [indices_this_batch[i] for i in self.dl._local_indices]
indices_for_this_batch_of_batches.extend(indices_this_batch_this_process)
individual_datums = await self.dl.data_store.get_batch(indices_for_this_batch_of_batches)
individual_datums_for_each_batch = list(batched(individual_datums, len(self.dl._local_indices)))
return individual_datums_for_each_batch
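
Concretely, each requested batch number maps to a contiguous range of global example indices, and each process keeps only the positions listed in _local_indices. A worked example with hypothetical numbers (batch_size=8, this process owning positions 4-7 of every batch):

batch_size = 8
local_indices = [4, 5, 6, 7]
flat_indices = []
for bn in [3, 4]:  # two prefetched batch numbers
    global_indices = range(bn * batch_size, (bn + 1) * batch_size)  # 24..31, then 32..39
    flat_indices.extend(global_indices[i] for i in local_indices)
print(flat_indices)  # [28, 29, 30, 31, 36, 37, 38, 39]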

def _pspec_for(self, shape_spec: ShapeSpec | NamedShapeSpec) -> PartitionSpec:
if isinstance(shape_spec, ShapeSpec): # type: ignore
@@ -246,14 +265,6 @@ def _batchified_shape(Batch, leaf: hax.NamedArray | Array) -> ShapeSpec | NamedShapeSpec:
return ShapeSpec((Batch.size,) + leaf.shape, leaf.dtype)


def _pspec_for(self, shape_spec: ShapeSpec | NamedShapeSpec) -> PartitionSpec:
if isinstance(shape_spec, ShapeSpec): # type: ignore
batch_name = hax.partitioning.physical_axis_name(self.Batch, self.axis_resources)
return PartitionSpec(batch_name, *((None,) * (len(shape_spec.shape) - 1)))
else:
return hax.partitioning.pspec_for_axis(shape_spec.shape, self.axis_resources) # type: ignore


@functools.partial(jax.jit, static_argnums=(0,))
def _stack_tree(batch_name, individual_datums):
def _stack_leaves_unchecked(*leaves):
