-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
399 additions
and
251 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Import utils | ||
from utils import (build_logger, get_config) | ||
|
||
# Import misc | ||
from azure.core.credentials import AzureKeyCredential | ||
from fastapi import HTTPException, status | ||
from tenacity import retry, stop_after_attempt, wait_random_exponential | ||
import azure.ai.contentsafety as azure_cs | ||
import azure.core.exceptions as azure_exceptions | ||
|
||
|
||
###
# Init misc
###

logger = build_logger(__name__)

###
# Init Azure Content Safety
###

# Severity scores are: 0 - Safe, 2 - Low, 4 - Medium, 6 - High.
# Anything at or above this threshold is treated as moderated content.
# See: https://review.learn.microsoft.com/en-us/azure/cognitive-services/content-safety/concepts/harm-categories?branch=release-build-content-safety#severity-levels
ACS_SEVERITY_THRESHOLD = 2
# Required config values; get_config raises/fails when they are missing
ACS_API_BASE = get_config("acs", "api_base", str, required=True)
ACS_API_TOKEN = get_config("acs", "api_token", str, required=True)
# Maximum prompt length (characters) accepted by the moderation endpoint
ACS_MAX_LENGTH = get_config("acs", "max_length", int, required=True)
logger.info(f"Connected Azure Content Safety to {ACS_API_BASE}")
# Module-level singleton client shared by all ContentSafety instances
acs_client = azure_cs.ContentSafetyClient(
    ACS_API_BASE, AzureKeyCredential(ACS_API_TOKEN)
)
|
||
|
||
class ContentSafety:
    @retry(
        reraise=True,
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(multiplier=0.5, max=30),
    )
    async def is_moderated(self, prompt: str) -> bool:
        """
        Return True when *prompt* is flagged by Azure Content Safety.

        Checks the four harm categories (hate, self-harm, sexual, violence)
        and flags the prompt when any category's severity reaches
        ACS_SEVERITY_THRESHOLD.

        Raises:
            HTTPException: 400 when the prompt exceeds ACS_MAX_LENGTH.

        NOTE(review): on ClientAuthenticationError this fails open (returns
        False, i.e. "not moderated") after logging — confirm this is the
        intended trade-off for an ACS credential outage.
        """
        logger.debug(f"Checking moderation for text: {prompt}")

        if len(prompt) > ACS_MAX_LENGTH:
            logger.info(f"Message ({len(prompt)}) too long for moderation")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Message too long",
            )

        req = azure_cs.models.AnalyzeTextOptions(
            text=prompt,
            categories=[
                azure_cs.models.TextCategory.HATE,
                azure_cs.models.TextCategory.SELF_HARM,
                azure_cs.models.TextCategory.SEXUAL,
                azure_cs.models.TextCategory.VIOLENCE,
            ],
        )

        try:
            res = acs_client.analyze_text(req)
        except azure_exceptions.ClientAuthenticationError as e:
            logger.exception(e)
            return False

        # Fix: a per-category result can be None when the service does not
        # score that category; guard so a missing result never raises
        # AttributeError on `.severity`.
        is_moderated = any(
            cat is not None and cat.severity >= ACS_SEVERITY_THRESHOLD
            for cat in (
                res.hate_result,
                res.self_harm_result,
                res.sexual_result,
                res.violence_result,
            )
        )
        if is_moderated:
            logger.info(f"Message is moderated: {prompt}")
            logger.debug(f"Moderation result: {res}")

        return is_moderated
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
# Import utils | ||
from uuid import UUID | ||
from utils import (build_logger, get_config, hash_token) | ||
|
||
# Import misc | ||
from azure.identity import DefaultAzureCredential | ||
from models.user import UserModel | ||
from tenacity import retry, stop_after_attempt, wait_random_exponential | ||
from typing import Any, Dict, List, AsyncGenerator, Union | ||
import asyncio | ||
import openai | ||
|
||
|
||
###
# Init misc
###

logger = build_logger(__name__)
# NOTE(review): asyncio.get_running_loop() raises RuntimeError when no event
# loop is running at import time — this module must be imported from within a
# running loop (e.g. during ASGI app startup). TODO confirm import context.
loop = asyncio.get_running_loop()
|
||
|
||
### | ||
# Init OpenIA | ||
### | ||
|
||
async def refresh_oai_token_background():
    """
    Refresh the OpenAI token every 15 minutes, forever.

    The OpenAI SDK does not support token refresh, so we do it manually and
    pass the token to the SDK ourselves. Azure AD tokens are valid for 30
    minutes, but we refresh every 15 minutes to be safe.
    See: https://github.com/openai/openai-python/pull/350#issuecomment-1489813285
    """
    while True:
        logger.info("Refreshing OpenAI token")
        oai_cred = DefaultAzureCredential()
        oai_token = oai_cred.get_token("https://cognitiveservices.azure.com/.default")
        openai.api_key = oai_token.token
        # Sleep 15 minutes between refreshes (half the token lifetime).
        # (Previous comment incorrectly said "every 20 minutes".)
        await asyncio.sleep(15 * 60)
|
||
|
||
# Configure the OpenAI SDK for the Azure private endpoint, authenticated via
# Azure AD (api_key is populated by refresh_oai_token_background).
openai.api_base = get_config("openai", "api_base", str, required=True)
openai.api_type = "azure_ad"
openai.api_version = "2023-05-15"
# Fix: log message typo "Aure" -> "Azure"
logger.info(f"Using Azure private service ({openai.api_base})")
# Keep the AD token fresh in the background for the process lifetime
loop.create_task(refresh_oai_token_background())

OAI_GPT_DEPLOY_ID = get_config("openai", "gpt_deploy_id", str, required=True)
OAI_GPT_MAX_TOKENS = get_config("openai", "gpt_max_tokens", int, required=True)
OAI_GPT_MODEL = get_config(
    "openai", "gpt_model", str, default="gpt-3.5-turbo", required=True
)
# Fix: this line previously logged 'Using OpenAI ADA model' (copy/paste from
# the ADA block below) although it reports the GPT deployment.
logger.info(
    f'Using OpenAI GPT model "{OAI_GPT_MODEL}" ({OAI_GPT_DEPLOY_ID}) with {OAI_GPT_MAX_TOKENS} tokens max'
)

OAI_ADA_DEPLOY_ID = get_config("openai", "ada_deploy_id", str, required=True)
OAI_ADA_MAX_TOKENS = get_config("openai", "ada_max_tokens", int, required=True)
OAI_ADA_MODEL = get_config(
    "openai", "ada_model", str, default="text-embedding-ada-002", required=True
)
logger.info(
    f'Using OpenAI ADA model "{OAI_ADA_MODEL}" ({OAI_ADA_DEPLOY_ID}) with {OAI_ADA_MAX_TOKENS} tokens max'
)
|
||
|
||
class OpenAI:
    @retry(
        reraise=True,
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(multiplier=0.5, max=30),
    )
    async def vector_from_text(self, prompt: str, user_id: UUID) -> List[float]:
        """
        Embed *prompt* with the ADA deployment.

        Returns the embedding vector, or an empty list when the API rejects
        our credentials (logged, not raised).
        """
        logger.debug(f"Getting vector for text: {prompt}")
        try:
            embedding_res = openai.Embedding.create(
                deployment_id=OAI_ADA_DEPLOY_ID,
                input=prompt,
                model=OAI_ADA_MODEL,
                user=user_id.hex,
            )
        except openai.error.AuthenticationError as e:
            logger.exception(e)
            return []

        first_result = embedding_res.data[0]
        return first_result.embedding

    @retry(
        reraise=True,
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(multiplier=0.5, max=30),
    )
    async def completion(self, messages: List[Dict[str, str]], current_user: UserModel) -> Union[str, None]:
        """
        Run a single (non-streaming) chat completion over *messages*.

        Returns the assistant message content, or None when the API rejects
        our credentials (logged, not raised).
        """
        # Chat completion gives a more natural response and a lower usage
        # cost than the plain completion endpoint.
        try:
            chat_res = openai.ChatCompletion.create(
                deployment_id=OAI_GPT_DEPLOY_ID,
                messages=messages,
                model=OAI_GPT_MODEL,
                # Nudge the model toward new topics
                presence_penalty=1,
                user=hash_token(current_user.id.bytes).hex,
            )
        except openai.error.AuthenticationError as e:
            logger.exception(e)
            return None

        return chat_res["choices"][0].message.content

    @retry(
        reraise=True,
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(multiplier=0.5, max=30),
    )
    async def completion_stream(self, messages: List[Dict[str, str]], current_user: UserModel) -> AsyncGenerator[Any, None]:
        """
        Run a streaming chat completion over *messages*, yielding content
        deltas as they arrive. Yields nothing when the API rejects our
        credentials (logged, not raised).
        """
        # Chat completion gives a more natural response and a lower usage
        # cost than the plain completion endpoint.
        try:
            stream = openai.ChatCompletion.create(
                deployment_id=OAI_GPT_DEPLOY_ID,
                messages=messages,
                model=OAI_GPT_MODEL,
                # Nudge the model toward new topics
                presence_penalty=1,
                stream=True,
                user=hash_token(current_user.id.bytes).hex,
            )
        except openai.error.AuthenticationError as e:
            logger.exception(e)
            return

        for part in stream:
            delta_content = part["choices"][0].get("delta", {}).get("content")
            if delta_content is not None:
                yield delta_content
Oops, something went wrong.