Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: support UAE sentence embeddings #15134

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion libs/community/langchain_community/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"""


import logging
import logging # NOQA
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete comment

from typing import Any

from langchain_community.embeddings.aleph_alpha import (
Expand Down Expand Up @@ -47,6 +47,7 @@
HuggingFaceEmbeddings,
HuggingFaceInferenceAPIEmbeddings,
HuggingFaceInstructEmbeddings,
HuggingFaceUaeEmbeddings
)
from langchain_community.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
from langchain_community.embeddings.infinity import InfinityEmbeddings
Expand Down Expand Up @@ -139,6 +140,7 @@
"JohnSnowLabsEmbeddings",
"VoyageEmbeddings",
"BookendEmbeddings",
"HuggingFaceUaeEmbeddings",
"VolcanoEmbeddings",
]

Expand Down
86 changes: 86 additions & 0 deletions libs/community/langchain_community/embeddings/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
DEFAULT_BGE_MODEL = "BAAI/bge-large-en"
DEFAULT_UAE_MODEL = "WhereIsAI/UAE-Large-V1"
DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: "
DEFAULT_QUERY_INSTRUCTION = (
"Represent the question for retrieving supporting documents: "
Expand Down Expand Up @@ -341,3 +342,88 @@ def embed_query(self, text: str) -> List[float]:
Embeddings for the text.
"""
return self.embed_documents([text])[0]


class HuggingFaceUaeEmbeddings(BaseModel, Embeddings):
"""HuggingFace UAE sentence embedding models.
Arxiv: https://arxiv.org/abs/2309.12871

To use, you should have the ``angle_emb`` python package installed.

Example:
.. code-block:: python

from langchain_community.embeddings import HuggingFaceUaeEmbeddings

model_name = "WhereIsAI/UAE-Large-V1"
model_kwargs = {
'device': 'cpu',
'pooling_strategy': 'cls',
}
encode_kwargs = {'to_numpy': True}
prompt = None
hf = HuggingFaceUaeEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs,
prompt=prompt
)
"""

client: Any #: :meta private:
model_name: str = DEFAULT_UAE_MODEL
"""Model name to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass to the model."""
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass when calling the `encode` method of the model."""
prompt: Optional[str] = None
"""prompt argument"""

def __init__(self, **kwargs: Any):
"""Initialize the angle_emb."""
super().__init__(**kwargs)
try:
import angle_emb

except ImportError as exc:
raise ImportError(
"Could not import angle_emb python package. "
"Please install it with `pip install angle_emb`."
) from exc

self.client = angle_emb.AnglE(
self.model_name, **self.model_kwargs
)
self.client.set_prompt(prompt=self.prompt)

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.

Args:
texts: The list of texts to embed.

Returns:
List of embeddings, one for each text.
"""
texts = [t.replace("\n", " ") for t in texts]
if isinstance(self.prompt, str):
texts = [{'text': text} for text in texts]
embeddings = self.client.encode(texts, **self.encode_kwargs)
return embeddings.tolist()

def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a HuggingFace transformer model.

Args:
text: The text to embed.

Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]
1 change: 1 addition & 0 deletions libs/community/tests/unit_tests/embeddings/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"JohnSnowLabsEmbeddings",
"VoyageEmbeddings",
"BookendEmbeddings",
"HuggingFaceUaeEmbeddings",
"VolcanoEmbeddings",
]

Expand Down
Loading