From dbd9b42aef1eaca3aa6d9411b630297cd2ccd7f0 Mon Sep 17 00:00:00 2001 From: Xiaopeng Wang Date: Sat, 11 May 2024 16:57:13 +0800 Subject: [PATCH 1/3] support serving engine switch in pf serve cmd (#3161) # Description Add --engine parameter in `pf flow serve` command, customer can use this parameter to switch the python serving engine between `flask` and `fastapi`, it's default to `flask` # All Promptflow Contribution checklist: - [x] **The pull request does not introduce [breaking changes].** - [x] **CHANGELOG is updated for new features, bug fixes or other significant changes.** - [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).** - [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).** ## General Guidelines and Best Practices - [ ] Title of the pull request is clear and informative. - [ ] There are a small number of commits, each of which have an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md). ### Testing Guidelines - [ ] Pull request includes test coverage for the included changes. --------- Co-authored-by: xiaopwan --- docs/reference/pf-command-reference.md | 11 +++++++ src/promptflow-devkit/CHANGELOG.md | 1 + .../promptflow/_cli/_pf/_flow.py | 5 +++ .../promptflow/_sdk/_utilities/serve_utils.py | 21 +++++++++++- .../unittests/test_flow_serve_cli.py | 33 +++++++++++++++++++ src/promptflow/CHANGELOG.md | 1 + 6 files changed, 71 insertions(+), 1 deletion(-) diff --git a/docs/reference/pf-command-reference.md b/docs/reference/pf-command-reference.md index 0b9c623add8..076e5f4882b 100644 --- a/docs/reference/pf-command-reference.md +++ b/docs/reference/pf-command-reference.md @@ -294,6 +294,7 @@ pf flow serve --source [--verbose] [--debug] [--skip-open-browser] + [--engine] ``` #### Examples @@ -310,6 +311,12 @@ Serve flow as an endpoint with specific port and host. pf flow serve --source --port --host --environment-variables key1="`${my_connection.api_key}`" key2="value2" ``` +Serve flow as an endpoint with specific port, host, environment-variables and fastapi serving engine. + +```bash +pf flow serve --source --port --host --environment-variables key1="`${my_connection.api_key}`" key2="value2" --engine fastapi +``` + #### Required Parameter `--source` @@ -342,6 +349,10 @@ Show debug information during serve. Skip opening browser after serve. Store true parameter. +`--engine` + +Switch python serving engine between `flask` amd `fastapi`, default to `flask`. + ## pf connection Manage prompt flow connections. diff --git a/src/promptflow-devkit/CHANGELOG.md b/src/promptflow-devkit/CHANGELOG.md index 1fa4efd1e3f..7f0401a9889 100644 --- a/src/promptflow-devkit/CHANGELOG.md +++ b/src/promptflow-devkit/CHANGELOG.md @@ -5,6 +5,7 @@ ### Improvements - Interactive browser credential is excluded by default when using Azure AI connections, user could set `PF_NO_INTERACTIVE_LOGIN=False` to enable it. - Visualize flex flow run(s) switches to trace UI page. +- Add new `--engine` parameter for `pf flow serve`. This parameter can be used to switch python serving engine between `flask` and `fastapi`, currently it defaults to `flask`. - Return the secrets in the connection object by default to improve flex flow experience. - Behaviors not changed: 'pf connection' command will scrub secrets. - New behavior: connection object by `client.connection.get` will have real secrets. `print(connection_obj)` directly will scrub those secrets. `print(connection_obj.api_key)` or `print(connection_obj.secrets)` will print the REAL secrets. diff --git a/src/promptflow-devkit/promptflow/_cli/_pf/_flow.py b/src/promptflow-devkit/promptflow/_cli/_pf/_flow.py index 29e316ed73e..d0e1b348585 100644 --- a/src/promptflow-devkit/promptflow/_cli/_pf/_flow.py +++ b/src/promptflow-devkit/promptflow/_cli/_pf/_flow.py @@ -199,6 +199,9 @@ def add_parser_serve_flow(subparsers): add_param_skip_browser = lambda parser: parser.add_argument( # noqa: E731 "--skip-open-browser", action="store_true", default=False, help="Skip open browser for flow serving." ) + add_param_engine = lambda parser: parser.add_argument( # noqa: E731 + "--engine", type=str, default="flask", help="The engine to serve the flow, can be flask or fastapi." + ) activate_action( name="serve", description="Serving a flow as an endpoint.", @@ -207,6 +210,7 @@ def add_parser_serve_flow(subparsers): add_param_source, add_param_port, add_param_host, + add_param_engine, add_param_static_folder, add_param_environment_variables, add_param_config, @@ -595,6 +599,7 @@ def serve_flow(args): host=args.host, port=args.port, skip_open_browser=args.skip_open_browser, + engine=args.engine, ) logger.info("Promptflow app ended") diff --git a/src/promptflow-devkit/promptflow/_sdk/_utilities/serve_utils.py b/src/promptflow-devkit/promptflow/_sdk/_utilities/serve_utils.py index cb6414900d4..7d41324a708 100644 --- a/src/promptflow-devkit/promptflow/_sdk/_utilities/serve_utils.py +++ b/src/promptflow-devkit/promptflow/_sdk/_utilities/serve_utils.py @@ -58,6 +58,7 @@ def start_flow_service( environment_variables: Dict[str, str] = None, init: Dict[str, Any] = None, skip_open_browser: bool = True, + engine: str = "flask", ): logger.info( "Start promptflow server with port %s", @@ -72,6 +73,11 @@ def start_flow_service( message_format="Support directory `source` for Python flow only for now, but got {source}.", source=source, ) + if engine not in ["flask", "fastapi"]: + raise UserErrorException( + message_format="Unsupported engine {engine} for Python flow, only support 'flask' and 'fastapi'.", + engine=engine, + ) serve_python_flow( flow_file_name=flow_file_name, flow_dir=flow_dir, @@ -82,6 +88,7 @@ def start_flow_service( config=config or {}, environment_variables=environment_variables or {}, skip_open_browser=skip_open_browser, + engine=engine, ) else: serve_csharp_flow( @@ -103,6 +110,7 @@ def serve_python_flow( environment_variables, init, skip_open_browser: bool, + engine, ): from promptflow._sdk._configuration import Configuration from promptflow.core._serving.app import create_app @@ -121,13 +129,24 @@ def serve_python_flow( environment_variables=environment_variables, connection_provider=connection_provider, init=init, + engine=engine, ) if not skip_open_browser: target = f"http://{host}:{port}" logger.info(f"Opening browser {target}...") webbrowser.open(target) # Debug is not supported for now as debug will rerun command, and we changed working directory. - app.run(port=port, host=host) + if engine == "flask": + app.run(port=port, host=host) + else: + try: + import uvicorn + + uvicorn.run(app, host=host, port=port, access_log=False, log_config=None) + except ImportError: + raise UserErrorException( + message_format="FastAPI engine requires uvicorn, please install uvicorn by `pip install uvicorn`." + ) @contextlib.contextmanager diff --git a/src/promptflow-devkit/tests/sdk_cli_test/unittests/test_flow_serve_cli.py b/src/promptflow-devkit/tests/sdk_cli_test/unittests/test_flow_serve_cli.py index 3a2897e06dc..fc7aa6bdd1b 100644 --- a/src/promptflow-devkit/tests/sdk_cli_test/unittests/test_flow_serve_cli.py +++ b/src/promptflow-devkit/tests/sdk_cli_test/unittests/test_flow_serve_cli.py @@ -49,6 +49,17 @@ def test_flow_serve(self, source: Path): "--skip-open-browser", ) mock_run.assert_called_once_with(port=8080, host="localhost") + with mock.patch("uvicorn.run") as mock_run: + run_pf_command( + "flow", + "serve", + "--source", + source.as_posix(), + "--skip-open-browser", + "--engine", + "fastapi", + ) + mock_run.assert_called_once() @pytest.mark.parametrize( "source", @@ -71,3 +82,25 @@ def test_flow_serve_failed(self, source: Path, capsys): "pf.flow.serve failed with UserErrorException: Support directory `source` for Python flow only for now" in out ) + + @pytest.mark.parametrize( + "source", + [ + pytest.param(EAGER_FLOWS_DIR / "simple_with_yaml", id="simple_with_yaml_file"), + pytest.param(FLOWS_DIR / "simple_hello_world", id="simple_hello_world_file"), + ], + ) + def test_flow_serve_invalid_engine(self, source: Path, capsys): + invalid_engine = "invalid_engine" + with pytest.raises(SystemExit): + run_pf_command( + "flow", + "serve", + "--source", + source.as_posix(), + "--skip-open-browser", + "--engine", + invalid_engine, + ) + out, err = capsys.readouterr() + assert f"Unsupported engine {invalid_engine} for Python flow, only support 'flask' and 'fastapi'." in out diff --git a/src/promptflow/CHANGELOG.md b/src/promptflow/CHANGELOG.md index 61eac63b1a4..b4c261eb1a7 100644 --- a/src/promptflow/CHANGELOG.md +++ b/src/promptflow/CHANGELOG.md @@ -4,6 +4,7 @@ ### Improvements - [promptflow-devkit]: Interactive browser credential is excluded by default when using Azure AI connections, user could set `PF_NO_INTERACTIVE_LOGIN=False` to enable it. +- [promptflow-devkit]: Add new `--engine` parameter for `pf flow serve`. This parameter can be used to switch python serving engine between `flask` and `fastapi`, currently it defaults to `flask`. - [promptflow-azure]: Refine trace Cosmos DB setup process to print setup status during the process, and display error message from service when setup failed. - [promptflow-devkit][promptflow-azure] - Return the secrets in the connection object by default to improve flex flow experience. - Reach the sub package docs for more details about this. [promptflow-devkit](https://microsoft.github.io/promptflow/reference/changelog/promptflow-devkit.html) [promptflow-azure](https://microsoft.github.io/promptflow/reference/changelog/promptflow-azure.html) From 5c3a5c2385d42908a38954182d53001a246a12bb Mon Sep 17 00:00:00 2001 From: Heyi Tang Date: Sat, 11 May 2024 17:53:21 +0800 Subject: [PATCH 2/3] [Internal][tracing] Refactor span enrich logic to make it more clear. (#3056) # Description Refactor span enrich logic to make it more clear. This pull request introduces a new structure for enriching spans in the tracing module of the `promptflow-tracing` package. The changes primarily involve the creation of a `SpanEnricher` class and a `SpanEnricherManager` singleton class in the `src/promptflow-tracing/promptflow/tracing/_span_enricher.py` file. These classes are used to enrich spans with inputs and outputs of traced functions. Two specific enrichers, `LLMSpanEnricher` and `EmbeddingSpanEnricher`, are created and registered with the manager. The `TraceType` enum is also updated to include a new `RETRIEVAL` type. Key changes include: * [`src/promptflow-tracing/promptflow/tracing/_span_enricher.py`](diffhunk://#diff-695282eb7072d912e1bbe30979ce4b19b7c4ffe45a76fd64d48d1c392d4f5c25R1-R49): Introduced `SpanEnricher` and `SpanEnricherManager` classes. The manager is a singleton that maintains a dictionary mapping trace types to their respective span enrichers. The base `SpanEnricher` class provides a method to enrich a span with the output of a traced function. * [`src/promptflow-tracing/promptflow/tracing/_trace.py`](diffhunk://#diff-580941184737186a94e3e9c06e467e86c06ce846d30ffa42c1c43b45812ddcdcR23): Imported the `SpanEnricher` and `SpanEnricherManager` classes. The `enrich_span_with_trace_type` function now uses the `SpanEnricherManager` to enrich spans based on their trace type. Two specific enrichers, `LLMSpanEnricher` and `EmbeddingSpanEnricher`, are implemented and registered with the `SpanEnricherManager`. [[1]](diffhunk://#diff-580941184737186a94e3e9c06e467e86c06ce846d30ffa42c1c43b45812ddcdcR23) [[2]](diffhunk://#diff-580941184737186a94e3e9c06e467e86c06ce846d30ffa42c1c43b45812ddcdcL150-R154) [[3]](diffhunk://#diff-580941184737186a94e3e9c06e467e86c06ce846d30ffa42c1c43b45812ddcdcR515-R532) * [`src/promptflow-tracing/promptflow/tracing/contracts/trace.py`](diffhunk://#diff-064c539f55a396428c9d28af1cbcf4fe39beb69daac7d9891cd6f713897fa704R18): Added `RETRIEVAL` as a new trace type to the `TraceType` enum. These changes aim to make the process of enriching spans more flexible and extensible, allowing for different enrichment logic based on the trace type. # All Promptflow Contribution checklist: - [ ] **The pull request does not introduce [breaking changes].** - [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.** - [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).** - [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).** ## General Guidelines and Best Practices - [ ] Title of the pull request is clear and informative. - [ ] There are a small number of commits, each of which have an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md). ### Testing Guidelines - [ ] Pull request includes test coverage for the included changes. Co-authored-by: Heyi --- .../promptflow/tracing/_span_enricher.py | 49 +++++++++++++++++++ .../promptflow/tracing/_trace.py | 32 ++++++++---- .../promptflow/tracing/contracts/trace.py | 1 + 3 files changed, 72 insertions(+), 10 deletions(-) create mode 100644 src/promptflow-tracing/promptflow/tracing/_span_enricher.py diff --git a/src/promptflow-tracing/promptflow/tracing/_span_enricher.py b/src/promptflow-tracing/promptflow/tracing/_span_enricher.py new file mode 100644 index 00000000000..da4b4f14275 --- /dev/null +++ b/src/promptflow-tracing/promptflow/tracing/_span_enricher.py @@ -0,0 +1,49 @@ +from typing import Dict + +from .contracts.trace import TraceType + + +class SpanEnricher: + def __init__(self): + pass + + def enrich(self, span, inputs, output): + """This method is used to enrich the span with the inputs and output of the traced function. + Note that this method is called after the function is called, so some inputs related logic is not here. + """ + # TODO: Also move input related logic here. + from ._trace import enrich_span_with_output + + enrich_span_with_output(span, output) + + +class SpanEnricherManager: + _instance = None + + def __init__(self): + self._type2enricher: Dict[str, SpanEnricher] = {} + self._base_enricher = SpanEnricher() + + @classmethod + def get_instance(cls) -> "SpanEnricherManager": + if cls._instance is None: + cls._instance = SpanEnricherManager() + return cls._instance + + @classmethod + def register(cls, trace_type, enricher: SpanEnricher): + cls.get_instance()._register(trace_type, enricher) + + @classmethod + def enrich(cls, span, inputs, output, trace_type): + cls.get_instance()._enrich(span, inputs, output, trace_type) + + def _register(self, trace_type, enricher: SpanEnricher): + self._type2enricher[trace_type] = enricher + + def _enrich(self, span, inputs, output, trace_type): + enricher = self._type2enricher.get(trace_type, self._base_enricher) + enricher.enrich(span, inputs, output) + + +SpanEnricherManager.register(TraceType.FUNCTION, SpanEnricher()) diff --git a/src/promptflow-tracing/promptflow/tracing/_trace.py b/src/promptflow-tracing/promptflow/tracing/_trace.py index ba9d56ab15d..9d93710f79c 100644 --- a/src/promptflow-tracing/promptflow/tracing/_trace.py +++ b/src/promptflow-tracing/promptflow/tracing/_trace.py @@ -21,6 +21,7 @@ from ._openai_utils import OpenAIMetricsCalculator, OpenAIResponseParser from ._operation_context import OperationContext +from ._span_enricher import SpanEnricher, SpanEnricherManager from ._tracer import Tracer, _create_trace_from_function_call, get_node_name_from_context from ._utils import get_input_names_for_prompt_template, get_prompt_param_name_from_func, serialize from .contracts.generator_proxy import AsyncGeneratorProxy, GeneratorProxy @@ -148,17 +149,10 @@ def enrich_span_with_input(span, input): def enrich_span_with_trace_type(span, inputs, output, trace_type): - if trace_type == TraceType.LLM: - # Handle the non-streaming output of LLM, the streaming output will be handled in traced_generator. - token_collector.collect_openai_tokens(span, output) - enrich_span_with_llm_output(span, output) - elif trace_type == TraceType.EMBEDDING: - token_collector.collect_openai_tokens(span, output) - enrich_span_with_embedding(span, inputs, output) + SpanEnricherManager.enrich(span, inputs, output, trace_type) + # TODO: Move the following logic to SpanEnricher enrich_span_with_openai_tokens(span, trace_type) - enrich_span_with_output(span, output) - output = trace_iterator_if_needed(span, inputs, output) - return output + return trace_iterator_if_needed(span, inputs, output) def trace_iterator_if_needed(span, inputs, output): @@ -519,3 +513,21 @@ async def greetings_async(user_id): """ return _traced(func, trace_type=TraceType.FUNCTION) + + +class LLMSpanEnricher(SpanEnricher): + def enrich(self, span, inputs, output): + token_collector.collect_openai_tokens(span, output) + enrich_span_with_llm_output(span, output) + super().enrich(span, inputs, output) + + +class EmbeddingSpanEnricher(SpanEnricher): + def enrich(self, span, inputs, output): + token_collector.collect_openai_tokens(span, output) + enrich_span_with_embedding(span, inputs, output) + super().enrich(span, inputs, output) + + +SpanEnricherManager.register(TraceType.LLM, LLMSpanEnricher()) +SpanEnricherManager.register(TraceType.EMBEDDING, EmbeddingSpanEnricher()) diff --git a/src/promptflow-tracing/promptflow/tracing/contracts/trace.py b/src/promptflow-tracing/promptflow/tracing/contracts/trace.py index 3003967314a..980e3f7d9c2 100644 --- a/src/promptflow-tracing/promptflow/tracing/contracts/trace.py +++ b/src/promptflow-tracing/promptflow/tracing/contracts/trace.py @@ -15,6 +15,7 @@ class TraceType(str, Enum): LANGCHAIN = "LangChain" FLOW = "Flow" EMBEDDING = "Embedding" + RETRIEVAL = "Retrieval" @dataclass From 2cf0fbef18ba46aaa1bc81e875cb5a672de4a0f2 Mon Sep 17 00:00:00 2001 From: Peiwen Gao <111329184+PeiwenGaoMS@users.noreply.github.com> Date: Sat, 11 May 2024 17:54:34 +0800 Subject: [PATCH 3/3] [Internal][Executor] Return inputs definition and has aggregation by initialize api in execution server (#3212) # Description Return inputs definition and has aggregation by initialize api in execution server to avoid creating flow executor in runtime container. # All Promptflow Contribution checklist: - [x] **The pull request does not introduce [breaking changes].** - [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.** - [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).** - [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).** ## General Guidelines and Best Practices - [x] Title of the pull request is clear and informative. - [x] There are a small number of commits, each of which have an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md). ### Testing Guidelines - [ ] Pull request includes test coverage for the included changes. --- .../promptflow/executor/_prompty_executor.py | 3 + .../executor/_service/apis/batch.py | 10 +- .../_service/utils/batch_coordinator.py | 6 + .../promptflow/executor/flow_executor.py | 3 + .../executor/_service/apis/test_batch.py | 108 ++++++++++++++++++ src/promptflow/tests/executor/utils.py | 34 ++++-- 6 files changed, 151 insertions(+), 13 deletions(-) create mode 100644 src/promptflow/tests/executor/unittests/executor/_service/apis/test_batch.py diff --git a/src/promptflow-core/promptflow/executor/_prompty_executor.py b/src/promptflow-core/promptflow/executor/_prompty_executor.py index d8a7abed15b..a38e4ae9699 100644 --- a/src/promptflow-core/promptflow/executor/_prompty_executor.py +++ b/src/promptflow-core/promptflow/executor/_prompty_executor.py @@ -66,3 +66,6 @@ def _init_input_sign(self): self._inputs_sign = flow.inputs # The init signature only used for flex flow, so we set the _init_sign to empty dict for prompty flow. self._init_sign = {} + + def get_inputs_definition(self): + return self._inputs diff --git a/src/promptflow-core/promptflow/executor/_service/apis/batch.py b/src/promptflow-core/promptflow/executor/_service/apis/batch.py index 22d237a1347..ace583e86b5 100644 --- a/src/promptflow-core/promptflow/executor/_service/apis/batch.py +++ b/src/promptflow-core/promptflow/executor/_service/apis/batch.py @@ -24,13 +24,13 @@ @router.post("/initialize") def initialize(request: InitializationRequest): with get_log_context(request, enable_service_logger=True): - # validate request and get operation context + # Validate request and get operation context. request.validate_request() operation_context = update_and_get_operation_context(request.operation_context) service_logger.info(f"Received batch init request, executor version: {operation_context.get_user_agent()}.") - # resolve environment variables + # Resolve environment variables. set_environment_variables(request.environment_variables) - # init batch coordinator to validate flow and create process pool + # Init batch coordinator to validate flow and create process pool. batch_coordinator = BatchCoordinator( working_dir=request.working_dir, flow_file=request.flow_file, @@ -42,8 +42,8 @@ def initialize(request: InitializationRequest): init_kwargs=request.init_kwargs, ) batch_coordinator.start() - # return json response - return {"status": "initialized"} + # Return some flow infos including the flow inputs definition and whether it has aggregation nodes. + return batch_coordinator.get_flow_infos() @router.post("/execution") diff --git a/src/promptflow-core/promptflow/executor/_service/utils/batch_coordinator.py b/src/promptflow-core/promptflow/executor/_service/utils/batch_coordinator.py index fe6d0467452..a586445b451 100644 --- a/src/promptflow-core/promptflow/executor/_service/utils/batch_coordinator.py +++ b/src/promptflow-core/promptflow/executor/_service/utils/batch_coordinator.py @@ -79,6 +79,12 @@ def get_instance(cls): def get_log_context(self): return self._log_context + def get_flow_infos(self): + return { + "inputs_definition": self._flow_executor.get_inputs_definition(), + "has_aggregation": self._flow_executor.has_aggregation_node, + } + def start(self): """Start the process pool.""" self._process_pool.start() diff --git a/src/promptflow-core/promptflow/executor/flow_executor.py b/src/promptflow-core/promptflow/executor/flow_executor.py index df3099404b4..4ffede414c4 100644 --- a/src/promptflow-core/promptflow/executor/flow_executor.py +++ b/src/promptflow-core/promptflow/executor/flow_executor.py @@ -681,6 +681,9 @@ def _exec_in_thread(self, args) -> LineResult: self._completed_idx[line_number] = thread_name return results + def get_inputs_definition(self): + return self._flow.inputs + def exec_line( self, inputs: Mapping[str, Any], diff --git a/src/promptflow/tests/executor/unittests/executor/_service/apis/test_batch.py b/src/promptflow/tests/executor/unittests/executor/_service/apis/test_batch.py new file mode 100644 index 00000000000..ab0ca967992 --- /dev/null +++ b/src/promptflow/tests/executor/unittests/executor/_service/apis/test_batch.py @@ -0,0 +1,108 @@ +import pytest +from fastapi.testclient import TestClient + +from .....utils import construct_initialization_request_json + + +def construct_initialize_request_json(): + return {} + + +@pytest.mark.unittest +class TestBatchApis: + @pytest.mark.parametrize( + "flow_folder, flow_file, init_kwargs, expected_inputs_definition, expected_has_aggregation", + [ + # dag flow without aggregation nodes + ( + "print_input_flow", + "flow.dag.yaml", + None, + { + "text": { + "type": "string", + "default": None, + "description": "", + "enum": [], + "is_chat_input": False, + "is_chat_history": None, + } + }, + False, + ), + # dag flow with aggregation nodes + ( + "simple_aggregation", + "flow.dag.yaml", + None, + { + "text": { + "type": "string", + "default": "play", + "description": "", + "enum": [], + "is_chat_input": False, + "is_chat_history": None, + } + }, + True, + ), + # flex flow without aggregation + ( + "simple_with_yaml", + "flow.flex.yaml", + None, + { + "input_val": { + "type": "string", + "default": "gpt", + "description": None, + "enum": None, + "is_chat_input": False, + "is_chat_history": None, + } + }, + False, + ), + # flex flow with aggregation + ( + "basic_callable_class_async", + "flow.flex.yaml", + {"obj_input": "obj_input"}, + { + "func_input": { + "type": "string", + "default": None, + "description": None, + "enum": None, + "is_chat_input": False, + "is_chat_history": None, + } + }, + True, + ), + ], + ) + def test_initialize( + self, + executor_client: TestClient, + flow_folder, + flow_file, + init_kwargs, + expected_inputs_definition, + expected_has_aggregation, + ): + initialization_request = construct_initialization_request_json( + flow_folder=flow_folder, + flow_file=flow_file, + init_kwargs=init_kwargs, + ) + response = executor_client.post(url="/initialize", json=initialization_request) + # assert response + assert response.status_code == 200 + assert response.json() == { + "inputs_definition": expected_inputs_definition, + "has_aggregation": expected_has_aggregation, + } + executor_client.post(url="/finalize") + assert response.status_code == 200 diff --git a/src/promptflow/tests/executor/utils.py b/src/promptflow/tests/executor/utils.py index fd0edf82dbb..9ccab8bf4d0 100644 --- a/src/promptflow/tests/executor/utils.py +++ b/src/promptflow/tests/executor/utils.py @@ -132,21 +132,39 @@ def is_image_file(file_path: Path): def construct_flow_execution_request_json(flow_folder, root=FLOW_ROOT, inputs=None, connections=None): + base_execution_request = construct_base_execution_request_json(flow_folder, root=root, connections=connections) + flow_execution_request = { + "run_id": str(uuid.uuid4()), + "inputs": inputs, + "operation_context": { + "request_id": "test-request-id", + "user_agent": "test-user-agent", + }, + } + return {**base_execution_request, **flow_execution_request} + + +def construct_initialization_request_json( + flow_folder, root=FLOW_ROOT, flow_file="flow.dag.yaml", connections=None, init_kwargs=None +): + if flow_file == "flow.flex.yaml": + root = EAGER_FLOW_ROOT + base_execution_request = construct_base_execution_request_json( + flow_folder, root=root, connections=connections, flow_file=flow_file + ) + return {**base_execution_request, "init_kwargs": init_kwargs} if init_kwargs is not None else base_execution_request + + +def construct_base_execution_request_json(flow_folder, root=FLOW_ROOT, connections=None, flow_file="flow.dag.yaml"): working_dir = get_flow_folder(flow_folder, root=root) tmp_dir = Path(mkdtemp()) log_path = tmp_dir / "log.txt" return { - "run_id": str(uuid.uuid4()), "working_dir": working_dir.as_posix(), - "flow_file": "flow.dag.yaml", + "flow_file": flow_file, "output_dir": tmp_dir.as_posix(), - "connections": connections, "log_path": log_path.as_posix(), - "inputs": inputs, - "operation_context": { - "request_id": "test-request-id", - "user_agent": "test-user-agent", - }, + "connections": connections, }