Skip to content

Commit

Permalink
Fix progress bar
Browse files — Browse the repository at this point in the history
Branch information: (loading)
stephantul committed Oct 19, 2024
1 parent ac1afd5 commit b90e168
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
12 changes: 6 additions & 6 deletions model2vec/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def encode_as_sequence(
sentences: list[str] | str,
max_length: int | None = None,
batch_size: int = 1024,
show_progressbar: bool = False,
show_progress_bar: bool = False,
) -> list[np.ndarray] | np.ndarray:
"""
Encode a list of sentences as a list of numpy arrays of tokens.
Expand All @@ -177,7 +177,7 @@ def encode_as_sequence(
:param max_length: The maximum length of the sentences. Any tokens beyond this length will be truncated.
If this is None, no truncation is done.
:param batch_size: The batch size to use.
:param show_progressbar: Whether to show the progress bar.
:param show_progress_bar: Whether to show the progress bar.
:return: The encoded sentences with an embedding per token.
"""
was_single = False
Expand All @@ -189,7 +189,7 @@ def encode_as_sequence(
for batch in tqdm(
self._batch(sentences, batch_size),
total=math.ceil(len(sentences) / batch_size),
disable=not show_progressbar,
disable=not show_progress_bar,
):
out_array.extend(self._encode_batch_as_sequence(batch, max_length))

Expand All @@ -213,7 +213,7 @@ def _encode_batch_as_sequence(self, sentences: list[str], max_length: int | None
def encode(
self,
sentences: list[str] | str,
show_progressbar: bool = False,
show_progress_bar: bool = False,
max_length: int | None = 512,
batch_size: int = 1024,
**kwargs: Any,
Expand All @@ -225,7 +225,7 @@ def encode(
For ease of use, we don't batch sentences together.
:param sentences: The list of sentences to encode. You can also pass a single sentence.
:param show_progressbar: Whether to show the progress bar.
:param show_progress_bar: Whether to show the progress bar.
:param max_length: The maximum length of the sentences. Any tokens beyond this length will be truncated.
If this is None, no truncation is done.
:param batch_size: The batch size to use.
Expand All @@ -241,7 +241,7 @@ def encode(
for batch in tqdm(
self._batch(sentences, batch_size),
total=math.ceil(len(sentences) / batch_size),
disable=not show_progressbar,
disable=not show_progress_bar,
):
out_arrays.append(self._encode_batch(batch, max_length))

Expand Down
6 changes: 3 additions & 3 deletions tutorials/semantic_deduplication.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@
],
"source": [
"# Encode texts into embeddings\n",
"embedding_matrix = model.encode(texts, show_progressbar=True)"
"embedding_matrix = model.encode(texts)"
]
},
{
Expand Down Expand Up @@ -407,8 +407,8 @@
"texts_test = ds_test['text']\n",
"\n",
"# Encode texts into embeddings\n",
"embedding_matrix_train = model.encode(texts_train, show_progressbar=True)\n",
"embedding_matrix_test = model.encode(texts_test, show_progressbar=True)\n",
"embedding_matrix_train = model.encode(texts_train)\n",
"embedding_matrix_test = model.encode(texts_test)\n",
"\n",
"def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024) -> tuple[list[int], dict[int, int]]:\n",
" \"\"\"\n",
Expand Down

0 comments on commit b90e168

Please sign in to comment.