Skip to content

Commit

Permalink
Refine code: not using flow contract in script executor
Browse files Browse the repository at this point in the history
  • Loading branch information
Lina Tang committed Apr 23, 2024
1 parent 8102634 commit 78694b4
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 18 deletions.
22 changes: 14 additions & 8 deletions src/promptflow-core/promptflow/executor/_script_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
self._working_dir = Flow._resolve_working_dir(entry, working_dir)
else:
self._working_dir = working_dir or Path.cwd()
self._initialize_flow()
self._init_sign_from_yaml()
self._initialize_function()
self._connections = connections
self._storage = storage or DefaultRunStorage()
Expand Down Expand Up @@ -94,7 +94,8 @@ def exec_line(
**kwargs,
) -> LineResult:
run_id = run_id or str(uuid.uuid4())
inputs = apply_default_value_for_input(self._flow.inputs, inputs)
if not self.is_function_entry:
inputs = apply_default_value_for_input(self._inputs_sign, inputs)
with self._exec_line_context(run_id, index):
return self._exec_line(inputs, index, run_id, allow_generator_output=allow_generator_output)

Expand All @@ -120,7 +121,8 @@ def _exec_line_preprocess(
# Executor will add line_number to batch inputs if there is no line_number in the original inputs,
# which should be removed, so, we only preserve the inputs that are contained in self._inputs.
inputs = {k: inputs[k] for k in self._inputs if k in inputs}
FlowValidator.ensure_flow_inputs_type(self._flow, inputs)
if not self.is_function_entry:
FlowValidator.ensure_flow_inputs_type(self._inputs_sign, inputs)
return run_info, inputs, run_tracker, None, []

def _exec_line(
Expand Down Expand Up @@ -226,7 +228,8 @@ async def exec_line_async(
**kwargs,
) -> LineResult:
run_id = run_id or str(uuid.uuid4())
inputs = apply_default_value_for_input(self._flow.inputs, inputs)
if not self.is_function_entry:
inputs = apply_default_value_for_input(self._inputs_sign, inputs)
with self._exec_line_context(run_id, index):
return await self._exec_line_async(inputs, index, run_id, allow_generator_output=allow_generator_output)

Expand Down Expand Up @@ -289,7 +292,8 @@ def get_inputs_definition(self):
def _resolve_init_kwargs(self, c: type, init_kwargs: dict):
"""Resolve init kwargs, the connection names will be resolved to connection objects."""
logger.debug(f"Resolving init kwargs: {init_kwargs.keys()}.")
init_kwargs = apply_default_value_for_input(self._flow.init, init_kwargs)
if not self.is_function_entry:
init_kwargs = apply_default_value_for_input(self._init_sign, init_kwargs)
sig = inspect.signature(c.__init__)
connection_params = []
model_config_param_name_2_cls = {}
Expand Down Expand Up @@ -453,8 +457,10 @@ def _parse_flow_file(self):
) from e
return module_name, func_name

def _initialize_flow(self):
if not inspect.isfunction(self._flow_file):
def _init_sign_from_yaml(self):
if not self.is_function_entry:
with open(self._working_dir / self._flow_file, "r", encoding="utf-8") as fin:

Check failure

Code scanning / CodeQL

Uncontrolled data used in path expression High

This path depends on a user-provided value.
flow_dag = load_yaml(fin)
self._flow = FlexFlow.deserialize(flow_dag)
flow = FlexFlow.deserialize(flow_dag)
self._inputs_sign = flow.inputs
self._init_sign = flow.init
8 changes: 5 additions & 3 deletions src/promptflow-core/promptflow/executor/flow_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def convert_flow_input_types(self, inputs: dict) -> Mapping[str, Any]:
:return: A dictionary containing the converted inputs.
:rtype: Mapping[str, Any]
"""
return FlowValidator.resolve_flow_inputs_type(self._flow, inputs)
return FlowValidator.resolve_flow_inputs_type(self._flow.inputs, inputs)

@property
def _default_inputs_mapping(self):
Expand Down Expand Up @@ -956,7 +956,9 @@ def _exec(
aggregation_inputs = {}
try:
if validate_inputs:
inputs = FlowValidator.ensure_flow_inputs_type(flow=self._flow, inputs=inputs, idx=run_info.index)
inputs = FlowValidator.ensure_flow_inputs_type(
flow=self._flow.inputs, inputs=inputs, idx=run_info.index
)
inputs = self._multimedia_processor.load_multimedia_data(self._flow.inputs, inputs)
# Inputs are assigned after validation and multimedia data loading, instead of at the start of the flow run.
# This way, if validation or multimedia data loading fails, we avoid persisting invalid inputs.
Expand Down Expand Up @@ -1041,7 +1043,7 @@ async def _exec_async(
aggregation_inputs = {}
try:
if validate_inputs:
inputs = FlowValidator.ensure_flow_inputs_type(flow=self._flow, inputs=inputs, idx=line_number)
inputs = FlowValidator.ensure_flow_inputs_type(flow=self._flow.inputs, inputs=inputs, idx=line_number)
# TODO: Consider async implementation for load_multimedia_data
inputs = self._multimedia_processor.load_multimedia_data(self._flow.inputs, inputs)
# Make sure the run_info with converted inputs results rather than original inputs
Expand Down
16 changes: 10 additions & 6 deletions src/promptflow-core/promptflow/executor/flow_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, List, Mapping, Optional

from promptflow._utils.logger_utils import logger
from promptflow.contracts.flow import Flow, InputValueType, Node
from promptflow.contracts.flow import Flow, FlowInputDefinition, InputValueType, Node
from promptflow.contracts.tool import ValueType
from promptflow.executor._errors import (
DuplicateNodeName,
Expand Down Expand Up @@ -182,7 +182,9 @@ def resolve_aggregated_flow_inputs_type(flow: Flow, inputs: Mapping[str, List[An
return updated_inputs

@staticmethod
def resolve_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optional[int] = None) -> Mapping[str, Any]:
def resolve_flow_inputs_type(
flow_inputs: FlowInputDefinition, inputs: Mapping[str, Any], idx: Optional[int] = None
) -> Mapping[str, Any]:
"""Resolve inputs by type if existing. Ignore missing inputs.
:param flow: The `flow` parameter is of type `Flow` and represents a flow object
Expand All @@ -198,13 +200,15 @@ def resolve_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optiona
:rtype: Mapping[str, Any]
"""
updated_inputs = {k: v for k, v in inputs.items()}
for k, v in flow.inputs.items():
for k, v in flow_inputs.items():
if k in inputs:
updated_inputs[k] = FlowValidator._parse_input_value(k, inputs[k], v.type, idx)
return updated_inputs

@staticmethod
def ensure_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optional[int] = None) -> Mapping[str, Any]:
def ensure_flow_inputs_type(
flow_inputs: FlowInputDefinition, inputs: Mapping[str, Any], idx: Optional[int] = None
) -> Mapping[str, Any]:
"""Make sure the inputs are completed and in the correct type. Raise Exception if not valid.
:param flow: The `flow` parameter is of type `Flow` and represents a flow object
Expand All @@ -219,7 +223,7 @@ def ensure_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optional
type specified in the `flow` object.
:rtype: Mapping[str, Any]
"""
for k, v in flow.inputs.items():
for k, _ in flow_inputs.items():
if k not in inputs:
line_info = "in input data" if idx is None else f"in line {idx} of input data"
msg_format = (
Expand All @@ -228,7 +232,7 @@ def ensure_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optional
"if it's no longer needed."
)
raise InputNotFound(message_format=msg_format, input_name=k, line_info=line_info)
return FlowValidator.resolve_flow_inputs_type(flow, inputs, idx)
return FlowValidator.resolve_flow_inputs_type(flow_inputs, inputs, idx)

@staticmethod
def convert_flow_inputs_for_node(flow: Flow, node: Node, inputs: Mapping[str, Any]) -> Mapping[str, Any]:
Expand Down
3 changes: 2 additions & 1 deletion src/promptflow-devkit/promptflow/batch/_batch_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,8 @@ def _get_aggregation_inputs(self, batch_inputs, line_results: List[LineResult]):

succeeded_batch_inputs = [batch_inputs[i] for i in succeeded]
resolved_succeeded_batch_inputs = [
FlowValidator.ensure_flow_inputs_type(flow=self._flow, inputs=input) for input in succeeded_batch_inputs
FlowValidator.ensure_flow_inputs_type(flow=self._flow.inputs, inputs=input)
for input in succeeded_batch_inputs
]
succeeded_inputs = transpose(resolved_succeeded_batch_inputs, keys=list(self._flow.inputs.keys()))
aggregation_inputs = transpose(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
entry: my_flow:MyClass
init:
  input_init:
    type: string
    default: input_init
inputs:
  input_1:
    type: string
  input_2:
    type: string
    default: input_2
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{"input_1": "input_1"}
{"input_1": "input_1"}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class MyClass:
    """Minimal callable flow fixture that echoes ``input_2`` back to the caller.

    NOTE(review): presumably the ``my_flow:MyClass`` entry named by the
    accompanying flow YAML, whose declared defaults (``input_init``,
    ``input_2``) differ from the Python defaults below -- confirm.
    """

    def __init__(self, input_init: str = "default_input_init"):
        """Accept ``input_init`` and discard it; the instance keeps no state."""

    def __call__(self, input_1, input_2: str = "default_input_2"):
        """Return ``{"output": input_2}``.

        :param input_1: Required argument; not used in the result.
        :param input_2: Value placed under the ``"output"`` key.
        :return: A single-key dict mapping ``"output"`` to ``input_2``.
        :rtype: dict
        """
        return dict(output=input_2)

0 comments on commit 78694b4

Please sign in to comment.