Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Cannot resolve model_card GQL query #2161

Merged
merged 17 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/2161.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix buggy resolver of `model_card` GQL Query.
3 changes: 3 additions & 0 deletions src/ai/backend/manager/api/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,9 @@
id: ID!
name: String
vfolder: VirtualFolder

"""Added in 24.09.0."""
vfolder_node: VirtualFolderNode

Check notice on line 1010 in src/ai/backend/manager/api/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

Field 'vfolder_node' was added to object type 'ModelCard'

Field 'vfolder_node' was added to object type 'ModelCard'
author: String

"""Human readable name of the model."""
Expand Down
47 changes: 42 additions & 5 deletions src/ai/backend/manager/models/vfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from sqlalchemy.engine.row import Row
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
from sqlalchemy.ext.asyncio import AsyncSession as SASession
from sqlalchemy.orm import load_only, relationship, selectinload
from sqlalchemy.orm import joinedload, load_only, relationship, selectinload

from ai.backend.common.bgtask import ProgressReporter
from ai.backend.common.config import model_definition_iv
Expand Down Expand Up @@ -1347,6 +1347,31 @@ def _get_field(name: str) -> Any:
cur_size=row["cur_size"],
)

@classmethod
def from_orm_row(cls, row: VFolderRow) -> VirtualFolder:
return cls(
id=row.id,
host=row.host,
quota_scope_id=row.quota_scope_id,
name=row.name,
user=row.user,
user_email=row.user_row.email if row.user_row is not None else None,
group=row.group,
group_name=row.group_row.name if row.group_row is not None else None,
creator=row.creator,
unmanaged_path=row.unmanaged_path,
usage_mode=row.usage_mode,
permission=row.permission,
ownership_type=row.ownership_type,
max_files=row.max_files,
max_size=row.max_size,
created_at=row.created_at,
last_used=row.last_used,
cloneable=row.cloneable,
status=row.status,
cur_size=row.cur_size,
)

async def resolve_num_files(self, info: graphene.ResolveInfo) -> int:
# TODO: measure on-the-fly
return 0
Expand Down Expand Up @@ -1912,6 +1937,10 @@ async def get_connection(
last=last,
)

query = query.options(
joinedload(VFolderRow.user_row),
joinedload(VFolderRow.group_row),
)
async with graph_ctx.db.begin_readonly_session() as db_session:
vfolder_rows = (await db_session.scalars(query)).all()
result = [(cls.from_row(info, vf)) for vf in vfolder_rows]
Expand Down Expand Up @@ -2222,6 +2251,7 @@ class Meta:

name = graphene.String()
vfolder = graphene.Field(VirtualFolder)
vfolder_node = graphene.Field(VirtualFolderNode, description="Added in 24.09.0.")
fregataa marked this conversation as resolved.
Show resolved Hide resolved
author = graphene.String()
title = graphene.String(description="Human readable name of the model.")
version = graphene.String()
Expand Down Expand Up @@ -2304,7 +2334,7 @@ def resolve_created_at(
) -> datetime:
try:
return dtparse(self.created_at)
except ParserError:
except (TypeError, ParserError):
return self.created_at

def resolve_modified_at(
Expand All @@ -2313,7 +2343,7 @@ def resolve_modified_at(
) -> datetime:
try:
return dtparse(self.modified_at)
except ParserError:
except (TypeError, ParserError):
return self.modified_at

@classmethod
Expand All @@ -2338,6 +2368,8 @@ def parse_model(
name = vfolder_row.name
return cls(
id=vfolder_row.id,
vfolder=VirtualFolder.from_orm_row(vfolder_row),
vfolder_node=VirtualFolderNode.from_row(resolve_info, vfolder_row),
name=name,
author=metadata.get("author") or vfolder_row.creator or "",
title=metadata.get("title") or vfolder_row.name,
Expand Down Expand Up @@ -2451,7 +2483,9 @@ async def get_node(cls, info: graphene.ResolveInfo, id: str) -> ModelCard:

_, vfolder_row_id = AsyncNode.resolve_global_id(info, id)
async with graph_ctx.db.begin_readonly_session() as db_session:
vfolder_row = await VFolderRow.get(db_session, uuid.UUID(vfolder_row_id))
vfolder_row = await VFolderRow.get(
db_session, uuid.UUID(vfolder_row_id), load_user=True, load_group=True
)
if vfolder_row.usage_mode != VFolderUsageMode.MODEL:
raise ValueError(
f"The vfolder is not model. expect: {VFolderUsageMode.MODEL.value}, got:"
Expand Down Expand Up @@ -2522,7 +2556,10 @@ async def get_connection(
VFolderRow.group.in_(model_store_project_gids)
)
query = query.where(additional_cond)
cnt_query = cnt_query.where(additional_cond)
query = query.options(
joinedload(VFolderRow.user_row),
joinedload(VFolderRow.group_row),
)
async with graph_ctx.db.begin_readonly_session() as db_session:
vfolder_rows = (await db_session.scalars(query)).all()
result = [(await cls.from_row(info, vf)) for vf in vfolder_rows]
Expand Down
Loading