diff --git a/libs/community/langchain_community/vectorstores/opensearch_vector_search.py b/libs/community/langchain_community/vectorstores/opensearch_vector_search.py index 99ece00e7a69c..685bf33d328f3 100644 --- a/libs/community/langchain_community/vectorstores/opensearch_vector_search.py +++ b/libs/community/langchain_community/vectorstores/opensearch_vector_search.py @@ -81,14 +81,16 @@ def _validate_aoss_with_engines(is_aoss: bool, engine: str) -> None: def _is_aoss_enabled(http_auth: Any) -> bool: - """Check if the service is http_auth is set as `aoss`.""" - if ( - http_auth is not None - and hasattr(http_auth, "service") - and http_auth.service == "aoss" - ): - return True - return False + """Check if service attribute of http_auth is set to `aoss`.""" + if http_auth is not None: + if hasattr(http_auth, "service") and http_auth.service == "aoss": + return True + elif hasattr(http_auth, "signer") and http_auth.signer.service == "aoss": + return True + else: + return False + else: + return False def _bulk_ingest_embeddings( diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml index f05d6580af262..7532ea747f589 100644 --- a/libs/community/pyproject.toml +++ b/libs/community/pyproject.toml @@ -83,6 +83,8 @@ msal = {version = "^1.25.0", optional = true} databricks-vectorsearch = {version = "^0.21", optional = true} dgml-utils = {version = "^0.3.0", optional = true} datasets = {version = "^2.15.0", optional = true} +requests-aws4auth = {version = "^1.2.3", optional = true} +opensearch-py = {version = "^2.4.2", optional = true} [tool.poetry.group.test] optional = true @@ -241,6 +243,8 @@ extended_testing = [ "databricks-vectorsearch", "dgml-utils", "cohere", + "opensearch-py", + "requests-aws4auth", ] [tool.ruff] diff --git a/libs/community/tests/unit_tests/vectorstores/test_opensearch_vector_search.py b/libs/community/tests/unit_tests/vectorstores/test_opensearch_vector_search.py new file mode 100644 index 0000000000000..6dd6bbc5e9ce5 --- /dev/null +++ b/libs/community/tests/unit_tests/vectorstores/test_opensearch_vector_search.py @@ -0,0 +1,56 @@ +import pytest +from pytest_mock import MockerFixture + +from langchain_community.embeddings import FakeEmbeddings +from langchain_community.vectorstores.opensearch_vector_search import ( + OpenSearchVectorSearch, +) + + +@pytest.mark.requires("opensearchpy") +@pytest.mark.parametrize(["service", "expected"], [("aoss", True), ("es", False)]) +def test_detect_aoss_using_signer_auth( + mocker: MockerFixture, service: str, expected: bool +) -> None: + from opensearchpy import RequestsAWSV4SignerAuth + + mocker.patch.object(RequestsAWSV4SignerAuth, "_sign_request") + http_auth = RequestsAWSV4SignerAuth( + credentials="credentials", region="eu-central-1", service=service + ) + database = OpenSearchVectorSearch( + opensearch_url="http://localhost:9200", + index_name="test", + embedding_function=FakeEmbeddings(size=42), + http_auth=http_auth, + ) + + assert database.is_aoss == expected + + +@pytest.mark.requires("opensearchpy") +@pytest.mark.requires("requests_aws4auth") +@pytest.mark.parametrize(["service", "expected"], [("aoss", True), ("es", False)]) +def test_detect_aoss_using_aws4auth(service: str, expected: bool) -> None: + from requests_aws4auth import AWS4Auth + + http_auth = AWS4Auth("access_key_id", "secret_access_key", "eu-central-1", service) + database = OpenSearchVectorSearch( + opensearch_url="http://localhost:9200", + index_name="test", + embedding_function=FakeEmbeddings(size=42), + http_auth=http_auth, + ) + + assert database.is_aoss == expected + + +@pytest.mark.requires("opensearchpy") +def test_detect_aoss_using_no_auth() -> None: + database = OpenSearchVectorSearch( + opensearch_url="http://localhost:9200", + index_name="test", + embedding_function=FakeEmbeddings(size=42), + ) + + assert database.is_aoss is False