forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
365 additions
and
0 deletions.
There are no files selected for viewing
126 changes: 126 additions & 0 deletions
126
docs/docs/integrations/text_embedding/oci_model_deployment_endpoint.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "raw", | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "raw" | ||
} | ||
}, | ||
"source": [ | ||
"---\n", | ||
"keywords: [OCIModelDeploymentEndpointEmbeddings]\n", | ||
"---" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# OCI Data Science Model Deployment Endpoint\n", | ||
"\n", | ||
"[OCI Data Science](https://docs.oracle.com/en-us/iaas/data-science/using/home.htm) is a fully managed and serverless platform for data science teams to build, train, and manage machine learning models in the Oracle Cloud Infrastructure.\n", | ||
"\n", | ||
"This notebooks goes over how to use an embedding model hosted on a [OCI Data Science Model Deployment](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-about.htm).\n", | ||
"\n", | ||
"To authenticate, [oracle-ads](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html) has been used to automatically load credentials for invoking endpoint." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Prerequisite\n", | ||
"We will need to install the `oracle-ads` sdk" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip3 install -U oracle-ads" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Prerequisite\n", | ||
"\n", | ||
"### Deploy model\n", | ||
"Check [Oracle GitHub samples repository](https://github.com/oracle-samples/oci-data-science-ai-samples/tree/main/model-deployment/containers/llama2) on how to deploy your embedding model on OCI Data Science Model deployment.\n", | ||
"\n", | ||
"### Policies\n", | ||
"Make sure to have the required [policies](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint) to access the OCI Data Science Model Deployment endpoint.\n", | ||
"\n", | ||
"## Set up\n", | ||
"After having deployed model, you have to set up **`endpoint`**: The model HTTP endpoint from the deployed model, e.g. `\"https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<MD_OCID>/predict\"` of the `OCIModelDeploymentEndpointEmbeddings` call.\n", | ||
"\n", | ||
"\n", | ||
"### Authentication\n", | ||
"\n", | ||
"You can set authentication through either ads or environment variables. When you are working in OCI Data Science Notebook Session, you can leverage resource principal to access other OCI resources. Check out [here](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html) to see more options. \n", | ||
"\n", | ||
"## Example" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import ads\n", | ||
"from langchain_community.embeddings import OCIModelDeploymentEndpointEmbeddings\n", | ||
"\n", | ||
"# Set authentication through ads\n", | ||
"# Use resource principal are operating within a\n", | ||
"# OCI service that has resource principal based\n", | ||
"# authentication configured\n", | ||
"ads.set_auth(\"resource_principal\")\n", | ||
"\n", | ||
"# Create an instance of OCI Model Deployment Endpoint\n", | ||
"# Replace the endpoint uri with your own\n", | ||
"embeddings = OCIModelDeploymentEndpointEmbeddings(\n", | ||
" endpoint=\"https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<MD_OCID>/predict\",\n", | ||
")\n", | ||
"\n", | ||
"query = \"Hello World!\"\n", | ||
"embeddings.embed_query(query)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"documents = [\"This is a sample document\", \"and here is another one\"]\n", | ||
"embeddings.embed_documents(documents)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "oci_langchain", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.9" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
202 changes: 202 additions & 0 deletions
202
libs/community/langchain_community/embeddings/oci_data_science_model_deployment_endpoint.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
from typing import Any, Dict, List, Optional, Mapping, Callable | ||
from langchain_core.embeddings import Embeddings | ||
from langchain_core.pydantic_v1 import BaseModel, root_validator, Field | ||
from langchain_core.utils import get_from_dict_or_env | ||
from langchain_core.language_models.llms import create_base_retry_decorator | ||
import requests | ||
|
||
|
||
DEFAULT_HEADER = { | ||
"Content-Type": "application/json", | ||
} | ||
|
||
|
||
class TokenExpiredError(Exception): | ||
pass | ||
|
||
|
||
def _create_retry_decorator(llm) -> Callable[[Any], Any]: | ||
"""Creates a retry decorator.""" | ||
errors = [requests.exceptions.ConnectTimeout, TokenExpiredError] | ||
decorator = create_base_retry_decorator( | ||
error_types=errors, max_retries=llm.max_retries | ||
) | ||
return decorator | ||
|
||
|
||
class OCIModelDeploymentEndpointEmbeddings(BaseModel, Embeddings): | ||
"""Embedding model deployed on OCI Data Science Model Deployment. | ||
Example: | ||
.. code-block:: python | ||
from langchain_community.embeddings import OCIModelDeploymentEndpointEmbeddings | ||
embeddings = OCIModelDeploymentEndpointEmbeddings( | ||
endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<md_ocid>/predict", | ||
) | ||
""" | ||
|
||
auth: dict = Field(default_factory=dict, exclude=True) | ||
"""ADS auth dictionary for OCI authentication: | ||
https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html. | ||
This can be generated by calling `ads.common.auth.api_keys()` | ||
or `ads.common.auth.resource_principal()`. If this is not | ||
provided then the `ads.common.default_signer()` will be used.""" | ||
|
||
endpoint: str = "" | ||
"""The uri of the endpoint from the deployed Model Deployment model.""" | ||
|
||
model_kwargs: Optional[Dict] = None | ||
"""Keyword arguments to pass to the model.""" | ||
|
||
endpoint_kwargs: Optional[Dict] = None | ||
"""Optional attributes (except for headers) passed to the request.post | ||
function. | ||
""" | ||
|
||
max_retries: int = 1 | ||
"""The maximum number of retries to make when generating.""" | ||
|
||
@root_validator() | ||
def validate_environment( # pylint: disable=no-self-argument | ||
cls, values: Dict | ||
) -> Dict: | ||
"""Validate that python package exists in environment.""" | ||
try: | ||
import ads | ||
|
||
except ImportError as ex: | ||
raise ImportError( | ||
"Could not import ads python package. " | ||
"Please install it with `pip install oracle_ads`." | ||
) from ex | ||
if not values.get("auth", None): | ||
values["auth"] = ads.common.auth.default_signer() | ||
values["endpoint"] = get_from_dict_or_env( | ||
values, | ||
"endpoint", | ||
"OCI_LLM_ENDPOINT", | ||
) | ||
return values | ||
|
||
@property | ||
def _identifying_params(self) -> Mapping[str, Any]: | ||
"""Get the identifying parameters.""" | ||
_model_kwargs = self.model_kwargs or {} | ||
return { | ||
**{"endpoint": self.endpoint}, | ||
**{"model_kwargs": _model_kwargs}, | ||
} | ||
|
||
def _embed_with_retry(self, **kwargs) -> Any: | ||
"""Use tenacity to retry the call.""" | ||
retry_decorator = _create_retry_decorator(self) | ||
|
||
@retry_decorator | ||
def _completion_with_retry(**kwargs: Any) -> Any: | ||
try: | ||
response = requests.post(self.endpoint, **kwargs) | ||
response.raise_for_status() | ||
return response | ||
except requests.exceptions.HTTPError as http_err: | ||
if response.status_code == 401 and self._refresh_signer(): | ||
raise TokenExpiredError() from http_err | ||
else: | ||
raise ValueError( | ||
f"Server error: {str(http_err)}. Message: {response.text}" | ||
) from http_err | ||
except Exception as e: | ||
raise ValueError(f"Error occurs by inference endpoint: {str(e)}") from e | ||
|
||
return _completion_with_retry(**kwargs) | ||
|
||
def _embedding(self, texts: List[str]) -> List[List[float]]: | ||
"""Call out to OCI Data Science Model Deployment Endpoint. | ||
Args: | ||
texts: A list of texts to embed. | ||
Returns: | ||
A list of list of floats representing the embeddings, or None if an | ||
error occurs. | ||
""" | ||
_model_kwargs = self.model_kwargs or {} | ||
body = self._construct_request_body(texts, _model_kwargs) | ||
request_kwargs = self._construct_request_kwargs(body) | ||
response = self._embed_with_retry(**request_kwargs) | ||
return self._proceses_response(response) | ||
|
||
def _construct_request_kwargs(self, body: Any) -> dict: | ||
"""Constructs the request kwargs as a dictionary.""" | ||
from ads.model.common.utils import _is_json_serializable | ||
|
||
_endpoint_kwargs = self.endpoint_kwargs or {} | ||
headers = _endpoint_kwargs.pop("headers", DEFAULT_HEADER) | ||
return ( | ||
dict( | ||
headers=headers, | ||
json=body, | ||
auth=self.auth.get("signer"), | ||
**_endpoint_kwargs, | ||
) | ||
if _is_json_serializable(body) | ||
else dict( | ||
headers=headers, | ||
data=body, | ||
auth=self.auth.get("signer"), | ||
**_endpoint_kwargs, | ||
) | ||
) | ||
|
||
def _construct_request_body(self, texts: List[str], params: dict) -> Any: | ||
"""Constructs the request body.""" | ||
return {"input": texts} | ||
|
||
def _proceses_response(self, response: requests.Response) -> List[List[float]]: | ||
"""Extracts results from requests.Response.""" | ||
try: | ||
res_json = response.json() | ||
embeddings = res_json["data"][0]["embedding"] | ||
except Exception as e: | ||
raise ValueError( | ||
f"Error raised by inference API: {e}.\nResponse: {response.text}" | ||
) | ||
return embeddings | ||
|
||
def embed_documents( | ||
self, | ||
texts: List[str], | ||
chunk_size: Optional[int] = None, | ||
) -> List[List[float]]: | ||
"""Compute doc embeddings using OCI Data Science Model Deployment Endpoint. | ||
Args: | ||
texts: The list of texts to embed. | ||
chunk_size: The chunk size defines how many input texts will | ||
be grouped together as request. If None, will use the | ||
chunk size specified by the class. | ||
Returns: | ||
List of embeddings, one for each text. | ||
""" | ||
results = [] | ||
_chunk_size = ( | ||
len(texts) if (not chunk_size or chunk_size > len(texts)) else chunk_size | ||
) | ||
for i in range(0, len(texts), _chunk_size): | ||
response = self._embedding(texts[i : i + _chunk_size]) | ||
results.extend(response) | ||
return results | ||
|
||
def embed_query(self, text: str) -> List[float]: | ||
"""Compute query embeddings using OCI Data Science Model Deployment Endpoint. | ||
Args: | ||
text: The text to embed. | ||
Returns: | ||
Embeddings for the text. | ||
""" | ||
return self._embedding([text])[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
31 changes: 31 additions & 0 deletions
31
libs/community/tests/unit_tests/embeddings/test_oci_model_deployment_endpoint.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
"""Test OCI Data Science Model Deployment Endpoint.""" | ||
|
||
import responses | ||
import pytest | ||
from pytest_mock import MockerFixture | ||
from langchain_community.embeddings import OCIModelDeploymentEndpointEmbeddings | ||
|
||
|
||
@pytest.mark.requires("ads") | ||
@responses.activate | ||
def test_embedding_call(mocker: MockerFixture) -> None: | ||
"""Test valid call to oci model deployment endpoint.""" | ||
endpoint = "https://MD_OCID/predict" | ||
documents = ["Hello", "World"] | ||
expected_output = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]] | ||
responses.add( | ||
responses.POST, | ||
endpoint, | ||
json={ | ||
"embeddings": expected_output, | ||
}, | ||
status=200, | ||
) | ||
mocker.patch("ads.common.auth.default_signer", return_value=dict(signer=None)) | ||
|
||
embeddings = OCIModelDeploymentEndpointEmbeddings( # type: ignore[call-arg] | ||
endpoint=endpoint, | ||
) | ||
|
||
output = embeddings.embed_documents(documents) | ||
assert output == expected_output |