diff --git a/changes/3047.fix.md b/changes/3047.fix.md new file mode 100644 index 00000000000..9839cb249bc --- /dev/null +++ b/changes/3047.fix.md @@ -0,0 +1 @@ +Fix regression of outdated `vfolder` GQL resolver. diff --git a/src/ai/backend/manager/models/base.py b/src/ai/backend/manager/models/base.py index 594bef2680d..9cc41bb94d3 100644 --- a/src/ai/backend/manager/models/base.py +++ b/src/ai/backend/manager/models/base.py @@ -922,6 +922,27 @@ async def batch_result_in_session( return [*objs_per_key.values()] +async def batch_result_in_scalar_stream( + graph_ctx: GraphQueryContext, + db_sess: SASession, + query: sa.sql.Select, + obj_type: type[T_SQLBasedGQLObject], + key_list: Iterable[T_Key], + key_getter: Callable[[Row], T_Key], +) -> Sequence[Optional[T_SQLBasedGQLObject]]: + """ + A batched query adaptor for (key -> item) resolving patterns. + stream the result scalar in async session. + """ + objs_per_key: dict[T_Key, Optional[T_SQLBasedGQLObject]] + objs_per_key = {} + for key in key_list: + objs_per_key[key] = None + async for row in await db_sess.stream_scalars(query): + objs_per_key[key_getter(row)] = obj_type.from_row(graph_ctx, row) + return [*objs_per_key.values()] + + async def batch_multiresult_in_session( graph_ctx: GraphQueryContext, db_sess: SASession, diff --git a/src/ai/backend/manager/models/gql.py b/src/ai/backend/manager/models/gql.py index b2338e23b2e..6caa40676e5 100644 --- a/src/ai/backend/manager/models/gql.py +++ b/src/ai/backend/manager/models/gql.py @@ -1756,16 +1756,16 @@ async def resolve_vfolder( user_id: Optional[uuid.UUID] = None, ) -> Optional[VirtualFolder]: graph_ctx: GraphQueryContext = info.context - user_role = graph_ctx.user["role"] + vfolder_id = uuid.UUID(id) loader = graph_ctx.dataloader_manager.get_loader( graph_ctx, "VirtualFolder.by_id", - user_uuid=user_id, - user_role=user_role, domain_name=domain_name, group_id=group_id, + user_id=user_id, + filter=None, ) - return await loader.load(id) + return await loader.load(vfolder_id) @staticmethod @scoped_query(autofill_user=False, user_key="user_id") diff --git a/src/ai/backend/manager/models/vfolder.py b/src/ai/backend/manager/models/vfolder.py index 3b8022a7e70..dc1dc4cb3f4 100644 --- a/src/ai/backend/manager/models/vfolder.py +++ b/src/ai/backend/manager/models/vfolder.py @@ -17,6 +17,7 @@ List, NamedTuple, Optional, + Self, Sequence, TypeAlias, cast, @@ -35,7 +36,7 @@ from sqlalchemy.engine.row import Row from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection from sqlalchemy.ext.asyncio import AsyncSession as SASession -from sqlalchemy.orm import load_only, relationship, selectinload +from sqlalchemy.orm import joinedload, load_only, relationship, selectinload from ai.backend.common.bgtask import ProgressReporter from ai.backend.common.defs import MODEL_VFOLDER_LENGTH_LIMIT @@ -76,6 +77,7 @@ QuotaScopeIDType, StrEnumType, batch_multiresult, + batch_result_in_scalar_stream, metadata, ) from .group import GroupRow @@ -1406,40 +1408,67 @@ class Meta: status = graphene.String() @classmethod - def from_row(cls, ctx: GraphQueryContext, row: Row | VFolderRow) -> Optional[VirtualFolder]: - if row is None: - return None - - def _get_field(name: str) -> Any: - try: - return row[name] - except sa.exc.NoSuchColumnError: + def from_row(cls, ctx: GraphQueryContext, row: Row | VFolderRow | None) -> Optional[Self]: + match row: + case None: return None - - return cls( - id=row["id"], - host=row["host"], - quota_scope_id=row["quota_scope_id"], - name=row["name"], - user=row["user"], - user_email=_get_field("users_email"), - group=row["group"], - group_name=_get_field("groups_name"), - creator=row["creator"], - domain_name=row["domain_name"], - unmanaged_path=row["unmanaged_path"], - usage_mode=row["usage_mode"], - permission=row["permission"], - ownership_type=row["ownership_type"], - max_files=row["max_files"], - max_size=row["max_size"], # in MiB - created_at=row["created_at"], - last_used=row["last_used"], - # num_attached=row['num_attached'], - cloneable=row["cloneable"], - status=row["status"], - cur_size=row["cur_size"], - ) + case VFolderRow(): + return cls( + id=row.id, + host=row.host, + quota_scope_id=row.quota_scope_id, + name=row.name, + user=row.user, + user_email=row.user_row.email if row.user_row is not None else None, + group=row.group, + group_name=row.group_row.name if row.group_row is not None else None, + creator=row.creator, + domain_name=row.domain_name, + unmanaged_path=row.unmanaged_path, + usage_mode=row.usage_mode, + permission=row.permission, + ownership_type=row.ownership_type, + max_files=row.max_files, + max_size=row.max_size, # in MiB + created_at=row.created_at, + last_used=row.last_used, + cloneable=row.cloneable, + status=row.status, + cur_size=row.cur_size, + ) + case Row(): + + def _get_field(name: str) -> Any: + try: + return row[name] + except (KeyError, sa.exc.NoSuchColumnError): + return None + + return cls( + id=row["id"], + host=row["host"], + quota_scope_id=row["quota_scope_id"], + name=row["name"], + user=row["user"], + user_email=_get_field("users_email"), + group=row["group"], + group_name=_get_field("groups_name"), + creator=row["creator"], + domain_name=row["domain_name"], + unmanaged_path=row["unmanaged_path"], + usage_mode=row["usage_mode"], + permission=row["permission"], + ownership_type=row["ownership_type"], + max_files=row["max_files"], + max_size=row["max_size"], # in MiB + created_at=row["created_at"], + last_used=row["last_used"], + # num_attached=row['num_attached'], + cloneable=row["cloneable"], + status=row["status"], + cur_size=row["cur_size"], + ) + raise ValueError(f"Type not allowed to parse (t:{type(row)})") @classmethod def from_orm_row(cls, row: VFolderRow) -> VirtualFolder: @@ -1608,20 +1637,17 @@ async def load_slice( async def batch_load_by_id( cls, graph_ctx: GraphQueryContext, - ids: list[str], + ids: list[uuid.UUID], *, - domain_name: str | None = None, - group_id: uuid.UUID | None = None, - user_id: uuid.UUID | None = None, - filter: str | None = None, - ) -> Sequence[Sequence[VirtualFolder]]: - from .user import UserRow - - j = sa.join(VFolderRow, UserRow, VFolderRow.user == UserRow.uuid) + domain_name: Optional[str] = None, + group_id: Optional[uuid.UUID] = None, + user_id: Optional[uuid.UUID] = None, + filter: Optional[str] = None, + ) -> Sequence[Optional[VirtualFolder]]: query = ( sa.select(VFolderRow) - .select_from(j) .where(VFolderRow.id.in_(ids)) + .options(joinedload(VFolderRow.user_row), joinedload(VFolderRow.group_row)) .order_by(sa.desc(VFolderRow.created_at)) ) if user_id is not None: @@ -1634,13 +1660,13 @@ async def batch_load_by_id( qfparser = QueryFilterParser(cls._queryfilter_fieldspec) query = qfparser.append_filter(query, filter) async with graph_ctx.db.begin_readonly_session() as db_sess: - return await batch_multiresult( + return await batch_result_in_scalar_stream( graph_ctx, db_sess, query, cls, ids, - lambda row: row["user"], + lambda row: row.id, ) @classmethod diff --git a/tests/manager/conftest.py b/tests/manager/conftest.py index 0d7ed0d7da0..1f9cc21ca6b 100644 --- a/tests/manager/conftest.py +++ b/tests/manager/conftest.py @@ -28,6 +28,7 @@ from unittest.mock import AsyncMock, MagicMock from urllib.parse import quote_plus as urlquote +import aiofiles.os import aiohttp import asyncpg import pytest @@ -419,7 +420,12 @@ async def database_engine(local_config, database): @pytest.fixture() -def database_fixture(local_config, test_db, database) -> Iterator[None]: +def extra_fixtures(): + return {} + + +@pytest.fixture() +def database_fixture(local_config, test_db, database, extra_fixtures) -> Iterator[None]: """ Populate the example data as fixtures to the database and delete them after use. @@ -430,12 +436,20 @@ def database_fixture(local_config, test_db, database) -> Iterator[None]: db_url = f"postgresql+asyncpg://{db_user}:{urlquote(db_pass)}@{db_addr}/{test_db}" build_root = Path(os.environ["BACKEND_BUILD_ROOT"]) + + extra_fixture_file = tempfile.NamedTemporaryFile(delete=False) + extra_fixture_file_path = Path(extra_fixture_file.name) + + with open(extra_fixture_file_path, "w") as f: + json.dump(extra_fixtures, f) + fixture_paths = [ build_root / "fixtures" / "manager" / "example-users.json", build_root / "fixtures" / "manager" / "example-keypairs.json", build_root / "fixtures" / "manager" / "example-set-user-main-access-keys.json", build_root / "fixtures" / "manager" / "example-resource-presets.json", build_root / "fixtures" / "manager" / "example-container-registries-harbor.json", + extra_fixture_file_path, ] async def init_fixture() -> None: @@ -460,6 +474,9 @@ async def init_fixture() -> None: yield async def clean_fixture() -> None: + if extra_fixture_file_path.exists(): + await aiofiles.os.remove(extra_fixture_file_path) + engine: SAEngine = sa.ext.asyncio.create_async_engine( db_url, connect_args=pgsql_connect_opts, diff --git a/tests/manager/models/test_base.py b/tests/manager/models/test_base.py new file mode 100644 index 00000000000..de230e7900a --- /dev/null +++ b/tests/manager/models/test_base.py @@ -0,0 +1,40 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from ai.backend.manager.models.base import batch_result_in_scalar_stream + + +@pytest.mark.asyncio +async def test_batch_result_in_scalar_stream(): + key_list = [1, 2, 3] + + mock_rows = [SimpleNamespace(id=1, data="data1"), SimpleNamespace(id=3, data="data3")] + + async def mock_stream_scalars(query): + for row in mock_rows: + yield row + + mock_db_sess = MagicMock() + mock_db_sess.stream_scalars = AsyncMock(side_effect=mock_stream_scalars) + + def mock_from_row(graph_ctx, row): + return {"id": row.id, "data": row.data} + + mock_obj_type = MagicMock() + mock_obj_type.from_row.side_effect = mock_from_row + + key_getter = lambda row: row.id + graph_ctx = None + result = await batch_result_in_scalar_stream( + graph_ctx, + mock_db_sess, + query=None, # We use mocking instead of using query here + obj_type=mock_obj_type, + key_list=key_list, + key_getter=key_getter, + ) + + expected_result = [{"id": 1, "data": "data1"}, None, {"id": 3, "data": "data3"}] + assert result == expected_result diff --git a/tests/manager/models/test_vfolder.py b/tests/manager/models/test_vfolder.py new file mode 100644 index 00000000000..89e9118eb7a --- /dev/null +++ b/tests/manager/models/test_vfolder.py @@ -0,0 +1,221 @@ +import uuid + +import pytest + +from ai.backend.manager.models.gql import GraphQueryContext +from ai.backend.manager.models.utils import ExtendedAsyncSAEngine +from ai.backend.manager.models.vfolder import VirtualFolder +from ai.backend.manager.server import ( + database_ctx, +) + + +def get_graphquery_context(database_engine: ExtendedAsyncSAEngine) -> GraphQueryContext: + return GraphQueryContext( + schema=None, # type: ignore + dataloader_manager=None, # type: ignore + local_config=None, # type: ignore + shared_config=None, # type: ignore + etcd=None, # type: ignore + user={"domain": "default", "role": "superadmin"}, + access_key="AKIAIOSFODNN7EXAMPLE", + db=database_engine, # type: ignore + redis_stat=None, # type: ignore + redis_image=None, # type: ignore + redis_live=None, # type: ignore + manager_status=None, # type: ignore + known_slot_types=None, # type: ignore + background_task_manager=None, # type: ignore + storage_manager=None, # type: ignore + registry=None, # type: ignore + idle_checker_host=None, # type: ignore + network_plugin_ctx=None, # type: ignore + ) + + +FIXTURES = [ + { + "users": [ + { + "uuid": "00000000-0000-0000-0000-000000000000", + "username": "mock_user", + "email": "", + "password": "", + "need_password_change": False, + "full_name": "", + "description": "", + "status": "active", + "status_info": "admin-requested", + "domain_name": "default", + "resource_policy": "default", + "role": "superadmin", + } + ], + "groups": [ + { + "id": "00000000-0000-0000-0000-000000000000", + "name": "mock_group", + "description": "", + "is_active": True, + "domain_name": "default", + "resource_policy": "default", + "total_resource_slots": {}, + "allowed_vfolder_hosts": {}, + "type": "general", + }, + ], + "vfolders": [ + { + "id": "00000000-0000-0000-0000-000000000001", + "host": "mock", + "domain_name": "default", + "name": "mock_vfolder_1", + "quota_scope_id": "user:00000000-0000-0000-0000-000000000000", + "usage_mode": "general", + "permission": "rw", + "ownership_type": "user", + "status": "ready", + "cloneable": False, + "max_files": 0, + "num_files": 0, + "user": "00000000-0000-0000-0000-000000000000", + "group": None, + }, + { + "id": "00000000-0000-0000-0000-000000000002", + "host": "mock", + "domain_name": "default", + "name": "mock_vfolder_2", + "quota_scope_id": "user:00000000-0000-0000-0000-000000000000", + "usage_mode": "general", + "permission": "rw", + "ownership_type": "user", + "status": "ready", + "cloneable": False, + "max_files": 0, + "num_files": 0, + "user": "00000000-0000-0000-0000-000000000000", + "group": None, + }, + { + "id": "00000000-0000-0000-0000-000000000003", + "host": "mock", + "domain_name": "default", + "name": "mock_vfolder_3", + "quota_scope_id": "project:00000000-0000-0000-0000-000000000000", + "usage_mode": "general", + "permission": "rw", + "ownership_type": "group", + "status": "ready", + "cloneable": False, + "max_files": 0, + "num_files": 0, + "user": None, + "group": "00000000-0000-0000-0000-000000000000", + }, + ], + } +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "extra_fixtures", + FIXTURES, +) +@pytest.mark.parametrize( + "test_case", + [ + { + "vfolder_ids": [uuid.UUID("00000000-0000-0000-0000-000000000001")], + "user_id": None, + "group_id": None, + "domain_name": None, + "expected_result": [uuid.UUID("00000000-0000-0000-0000-000000000001")], + }, + { + "vfolder_ids": [ + uuid.UUID("00000000-0000-0000-0000-000000000001"), + uuid.UUID("00000000-0000-0000-0000-000000000002"), + ], + "user_id": None, + "group_id": None, + "domain_name": None, + "expected_result": [ + uuid.UUID("00000000-0000-0000-0000-000000000001"), + uuid.UUID("00000000-0000-0000-0000-000000000002"), + ], + }, + { + "vfolder_ids": [uuid.UUID("00000000-0000-0000-0000-000000000001")], + "user_id": uuid.UUID("00000000-0000-0000-0000-000000000000"), + "group_id": None, + "domain_name": None, + "expected_result": [uuid.UUID("00000000-0000-0000-0000-000000000001")], + }, + { + "vfolder_ids": [uuid.UUID("00000000-0000-0000-0000-000000000001")], + "user_id": uuid.UUID("00000000-0000-0000-0000-000000000000"), + "group_id": None, + "domain_name": "default", + "expected_result": [uuid.UUID("00000000-0000-0000-0000-000000000001")], + }, + { + "vfolder_ids": [uuid.UUID("00000000-0000-0000-0000-000000000001")], + "user_id": uuid.UUID("00000000-0000-0000-0000-000000000000"), + "group_id": None, + "domain_name": "INVALID", + "expected_result": [None], + }, + { + "vfolder_ids": [uuid.UUID("00000000-0000-0000-0000-000000000003")], + "user_id": None, + "group_id": uuid.UUID("00000000-0000-0000-0000-000000000000"), + "domain_name": None, + "expected_result": [uuid.UUID("00000000-0000-0000-0000-000000000003")], + }, + ], + ids=[ + "Batchload a vfolder by id", + "Batchload multiple vfolders by ids", + "Batchload a vfolder by user_id", + "Batchload a vfolder by user_id and domain_name", + "Batchload a vfolder by user_id and invalid domain_name", + "Batchload a vfolder by group_id", + ], +) +async def test_batch_load_by_id( + test_case, + database_fixture, + create_app_and_client, +): + test_app, _ = await create_app_and_client( + [ + database_ctx, + ], + [], + ) + + root_ctx = test_app["_root.context"] + context = get_graphquery_context(root_ctx.db) + + vfolder_ids = test_case["vfolder_ids"] + user_id = test_case["user_id"] + group_id = test_case["group_id"] + domain_name = test_case["domain_name"] + expected_result = test_case["expected_result"] + + result = await VirtualFolder.batch_load_by_id( + context, + vfolder_ids, + user_id=user_id, + group_id=group_id, + domain_name=domain_name, + ) + + assert len(result) == len(expected_result) + for res, expected_id in zip(result, expected_result): + if expected_id is None: + assert res is None + else: + assert res.id == expected_id