Skip to content

Commit

Permalink
Merge branch 'main' into fix/enhance-container-registry-mutations
Browse files Browse the repository at this point in the history
  • Loading branch information
rapsealk committed Oct 31, 2023
2 parents 2ceacde + b850b20 commit 2884748
Show file tree
Hide file tree
Showing 26 changed files with 264 additions and 254 deletions.
1 change: 1 addition & 0 deletions changes/1632.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Upgrade Graphene and GraphQL core (v2 -> v3) for better support of Relay, security rules, and other improvements
1 change: 1 addition & 0 deletions changes/1664.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a `allow_app_download_panel` config to webserver to show/hide the webui app download panel on the summary page.
1 change: 1 addition & 0 deletions changes/1665.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix symbolic link loop error of vfolder
1 change: 1 addition & 0 deletions changes/1666.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a `allow_custom_resource_allocation` config to webserver to show/hide the custom allocation on the session launcher.
1 change: 1 addition & 0 deletions changes/1668.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update the parameter of session-template update API to follow-up change of session-template create API.
1 change: 1 addition & 0 deletions changes/1674.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use the explicit `graphql.Undefined` value to fill the unspecified fields of GraphQL mutation input objects
1 change: 1 addition & 0 deletions changes/1676.misc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Include `HOME` env-var when running tests via pants
4 changes: 4 additions & 0 deletions configs/webserver/sample.conf
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ mask_user_info = false
# hide_agents = true
# URL to download the webui electron app. If blank, https://github.com/lablup/backend.ai-webui/releases/download will be used.
# app_download_url = ""
# Allow users to see the panel downloading the webui app from the summary page.
# allow_app_download_panel = true
# Enable/disable 2-Factor-Authentication (TOTP).
enable_2FA = false
# Force enable 2-Factor-Authentication (TOTP).
Expand All @@ -61,6 +63,8 @@ force_2FA = false
# system_SSH_image = ""
# If true, display the amount of usage per directory such as folder capacity, and number of files and directories.
directory_based_usage = false
# If true, display the custom allocation on the session launcher.
# allow_custom_resource_allocation = true

[resources]
# Display "Open port to public" checkbox in the app launcher.
Expand Down
2 changes: 1 addition & 1 deletion pants.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ root_patterns = [
]

[test]
extra_env_vars = ["BACKEND_BUILD_ROOT=%(buildroot)s"]
extra_env_vars = ["BACKEND_BUILD_ROOT=%(buildroot)s", "HOME"]

[python]
enable_resolves = true
Expand Down
198 changes: 81 additions & 117 deletions python.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ cryptography>=2.8
dataclasses-json~=0.5.7
etcetra==0.1.17
faker~=13.12.0
graphene~=2.1.9
graphene~=3.3.0
humanize>=3.1.0
ifaddr~=0.2
inquirer~=2.9.2
Expand Down
18 changes: 17 additions & 1 deletion src/ai/backend/client/cli/pretty.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
import textwrap
import traceback
from typing import Sequence

from click import echo, style
from tqdm import tqdm
Expand Down Expand Up @@ -105,6 +106,17 @@ def print_pretty(msg, *, status=PrintStatus.NONE, file=None):
print_warn = functools.partial(print_pretty, status=PrintStatus.WARNING)


def _format_gql_path(items: Sequence[str | int]) -> str:
pieces = []
for item in items:
match item:
case int():
pieces.append(f"[{item}]")
case _:
pieces.append(f".{str(item)}")
return "".join(pieces)[1:] # strip first dot


def format_error(exc: Exception):
if isinstance(exc, BackendAPIError):
yield "{0}: {1} {2}\n".format(exc.__class__.__name__, exc.status, exc.reason)
Expand Down Expand Up @@ -132,7 +144,11 @@ def format_error(exc: Exception):
else:
if exc.data["type"].endswith("/graphql-error"):
yield "\n\u279c Message:\n"
yield from (f"{err_item['message']}\n" for err_item in exc.data.get("data", []))
for err_item in exc.data.get("data", []):
yield f"{err_item['message']}"
if err_path := err_item.get("path"):
yield f" (path: {_format_gql_path(err_path)})"
yield "\n"
else:
other_details = exc.data.get("msg", None)
if other_details:
Expand Down
48 changes: 23 additions & 25 deletions src/ai/backend/manager/api/admin.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from __future__ import annotations

import inspect
import logging
import re
from typing import TYPE_CHECKING, Any, Iterable, Tuple

import aiohttp_cors
import attrs
import graphene
import trafaret as t
from aiohttp import web
from graphql.error import GraphQLError, format_error # pants: no-infer-dep
from graphene.types.inputobjecttype import set_input_object_type_default_value
from graphene.validation import DisableIntrospection
from graphql import Undefined, parse, validate
from graphql.error import GraphQLError # pants: no-infer-dep
from graphql.execution import ExecutionResult # pants: no-infer-dep
from graphql.execution.executors.asyncio import AsyncioExecutor # pants: no-infer-dep

from ai.backend.common import validators as tx
from ai.backend.common.logging import BraceStyleAdapter
Expand All @@ -30,8 +30,6 @@

log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined]

_rx_mutation_hdr = re.compile(r"^mutation(\s+\w+)?\s*(\(|{|@)", re.M)


class GQLLoggingMiddleware:
def resolve(self, next, root, info: graphene.ResolveInfo, **args) -> Any:
Expand All @@ -52,7 +50,14 @@ async def _handle_gql_common(request: web.Request, params: Any) -> ExecutionResu
app_ctx: PrivateContext = request.app["admin.context"]
manager_status = await root_ctx.shared_config.get_manager_status()
known_slot_types = await root_ctx.shared_config.get_resource_slots()

if not root_ctx.shared_config["api"]["allow-graphql-schema-introspection"]:
validate_errors = validate(
schema=app_ctx.gql_schema.graphql_schema,
document_ast=parse(params["query"]),
rules=(DisableIntrospection,),
)
if validate_errors:
return ExecutionResult(None, errors=validate_errors)
gql_ctx = GraphQueryContext(
schema=app_ctx.gql_schema,
dataloader_manager=DataLoaderManager(),
Expand All @@ -72,9 +77,9 @@ async def _handle_gql_common(request: web.Request, params: Any) -> ExecutionResu
registry=root_ctx.registry,
idle_checker_host=root_ctx.idle_checker_host,
)
result = app_ctx.gql_schema.execute(
result = await app_ctx.gql_schema.execute_async(
params["query"],
app_ctx.gql_executor,
None, # root
variable_values=params["variables"],
operation_name=params["operation_name"],
context_value=gql_ctx,
Expand All @@ -83,10 +88,7 @@ async def _handle_gql_common(request: web.Request, params: Any) -> ExecutionResu
GQLMutationUnfrozenRequiredMiddleware(),
GQLMutationPrivilegeCheckMiddleware(),
],
return_promise=True,
)
if inspect.isawaitable(result):
result = await result
return result


Expand All @@ -102,7 +104,7 @@ async def _handle_gql_common(request: web.Request, params: Any) -> ExecutionResu
)
async def handle_gql(request: web.Request, params: Any) -> web.Response:
result = await _handle_gql_common(request, params)
return web.json_response(result.to_dict(), status=200)
return web.json_response(result.formatted, status=200)


@auth_required
Expand All @@ -122,7 +124,7 @@ async def handle_gql_legacy(request: web.Request, params: Any) -> web.Response:
errors = []
for e in result.errors:
if isinstance(e, GraphQLError):
errmsg = format_error(e)
errmsg = e.formatted
errors.append(errmsg)
else:
errmsg = {"message": str(e)}
Expand All @@ -134,18 +136,22 @@ async def handle_gql_legacy(request: web.Request, params: Any) -> web.Response:

@attrs.define(auto_attribs=True, slots=True, init=False)
class PrivateContext:
gql_executor: AsyncioExecutor
gql_schema: graphene.Schema


async def init(app: web.Application) -> None:
app_ctx: PrivateContext = app["admin.context"]
app_ctx.gql_executor = AsyncioExecutor()
app_ctx.gql_schema = graphene.Schema(
query=Queries,
mutation=Mutations,
auto_camelcase=False,
)
root_ctx: RootContext = app["_root.context"]
if root_ctx.shared_config["api"]["allow-graphql-schema-introspection"]:
log.warning(
"GraphQL schema introspection is enabled. "
"It is strongly advised to disable this in production setups."
)


async def shutdown(app: web.Application) -> None:
Expand All @@ -159,16 +165,8 @@ def create_app(
app.on_startup.append(init)
app.on_shutdown.append(shutdown)
app["admin.context"] = PrivateContext()
set_input_object_type_default_value(Undefined)
cors = aiohttp_cors.setup(app, defaults=default_cors_options)
cors.add(app.router.add_route("POST", r"/graphql", handle_gql_legacy))
cors.add(app.router.add_route("POST", r"/gql", handle_gql))
return app, []


if __name__ == "__main__":
# If executed as a main program, print all GraphQL schemas.
# (graphene transforms our object model into a textual representation)
# This is useful for writing documentation!
schema = graphene.Schema(query=Queries, mutation=Mutations, auto_camelcase=False)
print("======== GraphQL API Schema ========")
print(str(schema))
2 changes: 1 addition & 1 deletion src/ai/backend/manager/api/session_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ async def put(request: web.Request, params: Any) -> web.Response:
body = yaml.safe_load(params["payload"])
except (yaml.YAMLError, yaml.MarkedYAMLError):
raise InvalidAPIParameters("Malformed payload")
for st in body["session_templates"]:
for st in body:
template_data = check_task_template(st["template"])
name = st["name"] if "name" in st else template_data["metadata"]["name"]
if "group_id" in st:
Expand Down
6 changes: 6 additions & 0 deletions src/ai/backend/manager/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
- timezone: "UTC" # pytz-compatible timezone names (e.g., "Asia/Seoul")
+ api
- allow-origins: "*"
- allow-graphql-schema-introspection: "yes" | "no" # (default: no)
+ resources
- group_resource_visibility: "true" # return group resource status in check-presets
# (default: false)
Expand Down Expand Up @@ -320,6 +321,7 @@
},
"api": {
"allow-origins": "*",
"allow-graphql-schema-introspection": False,
},
"redis": config.redis_default_config,
"docker": {
Expand Down Expand Up @@ -408,6 +410,10 @@ def container_registry_serialize(v: dict[str, Any]) -> dict[str, str]:
t.Key("api", default=_config_defaults["api"]): t.Dict(
{
t.Key("allow-origins", default=_config_defaults["api"]["allow-origins"]): t.String,
t.Key(
"allow-graphql-schema-introspection",
default=_config_defaults["api"]["allow-graphql-schema-introspection"],
): t.ToBool,
}
).allow_extra("*"),
t.Key("redis", default=_config_defaults["redis"]): config.redis_config_iv,
Expand Down
21 changes: 13 additions & 8 deletions src/ai/backend/manager/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from aiotools import apartial
from graphene.types import Scalar
from graphene.types.scalars import MAX_INT, MIN_INT
from graphql import Undefined
from graphql.language import ast # pants: no-infer-dep
from sqlalchemy.dialects.postgresql import ARRAY, CIDR, ENUM, JSONB, UUID
from sqlalchemy.engine.result import Result
Expand Down Expand Up @@ -69,8 +70,6 @@
from ..api.exceptions import GenericForbidden, InvalidAPIParameters

if TYPE_CHECKING:
from graphql.execution.executors.asyncio import AsyncioExecutor # pants: no-infer-dep

from .gql import GraphQueryContext
from .user import UserRole

Expand Down Expand Up @@ -759,14 +758,17 @@ def privileged_query(required_role: UserRole):
def wrap(func):
@functools.wraps(func)
async def wrapped(
executor: AsyncioExecutor, info: graphene.ResolveInfo, *args, **kwargs
root: Any,
info: graphene.ResolveInfo,
*args,
**kwargs,
) -> Any:
from .user import UserRole

ctx: GraphQueryContext = info.context
if ctx.user["role"] != UserRole.SUPERADMIN:
raise GenericForbidden("superadmin privilege required")
return await func(executor, info, *args, **kwargs)
return await func(root, info, *args, **kwargs)

return wrapped

Expand All @@ -792,7 +794,10 @@ def scoped_query(
def wrap(resolve_func):
@functools.wraps(resolve_func)
async def wrapped(
executor: AsyncioExecutor, info: graphene.ResolveInfo, *args, **kwargs
root: Any,
info: graphene.ResolveInfo,
*args,
**kwargs,
) -> Any:
from .user import UserRole

Expand Down Expand Up @@ -841,7 +846,7 @@ async def wrapped(
if kwargs.get("project", None) is not None:
kwargs["project"] = group_id
kwargs[user_key] = user_id
return await resolve_func(executor, info, *args, **kwargs)
return await resolve_func(root, info, *args, **kwargs)

return wrapped

Expand Down Expand Up @@ -1032,8 +1037,8 @@ def set_if_set(
target_key: Optional[str] = None,
) -> None:
v = getattr(src, name)
# NOTE: unset optional fields are passed as null.
if v is not None:
# NOTE: unset optional fields are passed as graphql.Undefined.
if v is not Undefined:
if callable(clean_func):
target[target_key or name] = clean_func(v)
else:
Expand Down
Loading

0 comments on commit 2884748

Please sign in to comment.