-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #6433 from BerriAI/litellm_fix_audit_logs
(proxy audit logs) fix serialization error on audit logs
- Loading branch information
Showing
5 changed files
with
199 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
""" | ||
Functions to create audit logs for LiteLLM Proxy | ||
""" | ||
|
||
import json | ||
|
||
import litellm | ||
from litellm._logging import verbose_proxy_logger | ||
from litellm.proxy._types import LiteLLM_AuditLogs | ||
|
||
|
||
async def create_audit_log_for_update(request_data: LiteLLM_AuditLogs): | ||
from litellm.proxy.proxy_server import premium_user, prisma_client | ||
|
||
if premium_user is not True: | ||
return | ||
|
||
if litellm.store_audit_logs is not True: | ||
return | ||
if prisma_client is None: | ||
raise Exception("prisma_client is None, no DB connected") | ||
|
||
verbose_proxy_logger.debug("creating audit log for %s", request_data) | ||
|
||
if isinstance(request_data.updated_values, dict): | ||
request_data.updated_values = json.dumps(request_data.updated_values) | ||
|
||
if isinstance(request_data.before_value, dict): | ||
request_data.before_value = json.dumps(request_data.before_value) | ||
|
||
_request_data = request_data.model_dump(exclude_none=True) | ||
|
||
try: | ||
await prisma_client.db.litellm_auditlog.create( | ||
data={ | ||
**_request_data, # type: ignore | ||
} | ||
) | ||
except Exception as e: | ||
# [Non-Blocking Exception. Do not allow blocking LLM API call] | ||
verbose_proxy_logger.error(f"Failed Creating audit log {e}") | ||
|
||
return |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
import os | ||
import sys | ||
import traceback | ||
import uuid | ||
from datetime import datetime | ||
|
||
from dotenv import load_dotenv | ||
from fastapi import Request | ||
from fastapi.routing import APIRoute | ||
|
||
|
||
import io | ||
import os | ||
import time | ||
|
||
# this file is to test litellm/proxy | ||
|
||
sys.path.insert( | ||
0, os.path.abspath("../..") | ||
) # Adds the parent directory to the system path | ||
import asyncio | ||
import logging | ||
|
||
load_dotenv() | ||
|
||
import pytest | ||
import uuid | ||
import litellm | ||
from litellm._logging import verbose_proxy_logger | ||
|
||
from litellm.proxy.proxy_server import ( | ||
LitellmUserRoles, | ||
audio_transcriptions, | ||
chat_completion, | ||
completion, | ||
embeddings, | ||
image_generation, | ||
model_list, | ||
moderations, | ||
new_end_user, | ||
user_api_key_auth, | ||
) | ||
|
||
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend | ||
|
||
verbose_proxy_logger.setLevel(level=logging.DEBUG) | ||
|
||
from starlette.datastructures import URL | ||
|
||
from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update | ||
from litellm.proxy._types import LiteLLM_AuditLogs, LitellmTableNames | ||
from litellm.caching.caching import DualCache | ||
from unittest.mock import patch, AsyncMock | ||
|
||
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) | ||
import json | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_create_audit_log_for_update_premium_user(): | ||
""" | ||
Basic unit test for create_audit_log_for_update | ||
Test that the audit log is created when a premium user updates a team | ||
""" | ||
with patch("litellm.proxy.proxy_server.premium_user", True), patch( | ||
"litellm.store_audit_logs", True | ||
), patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma: | ||
|
||
mock_prisma.db.litellm_auditlog.create = AsyncMock() | ||
|
||
request_data = LiteLLM_AuditLogs( | ||
id="test_id", | ||
updated_at=datetime.now(), | ||
changed_by="test_changed_by", | ||
action="updated", | ||
table_name=LitellmTableNames.TEAM_TABLE_NAME, | ||
object_id="test_object_id", | ||
updated_values=json.dumps({"key": "value"}), | ||
before_value=json.dumps({"old_key": "old_value"}), | ||
) | ||
|
||
await create_audit_log_for_update(request_data) | ||
|
||
mock_prisma.db.litellm_auditlog.create.assert_called_once_with( | ||
data={ | ||
"id": "test_id", | ||
"updated_at": request_data.updated_at, | ||
"changed_by": request_data.changed_by, | ||
"action": request_data.action, | ||
"table_name": request_data.table_name, | ||
"object_id": request_data.object_id, | ||
"updated_values": request_data.updated_values, | ||
"before_value": request_data.before_value, | ||
} | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def prisma_client(): | ||
from litellm.proxy.proxy_cli import append_query_params | ||
|
||
### add connection pool + pool timeout args | ||
params = {"connection_limit": 100, "pool_timeout": 60} | ||
database_url = os.getenv("DATABASE_URL") | ||
modified_url = append_query_params(database_url, params) | ||
os.environ["DATABASE_URL"] = modified_url | ||
|
||
# Assuming PrismaClient is a class that needs to be instantiated | ||
prisma_client = PrismaClient( | ||
database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj | ||
) | ||
|
||
return prisma_client | ||
|
||
|
||
@pytest.mark.asyncio() | ||
async def test_create_audit_log_in_db(prisma_client): | ||
print("prisma client=", prisma_client) | ||
|
||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | ||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | ||
setattr(litellm.proxy.proxy_server, "premium_user", True) | ||
setattr(litellm, "store_audit_logs", True) | ||
|
||
await litellm.proxy.proxy_server.prisma_client.connect() | ||
audit_log_id = f"audit_log_id_{uuid.uuid4()}" | ||
|
||
# create a audit log for /key/generate | ||
request_data = LiteLLM_AuditLogs( | ||
id=audit_log_id, | ||
updated_at=datetime.now(), | ||
changed_by="test_changed_by", | ||
action="updated", | ||
table_name=LitellmTableNames.TEAM_TABLE_NAME, | ||
object_id="test_object_id", | ||
updated_values=json.dumps({"key": "value"}), | ||
before_value=json.dumps({"old_key": "old_value"}), | ||
) | ||
|
||
await create_audit_log_for_update(request_data) | ||
|
||
await asyncio.sleep(1) | ||
|
||
# now read the last log from the db | ||
last_log = await prisma_client.db.litellm_auditlog.find_first( | ||
where={"id": audit_log_id} | ||
) | ||
|
||
assert last_log.id == audit_log_id | ||
|
||
setattr(litellm, "store_audit_logs", False) |