Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support sqlalchemy select API #229

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: minor

Support SQLAlchemy select API when resolving.
64 changes: 35 additions & 29 deletions src/strawberry_sqlalchemy_mapper/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
from typing_extensions import Annotated, TypeAlias

from sqlakeyset.types import Keyset
from sqlalchemy import Select, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Query, Session
from sqlalchemy.orm import Session
from strawberry import relay
from strawberry.annotation import StrawberryAnnotation
from strawberry.extensions.field_extension import (
Expand Down Expand Up @@ -59,11 +60,11 @@
assert argument # type: ignore[truthy-function]


connection_session: contextvars.ContextVar[
Union[Session, AsyncSession, None]
] = contextvars.ContextVar(
"connection-session",
default=None,
connection_session: contextvars.ContextVar[Union[Session, AsyncSession, None]] = (
contextvars.ContextVar(
"connection-session",
default=None,
)
)


Expand Down Expand Up @@ -97,7 +98,7 @@ def __init__(
@dataclasses.dataclass
class StrawberrySQLAlchemyAsyncQuery:
session: AsyncSession
query: Callable[[Session], Query]
query: Callable[[], Select]
iterator: Iterator[Any] | None = None
limit: int | None = None
offset: int | None = None
Expand All @@ -120,16 +121,13 @@ def __aiter__(self):

async def __anext__(self):
if self.iterator is None:
q = self.query()
if self.limit is not None:
q = q.limit(self.limit)
if self.offset is not None:
q = q.offset(self.offset)

def query_runner(s: Session):
q = self.query(s)
if self.limit is not None:
q = q.limit(self.limit)
if self.offset is not None:
q = q.offset(self.offset)
return list(q)

self.iterator = iter(await self.session.run_sync(query_runner))
self.iterator = iter(await self.session.scalars(q))

try:
return next(self.iterator)
Expand Down Expand Up @@ -325,7 +323,7 @@ def default_resolver(
if session is None:
session = field_sessionmaker()

def _get_query(s: Session):
def _get_orm_query(s: Session):
if root is not None:
# root won't be None when resolving nested connections.
# TODO: Maybe we want to send this to a dataloader?
Expand All @@ -338,16 +336,29 @@ def _get_query(s: Session):

return query

def _get_select_query():
if root is not None:
# root won't be None when resolving nested connections.
# TODO: Maybe we want to send this to a dataloader?
query = getattr(root, field.python_name)
else:
query = select(model)

if field.keyset is not None:
query = query.order_by(*field.keyset)

return query

if isinstance(session, AsyncSession):
return cast(
Iterable[Any],
StrawberrySQLAlchemyAsyncQuery(
session=session,
query=lambda s: _get_query(s),
query=_get_select_query,
),
)

return _get_query(session)
return _get_orm_query(session)

field.base_resolver = StrawberryResolver(default_resolver)

Expand Down Expand Up @@ -415,8 +426,7 @@ def field(
graphql_type: Any | None = None,
extensions: Sequence[FieldExtension] = (),
sessionmaker: _SessionMaker | None = None,
) -> _T:
...
) -> _T: ...


@overload
Expand All @@ -437,8 +447,7 @@ def field(
graphql_type: Any | None = None,
extensions: Sequence[FieldExtension] = (),
sessionmaker: _SessionMaker | None = None,
) -> Any:
...
) -> Any: ...


@overload
Expand All @@ -459,8 +468,7 @@ def field(
graphql_type: Any | None = None,
extensions: Sequence[FieldExtension] = (),
sessionmaker: _SessionMaker | None = None,
) -> StrawberrySQLAlchemyField:
...
) -> StrawberrySQLAlchemyField: ...


def field(
Expand Down Expand Up @@ -599,8 +607,7 @@ def connection(
extensions: Sequence[FieldExtension] = (),
sessionmaker: _SessionMaker | None = None,
keyset: Keyset | None = None,
) -> Any:
...
) -> Any: ...


@overload
Expand All @@ -622,8 +629,7 @@ def connection(
extensions: Sequence[FieldExtension] = (),
sessionmaker: _SessionMaker | None = None,
keyset: Keyset | None = None,
) -> Any:
...
) -> Any: ...


def connection(
Expand Down
68 changes: 53 additions & 15 deletions src/strawberry_sqlalchemy_mapper/relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
)

import sqlakeyset
import sqlakeyset.asyncio
import strawberry
from sqlalchemy import and_, or_
from sqlalchemy import Row, Select, and_, or_
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.inspection import inspect as sqlalchemy_inspect
from sqlalchemy.orm import Query
from strawberry import relay
from strawberry.relay.exceptions import NodeIDAnnotationError
from strawberry.relay.types import NodeType
Expand All @@ -27,7 +29,7 @@
if TYPE_CHECKING:
from typing_extensions import Literal, Self

from sqlalchemy.orm import Query, Session
from sqlalchemy.orm import Session
from strawberry.types.info import Info
from strawberry.utils.await_maybe import AwaitableOrValue

Expand Down Expand Up @@ -64,7 +66,7 @@ class KeysetConnection(relay.Connection[NodeType]):
@classmethod
def resolve_connection(
cls,
nodes: Union[Query, StrawberrySQLAlchemyAsyncQuery], # type: ignore[override]
nodes: Union[Query, Select, StrawberrySQLAlchemyAsyncQuery], # type: ignore[override]
*,
info: Info,
before: Optional[str] = None,
Expand Down Expand Up @@ -110,40 +112,76 @@ def resolve_connection(page: sqlakeyset.Page):
end_cursor=page.paging.get_bookmark_at(-1) if page else None,
),
edges=[
edge_class.resolve_edge(n, cursor=page.paging.get_bookmark_at(i))
edge_class.resolve_edge(
n[0] if isinstance(n, Row) else n,
cursor=page.paging.get_bookmark_at(i),
)
for i, n in enumerate(page)
],
)

def resolve_nodes(s: Session, nodes=nodes):
if isinstance(nodes, StrawberrySQLAlchemyAsyncQuery):
nodes = nodes.query(s)
def resolve_nodes(s: Session, nodes: Union[Query, Select]):
if isinstance(nodes, Select):
return resolve_connection(
sqlakeyset.select_page(
s,
nodes,
per_page=per_page,
after=(
sqlakeyset.unserialize_bookmark(after).place
if after
else None
),
before=(
sqlakeyset.unserialize_bookmark(before).place
if before
else None
),
)
)

return resolve_connection(
sqlakeyset.get_page(
nodes,
per_page=per_page,
after=(
sqlakeyset.unserialize_bookmark(after).place if after else None
),
before=(
sqlakeyset.unserialize_bookmark(before).place
if before
else None
),
)
)

async def resolve_nodes_async(s: AsyncSession, nodes: Select):
# the asynchronous SQLAlchemy API only supports select
return resolve_connection(
await sqlakeyset.asyncio.select_page(
s,
nodes,
per_page=per_page,
after=(
sqlakeyset.unserialize_bookmark(after).place if after else None
),
per_page=per_page,
before=(
sqlakeyset.unserialize_bookmark(before).place
if before
else None
),
)
)

# TODO: It would be better to aboid session.run_sync in here but
# sqlakeyset doesn't have a `get_page` async counterpart.
if isinstance(session, AsyncSession):
if isinstance(nodes, StrawberrySQLAlchemyAsyncQuery):
nodes = nodes.query()

async def resolve_async(nodes=nodes):
return await session.run_sync(lambda s: resolve_nodes(s))

return resolve_async()
assert isinstance(nodes, Select)
return resolve_nodes_async(session, nodes)

return resolve_nodes(session)
assert isinstance(nodes, (Query, Select))
return resolve_nodes(session, nodes)


@overload
Expand Down
4 changes: 2 additions & 2 deletions tests/relay/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class Query:
await session.commit()

session.add_all([f1, f2, f3])
session.commit()
await session.commit()

for f in [f1, f2, f3]:
result = await schema.execute(query, {"id": relay.to_base64("Fruit", f.id)})
Expand Down Expand Up @@ -266,7 +266,7 @@ class Query:
await session.commit()

session.add_all([f1, f2, f3])
session.commit()
await session.commit()

result = await schema.execute(
query,
Expand Down
Loading