diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index 450cf185..b3b54495 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -70,9 +70,9 @@ from pydantic import VERSION as PYDANTIC_VERSION if PYDANTIC_VERSION.startswith("2."): - from pydantic.v1 import Field, root_validator + from pydantic.v1 import BaseModel, Field, root_validator else: - from pydantic import Field, root_validator + from pydantic import Field, root_validator, BaseModel from slugify import slugify from tenacity import retry, stop_after_attempt, wait_fixed, wait_random @@ -367,6 +367,16 @@ def network_configuration_requires_vpc_id(cls, values: dict) -> dict: return values +class CapacityProvider(BaseModel): + """ + The capacity provider strategy to use when running the task. + """ + + capacity_provider: str + weight: int + base: int + + class ECSVariables(BaseVariables): """ Variables for templating an ECS job. @@ -425,6 +435,13 @@ class ECSVariables(BaseVariables): ), ) ) + capacity_provider_strategy: Optional[List[CapacityProvider]] = Field( + default=None, + description=( + "The capacity provider strategy to use when running the task. This is only" + "If a capacityProviderStrategy is specified, we will omit the launchType" + ), + ) image: Optional[str] = Field( default=None, description=( diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index 39329d8d..9db3882e 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -2020,15 +2020,11 @@ async def test_user_defined_environment_variables_in_task_definition_template( async def test_user_defined_capacity_provider_strategy( aws_credentials: AwsCredentials, flow_run: FlowRun ): - configuration = await construct_configuration_with_job_template( - template_overrides=dict( - task_run_request={ - "capacityProviderStrategy": [ - {"base": 0, "weight": 1, "capacityProvider": "r6i.large"}, - ] - }, - ), + configuration = await construct_configuration( aws_credentials=aws_credentials, + capacity_provider_strategy=[ + {"base": 0, "weight": 1, "capacityProvider": "r6i.large"} + ], ) assert "launchType" not in configuration.task_run_request