
Feature/azure openai support #176

Open · wants to merge 10 commits into base: main
8 changes: 8 additions & 0 deletions README.md
@@ -76,6 +76,14 @@ You can sign up for an OpenAI account [here](https://platform.openai.com/) and m
Important: Free-tier OpenAI accounts may be subject to rate limits, which could affect AG-A's performance. We recommend using a paid OpenAI API key for seamless functionality.


#### Azure OpenAI Setup
To use Azure OpenAI, you'll need to set the following values as environment variables:
```bash
export AZURE_OPENAI_API_KEY=<...>
export OPENAI_API_VERSION=<...>
export AZURE_OPENAI_ENDPOINT=<...>
```
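
For reference, a minimal sketch (illustrative, not part of this PR) of how these variables are picked up: `langchain_openai.AzureChatOpenAI` falls back to all three when the corresponding arguments are omitted.

```python
import os

from langchain_openai import AzureChatOpenAI

# AzureChatOpenAI resolves api_key, azure_endpoint, and api_version from
# AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and OPENAI_API_VERSION
# when they are not passed explicitly.
assert "AZURE_OPENAI_ENDPOINT" in os.environ, "export the variables above first"

llm = AzureChatOpenAI(
    azure_deployment="gpt-4o",  # hypothetical deployment name; substitute your own
    temperature=0,
)
print(llm.invoke("Hello!").content)
```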

## Usage

We support two ways of using AutoGluon Assistant: WebUI and CLI.
6 changes: 4 additions & 2 deletions src/autogluon/assistant/assistant.py
@@ -8,7 +8,7 @@
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

-from autogluon.assistant.llm import AssistantChatBedrock, AssistantChatOpenAI, LLMFactory
+from autogluon.assistant.llm import AssistantAzureChatOpenAI, AssistantChatBedrock, AssistantChatOpenAI, LLMFactory

from .predictor import AutogluonTabularPredictor
from .task import TabularPredictionTask
@@ -55,7 +55,9 @@ class TabularPredictionAssistant:

    def __init__(self, config: DictConfig) -> None:
        self.config = config
-        self.llm: Union[AssistantChatOpenAI, AssistantChatBedrock] = LLMFactory.get_chat_model(config.llm)
+        self.llm: Union[AssistantChatOpenAI, AssistantAzureChatOpenAI, AssistantChatBedrock] = (
+            LLMFactory.get_chat_model(config.llm)
+        )
        self.predictor = AutogluonTabularPredictor(config.autogluon)
        self.feature_transformers_config = get_feature_transformers_config(config)

3 changes: 2 additions & 1 deletion src/autogluon/assistant/llm/__init__.py
@@ -1,6 +1,7 @@
-from .llm import AssistantChatBedrock, AssistantChatOpenAI, LLMFactory
+from .llm import AssistantAzureChatOpenAI, AssistantChatBedrock, AssistantChatOpenAI, LLMFactory

__all__ = [
+    "AssistantAzureChatOpenAI",
    "AssistantChatOpenAI",
    "AssistantChatBedrock",
    "LLMFactory",
95 changes: 87 additions & 8 deletions src/autogluon/assistant/llm/llm.py
@@ -7,9 +7,9 @@
import botocore
from langchain.schema import AIMessage, BaseMessage
from langchain_aws import ChatBedrock
-from langchain_openai import ChatOpenAI
+from langchain_openai import AzureChatOpenAI, ChatOpenAI
from omegaconf import DictConfig
-from openai import OpenAI
+from openai import AzureOpenAI, OpenAI
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_exponential

@@ -57,6 +57,45 @@ def invoke(self, *args, **kwargs):
        return response


+class AssistantAzureChatOpenAI(AzureChatOpenAI, BaseModel):
+    """
+    AssistantAzureChatOpenAI is a subclass of AzureChatOpenAI that traces the input and output of the model.
+    """
+
+    history_: List[Dict[str, Any]] = Field(default_factory=list)
+    input_: int = Field(default=0)
+    output_: int = Field(default=0)
+
+    def describe(self) -> Dict[str, Any]:
+        return {
+            "model": self.model_name,
+            "proxy": self.openai_proxy,
+            "history": self.history_,
+            "prompt_tokens": self.input_,
+            "completion_tokens": self.output_,
+        }
+
+    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10))
+    def invoke(self, *args, **kwargs):
+        input_: List[BaseMessage] = args[0]
+        response = super().invoke(*args, **kwargs)
+
+        # Update token usage
+        if isinstance(response, AIMessage) and response.usage_metadata:
+            self.input_ += response.usage_metadata.get("input_tokens", 0)
+            self.output_ += response.usage_metadata.get("output_tokens", 0)
+
+        self.history_.append(
+            {
+                "input": [{"type": msg.type, "content": msg.content} for msg in input_],
+                "output": pprint.pformat(dict(response)),
+                "prompt_tokens": self.input_,
+                "completion_tokens": self.output_,
+            }
+        )
+        return response
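
A minimal usage sketch for this wrapper (hypothetical, assuming the Azure environment variables from the README are exported):

```python
from langchain.schema import HumanMessage

# model_name is forwarded to AzureChatOpenAI, mirroring _get_azure_chat_model below.
llm = AssistantAzureChatOpenAI(model_name="gpt-4o", temperature=0, max_tokens=512)
llm.invoke([HumanMessage(content="Say hello.")])

trace = llm.describe()  # running totals across all invoke() calls
print(trace["prompt_tokens"], trace["completion_tokens"], len(trace["history"]))
```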


class AssistantChatBedrock(ChatBedrock, BaseModel):
"""
AssistantChatBedrock is a subclass of ChatBedrock that traces the input and output of the model.
@@ -123,9 +162,21 @@ def get_bedrock_models() -> List[str]:
print(f"Error fetching Bedrock models: {e}")
return []

+    @staticmethod
+    def get_azure_models() -> List[str]:
+        try:
+            client = AzureOpenAI()
+            models = client.models.list()
+            return [model.id for model in models if model.id.startswith(("gpt-3.5", "gpt-4"))]
+        except Exception as e:
+            print(f"Error fetching Azure models: {e}")
+            return []
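
Note: constructing `AzureOpenAI()` with no arguments resolves credentials from the same three environment variables documented in the README (`AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_ENDPOINT`, `OPENAI_API_VERSION`), so the model listing needs no extra configuration.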

    @classmethod
    def get_valid_models(cls, provider):
-        if provider == "openai":
+        if provider == "azure":
+            return cls.get_azure_models()
+        elif provider == "openai":
            return cls.get_openai_models()
        elif provider == "bedrock":
            model_names = cls.get_bedrock_models()
@@ -136,17 +187,41 @@ def get_valid_models(cls, provider):

    @classmethod
    def get_valid_providers(cls):
-        return ["openai", "bedrock"]
+        return ["azure", "openai", "bedrock"]

    @staticmethod
-    def _get_openai_chat_model(config: DictConfig) -> AssistantChatOpenAI:
+    def _get_azure_chat_model(
+        config: DictConfig,
+    ) -> AssistantAzureChatOpenAI:
+        if "AZURE_OPENAI_API_KEY" in os.environ:
+            api_key = os.environ["AZURE_OPENAI_API_KEY"]
+        else:
+            raise Exception("Azure API env variable AZURE_OPENAI_API_KEY not set")

if "OPENAI_API_VERSION" not in os.environ:
raise Exception("Azure API env variable OPENAI_API_VERSION not set")
if "AZURE_OPENAI_ENDPOINT" not in os.environ:
raise Exception("Azure API env variable AZURE_OPENAI_ENDPOINT not set")

logger.info(f"AGA is using model {config.model} from Azure to assist you with the task.")
return AssistantAzureChatOpenAI(
api_key=api_key,
model_name=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
verbose=config.verbose,
)

+    @staticmethod
+    def _get_openai_chat_model(
+        config: DictConfig,
+    ) -> AssistantChatOpenAI:
        if "OPENAI_API_KEY" in os.environ:
            api_key = os.environ["OPENAI_API_KEY"]
        else:
            raise Exception("OpenAI API env variable OPENAI_API_KEY not set")

        logger.info(f"AGA is using model {config.model} from OpenAI to assist you with the task.")

        return AssistantChatOpenAI(
            model_name=config.model,
            temperature=config.temperature,
@@ -172,7 +247,9 @@ def _get_bedrock_chat_model(config: DictConfig) -> AssistantChatBedrock:
        )

    @classmethod
-    def get_chat_model(cls, config: DictConfig) -> Union[AssistantChatOpenAI, AssistantChatBedrock]:
+    def get_chat_model(
+        cls, config: DictConfig
+    ) -> Union[AssistantChatOpenAI, AssistantAzureChatOpenAI, AssistantChatBedrock]:
        valid_providers = cls.get_valid_providers()
        assert config.provider in valid_providers, f"{config.provider} is not a valid provider in: {valid_providers}"
@@ -184,7 +261,9 @@ def get_chat_model(cls, config: DictConfig) -> Union[AssistantChatOpenAI, Assist
        if config.model not in WHITE_LIST_LLM:
            logger.warning(f"{config.model} is not on the white list. Our white list models include {WHITE_LIST_LLM}")

-        if config.provider == "openai":
+        if config.provider == "azure":
+            return LLMFactory._get_azure_chat_model(config)
+        elif config.provider == "openai":
            return LLMFactory._get_openai_chat_model(config)
        elif config.provider == "bedrock":
            return LLMFactory._get_bedrock_chat_model(config)
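
End to end, the factory can be driven from a config like the following (a sketch with hypothetical values; the field names match the accesses above):

```python
from omegaconf import OmegaConf

# Hypothetical config; provider="azure" dispatches to _get_azure_chat_model.
llm_config = OmegaConf.create(
    {
        "provider": "azure",
        "model": "gpt-4o-2024-08-06",
        "temperature": 0.0,
        "max_tokens": 2048,
        "verbose": False,
    }
)
llm = LLMFactory.get_chat_model(llm_config)
```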
11 changes: 8 additions & 3 deletions src/autogluon/assistant/ui/constants.py
@@ -35,13 +35,18 @@
# LLM configurations
LLM_MAPPING = {
    "Claude 3.5 with Amazon Bedrock": "anthropic.claude-3-5-sonnet-20241022-v2:0",
-    "GPT 4o": "gpt-4o-2024-08-06",
+    "GPT 4o with OpenAI": "gpt-4o-2024-08-06",
+    "GPT 4o with Azure": "gpt-4o-2024-08-06",
}

-LLM_OPTIONS = ["Claude 3.5 with Amazon Bedrock", "GPT 4o"]
+LLM_OPTIONS = ["Claude 3.5 with Amazon Bedrock", "GPT 4o with OpenAI", "GPT 4o with Azure"]

# Provider configuration
-PROVIDER_MAPPING = {"Claude 3.5 with Amazon Bedrock": "bedrock", "GPT 4o": "openai"}
+PROVIDER_MAPPING = {
+    "Claude 3.5 with Amazon Bedrock": "bedrock",
+    "GPT 4o with OpenAI": "openai",
+    "GPT 4o with Azure": "azure",
+}
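
A UI selection resolves to a (model, provider) pair through these two dictionaries, e.g. (illustrative only):

```python
choice = "GPT 4o with Azure"
model_id = LLM_MAPPING[choice]       # "gpt-4o-2024-08-06"
provider = PROVIDER_MAPPING[choice]  # "azure"
```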

INITIAL_STAGE = {
    "Task Understanding": [],