From df370438aec0c1b428f17e95dd528c77b47582f8 Mon Sep 17 00:00:00 2001 From: Gyubong Lee Date: Tue, 29 Oct 2024 05:12:07 +0000 Subject: [PATCH] refactor: Implement `fetch_repositories()` per project --- .../manager/container_registry/aws_ecr.py | 4 +- .../manager/container_registry/base.py | 5 +- .../manager/container_registry/docker.py | 6 +- .../manager/container_registry/github.py | 6 +- .../manager/container_registry/gitlab.py | 77 ++++++++--------- .../manager/container_registry/harbor.py | 82 +++++++++---------- .../manager/container_registry/local.py | 5 +- 7 files changed, 92 insertions(+), 93 deletions(-) diff --git a/src/ai/backend/manager/container_registry/aws_ecr.py b/src/ai/backend/manager/container_registry/aws_ecr.py index a29d5aefaf1..0f375c20a7b 100644 --- a/src/ai/backend/manager/container_registry/aws_ecr.py +++ b/src/ai/backend/manager/container_registry/aws_ecr.py @@ -1,5 +1,5 @@ import logging -from typing import AsyncIterator +from typing import AsyncIterator, override import aiohttp import boto3 @@ -14,9 +14,11 @@ class AWSElasticContainerRegistry(BaseContainerRegistry): + @override async def fetch_repositories( self, sess: aiohttp.ClientSession, + project: str | None, ) -> AsyncIterator[str]: access_key, secret_access_key, region, type_ = ( self.registry_info.extra.get("access_key"), diff --git a/src/ai/backend/manager/container_registry/base.py b/src/ai/backend/manager/container_registry/base.py index c2a0240ae02..daa49f7b0a6 100644 --- a/src/ai/backend/manager/container_registry/base.py +++ b/src/ai/backend/manager/container_registry/base.py @@ -115,7 +115,9 @@ async def rescan_single_registry( async with self.prepare_client_session() as (url, client_session): self.registry_url = url async with aiotools.TaskGroup() as tg: - async for image in self.fetch_repositories(client_session): + async for image in self.fetch_repositories( + client_session, self.registry_info.project + ): tg.create_task(self._scan_image(client_session, image)) await self.commit_rescan_result() finally: @@ -554,5 +556,6 @@ async def _read_manifest( async def fetch_repositories( self, sess: aiohttp.ClientSession, + project: str | None, ) -> AsyncIterator[str]: yield "" diff --git a/src/ai/backend/manager/container_registry/docker.py b/src/ai/backend/manager/container_registry/docker.py index e64536b6b41..3bd4971cfc1 100644 --- a/src/ai/backend/manager/container_registry/docker.py +++ b/src/ai/backend/manager/container_registry/docker.py @@ -1,6 +1,6 @@ import json import logging -from typing import AsyncIterator, Optional, cast +from typing import AsyncIterator, Optional, cast, override import aiohttp import typing_extensions @@ -18,9 +18,11 @@ class DockerHubRegistry(BaseContainerRegistry): @typing_extensions.deprecated( "Rescanning a whole Docker Hub account is disabled due to the API rate limit." ) + @override async def fetch_repositories( self, sess: aiohttp.ClientSession, + project: str | None, ) -> AsyncIterator[str]: # We need some special treatment for the Docker Hub. raise DeprecationWarning( @@ -63,9 +65,11 @@ async def fetch_repositories_legacy( class DockerRegistry_v2(BaseContainerRegistry): + @override async def fetch_repositories( self, sess: aiohttp.ClientSession, + project: str | None, ) -> AsyncIterator[str]: # The credential should have the catalog search privilege. rqst_args = await registry_login( diff --git a/src/ai/backend/manager/container_registry/github.py b/src/ai/backend/manager/container_registry/github.py index 08b4216972d..d467495884d 100644 --- a/src/ai/backend/manager/container_registry/github.py +++ b/src/ai/backend/manager/container_registry/github.py @@ -1,5 +1,5 @@ import logging -from typing import AsyncIterator +from typing import AsyncIterator, override import aiohttp @@ -13,9 +13,9 @@ class GitHubRegistry(BaseContainerRegistry): + @override async def fetch_repositories( - self, - sess: aiohttp.ClientSession, + self, sess: aiohttp.ClientSession, project: str | None ) -> AsyncIterator[str]: username = self.registry_info.username access_token = self.registry_info.password diff --git a/src/ai/backend/manager/container_registry/gitlab.py b/src/ai/backend/manager/container_registry/gitlab.py index 02875a608b9..38fd236a5b1 100644 --- a/src/ai/backend/manager/container_registry/gitlab.py +++ b/src/ai/backend/manager/container_registry/gitlab.py @@ -1,12 +1,10 @@ import logging import urllib.parse -from typing import AsyncIterator, cast +from typing import AsyncIterator, override import aiohttp -import sqlalchemy as sa from ai.backend.logging import BraceStyleAdapter -from ai.backend.manager.models.container_registry import ContainerRegistryRow from .base import ( BaseContainerRegistry, @@ -16,49 +14,44 @@ class GitLabRegistry(BaseContainerRegistry): - async def fetch_repositories(self, sess: aiohttp.ClientSession) -> AsyncIterator[str]: + @override + async def fetch_repositories( + self, sess: aiohttp.ClientSession, project: str | None + ) -> AsyncIterator[str]: access_token = self.registry_info.password api_endpoint = self.registry_info.extra.get("api_endpoint", None) if api_endpoint is None: raise RuntimeError('"api_endpoint" is not provided for GitLab registry!') - async with self.db.begin_readonly_session() as db_sess: - result = await db_sess.execute( - sa.select(ContainerRegistryRow.project).where( - ContainerRegistryRow.registry_name == self.registry_info.registry_name - ) - ) - projects = cast(list[str], result.scalars().all()) - - for project in projects: - encoded_project_id = urllib.parse.quote(project, safe="") - repo_list_url = ( - f"{api_endpoint}/api/v4/projects/{encoded_project_id}/registry/repositories" - ) - - headers = { - "Accept": "application/json", - "PRIVATE-TOKEN": access_token, - } - page = 1 - - while True: - async with sess.get( - repo_list_url, - headers=headers, - params={"per_page": 30, "page": page}, - ) as response: - if response.status == 200: - data = await response.json() - - for repo in data: - yield repo["path"] - if "next" in response.headers.get("Link", ""): - page += 1 - else: - break + if project is None: + raise RuntimeError("Project should be provided for GitLab registry!") + + encoded_project_id = urllib.parse.quote(project, safe="") + repo_list_url = f"{api_endpoint}/api/v4/projects/{encoded_project_id}/registry/repositories" + + headers = { + "Accept": "application/json", + "PRIVATE-TOKEN": access_token, + } + page = 1 + + while True: + async with sess.get( + repo_list_url, + headers=headers, + params={"per_page": 30, "page": page}, + ) as response: + if response.status == 200: + data = await response.json() + + for repo in data: + yield repo["path"] + if "next" in response.headers.get("Link", ""): + page += 1 else: - raise RuntimeError( - f"Failed to fetch repositories for project {project}! {response.status} error occurred." - ) + break + else: + raise RuntimeError( + f"Failed to fetch repositories for project {project}! {response.status} error occurred." + ) diff --git a/src/ai/backend/manager/container_registry/harbor.py b/src/ai/backend/manager/container_registry/harbor.py index 054a200c20a..6c3444430e2 100644 --- a/src/ai/backend/manager/container_registry/harbor.py +++ b/src/ai/backend/manager/container_registry/harbor.py @@ -3,18 +3,16 @@ import json import logging import urllib.parse -from typing import Any, AsyncIterator, Mapping, Optional, cast +from typing import Any, AsyncIterator, Mapping, Optional, cast, override import aiohttp import aiohttp.client_exceptions import aiotools -import sqlalchemy as sa import yarl from ai.backend.common.docker import ImageRef, arch_name_aliases from ai.backend.common.docker import login as registry_login from ai.backend.logging import BraceStyleAdapter -from ai.backend.manager.models.container_registry import ContainerRegistryRow from .base import ( BaseContainerRegistry, @@ -26,20 +24,14 @@ class HarborRegistry_v1(BaseContainerRegistry): + @override async def fetch_repositories( self, sess: aiohttp.ClientSession, + project: str | None, ) -> AsyncIterator[str]: api_url = self.registry_url / "api" - async with self.db.begin_readonly_session() as db_sess: - result = await db_sess.execute( - sa.select(ContainerRegistryRow.project).where( - ContainerRegistryRow.registry_name == self.registry_info.registry_name - ) - ) - registry_projects = cast(list[str | None], result.scalars().all()) - rqst_args: dict[str, Any] = {} if self.credentials: rqst_args["auth"] = aiohttp.BasicAuth( @@ -55,7 +47,7 @@ async def fetch_repositories( async with sess.get(project_list_url, allow_redirects=False, **rqst_args) as resp: projects = await resp.json() for item in projects: - if item["name"] in registry_projects: + if item["name"] == project: project_ids.append(item["project_id"]) project_list_url = None next_page_link = resp.links.get("next") @@ -86,6 +78,7 @@ async def fetch_repositories( next_page_url.query ) + @override async def _scan_tag( self, sess: aiohttp.ClientSession, @@ -178,53 +171,50 @@ async def untag( ): # 404 means image is already removed from harbor so we can just safely ignore the exception raise RuntimeError(f"Failed to untag {image}: {e.message}") from e + @override async def fetch_repositories( self, sess: aiohttp.ClientSession, + project: str | None, ) -> AsyncIterator[str]: api_url = self.registry_url / "api" / "v2.0" - async with self.db.begin_readonly_session() as db_sess: - result = await db_sess.execute( - sa.select(ContainerRegistryRow.project).where( - ContainerRegistryRow.registry_name == self.registry_info.registry_name - ) - ) - registry_projects = cast(list[str | None], result.scalars().all()) - rqst_args: dict[str, Any] = {} if self.credentials: rqst_args["auth"] = aiohttp.BasicAuth( self.credentials["username"], self.credentials["password"], ) + repo_list_url: Optional[yarl.URL] - for project_name in registry_projects: - assert project_name is not None - repo_list_url = (api_url / "projects" / project_name / "repositories").with_query( - {"page_size": "30"}, - ) - while repo_list_url is not None: - async with sess.get(repo_list_url, allow_redirects=False, **rqst_args) as resp: - items = await resp.json() - if isinstance(items, dict) and (errors := items.get("errors", [])): - raise RuntimeError( - f"failed to fetch repositories in project {project_name}", - errors[0]["code"], - errors[0]["message"], - ) - repos = [item["name"] for item in items] - for item in repos: - yield item - repo_list_url = None - next_page_link = resp.links.get("next") - if next_page_link: - next_page_url = cast(yarl.URL, next_page_link["url"]) - repo_list_url = self.registry_url.with_path(next_page_url.path).with_query( - next_page_url.query - ) + if project is None: + raise RuntimeError("Project should be provided for Harbor registry!") + + repo_list_url = (api_url / "projects" / project / "repositories").with_query( + {"page_size": "30"}, + ) + while repo_list_url is not None: + async with sess.get(repo_list_url, allow_redirects=False, **rqst_args) as resp: + items = await resp.json() + if isinstance(items, dict) and (errors := items.get("errors", [])): + raise RuntimeError( + f"failed to fetch repositories in project {project}", + errors[0]["code"], + errors[0]["message"], + ) + repos = [item["name"] for item in items] + for item in repos: + yield item + repo_list_url = None + next_page_link = resp.links.get("next") + if next_page_link: + next_page_url = cast(yarl.URL, next_page_link["url"]) + repo_list_url = self.registry_url.with_path(next_page_url.path).with_query( + next_page_url.query + ) + @override async def _scan_image( self, sess: aiohttp.ClientSession, @@ -290,6 +280,7 @@ async def _scan_image( next_page_url.query ) + @override async def _scan_tag( self, sess: aiohttp.ClientSession, @@ -330,6 +321,7 @@ async def _scan_tag( case _ as media_type: raise RuntimeError(f"Unsupported artifact media-type: {media_type}") + @override async def _process_oci_index( self, tg: aiotools.TaskGroup, @@ -366,6 +358,7 @@ async def _process_oci_index( ) ) + @override async def _process_docker_v2_multiplatform_image( self, tg: aiotools.TaskGroup, @@ -404,6 +397,7 @@ async def _process_docker_v2_multiplatform_image( ) ) + @override async def _process_docker_v2_image( self, tg: aiotools.TaskGroup, diff --git a/src/ai/backend/manager/container_registry/local.py b/src/ai/backend/manager/container_registry/local.py index 9aab29987b0..6ab3d5e3a2f 100644 --- a/src/ai/backend/manager/container_registry/local.py +++ b/src/ai/backend/manager/container_registry/local.py @@ -3,7 +3,7 @@ import json import logging from contextlib import asynccontextmanager as actxmgr -from typing import AsyncIterator, Optional +from typing import AsyncIterator, Optional, override import aiohttp import sqlalchemy as sa @@ -29,9 +29,11 @@ async def prepare_client_session(self) -> AsyncIterator[tuple[yarl.URL, aiohttp. async with aiohttp.ClientSession(connector=connector.connector) as sess: yield connector.docker_host, sess + @override async def fetch_repositories( self, sess: aiohttp.ClientSession, + project: str | None, ) -> AsyncIterator[str]: async with sess.get(self.registry_url / "images" / "json") as response: items = await response.json() @@ -48,6 +50,7 @@ async def fetch_repositories( continue yield image_ref_str # this includes the tag part + @override async def _scan_image( self, sess: aiohttp.ClientSession,