diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index 320e8266d..fdecfa245 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -99,6 +99,8 @@ def __iter__(self): return self.iter_from_step(None) def iter_from_step(self, start_from_batch: Optional[int] = None): + # sometimes we pass in an array for the start_from_batch, so we need to check for that + start_from_batch = int(start_from_batch) if start_from_batch is not None else None return DataLoaderIterator(self, start_from_batch=start_from_batch)