diff --git a/ldai/testing/test_tracker.py b/ldai/testing/test_tracker.py
index b4bac6d..934197a 100644
--- a/ldai/testing/test_tracker.py
+++ b/ldai/testing/test_tracker.py
@@ -1,10 +1,10 @@
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, call
 
 import pytest
 from ldclient import Config, Context, LDClient
 from ldclient.integrations.test_data import TestData
 
-from ldai.tracker import FeedbackKind, LDAIConfigTracker
+from ldai.tracker import FeedbackKind, LDAIConfigTracker, TokenUsage
 
 
 @pytest.fixture
@@ -60,6 +60,85 @@ def test_tracks_duration(client: LDClient):
     assert tracker.get_summary().duration == 100
 
 
+def test_tracks_token_usage(client: LDClient):
+    context = Context.create('user-key')
+    tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
+
+    tokens = TokenUsage(300, 200, 100)
+    tracker.track_tokens(tokens)
+
+    calls = [
+        call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 300),
+        call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 200),
+        call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 100),
+    ]
+
+    client.track.assert_has_calls(calls)  # type: ignore
+
+    assert tracker.get_summary().usage == tokens
+
+
+def test_tracks_bedrock_metrics(client: LDClient):
+    context = Context.create('user-key')
+    tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
+
+    bedrock_result = {
+        '$metadata': {'httpStatusCode': 200},
+        'usage': {
+            'totalTokens': 330,
+            'inputTokens': 220,
+            'outputTokens': 110,
+        },
+        'metrics': {
+            'latencyMs': 50,
+        }
+    }
+    tracker.track_bedrock_converse_metrics(bedrock_result)
+
+    calls = [
+        call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
+        call('$ld:ai:duration:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 50),
+        call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
+        call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
+        call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 110),
+    ]
+
+    client.track.assert_has_calls(calls)  # type: ignore
+
+    assert tracker.get_summary().success is True
+    assert tracker.get_summary().duration == 50
+    assert tracker.get_summary().usage == TokenUsage(330, 220, 110)
+
+
+def test_tracks_openai_metrics(client: LDClient):
+    context = Context.create('user-key')
+    tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)
+
+    class Result:
+        def __init__(self):
+            self.usage = Usage()
+
+    class Usage:
+        def to_dict(self):
+            return {
+                'total_tokens': 330,
+                'prompt_tokens': 220,
+                'completion_tokens': 110,
+            }
+
+    tracker.track_openai_metrics(lambda: Result())
+
+    calls = [
+        call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
+        call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
+        call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 110),
+    ]
+
+    client.track.assert_has_calls(calls, any_order=False)  # type: ignore
+
+    assert tracker.get_summary().usage == TokenUsage(330, 220, 110)
+
+
 @pytest.mark.parametrize(
     "kind,label",
     [
diff --git a/ldai/tracker.py b/ldai/tracker.py
index 7b12c50..2016b02 100644
--- a/ldai/tracker.py
+++ b/ldai/tracker.py
@@ -1,26 +1,11 @@
 import time
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, Optional, Union
+from typing import Dict, Optional
 
 from ldclient import Context, LDClient
 
 
-@dataclass
-class TokenMetrics:
-    """
-    Metrics for token usage in AI operations.
-
-    :param total: Total number of tokens used.
-    :param input: Number of input tokens.
-    :param output: Number of output tokens.
-    """
-
-    total: int
-    input: int
-    output: int  # type: ignore
-
-
 class FeedbackKind(Enum):
     """
     Types of feedback that can be provided for AI operations.
@@ -35,99 +20,14 @@ class TokenUsage:
     """
     Tracks token usage for AI operations.
 
-    :param total_tokens: Total number of tokens used.
-    :param prompt_tokens: Number of tokens in the prompt.
-    :param completion_tokens: Number of tokens in the completion.
-    """
-
-    total_tokens: int
-    prompt_tokens: int
-    completion_tokens: int
-
-    def to_metrics(self):
-        """
-        Convert token usage to metrics format.
-
-        :return: Dictionary containing token metrics.
-        """
-        return {
-            'total': self['total_tokens'],
-            'input': self['prompt_tokens'],
-            'output': self['completion_tokens'],
-        }
-
-
-@dataclass
-class LDOpenAIUsage:
-    """
-    LaunchDarkly-specific OpenAI usage tracking.
-
-    :param total_tokens: Total number of tokens used.
-    :param prompt_tokens: Number of tokens in the prompt.
-    :param completion_tokens: Number of tokens in the completion.
-    """
-
-    total_tokens: int
-    prompt_tokens: int
-    completion_tokens: int
-
-
-@dataclass
-class OpenAITokenUsage:
-    """
-    Tracks OpenAI-specific token usage.
-    """
-
-    def __init__(self, data: LDOpenAIUsage):
-        """
-        Initialize OpenAI token usage tracking.
-
-        :param data: OpenAI usage data.
-        """
-        self.total_tokens = data.total_tokens
-        self.prompt_tokens = data.prompt_tokens
-        self.completion_tokens = data.completion_tokens
-
-    def to_metrics(self) -> TokenMetrics:
-        """
-        Convert OpenAI token usage to metrics format.
-
-        :return: TokenMetrics object containing usage data.
-        """
-        return TokenMetrics(
-            total=self.total_tokens,
-            input=self.prompt_tokens,
-            output=self.completion_tokens,
-        )
-
-
-@dataclass
-class BedrockTokenUsage:
-    """
-    Tracks AWS Bedrock-specific token usage.
+    :param total: Total number of tokens used.
+    :param input: Number of tokens in the prompt.
+    :param output: Number of tokens in the completion.
     """
 
-    def __init__(self, data: dict):
-        """
-        Initialize Bedrock token usage tracking.
-
-        :param data: Dictionary containing Bedrock usage data.
-        """
-        self.totalTokens = data.get('totalTokens', 0)
-        self.inputTokens = data.get('inputTokens', 0)
-        self.outputTokens = data.get('outputTokens', 0)
-
-    def to_metrics(self) -> TokenMetrics:
-        """
-        Convert Bedrock token usage to metrics format.
-
-        :return: TokenMetrics object containing usage data.
-        """
-        return TokenMetrics(
-            total=self.totalTokens,
-            input=self.inputTokens,
-            output=self.outputTokens,
-        )
+    total: int
+    input: int
+    output: int
 
 
 class LDAIMetricSummary:
@@ -154,7 +54,7 @@ def feedback(self) -> Optional[Dict[str, FeedbackKind]]:
         return self._feedback
 
     @property
-    def usage(self) -> Optional[Union[TokenUsage, BedrockTokenUsage]]:
+    def usage(self) -> Optional[TokenUsage]:
         return self._usage
 
 
@@ -255,8 +155,8 @@ def track_openai_metrics(self, func):
         :return: Result of the tracked function.
         """
         result = self.track_duration_of(func)
-        if result.usage:
-            self.track_tokens(OpenAITokenUsage(result.usage))
+        if hasattr(result, 'usage') and hasattr(result.usage, 'to_dict'):
+            self.track_tokens(_openai_to_token_usage(result.usage.to_dict()))
         return result
 
     def track_bedrock_converse_metrics(self, res: dict) -> dict:
@@ -275,37 +175,36 @@ def track_bedrock_converse_metrics(self, res: dict) -> dict:
         if res.get('metrics', {}).get('latencyMs'):
             self.track_duration(res['metrics']['latencyMs'])
         if res.get('usage'):
-            self.track_tokens(BedrockTokenUsage(res['usage']))
+            self.track_tokens(_bedrock_to_token_usage(res['usage']))
         return res
 
-    def track_tokens(self, tokens: Union[TokenUsage, BedrockTokenUsage]) -> None:
+    def track_tokens(self, tokens: TokenUsage) -> None:
         """
         Track token usage metrics.
 
         :param tokens: Token usage data from either custom, OpenAI, or Bedrock sources.
         """
         self._summary._usage = tokens
-        token_metrics = tokens.to_metrics()
-        if token_metrics.total > 0:
+        if tokens.total > 0:
             self._ld_client.track(
                 '$ld:ai:tokens:total',
                 self._context,
                 self.__get_track_data(),
-                token_metrics.total,
+                tokens.total,
             )
-        if token_metrics.input > 0:
+        if tokens.input > 0:
             self._ld_client.track(
                 '$ld:ai:tokens:input',
                 self._context,
                 self.__get_track_data(),
-                token_metrics.input,
+                tokens.input,
             )
-        if token_metrics.output > 0:
+        if tokens.output > 0:
             self._ld_client.track(
                 '$ld:ai:tokens:output',
                 self._context,
                 self.__get_track_data(),
-                token_metrics.output,
+                tokens.output,
             )
 
     def get_summary(self) -> LDAIMetricSummary:
@@ -315,3 +214,31 @@ def get_summary(self) -> LDAIMetricSummary:
         :return: Summary of AI metrics.
         """
         return self._summary
+
+
+def _bedrock_to_token_usage(data: dict) -> TokenUsage:
+    """
+    Convert a Bedrock usage dictionary to a TokenUsage object.
+
+    :param data: Dictionary containing Bedrock usage data.
+    :return: TokenUsage object containing usage data.
+    """
+    return TokenUsage(
+        total=data.get('totalTokens', 0),
+        input=data.get('inputTokens', 0),
+        output=data.get('outputTokens', 0),
+    )
+
+
+def _openai_to_token_usage(data: dict) -> TokenUsage:
+    """
+    Convert an OpenAI usage dictionary to a TokenUsage object.
+
+    :param data: Dictionary containing OpenAI usage data.
+    :return: TokenUsage object containing usage data.
+    """
+    return TokenUsage(
+        total=data.get('total_tokens', 0),
+        input=data.get('prompt_tokens', 0),
+        output=data.get('completion_tokens', 0),
+    )