From 2d84ea9ffc373a9f071eeefa54c43624209b1e41 Mon Sep 17 00:00:00 2001 From: Peiwen Gao <111329184+PeiwenGaoMS@users.noreply.github.com> Date: Mon, 30 Oct 2023 18:03:56 +0800 Subject: [PATCH] [Internal][Executor] Add docstring for storage and flow (#947) # Description Add docstring for storage and flow # All Promptflow Contribution checklist: - [x] **The pull request does not introduce [breaking changes].** - [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.** - [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).** - [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).** ## General Guidelines and Best Practices - [x] Title of the pull request is clear and informative. - [x] There are a small number of commits, each of which have an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md). ### Testing Guidelines - [ ] Pull request includes test coverage for the included changes. 
--- src/promptflow/promptflow/contracts/flow.py | 189 +++++++++++++++--- src/promptflow/promptflow/exceptions.py | 4 +- .../promptflow/storage/_run_storage.py | 49 ++++- 3 files changed, 209 insertions(+), 33 deletions(-) diff --git a/src/promptflow/promptflow/contracts/flow.py b/src/promptflow/promptflow/contracts/flow.py index 4396cfe8dd1..6ff711a133b 100644 --- a/src/promptflow/promptflow/contracts/flow.py +++ b/src/promptflow/promptflow/contracts/flow.py @@ -12,8 +12,8 @@ import yaml from promptflow.exceptions import ErrorTarget -from .._sdk._constants import DEFAULT_ENCODING +from .._sdk._constants import DEFAULT_ENCODING from .._utils.dataclass_serializer import serialize from .._utils.utils import try_import from ._errors import FailedToImportModule, NodeConditionConflict @@ -36,7 +36,17 @@ class InputValueType(Enum): @dataclass class InputAssignment: - """This class represents the assignment of an input value.""" + """This class represents the assignment of an input value. + + :param value: The value of the input assignment. + :type value: Any + :param value_type: The type of the input assignment. + :type value_type: ~promptflow.contracts.flow.InputValueType + :param section: The section of the input assignment, usually the output. + :type section: str + :param property: The property of the input assignment that exists in the section. + :type property: str + """ value: Any value_type: InputValueType = InputValueType.LITERAL @@ -62,7 +72,7 @@ def deserialize(value: str) -> "InputAssignment": :param value: The string to be deserialized. :type value: str :return: The input assignment constructed from the string. - :rtype: InputAssignment + :rtype: ~promptflow.contracts.flow.InputAssignment """ literal_value = InputAssignment(value, InputValueType.LITERAL) if isinstance(value, str) and value.startswith("$") and len(value) > 2: @@ -80,7 +90,7 @@ def deserialize_reference(value: str) -> "InputAssignment": :param value: The string to be deserialized. 
:type value: str :return: The input assignment of reference types. - :rtype: InputAssignment + :rtype: ~promptflow.contracts.flow.InputAssignment """ if FlowInputAssignment.is_flow_input(value): return FlowInputAssignment.deserialize(value) @@ -93,7 +103,7 @@ def deserialize_node_reference(data: str) -> "InputAssignment": :param data: The string to be deserialized. :type data: str :return: Input assignment of node reference type. - :rtype: InputAssignment + :rtype: ~promptflow.contracts.flow.InputAssignment """ value_type = InputValueType.NODE_REFERENCE if "." not in data: @@ -107,7 +117,11 @@ def deserialize_node_reference(data: str) -> "InputAssignment": @dataclass class FlowInputAssignment(InputAssignment): - """This class represents the assignment of a flow input value.""" + """This class represents the assignment of a flow input value. + + :param prefix: The prefix of the flow input. + :type prefix: str + """ prefix: str = FLOW_INPUT_PREFIX @@ -132,7 +146,7 @@ def deserialize(value: str) -> "FlowInputAssignment": :param value: The string to be deserialized. :type value: str :return: The flow input assignment constructed from the string. - :rtype: FlowInputAssignment + :rtype: ~promptflow.contracts.flow.FlowInputAssignment """ for prefix in FLOW_INPUT_PREFIXES: if value.startswith(prefix): @@ -152,7 +166,15 @@ class ToolSourceType(str, Enum): @dataclass class ToolSource: - """This class represents the source of a tool.""" + """This class represents the source of a tool. + + :param type: The type of the tool source. + :type type: ~promptflow.contracts.flow.ToolSourceType + :param tool: The tool of the tool source. + :type tool: str + :param path: The path of the tool source. + :type path: str + """ type: ToolSourceType = ToolSourceType.Code tool: Optional[str] = None @@ -165,7 +187,7 @@ def deserialize(data: dict) -> "ToolSource": :param data: The dict to be deserialized. :type data: dict :return: The tool source constructed from the dict. 
- :rtype: ToolSource + :rtype: ~promptflow.contracts.flow.ToolSource """ result = ToolSource(data.get("type", ToolSourceType.Code.value)) if "tool" in data: @@ -177,7 +199,13 @@ def deserialize(data: dict) -> "ToolSource": @dataclass class ActivateCondition: - """This class represents the activate condition of a node.""" + """This class represents the activate condition of a node. + + :param condition: The condition of the activate condition. + :type condition: ~promptflow.contracts.flow.InputAssignment + :param condition_value: The value of the condition. + :type condition_value: Any + """ condition: InputAssignment condition_value: Any @@ -189,7 +217,7 @@ def deserialize(data: dict) -> "ActivateCondition": :param data: The dict to be deserialized. :type data: dict :return: The activate condition constructed from the dict. - :rtype: ActivateCondition + :rtype: ~promptflow.contracts.flow.ActivateCondition """ result = ActivateCondition( condition=InputAssignment.deserialize(data["when"]), @@ -200,7 +228,15 @@ def deserialize(data: dict) -> "ActivateCondition": @dataclass class SkipCondition: - """This class represents the skip condition of a node.""" + """This class represents the skip condition of a node. + + :param condition: The condition of the skip condition. + :type condition: ~promptflow.contracts.flow.InputAssignment + :param condition_value: The value of the condition. + :type condition_value: Any + :param return_value: The return value when skip condition is met. + :type return_value: ~promptflow.contracts.flow.InputAssignment + """ condition: InputAssignment condition_value: Any @@ -213,7 +249,7 @@ def deserialize(data: dict) -> "SkipCondition": :param data: The dict to be deserialized. :type data: dict :return: The skip condition constructed from the dict. 
- :rtype: SkipCondition + :rtype: ~promptflow.contracts.flow.SkipCondition """ result = SkipCondition( condition=InputAssignment.deserialize(data["when"]), @@ -225,7 +261,39 @@ def deserialize(data: dict) -> "SkipCondition": @dataclass class Node: - """This class represents a node in a flow.""" + """This class represents a node in a flow. + + :param name: The name of the node. + :type name: str + :param tool: The tool of the node. + :type tool: str + :param inputs: The inputs of the node. + :type inputs: Dict[str, InputAssignment] + :param comment: The comment of the node. + :type comment: str + :param api: The api of the node. + :type api: str + :param provider: The provider of the node. + :type provider: str + :param module: The module of the node. + :type module: str + :param connection: The connection of the node. + :type connection: str + :param aggregation: Whether the node is an aggregation node. + :type aggregation: bool + :param enable_cache: Whether the node enables cache. + :type enable_cache: bool + :param use_variants: Whether the node uses variants. + :type use_variants: bool + :param source: The source of the node. + :type source: ~promptflow.contracts.flow.ToolSource + :param type: The tool type of the node. + :type type: ~promptflow.contracts.tool.ToolType + :param skip: The skip condition of the node. + :type skip: ~promptflow.contracts.flow.SkipCondition + :param activate: The activate condition of the node. + :type activate: ~promptflow.contracts.flow.ActivateCondition + """ name: str tool: str @@ -264,7 +332,7 @@ def deserialize(data: dict) -> "Node": :param data: The dict to be deserialized. :type data: dict :return: The node constructed from the dict. 
- :rtype: Node + :rtype: ~promptflow.contracts.flow.Node """ node = Node( name=data.get("name"), @@ -295,7 +363,21 @@ def deserialize(data: dict) -> "Node": @dataclass class FlowInputDefinition: - """This class represents the definition of a flow input.""" + """This class represents the definition of a flow input. + + :param type: The type of the flow input. + :type type: ~promptflow.contracts.tool.ValueType + :param default: The default value of the flow input. + :type default: str + :param description: The description of the flow input. + :type description: str + :param enum: The enum of the flow input. + :type enum: List[str] + :param is_chat_input: Whether the flow input is a chat input. + :type is_chat_input: bool + :param is_chat_history: Whether the flow input is a chat history. + :type is_chat_history: bool + """ type: ValueType default: str = None @@ -305,6 +387,11 @@ class FlowInputDefinition: is_chat_history: bool = None def serialize(self): + """Serialize the flow input definition to a dict. + + :return: The dict of the flow input definition. + :rtype: dict + """ data = {} data["type"] = self.type.value if self.default: @@ -326,7 +413,7 @@ def deserialize(data: dict) -> "FlowInputDefinition": :param data: The dict to be deserialized. :type data: dict :return: The flow input definition constructed from the dict. - :rtype: FlowInputDefinition + :rtype: ~promptflow.contracts.flow.FlowInputDefinition """ return FlowInputDefinition( ValueType(data["type"]), @@ -340,7 +427,19 @@ def deserialize(data: dict) -> "FlowInputDefinition": @dataclass class FlowOutputDefinition: - """This class represents the definition of a flow output.""" + """This class represents the definition of a flow output. + + :param type: The type of the flow output. + :type type: ~promptflow.contracts.tool.ValueType + :param reference: The reference of the flow output. + :type reference: ~promptflow.contracts.flow.InputAssignment + :param description: The description of the flow output. 
+ :type description: str + :param evaluation_only: Whether the flow output is for evaluation only. + :type evaluation_only: bool + :param is_chat_output: Whether the flow output is a chat output. + :type is_chat_output: bool + """ type: ValueType reference: InputAssignment @@ -349,7 +448,11 @@ class FlowOutputDefinition: is_chat_output: bool = False def serialize(self): - """Serialize the flow output definition to a dict.""" + """Serialize the flow output definition to a dict. + + :return: The dict of the flow output definition. + :rtype: dict + """ data = {} data["type"] = self.type.value if self.reference: @@ -369,7 +472,7 @@ def deserialize(data: dict): :param data: The dict to be deserialized. :type data: dict :return: The flow output definition constructed from the dict. - :rtype: FlowOutputDefinition + :rtype: ~promptflow.contracts.flow.FlowOutputDefinition """ return FlowOutputDefinition( ValueType(data["type"]), @@ -382,7 +485,13 @@ def deserialize(data: dict): @dataclass class NodeVariant: - """This class represents a node variant.""" + """This class represents a node variant. + + :param node: The node of the node variant. + :type node: ~promptflow.contracts.flow.Node + :param description: The description of the node variant. + :type description: str + """ node: Node description: str = "" @@ -394,7 +503,7 @@ def deserialize(data: dict) -> "NodeVariant": :param data: The dict to be deserialized. :type data: dict :return: The node variant constructed from the dict. - :rtype: NodeVariant + :rtype: ~promptflow.contracts.flow.NodeVariant """ return NodeVariant( Node.deserialize(data["node"]), @@ -404,7 +513,13 @@ def deserialize(data: dict) -> "NodeVariant": @dataclass class NodeVariants: - """This class represents the variants of a node.""" + """This class represents the variants of a node. + + :param default_variant_id: The default variant id of the node. + :type default_variant_id: str + :param variants: The variants of the node. 
+ :type variants: Dict[str, NodeVariant] + """ default_variant_id: str # The default variant id of the node variants: Dict[str, NodeVariant] # The variants of the node @@ -416,7 +531,7 @@ def deserialize(data: dict) -> "NodeVariants": :param data: The dict to be deserialized. :type data: dict :return: The node variants constructed from the dict. - :rtype: NodeVariants + :rtype: ~promptflow.contracts.flow.NodeVariants """ variants = {} for variant_id, node in data["variants"].items(): @@ -426,7 +541,23 @@ def deserialize(data: dict) -> "NodeVariants": @dataclass class Flow: - """This class represents a flow.""" + """This class represents a flow. + + :param id: The id of the flow. + :type id: str + :param name: The name of the flow. + :type name: str + :param nodes: The nodes of the flow. + :type nodes: List[Node] + :param inputs: The inputs of the flow. + :type inputs: Dict[str, FlowInputDefinition] + :param outputs: The outputs of the flow. + :type outputs: Dict[str, FlowOutputDefinition] + :param tools: The tools of the flow. + :type tools: List[Tool] + :param node_variants: The node variants of the flow. + :type node_variants: Dict[str, NodeVariants] + """ id: str name: str @@ -437,7 +568,11 @@ class Flow: node_variants: Dict[str, NodeVariants] = None def serialize(self): - """Serialize the flow to a dict.""" + """Serialize the flow to a dict. + + :return: The dict of the flow. + :rtype: dict + """ data = { "id": self.id, "name": self.name, @@ -473,7 +608,7 @@ def deserialize(data: dict) -> "Flow": :param data: The dict to be deserialized. :type data: dict :return: The flow constructed from the dict. 
- :rtype: Flow + :rtype: ~promptflow.contracts.flow.Flow """ tools = [Tool.deserialize(t) for t in data.get("tools") or []] nodes = [Node.deserialize(n) for n in data.get("nodes") or []] diff --git a/src/promptflow/promptflow/exceptions.py b/src/promptflow/promptflow/exceptions.py index 256cef52298..d440151968b 100644 --- a/src/promptflow/promptflow/exceptions.py +++ b/src/promptflow/promptflow/exceptions.py @@ -31,7 +31,7 @@ class PromptflowException(Exception): :param message: A message describing the error. This is the error message the user will see. :type message: str :param target: The name of the element that caused the exception to be thrown. - :type target: ErrorTarget + :type target: ~promptflow.exceptions.ErrorTarget :param error: The original exception if any. :type error: Exception """ @@ -97,7 +97,7 @@ def target(self): """The error target. :return: The error target. - :rtype: ErrorTarget + :rtype: ~promptflow.exceptions.ErrorTarget """ return self._target diff --git a/src/promptflow/promptflow/storage/_run_storage.py b/src/promptflow/promptflow/storage/_run_storage.py index 8d91d1fa031..f9f0cc80efa 100644 --- a/src/promptflow/promptflow/storage/_run_storage.py +++ b/src/promptflow/promptflow/storage/_run_storage.py @@ -7,33 +7,64 @@ from promptflow._utils.multimedia_utils import get_file_reference_encoder, recursive_process from promptflow.contracts.multimedia import Image -from promptflow.contracts.run_info import FlowRunInfo, RunInfo as NodeRunInfo +from promptflow.contracts.run_info import FlowRunInfo +from promptflow.contracts.run_info import RunInfo as NodeRunInfo class AbstractRunStorage: def persist_node_run(self, run_info: NodeRunInfo): - """Write the node run info to somewhere immediately after the node is executed.""" + """Write the node run info to somewhere immediately after the node is executed. + + :param run_info: The run info of the node. 
+ :type run_info: ~promptflow.contracts.run_info.RunInfo + """ raise NotImplementedError("AbstractRunStorage is an abstract class, no implementation for persist_node_run.") def persist_flow_run(self, run_info: FlowRunInfo): - """Write the flow run info to somewhere immediately after one line data is executed for the flow.""" + """Write the flow run info to somewhere immediately after one line data is executed for the flow. + + :param run_info: The run info of the flow. + :type run_info: ~promptflow.contracts.run_info.FlowRunInfo + """ raise NotImplementedError("AbstractRunStorage is an abstract class, no implementation for persist_flow_run.") class DummyRunStorage(AbstractRunStorage): def persist_node_run(self, run_info: NodeRunInfo): + """Dummy implementation for persist_node_run + + :param run_info: The run info of the node. + :type run_info: ~promptflow.contracts.run_info.RunInfo + """ pass def persist_flow_run(self, run_info: FlowRunInfo): + """Dummy implementation for persist_flow_run + + :param run_info: The run info of the flow. + :type run_info: ~promptflow.contracts.run_info.FlowRunInfo + """ pass class DefaultRunStorage(AbstractRunStorage): def __init__(self, base_dir: Path = None, sub_dir: Path = None): + """Initialize the default run storage. + + :param base_dir: The base directory to store the multimedia data. + :type base_dir: Path + :param sub_dir: The sub directory to store the multimedia data. + :type sub_dir: Path + """ self._base_dir = base_dir self._sub_dir = sub_dir def persist_node_run(self, run_info: NodeRunInfo): + """Persist the multimedia data in node run info after the node is executed. + + :param run_info: The run info of the node. 
+ :type run_info: ~promptflow.contracts.run_info.RunInfo + """ if run_info.inputs: run_info.inputs = self._persist_images(run_info.inputs) if run_info.output: @@ -44,6 +75,11 @@ def persist_node_run(self, run_info: NodeRunInfo): run_info.api_calls = self._persist_images(run_info.api_calls) def persist_flow_run(self, run_info: FlowRunInfo): + """Persist the multimedia data in flow run info after one line data is executed for the flow. + + :param run_info: The run info of the flow. + :type run_info: ~promptflow.contracts.run_info.FlowRunInfo + """ if run_info.inputs: run_info.inputs = self._persist_images(run_info.inputs) if run_info.output: @@ -54,10 +90,15 @@ def persist_flow_run(self, run_info: FlowRunInfo): run_info.api_calls = self._persist_images(run_info.api_calls) def _persist_images(self, value): + """Serialize the images in the value to file path and save them to the disk. + + :param value: A value that may contain images. + :type value: Any + """ if self._base_dir: pfbytes_file_reference_encoder = get_file_reference_encoder( folder_path=self._base_dir, - relative_path=self._sub_dir + relative_path=self._sub_dir, ) else: pfbytes_file_reference_encoder = None