Feature/azure openai support #176

Open · wants to merge 10 commits into base: main
8 changes: 8 additions & 0 deletions README.md
@@ -76,6 +76,14 @@ You can sign up for an OpenAI account [here](https://platform.openai.com/) and m
Important: Free-tier OpenAI accounts may be subject to rate limits, which could affect AG-A's performance. We recommend using a paid OpenAI API key for seamless functionality.


#### Azure OpenAI Setup
To use Azure OpenAI, set the following values as environment variables:
```bash
export AZURE_OPENAI_API_KEY=<...>
export OPENAI_API_VERSION=<...>
export AZURE_OPENAI_ENDPOINT=<...>
```
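
To sanity-check that these variables are picked up, you can instantiate the underlying LangChain client directly. This is a minimal sketch, assuming `langchain-openai` is installed; `gpt-4o` is a placeholder — substitute the name of your own Azure deployment:

```python
# Quick check that the Azure OpenAI env vars above are visible to LangChain.
# "gpt-4o" is a hypothetical deployment name for illustration only.
from langchain_openai import AzureChatOpenAI

llm = AzureChatOpenAI(azure_deployment="gpt-4o")  # key, endpoint, and API version are read from the env vars
print(llm.invoke("Hello!").content)
```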

## Usage

We support two ways of using AutoGluon Assistant: WebUI and CLI.
35 changes: 27 additions & 8 deletions src/autogluon/assistant/assistant.py
@@ -8,7 +8,12 @@
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

from autogluon.assistant.llm import AssistantChatBedrock, AssistantChatOpenAI, LLMFactory
from autogluon.assistant.llm import (
    AssistantChatBedrock,
    AssistantChatOpenAI,
    AssistantAzureChatOpenAI,
    LLMFactory,
)

from .predictor import AutogluonTabularPredictor
from .task import TabularPredictionTask
@@ -31,7 +36,9 @@
def timeout(seconds: int, error_message: Optional[str] = None):
    if sys.platform == "win32":
        # Windows implementation using threading
        timer = threading.Timer(seconds, lambda: (_ for _ in ()).throw(TimeoutError(error_message)))
        timer = threading.Timer(
            seconds, lambda: (_ for _ in ()).throw(TimeoutError(error_message))
        )
        timer.start()
        try:
            yield
@@ -55,7 +62,9 @@ class TabularPredictionAssistant:

    def __init__(self, config: DictConfig) -> None:
        self.config = config
        self.llm: Union[AssistantChatOpenAI, AssistantChatBedrock] = LLMFactory.get_chat_model(config.llm)
        self.llm: Union[
            AssistantChatOpenAI, AssistantAzureChatOpenAI, AssistantChatBedrock
        ] = LLMFactory.get_chat_model(config.llm)
        self.predictor = AutogluonTabularPredictor(config.autogluon)
        self.feature_transformers_config = get_feature_transformers_config(config)

@@ -95,13 +104,19 @@ def inference_task(self, task: TabularPredictionTask) -> TabularPredictionTask:
                ):
                    task = preprocessor.transform(task)
            except Exception as e:
                self.handle_exception(f"Task inference preprocessing: {preprocessor_class}", e)
                self.handle_exception(
                    f"Task inference preprocessing: {preprocessor_class}", e
                )

        bold_start = "\033[1m"
        bold_end = "\033[0m"

        logger.info(f"{bold_start}Total number of prompt tokens:{bold_end} {self.llm.input_}")
        logger.info(f"{bold_start}Total number of completion tokens:{bold_end} {self.llm.output_}")
        logger.info(
            f"{bold_start}Total number of prompt tokens:{bold_end} {self.llm.input_}"
        )
        logger.info(
            f"{bold_start}Total number of completion tokens:{bold_end} {self.llm.output_}"
        )
        logger.info("Task understanding complete!")
        return task

@@ -111,7 +126,9 @@ def preprocess_task(self, task: TabularPredictionTask) -> TabularPredictionTask:
        task = self.inference_task(task)
        if self.feature_transformers_config:
            logger.info("Automatic feature generation starts...")
            fe_transformers = [instantiate(ft_config) for ft_config in self.feature_transformers_config]
            fe_transformers = [
                instantiate(ft_config) for ft_config in self.feature_transformers_config
            ]
            for fe_transformer in fe_transformers:
                try:
                    with timeout(
@@ -120,7 +137,9 @@ def preprocess_task(self, task: TabularPredictionTask) -> TabularPredictionTask:
                    ):
                        task = fe_transformer.fit_transform(task)
                except Exception as e:
                    self.handle_exception(f"Task preprocessing: {fe_transformer.name}", e)
                    self.handle_exception(
                        f"Task preprocessing: {fe_transformer.name}", e
                    )
            logger.info("Automatic feature generation complete!")
        else:
            logger.info("Automatic feature generation is disabled. ")
4 changes: 3 additions & 1 deletion src/autogluon/assistant/llm/__init__.py
@@ -1,7 +1,9 @@
from .llm import AssistantChatBedrock, AssistantChatOpenAI, LLMFactory
from .llm import AssistantChatBedrock, AssistantChatOpenAI, LLMFactory, AssistantAzureChatOpenAI

__all__ = [
    "AssistantAzureChatOpenAI",
    "AssistantChatOpenAI",
    "AssistantChatBedrock",
    "LLMFactory",
]
128 changes: 113 additions & 15 deletions src/autogluon/assistant/llm/llm.py
@@ -7,9 +7,9 @@
import botocore
from langchain.schema import AIMessage, BaseMessage
from langchain_aws import ChatBedrock
from langchain_openai import ChatOpenAI
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from omegaconf import DictConfig
from openai import OpenAI
from openai import OpenAI, AzureOpenAI
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_exponential

@@ -36,7 +36,50 @@ def describe(self) -> Dict[str, Any]:
"completion_tokens": self.output_,
}

    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10))
    @retry(
        stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10)
    )
    def invoke(self, *args, **kwargs):
        input_: List[BaseMessage] = args[0]
        response = super().invoke(*args, **kwargs)

        # Update token usage
        if isinstance(response, AIMessage) and response.usage_metadata:
            self.input_ += response.usage_metadata.get("input_tokens", 0)
            self.output_ += response.usage_metadata.get("output_tokens", 0)

        self.history_.append(
            {
                "input": [{"type": msg.type, "content": msg.content} for msg in input_],
                "output": pprint.pformat(dict(response)),
                "prompt_tokens": self.input_,
                "completion_tokens": self.output_,
            }
        )
        return response


class AssistantAzureChatOpenAI(AzureChatOpenAI, BaseModel):
    """
    AssistantAzureChatOpenAI is a subclass of AzureChatOpenAI that traces the input and output of the model.
    """

    history_: List[Dict[str, Any]] = Field(default_factory=list)
    input_: int = Field(default=0)
    output_: int = Field(default=0)

    def describe(self) -> Dict[str, Any]:
        return {
            "model": self.model_name,
            "proxy": self.openai_proxy,
            "history": self.history_,
            "prompt_tokens": self.input_,
            "completion_tokens": self.output_,
        }

    @retry(
        stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10)
    )
    def invoke(self, *args, **kwargs):
        input_: List[BaseMessage] = args[0]
        response = super().invoke(*args, **kwargs)
@@ -74,7 +117,9 @@ def describe(self) -> Dict[str, Any]:
"completion_tokens": self.output_,
}

    @retry(stop=stop_after_attempt(50), wait=wait_exponential(multiplier=1, min=4, max=10))
    @retry(
        stop=stop_after_attempt(50), wait=wait_exponential(multiplier=1, min=4, max=10)
    )
    def invoke(self, *args, **kwargs):
        input_: List[BaseMessage] = args[0]
        try:
@@ -107,7 +152,11 @@ def get_openai_models() -> List[str]:
        try:
            client = OpenAI()
            models = client.models.list()
            return [model.id for model in models if model.id.startswith(("gpt-3.5", "gpt-4"))]
            return [
                model.id
                for model in models
                if model.id.startswith(("gpt-3.5", "gpt-4"))
            ]
        except Exception as e:
            print(f"Error fetching OpenAI models: {e}")
            return []
@@ -123,9 +172,25 @@ def get_bedrock_models() -> List[str]:
print(f"Error fetching Bedrock models: {e}")
return []

    @staticmethod
    def get_azure_models() -> List[str]:
        try:
            client = AzureOpenAI()
            models = client.models.list()
            return [
                model.id
                for model in models
                if model.id.startswith(("gpt-3.5", "gpt-4"))
            ]
        except Exception as e:
            print(f"Error fetching Azure models: {e}")
            return []

    @classmethod
    def get_valid_models(cls, provider):
        if provider == "openai":
        if provider == "azure":
            return cls.get_azure_models()
        elif provider == "openai":
            return cls.get_openai_models()
        elif provider == "bedrock":
            model_names = cls.get_bedrock_models()
@@ -136,17 +201,40 @@ def get_valid_models(cls, provider):

    @classmethod
    def get_valid_providers(cls):
        return ["openai", "bedrock"]
        return ["azure", "openai", "bedrock"]

    @staticmethod
    def _get_openai_chat_model(config: DictConfig) -> AssistantChatOpenAI:
    def _get_azure_chat_model(
        config: DictConfig,
    ) -> AssistantAzureChatOpenAI:
        if "AZURE_OPENAI_API_KEY" in os.environ:
            api_key = os.environ["AZURE_OPENAI_API_KEY"]
        else:
            raise Exception("Azure API env variable AZURE_OPENAI_API_KEY not set")

        logger.info(
            f"AGA is using model {config.model} from Azure to assist you with the task."
        )
        return AssistantAzureChatOpenAI(
            api_key=api_key,
            model_name=config.model,
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            verbose=config.verbose,
        )

    @staticmethod
    def _get_openai_chat_model(
        config: DictConfig,
    ) -> AssistantChatOpenAI:
        if "OPENAI_API_KEY" in os.environ:
            api_key = os.environ["OPENAI_API_KEY"]
        else:
            raise Exception("OpenAI API env variable OPENAI_API_KEY not set")

        logger.info(f"AGA is using model {config.model} from OpenAI to assist you with the task.")

        logger.info(
            f"AGA is using model {config.model} from OpenAI to assist you with the task."
        )
        return AssistantChatOpenAI(
            model_name=config.model,
            temperature=config.temperature,
@@ -158,7 +246,9 @@ def _get_openai_chat_model(config: DictConfig) -> AssistantChatOpenAI:

    @staticmethod
    def _get_bedrock_chat_model(config: DictConfig) -> AssistantChatBedrock:
        logger.info(f"AGA is using model {config.model} from Bedrock to assist you with the task.")
        logger.info(
            f"AGA is using model {config.model} from Bedrock to assist you with the task."
        )

        return AssistantChatBedrock(
            model_id=config.model,
@@ -172,19 +262,27 @@ def _get_bedrock_chat_model(config: DictConfig) -> AssistantChatBedrock:
        )

    @classmethod
    def get_chat_model(cls, config: DictConfig) -> Union[AssistantChatOpenAI, AssistantChatBedrock]:
    def get_chat_model(
        cls, config: DictConfig
    ) -> Union[AssistantChatOpenAI, AssistantAzureChatOpenAI, AssistantChatBedrock]:
        valid_providers = cls.get_valid_providers()
        assert config.provider in valid_providers, f"{config.provider} is not a valid provider in: {valid_providers}"
        assert (
            config.provider in valid_providers
        ), f"{config.provider} is not a valid provider in: {valid_providers}"

        valid_models = cls.get_valid_models(config.provider)
        assert (
            config.model in valid_models
        ), f"{config.model} is not a valid model in: {valid_models} for provider {config.provider}"

        if config.model not in WHITE_LIST_LLM:
            logger.warning(f"{config.model} is not on the white list. Our white list models include {WHITE_LIST_LLM}")
            logger.warning(
                f"{config.model} is not on the white list. Our white list models include {WHITE_LIST_LLM}"
            )

        if config.provider == "openai":
        if config.provider == "azure":
            return LLMFactory._get_azure_chat_model(config)
        elif config.provider == "openai":
            return LLMFactory._get_openai_chat_model(config)
        elif config.provider == "bedrock":
            return LLMFactory._get_bedrock_chat_model(config)
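
For reviewers, a minimal sketch of exercising the new provider path (assuming the Azure env vars from the README are set, and the same config fields the OpenAI and Bedrock paths already use):

```python
from omegaconf import OmegaConf

from autogluon.assistant.llm import LLMFactory

# Hypothetical config values for illustration; the model name must match a
# deployment returned by LLMFactory.get_azure_models() in your Azure account.
cfg = OmegaConf.create(
    {
        "provider": "azure",
        "model": "gpt-4o-2024-08-06",
        "temperature": 0.0,
        "max_tokens": 512,
        "verbose": False,
    }
)
llm = LLMFactory.get_chat_model(cfg)  # dispatches to _get_azure_chat_model
print(llm.describe()["model"])
```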
8 changes: 5 additions & 3 deletions src/autogluon/assistant/ui/constants.py
@@ -35,13 +35,15 @@
# LLM configurations
LLM_MAPPING = {
    "Claude 3.5 with Amazon Bedrock": "anthropic.claude-3-5-sonnet-20241022-v2:0",
    "GPT 4o": "gpt-4o-2024-08-06",
    "GPT 4o with OpenAI": "gpt-4o-2024-08-06",
    "GPT 4o with Azure": "gpt-4o-2024-08-06",
}

LLM_OPTIONS = ["Claude 3.5 with Amazon Bedrock", "GPT 4o"]
LLM_OPTIONS = ["Claude 3.5 with Amazon Bedrock", "GPT 4o with OpenAI", "GPT 4o with Azure"]

# Provider configuration
PROVIDER_MAPPING = {"Claude 3.5 with Amazon Bedrock": "bedrock", "GPT 4o": "openai"}
PROVIDER_MAPPING = {"Claude 3.5 with Amazon Bedrock": "bedrock", "GPT 4o with OpenAI": "openai", "GPT 4o with Azure": "azure"}

INITIAL_STAGE = {
"Task Understanding": [],