fix!: Unify tracking token to use only TokenUsage (#32)
keelerm84 authored Dec 13, 2024
1 parent e425b1f · commit 80e1845
Showing 2 changed files with 127 additions and 121 deletions.
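
The breaking change: the provider-specific OpenAITokenUsage and BedrockTokenUsage classes (and the intermediate TokenMetrics conversion step) are removed, and callers now pass a single TokenUsage(total, input, output) to track_tokens; provider payloads are normalized internally. A minimal sketch of the new call pattern, mirroring the fixtures in the tests below (the SDK key is a placeholder):

from ldclient import Config, Context, LDClient
from ldai.tracker import LDAIConfigTracker, TokenUsage

# Placeholder SDK key; tracker construction mirrors test_tracker.py.
client = LDClient(Config('sdk-key'))
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, 'variation-key', 'config-key', context)

# One shape for every provider: total, input, output.
tracker.track_tokens(TokenUsage(total=300, input=200, output=100))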
83 changes: 81 additions & 2 deletions ldai/testing/test_tracker.py
@@ -1,10 +1,10 @@
from unittest.mock import MagicMock
from unittest.mock import MagicMock, call

import pytest
from ldclient import Config, Context, LDClient
from ldclient.integrations.test_data import TestData

from ldai.tracker import FeedbackKind, LDAIConfigTracker
from ldai.tracker import FeedbackKind, LDAIConfigTracker, TokenUsage


@pytest.fixture
@@ -60,6 +60,85 @@ def test_tracks_duration(client: LDClient):
assert tracker.get_summary().duration == 100


def test_tracks_token_usage(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)

tokens = TokenUsage(300, 200, 100)
tracker.track_tokens(tokens)

calls = [
call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 300),
call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 200),
call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 100),
]

client.track.assert_has_calls(calls) # type: ignore

assert tracker.get_summary().usage == tokens


def test_tracks_bedrock_metrics(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)

bedrock_result = {
'$metadata': {'httpStatusCode': 200},
'usage': {
'totalTokens': 330,
'inputTokens': 220,
'outputTokens': 110,
},
'metrics': {
'latencyMs': 50,
}
}
tracker.track_bedrock_converse_metrics(bedrock_result)

calls = [
call('$ld:ai:generation', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 1),
call('$ld:ai:duration:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 50),
call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 110),
]

client.track.assert_has_calls(calls) # type: ignore

assert tracker.get_summary().success is True
assert tracker.get_summary().duration == 50
assert tracker.get_summary().usage == TokenUsage(330, 220, 110)


def test_tracks_openai_metrics(client: LDClient):
context = Context.create('user-key')
tracker = LDAIConfigTracker(client, "variation-key", "config-key", context)

class Result:
def __init__(self):
self.usage = Usage()

class Usage:
def to_dict(self):
return {
'total_tokens': 330,
'prompt_tokens': 220,
'completion_tokens': 110,
}

tracker.track_openai_metrics(lambda: Result())

calls = [
call('$ld:ai:tokens:total', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 330),
call('$ld:ai:tokens:input', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 220),
call('$ld:ai:tokens:output', context, {'variationKey': 'variation-key', 'configKey': 'config-key'}, 110),
]

client.track.assert_has_calls(calls, any_order=False) # type: ignore

assert tracker.get_summary().usage == TokenUsage(330, 220, 110)


@pytest.mark.parametrize(
"kind,label",
[
165 changes: 46 additions & 119 deletions ldai/tracker.py
@@ -1,26 +1,11 @@
import time
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Optional, Union
from typing import Dict, Optional

from ldclient import Context, LDClient


@dataclass
class TokenMetrics:
"""
Metrics for token usage in AI operations.
:param total: Total number of tokens used.
:param input: Number of input tokens.
:param output: Number of output tokens.
"""

total: int
input: int
output: int # type: ignore


class FeedbackKind(Enum):
"""
Types of feedback that can be provided for AI operations.
@@ -35,99 +20,14 @@ class TokenUsage:
"""
Tracks token usage for AI operations.
:param total_tokens: Total number of tokens used.
:param prompt_tokens: Number of tokens in the prompt.
:param completion_tokens: Number of tokens in the completion.
"""

total_tokens: int
prompt_tokens: int
completion_tokens: int

def to_metrics(self):
"""
Convert token usage to metrics format.
:return: Dictionary containing token metrics.
"""
return {
'total': self['total_tokens'],
'input': self['prompt_tokens'],
'output': self['completion_tokens'],
}


@dataclass
class LDOpenAIUsage:
"""
LaunchDarkly-specific OpenAI usage tracking.
:param total_tokens: Total number of tokens used.
:param prompt_tokens: Number of tokens in the prompt.
:param completion_tokens: Number of tokens in the completion.
"""

total_tokens: int
prompt_tokens: int
completion_tokens: int


@dataclass
class OpenAITokenUsage:
"""
Tracks OpenAI-specific token usage.
"""

def __init__(self, data: LDOpenAIUsage):
"""
Initialize OpenAI token usage tracking.
:param data: OpenAI usage data.
"""
self.total_tokens = data.total_tokens
self.prompt_tokens = data.prompt_tokens
self.completion_tokens = data.completion_tokens

def to_metrics(self) -> TokenMetrics:
"""
Convert OpenAI token usage to metrics format.
:return: TokenMetrics object containing usage data.
"""
return TokenMetrics(
total=self.total_tokens,
input=self.prompt_tokens,
output=self.completion_tokens,
)


@dataclass
class BedrockTokenUsage:
"""
Tracks AWS Bedrock-specific token usage.
:param total: Total number of tokens used.
:param input: Number of tokens in the prompt.
:param output: Number of tokens in the completion.
"""

def __init__(self, data: dict):
"""
Initialize Bedrock token usage tracking.
:param data: Dictionary containing Bedrock usage data.
"""
self.totalTokens = data.get('totalTokens', 0)
self.inputTokens = data.get('inputTokens', 0)
self.outputTokens = data.get('outputTokens', 0)

def to_metrics(self) -> TokenMetrics:
"""
Convert Bedrock token usage to metrics format.
:return: TokenMetrics object containing usage data.
"""
return TokenMetrics(
total=self.totalTokens,
input=self.inputTokens,
output=self.outputTokens,
)
total: int
input: int
output: int


class LDAIMetricSummary:
Expand All @@ -154,7 +54,7 @@ def feedback(self) -> Optional[Dict[str, FeedbackKind]]:
return self._feedback

@property
def usage(self) -> Optional[Union[TokenUsage, BedrockTokenUsage]]:
def usage(self) -> Optional[TokenUsage]:
return self._usage


@@ -255,8 +155,8 @@ def track_openai_metrics(self, func):
:return: Result of the tracked function.
"""
result = self.track_duration_of(func)
if result.usage:
self.track_tokens(OpenAITokenUsage(result.usage))
if hasattr(result, 'usage') and hasattr(result.usage, 'to_dict'):
self.track_tokens(_openai_to_token_usage(result.usage.to_dict()))
return result

def track_bedrock_converse_metrics(self, res: dict) -> dict:
@@ -275,37 +175,36 @@ def track_bedrock_converse_metrics(self, res: dict) -> dict:
if res.get('metrics', {}).get('latencyMs'):
self.track_duration(res['metrics']['latencyMs'])
if res.get('usage'):
self.track_tokens(BedrockTokenUsage(res['usage']))
self.track_tokens(_bedrock_to_token_usage(res['usage']))
return res

def track_tokens(self, tokens: Union[TokenUsage, BedrockTokenUsage]) -> None:
def track_tokens(self, tokens: TokenUsage) -> None:
"""
Track token usage metrics.
:param tokens: Token usage data from either custom, OpenAI, or Bedrock sources.
"""
self._summary._usage = tokens
token_metrics = tokens.to_metrics()
if token_metrics.total > 0:
if tokens.total > 0:
self._ld_client.track(
'$ld:ai:tokens:total',
self._context,
self.__get_track_data(),
token_metrics.total,
tokens.total,
)
if token_metrics.input > 0:
if tokens.input > 0:
self._ld_client.track(
'$ld:ai:tokens:input',
self._context,
self.__get_track_data(),
token_metrics.input,
tokens.input,
)
if token_metrics.output > 0:
if tokens.output > 0:
self._ld_client.track(
'$ld:ai:tokens:output',
self._context,
self.__get_track_data(),
token_metrics.output,
tokens.output,
)

def get_summary(self) -> LDAIMetricSummary:
@@ -315,3 +214,31 @@ def get_summary(self) -> LDAIMetricSummary:
:return: Summary of AI metrics.
"""
return self._summary


def _bedrock_to_token_usage(data: dict) -> TokenUsage:
"""
Convert a Bedrock usage dictionary to a TokenUsage object.
:param data: Dictionary containing Bedrock usage data.
:return: TokenUsage object containing usage data.
"""
return TokenUsage(
total=data.get('totalTokens', 0),
input=data.get('inputTokens', 0),
output=data.get('outputTokens', 0),
)


def _openai_to_token_usage(data: dict) -> TokenUsage:
"""
Convert an OpenAI usage dictionary to a TokenUsage object.
:param data: Dictionary containing OpenAI usage data.
:return: TokenUsage object containing usage data.
"""
return TokenUsage(
total=data.get('total_tokens', 0),
input=data.get('prompt_tokens', 0),
output=data.get('completion_tokens', 0),
)
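
For illustration only (these helpers are private, and the values are the ones exercised in the tests above), the normalization behaves like:

_bedrock_to_token_usage({'totalTokens': 330, 'inputTokens': 220, 'outputTokens': 110})
# -> TokenUsage(total=330, input=220, output=110)

_openai_to_token_usage({'total_tokens': 330, 'prompt_tokens': 220, 'completion_tokens': 110})
# -> TokenUsage(total=330, input=220, output=110)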
