Add resume to batch engine (#2003)
# Description

This pull request adds resume support for batch runs to the batch engine.

The most important changes are:

1. Addition of new feature:
* [`src/promptflow/promptflow/batch/_batch_engine.py`](diffhunk://#diff-ecf0905f2116abe08b5cf2931b856bf39aa8b38a30fadd9042538d076dbfde80R20-R29):
Added logic to handle resuming from a previous run. In a resume run, lines
that completed in the previous run are reused rather than re-executed, while
the remaining lines (including previously failed ones) are processed again.
Added `_copy_previous_run_result` to load the previous flow run output from
output.jsonl, copy image files to output_dir, and extract aggregation inputs
for the aggregation nodes. (A usage sketch follows this list.)

2. Addition of new utility functions:
* [`src/promptflow/promptflow/_utils/utils.py`](diffhunk://#diff-d12fdd7b90cc1748f1d3e1237b4f357ba7f66740445d117beeb68ed174d1e86eR74-R81):
Added a `load_list_from_jsonl` function to load a list from a JSONL file
and a `copy_file_except` function to copy all files from one directory to
another, excluding a specific file.
[[1]](diffhunk://#diff-d12fdd7b90cc1748f1d3e1237b4f357ba7f66740445d117beeb68ed174d1e86eR74-R81)
[[2]](diffhunk://#diff-d12fdd7b90cc1748f1d3e1237b4f357ba7f66740445d117beeb68ed174d1e86eR332-R359)

3. Rearrangement of functions:
* [`src/promptflow/promptflow/executor/flow_executor.py`](diffhunk://#diff-faa6c81d614b7e41b18a42a93139d961d92afa9aa9dd0b72cb6b7176d7541e69L663-L671):
Moved the `_extract_aggregation_inputs` and `_extract_aggregation_input`
functions to the shared execution utils so they can be used in more general
cases.

4. Changes to test files:
* [`src/promptflow/tests/executor/e2etests/test_batch_engine.py`](diffhunk://#diff-7e540aa137597e4a83cd0e2592915e9fe80c2a976e79df2f9b461254c15faca0R34-R36):
Added e2e tests for resuming a batch run, both with and without an
aggregation node.
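
For context, here is a minimal sketch of how a caller might wire up the new resume parameters. This is an illustration, not code from this PR: the flow paths and input mapping are hypothetical, `previous_run_storage` stands in for the `AbstractBatchRunStorage` of the run being resumed, and it assumes `BatchEngine` is importable from `promptflow.batch`.

```python
from pathlib import Path

from promptflow.batch import BatchEngine

# Hypothetical flow and data paths.
engine = BatchEngine(Path("my_flow/flow.dag.yaml"), working_dir=Path("my_flow"))

result = engine.run(
    input_dirs={"data": "data/inputs.jsonl"},
    inputs_mapping={"question": "${data.question}"},
    output_dir=Path("runs/resumed_run"),
    # Storage of the previous (partially failed) run; an AbstractBatchRunStorage,
    # so flow/node run infos of completed lines can be reloaded.
    resume_from_run_storage=previous_run_storage,
    # Output dir of the previous run, containing output.jsonl and any images.
    resume_from_run_output_dir=Path("runs/previous_run"),
)
```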

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which has an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

---------

Co-authored-by: Min Shi <[email protected]>
Jasmin3q and Min Shi authored Feb 23, 2024
1 parent 5a15859 commit 61356cc
Showing 60 changed files with 654 additions and 20 deletions.
6 changes: 6 additions & 0 deletions src/promptflow/promptflow/_core/_errors.py
@@ -155,3 +155,9 @@ class DuplicateToolMappingError(ValidationException):
    """Exception raised when multiple tools are linked to the same deprecated tool id."""

    pass


class ResumeCopyError(SystemErrorException):
"""Exception raised when failed to copy the results when resuming the run."""

pass
15 changes: 14 additions & 1 deletion src/promptflow/promptflow/_utils/execution_utils.py
@@ -4,8 +4,9 @@
from typing import AbstractSet, Any, Dict, List, Mapping

from promptflow._utils.logger_utils import logger
-from promptflow.contracts.flow import Flow, FlowInputDefinition, InputValueType
+from promptflow.contracts.flow import Flow, FlowInputDefinition, InputAssignment, InputValueType
from promptflow.contracts.run_info import FlowRunInfo, Status
from promptflow.executor import _input_assignment_parser


def apply_default_value_for_input(inputs: Dict[str, FlowInputDefinition], line_inputs: Mapping) -> Dict[str, Any]:
@@ -56,3 +57,15 @@ def get_aggregation_inputs_properties(flow: Flow) -> AbstractSet[str]:
def collect_lines(indexes: List[int], kvs: Mapping[str, List]) -> Mapping[str, List]:
"""Collect the values from the kvs according to the indexes."""
return {k: [v[i] for i in indexes] for k, v in kvs.items()}


def extract_aggregation_inputs(flow: Flow, nodes_outputs: dict) -> Dict[str, Any]:
"""Extract the aggregation inputs of a flow from the nodes outputs."""
_aggregation_inputs_references = get_aggregation_inputs_properties(flow)
return {prop: _parse_aggregation_input(nodes_outputs, prop) for prop in _aggregation_inputs_references}


def _parse_aggregation_input(nodes_outputs: dict, aggregation_input_property: str):
"""Parse the value of the aggregation input from the nodes outputs."""
assign = InputAssignment.deserialize(aggregation_input_property)
return _input_assignment_parser.parse_value(assign, nodes_outputs, {})
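
To make the reference format concrete, here is a small illustrative example of what `_parse_aggregation_input` resolves. The node name and output value are made up; it assumes `promptflow` is installed and uses only the two internal helpers shown above.

```python
from promptflow.contracts.flow import InputAssignment
from promptflow.executor import _input_assignment_parser

# Made-up mapping of node name -> node output, as produced by a flow run.
nodes_outputs = {"line_process": "processed result"}

# An aggregation input property is a reference string such as "${line_process.output}".
assign = InputAssignment.deserialize("${line_process.output}")
value = _input_assignment_parser.parse_value(assign, nodes_outputs, {})
print(value)  # "processed result"
```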
37 changes: 37 additions & 0 deletions src/promptflow/promptflow/_utils/utils.py
@@ -13,6 +13,7 @@
import logging
import os
import re
import shutil
import time
import traceback
from datetime import datetime
@@ -70,6 +71,14 @@ def dump_list_to_jsonl(file_path: Union[str, Path], list_data: List[Dict]):
            jsonl_file.write("\n")


def load_list_from_jsonl(file: Union[str, Path]):
    content = []
    with open(file, "r", encoding=DEFAULT_ENCODING) as fin:
        for line in fin:
            content.append(json.loads(line))
    return content


def transpose(values: List[Dict[str, Any]], keys: Optional[List] = None) -> Dict[str, List]:
    keys = keys or list(values[0].keys())
    return {key: [v.get(key) for v in values] for key in keys}
@@ -329,3 +338,31 @@ def _match_reference(env_val: str):
        return None, None
    name, key = m.groups()
    return name, key


def copy_file_except(src_dir, dst_dir, exclude_file):
"""
Copy all files from src_dir to dst_dir recursively, excluding a specific file
directly under the root of src_dir.
:param src_dir: Source directory path
:type src_dir: str
:param dst_dir: Destination directory path
:type dst_dir: str
:param exclude_file: Name of the file to exclude from copying
:type exclude_file: str
"""
os.makedirs(dst_dir, exist_ok=True)

for root, dirs, files in os.walk(src_dir):
rel_path = os.path.relpath(root, src_dir)
current_dst_dir = os.path.join(dst_dir, rel_path)

os.makedirs(current_dst_dir, exist_ok=True)

for file in files:
if rel_path == "." and file == exclude_file:
continue # Skip the excluded file
src_file_path = os.path.join(root, file)
dst_file_path = os.path.join(current_dst_dir, file)
shutil.copy2(src_file_path, dst_file_path)
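
Together, these two helpers back the resume scenario in the batch engine. A brief usage sketch with hypothetical run directories:

```python
from pathlib import Path

from promptflow._utils.utils import copy_file_except, load_list_from_jsonl

previous_output_dir = Path("runs/previous_run")  # hypothetical prior run
resumed_output_dir = Path("runs/resumed_run")    # hypothetical new run

# Reload the line outputs the previous run persisted to output.jsonl.
previous_outputs = load_list_from_jsonl(previous_output_dir / "output.jsonl")

# Carry over every other artifact (e.g. generated images), excluding
# output.jsonl itself since the resumed run rewrites that file.
copy_file_except(previous_output_dir, resumed_output_dir, "output.jsonl")
```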
92 changes: 86 additions & 6 deletions src/promptflow/promptflow/batch/_batch_engine.py
@@ -1,7 +1,6 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import asyncio
import signal
import threading
@@ -11,20 +10,23 @@
from typing import Any, Dict, List, Mapping, Optional

from promptflow._constants import LANGUAGE_KEY, LINE_NUMBER_KEY, LINE_TIMEOUT_SEC, FlowLanguage
-from promptflow._core._errors import UnexpectedError
+from promptflow._core._errors import ResumeCopyError, UnexpectedError
from promptflow._core.operation_context import OperationContext
from promptflow._utils.async_utils import async_run_allowing_running_loop
from promptflow._utils.context_utils import _change_working_dir
from promptflow._utils.execution_utils import (
    apply_default_value_for_input,
    collect_lines,
    extract_aggregation_inputs,
    get_aggregation_inputs_properties,
    handle_line_failures,
)
from promptflow._utils.logger_utils import bulk_logger
from promptflow._utils.utils import (
    copy_file_except,
    dump_list_to_jsonl,
    get_int_env_var,
    load_list_from_jsonl,
    log_progress,
    resolve_dir_to_absolute,
    transpose,
@@ -42,7 +44,7 @@
from promptflow.executor._line_execution_process_pool import signal_handler
from promptflow.executor._result import AggregationResult, LineResult
from promptflow.executor.flow_validator import FlowValidator
-from promptflow.storage._run_storage import AbstractBatchRunStorage, AbstractRunStorage
+from promptflow.storage import AbstractBatchRunStorage, AbstractRunStorage

OUTPUT_FILE_NAME = "output.jsonl"
DEFAULT_CONCURRENCY = 10
@@ -192,9 +194,21 @@ def run(
                batch_inputs = batch_input_processor.process_batch_inputs(input_dirs, inputs_mapping)
                # resolve output dir
                output_dir = resolve_dir_to_absolute(self._working_dir, output_dir)

                previous_run_results = None
                if resume_from_run_storage and resume_from_run_output_dir:
                    previous_run_results = self._copy_previous_run_result(
                        resume_from_run_storage, resume_from_run_output_dir, batch_inputs, output_dir
                    )

                # run flow in batch mode
                return async_run_allowing_running_loop(
-                    self._exec_in_task, batch_inputs, run_id, output_dir, raise_on_line_failure
+                    self._exec_in_task,
+                    batch_inputs,
+                    run_id,
+                    output_dir,
+                    raise_on_line_failure,
+                    previous_run_results,
                )
            finally:
                async_run_allowing_running_loop(self._executor_proxy.destroy)
@@ -213,6 +227,66 @@ def run(
                )
                raise unexpected_error from e

    def _copy_previous_run_result(
        self,
        resume_from_run_storage: AbstractBatchRunStorage,
        resume_from_run_output_dir: Path,
        batch_inputs: List,
        output_dir: Path,
    ) -> List[LineResult]:
        """Duplicate the previous debug_info from resume_from_run_storage and output from resume_from_run_output_dir
        to the storage of new run,
        return the list of previous line results for the usage of aggregation and summarization.
        """
        # Load the previous flow run output from output.jsonl
        previous_run_output = load_list_from_jsonl(resume_from_run_output_dir / "output.jsonl")
        previous_run_output_dict = {
            each_line_output[LINE_NUMBER_KEY]: each_line_output for each_line_output in previous_run_output
        }

        # Copy other files from resume_from_run_output_dir to output_dir in case there are images
        copy_file_except(resume_from_run_output_dir, output_dir, "output.jsonl")

        try:
            previous_run_results = []
            for i in range(len(batch_inputs)):
                previous_run_info = resume_from_run_storage.load_flow_run_info(i)

                if previous_run_info and previous_run_info.status == Status.Completed:
                    # Load previous node run info
                    previous_node_run_infos = resume_from_run_storage.load_node_run_info_for_line(i)
                    previous_node_run_infos_dict = {node_run.node: node_run for node_run in previous_node_run_infos}
                    previous_node_run_outputs = {
                        node_info.node: node_info.output for node_info in previous_node_run_infos
                    }

                    # Extract aggregation inputs for flow with aggregation node
                    aggregation_inputs = extract_aggregation_inputs(self._flow, previous_node_run_outputs)

                    # Persist previous run info and node run info
                    self._storage.persist_flow_run(previous_run_info)
                    for node_run_info in previous_node_run_infos:
                        self._storage.persist_node_run(node_run_info)

                    # Create LineResult object for previous line result
                    previous_line_result = LineResult(
                        output=previous_run_output_dict[i],
                        aggregation_inputs=aggregation_inputs,
                        run_info=previous_run_info,
                        node_run_infos=previous_node_run_infos_dict,
                    )
                    previous_run_results.append(previous_line_result)

            return previous_run_results
        except Exception as e:
            bulk_logger.error(f"Error occurred while copying previous run result. Exception: {str(e)}")
            resume_copy_error = ResumeCopyError(
                target=ErrorTarget.BATCH,
                message_format="Failed to copy results when resuming the run. Error: {error_type_and_message}.",
                error_type_and_message=f"({e.__class__.__name__}) {e}",
            )
            raise resume_copy_error from e

    def cancel(self):
        """Cancel the batch run"""
        self._is_canceled = True
@@ -223,11 +297,13 @@ async def _exec_in_task(
        run_id: str = None,
        output_dir: Path = None,
        raise_on_line_failure: bool = False,
        previous_line_results: List[LineResult] = None,
    ) -> BatchResult:
        # if the batch run is canceled, asyncio.CancelledError will be raised and no results will be returned,
        # so we pass empty line results list and aggr results and update them in _exec so that when the batch
        # run is canceled we can get the current completed line results and aggr results.
        line_results: List[LineResult] = []
        line_results.extend(previous_line_results or [])
        aggr_result = AggregationResult({}, {}, {})
        task = asyncio.create_task(
            self._exec(line_results, aggr_result, batch_inputs, run_id, output_dir, raise_on_line_failure)
@@ -260,13 +336,18 @@ async def _exec(
        batch_inputs = [
            apply_default_value_for_input(self._flow.inputs, each_line_input) for each_line_input in batch_inputs
        ]

        existing_results_line_numbers = set([r.run_info.index for r in line_results])
        bulk_logger.info(f"Skipped the execution of {len(existing_results_line_numbers)} existing results.")
        inputs_to_run = [input for input in batch_inputs if input[LINE_NUMBER_KEY] not in existing_results_line_numbers]

        run_id = run_id or str(uuid.uuid4())

        # execute lines
        is_timeout = False
        if isinstance(self._executor_proxy, PythonExecutorProxy):
            results, is_timeout = self._executor_proxy._exec_batch(
-                batch_inputs,
+                inputs_to_run,
                output_dir,
                run_id,
                batch_timeout_sec=self._batch_timeout_sec,
@@ -278,7 +359,6 @@
            # TODO: Enable batch timeout for other api based executor proxy
            await self._exec_batch(line_results, batch_inputs, run_id)
        handle_line_failures([r.run_info for r in line_results], raise_on_line_failure)

        # persist outputs to output dir
        outputs = [
            {LINE_NUMBER_KEY: r.run_info.index, **r.output}
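
The resume filter in `_exec` above skips carried-over lines by line number. A standalone toy illustration with made-up inputs:

```python
from promptflow._constants import LINE_NUMBER_KEY

# Suppose lines 0 and 2 completed in the previous run and were carried over
# as existing line results, so only line 1 should be re-executed.
batch_inputs = [{LINE_NUMBER_KEY: i, "question": f"q{i}"} for i in range(3)]
existing_results_line_numbers = {0, 2}

inputs_to_run = [
    line for line in batch_inputs
    if line[LINE_NUMBER_KEY] not in existing_results_line_numbers
]
assert [line[LINE_NUMBER_KEY] for line in inputs_to_run] == [1]
```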
12 changes: 2 additions & 10 deletions src/promptflow/promptflow/executor/flow_executor.py
@@ -36,6 +36,7 @@
from promptflow._utils.execution_utils import (
    apply_default_value_for_input,
    collect_lines,
    extract_aggregation_inputs,
    get_aggregation_inputs_properties,
)
from promptflow._utils.logger_utils import flow_logger, logger
@@ -660,15 +661,6 @@ def _exec_in_thread(self, args) -> LineResult:
        self._completed_idx[line_number] = thread_name
        return results

-    def _extract_aggregation_inputs(self, nodes_outputs: dict):
-        return {
-            prop: self._extract_aggregation_input(nodes_outputs, prop) for prop in self._aggregation_inputs_references
-        }
-
-    def _extract_aggregation_input(self, nodes_outputs: dict, aggregation_input_property: str):
-        assign = InputAssignment.deserialize(aggregation_input_property)
-        return _input_assignment_parser.parse_value(assign, nodes_outputs, {})

    def exec_line(
        self,
        inputs: Mapping[str, Any],
@@ -833,7 +825,7 @@ def _exec_inner(
        run_tracker.persist_selected_node_runs(run_info, generator_output_nodes)
        run_tracker.allow_generator_types = allow_generator_output
        run_tracker.end_run(run_info.run_id, result=output)
-        aggregation_inputs = self._extract_aggregation_inputs(nodes_outputs)
+        aggregation_inputs = extract_aggregation_inputs(self._flow, nodes_outputs)
        return output, aggregation_inputs

    def _exec(