Added train test bleed example
Pringled committed Oct 11, 2024
1 parent a0375ba commit 91e958b
Showing 2 changed files with 151 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tutorials/README.md
@@ -12,4 +12,4 @@ This is a list of all our tutorials. They are all self-contained ipython noteboo
| Tutorial | What? | Link |
|--------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------|
| **Recipe search** | Learn how to do lightning-fast semantic search by distilling a small model. Compare a really tiny model to a larger one with a better vocabulary. Learn what Fattoush is (delicious). | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/minishlab/model2vec/blob/master/tutorials/recipe_search.ipynb) |
| **Semantic deduplication** | Learn how Model2Vec can be used to detect duplicate texts. Clean your dataset efficiently by finding both exact and semantic duplicates. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/minishlab/model2vec/blob/master/tutorials/semantic_deduplication.ipynb) |
| **Semantic deduplication** | Learn how Model2Vec can be used to detect duplicate texts. Clean your dataset efficiently by finding both exact and semantic duplicates. Detect train-test bleed. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/minishlab/model2vec/blob/master/tutorials/semantic_deduplication.ipynb) |
151 changes: 150 additions & 1 deletion tutorials/semantic_deduplication.ipynb
@@ -340,7 +340,156 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Model2Vec is again much faster, with 27 seconds vs 56 seconds for MinHash. The number of found duplicates is roughly the same using the default settings for MinHash."
"Model2Vec is again much faster, with 27 seconds vs 56 seconds for MinHash. The number of found duplicates is roughly the same using the default settings for MinHash.\n",
"\n",
"Now, as a last experiment, let's also embed the test set, and see if there are any duplicates between the training and test set. This is a common issue in NLP, where the test set may contain instances that are also in the training set."
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 118/118 [00:02<00:00, 45.36it/s]\n",
"100%|██████████| 8/8 [00:00<00:00, 51.05it/s]\n",
"100%|██████████| 8/8 [00:01<00:00, 5.40it/s]\n",
"100%|██████████| 7600/7600 [00:00<00:00, 901108.42it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of duplicates found between train and test: 138\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Load the datasets\n",
"ds_train = load_dataset(\"ag_news\")[\"train\"]\n",
"ds_test = load_dataset(\"ag_news\")[\"test\"]\n",
"\n",
"texts_train = ds_train['text']\n",
"texts_test = ds_test['text']\n",
"\n",
"# Encode texts into embeddings\n",
"embedding_matrix_train = model.encode(texts_train, show_progressbar=True)\n",
"embedding_matrix_test = model.encode(texts_test, show_progressbar=True)\n",
"\n",
"def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024) -> tuple[list[int], dict[int, int]]:\n",
" \"\"\"\n",
" Deduplicate embeddings across two datasets and return the indices of duplicates between them.\n",
" \n",
" :param embedding_matrix_1: The embeddings of the first dataset (e.g., train).\n",
" :param embedding_matrix_2: The embeddings of the second dataset (e.g., test).\n",
" :param threshold: The similarity threshold to use for deduplication.\n",
" :param batch_size: The batch size to use for similarity computation.\n",
" :return: A tuple containing the duplicate indices and a dictionary mapping removed indices in the second dataset to their corresponding indices in the first dataset.\n",
" \"\"\"\n",
" reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])\n",
"\n",
" # Keep track of duplicates in the second dataset\n",
" duplicate_indices_in_test = []\n",
" duplicate_to_original_mapping = {}\n",
"\n",
" # Find nearest neighbors from the test set in the train set\n",
" results = reach.nearest_neighbor_threshold(\n",
" embedding_matrix_2, \n",
" threshold=threshold, \n",
" batch_size=batch_size, \n",
" show_progressbar=True\n",
" )\n",
" \n",
" # Process duplicates\n",
" for i, similar_items in enumerate(tqdm(results)):\n",
" # Similar items are returned as (index, score), we are only interested in the index\n",
" similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold] # Keep those above the threshold\n",
" \n",
" # If we find a similar item in the train set, mark it as a duplicate\n",
" if similar_indices:\n",
" duplicate_indices_in_test.append(i)\n",
" duplicate_to_original_mapping[i] = similar_indices[0] # Map duplicate in test to original in train\n",
"\n",
" return duplicate_indices_in_test, duplicate_to_original_mapping\n",
"\n",
"# Check for train/test bleed\n",
"duplicate_indices_in_test, duplicate_to_original_mapping = deduplicate_across_datasets(\n",
" embedding_matrix_train, \n",
" embedding_matrix_test, \n",
" threshold=0.99 # High threshold for deduplication\n",
")\n",
"\n",
"print(f\"Number of duplicates found between train and test: {len(duplicate_indices_in_test)}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Example 1:\n",
"Train text:\n",
"Jackson Squares Off With Attorney SANTA MARIA, Calif. - Fans of Michael Jackson erupted in cheers Monday as the pop star emerged from a double-decker tour bus and went into court for a showdown with the prosecutor who has pursued him for years on child molestation charges...\n",
"Test text:\n",
"Jackson Squares Off With Prosecutor SANTA MARIA, Calif. - Fans of Michael Jackson erupted in cheers Monday as the pop star emerged from a double-decker tour bus and went into court for a showdown with the prosecutor who has pursued him for years on child molestation charges...\n",
"--------------------------------------------------\n",
"Example 2:\n",
"Train text:\n",
"Cassini Spies Two Moons Around Saturn (AP) AP - NASA's Cassini spacecraft has spied two new little moons around satellite-rich Saturn, the space agency said.\n",
"Test text:\n",
"Cassini Spies Two Little Saturn Moons (AP) AP - NASA's Cassini spacecraft has spied two new little moons around satellite-rich Saturn, the space agency said Monday.\n",
"--------------------------------------------------\n",
"Example 3:\n",
"Train text:\n",
"Intel to Delay Product for High-Definition TVs SAN FRANCISCO (Reuters) - In the latest of a series of product delays, Intel Corp. has postponed the launch of a video display chip it had previously planned to introduce by year end, putting off a showdown with Texas Instruments Inc. in the fast-growing market for high-definition television displays.\n",
"Test text:\n",
"Intel to delay product aimed for high-definition TVs SAN FRANCISCO -- In the latest of a series of product delays, Intel Corp. has postponed the launch of a video display chip it had previously planned to introduce by year end, putting off a showdown with Texas Instruments Inc. in the fast-growing market for high-definition television displays.\n",
"--------------------------------------------------\n",
"Example 4:\n",
"Train text:\n",
"Staples Profit Up Sharply, to Enter China NEW YORK (Reuters) - Staples Inc. &lt;A HREF=\"http://www.investor.reuters.com/FullQuote.aspx?ticker=SPLS.O target=/stocks/quickinfo/fullquote\"&gt;SPLS.O&lt;/A&gt;, the top U.S. office products retailer, on Tuesday reported a 39 percent jump in quarterly profit, raised its full-year forecast and said it plans to enter the fast-growing Chinese market.\n",
"Test text:\n",
"Staples Profit Up, to Enter China Market NEW YORK (Reuters) - Staples Inc. &lt;A HREF=\"http://www.investor.reuters.com/FullQuote.aspx?ticker=SPLS.O target=/stocks/quickinfo/fullquote\"&gt;SPLS.O&lt;/A&gt;, the top U.S. office products retailer, on Tuesday reported a 39 percent jump in quarterly profit, raised its full-year forecast and said it plans to enter the fast-growing Chinese market, sending its shares higher.\n",
"--------------------------------------------------\n",
"Example 5:\n",
"Train text:\n",
"Stocks Climb on Drop in Consumer Prices NEW YORK - Stocks rose for a second straight session Tuesday as a drop in consumer prices Tuesday allowed investors to put aside worries about inflation, at least for the short term. With gasoline prices falling to eight-month lows, the Consumer Price Index registered a small drop in July, giving consumers a respite from soaring energy prices...\n",
"Test text:\n",
"Stocks Climb on Drop in Consumer Prices NEW YORK - Stocks rose for a second straight session Tuesday as a drop in consumer prices allowed investors to put aside worries about inflation, at least for the short term. With gasoline prices falling to eight-month lows, the Consumer Price Index registered a small drop in July, giving consumers a respite from soaring energy prices...\n",
"--------------------------------------------------\n"
]
}
],
"source": [
"num_examples = 5\n",
"for i, test_idx in enumerate(duplicate_indices_in_test[:num_examples]):\n",
" train_idx = duplicate_to_original_mapping[test_idx]\n",
" print(f\"Example {i + 1}:\")\n",
" print(f\"Train text:\\n{texts_train[train_idx]}\")\n",
" print(f\"Test text:\\n{texts_test[test_idx]}\")\n",
" print(\"-\" * 50)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These again look like duplicates. We can very efficiently find train/test bleed examples using Model2Vec, ensuring that our test set is clean and does not contain any duplicates from the training set."
]
},
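{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "As an optional final step, we could use the indices we just found to drop the overlapping instances from the test set. Below is a minimal sketch of that clean-up, reusing the `duplicate_indices_in_test` and `texts_test` variables from the cells above."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Minimal sketch: drop the detected duplicates from the test set\n",
  "duplicate_set = set(duplicate_indices_in_test)\n",
  "clean_texts_test = [text for i, text in enumerate(texts_test) if i not in duplicate_set]\n",
  "\n",
  "print(f\"Test set size before: {len(texts_test)}\")\n",
  "print(f\"Test set size after removing train/test bleed: {len(clean_texts_test)}\")"
 ]
}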
],