diff --git a/docs/user-guide/api/deduplication.rst b/docs/user-guide/api/deduplication.rst
index 81b17667..1eeb4bd7 100644
--- a/docs/user-guide/api/deduplication.rst
+++ b/docs/user-guide/api/deduplication.rst
@@ -13,12 +13,21 @@ Exact
Fuzzy
------------------------
+.. autoclass:: nemo_curator.BucketsToEdges
+ :members:
+
+.. autoclass:: nemo_curator.ConnectedComponents
+ :members:
+
.. autoclass:: nemo_curator.FuzzyDuplicatesConfig
:members:
.. autoclass:: nemo_curator.FuzzyDuplicates
:members:
+.. autoclass:: nemo_curator.JaccardSimilarity
+ :members:
+
.. autoclass:: nemo_curator.LSH
:members:
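
For readers of these API docs, the following is a minimal sketch (not part of this patch) of how the newly documented GPU-only classes chain together when the false positive check is disabled, mirroring the internal flow of ``FuzzyDuplicates``. The dataset path, cache directory, and LSH parameters are placeholders; a GPU Dask cluster with cuDF/cuGraph is assumed.

```python
# Hypothetical end-to-end sketch: MinHash -> LSH -> BucketsToEdges -> ConnectedComponents.
import os

import dask_cudf

from nemo_curator import LSH, BucketsToEdges, ConnectedComponents, MinHash
from nemo_curator.datasets import DocumentDataset

cache_dir = "/tmp/fuzzy_dedup_cache"  # placeholder location for intermediate outputs
dataset = DocumentDataset(dask_cudf.read_parquet("/path/to/docs.parquet"))

# Compute minhash signatures, then band them into LSH buckets
minhashes = MinHash(cache_dir=cache_dir, id_field="id", text_field="text")(dataset)
buckets = LSH(cache_dir=cache_dir, num_hashes=260, num_buckets=20, id_fields="id")(minhashes)

if buckets is not None:  # LSH returns None when no documents share a bucket
    # Convert buckets to an edgelist and group duplicates via connected components
    BucketsToEdges(cache_dir=cache_dir, id_fields="id")(buckets)
    cc = ConnectedComponents(
        cache_dir=cache_dir,
        jaccard_pairs_path=os.path.join(cache_dir, "_edges.parquet"),
        id_column="id",
    )
    cc.cc_workflow(os.path.join(cache_dir, "connected_components.parquet"))
```
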
diff --git a/nemo_curator/modules/__init__.py b/nemo_curator/modules/__init__.py
index b792d807..897e5402 100644
--- a/nemo_curator/modules/__init__.py
+++ b/nemo_curator/modules/__init__.py
@@ -30,50 +30,62 @@
from .task import TaskDecontamination
# GPU packages
-LSH = gpu_only_import_from("nemo_curator.modules.fuzzy_dedup", "LSH")
-MinHash = gpu_only_import_from("nemo_curator.modules.fuzzy_dedup", "MinHash")
-FuzzyDuplicates = gpu_only_import_from(
- "nemo_curator.modules.fuzzy_dedup", "FuzzyDuplicates"
+MinHash = gpu_only_import_from("nemo_curator.modules.fuzzy_dedup.minhash", "MinHash")
+LSH = gpu_only_import_from("nemo_curator.modules.fuzzy_dedup.lsh", "LSH")
+JaccardSimilarity = gpu_only_import_from(
+ "nemo_curator.modules.fuzzy_dedup.jaccardsimilarity", "JaccardSimilarity"
)
BucketsToEdges = gpu_only_import_from(
- "nemo_curator.modules.fuzzy_dedup", "BucketsToEdges"
+ "nemo_curator.modules.fuzzy_dedup.bucketstoedges", "BucketsToEdges"
+)
+ConnectedComponents = gpu_only_import_from(
+ "nemo_curator.modules.fuzzy_dedup.connectedcomponents", "ConnectedComponents"
+)
+FuzzyDuplicates = gpu_only_import_from(
+ "nemo_curator.modules.fuzzy_dedup.fuzzyduplicates", "FuzzyDuplicates"
)
-SemDedup = gpu_only_import_from("nemo_curator.modules.semantic_dedup", "SemDedup")
EmbeddingCreator = gpu_only_import_from(
- "nemo_curator.modules.semantic_dedup", "EmbeddingCreator"
+ "nemo_curator.modules.semantic_dedup.embeddings", "EmbeddingCreator"
)
ClusteringModel = gpu_only_import_from(
- "nemo_curator.modules.semantic_dedup", "ClusteringModel"
+ "nemo_curator.modules.semantic_dedup.clusteringmodel", "ClusteringModel"
)
SemanticClusterLevelDedup = gpu_only_import_from(
- "nemo_curator.modules.semantic_dedup", "SemanticClusterLevelDedup"
+ "nemo_curator.modules.semantic_dedup.semanticclusterleveldedup",
+ "SemanticClusterLevelDedup",
+)
+SemDedup = gpu_only_import_from(
+ "nemo_curator.modules.semantic_dedup.semdedup", "SemDedup"
)
-# Pytorch related imports must come after all imports that require cugraph,
-# because of context cleanup issues b/w pytorch and cugraph
+
+# PyTorch-related imports must come after all imports that require cuGraph
+# because of context cleanup issues between PyTorch and cuGraph
# See this issue: https://github.com/rapidsai/cugraph/issues/2718
from .filter import Filter, Score, ScoreFilter, ParallelScoreFilter
__all__ = [
+ "AddId",
+ "FuzzyDuplicatesConfig",
+ "SemDedupConfig",
+ "blend_datasets",
+ "Shuffle",
"ExactDuplicates",
"Filter",
- "FuzzyDuplicatesConfig",
- "FuzzyDuplicates",
- "BucketsToEdges",
- "LSH",
- "MinHash",
- "Modify",
"Score",
"ScoreFilter",
"ParallelScoreFilter",
"Sequential",
+ "Modify",
"TaskDecontamination",
- "AddId",
- "blend_datasets",
- "Shuffle",
- "SemDedup",
- "SemDedupConfig",
+ "MinHash",
+ "LSH",
+ "JaccardSimilarity",
+ "BucketsToEdges",
+ "ConnectedComponents",
+ "FuzzyDuplicates",
"EmbeddingCreator",
"ClusteringModel",
"SemanticClusterLevelDedup",
+ "SemDedup",
]
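
Because every class is still re-exported at the package root, user code that imports from ``nemo_curator`` is unaffected by the module split. Below is a hedged usage sketch under that assumption; paths are placeholders, only the config fields that ``FuzzyDuplicates`` reads in this patch are shown, and the remaining ``FuzzyDuplicatesConfig`` fields are assumed to keep their defaults.

```python
# Assumed usage via the stable package-level API.
import dask_cudf

from nemo_curator import FuzzyDuplicates, FuzzyDuplicatesConfig
from nemo_curator.datasets import DocumentDataset

config = FuzzyDuplicatesConfig(
    cache_dir="/tmp/fuzzy_dedup_cache",  # placeholder path for intermediate outputs
    id_field="id",
    text_field="text",
    false_positive_check=False,  # skip the anchor/Jaccard verification stages
)
dataset = DocumentDataset(dask_cudf.read_parquet("/path/to/docs.parquet"))

# Returns a DocumentDataset of document ids with a "group" label per duplicate
# cluster, or None when LSH finds no candidate duplicates.
duplicates = FuzzyDuplicates(config=config)(dataset)
```
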
diff --git a/nemo_curator/modules/fuzzy_dedup.py b/nemo_curator/modules/fuzzy_dedup.py
deleted file mode 100644
index 3afc4123..00000000
--- a/nemo_curator/modules/fuzzy_dedup.py
+++ /dev/null
@@ -1,1798 +0,0 @@
-# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import annotations
-
-import logging
-import math
-import os
-import time
-import warnings
-from itertools import pairwise
-from typing import List, Optional, Tuple, Union
-
-import cudf
-import cugraph.dask as dcg
-import cugraph.dask.comms.comms as Comms
-import cupy as cp
-import dask_cudf
-import numpy as np
-import pandas as pd
-import pyarrow as pa
-from cugraph import MultiGraph
-from dask import dataframe as dd
-from dask.utils import M
-from tqdm import tqdm
-
-from nemo_curator._compat import MINHASH_DEPRECATED_API, MINHASH_PERMUTED_AVAILABLE
-from nemo_curator.datasets import DocumentDataset
-from nemo_curator.log import create_logger
-from nemo_curator.modules.config import FuzzyDuplicatesConfig
-from nemo_curator.modules.meta import Sequential
-from nemo_curator.utils.distributed_utils import (
- get_current_client,
- get_num_workers,
- performance_report_if_with_ts_suffix,
-)
-from nemo_curator.utils.fuzzy_dedup_utils.id_mapping import int_ids_to_str
-from nemo_curator.utils.fuzzy_dedup_utils.io_utils import (
- aggregated_anchor_docs_with_bk_read,
- check_empty_buckets,
- get_restart_offsets,
- update_restart_offsets,
-)
-from nemo_curator.utils.fuzzy_dedup_utils.merge_utils import (
- extract_partitioning_index,
- filter_text_rows_by_bucket_batch,
- merge_left_to_shuffled_right,
-)
-from nemo_curator.utils.fuzzy_dedup_utils.output_map_utils import (
- build_partition,
- get_agg_text_bytes_df,
-)
-from nemo_curator.utils.fuzzy_dedup_utils.shuffle_utils import write_partitioned_file
-
-
-class MinHash:
- """
- Computes minhash signatures of a document corpus
- """
-
- def __init__(
- self,
- seed: int = 42,
- num_hashes: int = 260,
- char_ngrams: int = 5,
- use_64bit_hash: bool = False,
- logger: Union[logging.LoggerAdapter, str] = "./",
- id_field: str = "id",
- text_field: str = "text",
- profile_dir: str = None,
- cache_dir: str = None,
- ):
- """
- Parameters
- ----------
- seed: Seed for minhash permutations
- num_hashes: Length of minhash signature (No. of minhash permutations)
- char_ngrams: Width of text window (in characters) while computing minhashes.
- use_64bit_hash: Whether to use a 64 bit hash function.
- logger: Existing logger to log to, or a path to a log directory.
- id_field: Column in the Dataset denoting document ID.
- text_field: Column in the Dataset denoting document content.
- profile_dir: str, Default None
-            If specified, directory to write the dask profile to
-        cache_dir: str, Default None
-            If specified, the (id, minhash signature) pairs are computed and written to this directory
- """
- self.num_hashes = num_hashes
- self.char_ngram = char_ngrams
- if MINHASH_DEPRECATED_API:
- self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed)
- else:
- self.seeds = self.generate_hash_permutation_seeds(
- bit_width=64 if use_64bit_hash else 32,
- n_permutations=self.num_hashes,
- seed=seed,
- )
-
- self.minhash_method = self.minhash64 if use_64bit_hash else self.minhash32
-
- self.id_field = id_field
- self.text_field = text_field
-
- if cache_dir is None and profile_dir is not None:
- warnings.warn(
- "cache_dir for intermediate outputs is required to generate profiles"
- )
- self.cache_dir = cache_dir
- self.profile_dir = profile_dir
-
- if isinstance(logger, str):
- self._logger = create_logger(
- rank=0,
- log_file=os.path.join(logger, "Minhash.log"),
- name="Minhash",
- )
- else:
- self._logger = logger
-
- def generate_seeds(self, n_seeds: int = 260, seed: int = 0) -> np.ndarray:
- """
- Generate seeds for all minhash permutations based on the given seed.
- """
- gen = np.random.RandomState(seed)
- return gen.randint(0, 1e6, size=n_seeds)
-
- def generate_hash_permutation_seeds(
- self, bit_width: int, n_permutations: int = 260, seed: int = 0
- ) -> np.ndarray:
- """
- Generate seeds for all minhash permutations based on the given seed.
- """
- gen = np.random.RandomState(seed)
-
- if bit_width == 32:
- MERSENNE_PRIME = np.uint32((1 << 31) - 1)
- dtype = np.uint32
- elif bit_width == 64:
- # For 64-bit, use a larger prime number suitable for 64-bit operations
- MERSENNE_PRIME = np.uint64((1 << 61) - 1)
- dtype = np.uint64
- else:
- raise ValueError("Unsupported bit width. Use either 32 or 64.")
-
- return np.array(
- [
- (
- gen.randint(1, MERSENNE_PRIME, dtype=dtype),
- gen.randint(0, MERSENNE_PRIME, dtype=dtype),
- )
- for _ in range(n_permutations)
- ],
- dtype=dtype,
- )
-
- def minhash32(
- self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int
- ) -> cudf.Series:
- """
- Compute 32bit minhashes based on the MurmurHash3 algorithm
- """
- if not isinstance(ser, cudf.Series):
- raise TypeError("Expected data of type cudf.Series")
-
- if MINHASH_DEPRECATED_API:
- warnings.warn(
- "Using an outdated minhash implementation, please update to cuDF version 24.12 "
- "or later for improved performance. "
- "Install the latest version of cuDF using `pip install curator[cuda12x_nightly]`",
- category=FutureWarning,
- )
- seeds = cudf.Series(seeds, dtype="uint32")
- return ser.str.minhash(seeds=seeds, width=char_ngram)
- else:
- seeds_a = cudf.Series(seeds[:, 0], dtype="uint32")
- seeds_b = cudf.Series(seeds[:, 1], dtype="uint32")
-
- if MINHASH_PERMUTED_AVAILABLE:
- return ser.str.minhash_permuted(
- a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
- )
- else:
- return ser.str.minhash(
- a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
- )
-
- def minhash64(
- self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int
- ) -> cudf.Series:
- """
- Compute 64bit minhashes based on the MurmurHash3 algorithm
- """
- if not isinstance(ser, cudf.Series):
- raise TypeError("Expected data of type cudf.Series")
- if MINHASH_DEPRECATED_API:
- warnings.warn(
- "Using an outdated minhash implementation, please update to cuDF version 24.12 "
- "or later for improved performance. "
- "Install the latest version of cuDF using `pip install curator[cuda12x_nightly]`",
- category=FutureWarning,
- )
- seeds = cudf.Series(seeds, dtype="uint64")
- return ser.str.minhash64(seeds=seeds, width=char_ngram)
- else:
- seeds_a = cudf.Series(seeds[:, 0], dtype="uint64")
- seeds_b = cudf.Series(seeds[:, 1], dtype="uint64")
-
- if MINHASH_PERMUTED_AVAILABLE:
- return ser.str.minhash64_permuted(
- a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
- )
- else:
- return ser.str.minhash64(
- a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
- )
-
- def __call__(self, dataset: DocumentDataset) -> Union[str, DocumentDataset]:
- """
- Computes the MinHash Signatures for a given dataset.
- Parameters
- ----------
- dataset: DocumentDataset
-            The input dataset on which to compute MinHashes.
- Returns
- -------
- DocumentDataset containing IDs of all documents and the corresponding MinHash Signature
- """
- result = dataset.df[[self.id_field]]
- result["_minhash_signature"] = dataset.df[self.text_field].map_partitions(
- self.minhash_method,
- seeds=self.seeds,
- char_ngram=self.char_ngram,
- )
-
- if self.cache_dir is None:
- return DocumentDataset(result)
-
- t0 = time.time()
- self._logger.info("Starting execution for Minhashes")
- write_path = os.path.join(self.cache_dir, "_minhashes.parquet")
- if os.path.exists(write_path):
- warnings.warn(
- f"Output path {write_path} already exists and will be overwritten"
- )
- with performance_report_if_with_ts_suffix(self.profile_dir, "minhash-profile"):
- result.to_parquet(write_path, write_index=False, overwrite=True)
- self._logger.info(
- f"Time taken for Minhash signature computation = {time.time() - t0}s and output written at {write_path}"
- )
- return DocumentDataset(
- dask_cudf.read_parquet(write_path, blocksize="2GB", aggregate_files=True)
- )
-
-
-class LSH:
- """
-    Performs LSH on MinHash signatures
- """
-
- def __init__(
- self,
- cache_dir: str,
- num_hashes: int,
- num_buckets: int,
- buckets_per_shuffle: int = 1,
- false_positive_check: bool = False,
- logger: Union[logging.LoggerAdapter, str] = "./",
- id_fields: Union[str, list] = "id",
- minhash_field: str = "_minhash_signature",
- profile_dir: Optional[str] = None,
- ):
- """
- Parameters
- ----------
- cache_dir: str
-            Must be specified; the (document id, bucket id) pairs are written to this cache directory.
- num_hashes: Length of minhash signature
- num_buckets: Number of bands/buckets to create from the minhash signature.
- Hashes_per_signature = num_hashes / num_buckets
-        buckets_per_shuffle: Number of bands/buckets to shuffle concurrently.
-            Larger values can speed up processing but might lead to memory pressure and related errors.
-        false_positive_check: bool
-            If True, writes out buckets in a format compatible with downstream false positive check.
-        logger: Existing logger to log to, or a path to a log directory.
-        id_fields: Column(s) in the Dataset denoting document ID.
-        minhash_field: Column in the Dataset denoting minhash signature.
-        profile_dir: str, Default None
-            If specified, directory to write the dask profile to
- """
- self.num_hashes = num_hashes
- self.num_buckets = num_buckets
- self.id_fields = [id_fields] if isinstance(id_fields, str) else id_fields
- self.minhash_field = minhash_field
- self.buckets_per_shuffle = buckets_per_shuffle
- self.bucket_ranges = self._generate_bucket_ranges(
- self.num_buckets, self.num_hashes
- )
- self.buckets_as_int = false_positive_check
-
- if cache_dir is None:
- raise ValueError(
- "cache_dir for intermediate outputs is required for this stage"
- )
- self.cache_dir = cache_dir
- self.profile_dir = profile_dir
-
- if isinstance(logger, str):
- self._logger = create_logger(
- rank=0,
- log_file=os.path.join(logger, "LSH.log"),
- name="LSH",
- )
- else:
- self._logger = logger
-
- def _generate_bucket_ranges(
- self, num_buckets: int, num_hashes: int
- ) -> List[List[int]]:
- """
-        Generates a list of indices for the minhash ranges given num_buckets &
-        num_hashes.
-        e.g.: num_buckets=3, num_hashes=6
- [[0, 1], [2, 3], [4, 5]]
- """
- minhashes_per_bucket = num_hashes // num_buckets
-
- bucket_ranges = [
- list(
- range(
- bucket * minhashes_per_bucket, (bucket + 1) * minhashes_per_bucket
- )
- )
- for bucket in range(num_buckets)
- ]
- return bucket_ranges
-
- def minhash_to_buckets(
- self,
- df: cudf.DataFrame,
- bucket_ranges: List[List[int]],
- ) -> cudf.DataFrame:
- df2 = df[self.id_fields]
- for i, h in enumerate(bucket_ranges):
- indices = cudf.Series([h]).repeat(len(df2))
- df2[f"_bucket_{i}"] = f"b{i}_" + df[self.minhash_field].list.take(
- indices
- ).hash_values(method="md5")
- return df2
-
- def bucket_id_to_int(
- self,
- bucket_ddf: dask_cudf.DataFrame,
- bucket_col_name: str = "bucket_id",
- start_id: int = 0,
- ) -> Tuple[dask_cudf.DataFrame, int]:
- """
-        Maps bucket ids to a contiguous integer range starting from start_id.
- """
- unique_bucket_df = (
- bucket_ddf[[bucket_col_name]]
- .map_partitions(lambda x: x.drop_duplicates(ignore_index=True))
- .persist()
- )
- end_bucket_id = len(unique_bucket_df) - 1 + start_id
- unique_bucket_df["bucket_int_id"] = np.uint64(1)
- unique_bucket_df["bucket_int_id"] = unique_bucket_df["bucket_int_id"].cumsum()
- unique_bucket_df["bucket_int_id"] = (
- unique_bucket_df["bucket_int_id"] - 1 + start_id
- )
- bucket_ddf = bucket_ddf.merge(unique_bucket_df, on=[bucket_col_name])
- bucket_ddf = bucket_ddf.drop(columns=[bucket_col_name])
- bucket_ddf = bucket_ddf.rename(columns={"bucket_int_id": "_bucket_id"})
- bucket_ddf["_bucket_id"] = bucket_ddf["_bucket_id"].astype(np.uint64)
- return (bucket_ddf, end_bucket_id)
-
- def _minhash_to_bucket_meta(
- self, df: dask_cudf.DataFrame
- ) -> Tuple[cudf.DataFrame, int]:
- meta = df._meta_nonempty[self.id_fields]
- meta[self.minhash_field] = [np.ones(self.num_hashes)] * len(meta)
- return self.minhash_to_buckets(meta, self.bucket_ranges)
-
- def lsh(
- self,
- write_path: str,
- df: dask_cudf.DataFrame,
- ) -> bool:
- """
- Computes hash buckets for the DataFrame and writes them as parquet files to the specified path.
-
- Parameters:
- - write_path (str): The directory path to write parquet files.
- - df (dask_cudf.DataFrame): The input DataFrame with minhashes to be bucketed.
- Returns:
- are_buckets_empty: True if buckets were empty (no duplicates found), False otherwise.
- """
- wrote_buckets = False
- are_buckets_empty = True
-
- meta = self._minhash_to_bucket_meta(df)
- df = df.map_partitions(
- self.minhash_to_buckets,
- bucket_ranges=self.bucket_ranges,
- meta=meta,
- )
- bucket_start_id = 0
- for i in range(0, self.num_buckets, self.buckets_per_shuffle):
- bucket_columns = [
- f"_bucket_{i}"
- for i in range(i, min(self.num_buckets, i + self.buckets_per_shuffle))
- ]
- df2 = df.melt(
- id_vars=self.id_fields,
- value_name="_bucket_id",
- value_vars=bucket_columns,
- )[self.id_fields + ["_bucket_id"]]
-
- df2 = df2.shuffle(
- on=["_bucket_id"],
- ignore_index=True,
- npartitions=max(1, 2 ** math.floor(math.log2(df2.npartitions))),
- ).map_partitions(lambda x: x[x["_bucket_id"].duplicated(keep=False)])
-
- df2 = df2.reset_index(drop=True)
- # Buckets to Int
- if self.buckets_as_int:
- df2, end_id = self.bucket_id_to_int(
- df2, bucket_col_name="_bucket_id", start_id=bucket_start_id
- )
-            # If bucket_id_to_int returned an empty dataframe (no duplicates in this batch)
- if end_id < bucket_start_id:
- self._logger.info(
- f"No duplicate documents found for buckets: {bucket_columns}"
- )
- continue
- bucket_start_id = end_id + 1
- are_buckets_empty = False
-
- wrote_buckets, are_buckets_empty = self._write_bucket_parquet(
- df2,
- write_path,
- wrote_buckets,
- are_buckets_empty,
- bucket_columns,
- )
-
- if are_buckets_empty:
- self._logger.info("No duplicate documents found during LSH")
- if os.path.exists(write_path):
- import shutil
-
- shutil.rmtree(write_path)
-
- return are_buckets_empty
-
- def _write_bucket_parquet(
- self,
- df: dask_cudf.DataFrame,
- write_path: str,
- wrote_buckets: bool,
- are_buckets_empty: bool,
- buckets_to_write: List[str],
- ) -> tuple[bool, bool]:
- """
- Utility function to write the bucketed data to parquet
- handling cases of overwriting and appending as needed.
- """
- if not wrote_buckets:
- if os.path.exists(write_path):
- warnings.warn(
- f"Output path {write_path} already exists and will be overwritten"
- )
- df.to_parquet(write_path, write_index=False, overwrite=True)
- else:
- df.to_parquet(
- write_path,
- write_index=False,
- overwrite=are_buckets_empty,
- append=not are_buckets_empty,
- ignore_divisions=True,
- )
- # Only check if buckets written so far are empty
- if are_buckets_empty:
- are_buckets_empty = check_empty_buckets(write_path)
- wrote_buckets = True
-
- if are_buckets_empty:
- self._logger.info(
- f"No duplicate documents found for buckets: {buckets_to_write}"
- )
- else:
- self._logger.info(f"Wrote data for buckets: {buckets_to_write}")
- return wrote_buckets, are_buckets_empty
-
- def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
- df = dataset.df
-
- write_path = os.path.join(self.cache_dir, "_buckets.parquet")
- t0 = time.time()
- with performance_report_if_with_ts_suffix(self.profile_dir, "lsh-profile"):
- empty_result = self.lsh(write_path=write_path, df=df)
- self._logger.info(
- f"Time taken for LSH = {time.time() - t0}s and output written at {write_path}"
- )
- if empty_result:
- return None
- buckets_df = dask_cudf.read_parquet(write_path, split_row_groups=False)
- return DocumentDataset(buckets_df)
-
-
-class FuzzyDuplicates:
- def __init__(
- self,
- config: FuzzyDuplicatesConfig,
- logger: Union[logging.LoggerAdapter, str] = "./",
- ):
- """
- Parameters
- ----------
- config: FuzzyDuplicatesConfig,
- Config options for finding FuzzyDuplicates
- logger: Existing logger to log to, or a path to a log directory.
-
- Returns
- -------
- DocumentDataset containing IDs of all documents and the corresponding duplicate group
- they belong to. Documents in the same group are near duplicates.
- """
- if isinstance(logger, str):
- self._logger = create_logger(
- rank=0,
- log_file=os.path.join(logger, "FuzzyDuplicates.log"),
- name="FuzzyDuplicates",
- )
- else:
- self._logger = logger
-
- self.config = config
- self.minhash = MinHash(
- seed=self.config.seed,
- num_hashes=self.config.num_hashes,
- char_ngrams=self.config.char_ngrams,
- use_64bit_hash=self.config.use_64_bit_hash,
- logger=self._logger,
- id_field=self.config.id_field,
- text_field=self.config.text_field,
- profile_dir=self.config.profile_dir,
- cache_dir=self.config.cache_dir,
- )
- self.lsh = LSH(
- cache_dir=self.config.cache_dir,
- num_hashes=self.config.num_hashes,
- num_buckets=self.config.num_buckets,
- buckets_per_shuffle=self.config.buckets_per_shuffle,
- false_positive_check=self.config.false_positive_check,
- logger=self._logger,
- id_fields=[self.config.id_field],
- profile_dir=self.config.profile_dir,
- )
-
- if self.config.false_positive_check:
- self.map_buckets = _MapBuckets(
- id_fields=[self.config.id_field],
- text_field=self.config.text_field,
- logger=self._logger,
- num_anchors=self.config.num_anchors,
- )
- self.jaccard_shuffle = _Shuffle(
- id_fields=[self.config.id_field],
- text_field=self.config.text_field,
- logger=self._logger,
- profile_dir=self.config.profile_dir,
- )
- self.jaccard_compute = JaccardSimilarity(
- id_field=self.config.id_field,
- text_field=self.config.text_field,
- ngram_width=self.config.char_ngrams,
- anchor_id_fields=[
- f"anchor_{i}_{self.config.id_field}"
- for i in range(self.config.num_anchors)
- ],
- )
- else:
- self.buckets_to_edges = BucketsToEdges(
- cache_dir=self.config.cache_dir,
- id_fields=self.config.id_field,
- logger=self._logger,
- profile_dir=self.config.profile_dir,
- )
-
- jaccard_pairs_fname = (
- "jaccard_similarity_results.parquet"
- if self.config.false_positive_check
- else "_edges.parquet"
- )
- self.connected_components = ConnectedComponents(
- cache_dir=self.config.cache_dir,
- jaccard_pairs_path=os.path.join(self.config.cache_dir, jaccard_pairs_fname),
- id_column=self.config.id_field,
- jaccard_threshold=self.config.jaccard_threshold,
- logger=self._logger,
- profile_dir=self.config.profile_dir,
- )
-
- def __call__(self, dataset: DocumentDataset):
- """
- Parameters
- ----------
- dataset: DocumentDataset
-            The input dataset on which to compute FuzzyDuplicates. Must contain a text and unique id field.
-
- Returns
- -------
- DocumentDataset containing IDs of all documents and the corresponding duplicate group
- they belong to. Documents in the same group are near duplicates.
- """
-
- # Minhash + LSH
- stage_num = 1
- print(f"Stage{stage_num}: Starting Minhash + LSH computation")
- minhashLSH = Sequential([self.minhash, self.lsh])
- buckets_df = minhashLSH(dataset)
- print(f"Stage{stage_num}: Minhash + LSH complete!")
- if buckets_df is None:
- print(
- f"Stage{stage_num}: No potential duplicate documents found during LSH"
- )
- return None
- stage_num += 1
-
- if self.config.false_positive_check:
- # Map buckets to lower cardinality distribution
- print(f"Stage{stage_num} (False Positive Check): Starting Map_Buckets")
- t0 = time.time()
- mapped_buckets_w_anchors_path = os.path.join(
- self.config.cache_dir, "anchor_docs_with_bk.parquet"
- )
- with performance_report_if_with_ts_suffix(
- self.config.profile_dir,
- "map_buckets",
- ):
- ddf_mapped_buckets_w_anchors = (
- self.map_buckets.map_buckets_with_anchors(
- documents_df=dataset.df, buckets_df=buckets_df.df
- )
- )
- ddf_mapped_buckets_w_anchors.to_parquet(
- mapped_buckets_w_anchors_path, write_index=False, overwrite=True
- )
- self._logger.info(
- f"Time taken for Map_buckets : {time.time() - t0}s and output written at {mapped_buckets_w_anchors_path}"
- )
-
- print(f"Stage{stage_num} (False Postive Check): Map_Buckets Complete!")
- stage_num += 1
-
- # Shuffle documents based on mapped buckets
- print(f"Stage{stage_num} (False Postive Check): Shuffle docs")
- shuffled_docs_path = os.path.join(
- self.config.cache_dir, "shuffled_docs.parquet"
- )
- self.jaccard_shuffle.shuffle_docs_on_buckets(
- documents_df=dataset.df,
- bucket_w_anchors_path=mapped_buckets_w_anchors_path,
- output_shuffled_docs_path=shuffled_docs_path,
- bucket_mapping_df_blocksize=self.config.bucket_mapping_blocksize,
- parts_per_worker=self.config.parts_per_worker,
- bucket_parts_per_worker=self.config.bucket_parts_per_worker,
- )
- print(f"Stage{stage_num} (False Postive Check): Shuffle docs complete!")
- stage_num += 1
-
-            # jaccard comparison within buckets
-            print(
-                f"Stage{stage_num} (False Positive Check): Jaccard Similarity in Buckets"
- )
- jaccard_pairs_path = os.path.join(
- self.config.cache_dir, "jaccard_similarity_results.parquet"
- )
- t0 = time.time()
- with performance_report_if_with_ts_suffix(
- self.config.profile_dir,
- "jaccard-similarity",
- ):
- jaccard_pairs_df = self.jaccard_compute.jaccard_compute(
- shuffled_docs_path=shuffled_docs_path
- )
- jaccard_pairs_df.to_parquet(
- jaccard_pairs_path,
- write_index=False,
- write_metadata_file=False,
- overwrite=True,
- )
- self._logger.info(
- f"Time taken for Jaccard Similarity = {time.time()-t0}s and output written at {jaccard_pairs_path}"
- )
-
- print(
- f"Stage{stage_num} (False Postive Check): Jaccard Similarity in Buckets Complete!"
- )
- stage_num += 1
-
- else:
-            # Convert LSH buckets directly into a graph edgelist (no false positive check)
- print(f"Stage{stage_num}: Starting LSH Buckets to Graph edgelist")
- self.buckets_to_edges(buckets_df)
- print(f"Stage{stage_num}: Starting LSH Buckets to Graph edgelist Complete!")
- stage_num += 1
-
- # Connected components across buckets
- print(f"Stage{stage_num}: Connected Components across buckets")
- cc_path = os.path.join(self.config.cache_dir, "connected_components.parquet")
- self.connected_components.cc_workflow(cc_path)
- print(f"Stage{stage_num}: Connected Components across buckets complete!")
- stage_num += 1
-
- return DocumentDataset(dask_cudf.read_parquet(cc_path, split_row_groups=False))
-
-
-class BucketsToEdges:
- """
- Maps buckets generated from LSH into an edgelist that
- can be processed further by Connected Components to find duplicate
- documents
- """
-
- def __init__(
- self,
- cache_dir: str = None,
- id_fields: Union[list, str] = "id",
- str_id_name: str = "id",
- bucket_field: str = "_bucket_id",
- logger: Union[logging.LoggerAdapter, str] = "./",
- profile_dir: Optional[str] = None,
- ):
- """
- Parameters
- ----------
- cache_dir: str or None
- If specified, will compute & write the edgelist to a file
- id_fields: list or str
- id fields of documents in buckets_df
- str_id_name: str
- Ignored if there is a single id field. Multiple id fields
- will be combined into a single id field with the given name.
- bucket_field: str
- Column denoting bucket ID
-        logger: Existing logger to log to, or a path to a log directory.
-        profile_dir: str, Default None
-            If specified, directory to write the dask profile to
- """
- self.cache_dir = cache_dir
- self.id_fields = [id_fields] if isinstance(id_fields, str) else id_fields
- self.str_id_name = str_id_name if len(self.id_fields) > 1 else self.id_fields[0]
- self.output_ids = [f"{self.str_id_name}_x", f"{self.str_id_name}_y"]
- self.bucket_field = bucket_field
- self.profile_dir = profile_dir
- if isinstance(logger, str):
- self._logger = create_logger(
- rank=0,
- log_file=os.path.join(logger, "Buckets_to_Edges.log"),
- name="Buckets_to_Edges",
- )
- else:
- self._logger = logger
-
- @staticmethod
- def _combine_multiple_ids(
- input_df: cudf.DataFrame, input_id_fields: list, output_id_field: str
- ) -> cudf.DataFrame:
- if output_id_field in input_df.columns:
- raise ValueError(
- f"Input df already contains column named: {output_id_field}"
- )
-
- output_df = input_df.copy()[input_df.columns.difference(input_id_fields)]
-
- output_df[output_id_field] = input_df[input_id_fields[0]].astype(str)
- for input_field in input_id_fields[1:]:
-            output_df[output_id_field] = (
-                output_df[output_id_field]
- + "-"
- + input_df[input_field].astype(str)
- )
-
- return output_df
-
- def buckets_to_edges(
- self,
- buckets_df: cudf.DataFrame,
- ) -> cudf.DataFrame:
-
- grouped_buckets = (
- buckets_df.groupby(self.bucket_field)[self.str_id_name]
- .agg(list)
- .list.sort_values()
- )
- bucket_docs = grouped_buckets.to_arrow().to_pylist()
- edges = []
-        # Create pairs of consecutive documents within a bucket since they are near duplicates
-        # Effectively create an edge list connecting all near duplicate documents
- for bucket_doc in bucket_docs:
- edges.extend(pairwise(bucket_doc))
- edges = pd.DataFrame(edges, columns=self.output_ids)
- edges = pa.Table.from_pandas(edges)
- result_df = cudf.DataFrame.from_arrow(edges)
- del edges
- result_df = result_df.drop_duplicates(self.output_ids).reset_index(drop=True)
- result_df["jaccard"] = np.float32(1.0)
- return result_df
-
- def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
- buckets_df = dataset.df
- self._logger.info(f"Starting conversion of LSH Buckets to Graph Edgelist")
- if len(self.id_fields) > 1:
- buckets_df = buckets_df.map_partitions(
- BucketsToEdges._combine_multiple_ids,
- input_id_fields=self.id_fields,
- output_id_field=self.str_id_name,
- )
-
- meta = [(output_id, str) for output_id in self.output_ids]
- meta.append(("jaccard", np.float32))
- edges_df = buckets_df.map_partitions(self.buckets_to_edges, meta=meta)
-
- if self.cache_dir is None:
- return DocumentDataset(edges_df)
-
- write_path = os.path.join(self.cache_dir, "_edges.parquet")
- if os.path.exists(write_path):
- warnings.warn(
- f"Output path {write_path} already exists and will be overwritten"
- )
- t0 = time.time()
- with performance_report_if_with_ts_suffix(
- self.profile_dir,
- "bucket-to-edges",
- ):
- edges_df.to_parquet(write_path, write_index=False, overwrite=True)
- self._logger.info(
- f"Time taken for Converted Buckets To Edgelist = {time.time() - t0}s and output written at {write_path}"
- )
-
- return DocumentDataset(
- dask_cudf.read_parquet(write_path, split_row_groups=False)
- )
-
-
-class _MapBuckets:
- """
-    Combines buckets generated from LSH (typically high cardinality)
-    into coarser, lower-cardinality bucket groups by mapping multiple buckets
-    to a logical partition using document length information and a modified
-    bin packing algorithm.
-    Only needed if running the false positive check to remove false positives.
- """
-
- def __init__(
- self,
- id_fields: Union[list, str] = "id",
- text_field: str = "text",
- bucket_field: str = "_bucket_id",
- num_anchors: int = 2,
- logger: Union[logging.LoggerAdapter, str] = "./",
- ):
- """
-        id_fields: list or str
-            id fields of the documents DataFrame
-        text_field: Column in the Dataset denoting document content.
-        bucket_field: Column denoting the bucket ID.
-        num_anchors: Number of anchor documents to select per bucket.
-        logger: Existing logger to log to, or a path to a log directory.
- """
- self.id_fields = [id_fields] if isinstance(id_fields, str) else id_fields
- self.text_field = text_field
- self.num_anchors = num_anchors
- self.bucket_field = bucket_field
- if isinstance(logger, str):
- self._logger = create_logger(
- rank=0,
- log_file=os.path.join(logger, "Map_Buckets.log"),
- name="Map_Buckets",
- )
- else:
- self._logger = logger
-
- @staticmethod
- def _get_output_part_ids_with_approx_equal_sum(
- bucket_text_bytes_df: cudf.DataFrame,
- max_text_bytes_per_part: int,
- buckets_column: str,
- bytes_column: str,
- output_partition_column: str,
- ) -> cudf.DataFrame:
- """
-        Create an output series that maps ser.index into `nparts`
-        so that the total sum of bucket_text_bytes_df
-        for each output id is almost equal and
-        less than max_text_bytes_per_part.
-        This is used downstream for creating balanced output partitions.
- """
- sizes = bucket_text_bytes_df[bytes_column].values
- bucket_output_ar = build_partition(
- sizes=sizes.get(), max_size=max_text_bytes_per_part
- )
- df = cudf.DataFrame()
- df[buckets_column] = bucket_text_bytes_df[buckets_column]
- df[output_partition_column] = bucket_output_ar
- return df
-
- def _get_output_map_from_text_bytes_per_bucket(
- self,
- ddf_bk_text_bytes,
- bytes_column,
- output_partition_column="_output_partition_id",
- ):
- # String bytes limit for cuDF
- # https://github.com/rapidsai/cudf/issues/13733
- max_text_bytes_per_part = int(np.iinfo(np.int32).max * 3)
-
- self._logger.info(f"max_text_bytes_per_part = {max_text_bytes_per_part}")
- # Increasing in an attempt to prevent hitting
- # ulimits
- output_map_df_meta = cudf.DataFrame(
- {self.bucket_field: [0], output_partition_column: [1]}
- )
- output_map_df_meta = output_map_df_meta.astype(
- {self.bucket_field: np.uint64, output_partition_column: np.int32}
- )
-
- output_map_df = ddf_bk_text_bytes.map_partitions(
- _MapBuckets._get_output_part_ids_with_approx_equal_sum,
- max_text_bytes_per_part=max_text_bytes_per_part,
- buckets_column=self.bucket_field,
- bytes_column=bytes_column,
- output_partition_column=output_partition_column,
- meta=output_map_df_meta,
- )
- output_map_df = output_map_df.persist()
- self._logger.info(
- f"Step 1 of output_map_df of len: {len(output_map_df)} computed"
- )
- lower_bounds = (
- output_map_df[output_partition_column]
- .map_partitions(lambda s: (s.max() + 1))
- .compute()
- )
- lower_bounds = np.cumsum(lower_bounds)
-
- def update_id(df, lower_bound):
- df[output_partition_column] += lower_bound
- return df
-
- updated_parts = [
- output_map_df.get_partition(i).map_partitions(
- update_id, lower_bounds[i - 1]
- )
- for i in range(1, len(lower_bounds))
- ]
- updated_parts.append(output_map_df.get_partition(0))
- output_map_df = dask_cudf.concat(updated_parts)
- output_map_df = output_map_df.persist()
- self._logger.info(
- f"All steps of output_map_df of len: {len(output_map_df)} computed"
- )
- return output_map_df
-
- def _get_output_map_based_on_str_bytes(
- self, buckets_df, documents_df, bytes_column="_text_bytes"
- ):
- """
- Add output_partition_id to buckets_ddf
- """
- documents_df = documents_df.copy()
- documents_df[bytes_column] = documents_df[self.text_field].map_partitions(
- lambda s: s.str.byte_count()
- )
- n_partitions = buckets_df.npartitions
- documents_df = documents_df.drop(columns=[self.text_field]).repartition(
- npartitions=n_partitions
- )
- buckets_df = buckets_df.merge(documents_df).repartition(
- npartitions=n_partitions
- )
- del documents_df
- ddf_bk_text_bytes, agg_df_len = get_agg_text_bytes_df(
- df=buckets_df,
- agg_column=self.bucket_field,
- bytes_column=bytes_column,
- n_partitions=n_partitions,
- shuffle=True,
- )
- self._logger.info(f"Agg_df computed of length = {agg_df_len}")
- del buckets_df
- output_map_df = self._get_output_map_from_text_bytes_per_bucket(
- ddf_bk_text_bytes=ddf_bk_text_bytes,
- bytes_column=bytes_column,
- )
- return output_map_df
-
- def _random_select_anchor(self, buckets_df, n=2):
- """
- Randomly select `n` anchors from each bucket.
- """
- buckets_df = buckets_df.copy()
- buckets_df["_id_hash"] = buckets_df[self.id_fields].hash_values()
- buckets_df = buckets_df.sort_values([self.bucket_field, "_id_hash"])
- buckets_df["_order_in_bucket"] = buckets_df.groupby(
- self.bucket_field
- ).cumcount()
- buckets_df["is_anchor"] = buckets_df["_order_in_bucket"] < n
- for i in range(0, n):
- buckets_df[f"is_anchor_id_{i}"] = buckets_df["_order_in_bucket"] == i
- buckets_df = buckets_df.drop(columns=["_id_hash", "_order_in_bucket"], axis=1)
- buckets_df = buckets_df.reset_index(drop=True)
- buckets_df = buckets_df[buckets_df.is_anchor]
- return buckets_df
-
- def _add_anchor_docs(self, buckets_df, num_anchors):
- """
- Get anchor documents for each bucket.
- """
- df_anchor_bk = self._random_select_anchor(buckets_df=buckets_df, n=num_anchors)
- df_anchor_docs = None
- for i in range(num_anchors):
- df_anchor_bk_i = df_anchor_bk[df_anchor_bk[f"is_anchor_id_{i}"]][
- [self.bucket_field] + self.id_fields
- ].reset_index(drop=True)
- column_mapping = {id: f"anchor_{i}_{id}" for id in self.id_fields}
- df_anchor_bk_i = df_anchor_bk_i.rename(columns=column_mapping)
- if i == 0:
- df_anchor_docs = df_anchor_bk_i
- else:
- df_anchor_docs = df_anchor_bk_i.merge(
- df_anchor_docs, on=[self.bucket_field], how="inner"
- )
-
- df_anchor_docs_with_bk = buckets_df.merge(
- df_anchor_docs, on=[self.bucket_field], how="inner"
- )
- return df_anchor_docs_with_bk
-
- def map_buckets_with_anchors(
- self,
- documents_df: dask_cudf.DataFrame,
- buckets_df: dask_cudf.DataFrame,
- shuffle_type: Union[str, bool, None] = "tasks",
- ) -> dask_cudf.DataFrame:
- """
- Get anchor docs with bucket info
- Args:
-            documents_df: dask_cudf DataFrame of input documents
-            buckets_df: dask_cudf DataFrame of LSH buckets
- shuffle_type: type of shuffle to use
- Returns:
- ddf_anchor_docs_with_bk
- """
- output_map_df = self._get_output_map_based_on_str_bytes(
- buckets_df=buckets_df, documents_df=documents_df
- )
- ddf_anchor_docs_with_bk = buckets_df.map_partitions(
- self._add_anchor_docs, num_anchors=self.num_anchors
- )
- self._logger.info("output_map_df is based on string bytes")
- ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.merge(
- output_map_df, on=self.bucket_field
- )
- # Bucket is no longer needed
- ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.drop(
- columns=[self.bucket_field]
- )
- # Below removes any duplicates lying around after dropping buckets
- ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.map_partitions(
- M.drop_duplicates,
- meta=ddf_anchor_docs_with_bk._meta,
- enforce_metadata=False,
- transform_divisions=False,
- align_dataframes=False,
- )
- ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.shuffle(
- self.id_fields,
- ignore_index=True,
- shuffle_method=shuffle_type,
- ).map_partitions(
- M.drop_duplicates,
- meta=ddf_anchor_docs_with_bk._meta,
- enforce_metadata=False,
- transform_divisions=False,
- align_dataframes=False,
- )
- del output_map_df
- return ddf_anchor_docs_with_bk
-
-
-class _Shuffle:
- def __init__(
- self,
- id_fields: Union[str, list] = "id",
- text_field: str = "text",
- logger: Union[logging.LoggerAdapter, str] = "./",
- profile_dir: str = None,
- int_to_str_id: str = None,
- ):
- if isinstance(logger, str):
- self._logger = create_logger(
- rank=0,
- log_file=os.path.join(logger, "LSH.log"),
- name="LSH",
- )
- else:
- self._logger = logger
-
- self.id_fields = id_fields
- self.text_field = text_field
- self.profile_dir = profile_dir
- self.int_to_str_id = int_to_str_id
-
- def shuffle_docs_on_buckets(
- self,
- documents_df: dask_cudf.DataFrame,
- bucket_w_anchors_path: str,
- output_shuffled_docs_path: str,
- bucket_mapping_df_blocksize,
- parts_per_worker: int = 1,
- bucket_parts_per_worker: int = 8,
- partition_on: str = "_output_partition_id",
- ):
-
- ddf_anchor_docs_with_bk, bk_mapping = aggregated_anchor_docs_with_bk_read(
- path=bucket_w_anchors_path,
- blocksize=bucket_mapping_df_blocksize,
- )
- self._logger.info("Getting ddf_anchor_docs_with_bk completed")
- self._logger.debug(
- f"ddf_anchor_docs_with_bk.npartitions = {ddf_anchor_docs_with_bk.npartitions}"
- )
- st = time.time()
- num_workers = get_num_workers(get_current_client())
- parts_per_batch = num_workers * parts_per_worker
- self._logger.debug(f"parts_per_batch = {parts_per_batch}")
- parts_per_bucket_batch = num_workers * bucket_parts_per_worker
- self._logger.debug(f"parts_per_bucket_batch = {parts_per_bucket_batch}")
-
- dask_profile_name = (
- "suffle_docs"
- + f"-parts_per_batch-{parts_per_batch}"
- + f"-parts_per_bucket_batch-{parts_per_bucket_batch}"
- )
- documents_df = documents_df[self.id_fields + [self.text_field]]
-
- with performance_report_if_with_ts_suffix(self.profile_dir, dask_profile_name):
- self._batched_merge_and_write(
- left_df=documents_df,
- right_df=ddf_anchor_docs_with_bk,
- output_path=output_shuffled_docs_path,
- merge_on=self.id_fields,
- partition_on=partition_on,
- parts_per_text_batch=parts_per_batch,
- parts_per_bucket_batch=parts_per_bucket_batch,
- bk_mapping=bk_mapping,
- num_workers=num_workers,
- )
- self._logger.info(
- f"Time taken for Shuffle = {time.time()-st}s and output written at {output_shuffled_docs_path}"
- )
-
- def _batched_merge_and_write(
- self,
- left_df: dask_cudf.DataFrame,
- right_df: dask_cudf.DataFrame,
- output_path: str,
- merge_on: List[str],
- partition_on: str,
- parts_per_text_batch: int,
- parts_per_bucket_batch: int,
- bk_mapping,
- num_workers: int = None,
- ):
- total_text_partitions = left_df.npartitions
- total_bucket_partitions = right_df.npartitions
-
- # Extract global partitioning index
- left_df, global_partitioning_index = extract_partitioning_index(
- left_df,
- merge_on,
- bk_mapping,
- parts_per_bucket_batch,
- total_bucket_partitions,
- )
-
- # Set start offsets
- bucket_part_start_offset, text_part_start_offset = get_restart_offsets(
- output_path
- )
-
- # Set end offsets
- # NOTE: These end offsets are always set to the end
- # of the data. However, we may want to be able to set
- # both the start and end offsets from the command line
- # in the future.
- bucket_part_end_offset = total_bucket_partitions
- text_part_end_offset = total_text_partitions
-
- # Check that offsets are valid
- assert bucket_part_start_offset % parts_per_bucket_batch == 0
- assert bucket_part_end_offset > bucket_part_start_offset
- assert text_part_end_offset > text_part_start_offset
-
- # Initialize "retry" variables
- #
- # - retry_count: The number of successive batches that
- # we have already performed at a reduced batch size.
- # - retry_threshold: The number of successive batches
- # for which we should keep the batch size low
- # before attempting the default batch size again.
- # Every time we return to the default batch size
- # and immediately fail, retry_threshold will double.
- parts_per_text_batch_retry = None
- retry_count, retry_threshold = 0, 1
-
- self._logger.info(
- f"Starting at bucket-map partition {bucket_part_start_offset}"
- f" and text-df partition {text_part_start_offset}",
- )
-
- for bucket_part_offset in tqdm(
- range(
- bucket_part_start_offset, bucket_part_end_offset, parts_per_bucket_batch
- )
- ):
-
- # Outer loop over batches of "bucket-map" partitions
- end_bucket_offset = min(
- bucket_part_offset + parts_per_bucket_batch, bucket_part_end_offset
- )
- print(
- f"\nStarted processing bucket-map partitions {bucket_part_offset} "
- f"through {end_bucket_offset} of {bucket_part_end_offset}",
- flush=True,
- )
- st_bucket = time.time()
-
- # Select our bucket-mapping batch
- subset_bucket_df = right_df.partitions[bucket_part_offset:end_bucket_offset]
- subset_bucket_df = subset_bucket_df.persist()
-
- # Filter out rows of left_df that we know cannot
- # align with any rows of subset_bucket_df
- left_df_use = filter_text_rows_by_bucket_batch(
- left_df,
- global_partitioning_index,
- bucket_part_offset,
- bucket_part_end_offset,
- total_bucket_partitions,
- )
-
- text_part_offset = text_part_start_offset
- while text_part_offset < text_part_end_offset:
-
- # Check if we are "retrying" with a smaller "parts_per_text_batch"
- if parts_per_text_batch_retry:
- parts_per_text_batch_use = parts_per_text_batch_retry
- else:
- st_text = time.time()
- parts_per_text_batch_use = parts_per_text_batch
- print(f"Using {parts_per_text_batch_use} text partitions.", flush=True)
-
- # Select partitions for our text batch
- end_text_offset = min(
- text_part_offset + parts_per_text_batch_use, text_part_end_offset
- )
- subset_text_df = left_df_use.partitions[
- text_part_offset:end_text_offset
- ]
- subset_merged_df = merge_left_to_shuffled_right(
- subset_text_df,
- subset_bucket_df,
- merge_on,
- )
- output_df = subset_merged_df.shuffle(on=partition_on)
-
- if self.int_to_str_id is not None and output_df is not None:
- output_df = output_df.map_partitions(
- int_ids_to_str, id_column=self.int_to_str_id
- )
- batch_label = f"{end_bucket_offset}_{end_text_offset}"
- if output_df is not None:
- written_files = output_df.map_partitions(
- write_partitioned_file,
- output_path,
- partition_on,
- batch_label,
- meta=cudf.Series([True]),
- )
- written_files = written_files.compute()
- update_restart_offsets(output_path, bucket_part_offset, end_text_offset)
- del output_df
-
- print(
- "Text-df partition ",
- f"{end_text_offset}/{text_part_end_offset} "
- f"completed in {time.time()-st_text}",
- flush=True,
- )
-
- # Update loop control-flow variables
- if parts_per_text_batch_use == parts_per_text_batch:
- # We succeeded at the default batch size.
- # Reset the retry count
- retry_count, retry_threshold = 0, 1
- else:
- # We succeeded at a lower batch size
- retry_count += 1
- if retry_count >= retry_threshold:
- # Go back to the default text-batch size,
- # but increase the retry_threshold in
- # case we fail again
- parts_per_text_batch_retry = None
- retry_count, retry_threshold = 0, min(retry_threshold * 2, 16)
- text_part_offset += parts_per_text_batch_use
-
- update_restart_offsets(output_path, end_bucket_offset, end_text_offset)
- print(
- "Bucket partition ",
- f"{end_bucket_offset}/{bucket_part_end_offset} "
- f"completed in {time.time()-st_bucket}",
- flush=True,
- )
-
- # Need to reset text_part_start_offset to 0 after
- # a single bucket-batch pass (only matters if we are
- # breaking the bucket-mapping df into multiple batches)
- text_part_start_offset = 0
-
-
-class JaccardSimilarity:
- def __init__(
- self,
- id_field="id",
- anchor_id_fields=["anchor_0_id", "anchor_1_id"],
- text_field="text",
- ngram_width=5,
- ):
- self.id_field = id_field
- self.anchor_id_fields = anchor_id_fields
- self.text_field = text_field
- self.anchor_id = f"anchor_{id_field}"
- self.left_id = f"{self.id_field}_x"
- self.right_id = f"{self.id_field}_y"
- self.ngram_width = ngram_width
-
-    def __call__(self, dataset: DocumentDataset):
- raise NotImplementedError
-
- def jaccard_compute(self, shuffled_docs_path):
- paths = [
- entry.path
- for entry in os.scandir(shuffled_docs_path)
- if not entry.path.endswith(".txt")
- ]
- meta_df = cudf.DataFrame(
- {
- self.left_id: ["x"],
- self.right_id: ["y"],
- "jaccard": np.float32([0.0]),
- }
- )
- result_df = dd.from_map(
- self._compute_jaccard_on_1_partition, paths, meta=meta_df
- ).reset_index(drop=True)
- return result_df
-
- def _compute_jaccard_on_1_partition(self, path):
- try:
- df = cudf.read_parquet(path)
- pair_df = self._compute_jaccard_and_create_pair_df(df)
- except OverflowError:
- paths = [entry.path for entry in os.scandir(os.path.join(path))]
- anchor_df_str_size_ls = [
- self._get_anchor_docs_and_string_size(path) for path in paths
- ]
- anchor_df = cudf.concat(
- [anchor_doc for anchor_doc, _ in anchor_df_str_size_ls],
- ignore_index=True,
- ).drop_duplicates()
- df_str_size = [str_size for _, str_size in anchor_df_str_size_ls]
- paths = JaccardSimilarity._create_bins(
- df_str_size, np.iinfo(np.int32).max // 10
- )
- pair_dfs = []
- for path in paths:
- print(path)
- df = cudf.read_parquet(path).reset_index(drop=True)
- df = cudf.concat([df, anchor_df], ignore_index=True)
- pair_df = self._compute_jaccard_and_create_pair_df(df)
- pair_dfs.append(pair_df)
- pair_df = cudf.concat(pair_dfs, ignore_index=True)
- return pair_df
-
- def _get_anchor_docs_and_string_size(self, path):
- df = cudf.read_parquet(path)
- str_bytes = df[self.text_field].str.byte_count().sum()
- is_anchor_flag = df[self.id_field] == df[self.anchor_id_fields[0]]
- for anchor_id in self.anchor_id_fields[1:]:
- is_anchor_flag = is_anchor_flag | (df[self.id_field] == df[anchor_id])
- anchor_df = df[is_anchor_flag].reset_index(drop=True)
- return anchor_df, {"path": path, "str_bytes": str_bytes}
-
- @staticmethod
- def _create_bins(path_dicts, max_size):
- path_dicts.sort(key=lambda x: x["str_bytes"], reverse=True)
- bins, bin_sizes = [], []
- for path_d in path_dicts:
- new_path, new_size = path_d["path"], path_d["str_bytes"]
- for i, bin_size in enumerate(bin_sizes):
- if bin_size + new_size <= max_size:
- bins[i].append(new_path)
- bin_sizes[i] += new_size
- new_size = 0
- break
- if new_size:
- bins.append([new_path])
- bin_sizes.append(new_size)
- return bins
-
- def _compute_jaccard_and_create_pair_df(self, df):
- df = df.drop_duplicates(
- subset=[self.id_field] + self.anchor_id_fields, ignore_index=True
- )
- anchor_columns = self.anchor_id_fields
- id_field = self.id_field
- result_ls = []
- try:
- for anchor_col in anchor_columns:
- doc_df = df[[id_field, self.text_field, anchor_col]]
- doc_df = doc_df.rename(columns={anchor_col: self.anchor_id})
- doc_df = doc_df[doc_df[id_field] != doc_df[self.anchor_id]]
- anchor_df = self._get_anchor_df(df, anchor_col)
- result_df = self._compute_jaccard_pair(doc_df, anchor_df)
- result_ls.append(result_df)
-
- return cudf.concat(result_ls)
- except OverflowError as e:
- print(
- "Failed with OverflowError in compute_jaccard_and_create_pair_df",
- flush=True,
- )
- print(df, flush=True)
- print("--" * 30)
- print("Error")
- print("---" * 30)
- raise e
-
- def _get_anchor_df(self, df, anchor_col):
- anchor_df = df[df[self.id_field] == df[anchor_col]]
- anchor_df = anchor_df.reset_index(drop=True)
- anchor_df = anchor_df[[anchor_col, self.text_field]]
- anchor_df = anchor_df.rename(columns={anchor_col: self.anchor_id})
- return anchor_df
-
- def _compute_jaccard_pair(self, docs_df, anchor_df):
- nrows_at_once = JaccardSimilarity._get_max_num_rows_to_process_once(
- df=docs_df, text_field=self.text_field
- )
- result_ls = []
- for i in range(0, docs_df.shape[0], nrows_at_once):
- pair_df = docs_df[i : i + nrows_at_once]
- pair_df = pair_df.merge(anchor_df, on=self.anchor_id)
- pair_df = pair_df.rename(
- columns={self.id_field: self.left_id, self.anchor_id: self.right_id}
- )
- mask = pair_df[self.left_id] != pair_df[self.right_id]
- pair_df = pair_df[mask].reset_index(drop=True)
- if len(pair_df) == 0:
- result_df = self._create_empty_jaccard_result()
- else:
- result_df = self._compute_jaccard_partition(pair_df)
- result_ls.append(result_df)
- if len(result_ls) == 0:
- return self._create_empty_jaccard_result()
- df_pair = cudf.concat(result_ls)
- return df_pair
-
- def _create_empty_jaccard_result(self):
- df = cudf.DataFrame()
- df[self.left_id] = "x"
- df[self.right_id] = "y"
- df["jaccard"] = np.empty(shape=0, dtype=np.float32)
- return df
-
- def _compute_jaccard_partition(self, df):
- text_x = f"{self.text_field}_x"
- text_y = f"{self.text_field}_y"
- df["jaccard"] = df[text_x].str.jaccard_index(df[text_y], width=self.ngram_width)
- df.drop(columns=[text_x, text_y], inplace=True)
- return df
-
- @staticmethod
- def _get_max_num_rows_to_process_once(df, text_field):
- nbytes = df[text_field].str.byte_count().sum()
-        # Number of exploded bytes
- exploded_bytes = nbytes * 5 * 2
- max_chars_allowed = 2_147_483_647
- byte_ratio = int(exploded_bytes) // max_chars_allowed
- if byte_ratio > 1:
- nrows_at_once = len(df) // byte_ratio
- else:
- nrows_at_once = len(df)
-
- nrows_at_once = max(1, nrows_at_once)
- return nrows_at_once
-
-
-class ConnectedComponents:
- def __init__(
- self,
- cache_dir: str,
- jaccard_pairs_path: str,
- id_column="id",
- jaccard_threshold: float = 0.8,
- logger: Union[logging.LoggerAdapter, str] = "./",
- profile_dir: Optional[str] = None,
- ):
- self.cache_dir = cache_dir
- self.jaccard_pairs_path = jaccard_pairs_path
- self.id_column = id_column
- self.left_id = f"{id_column}_x"
- self.right_id = f"{id_column}_y"
- self.jaccard_threshold = jaccard_threshold
- self.profile_dir = profile_dir
- if isinstance(logger, str):
- self._logger = create_logger(
- rank=0,
- log_file=os.path.join(logger, "ConnectedComponents.log"),
- name="ConnectedComponents",
- )
- else:
- self._logger = logger
-
- def cc_workflow(self, output_path):
- deduped_parsed_id_path = self._write_dedup_parsed_id()
- encoded_jaccard_pair_path = self._write_encoded_jaccard_pair(
- deduped_parsed_id_path
- )
- deduped_encoded_jaccard_path = self._write_dedup_encoded_jaccard_pair(
- encoded_jaccard_pair_path
- )
- cc_path = self._run_connected_components(
- deduped_encoded_jaccard_path, deduped_parsed_id_path, output_path
- )
- return cc_path
-
- def _run_connected_components(
- self,
- deduped_encoded_jaccard_path,
- deduped_parsed_id_path,
- output_path,
- ):
- t0 = time.time()
- with performance_report_if_with_ts_suffix(
- self.profile_dir, "connected-components-run"
- ):
-
- Comms.initialize(p2p=False)
- df = dask_cudf.read_parquet(
- deduped_encoded_jaccard_path, blocksize="1GB", aggregate_files=True
- )
- df = df[df["jaccard"] == 1].reset_index(drop=True)
-
- labels_df = dask_cudf.read_parquet(deduped_parsed_id_path)
- num_nodes = len(labels_df)
- self_edge_df = labels_df[["uid"]].rename(columns={"uid": self.left_id})
- self_edge_df[self.right_id] = self_edge_df[self.left_id]
-
- df = df[[self.left_id, self.right_id]].astype(np.int64)
- df = dask_cudf.concat([df, self_edge_df])
-
- G = MultiGraph(directed=False)
- G.from_dask_cudf_edgelist(
- df, source=self.left_id, destination=self.right_id, renumber=False
- )
- result = dcg.weakly_connected_components(G)
- del G
- max_partitions = min(32, result.npartitions)
- n_components = len(
- result[["labels"]].drop_duplicates(split_out=max_partitions)
- )
- num_labels = len(result)
- labels_df = labels_df.merge(
- result, left_on=["uid"], right_on=["vertex"], how="inner"
- )
- id_columns = [self.id_column]
- labels_df = labels_df[id_columns + ["labels"]]
- labels_df = labels_df.rename(columns={"labels": "group"})
- labels_df = labels_df.persist()
- # Doing an inner merge above
- # should not change any rows
-
- self._logger.info(
- "Result of connected compoinents are "
- f"# of groups : {n_components}, "
- f"# of docs removed : {num_labels - n_components}, "
- f"# nodes = {num_nodes}, "
- f"# rows in labels_df = {len(labels_df)}"
- )
- assert num_nodes == len(labels_df)
- # Ensure all docs in the same group are in the same partition
- labels_df = labels_df.shuffle(on=["group"], ignore_index=True)
- labels_df.to_parquet(output_path, write_index=False, overwrite=True)
- Comms.destroy()
- self._logger.info(
- f"Time taken for Connected Components Run = {time.time() - t0}s and output written at {output_path}"
- )
-
- @staticmethod
- def _sort_ids(df, id_columns):
- x = df[id_columns].values
- x = cp.sort(x, axis=1)
- for i, id_column in enumerate(id_columns):
- df[id_column] = x[:, i]
- df[id_column] = df[id_column].astype("uint64")
- return df
-
- @staticmethod
- def thresholding(df, threshold, column_to_threshold):
- mask = df[column_to_threshold] > threshold
- df.loc[mask, column_to_threshold] = np.int8(1)
- df.loc[~mask, column_to_threshold] = np.int8(0)
- return df
-
- def _write_dedup_encoded_jaccard_pair(self, encoded_jaccard_pair_path):
- output_path = f"{self.cache_dir}/final_dedup_encoded_jaccard_pair.parquet"
- t0 = time.time()
- with performance_report_if_with_ts_suffix(
- self.profile_dir, "connected-components-dedup-encoded-jaccard-pair"
- ):
-
- ddf = dask_cudf.read_parquet(
- encoded_jaccard_pair_path, blocksize="512MB", aggregate_files=True
- )
- meta = {
- self.left_id: "uint64",
- self.right_id: "uint64",
- "jaccard": "float32",
- }
- ddf = ddf.map_partitions(
- ConnectedComponents._sort_ids,
- id_columns=[self.left_id, self.right_id],
- meta=meta,
- )
- ddf = ddf.map_partitions(
- ConnectedComponents.thresholding,
- threshold=self.jaccard_threshold,
- column_to_threshold="jaccard",
- meta=meta,
- )
- ddf = ddf.map_partitions(
- M.drop_duplicates,
- meta=ddf._meta,
- enforce_metadata=False,
- transform_divisions=False,
- align_dataframes=False,
- )
-
- ddf = ddf.shuffle(
- [self.left_id, self.right_id],
- ignore_index=True,
- shuffle_method="tasks",
- )
- ddf = ddf.map_partitions(
- M.drop_duplicates,
- meta=ddf._meta,
- enforce_metadata=False,
- transform_divisions=False,
- align_dataframes=False,
- )
- ddf.to_parquet(output_path, write_index=False, overwrite=True)
- self._logger.info(
- f"Time taken for Dedup Encoding Jaccard Pairs = {time.time() - t0}s and output written at {output_path}"
- )
- return output_path
-
- def _write_dedup_parsed_id(self):
- dedup_parsed_id_path = f"{self.cache_dir}/dedup_parsed_id.parquet"
- t0 = time.time()
- with performance_report_if_with_ts_suffix(
- self.profile_dir, "connected-components-dedup-parsed-id"
- ):
- ddf = dask_cudf.read_parquet(
- self.jaccard_pairs_path,
- columns=[self.left_id, self.right_id],
- blocksize="512MB",
- aggregate_files=True,
- )
- id_columns = [self.id_column]
- unique_docs = ddf.map_partitions(
- ConnectedComponents._get_unique_ids_per_partition, id_columns=id_columns
- )
- unique_docs = unique_docs.drop_duplicates(
- # Dask does not guard against split_out=0
- split_out=max(ddf.npartitions // 4, 1)
- )
- unique_docs["uid"] = np.uint64(1)
- unique_docs["uid"] = unique_docs["uid"].cumsum()
- unique_docs["uid"] = unique_docs["uid"] - 1
- unique_docs.to_parquet(
- dedup_parsed_id_path, write_index=False, overwrite=True
- )
- self._logger.info(
- f"Time taken for Dedup Parsed Id = {time.time() - t0}s and output written at {dedup_parsed_id_path}"
- )
- return dedup_parsed_id_path
-
- def _write_encoded_jaccard_pair(self, dedup_parsed_id_path):
- output_path = f"{self.cache_dir}/encoded_jaccard_pair/"
- t0 = time.time()
- with performance_report_if_with_ts_suffix(
- self.profile_dir, "connected-components-encoded-jaccard-pair"
- ):
- ddf_id = dask_cudf.read_parquet(
- dedup_parsed_id_path, blocksize="2GB", aggregate_files=True
- )
- ddf = dask_cudf.read_parquet(
- self.jaccard_pairs_path,
- blocksize="1GB",
- aggregate_files=True,
- )
- self._merge_and_write(
- ddf=ddf,
- ddf_id=ddf_id,
- output_path=output_path,
- id_column=self.id_column,
- )
- self._logger.info(
- f"Time taken for Encoding Jaccard Pairs = {time.time() - t0}s and output written at {output_path}"
- )
- return output_path
-
- def _merge_and_write(
- self,
- ddf: dask_cudf.DataFrame,
- ddf_id: dask_cudf.DataFrame,
- output_path: str,
- id_column: str,
- ) -> None:
- st = time.time()
- # Ensure 'id_columns' is a list
- ddf_id = ddf_id.set_index(id_column)
- for tag in ["x", "y"]:
- pair_id = f"{id_column}_{tag}"
- # Merge 'ddf' with 'ddf_id' to map ids to uids
- ddf = ddf.merge(
- ddf_id,
- left_on=pair_id,
- right_index=True,
- how="inner",
- broadcast=True,
- )
- ddf = ddf.drop(columns=pair_id)
- ddf = ddf.rename(columns={"uid": f"{self.id_column}_{tag}"})
- ddf = ddf[[self.left_id, self.right_id, "jaccard"]]
- ddf.to_parquet(output_path, write_index=False, overwrite=True)
-
- et = time.time()
- self._logger.info(
- f"Time taken for merge and write = {et - st}s and output written at {output_path}"
- )
-
- @staticmethod
- def _get_unique_ids_per_partition(df, id_columns):
- unique_df_ls = []
- for tag in ["x", "y"]:
- cols_to_drop = []
- for id_col in id_columns:
- cols_to_drop.append(f"{id_col}_{tag}")
-
- subset_df = df[cols_to_drop].drop_duplicates(ignore_index=True)
- subset_df = subset_df.rename(
- columns={f"{id_col}_{tag}": f"{id_col}" for id_col in id_columns}
- )
- unique_df_ls.append(subset_df)
- unique_df = cudf.concat(unique_df_ls, ignore_index=True)
- unique_df = unique_df.drop_duplicates(ignore_index=True)
- return unique_df
diff --git a/nemo_curator/modules/fuzzy_dedup/_mapbuckets.py b/nemo_curator/modules/fuzzy_dedup/_mapbuckets.py
new file mode 100644
index 00000000..20a09ed7
--- /dev/null
+++ b/nemo_curator/modules/fuzzy_dedup/_mapbuckets.py
@@ -0,0 +1,280 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+import os
+from typing import Union
+
+import cudf
+import dask_cudf
+import numpy as np
+from dask.utils import M
+
+from nemo_curator.log import create_logger
+from nemo_curator.utils.fuzzy_dedup_utils.output_map_utils import (
+ build_partition,
+ get_agg_text_bytes_df,
+)
+
+
+class _MapBuckets:
+ """
+ buckets to a logical partition by using a modified bin packing algorithm.
+ Combines buckets generated from LSH (typically high cardinality)
+ to more coarse lower cardinality bucket groups by mapping multiple buckets
+ to a logical partition using document length information and a modified bin
+ packing algorithm.
+ Only needed if running False Postive check to remove false positives.
+ """
+
+ def __init__(
+ self,
+ id_fields: Union[list, str] = "id",
+ text_field: str = "text",
+ bucket_field: str = "_bucket_id",
+ num_anchors: int = 2,
+ logger: Union[logging.LoggerAdapter, str] = "./",
+ ):
+ """
+ id_fields: list or str
+ id fields of df
+ text_field: str = "text",
+ bucket_column: str = "bucket_column",
+ num_anchors: int = 2,
+ logger: Union[logging.LoggerAdapter, str] = "./",
+ """
+ self.id_fields = [id_fields] if isinstance(id_fields, str) else id_fields
+ self.text_field = text_field
+ self.num_anchors = num_anchors
+ self.bucket_field = bucket_field
+ if isinstance(logger, str):
+ self._logger = create_logger(
+ rank=0,
+ log_file=os.path.join(logger, "Map_Buckets.log"),
+ name="Map_Buckets",
+ )
+ else:
+ self._logger = logger
+
+ @staticmethod
+ def _get_output_part_ids_with_approx_equal_sum(
+ bucket_text_bytes_df: cudf.DataFrame,
+ max_text_bytes_per_part: int,
+ buckets_column: str,
+ bytes_column: str,
+ output_partition_column: str,
+ ) -> cudf.DataFrame:
+ """
+ Create a output_series that maps the ser.index into `nparts`
+ so that the total sum of bucket_val_counts_df
+ for each output id are all most equal and
+ less than max_text_bytes_per_part
+ This is used downstream for creating equal output_ids
+ """
+ sizes = bucket_text_bytes_df[bytes_column].values
+ bucket_output_ar = build_partition(
+ sizes=sizes.get(), max_size=max_text_bytes_per_part
+ )
+ df = cudf.DataFrame()
+ df[buckets_column] = bucket_text_bytes_df[buckets_column]
+ df[output_partition_column] = bucket_output_ar
+ return df
+
+ def _get_output_map_from_text_bytes_per_bucket(
+ self,
+ ddf_bk_text_bytes,
+ bytes_column,
+ output_partition_column="_output_partition_id",
+ ):
+ # String bytes limit for cuDF
+ # https://github.com/rapidsai/cudf/issues/13733
+ max_text_bytes_per_part = int(np.iinfo(np.int32).max * 3)
+
+ self._logger.info(f"max_text_bytes_per_part = {max_text_bytes_per_part}")
+ # Increasing in an attempt to prevent hitting
+ # ulimits
+ output_map_df_meta = cudf.DataFrame(
+ {self.bucket_field: [0], output_partition_column: [1]}
+ )
+ output_map_df_meta = output_map_df_meta.astype(
+ {self.bucket_field: np.uint64, output_partition_column: np.int32}
+ )
+
+ output_map_df = ddf_bk_text_bytes.map_partitions(
+ _MapBuckets._get_output_part_ids_with_approx_equal_sum,
+ max_text_bytes_per_part=max_text_bytes_per_part,
+ buckets_column=self.bucket_field,
+ bytes_column=bytes_column,
+ output_partition_column=output_partition_column,
+ meta=output_map_df_meta,
+ )
+ output_map_df = output_map_df.persist()
+ self._logger.info(
+ f"Step 1 of output_map_df of len: {len(output_map_df)} computed"
+ )
+ lower_bounds = (
+ output_map_df[output_partition_column]
+ .map_partitions(lambda s: (s.max() + 1))
+ .compute()
+ )
+ lower_bounds = np.cumsum(lower_bounds)
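+        # Each dask partition assigns its output partition ids starting from 0;
+        # the cumulative per-partition maxima computed above are added to later
+        # partitions below so that ids are globally unique across the output map.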
+
+ def update_id(df, lower_bound):
+ df[output_partition_column] += lower_bound
+ return df
+
+ updated_parts = [
+ output_map_df.get_partition(i).map_partitions(
+ update_id, lower_bounds[i - 1]
+ )
+ for i in range(1, len(lower_bounds))
+ ]
+ updated_parts.append(output_map_df.get_partition(0))
+ output_map_df = dask_cudf.concat(updated_parts)
+ output_map_df = output_map_df.persist()
+ self._logger.info(
+ f"All steps of output_map_df of len: {len(output_map_df)} computed"
+ )
+ return output_map_df
+
+ def _get_output_map_based_on_str_bytes(
+ self, buckets_df, documents_df, bytes_column="_text_bytes"
+ ):
+ """
+ Add output_partition_id to buckets_ddf
+ """
+ documents_df = documents_df.copy()
+ documents_df[bytes_column] = documents_df[self.text_field].map_partitions(
+ lambda s: s.str.byte_count()
+ )
+ n_partitions = buckets_df.npartitions
+ documents_df = documents_df.drop(columns=[self.text_field]).repartition(
+ npartitions=n_partitions
+ )
+ buckets_df = buckets_df.merge(documents_df).repartition(
+ npartitions=n_partitions
+ )
+ del documents_df
+ ddf_bk_text_bytes, agg_df_len = get_agg_text_bytes_df(
+ df=buckets_df,
+ agg_column=self.bucket_field,
+ bytes_column=bytes_column,
+ n_partitions=n_partitions,
+ shuffle=True,
+ )
+ self._logger.info(f"Agg_df computed of length = {agg_df_len}")
+ del buckets_df
+ output_map_df = self._get_output_map_from_text_bytes_per_bucket(
+ ddf_bk_text_bytes=ddf_bk_text_bytes,
+ bytes_column=bytes_column,
+ )
+ return output_map_df
+
+ def _random_select_anchor(self, buckets_df, n=2):
+ """
+ Randomly select `n` anchors from each bucket.
+ """
+ buckets_df = buckets_df.copy()
+ buckets_df["_id_hash"] = buckets_df[self.id_fields].hash_values()
+ buckets_df = buckets_df.sort_values([self.bucket_field, "_id_hash"])
+ buckets_df["_order_in_bucket"] = buckets_df.groupby(
+ self.bucket_field
+ ).cumcount()
+ buckets_df["is_anchor"] = buckets_df["_order_in_bucket"] < n
+ for i in range(0, n):
+ buckets_df[f"is_anchor_id_{i}"] = buckets_df["_order_in_bucket"] == i
+ buckets_df = buckets_df.drop(columns=["_id_hash", "_order_in_bucket"], axis=1)
+ buckets_df = buckets_df.reset_index(drop=True)
+ buckets_df = buckets_df[buckets_df.is_anchor]
+ return buckets_df
+
+ def _add_anchor_docs(self, buckets_df, num_anchors):
+ """
+ Get anchor documents for each bucket.
+ """
+ df_anchor_bk = self._random_select_anchor(buckets_df=buckets_df, n=num_anchors)
+ df_anchor_docs = None
+ for i in range(num_anchors):
+ df_anchor_bk_i = df_anchor_bk[df_anchor_bk[f"is_anchor_id_{i}"]][
+ [self.bucket_field] + self.id_fields
+ ].reset_index(drop=True)
+ column_mapping = {id: f"anchor_{i}_{id}" for id in self.id_fields}
+ df_anchor_bk_i = df_anchor_bk_i.rename(columns=column_mapping)
+ if i == 0:
+ df_anchor_docs = df_anchor_bk_i
+ else:
+ df_anchor_docs = df_anchor_bk_i.merge(
+ df_anchor_docs, on=[self.bucket_field], how="inner"
+ )
+
+ df_anchor_docs_with_bk = buckets_df.merge(
+ df_anchor_docs, on=[self.bucket_field], how="inner"
+ )
+ return df_anchor_docs_with_bk
+
+ def map_buckets_with_anchors(
+ self,
+ documents_df: dask_cudf.DataFrame,
+ buckets_df: dask_cudf.DataFrame,
+ shuffle_type: Union[str, bool, None] = "tasks",
+ ) -> dask_cudf.DataFrame:
+ """
+ Get anchor docs with bucket info
+ Args:
+ input_data_paths: list of paths to input data
+ input_bucket_path: path to input buckets
+ text_ddf_blocksize: blocksize for text ddf
+ num_files: number of files to read
+ num_workers: number of workers
+ shuffle_type: type of shuffle to use
+ Returns:
+ ddf_anchor_docs_with_bk
+ """
+ output_map_df = self._get_output_map_based_on_str_bytes(
+ buckets_df=buckets_df, documents_df=documents_df
+ )
+ ddf_anchor_docs_with_bk = buckets_df.map_partitions(
+ self._add_anchor_docs, num_anchors=self.num_anchors
+ )
+ self._logger.info("output_map_df is based on string bytes")
+ ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.merge(
+ output_map_df, on=self.bucket_field
+ )
+ # Bucket is no longer needed
+ ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.drop(
+ columns=[self.bucket_field]
+ )
+ # Below removes any duplicates lying around after dropping buckets
+ ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.map_partitions(
+ M.drop_duplicates,
+ meta=ddf_anchor_docs_with_bk._meta,
+ enforce_metadata=False,
+ transform_divisions=False,
+ align_dataframes=False,
+ )
+ ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.shuffle(
+ self.id_fields,
+ ignore_index=True,
+ shuffle_method=shuffle_type,
+ ).map_partitions(
+ M.drop_duplicates,
+ meta=ddf_anchor_docs_with_bk._meta,
+ enforce_metadata=False,
+ transform_divisions=False,
+ align_dataframes=False,
+ )
+ del output_map_df
+ return ddf_anchor_docs_with_bk
diff --git a/nemo_curator/modules/fuzzy_dedup/_shuffle.py b/nemo_curator/modules/fuzzy_dedup/_shuffle.py
new file mode 100644
index 00000000..218bf4a6
--- /dev/null
+++ b/nemo_curator/modules/fuzzy_dedup/_shuffle.py
@@ -0,0 +1,284 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+import os
+import time
+from typing import List, Union
+
+import cudf
+import dask_cudf
+from tqdm import tqdm
+
+from nemo_curator.log import create_logger
+from nemo_curator.utils.distributed_utils import (
+ get_current_client,
+ get_num_workers,
+ performance_report_if_with_ts_suffix,
+)
+from nemo_curator.utils.fuzzy_dedup_utils.id_mapping import int_ids_to_str
+from nemo_curator.utils.fuzzy_dedup_utils.io_utils import (
+ aggregated_anchor_docs_with_bk_read,
+ get_restart_offsets,
+ update_restart_offsets,
+)
+from nemo_curator.utils.fuzzy_dedup_utils.merge_utils import (
+ extract_partitioning_index,
+ filter_text_rows_by_bucket_batch,
+ merge_left_to_shuffled_right,
+)
+from nemo_curator.utils.fuzzy_dedup_utils.shuffle_utils import write_partitioned_file
+
+
+class _Shuffle:
+ def __init__(
+ self,
+ id_fields: Union[str, list] = "id",
+ text_field: str = "text",
+ logger: Union[logging.LoggerAdapter, str] = "./",
+ profile_dir: str = None,
+ int_to_str_id: str = None,
+ ):
+ if isinstance(logger, str):
+ self._logger = create_logger(
+ rank=0,
+ log_file=os.path.join(logger, "LSH.log"),
+ name="LSH",
+ )
+ else:
+ self._logger = logger
+
+ self.id_fields = id_fields
+ self.text_field = text_field
+ self.profile_dir = profile_dir
+ self.int_to_str_id = int_to_str_id
+
+ def shuffle_docs_on_buckets(
+ self,
+ documents_df: dask_cudf.DataFrame,
+ bucket_w_anchors_path: str,
+ output_shuffled_docs_path: str,
+ bucket_mapping_df_blocksize,
+ parts_per_worker: int = 1,
+ bucket_parts_per_worker: int = 8,
+ partition_on: str = "_output_partition_id",
+ ):
+
+ ddf_anchor_docs_with_bk, bk_mapping = aggregated_anchor_docs_with_bk_read(
+ path=bucket_w_anchors_path,
+ blocksize=bucket_mapping_df_blocksize,
+ )
+ self._logger.info("Getting ddf_anchor_docs_with_bk completed")
+ self._logger.debug(
+ f"ddf_anchor_docs_with_bk.npartitions = {ddf_anchor_docs_with_bk.npartitions}"
+ )
+ st = time.time()
+ num_workers = get_num_workers(get_current_client())
+ parts_per_batch = num_workers * parts_per_worker
+ self._logger.debug(f"parts_per_batch = {parts_per_batch}")
+ parts_per_bucket_batch = num_workers * bucket_parts_per_worker
+ self._logger.debug(f"parts_per_bucket_batch = {parts_per_bucket_batch}")
+
+ dask_profile_name = (
+ "suffle_docs"
+ + f"-parts_per_batch-{parts_per_batch}"
+ + f"-parts_per_bucket_batch-{parts_per_bucket_batch}"
+ )
+ documents_df = documents_df[self.id_fields + [self.text_field]]
+
+ with performance_report_if_with_ts_suffix(self.profile_dir, dask_profile_name):
+ self._batched_merge_and_write(
+ left_df=documents_df,
+ right_df=ddf_anchor_docs_with_bk,
+ output_path=output_shuffled_docs_path,
+ merge_on=self.id_fields,
+ partition_on=partition_on,
+ parts_per_text_batch=parts_per_batch,
+ parts_per_bucket_batch=parts_per_bucket_batch,
+ bk_mapping=bk_mapping,
+ num_workers=num_workers,
+ )
+ self._logger.info(
+ f"Time taken for Shuffle = {time.time()-st}s and output written at {output_shuffled_docs_path}"
+ )
+
+ def _batched_merge_and_write(
+ self,
+ left_df: dask_cudf.DataFrame,
+ right_df: dask_cudf.DataFrame,
+ output_path: str,
+ merge_on: List[str],
+ partition_on: str,
+ parts_per_text_batch: int,
+ parts_per_bucket_batch: int,
+ bk_mapping,
+ num_workers: int = None,
+ ):
+ total_text_partitions = left_df.npartitions
+ total_bucket_partitions = right_df.npartitions
+
+ # Extract global partitioning index
+ left_df, global_partitioning_index = extract_partitioning_index(
+ left_df,
+ merge_on,
+ bk_mapping,
+ parts_per_bucket_batch,
+ total_bucket_partitions,
+ )
+
+ # Set start offsets
+ bucket_part_start_offset, text_part_start_offset = get_restart_offsets(
+ output_path
+ )
+
+ # Set end offsets
+ # NOTE: These end offsets are always set to the end
+ # of the data. However, we may want to be able to set
+ # both the start and end offsets from the command line
+ # in the future.
+ bucket_part_end_offset = total_bucket_partitions
+ text_part_end_offset = total_text_partitions
+
+ # Check that offsets are valid
+ assert bucket_part_start_offset % parts_per_bucket_batch == 0
+ assert bucket_part_end_offset > bucket_part_start_offset
+ assert text_part_end_offset > text_part_start_offset
+
+ # Initialize "retry" variables
+ #
+ # - retry_count: The number of successive batches that
+ # we have already performed at a reduced batch size.
+ # - retry_threshold: The number of successive batches
+ # for which we should keep the batch size low
+ # before attempting the default batch size again.
+ # Every time we return to the default batch size
+ # and immediately fail, retry_threshold will double.
+ parts_per_text_batch_retry = None
+ retry_count, retry_threshold = 0, 1
+
+ self._logger.info(
+ f"Starting at bucket-map partition {bucket_part_start_offset}"
+ f" and text-df partition {text_part_start_offset}",
+ )
+
+ for bucket_part_offset in tqdm(
+ range(
+ bucket_part_start_offset, bucket_part_end_offset, parts_per_bucket_batch
+ )
+ ):
+
+ # Outer loop over batches of "bucket-map" partitions
+ end_bucket_offset = min(
+ bucket_part_offset + parts_per_bucket_batch, bucket_part_end_offset
+ )
+ print(
+ f"\nStarted processing bucket-map partitions {bucket_part_offset} "
+ f"through {end_bucket_offset} of {bucket_part_end_offset}",
+ flush=True,
+ )
+ st_bucket = time.time()
+
+ # Select our bucket-mapping batch
+ subset_bucket_df = right_df.partitions[bucket_part_offset:end_bucket_offset]
+ subset_bucket_df = subset_bucket_df.persist()
+
+ # Filter out rows of left_df that we know cannot
+ # align with any rows of subset_bucket_df
+ left_df_use = filter_text_rows_by_bucket_batch(
+ left_df,
+ global_partitioning_index,
+ bucket_part_offset,
+ bucket_part_end_offset,
+ total_bucket_partitions,
+ )
+
+ text_part_offset = text_part_start_offset
+ while text_part_offset < text_part_end_offset:
+
+ # Check if we are "retrying" with a smaller "parts_per_text_batch"
+ if parts_per_text_batch_retry:
+ parts_per_text_batch_use = parts_per_text_batch_retry
+ else:
+ st_text = time.time()
+ parts_per_text_batch_use = parts_per_text_batch
+ print(f"Using {parts_per_text_batch_use} text partitions.", flush=True)
+
+ # Select partitions for our text batch
+ end_text_offset = min(
+ text_part_offset + parts_per_text_batch_use, text_part_end_offset
+ )
+ subset_text_df = left_df_use.partitions[
+ text_part_offset:end_text_offset
+ ]
+ subset_merged_df = merge_left_to_shuffled_right(
+ subset_text_df,
+ subset_bucket_df,
+ merge_on,
+ )
+ output_df = subset_merged_df.shuffle(on=partition_on)
+
+ if self.int_to_str_id is not None and output_df is not None:
+ output_df = output_df.map_partitions(
+ int_ids_to_str, id_column=self.int_to_str_id
+ )
+ batch_label = f"{end_bucket_offset}_{end_text_offset}"
+ if output_df is not None:
+ written_files = output_df.map_partitions(
+ write_partitioned_file,
+ output_path,
+ partition_on,
+ batch_label,
+ meta=cudf.Series([True]),
+ )
+ written_files = written_files.compute()
+ update_restart_offsets(output_path, bucket_part_offset, end_text_offset)
+ del output_df
+
+ print(
+ "Text-df partition ",
+ f"{end_text_offset}/{text_part_end_offset} "
+ f"completed in {time.time()-st_text}",
+ flush=True,
+ )
+
+ # Update loop control-flow variables
+ if parts_per_text_batch_use == parts_per_text_batch:
+ # We succeeded at the default batch size.
+ # Reset the retry count
+ retry_count, retry_threshold = 0, 1
+ else:
+ # We succeeded at a lower batch size
+ retry_count += 1
+ if retry_count >= retry_threshold:
+ # Go back to the default text-batch size,
+ # but increase the retry_threshold in
+ # case we fail again
+ parts_per_text_batch_retry = None
+ retry_count, retry_threshold = 0, min(retry_threshold * 2, 16)
+ text_part_offset += parts_per_text_batch_use
+
+ update_restart_offsets(output_path, end_bucket_offset, end_text_offset)
+ print(
+ "Bucket partition ",
+ f"{end_bucket_offset}/{bucket_part_end_offset} "
+ f"completed in {time.time()-st_bucket}",
+ flush=True,
+ )
+
+ # Need to reset text_part_start_offset to 0 after
+ # a single bucket-batch pass (only matters if we are
+ # breaking the bucket-mapping df into multiple batches)
+ text_part_start_offset = 0
diff --git a/nemo_curator/modules/fuzzy_dedup/bucketstoedges.py b/nemo_curator/modules/fuzzy_dedup/bucketstoedges.py
new file mode 100644
index 00000000..5ff08b4c
--- /dev/null
+++ b/nemo_curator/modules/fuzzy_dedup/bucketstoedges.py
@@ -0,0 +1,160 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+import os
+import time
+import warnings
+from itertools import pairwise
+from typing import Optional, Union
+
+import cudf
+import dask_cudf
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+
+from nemo_curator.datasets import DocumentDataset
+from nemo_curator.log import create_logger
+from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
+
+
+class BucketsToEdges:
+ """
+ Maps buckets generated from LSH into an edgelist that
+ can be processed further by Connected Components to find duplicate
+ documents
+ """
+
+ def __init__(
+ self,
+ cache_dir: str = None,
+ id_fields: Union[list, str] = "id",
+ str_id_name: str = "id",
+ bucket_field: str = "_bucket_id",
+ logger: Union[logging.LoggerAdapter, str] = "./",
+ profile_dir: Optional[str] = None,
+ ):
+ """
+ Parameters
+ ----------
+ cache_dir: str or None
+ If specified, will compute & write the edgelist to a file
+ id_fields: list or str
+ id fields of documents in buckets_df
+ str_id_name: str
+ Ignored if there is a single id field. Multiple id fields
+ will be combined into a single id field with the given name.
+ bucket_field: str
+ Column denoting bucket ID
+        logger: Existing logger to log to, or a path to a log directory.
+        profile_dir: str, Default None
+            If specified, directory to write Dask profile reports to.
+ """
+ self.cache_dir = cache_dir
+ self.id_fields = [id_fields] if isinstance(id_fields, str) else id_fields
+ self.str_id_name = str_id_name if len(self.id_fields) > 1 else self.id_fields[0]
+ self.output_ids = [f"{self.str_id_name}_x", f"{self.str_id_name}_y"]
+ self.bucket_field = bucket_field
+ self.profile_dir = profile_dir
+ if isinstance(logger, str):
+ self._logger = create_logger(
+ rank=0,
+ log_file=os.path.join(logger, "Buckets_to_Edges.log"),
+ name="Buckets_to_Edges",
+ )
+ else:
+ self._logger = logger
+
+ @staticmethod
+ def _combine_multiple_ids(
+ input_df: cudf.DataFrame, input_id_fields: list, output_id_field: str
+ ) -> cudf.DataFrame:
+ if output_id_field in input_df.columns:
+ raise ValueError(
+ f"Input df already contains column named: {output_id_field}"
+ )
+
+ output_df = input_df.copy()[input_df.columns.difference(input_id_fields)]
+
+        output_df[output_id_field] = input_df[input_id_fields[0]].astype(str)
+        for input_field in input_id_fields[1:]:
+            # Append each additional id field to the combined id
+            output_df[output_id_field] = (
+                output_df[output_id_field]
+                + "-"
+                + input_df[input_field].astype(str)
+            )
+
+ return output_df
+
+ def buckets_to_edges(
+ self,
+ buckets_df: cudf.DataFrame,
+ ) -> cudf.DataFrame:
+
+ grouped_buckets = (
+ buckets_df.groupby(self.bucket_field)[self.str_id_name]
+ .agg(list)
+ .list.sort_values()
+ )
+ bucket_docs = grouped_buckets.to_arrow().to_pylist()
+ edges = []
+        # Create pairs of consecutive documents within a bucket; consecutive
+        # pairs are enough to connect every near-duplicate document in the
+        # bucket, effectively building an edge list of near-duplicate documents
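+        #
+        # Illustrative example (hypothetical ids, not from the original code):
+        #
+        #   >>> from itertools import pairwise
+        #   >>> list(pairwise(["doc-0", "doc-1", "doc-2"]))
+        #   [('doc-0', 'doc-1'), ('doc-1', 'doc-2')]
+        #
+        # Two consecutive edges are enough for connected components to place all
+        # three documents in the same duplicate group.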
+ for bucket_doc in bucket_docs:
+ edges.extend(pairwise(bucket_doc))
+ edges = pd.DataFrame(edges, columns=self.output_ids)
+ edges = pa.Table.from_pandas(edges)
+ result_df = cudf.DataFrame.from_arrow(edges)
+ del edges
+ result_df = result_df.drop_duplicates(self.output_ids).reset_index(drop=True)
+ result_df["jaccard"] = np.float32(1.0)
+ return result_df
+
+ def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
+ buckets_df = dataset.df
+ self._logger.info(f"Starting conversion of LSH Buckets to Graph Edgelist")
+ if len(self.id_fields) > 1:
+ buckets_df = buckets_df.map_partitions(
+ BucketsToEdges._combine_multiple_ids,
+ input_id_fields=self.id_fields,
+ output_id_field=self.str_id_name,
+ )
+
+ meta = [(output_id, str) for output_id in self.output_ids]
+ meta.append(("jaccard", np.float32))
+ edges_df = buckets_df.map_partitions(self.buckets_to_edges, meta=meta)
+
+ if self.cache_dir is None:
+ return DocumentDataset(edges_df)
+
+ write_path = os.path.join(self.cache_dir, "_edges.parquet")
+ if os.path.exists(write_path):
+ warnings.warn(
+ f"Output path {write_path} already exists and will be overwritten"
+ )
+ t0 = time.time()
+ with performance_report_if_with_ts_suffix(
+ self.profile_dir,
+ "bucket-to-edges",
+ ):
+ edges_df.to_parquet(write_path, write_index=False, overwrite=True)
+ self._logger.info(
+ f"Time taken for Converted Buckets To Edgelist = {time.time() - t0}s and output written at {write_path}"
+ )
+
+ return DocumentDataset(
+ dask_cudf.read_parquet(write_path, split_row_groups=False)
+ )
diff --git a/nemo_curator/modules/fuzzy_dedup/connectedcomponents.py b/nemo_curator/modules/fuzzy_dedup/connectedcomponents.py
new file mode 100644
index 00000000..1394ae9a
--- /dev/null
+++ b/nemo_curator/modules/fuzzy_dedup/connectedcomponents.py
@@ -0,0 +1,305 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+import os
+import time
+from typing import Optional, Union
+
+import cudf
+import cugraph.dask as dcg
+import cugraph.dask.comms.comms as Comms
+import cupy as cp
+import dask_cudf
+import numpy as np
+from cugraph import MultiGraph
+from dask.utils import M
+
+from nemo_curator.log import create_logger
+from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
+
+
+class ConnectedComponents:
+ def __init__(
+ self,
+ cache_dir: str,
+ jaccard_pairs_path: str,
+ id_column="id",
+ jaccard_threshold: float = 0.8,
+ logger: Union[logging.LoggerAdapter, str] = "./",
+ profile_dir: Optional[str] = None,
+ ):
+ self.cache_dir = cache_dir
+ self.jaccard_pairs_path = jaccard_pairs_path
+ self.id_column = id_column
+ self.left_id = f"{id_column}_x"
+ self.right_id = f"{id_column}_y"
+ self.jaccard_threshold = jaccard_threshold
+ self.profile_dir = profile_dir
+ if isinstance(logger, str):
+ self._logger = create_logger(
+ rank=0,
+ log_file=os.path.join(logger, "ConnectedComponents.log"),
+ name="ConnectedComponents",
+ )
+ else:
+ self._logger = logger
+
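+    # Usage sketch (illustrative paths and settings, not from the original code):
+    #
+    #   cc = ConnectedComponents(
+    #       cache_dir="./fuzzy_cache",
+    #       jaccard_pairs_path="./fuzzy_cache/_edges.parquet",
+    #       id_column="id",
+    #       jaccard_threshold=0.8,
+    #   )
+    #   cc.cc_workflow(output_path="./fuzzy_cache/connected_components.parquet")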
+ def cc_workflow(self, output_path):
+ deduped_parsed_id_path = self._write_dedup_parsed_id()
+ encoded_jaccard_pair_path = self._write_encoded_jaccard_pair(
+ deduped_parsed_id_path
+ )
+ deduped_encoded_jaccard_path = self._write_dedup_encoded_jaccard_pair(
+ encoded_jaccard_pair_path
+ )
+ cc_path = self._run_connected_components(
+ deduped_encoded_jaccard_path, deduped_parsed_id_path, output_path
+ )
+ return cc_path
+
+ def _run_connected_components(
+ self,
+ deduped_encoded_jaccard_path,
+ deduped_parsed_id_path,
+ output_path,
+ ):
+ t0 = time.time()
+ with performance_report_if_with_ts_suffix(
+ self.profile_dir, "connected-components-run"
+ ):
+
+ Comms.initialize(p2p=False)
+ df = dask_cudf.read_parquet(
+ deduped_encoded_jaccard_path, blocksize="1GB", aggregate_files=True
+ )
+ df = df[df["jaccard"] == 1].reset_index(drop=True)
+
+ labels_df = dask_cudf.read_parquet(deduped_parsed_id_path)
+ num_nodes = len(labels_df)
+ self_edge_df = labels_df[["uid"]].rename(columns={"uid": self.left_id})
+ self_edge_df[self.right_id] = self_edge_df[self.left_id]
+
+ df = df[[self.left_id, self.right_id]].astype(np.int64)
+ df = dask_cudf.concat([df, self_edge_df])
+
+ G = MultiGraph(directed=False)
+ G.from_dask_cudf_edgelist(
+ df, source=self.left_id, destination=self.right_id, renumber=False
+ )
+ result = dcg.weakly_connected_components(G)
+ del G
+ max_partitions = min(32, result.npartitions)
+ n_components = len(
+ result[["labels"]].drop_duplicates(split_out=max_partitions)
+ )
+ num_labels = len(result)
+ labels_df = labels_df.merge(
+ result, left_on=["uid"], right_on=["vertex"], how="inner"
+ )
+ id_columns = [self.id_column]
+ labels_df = labels_df[id_columns + ["labels"]]
+ labels_df = labels_df.rename(columns={"labels": "group"})
+ labels_df = labels_df.persist()
+ # Doing an inner merge above
+ # should not change any rows
+
+ self._logger.info(
+ "Result of connected compoinents are "
+ f"# of groups : {n_components}, "
+ f"# of docs removed : {num_labels - n_components}, "
+ f"# nodes = {num_nodes}, "
+ f"# rows in labels_df = {len(labels_df)}"
+ )
+ assert num_nodes == len(labels_df)
+ # Ensure all docs in the same group are in the same partition
+ labels_df = labels_df.shuffle(on=["group"], ignore_index=True)
+ labels_df.to_parquet(output_path, write_index=False, overwrite=True)
+ Comms.destroy()
+ self._logger.info(
+ f"Time taken for Connected Components Run = {time.time() - t0}s and output written at {output_path}"
+ )
+
+ @staticmethod
+ def _sort_ids(df, id_columns):
+ x = df[id_columns].values
+ x = cp.sort(x, axis=1)
+ for i, id_column in enumerate(id_columns):
+ df[id_column] = x[:, i]
+ df[id_column] = df[id_column].astype("uint64")
+ return df
+
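+    # Illustrative note on thresholding() below (hypothetical values, not from
+    # the original code):
+    #
+    #   >>> pairs = cudf.DataFrame({"jaccard": [0.9, 0.8, 0.7]})
+    #   >>> out = ConnectedComponents.thresholding(pairs, 0.8, "jaccard")
+    #   >>> out["jaccard"].to_arrow().to_pylist()
+    #   [1.0, 0.0, 0.0]
+    #
+    # The comparison is strictly greater-than, so pairs exactly at the threshold
+    # are binarized to 0 and later dropped by the jaccard == 1 filter above.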
+ @staticmethod
+ def thresholding(df, threshold, column_to_threshold):
+ mask = df[column_to_threshold] > threshold
+ df.loc[mask, column_to_threshold] = np.int8(1)
+ df.loc[~mask, column_to_threshold] = np.int8(0)
+ return df
+
+ def _write_dedup_encoded_jaccard_pair(self, encoded_jaccard_pair_path):
+ output_path = f"{self.cache_dir}/final_dedup_encoded_jaccard_pair.parquet"
+ t0 = time.time()
+ with performance_report_if_with_ts_suffix(
+ self.profile_dir, "connected-components-dedup-encoded-jaccard-pair"
+ ):
+
+ ddf = dask_cudf.read_parquet(
+ encoded_jaccard_pair_path, blocksize="512MB", aggregate_files=True
+ )
+ meta = {
+ self.left_id: "uint64",
+ self.right_id: "uint64",
+ "jaccard": "float32",
+ }
+ ddf = ddf.map_partitions(
+ ConnectedComponents._sort_ids,
+ id_columns=[self.left_id, self.right_id],
+ meta=meta,
+ )
+ ddf = ddf.map_partitions(
+ ConnectedComponents.thresholding,
+ threshold=self.jaccard_threshold,
+ column_to_threshold="jaccard",
+ meta=meta,
+ )
+ ddf = ddf.map_partitions(
+ M.drop_duplicates,
+ meta=ddf._meta,
+ enforce_metadata=False,
+ transform_divisions=False,
+ align_dataframes=False,
+ )
+
+ ddf = ddf.shuffle(
+ [self.left_id, self.right_id],
+ ignore_index=True,
+ shuffle_method="tasks",
+ )
+ ddf = ddf.map_partitions(
+ M.drop_duplicates,
+ meta=ddf._meta,
+ enforce_metadata=False,
+ transform_divisions=False,
+ align_dataframes=False,
+ )
+ ddf.to_parquet(output_path, write_index=False, overwrite=True)
+ self._logger.info(
+ f"Time taken for Dedup Encoding Jaccard Pairs = {time.time() - t0}s and output written at {output_path}"
+ )
+ return output_path
+
+ def _write_dedup_parsed_id(self):
+ dedup_parsed_id_path = f"{self.cache_dir}/dedup_parsed_id.parquet"
+ t0 = time.time()
+ with performance_report_if_with_ts_suffix(
+ self.profile_dir, "connected-components-dedup-parsed-id"
+ ):
+ ddf = dask_cudf.read_parquet(
+ self.jaccard_pairs_path,
+ columns=[self.left_id, self.right_id],
+ blocksize="512MB",
+ aggregate_files=True,
+ )
+ id_columns = [self.id_column]
+ unique_docs = ddf.map_partitions(
+ ConnectedComponents._get_unique_ids_per_partition, id_columns=id_columns
+ )
+ unique_docs = unique_docs.drop_duplicates(
+ # Dask does not guard against split_out=0
+ split_out=max(ddf.npartitions // 4, 1)
+ )
+ unique_docs["uid"] = np.uint64(1)
+ unique_docs["uid"] = unique_docs["uid"].cumsum()
+ unique_docs["uid"] = unique_docs["uid"] - 1
+ unique_docs.to_parquet(
+ dedup_parsed_id_path, write_index=False, overwrite=True
+ )
+ self._logger.info(
+ f"Time taken for Dedup Parsed Id = {time.time() - t0}s and output written at {dedup_parsed_id_path}"
+ )
+ return dedup_parsed_id_path
+
+ def _write_encoded_jaccard_pair(self, dedup_parsed_id_path):
+ output_path = f"{self.cache_dir}/encoded_jaccard_pair/"
+ t0 = time.time()
+ with performance_report_if_with_ts_suffix(
+ self.profile_dir, "connected-components-encoded-jaccard-pair"
+ ):
+ ddf_id = dask_cudf.read_parquet(
+ dedup_parsed_id_path, blocksize="2GB", aggregate_files=True
+ )
+ ddf = dask_cudf.read_parquet(
+ self.jaccard_pairs_path,
+ blocksize="1GB",
+ aggregate_files=True,
+ )
+ self._merge_and_write(
+ ddf=ddf,
+ ddf_id=ddf_id,
+ output_path=output_path,
+ id_column=self.id_column,
+ )
+ self._logger.info(
+ f"Time taken for Encoding Jaccard Pairs = {time.time() - t0}s and output written at {output_path}"
+ )
+ return output_path
+
+ def _merge_and_write(
+ self,
+ ddf: dask_cudf.DataFrame,
+ ddf_id: dask_cudf.DataFrame,
+ output_path: str,
+ id_column: str,
+ ) -> None:
+ st = time.time()
+        # Index ddf_id by the document id so pair ids can be mapped to uids below
+ ddf_id = ddf_id.set_index(id_column)
+ for tag in ["x", "y"]:
+ pair_id = f"{id_column}_{tag}"
+ # Merge 'ddf' with 'ddf_id' to map ids to uids
+ ddf = ddf.merge(
+ ddf_id,
+ left_on=pair_id,
+ right_index=True,
+ how="inner",
+ broadcast=True,
+ )
+ ddf = ddf.drop(columns=pair_id)
+ ddf = ddf.rename(columns={"uid": f"{self.id_column}_{tag}"})
+ ddf = ddf[[self.left_id, self.right_id, "jaccard"]]
+ ddf.to_parquet(output_path, write_index=False, overwrite=True)
+
+ et = time.time()
+ self._logger.info(
+ f"Time taken for merge and write = {et - st}s and output written at {output_path}"
+ )
+
+ @staticmethod
+ def _get_unique_ids_per_partition(df, id_columns):
+ unique_df_ls = []
+ for tag in ["x", "y"]:
+            tagged_id_columns = [f"{id_col}_{tag}" for id_col in id_columns]
+
+            subset_df = df[tagged_id_columns].drop_duplicates(ignore_index=True)
+ subset_df = subset_df.rename(
+ columns={f"{id_col}_{tag}": f"{id_col}" for id_col in id_columns}
+ )
+ unique_df_ls.append(subset_df)
+ unique_df = cudf.concat(unique_df_ls, ignore_index=True)
+ unique_df = unique_df.drop_duplicates(ignore_index=True)
+ return unique_df
diff --git a/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py b/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py
new file mode 100644
index 00000000..41ef9c9f
--- /dev/null
+++ b/nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py
@@ -0,0 +1,246 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+import os
+import time
+from typing import Union
+
+import dask_cudf
+
+from nemo_curator.datasets import DocumentDataset
+from nemo_curator.log import create_logger
+from nemo_curator.modules.config import FuzzyDuplicatesConfig
+from nemo_curator.modules.fuzzy_dedup._mapbuckets import _MapBuckets
+from nemo_curator.modules.fuzzy_dedup._shuffle import _Shuffle
+from nemo_curator.modules.fuzzy_dedup.bucketstoedges import BucketsToEdges
+from nemo_curator.modules.fuzzy_dedup.connectedcomponents import ConnectedComponents
+from nemo_curator.modules.fuzzy_dedup.jaccardsimilarity import JaccardSimilarity
+from nemo_curator.modules.fuzzy_dedup.lsh import LSH
+from nemo_curator.modules.fuzzy_dedup.minhash import MinHash
+from nemo_curator.modules.meta import Sequential
+from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
+
+
+class FuzzyDuplicates:
+ def __init__(
+ self,
+ config: FuzzyDuplicatesConfig,
+ logger: Union[logging.LoggerAdapter, str] = "./",
+ ):
+ """
+ Parameters
+ ----------
+ config: FuzzyDuplicatesConfig,
+ Config options for finding FuzzyDuplicates
+ logger: Existing logger to log to, or a path to a log directory.
+        """
+ if isinstance(logger, str):
+ self._logger = create_logger(
+ rank=0,
+ log_file=os.path.join(logger, "FuzzyDuplicates.log"),
+ name="FuzzyDuplicates",
+ )
+ else:
+ self._logger = logger
+
+ self.config = config
+ self.minhash = MinHash(
+ seed=self.config.seed,
+ num_hashes=self.config.num_hashes,
+ char_ngrams=self.config.char_ngrams,
+ use_64bit_hash=self.config.use_64_bit_hash,
+ logger=self._logger,
+ id_field=self.config.id_field,
+ text_field=self.config.text_field,
+ profile_dir=self.config.profile_dir,
+ cache_dir=self.config.cache_dir,
+ )
+ self.lsh = LSH(
+ cache_dir=self.config.cache_dir,
+ num_hashes=self.config.num_hashes,
+ num_buckets=self.config.num_buckets,
+ buckets_per_shuffle=self.config.buckets_per_shuffle,
+ false_positive_check=self.config.false_positive_check,
+ logger=self._logger,
+ id_fields=[self.config.id_field],
+ profile_dir=self.config.profile_dir,
+ )
+
+ if self.config.false_positive_check:
+ self.map_buckets = _MapBuckets(
+ id_fields=[self.config.id_field],
+ text_field=self.config.text_field,
+ logger=self._logger,
+ num_anchors=self.config.num_anchors,
+ )
+ self.jaccard_shuffle = _Shuffle(
+ id_fields=[self.config.id_field],
+ text_field=self.config.text_field,
+ logger=self._logger,
+ profile_dir=self.config.profile_dir,
+ )
+ self.jaccard_compute = JaccardSimilarity(
+ id_field=self.config.id_field,
+ text_field=self.config.text_field,
+ ngram_width=self.config.char_ngrams,
+ anchor_id_fields=[
+ f"anchor_{i}_{self.config.id_field}"
+ for i in range(self.config.num_anchors)
+ ],
+ )
+ else:
+ self.buckets_to_edges = BucketsToEdges(
+ cache_dir=self.config.cache_dir,
+ id_fields=self.config.id_field,
+ logger=self._logger,
+ profile_dir=self.config.profile_dir,
+ )
+
+ jaccard_pairs_fname = (
+ "jaccard_similarity_results.parquet"
+ if self.config.false_positive_check
+ else "_edges.parquet"
+ )
+ self.connected_components = ConnectedComponents(
+ cache_dir=self.config.cache_dir,
+ jaccard_pairs_path=os.path.join(self.config.cache_dir, jaccard_pairs_fname),
+ id_column=self.config.id_field,
+ jaccard_threshold=self.config.jaccard_threshold,
+ logger=self._logger,
+ profile_dir=self.config.profile_dir,
+ )
+
+ def __call__(self, dataset: DocumentDataset):
+ """
+ Parameters
+ ----------
+ dataset: DocumentDataset
+            The input dataset on which to compute fuzzy duplicates. It must contain a text field and a unique id field.
+
+ Returns
+ -------
+ DocumentDataset containing IDs of all documents and the corresponding duplicate group
+ they belong to. Documents in the same group are near duplicates.
+ """
+
+ # Minhash + LSH
+ stage_num = 1
+ print(f"Stage {stage_num}: Starting Minhash + LSH computation")
+ minhashLSH = Sequential([self.minhash, self.lsh])
+ buckets_df = minhashLSH(dataset)
+ print(f"Stage {stage_num}: Minhash + LSH complete!")
+ if buckets_df is None:
+ print(
+ f"Stage {stage_num}: No potential duplicate documents found during LSH"
+ )
+ return None
+ stage_num += 1
+
+ if self.config.false_positive_check:
+ # Map buckets to lower cardinality distribution
+ print(f"Stage {stage_num} (False Positive Check): Starting Map_Buckets")
+ t0 = time.time()
+ mapped_buckets_w_anchors_path = os.path.join(
+ self.config.cache_dir, "anchor_docs_with_bk.parquet"
+ )
+ with performance_report_if_with_ts_suffix(
+ self.config.profile_dir,
+ "map_buckets",
+ ):
+ ddf_mapped_buckets_w_anchors = (
+ self.map_buckets.map_buckets_with_anchors(
+ documents_df=dataset.df, buckets_df=buckets_df.df
+ )
+ )
+ ddf_mapped_buckets_w_anchors.to_parquet(
+ mapped_buckets_w_anchors_path, write_index=False, overwrite=True
+ )
+ self._logger.info(
+ f"Time taken for Map_buckets : {time.time() - t0}s and output written at {mapped_buckets_w_anchors_path}"
+ )
+
+ print(f"Stage {stage_num} (False Postive Check): Map_Buckets Complete!")
+ stage_num += 1
+
+ # Shuffle documents based on mapped buckets
+ print(f"Stage {stage_num} (False Postive Check): Shuffle docs")
+ shuffled_docs_path = os.path.join(
+ self.config.cache_dir, "shuffled_docs.parquet"
+ )
+ self.jaccard_shuffle.shuffle_docs_on_buckets(
+ documents_df=dataset.df,
+ bucket_w_anchors_path=mapped_buckets_w_anchors_path,
+ output_shuffled_docs_path=shuffled_docs_path,
+ bucket_mapping_df_blocksize=self.config.bucket_mapping_blocksize,
+ parts_per_worker=self.config.parts_per_worker,
+ bucket_parts_per_worker=self.config.bucket_parts_per_worker,
+ )
+ print(f"Stage {stage_num} (False Postive Check): Shuffle docs complete!")
+ stage_num += 1
+
+            # Jaccard comparison within buckets
+            print(
+                f"Stage {stage_num} (False Positive Check): Jaccard Similarity in Buckets"
+ )
+ jaccard_pairs_path = os.path.join(
+ self.config.cache_dir, "jaccard_similarity_results.parquet"
+ )
+ t0 = time.time()
+ with performance_report_if_with_ts_suffix(
+ self.config.profile_dir,
+ "jaccard-similarity",
+ ):
+ jaccard_pairs_df = self.jaccard_compute.jaccard_compute(
+ shuffled_docs_path=shuffled_docs_path
+ )
+ jaccard_pairs_df.to_parquet(
+ jaccard_pairs_path,
+ write_index=False,
+ write_metadata_file=False,
+ overwrite=True,
+ )
+ self._logger.info(
+ f"Time taken for Jaccard Similarity = {time.time()-t0}s and output written at {jaccard_pairs_path}"
+ )
+
+ print(
+ f"Stage {stage_num} (False Postive Check): Jaccard Similarity in Buckets Complete!"
+ )
+ stage_num += 1
+
+ else:
+            # Convert LSH buckets directly into a graph edgelist
+ print(f"Stage {stage_num}: Starting LSH Buckets to Graph Edgelist")
+ self.buckets_to_edges(buckets_df)
+ print(
+ f"Stage {stage_num}: Starting LSH Buckets to Graph Edgelist Complete!"
+ )
+ stage_num += 1
+
+ # Connected components across buckets
+ print(f"Stage {stage_num}: Connected Components across buckets")
+ cc_path = os.path.join(self.config.cache_dir, "connected_components.parquet")
+ self.connected_components.cc_workflow(cc_path)
+ print(f"Stage {stage_num}: Connected Components across buckets complete!")
+ stage_num += 1
+
+ return DocumentDataset(dask_cudf.read_parquet(cc_path, split_row_groups=False))
diff --git a/nemo_curator/modules/fuzzy_dedup/jaccardsimilarity.py b/nemo_curator/modules/fuzzy_dedup/jaccardsimilarity.py
new file mode 100644
index 00000000..04ac73a4
--- /dev/null
+++ b/nemo_curator/modules/fuzzy_dedup/jaccardsimilarity.py
@@ -0,0 +1,199 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import os
+
+import cudf
+import numpy as np
+from dask import dataframe as dd
+
+
+class JaccardSimilarity:
+ def __init__(
+ self,
+ id_field="id",
+ anchor_id_fields=["anchor_0_id", "anchor_1_id"],
+ text_field="text",
+ ngram_width=5,
+ ):
+ self.id_field = id_field
+ self.anchor_id_fields = anchor_id_fields
+ self.text_field = text_field
+ self.anchor_id = f"anchor_{id_field}"
+ self.left_id = f"{self.id_field}_x"
+ self.right_id = f"{self.id_field}_y"
+ self.ngram_width = ngram_width
+
+    def __call__(self, dataset):
+ raise NotImplementedError
+
+ def jaccard_compute(self, shuffled_docs_path):
+ paths = [
+ entry.path
+ for entry in os.scandir(shuffled_docs_path)
+ if not entry.path.endswith(".txt")
+ ]
+ meta_df = cudf.DataFrame(
+ {
+ self.left_id: ["x"],
+ self.right_id: ["y"],
+ "jaccard": np.float32([0.0]),
+ }
+ )
+ result_df = dd.from_map(
+ self._compute_jaccard_on_1_partition, paths, meta=meta_df
+ ).reset_index(drop=True)
+ return result_df
+
+ def _compute_jaccard_on_1_partition(self, path):
+ try:
+ df = cudf.read_parquet(path)
+ pair_df = self._compute_jaccard_and_create_pair_df(df)
+ except OverflowError:
+ paths = [entry.path for entry in os.scandir(os.path.join(path))]
+ anchor_df_str_size_ls = [
+ self._get_anchor_docs_and_string_size(path) for path in paths
+ ]
+ anchor_df = cudf.concat(
+ [anchor_doc for anchor_doc, _ in anchor_df_str_size_ls],
+ ignore_index=True,
+ ).drop_duplicates()
+ df_str_size = [str_size for _, str_size in anchor_df_str_size_ls]
+ paths = JaccardSimilarity._create_bins(
+ df_str_size, np.iinfo(np.int32).max // 10
+ )
+ pair_dfs = []
+ for path in paths:
+ print(path)
+ df = cudf.read_parquet(path).reset_index(drop=True)
+ df = cudf.concat([df, anchor_df], ignore_index=True)
+ pair_df = self._compute_jaccard_and_create_pair_df(df)
+ pair_dfs.append(pair_df)
+ pair_df = cudf.concat(pair_dfs, ignore_index=True)
+ return pair_df
+
+ def _get_anchor_docs_and_string_size(self, path):
+ df = cudf.read_parquet(path)
+ str_bytes = df[self.text_field].str.byte_count().sum()
+ is_anchor_flag = df[self.id_field] == df[self.anchor_id_fields[0]]
+ for anchor_id in self.anchor_id_fields[1:]:
+ is_anchor_flag = is_anchor_flag | (df[self.id_field] == df[anchor_id])
+ anchor_df = df[is_anchor_flag].reset_index(drop=True)
+ return anchor_df, {"path": path, "str_bytes": str_bytes}
+
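+    # Illustrative trace of _create_bins below (hypothetical sizes, not from the
+    # original code): for paths with str_bytes {a: 9, b: 5, c: 4, d: 3} and
+    # max_size=10, the first-fit-decreasing loop yields bins [[a], [b, c], [d]]:
+    # "a" fills its own bin, "b" and "c" share one (5 + 4 <= 10), and "d" opens a
+    # new bin because neither existing bin has 3 bytes of headroom left.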
+ @staticmethod
+ def _create_bins(path_dicts, max_size):
+ path_dicts.sort(key=lambda x: x["str_bytes"], reverse=True)
+ bins, bin_sizes = [], []
+ for path_d in path_dicts:
+ new_path, new_size = path_d["path"], path_d["str_bytes"]
+ for i, bin_size in enumerate(bin_sizes):
+ if bin_size + new_size <= max_size:
+ bins[i].append(new_path)
+ bin_sizes[i] += new_size
+ new_size = 0
+ break
+ if new_size:
+ bins.append([new_path])
+ bin_sizes.append(new_size)
+ return bins
+
+ def _compute_jaccard_and_create_pair_df(self, df):
+ df = df.drop_duplicates(
+ subset=[self.id_field] + self.anchor_id_fields, ignore_index=True
+ )
+ anchor_columns = self.anchor_id_fields
+ id_field = self.id_field
+ result_ls = []
+ try:
+ for anchor_col in anchor_columns:
+ doc_df = df[[id_field, self.text_field, anchor_col]]
+ doc_df = doc_df.rename(columns={anchor_col: self.anchor_id})
+ doc_df = doc_df[doc_df[id_field] != doc_df[self.anchor_id]]
+ anchor_df = self._get_anchor_df(df, anchor_col)
+ result_df = self._compute_jaccard_pair(doc_df, anchor_df)
+ result_ls.append(result_df)
+
+ return cudf.concat(result_ls)
+ except OverflowError as e:
+ print(
+ "Failed with OverflowError in compute_jaccard_and_create_pair_df",
+ flush=True,
+ )
+ print(df, flush=True)
+ print("--" * 30)
+ print("Error")
+ print("---" * 30)
+ raise e
+
+ def _get_anchor_df(self, df, anchor_col):
+ anchor_df = df[df[self.id_field] == df[anchor_col]]
+ anchor_df = anchor_df.reset_index(drop=True)
+ anchor_df = anchor_df[[anchor_col, self.text_field]]
+ anchor_df = anchor_df.rename(columns={anchor_col: self.anchor_id})
+ return anchor_df
+
+ def _compute_jaccard_pair(self, docs_df, anchor_df):
+ nrows_at_once = JaccardSimilarity._get_max_num_rows_to_process_once(
+ df=docs_df, text_field=self.text_field
+ )
+ result_ls = []
+ for i in range(0, docs_df.shape[0], nrows_at_once):
+ pair_df = docs_df[i : i + nrows_at_once]
+ pair_df = pair_df.merge(anchor_df, on=self.anchor_id)
+ pair_df = pair_df.rename(
+ columns={self.id_field: self.left_id, self.anchor_id: self.right_id}
+ )
+ mask = pair_df[self.left_id] != pair_df[self.right_id]
+ pair_df = pair_df[mask].reset_index(drop=True)
+ if len(pair_df) == 0:
+ result_df = self._create_empty_jaccard_result()
+ else:
+ result_df = self._compute_jaccard_partition(pair_df)
+ result_ls.append(result_df)
+ if len(result_ls) == 0:
+ return self._create_empty_jaccard_result()
+ df_pair = cudf.concat(result_ls)
+ return df_pair
+
+ def _create_empty_jaccard_result(self):
+ df = cudf.DataFrame()
+ df[self.left_id] = "x"
+ df[self.right_id] = "y"
+ df["jaccard"] = np.empty(shape=0, dtype=np.float32)
+ return df
+
+ def _compute_jaccard_partition(self, df):
+ text_x = f"{self.text_field}_x"
+ text_y = f"{self.text_field}_y"
+ df["jaccard"] = df[text_x].str.jaccard_index(df[text_y], width=self.ngram_width)
+ df.drop(columns=[text_x, text_y], inplace=True)
+ return df
+
+ @staticmethod
+ def _get_max_num_rows_to_process_once(df, text_field):
+ nbytes = df[text_field].str.byte_count().sum()
+        # Approximate number of bytes after ngram explosion for both text columns
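+        # Rough illustration (hypothetical numbers): ~1e9 text bytes explode to
+        # about 1e9 * 5 * 2 = 1e10 bytes, a byte_ratio of 4 against the
+        # ~2.1e9-character cuDF limit, so roughly len(df) // 4 rows are merged
+        # per iteration.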
+ exploded_bytes = nbytes * 5 * 2
+ max_chars_allowed = 2_147_483_647
+ byte_ratio = int(exploded_bytes) // max_chars_allowed
+ if byte_ratio > 1:
+ nrows_at_once = len(df) // byte_ratio
+ else:
+ nrows_at_once = len(df)
+
+ nrows_at_once = max(1, nrows_at_once)
+ return nrows_at_once
diff --git a/nemo_curator/modules/fuzzy_dedup/lsh.py b/nemo_curator/modules/fuzzy_dedup/lsh.py
new file mode 100644
index 00000000..4a38b7c6
--- /dev/null
+++ b/nemo_curator/modules/fuzzy_dedup/lsh.py
@@ -0,0 +1,289 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+import math
+import os
+import time
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import cudf
+import dask_cudf
+import numpy as np
+
+from nemo_curator.datasets import DocumentDataset
+from nemo_curator.log import create_logger
+from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
+from nemo_curator.utils.fuzzy_dedup_utils.io_utils import check_empty_buckets
+
+
+class LSH:
+ """
+    Performs LSH on MinHash signatures
+ """
+
+ def __init__(
+ self,
+ cache_dir: str,
+ num_hashes: int,
+ num_buckets: int,
+ buckets_per_shuffle: int = 1,
+ false_positive_check: bool = False,
+ logger: Union[logging.LoggerAdapter, str] = "./",
+ id_fields: Union[str, list] = "id",
+ minhash_field: str = "_minhash_signature",
+ profile_dir: Optional[str] = None,
+ ):
+ """
+ Parameters
+ ----------
+ cache_dir: str
+            Must be specified; duplicate (document id, bucket id) pairs are computed and written to this directory.
+ num_hashes: Length of minhash signature
+ num_buckets: Number of bands/buckets to create from the minhash signature.
+ Hashes_per_signature = num_hashes / num_buckets
+        buckets_per_shuffle: Number of bands/buckets to shuffle concurrently.
+            Larger values process more buckets per batch but might lead to memory
+            pressure and related errors.
+ false_positive_check: bool
+ If True, writes out buckets in a format compatible with downstream false positive check.
+ logger: Existing logger to log to, or a path to a log directory.
+        id_fields: Column(s) in the dataset denoting the document ID.
+ minhash_field: Column in the Dataset denoting minhash signature.
+ profile_dir: str, Default None
+            If specified, directory to write Dask profile reports to.
+ """
+ self.num_hashes = num_hashes
+ self.num_buckets = num_buckets
+ self.id_fields = [id_fields] if isinstance(id_fields, str) else id_fields
+ self.minhash_field = minhash_field
+ self.buckets_per_shuffle = buckets_per_shuffle
+ self.bucket_ranges = self._generate_bucket_ranges(
+ self.num_buckets, self.num_hashes
+ )
+ self.buckets_as_int = false_positive_check
+
+ if cache_dir is None:
+ raise ValueError(
+ "cache_dir for intermediate outputs is required for this stage"
+ )
+ self.cache_dir = cache_dir
+ self.profile_dir = profile_dir
+
+ if isinstance(logger, str):
+ self._logger = create_logger(
+ rank=0,
+ log_file=os.path.join(logger, "LSH.log"),
+ name="LSH",
+ )
+ else:
+ self._logger = logger
+
+ def _generate_bucket_ranges(
+ self, num_buckets: int, num_hashes: int
+ ) -> List[List[int]]:
+ """
+        Generates a list of indices for the minhash ranges given num_buckets
+        (i.e. the number of bands) & num_hashes.
+        e.g.: num_buckets=3, num_hashes=6
+ [[0, 1], [2, 3], [4, 5]]
+ """
+ minhashes_per_bucket = num_hashes // num_buckets
+
+ bucket_ranges = [
+ list(
+ range(
+ bucket * minhashes_per_bucket, (bucket + 1) * minhashes_per_bucket
+ )
+ )
+ for bucket in range(num_buckets)
+ ]
+ return bucket_ranges
+
+ def minhash_to_buckets(
+ self,
+ df: cudf.DataFrame,
+ bucket_ranges: List[List[int]],
+ ) -> cudf.DataFrame:
+ df2 = df[self.id_fields]
+ for i, h in enumerate(bucket_ranges):
+ indices = cudf.Series([h]).repeat(len(df2))
+ df2[f"_bucket_{i}"] = f"b{i}_" + df[self.minhash_field].list.take(
+ indices
+ ).hash_values(method="md5")
+ return df2
+
+ def bucket_id_to_int(
+ self,
+ bucket_ddf: dask_cudf.DataFrame,
+ bucket_col_name: str = "bucket_id",
+ start_id: int = 0,
+ ) -> Tuple[dask_cudf.DataFrame, int]:
+ """
+        Maps bucket ids to a contiguous integer range starting from start_id.
+ """
+ unique_bucket_df = (
+ bucket_ddf[[bucket_col_name]]
+ .map_partitions(lambda x: x.drop_duplicates(ignore_index=True))
+ .persist()
+ )
+ end_bucket_id = len(unique_bucket_df) - 1 + start_id
+ unique_bucket_df["bucket_int_id"] = np.uint64(1)
+ unique_bucket_df["bucket_int_id"] = unique_bucket_df["bucket_int_id"].cumsum()
+ unique_bucket_df["bucket_int_id"] = (
+ unique_bucket_df["bucket_int_id"] - 1 + start_id
+ )
+ bucket_ddf = bucket_ddf.merge(unique_bucket_df, on=[bucket_col_name])
+ bucket_ddf = bucket_ddf.drop(columns=[bucket_col_name])
+ bucket_ddf = bucket_ddf.rename(columns={"bucket_int_id": "_bucket_id"})
+ bucket_ddf["_bucket_id"] = bucket_ddf["_bucket_id"].astype(np.uint64)
+ return (bucket_ddf, end_bucket_id)
+
+ def _minhash_to_bucket_meta(
+ self, df: dask_cudf.DataFrame
+ ) -> Tuple[cudf.DataFrame, int]:
+ meta = df._meta_nonempty[self.id_fields]
+ meta[self.minhash_field] = [np.ones(self.num_hashes)] * len(meta)
+ return self.minhash_to_buckets(meta, self.bucket_ranges)
+
+ def lsh(
+ self,
+ write_path: str,
+ df: dask_cudf.DataFrame,
+ ) -> bool:
+ """
+ Computes hash buckets for the DataFrame and writes them as parquet files to the specified path.
+
+ Parameters:
+ - write_path (str): The directory path to write parquet files.
+ - df (dask_cudf.DataFrame): The input DataFrame with minhashes to be bucketed.
+ Returns:
+ are_buckets_empty: True if buckets were empty (no duplicates found), False otherwise.
+ """
+ wrote_buckets = False
+ are_buckets_empty = True
+
+ meta = self._minhash_to_bucket_meta(df)
+ df = df.map_partitions(
+ self.minhash_to_buckets,
+ bucket_ranges=self.bucket_ranges,
+ meta=meta,
+ )
+ bucket_start_id = 0
+ for i in range(0, self.num_buckets, self.buckets_per_shuffle):
+ bucket_columns = [
+ f"_bucket_{i}"
+ for i in range(i, min(self.num_buckets, i + self.buckets_per_shuffle))
+ ]
+ df2 = df.melt(
+ id_vars=self.id_fields,
+ value_name="_bucket_id",
+ value_vars=bucket_columns,
+ )[self.id_fields + ["_bucket_id"]]
+
+ df2 = df2.shuffle(
+ on=["_bucket_id"],
+ ignore_index=True,
+ npartitions=max(1, 2 ** math.floor(math.log2(df2.npartitions))),
+ ).map_partitions(lambda x: x[x["_bucket_id"].duplicated(keep=False)])
+
+ df2 = df2.reset_index(drop=True)
+ # Buckets to Int
+ if self.buckets_as_int:
+ df2, end_id = self.bucket_id_to_int(
+ df2, bucket_col_name="_bucket_id", start_id=bucket_start_id
+ )
+                # bucket_id_to_int returned an empty dataframe (no duplicates in this batch), so skip it
+ if end_id < bucket_start_id:
+ self._logger.info(
+ f"No duplicate documents found for buckets: {bucket_columns}"
+ )
+ continue
+ bucket_start_id = end_id + 1
+ are_buckets_empty = False
+
+ wrote_buckets, are_buckets_empty = self._write_bucket_parquet(
+ df2,
+ write_path,
+ wrote_buckets,
+ are_buckets_empty,
+ bucket_columns,
+ )
+
+ if are_buckets_empty:
+ self._logger.info("No duplicate documents found during LSH")
+ if os.path.exists(write_path):
+ import shutil
+
+ shutil.rmtree(write_path)
+
+ return are_buckets_empty
+
+ def _write_bucket_parquet(
+ self,
+ df: dask_cudf.DataFrame,
+ write_path: str,
+ wrote_buckets: bool,
+ are_buckets_empty: bool,
+ buckets_to_write: List[str],
+ ) -> tuple[bool, bool]:
+ """
+        Utility function to write the bucketed data to Parquet,
+        handling the overwrite and append cases as needed.
+ """
+ if not wrote_buckets:
+ if os.path.exists(write_path):
+ warnings.warn(
+ f"Output path {write_path} already exists and will be overwritten"
+ )
+ df.to_parquet(write_path, write_index=False, overwrite=True)
+ else:
+ df.to_parquet(
+ write_path,
+ write_index=False,
+ overwrite=are_buckets_empty,
+ append=not are_buckets_empty,
+ ignore_divisions=True,
+ )
+ # Only check if buckets written so far are empty
+ if are_buckets_empty:
+ are_buckets_empty = check_empty_buckets(write_path)
+ wrote_buckets = True
+
+ if are_buckets_empty:
+ self._logger.info(
+ f"No duplicate documents found for buckets: {buckets_to_write}"
+ )
+ else:
+ self._logger.info(f"Wrote data for buckets: {buckets_to_write}")
+ return wrote_buckets, are_buckets_empty
+
+ def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
+ df = dataset.df
+
+ write_path = os.path.join(self.cache_dir, "_buckets.parquet")
+ t0 = time.time()
+ with performance_report_if_with_ts_suffix(self.profile_dir, "lsh-profile"):
+ empty_result = self.lsh(write_path=write_path, df=df)
+ self._logger.info(
+ f"Time taken for LSH = {time.time() - t0}s and output written at {write_path}"
+ )
+
+ if empty_result:
+ return None
+
+ buckets_df = dask_cudf.read_parquet(write_path, split_row_groups=False)
+ return DocumentDataset(buckets_df)
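A minimal usage sketch for the relocated LSH stage follows. It is illustrative only: the constructor arguments are inferred from the attributes referenced in this file, and the paths are placeholders.

    # Hedged sketch: bucket pre-computed minhash signatures with the relocated LSH module.
    # Assumes a GPU Dask client is already running; the argument names below
    # (num_hashes, num_buckets, buckets_per_shuffle, id_fields, minhash_field) are
    # inferred from the attributes this class uses, so verify them against the signature.
    import dask_cudf

    from nemo_curator import LSH
    from nemo_curator.datasets import DocumentDataset

    minhashes = DocumentDataset(
        dask_cudf.read_parquet("./fuzzy_cache/_minhashes.parquet")  # MinHash output
    )
    lsh = LSH(
        cache_dir="./fuzzy_cache",  # required; buckets land in <cache_dir>/_buckets.parquet
        num_hashes=260,
        num_buckets=20,
        buckets_per_shuffle=1,
        id_fields=["id"],
        minhash_field="_minhash_signature",
    )
    buckets = lsh(minhashes)  # returns None when no duplicate buckets are found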
diff --git a/nemo_curator/modules/fuzzy_dedup/minhash.py b/nemo_curator/modules/fuzzy_dedup/minhash.py
new file mode 100644
index 00000000..b38b2268
--- /dev/null
+++ b/nemo_curator/modules/fuzzy_dedup/minhash.py
@@ -0,0 +1,229 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+import os
+import time
+import warnings
+from typing import Union
+
+import cudf
+import dask_cudf
+import numpy as np
+
+from nemo_curator._compat import MINHASH_DEPRECATED_API, MINHASH_PERMUTED_AVAILABLE
+from nemo_curator.datasets import DocumentDataset
+from nemo_curator.log import create_logger
+from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
+
+
+class MinHash:
+ """
+ Computes minhash signatures of a document corpus
+ """
+
+ def __init__(
+ self,
+ seed: int = 42,
+ num_hashes: int = 260,
+ char_ngrams: int = 5,
+ use_64bit_hash: bool = False,
+ logger: Union[logging.LoggerAdapter, str] = "./",
+ id_field: str = "id",
+ text_field: str = "text",
+ profile_dir: str = None,
+ cache_dir: str = None,
+ ):
+ """
+ Parameters
+ ----------
+ seed: Seed for minhash permutations
+ num_hashes: Length of minhash signature (No. of minhash permutations)
+ char_ngrams: Width of text window (in characters) while computing minhashes.
+ use_64bit_hash: Whether to use a 64 bit hash function.
+ logger: Existing logger to log to, or a path to a log directory.
+ id_field: Column in the Dataset denoting document ID.
+ text_field: Column in the Dataset denoting document content.
+        profile_dir: str, Default None
+          If specified, the directory to which the Dask profile is written.
+        cache_dir: str, Default None
+          If specified, the (id, minhash) pairs are computed and written to this directory.
+ """
+ self.num_hashes = num_hashes
+ self.char_ngram = char_ngrams
+
+ if MINHASH_DEPRECATED_API:
+ self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed)
+ else:
+ self.seeds = self.generate_hash_permutation_seeds(
+ bit_width=64 if use_64bit_hash else 32,
+ n_permutations=self.num_hashes,
+ seed=seed,
+ )
+
+ self.minhash_method = self.minhash64 if use_64bit_hash else self.minhash32
+ self.id_field = id_field
+ self.text_field = text_field
+
+ if cache_dir is None and profile_dir is not None:
+ warnings.warn(
+ "cache_dir for intermediate outputs is required to generate profiles"
+ )
+ self.cache_dir = cache_dir
+ self.profile_dir = profile_dir
+
+ if isinstance(logger, str):
+ self._logger = create_logger(
+ rank=0,
+ log_file=os.path.join(logger, "Minhash.log"),
+ name="Minhash",
+ )
+ else:
+ self._logger = logger
+
+ def generate_seeds(self, n_seeds: int = 260, seed: int = 0) -> np.ndarray:
+ """
+ Generate seeds for all minhash permutations based on the given seed.
+ """
+ gen = np.random.RandomState(seed)
+ return gen.randint(0, 1e6, size=n_seeds)
+
+ def generate_hash_permutation_seeds(
+ self, bit_width: int, n_permutations: int = 260, seed: int = 0
+ ) -> np.ndarray:
+ """
+ Generate seeds for all minhash permutations based on the given seed.
+ """
+ gen = np.random.RandomState(seed)
+
+ if bit_width == 32:
+ MERSENNE_PRIME = np.uint32((1 << 31) - 1)
+ dtype = np.uint32
+ elif bit_width == 64:
+ # For 64-bit, use a larger prime number suitable for 64-bit operations
+ MERSENNE_PRIME = np.uint64((1 << 61) - 1)
+ dtype = np.uint64
+ else:
+ raise ValueError("Unsupported bit width. Use either 32 or 64.")
+
+ return np.array(
+ [
+ (
+ gen.randint(1, MERSENNE_PRIME, dtype=dtype),
+ gen.randint(0, MERSENNE_PRIME, dtype=dtype),
+ )
+ for _ in range(n_permutations)
+ ],
+ dtype=dtype,
+ )
+
+ def minhash32(
+ self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int
+ ) -> cudf.Series:
+ """
+        Compute 32-bit minhashes based on the MurmurHash3 algorithm.
+ """
+ if not isinstance(ser, cudf.Series):
+ raise TypeError("Expected data of type cudf.Series")
+
+ if MINHASH_DEPRECATED_API:
+ warnings.warn(
+ "Using an outdated minhash implementation, please update to cuDF version 24.12 "
+ "or later for improved performance. "
+ "Install the latest version of cuDF using `pip install curator[cuda12x_nightly]`",
+ category=FutureWarning,
+ )
+ seeds = cudf.Series(seeds, dtype="uint32")
+ return ser.str.minhash(seeds=seeds, width=char_ngram)
+ else:
+ seeds_a = cudf.Series(seeds[:, 0], dtype="uint32")
+ seeds_b = cudf.Series(seeds[:, 1], dtype="uint32")
+
+ if MINHASH_PERMUTED_AVAILABLE:
+ return ser.str.minhash_permuted(
+ a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
+ )
+ else:
+ return ser.str.minhash(
+ a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
+ )
+
+ def minhash64(
+ self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int
+ ) -> cudf.Series:
+ """
+        Compute 64-bit minhashes based on the MurmurHash3 algorithm.
+ """
+ if not isinstance(ser, cudf.Series):
+ raise TypeError("Expected data of type cudf.Series")
+ if MINHASH_DEPRECATED_API:
+ warnings.warn(
+ "Using an outdated minhash implementation, please update to cuDF version 24.12 "
+ "or later for improved performance. "
+ "Install the latest version of cuDF using `pip install curator[cuda12x_nightly]`",
+ category=FutureWarning,
+ )
+ seeds = cudf.Series(seeds, dtype="uint64")
+ return ser.str.minhash64(seeds=seeds, width=char_ngram)
+ else:
+ seeds_a = cudf.Series(seeds[:, 0], dtype="uint64")
+ seeds_b = cudf.Series(seeds[:, 1], dtype="uint64")
+
+ if MINHASH_PERMUTED_AVAILABLE:
+ return ser.str.minhash64_permuted(
+ a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
+ )
+ else:
+ return ser.str.minhash64(
+ a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
+ )
+
+ def __call__(self, dataset: DocumentDataset) -> Union[str, DocumentDataset]:
+ """
+ Computes the MinHash Signatures for a given dataset.
+ Parameters
+ ----------
+        dataset: DocumentDataset
+          The input dataset on which MinHashes are computed.
+        Returns
+        -------
+        DocumentDataset containing the ID of each document and its corresponding MinHash signature.
+ """
+ result = dataset.df[[self.id_field]]
+ result["_minhash_signature"] = dataset.df[self.text_field].map_partitions(
+ self.minhash_method,
+ seeds=self.seeds,
+ char_ngram=self.char_ngram,
+ )
+
+ if self.cache_dir is None:
+ return DocumentDataset(result)
+
+ t0 = time.time()
+ self._logger.info("Starting execution for Minhashes")
+ write_path = os.path.join(self.cache_dir, "_minhashes.parquet")
+ if os.path.exists(write_path):
+ warnings.warn(
+ f"Output path {write_path} already exists and will be overwritten"
+ )
+ with performance_report_if_with_ts_suffix(self.profile_dir, "minhash-profile"):
+ result.to_parquet(write_path, write_index=False, overwrite=True)
+ self._logger.info(
+ f"Time taken for Minhash signature computation = {time.time() - t0}s and output written at {write_path}"
+ )
+ return DocumentDataset(
+ dask_cudf.read_parquet(write_path, blocksize="2GB", aggregate_files=True)
+ )
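A usage sketch for the relocated MinHash stage follows; the input path and client setup are placeholders, while the constructor arguments mirror the signature above.

    # Sketch: compute minhash signatures for a cuDF-backed dataset. Assumes a GPU Dask
    # client has already been started (see nemo_curator.utils.distributed_utils.get_client)
    # and that the JSONL input has "id" and "text" columns.
    from nemo_curator import MinHash
    from nemo_curator.datasets import DocumentDataset

    dataset = DocumentDataset.read_json("/path/to/jsonl", backend="cudf")
    minhasher = MinHash(
        seed=42,
        num_hashes=260,
        char_ngrams=5,
        use_64bit_hash=False,
        id_field="id",
        text_field="text",
        cache_dir="./fuzzy_cache",  # also writes <cache_dir>/_minhashes.parquet
    )
    signatures = minhasher(dataset)  # columns: "id" and "_minhash_signature"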
diff --git a/nemo_curator/modules/semantic_dedup.py b/nemo_curator/modules/semantic_dedup.py
deleted file mode 100644
index c8a96774..00000000
--- a/nemo_curator/modules/semantic_dedup.py
+++ /dev/null
@@ -1,649 +0,0 @@
-# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import logging
-import os
-import shutil
-import time
-from dataclasses import dataclass
-from typing import List, Optional, Union
-
-import cudf
-import cupy as cp
-import dask.bag as db
-import dask.dataframe as dd
-import dask_cudf
-import numpy as np
-import torch
-import torch.nn as nn
-from crossfit import op
-from crossfit.backend.torch.hf.model import HFModel
-from cuml.dask.cluster import KMeans
-from torch.nn import functional as F
-from transformers import AutoConfig, AutoModel, AutoTokenizer
-
-from nemo_curator.classifiers.base import _get_suggest_memory_for_classifier
-from nemo_curator.datasets import DocumentDataset
-from nemo_curator.log import create_logger
-from nemo_curator.modules.config import SemDedupConfig
-from nemo_curator.utils.distributed_utils import (
- performance_report_if_with_ts_suffix,
- write_to_disk,
-)
-from nemo_curator.utils.file_utils import expand_outdir_and_mkdir
-from nemo_curator.utils.semdedup_utils import (
- assign_and_sort_clusters,
- extract_dedup_data,
- get_semantic_matches_per_cluster,
-)
-
-
-# Embedding Creation Module
-@dataclass
-class EmbeddingConfig:
- model_name_or_path: str
- max_seq_length: int = None
-
- def __post_init__(self):
- self.max_seq_length = AutoTokenizer.from_pretrained(
- self.model_name_or_path
- ).model_max_length
- # Gaurd against the HF bug
- # which sets max_seq_length to max(int) for some models
- if self.max_seq_length > 1e5:
- self.max_seq_length = AutoConfig.from_pretrained(
- self.model_name_or_path
- ).max_position_embeddings
-
-
-class EmbeddingPytorchModel(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.model = AutoModel.from_pretrained(
- config.model_name_or_path, config=self.config, force_download=False
- )
-
- def feature(self, input_ids, attention_mask):
- with torch.autocast(device_type=input_ids.device.type):
- embeddings = self.model(input_ids=input_ids, attention_mask=attention_mask)
- return embeddings
-
- @torch.no_grad()
- def forward(self, batch):
- feature = self.feature(batch["input_ids"], batch["attention_mask"])
- return self._mean_pooling(feature, batch["attention_mask"])
-
- def _mean_pooling(self, model_output, attention_mask):
- token_embeddings = model_output[0]
- input_mask_expanded = (
- attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
- )
- sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
- sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
- return F.normalize(sum_embeddings / sum_mask, dim=1)
-
-
-class EmbeddingCrossFitModel(HFModel):
- def __init__(
- self,
- config: EmbeddingConfig,
- max_mem_gb: Optional[int] = None,
- ):
- self.config = config
- if max_mem_gb is None:
- max_mem_gb = _get_suggest_memory_for_classifier()
- super().__init__(self.config.model_name_or_path, max_mem_gb=max_mem_gb)
-
- def load_model(self, device="cuda"):
- model = EmbeddingPytorchModel(self.config)
- model = model.to(device)
- model.eval()
- return model
-
- def max_seq_length(self):
- return self.config.max_seq_length
-
- def load_config(self):
- return AutoConfig.from_pretrained(self.config.model_name_or_path)
-
- def load_tokenizer(self):
- return AutoTokenizer.from_pretrained(self.config.model_name_or_path)
-
-
-class EmbeddingCreator:
- def __init__(
- self,
- embedding_model_name_or_path: str,
- embedding_batch_size: int,
- embedding_output_dir: str,
- embedding_max_mem_gb: Optional[int] = None,
- input_column: str = "text",
- embedding_column: str = "embeddings",
- write_embeddings_to_disk: bool = True,
- write_to_filename: bool = False,
- logger: Union[logging.Logger, str] = "./",
- profile_dir: Optional[str] = None,
- ):
- """
- Initializes an EmbeddingCreator for generating embeddings using the specified model configurations.
-
- Args:
- embedding_model_name_or_path (str): The path or identifier for the model used to generate embeddings.
- embedding_batch_size (int): Number of samples to process in each batch.
- embedding_output_dir (str): Directory path where embeddings will be saved.
- embedding_max_mem_gb (int): Maximum memory usage in GB for the embedding process.
- If None, it defaults to the available GPU memory minus 4 GB.
- input_column (str): Column name from the data to be used for embedding generation, defaults to "text".
- write_embeddings_to_disk (bool, optional): If True, saves the embeddings to disk, defaults to True.
- We recommend setting this to False when you have a delayed pipeline.
- Setting it to False can lead to more memory overhead.
- write_to_filename (bool): If True, saves the embeddings to the same filename as input files, defaults to False.
- logger (Union[logging.Logger, str]): Logger object or path to store logs, defaults to "./".
- profile_dir (str): If specified directory to write dask profile. Default is None.
-
- Attributes:
- embeddings_config (EmbeddingConfig): Configuration for embeddings.
- batch_size (int): Batch size for embedding generation.
- logger (logging.Logger): Logger instance for the class.
- embedding_output_dir (str): Output directory for embeddings.
- input_column (str): Input column for data processing.
- model (EmbeddingCrossFitModel): Model instance for embedding generation.
- write_to_filename (bool): If True, saves the embeddings to the same filename as input files, defaults to False.
- """
-
- self.embeddings_config = EmbeddingConfig(
- model_name_or_path=embedding_model_name_or_path,
- )
- self.batch_size = embedding_batch_size
- self.logger = self._setup_logger(logger)
- self.embedding_output_dir = embedding_output_dir
- self.input_column = input_column
- self.embedding_column = embedding_column
- self.model = EmbeddingCrossFitModel(
- self.embeddings_config, max_mem_gb=embedding_max_mem_gb
- )
- self.write_embeddings_to_disk = write_embeddings_to_disk
- self.write_to_filename = write_to_filename
- self.profile_dir = profile_dir
-
- def _setup_logger(self, logger):
- if isinstance(logger, str):
- return create_logger(
- rank=0,
- name="compute-embeddings",
- log_file=os.path.join(logger, "compute_embeddings.log"),
- log_level=logging.INFO,
- stdout=True,
- )
- else:
- return logger
-
- def create_embeddings(
- self, ddf: dask_cudf.DataFrame, input_column="text"
- ) -> dask_cudf.DataFrame:
- pipe = op.Sequential(
- op.Tokenizer(
- self.model,
- cols=[input_column],
- tokenizer_type="sentencepiece",
- max_length=self.embeddings_config.max_seq_length,
- ),
- op.Predictor(
- self.model,
- sorted_data_loader=True,
- batch_size=self.batch_size,
- pred_output_col=self.embedding_column,
- ),
- keep_cols=ddf.columns.tolist(),
- )
- return pipe(ddf)
-
- def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
- t0 = time.time()
- if self.write_embeddings_to_disk:
- with performance_report_if_with_ts_suffix(
- self.profile_dir, "embedding-creator"
- ):
- embedding_ddf = self.create_embeddings(dataset.df, self.input_column)
- write_to_disk(
- embedding_ddf,
- self.embedding_output_dir,
- write_to_filename=self.write_to_filename,
- output_type="parquet",
- )
-
- ddf = DocumentDataset(
- dask_cudf.read_parquet(
- self.embedding_output_dir, blocksize="2GB", aggregate_files=True
- )
- )
- else:
- ddf = DocumentDataset(embedding_ddf)
-
- self.logger.info(
- f"Time taken for Creating Embeddings : {time.time() - t0}"
- + (
- f" and output written at {self.embedding_output_dir}"
- if self.write_embeddings_to_disk
- else ""
- )
- )
-
- return ddf
-
-
-### Clustering Module
-def get_embedding_ar(df: "cudf.DataFrame", embedding_col: str) -> cp.ndarray:
- return df[embedding_col].list.leaves.values.reshape(len(df), -1)
-
-
-def add_dist_to_cents(
- df: "cudf.DataFrame", embedding_col: str, centroids: cp.ndarray
-) -> "cudf.DataFrame":
- embed_array = get_embedding_ar(df, embedding_col)
- centroids_ar = centroids[df["nearest_cent"].values]
- dist_to_cents = cp.sqrt(np.sum((embed_array - centroids_ar) ** 2, axis=1))
- df["dist_to_cent"] = dist_to_cents
- return df
-
-
-class ClusteringModel:
- def __init__(
- self,
- id_column: str,
- max_iter: int,
- n_clusters: int,
- clustering_output_dir: str,
- embedding_col: str = "embeddings",
- sim_metric: str = "cosine",
- which_to_keep: str = "hard",
- sort_clusters: bool = True,
- kmeans_with_cos_dist: bool = False,
- partition_size: str = "2gb",
- logger: Union[logging.Logger, str] = "./",
- profile_dir: Optional[str] = None,
- ):
- """
- Initializes the ClusteringModel with the provided settings for semantic clustering to help semantic deduplication.
-
- Args:
- id_column (str): Column name used as the identifier in the dataset.
- max_iter (int): Maximum number of iterations for the clustering algorithm.
- n_clusters (int): The number of clusters to form.
- clustering_output_dir (str): Directory path where clustering results will be saved.
- embedding_col (str): Column name where the embeddings are stored.
- sim_metric (str): Similarity metric to use for clustering, default is "cosine".
- which_to_keep (str): Strategy to decide which duplicates to keep; default is "hard".
- sort_clusters (bool): Whether to sort clusters, default is True.
- kmeans_with_cos_dist (bool): Whether to use KMeans with cosine distance, default is False.
- partition_size (str): The size of data partition to run kmeans with, default is "2gb".
- logger (Union[logging.Logger, str]): Logger object or directory path to save logs; default is "./".
- profile_dir (str): If specified directory to write dask profile. Default is None.
-
- This constructor sets up the parameters required for clustering operations.
- """
- self.id_col = id_column
- self.max_iter = max_iter
- self.n_clusters = n_clusters
- self.clustering_output_dir = clustering_output_dir
- self.embedding_col = embedding_col
- self.sim_metric = sim_metric
- self.keep_hard = which_to_keep == "hard"
- self.kmeans_with_cos_dist = kmeans_with_cos_dist
- self.partition_size = partition_size
- self.sort_clusters = sort_clusters
- self.logger = self._setup_logger(logger)
- self.profile_dir = profile_dir
-
- if not os.path.exists(self.clustering_output_dir):
- expand_outdir_and_mkdir(self.clustering_output_dir)
- else:
- self.logger.warning(
- f"Clustering output directory {self.clustering_output_dir} already exists and will be overwritten"
- )
-
- def _setup_logger(self, logger):
- if isinstance(logger, str):
- return create_logger(
- rank=0,
- name="SemanticClusterLevelDedup",
- log_file=os.path.join(logger, "SemanticClusterLevelDedup.log"),
- log_level=logging.INFO,
- stdout=True,
- )
- else:
- return logger
-
- def __call__(self, embeddings_dataset: DocumentDataset):
- embeddings_df = embeddings_dataset.df
-
- if self.embedding_col not in embeddings_df.columns:
- raise ValueError(
- f"Expected embedding column '{self.embedding_col}'"
- f" to be in dataset. Only found columns {embeddings_df.columns}"
- )
-
- with performance_report_if_with_ts_suffix(self.profile_dir, "clustering-model"):
- embeddings_df = embeddings_df[[self.id_col, self.embedding_col]]
- embeddings_df = embeddings_df.repartition(
- partition_size=self.partition_size
- )
- embeddings_df = embeddings_df.to_backend("pandas").persist()
- embeddings_df = embeddings_df.to_backend("cudf")
-
- cupy_darr = embeddings_df.map_partitions(
- get_embedding_ar, self.embedding_col, meta=cp.ndarray([1, 1])
- )
- cupy_darr.compute_chunk_sizes()
- t0 = time.time()
- kmeans = KMeans(n_clusters=self.n_clusters, max_iter=self.max_iter)
- self.logger.info("KMeans starting fit")
- kmeans.fit(cupy_darr)
- self.logger.info("KMeans fit complete")
- self.logger.info(f"Time taken for KMeans Fit: {time.time() - t0}")
-
- self.logger.info(
- "Computing nearest centroids + distance to centers using kmeans.predict"
- )
- t0 = time.time()
- nearest_cents = kmeans.predict(cupy_darr)
- self.logger.info(f"Time taken for KMeans Predict: {time.time() - t0}")
- t0 = time.time()
- embeddings_df["nearest_cent"] = nearest_cents.astype(np.int32)
- del nearest_cents
- meta_df = embeddings_df._meta.copy()
- meta_df["dist_to_cent"] = cp.zeros(1)
- embeddings_df = embeddings_df.map_partitions(
- add_dist_to_cents,
- embedding_col=self.embedding_col,
- centroids=kmeans.cluster_centers_,
- meta=meta_df,
- )
- embeddings_df = embeddings_df.reset_index(drop=True)
- centroids = kmeans.cluster_centers_
- kmeans_centroids_file = os.path.join(
- self.clustering_output_dir, "kmeans_centroids.npy"
- )
- np.save(kmeans_centroids_file, centroids)
- self.logger.info("Saving centroids complete")
- del kmeans, cupy_darr, centroids
-
- clustering_output_dir = os.path.join(
- self.clustering_output_dir, "embs_by_nearest_center"
- )
- if os.path.exists(clustering_output_dir):
- self.logger.warning(
- f"Output directory {clustering_output_dir} already exists and will be overwritten"
- )
- shutil.rmtree(clustering_output_dir)
-
- embeddings_df.to_parquet(
- clustering_output_dir,
- index=False,
- partition_on="nearest_cent",
- )
- self.logger.info(
- f"Time taken for Assigning distance to each embedding : {time.time() - t0} "
- f"and output written at {clustering_output_dir}"
- )
-
- del embeddings_df
-
- if self.sort_clusters:
- assign_and_sort_clusters(
- id_col=self.id_col,
- kmeans_centroids_file=kmeans_centroids_file,
- nearest_cent_dir=clustering_output_dir,
- output_sorted_clusters_dir=os.path.join(
- self.clustering_output_dir, "sorted"
- ),
- embedding_col=self.embedding_col,
- sim_metric=self.sim_metric,
- keep_hard=self.keep_hard,
- kmeans_with_cos_dist=self.kmeans_with_cos_dist,
- cluster_ids=range(self.n_clusters),
- logger=self.logger,
- profile_dir=self.profile_dir,
- )
-
- fps = [
- os.path.join(clustering_output_dir, file_name)
- for file_name in os.listdir(clustering_output_dir)
- ]
- embeddings_df = dd.from_map(cudf.read_parquet, fps)
- return DocumentDataset(embeddings_df)
-
-
-class SemanticClusterLevelDedup:
- def __init__(
- self,
- n_clusters: int,
- emb_by_clust_dir: str,
- sorted_clusters_dir: str,
- id_column: str,
- id_column_type: str,
- which_to_keep: str,
- output_dir: str,
- embedding_col: str = "embeddings",
- logger: Union[logging.Logger, str] = "./",
- profile_dir: Optional[str] = None,
- ) -> None:
- """
- Initialize the SemanticClusterLevelDedup class.
-
- Args:
- n_clusters (int): Number of clusters.
- emb_by_clust_dir (str): Directory containing embeddings by cluster.
- sorted_clusters_dir (str): Directory containing sorted clusters.
- id_column (str): Column name for IDs.
- id_column_type (str): Data type of the ID column.
- which_to_keep (str): Strategy for which duplicate to keep.
- output_dir (str): Directory to save output files.
- embedding_col (str): Column where the embeddings are stored.
- logger (Union[logging.Logger, str]): Logger instance or path to the log file directory.
- profile_dir (str): If specified directory to write dask profile. Default is None.
- """
- self.n_clusters = n_clusters
- self.emb_by_clust_dir = emb_by_clust_dir
- self.sorted_clusters_dir = sorted_clusters_dir
- self.id_col = id_column
- self.id_col_type = id_column_type
- self.which_to_keep = which_to_keep
- self.output_dir = output_dir
- self.semdedup_pruning_tables_dir = os.path.join(
- output_dir, "semdedup_pruning_tables"
- )
- self.computed_semantic_match_dfs = False
- self.embedding_col = embedding_col
- self.logger = self._setup_logger(logger)
- self.profile_dir = profile_dir
-
- def _setup_logger(self, logger: Union[logging.Logger, str]) -> logging.Logger:
- """
- Set up the logger.
-
- Args:
- logger (Union[logging.Logger, str]): Logger instance or path to the log file directory.
-
- Returns:
- logging.Logger: Configured logger.
- """
- if isinstance(logger, str):
- return create_logger(
- rank=0,
- name="SemanticClusterLevelDedup",
- log_file=os.path.join(logger, "SemanticClusterLevelDedup.log"),
- log_level=logging.INFO,
- stdout=True,
- )
- else:
- return logger
-
- def compute_semantic_match_dfs(
- self, eps_list: Optional[List[float]] = None
- ) -> None:
- """
- Compute semantic match dataframes for clusters.
-
- Args:
- eps_list (Optional[List[float]]): List of epsilon values for clustering.
- """
- if eps_list is None:
- eps_list1 = [1.0e-2, 1.0e-3, 1.0e-4, 1.0e-5, 1.0e-6]
- eps_list2 = [0.1 + x * 0.005 for x in range(34)]
- eps_list = eps_list1 + eps_list2
-
- if os.path.exists(self.semdedup_pruning_tables_dir):
- self.logger.info(
- f"Removing existing directory {self.semdedup_pruning_tables_dir}"
- )
- shutil.rmtree(self.semdedup_pruning_tables_dir)
- expand_outdir_and_mkdir(self.semdedup_pruning_tables_dir)
- t0 = time.time()
- with performance_report_if_with_ts_suffix(
- self.profile_dir, "semantic-match-compute"
- ):
- tasks = db.from_sequence(
- list(range(self.n_clusters)), npartitions=self.n_clusters
- ).map(
- lambda cluster_id: get_semantic_matches_per_cluster(
- cluster_id=cluster_id,
- emb_by_clust_dir=self.emb_by_clust_dir,
- sorted_clusters_dir=self.sorted_clusters_dir,
- id_col=self.id_col,
- id_col_type=self.id_col_type,
- eps_list=eps_list,
- output_dir=self.semdedup_pruning_tables_dir,
- embedding_col=self.embedding_col,
- which_to_keep=self.which_to_keep,
- )
- )
- tasks.compute()
- self.logger.info(
- f"Time taken for Computing Semantic Matches : {time.time() - t0}"
- )
- self.computed_semantic_match_dfs = True
-
- def extract_dedup_data(self, eps_to_extract: float) -> DocumentDataset:
- """
- Extract deduplicated data based on epsilon value.
-
- Args:
- eps_to_extract (float): Epsilon threshold for extracting deduplicated data.
-
- Returns:
- DocumentDataset: Dataset containing deduplicated documents.
- """
- if not self.computed_semantic_match_dfs:
- raise ValueError(
- "Run compute_semantic_match_dfs before calling extract_dedup_data"
- )
-
- output_summary_file = os.path.join(
- self.output_dir, f"dedup_summary_{eps_to_extract}.csv"
- )
- output_parquet_path = os.path.join(
- self.output_dir, f"unique_ids_{eps_to_extract}.parquet"
- )
- extract_dedup_data(
- eps=eps_to_extract,
- n_clusters=self.n_clusters,
- id_col=self.id_col,
- id_col_type=self.id_col_type,
- sorted_clusters_dir=self.sorted_clusters_dir,
- semdedup_pruning_tables_dir=self.semdedup_pruning_tables_dir,
- output_summary_file=output_summary_file,
- output_parquet_path=output_parquet_path,
- logger=self.logger,
- profile_dir=self.profile_dir,
- )
-
- fps = [
- os.path.join(output_parquet_path, file_name)
- for file_name in os.listdir(output_parquet_path)
- ]
- return DocumentDataset.read_parquet(fps, backend="cudf")
-
-
-class SemDedup:
- def __init__(
- self,
- config: SemDedupConfig,
- input_column: str = "text",
- id_column: str = "id",
- id_column_type: str = "int",
- logger: Union[logging.Logger, str] = "./",
- ) -> None:
- """
- Initialize the SemDedup class.
-
- Args:
- config (SemDedupConfig): Configuration for SemDedup.
- logger (Union[logging.Logger, str]): Logger instance or path to the log file directory.
- """
- self.config = config
- self.logger = logger
- cache_dir = config.cache_dir
- self.embedding_creator = EmbeddingCreator(
- embedding_model_name_or_path=config.embedding_model_name_or_path,
- embedding_batch_size=config.embedding_batch_size,
- input_column=input_column,
- embedding_output_dir=os.path.join(cache_dir, config.embeddings_save_loc),
- logger=logger,
- profile_dir=self.config.profile_dir,
- )
- self.clustering_model = ClusteringModel(
- id_column=id_column,
- max_iter=config.max_iter,
- n_clusters=config.n_clusters,
- clustering_output_dir=os.path.join(cache_dir, config.clustering_save_loc),
- logger=logger,
- profile_dir=self.config.profile_dir,
- )
- self.semantic_cluster_dedup = SemanticClusterLevelDedup(
- n_clusters=config.n_clusters,
- emb_by_clust_dir=os.path.join(
- cache_dir, config.clustering_save_loc, "embs_by_nearest_center"
- ),
- sorted_clusters_dir=os.path.join(
- cache_dir, config.clustering_save_loc, "sorted"
- ),
- id_column=id_column,
- id_column_type=id_column_type,
- which_to_keep=config.which_to_keep,
- output_dir=os.path.join(cache_dir, config.clustering_save_loc),
- logger=logger,
- profile_dir=self.config.profile_dir,
- )
- self.eps_thresholds = config.eps_thresholds
- self.eps_to_extract = config.eps_to_extract
-
- def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
- """
- Execute the SemDedup process.
-
- Args:
- dataset (DocumentDataset): Input dataset for deduplication.
-
- Returns:
- DocumentDataset: Deduplicated dataset.
- """
- embeddings_dataset = self.embedding_creator(dataset)
- self.clustering_model(embeddings_dataset)
- self.semantic_cluster_dedup.compute_semantic_match_dfs(self.eps_thresholds)
- return self.semantic_cluster_dedup.extract_dedup_data(
- eps_to_extract=self.eps_to_extract
- )
diff --git a/nemo_curator/modules/semantic_dedup/clusteringmodel.py b/nemo_curator/modules/semantic_dedup/clusteringmodel.py
new file mode 100644
index 00000000..18714ff7
--- /dev/null
+++ b/nemo_curator/modules/semantic_dedup/clusteringmodel.py
@@ -0,0 +1,215 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import os
+import shutil
+import time
+from typing import Optional, Union
+
+import cudf
+import cupy as cp
+import dask.dataframe as dd
+import numpy as np
+from cuml.dask.cluster import KMeans
+
+from nemo_curator.datasets import DocumentDataset
+from nemo_curator.log import create_logger
+from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
+from nemo_curator.utils.file_utils import expand_outdir_and_mkdir
+from nemo_curator.utils.semdedup_utils import assign_and_sort_clusters
+
+
+### Clustering Module
+def get_embedding_ar(df: "cudf.DataFrame", embedding_col: str) -> cp.ndarray:
+ return df[embedding_col].list.leaves.values.reshape(len(df), -1)
+
+
+def add_dist_to_cents(
+ df: "cudf.DataFrame", embedding_col: str, centroids: cp.ndarray
+) -> "cudf.DataFrame":
+ embed_array = get_embedding_ar(df, embedding_col)
+ centroids_ar = centroids[df["nearest_cent"].values]
+ dist_to_cents = cp.sqrt(np.sum((embed_array - centroids_ar) ** 2, axis=1))
+ df["dist_to_cent"] = dist_to_cents
+ return df
+
+
+class ClusteringModel:
+ def __init__(
+ self,
+ id_column: str,
+ max_iter: int,
+ n_clusters: int,
+ clustering_output_dir: str,
+ embedding_col: str = "embeddings",
+ sim_metric: str = "cosine",
+ which_to_keep: str = "hard",
+ sort_clusters: bool = True,
+ kmeans_with_cos_dist: bool = False,
+ partition_size: str = "2gb",
+ logger: Union[logging.Logger, str] = "./",
+ profile_dir: Optional[str] = None,
+ ):
+ """
+        Initializes the ClusteringModel with the provided settings for semantic clustering as part of semantic deduplication.
+
+ Args:
+ id_column (str): Column name used as the identifier in the dataset.
+ max_iter (int): Maximum number of iterations for the clustering algorithm.
+ n_clusters (int): The number of clusters to form.
+ clustering_output_dir (str): Directory path where clustering results will be saved.
+ embedding_col (str): Column name where the embeddings are stored.
+ sim_metric (str): Similarity metric to use for clustering, default is "cosine".
+ which_to_keep (str): Strategy to decide which duplicates to keep; default is "hard".
+ sort_clusters (bool): Whether to sort clusters, default is True.
+ kmeans_with_cos_dist (bool): Whether to use KMeans with cosine distance, default is False.
+            partition_size (str): The size of the data partitions used to run KMeans, default is "2gb".
+            logger (Union[logging.Logger, str]): Logger object or directory path to save logs; default is "./".
+            profile_dir (str): If specified, the directory to which the Dask profile is written. Default is None.
+
+ This constructor sets up the parameters required for clustering operations.
+ """
+ self.id_col = id_column
+ self.max_iter = max_iter
+ self.n_clusters = n_clusters
+ self.clustering_output_dir = clustering_output_dir
+ self.embedding_col = embedding_col
+ self.sim_metric = sim_metric
+ self.keep_hard = which_to_keep == "hard"
+ self.kmeans_with_cos_dist = kmeans_with_cos_dist
+ self.partition_size = partition_size
+ self.sort_clusters = sort_clusters
+ self.logger = self._setup_logger(logger)
+ self.profile_dir = profile_dir
+
+ if not os.path.exists(self.clustering_output_dir):
+ expand_outdir_and_mkdir(self.clustering_output_dir)
+ else:
+ self.logger.warning(
+ f"Clustering output directory {self.clustering_output_dir} already exists and will be overwritten"
+ )
+
+ def _setup_logger(self, logger):
+ if isinstance(logger, str):
+ return create_logger(
+ rank=0,
+ name="SemanticClusterLevelDedup",
+ log_file=os.path.join(logger, "SemanticClusterLevelDedup.log"),
+ log_level=logging.INFO,
+ stdout=True,
+ )
+ else:
+ return logger
+
+ def __call__(self, embeddings_dataset: DocumentDataset):
+ embeddings_df = embeddings_dataset.df
+
+ if self.embedding_col not in embeddings_df.columns:
+ raise ValueError(
+ f"Expected embedding column '{self.embedding_col}'"
+ f" to be in dataset. Only found columns {embeddings_df.columns}"
+ )
+
+ with performance_report_if_with_ts_suffix(self.profile_dir, "clustering-model"):
+ embeddings_df = embeddings_df[[self.id_col, self.embedding_col]]
+ embeddings_df = embeddings_df.repartition(
+ partition_size=self.partition_size
+ )
+ embeddings_df = embeddings_df.to_backend("pandas").persist()
+ embeddings_df = embeddings_df.to_backend("cudf")
+
+ cupy_darr = embeddings_df.map_partitions(
+ get_embedding_ar, self.embedding_col, meta=cp.ndarray([1, 1])
+ )
+ cupy_darr.compute_chunk_sizes()
+ t0 = time.time()
+ kmeans = KMeans(n_clusters=self.n_clusters, max_iter=self.max_iter)
+ self.logger.info("KMeans starting fit")
+ kmeans.fit(cupy_darr)
+ self.logger.info("KMeans fit complete")
+ self.logger.info(f"Time taken for KMeans Fit: {time.time() - t0}")
+
+ self.logger.info(
+ "Computing nearest centroids + distance to centers using kmeans.predict"
+ )
+ t0 = time.time()
+ nearest_cents = kmeans.predict(cupy_darr)
+ self.logger.info(f"Time taken for KMeans Predict: {time.time() - t0}")
+
+ t0 = time.time()
+ embeddings_df["nearest_cent"] = nearest_cents.astype(np.int32)
+ del nearest_cents
+ meta_df = embeddings_df._meta.copy()
+ meta_df["dist_to_cent"] = cp.zeros(1)
+ embeddings_df = embeddings_df.map_partitions(
+ add_dist_to_cents,
+ embedding_col=self.embedding_col,
+ centroids=kmeans.cluster_centers_,
+ meta=meta_df,
+ )
+ embeddings_df = embeddings_df.reset_index(drop=True)
+ centroids = kmeans.cluster_centers_
+ kmeans_centroids_file = os.path.join(
+ self.clustering_output_dir, "kmeans_centroids.npy"
+ )
+ np.save(kmeans_centroids_file, centroids)
+ self.logger.info("Saving centroids complete")
+ del kmeans, cupy_darr, centroids
+
+ clustering_output_dir = os.path.join(
+ self.clustering_output_dir, "embs_by_nearest_center"
+ )
+ if os.path.exists(clustering_output_dir):
+ self.logger.warning(
+ f"Output directory {clustering_output_dir} already exists and will be overwritten"
+ )
+ shutil.rmtree(clustering_output_dir)
+
+ embeddings_df.to_parquet(
+ clustering_output_dir,
+ index=False,
+ partition_on="nearest_cent",
+ )
+ self.logger.info(
+ f"Time taken for Assigning distance to each embedding : {time.time() - t0} "
+ f"and output written at {clustering_output_dir}"
+ )
+
+ del embeddings_df
+
+ if self.sort_clusters:
+ assign_and_sort_clusters(
+ id_col=self.id_col,
+ kmeans_centroids_file=kmeans_centroids_file,
+ nearest_cent_dir=clustering_output_dir,
+ output_sorted_clusters_dir=os.path.join(
+ self.clustering_output_dir, "sorted"
+ ),
+ embedding_col=self.embedding_col,
+ sim_metric=self.sim_metric,
+ keep_hard=self.keep_hard,
+ kmeans_with_cos_dist=self.kmeans_with_cos_dist,
+ cluster_ids=range(self.n_clusters),
+ logger=self.logger,
+ profile_dir=self.profile_dir,
+ )
+
+ fps = [
+ os.path.join(clustering_output_dir, file_name)
+ for file_name in os.listdir(clustering_output_dir)
+ ]
+ embeddings_df = dd.from_map(cudf.read_parquet, fps)
+ return DocumentDataset(embeddings_df)
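A usage sketch for the relocated ClusteringModel follows; the embeddings path, cluster count, and output directory are placeholders.

    # Sketch: K-means cluster the embeddings produced by EmbeddingCreator.
    import dask_cudf

    from nemo_curator import ClusteringModel
    from nemo_curator.datasets import DocumentDataset

    embeddings = DocumentDataset(
        dask_cudf.read_parquet("./sem_cache/embeddings", blocksize="2GB")
    )
    clustering_model = ClusteringModel(
        id_column="id",
        max_iter=100,
        n_clusters=1000,
        clustering_output_dir="./sem_cache/clustering_results",
    )
    clustered = clustering_model(embeddings)
    # Side effects: kmeans_centroids.npy plus embs_by_nearest_center/ (and sorted/, when
    # sort_clusters=True) under clustering_output_dir.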
diff --git a/nemo_curator/modules/semantic_dedup/embeddings.py b/nemo_curator/modules/semantic_dedup/embeddings.py
new file mode 100644
index 00000000..4a0b638b
--- /dev/null
+++ b/nemo_curator/modules/semantic_dedup/embeddings.py
@@ -0,0 +1,231 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import os
+import time
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import dask_cudf
+import torch
+import torch.nn as nn
+from crossfit import op
+from crossfit.backend.torch.hf.model import HFModel
+from torch.nn import functional as F
+from transformers import AutoConfig, AutoModel, AutoTokenizer
+
+from nemo_curator.classifiers.base import _get_suggest_memory_for_classifier
+from nemo_curator.datasets import DocumentDataset
+from nemo_curator.log import create_logger
+from nemo_curator.utils.distributed_utils import (
+ performance_report_if_with_ts_suffix,
+ write_to_disk,
+)
+
+
+# Embedding Creation Module
+@dataclass
+class EmbeddingConfig:
+ model_name_or_path: str
+ max_seq_length: int = None
+
+ def __post_init__(self):
+ self.max_seq_length = AutoTokenizer.from_pretrained(
+ self.model_name_or_path
+ ).model_max_length
+        # Guard against the HF bug
+        # that sets max_seq_length to max(int) for some models
+ if self.max_seq_length > 1e5:
+ self.max_seq_length = AutoConfig.from_pretrained(
+ self.model_name_or_path
+ ).max_position_embeddings
+
+
+class EmbeddingPytorchModel(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.model = AutoModel.from_pretrained(
+ config.model_name_or_path, config=self.config, force_download=False
+ )
+
+ def feature(self, input_ids, attention_mask):
+ with torch.autocast(device_type=input_ids.device.type):
+ embeddings = self.model(input_ids=input_ids, attention_mask=attention_mask)
+ return embeddings
+
+ @torch.no_grad()
+ def forward(self, batch):
+ feature = self.feature(batch["input_ids"], batch["attention_mask"])
+ return self._mean_pooling(feature, batch["attention_mask"])
+
+ def _mean_pooling(self, model_output, attention_mask):
+ token_embeddings = model_output[0]
+ input_mask_expanded = (
+ attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ )
+ sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
+ sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
+ return F.normalize(sum_embeddings / sum_mask, dim=1)
+
+
+class EmbeddingCrossFitModel(HFModel):
+ def __init__(
+ self,
+ config: EmbeddingConfig,
+ max_mem_gb: Optional[int] = None,
+ ):
+ self.config = config
+ if max_mem_gb is None:
+ max_mem_gb = _get_suggest_memory_for_classifier()
+ super().__init__(self.config.model_name_or_path, max_mem_gb=max_mem_gb)
+
+ def load_model(self, device="cuda"):
+ model = EmbeddingPytorchModel(self.config)
+ model = model.to(device)
+ model.eval()
+ return model
+
+ def max_seq_length(self):
+ return self.config.max_seq_length
+
+ def load_config(self):
+ return AutoConfig.from_pretrained(self.config.model_name_or_path)
+
+ def load_tokenizer(self):
+ return AutoTokenizer.from_pretrained(self.config.model_name_or_path)
+
+
+class EmbeddingCreator:
+ def __init__(
+ self,
+ embedding_model_name_or_path: str,
+ embedding_batch_size: int,
+ embedding_output_dir: str,
+ embedding_max_mem_gb: Optional[int] = None,
+ input_column: str = "text",
+ embedding_column: str = "embeddings",
+ write_embeddings_to_disk: bool = True,
+ write_to_filename: bool = False,
+ logger: Union[logging.Logger, str] = "./",
+ profile_dir: Optional[str] = None,
+ ):
+ """
+ Initializes an EmbeddingCreator for generating embeddings using the specified model configurations.
+
+ Args:
+ embedding_model_name_or_path (str): The path or identifier for the model used to generate embeddings.
+ embedding_batch_size (int): Number of samples to process in each batch.
+ embedding_output_dir (str): Directory path where embeddings will be saved.
+ embedding_max_mem_gb (int): Maximum memory usage in GB for the embedding process.
+ If None, it defaults to the available GPU memory minus 4 GB.
+ input_column (str): Column name from the data to be used for embedding generation, defaults to "text".
+ write_embeddings_to_disk (bool, optional): If True, saves the embeddings to disk, defaults to True.
+                We recommend setting this to False when you have a delayed pipeline, although
+                setting it to False can lead to more memory overhead.
+ write_to_filename (bool): If True, saves the embeddings to the same filename as input files, defaults to False.
+ logger (Union[logging.Logger, str]): Logger object or path to store logs, defaults to "./".
+            profile_dir (str): If specified, the directory to which the Dask profile is written. Default is None.
+
+ Attributes:
+ embeddings_config (EmbeddingConfig): Configuration for embeddings.
+ batch_size (int): Batch size for embedding generation.
+ logger (logging.Logger): Logger instance for the class.
+ embedding_output_dir (str): Output directory for embeddings.
+ input_column (str): Input column for data processing.
+ model (EmbeddingCrossFitModel): Model instance for embedding generation.
+ write_to_filename (bool): If True, saves the embeddings to the same filename as input files, defaults to False.
+ """
+
+ self.embeddings_config = EmbeddingConfig(
+ model_name_or_path=embedding_model_name_or_path,
+ )
+ self.batch_size = embedding_batch_size
+ self.logger = self._setup_logger(logger)
+ self.embedding_output_dir = embedding_output_dir
+ self.input_column = input_column
+ self.embedding_column = embedding_column
+ self.model = EmbeddingCrossFitModel(
+ self.embeddings_config, max_mem_gb=embedding_max_mem_gb
+ )
+ self.write_embeddings_to_disk = write_embeddings_to_disk
+ self.write_to_filename = write_to_filename
+ self.profile_dir = profile_dir
+
+ def _setup_logger(self, logger):
+ if isinstance(logger, str):
+ return create_logger(
+ rank=0,
+ name="compute-embeddings",
+ log_file=os.path.join(logger, "compute_embeddings.log"),
+ log_level=logging.INFO,
+ stdout=True,
+ )
+ else:
+ return logger
+
+ def create_embeddings(
+ self, ddf: dask_cudf.DataFrame, input_column="text"
+ ) -> dask_cudf.DataFrame:
+ pipe = op.Sequential(
+ op.Tokenizer(
+ self.model,
+ cols=[input_column],
+ tokenizer_type="sentencepiece",
+ max_length=self.embeddings_config.max_seq_length,
+ ),
+ op.Predictor(
+ self.model,
+ sorted_data_loader=True,
+ batch_size=self.batch_size,
+ pred_output_col=self.embedding_column,
+ ),
+ keep_cols=ddf.columns.tolist(),
+ )
+ return pipe(ddf)
+
+ def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
+ t0 = time.time()
+ if self.write_embeddings_to_disk:
+ with performance_report_if_with_ts_suffix(
+ self.profile_dir, "embedding-creator"
+ ):
+ embedding_ddf = self.create_embeddings(dataset.df, self.input_column)
+ write_to_disk(
+ embedding_ddf,
+ self.embedding_output_dir,
+ write_to_filename=self.write_to_filename,
+ output_type="parquet",
+ )
+
+ ddf = DocumentDataset(
+ dask_cudf.read_parquet(
+ self.embedding_output_dir, blocksize="2GB", aggregate_files=True
+ )
+ )
+ else:
+ ddf = DocumentDataset(embedding_ddf)
+
+ self.logger.info(
+ f"Time taken for Creating Embeddings : {time.time() - t0}"
+ + (
+ f" and output written at {self.embedding_output_dir}"
+ if self.write_embeddings_to_disk
+ else ""
+ )
+ )
+
+ return ddf
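A usage sketch for the relocated EmbeddingCreator follows; the model identifier, batch size, and paths are placeholders.

    # Sketch: embed the "text" column of a cuDF-backed dataset. Assumes a GPU Dask client
    # is already running.
    from nemo_curator import EmbeddingCreator
    from nemo_curator.datasets import DocumentDataset

    dataset = DocumentDataset.read_json("/path/to/jsonl", backend="cudf")
    embedding_creator = EmbeddingCreator(
        embedding_model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
        embedding_batch_size=128,
        embedding_output_dir="./sem_cache/embeddings",
        input_column="text",
    )
    embeddings = embedding_creator(dataset)  # adds an "embeddings" list column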
diff --git a/nemo_curator/modules/semantic_dedup/semanticclusterleveldedup.py b/nemo_curator/modules/semantic_dedup/semanticclusterleveldedup.py
new file mode 100644
index 00000000..4329c2b0
--- /dev/null
+++ b/nemo_curator/modules/semantic_dedup/semanticclusterleveldedup.py
@@ -0,0 +1,183 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import os
+import shutil
+import time
+from typing import List, Optional, Union
+
+import dask.bag as db
+
+from nemo_curator.datasets import DocumentDataset
+from nemo_curator.log import create_logger
+from nemo_curator.modules.config import SemDedupConfig
+from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
+from nemo_curator.utils.file_utils import expand_outdir_and_mkdir
+from nemo_curator.utils.semdedup_utils import (
+ extract_dedup_data,
+ get_semantic_matches_per_cluster,
+)
+
+
+class SemanticClusterLevelDedup:
+ def __init__(
+ self,
+ n_clusters: int,
+ emb_by_clust_dir: str,
+ sorted_clusters_dir: str,
+ id_column: str,
+ id_column_type: str,
+ which_to_keep: str,
+ output_dir: str,
+ embedding_col: str = "embeddings",
+ logger: Union[logging.Logger, str] = "./",
+ profile_dir: Optional[str] = None,
+ ) -> None:
+ """
+ Initialize the SemanticClusterLevelDedup class.
+
+ Args:
+ n_clusters (int): Number of clusters.
+ emb_by_clust_dir (str): Directory containing embeddings by cluster.
+ sorted_clusters_dir (str): Directory containing sorted clusters.
+ id_column (str): Column name for IDs.
+ id_column_type (str): Data type of the ID column.
+ which_to_keep (str): Strategy for which duplicate to keep.
+ output_dir (str): Directory to save output files.
+ embedding_col (str): Column where the embeddings are stored.
+ logger (Union[logging.Logger, str]): Logger instance or path to the log file directory.
+            profile_dir (str): If specified, the directory to which the Dask profile is written. Default is None.
+ """
+ self.n_clusters = n_clusters
+ self.emb_by_clust_dir = emb_by_clust_dir
+ self.sorted_clusters_dir = sorted_clusters_dir
+ self.id_col = id_column
+ self.id_col_type = id_column_type
+ self.which_to_keep = which_to_keep
+ self.output_dir = output_dir
+ self.semdedup_pruning_tables_dir = os.path.join(
+ output_dir, "semdedup_pruning_tables"
+ )
+ self.computed_semantic_match_dfs = False
+ self.embedding_col = embedding_col
+ self.logger = self._setup_logger(logger)
+ self.profile_dir = profile_dir
+
+ def _setup_logger(self, logger: Union[logging.Logger, str]) -> logging.Logger:
+ """
+ Set up the logger.
+
+ Args:
+ logger (Union[logging.Logger, str]): Logger instance or path to the log file directory.
+
+ Returns:
+ logging.Logger: Configured logger.
+ """
+ if isinstance(logger, str):
+ return create_logger(
+ rank=0,
+ name="SemanticClusterLevelDedup",
+ log_file=os.path.join(logger, "SemanticClusterLevelDedup.log"),
+ log_level=logging.INFO,
+ stdout=True,
+ )
+ else:
+ return logger
+
+ def compute_semantic_match_dfs(
+ self, eps_list: Optional[List[float]] = None
+ ) -> None:
+ """
+ Compute semantic match dataframes for clusters.
+
+ Args:
+ eps_list (Optional[List[float]]): List of epsilon values for clustering.
+ """
+ if eps_list is None:
+ eps_list1 = [1.0e-2, 1.0e-3, 1.0e-4, 1.0e-5, 1.0e-6]
+ eps_list2 = [0.1 + x * 0.005 for x in range(34)]
+ eps_list = eps_list1 + eps_list2
+
+ if os.path.exists(self.semdedup_pruning_tables_dir):
+ self.logger.info(
+ f"Removing existing directory {self.semdedup_pruning_tables_dir}"
+ )
+ shutil.rmtree(self.semdedup_pruning_tables_dir)
+ expand_outdir_and_mkdir(self.semdedup_pruning_tables_dir)
+ t0 = time.time()
+ with performance_report_if_with_ts_suffix(
+ self.profile_dir, "semantic-match-compute"
+ ):
+ tasks = db.from_sequence(
+ list(range(self.n_clusters)), npartitions=self.n_clusters
+ ).map(
+ lambda cluster_id: get_semantic_matches_per_cluster(
+ cluster_id=cluster_id,
+ emb_by_clust_dir=self.emb_by_clust_dir,
+ sorted_clusters_dir=self.sorted_clusters_dir,
+ id_col=self.id_col,
+ id_col_type=self.id_col_type,
+ eps_list=eps_list,
+ output_dir=self.semdedup_pruning_tables_dir,
+ embedding_col=self.embedding_col,
+ which_to_keep=self.which_to_keep,
+ )
+ )
+ tasks.compute()
+ self.logger.info(
+ f"Time taken for Computing Semantic Matches : {time.time() - t0}"
+ )
+ self.computed_semantic_match_dfs = True
+
+ def extract_dedup_data(self, eps_to_extract: float) -> DocumentDataset:
+ """
+ Extract deduplicated data based on epsilon value.
+
+ Args:
+ eps_to_extract (float): Epsilon threshold for extracting deduplicated data.
+
+ Returns:
+ DocumentDataset: Dataset containing deduplicated documents.
+ """
+ if not self.computed_semantic_match_dfs:
+ raise ValueError(
+ "Run compute_semantic_match_dfs before calling extract_dedup_data"
+ )
+
+ output_summary_file = os.path.join(
+ self.output_dir, f"dedup_summary_{eps_to_extract}.csv"
+ )
+ output_parquet_path = os.path.join(
+ self.output_dir, f"unique_ids_{eps_to_extract}.parquet"
+ )
+ extract_dedup_data(
+ eps=eps_to_extract,
+ n_clusters=self.n_clusters,
+ id_col=self.id_col,
+ id_col_type=self.id_col_type,
+ sorted_clusters_dir=self.sorted_clusters_dir,
+ semdedup_pruning_tables_dir=self.semdedup_pruning_tables_dir,
+ output_summary_file=output_summary_file,
+ output_parquet_path=output_parquet_path,
+ logger=self.logger,
+ profile_dir=self.profile_dir,
+ )
+
+ fps = [
+ os.path.join(output_parquet_path, file_name)
+ for file_name in os.listdir(output_parquet_path)
+ ]
+ return DocumentDataset.read_parquet(fps, backend="cudf")
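A usage sketch for the relocated SemanticClusterLevelDedup follows; the directory layout mirrors what ClusteringModel writes under its output directory, and the epsilon value is only an example.

    # Sketch: prune semantic duplicates per cluster and extract the IDs to keep.
    from nemo_curator import SemanticClusterLevelDedup

    cluster_dedup = SemanticClusterLevelDedup(
        n_clusters=1000,
        emb_by_clust_dir="./sem_cache/clustering_results/embs_by_nearest_center",
        sorted_clusters_dir="./sem_cache/clustering_results/sorted",
        id_column="id",
        id_column_type="int",
        which_to_keep="hard",
        output_dir="./sem_cache/clustering_results",
    )
    cluster_dedup.compute_semantic_match_dfs()  # required before extraction
    unique_ids = cluster_dedup.extract_dedup_data(eps_to_extract=0.07)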
diff --git a/nemo_curator/modules/semantic_dedup/semdedup.py b/nemo_curator/modules/semantic_dedup/semdedup.py
new file mode 100644
index 00000000..a03d152b
--- /dev/null
+++ b/nemo_curator/modules/semantic_dedup/semdedup.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import os
+from typing import Union
+
+from nemo_curator.datasets import DocumentDataset
+from nemo_curator.modules.config import SemDedupConfig
+from nemo_curator.modules.semantic_dedup.clusteringmodel import ClusteringModel
+from nemo_curator.modules.semantic_dedup.embeddings import EmbeddingCreator
+from nemo_curator.modules.semantic_dedup.semanticclusterleveldedup import (
+ SemanticClusterLevelDedup,
+)
+
+
+class SemDedup:
+ def __init__(
+ self,
+ config: SemDedupConfig,
+ input_column: str = "text",
+ id_column: str = "id",
+ id_column_type: str = "int",
+ logger: Union[logging.Logger, str] = "./",
+ ) -> None:
+ """
+ Initialize the SemDedup class.
+
+        Args:
+            config (SemDedupConfig): Configuration for SemDedup.
+            input_column (str): Name of the column containing document text, defaults to "text".
+            id_column (str): Name of the column containing document IDs, defaults to "id".
+            id_column_type (str): Data type of the ID column, defaults to "int".
+            logger (Union[logging.Logger, str]): Logger instance or path to the log file directory.
+        """
+ self.config = config
+ self.logger = logger
+ cache_dir = config.cache_dir
+ self.embedding_creator = EmbeddingCreator(
+ embedding_model_name_or_path=config.embedding_model_name_or_path,
+ embedding_batch_size=config.embedding_batch_size,
+ input_column=input_column,
+ embedding_output_dir=os.path.join(cache_dir, config.embeddings_save_loc),
+ logger=logger,
+ profile_dir=self.config.profile_dir,
+ )
+ self.clustering_model = ClusteringModel(
+ id_column=id_column,
+ max_iter=config.max_iter,
+ n_clusters=config.n_clusters,
+ clustering_output_dir=os.path.join(cache_dir, config.clustering_save_loc),
+ logger=logger,
+ profile_dir=self.config.profile_dir,
+ )
+ self.semantic_cluster_dedup = SemanticClusterLevelDedup(
+ n_clusters=config.n_clusters,
+ emb_by_clust_dir=os.path.join(
+ cache_dir, config.clustering_save_loc, "embs_by_nearest_center"
+ ),
+ sorted_clusters_dir=os.path.join(
+ cache_dir, config.clustering_save_loc, "sorted"
+ ),
+ id_column=id_column,
+ id_column_type=id_column_type,
+ which_to_keep=config.which_to_keep,
+ output_dir=os.path.join(cache_dir, config.clustering_save_loc),
+ logger=logger,
+ profile_dir=self.config.profile_dir,
+ )
+ self.eps_thresholds = config.eps_thresholds
+ self.eps_to_extract = config.eps_to_extract
+
+ def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
+ """
+ Execute the SemDedup process.
+
+ Args:
+ dataset (DocumentDataset): Input dataset for deduplication.
+
+ Returns:
+ DocumentDataset: Deduplicated dataset.
+ """
+ embeddings_dataset = self.embedding_creator(dataset)
+ self.clustering_model(embeddings_dataset)
+ self.semantic_cluster_dedup.compute_semantic_match_dfs(self.eps_thresholds)
+ return self.semantic_cluster_dedup.extract_dedup_data(
+ eps_to_extract=self.eps_to_extract
+ )
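An end-to-end sketch of the SemDedup wrapper follows. It assumes that SemDedupConfig supplies defaults for the fields referenced above (embedding model, cluster count, epsilon thresholds) so that only cache_dir needs to be set.

    # Sketch: run embedding creation, clustering, and cluster-level dedup in one call.
    from nemo_curator import SemDedup
    from nemo_curator.datasets import DocumentDataset
    from nemo_curator.modules.config import SemDedupConfig

    config = SemDedupConfig(cache_dir="./sem_cache")  # other fields assumed to default
    sem_dedup = SemDedup(
        config=config, input_column="text", id_column="id", id_column_type="int"
    )

    dataset = DocumentDataset.read_json("/path/to/jsonl", backend="cudf")
    unique_ids = sem_dedup(dataset)  # DocumentDataset of document IDs to keep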
diff --git a/nemo_curator/scripts/fuzzy_deduplication/connected_components.py b/nemo_curator/scripts/fuzzy_deduplication/connected_components.py
index 725a34ea..f43eec74 100644
--- a/nemo_curator/scripts/fuzzy_deduplication/connected_components.py
+++ b/nemo_curator/scripts/fuzzy_deduplication/connected_components.py
@@ -16,7 +16,7 @@
import os
import time
-from nemo_curator.modules.fuzzy_dedup import ConnectedComponents
+from nemo_curator import ConnectedComponents
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.script_utils import ArgumentHelper
diff --git a/nemo_curator/scripts/fuzzy_deduplication/jaccard_compute.py b/nemo_curator/scripts/fuzzy_deduplication/jaccard_compute.py
index 81321424..d87be3cb 100644
--- a/nemo_curator/scripts/fuzzy_deduplication/jaccard_compute.py
+++ b/nemo_curator/scripts/fuzzy_deduplication/jaccard_compute.py
@@ -16,7 +16,7 @@
import os
import time
-from nemo_curator.modules.fuzzy_dedup import JaccardSimilarity
+from nemo_curator import JaccardSimilarity
from nemo_curator.utils.distributed_utils import get_client, get_num_workers
from nemo_curator.utils.script_utils import ArgumentHelper
diff --git a/nemo_curator/scripts/fuzzy_deduplication/jaccard_shuffle.py b/nemo_curator/scripts/fuzzy_deduplication/jaccard_shuffle.py
index e0c4e67a..24c2243a 100644
--- a/nemo_curator/scripts/fuzzy_deduplication/jaccard_shuffle.py
+++ b/nemo_curator/scripts/fuzzy_deduplication/jaccard_shuffle.py
@@ -16,7 +16,7 @@
import os
import time
-from nemo_curator.modules.fuzzy_dedup import _Shuffle
+from nemo_curator.modules.fuzzy_dedup._shuffle import _Shuffle
from nemo_curator.utils.distributed_utils import get_client, get_num_workers
from nemo_curator.utils.fuzzy_dedup_utils.io_utils import (
get_text_ddf_from_json_path_with_blocksize,
@@ -27,7 +27,7 @@
def func():
import cudf
- from nemo_curator.modules.fuzzy_dedup import _Shuffle
+ from nemo_curator.modules.fuzzy_dedup._shuffle import _Shuffle
def main(args):
diff --git a/nemo_curator/scripts/fuzzy_deduplication/map_buckets.py b/nemo_curator/scripts/fuzzy_deduplication/map_buckets.py
index 7af70cb5..fb825b1b 100644
--- a/nemo_curator/scripts/fuzzy_deduplication/map_buckets.py
+++ b/nemo_curator/scripts/fuzzy_deduplication/map_buckets.py
@@ -16,7 +16,7 @@
import os
import time
-from nemo_curator.modules.fuzzy_dedup import _MapBuckets
+from nemo_curator.modules.fuzzy_dedup._mapbuckets import _MapBuckets
from nemo_curator.utils.distributed_utils import get_client, get_num_workers
from nemo_curator.utils.fuzzy_dedup_utils.io_utils import (
get_bucket_ddf_from_parquet_path,
diff --git a/nemo_curator/scripts/semdedup/clustering.py b/nemo_curator/scripts/semdedup/clustering.py
index 27ddfb3a..db4885c3 100644
--- a/nemo_curator/scripts/semdedup/clustering.py
+++ b/nemo_curator/scripts/semdedup/clustering.py
@@ -18,10 +18,10 @@
import dask_cudf
+from nemo_curator import ClusteringModel
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.config import SemDedupConfig
-from nemo_curator.modules.semantic_dedup import ClusteringModel
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.file_utils import expand_outdir_and_mkdir
from nemo_curator.utils.script_utils import ArgumentHelper
diff --git a/nemo_curator/scripts/semdedup/compute_embeddings.py b/nemo_curator/scripts/semdedup/compute_embeddings.py
index 014390f8..e46c9d01 100644
--- a/nemo_curator/scripts/semdedup/compute_embeddings.py
+++ b/nemo_curator/scripts/semdedup/compute_embeddings.py
@@ -16,10 +16,10 @@
import os
import time
+from nemo_curator import EmbeddingCreator
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.config import SemDedupConfig
-from nemo_curator.modules.semantic_dedup import EmbeddingCreator
from nemo_curator.utils.distributed_utils import get_client, read_data
from nemo_curator.utils.file_utils import expand_outdir_and_mkdir, get_remaining_files
from nemo_curator.utils.script_utils import ArgumentHelper
diff --git a/nemo_curator/scripts/semdedup/extract_dedup_data.py b/nemo_curator/scripts/semdedup/extract_dedup_data.py
index 5b489fa1..b6ffaebc 100755
--- a/nemo_curator/scripts/semdedup/extract_dedup_data.py
+++ b/nemo_curator/scripts/semdedup/extract_dedup_data.py
@@ -2,9 +2,9 @@
import os
from datetime import datetime
+from nemo_curator import SemanticClusterLevelDedup
from nemo_curator.log import create_logger
from nemo_curator.modules.config import SemDedupConfig
-from nemo_curator.modules.semantic_dedup import SemanticClusterLevelDedup
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.script_utils import ArgumentHelper
diff --git a/tutorials/dapt-curation/README.md b/tutorials/dapt-curation/README.md
index 0e43e48a..4f67c616 100755
--- a/tutorials/dapt-curation/README.md
+++ b/tutorials/dapt-curation/README.md
@@ -37,7 +37,7 @@ The tutorial follows the steps below:
-  - Heuristic-based quality filtering (Number of lines, worc count, top N-grams, etc.)
+  - Heuristic-based quality filtering (Number of lines, word count, top N-grams, etc.)
- Fix unicode errors via ftfy
- PII redaction
- - GPU accelerated fuzzy and semanctic deduplication
+ - GPU accelerated fuzzy and semantic deduplication
- Step 6: Save the filtered and curated data
- Step 7: Blend datasets and shuffle
diff --git a/tutorials/pretraining-data-curation/red-pajama-v2-curation-tutorial.ipynb b/tutorials/pretraining-data-curation/red-pajama-v2-curation-tutorial.ipynb
index d0f690ea..ae4adffd 100644
--- a/tutorials/pretraining-data-curation/red-pajama-v2-curation-tutorial.ipynb
+++ b/tutorials/pretraining-data-curation/red-pajama-v2-curation-tutorial.ipynb
@@ -31,7 +31,7 @@
"# 1. Introduction\n",
"\n",
"\n",
- "In this tutorial, we will show how to curate large-scale data for LLM pretraining in a distributed environment using NeMo-Curator. Specifically, we will focus on the following modules in NeMo-Curator:\n",
+ "In this tutorial, we will show how to curate large-scale data for LLM pretraining in a distributed environment using NeMo Curator. Specifically, we will focus on the following modules in NeMo Curator:\n",
"\n",
"- Language identification and separation\n",
"- Text reformatting and cleaning\n",
@@ -58,13 +58,13 @@
"\n",
"- **OS**: Ubuntu 22.04.4 LTS\n",
"\n",
- "## 1.2 Running NeMo-Curator\n",
+ "## 1.2 Running NeMo Curator\n",
"\n",
- "NeMo-curator came pre-installed in Nemo Framework container. This notebook use 24.07 release of the NeMo Framework container. User can pull the container following the steps below:\n",
+    "NeMo Curator comes pre-installed in the NeMo Framework container. This notebook uses the 24.07 release of the container, which the user can pull by following the steps below:\n",
"\n",
"- Get access to the NeMo Framework container on [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo)\n",
"\n",
- "- Set your docker credentials\n",
+ "- Set your Docker credentials:\n",
"\n",
"\n",
" `docker login nvcr.io`\n",
@@ -77,7 +77,7 @@
" \n",
" `docker pull docker pull nvcr.io/nvidia/nemo:24.07`\n",
"\n",
- "Alternatively, NeMo-Curator is also available on [PyPi](https://pypi.org/project/nemo-curator/) and [GitHub](https://github.com/NVIDIA/NeMo-Curator)."
+    "Alternatively, NeMo Curator is available on [PyPI](https://pypi.org/project/nemo-curator/) and [GitHub](https://github.com/NVIDIA/NeMo-Curator)."
]
},
{
@@ -88,22 +88,22 @@
"# 2. Getting started\n",
"\n",
"\n",
- "NeMo-Curator uses dask for parallelization. Before we start using curator, we need to start a dask cluster. To start a multi-node dask cluster in Slurm, we can use the `start-distributed-notebook.sh` script in this directory to start the cluster. The user will need to change the following variables:\n",
+ "NeMo Curator uses Dask for parallelization. Before we start using NeMo Curator, we need to start a Dask cluster. To start a multi-node Dask cluster in Slurm, we can use the `start-distributed-notebook.sh` script in this directory. The user will need to change the following variables:\n",
"\n",
"- Slurm job directives\n",
- "- Device type (`cpu` or `gpu`). Curator has both cpu and gpu modules. Check [here](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/cpuvsgpu.html) to see which modules are cpu/gpu\n",
- "- CPU related parameters if using cpu modules. Configure the number of workers and memory limit to efficiently use available computational resources while preventing out of memory\n",
+ "- Device type (`cpu` or `gpu`). NeMo Curator has both CPU-based and GPU-based modules. Check [here](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/cpuvsgpu.html) to see which modules are CPU-based and/or GPU-based\n",
+    "- CPU-related parameters (used by the CPU-based modules): configure the number of workers and the memory limit to use the available computational resources efficiently and prevent out-of-memory errors\n",
"- Path to the NeMo Framework container image\n",
- "- Path to `container-entrypoint.sh` script which is responsible for launching the dask schduler and workers\n",
+    "- Path to the `container-entrypoint.sh` script, which is responsible for launching the Dask scheduler and workers\n",
"\n",
- "Running the script will also launch a jupyter lab session on the rank 0 node and pass the dask schduler address as an environment variable that will be used later to connect to the dask client.\n",
+ "Running the script will also launch a JupyterLab session on the rank 0 node and pass the Dask scheduler address as an environment variable to be used later for connecting to the Dask client.\n",
"\n",
- "The preprocessing modules such as Add ID and Text cleaning are cpu-based so we will start a cpu dask cluster first."
+ "The preprocessing modules such as AddId and text cleaning are CPU-based, so we will start a CPU-based Dask cluster first."
]
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"id": "5de0fe93",
"metadata": {
"tags": []
@@ -121,31 +121,19 @@
"source": [
"import os\n",
"import time\n",
- "from dask.distributed import Client\n",
"import warnings\n",
"import dask.dataframe as dd\n",
"import dask_cudf\n",
"import cudf\n",
- "import gzip\n",
- "import json\n",
- "import dask.bag as db\n",
- "import glob\n",
"from dask.distributed import wait\n",
"import numpy as np\n",
"\n",
"from nemo_curator import get_client\n",
- "from nemo_curator.datasets import DocumentDataset\n",
"from nemo_curator.utils.distributed_utils import (\n",
" get_num_workers,\n",
" read_data,\n",
" write_to_disk,\n",
")\n",
- "from nemo_curator.utils.file_utils import (\n",
- " expand_outdir_and_mkdir, \n",
- " get_all_files_paths_under, \n",
- " separate_by_metadata,\n",
- " get_batched_files,\n",
- ")\n",
"\n",
"warnings.filterwarnings('ignore')\n",
"base_dir = \"/path/to/data\""
@@ -495,21 +483,20 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"id": "7419a216-0dad-4d13-89ee-c3c1d009efa8",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
- "from nemo_curator import ScoreFilter, Modify\n",
+ "from nemo_curator import ScoreFilter\n",
"from nemo_curator.filters import FastTextLangId\n",
- "from nemo_curator.modifiers import UnicodeReformatter\n",
"from nemo_curator.utils.file_utils import get_all_files_paths_under, separate_by_metadata\n",
"\n",
"# Language ID path\n",
- "language_output_path = expand_outdir_and_mkdir(os.path.join(base_dir,\"rpv2-2023-06-language\"))\n",
- "language_data_output_path = expand_outdir_and_mkdir(os.path.join(language_output_path,\"data\"))\n",
+ "language_output_path = expand_outdir_and_mkdir(os.path.join(base_dir, \"rpv2-2023-06-language\"))\n",
+ "language_data_output_path = expand_outdir_and_mkdir(os.path.join(language_output_path, \"data\"))\n",
"\n",
"# Fasttext model path\n",
"model_path = language_output_path\n",
@@ -808,14 +795,13 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"id": "f6dc1754",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
- "from nemo_curator.log import create_logger\n",
"from nemo_curator.modules import ExactDuplicates\n",
"\n",
"def pre_imports():\n",
@@ -1796,14 +1782,14 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": null,
"id": "7985cf1a-9d88-4844-8ce4-e68d9792118c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
- "from nemo_curator.modules.fuzzy_dedup import _MapBuckets\n",
+ "from nemo_curator.modules.fuzzy_dedup._mapbuckets import _MapBuckets\n",
"from nemo_curator.utils.fuzzy_dedup_utils.io_utils import (\n",
" get_bucket_ddf_from_parquet_path,\n",
" get_text_ddf_from_json_path_with_blocksize,\n",
@@ -2031,14 +2017,14 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": null,
"id": "11d7184d-4ca5-4b49-85b4-1264056f5c33",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
- "from nemo_curator.modules.fuzzy_dedup import _Shuffle\n",
+ "from nemo_curator.modules.fuzzy_dedup._shuffle import _Shuffle\n",
"\n",
"log_dir = os.path.join(base_dir, \"logs\")\n",
"input_anchor_docs_with_bk_path = os.path.join(base_dir,\"fuzzy-dedup-output-2023-06/anchor_docs_with_bk.parquet\")\n",
@@ -2512,14 +2498,14 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"id": "573dccf7-2e23-4aae-a3ec-2b9e1a42d97d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
- "from nemo_curator.modules.fuzzy_dedup import JaccardSimilarity\n",
+ "from nemo_curator import JaccardSimilarity\n",
"\n",
"id_field = 'id'\n",
"text_field = 'raw_content'\n",
@@ -2670,14 +2656,14 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"id": "f9aeb619-3fab-4a18-b582-bccae3eefd17",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
- "from nemo_curator.modules.fuzzy_dedup import ConnectedComponents\n",
+ "from nemo_curator import ConnectedComponents\n",
"\n",
"cache_dir = expand_outdir_and_mkdir(\n",
" os.path.join(base_dir, \"fuzzy-dedup-output-2023-06/cc-cache\")\n",
@@ -3255,7 +3241,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"id": "f1461b61-887c-4099-bd9f-32e79dc5fdbb",
"metadata": {
"tags": []
@@ -3264,10 +3250,10 @@
"source": [
"from nemo_curator import MinHash\n",
"from nemo_curator import LSH\n",
- "from nemo_curator.modules.fuzzy_dedup import _MapBuckets\n",
- "from nemo_curator.modules.fuzzy_dedup import _Shuffle\n",
- "from nemo_curator.modules.fuzzy_dedup import ConnectedComponents\n",
- "from nemo_curator.modules.fuzzy_dedup import JaccardSimilarity\n",
+ "from nemo_curator.modules.fuzzy_dedup._mapbuckets import _MapBuckets\n",
+ "from nemo_curator.modules.fuzzy_dedup._shuffle import _Shuffle\n",
+ "from nemo_curator import ConnectedComponents\n",
+ "from nemo_curator import JaccardSimilarity\n",
"\n",
"from nemo_curator.utils.file_utils import reshard_jsonl\n",
"from nemo_curator.utils.fuzzy_dedup_utils.id_mapping import convert_str_id_to_int\n",
@@ -4718,14 +4704,13 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"id": "49273a8b-848f-4f24-a0ba-3c0b478d17cc",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
- "import nemo_curator\n",
"from nemo_curator.utils.config_utils import build_filter_pipeline\n",
"\n",
"filter_config_file = os.path.join(base_dir, \"config/heuristic_filter_en.yaml\")\n",
diff --git a/tutorials/single_node_tutorial/single_gpu_tutorial.ipynb b/tutorials/single_node_tutorial/single_gpu_tutorial.ipynb
index 3170b350..2512cf73 100644
--- a/tutorials/single_node_tutorial/single_gpu_tutorial.ipynb
+++ b/tutorials/single_node_tutorial/single_gpu_tutorial.ipynb
@@ -122,15 +122,13 @@
},
"outputs": [],
"source": [
- "import argparse\n",
"import os\n",
"\n",
- "from nemo_curator.utils.distributed_utils import get_client,get_num_workers\n",
+ "from nemo_curator.utils.distributed_utils import get_client, get_num_workers\n",
"from nemo_curator.utils.file_utils import get_all_files_paths_under, separate_by_metadata\n",
- "from nemo_curator.utils.distributed_utils import read_data,write_to_disk\n",
+ "from nemo_curator.utils.distributed_utils import read_data, write_to_disk\n",
"from nemo_curator.datasets import DocumentDataset\n",
"\n",
- "import sys\n",
"import pandas as pd\n",
"import time\n",
"import cudf\n",
@@ -138,7 +136,7 @@
"import dask\n",
"import numpy as np\n",
"from dask.distributed import Client, LocalCluster\n",
- "import jsonlines\n"
+ "import jsonlines"
]
},
{
@@ -406,7 +404,7 @@
},
"outputs": [],
"source": [
- "from nemo_curator import ScoreFilter,Modify\n",
+ "from nemo_curator import ScoreFilter, Modify\n",
"from nemo_curator.filters import FastTextLangId\n",
"from nemo_curator.modifiers import UnicodeReformatter"
]
@@ -1360,7 +1358,8 @@
" get_bucket_ddf_from_parquet_path,\n",
" get_text_ddf_from_json_path_with_blocksize,\n",
")\n",
- "from nemo_curator.modules.fuzzy_dedup import _MapBuckets,_Shuffle"
+ "from nemo_curator.modules.fuzzy_dedup._mapbuckets import _MapBuckets\n",
+ "from nemo_curator.modules.fuzzy_dedup._shuffle import _Shuffle"
]
},
{
@@ -1572,7 +1571,7 @@
},
"outputs": [],
"source": [
- "from nemo_curator.modules.fuzzy_dedup import JaccardSimilarity"
+ "from nemo_curator import JaccardSimilarity"
]
},
{
@@ -1691,7 +1690,7 @@
},
"outputs": [],
"source": [
- "from nemo_curator.modules.fuzzy_dedup import ConnectedComponents"
+ "from nemo_curator import ConnectedComponents"
]
},
{
@@ -2258,8 +2257,8 @@
"outputs": [],
"source": [
"from nemo_curator.utils.config_utils import build_filter_pipeline\n",
- "from nemo_curator import Score, Filter, ScoreFilter\n",
- "from nemo_curator.utils.file_utils import get_batched_files,expand_outdir_and_mkdir"
+ "from nemo_curator import Score, ScoreFilter\n",
+ "from nemo_curator.utils.file_utils import expand_outdir_and_mkdir"
]
},
{
@@ -2282,7 +2281,7 @@
"import warnings\n",
"\n",
"# Disable the metadata warning\n",
- "warnings.filterwarnings(\"ignore\",module=\"dask.dataframe.core\")"
+ "warnings.filterwarnings(\"ignore\", module=\"dask.dataframe.core\")"
]
},
{
diff --git a/tutorials/zyda2-tutorial/1_fuzzy_dedup/2_buckets_to_edges.py b/tutorials/zyda2-tutorial/1_fuzzy_dedup/2_buckets_to_edges.py
index 853fe6fd..45755673 100644
--- a/tutorials/zyda2-tutorial/1_fuzzy_dedup/2_buckets_to_edges.py
+++ b/tutorials/zyda2-tutorial/1_fuzzy_dedup/2_buckets_to_edges.py
@@ -4,8 +4,8 @@
import dask_cudf
+from nemo_curator import BucketsToEdges
from nemo_curator.datasets import DocumentDataset
-from nemo_curator.modules.fuzzy_dedup import BucketsToEdges
from nemo_curator.utils.distributed_utils import get_client, get_num_workers
logging.basicConfig(format="%(asctime)s: %(message)s", level=logging.INFO)
diff --git a/tutorials/zyda2-tutorial/1_fuzzy_dedup/3_connected_components.py b/tutorials/zyda2-tutorial/1_fuzzy_dedup/3_connected_components.py
index db76ce5b..e6ad2165 100644
--- a/tutorials/zyda2-tutorial/1_fuzzy_dedup/3_connected_components.py
+++ b/tutorials/zyda2-tutorial/1_fuzzy_dedup/3_connected_components.py
@@ -2,7 +2,7 @@
import os
import time
-from nemo_curator.modules.fuzzy_dedup import ConnectedComponents
+from nemo_curator import ConnectedComponents
from nemo_curator.utils.distributed_utils import get_client, get_num_workers
logging.basicConfig(format="%(asctime)s: %(message)s", level=logging.INFO)