Redact extra fields for Asset Endpoints in FastAPI (apache#44069)
* redact extra fields in asset endpoints for fast api

* redact extra fields in asset endpoints for fast api

* updating test name correctly

* removing duplicated time_freezer
vatsrahul1001 authored Nov 18, 2024
1 parent 8027055 commit 313be64
Showing 2 changed files with 157 additions and 0 deletions.
12 changes: 12 additions & 0 deletions airflow/api_fastapi/core_api/datamodels/assets.py
@@ -21,6 +21,8 @@

from pydantic import BaseModel, Field, field_validator

from airflow.utils.log.secrets_masker import redact


class DagScheduleAssetReference(BaseModel):
"""DAG schedule reference serializer for assets."""
@@ -58,6 +60,11 @@ class AssetResponse(BaseModel):
producing_tasks: list[TaskOutletAssetReference]
aliases: list[AssetAliasSchema]

@field_validator("extra", mode="after")
@classmethod
def redact_extra(cls, v: dict):
return redact(v)


class AssetCollectionResponse(BaseModel):
"""Asset collection response."""
@@ -93,6 +100,11 @@ class AssetEventResponse(BaseModel):
created_dagruns: list[DagRunAssetReference]
timestamp: datetime

@field_validator("extra", mode="after")
@classmethod
def redact_extra(cls, v: dict):
return redact(v)


class AssetEventCollectionResponse(BaseModel):
"""Asset event collection response."""
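For context, both validators above delegate to Airflow's secrets masker. Below is a minimal, self-contained sketch of the same pattern; `ExampleAssetResponse` is a hypothetical stand-in for the real models (not part of the diff), and the expected output assumes Airflow's default masking configuration, where `password` is on the sensitive-field list:

from pydantic import BaseModel, field_validator

from airflow.utils.log.secrets_masker import redact


class ExampleAssetResponse(BaseModel):
    """Hypothetical model mirroring the validator added to AssetResponse above."""

    uri: str
    extra: dict

    @field_validator("extra", mode="after")
    @classmethod
    def redact_extra(cls, v: dict):
        # Runs after standard validation; masks values whose keys look sensitive.
        return redact(v)


resp = ExampleAssetResponse(uri="s3://bucket/key/1", extra={"password": "bar", "team": "data"})
print(resp.extra)  # expected: {'password': '***', 'team': 'data'}

Because the masking lives in the Pydantic models rather than in each route handler, every endpoint that serializes these models picks it up automatically.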
145 changes: 145 additions & 0 deletions tests/api_fastapi/core_api/routes/public/test_assets.py
@@ -59,6 +59,22 @@ def _create_assets(session, num: int = 2) -> None:
session.commit()


def _create_assets_with_sensitive_extra(session, num: int = 2) -> None:
default_time = "2020-06-11T18:00:00+00:00"
assets = [
AssetModel(
id=i,
uri=f"s3://bucket/key/{i}",
extra={"password": "bar"},
created_at=timezone.parse(default_time),
updated_at=timezone.parse(default_time),
)
for i in range(1, 1 + num)
]
session.add_all(assets)
session.commit()


def _create_provided_asset(session, asset: AssetModel) -> None:
session.add(asset)
session.commit()
@@ -82,6 +98,24 @@ def _create_assets_events(session, num: int = 2) -> None:
session.commit()


def _create_assets_events_with_sensitive_extra(session, num: int = 2) -> None:
default_time = "2020-06-11T18:00:00+00:00"
assets_events = [
AssetEvent(
id=i,
asset_id=i,
extra={"password": "bar"},
source_task_id="source_task_id",
source_dag_id="source_dag_id",
source_run_id=f"source_run_id_{i}",
timestamp=timezone.parse(default_time),
)
for i in range(1, 1 + num)
]
session.add_all(assets_events)
session.commit()


def _create_provided_asset_event(session, asset_event: AssetEvent) -> None:
session.add(asset_event)
session.commit()
@@ -142,6 +176,10 @@ def teardown_method(self) -> None:
def create_assets(self, session, num: int = 2):
_create_assets(session=session, num=num)

@provide_session
def create_assets_with_sensitive_extra(self, session, num: int = 2):
_create_assets_with_sensitive_extra(session=session, num=num)

@provide_session
def create_provided_asset(self, session, asset: AssetModel):
_create_provided_asset(session=session, asset=asset)
@@ -150,6 +188,10 @@ def create_provided_asset(self, session, asset: AssetModel):
def create_assets_events(self, session, num: int = 2):
_create_assets_events(session=session, num=num)

@provide_session
def create_assets_events_with_sensitive_extra(self, session, num: int = 2):
_create_assets_events_with_sensitive_extra(session=session, num=num)

@provide_session
def create_provided_asset_event(self, session, asset_event: AssetEvent):
_create_provided_asset_event(session=session, asset_event=asset_event)
@@ -439,6 +481,68 @@ def test_limit_and_offset(self, test_client, params, expected_asset_uris):
asset_uris = [asset["uri"] for asset in response.json()["asset_events"]]
assert asset_uris == expected_asset_uris

@pytest.mark.usefixtures("time_freezer")
@pytest.mark.enable_redact
def test_should_mask_sensitive_extra(self, test_client, session):
self.create_assets_with_sensitive_extra()
self.create_assets_events_with_sensitive_extra()
self.create_dag_run()
self.create_asset_dag_run()
response = test_client.get("/public/assets/events")
assert response.status_code == 200
response_data = response.json()
assert response_data == {
"asset_events": [
{
"id": 1,
"asset_id": 1,
"uri": "s3://bucket/key/1",
"extra": {"password": "***"},
"source_task_id": "source_task_id",
"source_dag_id": "source_dag_id",
"source_run_id": "source_run_id_1",
"source_map_index": -1,
"created_dagruns": [
{
"run_id": "source_run_id_1",
"dag_id": "source_dag_id",
"logical_date": "2020-06-11T18:00:00Z",
"start_date": "2020-06-11T18:00:00Z",
"end_date": "2020-06-11T18:00:00Z",
"state": "success",
"data_interval_start": "2020-06-11T18:00:00Z",
"data_interval_end": "2020-06-11T18:00:00Z",
}
],
"timestamp": "2020-06-11T18:00:00Z",
},
{
"id": 2,
"asset_id": 2,
"uri": "s3://bucket/key/2",
"extra": {"password": "***"},
"source_task_id": "source_task_id",
"source_dag_id": "source_dag_id",
"source_run_id": "source_run_id_2",
"source_map_index": -1,
"created_dagruns": [
{
"run_id": "source_run_id_2",
"dag_id": "source_dag_id",
"logical_date": "2020-06-11T18:00:00Z",
"start_date": "2020-06-11T18:00:00Z",
"end_date": "2020-06-11T18:00:00Z",
"state": "success",
"data_interval_start": "2020-06-11T18:00:00Z",
"data_interval_end": "2020-06-11T18:00:00Z",
}
],
"timestamp": "2020-06-11T18:00:00Z",
},
],
"total_entries": 2,
}


class TestGetAssetEndpoint(TestAssets):
@pytest.mark.parametrize(
@@ -478,6 +582,27 @@ def test_should_respond_404(self, test_client):
assert response.status_code == 404
assert response.json()["detail"] == "The Asset with uri: `s3://bucket/key` was not found"

@pytest.mark.usefixtures("time_freezer")
@pytest.mark.enable_redact
def test_should_mask_sensitive_extra(self, test_client, session):
self.create_assets_with_sensitive_extra()
tz_datetime_format = self.default_time.replace("+00:00", "Z")
uri = "s3://bucket/key/1"
response = test_client.get(
f"/public/assets/{uri}",
)
assert response.status_code == 200
assert response.json() == {
"id": 1,
"uri": "s3://bucket/key/1",
"extra": {"password": "***"},
"created_at": tz_datetime_format,
"updated_at": tz_datetime_format,
"consuming_dags": [],
"producing_tasks": [],
"aliases": [],
}


class TestQueuedEventEndpoint(TestAssets):
def _create_asset_dag_run_queues(self, dag_id, asset_id, session):
@@ -593,3 +718,23 @@ def test_invalid_attr_not_allowed(self, test_client, session):
response = test_client.post("/public/assets/events", json=event_invalid_payload)

assert response.status_code == 422

@pytest.mark.usefixtures("time_freezer")
@pytest.mark.enable_redact
def test_should_mask_sensitive_extra(self, test_client, session):
self.create_assets()
event_payload = {"uri": "s3://bucket/key/1", "extra": {"password": "bar"}}
response = test_client.post("/public/assets/events", json=event_payload)
assert response.status_code == 200
assert response.json() == {
"id": mock.ANY,
"asset_id": 1,
"uri": "s3://bucket/key/1",
"extra": {"password": "***", "from_rest_api": True},
"source_task_id": None,
"source_dag_id": None,
"source_run_id": None,
"source_map_index": -1,
"created_dagruns": [],
"timestamp": self.default_time.replace("+00:00", "Z"),
}
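As a usage note, the expected payloads above reduce to the behaviour of `redact` itself. A minimal sketch, assuming masking is enabled (the `@pytest.mark.enable_redact` marker used in these tests appears to toggle that in the Airflow test suite) and the default sensitive-field list; the test name here is illustrative, not part of the diff:

from airflow.utils.log.secrets_masker import redact


def test_redact_masks_only_sensitive_keys():
    # Keys on Airflow's sensitive-field list (e.g. "password") are replaced with "***",
    # while other keys such as "from_rest_api" pass through unchanged.
    assert redact({"password": "bar", "from_rest_api": True}) == {
        "password": "***",
        "from_rest_api": True,
    }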
