diff --git a/ldai/testing/test_tracker.py b/ldai/testing/test_tracker.py
index 934197a..3196bfb 100644
--- a/ldai/testing/test_tracker.py
+++ b/ldai/testing/test_tracker.py
@@ -1,3 +1,4 @@
+from time import sleep
 from unittest.mock import MagicMock, call
 
 import pytest
@@ -60,6 +61,43 @@ def test_tracks_duration(client: LDClient):
     assert tracker.get_summary().duration == 100
 
 
+def test_tracks_duration_of(client: LDClient):
+    context = Context.create('user-key')
+    tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
+    tracker.track_duration_of(lambda: sleep(0.01))
+
+    calls = client.track.mock_calls  # type: ignore
+
+    assert len(calls) == 1
+    assert calls[0].args[0] == '$ld:ai:duration:total'
+    assert calls[0].args[1] == context
+    assert calls[0].args[2] == {'variationKey': 'variation-key', 'configKey': 'config-key'}
+    assert calls[0].args[3] == pytest.approx(10, rel=10)
+
+
+def test_tracks_duration_of_with_exception(client: LDClient):
+    context = Context.create('user-key')
+    tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
+
+    def sleep_and_throw():
+        sleep(0.01)
+        raise ValueError("Something went wrong")
+
+    try:
+        tracker.track_duration_of(sleep_and_throw)
+        assert False, "Should have thrown an exception"
+    except ValueError:
+        pass
+
+    calls = client.track.mock_calls  # type: ignore
+
+    assert len(calls) == 1
+    assert calls[0].args[0] == '$ld:ai:duration:total'
+    assert calls[0].args[1] == context
+    assert calls[0].args[2] == {'variationKey': 'variation-key', 'configKey': 'config-key'}
+    assert calls[0].args[3] == pytest.approx(10, rel=10)
+
+
 def test_tracks_token_usage(client: LDClient):
     context = Context.create('user-key')
     tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
@@ -97,6 +135,7 @@ def test_tracks_bedrock_metrics(client: LDClient):
 
     calls = [
         call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
+        call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
         call('$ld:ai:duration:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 50),
         call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
         call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
@@ -110,6 +149,39 @@ def test_tracks_bedrock_metrics(client: LDClient):
     assert tracker.get_summary().usage == TokenUsage(330, 220, 110)
 
 
+def test_tracks_bedrock_metrics_with_error(client: LDClient):
+    context = Context.create('user-key')
+    tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
+
+    bedrock_result = {
+        '$metadata': {'httpStatusCode': 500},
+        'usage': {
+            'totalTokens': 330,
+            'inputTokens': 220,
+            'outputTokens': 110,
+        },
+        'metrics': {
+            'latencyMs': 50,
+        }
+    }
+    tracker.track_bedrock_converse_metrics(bedrock_result)
+
+    calls = [
+        call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
+        call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
+        call('$ld:ai:duration:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 50),
+        call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
+        call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
+        call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 110),
+    ]
+
+    client.track.assert_has_calls(calls)  # type: ignore
+
+    assert tracker.get_summary().success is False
+    assert tracker.get_summary().duration == 50
+    assert tracker.get_summary().usage == TokenUsage(330, 220, 110)
+
+
 def test_tracks_openai_metrics(client: LDClient):
     context = Context.create('user-key')
     tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
@@ -129,6 +201,8 @@ def to_dict(self):
     tracker.track_openai_metrics(lambda: Result())
 
     calls = [
+        call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
+        call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
         call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
         call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
         call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 110),
@@ -139,6 +213,29 @@ def to_dict(self):
     assert tracker.get_summary().usage == TokenUsage(330, 220, 110)
 
 
+def test_tracks_openai_metrics_with_exception(client: LDClient):
+    context = Context.create('user-key')
+    tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
+
+    def raise_exception():
+        raise ValueError("Something went wrong")
+
+    try:
+        tracker.track_openai_metrics(raise_exception)
+        assert False, "Should have thrown an exception"
+    except ValueError:
+        pass
+
+    calls = [
+        call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
+        call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
+    ]
+
+    client.track.assert_has_calls(calls, any_order=False)  # type: ignore
+
+    assert tracker.get_summary().usage is None
+
+
 @pytest.mark.parametrize(
     "kind,label",
     [
@@ -166,11 +263,44 @@ def test_tracks_success(client: LDClient):
     tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
     tracker.track_success()
 
-    client.track.assert_called_with(  # type: ignore
-        '$ld:ai:generation',
-        context,
-        {'variationKey': 'variation-key', 'configKey': 'config-key'},
-        1
-    )
+    calls = [
+        call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
+        call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
+    ]
+
+    client.track.assert_has_calls(calls)  # type: ignore
 
     assert tracker.get_summary().success is True
+
+
+def test_tracks_error(client: LDClient):
+    context = Context.create('user-key')
+    tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
+    tracker.track_error()
+
+    calls = [
+        call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
+        call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
+    ]
+
+    client.track.assert_has_calls(calls)  # type: ignore
+
+    assert tracker.get_summary().success is False
+
+
+def test_error_overwrites_success(client: LDClient):
+    context = Context.create('user-key')
+    tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
+    tracker.track_success()
+    tracker.track_error()
+
+    calls = [
+        call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
+        call('$ld:ai:generation:success', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
+        call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
+        call('$ld:ai:generation:error', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
+    ]
+
+    client.track.assert_has_calls(calls)  # type: ignore
+
+    assert tracker.get_summary().success is False
diff --git a/ldai/tracker.py b/ldai/tracker.py
index 2016b02..8f3c15c 100644
--- a/ldai/tracker.py
+++ b/ldai/tracker.py
@@ -106,14 +106,20 @@ def track_duration_of(self, func):
         """
         Automatically track the duration of an AI operation.
 
+        If the function throws an exception, the duration is still tracked
+        and the exception is re-thrown.
+
         :param func: Function to track.
         :return: Result of the tracked function.
         """
         start_time = time.time()
-        result = func()
-        end_time = time.time()
-        duration = int((end_time - start_time) * 1000)  # duration in milliseconds
-        self.track_duration(duration)
+        try:
+            result = func()
+        finally:
+            end_time = time.time()
+            duration = int((end_time - start_time) * 1000)  # duration in milliseconds
+            self.track_duration(duration)
+
         return result
 
     def track_feedback(self, feedback: Dict[str, FeedbackKind]) -> None:
@@ -146,23 +152,58 @@ def track_success(self) -> None:
         self._ld_client.track(
             '$ld:ai:generation', self._context, self.__get_track_data(), 1
         )
+        self._ld_client.track(
+            '$ld:ai:generation:success', self._context, self.__get_track_data(), 1
+        )
+
+    def track_error(self) -> None:
+        """
+        Track an unsuccessful AI generation attempt.
+        """
+        self._summary._success = False
+        self._ld_client.track(
+            '$ld:ai:generation', self._context, self.__get_track_data(), 1
+        )
+        self._ld_client.track(
+            '$ld:ai:generation:error', self._context, self.__get_track_data(), 1
+        )
 
     def track_openai_metrics(self, func):
         """
         Track OpenAI-specific operations.
 
+        This function will track the duration of the operation, the token
+        usage, and the success or error status.
+
+        If the provided function throws, then this method will also throw.
+
+        Even when the provided function throws, the duration and an error
+        status are still recorded.
+
+        A failed operation will not have any token usage data.
+
         :param func: Function to track.
         :return: Result of the tracked function.
         """
-        result = self.track_duration_of(func)
-        if hasattr(result, 'usage') and hasattr(result.usage, 'to_dict'):
-            self.track_tokens(_openai_to_token_usage(result.usage.to_dict()))
+        try:
+            result = self.track_duration_of(func)
+            self.track_success()
+            if hasattr(result, 'usage') and hasattr(result.usage, 'to_dict'):
+                self.track_tokens(_openai_to_token_usage(result.usage.to_dict()))
+        except Exception:
+            self.track_error()
+            raise
+
         return result
 
     def track_bedrock_converse_metrics(self, res: dict) -> dict:
         """
         Track AWS Bedrock conversation operations.
+
+        This function will track the duration of the operation, the token
+        usage, and the success or error status.
+
         :param res: Response dictionary from Bedrock.
         :return: The original response dictionary.
         """
@@ -170,8 +211,7 @@ def track_bedrock_converse_metrics(self, res: dict) -> dict:
         if status_code == 200:
             self.track_success()
         elif status_code >= 400:
-            # Potentially add error tracking in the future.
-            pass
+            self.track_error()
         if res.get('metrics', {}).get('latencyMs'):
             self.track_duration(res['metrics']['latencyMs'])
         if res.get('usage'):
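
For context, a minimal sketch of how the new success/error tracking surfaces in application code. This is illustrative only: it assumes an already-initialized `LDClient` and an OpenAI v1 client, and the `generate` helper, model name, and config/variation keys are hypothetical, not part of this change.

```python
from ldclient import Context, LDClient

from ldai.tracker import LDAIConfigTracker


def generate(ld_client: LDClient, openai_client, prompt: str):
    # Hypothetical helper: ld_client and openai_client are assumed to be
    # configured elsewhere; 'variation-key' / 'config-key' are placeholders.
    context = Context.create('user-key')
    tracker = LDAIConfigTracker(ld_client, 'variation-key', 'config-key', context)

    # On success this records $ld:ai:generation, $ld:ai:generation:success,
    # the duration, and token usage. If the call throws, it records
    # $ld:ai:generation:error (plus the duration) and re-raises, so the
    # caller still sees the original exception.
    return tracker.track_openai_metrics(
        lambda: openai_client.chat.completions.create(
            model='gpt-4o',
            messages=[{'role': 'user', 'content': prompt}],
        )
    )
```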