diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000..d7c4620 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: minor + +Support SQLAlchemy select API when resolving. diff --git a/src/strawberry_sqlalchemy_mapper/field.py b/src/strawberry_sqlalchemy_mapper/field.py index 2b66d98..b49d1a5 100644 --- a/src/strawberry_sqlalchemy_mapper/field.py +++ b/src/strawberry_sqlalchemy_mapper/field.py @@ -29,8 +29,9 @@ from typing_extensions import Annotated, TypeAlias from sqlakeyset.types import Keyset +from sqlalchemy import Select, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Query, Session +from sqlalchemy.orm import Session from strawberry import relay from strawberry.annotation import StrawberryAnnotation from strawberry.extensions.field_extension import ( @@ -59,11 +60,11 @@ assert argument # type: ignore[truthy-function] -connection_session: contextvars.ContextVar[ - Union[Session, AsyncSession, None] -] = contextvars.ContextVar( - "connection-session", - default=None, +connection_session: contextvars.ContextVar[Union[Session, AsyncSession, None]] = ( + contextvars.ContextVar( + "connection-session", + default=None, + ) ) @@ -97,7 +98,7 @@ def __init__( @dataclasses.dataclass class StrawberrySQLAlchemyAsyncQuery: session: AsyncSession - query: Callable[[Session], Query] + query: Callable[[], Select] iterator: Iterator[Any] | None = None limit: int | None = None offset: int | None = None @@ -120,16 +121,13 @@ def __aiter__(self): async def __anext__(self): if self.iterator is None: + q = self.query() + if self.limit is not None: + q = q.limit(self.limit) + if self.offset is not None: + q = q.offset(self.offset) - def query_runner(s: Session): - q = self.query(s) - if self.limit is not None: - q = q.limit(self.limit) - if self.offset is not None: - q = q.offset(self.offset) - return list(q) - - self.iterator = iter(await self.session.run_sync(query_runner)) + self.iterator = iter(await self.session.scalars(q)) try: return next(self.iterator) @@ -325,7 +323,7 @@ def default_resolver( if session is None: session = field_sessionmaker() - def _get_query(s: Session): + def _get_orm_query(s: Session): if root is not None: # root won't be None when resolving nested connections. # TODO: Maybe we want to send this to a dataloader? @@ -338,16 +336,29 @@ def _get_query(s: Session): return query + def _get_select_query(): + if root is not None: + # root won't be None when resolving nested connections. + # TODO: Maybe we want to send this to a dataloader? + query = getattr(root, field.python_name) + else: + query = select(model) + + if field.keyset is not None: + query = query.order_by(*field.keyset) + + return query + if isinstance(session, AsyncSession): return cast( Iterable[Any], StrawberrySQLAlchemyAsyncQuery( session=session, - query=lambda s: _get_query(s), + query=_get_select_query, ), ) - return _get_query(session) + return _get_orm_query(session) field.base_resolver = StrawberryResolver(default_resolver) @@ -415,8 +426,7 @@ def field( graphql_type: Any | None = None, extensions: Sequence[FieldExtension] = (), sessionmaker: _SessionMaker | None = None, -) -> _T: - ... +) -> _T: ... @overload @@ -437,8 +447,7 @@ def field( graphql_type: Any | None = None, extensions: Sequence[FieldExtension] = (), sessionmaker: _SessionMaker | None = None, -) -> Any: - ... +) -> Any: ... @overload @@ -459,8 +468,7 @@ def field( graphql_type: Any | None = None, extensions: Sequence[FieldExtension] = (), sessionmaker: _SessionMaker | None = None, -) -> StrawberrySQLAlchemyField: - ... +) -> StrawberrySQLAlchemyField: ... def field( @@ -599,8 +607,7 @@ def connection( extensions: Sequence[FieldExtension] = (), sessionmaker: _SessionMaker | None = None, keyset: Keyset | None = None, -) -> Any: - ... +) -> Any: ... @overload @@ -622,8 +629,7 @@ def connection( extensions: Sequence[FieldExtension] = (), sessionmaker: _SessionMaker | None = None, keyset: Keyset | None = None, -) -> Any: - ... +) -> Any: ... def connection( diff --git a/src/strawberry_sqlalchemy_mapper/relay.py b/src/strawberry_sqlalchemy_mapper/relay.py index 4defe0a..9b7dd12 100644 --- a/src/strawberry_sqlalchemy_mapper/relay.py +++ b/src/strawberry_sqlalchemy_mapper/relay.py @@ -14,11 +14,13 @@ ) import sqlakeyset +import sqlakeyset.asyncio import strawberry -from sqlalchemy import and_, or_ +from sqlalchemy import Row, Select, and_, or_ from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.inspection import inspect as sqlalchemy_inspect +from sqlalchemy.orm import Query from strawberry import relay from strawberry.relay.exceptions import NodeIDAnnotationError from strawberry.relay.types import NodeType @@ -27,7 +29,7 @@ if TYPE_CHECKING: from typing_extensions import Literal, Self - from sqlalchemy.orm import Query, Session + from sqlalchemy.orm import Session from strawberry.types.info import Info from strawberry.utils.await_maybe import AwaitableOrValue @@ -64,7 +66,7 @@ class KeysetConnection(relay.Connection[NodeType]): @classmethod def resolve_connection( cls, - nodes: Union[Query, StrawberrySQLAlchemyAsyncQuery], # type: ignore[override] + nodes: Union[Query, Select, StrawberrySQLAlchemyAsyncQuery], # type: ignore[override] *, info: Info, before: Optional[str] = None, @@ -110,40 +112,76 @@ def resolve_connection(page: sqlakeyset.Page): end_cursor=page.paging.get_bookmark_at(-1) if page else None, ), edges=[ - edge_class.resolve_edge(n, cursor=page.paging.get_bookmark_at(i)) + edge_class.resolve_edge( + n[0] if isinstance(n, Row) else n, + cursor=page.paging.get_bookmark_at(i), + ) for i, n in enumerate(page) ], ) - def resolve_nodes(s: Session, nodes=nodes): - if isinstance(nodes, StrawberrySQLAlchemyAsyncQuery): - nodes = nodes.query(s) + def resolve_nodes(s: Session, nodes: Union[Query, Select]): + if isinstance(nodes, Select): + return resolve_connection( + sqlakeyset.select_page( + s, + nodes, + per_page=per_page, + after=( + sqlakeyset.unserialize_bookmark(after).place + if after + else None + ), + before=( + sqlakeyset.unserialize_bookmark(before).place + if before + else None + ), + ) + ) return resolve_connection( sqlakeyset.get_page( nodes, + per_page=per_page, + after=( + sqlakeyset.unserialize_bookmark(after).place if after else None + ), before=( sqlakeyset.unserialize_bookmark(before).place if before else None ), + ) + ) + + async def resolve_nodes_async(s: AsyncSession, nodes: Select): + # the asynchronous SQLAlchemy API only supports select + return resolve_connection( + await sqlakeyset.asyncio.select_page( + s, + nodes, + per_page=per_page, after=( sqlakeyset.unserialize_bookmark(after).place if after else None ), - per_page=per_page, + before=( + sqlakeyset.unserialize_bookmark(before).place + if before + else None + ), ) ) - # TODO: It would be better to aboid session.run_sync in here but - # sqlakeyset doesn't have a `get_page` async counterpart. if isinstance(session, AsyncSession): + if isinstance(nodes, StrawberrySQLAlchemyAsyncQuery): + nodes = nodes.query() - async def resolve_async(nodes=nodes): - return await session.run_sync(lambda s: resolve_nodes(s)) - - return resolve_async() + assert isinstance(nodes, Select) + return resolve_nodes_async(session, nodes) - return resolve_nodes(session) + assert isinstance(nodes, (Query, Select)) + return resolve_nodes(session, nodes) @overload diff --git a/tests/relay/test_node.py b/tests/relay/test_node.py index e520957..df20a91 100644 --- a/tests/relay/test_node.py +++ b/tests/relay/test_node.py @@ -109,7 +109,7 @@ class Query: await session.commit() session.add_all([f1, f2, f3]) - session.commit() + await session.commit() for f in [f1, f2, f3]: result = await schema.execute(query, {"id": relay.to_base64("Fruit", f.id)}) @@ -266,7 +266,7 @@ class Query: await session.commit() session.add_all([f1, f2, f3]) - session.commit() + await session.commit() result = await schema.execute( query,