Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add sqlmesh.lock and fix plans with only req changes #3299

Merged
merged 2 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/guides/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -1055,3 +1055,18 @@ Example enabling debug mode for the CLI command `sqlmesh plan`:
C:\> set SQLMESH_DEBUG=1
C:\> sqlmesh plan
```


### Python library dependencies
SQLMesh enables you to write Python models and macros which depend on third-party libraries. To ensure each run / evaluation uses the same version, you can pin versions in a `sqlmesh-requirements.lock` file in the root of your project.

Each entry in `sqlmesh-requirements.lock` must be of the format `dep==version`. Only `==` is supported.

For example:

```
numpy==2.1.2
pandas==2.2.3
```

This feature is only available in [Tobiko Cloud](https://tobikodata.com/product.html).
1 change: 1 addition & 0 deletions examples/sushi/sqlmesh-requirements.lock
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pandas==2.2.2
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
"clickhouse-connect",
"cryptography~=42.0.4",
"custom-materializations",
"databricks-sql-connector",
"dbt-bigquery",
"dbt-core",
"dbt-duckdb>=1.7.1",
Expand All @@ -78,6 +79,7 @@
"pre-commit",
"psycopg2-binary",
"pydantic<2.6.0",
"PyAthena[Pandas]",
"PyGithub",
"pyspark~=3.5.0",
"pytest",
Expand Down
11 changes: 11 additions & 0 deletions sqlmesh/core/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,10 @@ def show_model_difference_summary(
return

self._print(Tree(f"[bold]Summary of differences against `{context_diff.environment}`:"))

if context_diff.has_requirement_changes:
self._print(f"Requirements:\n{context_diff.requirements_diff()}")

self._show_summary_tree_for(
context_diff,
"Models",
Expand Down Expand Up @@ -1563,6 +1567,9 @@ def show_model_difference_summary(

self._print(f"**Summary of differences against `{context_diff.environment}`:**\n")

if context_diff.has_requirement_changes:
self._print(f"Requirements:\n{context_diff.requirements_diff()}")

added_snapshots = {context_diff.snapshots[s_id] for s_id in context_diff.added}
added_snapshot_models = {s for s in added_snapshots if s.is_model}
if added_snapshot_models:
Expand Down Expand Up @@ -1999,6 +2006,10 @@ def show_model_difference_summary(
no_diff: bool = True,
) -> None:
self._write("Model Difference Summary:")

if context_diff.has_requirement_changes:
self._write(f"Requirements:\n{context_diff.requirements_diff()}")

for added in context_diff.new_snapshots:
self._write(f" Added: {added}")
for removed in context_diff.removed_snapshots:
Expand Down
1 change: 1 addition & 0 deletions sqlmesh/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@

EXTERNAL_MODELS_YAML = "external_models.yaml"
EXTERNAL_MODELS_DEPRECATED_YAML = "schema.yaml"
REQUIREMENTS = "sqlmesh-requirements.lock"

DEFAULT_SCHEMA = "default"

Expand Down
23 changes: 23 additions & 0 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ def __init__(
self._macros: UniqueKeyDict[str, ExecutableOrMacro] = UniqueKeyDict("macros")
self._metrics: UniqueKeyDict[str, Metric] = UniqueKeyDict("metrics")
self._jinja_macros = JinjaMacroRegistry()
self._requirements: t.Dict[str, str] = {}
self._default_catalog: t.Optional[str] = None
self._loaded: bool = False

Expand Down Expand Up @@ -534,7 +535,10 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
load_start_ts = time.perf_counter()

projects = []
self._requirements.clear()
for context_loader in self._loaders.values():
for path in context_loader.configs:
self._load_requirements(path)
with sys_path(*context_loader.configs):
projects.append(context_loader.loader.load(self, update_schemas))

Expand Down Expand Up @@ -2061,6 +2065,7 @@ def _context_diff(
snapshots=snapshots or self.snapshots,
create_from=create_from or c.PROD,
state_reader=self.state_reader,
requirements=self._requirements,
ensure_finalized_snapshots=ensure_finalized_snapshots,
)

Expand Down Expand Up @@ -2124,6 +2129,24 @@ def _load_materializations_and_signals(self) -> None:
context_loader.loader.load_signals(self)
context_loader.loader.load_materializations(self)

def _load_requirements(self, path: Path) -> None:
    """Load pinned Python dependencies from the lock file of a project directory.

    Looks for the file named by ``c.REQUIREMENTS`` directly under ``path`` and, if it
    exists, merges its ``dep==ver`` entries into ``self._requirements``. Blank lines
    are skipped so that spacing between entries does not invalidate the file.

    Args:
        path: The project directory that may contain the lock file.

    Raises:
        SQLMeshError: If an entry is not of the form ``dep==ver`` (including an empty
            dependency name or version), or if it pins a version that differs from
            one already loaded for the same dependency.
    """
    path = path / c.REQUIREMENTS
    if not path.is_file():
        return

    with open(path, "r", encoding="utf-8") as file:
        for line in file:
            entry = line.strip()
            if not entry:
                # An empty line carries no requirement; treating it as an entry
                # would reject otherwise valid lock files.
                continue

            args = [k.strip() for k in entry.split("==")]
            # Exactly one '==' with a non-empty name and version on either side.
            if len(args) != 2 or not all(args):
                raise SQLMeshError(
                    f"Invalid lock file entry '{entry}'. Only 'dep==ver' is supported"
                )

            dep, ver = args
            # Defaulting to `ver` makes a first-time dependency compare equal to itself.
            other_ver = self._requirements.get(dep, ver)
            if ver != other_ver:
                raise SQLMeshError(
                    f"Conflicting requirement {dep}: {ver} != {other_ver}. Fix your {c.REQUIREMENTS} file."
                )
            self._requirements[dep] = ver


class Context(GenericContext[Config]):
CONFIG_TYPE = Config
61 changes: 60 additions & 1 deletion sqlmesh/core/context_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,27 @@
from __future__ import annotations

import logging
import sys
import typing as t
from difflib import ndiff
from functools import cached_property

from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotTableInfo
from sqlmesh.utils.errors import SQLMeshError
from sqlmesh.utils.pydantic import PydanticModel


if sys.version_info >= (3, 12):
from importlib import metadata
else:
import importlib_metadata as metadata # type: ignore


if t.TYPE_CHECKING:
from sqlmesh.core.state_sync import StateReader

IGNORED_PACKAGES = {"sqlmesh", "sqlglot"}

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -59,6 +70,10 @@ class ContextDiff(PydanticModel):
"""Snapshot IDs that were promoted by the previous plan."""
previous_finalized_snapshots: t.Optional[t.List[SnapshotTableInfo]]
"""Snapshots from the previous finalized state."""
previous_requirements: t.Dict[str, str] = {}
"""Previous requirements."""
provided_requirements: t.Dict[str, str] = {}
"""Requirements from lock file."""

@classmethod
def create(
Expand All @@ -68,6 +83,7 @@ def create(
create_from: str,
state_reader: StateReader,
ensure_finalized_snapshots: bool = False,
requirements: t.Optional[t.Dict[str, str]] = None,
) -> ContextDiff:
"""Create a ContextDiff object.
Expand All @@ -80,6 +96,7 @@ def create(
ensure_finalized_snapshots: Whether to compare against snapshots from the latest finalized
environment state, or to use whatever snapshots are in the current environment state even if
the environment is not finalized.
requirements: Fixed requirements to build the context diff with.
Returns:
The ContextDiff object.
Expand Down Expand Up @@ -174,6 +191,8 @@ def create(
previous_plan_id=env.plan_id if env and not is_new_environment else None,
previously_promoted_snapshot_ids=previously_promoted_snapshot_ids,
previous_finalized_snapshots=env.previous_finalized_snapshots if env else None,
previous_requirements=env.requirements if env else {},
provided_requirements=requirements,
)

@classmethod
Expand Down Expand Up @@ -207,14 +226,23 @@ def create_no_diff(cls, environment: str, state_reader: StateReader) -> ContextD
previous_plan_id=env.plan_id,
previously_promoted_snapshot_ids={s.snapshot_id for s in env.promoted_snapshots},
previous_finalized_snapshots=env.previous_finalized_snapshots,
previous_requirements=env.requirements,
provided_requirements=env.requirements,
)

@property
def has_changes(self) -> bool:
return (
self.has_snapshot_changes or self.is_new_environment or self.is_unfinalized_environment
self.has_snapshot_changes
or self.is_new_environment
or self.is_unfinalized_environment
or self.has_requirement_changes
)

@property
def has_requirement_changes(self) -> bool:
    """Whether the current requirements differ from the previously recorded ones."""
    current = self.requirements
    return current != self.previous_requirements

@property
def has_snapshot_changes(self) -> bool:
    """Whether any snapshot was added, removed, or modified."""
    if self.added:
        return True
    if self.removed_snapshots:
        return True
    return bool(self.modified_snapshots)
Expand Down Expand Up @@ -251,6 +279,37 @@ def current_modified_snapshot_ids(self) -> t.Set[SnapshotId]:
def snapshots_by_name(self) -> t.Dict[str, Snapshot]:
return {x.name: x for x in self.snapshots.values()}

@cached_property
def requirements(self) -> t.Dict[str, str]:
    """Resolve the Python requirements for the snapshots in this diff.

    Starts from a copy of ``provided_requirements`` and adds the installed version of
    every distribution that provides a top-level module imported by a model's Python
    environment. Entries already present are never overwritten, and distributions
    listed in ``IGNORED_PACKAGES`` are skipped.

    Returns:
        A mapping of distribution name to version.
    """
    resolved = self.provided_requirements.copy()
    installed = metadata.packages_distributions()

    for snapshot in self.snapshots.values():
        if not snapshot.is_model:
            continue

        for executable in snapshot.model.python_env.values():
            if executable.kind != "import":
                continue

            try:
                payload = executable.payload
                prefix = "from " if payload.startswith("from ") else "import "
                # Keep only the top-level module, e.g. "a" for "from a.b import c".
                lib = payload.split(prefix)[1].split()[0].split(".")[0]
                for dist in installed.get(lib, ()):
                    if dist not in resolved and dist not in IGNORED_PACKAGES:
                        resolved[dist] = metadata.version(dist)
            except metadata.PackageNotFoundError:
                logger.warning("Failed to find package for %s", lib)

    return resolved

def requirements_diff(self) -> str:
    """Render a line-based diff between the previous and the current requirements.

    Both sides are formatted as ``name==version`` lines sorted by name and compared
    with ``difflib.ndiff``.

    Returns:
        The diff lines joined with newlines.
    """
    before = [f"{name}=={version}" for name, version in sorted(self.previous_requirements.items())]
    after = [f"{name}=={version}" for name, version in sorted(self.requirements.items())]
    return "\n".join(ndiff(before, after))

@property
def environment_snapshots(self) -> t.List[SnapshotTableInfo]:
"""Returns current snapshots in the environment."""
Expand Down
43 changes: 2 additions & 41 deletions sqlmesh/core/plan/definition.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import logging
import sys
import typing as t
from dataclasses import dataclass
from datetime import datetime
Expand All @@ -28,20 +26,9 @@
from sqlmesh.utils.date import TimeLike, now, to_datetime, to_timestamp
from sqlmesh.utils.pydantic import PydanticModel

logger = logging.getLogger(__name__)

SnapshotMapping = t.Dict[SnapshotId, t.Set[SnapshotId]]


if sys.version_info >= (3, 12):
from importlib import metadata
else:
import importlib_metadata as metadata # type: ignore


IGNORED_PACKAGES = {"sqlmesh", "sqlglot"}


class Plan(PydanticModel, frozen=True):
context_diff: ContextDiff
plan_id: str
Expand Down Expand Up @@ -99,16 +86,7 @@ def requires_backfill(self) -> bool:

@property
def has_changes(self) -> bool:
modified_snapshot_ids = {
*self.context_diff.added,
*self.context_diff.removed_snapshots,
*self.context_diff.current_modified_snapshot_ids,
}
return (
self.context_diff.is_new_environment
or self.context_diff.is_unfinalized_environment
or bool(modified_snapshot_ids)
)
return self.context_diff.has_changes

@property
def has_unmodified_unpromoted(self) -> bool:
Expand Down Expand Up @@ -217,23 +195,6 @@ def environment(self) -> Environment:
else self.context_diff.previous_finalized_snapshots
)

requirements = {}
distributions = metadata.packages_distributions()

for snapshot in self.context_diff.snapshots.values():
if snapshot.is_model:
for executable in snapshot.model.python_env.values():
if executable.kind == "import":
try:
start = "from " if executable.payload.startswith("from ") else "import "
lib = executable.payload.split(start)[1].split()[0].split(".")[0]
if lib in distributions:
for dist in distributions[lib]:
if dist not in requirements and dist not in IGNORED_PACKAGES:
requirements[dist] = metadata.version(dist)
except metadata.PackageNotFoundError:
logger.warning("Failed to find package for %s", lib)

return Environment(
snapshots=snapshots,
start_at=self.provided_start or self._earliest_interval_start,
Expand All @@ -243,7 +204,7 @@ def environment(self) -> Environment:
expiration_ts=expiration_ts,
promoted_snapshot_ids=promoted_snapshot_ids,
previous_finalized_snapshots=previous_finalized_snapshots,
requirements=requirements,
requirements=self.context_diff.requirements,
**self.environment_naming_info.dict(),
)

Expand Down
19 changes: 14 additions & 5 deletions tests/core/test_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2569,18 +2569,27 @@ def test_interval_end_per_model(make_snapshot):
assert plan_builder.build().interval_end_per_model is None


def test_plan_requirements():
def test_plan_requirements(mocker):
context = Context(paths="examples/sushi")
model = context.get_model("sushi.items")
model.python_env["ruamel"] = Executable(payload="import ruamel", kind="import")
model.python_env["Image"] = Executable(
payload="from ipywidgets.widgets.widget_media import Image", kind="import"
)

plan = context.plan(
"dev", no_prompts=True, skip_tests=True, skip_backfill=True
).environment.requirements
assert set(plan) == {"ipywidgets", "numpy", "pandas", "ruamel.yaml", "ruamel.yaml.clib"}
environment = context.plan(
"dev", no_prompts=True, skip_tests=True, skip_backfill=True, auto_apply=True
).environment
requirements = {"ipywidgets", "numpy", "pandas", "ruamel.yaml", "ruamel.yaml.clib"}
assert environment.requirements["pandas"] == "2.2.2"
assert set(environment.requirements) == requirements

mocker.patch(
"sqlmesh.core.context_diff.ContextDiff.requirements", {"numpy": "2.1.2", "pandas": "2.2.1"}
)
diff = context.plan("dev", no_prompts=True, skip_tests=True, skip_backfill=True).context_diff
assert set(diff.previous_requirements) == requirements
assert set(diff.requirements) == {"numpy", "pandas"}


def test_unaligned_start_model_with_forward_only_preview(make_snapshot):
Expand Down