Skip to content

Commit

Permalink
refactor: Revamp ContainerRegistryNode API
Browse files Browse the repository at this point in the history
  • Loading branch information
jopemachine committed Jan 10, 2025
1 parent 978b571 commit 265cecf
Show file tree
Hide file tree
Showing 6 changed files with 262 additions and 326 deletions.
60 changes: 13 additions & 47 deletions src/ai/backend/client/func/container_registry.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,33 @@
from __future__ import annotations

import textwrap
from ai.backend.client.request import Request

from ..session import api_session
from .base import BaseFunction, api_function

__all__ = ("ContainerRegistry",)


class ContainerRegistry(BaseFunction):
"""
Provides a shortcut of :func:`Admin.query()
<ai.backend.client.admin.Admin.query>` that fetches, modifies various container registry
information.
.. note::
All methods in this function class require your API access key to
have the *admin* privilege.
Provides functions to manage container registries.
"""

@api_function
@classmethod
async def associate_group(cls, registry_id: str, group_id: str) -> dict:
# TODO: Implement params type
async def patch_container_registry(cls, registry_id: str, params) -> None:
"""
Associate container_registry with group.
Updates the container registry information, and return the container registry.
:param registry_id: ID of the container registry.
:param group_id: ID of the group.
"""
query = textwrap.dedent(
"""\
mutation($registry_id: String!, $group_id: String!) {
associate_container_registry_with_group(
registry_id: $registry_id, group_id: $group_id) {
ok msg
}
}
"""
)
variables = {"registry_id": registry_id, "group_id": group_id}
data = await api_session.get().Admin._query(query, variables)
return data["associate_container_registry_with_group"]

@api_function
@classmethod
async def disassociate_group(cls, registry_id: str, group_id: str) -> dict:
:param params: Parameters to update the container registry.
"""
Disassociate container_registry with group.

:param registry_id: ID of the container registry.
:param group_id: ID of the group.
"""
query = textwrap.dedent(
"""\
mutation($registry_id: String!, $group_id: String!) {
disassociate_container_registry_with_group(
registry_id: $registry_id, group_id: $group_id) {
ok msg
}
}
"""
request = Request(
"PATCH",
f"/container-registries/{registry_id}",
)
variables = {"registry_id": registry_id, "group_id": group_id}
data = await api_session.get().Admin._query(query, variables)
return data["disassociate_container_registry_with_group"]
request.set_json(params)

async with request.fetch() as resp:
await resp.read()
129 changes: 71 additions & 58 deletions src/ai/backend/manager/api/container_registry.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Iterable, Tuple
import uuid
from typing import TYPE_CHECKING, Iterable, Optional, Tuple

import aiohttp_cors
import sqlalchemy as sa
from aiohttp import web
from pydantic import AliasChoices, BaseModel, Field
from pydantic import BaseModel
from sqlalchemy.exc import IntegrityError

from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.models.association_container_registries_groups import (
AssociationContainerRegistriesGroupsRow,
)
from ai.backend.manager.models.container_registry import ContainerRegistryRow
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine

from .exceptions import ContainerRegistryNotFound, GenericBadRequest
from .exceptions import GenericBadRequest, InternalServerError

if TYPE_CHECKING:
from .context import RootContext
Expand All @@ -27,75 +30,86 @@
log = BraceStyleAdapter(logging.getLogger(__spec__.name))


class AssociationRequestModel(BaseModel):
registry_id: str = Field(
validation_alias=AliasChoices("registry_id", "registry"),
description="Container registry row's ID",
)
group_id: str = Field(
validation_alias=AliasChoices("group_id", "group"),
description="Group row's ID",
)
class AllowedGroups(BaseModel):
add: list[str] = []
remove: list[str] = []


@server_status_required(READ_ALLOWED)
@superadmin_required
@pydantic_params_api_handler(AssociationRequestModel)
async def associate_with_group(
request: web.Request, params: AssociationRequestModel
) -> web.Response:
log.info("ASSOCIATE_WITH_GROUP (cr:{}, gr:{})", params.registry_id, params.group_id)
root_ctx: RootContext = request.app["_root.context"]
registry_id = params.registry_id
group_id = params.group_id
class PatchContainerRegistryRequestModel(BaseModel):
url: Optional[str] = None
type: Optional[str] = None
registry_name: Optional[str] = None
is_global: Optional[bool] = None
project: Optional[str] = None
username: Optional[str] = None
password: Optional[str] = None
ssl_verify: Optional[bool] = None
extra: Optional[str] = None
allowed_groups: Optional[AllowedGroups] = None

async with root_ctx.db.begin_session() as db_sess:
insert_query = sa.insert(AssociationContainerRegistriesGroupsRow).values({
"registry_id": registry_id,
"group_id": group_id,
})

try:
await db_sess.execute(insert_query)
except IntegrityError:
raise GenericBadRequest("Association already exists.")
# TODO: Add this. ContainerRegistryRow is not compatible with BaseModel
# class PatchContainerRegistryResponseModel(BaseModel):
# container_registry: ContainerRegistryRow

return web.Response(status=204)

async def handle_allowed_groups_update(
db: ExtendedAsyncSAEngine, registry_id: uuid.UUID, allowed_group_updates: AllowedGroups
):
async with db.begin_session() as db_sess:
if allowed_group_updates.add:
insert_values = [
{"registry_id": registry_id, "group_id": group_id}
for group_id in allowed_group_updates.add
]

insert_query = sa.insert(AssociationContainerRegistriesGroupsRow).values(insert_values)
await db_sess.execute(insert_query)

class DisassociationRequestModel(BaseModel):
registry_id: str = Field(
validation_alias=AliasChoices("registry_id", "registry"),
description="Container registry row's ID",
)
group_id: str = Field(
validation_alias=AliasChoices("group_id", "group"),
description="Group row's ID",
)
if allowed_group_updates.remove:
delete_query = (
sa.delete(AssociationContainerRegistriesGroupsRow)
.where(AssociationContainerRegistriesGroupsRow.registry_id == registry_id)
.where(
AssociationContainerRegistriesGroupsRow.group_id.in_(
allowed_group_updates.remove
)
)
)
await db_sess.execute(delete_query)


@server_status_required(READ_ALLOWED)
@superadmin_required
@pydantic_params_api_handler(DisassociationRequestModel)
async def disassociate_with_group(
request: web.Request, params: DisassociationRequestModel
@pydantic_params_api_handler(PatchContainerRegistryRequestModel)
async def patch_container_registry(
request: web.Request, params: PatchContainerRegistryRequestModel
) -> web.Response:
log.info("DISASSOCIATE_WITH_GROUP (cr:{}, gr:{})", params.registry_id, params.group_id)
registry_id = uuid.UUID(request.match_info["registry_id"])
log.info("PATCH_CONTAINER_REGISTRY (cr:{})", registry_id)
root_ctx: RootContext = request.app["_root.context"]
registry_id = params.registry_id
group_id = params.group_id

async with root_ctx.db.begin_session() as db_sess:
delete_query = (
sa.delete(AssociationContainerRegistriesGroupsRow)
.where(AssociationContainerRegistriesGroupsRow.registry_id == registry_id)
.where(AssociationContainerRegistriesGroupsRow.group_id == group_id)
input_config = params.model_dump(exclude={"allowed_groups"}, exclude_none=True)

async with root_ctx.db.begin_session() as db_session:
update_stmt = (
sa.update(ContainerRegistryRow)
.where(ContainerRegistryRow.id == registry_id)
.values(input_config)
)
await db_session.execute(update_stmt)

# select_stmt = sa.select(ContainerRegistryRow).where(ContainerRegistryRow.id == registry_id)
# updated_container_registry = await db_session.execute(select_stmt)

result = await db_sess.execute(delete_query)
if result.rowcount == 0:
raise ContainerRegistryNotFound()
try:
if params.allowed_groups:
await handle_allowed_groups_update(root_ctx.db, registry_id, params.allowed_groups)
except IntegrityError as e:
raise GenericBadRequest(f"Failed to update allowed groups! Details: {str(e)}")
except Exception as e:
raise InternalServerError(f"Failed to update allowed groups! Details: {str(e)}")

# return PatchContainerRegistryResponseModel(container_registry=updated_container_registry)
return web.Response(status=204)


Expand All @@ -106,6 +120,5 @@ def create_app(
app["api_versions"] = (1, 2, 3, 4, 5)
app["prefix"] = "container-registries"
cors = aiohttp_cors.setup(app, defaults=default_cors_options)
cors.add(app.router.add_route("POST", "/associate-with-group", associate_with_group))
cors.add(app.router.add_route("POST", "/disassociate-with-group", disassociate_with_group))
cors.add(app.router.add_route("PATCH", "/{registry_id}", patch_container_registry))
return app, []
Loading

0 comments on commit 265cecf

Please sign in to comment.