Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for nvidia/llama-3.2-nv-embedqa-1b-v2's dimensions param #126

Merged
merged 2 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""Embeddings Components Derived from NVEModel/Embeddings"""

from typing import Any, List, Literal, Optional
from typing import Any, Dict, List, Literal, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.outputs.llm_result import LLMResult
Expand Down Expand Up @@ -28,6 +26,8 @@ class NVIDIAEmbeddings(BaseModel, Embeddings):
- truncate: "NONE", "START", "END", truncate input text if it exceeds the model's
maximum token length. Default is "NONE", which raises an error if an input is
too long.
- dimensions: int, the number of dimensions for the embeddings. This parameter is
not supported by all models.
"""

model_config = ConfigDict(
Expand All @@ -47,6 +47,13 @@ class NVIDIAEmbeddings(BaseModel, Embeddings):
"Default is 'NONE', which raises an error if an input is too long."
),
)
dimensions: Optional[int] = Field(
default=None,
description=(
"The number of dimensions for the embeddings. This parameter is not "
"supported by all models."
),
)
max_batch_size: int = Field(default=_DEFAULT_BATCH_SIZE)

def __init__(self, **kwargs: Any):
Expand All @@ -67,6 +74,8 @@ def __init__(self, **kwargs: Any):
truncate (str): "NONE", "START", "END", truncate input text if it exceeds
the model's context length. Default is "NONE", which raises
an error if an input is too long.
dimensions (int): The number of dimensions for the embeddings. This
parameter is not supported by all models.

API Key:
- The recommended way to provide the API key is through the `NVIDIA_API_KEY`
Expand Down Expand Up @@ -125,14 +134,17 @@ def _embed(
# user: str -- ignored
# truncate: "NONE" | "START" | "END" -- default "NONE", error raised if
# an input is too long
payload = {
# dimensions: int -- not supported by all models
payload: Dict[str, Any] = {
"input": texts,
"model": self.model,
"encoding_format": "float",
"input_type": model_type,
}
if self.truncate:
payload["truncate"] = self.truncate
if self.dimensions:
payload["dimensions"] = self.dimensions

response = self._client.get_req(
payload=payload,
Expand Down
76 changes: 76 additions & 0 deletions libs/ai-endpoints/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,5 +97,81 @@ def test_embed_documents_truncate(
assert len(output) == count


@pytest.mark.parametrize("dimensions", [32, 64, 128, 2048])
def test_embed_query_with_dimensions(
    embedding_model: str, mode: dict, dimensions: int
) -> None:
    """A supported `dimensions` value yields a query embedding of exactly that size."""
    if embedding_model != "nvidia/llama-3.2-nv-embedqa-1b-v2":
        pytest.skip("Model does not support custom dimensions.")
    embedder = NVIDIAEmbeddings(model=embedding_model, dimensions=dimensions, **mode)
    result = embedder.embed_query("foo bar")
    assert len(result) == dimensions


@pytest.mark.parametrize("dimensions", [32, 64, 128, 2048])
def test_embed_documents_with_dimensions(
    embedding_model: str, mode: dict, dimensions: int
) -> None:
    """Every document embedding has exactly `dimensions` components."""
    if embedding_model != "nvidia/llama-3.2-nv-embedqa-1b-v2":
        pytest.skip("Model does not support custom dimensions.")
    docs = ["foo bar", "bar foo"]
    embedder = NVIDIAEmbeddings(model=embedding_model, dimensions=dimensions, **mode)
    vectors = embedder.embed_documents(docs)
    # One vector per input document, each at the requested width.
    assert len(vectors) == len(docs)
    for vector in vectors:
        assert len(vector) == dimensions


@pytest.mark.parametrize("dimensions", [102400])
def test_embed_query_with_large_dimensions(
    embedding_model: str, mode: dict, dimensions: int
) -> None:
    """An oversized `dimensions` request is not honored verbatim.

    The returned vector is expected to be at least 2048 wide (presumably the
    model's native size) but strictly smaller than the requested value.
    """
    if embedding_model != "nvidia/llama-3.2-nv-embedqa-1b-v2":
        pytest.skip("Model does not support custom dimensions.")
    embedder = NVIDIAEmbeddings(model=embedding_model, dimensions=dimensions, **mode)
    width = len(embedder.embed_query("foo bar"))
    assert 2048 <= width < dimensions


@pytest.mark.parametrize("dimensions", [102400])
def test_embed_documents_with_large_dimensions(
    embedding_model: str, mode: dict, dimensions: int
) -> None:
    """Oversized `dimensions` requests on documents are not honored verbatim.

    Each returned vector is expected to be at least 2048 wide (presumably the
    model's native size) but strictly smaller than the requested value.
    """
    if embedding_model != "nvidia/llama-3.2-nv-embedqa-1b-v2":
        pytest.skip("Model does not support custom dimensions.")
    docs = ["foo bar", "bar foo"]
    embedder = NVIDIAEmbeddings(model=embedding_model, dimensions=dimensions, **mode)
    vectors = embedder.embed_documents(docs)
    assert len(vectors) == len(docs)
    for vector in vectors:
        assert 2048 <= len(vector) < dimensions


@pytest.mark.parametrize("dimensions", [-1])
def test_embed_query_invalid_dimensions(
    embedding_model: str, mode: dict, dimensions: int
) -> None:
    """A negative `dimensions` value is rejected with an error mentioning 400."""
    if embedding_model != "nvidia/llama-3.2-nv-embedqa-1b-v2":
        pytest.skip("Model does not support custom dimensions.")
    # Construction stays inside the raises-block: the failure may surface
    # either at client creation or on the embed call itself.
    with pytest.raises(Exception) as excinfo:
        client = NVIDIAEmbeddings(
            model=embedding_model, dimensions=dimensions, **mode
        )
        client.embed_query("foo bar")
    assert "400" in str(excinfo.value)


@pytest.mark.parametrize("dimensions", [-1])
def test_embed_documents_invalid_dimensions(
    embedding_model: str, mode: dict, dimensions: int
) -> None:
    """Negative `dimensions` on document embedding is rejected with an error mentioning 400."""
    if embedding_model != "nvidia/llama-3.2-nv-embedqa-1b-v2":
        pytest.skip("Model does not support custom dimensions.")
    # Construction stays inside the raises-block: the failure may surface
    # either at client creation or on the embed call itself.
    with pytest.raises(Exception) as excinfo:
        client = NVIDIAEmbeddings(
            model=embedding_model, dimensions=dimensions, **mode
        )
        client.embed_documents(["foo bar", "bar foo"])
    assert "400" in str(excinfo.value)


# todo: test max_length > max length accepted by the model
# todo: test max_batch_size > max batch size accepted by the model
Loading