Updated preprocess_text.py to fix NeMo-text-processing :: WARNING :: Your input is too long and could take a long time to normalize. Use split_text_into_sentences() to make the input shorter and then call normalize_list(). #11780

Open · wants to merge 4 commits into main
68 changes: 56 additions & 12 deletions scripts/dataset_processing/tts/preprocess_text.py
@@ -26,7 +26,6 @@
     --num_workers=4 \
     --joblib_batch_size=16
 """
-
 import argparse
 from pathlib import Path
 
@@ -49,27 +48,42 @@
 
 def get_args():
     parser = argparse.ArgumentParser(
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Process and normalize text data.",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+        description="Process and normalize text data.",
     )
     parser.add_argument(
-        "--input_manifest", required=True, type=Path, help="Path to input training manifest.",
+        "--input_manifest",
+        required=True,
+        type=Path,
+        help="Path to input training manifest.",
     )
     parser.add_argument(
-        "--output_manifest", required=True, type=Path, help="Path to output training manifest with processed text.",
+        "--output_manifest",
+        required=True,
+        type=Path,
+        help="Path to output training manifest with processed text.",
     )
     parser.add_argument(
         "--overwrite",
         action=argparse.BooleanOptionalAction,
         help="Whether to overwrite the output manifest file if it exists.",
     )
     parser.add_argument(
-        "--text_key", default="text", type=str, help="Input text field to normalize.",
+        "--text_key",
+        default="text",
+        type=str,
+        help="Input text field to normalize.",
     )
     parser.add_argument(
-        "--normalized_text_key", default="normalized_text", type=str, help="Output field to save normalized text to.",
+        "--normalized_text_key",
+        default="normalized_text",
+        type=str,
+        help="Output field to save normalized text to.",
     )
     parser.add_argument(
-        "--lower_case", action=argparse.BooleanOptionalAction, help="Whether to convert the final text to lower case.",
+        "--lower_case",
+        action=argparse.BooleanOptionalAction,
+        help="Whether to convert the final text to lower case.",
     )
     parser.add_argument(
         "--normalizer_config_path",
@@ -102,14 +116,44 @@ def _process_entry(
     text = entry[text_key]
 
     if normalizer is not None:
-        if lower_case_norm:
-            text = text.lower()
-        text = normalizer.normalize(text, punct_pre_process=True, punct_post_process=True)
+        # Define additional split symbols to enhance splitting
+        additional_split_symbols = ";|:"  # Adjust based on your dataset's characteristics
+
+        # Split text into sentences using additional split symbols
+        sentences = normalizer.split_text_into_sentences(text, additional_split_symbols=additional_split_symbols)
+
+        # Further split sentences longer than 500 words
+        split_sentences = []
+        for sentence in sentences:
+            words = sentence.split()
+            if len(words) > 500:
+                # Split into chunks of 500 words
+                for i in range(0, len(words), 500):
+                    chunk = ' '.join(words[i : i + 500])
+                    split_sentences.append(chunk)
+            else:
+                split_sentences.append(sentence)
+
+        # Log sentences exceeding 500 words (for debugging)
+        for idx, sentence in enumerate(split_sentences):
+            word_count = len(sentence.split())
+            if word_count > 500:
+                print(f"Warning: Sentence {idx} with {word_count} words is still too long.")
+
+        # Normalize each sentence individually
+        normalized_sentences = [
+            normalizer.normalize(sentence, punct_pre_process=True, punct_post_process=True)
+            for sentence in split_sentences
+        ]
+        # Concatenate normalized sentences
+        normalized_text = ' '.join(normalized_sentences)
+    else:
+        normalized_text = text
 
     if lower_case:
-        text = text.lower()
+        normalized_text = normalized_text.lower()
 
-    entry[normalized_text_key] = text
+    entry[normalized_text_key] = normalized_text
 
     return entry
 
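For anyone who wants to try the new splitting logic outside the manifest pipeline, here is a minimal standalone sketch of the same pattern. It assumes nemo_text_processing is installed; the Normalizer constructor arguments (input_case, lang), the helper name normalize_long_text, and the example string are illustrative, while split_text_into_sentences, additional_split_symbols, and the 500-word budget mirror the diff above.

from nemo_text_processing.text_normalization.normalize import Normalizer

# Constructor arguments are assumptions; adjust lang/input_case for your data.
normalizer = Normalizer(input_case="cased", lang="en")

def normalize_long_text(text: str, max_words: int = 500) -> str:
    # Sentence-split first, using the same extra symbols as the diff above.
    sentences = normalizer.split_text_into_sentences(text, additional_split_symbols=";|:")
    # Hard-split any sentence that still exceeds the word budget.
    chunks = []
    for sentence in sentences:
        words = sentence.split()
        if len(words) > max_words:
            chunks.extend(" ".join(words[i : i + max_words]) for i in range(0, len(words), max_words))
        else:
            chunks.append(sentence)
    # Normalize each piece on its own, then rejoin with spaces.
    return " ".join(
        normalizer.normalize(chunk, punct_pre_process=True, punct_post_process=True) for chunk in chunks
    )

print(normalize_long_text("Dr. Smith paid $5.30 on Jan. 2; the meeting ran 3 hrs."))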
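Note that the warning message quoted in the PR title recommends normalize_list() after split_text_into_sentences(), rather than calling normalize() once per sentence as this diff does. A sketch of that variant, assuming the same normalizer as above (keyword arguments to normalize_list differ across nemo_text_processing releases, so none are passed here):

sentences = normalizer.split_text_into_sentences(text, additional_split_symbols=";|:")
normalized_text = " ".join(normalizer.normalize_list(sentences))

Either way the normalizer only ever sees short inputs, which is what the length warning is about.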