diff --git a/scripts/speech_recognition/convert_to_tarred_audio_dataset.py b/scripts/speech_recognition/convert_to_tarred_audio_dataset.py index c3b5cef57cbc..b722a79327b9 100644 --- a/scripts/speech_recognition/convert_to_tarred_audio_dataset.py +++ b/scripts/speech_recognition/convert_to_tarred_audio_dataset.py @@ -91,6 +91,7 @@ import soundfile from joblib import Parallel, delayed from omegaconf import DictConfig, OmegaConf, open_dict +from tqdm import tqdm try: import create_dali_tarred_dataset_index as dali_index @@ -99,117 +100,6 @@ except (ImportError, ModuleNotFoundError, FileNotFoundError): DALI_INDEX_SCRIPT_AVAILABLE = False -parser = argparse.ArgumentParser( - description="Convert an existing ASR dataset to tarballs compatible with TarredAudioToTextDataLayer." -) -parser.add_argument( - "--manifest_path", default=None, type=str, required=False, help="Path to the existing dataset's manifest." -) - -parser.add_argument( - '--concat_manifest_paths', - nargs='+', - default=None, - type=str, - required=False, - help="Path to the additional dataset's manifests that will be concatenated with base dataset.", -) - -# Optional arguments -parser.add_argument( - "--target_dir", - default='./tarred', - type=str, - help="Target directory for resulting tarballs and manifest. Defaults to `./tarred`. Creates the path if necessary.", -) - -parser.add_argument( - "--metadata_path", - required=False, - default=None, - type=str, - help="Path to metadata file for the dataset.", -) - -parser.add_argument( - "--num_shards", - default=-1, - type=int, - help="Number of shards (tarballs) to create. Used for partitioning data among workers.", -) -parser.add_argument( - '--max_duration', - default=None, - required=True, - type=float, - help='Maximum duration of audio clip in the dataset. By default, it is None and is required to be set.', -) -parser.add_argument( - '--min_duration', - default=None, - type=float, - help='Minimum duration of audio clip in the dataset. By default, it is None and will not filter files.', -) -parser.add_argument( - "--shuffle", - action='store_true', - help="Whether or not to randomly shuffle the samples in the manifest before tarring/sharding.", -) - -parser.add_argument( - "--keep_files_together", - action='store_true', - help="Whether or not to keep entries from the same file (but different offsets) together when sorting before tarring/sharding.", -) - -parser.add_argument( - "--sort_in_shards", - action='store_true', - help="Whether or not to sort samples inside the shards based on their duration.", -) - -parser.add_argument( - "--buckets_num", - type=int, - default=1, - help="Number of buckets to create based on duration.", -) - -parser.add_argument( - "--dynamic_buckets_num", - type=int, - default=30, - help="Intended for dynamic (on-the-fly) bucketing; this option will not bucket your dataset during tar conversion. " - "Estimates optimal bucket duration bins for a given number of buckets.", -) - -parser.add_argument("--shuffle_seed", type=int, default=None, help="Random seed for use if shuffling is enabled.") -parser.add_argument( - '--write_metadata', - action='store_true', - help=( - "Flag to write a blank metadata with the current call config. " - "Note that the metadata will not contain the number of shards, " - "and it must be filled out by the user." 
- ), -) -parser.add_argument( - "--no_shard_manifests", - action='store_true', - help="Do not write sharded manifests along with the aggregated manifest.", -) -parser.add_argument( - "--force_codec", - type=str, - default=None, - help=( - "If specified, transcode the audio to the given format. " - "Supports libnsndfile formats (example values: 'opus', 'flac')." - ), -) -parser.add_argument('--workers', type=int, default=1, help='Number of worker processes') -args = parser.parse_args() - @dataclass class ASRTarredDatasetConfig: @@ -219,6 +109,7 @@ class ASRTarredDatasetConfig: min_duration: Optional[float] = None shuffle_seed: Optional[int] = None sort_in_shards: bool = True + slice_with_offset: bool = True shard_manifests: bool = True keep_files_together: bool = False force_codec: Optional[str] = None @@ -277,19 +168,39 @@ def configure(self, config: ASRTarredDatasetConfig): if self.config.num_shards < 0: raise ValueError("`num_shards` must be > 0. Please fill in the metadata information correctly.") - def create_new_dataset(self, manifest_path: str, target_dir: str = "./tarred/", num_workers: int = 0): + def create_new_dataset( + self, + manifest_path: str, + target_dir: str = "./tarred/", + num_workers: int = 0, + buckets_num: int = 1, + dynamic_buckets_num: int = 30, + dry_run: bool = False, + ): """ Creates a new tarred dataset from a given manifest file. Args: - manifest_path: Path to the original ASR manifest. - target_dir: Output directory. - num_workers: Integer denoting number of parallel worker processes which will write tarfiles. - Defaults to 1 - which denotes sequential worker process. + manifest_path (str): Path to the original ASR manifest file. + target_dir (str, optional): Output directory where tarred files and manifests will be saved. Defaults to "./tarred/". + num_workers (int, optional): Number of parallel worker processes for writing tar files. Defaults to 0 (sequential processing). + buckets_num (int, optional): Number of buckets for static bucketing. Defaults to 1 (no bucketing). + dynamic_buckets_num (int, optional): Number of buckets to estimate for dynamic bucketing. Defaults to 30. + dry_run (bool, optional): If True, performs a dry run without creating actual tar files. Defaults to False. + + Raises: + ValueError: If the configuration has not been set. + FileNotFoundError: If the manifest file does not exist. Output: - Writes tarfiles, along with the tarred dataset compatible manifest file. - Also preserves a record of the metadata used to construct this tarred dataset. + - Creates tar files and a tarred dataset compatible manifest file in the specified `target_dir`. + - Preserves a record of the metadata used to construct the tarred dataset in `metadata.yaml`. + - Optionally creates shard manifests if `config.shard_manifests` is enabled. + + Notes: + - The function reads the manifest, applies filtering and shuffling if specified, and creates shards of tar files. + - It generates shard manifests and the main tarred dataset manifest. + - Metadata is updated and saved based on the tarred dataset configuration. """ if self.config is None: raise ValueError("Config has not been set. 
Please call `configure(config: ASRTarredDatasetConfig)`") @@ -357,7 +268,7 @@ def create_new_dataset(self, manifest_path: str, target_dir: str = "./tarred/", with Parallel(n_jobs=num_workers, verbose=config.num_shards) as parallel: # Call parallel tarfile construction new_entries_list = parallel( - delayed(self._create_shard)(entries[start_idx:end_idx], target_dir, i, manifest_folder) + delayed(self._create_shard)(entries[start_idx:end_idx], target_dir, i, manifest_folder, dry_run) for i, (start_idx, end_idx) in enumerate(zip(start_indices, end_indices)) ) @@ -395,10 +306,10 @@ def create_new_dataset(self, manifest_path: str, target_dir: str = "./tarred/", metadata.dataset_config = config metadata.num_samples_per_shard = len(new_entries) // config.num_shards - if args.buckets_num <= 1: + if buckets_num <= 1: # Estimate and update dynamic bucketing args bucketing_kwargs = self.estimate_dynamic_bucketing_duration_bins( - new_manifest_path, num_buckets=args.dynamic_buckets_num + new_manifest_path, num_buckets=dynamic_buckets_num ) for k, v in bucketing_kwargs.items(): setattr(metadata.dataset_config, k, v) @@ -410,6 +321,7 @@ def create_new_dataset(self, manifest_path: str, target_dir: str = "./tarred/", def estimate_dynamic_bucketing_duration_bins(self, manifest_path: str, num_buckets: int = 30) -> dict: from lhotse import CutSet from lhotse.dataset.sampling.dynamic_bucketing import estimate_duration_buckets + from nemo.collections.common.data.lhotse.nemo_adapters import LazyNeMoIterator cuts = CutSet(LazyNeMoIterator(manifest_path, metadata_only=True)) @@ -439,25 +351,33 @@ def create_concatenated_dataset( metadata: ASRTarredDatasetMetadata, target_dir: str = "./tarred_concatenated/", num_workers: int = 1, + dry_run: bool = False, ): """ - Creates new tarfiles in order to create a concatenated dataset, whose manifest contains the data for - both the original dataset as well as the new data submitted in manifest paths. + Creates a concatenated tarred dataset from the base manifest and additional manifest files. Args: - base_manifest_path: Path to the manifest file which contains the information for the original + base_manifest_path (str): Path to the base manifest file that contains information for the original tarred dataset (with flattened paths). - manifest_paths: List of one or more paths to manifest files that will be concatenated with above - base tarred dataset. - metadata: ASRTarredDatasetMetadata dataclass instance with overrides from command line. - target_dir: Output directory + manifest_paths (List[str]): List of paths to additional manifest files that will be concatenated with + the base tarred dataset. + metadata (ASRTarredDatasetMetadata): Metadata instance containing configuration and overrides. + target_dir (str, optional): Output directory where tarred files and manifests will be saved. Defaults to "./tarred_concatenated/". + num_workers (int, optional): Number of parallel worker processes for creating tar files. Defaults to 1. + dry_run (bool, optional): If True, performs a dry run without creating actual tar files. Defaults to False. - Output: - Writes tarfiles which with indices mapping to a "concatenated" tarred dataset, - along with the tarred dataset compatible manifest file which includes information - about all the datasets that comprise the concatenated dataset. + Raises: + FileNotFoundError: If the base manifest file or any of the additional manifest files does not exist. - Also preserves a record of the metadata used to construct this tarred dataset. 
+ Output: + - Creates tar files and a concatenated tarred dataset compatible manifest file in the specified `target_dir`. + - Updates metadata to reflect the concatenated dataset, including the version and historical data. + + Notes: + - The function reads the base manifest and additional manifests, filters and shuffles entries as needed, + and creates new shards of tar files. + - It generates a new concatenated dataset manifest and updates metadata with versioning and historical context. + - If `metadata` is provided, the function updates its version and includes historical data in the new metadata. """ if not os.path.exists(target_dir): os.makedirs(target_dir) @@ -548,7 +468,9 @@ def create_concatenated_dataset( with Parallel(n_jobs=num_workers, verbose=num_added_shards) as parallel: # Call parallel tarfile construction new_entries_list = parallel( - delayed(self._create_shard)(entries[start_idx:end_idx], target_dir, shard_idx, manifest_folder) + delayed(self._create_shard)( + entries[start_idx:end_idx], target_dir, shard_idx, manifest_folder, dry_run + ) for i, (start_idx, end_idx, shard_idx) in enumerate(zip(start_indices, end_indices, shard_indices)) ) @@ -625,6 +547,10 @@ def _read_manifest(self, manifest_path: str, config: ASRTarredDatasetConfig): for line in m: entry = json.loads(line) audio_key = "audio_filepath" if "audio_filepath" in entry else "audio_file" + if config.slice_with_offset and "offset" not in entry: + raise KeyError( + f"Manifest entry does not contain 'offset' field, but '--slice_with_offset' is enabled: {entry}" + ) if audio_key not in entry: raise KeyError(f"Manifest entry does not contain 'audio_filepath' or 'audio_file' key: {entry}") audio_filepath = entry[audio_key] @@ -644,17 +570,28 @@ def _read_manifest(self, manifest_path: str, config: ASRTarredDatasetConfig): return entries, total_duration, filtered_entries, filtered_duration - def _write_to_tar(self, tar, audio_filepath: str, squashed_filename: str) -> None: - if (codec := self.config.force_codec) is None or audio_filepath.endswith(f".{codec}"): + def _write_to_tar( + self, tar, audio_filepath: str, squashed_filename: str, duration: float = None, offset: float = 0 + ) -> None: + if ((codec := self.config.force_codec) is None or audio_filepath.endswith(f".{codec}")) and not duration: # Add existing file without transcoding. tar.add(audio_filepath, arcname=squashed_filename) else: - # Transcode to the desired format in-memory and add the result to the tar file. 
+ # Read audio file audio, sampling_rate = soundfile.read(audio_filepath, dtype=np.float32) + # Calculate start and end points for slicing + start_sample = int(offset * sampling_rate) + end_sample = int((offset + duration) * sampling_rate) if duration else None + audio = audio[start_sample:end_sample] + # Transcode and write to tar encoded_audio = BytesIO() - if codec == "opus": - kwargs = {"format": "ogg", "subtype": "opus"} + if codec is not None: + if codec == "opus": + kwargs = {"format": "ogg", "subtype": "opus"} + else: + kwargs = {"format": codec} else: + codec = soundfile.info(audio_filepath).format.lower() kwargs = {"format": codec} soundfile.write(encoded_audio, audio, sampling_rate, closefd=False, **kwargs) encoded_squashed_filename = f"{squashed_filename.split('.')[0]}.{codec}" @@ -663,20 +600,26 @@ def _write_to_tar(self, tar, audio_filepath: str, squashed_filename: str) -> Non ti.size = len(encoded_audio.getvalue()) tar.addfile(ti, encoded_audio) - def _create_shard(self, entries, target_dir, shard_id, manifest_folder): + def _create_shard(self, entries, target_dir, shard_id, manifest_folder: str = None, dry_run: bool = False): """Creates a tarball containing the audio files from `entries`.""" if self.config.sort_in_shards: entries.sort(key=lambda x: x["duration"], reverse=False) new_entries = [] - tar = tarfile.open(os.path.join(target_dir, f'audio_{shard_id}.tar'), mode='w', dereference=True) + + tar_filepath = os.path.join(target_dir, f'audio_{shard_id}.tar') + if not dry_run: + tar = tarfile.open(tar_filepath, mode='w', dereference=True) count = dict() - for entry in entries: + for entry in tqdm(entries, desc="Creating shard.."): # We squash the filename since we do not preserve directory structure of audio files in the tarball. if os.path.exists(entry["audio_filepath"]): audio_filepath = entry["audio_filepath"] else: + if not manifest_folder: + raise FileNotFoundError(f"Could not find {entry['audio_filepath']}!") + audio_filepath = os.path.join(manifest_folder, entry["audio_filepath"]) if not os.path.exists(audio_filepath): raise FileNotFoundError(f"Could not find {entry['audio_filepath']}!") @@ -686,14 +629,33 @@ def _create_shard(self, entries, target_dir, shard_id, manifest_folder): # Need the following replacement as long as WebDataset splits on first period base = base.replace('.', '_') squashed_filename = f'{base}{ext}' - if squashed_filename not in count: - self._write_to_tar(tar, audio_filepath, squashed_filename) - to_write = squashed_filename - count[squashed_filename] = 1 - else: - to_write = base + "-sub" + str(count[squashed_filename]) + ext + + if self.config.slice_with_offset: + if squashed_filename not in count: + count[squashed_filename] = 1 + + to_write = base + "_" + str(count[squashed_filename]) + ext + if not dry_run: + self._write_to_tar( + tar, audio_filepath, to_write, duration=entry['duration'], offset=entry['offset'] + ) count[squashed_filename] += 1 + entry['source_audio_offset'] = entry['offset'] + del entry['offset'] + else: + if squashed_filename not in count: + if not dry_run: + self._write_to_tar(tar, audio_filepath, squashed_filename) + to_write = squashed_filename + count[squashed_filename] = 1 + else: + to_write = base + "-sub" + str(count[squashed_filename]) + ext + count[squashed_filename] += 1 + + if dry_run: + entry['abs_audio_filepath'] = audio_filepath + # Carry over every key in the entry, override audio_filepath and shard_id new_entry = { **entry, @@ -702,7 +664,8 @@ def _create_shard(self, entries, target_dir, shard_id, 
manifest_folder): } new_entries.append(new_entry) - tar.close() + if not dry_run: + tar.close() return new_entries @classmethod @@ -718,41 +681,65 @@ def setup_history(cls, base_metadata: ASRTarredDatasetMetadata, history: List[An history.append(metadata_copy) -def main(): +def main(args): if args.buckets_num > 1: bucket_length = (args.max_duration - args.min_duration) / float(args.buckets_num) - for i in range(args.buckets_num): - min_duration = args.min_duration + i * bucket_length - max_duration = min_duration + bucket_length - if i == args.buckets_num - 1: + for i_bucket in range(args.buckets_num): + bucket_config = copy.deepcopy(args) + bucket_config.min_duration = args.min_duration + i_bucket * bucket_length + bucket_config.max_duration = bucket_config.min_duration + bucket_length + if i_bucket == args.buckets_num - 1: # add a small number to cover the samples with exactly duration of max_duration in the last bucket. - max_duration += 1e-5 - target_dir = os.path.join(args.target_dir, f"bucket{i+1}") - print(f"Creating bucket {i+1} with min_duration={min_duration} and max_duration={max_duration} ...") - print(f"Results are being saved at: {target_dir}.") - create_tar_datasets(min_duration=min_duration, max_duration=max_duration, target_dir=target_dir) - print(f"Bucket {i+1} is created.") + bucket_config.max_duration += 1e-5 + bucket_config.target_dir = os.path.join(args.target_dir, f"bucket{i_bucket+1}") + print( + f"Creating bucket {i_bucket+1} with min_duration={bucket_config.min_duration} and max_duration={bucket_config.max_duration} ..." + ) + print(f"Results are being saved at: {bucket_config.target_dir}.") + create_tar_datasets(**vars(bucket_config)) + print(f"Bucket {i_bucket+1} is created.") else: - create_tar_datasets(min_duration=args.min_duration, max_duration=args.max_duration, target_dir=args.target_dir) - - -def create_tar_datasets(min_duration: float, max_duration: float, target_dir: str): + create_tar_datasets(**vars(args)) + + +def create_tar_datasets( + manifest_path: str = None, + concat_manifest_paths: str = None, + target_dir: str = None, + metadata_path: str = None, + num_shards: int = -1, + max_duration: float = None, + min_duration: float = None, + shuffle: bool = False, + keep_files_together: bool = False, + sort_in_shards: bool = False, + buckets_num: int = 1, + dynamic_buckets_num: int = 30, + shuffle_seed: int = None, + write_metadata: bool = False, + no_shard_manifests: bool = False, + force_codec: str = None, + workers: int = 1, + slice_with_offset: bool = False, + dry_run: bool = False, +): builder = ASRTarredDatasetBuilder() - shard_manifests = False if args.no_shard_manifests else True + shard_manifests = False if no_shard_manifests else True - if args.write_metadata: + if write_metadata: metadata = ASRTarredDatasetMetadata() dataset_cfg = ASRTarredDatasetConfig( - num_shards=args.num_shards, - shuffle=args.shuffle, + num_shards=num_shards, + shuffle=shuffle, max_duration=max_duration, min_duration=min_duration, - shuffle_seed=args.shuffle_seed, - sort_in_shards=args.sort_in_shards, + shuffle_seed=shuffle_seed, + sort_in_shards=sort_in_shards, shard_manifests=shard_manifests, - keep_files_together=args.keep_files_together, - force_codec=args.force_codec, + keep_files_together=keep_files_together, + force_codec=force_codec, + slice_with_offset=slice_with_offset, ) metadata.dataset_config = dataset_cfg @@ -761,32 +748,40 @@ def create_tar_datasets(min_duration: float, max_duration: float, target_dir: st print(f"Default metadata written to {output_path}") 
         exit(0)
 
-    if args.concat_manifest_paths is None or len(args.concat_manifest_paths) == 0:
+    if concat_manifest_paths is None or len(concat_manifest_paths) == 0:
         print("Creating new tarred dataset ...")
 
         # Create a tarred dataset from scratch
         config = ASRTarredDatasetConfig(
-            num_shards=args.num_shards,
-            shuffle=args.shuffle,
+            num_shards=num_shards,
+            shuffle=shuffle,
             max_duration=max_duration,
             min_duration=min_duration,
-            shuffle_seed=args.shuffle_seed,
-            sort_in_shards=args.sort_in_shards,
+            shuffle_seed=shuffle_seed,
+            sort_in_shards=sort_in_shards,
             shard_manifests=shard_manifests,
-            keep_files_together=args.keep_files_together,
-            force_codec=args.force_codec,
+            keep_files_together=keep_files_together,
+            force_codec=force_codec,
+            slice_with_offset=slice_with_offset,
         )
         builder.configure(config)
-        builder.create_new_dataset(manifest_path=args.manifest_path, target_dir=target_dir, num_workers=args.workers)
+        builder.create_new_dataset(
+            manifest_path=manifest_path,
+            target_dir=target_dir,
+            num_workers=workers,
+            buckets_num=buckets_num,
+            dynamic_buckets_num=dynamic_buckets_num,
+            dry_run=dry_run,
+        )
 
     else:
-        if args.buckets_num > 1:
+        if buckets_num > 1:
             raise ValueError("Concatenation feature does not support buckets_num > 1.")
 
         print("Concatenating multiple tarred datasets ...")
 
         # Implicitly update config from base details
-        if args.metadata_path is not None:
-            metadata = ASRTarredDatasetMetadata.from_file(args.metadata_path)
+        if metadata_path is not None:
+            metadata = ASRTarredDatasetMetadata.from_file(metadata_path)
         else:
             raise ValueError("`metadata` yaml file path must be provided!")
 
@@ -798,27 +793,151 @@ def create_tar_datasets(min_duration: float, max_duration: float, target_dir: st
 
         # Add command line overrides (everything other than num_shards)
         metadata.dataset_config.max_duration = max_duration
         metadata.dataset_config.min_duration = min_duration
-        metadata.dataset_config.shuffle = args.shuffle
-        metadata.dataset_config.shuffle_seed = args.shuffle_seed
-        metadata.dataset_config.sort_in_shards = args.sort_in_shards
+        metadata.dataset_config.shuffle = shuffle
+        metadata.dataset_config.shuffle_seed = shuffle_seed
+        metadata.dataset_config.sort_in_shards = sort_in_shards
+        metadata.dataset_config.slice_with_offset = slice_with_offset
         metadata.dataset_config.shard_manifests = shard_manifests
 
         builder.configure(metadata.dataset_config)
 
         # Concatenate a tarred dataset onto a previous one
         builder.create_concatenated_dataset(
-            base_manifest_path=args.manifest_path,
-            manifest_paths=args.concat_manifest_paths,
+            base_manifest_path=manifest_path,
+            manifest_paths=concat_manifest_paths,
             metadata=metadata,
             target_dir=target_dir,
-            num_workers=args.workers,
+            num_workers=workers,
+            dry_run=dry_run,
         )
 
     if DALI_INDEX_SCRIPT_AVAILABLE and dali_index.INDEX_CREATOR_AVAILABLE:
         print("Constructing DALI Tarfile Index - ", target_dir)
-        index_config = dali_index.DALITarredIndexConfig(tar_dir=target_dir, workers=args.workers)
+        index_config = dali_index.DALITarredIndexConfig(tar_dir=target_dir, workers=workers)
         dali_index.main(index_config)
 
 
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser(
+        description="Convert an existing ASR dataset to tarballs compatible with TarredAudioToTextDataLayer."
+    )
+    parser.add_argument(
+        "--manifest_path", default=None, type=str, required=False, help="Path to the existing dataset's manifest."
+ ) + + parser.add_argument( + '--concat_manifest_paths', + nargs='+', + default=None, + type=str, + required=False, + help="Path to the additional dataset's manifests that will be concatenated with base dataset.", + ) + + # Optional arguments + parser.add_argument( + "--target_dir", + default='./tarred', + type=str, + help="Target directory for resulting tarballs and manifest. Defaults to `./tarred`. Creates the path if necessary.", + ) + + parser.add_argument( + "--metadata_path", + required=False, + default=None, + type=str, + help="Path to metadata file for the dataset.", + ) + + parser.add_argument( + "--num_shards", + default=-1, + type=int, + help="Number of shards (tarballs) to create. Used for partitioning data among workers.", + ) + parser.add_argument( + '--max_duration', + default=None, + required=True, + type=float, + help='Maximum duration of audio clip in the dataset. By default, it is None and is required to be set.', + ) + parser.add_argument( + '--min_duration', + default=None, + type=float, + help='Minimum duration of audio clip in the dataset. By default, it is None and will not filter files.', + ) + parser.add_argument( + "--shuffle", + action='store_true', + help="Whether or not to randomly shuffle the samples in the manifest before tarring/sharding.", + ) + + parser.add_argument( + "--keep_files_together", + action='store_true', + help="Whether or not to keep entries from the same file (but different offsets) together when sorting before tarring/sharding.", + ) + parser.add_argument( + "--slice_with_offset", + action='store_true', + help="If set, only slices the audio based on `duration` and `offset` entry parameters.", + ) + parser.add_argument( + "--sort_in_shards", + action='store_true', + help="Whether or not to sort samples inside the shards based on their duration.", + ) + + parser.add_argument( + "--buckets_num", + type=int, + default=1, + help="Number of buckets to create based on duration.", + ) + + parser.add_argument( + "--dynamic_buckets_num", + type=int, + default=30, + help="Intended for dynamic (on-the-fly) bucketing; this option will not bucket your dataset during tar conversion. " + "Estimates optimal bucket duration bins for a given number of buckets.", + ) + + parser.add_argument("--shuffle_seed", type=int, default=None, help="Random seed for use if shuffling is enabled.") + parser.add_argument( + '--write_metadata', + action='store_true', + help=( + "Flag to write a blank metadata with the current call config. " + "Note that the metadata will not contain the number of shards, " + "and it must be filled out by the user." + ), + ) + parser.add_argument( + "--no_shard_manifests", + action='store_true', + help="Do not write sharded manifests along with the aggregated manifest.", + ) + parser.add_argument( + "--force_codec", + type=str, + default=None, + help=( + "If specified, transcode the audio to the given format. " + "Supports libnsndfile formats (example values: 'opus', 'flac')." + ), + ) + parser.add_argument( + "--dry_run", + action='store_true', + help=( + "If set, only creates manifests for each shard without creating the actual tar files. " + "This allows you to verify the output structure and content before committing to the full tarball creation process." 
+        ),
+    )
+    parser.add_argument('--workers', type=int, default=1, help='Number of worker processes')
+    args = parser.parse_args()
+    main(args)
diff --git a/scripts/speech_recognition/partial_convertion_to_tarred_audio_dataset.py b/scripts/speech_recognition/partial_convertion_to_tarred_audio_dataset.py
new file mode 100644
index 000000000000..d1f774f87ae0
--- /dev/null
+++ b/scripts/speech_recognition/partial_convertion_to_tarred_audio_dataset.py
@@ -0,0 +1,195 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from dataclasses import dataclass, field
+from typing import Optional
+
+import hydra
+from convert_to_tarred_audio_dataset import ASRTarredDatasetBuilder, ASRTarredDatasetMetadata
+from hydra.core.config_store import ConfigStore
+from joblib import Parallel, delayed
+from omegaconf import MISSING
+from tqdm import tqdm
+
+"""
+# Partial Tarred Audio Dataset Creator
+
+## Overview
+
+This script facilitates the creation of tarred and sharded audio datasets from existing tarred manifests. It allows you to select specific shards from a manifest file and then tar them separately. This is useful when you need to process a subset of shards or manage tarred datasets in chunks.
+
+## Prerequisites
+
+- Ensure that the `convert_to_tarred_audio_dataset` script is correctly configured and run with the `--dry_run` flag to generate the necessary manifest files.
+- Make sure the paths to the manifest and metadata files are correct and accessible.
+
+## Usage
+
+### Script Execution
+
+To run the script, use the following command. The script is configured via Hydra, so options are passed as `key=value` pairs; only `tarred_manifest_filepath` is mandatory:
+
+python partial_convertion_to_tarred_audio_dataset.py \
+    tarred_manifest_filepath=<path to the tarred manifest that contains the entries for the shards you want to process> \
+    output_dir=<optional output directory; defaults to the manifest's directory> \
+    shards_to_tar=<optional shard selection, e.g. "2", "0:3", or "all"> \
+    num_workers=-1 \
+    dataset_metadata_filepath=<optional path to metadata.yaml; inferred from the manifest path by default>
+
+Example:
+python partial_convertion_to_tarred_audio_dataset.py \
+    tarred_manifest_filepath="path/to/manifest.json" \
+    shards_to_tar="0:3"
+"""
+
+
+def select_shards(manifest_filepath: str, shards_to_tar: str, slice_with_offset: bool = False):
+    """
+    Selects and returns a subset of shards from the tarred manifest file.
+
+    Args:
+        manifest_filepath (str): The path to the tarred manifest file.
+        shards_to_tar (str): A single shard ID (e.g., "2"), a range of shard IDs (e.g., "0:5"), or "all".
+        slice_with_offset (bool, optional): If True, slices entries based on audio offsets. Defaults to False.
+
+    Raises:
+        FileNotFoundError: If the manifest file does not exist.
+        KeyError: If `slice_with_offset` is enabled but required fields are missing in the manifest entries.
+
+    Returns:
+        Dict[int, List[Dict[str, Any]]]: A dictionary where the keys are shard IDs and the values are lists of entries for those shards.
+    """
+    shard_ids = []
+    if shards_to_tar != "all":
+        if ":" not in shards_to_tar:
+            shard_ids = [int(shards_to_tar)]
+        else:
+            start_shard_idx, end_shard_idx = map(
+                lambda x: int(x.strip()) if x.strip() else None, shards_to_tar.split(":")
+            )
+            shard_ids = list(range(start_shard_idx, end_shard_idx))
+
+    entries_to_shard = {}
+    with open(manifest_filepath, 'r') as manifest:
+        for line in tqdm(manifest, desc="Selecting shards"):
+            entry = json.loads(line)
+            if shards_to_tar == "all" or entry['shard_id'] in shard_ids:
+                if entry['shard_id'] not in entries_to_shard:
+                    entries_to_shard[entry['shard_id']] = []
+
+                if slice_with_offset:
+                    if 'abs_audio_filepath' not in entry or 'source_audio_offset' not in entry:
+                        raise KeyError(
+                            f"`slice_with_offset` is enabled, but `abs_audio_filepath` and/or `source_audio_offset` are not found in the entry:\n{entry}."
+                        )
+                    entry['audio_filepath'] = entry['abs_audio_filepath']
+                    entry['offset'] = entry['source_audio_offset']
+
+                entries_to_shard[entry['shard_id']].append(entry)
+
+    return entries_to_shard
+
+
+@dataclass
+class PartialASRTarredDatasetConfig:
+    """
+    Configuration class for creating partial tarred audio dataset shards.
+
+    Attributes:
+        tarred_manifest_filepath (str): The path to the tarred manifest file.
+        output_dir (Optional[str]): Directory where the output tarred shards will be saved.
+        shards_to_tar (Optional[str]): A single shard ID, a range such as "0:5", or "all".
+        num_workers (int): Number of parallel workers to use for tar file creation.
+        dataset_metadata_filepath (Optional[str]): Path to the dataset metadata YAML file.
+        dataset_metadata (ASRTarredDatasetMetadata): Dataset metadata configuration.
+        slice_with_offset (bool): Whether entries are sliced from the source audio using their offset and duration.
+    """
+
+    tarred_manifest_filepath: str = MISSING
+    output_dir: Optional[str] = None
+    shards_to_tar: Optional[str] = "all"
+    num_workers: int = 1
+    dataset_metadata_filepath: Optional[str] = None
+    dataset_metadata: ASRTarredDatasetMetadata = field(default_factory=ASRTarredDatasetMetadata)
+    slice_with_offset: bool = False
+
+
+def create_shards(cfg: PartialASRTarredDatasetConfig):
+    """
+    Creates tarred shards based on the provided configuration.
+
+    Args:
+        cfg (PartialASRTarredDatasetConfig): The configuration object containing paths, shard IDs, and metadata.
+
+    Raises:
+        ValueError: If the `tarred_manifest_filepath` is None.
+        FileNotFoundError: If the tarred manifest file or dataset metadata file does not exist.
+
+    Notes:
+        - Reads the tarred manifest file and selects the specified shards.
+        - Creates tarred shards in parallel using the `ASRTarredDatasetBuilder`.
+        - The `dataset_metadata_filepath` is inferred if not provided.
+    """
+    if cfg.tarred_manifest_filepath is None:
+        raise ValueError("The `tarred_manifest_filepath` cannot be `None`. Please check your configuration.")
+
+    if not os.path.exists(cfg.tarred_manifest_filepath):
+        raise FileNotFoundError(
+            f"The `tarred_manifest_filepath` was not found: {cfg.tarred_manifest_filepath}. Please verify that the filepath is correct."
+        )
+
+    if cfg.dataset_metadata_filepath is None:
+        cfg.dataset_metadata_filepath = os.path.join(os.path.dirname(cfg.tarred_manifest_filepath), "metadata.yaml")
+
+    if cfg.output_dir is None:
+        cfg.output_dir = os.path.dirname(cfg.tarred_manifest_filepath)
+
+    if not os.path.exists(cfg.dataset_metadata_filepath):
+        raise FileNotFoundError(
+            f"The `dataset_metadata_filepath` was not found: {cfg.dataset_metadata_filepath}. Please verify that the filepath is correct."
+ ) + else: + cfg.dataset_metadata = ASRTarredDatasetMetadata.from_file(cfg.dataset_metadata_filepath) + + entries_to_shard = select_shards( + cfg.tarred_manifest_filepath, cfg.shards_to_tar, cfg.dataset_metadata.dataset_config.slice_with_offset + ) + + builder = ASRTarredDatasetBuilder() + builder.configure(cfg.dataset_metadata.dataset_config) + + with Parallel(n_jobs=cfg.num_workers, verbose=len(entries_to_shard)) as parallel: + # Call parallel tarfile construction + _ = parallel( + delayed(builder._create_shard)( + entries=entries_to_shard[shard_id], + target_dir=cfg.output_dir, + shard_id=shard_id, + ) + for shard_id in entries_to_shard + ) + + +@hydra.main(config_path=None, config_name='partial_tar_config') +def main(cfg: PartialASRTarredDatasetConfig): + create_shards(cfg) + + +ConfigStore.instance().store(name='partial_tar_config', node=PartialASRTarredDatasetConfig) + +if __name__ == '__main__': + main()
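Taken together, the two scripts are intended to be used in sequence: a dry run of the converter writes the (sharded) manifests plus metadata.yaml without producing tar files, and the partial script then tars only the shards you select. Below is a minimal sketch, not part of the patch, of how that flow could be driven from Python instead of the CLI; the input manifest, target directory, and the "tarred_audio_manifest.json" output filename are illustrative assumptions.

from convert_to_tarred_audio_dataset import create_tar_datasets
from partial_convertion_to_tarred_audio_dataset import PartialASRTarredDatasetConfig, create_shards

# Step 1: dry run - writes the sharded manifests and metadata.yaml, but no tar files.
# With slice_with_offset=True, every input manifest entry must carry an 'offset' field.
create_tar_datasets(
    manifest_path="manifests/train.json",  # hypothetical input manifest
    target_dir="tarred/train",             # hypothetical output directory
    num_shards=8,
    max_duration=20.0,
    min_duration=0.5,
    shuffle=True,
    shuffle_seed=42,
    workers=4,
    slice_with_offset=True,
    dry_run=True,
)

# Step 2: tar only shards 0-3 using the manifest and metadata produced by the dry run.
cfg = PartialASRTarredDatasetConfig(
    tarred_manifest_filepath="tarred/train/tarred_audio_manifest.json",  # assumed output manifest name
    shards_to_tar="0:4",
    num_workers=4,
)
create_shards(cfg)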