From 8a92f182fa6d36ea88bf81b689eddb2911614853 Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Mon, 25 Dec 2023 10:29:25 +0800 Subject: [PATCH 1/2] support UAE sentence embeddings --- .../embeddings/__init__.py | 2 + .../embeddings/huggingface.py | 86 +++++++++++++++++++ .../unit_tests/embeddings/test_imports.py | 1 + 3 files changed, 89 insertions(+) diff --git a/libs/community/langchain_community/embeddings/__init__.py b/libs/community/langchain_community/embeddings/__init__.py index ce9cfc7aa0b76..4b30d3378ec77 100644 --- a/libs/community/langchain_community/embeddings/__init__.py +++ b/libs/community/langchain_community/embeddings/__init__.py @@ -47,6 +47,7 @@ HuggingFaceEmbeddings, HuggingFaceInferenceAPIEmbeddings, HuggingFaceInstructEmbeddings, + HuggingFaceUaeEmbeddings ) from langchain_community.embeddings.huggingface_hub import HuggingFaceHubEmbeddings from langchain_community.embeddings.infinity import InfinityEmbeddings @@ -136,6 +137,7 @@ "JohnSnowLabsEmbeddings", "VoyageEmbeddings", "BookendEmbeddings", + "HuggingFaceUaeEmbeddings", ] diff --git a/libs/community/langchain_community/embeddings/huggingface.py b/libs/community/langchain_community/embeddings/huggingface.py index 84a568866f178..10865b9996599 100644 --- a/libs/community/langchain_community/embeddings/huggingface.py +++ b/libs/community/langchain_community/embeddings/huggingface.py @@ -7,6 +7,7 @@ DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2" DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large" DEFAULT_BGE_MODEL = "BAAI/bge-large-en" +DEFAULT_UAE_MODEL = "WhereIsAI/UAE-Large-V1" DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: " DEFAULT_QUERY_INSTRUCTION = ( "Represent the question for retrieving supporting documents: " @@ -341,3 +342,88 @@ def embed_query(self, text: str) -> List[float]: Embeddings for the text. """ return self.embed_documents([text])[0] + + +class HuggingFaceUaeEmbeddings(BaseModel, Embeddings): + """HuggingFace UAE sentence embedding models. + Arxiv: https://arxiv.org/abs/2309.12871 + + To use, you should have the ``angle_emb`` python package installed. + + Example: + .. code-block:: python + + from langchain_community.embeddings import HuggingFaceUaeEmbeddings + + model_name = "WhereIsAI/UAE-Large-V1" + model_kwargs = { + 'device': 'cpu', + 'pooling_strategy': 'cls', + } + encode_kwargs = {'to_numpy': True} + prompt = None + hf = HuggingFaceUaeEmbeddings( + model_name=model_name, + model_kwargs=model_kwargs, + encode_kwargs=encode_kwargs, + prompt=prompt + ) + """ + + client: Any #: :meta private: + model_name: str = DEFAULT_UAE_MODEL + """Model name to use.""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Keyword arguments to pass to the model.""" + encode_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Keyword arguments to pass when calling the `encode` method of the model.""" + prompt: Optional[str] = None + """prompt argument""" + + def __init__(self, **kwargs: Any): + """Initialize the angle_emb.""" + super().__init__(**kwargs) + try: + import angle_emb + + except ImportError as exc: + raise ImportError( + "Could not import angle_emb python package. " + "Please install it with `pip install angle_emb`." + ) from exc + + self.client = angle_emb.AnglE( + self.model_name, **self.model_kwargs + ) + self.client.set_prompt(prompt=self.prompt) + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Compute doc embeddings using a HuggingFace transformer model. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + texts = [t.replace("\n", " ") for t in texts] + if isinstance(self.prompt, str): + texts = [{'text': text} for text in texts] + embeddings = self.client.encode(texts, **self.encode_kwargs) + return embeddings.tolist() + + def embed_query(self, text: str) -> List[float]: + """Compute query embeddings using a HuggingFace transformer model. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + return self.embed_documents([text])[0] diff --git a/libs/community/tests/unit_tests/embeddings/test_imports.py b/libs/community/tests/unit_tests/embeddings/test_imports.py index d33d98e493b26..75e7b0c90757d 100644 --- a/libs/community/tests/unit_tests/embeddings/test_imports.py +++ b/libs/community/tests/unit_tests/embeddings/test_imports.py @@ -53,6 +53,7 @@ "JohnSnowLabsEmbeddings", "VoyageEmbeddings", "BookendEmbeddings", + "HuggingFaceUaeEmbeddings", ] From 490b371683ccbe7550879e792a3299f498ec7e2c Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Wed, 31 Jan 2024 09:09:41 +0800 Subject: [PATCH 2/2] lint --- libs/community/langchain_community/embeddings/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/community/langchain_community/embeddings/__init__.py b/libs/community/langchain_community/embeddings/__init__.py index 85669045211a8..31447d2db7d97 100644 --- a/libs/community/langchain_community/embeddings/__init__.py +++ b/libs/community/langchain_community/embeddings/__init__.py @@ -11,7 +11,7 @@ """ -import logging +import logging # NOQA from typing import Any from langchain_community.embeddings.aleph_alpha import (