From 78694b4820e3432d54e3234a191d23314e0bff7d Mon Sep 17 00:00:00 2001 From: Lina Tang Date: Tue, 23 Apr 2024 11:19:57 +0800 Subject: [PATCH] Refine code: not using flow contract in script executor --- .../promptflow/executor/_script_executor.py | 22 ++++++++++++------- .../promptflow/executor/flow_executor.py | 8 ++++--- .../promptflow/executor/flow_validator.py | 16 +++++++++----- .../promptflow/batch/_batch_engine.py | 3 ++- .../flow_with_signature/flow.flex.yaml | 11 ++++++++++ .../flow_with_signature/inputs.jsonl | 2 ++ .../flow_with_signature/my_flow.py | 6 +++++ 7 files changed, 50 insertions(+), 18 deletions(-) create mode 100644 src/promptflow/tests/test_configs/eager_flows/flow_with_signature/flow.flex.yaml create mode 100644 src/promptflow/tests/test_configs/eager_flows/flow_with_signature/inputs.jsonl create mode 100644 src/promptflow/tests/test_configs/eager_flows/flow_with_signature/my_flow.py diff --git a/src/promptflow-core/promptflow/executor/_script_executor.py b/src/promptflow-core/promptflow/executor/_script_executor.py index 1feca7e5c21..1c3bd26ab8b 100644 --- a/src/promptflow-core/promptflow/executor/_script_executor.py +++ b/src/promptflow-core/promptflow/executor/_script_executor.py @@ -66,7 +66,7 @@ def __init__( self._working_dir = Flow._resolve_working_dir(entry, working_dir) else: self._working_dir = working_dir or Path.cwd() - self._initialize_flow() + self._init_sign_from_yaml() self._initialize_function() self._connections = connections self._storage = storage or DefaultRunStorage() @@ -94,7 +94,8 @@ def exec_line( **kwargs, ) -> LineResult: run_id = run_id or str(uuid.uuid4()) - inputs = apply_default_value_for_input(self._flow.inputs, inputs) + if not self.is_function_entry: + inputs = apply_default_value_for_input(self._inputs_sign, inputs) with self._exec_line_context(run_id, index): return self._exec_line(inputs, index, run_id, allow_generator_output=allow_generator_output) @@ -120,7 +121,8 @@ def _exec_line_preprocess( # Executor will add line_number to batch inputs if there is no line_number in the original inputs, # which should be removed, so, we only preserve the inputs that are contained in self._inputs. inputs = {k: inputs[k] for k in self._inputs if k in inputs} - FlowValidator.ensure_flow_inputs_type(self._flow, inputs) + if not self.is_function_entry: + FlowValidator.ensure_flow_inputs_type(self._inputs_sign, inputs) return run_info, inputs, run_tracker, None, [] def _exec_line( @@ -226,7 +228,8 @@ async def exec_line_async( **kwargs, ) -> LineResult: run_id = run_id or str(uuid.uuid4()) - inputs = apply_default_value_for_input(self._flow.inputs, inputs) + if not self.is_function_entry: + inputs = apply_default_value_for_input(self._inputs_sign, inputs) with self._exec_line_context(run_id, index): return await self._exec_line_async(inputs, index, run_id, allow_generator_output=allow_generator_output) @@ -289,7 +292,8 @@ def get_inputs_definition(self): def _resolve_init_kwargs(self, c: type, init_kwargs: dict): """Resolve init kwargs, the connection names will be resolved to connection objects.""" logger.debug(f"Resolving init kwargs: {init_kwargs.keys()}.") - init_kwargs = apply_default_value_for_input(self._flow.init, init_kwargs) + if not self.is_function_entry: + init_kwargs = apply_default_value_for_input(self._init_sign, init_kwargs) sig = inspect.signature(c.__init__) connection_params = [] model_config_param_name_2_cls = {} @@ -453,8 +457,10 @@ def _parse_flow_file(self): ) from e return module_name, func_name - def _initialize_flow(self): - if not inspect.isfunction(self._flow_file): + def _init_sign_from_yaml(self): + if not self.is_function_entry: with open(self._working_dir / self._flow_file, "r", encoding="utf-8") as fin: flow_dag = load_yaml(fin) - self._flow = FlexFlow.deserialize(flow_dag) + flow = FlexFlow.deserialize(flow_dag) + self._inputs_sign = flow.inputs + self._init_sign = flow.init diff --git a/src/promptflow-core/promptflow/executor/flow_executor.py b/src/promptflow-core/promptflow/executor/flow_executor.py index 09c60be1b69..9e09343a209 100644 --- a/src/promptflow-core/promptflow/executor/flow_executor.py +++ b/src/promptflow-core/promptflow/executor/flow_executor.py @@ -477,7 +477,7 @@ def convert_flow_input_types(self, inputs: dict) -> Mapping[str, Any]: :return: A dictionary containing the converted inputs. :rtype: Mapping[str, Any] """ - return FlowValidator.resolve_flow_inputs_type(self._flow, inputs) + return FlowValidator.resolve_flow_inputs_type(self._flow.inputs, inputs) @property def _default_inputs_mapping(self): @@ -956,7 +956,9 @@ def _exec( aggregation_inputs = {} try: if validate_inputs: - inputs = FlowValidator.ensure_flow_inputs_type(flow=self._flow, inputs=inputs, idx=run_info.index) + inputs = FlowValidator.ensure_flow_inputs_type( + flow=self._flow.inputs, inputs=inputs, idx=run_info.index + ) inputs = self._multimedia_processor.load_multimedia_data(self._flow.inputs, inputs) # Inputs are assigned after validation and multimedia data loading, instead of at the start of the flow run. # This way, if validation or multimedia data loading fails, we avoid persisting invalid inputs. @@ -1041,7 +1043,7 @@ async def _exec_async( aggregation_inputs = {} try: if validate_inputs: - inputs = FlowValidator.ensure_flow_inputs_type(flow=self._flow, inputs=inputs, idx=line_number) + inputs = FlowValidator.ensure_flow_inputs_type(flow=self._flow.inputs, inputs=inputs, idx=line_number) # TODO: Consider async implementation for load_multimedia_data inputs = self._multimedia_processor.load_multimedia_data(self._flow.inputs, inputs) # Make sure the run_info with converted inputs results rather than original inputs diff --git a/src/promptflow-core/promptflow/executor/flow_validator.py b/src/promptflow-core/promptflow/executor/flow_validator.py index fd6478aee35..5bbe7bf1145 100644 --- a/src/promptflow-core/promptflow/executor/flow_validator.py +++ b/src/promptflow-core/promptflow/executor/flow_validator.py @@ -7,7 +7,7 @@ from typing import Any, List, Mapping, Optional from promptflow._utils.logger_utils import logger -from promptflow.contracts.flow import Flow, InputValueType, Node +from promptflow.contracts.flow import Flow, FlowInputDefinition, InputValueType, Node from promptflow.contracts.tool import ValueType from promptflow.executor._errors import ( DuplicateNodeName, @@ -182,7 +182,9 @@ def resolve_aggregated_flow_inputs_type(flow: Flow, inputs: Mapping[str, List[An return updated_inputs @staticmethod - def resolve_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optional[int] = None) -> Mapping[str, Any]: + def resolve_flow_inputs_type( + flow_inputs: FlowInputDefinition, inputs: Mapping[str, Any], idx: Optional[int] = None + ) -> Mapping[str, Any]: """Resolve inputs by type if existing. Ignore missing inputs. :param flow: The `flow` parameter is of type `Flow` and represents a flow object @@ -198,13 +200,15 @@ def resolve_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optiona :rtype: Mapping[str, Any] """ updated_inputs = {k: v for k, v in inputs.items()} - for k, v in flow.inputs.items(): + for k, v in flow_inputs.items(): if k in inputs: updated_inputs[k] = FlowValidator._parse_input_value(k, inputs[k], v.type, idx) return updated_inputs @staticmethod - def ensure_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optional[int] = None) -> Mapping[str, Any]: + def ensure_flow_inputs_type( + flow_inputs: FlowInputDefinition, inputs: Mapping[str, Any], idx: Optional[int] = None + ) -> Mapping[str, Any]: """Make sure the inputs are completed and in the correct type. Raise Exception if not valid. :param flow: The `flow` parameter is of type `Flow` and represents a flow object @@ -219,7 +223,7 @@ def ensure_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optional type specified in the `flow` object. :rtype: Mapping[str, Any] """ - for k, v in flow.inputs.items(): + for k, _ in flow_inputs.items(): if k not in inputs: line_info = "in input data" if idx is None else f"in line {idx} of input data" msg_format = ( @@ -228,7 +232,7 @@ def ensure_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optional "if it's no longer needed." ) raise InputNotFound(message_format=msg_format, input_name=k, line_info=line_info) - return FlowValidator.resolve_flow_inputs_type(flow, inputs, idx) + return FlowValidator.resolve_flow_inputs_type(flow_inputs, inputs, idx) @staticmethod def convert_flow_inputs_for_node(flow: Flow, node: Node, inputs: Mapping[str, Any]) -> Mapping[str, Any]: diff --git a/src/promptflow-devkit/promptflow/batch/_batch_engine.py b/src/promptflow-devkit/promptflow/batch/_batch_engine.py index 83ddddad0de..230347be6cd 100644 --- a/src/promptflow-devkit/promptflow/batch/_batch_engine.py +++ b/src/promptflow-devkit/promptflow/batch/_batch_engine.py @@ -567,7 +567,8 @@ def _get_aggregation_inputs(self, batch_inputs, line_results: List[LineResult]): succeeded_batch_inputs = [batch_inputs[i] for i in succeeded] resolved_succeeded_batch_inputs = [ - FlowValidator.ensure_flow_inputs_type(flow=self._flow, inputs=input) for input in succeeded_batch_inputs + FlowValidator.ensure_flow_inputs_type(flow=self._flow.inputs, inputs=input) + for input in succeeded_batch_inputs ] succeeded_inputs = transpose(resolved_succeeded_batch_inputs, keys=list(self._flow.inputs.keys())) aggregation_inputs = transpose( diff --git a/src/promptflow/tests/test_configs/eager_flows/flow_with_signature/flow.flex.yaml b/src/promptflow/tests/test_configs/eager_flows/flow_with_signature/flow.flex.yaml new file mode 100644 index 00000000000..c04400b68b3 --- /dev/null +++ b/src/promptflow/tests/test_configs/eager_flows/flow_with_signature/flow.flex.yaml @@ -0,0 +1,11 @@ +entry: my_flow:MyClass +init: + input_init: + type: string + default: input_init +inputs: + input_1: + type: string + input_2: + type: string + default: input_2 \ No newline at end of file diff --git a/src/promptflow/tests/test_configs/eager_flows/flow_with_signature/inputs.jsonl b/src/promptflow/tests/test_configs/eager_flows/flow_with_signature/inputs.jsonl new file mode 100644 index 00000000000..ac71ed5fe13 --- /dev/null +++ b/src/promptflow/tests/test_configs/eager_flows/flow_with_signature/inputs.jsonl @@ -0,0 +1,2 @@ +{"input_1": "input_1"} +{"input_1": "input_1"} \ No newline at end of file diff --git a/src/promptflow/tests/test_configs/eager_flows/flow_with_signature/my_flow.py b/src/promptflow/tests/test_configs/eager_flows/flow_with_signature/my_flow.py new file mode 100644 index 00000000000..1fd41551131 --- /dev/null +++ b/src/promptflow/tests/test_configs/eager_flows/flow_with_signature/my_flow.py @@ -0,0 +1,6 @@ +class MyClass: + def __init__(self, input_init: str = "default_input_init"): + pass + + def __call__(self, input_1, input_2: str = "default_input_2"): + return {"output": input_2}