Skip to content

Commit

Permalink
Instrument acreate methods for OpenAI (#935)
Browse files Browse the repository at this point in the history
* Instrument async acreate methods for OpenAI

* Remove duplicated vendor

* Re-use expected & input payloads in tests
  • Loading branch information
hmstepanek authored Oct 18, 2023
1 parent 1d86430 commit 51e7362
Show file tree
Hide file tree
Showing 4 changed files with 300 additions and 63 deletions.
211 changes: 196 additions & 15 deletions newrelic/hooks/mlmodel_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@
def wrap_embedding_create(wrapped, instance, args, kwargs):
transaction = current_transaction()
if not transaction:
return
return wrapped(*args, **kwargs)

ft_name = callable_name(wrapped)
with FunctionTrace(ft_name) as ft:
response = wrapped(*args, **kwargs)

if not response:
return
return response

available_metadata = get_trace_linking_metadata()
span_id = available_metadata.get("span.id", "")
Expand Down Expand Up @@ -94,14 +94,14 @@ def wrap_chat_completion_create(wrapped, instance, args, kwargs):
transaction = current_transaction()

if not transaction:
return
return wrapped(*args, **kwargs)

ft_name = callable_name(wrapped)
with FunctionTrace(ft_name) as ft:
response = wrapped(*args, **kwargs)

if not response:
return
return response

custom_attrs_dict = transaction._custom_params
conversation_id = custom_attrs_dict.get("conversation_id", "")
Expand All @@ -113,10 +113,15 @@ def wrap_chat_completion_create(wrapped, instance, args, kwargs):

response_headers = getattr(response, "_nr_response_headers", None)
response_model = response.get("model", "")
response_id = response.get("id", "")
settings = transaction.settings if transaction.settings is not None else global_settings()
response_id = response.get("id")
request_id = response_headers.get("x-request-id", "")

api_key = getattr(response, "api_key", None)
response_usage = response.get("usage", {})
settings = transaction.settings if transaction.settings is not None else global_settings()

messages = kwargs.get("messages", [])
choices = response.get("choices", [])

chat_completion_summary_dict = {
"id": chat_completion_id,
Expand All @@ -126,18 +131,18 @@ def wrap_chat_completion_create(wrapped, instance, args, kwargs):
"trace_id": trace_id,
"transaction_id": transaction._transaction_id,
"request_id": request_id,
"api_key_last_four_digits": f"sk-{response.api_key[-4:]}",
"api_key_last_four_digits": f"sk-{api_key[-4:]}" if api_key else "",
"duration": ft.duration,
"request.model": kwargs.get("model") or kwargs.get("engine") or "",
"response.model": response_model,
"response.organization": response.organization,
"response.organization": getattr(response, "organization", ""),
"response.usage.completion_tokens": response_usage.get("completion_tokens", "") if any(response_usage) else "",
"response.usage.total_tokens": response_usage.get("total_tokens", "") if any(response_usage) else "",
"response.usage.prompt_tokens": response_usage.get("prompt_tokens", "") if any(response_usage) else "",
"request.temperature": kwargs.get("temperature", ""),
"request.max_tokens": kwargs.get("max_tokens", ""),
"response.choices.finish_reason": response.choices[0].finish_reason,
"response.api_type": response.api_type,
"response.choices.finish_reason": choices[0].finish_reason if choices else "",
"response.api_type": getattr(response, "api_type", ""),
"response.headers.llmVersion": response_headers.get("openai-version", ""),
"response.headers.ratelimitLimitRequests": check_rate_limit_header(
response_headers, "x-ratelimit-limit-requests", True
Expand All @@ -158,11 +163,13 @@ def wrap_chat_completion_create(wrapped, instance, args, kwargs):
response_headers, "x-ratelimit-remaining-requests", True
),
"vendor": "openAI",
"response.number_of_messages": len(kwargs.get("messages", [])) + len(response.choices),
"response.number_of_messages": len(messages) + len(choices),
}

transaction.record_ml_event("LlmChatCompletionSummary", chat_completion_summary_dict)
message_list = list(kwargs.get("messages", [])) + [response.choices[0].message]
message_list = list(messages)
if choices:
message_list.extend([choices[0].message])

create_chat_completion_message_event(
transaction,
Expand Down Expand Up @@ -208,9 +215,6 @@ def create_chat_completion_message_event(
request_id,
conversation_id,
):
if not transaction:
return

for index, message in enumerate(message_list):
chat_completion_message_dict = {
"id": "%s-%s" % (response_id, index),
Expand All @@ -230,6 +234,179 @@ def create_chat_completion_message_event(
transaction.record_ml_event("LlmChatCompletionMessage", chat_completion_message_dict)


async def wrap_embedding_acreate(wrapped, instance, args, kwargs):
    """Instrument ``openai.Embedding.acreate`` and record an LlmEmbedding event.

    Awaits the wrapped coroutine inside a FunctionTrace. When both a current
    transaction and a truthy response exist, records an ``LlmEmbedding`` ML
    event describing the request and response. The response object is always
    returned unchanged so callers are unaffected.

    Parameters follow the wrapt wrapper convention:
        wrapped  -- the original (async) function being instrumented
        instance -- bound instance (unused)
        args, kwargs -- positional/keyword arguments forwarded to ``wrapped``
    """
    transaction = current_transaction()
    if not transaction:
        # Outside a transaction there is nothing to record the event against.
        return await wrapped(*args, **kwargs)

    ft_name = callable_name(wrapped)
    with FunctionTrace(ft_name) as ft:
        response = await wrapped(*args, **kwargs)

    if not response:
        return response

    embedding_id = str(uuid.uuid4())
    # _nr_response_headers is attached by other instrumentation and may be
    # absent or None; fall back to an empty dict so the .get() lookups below
    # cannot raise AttributeError.
    response_headers = getattr(response, "_nr_response_headers", None) or {}

    settings = transaction.settings if transaction.settings is not None else global_settings()
    available_metadata = get_trace_linking_metadata()
    span_id = available_metadata.get("span.id", "")
    trace_id = available_metadata.get("trace.id", "")

    api_key = getattr(response, "api_key", None)
    usage = response.get("usage")
    # Token counts default to "" (not 0) to match the agent's attribute style.
    total_tokens = ""
    prompt_tokens = ""
    if usage:
        total_tokens = usage.get("total_tokens", "")
        prompt_tokens = usage.get("prompt_tokens", "")

    embedding_dict = {
        "id": embedding_id,
        "duration": ft.duration,
        "api_key_last_four_digits": f"sk-{api_key[-4:]}" if api_key else "",
        "request_id": response_headers.get("x-request-id", ""),
        "input": kwargs.get("input", ""),
        "response.api_type": getattr(response, "api_type", ""),
        "response.organization": getattr(response, "organization", ""),
        "request.model": kwargs.get("model") or kwargs.get("engine") or "",
        "response.model": response.get("model", ""),
        "appName": settings.app_name,
        "trace_id": trace_id,
        "transaction_id": transaction._transaction_id,
        "span_id": span_id,
        "response.usage.total_tokens": total_tokens,
        "response.usage.prompt_tokens": prompt_tokens,
        "response.headers.llmVersion": response_headers.get("openai-version", ""),
        "response.headers.ratelimitLimitRequests": check_rate_limit_header(
            response_headers, "x-ratelimit-limit-requests", True
        ),
        "response.headers.ratelimitLimitTokens": check_rate_limit_header(
            response_headers, "x-ratelimit-limit-tokens", True
        ),
        "response.headers.ratelimitResetTokens": check_rate_limit_header(
            response_headers, "x-ratelimit-reset-tokens", False
        ),
        "response.headers.ratelimitResetRequests": check_rate_limit_header(
            response_headers, "x-ratelimit-reset-requests", False
        ),
        "response.headers.ratelimitRemainingTokens": check_rate_limit_header(
            response_headers, "x-ratelimit-remaining-tokens", True
        ),
        "response.headers.ratelimitRemainingRequests": check_rate_limit_header(
            response_headers, "x-ratelimit-remaining-requests", True
        ),
        "vendor": "openAI",
    }

    transaction.record_ml_event("LlmEmbedding", embedding_dict)
    return response


async def wrap_chat_completion_acreate(wrapped, instance, args, kwargs):
    """Instrument ``openai.ChatCompletion.acreate`` and record chat events.

    Awaits the wrapped coroutine inside a FunctionTrace. When both a current
    transaction and a truthy response exist, records an
    ``LlmChatCompletionSummary`` ML event plus one ``LlmChatCompletionMessage``
    event per message (via create_chat_completion_message_event). The response
    object is always returned unchanged so callers are unaffected.

    Parameters follow the wrapt wrapper convention:
        wrapped  -- the original (async) function being instrumented
        instance -- bound instance (unused)
        args, kwargs -- positional/keyword arguments forwarded to ``wrapped``
    """
    transaction = current_transaction()

    if not transaction:
        # Outside a transaction there is nothing to record the events against.
        return await wrapped(*args, **kwargs)

    ft_name = callable_name(wrapped)
    with FunctionTrace(ft_name) as ft:
        response = await wrapped(*args, **kwargs)

    if not response:
        return response

    conversation_id = transaction._custom_params.get("conversation_id", "")

    chat_completion_id = str(uuid.uuid4())
    available_metadata = get_trace_linking_metadata()
    span_id = available_metadata.get("span.id", "")
    trace_id = available_metadata.get("trace.id", "")

    # _nr_response_headers is attached by other instrumentation and may be
    # absent or None; fall back to an empty dict so the .get() lookups below
    # cannot raise AttributeError.
    response_headers = getattr(response, "_nr_response_headers", None) or {}
    response_model = response.get("model", "")
    settings = transaction.settings if transaction.settings is not None else global_settings()
    response_id = response.get("id")
    request_id = response_headers.get("x-request-id", "")

    api_key = getattr(response, "api_key", None)
    usage = response.get("usage")
    # Token counts default to "" (not 0) to match the agent's attribute style.
    total_tokens = ""
    prompt_tokens = ""
    completion_tokens = ""
    if usage:
        total_tokens = usage.get("total_tokens", "")
        prompt_tokens = usage.get("prompt_tokens", "")
        completion_tokens = usage.get("completion_tokens", "")

    messages = kwargs.get("messages", [])
    choices = response.get("choices", [])

    chat_completion_summary_dict = {
        "id": chat_completion_id,
        "appName": settings.app_name,
        "conversation_id": conversation_id,
        "request_id": request_id,
        "span_id": span_id,
        "trace_id": trace_id,
        "transaction_id": transaction._transaction_id,
        "api_key_last_four_digits": f"sk-{api_key[-4:]}" if api_key else "",
        "duration": ft.duration,
        "request.model": kwargs.get("model") or kwargs.get("engine") or "",
        "response.model": response_model,
        "response.organization": getattr(response, "organization", ""),
        "response.usage.completion_tokens": completion_tokens,
        "response.usage.total_tokens": total_tokens,
        "response.usage.prompt_tokens": prompt_tokens,
        "response.number_of_messages": len(messages) + len(choices),
        "request.temperature": kwargs.get("temperature", ""),
        "request.max_tokens": kwargs.get("max_tokens", ""),
        "response.choices.finish_reason": choices[0].get("finish_reason", "") if choices else "",
        "response.api_type": getattr(response, "api_type", ""),
        "response.headers.llmVersion": response_headers.get("openai-version", ""),
        "response.headers.ratelimitLimitRequests": check_rate_limit_header(
            response_headers, "x-ratelimit-limit-requests", True
        ),
        "response.headers.ratelimitLimitTokens": check_rate_limit_header(
            response_headers, "x-ratelimit-limit-tokens", True
        ),
        "response.headers.ratelimitResetTokens": check_rate_limit_header(
            response_headers, "x-ratelimit-reset-tokens", False
        ),
        "response.headers.ratelimitResetRequests": check_rate_limit_header(
            response_headers, "x-ratelimit-reset-requests", False
        ),
        "response.headers.ratelimitRemainingTokens": check_rate_limit_header(
            response_headers, "x-ratelimit-remaining-tokens", True
        ),
        "response.headers.ratelimitRemainingRequests": check_rate_limit_header(
            response_headers, "x-ratelimit-remaining-requests", True
        ),
        "vendor": "openAI",
    }

    transaction.record_ml_event("LlmChatCompletionSummary", chat_completion_summary_dict)
    # Message events cover the request messages plus only the first choice's
    # reply, mirroring the synchronous create instrumentation.
    message_list = list(messages)
    if choices:
        message_list.extend([choices[0].message])

    create_chat_completion_message_event(
        transaction,
        settings.app_name,
        message_list,
        chat_completion_id,
        span_id,
        trace_id,
        response_model,
        response_id,
        request_id,
        conversation_id,
    )

    return response


def wrap_convert_to_openai_object(wrapped, instance, args, kwargs):
resp = args[0]
returned_response = wrapped(*args, **kwargs)
Expand All @@ -247,8 +424,12 @@ def instrument_openai_util(module):
def instrument_openai_api_resources_embedding(module):
    """Attach sync/async instrumentation to openai's Embedding class, when present."""
    embedding_cls = module.Embedding
    if hasattr(embedding_cls, "create"):
        wrap_function_wrapper(module, "Embedding.create", wrap_embedding_create)
    if hasattr(embedding_cls, "acreate"):
        wrap_function_wrapper(module, "Embedding.acreate", wrap_embedding_acreate)


def instrument_openai_api_resources_chat_completion(module):
    """Attach sync/async instrumentation to openai's ChatCompletion class, when present."""
    chat_completion_cls = module.ChatCompletion
    if hasattr(chat_completion_cls, "create"):
        wrap_function_wrapper(module, "ChatCompletion.create", wrap_chat_completion_create)
    if hasattr(chat_completion_cls, "acreate"):
        wrap_function_wrapper(module, "ChatCompletion.acreate", wrap_chat_completion_acreate)
15 changes: 15 additions & 0 deletions tests/mlmodel_openai/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
collector_available_fixture,
)

from newrelic.api.time_trace import current_trace
from newrelic.api.transaction import current_transaction
from newrelic.common.object_wrapper import wrap_function_wrapper

_default_settings = {
Expand All @@ -49,6 +51,19 @@
OPENAI_AUDIT_LOG_CONTENTS = {}


@pytest.fixture
def set_trace_info():
    """Provide a callable that pins the current trace/span ids to fixed test values."""

    def set_info():
        transaction = current_transaction()
        if transaction:
            transaction._trace_id = "trace-id"
        active_trace = current_trace()
        if active_trace:
            active_trace.guid = "span-id"

    return set_info


@pytest.fixture(autouse=True, scope="session")
def openai_server():
"""
Expand Down
Loading

0 comments on commit 51e7362

Please sign in to comment.