Reduce memory footprint of s3 key trigger (apache#40473)
* Use generator to reduce memory footprint

We can return True on the first positive match, and we don't need to keep track of the files.
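
For illustration (not part of the diff), this is the pattern _check_key_async now uses in the
hook below: stream the paginated metadata and return on the first match, accumulating nothing.
Here client, bucket_name, and key are the hook's usual arguments:

    async for obj in self.get_file_metadata_async(client, bucket_name, key):
        if fnmatch.fnmatch(obj["Key"], key):
            return True  # first positive match short-circuits the scan
    return False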

* add tests

* fixup

* Update airflow/providers/amazon/aws/hooks/s3.py

Co-authored-by: Vincent <[email protected]>

* Revert "Update airflow/providers/amazon/aws/hooks/s3.py"

This reverts commit cb2f31a.

* reapply Vincent's suggestion

* add check for param

* add changelog

---------

Co-authored-by: Vincent <[email protected]>
dstandish and vincbeck authored Jun 28, 2024
1 parent 6c12744 commit bbfeee4
Showing 3 changed files with 156 additions and 96 deletions.
16 changes: 16 additions & 0 deletions airflow/providers/amazon/CHANGELOG.rst
@@ -26,6 +26,22 @@
Changelog
---------

main
....

Bug Fixes
~~~~~~~~~

* Reduce memory footprint of s3 key trigger (#40473)
- Decorator ``provide_bucket_name_async`` removed
* A separate decorator for async is no longer needed; ``provide_bucket_name`` now works
for coroutine functions, async iterators, and normal synchronous functions alike.
- Hook method ``get_file_metadata_async`` is now an async iterator
* Previously, the metadata objects were accumulated in a list. Now the objects are yielded as we page
through the results. To get a list, use ``async for`` in a list comprehension, as shown below.
- ``S3KeyTrigger`` avoids loading all positive matches into memory in some circumstances
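
For example, to consume the async iterator (a sketch; ``hook`` is an ``S3Hook``, ``client`` is its
async connection, and the bucket name and pattern are placeholder values)::

    # Stream the matches without building an intermediate list
    async for obj in hook.get_file_metadata_async(client, "my-bucket", "data/*"):
        print(obj["Key"])

    # Or materialize a list with an async-for list comprehension
    files = [f async for f in hook.get_file_metadata_async(client, "my-bucket", "data/*")]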


8.25.0
......

114 changes: 62 additions & 52 deletions airflow/providers/amazon/aws/hooks/s3.py
@@ -22,6 +22,7 @@
import asyncio
import fnmatch
import gzip as gz
import inspect
import logging
import os
import re
@@ -36,7 +37,7 @@
from io import BytesIO
from pathlib import Path
from tempfile import NamedTemporaryFile, gettempdir
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable
from urllib.parse import urlsplit
from uuid import uuid4

@@ -65,38 +66,15 @@ def provide_bucket_name(func: Callable) -> Callable:
"""Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
if hasattr(func, "_unify_bucket_name_and_key_wrapped"):
logger.warning("`unify_bucket_name_and_key` should wrap `provide_bucket_name`.")
function_signature = signature(func)

@wraps(func)
def wrapper(*args, **kwargs) -> Callable:
bound_args = function_signature.bind(*args, **kwargs)

if "bucket_name" not in bound_args.arguments:
self = args[0]

if "bucket_name" in self.service_config:
bound_args.arguments["bucket_name"] = self.service_config["bucket_name"]
elif self.conn_config and self.conn_config.schema:
warnings.warn(
"s3 conn_type, and the associated schema field, is deprecated. "
"Please use aws conn_type instead, and specify `bucket_name` "
"in `service_config.s3` within `extras`.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
bound_args.arguments["bucket_name"] = self.conn_config.schema

return func(*bound_args.args, **bound_args.kwargs)

return wrapper


def provide_bucket_name_async(func: Callable) -> Callable:
"""Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
function_signature = signature(func)
if "bucket_name" not in function_signature.parameters:
raise RuntimeError(
"Decorator provide_bucket_name should only wrap a function with param 'bucket_name'."
)

@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
# todo: raise immediately if func has no bucket_name arg
async def maybe_add_bucket_name(*args, **kwargs):
bound_args = function_signature.bind(*args, **kwargs)

if "bucket_name" not in bound_args.arguments:
@@ -105,8 +83,46 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
connection = await sync_to_async(self.get_connection)(self.aws_conn_id)
if connection.schema:
bound_args.arguments["bucket_name"] = connection.schema
return bound_args

if inspect.iscoroutinefunction(func):

@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
bound_args = await maybe_add_bucket_name(*args, **kwargs)
print(f"invoking async function {func=}")
return await func(*bound_args.args, **bound_args.kwargs)

elif inspect.isasyncgenfunction(func):

@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
bound_args = await maybe_add_bucket_name(*args, **kwargs)
async for thing in func(*bound_args.args, **bound_args.kwargs):
yield thing

else:

@wraps(func)
def wrapper(*args, **kwargs) -> Callable:
bound_args = function_signature.bind(*args, **kwargs)

if "bucket_name" not in bound_args.arguments:
self = args[0]

if "bucket_name" in self.service_config:
bound_args.arguments["bucket_name"] = self.service_config["bucket_name"]
elif self.conn_config and self.conn_config.schema:
warnings.warn(
"s3 conn_type, and the associated schema field, is deprecated. "
"Please use aws conn_type instead, and specify `bucket_name` "
"in `service_config.s3` within `extras`.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
bound_args.arguments["bucket_name"] = self.conn_config.schema

return await func(*bound_args.args, **bound_args.kwargs)
return func(*bound_args.args, **bound_args.kwargs)

return wrapper

@@ -400,8 +416,8 @@ def list_prefixes(

return prefixes

@provide_bucket_name_async
@unify_bucket_name_and_key
@provide_bucket_name
async def get_head_object_async(
self, client: AioBaseClient, key: str, bucket_name: str | None = None
) -> dict[str, Any] | None:
@@ -462,10 +478,10 @@ async def list_prefixes_async(

return prefixes

@provide_bucket_name_async
@provide_bucket_name
async def get_file_metadata_async(
self, client: AioBaseClient, bucket_name: str, key: str | None = None
) -> list[Any]:
) -> AsyncIterator[Any]:
"""
Asynchronously iterate over metadata of files in a bucket whose keys match a wildcard expression.
@@ -477,11 +493,10 @@ async def get_file_metadata_async(
delimiter = ""
paginator = client.get_paginator("list_objects_v2")
response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter)
files = []
async for page in response:
if "Contents" in page:
files += page["Contents"]
return files
for row in page["Contents"]:
yield row

async def _check_key_async(
self,
@@ -506,21 +521,16 @@ async def _check_key_async(
"""
bucket_name, key = self.get_s3_bucket_key(bucket_val, key, "bucket_name", "bucket_key")
if wildcard_match:
keys = await self.get_file_metadata_async(client, bucket_name, key)
key_matches = [k for k in keys if fnmatch.fnmatch(k["Key"], key)]
if not key_matches:
return False
elif use_regex:
keys = await self.get_file_metadata_async(client, bucket_name)
key_matches = [k for k in keys if re.match(pattern=key, string=k["Key"])]
if not key_matches:
return False
else:
obj = await self.get_head_object_async(client, key, bucket_name)
if obj is None:
return False

return True
async for k in self.get_file_metadata_async(client, bucket_name, key):
if fnmatch.fnmatch(k["Key"], key):
return True
return False
if use_regex:
async for k in self.get_file_metadata_async(client, bucket_name):
if re.match(pattern=key, string=k["Key"]):
return True
return False
return bool(await self.get_head_object_async(client, key, bucket_name))

async def check_key_async(
self,
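The rewritten provide_bucket_name above selects a wrapper based on the kind of callable it
decorates. A minimal, self-contained sketch of that dispatch pattern (independent of Airflow;
all names here are illustrative, not from the diff):

    import asyncio
    import inspect
    from functools import wraps

    def dispatch_wrapper(func):
        # Pick the wrapper flavor that matches the decorated callable:
        # coroutine function, async generator function, or plain function.
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def wrapper(*args, **kwargs):
                return await func(*args, **kwargs)

        elif inspect.isasyncgenfunction(func):

            @wraps(func)
            async def wrapper(*args, **kwargs):
                async for item in func(*args, **kwargs):
                    yield item

        else:

            @wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

        return wrapper

    @dispatch_wrapper
    async def numbers():
        yield 1
        yield 2

    async def main():
        # The async generator is still an async generator after wrapping.
        assert [n async for n in numbers()] == [1, 2]

    asyncio.run(main())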
122 changes: 78 additions & 44 deletions tests/providers/amazon/aws/hooks/test_s3.py
@@ -23,7 +23,7 @@
import re
from datetime import datetime as std_datetime, timezone
from unittest import mock, mock as async_mock
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from urllib.parse import parse_qs

import boto3
@@ -428,8 +429,9 @@ async def test_s3_key_hook_get_file_metadata_async(self, mock_client):

s3_hook_async = S3Hook(client_type="S3")
mock_client.get_paginator = mock.Mock(return_value=mock_paginator)
task = await s3_hook_async.get_file_metadata_async(mock_client, "test_bucket", "test*")
assert task == [
keys = [x async for x in s3_hook_async.get_file_metadata_async(mock_client, "test_bucket", "test*")]

assert keys == [
{"Key": "test_key", "ETag": "etag1", "LastModified": datetime(2020, 8, 14, 17, 19, 34)},
{"Key": "test_key2", "ETag": "etag2", "LastModified": datetime(2020, 8, 14, 17, 19, 34)},
]
@@ -632,64 +633,90 @@ async def test_s3_prefix_sensor_hook_check_for_prefix_async(

@pytest.mark.asyncio
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_head_object_async")
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
async def test__check_key_async_without_wildcard_match(
self, mock_client, mock_head_object, mock_get_bucket_key
):
async def test__check_key_async_without_wildcard_match(self, mock_get_conn, mock_get_bucket_key):
"""Test _check_key_async function without using wildcard_match"""
mock_get_bucket_key.return_value = "test_bucket", "test.txt"
mock_head_object.return_value = {"ContentLength": 0}
mock_client = mock_get_conn.return_value
mock_client.head_object = AsyncMock(return_value={"ContentLength": 0})
s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
response = await s3_hook_async._check_key_async(
mock_client.return_value, "test_bucket", False, "s3://test_bucket/file/test.txt"
mock_client, "test_bucket", False, "s3://test_bucket/file/test.txt"
)
assert response is True

@pytest.mark.asyncio
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_head_object_async")
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
async def test_s3__check_key_async_without_wildcard_match_and_get_none(
self, mock_client, mock_head_object, mock_get_bucket_key
self, mock_get_conn, mock_get_bucket_key
):
"""Test _check_key_async function when get head object returns none"""
mock_get_bucket_key.return_value = "test_bucket", "test.txt"
mock_head_object.return_value = None
s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
mock_client = mock_get_conn.return_value
mock_client.head_object = AsyncMock(return_value=None)
response = await s3_hook_async._check_key_async(
mock_client.return_value, "test_bucket", False, "s3://test_bucket/file/test.txt"
mock_client, "test_bucket", False, "s3://test_bucket/file/test.txt"
)
assert response is False

# @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
@pytest.mark.asyncio
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_file_metadata_async")
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
async def test_s3__check_key_async_with_wildcard_match(
self, mock_client, mock_get_file_metadata, mock_get_bucket_key
):
@pytest.mark.parametrize(
"contents, result",
[
(
[
{
"Key": "test/example_s3_test_file.txt",
"ETag": "etag1",
"LastModified": datetime(2020, 8, 14, 17, 19, 34),
"Size": 0,
},
{
"Key": "test_key2",
"ETag": "etag2",
"LastModified": datetime(2020, 8, 14, 17, 19, 34),
"Size": 0,
},
],
True,
),
(
[
{
"Key": "test/example_aeoua.txt",
"ETag": "etag1",
"LastModified": datetime(2020, 8, 14, 17, 19, 34),
"Size": 0,
},
{
"Key": "test_key2",
"ETag": "etag2",
"LastModified": datetime(2020, 8, 14, 17, 19, 34),
"Size": 0,
},
],
False,
),
],
)
async def test_s3__check_key_async_with_wildcard_match(self, mock_get_conn, contents, result):
"""Test _check_key_async function"""
mock_get_bucket_key.return_value = "test_bucket", "test"
mock_get_file_metadata.return_value = [
{
"Key": "test_key",
"ETag": "etag1",
"LastModified": datetime(2020, 8, 14, 17, 19, 34),
"Size": 0,
},
{
"Key": "test_key2",
"ETag": "etag2",
"LastModified": datetime(2020, 8, 14, 17, 19, 34),
"Size": 0,
},
]
client = mock_get_conn.return_value
paginator = client.get_paginator.return_value
r = paginator.paginate.return_value
r.__aiter__.return_value = [{"Contents": contents}]
s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
response = await s3_hook_async._check_key_async(
mock_client.return_value, "test_bucket", True, "test/example_s3_test_file.txt"
client=client,
bucket_val="test_bucket",
wildcard_match=True,
key="test/example_s3_test_file.txt",
)
assert response is False
assert response is result

@pytest.mark.parametrize(
"key, pattern, expected",
Expand All @@ -701,24 +728,31 @@ async def test_s3__check_key_async_with_wildcard_match(
)
@pytest.mark.asyncio
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key")
@async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_file_metadata_async")
@async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn")
async def test__check_key_async_with_use_regex(
self, mock_client, mock_get_file_metadata, mock_get_bucket_key, key, pattern, expected
self, mock_get_conn, mock_get_bucket_key, key, pattern, expected
):
"""Match AWS S3 key with regex expression"""
mock_get_bucket_key.return_value = "test_bucket", pattern
mock_get_file_metadata.return_value = [
client = mock_get_conn.return_value
paginator = client.get_paginator.return_value
r = paginator.paginate.return_value
r.__aiter__.return_value = [
{
"Key": key,
"ETag": "etag1",
"LastModified": datetime(2020, 8, 14, 17, 19, 34),
"Size": 0,
},
"Contents": [
{
"Key": key,
"ETag": "etag1",
"LastModified": datetime(2020, 8, 14, 17, 19, 34),
"Size": 0,
},
]
}
]

s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
response = await s3_hook_async._check_key_async(
client=mock_client.return_value,
client=client,
bucket_val="test_bucket",
wildcard_match=False,
key=pattern,
Expand Down
