Skip to content

Commit

Permalink
feat: Add session tag in key-value format
Browse files Browse the repository at this point in the history
  • Loading branch information
Yaminyam committed Jan 19, 2024
1 parent 91ced47 commit 003bb7a
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 5 deletions.
9 changes: 8 additions & 1 deletion src/ai/backend/client/cli/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ai.backend.cli.types import ExitCode
from ai.backend.client.cli.session.execute import prepare_env_arg, prepare_resource_arg
from ai.backend.client.session import Session
from ai.backend.client.utils import validate_key_value
from ai.backend.common.arch import DEFAULT_IMAGE_ARCH

from ..output.fields import routing_fields, service_fields
Expand Down Expand Up @@ -194,7 +195,13 @@ def info(ctx: CLIContext, service_name_or_id: str):
help="A user-defined script to execute on startup.",
)
# extra options
@click.option("--tag", type=str, default=None, help="User-defined tag string to annotate sessions.")
@click.option(
"--tag",
type=str,
callback=validate_key_value,
default=None,
help="User-defined tag string to annotate sessions.",
)
@click.option(
"--arch",
"--architecture",
Expand Down
8 changes: 7 additions & 1 deletion src/ai/backend/client/cli/session/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import click

from ai.backend.client.utils import validate_key_value

START_OPTION = [
click.option(
"-t",
Expand Down Expand Up @@ -58,7 +60,11 @@
),
# extra options
click.option(
"--tag", type=str, default=None, help="User-defined tag string to annotate sessions."
"--tag",
type=str,
callback=validate_key_value,
default=None,
help="User-defined tag string to annotate sessions.",
),
# resource spec
click.option(
Expand Down
7 changes: 6 additions & 1 deletion src/ai/backend/client/cli/session/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ai.backend.cli.main import main
from ai.backend.cli.params import CommaSeparatedListType, OptionalType
from ai.backend.cli.types import ExitCode, Undefined, undefined
from ai.backend.client.utils import validate_key_value
from ai.backend.common.arch import DEFAULT_IMAGE_ARCH

from ...compat import asyncio_run
Expand Down Expand Up @@ -82,7 +83,11 @@ def _create_cmd(docs: str = None):
help="A user-defined script to execute on startup.",
)
@click.option(
"--tag", type=str, default=None, help="User-defined tag string to annotate sessions."
"--tag",
type=str,
callback=validate_key_value,
default=None,
help="User-defined tag string to annotate sessions.",
)
@click.option(
"--arch",
Expand Down
11 changes: 11 additions & 0 deletions src/ai/backend/client/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
import io
import os

import click
from tqdm import tqdm


def validate_key_value(ctx, param, value):
key_value_pairs = value.split(",")
for pair in key_value_pairs:
if "=" not in pair:
raise click.BadParameter(
'Invalid format. Each key-value pair should be in the format "key=value".'
)
return value


class ProgressReportingReader(io.BufferedReader):
def __init__(self, file_path, *, tqdm_instance=None):
super().__init__(open(file_path, "rb"))
Expand Down
20 changes: 18 additions & 2 deletions src/ai/backend/manager/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,23 @@ class SessionRow(Base):

# `image` column is identical to kernels `image` column.
images = sa.Column("images", sa.ARRAY(sa.String), nullable=True)
tag = sa.Column("tag", sa.String(length=64), nullable=True)
tag = sa.Column(
"tag",
pgsql.JSONB(),
sa.CheckConstraint(
"""
(
SELECT jsonb_object_agg(key, value)
FROM jsonb_each_text(tag)
WHERE length(KEY) <= 128
AND length(value) <= 256) is NOT NULL
AND jsonb_array_length(jsonb_object_keys(json_data)) <= 50
)
"""
),
nullable=True,
default={},
)

# Resource occupation
# occupied_slots = sa.Column('occupied_slots', ResourceSlotColumn(), nullable=False)
Expand Down Expand Up @@ -1141,7 +1157,7 @@ class Meta:
# identity
session_id = graphene.UUID() # identical to `id`
main_kernel_id = graphene.UUID()
tag = graphene.String()
tag = graphene.JSONString()
name = graphene.String()
type = graphene.String()
main_kernel_role = graphene.String()
Expand Down

0 comments on commit 003bb7a

Please sign in to comment.