From bde96a782b4420d8578dceb5221e7765015de446 Mon Sep 17 00:00:00 2001 From: Craig de Gouveia Date: Fri, 18 Aug 2023 19:09:13 +0100 Subject: [PATCH] Allow custom awsvpcConfiguration for ECS Worker (#304) --- CHANGELOG.md | 1 + prefect_aws/workers/ecs_worker.py | 92 +++++++++++++++++++++++-- tests/workers/test_ecs_worker.py | 107 ++++++++++++++++++++++++++++++ 3 files changed, 194 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 271c1711..e0a447c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added retries to ECS task run creation for ECS worker - [#303](https://github.com/PrefectHQ/prefect-aws/pull/303) +- Added support to `ECSWorker` for `awsvpcConfiguration` [#304](https://github.com/PrefectHQ/prefect-aws/pull/304) ### Changed diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index 699873a7..f4d97bef 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -229,6 +229,7 @@ class ECSJobConfiguration(BaseJobConfiguration): ) configure_cloudwatch_logs: Optional[bool] = Field(default=None) cloudwatch_logs_options: Dict[str, str] = Field(default_factory=dict) + network_configuration: Dict[str, Any] = Field(default_factory=dict) stream_output: Optional[bool] = Field(default=None) task_start_timeout_seconds: int = Field(default=300) task_watch_poll_interval: float = Field(default=5.0) @@ -321,6 +322,18 @@ def cloudwatch_logs_options_requires_configure_cloudwatch_logs( ) return values + @root_validator + def network_configuration_requires_vpc_id(cls, values: dict) -> dict: + """ + Enforces a `vpc_id` is provided when custom network configuration mode is + enabled for network settings. + """ + if values.get("network_configuration") and not values.get("vpc_id"): + raise ValueError( + "You must provide a `vpc_id` to enable custom `network_configuration`." + ) + return values + class ECSVariables(BaseVariables): """ @@ -459,10 +472,21 @@ class ECSVariables(BaseVariables): "When `configure_cloudwatch_logs` is enabled, this setting may be used to" " pass additional options to the CloudWatch logs configuration or override" " the default options. See the [AWS" - " documentation](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/using_awslogs.html#create_awslogs_logdriver_options.)" # noqa + " documentation](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/using_awslogs.html#create_awslogs_logdriver_options)" # noqa " for available options. " ), ) + + network_configuration: Dict[str, Any] = Field( + default_factory=dict, + description=( + "When `network_configuration` is supplied it will override ECS Worker's" + "awsvpcConfiguration that defined in the ECS task executing your workload. " + "See the [AWS documentation](https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-properties-ecs-service-awsvpcconfiguration.html)" # noqa + " for available options." + ), + ) + stream_output: bool = Field( default=None, description=( @@ -1242,7 +1266,7 @@ def _prepare_task_definition( return task_definition - def _load_vpc_network_config( + def _load_network_configuration( self, vpc_id: Optional[str], boto_session: boto3.Session ) -> dict: """ @@ -1289,6 +1313,47 @@ def _load_vpc_network_config( } } + def _custom_network_configuration( + self, vpc_id: str, network_configuration: dict, boto_session: boto3.Session + ) -> dict: + """ + Load settings from a specific VPC or the default VPC and generate a task + run request's network configuration. + """ + ec2_client = boto_session.client("ec2") + vpc_message = f"VPC with ID {vpc_id}" + + vpcs = ec2_client.describe_vpcs(VpcIds=[vpc_id]).get("Vpcs") + + if not vpcs: + raise ValueError( + f"Failed to find {vpc_message}. " + + "Network configuration cannot be inferred. " + + "Pass an explicit `vpc_id`." + ) + + vpc_id = vpcs[0]["VpcId"] + subnets = ec2_client.describe_subnets( + Filters=[{"Name": "vpc-id", "Values": [vpc_id]}] + )["Subnets"] + + if not subnets: + raise ValueError( + f"Failed to find subnets for {vpc_message}. " + + "Network configuration cannot be inferred." + ) + + config_subnets = network_configuration.get("subnets", []) + if not all( + [conf_sn in sn.values() for conf_sn in config_subnets for sn in subnets] + ): + raise ValueError( + f"Subnets {config_subnets} not found within {vpc_message}." + + "Please check that VPC is associated with supplied subnets." + ) + + return {"awsvpcConfiguration": network_configuration} + def _prepare_task_run_request( self, boto_session: boto3.Session, @@ -1318,14 +1383,29 @@ def _prepare_task_run_request( container_overrides = overrides.get("containerOverrides", []) # Ensure the network configuration is present if using awsvpc for network mode - - if task_definition.get("networkMode") == "awsvpc" and not task_run_request.get( - "networkConfiguration" + if ( + task_definition.get("networkMode") == "awsvpc" + and not task_run_request.get("networkConfiguration") + and not configuration.network_configuration ): - task_run_request["networkConfiguration"] = self._load_vpc_network_config( + task_run_request["networkConfiguration"] = self._load_network_configuration( configuration.vpc_id, boto_session ) + # Use networkConfiguration if supplied by user + if ( + task_definition.get("networkMode") == "awsvpc" + and configuration.network_configuration + and configuration.vpc_id + ): + task_run_request["networkConfiguration"] = ( + self._custom_network_configuration( + configuration.vpc_id, + configuration.network_configuration, + boto_session, + ) + ) + # Ensure the container name is set if not provided at template time container_name = ( diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index 4ebfb158..ef57372b 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -12,6 +12,7 @@ from moto.ec2.utils import generate_instance_identity_document from prefect.server.schemas.core import FlowRun from prefect.utilities.asyncutils import run_sync_in_worker_thread +from pydantic import ValidationError from tenacity import RetryError from prefect_aws.workers.ecs_worker import ( @@ -884,6 +885,112 @@ async def test_network_config_from_vpc_id( } +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_from_custom_settings( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ec2_resource = session.resource("ec2") + vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16") + subnet = ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id) + security_group = ec2_resource.create_security_group( + GroupName="ECSWorkerTestSG", Description="ECS Worker test SG", VpcId=vpc.id + ) + + configuration = await construct_configuration( + aws_credentials=aws_credentials, + vpc_id=vpc.id, + override_network_configuration=True, + network_configuration={ + "subnets": [subnet.id], + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + }, + ) + + session = aws_credentials.get_boto3_session() + + async with ECSWorker(work_pool_name="test") as worker: + # Capture the task run call because moto does not track 'networkConfiguration' + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration") + + # Subnet ids are copied from the vpc + assert network_configuration == { + "awsvpcConfiguration": { + "subnets": [subnet.id], + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + } + } + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_from_custom_settings_invalid_subnet( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ec2_resource = session.resource("ec2") + vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16") + security_group = ec2_resource.create_security_group( + GroupName="ECSWorkerTestSG", Description="ECS Worker test SG", VpcId=vpc.id + ) + ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id) + + configuration = await construct_configuration( + aws_credentials=aws_credentials, + vpc_id=vpc.id, + override_network_configuration=True, + network_configuration={ + "subnets": ["sn-8asdas"], + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + }, + ) + + session = aws_credentials.get_boto3_session() + + with pytest.raises( + ValueError, + match=( + r"Subnets \['sn-8asdas'\] not found within VPC with ID " + + vpc.id + + r"\.Please check that VPC is associated with supplied subnets\." + ), + ): + async with ECSWorker(work_pool_name="test") as worker: + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + await run_then_stop_task(worker, configuration, flow_run) + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_configure_network_requires_vpc_id( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + with pytest.raises( + ValidationError, + match="You must provide a `vpc_id` to enable custom `network_configuration`.", + ): + await construct_configuration( + aws_credentials=aws_credentials, + override_network_configuration=True, + network_configuration={ + "subnets": [], + "assignPublicIp": "ENABLED", + "securityGroups": [], + }, + ) + + @pytest.mark.usefixtures("ecs_mocks") async def test_network_config_from_default_vpc( aws_credentials: AwsCredentials, flow_run: FlowRun