Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable MERGE in postgres and oracle #255

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,5 +168,6 @@ cython_debug/

# generated changelog
/docs/changelog.md
.zed
.cursorrules
.cursorignore
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ repos:
- id: unasyncd
additional_dependencies: ["ruff"]
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.6.1"
rev: "v0.6.7"
hooks:
- id: ruff
args: ["--fix"]
Expand All @@ -36,7 +36,7 @@ repos:
additional_dependencies:
- tomli
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.11.1"
rev: "v1.11.2"
hooks:
- id: mypy
exclude: "docs"
Expand Down Expand Up @@ -71,6 +71,6 @@ repos:
# hooks:
# - id: pyright
- repo: https://github.com/sphinx-contrib/sphinx-lint
rev: "v0.9.1"
rev: "v1.0.0"
hooks:
- id: sphinx-lint
2 changes: 1 addition & 1 deletion advanced_alchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from fastnanoid import generate as nanoid # pyright: ignore[reportMissingImports]

else:
nanoid = uuid4
nanoid = uuid4 # type: ignore[assignment]

if TYPE_CHECKING:
from sqlalchemy.sql import FromClause
Expand Down
15 changes: 8 additions & 7 deletions advanced_alchemy/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, Any

from sqlalchemy import ClauseElement, ColumnElement, UpdateBase
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.ext.compiler import compiles # pyright: ignore[reportUnknownVariableType]

if TYPE_CHECKING:
from typing import Literal
Expand Down Expand Up @@ -32,7 +32,7 @@ def where(self, expr: ColumnElement[bool]) -> MergeClause:
def visit_merge_clause(element: MergeClause, compiler: StrSQLCompiler, **kw: Any) -> str:
case_predicate = ""
if element.predicate is not None:
case_predicate = f" AND {element.predicate._compiler_dispatch(compiler, **kw)!s}" # noqa: SLF001
case_predicate = f" AND {element.predicate._compiler_dispatch(compiler, **kw)!s}" # noqa: SLF001 # pyright: ignore[reportPrivateUsage]

if element.command == "INSERT":
sets, sets_tos = list(element.on_sets), list(element.on_sets.values())
Expand All @@ -41,7 +41,7 @@ def visit_merge_clause(element: MergeClause, compiler: StrSQLCompiler, **kw: Any
sets, sets_tos = list(sorted_on_sets), list(sorted_on_sets.values())

merge_insert = ", ".join(sets)
values = ", ".join(e._compiler_dispatch(compiler, **kw) for e in sets_tos) # noqa: SLF001
values = ", ".join(e._compiler_dispatch(compiler, **kw) for e in sets_tos) # noqa: SLF001 # pyright: ignore[reportPrivateUsage]
return f"WHEN NOT MATCHED{case_predicate} THEN {element.command} ({merge_insert}) VALUES ({values})"

set_list = list(element.on_sets.items())
Expand All @@ -54,7 +54,8 @@ def visit_merge_clause(element: MergeClause, compiler: StrSQLCompiler, **kw: Any

if element.on_sets:
values = ", ".join(
f"{name} = {column._compiler_dispatch(compiler, **kw)}" for name, column in set_list # noqa: SLF001
f"{name} = {column._compiler_dispatch(compiler, **kw)}" # noqa: SLF001 # pyright: ignore[reportPrivateUsage]
for name, column in set_list
)
merge_action = f" SET {values}"

Expand All @@ -75,13 +76,13 @@ def __init__(self, into: Any, using: Any, on: Any) -> None:
def when_matched(self, operations: set[Literal["UPDATE", "DELETE", "INSERT"]]) -> MergeClause:
for op in operations:
self.clauses.append(clause := MergeClause(op))
return clause
return clause # pyright: ignore[reportPossiblyUnboundVariable]


@compiles(Merge) # type: ignore[no-untyped-call, misc]
def visit_merge(element: Merge, compiler: StrSQLCompiler, **kw: Any) -> str:
clauses = " ".join(clause._compiler_dispatch(compiler, **kw) for clause in element.clauses) # noqa: SLF001
sql_text = f"MERGE INTO {element.into} USING {element.using} ON {element.on}"
clauses = " ".join(clause._compiler_dispatch(compiler, **kw) for clause in element.clauses) # noqa: SLF001 # pyright: ignore[reportPrivateUsage]
sql_text = f"MERGE INTO {element.into} USING {element.using} ON ({element.on})"

if clauses:
sql_text += f" {clauses}"
Expand Down
182 changes: 164 additions & 18 deletions advanced_alchemy/repository/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
StatementLambdaElement,
TextClause,
any_,
bindparam,
delete,
lambda_stmt,
over,
Expand Down Expand Up @@ -1718,24 +1719,6 @@ async def upsert(
self._expunge(instance, auto_expunge=auto_expunge)
return instance

def _supports_merge_operations(self, force_disable_merge: bool = False) -> bool:
    """Report whether a native ``MERGE`` statement may be used.

    Args:
        force_disable_merge: When true, the answer is always ``False``.

    Returns:
        ``True`` when the dialect is Oracle, or is PostgreSQL reporting a major
        server version of at least ``POSTGRES_VERSION_SUPPORTING_MERGE``, and
        merging has not been force-disabled.
    """
    dialect = self._dialect
    version_info = dialect.server_version_info
    # The version comparison is evaluated before the dialect name check.
    is_merge_capable_postgres = (
        version_info is not None
        and version_info[0] >= POSTGRES_VERSION_SUPPORTING_MERGE
        and dialect.name == "postgresql"
    )
    if not (is_merge_capable_postgres or dialect.name == "oracle"):
        return False
    return not force_disable_merge

def _get_merge_stmt(
    self,
    into: Any,
    using: Any,
    on: Any,
) -> Merge:
    """Construct a ``Merge`` statement from its target, source and join condition.

    Args:
        into: Passed to ``Merge`` as ``into``.
        using: Passed to ``Merge`` as ``using``.
        on: Passed to ``Merge`` as ``on``.

    Returns:
        The newly created ``Merge`` instance.
    """
    statement = Merge(into=into, using=using, on=on)
    return statement

async def upsert_many(
self,
data: list[ModelT],
Expand Down Expand Up @@ -1805,12 +1788,93 @@ async def upsert_many(
execution_options=execution_options,
auto_expunge=False,
)
if self._supports_merge_operations(force_disable_merge=no_merge):
result = await self.session.execute(self._get_merge_stmt(data=data, match_fields=match_fields))
instances = cast("list[ModelT]", result.fetchall())
await self._flush_or_commit(auto_commit=auto_commit)
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances
# fallback to the insert/update method
for field_name in match_fields:
field = get_instrumented_attr(self.model_type, field_name)
matched_values = list(
{getattr(datum, field_name) for datum in existing_objs if datum}, # ensure the list is unique
)
match_filter.append(
any_(matched_values) == field if self._prefer_any else field.in_(matched_values), # type: ignore[arg-type]
)
existing_ids = self._get_object_ids(existing_objs=existing_objs)
data = self._merge_on_match_fields(data, existing_objs, match_fields)
for datum in data:
if getattr(datum, self.id_attribute, None) in existing_ids:
data_to_update.append(datum)
else:
data_to_insert.append(datum)
if data_to_insert:
instances.extend(
await self.add_many(data_to_insert, auto_commit=False, auto_expunge=False),
)
if data_to_update:
instances.extend(
await self.update_many(
data_to_update,
auto_commit=False,
auto_expunge=False,
load=load,
execution_options=execution_options,
),
)
await self._flush_or_commit(auto_commit=auto_commit)
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances

async def _upsert_many_default(
self,
data: list[ModelT],
*,
auto_expunge: bool | None = None,
auto_commit: bool | None = None,
match_fields: list[str] | str | None = None,
error_messages: ErrorMessages | None | EmptyType = Empty,
load: LoadSpec | None = None,
execution_options: dict[str, Any] | None = None,
) -> list[ModelT]:
error_messages = self._get_error_messages(
error_messages=error_messages,
default_messages=self.error_messages,
)
instances: list[ModelT] = []
data_to_update: list[ModelT] = []
data_to_insert: list[ModelT] = []
match_fields = self._get_match_fields(match_fields=match_fields)
if match_fields is None:
match_fields = [self.id_attribute]
match_filter: list[StatementFilter | ColumnElement[bool]] = []
if match_fields:
for field_name in match_fields:
field = get_instrumented_attr(self.model_type, field_name)
matched_values = [
field_data for datum in data if (field_data := getattr(datum, field_name)) is not None
]
match_filter.append(any_(matched_values) == field if self._prefer_any else field.in_(matched_values)) # type: ignore[arg-type]

with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
existing_objs = await self.list(
*match_filter,
load=load,
execution_options=execution_options,
auto_expunge=False,
)
for field_name in match_fields:
field = get_instrumented_attr(self.model_type, field_name)
matched_values = list(
{getattr(datum, field_name) for datum in existing_objs if datum}, # ensure the list is unique
)
match_filter.append(
any_(matched_values) == field if self._prefer_any else field.in_(matched_values), # type: ignore[arg-type]
)
existing_ids = self._get_object_ids(existing_objs=existing_objs)
data = self._merge_on_match_fields(data, existing_objs, match_fields)
for datum in data:
Expand All @@ -1837,6 +1901,88 @@ async def upsert_many(
self._expunge(instance, auto_expunge=auto_expunge)
return instances

async def _upsert_many_merge(
    self,
    data: list[ModelT],
    *,
    auto_expunge: bool | None = None,
    auto_commit: bool | None = None,
    match_fields: list[str] | str | None = None,
    error_messages: ErrorMessages | None | EmptyType = Empty,
    load: LoadSpec | None = None,
    execution_options: dict[str, Any] | None = None,
) -> list[ModelT]:
    """Upsert ``data`` with a single ``MERGE`` statement.

    Builds a ``Merge`` against the model's table that joins on ``match_fields``
    and, when a row matches, updates every column not listed in ``match_fields``.
    The statement is executed with one dictionary per item in ``data`` bound to
    the ``src_data`` parameter.

    Args:
        data: Model instances to merge into the table.
        auto_expunge: Forwarded to ``self._expunge`` for each returned row.
        auto_commit: Forwarded to ``self._flush_or_commit``.
        match_fields: Field name(s) used for the ``ON`` condition; falls back to
            ``self.id_attribute`` when ``self._get_match_fields`` returns ``None``.
        error_messages: Overrides resolved through ``self._get_error_messages``.
        load: Accepted for signature parity; not used by this method.
        execution_options: Accepted for signature parity; not used by this method.

    Returns:
        The rows fetched from the executed statement.
    """
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    match_fields = self._get_match_fields(match_fields=match_fields)
    if match_fields is None:
        match_fields = [self.id_attribute]

    with wrap_sqlalchemy_exception(error_messages=error_messages, dialect_name=self._dialect.name):
        target = self.model_type.__table__
        columns = [column for column in target.columns if column.name not in match_fields]
        # NOTE(review): ``type_=list`` is a Python builtin rather than a SQLAlchemy
        # type object, and ``source.c`` is read from a plain ``select()`` — confirm
        # both behave as intended on the supported SQLAlchemy versions.
        source = select(bindparam("src_data", type_=list, expanding=True))  # pyright: ignore[reportUnknownArgumentType]
        if self._dialect.name == "oracle":
            source = source.select_from(text("DUAL"))

        on_clause = sql_func.and_(*[target.c[field] == source.c[field] for field in match_fields])
        merge_stmt = Merge(into=target, using=source, on=on_clause)
        # ``when_matched`` registers the clause on ``merge_stmt`` and returns the
        # clause itself, so the result must NOT be assigned back to ``merge_stmt``;
        # doing so would replace the statement with a bare clause.
        merge_stmt.when_matched({"UPDATE"}).values(
            **{column.name: source.c[column.name] for column in columns},
        )

        values = [{column.name: getattr(item, column.name) for column in target.columns} for item in data]
        # Statement parameters go in the ``params`` position; ``bind_arguments`` is
        # reserved for bind (engine/connection) resolution.
        result = await self.session.execute(merge_stmt, {"src_data": values})
        # NOTE(review): presumably the statement needs a RETURNING clause for
        # ``fetchall`` to yield rows — confirm against the target dialects.
        instances = cast("list[ModelT]", result.fetchall())
        await self._flush_or_commit(auto_commit=auto_commit)
        for instance in instances:
            self._expunge(instance, auto_expunge=auto_expunge)
    return instances

def _supports_merge_operations(self, force_disable_merge: bool = False) -> bool:
    """Report whether a native ``MERGE`` statement may be used.

    Args:
        force_disable_merge: When true, the answer is always ``False``.

    Returns:
        ``True`` when the dialect is Oracle, or is PostgreSQL reporting a major
        server version of at least ``POSTGRES_VERSION_SUPPORTING_MERGE``, and
        merging has not been force-disabled.
    """
    dialect = self._dialect
    version_info = dialect.server_version_info
    # The version comparison is evaluated before the dialect name check.
    is_merge_capable_postgres = (
        version_info is not None
        and version_info[0] >= POSTGRES_VERSION_SUPPORTING_MERGE
        and dialect.name == "postgresql"
    )
    if not (is_merge_capable_postgres or dialect.name == "oracle"):
        return False
    return not force_disable_merge

def _get_merge_stmt(
    self,
    data: list[ModelT],
    match_fields: list[str],
) -> Merge:
    """Build the ``MERGE`` statement used to upsert ``data`` into the model table.

    The statement joins the model's table with a ``src`` subquery on
    ``match_fields``, registers an ``UPDATE`` clause for every column outside
    ``match_fields`` and an ``INSERT`` clause for all columns.

    Args:
        data: Model instances whose column values are bound to the source.
        match_fields: Column names compared between target and source in ``ON``.

    Returns:
        The configured ``Merge`` statement.
    """
    target = self.model_type.__table__
    # One dict per item, limited to the table columns the item actually exposes.
    bind_values = [
        {column.name: getattr(item, column.name) for column in target.columns if hasattr(item, column.name)}
        for item in data
    ]

    # NOTE(review): the textual ``(values :data)`` source with ``:_<column>``
    # placeholders is kept as written — confirm it renders valid SQL on both
    # PostgreSQL and Oracle.
    source = (
        select(
            *[text(f":_{column.name}").label(column.name) for column in target.columns],  # pyright: ignore[reportUnknownArgumentType]
        )
        .select_from(text("(values :data)"))
        .params(data=bind_values)
        .subquery(name="src")
    )

    on_clause = sql_func.and_(*[target.c[field] == source.c[field] for field in match_fields])

    merge = Merge(into=target, using=source, on=on_clause)

    # Values must come from the source side. Mapping a target column to itself
    # would make the UPDATE a no-op and make the INSERT reference the target
    # table, which MERGE does not allow in its VALUES list.
    update_columns = {c.name: source.c[c.name] for c in target.c if c.name not in match_fields}
    insert_columns = {c.name: source.c[c.name] for c in target.c}

    merge.when_matched({"UPDATE"}).values(**update_columns)
    merge.when_matched({"INSERT"}).values(**insert_columns)

    return merge

def _get_object_ids(self, existing_objs: list[ModelT]) -> list[Any]:
    """Collect the identifier of every object that has one.

    Args:
        existing_objs: Objects to read ``self.id_attribute`` from.

    Returns:
        The identifier values in input order, omitting any that are ``None``.
    """
    object_ids: list[Any] = []
    for existing in existing_objs:
        identifier = getattr(existing, self.id_attribute)
        # Only ``None`` is dropped; falsy identifiers such as ``0`` are kept.
        if identifier is not None:
            object_ids.append(identifier)
    return object_ids

Expand Down
Loading
Loading