From 5f870430225d2dbca2fc39dd6ea05601c835306f Mon Sep 17 00:00:00 2001 From: Joongi Kim Date: Sun, 12 Nov 2023 23:07:47 -0700 Subject: [PATCH] feat: Per-image rescan with improved multi-arch support (#1712) --- changes/1712.feature.md | 1 + src/ai/backend/manager/cli/context.py | 48 ++++- src/ai/backend/manager/cli/etcd.py | 75 ++----- src/ai/backend/manager/cli/image.py | 8 +- src/ai/backend/manager/cli/image_impl.py | 47 +---- .../manager/container_registry/base.py | 193 +++++++++++------ .../manager/container_registry/docker.py | 13 ++ .../manager/container_registry/harbor.py | 195 ++++++++++++++++-- .../manager/container_registry/local.py | 72 ++++--- src/ai/backend/manager/models/image.py | 51 +++-- 10 files changed, 474 insertions(+), 229 deletions(-) create mode 100644 changes/1712.feature.md diff --git a/changes/1712.feature.md b/changes/1712.feature.md new file mode 100644 index 0000000000..614001654b --- /dev/null +++ b/changes/1712.feature.md @@ -0,0 +1 @@ +Implement per-image metadata sync in the `mgr image rescan` command and deprecate scanning a whole Docker Hub account to avoid the API rate limit diff --git a/src/ai/backend/manager/cli/context.py b/src/ai/backend/manager/cli/context.py index 2c54a674f3..5245e7fdb9 100644 --- a/src/ai/backend/manager/cli/context.py +++ b/src/ai/backend/manager/cli/context.py @@ -12,12 +12,12 @@ from ai.backend.common import redis_helper from ai.backend.common.config import redis_config_iv from ai.backend.common.defs import REDIS_IMAGE_DB, REDIS_LIVE_DB, REDIS_STAT_DB, REDIS_STREAM_DB +from ai.backend.common.etcd import AsyncEtcd, ConfigScopes from ai.backend.common.exception import ConfigurationError from ai.backend.common.logging import AbstractLogger, LocalLogger from ai.backend.common.types import RedisConnectionInfo -from ai.backend.manager.config import SharedConfig -from ..config import LocalConfig +from ..config import LocalConfig, SharedConfig from ..config import load as load_config @@ -61,6 +61,50 @@ def __exit__(self, *exc_info) -> None: self._logger.__exit__() +@contextlib.asynccontextmanager +async def etcd_ctx(cli_ctx: CLIContext) -> AsyncIterator[AsyncEtcd]: + local_config = cli_ctx.local_config + creds = None + if local_config["etcd"]["user"]: + creds = { + "user": local_config["etcd"]["user"], + "password": local_config["etcd"]["password"], + } + scope_prefix_map = { + ConfigScopes.GLOBAL: "", + # TODO: provide a way to specify other scope prefixes + } + etcd = AsyncEtcd( + local_config["etcd"]["addr"], + local_config["etcd"]["namespace"], + scope_prefix_map, + credentials=creds, + ) + try: + yield etcd + finally: + await etcd.close() + + +@contextlib.asynccontextmanager +async def config_ctx(cli_ctx: CLIContext) -> AsyncIterator[SharedConfig]: + local_config = cli_ctx.local_config + # scope_prefix_map is created inside ConfigServer + shared_config = SharedConfig( + local_config["etcd"]["addr"], + local_config["etcd"]["user"], + local_config["etcd"]["password"], + local_config["etcd"]["namespace"], + ) + await shared_config.reload() + raw_redis_config = await shared_config.etcd.get_prefix("config/redis") + local_config["redis"] = redis_config_iv.check(raw_redis_config) + try: + yield shared_config + finally: + await shared_config.close() + + @attrs.define(auto_attribs=True, frozen=True, slots=True) class RedisConnectionSet: live: RedisConnectionInfo diff --git a/src/ai/backend/manager/cli/etcd.py b/src/ai/backend/manager/cli/etcd.py index 897c85c212..36559f5c9e 100644 --- a/src/ai/backend/manager/cli/etcd.py +++ b/src/ai/backend/manager/cli/etcd.py @@ -1,23 +1,21 @@ from __future__ import annotations import asyncio -import contextlib import json import logging import sys -from typing import TYPE_CHECKING, AsyncIterator +from typing import TYPE_CHECKING import click from ai.backend.cli.types import ExitCode from ai.backend.common.cli import EnumChoice, MinMaxRange -from ai.backend.common.config import redis_config_iv -from ai.backend.common.etcd import AsyncEtcd, ConfigScopes +from ai.backend.common.etcd import ConfigScopes from ai.backend.common.etcd import quote as etcd_quote from ai.backend.common.etcd import unquote as etcd_unquote from ai.backend.common.logging import BraceStyleAdapter -from ..config import SharedConfig +from .context import etcd_ctx from .image_impl import alias as alias_impl from .image_impl import dealias as dealias_impl from .image_impl import forget_image as forget_image_impl @@ -37,50 +35,6 @@ def cli() -> None: pass -@contextlib.asynccontextmanager -async def etcd_ctx(cli_ctx: CLIContext) -> AsyncIterator[AsyncEtcd]: - local_config = cli_ctx.local_config - creds = None - if local_config["etcd"]["user"]: - creds = { - "user": local_config["etcd"]["user"], - "password": local_config["etcd"]["password"], - } - scope_prefix_map = { - ConfigScopes.GLOBAL: "", - # TODO: provide a way to specify other scope prefixes - } - etcd = AsyncEtcd( - local_config["etcd"]["addr"], - local_config["etcd"]["namespace"], - scope_prefix_map, - credentials=creds, - ) - try: - yield etcd - finally: - await etcd.close() - - -@contextlib.asynccontextmanager -async def config_ctx(cli_ctx: CLIContext) -> AsyncIterator[SharedConfig]: - local_config = cli_ctx.local_config - # scope_prefix_map is created inside ConfigServer - shared_config = SharedConfig( - local_config["etcd"]["addr"], - local_config["etcd"]["user"], - local_config["etcd"]["password"], - local_config["etcd"]["namespace"], - ) - await shared_config.reload() - raw_redis_config = await shared_config.etcd.get_prefix("config/redis") - local_config["redis"] = redis_config_iv.check(raw_redis_config) - try: - yield shared_config - finally: - await shared_config.close() - - @cli.command() @click.argument("key") @click.argument("value") @@ -300,7 +254,7 @@ def set_image_resource_limit( @cli.command() @click.argument("registry") @click.pass_obj -def rescan_images(cli_ctx: CLIContext, registry) -> None: +def rescan_images(cli_ctx: CLIContext, registry: str) -> None: """ Update the kernel image metadata from all configured docker registries. @@ -315,7 +269,7 @@ def rescan_images(cli_ctx: CLIContext, registry) -> None: @click.argument("target") @click.argument("architecture") @click.pass_obj -def alias(cli_ctx, alias, target, architecture) -> None: +def alias(cli_ctx: CLIContext, alias: str, target: str, architecture: str) -> None: """Add an image alias from the given alias to the target image reference.""" log.warn("etcd alias command is deprecated, use image alias instead") asyncio.run(alias_impl(cli_ctx, alias, target, architecture)) @@ -324,7 +278,7 @@ def alias(cli_ctx, alias, target, architecture) -> None: @cli.command() @click.argument("alias") @click.pass_obj -def dealias(cli_ctx, alias) -> None: +def dealias(cli_ctx: CLIContext, alias: str) -> None: """Remove an alias.""" log.warn("etcd dealias command is deprecated, use image dealias instead") asyncio.run(dealias_impl(cli_ctx, alias)) @@ -333,7 +287,7 @@ def dealias(cli_ctx, alias) -> None: @cli.command() @click.argument("value") @click.pass_obj -def quote(cli_ctx: CLIContext, value) -> None: +def quote(cli_ctx: CLIContext, value: str) -> None: """ Quote the given string for use as a URL piece in etcd keys. Use this to generate argument inputs for aliases and raw image keys. @@ -344,7 +298,7 @@ def quote(cli_ctx: CLIContext, value) -> None: @cli.command() @click.argument("value") @click.pass_obj -def unquote(cli_ctx: CLIContext, value) -> None: +def unquote(cli_ctx: CLIContext, value: str) -> None: """ Unquote the given string used as a URL piece in etcd keys. """ @@ -362,7 +316,12 @@ def unquote(cli_ctx: CLIContext, value) -> None: help="The configuration scope to put the value.", ) @click.pass_obj -def set_storage_sftp_scaling_group(cli_ctx: CLIContext, proxy, scaling_groups, scope) -> None: +def set_storage_sftp_scaling_group( + cli_ctx: CLIContext, + proxy: str, + scaling_groups: str, + scope: ConfigScopes, +) -> None: """ Updates storage proxy node config's SFTP desginated scaling groups. To enter multiple scaling groups concatenate names with comma(,). @@ -392,7 +351,11 @@ async def _impl(): help="The configuration scope to put the value.", ) @click.pass_obj -def remove_storage_sftp_scaling_group(cli_ctx: CLIContext, proxy, scope) -> None: +def remove_storage_sftp_scaling_group( + cli_ctx: CLIContext, + proxy: str, + scope: ConfigScopes, +) -> None: """ Removes storage proxy node config's SFTP desginated scaling groups. """ diff --git a/src/ai/backend/manager/cli/image.py b/src/ai/backend/manager/cli/image.py index 398ab98e1c..f89ee42910 100644 --- a/src/ai/backend/manager/cli/image.py +++ b/src/ai/backend/manager/cli/image.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import logging @@ -75,7 +77,7 @@ def set_resource_limit( @cli.command() -@click.argument("registry", required=False, default="") +@click.argument("registry_or_image", required=False, default="") @click.option( "--local", is_flag=True, @@ -83,13 +85,13 @@ def set_resource_limit( help="Scan the local Docker daemon instead of a registry", ) @click.pass_obj -def rescan(cli_ctx, registry: str, local: bool) -> None: +def rescan(cli_ctx, registry_or_image: str, local: bool) -> None: """ Update the kernel image metadata from all configured docker registries. Pass the name (usually hostname or "lablup") of the Docker registry configured as REGISTRY. """ - asyncio.run(rescan_images_impl(cli_ctx, registry, local)) + asyncio.run(rescan_images_impl(cli_ctx, registry_or_image, local)) @cli.command() diff --git a/src/ai/backend/manager/cli/image_impl.py b/src/ai/backend/manager/cli/image_impl.py index f44af4d957..b6c7c0f64a 100644 --- a/src/ai/backend/manager/cli/image_impl.py +++ b/src/ai/backend/manager/cli/image_impl.py @@ -1,7 +1,7 @@ -import contextlib +from __future__ import annotations + import logging from pprint import pformat, pprint -from typing import AsyncIterator import click import sqlalchemy as sa @@ -10,40 +10,15 @@ from ai.backend.common import redis_helper from ai.backend.common.docker import ImageRef -from ai.backend.common.etcd import AsyncEtcd, ConfigScopes from ai.backend.common.exception import UnknownImageReference from ai.backend.common.logging import BraceStyleAdapter -from ai.backend.manager.cli.context import CLIContext, redis_ctx -from ai.backend.manager.models.image import ImageAliasRow, ImageRow -from ai.backend.manager.models.image import rescan_images as rescan_images_func -from ai.backend.manager.models.utils import connect_database - -log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined] +from ..models.image import ImageAliasRow, ImageRow +from ..models.image import rescan_images as rescan_images_func +from ..models.utils import connect_database +from .context import CLIContext, etcd_ctx, redis_ctx -@contextlib.asynccontextmanager -async def etcd_ctx(cli_ctx: CLIContext) -> AsyncIterator[AsyncEtcd]: - local_config = cli_ctx.local_config - creds = None - if local_config["etcd"]["user"]: - creds = { - "user": local_config["etcd"]["user"], - "password": local_config["etcd"]["password"], - } - scope_prefix_map = { - ConfigScopes.GLOBAL: "", - # TODO: provide a way to specify other scope prefixes - } - etcd = AsyncEtcd( - local_config["etcd"]["addr"], - local_config["etcd"]["namespace"], - scope_prefix_map, - credentials=creds, - ) - try: - yield etcd - finally: - await etcd.close() +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined] async def list_images(cli_ctx, short, installed_only): @@ -175,15 +150,15 @@ async def set_image_resource_limit( log.exception("An error occurred.") -async def rescan_images(cli_ctx: CLIContext, registry: str, local: bool) -> None: - if not registry and not local: - raise click.BadArgumentUsage("Please specify a valid registry name.") +async def rescan_images(cli_ctx: CLIContext, registry_or_image: str, local: bool) -> None: + if not registry_or_image and not local: + raise click.BadArgumentUsage("Please specify a valid registry or full image name.") async with ( connect_database(cli_ctx.local_config) as db, etcd_ctx(cli_ctx) as etcd, ): try: - await rescan_images_func(etcd, db, registry=registry, local=local) + await rescan_images_func(etcd, db, registry_or_image, local=local) except Exception: log.exception("An error occurred.") diff --git a/src/ai/backend/manager/container_registry/base.py b/src/ai/backend/manager/container_registry/base.py index bf878bb09f..ea220c58a7 100644 --- a/src/ai/backend/manager/container_registry/base.py +++ b/src/ai/backend/manager/container_registry/base.py @@ -6,7 +6,7 @@ from abc import ABCMeta, abstractmethod from contextlib import asynccontextmanager as actxmgr from contextvars import ContextVar -from typing import Any, AsyncIterator, Dict, Mapping, Optional, cast +from typing import Any, AsyncIterator, Dict, Final, Mapping, Optional, cast import aiohttp import aiotools @@ -19,10 +19,16 @@ from ai.backend.common.docker import login as registry_login from ai.backend.common.exception import InvalidImageName, InvalidImageTag from ai.backend.common.logging import BraceStyleAdapter -from ai.backend.manager.models.image import ImageRow, ImageType -from ai.backend.manager.models.utils import ExtendedAsyncSAEngine + +from ..models.image import ImageRow, ImageType +from ..models.utils import ExtendedAsyncSAEngine log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined] +concurrency_sema: ContextVar[asyncio.Semaphore] = ContextVar("concurrency_sema") +progress_reporter: ContextVar[Optional[ProgressReporter]] = ContextVar( + "progress_reporter", default=None +) +all_updates: ContextVar[Dict[ImageRef, Dict[str, Any]]] = ContextVar("all_updates") class BaseContainerRegistry(metaclass=ABCMeta): @@ -35,9 +41,13 @@ class BaseContainerRegistry(metaclass=ABCMeta): credentials: Dict[str, str] ssl_verify: bool - sema: ContextVar[asyncio.Semaphore] - reporter: ContextVar[Optional[ProgressReporter]] - all_updates: ContextVar[Dict[ImageRef, Dict[str, Any]]] + content_type_docker_manifest_list: Final[str] = ( + "application/vnd.docker.distribution.manifest.list.v2+json" + ) + content_type_oci_manifest: Final[str] = "application/vnd.oci.image.index.v1+json" + content_type_docker_manifest: Final[str] = ( + "application/vnd.docker.distribution.manifest.v2+json" + ) def __init__( self, @@ -58,10 +68,8 @@ def __init__( } self.credentials = {} self.ssl_verify = ssl_verify - self.sema = ContextVar("sema") - self.reporter = ContextVar("reporter", default=None) - self.all_updates = ContextVar("all_updates") + @actxmgr async def prepare_client_session(self) -> AsyncIterator[tuple[yarl.URL, aiohttp.ClientSession]]: ssl_ctx = None # default if not self.registry_info["ssl_verify"]: @@ -74,26 +82,31 @@ async def rescan_single_registry( self, reporter: ProgressReporter | None = None, ) -> None: - self.all_updates.set({}) - self.sema.set(asyncio.Semaphore(self.max_concurrency_per_registry)) - self.reporter.set(reporter) - username = self.registry_info["username"] - if username is not None: - self.credentials["username"] = username - password = self.registry_info["password"] - if password is not None: - self.credentials["password"] = password - async with actxmgr(self.prepare_client_session)() as (url, client_session): - self.registry_url = url - async with aiotools.TaskGroup() as tg: - async for image in self.fetch_repositories(client_session): - tg.create_task(self._scan_image(client_session, image)) + all_updates_token = all_updates.set({}) + concurrency_sema.set(asyncio.Semaphore(self.max_concurrency_per_registry)) + progress_reporter.set(reporter) + try: + username = self.registry_info["username"] + if username is not None: + self.credentials["username"] = username + password = self.registry_info["password"] + if password is not None: + self.credentials["password"] = password + async with self.prepare_client_session() as (url, client_session): + self.registry_url = url + async with aiotools.TaskGroup() as tg: + async for image in self.fetch_repositories(client_session): + tg.create_task(self._scan_image(client_session, image)) + await self.commit_rescan_result() + finally: + all_updates.reset(all_updates_token) - all_updates = self.all_updates.get() - if not all_updates: + async def commit_rescan_result(self) -> None: + _all_updates = all_updates.get() + if not _all_updates: log.info("No images found in registry {0}", self.registry_url) else: - image_identifiers = [(k.canonical, k.architecture) for k in all_updates.keys()] + image_identifiers = [(k.canonical, k.architecture) for k in _all_updates.keys()] async with self.db.begin_session() as session: existing_images = await session.scalars( sa.select(ImageRow).where( @@ -104,10 +117,10 @@ async def rescan_single_registry( for image_row in existing_images: key = image_row.image_ref - values = all_updates.get(key) + values = _all_updates.get(key) if values is None: continue - all_updates.pop(key) + _all_updates.pop(key) image_row.config_digest = values["config_digest"] image_row.size_bytes = values["size_bytes"] image_row.accelerators = values.get("accels") @@ -131,11 +144,30 @@ async def rescan_single_registry( labels=v["labels"], resources=v["resources"], ) - for k, v in all_updates.items() + for k, v in _all_updates.items() ] ) await session.flush() + async def scan_single_ref(self, image_ref: str) -> None: + all_updates_token = all_updates.set({}) + sema_token = concurrency_sema.set(asyncio.Semaphore(1)) + try: + async with self.prepare_client_session() as (url, sess): + image, tag = ImageRef._parse_image_tag(image_ref) + rqst_args = await registry_login( + sess, + self.registry_url, + self.credentials, + f"repository:{image}:pull", + ) + rqst_args["headers"].update(**self.base_hdrs) + await self._scan_tag(sess, rqst_args, image, tag) + await self.commit_rescan_result() + finally: + concurrency_sema.reset(sema_token) + all_updates.reset(all_updates_token) + async def _scan_image( self, sess: aiohttp.ClientSession, @@ -166,7 +198,7 @@ async def _scan_image( tag_list_url = self.registry_url.with_path(next_page_url.path).with_query( next_page_url.query ) - if (reporter := self.reporter.get()) is not None: + if (reporter := progress_reporter.get()) is not None: reporter.total_progress += len(tags) async with aiotools.TaskGroup() as tg: for tag in tags: @@ -177,20 +209,40 @@ async def _scan_tag( sess: aiohttp.ClientSession, rqst_args: dict[str, Any], image: str, - digest: str, - tag: Optional[str] = None, + tag: str, ) -> None: - async with self.sema.get(): + manifests = {} + async with concurrency_sema.get(): + rqst_args["headers"]["Accept"] = self.content_type_docker_manifest_list async with sess.get( - self.registry_url / f"v2/{image}/manifests/{digest}", **rqst_args + self.registry_url / f"v2/{image}/manifests/{tag}", **rqst_args ) as resp: if resp.status == 404: # ignore missing tags # (may occur after deleting an image from the docker hub) return + content_type = resp.headers["Content-Type"] resp.raise_for_status() - data = await resp.json() - + if content_type != self.content_type_docker_manifest_list: + raise RuntimeError( + "The registry does not support the standard way of " + "listing multiarch images." + ) + resp_json = await resp.json() + manifest_list = resp_json["manifests"] + rqst_args["headers"]["Accept"] = self.content_type_docker_manifest + for manifest in manifest_list: + platform_arg = ( + f"{manifest['platform']['os']}/{manifest['platform']['architecture']}" + ) + if variant := manifest["platform"].get("variant", None): + platform_arg += f"/{variant}" + architecture = manifest["platform"]["architecture"] + architecture = arch_name_aliases.get(architecture, architecture) + async with sess.get( + self.registry_url / f"v2/{image}/manifests/{manifest['digest']}", **rqst_args + ) as resp: + data = await resp.json() config_digest = data["config"]["digest"] size_bytes = sum(layer["size"] for layer in data["layers"]) + data["config"]["size"] async with sess.get( @@ -198,32 +250,35 @@ async def _scan_tag( ) as resp: resp.raise_for_status() data = json.loads(await resp.read()) - architecture = arch_name_aliases.get(data["architecture"], data["architecture"]) - labels = {} - if "container_config" in data: - raw_labels = data["container_config"].get("Labels") - if raw_labels: - labels.update(raw_labels) - else: - log.warn( - "label not found on image {}:{}/{}", image, digest, architecture - ) + labels = {} + if "container_config" in data: + raw_labels = data["container_config"].get("Labels") + if raw_labels: + labels.update(raw_labels) else: - raw_labels = data["config"].get("Labels") - if raw_labels: - labels.update(raw_labels) - else: - log.warn( - "label not found on image {}:{}/{}", image, digest, architecture - ) - manifest = { - architecture: { - "size": size_bytes, - "labels": labels, - "digest": config_digest, - }, - } - await self._read_manifest(image, tag or digest, manifest) + log.warn( + "label not found on image {}:{}/{}", + image, + tag, + architecture, + ) + else: + raw_labels = data["config"].get("Labels") + if raw_labels: + labels.update(raw_labels) + else: + log.warn( + "label not found on image {}:{}/{}", + image, + tag, + architecture, + ) + manifests[architecture] = { + "size": size_bytes, + "labels": labels, + "digest": config_digest, + } + await self._read_manifest(image, tag, manifests) async def _read_manifest( self, @@ -237,7 +292,7 @@ async def _read_manifest( skip_reason = "missing/deleted" log.warning("Skipped image - {}:{} ({})", image, tag, skip_reason) progress_msg = f"Skipped {image}:{tag} ({skip_reason})" - if (reporter := self.reporter.get()) is not None: + if (reporter := progress_reporter.get()) is not None: await reporter.update(1, message=progress_msg) return @@ -292,7 +347,7 @@ async def _read_manifest( res_key = k[len(res_prefix) :] resources[res_key] = {"min": v} updates["resources"] = ImageRow.resources.type._schema.check(resources) - self.all_updates.get().update( + all_updates.get().update( { update_key: updates, } @@ -306,9 +361,15 @@ async def _read_manifest( ) progress_msg = f"Skipped {image}:{tag}/{architecture} ({skip_reason})" else: - log.info("Updated image - {0}:{1}/{2}", image, tag, architecture) - progress_msg = f"Updated {image}:{tag}/{architecture}" - if (reporter := self.reporter.get()) is not None: + log.info( + "Updated image - {0}:{1}/{2} ({3})", + image, + tag, + architecture, + manifest["digest"], + ) + progress_msg = f"Updated {image}:{tag}/{architecture} ({manifest['digest']})" + if (reporter := progress_reporter.get()) is not None: await reporter.update(1, message=progress_msg) @abstractmethod diff --git a/src/ai/backend/manager/container_registry/docker.py b/src/ai/backend/manager/container_registry/docker.py index dc4a775679..ce130b2f5b 100644 --- a/src/ai/backend/manager/container_registry/docker.py +++ b/src/ai/backend/manager/container_registry/docker.py @@ -3,6 +3,7 @@ from typing import AsyncIterator, Optional, cast import aiohttp +import typing_extensions import yarl from ai.backend.common.docker import login as registry_login @@ -14,11 +15,23 @@ class DockerHubRegistry(BaseContainerRegistry): + @typing_extensions.deprecated( + "Rescanning a whole Docker Hub account is disabled due to the API rate limit." + ) async def fetch_repositories( self, sess: aiohttp.ClientSession, ) -> AsyncIterator[str]: # We need some special treatment for the Docker Hub. + raise DeprecationWarning( + "Rescanning a whole Docker Hub account is disabled due to the API rate limit." + ) + yield "" # dead code to ensure the type of method + + async def fetch_repositories_legacy( + self, + sess: aiohttp.ClientSession, + ) -> AsyncIterator[str]: params = {"page_size": "30"} username = self.registry_info["username"] hub_url = yarl.URL("https://hub.docker.com") diff --git a/src/ai/backend/manager/container_registry/harbor.py b/src/ai/backend/manager/container_registry/harbor.py index c380ed0e48..335ff86b42 100644 --- a/src/ai/backend/manager/container_registry/harbor.py +++ b/src/ai/backend/manager/container_registry/harbor.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import json import logging import urllib.parse from typing import Any, AsyncIterator, Mapping, Optional, cast @@ -6,10 +9,15 @@ import aiotools import yarl +from ai.backend.common.docker import arch_name_aliases from ai.backend.common.docker import login as registry_login from ai.backend.common.logging import BraceStyleAdapter -from .base import BaseContainerRegistry +from .base import ( + BaseContainerRegistry, + concurrency_sema, + progress_reporter, +) log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined] @@ -144,18 +152,22 @@ async def _scan_image( return tag = image_info["tags"][0]["name"] match image_info["manifest_media_type"]: - case "application/vnd.oci.image.index.v1+json": + case self.content_type_oci_manifest: await self._process_oci_index( tg, sess, rqst_args, image, image_info ) - case "application/vnd.docker.distribution.manifest.list.v2+json": + case self.content_type_docker_manifest_list: await self._process_docker_v2_multiplatform_image( tg, sess, rqst_args, image, image_info ) - case _: + case self.content_type_docker_manifest: await self._process_docker_v2_image( tg, sess, rqst_args, image, image_info ) + case _ as media_type: + raise RuntimeError( + f"Unsupported artifact media-type: {media_type}" + ) finally: if skip_reason: log.warn("Skipped image - {}:{} ({})", image, tag, skip_reason) @@ -179,21 +191,27 @@ async def _process_oci_index( if not rqst_args.get("headers"): rqst_args["headers"] = {} rqst_args["headers"].update({"Accept": "application/vnd.oci.image.manifest.v1+json"}) - digests: list[str] = [] + digests: list[tuple[str, str]] = [] + tag_name = image_info["tags"][0]["name"] for reference in image_info["references"]: if ( reference["platform"]["os"] == "unknown" or reference["platform"]["architecture"] == "unknown" ): continue - digests.append(reference["child_digest"]) - if (reporter := self.reporter.get()) is not None: + digests.append((reference["child_digest"], reference["platform"]["architecture"])) + if (reporter := progress_reporter.get()) is not None: reporter.total_progress += len(digests) async with aiotools.TaskGroup() as tg: - for digest in digests: + for digest, architecture in digests: tg.create_task( - self._scan_tag( - sess, rqst_args, image, digest, tag=image_info["tags"][0]["name"] + self._harbor_scan_tag_per_arch( + sess, + rqst_args, + image, + digest=digest, + tag=tag_name, + architecture=architecture, ) ) @@ -211,21 +229,27 @@ async def _process_docker_v2_multiplatform_image( rqst_args["headers"].update( {"Accept": "application/vnd.docker.distribution.manifest.v2+json"} ) - digests: list[str] = [] + digests: list[tuple[str, str]] = [] + tag_name = image_info["tags"][0]["name"] for reference in image_info["references"]: if ( reference["platform"]["os"] == "unknown" or reference["platform"]["architecture"] == "unknown" ): continue - digests.append(reference["child_digest"]) - if (reporter := self.reporter.get()) is not None: + digests.append((reference["child_digest"], reference["platform"]["architecture"])) + if (reporter := progress_reporter.get()) is not None: reporter.total_progress += len(digests) async with aiotools.TaskGroup() as tg: - for digest in digests: + for digest, architecture in digests: tg.create_task( - self._scan_tag( - sess, rqst_args, image, digest, tag=image_info["tags"][0]["name"] + self._harbor_scan_tag_per_arch( + sess, + rqst_args, + image, + digest=digest, + tag=tag_name, + architecture=architecture, ) ) @@ -243,7 +267,142 @@ async def _process_docker_v2_image( rqst_args["headers"].update( {"Accept": "application/vnd.docker.distribution.manifest.v2+json"} ) - if (reporter := self.reporter.get()) is not None: + if (reporter := progress_reporter.get()) is not None: reporter.total_progress += 1 + tag_name = image_info["tags"][0]["name"] async with aiotools.TaskGroup() as tg: - tg.create_task(self._scan_tag(sess, rqst_args, image, image_info["tags"][0]["name"])) + tg.create_task( + self._harbor_scan_tag_single_arch( + sess, + rqst_args, + image, + tag=tag_name, + ) + ) + + async def _harbor_scan_tag_per_arch( + self, + sess: aiohttp.ClientSession, + rqst_args: dict[str, Any], + image: str, + *, + digest: str, + tag: str, + architecture: str, + ) -> None: + """ + Scan 'image:tag' when there are explicitly known values of digest and architecture. + """ + manifests = {} + async with concurrency_sema.get(): + async with sess.get( + self.registry_url / f"v2/{image}/manifests/{digest}", **rqst_args + ) as resp: + if resp.status == 404: + # ignore missing tags + # (may occur after deleting an image from the docker hub) + return + resp.raise_for_status() + top_manifest = await resp.json() + architecture = arch_name_aliases.get(architecture, architecture) + config_digest = top_manifest["config"]["digest"] + size_bytes = ( + sum(layer["size"] for layer in top_manifest["layers"]) + + top_manifest["config"]["size"] + ) + async with sess.get( + self.registry_url / f"v2/{image}/blobs/{config_digest}", **rqst_args + ) as resp: + resp.raise_for_status() + data = json.loads(await resp.read()) + labels = {} + if "container_config" in data: + raw_labels = data["container_config"].get("Labels") + if raw_labels: + labels.update(raw_labels) + else: + log.warn( + "label not found on image {}:{}/{}", + image, + tag, + architecture, + ) + else: + raw_labels = data["config"].get("Labels") + if raw_labels: + labels.update(raw_labels) + else: + log.warn( + "label not found on image {}:{}/{}", + image, + tag, + architecture, + ) + manifests[architecture] = { + "size": size_bytes, + "labels": labels, + "digest": config_digest, + } + await self._read_manifest(image, tag, manifests) + + async def _harbor_scan_tag_single_arch( + self, + sess: aiohttp.ClientSession, + rqst_args: dict[str, Any], + image: str, + tag: str, + ) -> None: + """ + Scan 'image:tag' which has been pusehd as a single architecture tag. + In this case, Harbor does not provide explicit methods to determine the architecture. + We infer the architecture from the tag naming patterns ("-arm64" for instance). + """ + manifests = {} + async with concurrency_sema.get(): + rqst_args["headers"]["Accept"] = self.content_type_docker_manifest + # Harbor does not provide architecture information for a single-arch tag reference. + # We heuristically detect the architecture using the tag name pattern. + if tag.endswith("-arm64") or tag.endswith("-aarch64"): + architecture = "aarch64" + else: + architecture = "x86_64" + async with sess.get( + self.registry_url / f"v2/{image}/manifests/{tag}", **rqst_args + ) as resp: + data = await resp.json() + config_digest = data["config"]["digest"] + size_bytes = sum(layer["size"] for layer in data["layers"]) + data["config"]["size"] + async with sess.get( + self.registry_url / f"v2/{image}/blobs/{config_digest}", **rqst_args + ) as resp: + resp.raise_for_status() + data = json.loads(await resp.read()) + labels = {} + if "container_config" in data: + raw_labels = data["container_config"].get("Labels") + if raw_labels: + labels.update(raw_labels) + else: + log.warn( + "label not found on image {}:{}/{}", + image, + tag, + architecture, + ) + else: + raw_labels = data["config"].get("Labels") + if raw_labels: + labels.update(raw_labels) + else: + log.warn( + "label not found on image {}:{}/{}", + image, + tag, + architecture, + ) + manifests[architecture] = { + "size": size_bytes, + "labels": labels, + "digest": config_digest, + } + await self._read_manifest(image, tag, manifests) diff --git a/src/ai/backend/manager/container_registry/local.py b/src/ai/backend/manager/container_registry/local.py index 52fa7fb8b2..333e31dfdc 100644 --- a/src/ai/backend/manager/container_registry/local.py +++ b/src/ai/backend/manager/container_registry/local.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import json import logging +from contextlib import asynccontextmanager as actxmgr from typing import AsyncIterator, Optional import aiohttp @@ -10,12 +13,17 @@ from ai.backend.common.logging import BraceStyleAdapter from ..models.image import ImageRow -from .base import BaseContainerRegistry +from .base import ( + BaseContainerRegistry, + concurrency_sema, + progress_reporter, +) log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined] class LocalRegistry(BaseContainerRegistry): + @actxmgr async def prepare_client_session(self) -> AsyncIterator[tuple[yarl.URL, aiohttp.ClientSession]]: url, connector = get_docker_connector() async with aiohttp.ClientSession(connector=connector) as sess: @@ -27,7 +35,7 @@ async def fetch_repositories( ) -> AsyncIterator[str]: async with sess.get(self.registry_url / "images" / "json") as response: items = await response.json() - if (reporter := self.reporter.get()) is not None: + if (reporter := progress_reporter.get()) is not None: reporter.total_progress = len(items) for item in items: labels = item["Labels"] @@ -46,9 +54,9 @@ async def _scan_image( image: str, ) -> None: repo, _, tag = image.rpartition(":") - await self._scan_tag(sess, {}, repo, tag) + await self._scan_tag_local(sess, {}, repo, tag) - async def _scan_tag( + async def _scan_tag_local( self, sess: aiohttp.ClientSession, rqst_args: dict[str, str], @@ -63,36 +71,34 @@ async def _read_image_info( self.registry_url / "images" / f"{image}:{digest}" / "json" ) as response: data = await response.json() - architecture = data["Architecture"] - summary = { - "Id": data["Id"], - "RepoDigests": data.get("RepoDigests", []), - "Config.Image": data["Config"]["Image"], - "ContainerConfig.Image": data["ContainerConfig"]["Image"], - "Architecture": data["Architecture"], - } - log.debug( - "scanned image info: {}:{}\n{}", image, digest, json.dumps(summary, indent=2) - ) - already_exists = 0 - config_digest = data["Id"] - async with self.db.begin_readonly_session() as db_session: - already_exists = await db_session.scalar( - sa.select([sa.func.count(ImageRow.id)]).where( - ImageRow.config_digest == config_digest, - ImageRow.is_local == sa.false(), - ) + architecture = data["Architecture"] + summary = { + "Id": data["Id"], + "RepoDigests": data.get("RepoDigests", []), + "Config.Image": data["Config"]["Image"], + "ContainerConfig.Image": data["ContainerConfig"]["Image"], + "Architecture": data["Architecture"], + } + log.debug("scanned image info: {}:{}\n{}", image, digest, json.dumps(summary, indent=2)) + already_exists = 0 + config_digest = data["Id"] + async with self.db.begin_readonly_session() as db_session: + already_exists = await db_session.scalar( + sa.select([sa.func.count(ImageRow.id)]).where( + ImageRow.config_digest == config_digest, + ImageRow.is_local == sa.false(), ) - if already_exists > 0: - return {}, "already synchronized from a remote registry" - return { - architecture: { - "size": data["Size"], - "labels": data["Config"]["Labels"], - "digest": config_digest, - }, - }, None + ) + if already_exists > 0: + return {}, "already synchronized from a remote registry" + return { + architecture: { + "size": data["Size"], + "labels": data["Config"]["Labels"], + "digest": config_digest, + }, + }, None - async with self.sema.get(): + async with concurrency_sema.get(): manifests, skip_reason = await _read_image_info(digest) await self._read_manifest(image, digest, manifests, skip_reason) diff --git a/src/ai/backend/manager/models/image.py b/src/ai/backend/manager/models/image.py index ceb8a8f593..de548d5b51 100644 --- a/src/ai/backend/manager/models/image.py +++ b/src/ai/backend/manager/models/image.py @@ -15,6 +15,7 @@ Sequence, Tuple, Union, + cast, ) import aiotools @@ -32,10 +33,10 @@ from ai.backend.common.exception import UnknownImageReference from ai.backend.common.logging import BraceStyleAdapter from ai.backend.common.types import BinarySize, ImageAlias, ResourceSlot -from ai.backend.manager.api.exceptions import ImageNotFound -from ai.backend.manager.container_registry import get_container_registry_cls -from ai.backend.manager.defs import DEFAULT_IMAGE_ARCH +from ..api.exceptions import ImageNotFound +from ..container_registry import get_container_registry_cls +from ..defs import DEFAULT_IMAGE_ARCH from .base import ( Base, BigInt, @@ -53,8 +54,8 @@ if TYPE_CHECKING: from ai.backend.common.bgtask import ProgressReporter - from ai.backend.manager.config import SharedConfig + from ..config import SharedConfig from .gql import GraphQueryContext log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined] @@ -78,13 +79,13 @@ async def rescan_images( etcd: AsyncEtcd, db: ExtendedAsyncSAEngine, - registry: str = None, - local: bool = False, + registry_or_image: str | None = None, *, - reporter: ProgressReporter = None, + local: bool | None = False, + reporter: ProgressReporter | None = None, ) -> None: # cannot import ai.backend.manager.config at start due to circular import - from ai.backend.manager.config import container_registry_iv + from ..config import container_registry_iv if local: registries = { @@ -98,18 +99,38 @@ async def rescan_images( } else: registry_config_iv = t.Mapping(t.String, container_registry_iv) - latest_registry_config = registry_config_iv.check( - await etcd.get_prefix("config/docker/registry"), + latest_registry_config = cast( + dict[str, Any], + registry_config_iv.check( + await etcd.get_prefix("config/docker/registry"), + ), ) # TODO: delete images from registries removed from the previous config? - if registry is None: + if registry_or_image is None: # scan all configured registries registries = latest_registry_config else: - try: - registries = {registry: latest_registry_config[registry]} - except KeyError: - raise RuntimeError("It is an unknown registry.", registry) + # find if it's a full image ref of one of configured registries + for registry_name, registry_info in latest_registry_config.items(): + if registry_or_image.startswith(registry_name + "/"): + repo_with_tag = registry_or_image.removeprefix(registry_name + "/") + log.debug( + "running a per-image metadata scan: {}, {}", + registry_name, + repo_with_tag, + ) + scanner_cls = get_container_registry_cls(registry_info) + scanner = scanner_cls(db, registry_name, registry_info) + await scanner.scan_single_ref(repo_with_tag) + return + else: + # treat it as a normal registry name + registry = registry_or_image + try: + registries = {registry: latest_registry_config[registry]} + log.debug("running a per-registry metadata scan") + except KeyError: + raise RuntimeError("It is an unknown registry.", registry) async with aiotools.TaskGroup() as tg: for registry_name, registry_info in registries.items(): log.info('Scanning kernel images from the registry "{0}"', registry_name)