feat: improve content ai safety logging and responses
mrickettsk committed Aug 4, 2024
1 parent fa7d0b6 commit 6abe546
Showing 5 changed files with 107 additions and 42 deletions.
7 changes: 7 additions & 0 deletions docs/mkdocs/aep-docs/docs/design.md
@@ -12,6 +12,13 @@ The diagram below illustrates the flow of the AEP application from the pipeline's

![AEP Application Flow](img/aep-application-flow.png)

## Prompt Injection Detection
AEP uses [Azure's AI Content Safety Service](https://azure-ai-content-safety-api-docs.developer.azure-api.net/) to detect prompt injection attempts in incoming prompt text before it is sent to OpenAI.

If a prompt contains more than 10000 characters, it is split into chunks of at most 10000 characters each, and every chunk is sent to the AI Content Safety Service for analysis.

If a prompt contains more than 100000 characters, it is automatically flagged as a possible prompt injection to protect the system from potential abuse.
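
Condensed, the policy amounts to the sketch below. This is illustrative only: the real implementation is `check_for_prompt_inj` in `src/helpers/prompts.py` (shown further down in this commit), the size is measured as the UTF-8 encoded length, and `scan_chunk` here stands in for the `call_ai_content_safety` helper.

```python
from typing import Callable

def screen_prompt(prompt: str, scan_chunk: Callable[[str], bool]) -> bool:
    """Illustrative sketch of the size policy: returns True if the prompt may proceed."""
    size = len(prompt.encode("utf-8"))  # size measured as UTF-8 encoded length
    if size > 100000:
        return False  # oversized prompts are flagged without calling the service
    # Split into chunks of at most 10000 characters and scan each one
    chunks = [prompt[i:i + 10000] for i in range(0, len(prompt), 10000)]
    return all(scan_chunk(chunk) for chunk in chunks)
```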

## AEP Costs Summary (March 31 - April 29) 💰

| Azure Resource | SKU/Size | Cost |
45 changes: 43 additions & 2 deletions src/helpers/open_ai.py
@@ -4,6 +4,7 @@
from openai import OpenAI, OpenAIError, AzureOpenAI
from logger.logger import event_logger
from middleware.metrics import update_last_request_time
import requests

def call_openai(system_prompt, user_prompt, query, model):
tic = time.perf_counter()
@@ -48,6 +49,46 @@ def call_openai(system_prompt, user_prompt, query, model):
event_logger.error(err)
return {"success": False, "error": str(err), "time": f"0.0"}


def call_ai_content_safety(prompt):
    """
    Send a single chunk of the prompt to the Azure AI Content Safety shieldPrompt
    endpoint and check it for prompt injection.
    Args:
        prompt (str): The prompt chunk to check for injection.
    Returns:
        bool: True if no attack is detected, False otherwise.
    """

url = config.azure_cs_endpoint + "/contentsafety/text:shieldPrompt?api-version=2024-02-15-preview"
headers = {
'Ocp-Apim-Subscription-Key': config.azure_cs_key,
'Content-Type': 'application/json'
}

# Use the prompt argument as the document value
data = {
"documents": [
f"{prompt}"
]
}

# Make a POST request to the AI Content Safety API
try:
response = requests.post(url, headers=headers, json=data)
# Log the response
response_json = response.json()
event_logger.info(f"Response from AI ContentSafety: {response_json}")
    except Exception as err:
        event_logger.error(f"{err}")
        event_logger.error("Failed to make request to AI Content Safety")
        return False

# Check if attackDetected is True in documentsAnalysis
try:
if response_json['documentsAnalysis'][0]['attackDetected']:
event_logger.info(f"Prompt injection Detected in: {prompt}")
return False # Fail if attackDetected is True
except Exception as err:
event_logger.error("Failed to check for prompt injection in response from AI Content Safety")
event_logger.error(f"{err}")
return False

return True
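
For reference, the shieldPrompt response parsed by `call_ai_content_safety` has (at least) the shape sketched below; the field names match what the code reads, while the values are made up for illustration.

```python
# Illustrative shieldPrompt response body, showing only the fields the helper
# above reads (the values are invented for illustration).
example_response = {
    "documentsAnalysis": [
        {"attackDetected": True}  # True causes call_ai_content_safety to return False
    ]
}
```
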
69 changes: 34 additions & 35 deletions src/helpers/prompts.py
@@ -3,12 +3,8 @@
import re
from logger.logger import event_logger
import helpers.config as config
from dotenv import load_dotenv
import helpers.open_ai as openai
from middleware.metrics import update_number_of_tokens_saved, update_perc_of_tokens_saved
import requests
import helpers.config as config

load_dotenv()

def read_prompt(prompt_type, request_id, prompt):
event_logger.info(f"Request ID: {request_id} Searching Prompt File")
@@ -89,33 +85,36 @@ def reduce_prompt_tokens(prompt):


def check_for_prompt_inj(prompt):
event_logger.debug(f"Checking for prompt injection")
url = config.azure_cs_endpoint + "/contentsafety/text:shieldPrompt?api-version=2024-02-15-preview"
event_logger.debug(f"CS Config URL: {url}")
headers = {
'Ocp-Apim-Subscription-Key': config.azure_cs_key,
'Content-Type': 'application/json'
}
data = {
# Use the prompt argument as the userPrompt value
"documents": [
f"{prompt}"
]
}
try:
response = requests.post(url, headers=headers, data=json.dumps(data))
event_logger.debug(f"Response from AI ContentSafety: {response.json()}")

# Log the response
response_json = response.json()

# Check if attackDetected is True in either userPromptAnalysis or documentsAnalysis
if response_json['documentsAnalysis'][0]['attackDetected']:
event_logger.info(f"Response from AI ContentSafety: {response.json()}")
event_logger.info(f"Prompt injection Detected in: {prompt}")
return False # Fail if attackDetected is True

except Exception as err:
event_logger.error(f"Failed to perform prompt injection detection: {err}")

return True
"""
Check for prompt injection in the given prompt.
If the prompt contains more than 10000 unicode characters, it will be split into chunks of 10000 characters.
If the prompt contains more than 100000 unicode characters, the function will return False.
Args:
prompt (str): The prompt to check for injection.
Returns:
bool: True if no prompt injection is detected, False otherwise.
"""
    # Measure the prompt size as its UTF-8 encoded length
    unicode_count = len(prompt.encode('utf-8'))

    # Reject prompts whose UTF-8 length exceeds 100000
    if unicode_count > 100000:
        event_logger.warning("Prompt contained more than 100000 characters (UTF-8 length).")
        return False

    # Split oversized prompts into chunks of at most 10000 characters
    if unicode_count > 10000:
        prompt = [prompt[i:i+10000] for i in range(0, len(prompt), 10000)]
    else:
        prompt = [prompt]

    # Check each chunk for prompt injection; fail closed if any chunk is flagged
    for chunk in prompt:
        result = openai.call_ai_content_safety(chunk)
        if not result:
            return False

    return True
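
As a quick worked example of the chunking above (for plain ASCII text, where the UTF-8 length equals the character count), a 25000-character prompt yields three chunks and therefore three content-safety calls:

```python
prompt = "a" * 25000
chunks = [prompt[i:i + 10000] for i in range(0, len(prompt), 10000)]
print([len(c) for c in chunks])  # [10000, 10000, 5000]
```
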
18 changes: 13 additions & 5 deletions src/main.py
@@ -74,21 +74,22 @@ async def create_predefined_query(request: Request, query: PredefinedQuery):
try:
prompt_dict = read_prompt(query.prompt_type, query.request_id, query.prompt)
except Exception as err:
event_logger.error(f"Request ID: {query.request_id} | Error: {err}")
event_logger.info(f"Request ID: {query.request_id} | Error: {err}")
raise HTTPException(status_code=500, detail=f"Internal Server Error. Request ID: {query.request_id}")

# Check for prompt injection
result = check_for_prompt_inj(prompt_dict["prompt"])
if not result:
event_logger.warning("Possible malicious content detected, including prompt injection.")
raise HTTPException(status_code=400, detail="Bad Request: Possible malicious content detected in prompt.")

event_logger.warning(f"Possible malicious content detected, including prompt injection in request ID: {query.request_id}")
raise HTTPException(status_code=400, detail=f"Bad Request: Possible malicious content detected in prompt. If you believe this is a mistake, reference: {query.request_id}")
event_logger.info(f"Checked for injection")

response = call_openai(prompt_dict["system"], prompt_dict["user"], prompt_dict["prompt"], prompt_dict["model"])

# If response is not successful, raise a 500 status code and log error in logs/requests.log
if not response["success"]:
error = response["error"]
request_logger.error(f"Request ID: {query.request_id} | Error: {error} | Time Elapsed: {response['time']}")
request_logger.info(f"Request ID: {query.request_id} | Error: {error} | Time Elapsed: {response['time']}")
raise HTTPException(status_code=500, detail=f"Internal Server Error. Request ID: {query.request_id} | Error: {error}")
else:
request_logger.info(f"Request ID: {query.request_id} | Successful: True | Time Elapsed: {response['time']}")
@@ -104,6 +105,13 @@ async def create_info_query(request: Request, info_query: CustomQuery):
if info_query.compression_enabled:
info_query.prompt = reduce_prompt_tokens(info_query.prompt)

# Check for prompt injection
result = check_for_prompt_inj(info_query.prompt)
if not result:
event_logger.warning(f"Possible malicious content detected, including prompt injection in request ID: {info_query.request_id}")
raise HTTPException(status_code=400, detail=f"Bad Request: Possible malicious content detected in prompt. If you believe this is a mistake, reference: {info_query.request_id}")
event_logger.info(f"Checked for injection")

# Read in custom prompt defined in the request and make call to OpenAI
response = call_openai(info_query.system_prompt, info_query.user_prompt, info_query.prompt, info_query.model)

10 changes: 10 additions & 0 deletions src/tests/test_prompt_ops.py
@@ -61,6 +61,16 @@ def test_prompt_compression(self):
print(compressed_prompt)
assert match is None

def test_prompt_injection_detection(self):
# Generate a prompt with more than 100000 unicode characters
prompt = "a"*100001

# Check for prompt injection
result = prompts.check_for_prompt_inj(prompt)

# Assert that the function returns False
assert not result
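
A companion test for the chunking path could look roughly like the sketch below, sitting alongside the test above in the same class. It assumes the content-safety call can be stubbed through the `openai` alias that `helpers/prompts.py` imports; the patch target is an assumption about how the test module resolves that import.

```python
def test_prompt_injection_chunking(self):
    # Sketch only: a 25000-character prompt should be split into three chunks,
    # each handed to the (stubbed) content-safety helper, and accepted overall.
    from unittest import mock

    prompt = "a" * 25000
    with mock.patch.object(prompts.openai, "call_ai_content_safety", return_value=True) as stub:
        result = prompts.check_for_prompt_inj(prompt)

    assert result
    assert stub.call_count == 3  # 10000 + 10000 + 5000 characters
```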

def test_retrieve_prompts_exc(self):
# rename prompt file to force an exception
os.rename('system_prompts/prompts.json', 'system_prompts/prompts.json.test')
