fix(BA-420): Regression of outdated vfolder GQL resolver (#3047)
Co-authored-by: Sanghun Lee <[email protected]>
Backported-from: main (24.12)
Backported-to: 24.09
Backport-of: 3047
jopemachine and fregataa committed Jan 14, 2025
1 parent 76dc7b7 commit e5972cb
Showing 7 changed files with 377 additions and 51 deletions.
1 change: 1 addition & 0 deletions changes/3047.fix.md
@@ -0,0 +1 @@
+Fix regression of outdated `vfolder` GQL resolver.
21 changes: 21 additions & 0 deletions src/ai/backend/manager/models/base.py
@@ -922,6 +922,27 @@ async def batch_result_in_session(
     return [*objs_per_key.values()]
 
 
+async def batch_result_in_scalar_stream(
+    graph_ctx: GraphQueryContext,
+    db_sess: SASession,
+    query: sa.sql.Select,
+    obj_type: type[T_SQLBasedGQLObject],
+    key_list: Iterable[T_Key],
+    key_getter: Callable[[Row], T_Key],
+) -> Sequence[Optional[T_SQLBasedGQLObject]]:
+    """
+    A batched query adaptor for (key -> item) resolving patterns.
+    Streams the result scalars in an async session.
+    """
+    objs_per_key: dict[T_Key, Optional[T_SQLBasedGQLObject]]
+    objs_per_key = {}
+    for key in key_list:
+        objs_per_key[key] = None
+    async for row in await db_sess.stream_scalars(query):
+        objs_per_key[key_getter(row)] = obj_type.from_row(graph_ctx, row)
+    return [*objs_per_key.values()]
+
+
 async def batch_multiresult_in_session(
     graph_ctx: GraphQueryContext,
     db_sess: SASession,
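For context, here is a minimal sketch of the call pattern this helper serves: a dataloader batch function issues one `SELECT ... WHERE id IN (...)` for all requested keys, and `batch_result_in_scalar_stream` realigns the streamed ORM rows to the caller's key order, leaving `None` for keys with no matching row. The function name `batch_load_example` is hypothetical; the types and session usage follow this diff.

```python
# Illustrative sketch, not part of the commit. `VFolderRow`, `VirtualFolder`,
# and `graph_ctx.db` are used the same way as in vfolder.py below.
import sqlalchemy as sa

from ai.backend.manager.models.base import batch_result_in_scalar_stream
from ai.backend.manager.models.vfolder import VFolderRow, VirtualFolder


async def batch_load_example(graph_ctx, ids):
    # One query for the whole batch of keys.
    query = sa.select(VFolderRow).where(VFolderRow.id.in_(ids))
    async with graph_ctx.db.begin_readonly_session() as db_sess:
        return await batch_result_in_scalar_stream(
            graph_ctx,
            db_sess,
            query,
            VirtualFolder,       # obj_type: GQL type providing .from_row()
            ids,                 # key_list: output preserves this order
            lambda row: row.id,  # key_getter: maps each streamed row to its key
        )
```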
8 changes: 4 additions & 4 deletions src/ai/backend/manager/models/gql.py
@@ -1756,16 +1756,16 @@ async def resolve_vfolder(
         user_id: Optional[uuid.UUID] = None,
     ) -> Optional[VirtualFolder]:
         graph_ctx: GraphQueryContext = info.context
-        user_role = graph_ctx.user["role"]
+        vfolder_id = uuid.UUID(id)
         loader = graph_ctx.dataloader_manager.get_loader(
             graph_ctx,
             "VirtualFolder.by_id",
-            user_uuid=user_id,
-            user_role=user_role,
             domain_name=domain_name,
             group_id=group_id,
+            user_id=user_id,
+            filter=None,
         )
-        return await loader.load(id)
+        return await loader.load(vfolder_id)
 
     @staticmethod
     @scoped_query(autofill_user=False, user_key="user_id")
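The regression: the resolver still passed stale loader arguments (`user_uuid`, `user_role`) that `batch_load_by_id` no longer accepts, and it loaded with the raw string `id` even though the batch loader now keys its results by `uuid.UUID`. A short self-contained illustration of why the unparsed string key alone would break result alignment (the example UUID is made up):

```python
# A plain string never compares equal to a uuid.UUID, so a loader that keys
# results by row.id (a UUID) maps a string key to None.
import uuid

raw_id = "8e4b9073-7a8f-46b3-9a05-c26bbbc9a6f1"
row_id = uuid.UUID(raw_id)

assert raw_id != row_id              # str vs UUID -> lookup misses
assert uuid.UUID(raw_id) == row_id   # parse first, as the fix now does
```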
118 changes: 72 additions & 46 deletions src/ai/backend/manager/models/vfolder.py
@@ -17,6 +17,7 @@
     List,
     NamedTuple,
     Optional,
+    Self,
     Sequence,
     TypeAlias,
     cast,
@@ -35,7 +36,7 @@
 from sqlalchemy.engine.row import Row
 from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
 from sqlalchemy.ext.asyncio import AsyncSession as SASession
-from sqlalchemy.orm import load_only, relationship, selectinload
+from sqlalchemy.orm import joinedload, load_only, relationship, selectinload
 
 from ai.backend.common.bgtask import ProgressReporter
 from ai.backend.common.defs import MODEL_VFOLDER_LENGTH_LIMIT
@@ -76,6 +77,7 @@
     QuotaScopeIDType,
     StrEnumType,
     batch_multiresult,
+    batch_result_in_scalar_stream,
     metadata,
 )
 from .group import GroupRow
@@ -1406,40 +1408,67 @@ class Meta:
     status = graphene.String()
 
     @classmethod
-    def from_row(cls, ctx: GraphQueryContext, row: Row | VFolderRow) -> Optional[VirtualFolder]:
-        if row is None:
-            return None
+    def from_row(cls, ctx: GraphQueryContext, row: Row | VFolderRow | None) -> Optional[Self]:
+        match row:
+            case None:
+                return None
+            case VFolderRow():
+                return cls(
+                    id=row.id,
+                    host=row.host,
+                    quota_scope_id=row.quota_scope_id,
+                    name=row.name,
+                    user=row.user,
+                    user_email=row.user_row.email if row.user_row is not None else None,
+                    group=row.group,
+                    group_name=row.group_row.name if row.group_row is not None else None,
+                    creator=row.creator,
+                    domain_name=row.domain_name,
+                    unmanaged_path=row.unmanaged_path,
+                    usage_mode=row.usage_mode,
+                    permission=row.permission,
+                    ownership_type=row.ownership_type,
+                    max_files=row.max_files,
+                    max_size=row.max_size,  # in MiB
+                    created_at=row.created_at,
+                    last_used=row.last_used,
+                    cloneable=row.cloneable,
+                    status=row.status,
+                    cur_size=row.cur_size,
+                )
+            case Row():
 
-        def _get_field(name: str) -> Any:
-            try:
-                return row[name]
-            except sa.exc.NoSuchColumnError:
-                return None
-
-        return cls(
-            id=row["id"],
-            host=row["host"],
-            quota_scope_id=row["quota_scope_id"],
-            name=row["name"],
-            user=row["user"],
-            user_email=_get_field("users_email"),
-            group=row["group"],
-            group_name=_get_field("groups_name"),
-            creator=row["creator"],
-            domain_name=row["domain_name"],
-            unmanaged_path=row["unmanaged_path"],
-            usage_mode=row["usage_mode"],
-            permission=row["permission"],
-            ownership_type=row["ownership_type"],
-            max_files=row["max_files"],
-            max_size=row["max_size"],  # in MiB
-            created_at=row["created_at"],
-            last_used=row["last_used"],
-            # num_attached=row['num_attached'],
-            cloneable=row["cloneable"],
-            status=row["status"],
-            cur_size=row["cur_size"],
-        )
+                def _get_field(name: str) -> Any:
+                    try:
+                        return row[name]
+                    except (KeyError, sa.exc.NoSuchColumnError):
+                        return None
+
+                return cls(
+                    id=row["id"],
+                    host=row["host"],
+                    quota_scope_id=row["quota_scope_id"],
+                    name=row["name"],
+                    user=row["user"],
+                    user_email=_get_field("users_email"),
+                    group=row["group"],
+                    group_name=_get_field("groups_name"),
+                    creator=row["creator"],
+                    domain_name=row["domain_name"],
+                    unmanaged_path=row["unmanaged_path"],
+                    usage_mode=row["usage_mode"],
+                    permission=row["permission"],
+                    ownership_type=row["ownership_type"],
+                    max_files=row["max_files"],
+                    max_size=row["max_size"],  # in MiB
+                    created_at=row["created_at"],
+                    last_used=row["last_used"],
+                    # num_attached=row['num_attached'],
+                    cloneable=row["cloneable"],
+                    status=row["status"],
+                    cur_size=row["cur_size"],
+                )
+        raise ValueError(f"Type not allowed to parse (t:{type(row)})")
 
     @classmethod
     def from_orm_row(cls, row: VFolderRow) -> VirtualFolder:
@@ -1608,20 +1637,17 @@ async def load_slice(
     async def batch_load_by_id(
         cls,
         graph_ctx: GraphQueryContext,
-        ids: list[str],
+        ids: list[uuid.UUID],
         *,
-        domain_name: str | None = None,
-        group_id: uuid.UUID | None = None,
-        user_id: uuid.UUID | None = None,
-        filter: str | None = None,
-    ) -> Sequence[Sequence[VirtualFolder]]:
-        from .user import UserRow
-
-        j = sa.join(VFolderRow, UserRow, VFolderRow.user == UserRow.uuid)
+        domain_name: Optional[str] = None,
+        group_id: Optional[uuid.UUID] = None,
+        user_id: Optional[uuid.UUID] = None,
+        filter: Optional[str] = None,
+    ) -> Sequence[Optional[VirtualFolder]]:
         query = (
             sa.select(VFolderRow)
-            .select_from(j)
             .where(VFolderRow.id.in_(ids))
+            .options(joinedload(VFolderRow.user_row), joinedload(VFolderRow.group_row))
             .order_by(sa.desc(VFolderRow.created_at))
         )
         if user_id is not None:
@@ -1634,13 +1660,13 @@ async def batch_load_by_id(
             qfparser = QueryFilterParser(cls._queryfilter_fieldspec)
             query = qfparser.append_filter(query, filter)
         async with graph_ctx.db.begin_readonly_session() as db_sess:
-            return await batch_multiresult(
+            return await batch_result_in_scalar_stream(
                 graph_ctx,
                 db_sess,
                 query,
                 cls,
                 ids,
-                lambda row: row["user"],
+                lambda row: row.id,
             )
 
     @classmethod
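The rewritten `from_row` dispatches on the runtime type of `row` with structural pattern matching: attribute access for ORM `VFolderRow` instances (whose `user_row`/`group_row` are now eagerly populated via `joinedload`), mapping access through a tolerant `_get_field` for core `Row` results, and an explicit error for anything else. A stripped-down, self-contained sketch of the same dispatch shape, using stand-in classes rather than the real models:

```python
# Illustrative only: ORMRow stands in for VFolderRow, and a plain dict
# stands in for sqlalchemy's mapping-style Row access.
from dataclasses import dataclass


@dataclass
class ORMRow:
    id: int
    name: str


def parse(row):
    match row:
        case None:
            return None
        case ORMRow():
            return {"id": row.id, "name": row.name}      # attribute access
        case dict():
            return {"id": row["id"], "name": row["name"]}  # mapping access
    raise ValueError(f"Type not allowed to parse (t:{type(row)})")


assert parse(None) is None
assert parse(ORMRow(1, "a")) == {"id": 1, "name": "a"}
assert parse({"id": 2, "name": "b"}) == {"id": 2, "name": "b"}
```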
19 changes: 18 additions & 1 deletion tests/manager/conftest.py
@@ -28,6 +28,7 @@
 from unittest.mock import AsyncMock, MagicMock
 from urllib.parse import quote_plus as urlquote
 
+import aiofiles.os
 import aiohttp
 import asyncpg
 import pytest
@@ -419,7 +420,12 @@ async def database_engine(local_config, database):
 
 
 @pytest.fixture()
-def database_fixture(local_config, test_db, database) -> Iterator[None]:
+def extra_fixtures():
+    return {}
+
+
+@pytest.fixture()
+def database_fixture(local_config, test_db, database, extra_fixtures) -> Iterator[None]:
     """
     Populate the example data as fixtures to the database
     and delete them after use.
@@ -430,12 +436,20 @@ def database_fixture(local_config, test_db, database, extra_fixtures) -> Iterator[None]:
     db_url = f"postgresql+asyncpg://{db_user}:{urlquote(db_pass)}@{db_addr}/{test_db}"
 
     build_root = Path(os.environ["BACKEND_BUILD_ROOT"])
+
+    extra_fixture_file = tempfile.NamedTemporaryFile(delete=False)
+    extra_fixture_file_path = Path(extra_fixture_file.name)
+
+    with open(extra_fixture_file_path, "w") as f:
+        json.dump(extra_fixtures, f)
+
     fixture_paths = [
         build_root / "fixtures" / "manager" / "example-users.json",
         build_root / "fixtures" / "manager" / "example-keypairs.json",
         build_root / "fixtures" / "manager" / "example-set-user-main-access-keys.json",
         build_root / "fixtures" / "manager" / "example-resource-presets.json",
         build_root / "fixtures" / "manager" / "example-container-registries-harbor.json",
+        extra_fixture_file_path,
     ]
 
     async def init_fixture() -> None:
@@ -460,6 +474,9 @@ async def init_fixture() -> None:
     yield
 
     async def clean_fixture() -> None:
+        if extra_fixture_file_path.exists():
+            await aiofiles.os.remove(extra_fixture_file_path)
+
         engine: SAEngine = sa.ext.asyncio.create_async_engine(
             db_url,
             connect_args=pgsql_connect_opts,
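The new `extra_fixtures` fixture defaults to an empty dict; `database_fixture` serializes it into a temporary JSON file, appends that file to `fixture_paths`, and deletes it again in `clean_fixture`. A test module can inject its own rows by overriding the fixture locally, as in this sketch (the table name and row contents are hypothetical):

```python
import pytest


@pytest.fixture()
def extra_fixtures():
    # Overrides the empty default from conftest.py for this module only.
    return {
        "vfolders": [
            {"id": "7f1c2d34-5e6f-4a8b-9c0d-1e2f3a4b5c6d", "name": "test-vfolder"},
        ],
    }


def test_with_extra_rows(database_fixture):
    # database_fixture now populates (and later cleans up) the rows above.
    ...
```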
40 changes: 40 additions & 0 deletions tests/manager/models/test_base.py
@@ -0,0 +1,40 @@
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from ai.backend.manager.models.base import batch_result_in_scalar_stream
+
+
+@pytest.mark.asyncio
+async def test_batch_result_in_scalar_stream():
+    key_list = [1, 2, 3]
+
+    mock_rows = [SimpleNamespace(id=1, data="data1"), SimpleNamespace(id=3, data="data3")]
+
+    async def mock_stream_scalars(query):
+        for row in mock_rows:
+            yield row
+
+    mock_db_sess = MagicMock()
+    mock_db_sess.stream_scalars = AsyncMock(side_effect=mock_stream_scalars)
+
+    def mock_from_row(graph_ctx, row):
+        return {"id": row.id, "data": row.data}
+
+    mock_obj_type = MagicMock()
+    mock_obj_type.from_row.side_effect = mock_from_row
+
+    key_getter = lambda row: row.id
+    graph_ctx = None
+    result = await batch_result_in_scalar_stream(
+        graph_ctx,
+        mock_db_sess,
+        query=None,  # We use mocking instead of a real query here
+        obj_type=mock_obj_type,
+        key_list=key_list,
+        key_getter=key_getter,
+    )
+
+    expected_result = [{"id": 1, "data": "data1"}, None, {"id": 3, "data": "data3"}]
+    assert result == expected_result
(The diff of the seventh changed file — the remaining 221 additions — did not load in this view.)
