Use sentinel to mark dag as removed on reserialization (apache#39825)
We don't serialize the dag on the task.dag attr when making RPC calls. By marking it with a sentinel value, we can tell when we're dealing with a deserialized object, and then re-set the dag attr while skipping some of the extra logic applied in the setter.
dstandish authored Jul 10, 2024
1 parent 17b792d commit 7d7a4cd
Showing 10 changed files with 131 additions and 47 deletions.
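Before the file-by-file diff, the pattern in isolation: a minimal sketch mirroring the hunks below. The sentinel matches the commit, but the Task class here is a simplified stand-in for illustration, not Airflow's BaseOperator.

class AttributeRemoved:
    """Sentinel: any attribute access fails loudly until the attribute is restored."""

    def __getattr__(self, item):
        raise RuntimeError("Attribute was removed on serialization and must be set again.")


ATTRIBUTE_REMOVED = AttributeRemoved()


class Task:
    """Simplified stand-in for an operator with a one-time dag assignment."""

    def __init__(self):
        self._dag = None

    @property
    def dag(self):
        return self._dag

    @dag.setter
    def dag(self, dag):
        if self._dag is ATTRIBUTE_REMOVED:
            # restoring after deserialization: assign directly, skip the usual checks
            self._dag = dag
            return
        if dag is ATTRIBUTE_REMOVED:
            # serialization marked the attribute removed: just record the sentinel
            self._dag = ATTRIBUTE_REMOVED
            return
        # the normal assignment path (type validation, DAG registration) goes here
        self._dag = dag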
16 changes: 13 additions & 3 deletions airflow/models/baseoperator.py
@@ -101,7 +101,7 @@
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET
from airflow.utils.types import ATTRIBUTE_REMOVED, NOTSET
from airflow.utils.xcom import XCOM_RETURN_KEY

if TYPE_CHECKING:
@@ -1245,11 +1245,21 @@ def dag(self) -> DAG: # type: ignore[override]
@dag.setter
def dag(self, dag: DAG | None):
"""Operators can be assigned to one DAG, one time. Repeat assignments to that same DAG are ok."""
from airflow.models.dag import DAG

if dag is None:
self._dag = None
return

# if set to removed, then just set and exit
if self._dag is ATTRIBUTE_REMOVED:
self._dag = dag
return
# if setting to removed, then just set and exit
if dag is ATTRIBUTE_REMOVED:
self._dag = ATTRIBUTE_REMOVED # type: ignore[assignment]
return

from airflow.models.dag import DAG

if not isinstance(dag, DAG):
raise TypeError(f"Expected DAG; received {dag.__class__.__name__}")
elif self.has_dag() and self.dag is not dag:
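Combined with the deserialize_operator change further below, which installs the sentinel on every deserialized operator, the intended round-trip looks roughly like this (a sketch; encoded_op stands for a dict previously produced by SerializedBaseOperator.serialize):

from airflow.models.dag import DAG
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.utils.types import ATTRIBUTE_REMOVED

# encoded_op: assumed to be a dict previously produced by SerializedBaseOperator.serialize(...)
op = SerializedBaseOperator.deserialize_operator(encoded_op)
assert op.dag is ATTRIBUTE_REMOVED  # sentinel installed during deserialization

dag = DAG(dag_id="example")
op.dag = dag  # the setter sees the sentinel and assigns directly, skipping the extra checks
assert op.dag is dag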
47 changes: 20 additions & 27 deletions airflow/models/taskinstance.py
@@ -136,6 +136,7 @@
from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.task_instance_session import set_current_task_instance_session
from airflow.utils.timeout import timeout
from airflow.utils.types import ATTRIBUTE_REMOVED
from airflow.utils.xcom import XCOM_RETURN_KEY

TR = TaskReschedule
@@ -902,13 +903,15 @@ def _clear_next_method_args(*, task_instance: TaskInstance | TaskInstancePydanti
def _get_template_context(
*,
task_instance: TaskInstance | TaskInstancePydantic,
dag: DAG,
session: Session | None = None,
ignore_param_exceptions: bool = True,
) -> Context:
"""
Return TI Context.
:param task_instance: the task instance
:param task_instance: the task instance for the task
:param dag: the dag for the task
:param session: SQLAlchemy ORM Session
:param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
@@ -928,27 +931,10 @@ def _get_template_context(
assert task_instance.task
assert task
assert task.dag
try:
dag: DAG = task.dag
except AirflowException:
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic

if isinstance(task_instance, TaskInstancePydantic):
ti = session.scalar(
select(TaskInstance).where(
TaskInstance.task_id == task_instance.task_id,
TaskInstance.dag_id == task_instance.dag_id,
TaskInstance.run_id == task_instance.run_id,
TaskInstance.map_index == task_instance.map_index,
)
)
dag = ti.dag_model.serialized_dag.dag
if hasattr(task_instance.task, "_dag"): # BaseOperator
task_instance.task._dag = dag
else: # MappedOperator
task_instance.task.dag = dag
else:
raise
if task.dag is ATTRIBUTE_REMOVED:
task.dag = dag # required after deserialization

dag_run = task_instance.get_dagrun(session)
data_interval = dag.get_run_data_interval(dag_run)

@@ -1278,12 +1264,8 @@ def _record_task_map_for_downstreams(
:meta private:
"""
# when taking task over RPC, we need to add the dag back
if isinstance(task, MappedOperator):
if not task.dag:
task.dag = dag
elif not task._dag:
task._dag = dag
if task.dag is ATTRIBUTE_REMOVED:
task.dag = dag # required after deserialization

if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate.
return
@@ -3313,8 +3295,12 @@ def get_template_context(
:param session: SQLAlchemy ORM Session
:param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
"""
if TYPE_CHECKING:
assert self.task
assert self.task.dag
return _get_template_context(
task_instance=self,
dag=self.task.dag,
session=session,
ignore_param_exceptions=ignore_param_exceptions,
)
@@ -3374,8 +3360,15 @@ def render_templates(
context = self.get_template_context()
original_task = self.task

ti = context["ti"]

if TYPE_CHECKING:
assert original_task
assert self.task
assert ti.task

if ti.task.dag is ATTRIBUTE_REMOVED:
ti.task.dag = self.task.dag

# If self.task is mapped, this call replaces self.task to point to the
# unmapped BaseOperator created by this function! This is because the
18 changes: 18 additions & 0 deletions airflow/serialization/pydantic/taskinstance.py
@@ -288,8 +288,12 @@ def get_template_context(
"""
from airflow.models.taskinstance import _get_template_context

if TYPE_CHECKING:
assert self.task
assert self.task.dag
return _get_template_context(
task_instance=self,
dag=self.task.dag,
session=session,
ignore_param_exceptions=ignore_param_exceptions,
)
@@ -518,6 +522,20 @@ def _handle_reschedule(
)
_set_ti_attrs(self, updated_ti) # _handle_reschedule is a remote call that mutates the TI

def get_relevant_upstream_map_indexes(
self,
upstream: Operator,
ti_count: int | None,
*,
session: Session | None = None,
) -> int | range | None:
return TaskInstance.get_relevant_upstream_map_indexes(
self=self, # type: ignore[arg-type]
upstream=upstream,
ti_count=ti_count,
session=session,
)


if is_pydantic_2_installed():
TaskInstancePydantic.model_rebuild()
4 changes: 2 additions & 2 deletions airflow/serialization/serialized_objects.py
@@ -74,7 +74,7 @@
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import MappedTaskGroup, TaskGroup
from airflow.utils.timezone import from_timestamp, parse_timezone
from airflow.utils.types import NOTSET, ArgNotSet
from airflow.utils.types import ATTRIBUTE_REMOVED, NOTSET, ArgNotSet

if TYPE_CHECKING:
from inspect import Parameter
@@ -1297,7 +1297,7 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator:
)
else:
op = SerializedBaseOperator(task_id=encoded_op["task_id"])

op.dag = ATTRIBUTE_REMOVED # type: ignore[assignment]
cls.populate_operator(op, encoded_op)
return op

19 changes: 19 additions & 0 deletions airflow/utils/types.py
@@ -46,6 +46,25 @@ def is_arg_passed(arg: Union[ArgNotSet, None] = NOTSET) -> bool:
"""Sentinel value for argument default. See ``ArgNotSet``."""


class AttributeRemoved:
"""
Sentinel type to signal when an attribute was removed on serialization.
:meta private:
"""

def __getattr__(self, item):
raise RuntimeError("Attribute was removed on serialization and must be set again.")


ATTRIBUTE_REMOVED = AttributeRemoved()
"""
Sentinel value for attributes removed on serialization.
:meta private:
"""


class DagRunType(str, enum.Enum):
"""Class with DagRun types."""

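A quick illustration of the sentinel's behavior (a sketch, assuming only the class above): call sites detect it by identity, and any attribute access fails loudly, so a forgotten re-set surfaces immediately.

from airflow.utils.types import ATTRIBUTE_REMOVED

task_dag = ATTRIBUTE_REMOVED  # stand-in for a deserialized task's dag attribute

assert task_dag is ATTRIBUTE_REMOVED  # identity check, as used at the call sites above

try:
    task_dag.dag_id  # any attribute access raises
except RuntimeError as e:
    print(e)  # Attribute was removed on serialization and must be set again.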
1 change: 1 addition & 0 deletions tests/models/test_taskinstance.py
@@ -3568,6 +3568,7 @@ def test_operator_field_with_serialization(self, create_task_instance):
deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op)
assert deserialized_op.task_type == "EmptyOperator"
# Verify that ti.operator field renders correctly "with" Serialization
deserialized_op.dag = ti.task.dag
ser_ti = TI(task=deserialized_op, run_id=None)
assert ser_ti.operator == "EmptyOperator"
assert ser_ti.task.operator_name == "EmptyOperator"
5 changes: 4 additions & 1 deletion tests/providers/postgres/operators/test_postgres.py
@@ -221,7 +221,10 @@ def test_parameters_are_templatized(create_task_instance_of_operator):
task_id="test-task",
)
task: SQLExecuteQueryOperator = ti.render_templates(
{"param": {"conn_id": "pg", "table": "foo", "bar": "egg"}}
{
"param": {"conn_id": "pg", "table": "foo", "bar": "egg"},
"ti": ti,
}
)
assert task.conn_id == "pg"
assert task.sql == "SELECT * FROM foo WHERE spam = %(spam)s;"
32 changes: 25 additions & 7 deletions tests/serialization/test_dag_serialization.py
@@ -2509,21 +2509,31 @@ def test_operator_expand_deserialized_unmap():

ser_mapped = BaseSerialization.serialize(mapped)
deser_mapped = BaseSerialization.deserialize(ser_mapped)
deser_mapped.dag = None

ser_normal = BaseSerialization.serialize(normal)
deser_normal = BaseSerialization.deserialize(ser_normal)
deser_normal.dag = None
assert deser_mapped.unmap(None) == deser_normal


@pytest.mark.db_test
def test_sensor_expand_deserialized_unmap():
"""Unmap a deserialized mapped sensor should be similar to deserializing a non-mapped sensor"""
normal = BashSensor(task_id="a", bash_command=[1, 2], mode="reschedule")
mapped = BashSensor.partial(task_id="a", mode="reschedule").expand(bash_command=[1, 2])

serialize = SerializedBaseOperator.serialize

deserialize = SerializedBaseOperator.deserialize
assert deserialize(serialize(mapped)).unmap(None) == deserialize(serialize(normal))
dag = DAG(dag_id="hello", start_date=None)
with dag:
normal = BashSensor(task_id="a", bash_command=[1, 2], mode="reschedule")
mapped = BashSensor.partial(task_id="b", mode="reschedule").expand(bash_command=[1, 2])
ser_mapped = SerializedBaseOperator.serialize(mapped)
deser_mapped = SerializedBaseOperator.deserialize(ser_mapped)
deser_mapped.dag = dag
deser_unmapped = deser_mapped.unmap(None)
ser_normal = SerializedBaseOperator.serialize(normal)
deser_normal = SerializedBaseOperator.deserialize(ser_normal)
deser_normal.dag = dag
comps = set(BashSensor._comps)
comps.remove("task_id")
assert all(getattr(deser_unmapped, c, None) == getattr(deser_normal, c, None) for c in comps)


def test_task_resources_serde():
@@ -2625,6 +2635,10 @@ def x(arg1, arg2, arg3):
"retry_delay": timedelta(seconds=30),
}

# this dag is not pickleable in this context, so we have to simply
# set it to None
deserialized.dag = None

# Ensure the serialized operator can also be correctly pickled, to ensure
# correct interaction between DAG pickling and serialization. This is done
# here so we don't need to duplicate tests between pickled and non-pickled
@@ -2721,6 +2735,10 @@ def x(arg1, arg2, arg3):
"retry_delay": timedelta(seconds=30),
}

# this dag is not pickleable in this context, so we have to simply
# set it to None
deserialized.dag = None

# Ensure the serialized operator can also be correctly pickled, to ensure
# correct interaction between DAG pickling and serialization. This is done
# here so we don't need to duplicate tests between pickled and non-pickled
27 changes: 20 additions & 7 deletions tests/serialization/test_pydantic_models.py
@@ -27,7 +27,7 @@
from airflow.jobs.job import Job
from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
from airflow.models import MappedOperator
from airflow.models.dag import DagModel
from airflow.models.dag import DAG, DagModel
from airflow.models.dataset import (
DagScheduleDatasetReference,
DatasetEvent,
@@ -43,7 +43,7 @@
from airflow.settings import _ENABLE_AIP_44
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.types import DagRunType
from airflow.utils.types import ATTRIBUTE_REMOVED, DagRunType
from tests.models import DEFAULT_DATE

pytestmark = pytest.mark.db_test
@@ -89,7 +89,7 @@ def test_deserialize_ti_mapped_op_reserialized_with_refresh_from_task(session, d
"task_id": "target",
}

with dag_maker():
with dag_maker() as dag:

@task
def source():
@@ -117,7 +117,7 @@ def target(val=None):
# roundtrip ti
sered = BaseSerialization.serialize(ti, use_pydantic_models=True)
desered = BaseSerialization.deserialize(sered, use_pydantic_models=True)

assert desered.task.dag is ATTRIBUTE_REMOVED
assert "operator_class" not in sered["__var"]["task"]

assert desered.task.__class__ == MappedOperator
@@ -130,9 +130,22 @@ def target(val=None):

assert isinstance(desered.task.operator_class, dict)

resered = BaseSerialization.serialize(desered, use_pydantic_models=True)
deresered = BaseSerialization.deserialize(resered, use_pydantic_models=True)
assert deresered.task.operator_class == desered.task.operator_class == op_class_dict_expected
# let's check that we can safely add back dag...
assert isinstance(dag, DAG)
# dag already has this task
assert dag.has_task(desered.task.task_id) is True
# but the task has no dag
assert desered.task.dag is ATTRIBUTE_REMOVED
# and there are no upstream / downstreams on the task because those are wiped out on serialization
# and this is wrong / not great but that's how it is
assert desered.task.upstream_task_ids == set()
assert desered.task.downstream_task_ids == set()
# add the dag back
desered.task.dag = dag
# great, no error
# but still, there are no upstream downstreams
assert desered.task.upstream_task_ids == set()
assert desered.task.downstream_task_ids == set()


@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
9 changes: 9 additions & 0 deletions tests/serialization/test_serialized_objects.py
@@ -339,6 +339,11 @@ def test_serialize_deserialize_pydantic(input, pydantic_class, encoded_type, cmp
reserialized = BaseSerialization.serialize(deserialized, use_pydantic_models=True)
dereserialized = BaseSerialization.deserialize(reserialized, use_pydantic_models=True)
assert isinstance(dereserialized, pydantic_class)

if encoded_type == "task_instance":
deserialized.task.dag = None
dereserialized.task.dag = None

assert dereserialized == deserialized

# Verify recursive behavior
@@ -394,6 +399,10 @@ def test_all_pydantic_models_round_trip():
serialized = BaseSerialization.serialize(pydantic_instance, use_pydantic_models=True)
deserialized = BaseSerialization.deserialize(serialized, use_pydantic_models=True)
assert isinstance(deserialized, c)
if isinstance(pydantic_instance, TaskInstancePydantic):
# we can't access the dag on deserialization; but there is no dag here.
deserialized.task.dag = None
pydantic_instance.task.dag = None
assert pydantic_instance == deserialized


