diff --git a/.github/workflows/run-test-suite.yml b/.github/workflows/run-test-suite.yml
index f083755..5a2f308 100644
--- a/.github/workflows/run-test-suite.yml
+++ b/.github/workflows/run-test-suite.yml
@@ -25,13 +25,12 @@ jobs:
       OPENAI_API_KEY: ${{ secrets.TEST_AZURE_OPENAI_KEY }}
       AZURE_CS_ENDPOINT: ${{ secrets.TEST_AZURE_CS_ENDPOINT }}
       AZURE_CS_KEY: ${{ secrets.TEST_AZURE_CS_KEY }}
-      SYSTEM_PROMPT_FILE: "system_prompts/prompts.json"
+      azure_openai_api_version: "2023-12-01-preview"
       SYSTEM_API_KEY: "system"
       OPENAI_API_TYPE: "azure"
 
     steps:
-
      - name: Checkout repository
        uses: actions/checkout@v2

@@ -40,23 +39,15 @@ jobs:
        with:
          python-version: ${{ matrix.python-version }}

+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          if [ -f src/requirements.txt ]; then pip install -r src/requirements.txt; fi
+
      - name: Run tests
+        working-directory: src/
        run: |
-          pwd
-          echo "AZURE_VAULT_ID=${AZURE_VAULT_ID}" >> .env
-          echo "AZURE_CLIENT_ID=${AZURE_CLIENT_ID}" >> .env
-          echo "AZURE_TENANT_ID=${AZURE_TENANT_ID}" >> .env
-          echo "AZURE_CLIENT_SECRET=${AZURE_CLIENT_SECRET}" >> .env
-          echo "AZURE_OPENAI_ENDPOINT=${AZURE_OPENAI_ENDPOINT}" >> .env
-          echo "AZURE_OPENAI_KEY=${AZURE_OPENAI_KEY}" >> .env
-          echo "OPENAI_API_KEY=${OPENAI_API_KEY}" >> .env
-          echo "AZURE_CS_ENDPOINT=${AZURE_CS_ENDPOINT}" >> .env
-          echo "AZURE_CS_KEY=${AZURE_CS_KEY}" >> .env
-          echo "SYSTEM_PROMPT_FILE=${SYSTEM_PROMPT_FILE}" >> .env
-          echo "azure_openai_api_version=${azure_openai_api_version}" >> .env
-          echo "SYSTEM_API_KEY=${SYSTEM_API_KEY}" >> .env
-          echo "OPENAI_API_TYPE=${OPENAI_API_TYPE}" >> .env
-          make unittest
+          pytest tests/ -vv -s --junit-xml=test-results.xml

      - name: Surface failing tests
        if: always()

@@ -75,7 +66,7 @@ jobs:
          display-options: fEX

          # (Optional) Fail the workflow if no JUnit XML was found.
-          fail-on-empty: false
+          fail-on-empty: true

          # (Optional) Title of the test results section in the workflow summary
          title: AEP Test Results

diff --git a/docs/mkdocs/aep-docs/docs/design.md b/docs/mkdocs/aep-docs/docs/design.md
index fb2f49e..7435bd0 100644
--- a/docs/mkdocs/aep-docs/docs/design.md
+++ b/docs/mkdocs/aep-docs/docs/design.md
@@ -12,6 +12,13 @@ The diagram below illustrates the flow of the AEP application from the pipeline'
 ![AEP Application Flow](img/aep-application-flow.png)
 
+## Prompt Injection Detection
+AEP uses [Azure's AI Content Safety Service](https://azure-ai-content-safety-api-docs.developer.azure-api.net/) to detect prompt injection attacks in incoming prompts before they are sent to the model.
+
+If the prompt contains more than 10000 characters, it is split into chunks of at most 10000 characters. Each chunk is then sent to the AI Content Safety Service for analysis.
+
+If the prompt contains more than 100000 characters, it is automatically flagged as a possible prompt injection to protect the system from potential abuse.
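+
+The behaviour described above amounts to the sketch below (illustrative only; `is_prompt_safe` and `analyse_chunk` are hypothetical names, and the actual implementation in `check_for_prompt_inj` in `src/helpers/prompts.py` measures the prompt by its UTF-8 encoded length rather than its character count):
+
+```python
+MAX_PROMPT_SIZE = 100000  # anything larger is flagged as a possible prompt injection
+CHUNK_SIZE = 10000        # upper bound on the size of a single Content Safety request
+
+
+def is_prompt_safe(prompt, analyse_chunk):
+    """Return False if the prompt is flagged, True otherwise (sketch)."""
+    if len(prompt) > MAX_PROMPT_SIZE:
+        return False  # too large: automatically treated as a possible injection
+    # Split into chunks of at most CHUNK_SIZE characters and analyse each one
+    chunks = [prompt[i:i + CHUNK_SIZE] for i in range(0, len(prompt), CHUNK_SIZE)]
+    return all(analyse_chunk(chunk) for chunk in chunks)
+```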
+
 ## AEP Costs Summary (March 31 - April 29) 💰
 
 | Azure Resource | SKU/Size | Cost |

diff --git a/src/helpers/open_ai.py b/src/helpers/open_ai.py
index f753047..7dcd8ce 100644
--- a/src/helpers/open_ai.py
+++ b/src/helpers/open_ai.py
@@ -4,6 +4,7 @@
 from openai import OpenAI, OpenAIError, AzureOpenAI
 from logger.logger import event_logger
 from middleware.metrics import update_last_request_time
+import requests
 
 def call_openai(system_prompt, user_prompt, query, model):
     tic = time.perf_counter()
@@ -48,6 +49,46 @@ def call_openai(system_prompt, user_prompt, query, model):
         event_logger.error(err)
         return {"success": False, "error": str(err), "time": f"0.0"}
 
+
+def call_ai_content_safety(prompt):
+    """
+    Check for prompt injection in a single chunk of the prompt.
+    Args:
+        prompt (str): The chunk of text to check for injection.
+    """
+
+    url = config.azure_cs_endpoint + "/contentsafety/text:shieldPrompt?api-version=2024-02-15-preview"
+    headers = {
+        'Ocp-Apim-Subscription-Key': config.azure_cs_key,
+        'Content-Type': 'application/json'
+    }
+
+    # Use the prompt argument as the document value
+    data = {
+        "documents": [
+            f"{prompt}"
+        ]
+    }
+
+    # Make a POST request to the AI Content Safety API
+    try:
+        response = requests.post(url, headers=headers, json=data)
+        # Log the response
+        response_json = response.json()
+        event_logger.info(f"Response from AI ContentSafety: {response_json}")
     except Exception as err:
-        event_logger.error(err)
-        return {"success": False, "error": str(err), "time": f"0.0"}
+        event_logger.error(f"{err}")
+        event_logger.error("Failed to make request to AI Content Safety")
+        return False
+
+    # Check if attackDetected is True in documentsAnalysis
+    try:
+        if response_json['documentsAnalysis'][0]['attackDetected']:
+            event_logger.info(f"Prompt injection Detected in: {prompt}")
+            return False # Fail if attackDetected is True
+    except Exception as err:
+        event_logger.error("Failed to check for prompt injection in response from AI Content Safety")
+        event_logger.error(f"{err}")
+        return False
+
+    return True
\ No newline at end of file

diff --git a/src/helpers/prompts.py b/src/helpers/prompts.py
index 12fe03a..e074909 100644
--- a/src/helpers/prompts.py
+++ b/src/helpers/prompts.py
@@ -3,12 +3,8 @@
 import re
 from logger.logger import event_logger
 import helpers.config as config
-from dotenv import load_dotenv
+import helpers.open_ai as openai
 from middleware.metrics import update_number_of_tokens_saved, update_perc_of_tokens_saved
-import requests
-import helpers.config as config
-
-load_dotenv()
 
 def read_prompt(prompt_type, request_id, prompt):
     event_logger.info(f"Request ID: {request_id} Searching Prompt File")
@@ -89,33 +85,36 @@ def reduce_prompt_tokens(prompt):
 
 
 def check_for_prompt_inj(prompt):
-    event_logger.debug(f"Checking for prompt injection")
-    url = config.azure_cs_endpoint + "/contentsafety/text:shieldPrompt?api-version=2024-02-15-preview"
-    event_logger.debug(f"CS Config URL: {url}")
-    headers = {
-        'Ocp-Apim-Subscription-Key': config.azure_cs_key,
-        'Content-Type': 'application/json'
-    }
-    data = {
-        # Use the prompt argument as the userPrompt value
-        "documents": [
-            f"{prompt}"
-        ]
-    }
-    try:
-        response = requests.post(url, headers=headers, data=json.dumps(data))
-        event_logger.debug(f"Response from AI ContentSafety: {response.json()}")
-
-        # Log the response
-        response_json = response.json()
-
-        # Check if attackDetected is True in either userPromptAnalysis or documentsAnalysis
-        if response_json['documentsAnalysis'][0]['attackDetected']:
-            event_logger.info(f"Response from AI ContentSafety: {response.json()}")
-            event_logger.info(f"Prompt injection Detected in: {prompt}")
-            return False # Fail if attackDetected is True
-
-    except Exception as err:
-        event_logger.error(f"Failed to perform prompt injection detection: {err}")
-
-    return True
+    """
+    Check for prompt injection in the given prompt.
+
+    If the prompt contains more than 10000 unicode characters, it will be split into chunks of 10000 characters.
+
+    If the prompt contains more than 100000 unicode characters, the function will return False.
+
+    Args:
+        prompt (str): The prompt to check for injection.
+    Returns:
+        bool: True if no prompt injection is detected, False otherwise.
+    """
+    # Calculate number of unicode characters in the prompt
+    unicode_count = len(prompt.encode('utf-8'))
+
+    # If the number of unicode characters exceeds 100000 then return False
+    if unicode_count > 100000:
+        event_logger.warning("Prompt contained more than 100000 unicode characters.")
+        return False
+
+    # If the number of unicode characters exceeds 10000 then split the prompt into chunks of 10000 characters
+    if unicode_count > 10000:
+        prompt = [prompt[i:i+10000] for i in range(0, len(prompt), 10000)]
+    else:
+        prompt = [prompt]
+
+    # Check each chunk for prompt injection; fail as soon as any chunk is flagged
+    for chunk in prompt:
+        result = openai.call_ai_content_safety(chunk)
+        if not result:
+            return False
+
+    return True
\ No newline at end of file

diff --git a/src/main.py b/src/main.py
index 1024369..fb4f59d 100644
--- a/src/main.py
+++ b/src/main.py
@@ -45,10 +45,12 @@ async def app_lifespan(app: FastAPI):
 app = FastAPI(lifespan=app_lifespan)
 
+# Instrument application with Prometheus metrics
+metric_instrumentation = metrics.begin_instrumentation()
+metric_instrumentation.instrument(app).expose(app)
+
 # Apply the authentication middleware to the app
 app.middleware('http')(authenticate)
 
-# Instrument application with Prometheus metrics
-metrics.begin_instrumentation(app)
 
 # GET - API Root
 @app.get("/")
@@ -74,21 +76,22 @@ async def create_predefined_query(request: Request, query: PredefinedQuery):
     try:
         prompt_dict = read_prompt(query.prompt_type, query.request_id, query.prompt)
     except Exception as err:
-        event_logger.error(f"Request ID: {query.request_id} | Error: {err}")
+        event_logger.info(f"Request ID: {query.request_id} | Error: {err}")
         raise HTTPException(status_code=500, detail=f"Internal Server Error. Request ID: {query.request_id}")
 
     # Check for prompt injection
     result = check_for_prompt_inj(prompt_dict["prompt"])
     if not result:
-        event_logger.warning("Possible malicious content detected, including prompt injection.")
-        raise HTTPException(status_code=400, detail="Bad Request: Possible malicious content detected in prompt.")
-
+        event_logger.warning(f"Possible malicious content detected, including prompt injection in request ID: {query.request_id}")
+        raise HTTPException(status_code=400, detail=f"Bad Request: Possible malicious content detected in prompt. If you believe this is a mistake, reference: {query.request_id}")
+    event_logger.info(f"Checked for injection")
+
     response = call_openai(prompt_dict["system"], prompt_dict["user"], prompt_dict["prompt"], prompt_dict["model"])
 
     # If response is not successful, raise a 500 status code and log error in logs/requests.log
     if not response["success"]:
         error = response["error"]
-        request_logger.error(f"Request ID: {query.request_id} | Error: {error} | Time Elapsed: {response['time']}")
+        request_logger.info(f"Request ID: {query.request_id} | Error: {error} | Time Elapsed: {response['time']}")
         raise HTTPException(status_code=500, detail=f"Internal Server Error. Request ID: {query.request_id} | Error: {error}")
     else:
         request_logger.info(f"Request ID: {query.request_id} | Successful: True | Time Elapsed: {response['time']}")
@@ -103,6 +106,13 @@ async def create_info_query(request: Request, info_query: CustomQuery):
     # If compression is enabled, reduce the prompt tokens
     if info_query.compression_enabled:
         info_query.prompt = reduce_prompt_tokens(info_query.prompt)
+
+    # Check for prompt injection
+    result = check_for_prompt_inj(info_query.prompt)
+    if not result:
+        event_logger.warning(f"Possible malicious content detected, including prompt injection in request ID: {info_query.request_id}")
+        raise HTTPException(status_code=400, detail=f"Bad Request: Possible malicious content detected in prompt. If you believe this is a mistake, reference: {info_query.request_id}")
+    event_logger.info(f"Checked for injection")
 
     # Read in custom prompt defined in the request and make call to OpenAI
     response = call_openai(info_query.system_prompt, info_query.user_prompt, info_query.prompt, info_query.model)
@@ -120,3 +130,5 @@ async def create_info_query(request: Request, info_query: CustomQuery):
 @app.get("/healthcheck")
 def healthcheck():
     return {"status": "healthy"}
+
+

diff --git a/src/tests/test_prompt_ops.py b/src/tests/test_prompt_ops.py
index 9946065..357f7c0 100644
--- a/src/tests/test_prompt_ops.py
+++ b/src/tests/test_prompt_ops.py
@@ -61,6 +61,16 @@ def test_prompt_compression(self):
         print(compressed_prompt)
         assert match is None
 
+    def test_prompt_injection_detection(self):
+        # Generate a prompt with more than 100000 unicode characters
+        prompt = "a"*100001
+
+        # Check for prompt injection
+        result = prompts.check_for_prompt_inj(prompt)
+
+        # Assert that the function returns False
+        assert not result
+
     def test_retrieve_prompts_exc(self):
         # rename prompt file to force an exception
         os.rename('system_prompts/prompts.json', 'system_prompts/prompts.json.test')
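
The new test above covers the oversized-prompt path. A sketch of how the chunking path could additionally be unit-tested with the Content Safety call mocked out (hypothetical test, not taken from the repository; it assumes pytest collects standalone functions alongside the existing class-based tests and that patching `helpers.open_ai.call_ai_content_safety` is picked up through the `helpers.open_ai` import in `helpers/prompts.py`):

```python
from unittest.mock import patch

import helpers.prompts as prompts


def test_prompt_injection_chunking():
    # 25000 ASCII characters: over the 10000 limit, so three chunks are expected
    prompt = "a" * 25000

    # Pretend the Content Safety service finds nothing suspicious in any chunk
    with patch("helpers.open_ai.call_ai_content_safety", return_value=True) as mock_call:
        result = prompts.check_for_prompt_inj(prompt)

    # No injection detected, and one Content Safety call per 10000-character chunk
    assert result is True
    assert mock_call.call_count == 3
```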