Added langchain embedding.
lu-ohai committed Jan 14, 2025
1 parent b1d3e25 commit c3dd4f9
Showing 5 changed files with 365 additions and 0 deletions.
@@ -0,0 +1,126 @@
{
"cells": [
{
"cell_type": "raw",
"metadata": {
"vscode": {
"languageId": "raw"
}
},
"source": [
"---\n",
"keywords: [OCIModelDeploymentEndpointEmbeddings]\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# OCI Data Science Model Deployment Endpoint\n",
"\n",
"[OCI Data Science](https://docs.oracle.com/en-us/iaas/data-science/using/home.htm) is a fully managed and serverless platform for data science teams to build, train, and manage machine learning models in Oracle Cloud Infrastructure.\n",
"\n",
"This notebook goes over how to use an embedding model hosted on an [OCI Data Science Model Deployment](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-about.htm).\n",
"\n",
"To authenticate, [oracle-ads](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html) is used to automatically load the credentials needed to invoke the endpoint."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Installation\n",
"Install the `oracle-ads` SDK:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip3 install -U oracle-ads"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prerequisites\n",
"\n",
"### Deploy model\n",
"See the [Oracle GitHub samples repository](https://github.com/oracle-samples/oci-data-science-ai-samples/tree/main/model-deployment/containers/llama2) for how to deploy your embedding model on OCI Data Science Model Deployment.\n",
"\n",
"### Policies\n",
"Make sure you have the required [policies](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint) to access the OCI Data Science Model Deployment endpoint.\n",
"\n",
"## Set up\n",
"After deploying the model, set the **`endpoint`** parameter of the `OCIModelDeploymentEndpointEmbeddings` call to the model's HTTP endpoint, e.g. `\"https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<MD_OCID>/predict\"`.\n",
"\n",
"\n",
"### Authentication\n",
"\n",
"You can set authentication through either ads or environment variables. When working in an OCI Data Science Notebook Session, you can use a resource principal to access other OCI resources. See the [oracle-ads authentication guide](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html) for more options.\n",
"\n",
"## Example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ads\n",
"from langchain_community.embeddings import OCIModelDeploymentEndpointEmbeddings\n",
"\n",
"# Set authentication through ads\n",
"# Use resource principal when operating within an\n",
"# OCI service that has resource principal based\n",
"# authentication configured\n",
"ads.set_auth(\"resource_principal\")\n",
"\n",
"# Create an instance of OCI Model Deployment Endpoint\n",
"# Replace the endpoint uri with your own\n",
"embeddings = OCIModelDeploymentEndpointEmbeddings(\n",
" endpoint=\"https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<MD_OCID>/predict\",\n",
")\n",
"\n",
"query = \"Hello World!\"\n",
"embeddings.embed_query(query)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"documents = [\"This is a sample document\", \"and here is another one\"]\n",
"embeddings.embed_documents(documents)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "oci_langchain",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
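The notebook passes `endpoint` explicitly. The validator in this commit also calls `get_from_dict_or_env` with the `OCI_LLM_ENDPOINT` environment variable, so the endpoint can apparently be supplied through the environment instead. A minimal standalone sketch of that lookup order follows; `resolve_endpoint` is a hypothetical helper written for illustration, not part of `langchain_community`, and the example URL is a placeholder.

```python
import os
from typing import Mapping, Optional


def resolve_endpoint(
    values: Mapping[str, str],
    env: Optional[Mapping[str, str]] = None,
) -> str:
    """Return the explicit endpoint if set, else fall back to OCI_LLM_ENDPOINT.

    Hypothetical helper that mirrors the lookup order used by the validator.
    """
    env = os.environ if env is None else env
    explicit = values.get("endpoint")
    if explicit:
        return explicit
    from_env = env.get("OCI_LLM_ENDPOINT")
    if from_env:
        return from_env
    raise ValueError("Pass `endpoint=` or set the OCI_LLM_ENDPOINT variable.")


# An empty explicit value falls through to the environment mapping.
fake_env = {"OCI_LLM_ENDPOINT": "https://example.invalid/predict"}
resolved = resolve_endpoint({"endpoint": ""}, fake_env)
```

Passing the environment as a parameter keeps the sketch easy to test without touching the real process environment.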
5 changes: 5 additions & 0 deletions libs/community/langchain_community/embeddings/__init__.py
@@ -166,6 +166,9 @@
from langchain_community.embeddings.nlpcloud import (
NLPCloudEmbeddings,
)
from langchain_community.embeddings.oci_data_science_model_deployment_endpoint import (
OCIModelDeploymentEndpointEmbeddings,
)
from langchain_community.embeddings.oci_generative_ai import (
OCIGenAIEmbeddings,
)
@@ -300,6 +303,7 @@
"MosaicMLInstructorEmbeddings",
"NLPCloudEmbeddings",
"NeMoEmbeddings",
"OCIModelDeploymentEndpointEmbeddings",
"OCIGenAIEmbeddings",
"OctoAIEmbeddings",
"OllamaEmbeddings",
@@ -385,6 +389,7 @@
"MosaicMLInstructorEmbeddings": "langchain_community.embeddings.mosaicml",
"NLPCloudEmbeddings": "langchain_community.embeddings.nlpcloud",
"NeMoEmbeddings": "langchain_community.embeddings.nemo",
"OCIModelDeploymentEndpointEmbeddings": "langchain_community.embeddings.oci_data_science_model_deployment_endpoint",
"OCIGenAIEmbeddings": "langchain_community.embeddings.oci_generative_ai",
"OctoAIEmbeddings": "langchain_community.embeddings.octoai_embeddings",
"OllamaEmbeddings": "langchain_community.embeddings.ollama",
@@ -0,0 +1,202 @@
from typing import Any, Dict, List, Optional, Mapping, Callable
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator, Field
from langchain_core.utils import get_from_dict_or_env
from langchain_core.language_models.llms import create_base_retry_decorator
import requests


DEFAULT_HEADER = {
    "Content-Type": "application/json",
}


class TokenExpiredError(Exception):
    pass


def _create_retry_decorator(llm) -> Callable[[Any], Any]:
    """Creates a retry decorator."""
    errors = [requests.exceptions.ConnectTimeout, TokenExpiredError]
    decorator = create_base_retry_decorator(
        error_types=errors, max_retries=llm.max_retries
    )
    return decorator


class OCIModelDeploymentEndpointEmbeddings(BaseModel, Embeddings):
    """Embedding model deployed on OCI Data Science Model Deployment.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import OCIModelDeploymentEndpointEmbeddings

            embeddings = OCIModelDeploymentEndpointEmbeddings(
                endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<md_ocid>/predict",
            )
    """

    auth: dict = Field(default_factory=dict, exclude=True)
    """ADS auth dictionary for OCI authentication:
    https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html.
    This can be generated by calling `ads.common.auth.api_keys()`
    or `ads.common.auth.resource_principal()`. If this is not
    provided then `ads.common.auth.default_signer()` will be used."""

    endpoint: str = ""
    """The URI of the endpoint from the deployed Model Deployment model."""

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model."""

    endpoint_kwargs: Optional[Dict] = None
    """Optional attributes (except for headers) passed to the `requests.post`
    function.
    """

    max_retries: int = 1
    """The maximum number of retries to make when embedding."""

    @root_validator()
    def validate_environment(  # pylint: disable=no-self-argument
        cls, values: Dict
    ) -> Dict:
        """Validate that python package exists in environment."""
        try:
            import ads

        except ImportError as ex:
            raise ImportError(
                "Could not import ads python package. "
                "Please install it with `pip install oracle-ads`."
            ) from ex
        if not values.get("auth", None):
            values["auth"] = ads.common.auth.default_signer()
        values["endpoint"] = get_from_dict_or_env(
            values,
            "endpoint",
            "OCI_LLM_ENDPOINT",
        )
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"endpoint": self.endpoint},
            **{"model_kwargs": _model_kwargs},
        }

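    def _refresh_signer(self) -> bool:
        """Refresh the signer's security token if it supports refreshing.

        Added because `_embed_with_retry` calls this method on a 401 response.
        Assumption: refreshable signers expose `refresh_security_token()`.
        """
        signer = self.auth.get("signer")
        if signer is not None and hasattr(signer, "refresh_security_token"):
            signer.refresh_security_token()
            return True
        return False

    def _embed_with_retry(self, **kwargs) -> Any:
        """Use tenacity to retry the call."""
        retry_decorator = _create_retry_decorator(self)

        @retry_decorator
        def _completion_with_retry(**kwargs: Any) -> Any:
            try:
                response = requests.post(self.endpoint, **kwargs)
                response.raise_for_status()
                return response
            except requests.exceptions.HTTPError as http_err:
                if response.status_code == 401 and self._refresh_signer():
                    raise TokenExpiredError() from http_err
                else:
                    raise ValueError(
                        f"Server error: {str(http_err)}. Message: {response.text}"
                    ) from http_err
            except Exception as e:
                raise ValueError(f"Error raised by inference endpoint: {str(e)}") from e

        return _completion_with_retry(**kwargs)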

    def _embedding(self, texts: List[str]) -> List[List[float]]:
        """Call out to OCI Data Science Model Deployment Endpoint.

        Args:
            texts: A list of texts to embed.

        Returns:
            A list of lists of floats representing the embeddings.
        """
        _model_kwargs = self.model_kwargs or {}
        body = self._construct_request_body(texts, _model_kwargs)
        request_kwargs = self._construct_request_kwargs(body)
        response = self._embed_with_retry(**request_kwargs)
        return self._proceses_response(response)

    def _construct_request_kwargs(self, body: Any) -> dict:
        """Constructs the request kwargs as a dictionary."""
        from ads.model.common.utils import _is_json_serializable

        # Copy so that popping "headers" does not mutate `self.endpoint_kwargs`.
        _endpoint_kwargs = dict(self.endpoint_kwargs or {})
        headers = _endpoint_kwargs.pop("headers", DEFAULT_HEADER)
        return (
            dict(
                headers=headers,
                json=body,
                auth=self.auth.get("signer"),
                **_endpoint_kwargs,
            )
            if _is_json_serializable(body)
            else dict(
                headers=headers,
                data=body,
                auth=self.auth.get("signer"),
                **_endpoint_kwargs,
            )
        )

    def _construct_request_body(self, texts: List[str], params: dict) -> Any:
        """Constructs the request body."""
        return {"input": texts}

    def _proceses_response(self, response: requests.Response) -> List[List[float]]:
        """Extracts results from requests.Response."""
        try:
            res_json = response.json()
            embeddings = res_json["data"][0]["embedding"]
        except Exception as e:
            raise ValueError(
                f"Error raised by inference API: {e}.\nResponse: {response.text}"
            ) from e
        return embeddings

    def embed_documents(
        self,
        texts: List[str],
        chunk_size: Optional[int] = None,
    ) -> List[List[float]]:
        """Compute doc embeddings using OCI Data Science Model Deployment Endpoint.

        Args:
            texts: The list of texts to embed.
            chunk_size: The chunk size defines how many input texts will
                be grouped together as request. If None, all texts are
                sent in a single request.

        Returns:
            List of embeddings, one for each text.
        """
        # An empty input would give a zero step for range() below.
        if not texts:
            return []
        results = []
        _chunk_size = (
            len(texts) if (not chunk_size or chunk_size > len(texts)) else chunk_size
        )
        for i in range(0, len(texts), _chunk_size):
            response = self._embedding(texts[i : i + _chunk_size])
            results.extend(response)
        return results

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using OCI Data Science Model Deployment Endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self._embedding([text])[0]
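`embed_documents` splits `texts` into groups of `chunk_size` and issues one request per group. The batching can be sketched on its own, with the HTTP call swapped for a caller-supplied function. This is a minimal sketch under stated assumptions: `embed_in_chunks` and `fake_embed_batch` are hypothetical names used only for illustration, and the guard for an empty list is an addition of the sketch.

```python
from typing import Callable, List, Optional


def embed_in_chunks(
    texts: List[str],
    embed_batch: Callable[[List[str]], List[List[float]]],
    chunk_size: Optional[int] = None,
) -> List[List[float]]:
    """Call `embed_batch` once per group of at most `chunk_size` texts."""
    if not texts:
        return []
    # No chunk size, or one larger than the input, means a single request.
    size = len(texts) if (not chunk_size or chunk_size > len(texts)) else chunk_size
    results: List[List[float]] = []
    for start in range(0, len(texts), size):
        results.extend(embed_batch(texts[start : start + size]))
    return results


calls: List[List[str]] = []


def fake_embed_batch(batch: List[str]) -> List[List[float]]:
    """Stand-in for the endpoint: records the batch, returns 1-d vectors."""
    calls.append(batch)
    return [[float(len(text))] for text in batch]


vectors = embed_in_chunks(["a", "bb", "ccc"], fake_embed_batch, chunk_size=2)
```

With three texts and `chunk_size=2`, the stand-in is called twice, first with two texts and then with the remaining one.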
1 change: 1 addition & 0 deletions libs/community/tests/unit_tests/embeddings/test_imports.py
@@ -68,6 +68,7 @@
"VoyageEmbeddings",
"BookendEmbeddings",
"VolcanoEmbeddings",
"OCIModelDeploymentEndpointEmbeddings",
"OCIGenAIEmbeddings",
"QuantizedBiEncoderEmbeddings",
"NeMoEmbeddings",
@@ -0,0 +1,31 @@
"""Test OCI Data Science Model Deployment Endpoint."""

import responses
import pytest
from pytest_mock import MockerFixture
from langchain_community.embeddings import OCIModelDeploymentEndpointEmbeddings


@pytest.mark.requires("ads")
@responses.activate
def test_embedding_call(mocker: MockerFixture) -> None:
    """Test valid call to oci model deployment endpoint."""
    endpoint = "https://MD_OCID/predict"
    documents = ["Hello", "World"]
    expected_output = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
    responses.add(
        responses.POST,
        endpoint,
        # The payload must match what `_proceses_response` reads:
        # response.json()["data"][0]["embedding"].
        json={
            "data": [{"embedding": expected_output}],
        },
        status=200,
    )
    mocker.patch("ads.common.auth.default_signer", return_value=dict(signer=None))

    embeddings = OCIModelDeploymentEndpointEmbeddings(  # type: ignore[call-arg]
        endpoint=endpoint,
    )

    output = embeddings.embed_documents(documents)
    assert output == expected_output
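
A mocked payload only exercises the class if its shape matches what the parser reads. As written in this commit, the request body is `{"input": texts}` and the parser returns `response.json()["data"][0]["embedding"]`. The sketch below replays that round trip with plain dictionaries. `build_body` and `parse_embeddings` are hypothetical names that mirror `_construct_request_body` and `_proceses_response`, and the payload shape is inferred from the module code, not from OCI documentation.

```python
from typing import Any, Dict, List


def build_body(texts: List[str]) -> Dict[str, Any]:
    """Mirror of the request body built by the embeddings class."""
    return {"input": texts}


def parse_embeddings(payload: Dict[str, Any]) -> List[List[float]]:
    """Read the embeddings from the first entry of the "data" list."""
    try:
        return payload["data"][0]["embedding"]
    except (KeyError, IndexError, TypeError) as e:
        raise ValueError(f"Unexpected response shape: {payload!r}") from e


body = build_body(["Hello", "World"])
payload = {"data": [{"embedding": [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]}]}
vectors = parse_embeddings(payload)
```

A payload without a `"data"` key, such as `{"embeddings": [...]}`, raises `ValueError` in this sketch, which is the same failure mode the class would show for a mismatched mock.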
