From 067559724e6754c823fd3ffdab80b791daf9224f Mon Sep 17 00:00:00 2001 From: Gyubong Lee Date: Mon, 11 Nov 2024 02:47:05 +0000 Subject: [PATCH] feat: Implement `build_ctx_in_project_scope` using `AssociationContainerRegistriesGroups` --- src/ai/backend/manager/models/image.py | 46 +++++++++++++++++--------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/src/ai/backend/manager/models/image.py b/src/ai/backend/manager/models/image.py index af01897b648..d66fa677d2d 100644 --- a/src/ai/backend/manager/models/image.py +++ b/src/ai/backend/manager/models/image.py @@ -37,6 +37,9 @@ ResourceSlot, ) from ai.backend.logging import BraceStyleAdapter +from ai.backend.manager.models.association_container_registries_groups import ( + AssociationContainerRegistriesGroupsRow, +) from ai.backend.manager.models.container_registry import ContainerRegistryRow from ..api.exceptions import ImageNotFound @@ -767,32 +770,43 @@ async def _in_project_scope( ctx: ClientContext, scope: ProjectScope, ) -> ImagePermissionContext: - from .domain import DomainRow from .group import GroupRow permissions = await self.calculate_permission(ctx, scope) image_id_permission_map: dict[UUID, frozenset[ImagePermission]] = {} - _group_query_stmt = ( - sa.select(GroupRow) - .where(GroupRow.id == scope.project_id) - .options(joinedload(GroupRow.domain)) - ) - group_row = cast(Optional[GroupRow], await self.db_session.scalar(_group_query_stmt)) + group_query_stmt = sa.select(GroupRow).where(GroupRow.id == scope.project_id) + group_row = cast(Optional[GroupRow], await self.db_session.scalar(group_query_stmt)) if group_row is None: raise InvalidScope(f"Project not found (n:{scope.project_id})") - _domain_query_stmt = sa.select(DomainRow).where(DomainRow.name == group_row.domain.name) - domain_row = cast(Optional[DomainRow], await self.db_session.scalar(_domain_query_stmt)) - if domain_row is None: - raise InvalidScope(f"Domain not found (n:{scope.domain_name})") + image_select_stmt = ( + sa.select(ImageRow) + .select_from( + sa.join( + ImageRow, ContainerRegistryRow, ImageRow.registry_id == ContainerRegistryRow.id + ) + ) + .where( + sa.or_( + ContainerRegistryRow.is_global, + sa.and_( + not ContainerRegistryRow.is_global, + sa.exists().where( + (AssociationContainerRegistriesGroupsRow.group_id == scope.project_id) + & ( + AssociationContainerRegistriesGroupsRow.container_registry_id + == ImageRow.registry_id + ) + ), + ), + ) + ) + ) - allowed_registries: set[str] = set(domain_row.allowed_docker_registries) - _img_query_stmt = sa.select(ImageRow).options(load_only(ImageRow.id, ImageRow.registry)) - for row in await self.db_session.scalars(_img_query_stmt): + for row in await self.db_session.scalars(image_select_stmt): _row = cast(ImageRow, row) - if _row.registry in allowed_registries: - image_id_permission_map[_row.id] = permissions + image_id_permission_map[_row.id] = permissions return ImagePermissionContext( object_id_to_additional_permission_map=image_id_permission_map