Skip to content

Commit

Permalink
Improve serverless handling and AzureAISearchSource as input (#3071)
Browse files Browse the repository at this point in the history
# Description

Please add an informative description that covers the changes made by
the pull request and link all relevant issues.

Sample notebook updated accordingly:
https://github.com/microsoft/promptflow/blob/9de72467d945112ec1e91ab323e61a6658333ead/src/promptflow-rag/build_index_sample.ipynb

# All Promptflow Contribution checklist:
- [x] **The pull request does not introduce [breaking changes].**
- [x] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [x] Title of the pull request is clear and informative.
- [x] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
  • Loading branch information
jingyizhu99 authored May 7, 2024
1 parent 34d363b commit 0d983e2
Show file tree
Hide file tree
Showing 13 changed files with 205 additions and 45 deletions.
4 changes: 3 additions & 1 deletion src/promptflow-rag/promptflow/rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore

from ._build_mlindex import build_index
from ._get_langchain_retriever import get_langchain_retriever_from_index

__all__ = [
"build_index"
"build_index",
"get_langchain_retriever_from_index"
]
90 changes: 50 additions & 40 deletions src/promptflow-rag/promptflow/rag/_build_mlindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@
import yaml # type: ignore[import]
from packaging import version


from promptflow.rag.constants._common import AZURE_AI_SEARCH_API_VERSION
from promptflow.rag.resources import EmbeddingsModelConfig, AzureAISearchConfig, AzureAISearchSource, LocalSource
from promptflow.rag.config import EmbeddingsModelConfig, AzureAISearchConfig, AzureAISearchSource, LocalSource
from promptflow.rag._utils._open_ai_utils import build_open_ai_protocol


def build_index(
*,
name: str,
vector_store: str,
vector_store: str = "azure_ai_search",
input_source: Union[AzureAISearchSource, LocalSource],
index_config: AzureAISearchConfig, # todo better name?
index_config: Optional[AzureAISearchConfig] = None, # todo better name?
embeddings_model_config: EmbeddingsModelConfig,
data_source_url: Optional[str] = None,
tokens_per_chunk: int = 1024,
Expand All @@ -40,8 +41,8 @@ def build_index(
:paramtype input_source: Union[AzureAISearchSource, LocalSource]
:keyword index_config: The configuration for Azure Cognitive Search output.
:paramtype index_config: AzureAISearchConfig
:keyword index_config: The configuration for AOAI embedding model.
:paramtype index_config: EmbeddingsModelConfig
:keyword embeddings_model_config: The configuration for embedding model.
:paramtype embeddings_model_config: EmbeddingsModelConfig
:keyword data_source_url: The URL of the data source.
:paramtype data_source_url: Optional[str]
:keyword tokens_per_chunk: The size of each chunk.
Expand Down Expand Up @@ -72,29 +73,39 @@ def build_index(
)
raise e

is_serverless_connection = False
if not embeddings_model_config.model_name:
raise ValueError("Please specify embeddings_model_config.model_name")

if "cohere" in embeddings_model_config.model_name:
# If model uri is None, it is *considered* as a serverless endpoint for now.
# TODO: depends on azureml.rag.Embeddings.from_uri to finalize a scheme for different embeddings
if not embeddings_model_config.connection_config:
raise ValueError("Please specify embeddings_model_config.connection_config to use cohere embedding models")
if not embeddings_model_config.connection_config and not embeddings_model_config.connection_id:
raise ValueError(
"Please specify connection_config or connection_id to use serverless connection"
)
embeddings_model_uri = None
is_serverless_connection = True
print("Using serverless connection.")
else:
embeddings_model_uri = build_open_ai_protocol(
embeddings_model_config.deployment_name,
embeddings_model_config.model_name
)
connection_id = embeddings_model_config.get_connection_id()

if vector_store == "azure_ai_search" and isinstance(input_source, AzureAISearchSource):
if isinstance(input_source, AzureAISearchSource):
return _create_mlindex_from_existing_ai_search(
# TODO: Fix Bug 2818331
embedding_model=embeddings_model_config.embeddings_model,
name=name,
embedding_model_uri=embeddings_model_uri,
connection_id=embeddings_model_config.connection_config.build_connection_id(),
is_serverless_connection=is_serverless_connection,
connection_id=connection_id,
ai_search_config=input_source,
)

if not index_config:
raise ValueError("Please provide index_config details")
embeddings_cache_path = str(Path(embeddings_cache_path) if embeddings_cache_path else Path.cwd())
save_path = str(Path(embeddings_cache_path) / f"{name}-mlindex")
splitter_args = {"chunk_size": tokens_per_chunk, "chunk_overlap": token_overlap_across_chunks, "use_rcts": True}
Expand All @@ -103,6 +114,7 @@ def build_index(
if chunk_prepend_summary is not None:
splitter_args["chunk_preprend_summary"] = chunk_prepend_summary

print(f"Crack and chunk files from local path: {input_source.input_data_path}")
chunked_docs = DocumentChunksIterator(
files_source=input_source.input_data_path,
glob=input_glob,
Expand All @@ -118,8 +130,7 @@ def build_index(

connection_args = {}
if embeddings_model_uri and "open_ai" in embeddings_model_uri:
if embeddings_model_config.connection_config:
connection_id = embeddings_model_config.connection_config.build_connection_id()
if connection_id:
aoai_connection = get_connection_by_id_v2(connection_id)
if isinstance(aoai_connection, dict):
if "properties" in aoai_connection and "target" in aoai_connection["properties"]:
Expand All @@ -133,6 +144,7 @@ def build_index(
"connection": {"id": connection_id},
"endpoint": endpoint,
}
print(f"Start embedding using connection with id = {connection_id}")
else:
import openai
import os
Expand All @@ -147,23 +159,16 @@ def build_index(
"connection": {"key": api_key},
"endpoint": os.getenv(api_base),
}
print("Start embedding using api_key and api_base from environment variables.")
embedder = EmbeddingsContainer.from_uri(
embeddings_model_uri,
**connection_args,
)
elif not embeddings_model_uri:
# cohere connection doesn't support environment variables yet
# import os
# api_key = "SERVERLESS_CONNECTION_KEY"
# api_base = "SERVERLESS_CONNECTION_ENDPOINT"
# connection_args = {
# "connection_type": "environment",
# "connection": {"key": api_key},
# "endpoint": os.getenv(api_base),
# }
elif is_serverless_connection:
print(f"Start embedding using serverless connection with id = {connection_id}.")
connection_args = {
"connection_type": "workspace_connection",
"connection": {"id": embeddings_model_config.connection_config.build_connection_id()},
"connection": {"id": connection_id},
}
embedder = EmbeddingsContainer.from_uri(None, credential=None, **connection_args)
else:
Expand All @@ -177,7 +182,8 @@ def build_index(
ai_search_args = {
"index_name": index_config.ai_search_index_name,
}
if not index_config.ai_search_connection_config:
ai_search_connection_id = index_config.get_connection_id()
if not ai_search_connection_id:
import os

ai_search_args = {
Expand All @@ -191,9 +197,9 @@ def build_index(
}
connection_args = {"connection_type": "environment", "connection": {"key": "AZURE_AI_SEARCH_KEY"}}
else:
connection_id = index_config.ai_search_connection_config.build_connection_id()
ai_search_connection = get_connection_by_id_v2(connection_id)
ai_search_connection = get_connection_by_id_v2(ai_search_connection_id)
if isinstance(ai_search_connection, dict):
endpoint = ai_search_connection["properties"]["target"]
ai_search_args = {
**ai_search_args,
**{
Expand All @@ -205,6 +211,7 @@ def build_index(
}
elif ai_search_connection.target:
api_version = AZURE_AI_SEARCH_API_VERSION
endpoint = ai_search_connection.target
if ai_search_connection.tags and "ApiVersion" in ai_search_connection.tags:
api_version = ai_search_connection.tags["ApiVersion"]
ai_search_args = {
Expand All @@ -218,31 +225,34 @@ def build_index(
raise ValueError("Cannot get target from ai search connection")
connection_args = {
"connection_type": "workspace_connection",
"connection": {"id": connection_id},
"connection": {"id": ai_search_connection_id},
"endpoint": endpoint,
}

print("Start creating index from embeddings.")
create_index_from_raw_embeddings(
emb=embedder,
acs_config=ai_search_args,
connection=connection_args,
output_path=save_path,
)

print(f"Successfully created index at {save_path}")
return save_path


def _create_mlindex_from_existing_ai_search(
embedding_model: str,
name: str,
embedding_model_uri: Optional[str],
connection_id: Optional[str],
is_serverless_connection: bool,
ai_search_config: AzureAISearchSource,
) -> str:
try:
from azureml.rag.embeddings import EmbeddingsContainer
from azureml.rag.utils.connections import get_connection_by_id_v2
except ImportError as e:
print(
"In order to use build_index to build an Index locally, you must have azure-ai-generative[index] installed"
"In order to use build_index to build an Index locally, you must have azureml.rag installed"
)
raise e
mlindex_config = {}
Expand All @@ -259,8 +269,14 @@ def _create_mlindex_from_existing_ai_search(
}
else:
ai_search_connection = get_connection_by_id_v2(ai_search_config.ai_search_connection_id)
if isinstance(ai_search_connection, dict):
endpoint = ai_search_connection["properties"]["target"]
elif ai_search_connection.target:
endpoint = ai_search_connection.target
else:
raise ValueError("Cannot get target from ai search connection")
connection_info = {
"endpoint": ai_search_connection["properties"]["target"],
"endpoint": endpoint,
"connection_type": "workspace_connection",
"connection": {
"id": ai_search_config.ai_search_connection_id,
Expand All @@ -284,14 +300,7 @@ def _create_mlindex_from_existing_ai_search(
mlindex_config["index"]["field_mapping"]["metadata"] = ai_search_config.ai_search_metadata_key

model_connection_args: Dict[str, Optional[Union[str, Dict]]]
if "cohere" in embedding_model:
# api_key = "SERVERLESS_CONNECTION_KEY"
# api_base = "SERVERLESS_CONNECTION_ENDPOINT"
# connection_args = {
# "connection_type": "environment",
# "connection": {"key": api_key},
# "endpoint": os.getenv(api_base),
# }
if is_serverless_connection:
connection_args = {
"connection_type": "workspace_connection",
"connection": {"id": connection_id},
Expand All @@ -309,10 +318,11 @@ def _create_mlindex_from_existing_ai_search(
embedding = EmbeddingsContainer.from_uri(embedding_model_uri, credential=None, **model_connection_args)
mlindex_config["embeddings"] = embedding.get_metadata()

path = Path.cwd() / f"import-ai_search-{ai_search_config.ai_search_index_name}-mlindex"
path = Path.cwd() / f"{name}-mlindex"

path.mkdir(exist_ok=True)
with open(path / "MLIndex", "w", encoding="utf-8") as f:
yaml.dump(mlindex_config, f)

print(f"Successfully created index at {path}")
return path
14 changes: 14 additions & 0 deletions src/promptflow-rag/promptflow/rag/_get_langchain_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from azureml.rag.mlindex import MLIndex
from promptflow.rag.constants._common import STORAGE_URI_TO_MLINDEX_PATH_FORMAT
import re


def get_langchain_retriever_from_index(path: str):
    """Load the MLIndex stored at *path* and expose it as a langchain retriever.

    :param path: Storage URI pointing at the MLIndex files.
    :type path: str
    :raises ValueError: If *path* does not match the expected storage-URI format.
    :return: A langchain retriever backed by the MLIndex.
    """
    # Validate the URI shape up front so a malformed path fails fast with a
    # clear error instead of a confusing downstream load failure.
    if re.match(STORAGE_URI_TO_MLINDEX_PATH_FORMAT, path) is None:
        raise ValueError(
            "Path to MLIndex file doesn't have the correct format."
        )
    index = MLIndex(path)
    return index.as_langchain_retriever()
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# Defines stuff related to the resulting created index, like the index type.

from typing import Optional
from promptflow.rag.constants._common import CONNECTION_ID_FORMAT
from ._connection_config import ConnectionConfig


Expand All @@ -18,13 +19,31 @@ class AzureAISearchConfig:
:type ai_search_index_name: Optional[str]
:param ai_search_connection_config: The Azure AI Search connection config.
:type ai_search_connection_config: Optional[ConnectionConfig]
:param connection_id: The connection id of the Azure AI Search resource.
:type connection_id: Optional[str]
"""

def __init__(
    self,
    *,
    ai_search_index_name: Optional[str] = None,
    ai_search_connection_config: Optional[ConnectionConfig] = None,
    connection_id: Optional[str] = None,
) -> None:
    """Capture the index name plus either a connection config or a raw connection id."""
    # Either ``connection_id`` or ``ai_search_connection_config`` may identify the
    # connection; ``get_connection_id`` decides which one wins at lookup time.
    self.connection_id = connection_id
    self.ai_search_connection_config = ai_search_connection_config
    self.ai_search_index_name = ai_search_index_name

def get_connection_id(self) -> Optional[str]:
    """Resolve the effective connection id.

    An explicitly supplied ``connection_id`` takes precedence; otherwise the
    id is derived from ``ai_search_connection_config``.

    :raises ValueError: If the explicit connection id is malformed.
    :return: The ARM connection id, or ``None`` when neither source is set.
    """
    import re

    explicit_id = self.connection_id
    if explicit_id:
        # Reject ids that do not look like ARM connection resource ids.
        if re.match(CONNECTION_ID_FORMAT, explicit_id) is None:
            raise ValueError(
                "Your connection id doesn't have the correct format"
            )
        return explicit_id
    config = self.ai_search_connection_config
    if config:
        return config.build_connection_id()
    return None
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from typing import Optional
from ._connection_config import ConnectionConfig
from promptflow.rag.constants._common import CONNECTION_ID_FORMAT


class EmbeddingsModelConfig:
Expand All @@ -17,7 +18,9 @@ class EmbeddingsModelConfig:
:param model_name: The name of the embedding model.
:type model_name: Optional[str]
:param deployment_name: The deployment_name for the embedding model.
:type deployment_name: Optional[ConnectionConfig]
:type deployment_name: Optional[str]
:param connection_id: The connection id for the embedding model.
:type connection_id: Optional[str]
:param connection_config: The connection configuration for the embedding model.
:type connection_config: Optional[ConnectionConfig]
"""
Expand All @@ -27,8 +30,24 @@ def __init__(
*,
model_name: Optional[str] = None,
deployment_name: Optional[str] = None,
connection_id: Optional[str] = None,
connection_config: Optional[ConnectionConfig] = None,
) -> None:
self.model_name = model_name
self.deployment_name = deployment_name
self.connection_id = connection_id
self.connection_config = connection_config

def get_connection_id(self) -> Optional[str]:
    """Return the connection id for the embedding model.

    A directly supplied ``connection_id`` wins over ``connection_config``;
    when neither is present the result is ``None``.

    :raises ValueError: If the supplied connection id is malformed.
    """
    import re

    candidate = self.connection_id
    if candidate:
        # Guard against ids that do not match the ARM connection id shape.
        if not re.match(CONNECTION_ID_FORMAT, candidate):
            raise ValueError(
                "Your connection id doesn't have the correct format"
            )
        return candidate
    if not self.connection_config:
        return None
    return self.connection_config.build_connection_id()
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from typing import Union
from typing import Optional, Union

from promptflow.rag.constants import IndexInputType

Expand Down Expand Up @@ -46,7 +46,7 @@ def __init__(
ai_search_embedding_key: str,
ai_search_title_key: str,
ai_search_metadata_key: str,
ai_search_connection_id: str,
ai_search_connection_id: Optional[str] = None,
num_docs_to_import: int = 50,
):
self.ai_search_index_name = ai_search_index_name
Expand Down
3 changes: 3 additions & 0 deletions src/promptflow-rag/promptflow/rag/constants/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
AZURE_AI_SEARCH_API_VERSION = "2023-07-01-preview"
OPEN_AI_PROTOCOL_TEMPLATE = "azure_open_ai://deployment/{}/model/{}"
CONNECTION_ID_TEMPLATE = "/subscriptions/{}/resourceGroups/{}/providers/Microsoft.MachineLearningServices/workspaces/{}/connections/{}" # noqa: E501
CONNECTION_ID_FORMAT = CONNECTION_ID_TEMPLATE.format(".*", ".*", ".*", ".*")
STORAGE_URI_TO_MLINDEX_PATH_TEMPLATE = "azureml://subscriptions/{}/resourcegroups/{}/workspaces/{}/datastores/{}/paths/{}" # noqa: E501
STORAGE_URI_TO_MLINDEX_PATH_FORMAT = STORAGE_URI_TO_MLINDEX_PATH_TEMPLATE.format(".*", ".*", ".*", ".*", ".*")


class IndexInputType(object):
Expand Down
Loading

0 comments on commit 0d983e2

Please sign in to comment.