Catch OSError for arrow #7348

Merged
merged 3 commits on Jan 9, 2025
10 changes: 5 additions & 5 deletions src/datasets/arrow_dataset.py
@@ -1990,7 +1990,7 @@ def flatten(self, new_fingerprint: Optional[str] = None, max_depth=16) -> "Dataset":
dataset.info.features = self._info.features.flatten(max_depth=max_depth)
dataset.info.features = Features({col: dataset.info.features[col] for col in dataset.data.column_names})
dataset._data = update_metadata_with_features(dataset._data, dataset.features)
- logger.info(f'Flattened dataset from depth {depth} to depth {1 if depth + 1 < max_depth else "unknown"}.')
+ logger.info(f"Flattened dataset from depth {depth} to depth {1 if depth + 1 < max_depth else 'unknown'}.")
dataset._fingerprint = new_fingerprint
return dataset

@@ -3176,9 +3176,9 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
del kwargs["shard"]
else:
logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}")
- assert (
-     None not in transformed_shards
- ), f"Failed to retrieve results from map: result list {transformed_shards} still contains None - at least one worker failed to return its results"
+ assert None not in transformed_shards, (
+     f"Failed to retrieve results from map: result list {transformed_shards} still contains None - at least one worker failed to return its results"
+ )
logger.info(f"Concatenating {num_proc} shards")
result = _concatenate_map_style_datasets(transformed_shards)
# update fingerprint if the dataset changed
@@ -5651,7 +5651,7 @@ def push_to_hub(
create_pr=create_pr,
)
logger.info(
f"Commit #{i+1} completed"
f"Commit #{i + 1} completed"
+ (f" (still {num_commits - i - 1} to go)" if num_commits - i - 1 else "")
+ "."
)
14 changes: 7 additions & 7 deletions src/datasets/builder.py
@@ -1114,7 +1114,7 @@ def as_dataset(
"datasets.load_dataset() before trying to access the Dataset object."
)

- logger.debug(f'Constructing Dataset for split {split or ", ".join(self.info.splits)}, from {self._output_dir}')
+ logger.debug(f"Constructing Dataset for split {split or ', '.join(self.info.splits)}, from {self._output_dir}")

# By default, return all splits
if split is None:
@@ -1528,9 +1528,9 @@ def _prepare_split(
# the content is the number of examples progress update
pbar.update(content)

- assert (
-     None not in examples_per_job
- ), f"Failed to retrieve results from prepare_split: result list {examples_per_job} still contains None - at least one worker failed to return its results"
+ assert None not in examples_per_job, (
+     f"Failed to retrieve results from prepare_split: result list {examples_per_job} still contains None - at least one worker failed to return its results"
+ )

total_shards = sum(shards_per_job)
total_num_examples = sum(examples_per_job)
@@ -1783,9 +1783,9 @@ def _prepare_split(
# the content is the number of examples progress update
pbar.update(content)

- assert (
-     None not in examples_per_job
- ), f"Failed to retrieve results from prepare_split: result list {examples_per_job} still contains None - at least one worker failed to return its results"
+ assert None not in examples_per_job, (
+     f"Failed to retrieve results from prepare_split: result list {examples_per_job} still contains None - at least one worker failed to return its results"
+ )

total_shards = sum(shards_per_job)
total_num_examples = sum(examples_per_job)
2 changes: 1 addition & 1 deletion src/datasets/dataset_dict.py
@@ -1802,7 +1802,7 @@ def push_to_hub(
create_pr=create_pr,
)
logger.info(
f"Commit #{i+1} completed"
f"Commit #{i + 1} completed"
+ (f" (still {num_commits - i - 1} to go)" if num_commits - i - 1 else "")
+ "."
)
4 changes: 2 additions & 2 deletions src/datasets/features/features.py
@@ -2168,8 +2168,8 @@ def recursive_reorder(source, target, stack=""):
if sorted(source) != sorted(target):
message = (
f"Keys mismatch: between {source} (source) and {target} (target).\n"
f"{source.keys()-target.keys()} are missing from target "
f"and {target.keys()-source.keys()} are missing from source" + stack_position
f"{source.keys() - target.keys()} are missing from target "
f"and {target.keys() - source.keys()} are missing from source" + stack_position
)
raise ValueError(message)
return {key: recursive_reorder(source[key], target[key], stack + f".{key}") for key in target}
2 changes: 1 addition & 1 deletion src/datasets/features/translation.py
@@ -102,7 +102,7 @@ def encode_example(self, translation_dict):
return translation_dict
elif self.languages and set(translation_dict) - lang_set:
raise ValueError(
- f'Some languages in example ({", ".join(sorted(set(translation_dict) - lang_set))}) are not in valid set ({", ".join(lang_set)}).'
+ f"Some languages in example ({', '.join(sorted(set(translation_dict) - lang_set))}) are not in valid set ({', '.join(lang_set)})."
)

# Convert dictionary into tuples, splitting out cases where there are
2 changes: 1 addition & 1 deletion src/datasets/fingerprint.py
@@ -221,7 +221,7 @@ def generate_fingerprint(dataset: "Dataset") -> str:


def generate_random_fingerprint(nbits: int = 64) -> str:
return f"{fingerprint_rng.getrandbits(nbits):0{nbits//4}x}"
return f"{fingerprint_rng.getrandbits(nbits):0{nbits // 4}x}"


def update_fingerprint(fingerprint, transform, transform_args):
4 changes: 2 additions & 2 deletions src/datasets/packaged_modules/arrow/arrow.py
@@ -45,7 +45,7 @@ def _split_generators(self, dl_manager):
with open(file, "rb") as f:
try:
reader = pa.ipc.open_stream(f)
- except pa.lib.ArrowInvalid:
+ except (OSError, pa.lib.ArrowInvalid):
reader = pa.ipc.open_file(f)
self.info.features = datasets.Features.from_arrow_schema(reader.schema)
break
@@ -65,7 +65,7 @@ def _generate_tables(self, files):
try:
try:
batches = pa.ipc.open_stream(f)
- except pa.lib.ArrowInvalid:
+ except (OSError, pa.lib.ArrowInvalid):
reader = pa.ipc.open_file(f)
batches = (reader.get_batch(i) for i in range(reader.num_record_batches))
for batch_idx, record_batch in enumerate(batches):
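The two hunks above are the substance of this PR. Arrow IPC data comes in a stream format and a file format, and depending on the pyarrow version and the payload, pointing `pa.ipc.open_stream` at a file-format `.arrow` file can fail with `OSError` rather than `pa.lib.ArrowInvalid`. A minimal sketch of the fallback pattern, outside the builder and with a hypothetical local path:

```python
import pyarrow as pa


def open_arrow_ipc(path):
    """Sketch of the builder's fallback: try the IPC stream reader first,
    then fall back to the IPC file reader.

    Depending on the pyarrow version, feeding a file-format payload to
    open_stream raises pa.lib.ArrowInvalid or OSError, hence the
    broadened except clause in this PR.
    """
    f = open(path, "rb")  # hypothetical local .arrow file
    try:
        # Succeeds for data written with pa.ipc.new_stream.
        return pa.ipc.open_stream(f)
    except (OSError, pa.lib.ArrowInvalid):
        # File-format payloads keep their footer at the end of the file;
        # open_file locates it itself, so no explicit rewind is needed.
        return pa.ipc.open_file(f)
```

The same pattern appears in both hunks: `_split_generators` uses the reader only to infer `self.info.features` from the schema, while `_generate_tables` iterates over the record batches (via `reader.get_batch(i)` in the file-format branch).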
2 changes: 1 addition & 1 deletion src/datasets/search.py
@@ -175,7 +175,7 @@ def passage_generator():
successes += ok
if successes != len(documents):
logger.warning(
f"Some documents failed to be added to ElasticSearch. Failures: {len(documents)-successes}/{len(documents)}"
f"Some documents failed to be added to ElasticSearch. Failures: {len(documents) - successes}/{len(documents)}"
)
logger.info(f"Indexed {successes:d} documents")

2 changes: 1 addition & 1 deletion src/datasets/splits.py
@@ -478,7 +478,7 @@ class SplitReadInstruction:
"""

def __init__(self, split_info=None):
- self._splits = NonMutableDict(error_msg="Overlap between splits. Split {key} has been added with " "itself.")
+ self._splits = NonMutableDict(error_msg="Overlap between splits. Split {key} has been added with itself.")

if split_info:
self.add(SlicedSplitInfo(split_info=split_info, slice_value=None))
3 changes: 1 addition & 2 deletions src/datasets/utils/logging.py
@@ -57,8 +57,7 @@ def _get_default_logging_level():
return log_levels[env_level_str]
else:
logging.getLogger().warning(
f"Unknown option DATASETS_VERBOSITY={env_level_str}, "
f"has to be one of: { ', '.join(log_levels.keys()) }"
f"Unknown option DATASETS_VERBOSITY={env_level_str}, has to be one of: {', '.join(log_levels.keys())}"
)
return _default_log_level

4 changes: 2 additions & 2 deletions src/datasets/utils/stratify.py
@@ -81,11 +81,11 @@ def stratified_shuffle_split_generate_indices(y, n_train, n_test, rng, n_splits=
raise ValueError("Minimum class count error")
if n_train < n_classes:
raise ValueError(
"The train_size = %d should be greater or " "equal to the number of classes = %d" % (n_train, n_classes)
"The train_size = %d should be greater or equal to the number of classes = %d" % (n_train, n_classes)
)
if n_test < n_classes:
raise ValueError(
"The test_size = %d should be greater or " "equal to the number of classes = %d" % (n_test, n_classes)
"The test_size = %d should be greater or equal to the number of classes = %d" % (n_test, n_classes)
)
class_indices = np.split(np.argsort(y_indices, kind="mergesort"), np.cumsum(class_counts)[:-1])
for _ in range(n_splits):
24 changes: 12 additions & 12 deletions tests/test_iterable_dataset.py
@@ -1093,9 +1093,9 @@ def test_skip_examples_iterable():
skip_ex_iterable = SkipExamplesIterable(base_ex_iterable, n=count)
expected = list(generate_examples_fn(n=total))[count:]
assert list(skip_ex_iterable) == expected
- assert (
-     skip_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is skip_ex_iterable
- ), "skip examples makes the shards order fixed"
+ assert skip_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is skip_ex_iterable, (
+     "skip examples makes the shards order fixed"
+ )
assert_load_state_dict_resumes_iteration(skip_ex_iterable)


@@ -1105,9 +1105,9 @@ def test_take_examples_iterable():
take_ex_iterable = TakeExamplesIterable(base_ex_iterable, n=count)
expected = list(generate_examples_fn(n=total))[:count]
assert list(take_ex_iterable) == expected
- assert (
-     take_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is take_ex_iterable
- ), "skip examples makes the shards order fixed"
+ assert take_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is take_ex_iterable, (
+     "skip examples makes the shards order fixed"
+ )
assert_load_state_dict_resumes_iteration(take_ex_iterable)


@@ -1155,9 +1155,9 @@ def test_horizontally_concatenated_examples_iterable():
concatenated_ex_iterable = HorizontallyConcatenatedMultiSourcesExamplesIterable([ex_iterable1, ex_iterable2])
expected = [{**x, **y} for (_, x), (_, y) in zip(ex_iterable1, ex_iterable2)]
assert [x for _, x in concatenated_ex_iterable] == expected
- assert (
-     concatenated_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is concatenated_ex_iterable
- ), "horizontally concatenated examples makes the shards order fixed"
+ assert concatenated_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is concatenated_ex_iterable, (
+     "horizontally concatenated examples makes the shards order fixed"
+ )
assert_load_state_dict_resumes_iteration(concatenated_ex_iterable)


@@ -2217,7 +2217,7 @@ def test_iterable_dataset_batch():
assert len(batch["id"]) == 3
assert len(batch["text"]) == 3
assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
assert batch["text"] == [f"Text {3*i}", f"Text {3*i+1}", f"Text {3*i+2}"]
assert batch["text"] == [f"Text {3 * i}", f"Text {3 * i + 1}", f"Text {3 * i + 2}"]

# Check last partial batch
assert len(batches[3]["id"]) == 1
@@ -2234,7 +2234,7 @@ def test_iterable_dataset_batch():
assert len(batch["id"]) == 3
assert len(batch["text"]) == 3
assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
assert batch["text"] == [f"Text {3*i}", f"Text {3*i+1}", f"Text {3*i+2}"]
assert batch["text"] == [f"Text {3 * i}", f"Text {3 * i + 1}", f"Text {3 * i + 2}"]

# Test with batch_size=4 (doesn't evenly divide dataset size)
batched_ds = ds.batch(batch_size=4, drop_last_batch=False)
@@ -2245,7 +2245,7 @@ def test_iterable_dataset_batch():
assert len(batch["id"]) == 4
assert len(batch["text"]) == 4
assert batch["id"] == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3]
assert batch["text"] == [f"Text {4*i}", f"Text {4*i+1}", f"Text {4*i+2}", f"Text {4*i+3}"]
assert batch["text"] == [f"Text {4 * i}", f"Text {4 * i + 1}", f"Text {4 * i + 2}", f"Text {4 * i + 3}"]

# Check last partial batch
assert len(batches[2]["id"]) == 2
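As an end-to-end illustration of the code path this PR touches (not part of the PR's test changes), the sketch below writes one `.arrow` file per IPC format with pyarrow and loads both through the packaged `arrow` builder; the file names are hypothetical:

```python
import pyarrow as pa
from datasets import load_dataset

table = pa.table({"id": [0, 1, 2], "text": ["Text 0", "Text 1", "Text 2"]})

# IPC stream format: handled directly by pa.ipc.open_stream.
with pa.OSFile("stream.arrow", "wb") as sink, pa.ipc.new_stream(sink, table.schema) as writer:
    writer.write_table(table)

# IPC file format: rejected by open_stream, handled by the open_file fallback.
with pa.OSFile("file.arrow", "wb") as sink, pa.ipc.new_file(sink, table.schema) as writer:
    writer.write_table(table)

for path in ("stream.arrow", "file.arrow"):
    ds = load_dataset("arrow", data_files=path, split="train")
    assert ds.column_names == ["id", "text"]
```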