diff --git a/changes/1732.fix.md b/changes/1732.fix.md new file mode 100644 index 0000000000..96088f0b50 --- /dev/null +++ b/changes/1732.fix.md @@ -0,0 +1 @@ +Fix an installer regression introduced in #1724 that inappropriately cached an aiohttp connector instance used to access the local Docker API diff --git a/src/ai/backend/agent/vendor/linux.py b/src/ai/backend/agent/vendor/linux.py index 9c0473f785..be3af5fcc4 100644 --- a/src/ai/backend/agent/vendor/linux.py +++ b/src/ai/backend/agent/vendor/linux.py @@ -67,9 +67,9 @@ async def get_available_cores() -> set[int]: async def read_cgroup_cpuset() -> tuple[set[int], str] | None: try: - _, docker_host, connector = get_docker_connector() - async with aiohttp.ClientSession(connector=connector) as sess: - async with sess.get(docker_host / "info") as resp: + connector = get_docker_connector() + async with aiohttp.ClientSession(connector=connector.connector) as sess: + async with sess.get(connector.docker_host / "info") as resp: data = await resp.json() except (RuntimeError, aiohttp.ClientError): return None @@ -151,9 +151,9 @@ async def read_os_cpus() -> tuple[set[int], str] | None: case "darwin" | "win32": try: cpuset_source = "the cpus accessible by the docker service" - _, docker_host, connector = get_docker_connector() - async with aiohttp.ClientSession(connector=connector) as sess: - async with sess.get(docker_host / "info") as resp: + connector = get_docker_connector() + async with aiohttp.ClientSession(connector=connector.connector) as sess: + async with sess.get(connector.docker_host / "info") as resp: data = await resp.json() return {idx for idx in range(data["NCPU"])} except (RuntimeError, aiohttp.ClientError): diff --git a/src/ai/backend/common/cgroup.py b/src/ai/backend/common/cgroup.py index ea444a0ab3..6e52204cfa 100644 --- a/src/ai/backend/common/cgroup.py +++ b/src/ai/backend/common/cgroup.py @@ -31,9 +31,9 @@ class CgroupVersion: async def get_docker_cgroup_version() -> CgroupVersion: - _, docker_host, connector = 
get_docker_connector() - async with aiohttp.ClientSession(connector=connector) as sess: - async with sess.get(docker_host / "info") as resp: + connector = get_docker_connector() + async with aiohttp.ClientSession(connector=connector.connector) as sess: + async with sess.get(connector.docker_host / "info") as resp: data = await resp.json() return CgroupVersion(data["CgroupVersion"], data["CgroupDriver"]) diff --git a/src/ai/backend/common/docker.py b/src/ai/backend/common/docker.py index 8743afba7c..0dfea27eb4 100644 --- a/src/ai/backend/common/docker.py +++ b/src/ai/backend/common/docker.py @@ -1,5 +1,6 @@ from __future__ import annotations +import enum import functools import ipaddress import itertools @@ -8,6 +9,7 @@ import os import re import sys +from dataclasses import dataclass from pathlib import Path from typing import ( Any, @@ -98,6 +100,20 @@ ).ignore_extra("*") +class DockerConnectorSource(enum.Enum): + ENV_VAR = enum.auto() + USER_CONTEXT = enum.auto() + KNOWN_LOCATION = enum.auto() + + +@dataclass() +class DockerConnector: + sock_path: Path | None + docker_host: yarl.URL + connector: aiohttp.BaseConnector + source: DockerConnectorSource + + @functools.lru_cache() def get_docker_context_host() -> str | None: try: @@ -136,13 +152,16 @@ def parse_docker_host_url( raise RuntimeError("unsupported connection scheme", unknown_scheme) return ( path, - yarl.URL("http://localhost"), - connector_cls(decoded_path), + yarl.URL("http://docker"), # a fake hostname to construct a valid URL + connector_cls(decoded_path, force_close=True), ) +# We may cache the connector type but not connector instances! 
@functools.lru_cache() -def search_docker_socket_files() -> tuple[Path | None, yarl.URL, aiohttp.BaseConnector]: +def _search_docker_socket_files_impl() -> ( + tuple[Path, yarl.URL, type[aiohttp.UnixConnector] | type[aiohttp.NamedPipeConnector]] +): connector_cls: type[aiohttp.UnixConnector] | type[aiohttp.NamedPipeConnector] match sys.platform: case "linux" | "darwin": @@ -161,23 +180,50 @@ def search_docker_socket_files() -> tuple[Path | None, yarl.URL, aiohttp.BaseCon raise RuntimeError(f"unsupported platform: {platform_name}") for p in search_paths: if p.exists() and (p.is_socket() or p.is_fifo()): - decoded_path = os.fsdecode(p) return ( p, - yarl.URL("http://localhost"), - connector_cls(decoded_path), + yarl.URL("http://docker"), # a fake hostname to construct a valid URL + connector_cls, ) else: searched_paths = ", ".join(map(os.fsdecode, search_paths)) raise RuntimeError(f"could not find the docker socket; tried: {searched_paths}") -def get_docker_connector() -> tuple[Path | None, yarl.URL, aiohttp.BaseConnector]: +def search_docker_socket_files() -> tuple[Path | None, yarl.URL, aiohttp.BaseConnector]: + connector_cls: type[aiohttp.UnixConnector] | type[aiohttp.NamedPipeConnector] + sock_path, docker_host, connector_cls = _search_docker_socket_files_impl() + return ( + sock_path, + docker_host, + connector_cls(os.fsdecode(sock_path), force_close=True), + ) + + +def get_docker_connector() -> DockerConnector: if raw_docker_host := os.environ.get("DOCKER_HOST", None): - return parse_docker_host_url(yarl.URL(raw_docker_host)) + sock_path, docker_host, connector = parse_docker_host_url(yarl.URL(raw_docker_host)) + return DockerConnector( + sock_path, + docker_host, + connector, + DockerConnectorSource.ENV_VAR, + ) if raw_docker_host := get_docker_context_host(): - return parse_docker_host_url(yarl.URL(raw_docker_host)) - return search_docker_socket_files() + sock_path, docker_host, connector = parse_docker_host_url(yarl.URL(raw_docker_host)) + return 
DockerConnector( + sock_path, + docker_host, + connector, + DockerConnectorSource.USER_CONTEXT, + ) + sock_path, docker_host, connector = search_docker_socket_files() + return DockerConnector( + sock_path, + docker_host, + connector, + DockerConnectorSource.KNOWN_LOCATION, + ) async def login( diff --git a/src/ai/backend/install/docker.py b/src/ai/backend/install/docker.py index b7d83d88f9..758fabe4f8 100644 --- a/src/ai/backend/install/docker.py +++ b/src/ai/backend/install/docker.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING import aiohttp +from aiohttp.client_exceptions import ClientConnectorError from rich.text import Text from ai.backend.common.docker import get_docker_connector @@ -73,10 +74,10 @@ async def detect_system_docker(ctx: Context) -> str: Text.from_markup("[yellow]Docker commands require sudo. We will use sudo.[/]") ) try: - sock_path, docker_host, connector = get_docker_connector() + connector = get_docker_connector() except RuntimeError as e: raise PrerequisiteError(f"Could not find the docker socket ({e})") from e - ctx.log.write(Text.from_markup(f"[cyan]{docker_host=} {sock_path=}[/]")) + ctx.log.write(Text.from_markup(f"[cyan]{connector=}[/]")) # Test a docker command to ensure passwordless sudo. proc = await asyncio.create_subprocess_exec( @@ -101,9 +102,11 @@ async def detect_system_docker(ctx: Context) -> str: # Change the docker socket permission (temporarily) # so that we could access the docker daemon API directly. # NOTE: For TCP URLs (e.g., remote Docker), we don't have the socket file. 
- if sock_path is not None and not sock_path.resolve().is_relative_to(Path.home()): + if connector.sock_path is not None and not connector.sock_path.resolve().is_relative_to( + Path.home() + ): proc = await asyncio.create_subprocess_exec( - *["sudo", "chmod", "666", str(sock_path)], + *["sudo", "chmod", "666", str(connector.sock_path)], stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT, ) @@ -112,10 +115,13 @@ async def detect_system_docker(ctx: Context) -> str: if (await proc.wait()) != 0: raise RuntimeError("Failed to set the docker socket permission", stdout) - async with aiohttp.ClientSession(connector=connector) as sess: - async with sess.get(docker_host / "version") as r: + async with aiohttp.ClientSession(connector=connector.connector) as sess: + async with sess.get(connector.docker_host / "version") as r: if r.status != 200: - raise RuntimeError("Failed to query the Docker daemon API") + raise RuntimeError( + "The Docker daemon API responded with unexpected response:" + f" {r.status} {r.reason}" + ) response_data = await r.json() return response_data["Version"] @@ -155,11 +161,15 @@ async def get_preferred_pants_local_exec_root(ctx: Context) -> str: async def determine_docker_sudo() -> bool: - sock_path, docker_host, connector = get_docker_connector() + connector = get_docker_connector() try: - async with aiohttp.ClientSession(connector=connector) as sess: - async with sess.get(docker_host / "version") as r: + async with aiohttp.ClientSession(connector=connector.connector) as sess: + async with sess.get(connector.docker_host / "version") as r: await r.json() + except ClientConnectorError as e: + if isinstance(e.os_error, PermissionError): + return True + raise except PermissionError: return True return False @@ -180,6 +190,8 @@ async def check_docker(ctx: Context) -> None: else: fail_with_system_docker_install_request() + # Compose is not a part of the docker API but a client-side plugin. 
+ # We need to execute the client command to get information about it. proc = await asyncio.create_subprocess_exec( *ctx.docker_sudo, "docker", "compose", "version", stdout=asyncio.subprocess.PIPE ) diff --git a/src/ai/backend/manager/container_registry/local.py b/src/ai/backend/manager/container_registry/local.py index 11dba3903e..a5f930cee7 100644 --- a/src/ai/backend/manager/container_registry/local.py +++ b/src/ai/backend/manager/container_registry/local.py @@ -25,9 +25,9 @@ class LocalRegistry(BaseContainerRegistry): @actxmgr async def prepare_client_session(self) -> AsyncIterator[tuple[yarl.URL, aiohttp.ClientSession]]: - _, url, connector = get_docker_connector() - async with aiohttp.ClientSession(connector=connector) as sess: - yield url, sess + connector = get_docker_connector() + async with aiohttp.ClientSession(connector=connector.connector) as sess: + yield connector.docker_host, sess async def fetch_repositories( self, diff --git a/tests/common/test_docker.py b/tests/common/test_docker.py index 5078d3b860..7e6104d778 100644 --- a/tests/common/test_docker.py +++ b/tests/common/test_docker.py @@ -2,6 +2,7 @@ import functools import itertools import typing +from pathlib import PosixPath from unittest.mock import MagicMock, call import aiohttp @@ -10,46 +11,46 @@ from ai.backend.common.docker import ( ImageRef, PlatformTagSet, + _search_docker_socket_files_impl, default_registry, default_repository, get_docker_connector, get_docker_context_host, - search_docker_socket_files, ) @pytest.mark.asyncio async def test_get_docker_connector(monkeypatch): get_docker_context_host.cache_clear() - search_docker_socket_files.cache_clear() + _search_docker_socket_files_impl.cache_clear() with monkeypatch.context() as m: m.setenv("DOCKER_HOST", "http://localhost:2375") - _, url, connector = get_docker_connector() - assert str(url) == "http://localhost:2375" - assert isinstance(connector, aiohttp.TCPConnector) + connector = get_docker_connector() + assert 
str(connector.docker_host) == "http://localhost:2375" + assert isinstance(connector.connector, aiohttp.TCPConnector) get_docker_context_host.cache_clear() - search_docker_socket_files.cache_clear() + _search_docker_socket_files_impl.cache_clear() with monkeypatch.context() as m: m.setenv("DOCKER_HOST", "https://example.com:2375") - _, url, connector = get_docker_connector() - assert str(url) == "https://example.com:2375" - assert isinstance(connector, aiohttp.TCPConnector) + connector = get_docker_connector() + assert str(connector.docker_host) == "https://example.com:2375" + assert isinstance(connector.connector, aiohttp.TCPConnector) get_docker_context_host.cache_clear() - search_docker_socket_files.cache_clear() + _search_docker_socket_files_impl.cache_clear() with monkeypatch.context() as m: m.setenv("DOCKER_HOST", "unix:///run/docker.sock") m.setattr("pathlib.Path.exists", lambda self: True) m.setattr("pathlib.Path.is_socket", lambda self: True) m.setattr("pathlib.Path.is_fifo", lambda self: False) - _, url, connector = get_docker_connector() - assert str(url) == "http://localhost" - assert isinstance(connector, aiohttp.UnixConnector) - assert connector.path == "/run/docker.sock" + connector = get_docker_connector() + assert str(connector.docker_host) == "http://docker" + assert isinstance(connector.connector, aiohttp.UnixConnector) + assert connector.sock_path == PosixPath("/run/docker.sock") get_docker_context_host.cache_clear() - search_docker_socket_files.cache_clear() + _search_docker_socket_files_impl.cache_clear() with monkeypatch.context() as m: m.setenv("DOCKER_HOST", "unix:///run/docker.sock") m.setattr("pathlib.Path.exists", lambda self: False) @@ -59,7 +60,7 @@ async def test_get_docker_connector(monkeypatch): get_docker_connector() get_docker_context_host.cache_clear() - search_docker_socket_files.cache_clear() + _search_docker_socket_files_impl.cache_clear() with monkeypatch.context() as m: m.setenv("DOCKER_HOST", 
"npipe:////./pipe/docker_engine") m.setattr("pathlib.Path.exists", lambda self: True) @@ -67,17 +68,17 @@ async def test_get_docker_connector(monkeypatch): m.setattr("pathlib.Path.is_fifo", lambda self: True) mock_connector = MagicMock() m.setattr("aiohttp.NamedPipeConnector", mock_connector) - _, url, connector = get_docker_connector() - assert str(url) == "http://localhost" - mock_connector.assert_called_once_with(r"\\.\pipe\docker_engine") + connector = get_docker_connector() + assert str(connector.docker_host) == "http://docker" + mock_connector.assert_called_once_with(r"\\.\pipe\docker_engine", force_close=True) - search_docker_socket_files.cache_clear() + _search_docker_socket_files_impl.cache_clear() with monkeypatch.context() as m: m.setenv("DOCKER_HOST", "unknown://dockerhost") with pytest.raises(RuntimeError, match="unsupported connection scheme"): get_docker_connector() - search_docker_socket_files.cache_clear() + _search_docker_socket_files_impl.cache_clear() with monkeypatch.context() as m: m.delenv("DOCKER_HOST", raising=False) m.setattr("ai.backend.common.docker.get_docker_context_host", lambda: None) @@ -88,12 +89,12 @@ async def test_get_docker_connector(monkeypatch): m.setattr("pathlib.Path.exists", lambda self: True) m.setattr("pathlib.Path.is_socket", lambda self: True) m.setattr("pathlib.Path.is_fifo", lambda self: False) - _, url, connector = get_docker_connector() + connector = get_docker_connector() mock_path.assert_has_calls([call("/run/docker.sock"), call("/var/run/docker.sock")]) - assert str(url) == "http://localhost" - assert isinstance(connector, aiohttp.UnixConnector) + assert str(connector.docker_host) == "http://docker" + assert isinstance(connector.connector, aiohttp.UnixConnector) - search_docker_socket_files.cache_clear() + _search_docker_socket_files_impl.cache_clear() with monkeypatch.context() as m: m.delenv("DOCKER_HOST", raising=False) m.setattr("ai.backend.common.docker.get_docker_context_host", lambda: None) @@ -105,12 
+106,12 @@ async def test_get_docker_connector(monkeypatch): m.setattr("pathlib.Path.is_fifo", lambda self: True) mock_connector = MagicMock() m.setattr("aiohttp.NamedPipeConnector", mock_connector) - _, url, connector = get_docker_connector() + connector = get_docker_connector() mock_path.assert_has_calls([call(r"\\.\pipe\docker_engine")]) - assert str(url) == "http://localhost" + assert str(connector.docker_host) == "http://docker" get_docker_context_host.cache_clear() - search_docker_socket_files.cache_clear() + _search_docker_socket_files_impl.cache_clear() with monkeypatch.context() as m: m.delenv("DOCKER_HOST", raising=False) m.setattr("ai.backend.common.docker.get_docker_context_host", lambda: None)