From 6fc5dafea9593dce31f34002efc9ec46d15212aa Mon Sep 17 00:00:00 2001 From: tobymao Date: Fri, 25 Oct 2024 22:02:34 -0700 Subject: [PATCH] feat: add sqlmesh.lock and fix plans with only req changes --- docs/guides/configuration.md | 15 ++++++ examples/sushi/sqlmesh-requirements.lock | 1 + setup.py | 2 + sqlmesh/core/console.py | 11 +++++ sqlmesh/core/constants.py | 1 + sqlmesh/core/context.py | 21 ++++++++ sqlmesh/core/context_diff.py | 61 +++++++++++++++++++++++- sqlmesh/core/plan/definition.py | 43 +---------------- tests/core/test_plan.py | 19 ++++++-- 9 files changed, 127 insertions(+), 47 deletions(-) create mode 100644 examples/sushi/sqlmesh-requirements.lock diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index 8079d9e2f..76ba1a8f8 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -1055,3 +1055,18 @@ Example enabling debug mode for the CLI command `sqlmesh plan`: C:\> set SQLMESH_DEBUG=1 C:\> sqlmesh plan ``` + + +### Python library dependencies +SQLMesh enables you to write Python models and macros which depend on third-party libraries. To ensure each run / evaluation uses the same version, you can specify versions in a sqlmesh.lock file in the root of your project. + +The sqlmesh.lock must be of the format `dep==version`. Only `==` is supported. + +For example: + +``` +numpy==2.1.2 +pandas==2.2.3 +``` + +This feature is only available in [Tobiko Cloud](https://tobikodata.com/product.html). diff --git a/examples/sushi/sqlmesh-requirements.lock b/examples/sushi/sqlmesh-requirements.lock new file mode 100644 index 000000000..4b2e332e2 --- /dev/null +++ b/examples/sushi/sqlmesh-requirements.lock @@ -0,0 +1 @@ +pandas==2.2.2 diff --git a/setup.py b/setup.py index a142ccb2a..4ba8ed6ea 100644 --- a/setup.py +++ b/setup.py @@ -65,6 +65,7 @@ "clickhouse-connect", "cryptography~=42.0.4", "custom-materializations", + "databricks-sql-connector", "dbt-bigquery", "dbt-core", "dbt-duckdb>=1.7.1", @@ -78,6 +79,7 @@ "pre-commit", "psycopg2-binary", "pydantic<2.6.0", + "PyAthena[Pandas]", "PyGithub", "pyspark~=3.5.0", "pytest", diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 33d68c1d2..a1c16b145 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -626,6 +626,10 @@ def show_model_difference_summary( return self._print(Tree(f"[bold]Summary of differences against `{context_diff.environment}`:")) + + if context_diff.has_requirement_changes: + self._print(f"Requirements:\n{context_diff.requirements_diff()}") + self._show_summary_tree_for( context_diff, "Models", @@ -1563,6 +1567,9 @@ def show_model_difference_summary( self._print(f"**Summary of differences against `{context_diff.environment}`:**\n") + if context_diff.has_requirement_changes: + self._print(f"Requirements:\n{context_diff.requirements_diff()}") + added_snapshots = {context_diff.snapshots[s_id] for s_id in context_diff.added} added_snapshot_models = {s for s in added_snapshots if s.is_model} if added_snapshot_models: @@ -1999,6 +2006,10 @@ def show_model_difference_summary( no_diff: bool = True, ) -> None: self._write("Model Difference Summary:") + + if context_diff.has_requirement_changes: + self._write(f"Requirements:\n{context_diff.requirements_diff()}") + for added in context_diff.new_snapshots: self._write(f" Added: {added}") for removed in context_diff.removed_snapshots: diff --git a/sqlmesh/core/constants.py b/sqlmesh/core/constants.py index cbb8b9a3b..3fcb135f4 100644 --- a/sqlmesh/core/constants.py +++ b/sqlmesh/core/constants.py @@ -70,6 +70,7 @@ EXTERNAL_MODELS_YAML = "external_models.yaml" EXTERNAL_MODELS_DEPRECATED_YAML = "schema.yaml" +REQUIREMENTS = "sqlmesh-requirements.lock" DEFAULT_SCHEMA = "default" diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 9acd20fd0..0fc5423bf 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -327,6 +327,7 @@ def __init__( self._macros: UniqueKeyDict[str, ExecutableOrMacro] = UniqueKeyDict("macros") self._metrics: UniqueKeyDict[str, Metric] = UniqueKeyDict("metrics") self._jinja_macros = JinjaMacroRegistry() + self._requirements: t.Dict[str, str] = {} self._default_catalog: t.Optional[str] = None self._loaded: bool = False @@ -338,6 +339,7 @@ def __init__( loader=(loader or config.loader)(**config.loader_kwargs), configs={} ) self._loaders[project_type].configs[path] = config + self._load_requirements(path) self.project_type = c.HYBRID if len(self._loaders) > 1 else project_type self._all_dialects: t.Set[str] = {self.config.dialect or ""} @@ -2043,6 +2045,7 @@ def _context_diff( snapshots=snapshots or self.snapshots, create_from=create_from or c.PROD, state_reader=self.state_reader, + requirements=self._requirements, ensure_finalized_snapshots=ensure_finalized_snapshots, ) @@ -2106,6 +2109,24 @@ def _load_materializations_and_signals(self) -> None: context_loader.loader.load_signals(self) context_loader.loader.load_materializations(self) + def _load_requirements(self, path: Path) -> None: + path = path / c.REQUIREMENTS + if path.is_file(): + with open(path, "r", encoding="utf-8") as file: + for line in file: + args = [k.strip() for k in line.split("==")] + if len(args) != 2: + raise SQLMeshError( + f"Invalid lock file entry '{line.strip()}'. Only 'dep==ver' is supported" + ) + dep, ver = args + other_ver = self._requirements.get(dep, ver) + if ver != other_ver: + raise SQLMeshError( + "Conflicting requirement {dep}: {ver} != {other_ver}. Fix your sqlmesh.lock file." + ) + self._requirements[dep] = ver + class Context(GenericContext[Config]): CONFIG_TYPE = Config diff --git a/sqlmesh/core/context_diff.py b/sqlmesh/core/context_diff.py index 03b4cc5e2..7322a17ce 100644 --- a/sqlmesh/core/context_diff.py +++ b/sqlmesh/core/context_diff.py @@ -13,16 +13,27 @@ from __future__ import annotations import logging +import sys import typing as t +from difflib import ndiff from functools import cached_property from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotTableInfo from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.pydantic import PydanticModel + +if sys.version_info >= (3, 12): + from importlib import metadata +else: + import importlib_metadata as metadata # type: ignore + + if t.TYPE_CHECKING: from sqlmesh.core.state_sync import StateReader +IGNORED_PACKAGES = {"sqlmesh", "sqlglot"} + logger = logging.getLogger(__name__) @@ -59,6 +70,10 @@ class ContextDiff(PydanticModel): """Snapshot IDs that were promoted by the previous plan.""" previous_finalized_snapshots: t.Optional[t.List[SnapshotTableInfo]] """Snapshots from the previous finalized state.""" + previous_requirements: t.Dict[str, str] = {} + """Previous requirements.""" + provided_requirements: t.Dict[str, str] = {} + """Requirements from lock file.""" @classmethod def create( @@ -68,6 +83,7 @@ def create( create_from: str, state_reader: StateReader, ensure_finalized_snapshots: bool = False, + requirements: t.Optional[t.Dict[str, str]] = None, ) -> ContextDiff: """Create a ContextDiff object. @@ -80,6 +96,7 @@ def create( ensure_finalized_snapshots: Whether to compare against snapshots from the latest finalized environment state, or to use whatever snapshots are in the current environment state even if the environment is not finalized. + requirements: Fixed requirements to build the context diff with. Returns: The ContextDiff object. @@ -174,6 +191,8 @@ def create( previous_plan_id=env.plan_id if env and not is_new_environment else None, previously_promoted_snapshot_ids=previously_promoted_snapshot_ids, previous_finalized_snapshots=env.previous_finalized_snapshots if env else None, + previous_requirements=env.requirements if env else {}, + provided_requirements=requirements, ) @classmethod @@ -207,14 +226,23 @@ def create_no_diff(cls, environment: str, state_reader: StateReader) -> ContextD previous_plan_id=env.plan_id, previously_promoted_snapshot_ids={s.snapshot_id for s in env.promoted_snapshots}, previous_finalized_snapshots=env.previous_finalized_snapshots, + previous_requirements=env.requirements, + provided_requirements=env.requirements, ) @property def has_changes(self) -> bool: return ( - self.has_snapshot_changes or self.is_new_environment or self.is_unfinalized_environment + self.has_snapshot_changes + or self.is_new_environment + or self.is_unfinalized_environment + or self.has_requirement_changes ) + @property + def has_requirement_changes(self) -> bool: + return self.previous_requirements != self.requirements + @property def has_snapshot_changes(self) -> bool: return bool(self.added or self.removed_snapshots or self.modified_snapshots) @@ -251,6 +279,37 @@ def current_modified_snapshot_ids(self) -> t.Set[SnapshotId]: def snapshots_by_name(self) -> t.Dict[str, Snapshot]: return {x.name: x for x in self.snapshots.values()} + @cached_property + def requirements(self) -> t.Dict[str, str]: + requirements = self.provided_requirements.copy() + distributions = metadata.packages_distributions() + + for snapshot in self.snapshots.values(): + if snapshot.is_model: + for executable in snapshot.model.python_env.values(): + if executable.kind == "import": + try: + start = "from " if executable.payload.startswith("from ") else "import " + lib = executable.payload.split(start)[1].split()[0].split(".")[0] + if lib in distributions: + for dist in distributions[lib]: + if dist not in requirements and dist not in IGNORED_PACKAGES: + requirements[dist] = metadata.version(dist) + except metadata.PackageNotFoundError: + logger.warning("Failed to find package for %s", lib) + return requirements + + def requirements_diff(self) -> str: + return "\n".join( + ndiff( + [ + f"{k}=={self.previous_requirements[k]}" + for k in sorted(self.previous_requirements) + ], + [f"{k}=={self.requirements[k]}" for k in sorted(self.requirements)], + ) + ) + @property def environment_snapshots(self) -> t.List[SnapshotTableInfo]: """Returns current snapshots in the environment.""" diff --git a/sqlmesh/core/plan/definition.py b/sqlmesh/core/plan/definition.py index 81a15cf96..59c9814d7 100644 --- a/sqlmesh/core/plan/definition.py +++ b/sqlmesh/core/plan/definition.py @@ -1,7 +1,5 @@ from __future__ import annotations -import logging -import sys import typing as t from dataclasses import dataclass from datetime import datetime @@ -28,20 +26,9 @@ from sqlmesh.utils.date import TimeLike, now, to_datetime, to_timestamp from sqlmesh.utils.pydantic import PydanticModel -logger = logging.getLogger(__name__) - SnapshotMapping = t.Dict[SnapshotId, t.Set[SnapshotId]] -if sys.version_info >= (3, 12): - from importlib import metadata -else: - import importlib_metadata as metadata # type: ignore - - -IGNORED_PACKAGES = {"sqlmesh", "sqlglot"} - - class Plan(PydanticModel, frozen=True): context_diff: ContextDiff plan_id: str @@ -99,16 +86,7 @@ def requires_backfill(self) -> bool: @property def has_changes(self) -> bool: - modified_snapshot_ids = { - *self.context_diff.added, - *self.context_diff.removed_snapshots, - *self.context_diff.current_modified_snapshot_ids, - } - return ( - self.context_diff.is_new_environment - or self.context_diff.is_unfinalized_environment - or bool(modified_snapshot_ids) - ) + return self.context_diff.has_changes @property def has_unmodified_unpromoted(self) -> bool: @@ -217,23 +195,6 @@ def environment(self) -> Environment: else self.context_diff.previous_finalized_snapshots ) - requirements = {} - distributions = metadata.packages_distributions() - - for snapshot in self.context_diff.snapshots.values(): - if snapshot.is_model: - for executable in snapshot.model.python_env.values(): - if executable.kind == "import": - try: - start = "from " if executable.payload.startswith("from ") else "import " - lib = executable.payload.split(start)[1].split()[0].split(".")[0] - if lib in distributions: - for dist in distributions[lib]: - if dist not in requirements and dist not in IGNORED_PACKAGES: - requirements[dist] = metadata.version(dist) - except metadata.PackageNotFoundError: - logger.warning("Failed to find package for %s", lib) - return Environment( snapshots=snapshots, start_at=self.provided_start or self._earliest_interval_start, @@ -243,7 +204,7 @@ def environment(self) -> Environment: expiration_ts=expiration_ts, promoted_snapshot_ids=promoted_snapshot_ids, previous_finalized_snapshots=previous_finalized_snapshots, - requirements=requirements, + requirements=self.context_diff.requirements, **self.environment_naming_info.dict(), ) diff --git a/tests/core/test_plan.py b/tests/core/test_plan.py index 09869e162..d7bd87aa9 100644 --- a/tests/core/test_plan.py +++ b/tests/core/test_plan.py @@ -2551,7 +2551,7 @@ def test_interval_end_per_model(make_snapshot): assert plan_builder.build().interval_end_per_model is None -def test_plan_requirements(): +def test_plan_requirements(mocker): context = Context(paths="examples/sushi") model = context.get_model("sushi.items") model.python_env["ruamel"] = Executable(payload="import ruamel", kind="import") @@ -2559,10 +2559,19 @@ def test_plan_requirements(): payload="from ipywidgets.widgets.widget_media import Image", kind="import" ) - plan = context.plan( - "dev", no_prompts=True, skip_tests=True, skip_backfill=True - ).environment.requirements - assert set(plan) == {"ipywidgets", "numpy", "pandas", "ruamel.yaml", "ruamel.yaml.clib"} + environment = context.plan( + "dev", no_prompts=True, skip_tests=True, skip_backfill=True, auto_apply=True + ).environment + requirements = {"ipywidgets", "numpy", "pandas", "ruamel.yaml", "ruamel.yaml.clib"} + assert environment.requirements["pandas"] == "2.2.2" + assert set(environment.requirements) == requirements + + mocker.patch( + "sqlmesh.core.context_diff.ContextDiff.requirements", {"numpy": "2.1.2", "pandas": "2.2.1"} + ) + diff = context.plan("dev", no_prompts=True, skip_tests=True, skip_backfill=True).context_diff + assert set(diff.previous_requirements) == requirements + assert set(diff.requirements) == {"numpy", "pandas"} def test_unaligned_start_model_with_forward_only_preview(make_snapshot):