Skip to content

Commit

Permalink
Merge pull request #6433 from BerriAI/litellm_fix_audit_logs
Browse files Browse the repository at this point in the history
(proxy audit logs) fix serialization error on audit logs
  • Loading branch information
ishaan-jaff authored Oct 26, 2024
2 parents c03e5da + 6f1c06f commit b3141e1
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 39 deletions.
2 changes: 1 addition & 1 deletion litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def ui_label(self):
return ui_labels.get(self.value, "")


class LitellmTableNames(enum.Enum):
class LitellmTableNames(str, enum.Enum):
"""
Enum for Table Names used by LiteLLM
"""
Expand Down
43 changes: 43 additions & 0 deletions litellm/proxy/management_helpers/audit_logs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Functions to create audit logs for LiteLLM Proxy
"""

import json

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import LiteLLM_AuditLogs


async def create_audit_log_for_update(request_data: LiteLLM_AuditLogs):
from litellm.proxy.proxy_server import premium_user, prisma_client

if premium_user is not True:
return

if litellm.store_audit_logs is not True:
return
if prisma_client is None:
raise Exception("prisma_client is None, no DB connected")

verbose_proxy_logger.debug("creating audit log for %s", request_data)

if isinstance(request_data.updated_values, dict):
request_data.updated_values = json.dumps(request_data.updated_values)

if isinstance(request_data.before_value, dict):
request_data.before_value = json.dumps(request_data.before_value)

_request_data = request_data.model_dump(exclude_none=True)

try:
await prisma_client.db.litellm_auditlog.create(
data={
**_request_data, # type: ignore
}
)
except Exception as e:
# [Non-Blocking Exception. Do not allow blocking LLM API call]
verbose_proxy_logger.error(f"Failed Creating audit log {e}")

return
2 changes: 1 addition & 1 deletion litellm/proxy/proxy_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ general_settings:
proxy_batch_write_at: 60 # Batch write spend updates every 60s

litellm_settings:
success_callback: ["langfuse"]
store_audit_logs: true

# https://docs.litellm.ai/docs/proxy/reliability#default-fallbacks
default_fallbacks: ["gpt-4o-2024-08-06", "claude-3-5-sonnet-20240620"]
Expand Down
39 changes: 2 additions & 37 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def generate_feedback_box():
)
from litellm.proxy.management_endpoints.team_endpoints import router as team_router
from litellm.proxy.management_endpoints.ui_sso import router as ui_sso_router
from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update
from litellm.proxy.openai_files_endpoints.files_endpoints import is_known_model
from litellm.proxy.openai_files_endpoints.files_endpoints import (
router as openai_files_router,
Expand Down Expand Up @@ -6398,11 +6399,7 @@ async def list_end_user(
--header 'Authorization: Bearer sk-1234'
```
"""
from litellm.proxy.proxy_server import (
create_audit_log_for_update,
litellm_proxy_admin_name,
prisma_client,
)
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client

if (
user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN
Expand Down Expand Up @@ -6433,38 +6430,6 @@ async def list_end_user(
return returned_response


async def create_audit_log_for_update(request_data: LiteLLM_AuditLogs):
if premium_user is not True:
return

if litellm.store_audit_logs is not True:
return
if prisma_client is None:
raise Exception("prisma_client is None, no DB connected")

verbose_proxy_logger.debug("creating audit log for %s", request_data)

if isinstance(request_data.updated_values, dict):
request_data.updated_values = json.dumps(request_data.updated_values)

if isinstance(request_data.before_value, dict):
request_data.before_value = json.dumps(request_data.before_value)

_request_data = request_data.dict(exclude_none=True)

try:
await prisma_client.db.litellm_auditlog.create(
data={
**_request_data, # type: ignore
}
)
except Exception as e:
# [Non-Blocking Exception. Do not allow blocking LLM API call]
verbose_proxy_logger.error(f"Failed Creating audit log {e}")

return


#### BUDGET TABLE MANAGEMENT ####


Expand Down
152 changes: 152 additions & 0 deletions tests/local_testing/test_audit_logs_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import os
import sys
import traceback
import uuid
from datetime import datetime

from dotenv import load_dotenv
from fastapi import Request
from fastapi.routing import APIRoute


import io
import os
import time

# this file is to test litellm/proxy

sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import asyncio
import logging

load_dotenv()

import pytest
import uuid
import litellm
from litellm._logging import verbose_proxy_logger

from litellm.proxy.proxy_server import (
LitellmUserRoles,
audio_transcriptions,
chat_completion,
completion,
embeddings,
image_generation,
model_list,
moderations,
new_end_user,
user_api_key_auth,
)

from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend

verbose_proxy_logger.setLevel(level=logging.DEBUG)

from starlette.datastructures import URL

from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update
from litellm.proxy._types import LiteLLM_AuditLogs, LitellmTableNames
from litellm.caching.caching import DualCache
from unittest.mock import patch, AsyncMock

proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
import json


@pytest.mark.asyncio
async def test_create_audit_log_for_update_premium_user():
"""
Basic unit test for create_audit_log_for_update
Test that the audit log is created when a premium user updates a team
"""
with patch("litellm.proxy.proxy_server.premium_user", True), patch(
"litellm.store_audit_logs", True
), patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:

mock_prisma.db.litellm_auditlog.create = AsyncMock()

request_data = LiteLLM_AuditLogs(
id="test_id",
updated_at=datetime.now(),
changed_by="test_changed_by",
action="updated",
table_name=LitellmTableNames.TEAM_TABLE_NAME,
object_id="test_object_id",
updated_values=json.dumps({"key": "value"}),
before_value=json.dumps({"old_key": "old_value"}),
)

await create_audit_log_for_update(request_data)

mock_prisma.db.litellm_auditlog.create.assert_called_once_with(
data={
"id": "test_id",
"updated_at": request_data.updated_at,
"changed_by": request_data.changed_by,
"action": request_data.action,
"table_name": request_data.table_name,
"object_id": request_data.object_id,
"updated_values": request_data.updated_values,
"before_value": request_data.before_value,
}
)


@pytest.fixture
def prisma_client():
from litellm.proxy.proxy_cli import append_query_params

### add connection pool + pool timeout args
params = {"connection_limit": 100, "pool_timeout": 60}
database_url = os.getenv("DATABASE_URL")
modified_url = append_query_params(database_url, params)
os.environ["DATABASE_URL"] = modified_url

# Assuming PrismaClient is a class that needs to be instantiated
prisma_client = PrismaClient(
database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
)

return prisma_client


@pytest.mark.asyncio()
async def test_create_audit_log_in_db(prisma_client):
print("prisma client=", prisma_client)

setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "premium_user", True)
setattr(litellm, "store_audit_logs", True)

await litellm.proxy.proxy_server.prisma_client.connect()
audit_log_id = f"audit_log_id_{uuid.uuid4()}"

# create a audit log for /key/generate
request_data = LiteLLM_AuditLogs(
id=audit_log_id,
updated_at=datetime.now(),
changed_by="test_changed_by",
action="updated",
table_name=LitellmTableNames.TEAM_TABLE_NAME,
object_id="test_object_id",
updated_values=json.dumps({"key": "value"}),
before_value=json.dumps({"old_key": "old_value"}),
)

await create_audit_log_for_update(request_data)

await asyncio.sleep(1)

# now read the last log from the db
last_log = await prisma_client.db.litellm_auditlog.find_first(
where={"id": audit_log_id}
)

assert last_log.id == audit_log_id

setattr(litellm, "store_audit_logs", False)

0 comments on commit b3141e1

Please sign in to comment.