Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: impl graphql relay resolver for user and group #1719

Merged
merged 35 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
c8f4089
impl graphql relay resolver for user and group
fregataa Nov 15, 2023
74d297b
optimize group-user association query
fregataa Nov 16, 2023
7f20245
enable query node
fregataa Nov 16, 2023
796b24c
impl customized Connection object
fregataa Nov 16, 2023
8419f1f
connection root query works
fregataa Nov 16, 2023
18fca0b
complete fields in graphene nodes
fregataa Nov 17, 2023
4c89582
add news fragment
fregataa Nov 17, 2023
8f4b770
add deprecation message
fregataa Nov 17, 2023
e1784f3
Merge branch 'main' into feature/graphene-async-relay
fregataa Nov 19, 2023
469088d
impl connection pagination adaptor for sqlalchemy
fregataa Nov 20, 2023
4e04840
now it works
fregataa Nov 21, 2023
7c9f02a
Merge branch 'main' into feature/graphene-async-relay
fregataa Nov 22, 2023
f6a5ac5
it now works with offset
fregataa Nov 22, 2023
e16247b
some minor fixes
fregataa Nov 22, 2023
507623b
merge divided resolvers
fregataa Nov 22, 2023
85b2a73
fix connection cursor bug
fregataa Nov 22, 2023
024aea8
set default ordering in sql paginated query and etc
fregataa Nov 22, 2023
8b55da1
declare OrderingItem in minilang ordering
fregataa Nov 23, 2023
835348b
Merge branch 'main' into feature/graphene-async-relay
fregataa Nov 23, 2023
c74241d
minor changes
fregataa Nov 23, 2023
2a4be94
add user node resolver
fregataa Nov 23, 2023
38b6524
refactor more
fregataa Nov 23, 2023
e9a9803
refactor
fregataa Nov 23, 2023
74fc389
apply filter to count_func
fregataa Nov 23, 2023
b0cc252
reverse the result in node resolver not in connection resolver
fregataa Nov 23, 2023
46cc493
change names for better readability
fregataa Nov 24, 2023
d5887b6
let resolver reverse ordering when backward pagination
fregataa Nov 24, 2023
0088003
Merge branch 'main' into feature/graphene-async-relay
fregataa Nov 24, 2023
73625f5
apply join on total count resovler of user node
fregataa Nov 24, 2023
c8eaa2f
Merge branch 'main' into feature/graphene-async-relay
fregataa Nov 28, 2023
7201f2b
better type name
fregataa Nov 28, 2023
2e537b9
add description message to fields
fregataa Nov 28, 2023
b171569
a bit more comments
fregataa Nov 28, 2023
fc99cbc
fix description fields
fregataa Nov 29, 2023
fc74d2d
minor name change
fregataa Nov 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/1719.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement async compatible graphql relay node object and implement group/user graphql relay nodes.
229 changes: 229 additions & 0 deletions src/ai/backend/manager/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
List,
Mapping,
MutableMapping,
NamedTuple,
Optional,
Protocol,
Sequence,
Expand Down Expand Up @@ -68,6 +69,13 @@

from .. import models
from ..api.exceptions import GenericForbidden, InvalidAPIParameters
from .gql_relay import (
AsyncListConnectionField,
AsyncNode,
ConnectionPaginationOrder,
)
from .minilang.ordering import OrderDirection, OrderingItem, QueryOrderParser
from .minilang.queryfilter import QueryFilterParser, WhereClauseType

if TYPE_CHECKING:
from .gql import GraphQueryContext
Expand Down Expand Up @@ -1098,3 +1106,224 @@ class InferenceSessionErrorInfo(graphene.ObjectType):
session_id = graphene.UUID()

errors = graphene.List(graphene.NonNull(InferenceSessionErrorInfo), required=True)


class AsyncPaginatedConnectionField(AsyncListConnectionField):
def __init__(self, type, *args, **kwargs):
kwargs.setdefault("filter", graphene.String())
kwargs.setdefault("order", graphene.String())
kwargs.setdefault("offset", graphene.Int())
super().__init__(type, *args, **kwargs)


PaginatedConnectionField = AsyncPaginatedConnectionField


class ConnectionArgs(NamedTuple):
cursor: str | None
pagination_order: ConnectionPaginationOrder | None
requested_page_size: int | None


def validate_connection_args(
*,
after: str | None = None,
first: int | None = None,
before: str | None = None,
last: int | None = None,
) -> ConnectionArgs:
"""
Validate arguments used for GraphQL relay connection, and determine pagination ordering, cursor and page size.
It is not allowed to use arguments for forward pagination and arguments for backward pagination at the same time.
"""
order: ConnectionPaginationOrder | None = None
cursor: str | None = None
requested_page_size: int | None = None

if after is not None:
order = ConnectionPaginationOrder.FORWARD
cursor = after
if first is not None:
if first < 0:
raise ValueError("Argument 'first' must be a non-negative integer.")
order = ConnectionPaginationOrder.FORWARD
requested_page_size = first

if before is not None:
if order is ConnectionPaginationOrder.FORWARD:
raise ValueError(
"Can only paginate with single direction, forwards or backwards. Please set only"
" one of (after, first) and (before, last)."
)
order = ConnectionPaginationOrder.BACKWARD
cursor = before
if last is not None:
if last < 0:
raise ValueError("Argument 'last' must be a non-negative integer.")
if order is ConnectionPaginationOrder.FORWARD:
raise ValueError(
"Can only paginate with single direction, forwards or backwards. Please set only"
" one of (after, first) and (before, last)."
)
order = ConnectionPaginationOrder.BACKWARD
requested_page_size = last

return ConnectionArgs(cursor, order, requested_page_size)


def _build_sql_stmt_from_connection_arg(
info: graphene.ResolveInfo,
orm_class,
id_column: sa.Column,
filter_expr: str | None = None,
order_expr: str | None = None,
*,
connection_arg: ConnectionArgs,
) -> tuple[sa.sql.Select, list[WhereClauseType]]:
rapsealk marked this conversation as resolved.
Show resolved Hide resolved
stmt = sa.select(orm_class)
conditions: list[WhereClauseType] = []

cursor_id, pagination_order, requested_page_size = connection_arg
rapsealk marked this conversation as resolved.
Show resolved Hide resolved

# Default ordering by id column
id_ordering_item: OrderingItem = OrderingItem(id_column, OrderDirection.ASC)
ordering_item_list: list[OrderingItem] = []
if order_expr is not None:
parser = QueryOrderParser()
ordering_item_list = parser.parse_order(orm_class, order_expr)

# Apply SQL order_by
match pagination_order:
case ConnectionPaginationOrder.FORWARD | None:
set_ordering = lambda col, direction: (
col.asc() if direction == OrderDirection.ASC else col.desc()
)
case ConnectionPaginationOrder.BACKWARD:
set_ordering = lambda col, direction: (
col.desc() if direction == OrderDirection.ASC else col.asc()
)
# id column should be applied last
for col, direction in [*ordering_item_list, id_ordering_item]:
stmt = stmt.order_by(set_ordering(col, direction))

# Set cursor by comparing scalar values of subquery that queried by cursor id
if cursor_id is not None:
_, _id = AsyncNode.resolve_global_id(info, cursor_id)
match pagination_order:
case ConnectionPaginationOrder.FORWARD | None:
conditions.append(id_column > _id)
set_subquery = lambda col, subquery, direction: (
col >= subquery if direction == OrderDirection.ASC else col <= subquery
)
case ConnectionPaginationOrder.BACKWARD:
conditions.append(id_column < _id)
set_subquery = lambda col, subquery, direction: (
col <= subquery if direction == OrderDirection.ASC else col >= subquery
)
for col, direction in ordering_item_list:
subq = sa.select(col).where(id_column == _id).scalar_subquery()
stmt = stmt.where(set_subquery(col, subq, direction))

if requested_page_size is not None:
# Add 1 to determine has_next_page or has_previous_page
stmt = stmt.limit(requested_page_size + 1)

if filter_expr is not None:
condition_parser = QueryFilterParser()
conditions.append(condition_parser.parse_filter(orm_class, filter_expr))

for cond in conditions:
stmt = stmt.where(cond)
return stmt, conditions


def _build_sql_stmt_from_sql_arg(
info: graphene.ResolveInfo,
orm_class,
id_column: sa.Column,
filter_expr: str | None = None,
order_expr: str | None = None,
*,
limit: int | None = None,
offset: int | None = None,
) -> tuple[sa.sql.Select, list[WhereClauseType]]:
stmt = sa.select(orm_class)
conditions: list[WhereClauseType] = []

if order_expr is not None:
parser = QueryOrderParser()
stmt = parser.append_ordering(stmt, order_expr)

# default order_by id column
stmt = stmt.order_by(id_column.asc())

if filter_expr is not None:
condition_parser = QueryFilterParser()
# stmt = condition_parser.append_filter(stmt, filter_expr)
conditions.append(condition_parser.parse_filter(orm_class, filter_expr))

if limit is not None:
stmt = stmt.limit(limit)

if offset is not None:
stmt = stmt.offset(offset)
return stmt, conditions


class GraphQLConnectionSQLInfo(NamedTuple):
sql_stmt: sa.sql.Select
sql_conditions: list[WhereClauseType]
cursor: str | None
pagination_order: ConnectionPaginationOrder | None
requested_page_size: int | None


def generate_sql_info_for_gql_connection(
info: graphene.ResolveInfo,
orm_class,
id_column: sa.Column,
filter_expr: str | None = None,
order_expr: str | None = None,
offset: int | None = None,
after: str | None = None,
first: int | None = None,
before: str | None = None,
last: int | None = None,
) -> GraphQLConnectionSQLInfo:
"""
Get GraphQL arguments and generate SQL query statement, cursor that points an id of a node, pagination order, and page size.
If `offset` is None, return SQL query parsed from GraphQL Connection spec arguments.
Else, return normally paginated SQL query and `first` is used as SQL limit.
"""

if offset is None:
connection_arg = validate_connection_args(
after=after, first=first, before=before, last=last
)
stmt, conditions = _build_sql_stmt_from_connection_arg(
info,
orm_class,
id_column,
filter_expr,
order_expr,
connection_arg=connection_arg,
)
return GraphQLConnectionSQLInfo(
stmt,
conditions,
connection_arg.cursor,
connection_arg.pagination_order,
connection_arg.requested_page_size,
)
rapsealk marked this conversation as resolved.
Show resolved Hide resolved
else:
page_size = first
stmt, conditions = _build_sql_stmt_from_sql_arg(
info,
orm_class,
id_column,
filter_expr,
order_expr,
limit=page_size,
offset=offset,
)
return GraphQLConnectionSQLInfo(stmt, conditions, None, None, page_size)
87 changes: 84 additions & 3 deletions src/ai/backend/manager/models/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from ai.backend.common.types import QuotaScopeID
from ai.backend.manager.defs import DEFAULT_IMAGE_ARCH
from ai.backend.manager.models.gql_relay import AsyncNode, ConnectionResolverResult

from .etcd import (
ContainerRegistry,
Expand Down Expand Up @@ -49,10 +50,18 @@
)
from .acl import PredefinedAtomicPermission
from .agent import Agent, AgentList, AgentSummary, AgentSummaryList, ModifyAgent
from .base import DataLoaderManager, privileged_query, scoped_query
from .base import DataLoaderManager, PaginatedConnectionField, privileged_query, scoped_query
from .domain import CreateDomain, DeleteDomain, Domain, ModifyDomain, PurgeDomain
from .endpoint import Endpoint, EndpointList, EndpointToken, EndpointTokenList
from .group import CreateGroup, DeleteGroup, Group, ModifyGroup, PurgeGroup
from .group import (
CreateGroup,
DeleteGroup,
Group,
GroupConnection,
GroupNode,
ModifyGroup,
PurgeGroup,
)
from .image import (
AliasImage,
ClearImages,
Expand Down Expand Up @@ -114,7 +123,9 @@
ModifyUser,
PurgeUser,
User,
UserConnection,
UserList,
UserNode,
UserRole,
UserStatus,
)
Expand Down Expand Up @@ -238,7 +249,7 @@ class Queries(graphene.ObjectType):
All available GraphQL queries.
"""

node = graphene.relay.Node.Field()
node = AsyncNode.Field()

# super-admin only
agent = graphene.Field(
Expand Down Expand Up @@ -292,6 +303,11 @@ class Queries(graphene.ObjectType):
is_active=graphene.Boolean(),
)

group_node = graphene.Field(
GroupNode, id=graphene.String(required=True), description="Added in 24.03.0."
)
Comment on lines +306 to +308
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If AsyncNode is to handle Relay Global ID, how about setting id field to AsyncNode itself?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have tried it, but it needs so much work. so I deprioritized this task

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's try this in next PR!

group_nodes = PaginatedConnectionField(GroupConnection, description="Added in 24.03.0.")

group = graphene.Field(
Group,
id=graphene.UUID(required=True),
Expand Down Expand Up @@ -358,6 +374,11 @@ class Queries(graphene.ObjectType):
status=graphene.String(),
)

user_node = graphene.Field(
UserNode, id=graphene.String(required=True), description="Added in 24.03.0."
)
user_nodes = PaginatedConnectionField(UserConnection, description="Added in 24.03.0.")

keypair = graphene.Field(
KeyPair,
domain_name=graphene.String(),
Expand Down Expand Up @@ -794,6 +815,36 @@ async def resolve_domains(
) -> Sequence[Domain]:
return await Domain.load_all(info.context, is_active=is_active)

async def resolve_group_node(
root: Any,
info: graphene.ResolveInfo,
id: str,
):
return await GroupNode.get_node(info, id)

async def resolve_group_nodes(
root: Any,
info: graphene.ResolveInfo,
*,
filter: str | None = None,
order: str | None = None,
offset: int | None = None,
after: str | None = None,
first: int | None = None,
before: str | None = None,
last: int | None = None,
) -> ConnectionResolverResult:
return await GroupNode.get_connection(
info,
filter,
order,
offset,
after,
first,
before,
last,
)

@staticmethod
async def resolve_group(
root: Any,
Expand Down Expand Up @@ -1089,6 +1140,36 @@ async def resolve_user_list(
)
return UserList(user_list, total_count)

async def resolve_user_node(
root: Any,
info: graphene.ResolveInfo,
id: str,
):
return await UserNode.get_node(info, id)

async def resolve_user_nodes(
root: Any,
info: graphene.ResolveInfo,
*,
filter: str | None = None,
order: str | None = None,
offset: int | None = None,
after: str | None = None,
first: int | None = None,
before: str | None = None,
last: int | None = None,
) -> ConnectionResolverResult:
return await UserNode.get_connection(
info,
filter,
order,
offset,
after,
first,
before,
last,
)

@staticmethod
@scoped_query(autofill_user=True, user_key="access_key")
async def resolve_keypair(
Expand Down
Loading
Loading