refactor: Implement fetch_repositories() per project
jopemachine authored and kyujin-cho committed Oct 31, 2024
1 parent f36c70a commit df37043
Showing 7 changed files with 92 additions and 93 deletions.
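Before this change, each registry implementation looked up every project registered under the registry (via the container_registries table) inside fetch_repositories() itself; after it, the caller passes the single project configured on the registry row and the implementation scans only that project. A minimal sketch of the reworked contract, for orientation only (the real signatures are in the diffs below):

from typing import AsyncIterator

import aiohttp


class BaseContainerRegistry:  # stand-in shell; the real class does much more
    async def fetch_repositories(
        self,
        sess: aiohttp.ClientSession,
        project: str | None,  # the project configured on this registry row
    ) -> AsyncIterator[str]:
        yield ""  # default stub; each concrete registry overrides this
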
4 changes: 3 additions & 1 deletion src/ai/backend/manager/container_registry/aws_ecr.py
@@ -1,5 +1,5 @@
import logging
from typing import AsyncIterator
from typing import AsyncIterator, override

import aiohttp
import boto3
@@ -14,9 +14,11 @@


class AWSElasticContainerRegistry(BaseContainerRegistry):
@override
async def fetch_repositories(
self,
sess: aiohttp.ClientSession,
project: str | None,
) -> AsyncIterator[str]:
access_key, secret_access_key, region, type_ = (
self.registry_info.extra.get("access_key"),
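Note on the new @override import: typing.override exists only on Python 3.12 and later, while typing_extensions backports it for older interpreters (docker.py below already depends on typing_extensions for its deprecated decorator). A hypothetical compatibility shim, not part of this commit, would be:

try:
    from typing import override  # Python 3.12+
except ImportError:
    from typing_extensions import override  # backport for older interpreters
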
5 changes: 4 additions & 1 deletion src/ai/backend/manager/container_registry/base.py
@@ -115,7 +115,9 @@ async def rescan_single_registry(
async with self.prepare_client_session() as (url, client_session):
self.registry_url = url
async with aiotools.TaskGroup() as tg:
async for image in self.fetch_repositories(client_session):
async for image in self.fetch_repositories(
client_session, self.registry_info.project
):
tg.create_task(self._scan_image(client_session, image))
await self.commit_rescan_result()
finally:
@@ -554,5 +556,6 @@ async def _read_manifest(
async def fetch_repositories(
self,
sess: aiohttp.ClientSession,
project: str | None,
) -> AsyncIterator[str]:
yield ""
6 changes: 5 additions & 1 deletion src/ai/backend/manager/container_registry/docker.py
@@ -1,6 +1,6 @@
import json
import logging
from typing import AsyncIterator, Optional, cast
from typing import AsyncIterator, Optional, cast, override

import aiohttp
import typing_extensions
@@ -18,9 +18,11 @@ class DockerHubRegistry(BaseContainerRegistry):
@typing_extensions.deprecated(
"Rescanning a whole Docker Hub account is disabled due to the API rate limit."
)
@override
async def fetch_repositories(
self,
sess: aiohttp.ClientSession,
project: str | None,
) -> AsyncIterator[str]:
# We need some special treatment for the Docker Hub.
raise DeprecationWarning(
@@ -63,9 +65,11 @@ async def fetch_repositories_legacy(


class DockerRegistry_v2(BaseContainerRegistry):
@override
async def fetch_repositories(
self,
sess: aiohttp.ClientSession,
project: str | None,
) -> AsyncIterator[str]:
# The credential should have the catalog search privilege.
rqst_args = await registry_login(
6 changes: 3 additions & 3 deletions src/ai/backend/manager/container_registry/github.py
@@ -1,5 +1,5 @@
import logging
from typing import AsyncIterator
from typing import AsyncIterator, override

import aiohttp

@@ -13,9 +13,9 @@


class GitHubRegistry(BaseContainerRegistry):
@override
async def fetch_repositories(
self,
sess: aiohttp.ClientSession,
self, sess: aiohttp.ClientSession, project: str | None
) -> AsyncIterator[str]:
username = self.registry_info.username
access_token = self.registry_info.password
77 changes: 35 additions & 42 deletions src/ai/backend/manager/container_registry/gitlab.py
@@ -1,12 +1,10 @@
import logging
import urllib.parse
from typing import AsyncIterator, cast
from typing import AsyncIterator, override

import aiohttp
import sqlalchemy as sa

from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.models.container_registry import ContainerRegistryRow

from .base import (
BaseContainerRegistry,
@@ -16,49 +14,44 @@


class GitLabRegistry(BaseContainerRegistry):
async def fetch_repositories(self, sess: aiohttp.ClientSession) -> AsyncIterator[str]:
@override
async def fetch_repositories(
self, sess: aiohttp.ClientSession, project: str | None
) -> AsyncIterator[str]:
access_token = self.registry_info.password
api_endpoint = self.registry_info.extra.get("api_endpoint", None)

if api_endpoint is None:
raise RuntimeError('"api_endpoint" is not provided for GitLab registry!')

async with self.db.begin_readonly_session() as db_sess:
result = await db_sess.execute(
sa.select(ContainerRegistryRow.project).where(
ContainerRegistryRow.registry_name == self.registry_info.registry_name
)
)
projects = cast(list[str], result.scalars().all())

for project in projects:
encoded_project_id = urllib.parse.quote(project, safe="")
repo_list_url = (
f"{api_endpoint}/api/v4/projects/{encoded_project_id}/registry/repositories"
)

headers = {
"Accept": "application/json",
"PRIVATE-TOKEN": access_token,
}
page = 1

while True:
async with sess.get(
repo_list_url,
headers=headers,
params={"per_page": 30, "page": page},
) as response:
if response.status == 200:
data = await response.json()

for repo in data:
yield repo["path"]
if "next" in response.headers.get("Link", ""):
page += 1
else:
break
if project is None:
raise RuntimeError("Project should be provided for GitLab registry!")

encoded_project_id = urllib.parse.quote(project, safe="")
repo_list_url = f"{api_endpoint}/api/v4/projects/{encoded_project_id}/registry/repositories"

headers = {
"Accept": "application/json",
"PRIVATE-TOKEN": access_token,
}
page = 1

while True:
async with sess.get(
repo_list_url,
headers=headers,
params={"per_page": 30, "page": page},
) as response:
if response.status == 200:
data = await response.json()

for repo in data:
yield repo["path"]
if "next" in response.headers.get("Link", ""):
page += 1
else:
break
else:
raise RuntimeError(
f"Failed to fetch repositories for project {project}! {response.status} error occurred."
)
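For reference, the pagination style used here follows GitLab's page-number API: keep bumping the page parameter while the Link response header still advertises a next page. A self-contained sketch under those assumptions (the endpoint, project, and token values are placeholders):

import asyncio
import urllib.parse

import aiohttp


async def list_gitlab_repositories(
    api_endpoint: str, project: str, access_token: str
) -> list[str]:
    encoded_project_id = urllib.parse.quote(project, safe="")
    url = f"{api_endpoint}/api/v4/projects/{encoded_project_id}/registry/repositories"
    headers = {"Accept": "application/json", "PRIVATE-TOKEN": access_token}
    repos: list[str] = []
    page = 1
    async with aiohttp.ClientSession() as sess:
        while True:
            async with sess.get(
                url, headers=headers, params={"per_page": 30, "page": page}
            ) as resp:
                if resp.status != 200:
                    raise RuntimeError(
                        f"Failed to fetch repositories for project {project}! "
                        f"{resp.status} error occurred."
                    )
                data = await resp.json()
                repos.extend(repo["path"] for repo in data)
                if "next" in resp.headers.get("Link", ""):
                    page += 1
                else:
                    return repos


# Example (placeholder values):
# asyncio.run(list_gitlab_repositories("https://gitlab.example.com", "group/project", "glpat-..."))
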
82 changes: 38 additions & 44 deletions src/ai/backend/manager/container_registry/harbor.py
@@ -3,18 +3,16 @@
import json
import logging
import urllib.parse
from typing import Any, AsyncIterator, Mapping, Optional, cast
from typing import Any, AsyncIterator, Mapping, Optional, cast, override

import aiohttp
import aiohttp.client_exceptions
import aiotools
import sqlalchemy as sa
import yarl

from ai.backend.common.docker import ImageRef, arch_name_aliases
from ai.backend.common.docker import login as registry_login
from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.models.container_registry import ContainerRegistryRow

from .base import (
BaseContainerRegistry,
@@ -26,20 +24,14 @@


class HarborRegistry_v1(BaseContainerRegistry):
@override
async def fetch_repositories(
self,
sess: aiohttp.ClientSession,
project: str | None,
) -> AsyncIterator[str]:
api_url = self.registry_url / "api"

async with self.db.begin_readonly_session() as db_sess:
result = await db_sess.execute(
sa.select(ContainerRegistryRow.project).where(
ContainerRegistryRow.registry_name == self.registry_info.registry_name
)
)
registry_projects = cast(list[str | None], result.scalars().all())

rqst_args: dict[str, Any] = {}
if self.credentials:
rqst_args["auth"] = aiohttp.BasicAuth(
@@ -55,7 +47,7 @@ async def fetch_repositories(
async with sess.get(project_list_url, allow_redirects=False, **rqst_args) as resp:
projects = await resp.json()
for item in projects:
if item["name"] in registry_projects:
if item["name"] == project:
project_ids.append(item["project_id"])
project_list_url = None
next_page_link = resp.links.get("next")
@@ -86,6 +78,7 @@
next_page_url.query
)

@override
async def _scan_tag(
self,
sess: aiohttp.ClientSession,
@@ -178,53 +171,50 @@ async def untag(
): # 404 means image is already removed from harbor so we can just safely ignore the exception
raise RuntimeError(f"Failed to untag {image}: {e.message}") from e

@override
async def fetch_repositories(
self,
sess: aiohttp.ClientSession,
project: str | None,
) -> AsyncIterator[str]:
api_url = self.registry_url / "api" / "v2.0"

async with self.db.begin_readonly_session() as db_sess:
result = await db_sess.execute(
sa.select(ContainerRegistryRow.project).where(
ContainerRegistryRow.registry_name == self.registry_info.registry_name
)
)
registry_projects = cast(list[str | None], result.scalars().all())

rqst_args: dict[str, Any] = {}
if self.credentials:
rqst_args["auth"] = aiohttp.BasicAuth(
self.credentials["username"],
self.credentials["password"],
)

repo_list_url: Optional[yarl.URL]
for project_name in registry_projects:
assert project_name is not None

repo_list_url = (api_url / "projects" / project_name / "repositories").with_query(
{"page_size": "30"},
)
while repo_list_url is not None:
async with sess.get(repo_list_url, allow_redirects=False, **rqst_args) as resp:
items = await resp.json()
if isinstance(items, dict) and (errors := items.get("errors", [])):
raise RuntimeError(
f"failed to fetch repositories in project {project_name}",
errors[0]["code"],
errors[0]["message"],
)
repos = [item["name"] for item in items]
for item in repos:
yield item
repo_list_url = None
next_page_link = resp.links.get("next")
if next_page_link:
next_page_url = cast(yarl.URL, next_page_link["url"])
repo_list_url = self.registry_url.with_path(next_page_url.path).with_query(
next_page_url.query
)
if project is None:
raise RuntimeError("Project should be provided for Harbor registry!")

repo_list_url = (api_url / "projects" / project / "repositories").with_query(
{"page_size": "30"},
)
while repo_list_url is not None:
async with sess.get(repo_list_url, allow_redirects=False, **rqst_args) as resp:
items = await resp.json()
if isinstance(items, dict) and (errors := items.get("errors", [])):
raise RuntimeError(
f"failed to fetch repositories in project {project}",
errors[0]["code"],
errors[0]["message"],
)
repos = [item["name"] for item in items]
for item in repos:
yield item
repo_list_url = None
next_page_link = resp.links.get("next")
if next_page_link:
next_page_url = cast(yarl.URL, next_page_link["url"])
repo_list_url = self.registry_url.with_path(next_page_url.path).with_query(
next_page_url.query
)

@override
async def _scan_image(
self,
sess: aiohttp.ClientSession,
Expand Down Expand Up @@ -290,6 +280,7 @@ async def _scan_image(
next_page_url.query
)

@override
async def _scan_tag(
self,
sess: aiohttp.ClientSession,
Expand Down Expand Up @@ -330,6 +321,7 @@ async def _scan_tag(
case _ as media_type:
raise RuntimeError(f"Unsupported artifact media-type: {media_type}")

@override
async def _process_oci_index(
self,
tg: aiotools.TaskGroup,
Expand Down Expand Up @@ -366,6 +358,7 @@ async def _process_oci_index(
)
)

@override
async def _process_docker_v2_multiplatform_image(
self,
tg: aiotools.TaskGroup,
Expand Down Expand Up @@ -404,6 +397,7 @@ async def _process_docker_v2_multiplatform_image(
)
)

@override
async def _process_docker_v2_image(
self,
tg: aiotools.TaskGroup,
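The Harbor loops above paginate through the standard Link header, which aiohttp parses into resp.links, and then rebuild the next-page URL against the registry origin. A self-contained sketch of just that mechanism (the Harbor URL and project are placeholders; authentication is omitted):

import asyncio
from typing import Optional, cast

import aiohttp
import yarl


async def list_harbor_repositories(registry_url: yarl.URL, project: str) -> list[str]:
    repos: list[str] = []
    page_url: Optional[yarl.URL] = (
        registry_url / "api" / "v2.0" / "projects" / project / "repositories"
    ).with_query({"page_size": "30"})
    async with aiohttp.ClientSession() as sess:
        while page_url is not None:
            async with sess.get(page_url, allow_redirects=False) as resp:
                items = await resp.json()
                repos.extend(item["name"] for item in items)
                page_url = None
                next_link = resp.links.get("next")
                if next_link:
                    next_url = cast(yarl.URL, next_link["url"])
                    # Rebuild against the registry origin, as the code above does.
                    page_url = registry_url.with_path(next_url.path).with_query(
                        next_url.query
                    )
    return repos


# Example (placeholder values):
# asyncio.run(list_harbor_repositories(yarl.URL("https://harbor.example.com"), "library"))
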
5 changes: 4 additions & 1 deletion src/ai/backend/manager/container_registry/local.py
Expand Up @@ -3,7 +3,7 @@
import json
import logging
from contextlib import asynccontextmanager as actxmgr
from typing import AsyncIterator, Optional
from typing import AsyncIterator, Optional, override

import aiohttp
import sqlalchemy as sa
@@ -29,9 +29,11 @@ async def prepare_client_session(self) -> AsyncIterator[tuple[yarl.URL, aiohttp.
async with aiohttp.ClientSession(connector=connector.connector) as sess:
yield connector.docker_host, sess

@override
async def fetch_repositories(
self,
sess: aiohttp.ClientSession,
project: str | None,
) -> AsyncIterator[str]:
async with sess.get(self.registry_url / "images" / "json") as response:
items = await response.json()
@@ -48,6 +50,7 @@ async def fetch_repositories(
continue
yield image_ref_str # this includes the tag part

@override
async def _scan_image(
self,
sess: aiohttp.ClientSession,
