From 0d983e23c983ed3c002e2df4fc68f247a65491f7 Mon Sep 17 00:00:00 2001
From: jingyizhu99 <83610845+jingyizhu99@users.noreply.github.com>
Date: Tue, 7 May 2024 12:06:28 -0700
Subject: [PATCH] Improve serverless handling and AzureAISearchSource as input
 (#3071)

# Description

This PR improves serverless connection handling in `build_index` and adds support for building an index from an existing Azure AI Search index (`AzureAISearchSource`):

- `EmbeddingsModelConfig` and `AzureAISearchConfig` now accept a `connection_id` directly, in addition to a `ConnectionConfig`, and expose a `get_connection_id()` helper that validates the id format.
- Serverless embedding models (model names containing `cohere`) are tracked with an explicit `is_serverless_connection` flag and embed through a workspace connection.
- `vector_store` now defaults to `"azure_ai_search"`, and `index_config` becomes optional; it is only required when building from a `LocalSource`.
- A new `get_langchain_retriever_from_index` function loads an MLIndex stored at an `azureml://` datastore URI as a langchain retriever.
- The `promptflow.rag.resources` module is renamed to `promptflow.rag.config`, and packaging files (`pyproject.toml`, `version.txt`) are added for `promptflow-rag`.

Sample notebook updated accordingly:
https://github.com/microsoft/promptflow/blob/9de72467d945112ec1e91ab323e61a6658333ead/src/promptflow-rag/build_index_sample.ipynb

# All Promptflow Contribution checklist:
- [x] **The pull request does not introduce [breaking changes].**
- [x] **CHANGELOG is updated for new features, bug fixes or other significant changes.**
- [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [x] Title of the pull request is clear and informative.
- [x] There are a small number of commits, each of which has an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
---
 src/promptflow-rag/promptflow/rag/__init__.py |  4 +-
 .../promptflow/rag/_build_mlindex.py          | 90 ++++++++++--------
 .../rag/_get_langchain_retriever.py           | 14 +++
 .../rag/{resources => config}/__init__.py     |  0
 .../_azure_ai_search_config.py                | 19 ++++
 .../_connection_config.py                     |  0
 .../_embeddings_model_config.py               | 21 ++++-
 .../{resources => config}/_index_config.py    |  0
 .../_index_data_source.py                     |  4 +-
 .../promptflow/rag/constants/_common.py       |  3 +
 src/promptflow-rag/pyproject.toml             | 92 +++++++++++++++++++
 src/promptflow-rag/requirements.txt           |  2 +-
 src/promptflow-rag/version.txt                |  1 +
 13 files changed, 205 insertions(+), 45 deletions(-)
 create mode 100644 src/promptflow-rag/promptflow/rag/_get_langchain_retriever.py
 rename src/promptflow-rag/promptflow/rag/{resources => config}/__init__.py (100%)
 rename src/promptflow-rag/promptflow/rag/{resources => config}/_azure_ai_search_config.py (58%)
 rename src/promptflow-rag/promptflow/rag/{resources => config}/_connection_config.py (100%)
 rename src/promptflow-rag/promptflow/rag/{resources => config}/_embeddings_model_config.py (59%)
 rename src/promptflow-rag/promptflow/rag/{resources => config}/_index_config.py (100%)
 rename src/promptflow-rag/promptflow/rag/{resources => config}/_index_data_source.py (97%)
 create mode 100644 src/promptflow-rag/pyproject.toml
 create mode 100644 src/promptflow-rag/version.txt

diff --git a/src/promptflow-rag/promptflow/rag/__init__.py b/src/promptflow-rag/promptflow/rag/__init__.py
index 0e4a08aeecc..0185e9c769e 100644
--- a/src/promptflow-rag/promptflow/rag/__init__.py
+++ b/src/promptflow-rag/promptflow/rag/__init__.py
@@ -5,7 +5,9 @@
 __path__ = __import__("pkgutil").extend_path(__path__, __name__)  # type: ignore
 
 from ._build_mlindex import build_index
+from ._get_langchain_retriever import get_langchain_retriever_from_index
 
 __all__ = [
-    "build_index"
+    "build_index",
+    "get_langchain_retriever_from_index"
 ]
diff --git a/src/promptflow-rag/promptflow/rag/_build_mlindex.py b/src/promptflow-rag/promptflow/rag/_build_mlindex.py
index cc2157b1547..4f40abcc46d 100644
--- a/src/promptflow-rag/promptflow/rag/_build_mlindex.py
+++ b/src/promptflow-rag/promptflow/rag/_build_mlindex.py
@@ -8,17 +8,18 @@
 import yaml  # type: ignore[import]
 from packaging import version
+
 from promptflow.rag.constants._common import AZURE_AI_SEARCH_API_VERSION
-from promptflow.rag.resources import EmbeddingsModelConfig, AzureAISearchConfig, AzureAISearchSource, LocalSource
+from promptflow.rag.config import EmbeddingsModelConfig, AzureAISearchConfig, AzureAISearchSource, LocalSource
 from promptflow.rag._utils._open_ai_utils import build_open_ai_protocol
 
 
 def build_index(
     *,
     name: str,
-    vector_store: str,
+    vector_store: str = "azure_ai_search",
     input_source: Union[AzureAISearchSource, LocalSource],
-    index_config: AzureAISearchConfig,  # todo better name?
+    index_config: Optional[AzureAISearchConfig] = None,  # todo better name?
     embeddings_model_config: EmbeddingsModelConfig,
     data_source_url: Optional[str] = None,
     tokens_per_chunk: int = 1024,
@@ -40,8 +41,8 @@ def build_index(
     :paramtype input_source: Union[AzureAISearchSource, LocalSource]
     :keyword index_config: The configuration for Azure Cognitive Search output.
     :paramtype index_config: AzureAISearchConfig
-    :keyword index_config: The configuration for AOAI embedding model.
-    :paramtype index_config: EmbeddingsModelConfig
+    :keyword embeddings_model_config: The configuration for the embedding model.
+    :paramtype embeddings_model_config: EmbeddingsModelConfig
     :keyword data_source_url: The URL of the data source.
     :paramtype data_source_url: Optional[str]
     :keyword tokens_per_chunk: The size of each chunk.
@@ -72,29 +73,39 @@ def build_index(
         )
         raise e
 
+    is_serverless_connection = False
     if not embeddings_model_config.model_name:
         raise ValueError("Please specify embeddings_model_config.model_name")
     if "cohere" in embeddings_model_config.model_name:
         # If model uri is None, it is *considered* as a serverless endpoint for now.
         # TODO: depends on azureml.rag.Embeddings.from_uri to finalize a scheme for different embeddings
-        if not embeddings_model_config.connection_config:
-            raise ValueError("Please specify embeddings_model_config.connection_config to use cohere embedding models")
+        if not embeddings_model_config.connection_config and not embeddings_model_config.connection_id:
+            raise ValueError(
+                "Please specify connection_config or connection_id to use serverless connection"
+            )
         embeddings_model_uri = None
+        is_serverless_connection = True
+        print("Using serverless connection.")
     else:
         embeddings_model_uri = build_open_ai_protocol(
             embeddings_model_config.deployment_name,
             embeddings_model_config.model_name
         )
+    connection_id = embeddings_model_config.get_connection_id()
 
-    if vector_store == "azure_ai_search" and isinstance(input_source, AzureAISearchSource):
+    if isinstance(input_source, AzureAISearchSource):
         return _create_mlindex_from_existing_ai_search(
             # TODO: Fix Bug 2818331
-            embedding_model=embeddings_model_config.embeddings_model,
+            name=name,
             embedding_model_uri=embeddings_model_uri,
-            connection_id=embeddings_model_config.connection_config.build_connection_id(),
+            is_serverless_connection=is_serverless_connection,
+            connection_id=connection_id,
             ai_search_config=input_source,
         )
+
+    if not index_config:
+        raise ValueError("Please provide index_config details")
     embeddings_cache_path = str(Path(embeddings_cache_path) if embeddings_cache_path else Path.cwd())
     save_path = str(Path(embeddings_cache_path) / f"{name}-mlindex")
     splitter_args = {"chunk_size": tokens_per_chunk, "chunk_overlap": token_overlap_across_chunks, "use_rcts": True}
@@ -103,6 +114,7 @@ def build_index(
     if chunk_prepend_summary is not None:
         splitter_args["chunk_preprend_summary"] = chunk_prepend_summary
 
+    print(f"Crack and chunk files from local path: {input_source.input_data_path}")
     chunked_docs = DocumentChunksIterator(
         files_source=input_source.input_data_path,
         glob=input_glob,
@@ -118,8 +130,7 @@ def build_index(
 
     connection_args = {}
     if embeddings_model_uri and "open_ai" in embeddings_model_uri:
-        if embeddings_model_config.connection_config:
-            connection_id = embeddings_model_config.connection_config.build_connection_id()
+        if connection_id:
             aoai_connection = get_connection_by_id_v2(connection_id)
             if isinstance(aoai_connection, dict):
                 if "properties" in aoai_connection and "target" in aoai_connection["properties"]:
@@ -133,6 +144,7 @@ def build_index(
                 "connection": {"id": connection_id},
                 "endpoint": endpoint,
             }
+            print(f"Start embedding using connection with id = {connection_id}")
         else:
             import openai
             import os
@@ -147,23 +159,16 @@ def build_index(
                 "connection": {"key": api_key},
                 "endpoint": os.getenv(api_base),
             }
+            print("Start embedding using api_key and api_base from environment variables.")
         embedder = EmbeddingsContainer.from_uri(
             embeddings_model_uri,
             **connection_args,
         )
-    elif not embeddings_model_uri:
-        # cohere connection doesn't support environment variables yet
-        # import os
-        # api_key = "SERVERLESS_CONNECTION_KEY"
-        # api_base = "SERVERLESS_CONNECTION_ENDPOINT"
-        # connection_args = {
-        #     "connection_type": "environment",
-        #     "connection": {"key": api_key},
-        #     "endpoint": os.getenv(api_base),
-        # }
+    elif is_serverless_connection:
+        print(f"Start embedding using serverless connection with id = {connection_id}.")
         connection_args = {
             "connection_type": "workspace_connection",
-            "connection": {"id": embeddings_model_config.connection_config.build_connection_id()},
+            "connection": {"id": connection_id},
         }
         embedder = EmbeddingsContainer.from_uri(None, credential=None, **connection_args)
     else:
@@ -177,7 +182,8 @@ def build_index(
     ai_search_args = {
         "index_name": index_config.ai_search_index_name,
     }
-    if not index_config.ai_search_connection_config:
+    ai_search_connection_id = index_config.get_connection_id()
+    if not ai_search_connection_id:
         import os
 
         ai_search_args = {
@@ -191,9 +197,9 @@ def build_index(
         }
         connection_args = {"connection_type": "environment", "connection": {"key": "AZURE_AI_SEARCH_KEY"}}
     else:
-        connection_id = index_config.ai_search_connection_config.build_connection_id()
-        ai_search_connection = get_connection_by_id_v2(connection_id)
+        ai_search_connection = get_connection_by_id_v2(ai_search_connection_id)
         if isinstance(ai_search_connection, dict):
+            endpoint = ai_search_connection["properties"]["target"]
             ai_search_args = {
                 **ai_search_args,
                 **{
@@ -205,6 +211,7 @@ def build_index(
             }
         elif ai_search_connection.target:
             api_version = AZURE_AI_SEARCH_API_VERSION
+            endpoint = ai_search_connection.target
             if ai_search_connection.tags and "ApiVersion" in ai_search_connection.tags:
                 api_version = ai_search_connection.tags["ApiVersion"]
             ai_search_args = {
@@ -218,23 +225,26 @@ def build_index(
             raise ValueError("Cannot get target from ai search connection")
         connection_args = {
             "connection_type": "workspace_connection",
-            "connection": {"id": connection_id},
+            "connection": {"id": ai_search_connection_id},
+            "endpoint": endpoint,
         }
 
+    print("Start creating index from embeddings.")
     create_index_from_raw_embeddings(
         emb=embedder,
         acs_config=ai_search_args,
         connection=connection_args,
         output_path=save_path,
     )
-
+    print(f"Successfully created index at {save_path}")
     return save_path
 
 
 def _create_mlindex_from_existing_ai_search(
-    embedding_model: str,
+    name: str,
     embedding_model_uri: Optional[str],
     connection_id: Optional[str],
+    is_serverless_connection: bool,
     ai_search_config: AzureAISearchSource,
 ) -> str:
     try:
@@ -242,7 +252,7 @@ def _create_mlindex_from_existing_ai_search(
         from azureml.rag.utils.connections import get_connection_by_id_v2
     except ImportError as e:
         print(
-            "In order to use build_index to build an Index locally, you must have azure-ai-generative[index] installed"
+            "In order to use build_index to build an Index locally, you must have azureml.rag installed"
         )
         raise e
 
     mlindex_config = {}
@@ -259,8 +269,14 @@ def _create_mlindex_from_existing_ai_search(
         }
     else:
         ai_search_connection = get_connection_by_id_v2(ai_search_config.ai_search_connection_id)
+        if isinstance(ai_search_connection, dict):
+            endpoint = ai_search_connection["properties"]["target"]
+        elif ai_search_connection.target:
+            endpoint = ai_search_connection.target
+        else:
+            raise ValueError("Cannot get target from ai search connection")
         connection_info = {
-            "endpoint": ai_search_connection["properties"]["target"],
+            "endpoint": endpoint,
             "connection_type": "workspace_connection",
             "connection": {
                 "id": ai_search_config.ai_search_connection_id,
@@ -284,14 +300,7 @@ def _create_mlindex_from_existing_ai_search(
         mlindex_config["index"]["field_mapping"]["metadata"] = ai_search_config.ai_search_metadata_key
 
     model_connection_args: Dict[str, Optional[Union[str, Dict]]]
-    if "cohere" in embedding_model:
-        # api_key = "SERVERLESS_CONNECTION_KEY"
-        # api_base = "SERVERLESS_CONNECTION_ENDPOINT"
-        # connection_args = {
-        #     "connection_type": "environment",
-        #     "connection": {"key": api_key},
-        #     "endpoint": os.getenv(api_base),
-        # }
+    if is_serverless_connection:
         connection_args = {
             "connection_type": "workspace_connection",
             "connection": {"id": connection_id},
@@ -309,10 +318,11 @@ def _create_mlindex_from_existing_ai_search(
         embedding = EmbeddingsContainer.from_uri(embedding_model_uri, credential=None, **model_connection_args)
     mlindex_config["embeddings"] = embedding.get_metadata()
 
-    path = Path.cwd() / f"import-ai_search-{ai_search_config.ai_search_index_name}-mlindex"
+    path = Path.cwd() / f"{name}-mlindex"
     path.mkdir(exist_ok=True)
     with open(path / "MLIndex", "w", encoding="utf-8") as f:
         yaml.dump(mlindex_config, f)
+    print(f"Successfully created index at {path}")
     return path
diff --git a/src/promptflow-rag/promptflow/rag/_get_langchain_retriever.py b/src/promptflow-rag/promptflow/rag/_get_langchain_retriever.py
new file mode 100644
index 00000000000..4785f268edd
--- /dev/null
+++ b/src/promptflow-rag/promptflow/rag/_get_langchain_retriever.py
@@ -0,0 +1,14 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from azureml.rag.mlindex import MLIndex
+from promptflow.rag.constants._common import STORAGE_URI_TO_MLINDEX_PATH_FORMAT
+import re
+
+
+def get_langchain_retriever_from_index(path: str):
+    if not re.match(STORAGE_URI_TO_MLINDEX_PATH_FORMAT, path):
+        raise ValueError(
+            "Path to MLIndex file doesn't have the correct format."
+        )
+    return MLIndex(path).as_langchain_retriever()
diff --git a/src/promptflow-rag/promptflow/rag/resources/__init__.py b/src/promptflow-rag/promptflow/rag/config/__init__.py
similarity index 100%
rename from src/promptflow-rag/promptflow/rag/resources/__init__.py
rename to src/promptflow-rag/promptflow/rag/config/__init__.py
diff --git a/src/promptflow-rag/promptflow/rag/resources/_azure_ai_search_config.py b/src/promptflow-rag/promptflow/rag/config/_azure_ai_search_config.py
similarity index 58%
rename from src/promptflow-rag/promptflow/rag/resources/_azure_ai_search_config.py
rename to src/promptflow-rag/promptflow/rag/config/_azure_ai_search_config.py
index f94b4530df4..09dcf3ebb18 100644
--- a/src/promptflow-rag/promptflow/rag/resources/_azure_ai_search_config.py
+++ b/src/promptflow-rag/promptflow/rag/config/_azure_ai_search_config.py
@@ -8,6 +8,7 @@
 # Defines stuff related to the resulting created index, like the index type.
 
 from typing import Optional
+from promptflow.rag.constants._common import CONNECTION_ID_FORMAT
 from ._connection_config import ConnectionConfig
 
 
@@ -18,6 +19,8 @@ class AzureAISearchConfig:
     :type ai_search_index_name: Optional[str]
     :param ai_search_connection_id: The Azure AI Search connection Config.
     :type ai_search_connection_config: Optional[ConnectionConfig]
+    :param connection_id: The connection id of the Azure AI Search connection.
+    :type connection_id: Optional[str]
     """
 
     def __init__(
@@ -25,6 +28,22 @@ def __init__(
         self,
         *,
         ai_search_index_name: Optional[str] = None,
         ai_search_connection_config: Optional[ConnectionConfig] = None,
+        connection_id: Optional[str] = None,
     ) -> None:
         self.ai_search_index_name = ai_search_index_name
         self.ai_search_connection_config = ai_search_connection_config
+        self.connection_id = connection_id
+
+    def get_connection_id(self) -> Optional[str]:
+        """Get connection id from connection config or connection id"""
+        import re
+
+        if self.connection_id:
+            if not re.match(CONNECTION_ID_FORMAT, self.connection_id):
+                raise ValueError(
+                    "Your connection id doesn't have the correct format"
+                )
+            return self.connection_id
+        if self.ai_search_connection_config:
+            return self.ai_search_connection_config.build_connection_id()
+        return None
diff --git a/src/promptflow-rag/promptflow/rag/resources/_connection_config.py b/src/promptflow-rag/promptflow/rag/config/_connection_config.py
similarity index 100%
rename from src/promptflow-rag/promptflow/rag/resources/_connection_config.py
rename to src/promptflow-rag/promptflow/rag/config/_connection_config.py
diff --git a/src/promptflow-rag/promptflow/rag/resources/_embeddings_model_config.py b/src/promptflow-rag/promptflow/rag/config/_embeddings_model_config.py
similarity index 59%
rename from src/promptflow-rag/promptflow/rag/resources/_embeddings_model_config.py
rename to src/promptflow-rag/promptflow/rag/config/_embeddings_model_config.py
index fb15b9ef606..91bb871b6a3 100644
--- a/src/promptflow-rag/promptflow/rag/resources/_embeddings_model_config.py
+++ b/src/promptflow-rag/promptflow/rag/config/_embeddings_model_config.py
@@ -9,6 +9,7 @@
 from typing import Optional
 
 from ._connection_config import ConnectionConfig
+from promptflow.rag.constants._common import CONNECTION_ID_FORMAT
 
 
 class EmbeddingsModelConfig:
@@ -17,7 +18,9 @@ class EmbeddingsModelConfig:
     :param model_name: The name of the embedding model.
     :type model_name: Optional[str]
     :param deployment_name: The deployment_name for the embedding model.
-    :type deployment_name: Optional[ConnectionConfig]
+    :type deployment_name: Optional[str]
+    :param connection_id: The connection id for the embedding model.
+    :type connection_id: Optional[str]
     :param connection_config: The connection configuration for the embedding model.
     :type connection_config: Optional[ConnectionConfig]
     """
 
@@ -27,8 +30,24 @@ def __init__(
         *,
         model_name: Optional[str] = None,
         deployment_name: Optional[str] = None,
+        connection_id: Optional[str] = None,
         connection_config: Optional[ConnectionConfig] = None,
     ) -> None:
         self.model_name = model_name
         self.deployment_name = deployment_name
+        self.connection_id = connection_id
         self.connection_config = connection_config
+
+    def get_connection_id(self) -> Optional[str]:
+        """Get connection id from connection config or connection id"""
+        import re
+
+        if self.connection_id:
+            if not re.match(CONNECTION_ID_FORMAT, self.connection_id):
+                raise ValueError(
+                    "Your connection id doesn't have the correct format"
+                )
+            return self.connection_id
+        if self.connection_config:
+            return self.connection_config.build_connection_id()
+        return None
diff --git a/src/promptflow-rag/promptflow/rag/resources/_index_config.py b/src/promptflow-rag/promptflow/rag/config/_index_config.py
similarity index 100%
rename from src/promptflow-rag/promptflow/rag/resources/_index_config.py
rename to src/promptflow-rag/promptflow/rag/config/_index_config.py
diff --git a/src/promptflow-rag/promptflow/rag/resources/_index_data_source.py b/src/promptflow-rag/promptflow/rag/config/_index_data_source.py
similarity index 97%
rename from src/promptflow-rag/promptflow/rag/resources/_index_data_source.py
rename to src/promptflow-rag/promptflow/rag/config/_index_data_source.py
index d6062272ce1..2da8be9e7ef 100644
--- a/src/promptflow-rag/promptflow/rag/resources/_index_data_source.py
+++ b/src/promptflow-rag/promptflow/rag/config/_index_data_source.py
@@ -1,7 +1,7 @@
 # ---------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
-from typing import Union
+from typing import Optional, Union
 
 from promptflow.rag.constants import IndexInputType
 
@@ -46,7 +46,7 @@ def __init__(
         ai_search_embedding_key: str,
         ai_search_title_key: str,
         ai_search_metadata_key: str,
-        ai_search_connection_id: str,
+        ai_search_connection_id: Optional[str] = None,
         num_docs_to_import: int = 50,
     ):
         self.ai_search_index_name = ai_search_index_name
diff --git a/src/promptflow-rag/promptflow/rag/constants/_common.py b/src/promptflow-rag/promptflow/rag/constants/_common.py
index 4a78382c7ed..d1e6ad53a11 100644
--- a/src/promptflow-rag/promptflow/rag/constants/_common.py
+++ b/src/promptflow-rag/promptflow/rag/constants/_common.py
@@ -5,6 +5,9 @@
 AZURE_AI_SEARCH_API_VERSION = "2023-07-01-preview"
 OPEN_AI_PROTOCOL_TEMPLATE = "azure_open_ai://deployment/{}/model/{}"
 CONNECTION_ID_TEMPLATE = "/subscriptions/{}/resourceGroups/{}/providers/Microsoft.MachineLearningServices/workspaces/{}/connections/{}"  # noqa: E501
+CONNECTION_ID_FORMAT = CONNECTION_ID_TEMPLATE.format(".*", ".*", ".*", ".*")
+STORAGE_URI_TO_MLINDEX_PATH_TEMPLATE = "azureml://subscriptions/{}/resourcegroups/{}/workspaces/{}/datastores/{}/paths/{}"  # noqa: E501
+STORAGE_URI_TO_MLINDEX_PATH_FORMAT = STORAGE_URI_TO_MLINDEX_PATH_TEMPLATE.format(".*", ".*", ".*", ".*", ".*")
 
 
 class IndexInputType(object):
diff --git a/src/promptflow-rag/pyproject.toml b/src/promptflow-rag/pyproject.toml
new file mode 100644
index 00000000000..b40fa2763ed
--- /dev/null
+++ b/src/promptflow-rag/pyproject.toml
@@ -0,0 +1,92 @@
+# dummy toml file, will be replaced by setup.py during release
+# poetry
+[tool.poetry]
+name = "promptflow-rag"
+version = "0.2.0.dev0"
+description = "Prompt flow RAG"
+license = "MIT"
+authors = [
+    "Microsoft Corporation "
+]
+repository = "https://github.com/microsoft/promptflow"
+homepage = "https://microsoft.github.io/promptflow/"
+readme = ["README.md"]
+keywords = ["telemetry"]
+classifiers = [
+    "Programming Language :: Python",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3 :: Only",
+    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "License :: OSI Approved :: MIT License",
+    "Operating System :: OS Independent",
+]
+packages = [
+    { include = "promptflow" }
+]
+
+[tool.poetry.urls]
+"Bug Reports" = "https://github.com/microsoft/promptflow/issues"
+
+# dependencies
+[tool.poetry.dependencies]
+python = "<4.0,>=3.8"
+azureml-rag = { version = "*", extras = ["cognitive_search", "document_parsing"] }
+openai = "*"
+
+[tool.poetry.group.dev.dependencies]
+pre-commit = "*"
+import-linter = "*"
+
+[tool.poetry.group.test.dependencies]
+pytest = "*"
+pytest-asyncio = "*"
+pytest-cov = "*"
+pytest-mock = "*"
+pytest-xdist = "*"
+
+# test: pytest and coverage
+[tool.pytest.ini_options]
+markers = [
+    "unittest",
+    "e2etest",
+]
+# junit - analyse and publish test results (https://github.com/EnricoMi/publish-unit-test-result-action)
+# durations - list the slowest test durations
+addopts = """
+--junit-xml=test-results.xml \
+--dist loadfile \
+--log-level=info \
+--log-format="%(asctime)s %(levelname)s %(message)s" \
+--log-date-format="[%Y-%m-%d %H:%M:%S]" \
+--durations=5 \
+-ra \
+-vv
+"""
+env = [
+]
+testpaths = ["tests"]
+
+[tool.coverage.run]
+concurrency = ["multiprocessing"]
+source = ["promptflow"]
+omit = [
+    "__init__.py",
+]
+
+[tool.black]
+line-length = 120
+
+# import linter
+# reference: https://pypi.org/project/import-linter/
+[tool.importlinter]
+root_package = "promptflow"
+include_external_packages = "True"
+
+[[tool.importlinter.contracts]]
+name = "Contract forbidden modules"
+type = "forbidden"
+source_modules = ["promptflow.rag"]
+forbidden_modules = []
diff --git a/src/promptflow-rag/requirements.txt b/src/promptflow-rag/requirements.txt
index 19cf310c4f1..29a77e58f9c 100644
--- a/src/promptflow-rag/requirements.txt
+++ b/src/promptflow-rag/requirements.txt
@@ -1,2 +1,2 @@
-azureml-rag[cognitive_search,document_parsing]
+azureml-rag[cognitive_search,document_parsing,langchain]
 openai
diff --git a/src/promptflow-rag/version.txt b/src/promptflow-rag/version.txt
new file mode 100644
index 00000000000..8452eb1e1e5
--- /dev/null
+++ b/src/promptflow-rag/version.txt
@@ -0,0 +1 @@
+VERSION = "0.1.0b1"
\ No newline at end of file
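
## Usage example

A minimal sketch of the API surface after this change, following the linked sample notebook. Everything in angle brackets is a placeholder, and the `LocalSource(input_data=...)` signature comes from the sample notebook rather than from this diff, so treat it as an assumption:

```python
from promptflow.rag import build_index, get_langchain_retriever_from_index
from promptflow.rag.config import (
    AzureAISearchConfig,
    EmbeddingsModelConfig,
    LocalSource,
)

# Placeholder ids; real values must match CONNECTION_ID_FORMAT.
AOAI_CONNECTION_ID = (
    "/subscriptions/<subscription-id>/resourceGroups/<resource-group>"
    "/providers/Microsoft.MachineLearningServices/workspaces/<workspace>"
    "/connections/<aoai-connection>"
)
SEARCH_CONNECTION_ID = (
    "/subscriptions/<subscription-id>/resourceGroups/<resource-group>"
    "/providers/Microsoft.MachineLearningServices/workspaces/<workspace>"
    "/connections/<search-connection>"
)

# Build an MLIndex from local files; vector_store now defaults to
# "azure_ai_search", and a connection id can be passed directly instead
# of a ConnectionConfig.
index_path = build_index(
    name="sample-index",
    input_source=LocalSource(input_data="data/product-info/"),  # assumed signature
    index_config=AzureAISearchConfig(
        ai_search_index_name="sample-index",
        connection_id=SEARCH_CONNECTION_ID,
    ),
    embeddings_model_config=EmbeddingsModelConfig(
        model_name="text-embedding-ada-002",
        deployment_name="text-embedding-ada-002",
        connection_id=AOAI_CONNECTION_ID,
    ),
)
print(f"MLIndex written to {index_path}")

# Load an MLIndex stored on a workspace datastore as a langchain retriever;
# the path must match STORAGE_URI_TO_MLINDEX_PATH_FORMAT.
retriever = get_langchain_retriever_from_index(
    "azureml://subscriptions/<subscription-id>/resourcegroups/<resource-group>"
    "/workspaces/<workspace>/datastores/<datastore>/paths/<path-to-mlindex>"
)
docs = retriever.get_relevant_documents("How do I renew my subscription?")
```

For a serverless embedding model, pass a `model_name` containing `cohere` and omit `deployment_name`; `build_index` then skips building the `azure_open_ai://` URI and embeds through the workspace connection resolved by `get_connection_id()`.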