From bd0d379b5d801961173edd757bdfaf711dbe751d Mon Sep 17 00:00:00 2001 From: Stefan Langenbach Date: Tue, 12 Dec 2023 10:25:53 +0000 Subject: [PATCH 1/2] Fix AOSS detection --- .../vectorstores/opensearch_vector_search.py | 18 ++++--- .../test_opensearch_vector_search.py | 52 +++++++++++++++++++ libs/langchain/poetry.lock | 46 +++++++++++++++- libs/langchain/pyproject.toml | 4 ++ 4 files changed, 110 insertions(+), 10 deletions(-) create mode 100644 libs/community/tests/unit_tests/vectorstores/test_opensearch_vector_search.py diff --git a/libs/community/langchain_community/vectorstores/opensearch_vector_search.py b/libs/community/langchain_community/vectorstores/opensearch_vector_search.py index 99ece00e7a69c..685bf33d328f3 100644 --- a/libs/community/langchain_community/vectorstores/opensearch_vector_search.py +++ b/libs/community/langchain_community/vectorstores/opensearch_vector_search.py @@ -81,14 +81,16 @@ def _validate_aoss_with_engines(is_aoss: bool, engine: str) -> None: def _is_aoss_enabled(http_auth: Any) -> bool: - """Check if the service is http_auth is set as `aoss`.""" - if ( - http_auth is not None - and hasattr(http_auth, "service") - and http_auth.service == "aoss" - ): - return True - return False + """Check if service attribute of http_auth is set to `aoss`.""" + if http_auth is not None: + if hasattr(http_auth, "service") and http_auth.service == "aoss": + return True + elif hasattr(http_auth, "signer") and http_auth.signer.service == "aoss": + return True + else: + return False + else: + return False def _bulk_ingest_embeddings( diff --git a/libs/community/tests/unit_tests/vectorstores/test_opensearch_vector_search.py b/libs/community/tests/unit_tests/vectorstores/test_opensearch_vector_search.py new file mode 100644 index 0000000000000..a7fd55a7e0bb1 --- /dev/null +++ b/libs/community/tests/unit_tests/vectorstores/test_opensearch_vector_search.py @@ -0,0 +1,52 @@ +import pytest +from opensearchpy import RequestsAWSV4SignerAuth +from pytest_mock import MockerFixture +from requests_aws4auth import AWS4Auth + +from langchain_community.vectorstores.opensearch_vector_search import OpenSearchVectorSearch +from langchain_community.embeddings import FakeEmbeddings + + +@pytest.mark.requires("opensearchpy") +@pytest.mark.parametrize(["service", "expected"], [("aoss", True), ("es", False)]) +def test_detect_aoss_using_signer_auth( + mocker: MockerFixture, service: str, expected: bool +) -> None: + mocker.patch.object(RequestsAWSV4SignerAuth, "_sign_request") + http_auth = RequestsAWSV4SignerAuth( + credentials="credentials", region="eu-central-1", service=service + ) + database = OpenSearchVectorSearch( + opensearch_url="http://localhost:9200", + index_name="test", + embedding_function=FakeEmbeddings(size=42), + http_auth=http_auth, + ) + + assert database.is_aoss == expected + + +@pytest.mark.requires("opensearchpy") +@pytest.mark.requires("requests_aws4auth") +@pytest.mark.parametrize(["service", "expected"], [("aoss", True), ("es", False)]) +def test_detect_aoss_using_aws4auth(service: str, expected: bool) -> None: + http_auth = AWS4Auth("access_key_id", "secret_access_key", "eu-central-1", service) + database = OpenSearchVectorSearch( + opensearch_url="http://localhost:9200", + index_name="test", + embedding_function=FakeEmbeddings(size=42), + http_auth=http_auth, + ) + + assert database.is_aoss == expected + + +@pytest.mark.requires("opensearchpy") +def test_detect_aoss_using_no_auth() -> None: + database = OpenSearchVectorSearch( + opensearch_url="http://localhost:9200", + index_name="test", + embedding_function=FakeEmbeddings(size=42), + ) + + assert database.is_aoss is False diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock index 6387498bcd9f2..98f4e6d588124 100644 --- a/libs/langchain/poetry.lock +++ b/libs/langchain/poetry.lock @@ -4696,6 +4696,30 @@ files = [ [package.dependencies] requests = ">=2,<3" +[[package]] +name = "opensearch-py" +version = "2.4.2" +description = "Python client for OpenSearch" +optional = true +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, <4" +files = [ + {file = "opensearch-py-2.4.2.tar.gz", hash = "sha256:564f175af134aa885f4ced6846eb4532e08b414fff0a7976f76b276fe0e69158"}, + {file = "opensearch_py-2.4.2-py2.py3-none-any.whl", hash = "sha256:7867319132133e2974c09f76a54eb1d502b989229be52da583d93ddc743ea111"}, +] + +[package.dependencies] +certifi = ">=2022.12.07" +python-dateutil = "*" +requests = ">=2.4.0,<3.0.0" +six = "*" +urllib3 = ">=1.26.18" + +[package.extras] +async = ["aiohttp (>=3,<4)"] +develop = ["black", "botocore", "coverage (<8.0.0)", "jinja2", "mock", "myst-parser", "pytest (>=3.0.0)", "pytest-cov", "pytest-mock (<4.0.0)", "pytz", "pyyaml", "requests (>=2.0.0,<3.0.0)", "sphinx", "sphinx-copybutton", "sphinx-rtd-theme"] +docs = ["aiohttp (>=3,<4)", "myst-parser", "sphinx", "sphinx-copybutton", "sphinx-rtd-theme"] +kerberos = ["requests-kerberos"] + [[package]] name = "orjson" version = "3.9.10" @@ -6831,6 +6855,24 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "requests-aws4auth" +version = "1.2.3" +description = "AWS4 authentication for Requests" +optional = true +python-versions = ">=3.3" +files = [ + {file = "requests-aws4auth-1.2.3.tar.gz", hash = "sha256:d4c73c19f37f80d4aa9c5bd4fa376cfd0c69299c48b00a8eb2ae6b0416164fb8"}, + {file = "requests_aws4auth-1.2.3-py2.py3-none-any.whl", hash = "sha256:8070a5207e95fa5fe88e87d9a75f34e768cbab35bb3557ef20cbbf9426dee4d5"}, +] + +[package.dependencies] +requests = "*" +six = "*" + +[package.extras] +httpx = ["httpx"] + [[package]] name = "requests-file" version = "1.5.1" @@ -9093,7 +9135,7 @@ cli = ["typer"] cohere = ["cohere"] docarray = ["docarray"] embeddings = ["sentence-transformers"] -extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "couchbase", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "hologres-vector", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"] +extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "couchbase", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "hologres-vector", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-aws4auth", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"] javascript = ["esprima"] llms = ["clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"] openai = ["openai", "tiktoken"] @@ -9103,4 +9145,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "ef04ae95bce50580df3bc8df24d6ab273c01f968622b4c85a89e6b4e64ac8cd1" +content-hash = "b4271d2d998d68d9a4440a04ea8b8752449a9cab1490b5189187b521a0741fe1" diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 7bba73b4a18c3..f42824d2ad27d 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -110,6 +110,8 @@ databricks-vectorsearch = {version = "^0.21", optional = true} couchbase = {version = "^4.1.9", optional = true} dgml-utils = {version = "^0.3.0", optional = true} datasets = {version = "^2.15.0", optional = true} +requests-aws4auth = {version = "^1.2.3", optional = true} +opensearch-py = {version = "^2.4.2", optional = true} [tool.poetry.group.test] optional = true @@ -294,6 +296,8 @@ extended_testing = [ "couchbase", "dgml-utils", "cohere", + "opensearch-py", + "requests-aws4auth", ] [tool.ruff] From e5d2ab6105e01248dd9fbe65d08528a9d1e1f652 Mon Sep 17 00:00:00 2001 From: Stefan Langenbach Date: Tue, 12 Dec 2023 10:38:14 +0000 Subject: [PATCH 2/2] Fix AOSS tests --- .../vectorstores/test_opensearch_vector_search.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/libs/community/tests/unit_tests/vectorstores/test_opensearch_vector_search.py b/libs/community/tests/unit_tests/vectorstores/test_opensearch_vector_search.py index a7fd55a7e0bb1..6dd6bbc5e9ce5 100644 --- a/libs/community/tests/unit_tests/vectorstores/test_opensearch_vector_search.py +++ b/libs/community/tests/unit_tests/vectorstores/test_opensearch_vector_search.py @@ -1,10 +1,10 @@ import pytest -from opensearchpy import RequestsAWSV4SignerAuth from pytest_mock import MockerFixture -from requests_aws4auth import AWS4Auth -from langchain_community.vectorstores.opensearch_vector_search import OpenSearchVectorSearch from langchain_community.embeddings import FakeEmbeddings +from langchain_community.vectorstores.opensearch_vector_search import ( + OpenSearchVectorSearch, +) @pytest.mark.requires("opensearchpy") @@ -12,6 +12,8 @@ def test_detect_aoss_using_signer_auth( mocker: MockerFixture, service: str, expected: bool ) -> None: + from opensearchpy import RequestsAWSV4SignerAuth + mocker.patch.object(RequestsAWSV4SignerAuth, "_sign_request") http_auth = RequestsAWSV4SignerAuth( credentials="credentials", region="eu-central-1", service=service @@ -30,6 +32,8 @@ def test_detect_aoss_using_signer_auth( @pytest.mark.requires("requests_aws4auth") @pytest.mark.parametrize(["service", "expected"], [("aoss", True), ("es", False)]) def test_detect_aoss_using_aws4auth(service: str, expected: bool) -> None: + from requests_aws4auth import AWS4Auth + http_auth = AWS4Auth("access_key_id", "secret_access_key", "eu-central-1", service) database = OpenSearchVectorSearch( opensearch_url="http://localhost:9200",