Skip to content

Commit

Permalink
fix: Installer regression by #1724 and refactor docker connector for …
Browse files Browse the repository at this point in the history
…richer debug info (#1732)

Backported-from: main
Backported-to: 23.09
  • Loading branch information
achimnol committed Nov 22, 2023
1 parent 76bd784 commit 36e98cc
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 60 deletions.
1 change: 1 addition & 0 deletions changes/1732.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix an installer regression in #1724 to inappropriately cache an aiohttp connector instance used to access the local Docker API
12 changes: 6 additions & 6 deletions src/ai/backend/agent/vendor/linux.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ async def get_available_cores() -> set[int]:

async def read_cgroup_cpuset() -> tuple[set[int], str] | None:
try:
_, docker_host, connector = get_docker_connector()
async with aiohttp.ClientSession(connector=connector) as sess:
async with sess.get(docker_host / "info") as resp:
connector = get_docker_connector()
async with aiohttp.ClientSession(connector=connector.connector) as sess:
async with sess.get(connector.docker_host / "info") as resp:
data = await resp.json()
except (RuntimeError, aiohttp.ClientError):
return None
Expand Down Expand Up @@ -151,9 +151,9 @@ async def read_os_cpus() -> tuple[set[int], str] | None:
case "darwin" | "win32":
try:
cpuset_source = "the cpus accessible by the docker service"
_, docker_host, connector = get_docker_connector()
async with aiohttp.ClientSession(connector=connector) as sess:
async with sess.get(docker_host / "info") as resp:
connector = get_docker_connector()
async with aiohttp.ClientSession(connector=connector.connector) as sess:
async with sess.get(connector.docker_host / "info") as resp:
data = await resp.json()
return {idx for idx in range(data["NCPU"])}
except (RuntimeError, aiohttp.ClientError):
Expand Down
6 changes: 3 additions & 3 deletions src/ai/backend/common/cgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ class CgroupVersion:


async def get_docker_cgroup_version() -> CgroupVersion:
_, docker_host, connector = get_docker_connector()
async with aiohttp.ClientSession(connector=connector) as sess:
async with sess.get(docker_host / "info") as resp:
connector = get_docker_connector()
async with aiohttp.ClientSession(connector=connector.connector) as sess:
async with sess.get(connector.docker_host / "info") as resp:
data = await resp.json()
return CgroupVersion(data["CgroupVersion"], data["CgroupDriver"])

Expand Down
66 changes: 56 additions & 10 deletions src/ai/backend/common/docker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import enum
import functools
import ipaddress
import itertools
Expand All @@ -8,6 +9,7 @@
import os
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import (
Any,
Expand Down Expand Up @@ -98,6 +100,20 @@
).ignore_extra("*")


class DockerConnectorSource(enum.Enum):
ENV_VAR = enum.auto()
USER_CONTEXT = enum.auto()
KNOWN_LOCATION = enum.auto()


@dataclass()
class DockerConnector:
sock_path: Path | None
docker_host: yarl.URL
connector: aiohttp.BaseConnector
source: DockerConnectorSource


@functools.lru_cache()
def get_docker_context_host() -> str | None:
try:
Expand Down Expand Up @@ -136,13 +152,16 @@ def parse_docker_host_url(
raise RuntimeError("unsupported connection scheme", unknown_scheme)
return (
path,
yarl.URL("http://localhost"),
connector_cls(decoded_path),
yarl.URL("http://docker"), # a fake hostname to construct a valid URL
connector_cls(decoded_path, force_close=True),
)


# We may cache the connector type but not connector instances!
@functools.lru_cache()
def search_docker_socket_files() -> tuple[Path | None, yarl.URL, aiohttp.BaseConnector]:
def _search_docker_socket_files_impl() -> (
tuple[Path, yarl.URL, type[aiohttp.UnixConnector] | type[aiohttp.NamedPipeConnector]]
):
connector_cls: type[aiohttp.UnixConnector] | type[aiohttp.NamedPipeConnector]
match sys.platform:
case "linux" | "darwin":
Expand All @@ -161,23 +180,50 @@ def search_docker_socket_files() -> tuple[Path | None, yarl.URL, aiohttp.BaseCon
raise RuntimeError(f"unsupported platform: {platform_name}")
for p in search_paths:
if p.exists() and (p.is_socket() or p.is_fifo()):
decoded_path = os.fsdecode(p)
return (
p,
yarl.URL("http://localhost"),
connector_cls(decoded_path),
yarl.URL("http://docker"), # a fake hostname to construct a valid URL
connector_cls,
)
else:
searched_paths = ", ".join(map(os.fsdecode, search_paths))
raise RuntimeError(f"could not find the docker socket; tried: {searched_paths}")


def get_docker_connector() -> tuple[Path | None, yarl.URL, aiohttp.BaseConnector]:
def search_docker_socket_files() -> tuple[Path | None, yarl.URL, aiohttp.BaseConnector]:
connector_cls: type[aiohttp.UnixConnector] | type[aiohttp.NamedPipeConnector]
sock_path, docker_host, connector_cls = _search_docker_socket_files_impl()
return (
sock_path,
docker_host,
connector_cls(os.fsdecode(sock_path), force_close=True),
)


def get_docker_connector() -> DockerConnector:
if raw_docker_host := os.environ.get("DOCKER_HOST", None):
return parse_docker_host_url(yarl.URL(raw_docker_host))
sock_path, docker_host, connector = parse_docker_host_url(yarl.URL(raw_docker_host))
return DockerConnector(
sock_path,
docker_host,
connector,
DockerConnectorSource.ENV_VAR,
)
if raw_docker_host := get_docker_context_host():
return parse_docker_host_url(yarl.URL(raw_docker_host))
return search_docker_socket_files()
sock_path, docker_host, connector = parse_docker_host_url(yarl.URL(raw_docker_host))
return DockerConnector(
sock_path,
docker_host,
connector,
DockerConnectorSource.USER_CONTEXT,
)
sock_path, docker_host, connector = search_docker_socket_files()
return DockerConnector(
sock_path,
docker_host,
connector,
DockerConnectorSource.KNOWN_LOCATION,
)


async def login(
Expand Down
32 changes: 22 additions & 10 deletions src/ai/backend/install/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import TYPE_CHECKING

import aiohttp
from aiohttp.client_exceptions import ClientConnectorError
from rich.text import Text

from ai.backend.common.docker import get_docker_connector
Expand Down Expand Up @@ -73,10 +74,10 @@ async def detect_system_docker(ctx: Context) -> str:
Text.from_markup("[yellow]Docker commands require sudo. We will use sudo.[/]")
)
try:
sock_path, docker_host, connector = get_docker_connector()
connector = get_docker_connector()
except RuntimeError as e:
raise PrerequisiteError(f"Could not find the docker socket ({e})") from e
ctx.log.write(Text.from_markup(f"[cyan]{docker_host=} {sock_path=}[/]"))
ctx.log.write(Text.from_markup(f"[cyan]{connector=}[/]"))

# Test a docker command to ensure passwordless sudo.
proc = await asyncio.create_subprocess_exec(
Expand All @@ -101,9 +102,11 @@ async def detect_system_docker(ctx: Context) -> str:
# Change the docker socket permission (temporarily)
# so that we could access the docker daemon API directly.
# NOTE: For TCP URLs (e.g., remote Docker), we don't have the socket file.
if sock_path is not None and not sock_path.resolve().is_relative_to(Path.home()):
if connector.sock_path is not None and not connector.sock_path.resolve().is_relative_to(
Path.home()
):
proc = await asyncio.create_subprocess_exec(
*["sudo", "chmod", "666", str(sock_path)],
*["sudo", "chmod", "666", str(connector.sock_path)],
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
Expand All @@ -112,10 +115,13 @@ async def detect_system_docker(ctx: Context) -> str:
if (await proc.wait()) != 0:
raise RuntimeError("Failed to set the docker socket permission", stdout)

async with aiohttp.ClientSession(connector=connector) as sess:
async with sess.get(docker_host / "version") as r:
async with aiohttp.ClientSession(connector=connector.connector) as sess:
async with sess.get(connector.docker_host / "version") as r:
if r.status != 200:
raise RuntimeError("Failed to query the Docker daemon API")
raise RuntimeError(
"The Docker daemon API responded with unexpected response:"
f" {r.status} {r.reason}"
)
response_data = await r.json()
return response_data["Version"]

Expand Down Expand Up @@ -155,11 +161,15 @@ async def get_preferred_pants_local_exec_root(ctx: Context) -> str:


async def determine_docker_sudo() -> bool:
sock_path, docker_host, connector = get_docker_connector()
connector = get_docker_connector()
try:
async with aiohttp.ClientSession(connector=connector) as sess:
async with sess.get(docker_host / "version") as r:
async with aiohttp.ClientSession(connector=connector.connector) as sess:
async with sess.get(connector.docker_host / "version") as r:
await r.json()
except ClientConnectorError as e:
if isinstance(e.os_error, PermissionError):
return True
raise
except PermissionError:
return True
return False
Expand All @@ -180,6 +190,8 @@ async def check_docker(ctx: Context) -> None:
else:
fail_with_system_docker_install_request()

# Compose is not a part of the docker API but a client-side plugin.
# We need to execute the client command to get information about it.
proc = await asyncio.create_subprocess_exec(
*ctx.docker_sudo, "docker", "compose", "version", stdout=asyncio.subprocess.PIPE
)
Expand Down
6 changes: 3 additions & 3 deletions src/ai/backend/manager/container_registry/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
class LocalRegistry(BaseContainerRegistry):
@actxmgr
async def prepare_client_session(self) -> AsyncIterator[tuple[yarl.URL, aiohttp.ClientSession]]:
_, url, connector = get_docker_connector()
async with aiohttp.ClientSession(connector=connector) as sess:
yield url, sess
connector = get_docker_connector()
async with aiohttp.ClientSession(connector=connector.connector) as sess:
yield connector.docker_host, sess

async def fetch_repositories(
self,
Expand Down
57 changes: 29 additions & 28 deletions tests/common/test_docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import functools
import itertools
import typing
from pathlib import PosixPath
from unittest.mock import MagicMock, call

import aiohttp
Expand All @@ -10,46 +11,46 @@
from ai.backend.common.docker import (
ImageRef,
PlatformTagSet,
_search_docker_socket_files_impl,
default_registry,
default_repository,
get_docker_connector,
get_docker_context_host,
search_docker_socket_files,
)


@pytest.mark.asyncio
async def test_get_docker_connector(monkeypatch):
get_docker_context_host.cache_clear()
search_docker_socket_files.cache_clear()
_search_docker_socket_files_impl.cache_clear()
with monkeypatch.context() as m:
m.setenv("DOCKER_HOST", "http://localhost:2375")
_, url, connector = get_docker_connector()
assert str(url) == "http://localhost:2375"
assert isinstance(connector, aiohttp.TCPConnector)
connector = get_docker_connector()
assert str(connector.docker_host) == "http://localhost:2375"
assert isinstance(connector.connector, aiohttp.TCPConnector)

get_docker_context_host.cache_clear()
search_docker_socket_files.cache_clear()
_search_docker_socket_files_impl.cache_clear()
with monkeypatch.context() as m:
m.setenv("DOCKER_HOST", "https://example.com:2375")
_, url, connector = get_docker_connector()
assert str(url) == "https://example.com:2375"
assert isinstance(connector, aiohttp.TCPConnector)
connector = get_docker_connector()
assert str(connector.docker_host) == "https://example.com:2375"
assert isinstance(connector.connector, aiohttp.TCPConnector)

get_docker_context_host.cache_clear()
search_docker_socket_files.cache_clear()
_search_docker_socket_files_impl.cache_clear()
with monkeypatch.context() as m:
m.setenv("DOCKER_HOST", "unix:///run/docker.sock")
m.setattr("pathlib.Path.exists", lambda self: True)
m.setattr("pathlib.Path.is_socket", lambda self: True)
m.setattr("pathlib.Path.is_fifo", lambda self: False)
_, url, connector = get_docker_connector()
assert str(url) == "http://localhost"
assert isinstance(connector, aiohttp.UnixConnector)
assert connector.path == "/run/docker.sock"
connector = get_docker_connector()
assert str(connector.docker_host) == "http://docker"
assert isinstance(connector.connector, aiohttp.UnixConnector)
assert connector.sock_path == PosixPath("/run/docker.sock")

get_docker_context_host.cache_clear()
search_docker_socket_files.cache_clear()
_search_docker_socket_files_impl.cache_clear()
with monkeypatch.context() as m:
m.setenv("DOCKER_HOST", "unix:///run/docker.sock")
m.setattr("pathlib.Path.exists", lambda self: False)
Expand All @@ -59,25 +60,25 @@ async def test_get_docker_connector(monkeypatch):
get_docker_connector()

get_docker_context_host.cache_clear()
search_docker_socket_files.cache_clear()
_search_docker_socket_files_impl.cache_clear()
with monkeypatch.context() as m:
m.setenv("DOCKER_HOST", "npipe:////./pipe/docker_engine")
m.setattr("pathlib.Path.exists", lambda self: True)
m.setattr("pathlib.Path.is_socket", lambda self: False)
m.setattr("pathlib.Path.is_fifo", lambda self: True)
mock_connector = MagicMock()
m.setattr("aiohttp.NamedPipeConnector", mock_connector)
_, url, connector = get_docker_connector()
assert str(url) == "http://localhost"
mock_connector.assert_called_once_with(r"\\.\pipe\docker_engine")
connector = get_docker_connector()
assert str(connector.docker_host) == "http://docker"
mock_connector.assert_called_once_with(r"\\.\pipe\docker_engine", force_close=True)

search_docker_socket_files.cache_clear()
_search_docker_socket_files_impl.cache_clear()
with monkeypatch.context() as m:
m.setenv("DOCKER_HOST", "unknown://dockerhost")
with pytest.raises(RuntimeError, match="unsupported connection scheme"):
get_docker_connector()

search_docker_socket_files.cache_clear()
_search_docker_socket_files_impl.cache_clear()
with monkeypatch.context() as m:
m.delenv("DOCKER_HOST", raising=False)
m.setattr("ai.backend.common.docker.get_docker_context_host", lambda: None)
Expand All @@ -88,12 +89,12 @@ async def test_get_docker_connector(monkeypatch):
m.setattr("pathlib.Path.exists", lambda self: True)
m.setattr("pathlib.Path.is_socket", lambda self: True)
m.setattr("pathlib.Path.is_fifo", lambda self: False)
_, url, connector = get_docker_connector()
connector = get_docker_connector()
mock_path.assert_has_calls([call("/run/docker.sock"), call("/var/run/docker.sock")])
assert str(url) == "http://localhost"
assert isinstance(connector, aiohttp.UnixConnector)
assert str(connector.docker_host) == "http://docker"
assert isinstance(connector.connector, aiohttp.UnixConnector)

search_docker_socket_files.cache_clear()
_search_docker_socket_files_impl.cache_clear()
with monkeypatch.context() as m:
m.delenv("DOCKER_HOST", raising=False)
m.setattr("ai.backend.common.docker.get_docker_context_host", lambda: None)
Expand All @@ -105,12 +106,12 @@ async def test_get_docker_connector(monkeypatch):
m.setattr("pathlib.Path.is_fifo", lambda self: True)
mock_connector = MagicMock()
m.setattr("aiohttp.NamedPipeConnector", mock_connector)
_, url, connector = get_docker_connector()
connector = get_docker_connector()
mock_path.assert_has_calls([call(r"\\.\pipe\docker_engine")])
assert str(url) == "http://localhost"
assert str(connector.docker_host) == "http://docker"

get_docker_context_host.cache_clear()
search_docker_socket_files.cache_clear()
_search_docker_socket_files_impl.cache_clear()
with monkeypatch.context() as m:
m.delenv("DOCKER_HOST", raising=False)
m.setattr("ai.backend.common.docker.get_docker_context_host", lambda: None)
Expand Down

0 comments on commit 36e98cc

Please sign in to comment.