diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index b3bf56b4da42..e66731f504bc 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -101,7 +101,7 @@
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.setup_teardown import SetupTeardownContext
 from airflow.utils.trigger_rule import TriggerRule
-from airflow.utils.types import NOTSET
+from airflow.utils.types import ATTRIBUTE_REMOVED, NOTSET
 from airflow.utils.xcom import XCOM_RETURN_KEY
 
 if TYPE_CHECKING:
@@ -1245,11 +1245,21 @@ def dag(self) -> DAG:  # type: ignore[override]
 
     @dag.setter
     def dag(self, dag: DAG | None):
         """Operators can be assigned to one DAG, one time. Repeat assignments to that same DAG are ok."""
-        from airflow.models.dag import DAG
-
         if dag is None:
             self._dag = None
             return
+
+        # if the current value is the removed sentinel, just set and exit
+        if self._dag is ATTRIBUTE_REMOVED:
+            self._dag = dag
+            return
+        # if setting to the removed sentinel, just set and exit
+        if dag is ATTRIBUTE_REMOVED:
+            self._dag = ATTRIBUTE_REMOVED  # type: ignore[assignment]
+            return
+
+        from airflow.models.dag import DAG
+
         if not isinstance(dag, DAG):
             raise TypeError(f"Expected DAG; received {dag.__class__.__name__}")
         elif self.has_dag() and self.dag is not dag:
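
The setter now short-circuits on the sentinel in both directions before the usual type check and one-time-assignment validation. A minimal standalone sketch of that branch order, where `_Op` is a hypothetical stand-in (not the real BaseOperator) and the import assumes this PR is applied:

    from airflow.utils.types import ATTRIBUTE_REMOVED

    class _Op:
        def __init__(self):
            self._dag = None

        def set_dag(self, dag):
            if dag is None:
                self._dag = None  # explicit unset always wins
                return
            if self._dag is ATTRIBUTE_REMOVED:
                self._dag = dag  # restoring after deserialization skips validation
                return
            if dag is ATTRIBUTE_REMOVED:
                self._dag = ATTRIBUTE_REMOVED  # stripping the attribute skips validation too
                return
            self._dag = dag  # normal path: the isinstance and one-time checks go here

Note the identity (`is`) comparisons: the sentinel must be tested by identity, never probed via attribute access, since reading any attribute on it raises RuntimeError.
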
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 6b9afc01dd09..9d248f4db113 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -136,6 +136,7 @@
 from airflow.utils.task_group import MappedTaskGroup
 from airflow.utils.task_instance_session import set_current_task_instance_session
 from airflow.utils.timeout import timeout
+from airflow.utils.types import ATTRIBUTE_REMOVED
 from airflow.utils.xcom import XCOM_RETURN_KEY
 
 TR = TaskReschedule
@@ -902,13 +903,15 @@ def _clear_next_method_args(*, task_instance: TaskInstance | TaskInstancePydanti
 def _get_template_context(
     *,
     task_instance: TaskInstance | TaskInstancePydantic,
+    dag: DAG,
     session: Session | None = None,
     ignore_param_exceptions: bool = True,
 ) -> Context:
     """
     Return TI Context.
 
-    :param task_instance: the task instance
+    :param task_instance: the task instance to build the context for
+    :param dag: the DAG for the task
     :param session: SQLAlchemy ORM Session
     :param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
@@ -928,27 +931,10 @@ def _get_template_context(
     assert task_instance.task
     assert task
     assert task.dag
-    try:
-        dag: DAG = task.dag
-    except AirflowException:
-        from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
-
-        if isinstance(task_instance, TaskInstancePydantic):
-            ti = session.scalar(
-                select(TaskInstance).where(
-                    TaskInstance.task_id == task_instance.task_id,
-                    TaskInstance.dag_id == task_instance.dag_id,
-                    TaskInstance.run_id == task_instance.run_id,
-                    TaskInstance.map_index == task_instance.map_index,
-                )
-            )
-            dag = ti.dag_model.serialized_dag.dag
-            if hasattr(task_instance.task, "_dag"):  # BaseOperator
-                task_instance.task._dag = dag
-            else:  # MappedOperator
-                task_instance.task.dag = dag
-        else:
-            raise
+    if task.dag is ATTRIBUTE_REMOVED:
+        task.dag = dag  # required after deserialization
+
     dag_run = task_instance.get_dagrun(session)
     data_interval = dag.get_run_data_interval(dag_run)
@@ -1278,12 +1264,8 @@ def _record_task_map_for_downstreams(
 
     :meta private:
     """
-    # when taking task over RPC, we need to add the dag back
-    if isinstance(task, MappedOperator):
-        if not task.dag:
-            task.dag = dag
-    elif not task._dag:
-        task._dag = dag
+    if task.dag is ATTRIBUTE_REMOVED:
+        task.dag = dag  # required after deserialization
     if next(task.iter_mapped_dependants(), None) is None:  # No mapped dependants, no need to validate.
         return
@@ -3313,8 +3295,12 @@ def get_template_context(
         :param session: SQLAlchemy ORM Session
         :param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
         """
+        if TYPE_CHECKING:
+            assert self.task
+            assert self.task.dag
         return _get_template_context(
             task_instance=self,
+            dag=self.task.dag,
             session=session,
             ignore_param_exceptions=ignore_param_exceptions,
         )
@@ -3374,8 +3360,15 @@ def render_templates(
         context = self.get_template_context()
         original_task = self.task
 
+        ti = context["ti"]
+
         if TYPE_CHECKING:
             assert original_task
+            assert self.task
+            assert ti.task
+
+        if ti.task.dag is ATTRIBUTE_REMOVED:
+            ti.task.dag = self.task.dag
 
         # If self.task is mapped, this call replaces self.task to point to the
         # unmapped BaseOperator created by this function! This is because the
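
Where the old code caught an AirflowException and re-fetched the serialized DAG from the database, callers now pass the DAG in explicitly and it is reattached only when the sentinel is present. Truthiness is not a safe probe here: the sentinel is an ordinary object and therefore truthy, so the old `if not task.dag:` pattern would silently skip reattachment. A hedged sketch of the contract (the helper name is hypothetical):

    from airflow.utils.types import ATTRIBUTE_REMOVED

    def ensure_dag_attached(task, dag):
        # Identity, not truthiness: the sentinel is truthy, and reading any
        # attribute on it raises RuntimeError.
        if task.dag is ATTRIBUTE_REMOVED:
            task.dag = dag  # required after deserialization
        return task.dag
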
diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py
index e499a9869194..829594dc7084 100644
--- a/airflow/serialization/pydantic/taskinstance.py
+++ b/airflow/serialization/pydantic/taskinstance.py
@@ -288,8 +288,12 @@ def get_template_context(
         """
         from airflow.models.taskinstance import _get_template_context
 
+        if TYPE_CHECKING:
+            assert self.task
+            assert self.task.dag
         return _get_template_context(
             task_instance=self,
+            dag=self.task.dag,
             session=session,
             ignore_param_exceptions=ignore_param_exceptions,
         )
@@ -518,6 +522,20 @@ def _handle_reschedule(
         )
         _set_ti_attrs(self, updated_ti)  # _handle_reschedule is a remote call that mutates the TI
 
+    def get_relevant_upstream_map_indexes(
+        self,
+        upstream: Operator,
+        ti_count: int | None,
+        *,
+        session: Session | None = None,
+    ) -> int | range | None:
+        return TaskInstance.get_relevant_upstream_map_indexes(
+            self=self,  # type: ignore[arg-type]
+            upstream=upstream,
+            ti_count=ti_count,
+            session=session,
+        )
+
 
 if is_pydantic_2_installed():
     TaskInstancePydantic.model_rebuild()
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index bb7aca84d05b..80dfd98c8254 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -74,7 +74,7 @@
 from airflow.utils.operator_resources import Resources
 from airflow.utils.task_group import MappedTaskGroup, TaskGroup
 from airflow.utils.timezone import from_timestamp, parse_timezone
-from airflow.utils.types import NOTSET, ArgNotSet
+from airflow.utils.types import ATTRIBUTE_REMOVED, NOTSET, ArgNotSet
 
 if TYPE_CHECKING:
     from inspect import Parameter
@@ -1297,7 +1297,7 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator:
             )
         else:
             op = SerializedBaseOperator(task_id=encoded_op["task_id"])
-
+        op.dag = ATTRIBUTE_REMOVED  # type: ignore[assignment]
         cls.populate_operator(op, encoded_op)
         return op
diff --git a/airflow/utils/types.py b/airflow/utils/types.py
index 7467c9c3dd03..e4c5c511acb9 100644
--- a/airflow/utils/types.py
+++ b/airflow/utils/types.py
@@ -46,6 +46,25 @@ def is_arg_passed(arg: Union[ArgNotSet, None] = NOTSET) -> bool:
     """Sentinel value for argument default. See ``ArgNotSet``."""
 
 
+class AttributeRemoved:
+    """
+    Sentinel type signaling that an attribute was removed during serialization.
+
+    :meta private:
+    """
+
+    def __getattr__(self, item):
+        raise RuntimeError("Attribute was removed on serialization and must be set again.")
+
+
+ATTRIBUTE_REMOVED = AttributeRemoved()
+"""
+Sentinel value for attributes removed during serialization.
+
+:meta private:
+"""
+
+
 class DagRunType(str, enum.Enum):
     """Class with DagRun types."""
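
deserialize_operator now stamps every deserialized operator with the sentinel, and the sentinel itself is deliberately loud. A quick demonstration of its behavior, assuming the types.py change above is in place:

    from airflow.utils.types import ATTRIBUTE_REMOVED

    # Identity checks are safe and are the intended way to test for removal.
    assert ATTRIBUTE_REMOVED is ATTRIBUTE_REMOVED

    # Any attribute access fails loudly instead of returning a bogus value.
    try:
        ATTRIBUTE_REMOVED.dag_id
    except RuntimeError as err:
        print(err)  # Attribute was removed on serialization and must be set again.

This trades the old silent fallback for an immediate, actionable error at the point of misuse.
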
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 2e39914cd553..5a8d6139a2fc 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -3568,6 +3568,7 @@ def test_operator_field_with_serialization(self, create_task_instance):
         deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op)
         assert deserialized_op.task_type == "EmptyOperator"
         # Verify that ti.operator field renders correctly "with" Serialization
+        deserialized_op.dag = ti.task.dag
         ser_ti = TI(task=deserialized_op, run_id=None)
         assert ser_ti.operator == "EmptyOperator"
         assert ser_ti.task.operator_name == "EmptyOperator"
diff --git a/tests/providers/postgres/operators/test_postgres.py b/tests/providers/postgres/operators/test_postgres.py
index 54831882feab..0bc34a519549 100644
--- a/tests/providers/postgres/operators/test_postgres.py
+++ b/tests/providers/postgres/operators/test_postgres.py
@@ -221,7 +221,10 @@ def test_parameters_are_templatized(create_task_instance_of_operator):
         task_id="test-task",
     )
     task: SQLExecuteQueryOperator = ti.render_templates(
-        {"param": {"conn_id": "pg", "table": "foo", "bar": "egg"}}
+        {
+            "param": {"conn_id": "pg", "table": "foo", "bar": "egg"},
+            "ti": ti,
+        }
     )
     assert task.conn_id == "pg"
     assert task.sql == "SELECT * FROM foo WHERE spam = %(spam)s;"
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index c0650add5944..b2929e8c27b5 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -2509,21 +2509,31 @@ def test_operator_expand_deserialized_unmap():
     ser_mapped = BaseSerialization.serialize(mapped)
     deser_mapped = BaseSerialization.deserialize(ser_mapped)
+    deser_mapped.dag = None
+
     ser_normal = BaseSerialization.serialize(normal)
     deser_normal = BaseSerialization.deserialize(ser_normal)
+    deser_normal.dag = None
     assert deser_mapped.unmap(None) == deser_normal
 
 
 @pytest.mark.db_test
 def test_sensor_expand_deserialized_unmap():
     """Unmap a deserialized mapped sensor should be similar to deserializing a non-mapped sensor"""
-    normal = BashSensor(task_id="a", bash_command=[1, 2], mode="reschedule")
-    mapped = BashSensor.partial(task_id="a", mode="reschedule").expand(bash_command=[1, 2])
-
-    serialize = SerializedBaseOperator.serialize
-
-    deserialize = SerializedBaseOperator.deserialize
-    assert deserialize(serialize(mapped)).unmap(None) == deserialize(serialize(normal))
+    dag = DAG(dag_id="hello", start_date=None)
+    with dag:
+        normal = BashSensor(task_id="a", bash_command=[1, 2], mode="reschedule")
+        mapped = BashSensor.partial(task_id="b", mode="reschedule").expand(bash_command=[1, 2])
+    ser_mapped = SerializedBaseOperator.serialize(mapped)
+    deser_mapped = SerializedBaseOperator.deserialize(ser_mapped)
+    deser_mapped.dag = dag
+    deser_unmapped = deser_mapped.unmap(None)
+    ser_normal = SerializedBaseOperator.serialize(normal)
+    deser_normal = SerializedBaseOperator.deserialize(ser_normal)
+    deser_normal.dag = dag
+    comps = set(BashSensor._comps)
+    comps.remove("task_id")
+    assert all(getattr(deser_unmapped, c, None) == getattr(deser_normal, c, None) for c in comps)
 
 
 def test_task_resources_serde():
@@ -2625,6 +2635,10 @@ def x(arg1, arg2, arg3):
         "retry_delay": timedelta(seconds=30),
     }
 
+    # this dag is not picklable in this context, so we have to simply
+    # set it to None
+    deserialized.dag = None
+
     # Ensure the serialized operator can also be correctly pickled, to ensure
     # correct interaction between DAG pickling and serialization. This is done
     # here so we don't need to duplicate tests between pickled and non-pickled
@@ -2721,6 +2735,10 @@ def x(arg1, arg2, arg3):
         "retry_delay": timedelta(seconds=30),
     }
 
+    # this dag is not picklable in this context, so we have to simply
+    # set it to None
+    deserialized.dag = None
+
     # Ensure the serialized operator can also be correctly pickled, to ensure
     # correct interaction between DAG pickling and serialization. This is done
     # here so we don't need to duplicate tests between pickled and non-pickled
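
With the DAG now attached to deserialized operators, whole-object equality drags in dag and task_id differences, so the sensor test above compares field by field through `_comps`. A sketch of that idiom; note that `_comps` is a private BaseOperator attribute (the set of field names used for operator equality), so this leans on Airflow internals as they stand in this PR, and `comps_equal` is a hypothetical helper:

    def comps_equal(a, b, skip=("task_id",)):
        comps = set(type(a)._comps) - set(skip)
        return all(getattr(a, c, None) == getattr(b, c, None) for c in comps)

    # e.g. comps_equal(deser_unmapped, deser_normal) replaces a direct == between them.
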
diff --git a/tests/serialization/test_pydantic_models.py b/tests/serialization/test_pydantic_models.py
index dae611e68bd2..e3aacc829873 100644
--- a/tests/serialization/test_pydantic_models.py
+++ b/tests/serialization/test_pydantic_models.py
@@ -27,7 +27,7 @@
 from airflow.jobs.job import Job
 from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
 from airflow.models import MappedOperator
-from airflow.models.dag import DagModel
+from airflow.models.dag import DAG, DagModel
 from airflow.models.dataset import (
     DagScheduleDatasetReference,
     DatasetEvent,
@@ -43,7 +43,7 @@
 from airflow.settings import _ENABLE_AIP_44
 from airflow.utils import timezone
 from airflow.utils.state import State
-from airflow.utils.types import DagRunType
+from airflow.utils.types import ATTRIBUTE_REMOVED, DagRunType
 from tests.models import DEFAULT_DATE
 
 pytestmark = pytest.mark.db_test
@@ -89,7 +89,7 @@ def test_deserialize_ti_mapped_op_reserialized_with_refresh_from_task(session, d
         "task_id": "target",
     }
 
-    with dag_maker():
+    with dag_maker() as dag:
 
         @task
         def source():
@@ -117,7 +117,7 @@ def target(val=None):
     # roundtrip ti
     sered = BaseSerialization.serialize(ti, use_pydantic_models=True)
     desered = BaseSerialization.deserialize(sered, use_pydantic_models=True)
-
+    assert desered.task.dag is ATTRIBUTE_REMOVED
     assert "operator_class" not in sered["__var"]["task"]
 
     assert desered.task.__class__ == MappedOperator
@@ -130,9 +130,22 @@ def target(val=None):
 
     assert isinstance(desered.task.operator_class, dict)
 
-    resered = BaseSerialization.serialize(desered, use_pydantic_models=True)
-    deresered = BaseSerialization.deserialize(resered, use_pydantic_models=True)
-    assert deresered.task.operator_class == desered.task.operator_class == op_class_dict_expected
+    # let's check that we can safely add the dag back...
+    assert isinstance(dag, DAG)
+    # the dag already has this task
+    assert dag.has_task(desered.task.task_id) is True
+    # but the task has no dag
+    assert desered.task.dag is ATTRIBUTE_REMOVED
+    # and the task has no upstreams / downstreams, because those are wiped out
+    # on serialization; not great, but that is the current behavior
+    assert desered.task.upstream_task_ids == set()
+    assert desered.task.downstream_task_ids == set()
+    # add the dag back
+    desered.task.dag = dag
+    # no error is raised
+    # but the upstream / downstream ids are still empty
+    assert desered.task.upstream_task_ids == set()
+    assert desered.task.downstream_task_ids == set()
 
 
 @pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
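
The test above pins down the round-trip contract: the DAG is stripped on the wire and the caller is responsible for reattaching it. A hedged sketch of that contract as a helper (`roundtrip_ti` is a hypothetical name; assumes AIP-44 pydantic serialization is enabled):

    from airflow.serialization.serialized_objects import BaseSerialization
    from airflow.utils.types import ATTRIBUTE_REMOVED

    def roundtrip_ti(ti, dag):
        sered = BaseSerialization.serialize(ti, use_pydantic_models=True)
        desered = BaseSerialization.deserialize(sered, use_pydantic_models=True)
        assert desered.task.dag is ATTRIBUTE_REMOVED  # dag was stripped on serialization
        desered.task.dag = dag  # reattach; upstream/downstream ids remain empty
        return desered
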
diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
index 9e14cb35fe63..3272c9fa5081 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -339,6 +339,11 @@ def test_serialize_deserialize_pydantic(input, pydantic_class, encoded_type, cmp
     reserialized = BaseSerialization.serialize(deserialized, use_pydantic_models=True)
     dereserialized = BaseSerialization.deserialize(reserialized, use_pydantic_models=True)
     assert isinstance(dereserialized, pydantic_class)
+
+    if encoded_type == "task_instance":
+        deserialized.task.dag = None
+        dereserialized.task.dag = None
+
     assert dereserialized == deserialized
 
     # Verify recursive behavior
@@ -394,6 +399,10 @@ def test_all_pydantic_models_round_trip():
         serialized = BaseSerialization.serialize(pydantic_instance, use_pydantic_models=True)
         deserialized = BaseSerialization.deserialize(serialized, use_pydantic_models=True)
         assert isinstance(deserialized, c)
+        if isinstance(pydantic_instance, TaskInstancePydantic):
+            # the dag is stripped on serialization and there is none here anyway; normalize both sides
+            deserialized.task.dag = None
+            pydantic_instance.task.dag = None
         assert pydantic_instance == deserialized
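
Taken together, the tests encode one lifecycle: serialize, deserialize (the dag becomes the sentinel), reattach, then compare or execute. A condensed end-to-end sketch under the same assumptions as the tests above (EmptyOperator stands in for any task; this is an illustration, not part of the PR):

    from airflow.models.dag import DAG
    from airflow.operators.empty import EmptyOperator
    from airflow.serialization.serialized_objects import SerializedBaseOperator
    from airflow.utils.types import ATTRIBUTE_REMOVED

    with DAG(dag_id="demo") as dag:
        op = EmptyOperator(task_id="t")

    encoded = SerializedBaseOperator.serialize_operator(op)
    decoded = SerializedBaseOperator.deserialize_operator(encoded)

    assert decoded.dag is ATTRIBUTE_REMOVED  # stamped by deserialize_operator
    decoded.dag = dag                        # caller must reattach before use
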