diff --git a/evals/cli/oaieval.py b/evals/cli/oaieval.py
index e9cd432..06c3c5f 100644
--- a/evals/cli/oaieval.py
+++ b/evals/cli/oaieval.py
@@ -220,6 +220,7 @@ def to_number(x: str) -> Union[int, float, str]:
         **extra_eval_params,
     )
     result = eval.run(recorder)
+    add_token_usage_to_result(result, recorder)
     recorder.record_final_report(result)
 
     if not (args.dry_run or args.local_run):
@@ -258,6 +259,34 @@ def build_recorder(
     )
 
 
+def add_token_usage_to_result(result: dict[str, Any], recorder: RecorderBase) -> None:
+    """
+    Add token usage from logged sampling events to the result dictionary from the recorder.
+    """
+    usage_events = []
+    sampling_events = recorder.get_events("sampling")
+    for event in sampling_events:
+        if "usage" in event.data:
+            usage_events.append(dict(event.data["usage"]))
+    logger.info(f"Found {len(usage_events)}/{len(sampling_events)} sampling events with usage data")
+    if usage_events:
+        # Sum usage over the union of keys; events may report different keys, missing/None count as 0.
+        total_usage = {
+            key: sum(u.get(key) or 0 for u in usage_events)
+            for key in {k: None for u in usage_events for k in u}
+        }
+        total_usage_str = "\n".join(f"{key}: {value:,}" for key, value in total_usage.items())
+        logger.info(f"Token usage from {len(usage_events)} sampling events:\n{total_usage_str}")
+        for key, value in total_usage.items():
+            keyname = f"usage_{key}"
+            if keyname not in result:
+                result[keyname] = value
+            else:
+                logger.warning(
+                    f"Usage key {keyname} already exists in result, not adding {keyname}"
+                )
+
+
 def main() -> None:
     parser = get_parser()
     args = cast(OaiEvalArguments, parser.parse_args(sys.argv[1:]))
diff --git a/evals/completion_fns/openai.py b/evals/completion_fns/openai.py
index ed50818..f3075f6 100644
--- a/evals/completion_fns/openai.py
+++ b/evals/completion_fns/openai.py
@@ -88,7 +88,12 @@ def __call__(
             **{**kwargs, **self.extra_options},
         )
         result = OpenAICompletionResult(raw_data=result, prompt=openai_create_prompt)
-        record_sampling(prompt=result.prompt, sampled=result.get_completions())
+        record_sampling(
+            prompt=result.prompt,
+            sampled=result.get_completions(),
+            model=result.raw_data.model,
+            usage=result.raw_data.usage,
+        )
         return result
 
 
@@ -133,5 +138,10 @@ def __call__(
             **{**kwargs, **self.extra_options},
         )
         result = OpenAIChatCompletionResult(raw_data=result, prompt=openai_create_prompt)
-        record_sampling(prompt=result.prompt, sampled=result.get_completions())
+        record_sampling(
+            prompt=result.prompt,
+            sampled=result.get_completions(),
+            model=result.raw_data.model,
+            usage=result.raw_data.usage,
+        )
         return result