Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Improve docker context support #1724

Merged
merged 13 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/1724.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix the installer to use the refactored `common.docker.get_docker_connector()` for system docker detection, which now also detects the active docker context if one is configured.
4 changes: 2 additions & 2 deletions src/ai/backend/agent/vendor/linux.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def get_available_cores() -> set[int]:

async def read_cgroup_cpuset() -> tuple[set[int], str] | None:
try:
docker_host, connector = get_docker_connector()
_, docker_host, connector = get_docker_connector()
async with aiohttp.ClientSession(connector=connector) as sess:
async with sess.get(docker_host / "info") as resp:
data = await resp.json()
Expand Down Expand Up @@ -151,7 +151,7 @@ async def read_os_cpus() -> tuple[set[int], str] | None:
case "darwin" | "win32":
try:
cpuset_source = "the cpus accessible by the docker service"
docker_host, connector = get_docker_connector()
_, docker_host, connector = get_docker_connector()
async with aiohttp.ClientSession(connector=connector) as sess:
async with sess.get(docker_host / "info") as resp:
data = await resp.json()
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/common/cgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class CgroupVersion:


async def get_docker_cgroup_version() -> CgroupVersion:
docker_host, connector = get_docker_connector()
_, docker_host, connector = get_docker_connector()
async with aiohttp.ClientSession(connector=connector) as sess:
async with sess.get(docker_host / "info") as resp:
data = await resp.json()
Expand Down
123 changes: 83 additions & 40 deletions src/ai/backend/common/docker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations

import functools
import ipaddress
import itertools
import json
Expand All @@ -8,15 +11,12 @@
from pathlib import Path
from typing import (
Any,
Dict,
Final,
Iterable,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
Type,
Union,
)

Expand Down Expand Up @@ -98,46 +98,86 @@
).ignore_extra("*")


def get_docker_connector() -> tuple[yarl.URL, aiohttp.BaseConnector]:
connector_cls: Type[aiohttp.UnixConnector] | Type[aiohttp.NamedPipeConnector]
if raw_docker_host := os.environ.get("DOCKER_HOST", None):
docker_host = yarl.URL(raw_docker_host)
match docker_host.scheme:
case "http" | "https":
return docker_host, aiohttp.TCPConnector()
case "unix":
search_paths = [Path(docker_host.path)]
connector_cls = aiohttp.UnixConnector
case "npipe":
search_paths = [Path(docker_host.path.replace("/", "\\"))]
connector_cls = aiohttp.NamedPipeConnector
case _ as unknown_scheme:
raise RuntimeError("unsupported connection scheme", unknown_scheme)
else:
match sys.platform:
case "linux" | "darwin":
search_paths = [
Path("/run/docker.sock"),
Path("/var/run/docker.sock"),
Path.home() / ".docker/run/docker.sock",
]
connector_cls = aiohttp.UnixConnector
case "win32":
search_paths = [
Path(r"\\.\pipe\docker_engine"),
]
connector_cls = aiohttp.NamedPipeConnector
case _ as platform_name:
raise RuntimeError("unsupported platform", platform_name)
@functools.lru_cache()
def get_docker_context_host() -> str | None:
try:
docker_config_path = Path.home() / ".docker" / "config.json"
docker_config = json.loads(docker_config_path.read_bytes())
except IOError:
return None
current_context_name = docker_config.get("currentContext", "default")
for meta_path in (Path.home() / ".docker" / "contexts" / "meta").glob("*/meta.json"):
context_data = json.loads(meta_path.read_bytes())
if context_data["Name"] == current_context_name:
return context_data["Endpoints"]["docker"]["Host"]
return None


def parse_docker_host_url(
    docker_host: yarl.URL,
) -> tuple[Path | None, yarl.URL, aiohttp.BaseConnector]:
    """Translate a docker host URL into the pieces needed to call the daemon API.

    Args:
        docker_host: The endpoint URL; the ``http``, ``https``, ``unix`` and
            ``npipe`` schemes are recognized.

    Returns:
        A ``(socket_path, base_url, connector)`` tuple.  For ``http``/``https``
        the socket path is ``None`` and the URL is returned unchanged together
        with a TCP connector.  For ``unix``/``npipe`` the local path is
        returned with ``http://localhost`` as the base URL and a connector
        bound to that path.

    Raises:
        RuntimeError: If the scheme is not recognized, or the path of a
            ``unix``/``npipe`` URL is not an existing socket / named pipe.
    """
    scheme = docker_host.scheme
    if scheme in ("http", "https"):
        return None, docker_host, aiohttp.TCPConnector()
    if scheme == "unix":
        sock_path = Path(docker_host.path)
        if not (sock_path.exists() and sock_path.is_socket()):
            raise RuntimeError(f"DOCKER_HOST {sock_path} is not a valid socket file.")
        sock_path_str = os.fsdecode(sock_path)
        return (
            sock_path,
            yarl.URL("http://localhost"),
            aiohttp.UnixConnector(sock_path_str),
        )
    if scheme == "npipe":
        # The URL path uses forward slashes; named pipe paths use backslashes.
        pipe_path = Path(docker_host.path.replace("/", "\\"))
        if not (pipe_path.exists() and pipe_path.is_fifo()):
            raise RuntimeError(f"DOCKER_HOST {pipe_path} is not a valid named pipe.")
        pipe_path_str = os.fsdecode(pipe_path)
        return (
            pipe_path,
            yarl.URL("http://localhost"),
            aiohttp.NamedPipeConnector(pipe_path_str),
        )
    raise RuntimeError("unsupported connection scheme", scheme)


@functools.lru_cache()
def search_docker_socket_files() -> tuple[Path | None, yarl.URL, aiohttp.BaseConnector]:
connector_cls: type[aiohttp.UnixConnector] | type[aiohttp.NamedPipeConnector]
match sys.platform:
case "linux" | "darwin":
search_paths = [
Path("/run/docker.sock"),
Path("/var/run/docker.sock"),
Path.home() / ".docker/run/docker.sock",
]
connector_cls = aiohttp.UnixConnector
case "win32":
search_paths = [
Path(r"\\.\pipe\docker_engine"),
]
connector_cls = aiohttp.NamedPipeConnector
case _ as platform_name:
raise RuntimeError(f"unsupported platform: {platform_name}")
for p in search_paths:
if p.exists() and (p.is_socket() or p.is_fifo()):
decoded_path = os.fsdecode(p)
return (
p,
yarl.URL("http://localhost"),
connector_cls(decoded_path),
)
else:
raise RuntimeError("could not find the docker socket")
searched_paths = ", ".join(map(os.fsdecode, search_paths))
raise RuntimeError(f"could not find the docker socket; tried: {searched_paths}")


def get_docker_connector() -> tuple[Path | None, yarl.URL, aiohttp.BaseConnector]:
    """Resolve how to reach the docker daemon.

    The ``DOCKER_HOST`` environment variable is consulted first, then the
    host reported by ``get_docker_context_host()``; if neither yields a
    non-empty value, ``search_docker_socket_files()`` is used.

    Returns:
        A ``(socket_path, base_url, connector)`` tuple as produced by
        ``parse_docker_host_url()`` or ``search_docker_socket_files()``.
    """
    # ``or`` short-circuits, so the context lookup only runs when the
    # environment variable is unset or empty.
    raw_docker_host = os.environ.get("DOCKER_HOST", None) or get_docker_context_host()
    if raw_docker_host:
        return parse_docker_host_url(yarl.URL(raw_docker_host))
    return search_docker_socket_files()


async def login(
Expand Down Expand Up @@ -210,7 +250,10 @@ async def get_known_registries(etcd: AsyncEtcd) -> Mapping[str, yarl.URL]:
return results


def is_known_registry(val: str, known_registries: Union[Mapping[str, Any], Sequence[str]] = None):
def is_known_registry(
val: str,
known_registries: Union[Mapping[str, Any], Sequence[str]] | None = None,
):
if val == default_registry:
return True
if known_registries is not None and val in known_registries:
Expand All @@ -224,7 +267,7 @@ def is_known_registry(val: str, known_registries: Union[Mapping[str, Any], Seque
return False


async def get_registry_info(etcd: AsyncEtcd, name: str) -> Tuple[yarl.URL, dict]:
async def get_registry_info(etcd: AsyncEtcd, name: str) -> tuple[yarl.URL, dict]:
reg_path = f"config/docker/registry/{etcd_quote(name)}"
item = await etcd.get_prefix(reg_path)
if not item:
Expand Down Expand Up @@ -272,7 +315,7 @@ def validate_image_labels(labels: dict[str, str]) -> dict[str, str]:

class PlatformTagSet(Mapping):
__slots__ = ("_data",)
_data: Dict[str, str]
_data: dict[str, str]
_rx_ver = re.compile(r"^(?P<tag>[a-zA-Z]+)(?P<version>\d+(?:\.\d+)*[a-z0-9]*)?$")

def __init__(self, tags: Iterable[str]):
Expand Down Expand Up @@ -357,7 +400,7 @@ def __init__(
self._update_tag_set()

@staticmethod
def _parse_image_tag(s: str, using_default_registry: bool = False) -> Tuple[str, str]:
def _parse_image_tag(s: str, using_default_registry: bool = False) -> tuple[str, str]:
image_tag = s.rsplit(":", maxsplit=1)
if len(image_tag) == 1:
image = image_tag[0]
Expand Down Expand Up @@ -450,7 +493,7 @@ def architecture(self) -> str:
return self._arch

@property
def tag_set(self) -> Tuple[str, PlatformTagSet]:
def tag_set(self) -> tuple[str, PlatformTagSet]:
# e.g., '3.6', {'ubuntu', 'cuda', ...}
return self._tag_set

Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/install/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ async def check_prerequisites(self) -> None:
)
)
await self.log.wait_continue()
if determine_docker_sudo():
if await determine_docker_sudo():
self.docker_sudo = ["sudo"]
self.log.write(
Text.from_markup(
Expand Down
110 changes: 42 additions & 68 deletions src/ai/backend/install/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import asyncio
import base64
import hashlib
import json
import os
import re
from pathlib import Path
from typing import TYPE_CHECKING

import aiohttp
from rich.text import Text

from ai.backend.common.docker import get_docker_connector
from ai.backend.install.types import PrerequisiteError

from .http import request_unix
Expand Down Expand Up @@ -66,78 +67,57 @@ async def detect_snap_docker():
return pkg_data["version"]


async def detect_system_docker(ctx: Context):
# Well-known docker socket paths
sock_paths = [
Path("/run/docker.sock"), # Linux default
Path("/var/run/docker.sock"), # macOS default
]
async def detect_system_docker(ctx: Context) -> str:
if ctx.docker_sudo:
ctx.log.write(
Text.from_markup("[yellow]Docker commands require sudo. We will use sudo.[/]")
)
try:
sock_path, docker_host, connector = get_docker_connector()
except RuntimeError as e:
raise PrerequisiteError(f"Could not find the docker socket ({e})") from e
ctx.log.write(Text.from_markup(f"[cyan]{docker_host=} {sock_path=}[/]"))

# Read from context
# Test a docker command to ensure passwordless sudo.
proc = await asyncio.create_subprocess_exec(
*(*ctx.docker_sudo, "docker", "context", "show"),
*(*ctx.docker_sudo, "docker", "version"),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
assert proc.stdout is not None
stdout = ""
try:
async with asyncio.timeout(0.5):
stdout = (await proc.stdout.read()).decode().strip()
await proc.wait()
await proc.communicate()
except asyncio.TimeoutError:
proc.kill()
await proc.wait()
raise PrerequisiteError(
"sudo requires prompt.",
instruction="Please make sudo available without password prompts.",
)
context_name = stdout
proc = await asyncio.create_subprocess_exec(
*(*ctx.docker_sudo, "docker", "context", "inspect", context_name),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.DEVNULL,
)
assert proc.stdout is not None
stdout = (await proc.stdout.read()).decode()
await proc.wait()
context_info = json.loads(stdout)
context_sock_path = context_info[0]["Endpoints"]["docker"]["Host"].removeprefix("unix://")
sock_paths.insert(0, Path(context_sock_path))

# Read from environment variable
if env_sock_path := os.environ.get("DOCKER_HOST", None):
# Some special setups like OrbStack may have a custom DOCKER_HOST.
env_sock_path = env_sock_path.removeprefix("unix://")
sock_paths.insert(0, Path(env_sock_path))

for sock_path in sock_paths:
if sock_path.is_socket():
break
else:
raise RuntimeError(
"Failed to find Docker daemon socket ("
+ ", ".join(str(sock_path) for sock_path in sock_paths)
+ ")"
)
ctx.log.write(Text.from_markup(f"[yellow]{sock_path=}[/]"))

if ctx.docker_sudo:
# change the docker socket permission (temporarily)
# Change the docker socket permission (temporarily)
# so that we could access the docker daemon API directly.
proc = await asyncio.create_subprocess_exec(
*["sudo", "chmod", "666", str(sock_path)],
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
assert proc.stdout is not None
stdout = (await proc.stdout.read()).decode()
if (await proc.wait()) != 0:
raise RuntimeError("Failed to set the docker socket permission", stdout)
async with request_unix("GET", str(sock_path), "http://localhost/version") as r:
if r.status != 200:
raise RuntimeError("Failed to query the Docker daemon API")
response_data = await r.json()
return response_data["Version"]
# NOTE: For TCP URLs (e.g., remote Docker), we don't have the socket file.
if sock_path is not None and not sock_path.resolve().is_relative_to(Path.home()):
proc = await asyncio.create_subprocess_exec(
*["sudo", "chmod", "666", str(sock_path)],
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
assert proc.stdout is not None
stdout = (await proc.stdout.read()).decode()
if (await proc.wait()) != 0:
raise RuntimeError("Failed to set the docker socket permission", stdout)

async with aiohttp.ClientSession(connector=connector) as sess:
async with sess.get(docker_host / "version") as r:
if r.status != 200:
raise RuntimeError("Failed to query the Docker daemon API")
response_data = await r.json()
return response_data["Version"]


def fail_with_snap_docker_refresh_request() -> None:
Expand Down Expand Up @@ -175,19 +155,13 @@ async def get_preferred_pants_local_exec_root(ctx: Context) -> str:


async def determine_docker_sudo() -> bool:
    """Tell whether talking to the docker daemon requires elevated privileges.

    Issues ``GET <docker_host>/version`` through the connector returned by
    ``get_docker_connector()``.

    Returns:
        ``True`` if the request fails with a permission error, ``False`` if
        it succeeds.

    Raises:
        OSError: Any other connection-level failure is propagated unchanged.
    """
    import errno

    _, docker_host, connector = get_docker_connector()
    try:
        async with aiohttp.ClientSession(connector=connector) as sess:
            async with sess.get(docker_host / "version") as r:
                await r.json()
    except OSError as e:
        # aiohttp connectors re-raise socket-level errors as
        # ClientConnectorError (an OSError subclass carrying the original
        # errno) rather than as PermissionError, so check the errno too.
        if isinstance(e, PermissionError) or e.errno in (errno.EACCES, errno.EPERM):
            return True
        raise
    return False


Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/manager/container_registry/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
class LocalRegistry(BaseContainerRegistry):
@actxmgr
async def prepare_client_session(self) -> AsyncIterator[tuple[yarl.URL, aiohttp.ClientSession]]:
url, connector = get_docker_connector()
_, url, connector = get_docker_connector()
async with aiohttp.ClientSession(connector=connector) as sess:
yield url, sess

Expand Down
Loading
Loading