Skip to content

Commit

Permalink
[Executor] Update openai tokens to trace (#1918)
Browse files Browse the repository at this point in the history
# Description

This pull request primarily introduces changes to the
`src/promptflow/promptflow/_core/tracer.py` file. The main focus of the
changes is to enhance the tracing mechanism by adding a `TokenCollector`
class for collecting OpenAI tokens associated with spans, and
integrating it into the existing tracing flow. The changes also include
a minor import modification for the OpenTelemetry library.

Main Changes:

*
[`src/promptflow/promptflow/_core/tracer.py`](diffhunk://#diff-8f8c2ae53e5ffd37a14e8a899119fbb2742486db8faab6df3fcf506e1b720ad8R159-R198):
Introduced a new class `TokenCollector` to collect OpenAI tokens
associated with spans. The class provides methods to collect tokens,
merge tokens for parent spans, and retrieve tokens for a given span id.
*
[`src/promptflow/promptflow/_core/tracer.py`](diffhunk://#diff-8f8c2ae53e5ffd37a14e8a899119fbb2742486db8faab6df3fcf506e1b720ad8R272-R274):
The `enrich_span_with_output` function now also sets token attributes to
the span if any tokens are associated with the span id.
*
[`src/promptflow/promptflow/_core/tracer.py`](diffhunk://#diff-8f8c2ae53e5ffd37a14e8a899119fbb2742486db8faab6df3fcf506e1b720ad8R341-R350):
The `wrapped` function (both synchronous and asynchronous versions) now
collects OpenAI tokens if the trace type is `LLM`, and collects tokens
for the parent span before returning the output.

# All Promptflow Contribution checklist:
- [x] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [x] Title of the pull request is clear and informative.
- [x] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

---------

Co-authored-by: Lina Tang <[email protected]>
  • Loading branch information
lumoslnt and Lina Tang authored Feb 2, 2024
1 parent 669f080 commit d0a2829
Showing 1 changed file with 57 additions and 4 deletions.
61 changes: 57 additions & 4 deletions src/promptflow/promptflow/_core/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from collections.abc import Iterator
from contextvars import ContextVar
from datetime import datetime
from threading import Lock
from typing import Callable, Dict, List, Optional

from opentelemetry import trace
import opentelemetry.trace as otel_trace
from opentelemetry.trace.status import StatusCode

from promptflow._core.generator_proxy import GeneratorProxy, generate_from_proxy
Expand All @@ -24,7 +25,8 @@

from .thread_local_singleton import ThreadLocalSingleton

open_telemetry_tracer = trace.get_tracer("promptflow")

open_telemetry_tracer = otel_trace.get_tracer("promptflow")


class Tracer(ThreadLocalSingleton):
Expand Down Expand Up @@ -153,6 +155,46 @@ def _format_error(error: Exception) -> dict:
}


class TokenCollector():
_lock = Lock()

def __init__(self):
self._span_id_to_tokens = {}

def collect_openai_tokens(self, span, output):
span_id = span.get_span_context().span_id
if not inspect.isgenerator(output) and hasattr(output, "usage") and output.usage is not None:
tokens = {
f"__computed__.cumulative_token_count.{k.split('_')[0]}": v for k, v in output.usage.dict().items()
}
if tokens:
with self._lock:
self._span_id_to_tokens[span_id] = tokens

def collect_openai_tokens_for_parent_span(self, span):
tokens = self.try_get_openai_tokens(span.get_span_context().span_id)
if tokens:
if not hasattr(span, "parent") or span.parent is None:
return
parent_span_id = span.parent.span_id
with self._lock:
if parent_span_id in self._span_id_to_tokens:
merged_tokens = {
key: self._span_id_to_tokens[parent_span_id].get(key, 0) + tokens.get(key, 0)
for key in set(self._span_id_to_tokens[parent_span_id]) | set(tokens)
}
self._span_id_to_tokens[parent_span_id] = merged_tokens
else:
self._span_id_to_tokens[parent_span_id] = tokens

def try_get_openai_tokens(self, span_id):
with self._lock:
return self._span_id_to_tokens.get(span_id, None)


token_collector = TokenCollector()


def _create_trace_from_function_call(
f, *, args=None, kwargs=None, args_to_ignore: Optional[List[str]] = None, trace_type=TraceType.FUNCTION
):
Expand Down Expand Up @@ -242,6 +284,9 @@ def enrich_span_with_output(span, output):
try:
serialized_output = serialize_attribute(output)
span.set_attribute("output", serialized_output)
tokens = token_collector.try_get_openai_tokens(span.get_span_context().span_id)
if tokens:
span.set_attributes(tokens)
except Exception as e:
logging.warning(f"Failed to enrich span with output: {e}")

Expand Down Expand Up @@ -313,12 +358,16 @@ async def wrapped(*args, **kwargs):
Tracer.push(trace)
enrich_span_with_input(span, trace.inputs)
output = await func(*args, **kwargs)
if trace_type == TraceType.LLM:
token_collector.collect_openai_tokens(span, output)
enrich_span_with_output(span, output)
span.set_status(StatusCode.OK)
return Tracer.pop(output)
output = Tracer.pop(output)
except Exception as e:
Tracer.pop(None, e)
raise
token_collector.collect_openai_tokens_for_parent_span(span)
return output

wrapped.__original_function = func

Expand Down Expand Up @@ -358,12 +407,16 @@ def wrapped(*args, **kwargs):
Tracer.push(trace)
enrich_span_with_input(span, trace.inputs)
output = func(*args, **kwargs)
if trace_type == TraceType.LLM:
token_collector.collect_openai_tokens(span, output)
enrich_span_with_output(span, output)
span.set_status(StatusCode.OK)
return Tracer.pop(output)
output = Tracer.pop(output)
except Exception as e:
Tracer.pop(None, e)
raise
token_collector.collect_openai_tokens_for_parent_span(span)
return output

wrapped.__original_function = func

Expand Down

0 comments on commit d0a2829

Please sign in to comment.