diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 46d88cacb04..a732942a7ee 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -3352,15 +3352,12 @@ def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_example validate_function_output(processed_inputs, indices) if not update_data: return None # Nothing to update, let's move on - if shard._format_type or input_columns: - # TODO(QL, MS): ideally the behavior should be the same even if the dataset is formatted (may require major release) - inputs_to_merge = dict(zip(pa_inputs.column_names, pa_inputs.itercolumns())) - elif isinstance(inputs, LazyDict): + if isinstance(inputs, LazyDict): inputs_to_merge = { k: (v if k not in inputs.keys_to_format else pa_inputs[k]) for k, v in inputs.data.items() } else: - inputs_to_merge = inputs + inputs_to_merge = dict(zip(pa_inputs.column_names, pa_inputs.itercolumns())) if remove_columns is not None: for column in remove_columns: # `function` can modify input in-place causing column to be already removed.