diff --git a/README.md b/README.md
index 77f5d432..a6c9c165 100644
--- a/README.md
+++ b/README.md
@@ -76,6 +76,14 @@ You can sign up for an OpenAI account [here](https://platform.openai.com/) and m
 
 Important: Free-tier OpenAI accounts may be subject to rate limits, which could affect AG-A's performance. We recommend using a paid OpenAI API key for seamless functionality.
 
+#### Azure OpenAI Setup
+To use Azure OpenAI, you'll need to set the following environment variables:
+```bash
+export AZURE_OPENAI_API_KEY=<...>
+export OPENAI_API_VERSION=<...>
+export AZURE_OPENAI_ENDPOINT=<...>
+```
+
 ## Usage
 
 We support two ways of using AutoGluon Assistant: WebUI and CLI.
diff --git a/src/autogluon/assistant/assistant.py b/src/autogluon/assistant/assistant.py
index 708b1dd6..7ecfff38 100644
--- a/src/autogluon/assistant/assistant.py
+++ b/src/autogluon/assistant/assistant.py
@@ -8,7 +8,7 @@
 from hydra.utils import instantiate
 from omegaconf import DictConfig, OmegaConf
 
-from autogluon.assistant.llm import AssistantChatBedrock, AssistantChatOpenAI, LLMFactory
+from autogluon.assistant.llm import AssistantAzureChatOpenAI, AssistantChatBedrock, AssistantChatOpenAI, LLMFactory
 
 from .predictor import AutogluonTabularPredictor
 from .task import TabularPredictionTask
@@ -55,7 +55,9 @@ class TabularPredictionAssistant:
     def __init__(self, config: DictConfig) -> None:
         self.config = config
-        self.llm: Union[AssistantChatOpenAI, AssistantChatBedrock] = LLMFactory.get_chat_model(config.llm)
+        self.llm: Union[AssistantChatOpenAI, AssistantAzureChatOpenAI, AssistantChatBedrock] = (
+            LLMFactory.get_chat_model(config.llm)
+        )
         self.predictor = AutogluonTabularPredictor(config.autogluon)
         self.feature_transformers_config = get_feature_transformers_config(config)
diff --git a/src/autogluon/assistant/llm/__init__.py b/src/autogluon/assistant/llm/__init__.py
index e4c3210d..d029dc51 100644
--- a/src/autogluon/assistant/llm/__init__.py
+++ b/src/autogluon/assistant/llm/__init__.py
@@ -1,6 +1,7 @@
-from .llm import AssistantChatBedrock, AssistantChatOpenAI, LLMFactory
+from .llm import AssistantAzureChatOpenAI, AssistantChatBedrock, AssistantChatOpenAI, LLMFactory
 
 __all__ = [
+    "AssistantAzureChatOpenAI",
     "AssistantChatOpenAI",
     "AssistantChatBedrock",
     "LLMFactory",
diff --git a/src/autogluon/assistant/llm/llm.py b/src/autogluon/assistant/llm/llm.py
index f35b80fe..f054f939 100644
--- a/src/autogluon/assistant/llm/llm.py
+++ b/src/autogluon/assistant/llm/llm.py
@@ -7,9 +7,9 @@
 import botocore
 from langchain.schema import AIMessage, BaseMessage
 from langchain_aws import ChatBedrock
-from langchain_openai import ChatOpenAI
+from langchain_openai import AzureChatOpenAI, ChatOpenAI
 from omegaconf import DictConfig
-from openai import OpenAI
+from openai import AzureOpenAI, OpenAI
 from pydantic import BaseModel, Field
 from tenacity import retry, stop_after_attempt, wait_exponential
@@ -57,6 +57,45 @@ def invoke(self, *args, **kwargs):
         return response
 
 
+class AssistantAzureChatOpenAI(AzureChatOpenAI, BaseModel):
+    """
+    AssistantAzureChatOpenAI is a subclass of AzureChatOpenAI that traces the input and output of the model.
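+    Token usage and message history are accumulated across invoke() calls and surfaced via describe().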
+ """ + + history_: List[Dict[str, Any]] = Field(default_factory=list) + input_: int = Field(default=0) + output_: int = Field(default=0) + + def describe(self) -> Dict[str, Any]: + return { + "model": self.model_name, + "proxy": self.openai_proxy, + "history": self.history_, + "prompt_tokens": self.input_, + "completion_tokens": self.output_, + } + + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10)) + def invoke(self, *args, **kwargs): + input_: List[BaseMessage] = args[0] + response = super().invoke(*args, **kwargs) + + # Update token usage + if isinstance(response, AIMessage) and response.usage_metadata: + self.input_ += response.usage_metadata.get("input_tokens", 0) + self.output_ += response.usage_metadata.get("output_tokens", 0) + + self.history_.append( + { + "input": [{"type": msg.type, "content": msg.content} for msg in input_], + "output": pprint.pformat(dict(response)), + "prompt_tokens": self.input_, + "completion_tokens": self.output_, + } + ) + return response + + class AssistantChatBedrock(ChatBedrock, BaseModel): """ AssistantChatBedrock is a subclass of ChatBedrock that traces the input and output of the model. @@ -123,9 +162,21 @@ def get_bedrock_models() -> List[str]: print(f"Error fetching Bedrock models: {e}") return [] + @staticmethod + def get_azure_models() -> List[str]: + try: + client = AzureOpenAI() + models = client.models.list() + return [model.id for model in models if model.id.startswith(("gpt-3.5", "gpt-4"))] + except Exception as e: + print(f"Error fetching Azure models: {e}") + return [] + @classmethod def get_valid_models(cls, provider): - if provider == "openai": + if provider == "azure": + return cls.get_azure_models() + elif provider == "openai": return cls.get_openai_models() elif provider == "bedrock": model_names = cls.get_bedrock_models() @@ -136,17 +187,41 @@ def get_valid_models(cls, provider): @classmethod def get_valid_providers(cls): - return ["openai", "bedrock"] + return ["azure", "openai", "bedrock"] @staticmethod - def _get_openai_chat_model(config: DictConfig) -> AssistantChatOpenAI: + def _get_azure_chat_model( + config: DictConfig, + ) -> AssistantAzureChatOpenAI: + if "AZURE_OPENAI_API_KEY" in os.environ: + api_key = os.environ["AZURE_OPENAI_API_KEY"] + else: + raise Exception("Azure API env variable AZURE_API_KEY not set") + + if "OPENAI_API_VERSION" not in os.environ: + raise Exception("Azure API env variable OPENAI_API_VERSION not set") + if "AZURE_OPENAI_ENDPOINT" not in os.environ: + raise Exception("Azure API env variable AZURE_OPENAI_ENDPOINT not set") + + logger.info(f"AGA is using model {config.model} from Azure to assist you with the task.") + return AssistantAzureChatOpenAI( + api_key=api_key, + model_name=config.model, + temperature=config.temperature, + max_tokens=config.max_tokens, + verbose=config.verbose, + ) + + @staticmethod + def _get_openai_chat_model( + config: DictConfig, + ) -> AssistantChatOpenAI: if "OPENAI_API_KEY" in os.environ: api_key = os.environ["OPENAI_API_KEY"] else: raise Exception("OpenAI API env variable OPENAI_API_KEY not set") logger.info(f"AGA is using model {config.model} from OpenAI to assist you with the task.") - return AssistantChatOpenAI( model_name=config.model, temperature=config.temperature, @@ -172,7 +247,9 @@ def _get_bedrock_chat_model(config: DictConfig) -> AssistantChatBedrock: ) @classmethod - def get_chat_model(cls, config: DictConfig) -> Union[AssistantChatOpenAI, AssistantChatBedrock]: + def get_chat_model( + cls, config: DictConfig + ) -> 
         valid_providers = cls.get_valid_providers()
         assert config.provider in valid_providers, f"{config.provider} is not a valid provider in: {valid_providers}"
 
@@ -184,7 +261,9 @@ def get_chat_model(cls, config: DictConfig) -> Union[AssistantChatOpenAI, Assist
         if config.model not in WHITE_LIST_LLM:
             logger.warning(f"{config.model} is not on the white list. Our white list models include {WHITE_LIST_LLM}")
 
-        if config.provider == "openai":
+        if config.provider == "azure":
+            return LLMFactory._get_azure_chat_model(config)
+        elif config.provider == "openai":
             return LLMFactory._get_openai_chat_model(config)
         elif config.provider == "bedrock":
             return LLMFactory._get_bedrock_chat_model(config)
diff --git a/src/autogluon/assistant/ui/constants.py b/src/autogluon/assistant/ui/constants.py
index 3008ede0..973c2afc 100644
--- a/src/autogluon/assistant/ui/constants.py
+++ b/src/autogluon/assistant/ui/constants.py
@@ -35,13 +35,18 @@
 # LLM configurations
 LLM_MAPPING = {
     "Claude 3.5 with Amazon Bedrock": "anthropic.claude-3-5-sonnet-20241022-v2:0",
-    "GPT 4o": "gpt-4o-2024-08-06",
+    "GPT 4o with OpenAI": "gpt-4o-2024-08-06",
+    "GPT 4o with Azure": "gpt-4o-2024-08-06",
 }
 
-LLM_OPTIONS = ["Claude 3.5 with Amazon Bedrock", "GPT 4o"]
+LLM_OPTIONS = ["Claude 3.5 with Amazon Bedrock", "GPT 4o with OpenAI", "GPT 4o with Azure"]
 
 # Provider configuration
-PROVIDER_MAPPING = {"Claude 3.5 with Amazon Bedrock": "bedrock", "GPT 4o": "openai"}
+PROVIDER_MAPPING = {
+    "Claude 3.5 with Amazon Bedrock": "bedrock",
+    "GPT 4o with OpenAI": "openai",
+    "GPT 4o with Azure": "azure",
+}
 
 INITIAL_STAGE = {
     "Task Understanding": [],
diff --git a/tools/configure_llms.sh b/tools/configure_llms.sh
index 1ec46b5c..83dde0fb 100644
--- a/tools/configure_llms.sh
+++ b/tools/configure_llms.sh
@@ -22,6 +22,9 @@ tmp_AWS_DEFAULT_REGION=""
 tmp_AWS_ACCESS_KEY_ID=""
 tmp_AWS_SECRET_ACCESS_KEY=""
 tmp_OPENAI_API_KEY=""
+tmp_AZURE_OPENAI_API_KEY=""
+tmp_OPENAI_API_VERSION=""
+tmp_AZURE_OPENAI_ENDPOINT=""
 
 # Function to print colored messages
 print_color() {
@@ -60,6 +63,20 @@ validate_openai_api_key() {
     return 1
 }
 
+# Function to validate Azure OpenAI endpoint
+validate_azure_endpoint() {
+    local endpoint=$1
+    [[ $endpoint =~ ^https://[a-zA-Z0-9-]+\.openai\.azure\.com/?$ ]] && return 0
+    return 1
+}
+
+# Function to validate API version
+validate_api_version() {
+    local version=$1
+    [[ $version =~ ^[0-9]{4}-[0-9]{2}-[0-9]{2}$ ]] && return 0
+    return 1
+}
+
 # Function to read existing configuration into temporary variables
 read_existing_config() {
     if [ -f "$CONFIG_FILE" ]; then
@@ -70,6 +87,9 @@
                 "AWS_ACCESS_KEY_ID") tmp_AWS_ACCESS_KEY_ID="$value" ;;
                 "AWS_SECRET_ACCESS_KEY") tmp_AWS_SECRET_ACCESS_KEY="$value" ;;
                 "OPENAI_API_KEY") tmp_OPENAI_API_KEY="$value" ;;
+                "AZURE_OPENAI_API_KEY") tmp_AZURE_OPENAI_API_KEY="$value" ;;
+                "OPENAI_API_VERSION") tmp_OPENAI_API_VERSION="$value" ;;
+                "AZURE_OPENAI_ENDPOINT") tmp_AZURE_OPENAI_ENDPOINT="$value" ;;
             esac
         fi
     done < "$CONFIG_FILE"
@@ -83,15 +103,21 @@ save_configuration() {
     # Create or truncate the config file
     echo "" > "$CONFIG_FILE" || { print_color "$RED" "Error: Cannot write to '$CONFIG_FILE'"; return 1; }
 
-    if [ "$provider" = "bedrock" ]; then
-        # Update AWS variables
-        tmp_AWS_DEFAULT_REGION="$AWS_DEFAULT_REGION"
-        tmp_AWS_ACCESS_KEY_ID="$AWS_ACCESS_KEY_ID"
-        tmp_AWS_SECRET_ACCESS_KEY="$AWS_SECRET_ACCESS_KEY"
-    else
-        # Update OpenAI variable
-        tmp_OPENAI_API_KEY="$OPENAI_API_KEY"
-    fi
+    case "$provider" in
+        "bedrock")
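+            # Update AWS variables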
"bedrock") + tmp_AWS_DEFAULT_REGION="$AWS_DEFAULT_REGION" + tmp_AWS_ACCESS_KEY_ID="$AWS_ACCESS_KEY_ID" + tmp_AWS_SECRET_ACCESS_KEY="$AWS_SECRET_ACCESS_KEY" + ;; + "openai") + tmp_OPENAI_API_KEY="$OPENAI_API_KEY" + ;; + "azure") + tmp_AZURE_OPENAI_API_KEY="$AZURE_OPENAI_API_KEY" + tmp_OPENAI_API_VERSION="$OPENAI_API_VERSION" + tmp_AZURE_OPENAI_ENDPOINT="$AZURE_OPENAI_ENDPOINT" + ;; + esac # Save all configurations if [ -n "$tmp_AWS_ACCESS_KEY_ID" ]; then @@ -104,6 +130,12 @@ save_configuration() { echo "OPENAI_API_KEY=$tmp_OPENAI_API_KEY" >> "$CONFIG_FILE" fi + if [ -n "$tmp_AZURE_OPENAI_API_KEY" ]; then + echo "AZURE_OPENAI_API_KEY=$tmp_AZURE_OPENAI_API_KEY" >> "$CONFIG_FILE" + echo "OPENAI_API_VERSION=$tmp_OPENAI_API_VERSION" >> "$CONFIG_FILE" + echo "AZURE_OPENAI_ENDPOINT=$tmp_AZURE_OPENAI_ENDPOINT" >> "$CONFIG_FILE" + fi + # Export all variables if [ -n "$tmp_AWS_ACCESS_KEY_ID" ]; then export AWS_DEFAULT_REGION="$tmp_AWS_DEFAULT_REGION" @@ -115,6 +147,12 @@ save_configuration() { export OPENAI_API_KEY="$tmp_OPENAI_API_KEY" fi + if [ -n "$tmp_AZURE_OPENAI_API_KEY" ]; then + export AZURE_OPENAI_API_KEY="$tmp_AZURE_OPENAI_API_KEY" + export OPENAI_API_VERSION="$tmp_OPENAI_API_VERSION" + export AZURE_OPENAI_ENDPOINT="$tmp_AZURE_OPENAI_ENDPOINT" + fi + # Set proper permissions chmod 600 "$CONFIG_FILE" @@ -158,6 +196,16 @@ display_config() { else print_color "$YELLOW" "OpenAI is not configured" fi + + echo + print_color "$GREEN" "Azure OpenAI Configuration:" + if [ -n "$tmp_AZURE_OPENAI_API_KEY" ]; then + echo "AZURE_OPENAI_API_KEY=********" + echo "OPENAI_API_VERSION=${tmp_OPENAI_API_VERSION}" + echo "AZURE_OPENAI_ENDPOINT=${tmp_AZURE_OPENAI_ENDPOINT}" + else + print_color "$YELLOW" "Azure OpenAI is not configured" + fi echo } @@ -177,6 +225,16 @@ display_env_vars() { echo print_color "$GREEN" "OpenAI Environment Variables:" echo "OPENAI_API_KEY=${OPENAI_API_KEY:-(not set)}" + + echo + print_color "$GREEN" "Azure OpenAI Environment Variables:" + if [ -n "$AZURE_OPENAI_API_KEY" ]; then + echo "AZURE_OPENAI_API_KEY=********" + else + echo "AZURE_OPENAI_API_KEY=(not set)" + fi + echo "OPENAI_API_VERSION=${OPENAI_API_VERSION:-(not set)}" + echo "AZURE_OPENAI_ENDPOINT=${AZURE_OPENAI_ENDPOINT:-(not set)}" echo # Compare configuration with environment variables @@ -196,6 +254,15 @@ display_env_vars() { if [ -n "$tmp_OPENAI_API_KEY" ] && [ "$tmp_OPENAI_API_KEY" != "$OPENAI_API_KEY" ]; then has_mismatch=true fi + if [ -n "$tmp_AZURE_OPENAI_API_KEY" ] && [ "$tmp_AZURE_OPENAI_API_KEY" != "$AZURE_OPENAI_API_KEY" ]; then + has_mismatch=true + fi + if [ -n "$tmp_OPENAI_API_VERSION" ] && [ "$tmp_OPENAI_API_VERSION" != "$OPENAI_API_VERSION" ]; then + has_mismatch=true + fi + if [ -n "$tmp_AZURE_OPENAI_ENDPOINT" ] && [ "$tmp_AZURE_OPENAI_ENDPOINT" != "$AZURE_OPENAI_ENDPOINT" ]; then + has_mismatch=true + fi if [ "$has_mismatch" = true ]; then print_color "$YELLOW" "Warning: Some environment variables don't match the configuration file." @@ -209,7 +276,8 @@ configure_provider() { print_color "$GREEN" "Select your LLM provider to configure:" echo "1) AWS Bedrock" echo "2) OpenAI" - echo -n "Enter your choice (1/2): " + echo "3) Azure OpenAI" + echo -n "Enter your choice (1/2/3): " read provider_choice case $provider_choice in @@ -251,6 +319,34 @@ configure_provider() { save_configuration "openai" ;; + + 3) + print_color "$BLUE" "\nConfiguring Azure OpenAI..." 
+
+            echo -n "Enter your Azure OpenAI API Key: "
+            read -s AZURE_OPENAI_API_KEY
+            echo
+
+            while true; do
+                echo -n "Enter the API version (YYYY-MM-DD format): "
+                read OPENAI_API_VERSION
+                if validate_api_version "$OPENAI_API_VERSION"; then
+                    break
+                fi
+                print_color "$RED" "Invalid API version format. Please use YYYY-MM-DD format."
+            done
+
+            while true; do
+                echo -n "Enter your Azure OpenAI endpoint (https://<resource-name>.openai.azure.com): "
+                read AZURE_OPENAI_ENDPOINT
+                if validate_azure_endpoint "$AZURE_OPENAI_ENDPOINT"; then
+                    break
+                fi
+                print_color "$RED" "Invalid endpoint format. Please enter a valid Azure OpenAI endpoint."
+            done
+
+            save_configuration "azure"
+            ;;
         *)
             print_color "$RED" "Invalid choice. Exiting."