Commit

Added c4 badwords filter, added batch tokenization to tokenscounter (huggingface#160)

* added c4 badwords filter, added batch tokenization to tokenscounter

* add exclusion_writer

* handle tokenizers with vocab size > 65k

* add simple dataloader
guipenedo authored Apr 24, 2024
1 parent 6d06210 commit 4e9235f
Showing 10 changed files with 366 additions and 69 deletions.
2 changes: 1 addition & 1 deletion src/datatrove/pipeline/filters/__init__.py
@@ -1,4 +1,4 @@
from .c4_quality_filter import C4ParagraphFilter, C4QualityFilter
from .c4_filters import C4BadWordsFilter, C4ParagraphFilter, C4QualityFilter
from .fasttext_filter import FastTextClassifierFilter
from .fineweb_quality_filter import FineWebQualityFilter
from .gopher_quality_filter import GopherQualityFilter
src/datatrove/pipeline/filters/c4_filters.py
@@ -1,7 +1,10 @@
import heapq
import re

from numpy.random import default_rng

from datatrove.data import Document
from datatrove.io import cached_asset_path_or_download
from datatrove.pipeline.filters.base_filter import BaseFilter
from datatrove.pipeline.writers.disk_base import DiskWriter

@@ -168,3 +171,116 @@ def filter(self, doc: Document) -> bool | tuple[bool, str]:
        if not self.paragraph_filter(doc.text):
            return False, f"< {self.min_paragraphs} paragraphs"
        return True


_EN_BADWORDS_URL = "https://raw.githubusercontent.com/LDNOOBW/List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words/25e679f03d96baa721cde20db9944649e8d0a844/en"
_BADWORDS_URL = "https://raw.githubusercontent.com/LDNOOBW/List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words/5faf2ba42d7b1c0977169ec3611df25a3c08eb13/"
_BADWORDS_LANGS = [
    "ar",
    "cs",
    "da",
    "de",
    "en",
    "eo",
    "es",
    "fa",
    "fi",
    "fil",
    "fr",
    "fr-CA-u-sd-caqc",
    "hi",
    "hu",
    "it",
    "ja",
    "kab",
    "ko",
    "nl",
    "no",
    "pl",
    "pt",
    "ru",
    "sv",
    "th",
    "tlh",
    "tr",
    "zh",
]
# Words that are allowed since they are common subwords in languages without
# spaces. These each filter >10% of documents of their language when disallowed.
_BADWORDS_ALLOWLIST = {"ja": {"sm", "グロ", "女の子"}, "zh": {"性"}}


class C4BadWordsFilter(BaseFilter):
"""
Badwords filter from C4.
Args:
keep_fraction (float): what percentage of pages containing bad words should be kept
fail_on_missing_language (bool) whether to fail when a document has an unknown language
seed (int): used for the uniform distribution generator for use with keep_fraction
default_language (str): what language for samples without language in their metadata
"""

name = "⛰ C4 Badwords"

def __init__(
self,
keep_fraction: float = 0.0,
fail_on_missing_language: bool = True,
seed: int = None,
default_language: str = "en",
exclusion_writer: DiskWriter = None,
):
super().__init__(exclusion_writer)
self.keep_fraction = keep_fraction
self.fail_on_missing_language = fail_on_missing_language
self._badwords_regex: dict[str, re.Pattern] = {}
self.uniform = default_rng(seed).uniform
self.default_language = default_language

    def _get_badwords(self, lang: str):
        if lang not in self._badwords_regex:
            if lang not in _BADWORDS_LANGS:
                if self.fail_on_missing_language:
                    raise ValueError(
                        f'There is no badwords list available for "{lang}". '
                        f"Set fail_on_missing_language=False to continue anyway."
                    )
                else:
                    return None
            local_path = cached_asset_path_or_download(
                _BADWORDS_URL + lang if lang != "en" else _EN_BADWORDS_URL,
                namespace="filters",
                subfolder="c4_badwords",
            )
            badwords: set[str] = set()
            # load from file
            with open(local_path, "rt") as f:
                badwords.update(line.strip() for line in f)
            for lang, allowlist in _BADWORDS_ALLOWLIST.items():
                badwords -= allowlist

            words = [re.escape(w) for w in badwords]
            self._badwords_regex[lang] = (
                # For Japanese, Thai, and Chinese, do not require word separations.
                re.compile("|".join(words))
                if lang in ("ja", "th", "zh")
                # For other languages, match only when flanked by non-word chars.
                else re.compile(r"(?:\W|^)({})(?:\W|$)".format("|".join(words)))
            )
        return self._badwords_regex[lang]

    def filter(self, doc: Document) -> bool | tuple[bool, str]:
        lang: str = doc.metadata.get("language", self.default_language)
        badwords_regex = self._get_badwords(lang)
        if badwords_regex is None:
            self.stat_update("missing_badwords_lang", f"missing_badwords_lang_{lang}")
            return True
        badwords_found = badwords_regex.search(doc.text.lower())
        if badwords_found is not None:
            self.stat_update("documents_with_badwords", f"documents_with_badwords_{lang}")
            if self.keep_fraction > 0.0 and self.uniform() < self.keep_fraction:
                self.stat_update("document_kept_with_badwords", f"document_kept_with_badwords_{lang}")
                return True
            self.stat_update(f"document_removed_with_badwords_{lang}")
            return False, "document_removed_with_badwords"
        return True
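
For orientation, here is a minimal sketch of how the new filter could be wired into a datatrove pipeline, including the new exclusion_writer hook. The folder paths, executor choice and parameter values are illustrative assumptions, not part of this commit:

from datatrove.executor import LocalPipelineExecutor
from datatrove.pipeline.filters import C4BadWordsFilter
from datatrove.pipeline.readers import JsonlReader
from datatrove.pipeline.writers import JsonlWriter

pipeline = [
    # documents are expected to carry a "language" field in their metadata;
    # otherwise default_language is used
    JsonlReader("data/raw"),
    C4BadWordsFilter(
        keep_fraction=0.01,  # keep ~1% of flagged pages instead of dropping them all
        fail_on_missing_language=False,  # pass documents through instead of raising
        default_language="en",
        exclusion_writer=JsonlWriter("data/removed/c4_badwords"),  # new exclusion_writer hook
    ),
    JsonlWriter("data/filtered"),
]
LocalPipelineExecutor(pipeline=pipeline, tasks=4).run()

With keep_fraction left at its default of 0.0, every page that matches the badwords regex for its language is dropped (and routed to the exclusion writer when one is configured).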
3 changes: 2 additions & 1 deletion src/datatrove/pipeline/readers/base.py
@@ -91,7 +91,8 @@ def get_document_from_dict(self, data: dict, source_file: str, id_in_file: int |
if not self._empty_warning:
self._empty_warning = True
logger.warning(
f"Found document without text, skipping. " f'Is your `text_key` ("{self.text_key}") correct?'
f"Found document without text, skipping. "
f'Is your `text_key` ("{self.text_key}") correct? Available keys: {list(data.keys())}'
)
return None
document = Document(**parsed_data)
8 changes: 7 additions & 1 deletion src/datatrove/pipeline/tokens/context_shuffler.py
@@ -18,6 +18,7 @@ class DocumentTokenizerContextShuffler(PipelineStep):
output_folder: the output folder to write the shuffled documents to
window_size: the size of the window to shuffle (default: 2048 + 1)
seed: the seed for the random number generator (default: None)
token_size (int): size of each token, in bytes
"""

name = "🗃 Context Shuffler"
@@ -29,11 +30,13 @@ def __init__(
output_folder: DataFolderLike,
window_size: int = 2048 + 1,
seed: int = None,
token_size: int = 2,
):
super().__init__()
self.input_folder = get_datafolder(input_folder)
self.output_folder = get_datafolder(output_folder)
self.window_size = window_size
self.token_size = token_size
self.rand = default_rng(seed)

def get_ordering(self, all_doc_ends):
@@ -73,5 +76,8 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1
with mmap.mmap(f.fileno(), 0, prot=mmap.PROT_READ) as unshuf:
with self.track_time():
for windowi in ordering:
start, end = windowi * self.window_size * 2, (windowi + 1) * self.window_size * 2
start, end = (
windowi * self.window_size * self.token_size,
(windowi + 1) * self.window_size * self.token_size,
)
fout.write(unshuf[start:end])
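
The window offsets above are now scaled by the new token_size argument instead of a hard-coded 2, because tokenizers with vocabularies above 65k ids no longer fit in two bytes per token. A small self-contained sketch of the byte arithmetic and the matching numpy dtype (the helper names are illustrative, not datatrove APIs):

import numpy as np

def token_dtype(token_size: int):
    # 2 bytes per token -> uint16 (vocab <= 65535); 4 bytes -> uint32 for larger vocabs
    return np.uint16 if token_size == 2 else np.uint32

def window_bounds(window_index: int, window_size: int, token_size: int) -> tuple[int, int]:
    # each token occupies token_size bytes, so a window of window_size tokens
    # spans window_size * token_size bytes on disk
    start = window_index * window_size * token_size
    return start, start + window_size * token_size

# fake shard: 10 windows of 2049 tokens stored as 2-byte ids
token_size, window_size = 2, 2049
buffer = np.arange(10 * window_size, dtype=token_dtype(token_size)).tobytes()

start, end = window_bounds(3, window_size, token_size)
window = np.frombuffer(buffer[start:end], dtype=token_dtype(token_size))
assert len(window) == window_size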
39 changes: 17 additions & 22 deletions src/datatrove/pipeline/tokens/counter.py
@@ -1,6 +1,6 @@
from datatrove.data import DocumentsPipeline
from datatrove.pipeline.base import PipelineStep
from datatrove.utils.tokenization import PipelineStepWithTokenizer
from datatrove.utils.tokenization import PipelineStepWithTokenizer, batched


class TokensCounter(PipelineStepWithTokenizer):
@@ -10,7 +10,8 @@ class TokensCounter(PipelineStepWithTokenizer):
Args:
tokenizer_name_or_path (str): the name or path of the tokenizer to use, from the HuggingFace tokenizers library or a local file.
count_eos_token (bool): whether to count the EOS token on each document.
count_eos_token (bool): whether to count the EOS token on each document. (basically +1 per document)
batch_size: batch size for tokenization
"""

name = "📊 Counter"
@@ -20,20 +21,12 @@ def __init__(
self,
tokenizer_name_or_path: str = "gpt2", # tokenizer to use, from HF or a local file path
count_eos_token: bool = False, # whether to count the EOS token on each document
overwrite: bool = True, # re-tokenize and recompute nb of tokens even if they are already in metadata["tokens_count"]
batch_size: int = 10000, # batch size for tokenization
):
"""
Initializes the token counting pipeline step.
Args:
tokenizer_name_or_path: Name or path of tokenizer to use (from HF or local).
count_eos_token: Whether to include the EOS token in the token count per document. (basically +1 per document)
overwrite: Whether to re-tokenize and recompute the number of tokens even if they are already stored in metadata["tokens_count"]
"""
super().__init__()
self.tokenizer_name_or_path = tokenizer_name_or_path
self.count_eos_token = count_eos_token
self.overwrite = overwrite
self.batch_size = batch_size

def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> DocumentsPipeline:
"""
@@ -47,17 +40,19 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> Do
DocumentsPipeline: The pipeline with updated documents, each having a new or updated `token_count` in its metadata.
"""
for document in data:
if "token_count" in document.metadata and not self.overwrite:
count = document.metadata["token_count"]
else:
with self.track_time():
count = len(self.tokenizer.encode(document.text).ids)
if self.count_eos_token:
count += 1
from tokenizers import Encoding

# tokenize document's text in batches to go faster
for batch in batched(data, self.batch_size):
with self.track_time(unit="batch"):
encoded_batch: list[Encoding] = self.tokenizer.encode_batch([document.text for document in batch])
for document, encoded in zip(batch, encoded_batch):
count = len(encoded.ids)
if self.count_eos_token:
count += 1
document.metadata["token_count"] = count
self.stat_update("tokens", value=count)
yield document
self.stat_update("tokens", value=count)
yield document


class LengthCounter(PipelineStep):
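
TokensCounter now tokenizes in batches via the tokenizers library's encode_batch, which lets the Rust backend parallelize across documents. The batched helper it imports from datatrove.utils.tokenization is not shown in this diff; a plausible minimal version, in the spirit of the itertools batched recipe (an assumption, not the actual datatrove implementation):

from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def batched(iterable: Iterable[T], n: int) -> Iterator[list[T]]:
    # yield successive lists of up to n items; the last batch may be shorter
    it = iter(iterable)
    while batch := list(islice(it, n)):
        yield batch

# list(batched(range(5), 2)) -> [[0, 1], [2, 3], [4]]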
21 changes: 13 additions & 8 deletions src/datatrove/pipeline/tokens/merger.py
@@ -105,22 +105,25 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1
f"({len(datafiles)} vs {len(datafiles_index)} vs {len(datafiles_loss)})"
)

tokenizer_name_or_path, token_size = None, 2
if self.save_final_metadata:
if self.input_folder.isfile(f"{datafiles[0]}.metadata"):
with self.input_folder.open(f"{datafiles[0]}.metadata", "rt") as f:
tokenizer_name_or_path = f.read().splitlines()[0]
if "|" in tokenizer_name_or_path:
tokenizer_name_or_path, token_size = tokenizer_name_or_path.split("|")
token_size = int(token_size)

doc_ends = [load_doc_ends(self.input_folder.open(file, "rb")) for file in datafiles_index]
token_inputs = list(
map(partial(get_data_reader, nb_bytes=2), self.input_folder.open_files(datafiles), doc_ends)
map(partial(get_data_reader, nb_bytes=token_size), self.input_folder.open_files(datafiles), doc_ends)
)
loss_inputs = (
list(map(partial(get_data_reader, nb_bytes=1), self.input_folder.open_files(datafiles_loss), doc_ends))
if self.save_loss_metadata
else None
)

tokenizer_name_or_path = None
if self.save_final_metadata:
if self.input_folder.isfile(f"{datafiles[0]}.metadata"):
with self.input_folder.open(f"{datafiles[0]}.metadata", "rt") as f:
tokenizer_name_or_path = f.read().splitlines()[0]

ordering = self.get_ordering(doc_ends)

file_ct = 0
@@ -131,6 +134,7 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1
upload_block_size=self.upload_block_size,
tokenizer_name_or_path=tokenizer_name_or_path,
save_final_metadata=self.save_final_metadata,
token_size=token_size,
)
for input_file_id in tqdm(
ordering, desc="Merging documents", unit="documents", total=len(ordering), disable=not self.progress
@@ -147,13 +151,14 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1
upload_block_size=self.upload_block_size,
tokenizer_name_or_path=tokenizer_name_or_path,
save_final_metadata=self.save_final_metadata,
token_size=token_size,
)
# copy tokens and loss
tokens = next(token_inputs[input_file_id])
output_file.write_bytes(tokens)
if loss_inputs:
output_file.write_loss_bytes(next(loss_inputs[input_file_id]))
self.stat_update("tokens", value=len(tokens) // 2)
self.stat_update("tokens", value=len(tokens) // token_size)
# cleanup
output_file.close()
if self.save_final_metadata:
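
The merger now derives the token width from the first line of the shard's .metadata file, which, judging by the parsing here, newer tokenizer runs write as "tokenizer_name|token_size"; older files contain only the tokenizer name, in which case the historical 2-byte default applies. A tiny sketch of that parsing convention (the function name and sample values are illustrative):

def parse_tokenizer_metadata(first_line: str) -> tuple[str, int]:
    # "gpt2|2" -> ("gpt2", 2); legacy "gpt2" -> ("gpt2", 2) via the old 2-byte default
    if "|" in first_line:
        name, token_size = first_line.split("|")
        return name, int(token_size)
    return first_line, 2

assert parse_tokenizer_metadata("gpt2|2") == ("gpt2", 2)
assert parse_tokenizer_metadata("big-vocab-tokenizer|4") == ("big-vocab-tokenizer", 4)
assert parse_tokenizer_metadata("gpt2") == ("gpt2", 2)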