Skip to content

Commit

Permalink
feat: support sqlalchemy select API
Browse files Browse the repository at this point in the history
  • Loading branch information
novag committed Jan 10, 2025
1 parent 2133fd7 commit 918ea6c
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 46 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: minor

Support SQLAlchemy select API when resolving.
64 changes: 35 additions & 29 deletions src/strawberry_sqlalchemy_mapper/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
from typing_extensions import Annotated, TypeAlias

from sqlakeyset.types import Keyset
from sqlalchemy import Select, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Query, Session
from sqlalchemy.orm import Session
from strawberry import relay
from strawberry.annotation import StrawberryAnnotation
from strawberry.extensions.field_extension import (
Expand Down Expand Up @@ -59,11 +60,11 @@
assert argument # type: ignore[truthy-function]


connection_session: contextvars.ContextVar[
Union[Session, AsyncSession, None]
] = contextvars.ContextVar(
"connection-session",
default=None,
connection_session: contextvars.ContextVar[Union[Session, AsyncSession, None]] = (
contextvars.ContextVar(
"connection-session",
default=None,
)
)


Expand Down Expand Up @@ -97,7 +98,7 @@ def __init__(
@dataclasses.dataclass
class StrawberrySQLAlchemyAsyncQuery:
session: AsyncSession
query: Callable[[Session], Query]
query: Callable[[], Select]
iterator: Iterator[Any] | None = None
limit: int | None = None
offset: int | None = None
Expand All @@ -120,16 +121,13 @@ def __aiter__(self):

async def __anext__(self):
if self.iterator is None:
q = self.query()
if self.limit is not None:
q = q.limit(self.limit)
if self.offset is not None:
q = q.offset(self.offset)

def query_runner(s: Session):
q = self.query(s)
if self.limit is not None:
q = q.limit(self.limit)
if self.offset is not None:
q = q.offset(self.offset)
return list(q)

self.iterator = iter(await self.session.run_sync(query_runner))
self.iterator = iter(await self.session.scalars(q))

try:
return next(self.iterator)
Expand Down Expand Up @@ -325,7 +323,7 @@ def default_resolver(
if session is None:
session = field_sessionmaker()

def _get_query(s: Session):
def _get_orm_query(s: Session):
if root is not None:
# root won't be None when resolving nested connections.
# TODO: Maybe we want to send this to a dataloader?
Expand All @@ -338,16 +336,29 @@ def _get_query(s: Session):

return query

def _get_select_query():
if root is not None:
# root won't be None when resolving nested connections.
# TODO: Maybe we want to send this to a dataloader?
query = getattr(root, field.python_name)
else:
query = select(model)

if field.keyset is not None:
query = query.order_by(*field.keyset)

return query

if isinstance(session, AsyncSession):
return cast(
Iterable[Any],
StrawberrySQLAlchemyAsyncQuery(
session=session,
query=lambda s: _get_query(s),
query=_get_select_query,
),
)

return _get_query(session)
return _get_orm_query(session)

field.base_resolver = StrawberryResolver(default_resolver)

Expand Down Expand Up @@ -415,8 +426,7 @@ def field(
graphql_type: Any | None = None,
extensions: Sequence[FieldExtension] = (),
sessionmaker: _SessionMaker | None = None,
) -> _T:
...
) -> _T: ...


@overload
Expand All @@ -437,8 +447,7 @@ def field(
graphql_type: Any | None = None,
extensions: Sequence[FieldExtension] = (),
sessionmaker: _SessionMaker | None = None,
) -> Any:
...
) -> Any: ...


@overload
Expand All @@ -459,8 +468,7 @@ def field(
graphql_type: Any | None = None,
extensions: Sequence[FieldExtension] = (),
sessionmaker: _SessionMaker | None = None,
) -> StrawberrySQLAlchemyField:
...
) -> StrawberrySQLAlchemyField: ...


def field(
Expand Down Expand Up @@ -599,8 +607,7 @@ def connection(
extensions: Sequence[FieldExtension] = (),
sessionmaker: _SessionMaker | None = None,
keyset: Keyset | None = None,
) -> Any:
...
) -> Any: ...


@overload
Expand All @@ -622,8 +629,7 @@ def connection(
extensions: Sequence[FieldExtension] = (),
sessionmaker: _SessionMaker | None = None,
keyset: Keyset | None = None,
) -> Any:
...
) -> Any: ...


def connection(
Expand Down
68 changes: 53 additions & 15 deletions src/strawberry_sqlalchemy_mapper/relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
)

import sqlakeyset
import sqlakeyset.asyncio
import strawberry
from sqlalchemy import and_, or_
from sqlalchemy import Row, Select, and_, or_
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.inspection import inspect as sqlalchemy_inspect
from sqlalchemy.orm import Query
from strawberry import relay
from strawberry.relay.exceptions import NodeIDAnnotationError
from strawberry.relay.types import NodeType
Expand All @@ -27,7 +29,7 @@
if TYPE_CHECKING:
from typing_extensions import Literal, Self

from sqlalchemy.orm import Query, Session
from sqlalchemy.orm import Session
from strawberry.types.info import Info
from strawberry.utils.await_maybe import AwaitableOrValue

Expand Down Expand Up @@ -64,7 +66,7 @@ class KeysetConnection(relay.Connection[NodeType]):
@classmethod
def resolve_connection(
cls,
nodes: Union[Query, StrawberrySQLAlchemyAsyncQuery], # type: ignore[override]
nodes: Union[Query, Select, StrawberrySQLAlchemyAsyncQuery], # type: ignore[override]
*,
info: Info,
before: Optional[str] = None,
Expand Down Expand Up @@ -110,40 +112,76 @@ def resolve_connection(page: sqlakeyset.Page):
end_cursor=page.paging.get_bookmark_at(-1) if page else None,
),
edges=[
edge_class.resolve_edge(n, cursor=page.paging.get_bookmark_at(i))
edge_class.resolve_edge(
n[0] if isinstance(n, Row) else n,
cursor=page.paging.get_bookmark_at(i),
)
for i, n in enumerate(page)
],
)

def resolve_nodes(s: Session, nodes=nodes):
if isinstance(nodes, StrawberrySQLAlchemyAsyncQuery):
nodes = nodes.query(s)
def resolve_nodes(s: Session, nodes: Union[Query, Select]):
if isinstance(nodes, Select):
return resolve_connection(
sqlakeyset.select_page(
s,
nodes,
per_page=per_page,
after=(
sqlakeyset.unserialize_bookmark(after).place
if after
else None
),
before=(
sqlakeyset.unserialize_bookmark(before).place
if before
else None
),
)
)

return resolve_connection(
sqlakeyset.get_page(
nodes,
per_page=per_page,
after=(
sqlakeyset.unserialize_bookmark(after).place if after else None
),
before=(
sqlakeyset.unserialize_bookmark(before).place
if before
else None
),
)
)

async def resolve_nodes_async(s: AsyncSession, nodes: Select):
# the asynchronous SQLAlchemy API only supports select
return resolve_connection(
await sqlakeyset.asyncio.select_page(
s,
nodes,
per_page=per_page,
after=(
sqlakeyset.unserialize_bookmark(after).place if after else None
),
per_page=per_page,
before=(
sqlakeyset.unserialize_bookmark(before).place
if before
else None
),
)
)

# TODO: It would be better to aboid session.run_sync in here but
# sqlakeyset doesn't have a `get_page` async counterpart.
if isinstance(session, AsyncSession):
if isinstance(nodes, StrawberrySQLAlchemyAsyncQuery):
nodes = nodes.query()

async def resolve_async(nodes=nodes):
return await session.run_sync(lambda s: resolve_nodes(s))

return resolve_async()
assert isinstance(nodes, Select)
return resolve_nodes_async(session, nodes)

return resolve_nodes(session)
assert isinstance(nodes, (Query, Select))
return resolve_nodes(session, nodes)


@overload
Expand Down
4 changes: 2 additions & 2 deletions tests/relay/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class Query:
await session.commit()

session.add_all([f1, f2, f3])
session.commit()
await session.commit()

for f in [f1, f2, f3]:
result = await schema.execute(query, {"id": relay.to_base64("Fruit", f.id)})
Expand Down Expand Up @@ -266,7 +266,7 @@ class Query:
await session.commit()

session.add_all([f1, f2, f3])
session.commit()
await session.commit()

result = await schema.execute(
query,
Expand Down

0 comments on commit 918ea6c

Please sign in to comment.