feat: Add track_error to mirror track_success (#33)
Additionally, emit new `$ld:ai:generation:(success|error)` events on
success or failure.
keelerm84 authored Dec 17, 2024
1 parent 80e1845 commit 404f704
Showing 2 changed files with 185 additions and 15 deletions.
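For orientation, a minimal sketch of the new API against a mocked client, mirroring the fixtures in the tests below (any configured LDClient is tracked the same way; the setup here is illustrative, not part of the commit):

from unittest.mock import MagicMock

from ldclient import Context, LDClient

from ldai.tracker import LDAIConfigTracker

# Mocked client, as in the test suite; a real client emits the same events.
client = MagicMock(spec=LDClient)
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, 'variation-key', 'config-key', context)

tracker.track_success()  # emits '$ld:ai:generation' and '$ld:ai:generation:success'
tracker.track_error()    # emits '$ld:ai:generation' and '$ld:ai:generation:error'

# When both are recorded, the error wins in the summary.
assert tracker.get_summary().success is False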
142 changes: 136 additions & 6 deletions ldai/testing/test_tracker.py
@@ -1,3 +1,4 @@
from time import sleep
from unittest.mock import MagicMock, call

import pytest
@@ -60,6 +61,43 @@ def test_tracks_duration(client: LDClient):
assert tracker.get_summary().duration == 100


def test_tracks_duration_of(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
tracker.track_duration_of(lambda: sleep(0.01))

calls = client.track.mock_calls # type: ignore

assert len(calls) == 1
assert calls[0].args[0] == '$ld:ai:duration:total'
assert calls[0].args[1] == context
assert calls[0].args[2] == {'variationKey': 'variation-key', 'configKey': 'config-key'}
assert calls[0].args[3] == pytest.approx(10, rel=10)


def test_tracks_duration_of_with_exception(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)

def sleep_and_throw():
sleep(0.01)
raise ValueError("Something went wrong")

try:
tracker.track_duration_of(sleep_and_throw)
assert False, "Should have thrown an exception"
except ValueError:
pass

calls = client.track.mock_calls # type: ignore

assert len(calls) == 1
assert calls[0].args[0] == '$ld:ai:duration:total'
assert calls[0].args[1] == context
assert calls[0].args[2] == {'variationKey': 'variation-key', 'configKey': 'config-key'}
assert calls[0].args[3] == pytest.approx(10, rel=10)


def test_tracks_token_usage(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
@@ -97,6 +135,7 @@ def test_tracks_bedrock_metrics(client: LDClient):

calls = [
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:duration:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 50),
call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
@@ -110,6 +149,39 @@ def test_tracks_bedrock_metrics(client: LDClient):
assert tracker.get_summary().usage == TokenUsage(330, 220, 110)


def test_tracks_bedrock_metrics_with_error(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)

bedrock_result = {
'$metadata': {'httpStatusCode': 500},
'usage': {
'totalTokens': 330,
'inputTokens': 220,
'outputTokens': 110,
},
'metrics': {
'latencyMs': 50,
}
}
tracker.track_bedrock_converse_metrics(bedrock_result)

calls = [
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:duration:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 50),
call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 110),
]

client.track.assert_has_calls(calls) # type: ignore

assert tracker.get_summary().success is False
assert tracker.get_summary().duration == 50
assert tracker.get_summary().usage == TokenUsage(330, 220, 110)


def test_tracks_openai_metrics(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
@@ -129,6 +201,8 @@ def to_dict(self):
tracker.track_openai_metrics(lambda: Result())

calls = [
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 110),
@@ -139,6 +213,29 @@ def to_dict(self):
assert tracker.get_summary().usage == TokenUsage(330, 220, 110)


def test_tracks_openai_metrics_with_exception(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)

def raise_exception():
raise ValueError("Something went wrong")

try:
tracker.track_openai_metrics(raise_exception)
assert False, "Should have thrown an exception"
except ValueError:
pass

calls = [
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
]

client.track.assert_has_calls(calls, any_order=False) # type: ignore

assert tracker.get_summary().usage is None


@pytest.mark.parametrize(
"kind,label",
[
@@ -166,11 +263,44 @@ def test_tracks_success(client: LDClient):
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
tracker.track_success()

calls = [
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
]

client.track.assert_has_calls(calls) # type: ignore

assert tracker.get_summary().success is True


def test_tracks_error(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
tracker.track_error()

calls = [
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
]

client.track.assert_has_calls(calls) # type: ignore

assert tracker.get_summary().success is False


def test_error_overwrites_success(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
tracker.track_success()
tracker.track_error()

calls = [
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
]

client.track.assert_has_calls(calls) # type: ignore

assert tracker.get_summary().success is False
58 changes: 49 additions & 9 deletions ldai/tracker.py
@@ -106,14 +106,20 @@ def track_duration_of(self, func):
"""
Automatically track the duration of an AI operation.
An exception occurring during the execution of the function will still
track the duration. The exception will be re-thrown.
:param func: Function to track.
:return: Result of the tracked function.
"""
start_time = time.time()
try:
result = func()
finally:
end_time = time.time()
duration = int((end_time - start_time) * 1000) # duration in milliseconds
self.track_duration(duration)

return result

def track_feedback(self, feedback: Dict[str, FeedbackKind]) -> None:
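As the track_duration_of docstring above notes, a failure inside the tracked function still records its duration before the exception propagates. A minimal sketch, reusing the mocked tracker from the earlier example:

from time import sleep

def flaky_call():
    sleep(0.01)
    raise ValueError('model call failed')

try:
    tracker.track_duration_of(flaky_call)
except ValueError:
    pass  # re-raised by track_duration_of; the duration was already recorded

assert tracker.get_summary().duration >= 10  # milliseconds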
@@ -146,32 +152,66 @@ def track_success(self) -> None:
self._ld_client.track(
'$ld:ai:generation', self._context, self.__get_track_data(), 1
)
self._ld_client.track(
'$ld:ai:generation:success', self._context, self.__get_track_data(), 1
)

def track_error(self) -> None:
"""
Track an unsuccessful AI generation attempt.
"""
self._summary._success = False
self._ld_client.track(
'$ld:ai:generation', self._context, self.__get_track_data(), 1
)
self._ld_client.track(
'$ld:ai:generation:error', self._context, self.__get_track_data(), 1
)

def track_openai_metrics(self, func):
"""
Track OpenAI-specific operations.
This function will track the duration of the operation, the token
usage, and the success or error status.
If the provided function throws, this method records the duration and
an error, then re-raises the exception.
A failed operation will not have any token usage data.
:param func: Function to track.
:return: Result of the tracked function.
"""
try:
result = self.track_duration_of(func)
self.track_success()
if hasattr(result, 'usage') and hasattr(result.usage, 'to_dict'):
self.track_tokens(_openai_to_token_usage(result.usage.to_dict()))
except Exception:
self.track_error()
raise

return result

def track_bedrock_converse_metrics(self, res: dict) -> dict:
"""
Track AWS Bedrock conversation operations.
This function will track the duration of the operation, the token
usage, and the success or error status.
:param res: Response dictionary from Bedrock.
:return: The original response dictionary.
"""
status_code = res.get('$metadata', {}).get('httpStatusCode', 0)
if status_code == 200:
self.track_success()
elif status_code >= 400:
self.track_error()
if res.get('metrics', {}).get('latencyMs'):
self.track_duration(res['metrics']['latencyMs'])
if res.get('usage'):

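Taken together, the provider helpers now classify outcomes automatically: track_openai_metrics maps a raised exception to an error, while track_bedrock_converse_metrics keys off the HTTP status code. A hedged sketch continuing the example above; the OpenAI and Bedrock clients and model IDs are illustrative placeholders, not part of this commit:

# OpenAI path: success or error is inferred from whether the call raises.
try:
    completion = tracker.track_openai_metrics(
        lambda: openai_client.chat.completions.create(  # hypothetical pre-configured client
            model='gpt-4o-mini',
            messages=[{'role': 'user', 'content': 'Hello'}],
        )
    )
except Exception:
    pass  # '$ld:ai:generation:error' has already been emitted

# Bedrock path: a 200 status records success; a 4xx/5xx status records an error.
response = bedrock_client.converse(  # hypothetical boto3 bedrock-runtime client
    modelId='anthropic.claude-3-haiku-20240307-v1:0',
    messages=[{'role': 'user', 'content': [{'text': 'Hello'}]}],
)
tracker.track_bedrock_converse_metrics(response)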