Skip to content

Commit

Permalink
fix: Cannot resolve model_card GQL query (#2161)
Browse files Browse the repository at this point in the history
Co-authored-by: Joongi Kim <[email protected]>
  • Loading branch information
fregataa and achimnol authored Jul 1, 2024
1 parent 1065ed0 commit 266f128
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 5 deletions.
1 change: 1 addition & 0 deletions changes/2161.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix buggy resolver of `model_card` GQL Query.
3 changes: 3 additions & 0 deletions src/ai/backend/manager/api/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,9 @@ type ModelCard implements Node {
id: ID!
name: String
vfolder: VirtualFolder

"""Added in 24.09.0."""
vfolder_node: VirtualFolderNode
author: String

"""Human readable name of the model."""
Expand Down
47 changes: 42 additions & 5 deletions src/ai/backend/manager/models/vfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from sqlalchemy.engine.row import Row
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
from sqlalchemy.ext.asyncio import AsyncSession as SASession
from sqlalchemy.orm import load_only, relationship, selectinload
from sqlalchemy.orm import joinedload, load_only, relationship, selectinload

from ai.backend.common.bgtask import ProgressReporter
from ai.backend.common.config import model_definition_iv
Expand Down Expand Up @@ -1347,6 +1347,31 @@ def _get_field(name: str) -> Any:
cur_size=row["cur_size"],
)

@classmethod
def from_orm_row(cls, row: VFolderRow) -> VirtualFolder:
return cls(
id=row.id,
host=row.host,
quota_scope_id=row.quota_scope_id,
name=row.name,
user=row.user,
user_email=row.user_row.email if row.user_row is not None else None,
group=row.group,
group_name=row.group_row.name if row.group_row is not None else None,
creator=row.creator,
unmanaged_path=row.unmanaged_path,
usage_mode=row.usage_mode,
permission=row.permission,
ownership_type=row.ownership_type,
max_files=row.max_files,
max_size=row.max_size,
created_at=row.created_at,
last_used=row.last_used,
cloneable=row.cloneable,
status=row.status,
cur_size=row.cur_size,
)

async def resolve_num_files(self, info: graphene.ResolveInfo) -> int:
# TODO: measure on-the-fly
return 0
Expand Down Expand Up @@ -1912,6 +1937,10 @@ async def get_connection(
last=last,
)

query = query.options(
joinedload(VFolderRow.user_row),
joinedload(VFolderRow.group_row),
)
async with graph_ctx.db.begin_readonly_session() as db_session:
vfolder_rows = (await db_session.scalars(query)).all()
result = [(cls.from_row(info, vf)) for vf in vfolder_rows]
Expand Down Expand Up @@ -2222,6 +2251,7 @@ class Meta:

name = graphene.String()
vfolder = graphene.Field(VirtualFolder)
vfolder_node = graphene.Field(VirtualFolderNode, description="Added in 24.09.0.")
author = graphene.String()
title = graphene.String(description="Human readable name of the model.")
version = graphene.String()
Expand Down Expand Up @@ -2304,7 +2334,7 @@ def resolve_created_at(
) -> datetime:
try:
return dtparse(self.created_at)
except ParserError:
except (TypeError, ParserError):
return self.created_at

def resolve_modified_at(
Expand All @@ -2313,7 +2343,7 @@ def resolve_modified_at(
) -> datetime:
try:
return dtparse(self.modified_at)
except ParserError:
except (TypeError, ParserError):
return self.modified_at

@classmethod
Expand All @@ -2338,6 +2368,8 @@ def parse_model(
name = vfolder_row.name
return cls(
id=vfolder_row.id,
vfolder=VirtualFolder.from_orm_row(vfolder_row),
vfolder_node=VirtualFolderNode.from_row(resolve_info, vfolder_row),
name=name,
author=metadata.get("author") or vfolder_row.creator or "",
title=metadata.get("title") or vfolder_row.name,
Expand Down Expand Up @@ -2451,7 +2483,9 @@ async def get_node(cls, info: graphene.ResolveInfo, id: str) -> ModelCard:

_, vfolder_row_id = AsyncNode.resolve_global_id(info, id)
async with graph_ctx.db.begin_readonly_session() as db_session:
vfolder_row = await VFolderRow.get(db_session, uuid.UUID(vfolder_row_id))
vfolder_row = await VFolderRow.get(
db_session, uuid.UUID(vfolder_row_id), load_user=True, load_group=True
)
if vfolder_row.usage_mode != VFolderUsageMode.MODEL:
raise ValueError(
f"The vfolder is not model. expect: {VFolderUsageMode.MODEL.value}, got:"
Expand Down Expand Up @@ -2522,7 +2556,10 @@ async def get_connection(
VFolderRow.group.in_(model_store_project_gids)
)
query = query.where(additional_cond)
cnt_query = cnt_query.where(additional_cond)
query = query.options(
joinedload(VFolderRow.user_row),
joinedload(VFolderRow.group_row),
)
async with graph_ctx.db.begin_readonly_session() as db_session:
vfolder_rows = (await db_session.scalars(query)).all()
result = [(await cls.from_row(info, vf)) for vf in vfolder_rows]
Expand Down

0 comments on commit 266f128

Please sign in to comment.