Skip to content

Commit

Permalink
fix: Improve docker context support (#1724)
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol authored Nov 22, 2023
1 parent 2143f7c commit 717079a
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 120 deletions.
1 change: 1 addition & 0 deletions changes/1724.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix the installer to use the refactored `common.docker.get_docker_connector()` for system docker detection which now also detects the active docker context if configured
4 changes: 2 additions & 2 deletions src/ai/backend/agent/vendor/linux.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def get_available_cores() -> set[int]:

async def read_cgroup_cpuset() -> tuple[set[int], str] | None:
try:
docker_host, connector = get_docker_connector()
_, docker_host, connector = get_docker_connector()
async with aiohttp.ClientSession(connector=connector) as sess:
async with sess.get(docker_host / "info") as resp:
data = await resp.json()
Expand Down Expand Up @@ -151,7 +151,7 @@ async def read_os_cpus() -> tuple[set[int], str] | None:
case "darwin" | "win32":
try:
cpuset_source = "the cpus accessible by the docker service"
docker_host, connector = get_docker_connector()
_, docker_host, connector = get_docker_connector()
async with aiohttp.ClientSession(connector=connector) as sess:
async with sess.get(docker_host / "info") as resp:
data = await resp.json()
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/common/cgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class CgroupVersion:


async def get_docker_cgroup_version() -> CgroupVersion:
docker_host, connector = get_docker_connector()
_, docker_host, connector = get_docker_connector()
async with aiohttp.ClientSession(connector=connector) as sess:
async with sess.get(docker_host / "info") as resp:
data = await resp.json()
Expand Down
123 changes: 83 additions & 40 deletions src/ai/backend/common/docker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations

import functools
import ipaddress
import itertools
import json
Expand All @@ -8,15 +11,12 @@
from pathlib import Path
from typing import (
Any,
Dict,
Final,
Iterable,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
Type,
Union,
)

Expand Down Expand Up @@ -98,46 +98,86 @@
).ignore_extra("*")


def get_docker_connector() -> tuple[yarl.URL, aiohttp.BaseConnector]:
connector_cls: Type[aiohttp.UnixConnector] | Type[aiohttp.NamedPipeConnector]
if raw_docker_host := os.environ.get("DOCKER_HOST", None):
docker_host = yarl.URL(raw_docker_host)
match docker_host.scheme:
case "http" | "https":
return docker_host, aiohttp.TCPConnector()
case "unix":
search_paths = [Path(docker_host.path)]
connector_cls = aiohttp.UnixConnector
case "npipe":
search_paths = [Path(docker_host.path.replace("/", "\\"))]
connector_cls = aiohttp.NamedPipeConnector
case _ as unknown_scheme:
raise RuntimeError("unsupported connection scheme", unknown_scheme)
else:
match sys.platform:
case "linux" | "darwin":
search_paths = [
Path("/run/docker.sock"),
Path("/var/run/docker.sock"),
Path.home() / ".docker/run/docker.sock",
]
connector_cls = aiohttp.UnixConnector
case "win32":
search_paths = [
Path(r"\\.\pipe\docker_engine"),
]
connector_cls = aiohttp.NamedPipeConnector
case _ as platform_name:
raise RuntimeError("unsupported platform", platform_name)
@functools.lru_cache()
def get_docker_context_host() -> str | None:
try:
docker_config_path = Path.home() / ".docker" / "config.json"
docker_config = json.loads(docker_config_path.read_bytes())
except IOError:
return None
current_context_name = docker_config.get("currentContext", "default")
for meta_path in (Path.home() / ".docker" / "contexts" / "meta").glob("*/meta.json"):
context_data = json.loads(meta_path.read_bytes())
if context_data["Name"] == current_context_name:
return context_data["Endpoints"]["docker"]["Host"]
return None


def parse_docker_host_url(
docker_host: yarl.URL,
) -> tuple[Path | None, yarl.URL, aiohttp.BaseConnector]:
connector_cls: type[aiohttp.UnixConnector] | type[aiohttp.NamedPipeConnector]
match docker_host.scheme:
case "http" | "https":
return None, docker_host, aiohttp.TCPConnector()
case "unix":
path = Path(docker_host.path)
if not path.exists() or not path.is_socket():
raise RuntimeError(f"DOCKER_HOST {path} is not a valid socket file.")
decoded_path = os.fsdecode(path)
connector_cls = aiohttp.UnixConnector
case "npipe":
path = Path(docker_host.path.replace("/", "\\"))
if not path.exists() or not path.is_fifo():
raise RuntimeError(f"DOCKER_HOST {path} is not a valid named pipe.")
decoded_path = os.fsdecode(path)
connector_cls = aiohttp.NamedPipeConnector
case _ as unknown_scheme:
raise RuntimeError("unsupported connection scheme", unknown_scheme)
return (
path,
yarl.URL("http://localhost"),
connector_cls(decoded_path),
)


@functools.lru_cache()
def search_docker_socket_files() -> tuple[Path | None, yarl.URL, aiohttp.BaseConnector]:
connector_cls: type[aiohttp.UnixConnector] | type[aiohttp.NamedPipeConnector]
match sys.platform:
case "linux" | "darwin":
search_paths = [
Path("/run/docker.sock"),
Path("/var/run/docker.sock"),
Path.home() / ".docker/run/docker.sock",
]
connector_cls = aiohttp.UnixConnector
case "win32":
search_paths = [
Path(r"\\.\pipe\docker_engine"),
]
connector_cls = aiohttp.NamedPipeConnector
case _ as platform_name:
raise RuntimeError(f"unsupported platform: {platform_name}")
for p in search_paths:
if p.exists() and (p.is_socket() or p.is_fifo()):
decoded_path = os.fsdecode(p)
return (
p,
yarl.URL("http://localhost"),
connector_cls(decoded_path),
)
else:
raise RuntimeError("could not find the docker socket")
searched_paths = ", ".join(map(os.fsdecode, search_paths))
raise RuntimeError(f"could not find the docker socket; tried: {searched_paths}")


def get_docker_connector() -> tuple[Path | None, yarl.URL, aiohttp.BaseConnector]:
if raw_docker_host := os.environ.get("DOCKER_HOST", None):
return parse_docker_host_url(yarl.URL(raw_docker_host))
if raw_docker_host := get_docker_context_host():
return parse_docker_host_url(yarl.URL(raw_docker_host))
return search_docker_socket_files()


async def login(
Expand Down Expand Up @@ -210,7 +250,10 @@ async def get_known_registries(etcd: AsyncEtcd) -> Mapping[str, yarl.URL]:
return results


def is_known_registry(val: str, known_registries: Union[Mapping[str, Any], Sequence[str]] = None):
def is_known_registry(
val: str,
known_registries: Union[Mapping[str, Any], Sequence[str]] | None = None,
):
if val == default_registry:
return True
if known_registries is not None and val in known_registries:
Expand All @@ -224,7 +267,7 @@ def is_known_registry(val: str, known_registries: Union[Mapping[str, Any], Seque
return False


async def get_registry_info(etcd: AsyncEtcd, name: str) -> Tuple[yarl.URL, dict]:
async def get_registry_info(etcd: AsyncEtcd, name: str) -> tuple[yarl.URL, dict]:
reg_path = f"config/docker/registry/{etcd_quote(name)}"
item = await etcd.get_prefix(reg_path)
if not item:
Expand Down Expand Up @@ -272,7 +315,7 @@ def validate_image_labels(labels: dict[str, str]) -> dict[str, str]:

class PlatformTagSet(Mapping):
__slots__ = ("_data",)
_data: Dict[str, str]
_data: dict[str, str]
_rx_ver = re.compile(r"^(?P<tag>[a-zA-Z]+)(?P<version>\d+(?:\.\d+)*[a-z0-9]*)?$")

def __init__(self, tags: Iterable[str]):
Expand Down Expand Up @@ -357,7 +400,7 @@ def __init__(
self._update_tag_set()

@staticmethod
def _parse_image_tag(s: str, using_default_registry: bool = False) -> Tuple[str, str]:
def _parse_image_tag(s: str, using_default_registry: bool = False) -> tuple[str, str]:
image_tag = s.rsplit(":", maxsplit=1)
if len(image_tag) == 1:
image = image_tag[0]
Expand Down Expand Up @@ -450,7 +493,7 @@ def architecture(self) -> str:
return self._arch

@property
def tag_set(self) -> Tuple[str, PlatformTagSet]:
def tag_set(self) -> tuple[str, PlatformTagSet]:
# e.g., '3.6', {'ubuntu', 'cuda', ...}
return self._tag_set

Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/install/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ async def check_prerequisites(self) -> None:
)
)
await self.log.wait_continue()
if determine_docker_sudo():
if await determine_docker_sudo():
self.docker_sudo = ["sudo"]
self.log.write(
Text.from_markup(
Expand Down
110 changes: 42 additions & 68 deletions src/ai/backend/install/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import asyncio
import base64
import hashlib
import json
import os
import re
from pathlib import Path
from typing import TYPE_CHECKING

import aiohttp
from rich.text import Text

from ai.backend.common.docker import get_docker_connector
from ai.backend.install.types import PrerequisiteError

from .http import request_unix
Expand Down Expand Up @@ -66,78 +67,57 @@ async def detect_snap_docker():
return pkg_data["version"]


async def detect_system_docker(ctx: Context):
# Well-known docker socket paths
sock_paths = [
Path("/run/docker.sock"), # Linux default
Path("/var/run/docker.sock"), # macOS default
]
async def detect_system_docker(ctx: Context) -> str:
if ctx.docker_sudo:
ctx.log.write(
Text.from_markup("[yellow]Docker commands require sudo. We will use sudo.[/]")
)
try:
sock_path, docker_host, connector = get_docker_connector()
except RuntimeError as e:
raise PrerequisiteError(f"Could not find the docker socket ({e})") from e
ctx.log.write(Text.from_markup(f"[cyan]{docker_host=} {sock_path=}[/]"))

# Read from context
# Test a docker command to ensure passwordless sudo.
proc = await asyncio.create_subprocess_exec(
*(*ctx.docker_sudo, "docker", "context", "show"),
*(*ctx.docker_sudo, "docker", "version"),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
assert proc.stdout is not None
stdout = ""
try:
async with asyncio.timeout(0.5):
stdout = (await proc.stdout.read()).decode().strip()
await proc.wait()
await proc.communicate()
except asyncio.TimeoutError:
proc.kill()
await proc.wait()
raise PrerequisiteError(
"sudo requires prompt.",
instruction="Please make sudo available without password prompts.",
)
context_name = stdout
proc = await asyncio.create_subprocess_exec(
*(*ctx.docker_sudo, "docker", "context", "inspect", context_name),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.DEVNULL,
)
assert proc.stdout is not None
stdout = (await proc.stdout.read()).decode()
await proc.wait()
context_info = json.loads(stdout)
context_sock_path = context_info[0]["Endpoints"]["docker"]["Host"].removeprefix("unix://")
sock_paths.insert(0, Path(context_sock_path))

# Read from environment variable
if env_sock_path := os.environ.get("DOCKER_HOST", None):
# Some special setups like OrbStack may have a custom DOCKER_HOST.
env_sock_path = env_sock_path.removeprefix("unix://")
sock_paths.insert(0, Path(env_sock_path))

for sock_path in sock_paths:
if sock_path.is_socket():
break
else:
raise RuntimeError(
"Failed to find Docker daemon socket ("
+ ", ".join(str(sock_path) for sock_path in sock_paths)
+ ")"
)
ctx.log.write(Text.from_markup(f"[yellow]{sock_path=}[/]"))

if ctx.docker_sudo:
# change the docker socket permission (temporarily)
# Change the docker socket permission (temporarily)
# so that we could access the docker daemon API directly.
proc = await asyncio.create_subprocess_exec(
*["sudo", "chmod", "666", str(sock_path)],
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
assert proc.stdout is not None
stdout = (await proc.stdout.read()).decode()
if (await proc.wait()) != 0:
raise RuntimeError("Failed to set the docker socket permission", stdout)
async with request_unix("GET", str(sock_path), "http://localhost/version") as r:
if r.status != 200:
raise RuntimeError("Failed to query the Docker daemon API")
response_data = await r.json()
return response_data["Version"]
# NOTE: For TCP URLs (e.g., remote Docker), we don't have the socket file.
if sock_path is not None and not sock_path.resolve().is_relative_to(Path.home()):
proc = await asyncio.create_subprocess_exec(
*["sudo", "chmod", "666", str(sock_path)],
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
assert proc.stdout is not None
stdout = (await proc.stdout.read()).decode()
if (await proc.wait()) != 0:
raise RuntimeError("Failed to set the docker socket permission", stdout)

async with aiohttp.ClientSession(connector=connector) as sess:
async with sess.get(docker_host / "version") as r:
if r.status != 200:
raise RuntimeError("Failed to query the Docker daemon API")
response_data = await r.json()
return response_data["Version"]


def fail_with_snap_docker_refresh_request() -> None:
Expand Down Expand Up @@ -175,19 +155,13 @@ async def get_preferred_pants_local_exec_root(ctx: Context) -> str:


async def determine_docker_sudo() -> bool:
proc = await asyncio.create_subprocess_exec(
*("docker", "version"),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
assert proc.stdout is not None
stdout = (await proc.stdout.read()).decode()
if (await proc.wait()) != 0:
if "permission denied" in stdout.lower():
# installed, requires sudo
return True
raise RuntimeError("Docker client command is not available in the host.")
# installed, does not require sudo
sock_path, docker_host, connector = get_docker_connector()
try:
async with aiohttp.ClientSession(connector=connector) as sess:
async with sess.get(docker_host / "version") as r:
await r.json()
except PermissionError:
return True
return False


Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/manager/container_registry/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
class LocalRegistry(BaseContainerRegistry):
@actxmgr
async def prepare_client_session(self) -> AsyncIterator[tuple[yarl.URL, aiohttp.ClientSession]]:
url, connector = get_docker_connector()
_, url, connector = get_docker_connector()
async with aiohttp.ClientSession(connector=connector) as sess:
yield url, sess

Expand Down
Loading

0 comments on commit 717079a

Please sign in to comment.