Skip to content

Commit

Permalink
[Executor] Allow run async tools in executor (#1099)
Browse files Browse the repository at this point in the history
# Description

In current implementation, pf assume all the tools are normal
functions/methods.
When user provides an async function decorated by @tool, the behavior of
pf would be strange.

In this PR, we checked the function type of the passed function, then go
to a different call stack to ensure async function works well.

TODO: Currently if one async tool call another async tool, there would
be no traces support. We need to adjust the logic to add such support.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

---------

Co-authored-by: Heyi Tang <[email protected]>
  • Loading branch information
thy09 and Heyi Tang authored Nov 13, 2023
1 parent ce8ee8e commit caf34c6
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 27 deletions.
85 changes: 68 additions & 17 deletions src/promptflow/promptflow/_core/flow_execution_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import inspect
import logging
import threading
import time
Expand All @@ -10,7 +11,7 @@
from logging import WARNING
from typing import Callable, List

from promptflow._core._errors import ToolExecutionError
from promptflow._core._errors import ToolExecutionError, UnexpectedError
from promptflow._core.cache_manager import AbstractCacheManager, CacheInfo, CacheResult
from promptflow._core.operation_context import OperationContext
from promptflow._core.tool import parse_all_args
Expand Down Expand Up @@ -88,23 +89,9 @@ def invoke_tool_with_cache(self, f: Callable, argnames: List[str], args, kwargs)
output = Tracer.pop(output)
return output

self._current_tool = f
all_args = parse_all_args(argnames, args, kwargs)
node_run_id = self._generate_current_node_run_id()
flow_logger.info(f"Executing node {self._current_node.name}. node run id: {node_run_id}")
parent_run_id = f"{self._run_id}_{self._line_number}" if self._line_number is not None else self._run_id
run_info: RunInfo = self._run_tracker.start_node_run(
node=self._current_node.name,
flow_run_id=self._run_id,
parent_run_id=parent_run_id,
run_id=node_run_id,
index=self._line_number,
)
run_info = self.prepare_node_run(f, argnames, args, kwargs)
node_run_id = run_info.run_id

run_info.index = self._line_number
run_info.variant_id = self._variant_id

self._run_tracker.set_inputs(node_run_id, {key: value for key, value in all_args.items() if key != "self"})
traces = []
try:
hit_cache = False
Expand Down Expand Up @@ -145,6 +132,70 @@ def invoke_tool_with_cache(self, f: Callable, argnames: List[str], args, kwargs)
finally:
self._run_tracker.persist_node_run(run_info)

def prepare_node_run(self, f, argnames=[], args=[], kwargs={}):
self._current_tool = f
all_args = parse_all_args(argnames, args, kwargs)
node_run_id = self._generate_current_node_run_id()
flow_logger.info(f"Executing node {self._current_node.name}. node run id: {node_run_id}")
parent_run_id = f"{self._run_id}_{self._line_number}" if self._line_number is not None else self._run_id
run_info: RunInfo = self._run_tracker.start_node_run(
node=self._current_node.name,
flow_run_id=self._run_id,
parent_run_id=parent_run_id,
run_id=node_run_id,
index=self._line_number,
)
run_info.index = self._line_number
self._run_tracker.set_inputs(node_run_id, {key: value for key, value in all_args.items() if key != "self"})
return run_info

async def invoke_tool_async(self, f: Callable, kwargs):
if not inspect.iscoroutinefunction(f):
raise UnexpectedError(
message_format="Tool {function} is not a coroutine function.",
function=f.__name__
)
run_info = self.prepare_node_run(f, kwargs=kwargs)
node_run_id = run_info.run_id

traces = []
try:
Tracer.start_tracing(node_run_id)
trace = Tracer.push_tool(f, kwargs=kwargs)
trace.node_name = run_info.node
result = await self._invoke_tool_async_inner(f, kwargs)
result = Tracer.pop(result)
traces = Tracer.end_tracing()
self._current_tool = None
self._run_tracker.end_run(node_run_id, result=result, traces=traces)
flow_logger.info(f"Node {self._current_node.name} completes.")
return result
except Exception as e:
logger.exception(f"Node {self._current_node.name} in line {self._line_number} failed. Exception: {e}.")
Tracer.pop(error=e)
if not traces:
traces = Tracer.end_tracing()
self._run_tracker.end_run(node_run_id, ex=e, traces=traces)
raise
finally:
self._run_tracker.persist_node_run(run_info)

async def _invoke_tool_async_inner(self, f: Callable, kwargs):
try:
return await f(**kwargs)
except PromptflowException as e:
# All the exceptions from built-in tools are PromptflowException.
# For these cases, raise the exception directly.
if f.__module__ is not None:
e.module = f.__module__
raise e
except Exception as e:
node_name = self._current_node.name if self._current_node else f.__name__
# Otherwise, we assume the error comes from user's tool.
# For these cases, raise ToolExecutionError, which is classified as UserError
# and shows stack trace in the error message to make it easy for user to troubleshoot.
raise ToolExecutionError(node_name=node_name, module=f.__module__) from e

def invoke_tool(self, f: Callable, args, kwargs):
node_name = self._current_node.name if self._current_node else f.__name__
try:
Expand Down
22 changes: 14 additions & 8 deletions src/promptflow/promptflow/_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,20 @@ def tool(

def tool_decorator(func: Callable) -> Callable:
from promptflow.exceptions import UserErrorException

@functools.wraps(func)
def new_f(*args, **kwargs):
tool_invoker = ToolInvoker.active_instance()
# If there is no active tool invoker for tracing or other purposes, just call the function.
if tool_invoker is None:
return func(*args, **kwargs)
return tool_invoker.invoke_tool(func, *args, **kwargs)
if inspect.iscoroutinefunction(func):
@functools.wraps(func)
async def new_f_async(*args, **kwargs):
"""TODO: Add tracing support for async tools."""
return await func(*args, **kwargs)
new_f = new_f_async
else:
@functools.wraps(func)
def new_f(*args, **kwargs):
tool_invoker = ToolInvoker.active_instance()
# If there is no active tool invoker for tracing or other purposes, just call the function.
if tool_invoker is None:
return func(*args, **kwargs)
return tool_invoker.invoke_tool(func, *args, **kwargs)

if type is not None and type not in [k.value for k in ToolType]:
raise UserErrorException(f"Tool type {type} is not supported yet.")
Expand Down
2 changes: 1 addition & 1 deletion src/promptflow/promptflow/_core/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def end_tracing(cls, raise_ex=False):
return tracer.to_json()

@classmethod
def push_tool(cls, f, args, kwargs):
def push_tool(cls, f, args=[], kwargs={}):
obj = cls.active_instance()
sig = inspect.signature(f).parameters
all_kwargs = {**{k: v for k, v in zip(sig.keys(), args)}, **kwargs}
Expand Down
8 changes: 7 additions & 1 deletion src/promptflow/promptflow/executor/_flow_nodes_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import asyncio
import contextvars
import inspect
from concurrent import futures
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Dict, List, Tuple
Expand Down Expand Up @@ -113,7 +115,11 @@ def _exec_single_node_in_thread(self, args: Tuple[Node, DAGManager]):
f = self._tools_manager.get_tool(node.name)
kwargs = dag_manager.get_node_valid_inputs(node, f)
context.current_node = node
result = f(**kwargs)
if inspect.iscoroutinefunction(f):
# TODO: Run async functions in flow level event loop
result = asyncio.run(context.invoke_tool_async(f, kwargs=kwargs))
else:
result = f(**kwargs)
context.current_node = None
return result
finally:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def test_executor_exec_bulk_with_openai_metrics(self, dev_connections):
"script_with_import",
"package_tools",
"connection_as_input",
"async_tools",
],
)
def test_executor_exec_line(self, flow_folder, dev_connections):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from promptflow import tool
import asyncio


@tool
async def passthrough_str_and_wait(input1: str, wait_seconds=3) -> str:
assert isinstance(input1, str), f"input1 should be a string, got {input1}"
print("Wait for", wait_seconds, "seconds")
for i in range(wait_seconds):
print(i)
await asyncio.sleep(1)
return input1
36 changes: 36 additions & 0 deletions src/promptflow/tests/test_configs/flows/async_tools/flow.dag.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
inputs:
input_str:
type: string
default: Hello
outputs:
ouput1:
type: string
reference: ${async_passthrough1.output}
output2:
type: string
reference: ${async_passthrough2.output}
nodes:
- name: async_passthrough
type: python
source:
type: code
path: async_passthrough.py
inputs:
input1: ${inputs.input_str}
wait_seconds: 3
- name: async_passthrough1
type: python
source:
type: code
path: async_passthrough.py
inputs:
input1: ${async_passthrough.output}
wait_seconds: 3
- name: async_passthrough2
type: python
source:
type: code
path: async_passthrough.py
inputs:
input1: ${async_passthrough.output}
wait_seconds: 3

0 comments on commit caf34c6

Please sign in to comment.