Refactor SerializedDagModel and DagCode for dag versioning (apache#43821)

* Refactor SerializedDagModel and DagCode for dag versioning

Now that we have dag versioning, the SerializedDagModel and
DagCode objects should no longer be deleted. Deletion should
start with the DagModel, which will cascade to the DagVersion,
then to the DagCode and SerializedDagModel.

Also, these models are no longer updated in place. Instead, a new
row is added for each change; hence, the last_updated column is replaced with created_at.
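As a concrete illustration of the new lifecycle, here is a minimal sketch of how such a cascade chain is wired in SQLAlchemy. These are simplified stand-in models, not the actual Airflow definitions; column types and names are reduced for illustration.

# Sketch: each child table carries an ondelete="CASCADE" foreign key, so
# deleting a DagModel row removes its DagVersion rows, and each DagVersion
# removes its dependent DagCode (and, likewise, SerializedDagModel) row.
from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class DagModel(Base):
    __tablename__ = "dag"
    dag_id = Column(String(250), primary_key=True)
    fileloc = Column(String(2000))

class DagVersion(Base):
    __tablename__ = "dag_version"
    id = Column(String(36), primary_key=True)
    dag_id = Column(String(250), ForeignKey("dag.dag_id", ondelete="CASCADE"), nullable=False)
    version_number = Column(Integer, nullable=False, default=1)

class DagCode(Base):
    __tablename__ = "dag_code"
    id = Column(String(36), primary_key=True)
    dag_version_id = Column(
        String(36), ForeignKey("dag_version.id", ondelete="CASCADE"), nullable=False, unique=True
    )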

* fixup! Refactor SerializedDagModel and DagCode for dag versioning

* fixup! fixup! Refactor SerializedDagModel and DagCode for dag versioning

* Update RPC test since remove_deleted_dags has been removed

* Apply suggestions from code review
ephraimbuddy authored Nov 16, 2024
1 parent fa2e4e9 commit 2db0d11
Showing 13 changed files with 120 additions and 222 deletions.
6 changes: 0 additions & 6 deletions airflow/api/common/delete_dag.py
@@ -28,7 +28,6 @@
from airflow.exceptions import AirflowException, DagNotFound
from airflow.models import DagModel
from airflow.models.errors import ParseImportError
from airflow.models.serialized_dag import SerializedDagModel
from airflow.utils.db import get_sqla_model_classes
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import TaskInstanceState
@@ -64,11 +63,6 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session: Session =
if dag is None:
raise DagNotFound(f"Dag id {dag_id} not found")

# Scheduler removes DAGs without files from serialized_dag table every dag_dir_list_interval.
# There may be a lag, so explicitly removes serialized DAG here.
if SerializedDagModel.has_dag(dag_id=dag_id, session=session):
SerializedDagModel.remove_dag(dag_id=dag_id, session=session)

count = 0

for model in get_sqla_model_classes():
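With the cascade in place, the explicit serialized-DAG cleanup removed above is no longer needed: deleting the DagModel row is enough. A hedged sketch of what the deletion path reduces to, reusing the stand-in models from the earlier sketch (delete_dag_sketch is an illustrative name; the real delete_dag also clears other DAG-related tables):

def delete_dag_sketch(session, dag_id: str) -> None:
    # Deleting the DagModel row cascades at the database level:
    # DagModel -> DagVersion -> DagCode / SerializedDagModel.
    dag = session.get(DagModel, dag_id)
    if dag is None:
        raise ValueError(f"Dag id {dag_id} not found")
    session.delete(dag)
    session.commit()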
3 changes: 0 additions & 3 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -37,7 +37,6 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.jobs.job import Job, most_recent_job
from airflow.models.dagcode import DagCode
from airflow.models.taskinstance import _record_task_map_for_downstreams
from airflow.models.xcom_arg import _get_task_map_length
from airflow.sensors.base import _orig_start_date
@@ -94,7 +93,6 @@ def initialize_method_map() -> dict[str, Callable]:
_xcom_pull,
_record_task_map_for_downstreams,
trigger_dag,
DagCode.remove_deleted_code,
DagModel.deactivate_deleted_dags,
DagModel.get_paused_dag_ids,
DagModel.get_current,
@@ -138,7 +136,6 @@ def initialize_method_map() -> dict[str, Callable]:
DagRun._get_log_template,
RenderedTaskInstanceFields._update_runtime_evaluated_template_fields,
SerializedDagModel.get_serialized_dag,
SerializedDagModel.remove_deleted_dags,
SkipMixin._skip,
SkipMixin._skip_all_except,
TaskInstance._check_and_change_state_before_execution,
6 changes: 1 addition & 5 deletions airflow/dag_processing/manager.py
@@ -507,11 +507,7 @@ def deactivate_stale_dags(
stale_dag_threshold: int,
session: Session = NEW_SESSION,
):
"""
Detect DAGs which are no longer present in files.
Deactivate them and remove them in the serialized_dag table.
"""
"""Detect and deactivate DAGs which are no longer present in files."""
to_deactivate = set()
query = select(DagModel.dag_id, DagModel.fileloc, DagModel.last_parsed_time).where(DagModel.is_active)
standalone_dag_processor = conf.getboolean("scheduler", "standalone_dag_processor")
8 changes: 8 additions & 0 deletions airflow/migrations/versions/0047_3_0_0_add_dag_versioning.py
@@ -81,6 +81,8 @@ def upgrade():
ondelete="CASCADE",
)
batch_op.create_unique_constraint("dag_code_dag_version_id_uq", ["dag_version_id"])
batch_op.drop_column("last_updated")
batch_op.add_column(sa.Column("created_at", UtcDateTime(), nullable=False, default=timezone.utcnow))

with op.batch_alter_table(
"serialized_dag", recreate="always", naming_convention=naming_convention
@@ -100,6 +102,8 @@
ondelete="CASCADE",
)
batch_op.create_unique_constraint("serialized_dag_dag_version_id_uq", ["dag_version_id"])
batch_op.drop_column("last_updated")
batch_op.add_column(sa.Column("created_at", UtcDateTime(), nullable=False, default=timezone.utcnow))

with op.batch_alter_table("task_instance", schema=None) as batch_op:
batch_op.add_column(sa.Column("dag_version_id", UUIDType(binary=False)))
@@ -140,6 +144,8 @@ def downgrade():
batch_op.drop_constraint(batch_op.f("dag_code_dag_version_id_fkey"), type_="foreignkey")
batch_op.drop_column("dag_version_id")
batch_op.create_primary_key("dag_code_pkey", ["fileloc_hash"])
batch_op.drop_column("created_at")
batch_op.add_column(sa.Column("last_updated", UtcDateTime(), nullable=False))

with op.batch_alter_table("serialized_dag", schema=None, naming_convention=naming_convention) as batch_op:
batch_op.drop_column("id")
@@ -149,6 +155,8 @@
batch_op.create_primary_key("serialized_dag_pkey", ["dag_id"])
batch_op.drop_constraint(batch_op.f("serialized_dag_dag_version_id_fkey"), type_="foreignkey")
batch_op.drop_column("dag_version_id")
batch_op.drop_column("created_at")
batch_op.add_column(sa.Column("last_updated", UtcDateTime(), nullable=False))

with op.batch_alter_table("dag_run", schema=None) as batch_op:
batch_op.add_column(sa.Column("dag_hash", sa.String(length=32), autoincrement=False, nullable=True))
2 changes: 1 addition & 1 deletion airflow/models/dag_version.py
@@ -62,7 +62,7 @@ class DagVersion(Base):
)
dag_runs = relationship("DagRun", back_populates="dag_version", cascade="all, delete, delete-orphan")
task_instances = relationship("TaskInstance", back_populates="dag_version")
created_at = Column(UtcDateTime, default=timezone.utcnow)
created_at = Column(UtcDateTime, nullable=False, default=timezone.utcnow)

__table_args__ = (
UniqueConstraint("dag_id", "version_number", name="dag_id_v_name_v_number_unique_constraint"),
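Because these rows are append-only, "updating" a DAG now means inserting a new DagVersion row with the next version_number, kept unique per dag_id by the constraint above. A sketch of that pattern with the stand-in DagVersion model from earlier; add_new_version_sketch is an illustrative name and the numbering logic is an assumption (the real logic lives in airflow.models.dag_version):

import uuid

from sqlalchemy import func, select

def add_new_version_sketch(session, dag_id: str) -> DagVersion:
    # Find the highest existing version number for this dag_id, then
    # insert a fresh row rather than mutating an existing one.
    latest = session.scalar(
        select(func.max(DagVersion.version_number)).where(DagVersion.dag_id == dag_id)
    )
    version = DagVersion(id=str(uuid.uuid4()), dag_id=dag_id, version_number=(latest or 0) + 1)
    session.add(version)
    return version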
41 changes: 5 additions & 36 deletions airflow/models/dagcode.py
@@ -18,16 +18,15 @@

import logging
import struct
from typing import TYPE_CHECKING, Collection
from typing import TYPE_CHECKING

import uuid6
from sqlalchemy import BigInteger, Column, ForeignKey, String, Text, delete, select
from sqlalchemy import BigInteger, Column, ForeignKey, String, Text, select
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.orm import relationship
from sqlalchemy.sql.expression import literal
from sqlalchemy_utils import UUIDType

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
from airflow.exceptions import DagCodeNotFound
from airflow.models.base import Base
@@ -58,7 +57,7 @@ class DagCode(Base):
fileloc_hash = Column(BigInteger, nullable=False)
fileloc = Column(String(2000), nullable=False)
# The max length of fileloc exceeds the limit of indexing.
last_updated = Column(UtcDateTime, nullable=False)
created_at = Column(UtcDateTime, nullable=False, default=timezone.utcnow)
source_code = Column(Text().with_variant(MEDIUMTEXT(), "mysql"), nullable=False)
dag_version_id = Column(
UUIDType(binary=False), ForeignKey("dag_version.id", ondelete="CASCADE"), nullable=False, unique=True
@@ -74,7 +73,7 @@ def __init__(self, dag_version, full_filepath: str, source_code: str | None = No

@classmethod
@provide_session
def write_dag(cls, dag_version: DagVersion, fileloc: str, session: Session = NEW_SESSION) -> DagCode:
def write_code(cls, dag_version: DagVersion, fileloc: str, session: Session = NEW_SESSION) -> DagCode:
"""
Write code into database.
@@ -87,36 +86,6 @@ def write_dag(cls, dag_version: DagVersion, fileloc: str, session: Session = NEW
log.debug("DAG file %s written into DagCode table", fileloc)
return dag_code

@classmethod
@internal_api_call
@provide_session
def remove_deleted_code(
cls,
alive_dag_filelocs: Collection[str],
processor_subdir: str,
session: Session = NEW_SESSION,
) -> None:
"""
Delete code not included in alive_dag_filelocs.
:param alive_dag_filelocs: file paths of alive DAGs
:param processor_subdir: dag processor subdir
:param session: ORM Session
"""
alive_fileloc_hashes = [cls.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs]

log.debug("Deleting code from %s table ", cls.__tablename__)

session.execute(
delete(cls)
.where(
cls.fileloc_hash.notin_(alive_fileloc_hashes),
cls.fileloc.notin_(alive_dag_filelocs),
cls.fileloc.contains(processor_subdir),
)
.execution_options(synchronize_session="fetch")
)

@classmethod
@provide_session
def has_dag(cls, fileloc: str, session: Session = NEW_SESSION) -> bool:
@@ -172,7 +141,7 @@ def _get_code_from_db(cls, fileloc, session: Session = NEW_SESSION) -> str:
dag_code = session.scalar(
select(cls)
.where(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc))
.order_by(cls.last_updated.desc())
.order_by(cls.created_at.desc())
.limit(1)
)
if not dag_code:
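Since DagCode rows are immutable now, "the current source" for a file is simply the newest row, which is exactly what the reworked _get_code_from_db does. A usage-style sketch of the same read pattern against the real model; latest_source is an illustrative name and an active ORM session is assumed:

from sqlalchemy import select

from airflow.models.dagcode import DagCode

def latest_source(session, fileloc: str) -> str | None:
    # The newest row for this file wins; older rows are historical versions.
    row = session.scalar(
        select(DagCode)
        .where(DagCode.fileloc_hash == DagCode.dag_fileloc_hash(fileloc))
        .order_by(DagCode.created_at.desc())
        .limit(1)
    )
    return row.source_code if row else None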
81 changes: 19 additions & 62 deletions airflow/models/serialized_dag.py
@@ -22,11 +22,11 @@
import logging
import zlib
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Collection
from typing import TYPE_CHECKING, Any

import sqlalchemy_jsonfield
import uuid6
from sqlalchemy import Column, ForeignKey, LargeBinary, String, exc, or_, select
from sqlalchemy import Column, ForeignKey, LargeBinary, String, exc, select
from sqlalchemy.orm import backref, foreign, relationship
from sqlalchemy.sql.expression import func, literal
from sqlalchemy_utils import UUIDType
@@ -83,7 +83,7 @@ class SerializedDagModel(Base):
dag_id = Column(String(ID_LEN), nullable=False)
_data = Column("data", sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
_data_compressed = Column("data_compressed", LargeBinary, nullable=True)
last_updated = Column(UtcDateTime, nullable=False)
created_at = Column(UtcDateTime, nullable=False, default=timezone.utcnow)
dag_hash = Column(String(32), nullable=False)
processor_subdir = Column(String(2000), nullable=True)

@@ -110,7 +110,6 @@ class SerializedDagModel(Base):

def __init__(self, dag: DAG, processor_subdir: str | None = None) -> None:
self.dag_id = dag.dag_id
self.last_updated = timezone.utcnow()
self.processor_subdir = processor_subdir

dag_data = SerializedDAG.to_dict(dag)
@@ -186,7 +185,7 @@ def write_dag(
if session.scalar(
select(literal(True)).where(
cls.dag_id == dag.dag_id,
(timezone.utcnow() - timedelta(seconds=min_update_interval)) < cls.last_updated,
(timezone.utcnow() - timedelta(seconds=min_update_interval)) < cls.created_at,
)
):
return False
@@ -196,7 +195,7 @@
serialized_dag_db = session.execute(
select(cls.dag_hash, cls.processor_subdir)
.where(cls.dag_id == dag.dag_id)
.order_by(cls.last_updated.desc())
.order_by(cls.created_at.desc())
).first()

if (
Expand All @@ -215,13 +214,12 @@ def write_dag(
new_serialized_dag.dag_version = dagv
session.add(new_serialized_dag)
log.debug("DAG: %s written to the DB", dag.dag_id)

DagCode.write_dag(dagv, dag.fileloc, session=session)
DagCode.write_code(dagv, dag.fileloc, session=session)
return True

@classmethod
def latest_item_select_object(cls, dag_id):
return select(cls).where(cls.dag_id == dag_id).order_by(cls.last_updated.desc()).limit(1)
return select(cls).where(cls.dag_id == dag_id).order_by(cls.created_at.desc()).limit(1)

@classmethod
@provide_session
@@ -237,7 +235,7 @@ def get_latest_serialized_dags(
"""
# Subquery to get the latest serdag per dag_id
latest_serdag_subquery = (
session.query(cls.dag_id, func.max(cls.last_updated).label("last_updated"))
session.query(cls.dag_id, func.max(cls.created_at).label("created_at"))
.filter(cls.dag_id.in_(dag_ids))
.group_by(cls.dag_id)
.subquery()
@@ -246,7 +244,7 @@
select(cls)
.join(
latest_serdag_subquery,
cls.last_updated == latest_serdag_subquery.c.last_updated,
cls.created_at == latest_serdag_subquery.c.created_at,
)
.where(cls.dag_id.in_(dag_ids))
).all()
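The query above is the classic latest-row-per-group pattern: a subquery selects max(created_at) per dag_id, and the outer select joins back against it. A standalone sketch of the pattern; note that the join clauses are combined with & (SQL AND), since a plain Python `and` between two SQLAlchemy expressions can silently reduce the condition to a single clause:

from sqlalchemy import func, select

from airflow.models.serialized_dag import SerializedDagModel as SDM

latest = (
    select(SDM.dag_id, func.max(SDM.created_at).label("created_at"))
    .group_by(SDM.dag_id)
    .subquery()
)
# One row per dag_id: the serialized DAG with the newest created_at.
latest_serdags = select(SDM).join(
    latest,
    (SDM.dag_id == latest.c.dag_id) & (SDM.created_at == latest.c.created_at),
)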
@@ -262,15 +260,15 @@ def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[str, SerializedDA
:returns: a dict of DAGs read from database
"""
latest_serialized_dag_subquery = (
session.query(cls.dag_id, func.max(cls.last_updated).label("max_updated"))
session.query(cls.dag_id, func.max(cls.created_at).label("max_created"))
.group_by(cls.dag_id)
.subquery()
)
serialized_dags = session.scalars(
select(cls).join(
latest_serialized_dag_subquery,
(cls.dag_id == latest_serialized_dag_subquery.c.dag_id)
and (cls.last_updated == latest_serialized_dag_subquery.c.max_updated),
and (cls.created_at == latest_serialized_dag_subquery.c.max_created),
)
)

@@ -313,47 +311,6 @@ def dag(self) -> SerializedDAG:
raise ValueError("invalid or missing serialized DAG data")
return SerializedDAG.from_dict(data)

@classmethod
@provide_session
def remove_dag(cls, dag_id: str, session: Session = NEW_SESSION) -> None:
"""
Delete a DAG with given dag_id.
:param dag_id: dag_id to be deleted
:param session: ORM Session.
"""
session.execute(cls.__table__.delete().where(cls.dag_id == dag_id))

@classmethod
@internal_api_call
@provide_session
def remove_deleted_dags(
cls,
alive_dag_filelocs: Collection[str],
processor_subdir: str | None = None,
session: Session = NEW_SESSION,
) -> None:
"""
Delete DAGs not included in alive_dag_filelocs.
:param alive_dag_filelocs: file paths of alive DAGs
:param processor_subdir: dag processor subdir
:param session: ORM Session
"""
log.debug(
"Deleting Serialized DAGs (for which DAG files are deleted) from %s table ", cls.__tablename__
)
# Deleting the DagModel cascade deletes the serialized Dag through the dag version relationship
session.execute(
DagModel.__table__.delete().where(
DagModel.fileloc.notin_(alive_dag_filelocs),
or_(
DagModel.processor_subdir.is_(None),
DagModel.processor_subdir == processor_subdir,
),
)
)

@classmethod
@provide_session
def has_dag(cls, dag_id: str, session: Session = NEW_SESSION) -> bool:
@@ -418,7 +375,7 @@ def get_last_updated_datetime(cls, dag_id: str, session: Session = NEW_SESSION)
:param session: ORM Session
"""
return session.scalar(
select(cls.last_updated).where(cls.dag_id == dag_id).order_by(cls.last_updated.desc()).limit(1)
select(cls.created_at).where(cls.dag_id == dag_id).order_by(cls.created_at.desc()).limit(1)
)

@classmethod
@@ -429,7 +386,7 @@ def get_max_last_updated_datetime(cls, session: Session = NEW_SESSION) -> dateti
:param session: ORM Session
"""
return session.scalar(select(func.max(cls.last_updated)))
return session.scalar(select(func.max(cls.created_at)))

@classmethod
@provide_session
@@ -442,7 +399,7 @@ def get_latest_version_hash(cls, dag_id: str, session: Session = NEW_SESSION) ->
:return: DAG Hash, or None if the DAG is not found
"""
return session.scalar(
select(cls.dag_hash).where(cls.dag_id == dag_id).order_by(cls.last_updated.desc()).limit(1)
select(cls.dag_hash).where(cls.dag_id == dag_id).order_by(cls.created_at.desc()).limit(1)
)

@classmethod
@@ -461,9 +418,9 @@ def get_latest_version_hash_and_updated_datetime(
:return: A tuple of DAG Hash and last updated datetime, or None if the DAG is not found
"""
return session.execute(
select(cls.dag_hash, cls.last_updated)
select(cls.dag_hash, cls.created_at)
.where(cls.dag_id == dag_id)
.order_by(cls.last_updated.desc())
.order_by(cls.created_at.desc())
.limit(1)
).one_or_none()

@@ -476,7 +433,7 @@ def get_dag_dependencies(cls, session: Session = NEW_SESSION) -> dict[str, list[
:param session: ORM Session
"""
latest_sdag_subquery = (
session.query(cls.dag_id, func.max(cls.last_updated).label("max_updated"))
session.query(cls.dag_id, func.max(cls.created_at).label("max_created"))
.group_by(cls.dag_id)
.subquery()
)
@@ -485,7 +442,7 @@ def get_dag_dependencies(cls, session: Session = NEW_SESSION) -> dict[str, list[
select(cls.dag_id, func.json_extract(cls._data, "$.dag.dag_dependencies")).join(
latest_sdag_subquery,
(cls.dag_id == latest_sdag_subquery.c.dag_id)
and (cls.last_updated == latest_sdag_subquery.c.max_updated),
and (cls.created_at == latest_sdag_subquery.c.max_created),
)
)
iterator = ((dag_id, json.loads(deps_data) if deps_data else []) for dag_id, deps_data in query)
Expand All @@ -494,7 +451,7 @@ def get_dag_dependencies(cls, session: Session = NEW_SESSION) -> dict[str, list[
select(cls.dag_id, func.json_extract_path(cls._data, "dag", "dag_dependencies")).join(
latest_sdag_subquery,
(cls.dag_id == latest_sdag_subquery.c.dag_id)
and (cls.last_updated == latest_sdag_subquery.c.max_updated),
and (cls.created_at == latest_sdag_subquery.c.max_created),
)
)
return {dag_id: [DagDependency(**d) for d in (deps_data or [])] for dag_id, deps_data in iterator}
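With remove_dag and remove_deleted_dags gone, cleanup of DAGs whose files have disappeared starts at DagModel (DagModel.deactivate_deleted_dags stays registered in the RPC map above), and hard deletion rides the foreign-key cascade. A hedged sketch of that cascade-driven bulk delete, mirroring the notin_ filter from the removed helper; drop_dags_without_files is an illustrative name:

from collections.abc import Collection

from sqlalchemy import delete

from airflow.models.dag import DagModel

def drop_dags_without_files(session, alive_dag_filelocs: Collection[str]) -> None:
    # Deleting DagModel rows cascades through DagVersion to DagCode
    # and SerializedDagModel at the database level.
    session.execute(delete(DagModel).where(DagModel.fileloc.notin_(alive_dag_filelocs)))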
