diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..7ed5fea --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,28 @@ +name: CI + +on: + pull_request: + push: + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ">=3.10" + cache: "pip" + - name: Install dependencies + run: | + pip install . + pip install pylint + + - name: pre-commit + uses: pre-commit/actions@v3.0.1 + + - name: pyright + uses: jakebailey/pyright-action@v2 + + - name: pylint + run: pylint zmk diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 94572a8..5015d59 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,19 +1,24 @@ fail_fast: false repos: - - repo: https://github.com/psf/black - rev: "23.9.1" + - repo: https://github.com/asottile/pyupgrade + rev: v3.17.0 hooks: - - id: black + - id: pyupgrade + args: [--py310-plus] - repo: https://github.com/pycqa/isort rev: "5.13.2" hooks: - id: isort + - repo: https://github.com/psf/black + rev: "24.8.0" + hooks: + - id: black - repo: https://github.com/Lucas-C/pre-commit-hooks - rev: v1.5.2 + rev: v1.5.5 hooks: - id: remove-tabs - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.6.0 hooks: - id: trailing-whitespace - id: check-yaml diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000..d664ad0 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,7 @@ +{ + "recommendations": [ + "ms-python.black-formatter", + "ms-python.isort", + "ms-python.python" + ] +} diff --git a/.vscode/settings.json b/.vscode/settings.json index 787772a..7e4cc36 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -21,5 +21,7 @@ "**/templates/common/**": "mako" }, "isort.check": true, - "isort.args": ["--profile", "black"] + "isort.args": ["--settings-path", "${workspaceFolder}"], + "python.analysis.importFormat": "relative", + "python.analysis.typeCheckingMode": "standard" } diff --git a/README.md b/README.md index aba4ae2..9cebed6 100644 --- a/README.md +++ b/README.md @@ -239,3 +239,23 @@ For example, to point ZMK CLI to an existing repo at `~/Documents/zmk-config`, r ```sh zmk config user.home ~/Documents/zmk-config ``` + +# Development + +If you would like to help improve ZMK CLI, you can clone this repo and install it in editable mode so your changes to the code apply when you run `zmk`. Open a terminal to the root directory of the repository and run: + +```sh +pip install -e ".[dev]" +pre-commit install +``` + +You may optionally run these commands inside a [virtual environment](https://docs.python.org/3/library/venv.html) if you don't want to install ZMK CLI's dependencies globally or if your OS disallows doing this. + +After running `pre-commit install`, your code will be checked when you make a commit, but there are some slower checks that do not run automatically. To run these additional checks, run these commands: + +```sh +pyright . +pylint zmk +``` + +Alternatively, you can just create a pull request and GitHub will run the checks and report any errors. diff --git a/pyproject.toml b/pyproject.toml index 5c5f34a..4eeb3d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,9 @@ dependencies = [ ] dynamic = ["version"] +[project.optional-dependencies] +dev = ["pre-commit", "pylint", "pyright"] + [project.urls] Documentation = "https://zmk.dev/docs" "Source Code" = "https://github.com/zmkfirmware/zmk-cli/" @@ -42,6 +45,20 @@ build-backend = "setuptools.build_meta" [tool.isort] profile = "black" +[tool.pylint.MAIN] +ignore = "_version.py" + +[tool.pylint."MESSAGES CONTROL"] +disable = [ + "arguments-differ", # Covered by pyright + "fixme", + "too-few-public-methods", + "too-many-arguments", + "too-many-branches", + "too-many-instance-attributes", + "too-many-locals", +] + [tool.setuptools] packages = ["zmk"] diff --git a/zmk/backports.py b/zmk/backports.py index b967a35..ebfc6c0 100644 --- a/zmk/backports.py +++ b/zmk/backports.py @@ -2,6 +2,8 @@ Backports from Python > 3.10. """ +# pyright: reportMissingImports = false + try: # pylint: disable=unused-import from enum import StrEnum diff --git a/zmk/build.py b/zmk/build.py index c0ef6ee..0d22d9c 100644 --- a/zmk/build.py +++ b/zmk/build.py @@ -2,28 +2,30 @@ Build matrix processing. """ -import collections.abc +from collections.abc import Iterable, Mapping, Sequence from dataclasses import asdict, dataclass, field from pathlib import Path -from typing import Any, Iterable, Optional +from typing import Any, Self, TypeVar, cast, overload import dacite from .repo import Repo from .yaml import YAML +T = TypeVar("T") + @dataclass class BuildItem: """An item in the build matrix""" board: str - shield: Optional[str] = None - snippet: Optional[str] = None - cmake_args: Optional[str] = None - artifact_name: Optional[str] = None + shield: str | None = None + snippet: str | None = None + cmake_args: str | None = None + artifact_name: str | None = None - def __rich__(self): + def __rich__(self) -> str: parts = [] parts.append(self.board) @@ -52,28 +54,28 @@ class BuildMatrix: _path: Path _yaml: YAML - _data: Any + _data: dict[str, Any] | None @classmethod - def from_repo(cls, repo: Repo): + def from_repo(cls, repo: Repo) -> Self: """Get the build matrix for a repo""" return cls(repo.build_matrix_path) - def __init__(self, path: Path) -> None: + def __init__(self, path: Path): self._path = path self._yaml = YAML(typ="rt") self._yaml.indent(mapping=2, sequence=4, offset=2) try: - self._data = self._yaml.load(self._path) + self._data = cast(dict[str, Any], self._yaml.load(self._path)) except FileNotFoundError: self._data = None - def write(self): + def write(self) -> None: """Updated the YAML file, creating it if necessary""" self._yaml.dump(self._data, self._path) @property - def path(self): + def path(self) -> Path: """Path to the matrix's YAML file""" return self._path @@ -87,7 +89,7 @@ def include(self) -> list[BuildItem]: wrapper = dacite.from_dict(_BuildMatrixWrapper, normalized) return wrapper.include - def has_item(self, item: BuildItem): + def has_item(self, item: BuildItem) -> bool: """Get whether the matrix has a build item""" return item in self.include @@ -106,7 +108,7 @@ def append(self, items: BuildItem | Iterable[BuildItem]) -> list[BuildItem]: return [] if not self._data: - self._data = self._yaml.map() + self._data = cast(dict[str, Any], self._yaml.map()) if "include" not in self._data: self._data["include"] = self._yaml.seq() @@ -121,7 +123,7 @@ def remove(self, items: BuildItem | Iterable[BuildItem]) -> list[BuildItem]: :return: the items that were removed. """ if not self._data or "include" not in self._data: - return False + return [] removed = [] items = [items] if isinstance(items, BuildItem) else items @@ -138,7 +140,25 @@ def remove(self, items: BuildItem | Iterable[BuildItem]) -> list[BuildItem]: return removed -def _keys_to_python(data: Any): +@overload +def _keys_to_python(data: str) -> str: ... + + +@overload +def _keys_to_python( + data: Sequence[T], +) -> Sequence[T]: ... + + +@overload +def _keys_to_python(data: Mapping[str, T]) -> Mapping[str, T]: ... + + +@overload +def _keys_to_python(data: T) -> T: ... + + +def _keys_to_python(data: Any) -> Any: """ Fix any keys with hyphens to underscores so that dacite.from_dict() will work correctly. @@ -151,10 +171,10 @@ def fix_key(key: str): case str(): return data - case collections.abc.Sequence(): + case Sequence(): return [_keys_to_python(i) for i in data] - case collections.abc.Mapping(): + case Mapping(): return {fix_key(k): _keys_to_python(v) for k, v in data.items()} case _: diff --git a/zmk/commands/__init__.py b/zmk/commands/__init__.py index 02f0ddb..0227ab1 100644 --- a/zmk/commands/__init__.py +++ b/zmk/commands/__init__.py @@ -7,7 +7,7 @@ from . import cd, code, config, download, init, keyboard, module, west -def register(app: typer.Typer): +def register(app: typer.Typer) -> None: """Register all commands with the app""" app.command()(cd.cd) app.command()(code.code) diff --git a/zmk/commands/cd.py b/zmk/commands/cd.py index c755653..f824a4a 100644 --- a/zmk/commands/cd.py +++ b/zmk/commands/cd.py @@ -10,11 +10,11 @@ import shellingham import typer -from ..config import Config +from ..config import get_config from ..exceptions import FatalError, FatalHomeMissing, FatalHomeNotSet -def cd(ctx: typer.Context): +def cd(ctx: typer.Context) -> None: """Go to the ZMK config repo.""" if not sys.stdout.isatty(): raise FatalError( @@ -22,7 +22,7 @@ def cd(ctx: typer.Context): 'Use "cd $(zmk config user.home)" instead.' ) - cfg = ctx.find_object(Config) + cfg = get_config(ctx) home = cfg.home_path if home is None: diff --git a/zmk/commands/code.py b/zmk/commands/code.py index 7cc9ac5..0bf561b 100644 --- a/zmk/commands/code.py +++ b/zmk/commands/code.py @@ -6,16 +6,17 @@ import shlex import shutil import subprocess +from collections.abc import Callable from configparser import NoOptionError from dataclasses import dataclass, field from enum import Flag, auto -from typing import Annotated, Callable, Optional +from typing import Annotated import rich import typer from rich.markdown import Markdown -from ..config import Config, Settings +from ..config import Config, Settings, get_config from ..exceptions import FatalError from ..menu import show_menu from ..repo import Repo @@ -24,7 +25,7 @@ def code( ctx: typer.Context, keyboard: Annotated[ - Optional[str], + str | None, typer.Argument( help="Name of the keyboard to edit. If omitted, opens the repo directory.", ), @@ -39,10 +40,10 @@ def code( "--build", "-b", help="Open the build matrix instead of a keymap." ), ] = False, -): +) -> None: """Open the repo or a .keymap or .conf file in a text editor.""" - cfg = ctx.find_object(Config) + cfg = get_config(ctx) repo = cfg.get_repo() if open_build_matrix: @@ -56,7 +57,7 @@ def code( subprocess.call(cmd, shell=True) -def _get_file(repo: Repo, keyboard: str, open_conf: bool): +def _get_file(repo: Repo, keyboard: str | None, open_conf: bool): if not keyboard: return repo.path @@ -98,7 +99,7 @@ class Editor: "Executable name or command line to execute this tool" support: Support = Support.FILE "Types of files this tool supports editing" - test: Callable[[], bool] = None + test: Callable[[], bool] | None = None """ Function that returns true if the tool is installed. Defaults to `which {self.cmd}`. @@ -109,7 +110,7 @@ class Editor: def __rich__(self): return self.name - def get_command(self): + def get_command(self) -> str | None: """Get the command to execute the tool, or None if it is not installed""" if self.test and self.test(): return self.cmd @@ -169,7 +170,11 @@ def _select_editor(cfg: Config): ) editor = show_menu("Select a text editor:", file_editors, filter_func=_filter) - cfg.set(Settings.CORE_EDITOR, editor.get_command()) + editor_command = editor.get_command() + if not editor_command: + raise TypeError(f"Invalid editor {editor.name}") + + cfg.set(Settings.CORE_EDITOR, editor_command) explorer = None if editor.support & Support.DIR: @@ -181,7 +186,11 @@ def _select_editor(cfg: Config): dir_editors, filter_func=_filter, ) - cfg.set(Settings.CORE_EXPLORER, explorer.get_command()) + explorer_command = explorer.get_command() + if not explorer_command: + raise TypeError(f"Invalid explorer {editor.name}") + + cfg.set(Settings.CORE_EXPLORER, explorer_command) cfg.write() diff --git a/zmk/commands/config.py b/zmk/commands/config.py index 4d62476..bf26ce5 100644 --- a/zmk/commands/config.py +++ b/zmk/commands/config.py @@ -2,13 +2,13 @@ "zmk config" command. """ -from typing import Annotated, Optional +from typing import Annotated import typer from rich.console import Console from .. import styles -from ..config import Config +from ..config import Config, get_config console = Console( highlighter=styles.KeyValueHighlighter(), theme=styles.KEY_VALUE_THEME @@ -17,7 +17,7 @@ def _path_callback(ctx: typer.Context, value: bool): if value: - cfg = ctx.find_object(Config) + cfg = get_config(ctx) print(cfg.path) raise typer.Exit() @@ -25,13 +25,13 @@ def _path_callback(ctx: typer.Context, value: bool): def config( ctx: typer.Context, name: Annotated[ - Optional[str], + str | None, typer.Argument( help="Setting name. Prints all setting values if omitted.", ), ] = None, value: Annotated[ - Optional[str], + str | None, typer.Argument(help="New setting value. Prints the current value if omitted."), ] = None, unset: Annotated[ @@ -39,7 +39,7 @@ def config( typer.Option("--unset", "-u", help="Remove the setting with the given name."), ] = False, _: Annotated[ - Optional[bool], + bool | None, typer.Option( "--path", "-p", @@ -48,10 +48,10 @@ def config( callback=_path_callback, ), ] = False, -): +) -> None: """Get and set ZMK CLI settings.""" - cfg = ctx.find_object(Config) + cfg = get_config(ctx) if name is None: _list_settings(cfg) diff --git a/zmk/commands/download.py b/zmk/commands/download.py index a54cf42..7d0e58e 100644 --- a/zmk/commands/download.py +++ b/zmk/commands/download.py @@ -4,14 +4,14 @@ import typer -from ..config import Config +from ..config import get_config from ..repo import Repo -def download(ctx: typer.Context): +def download(ctx: typer.Context) -> None: """Open the web page to download firmware from GitHub.""" - cfg = ctx.find_object(Config) + cfg = get_config(ctx) repo = cfg.get_repo() actions_url = _get_actions_url(repo) diff --git a/zmk/commands/init.py b/zmk/commands/init.py index d84b572..9f87b67 100644 --- a/zmk/commands/init.py +++ b/zmk/commands/init.py @@ -13,7 +13,7 @@ from rich.prompt import Confirm, Prompt from rich.table import Table -from ..config import Config +from ..config import Config, get_config from ..exceptions import FatalError from ..prompt import UrlPrompt from ..repo import Repo, find_containing_repo, is_repo @@ -25,11 +25,11 @@ TEXT_WIDTH = 80 -def init(ctx: typer.Context): +def init(ctx: typer.Context) -> None: """Create a new ZMK config repo or clone an existing one.""" console = rich.get_console() - cfg = ctx.find_object(Config) + cfg = get_config(ctx) _check_dependencies() _check_for_existing_repo(cfg) diff --git a/zmk/commands/keyboard/__init__.py b/zmk/commands/keyboard/__init__.py index c361443..b580465 100644 --- a/zmk/commands/keyboard/__init__.py +++ b/zmk/commands/keyboard/__init__.py @@ -17,5 +17,5 @@ @app.callback() -def keyboard(): +def keyboard() -> None: """Add or remove keyboards from the build.""" diff --git a/zmk/commands/keyboard/add.py b/zmk/commands/keyboard/add.py index ae65816..a4f7408 100644 --- a/zmk/commands/keyboard/add.py +++ b/zmk/commands/keyboard/add.py @@ -5,24 +5,24 @@ import itertools import shutil from pathlib import Path -from typing import Annotated, Optional +from typing import Annotated import rich import typer from ...build import BuildItem, BuildMatrix +from ...config import get_config from ...exceptions import FatalError from ...hardware import Board, Keyboard, Shield, get_hardware, is_compatible from ...menu import show_menu from ...repo import Repo from ...util import spinner -from ..config import Config def keyboard_add( ctx: typer.Context, controller_id: Annotated[ - Optional[str], + str | None, typer.Option( "--controller", "-c", @@ -31,7 +31,7 @@ def keyboard_add( ), ] = None, keyboard_id: Annotated[ - Optional[str], + str | None, typer.Option( "--keyboard", "--kb", @@ -40,12 +40,12 @@ def keyboard_add( help="ID of the keyboard board/shield to add.", ), ] = None, -): +) -> None: """Add configuration for a keyboard and add it to the build.""" console = rich.get_console() - cfg = ctx.find_object(Config) + cfg = get_config(ctx) repo = cfg.get_repo() with spinner("Finding hardware..."): @@ -56,7 +56,8 @@ def keyboard_add( if keyboard_id: keyboard = hardware.find_keyboard(keyboard_id) - _check_keyboard_found(keyboard, keyboard_id) + if keyboard is None: + raise KeyboardNotFound(keyboard_id) if controller_id: if not isinstance(keyboard, Shield): @@ -66,13 +67,15 @@ def keyboard_add( ) controller = hardware.find_controller(controller_id) - _check_controller_found(controller, controller_id) + if controller is None: + raise ControllerNotFound(controller_id) elif controller_id: # User specified a controller but not a keyboard. Filter the keyboard # list to just those compatible with the controller. controller = hardware.find_controller(controller_id) - _check_controller_found(controller, controller_id) + if controller is None: + raise ControllerNotFound(controller_id) hardware.keyboards = [ kb @@ -86,19 +89,20 @@ def keyboard_add( "Select a keyboard:", hardware.keyboards, filter_func=_filter ) - if isinstance(keyboard, Shield) and controller is None: - hardware.controllers = [ - c for c in hardware.controllers if is_compatible(c, keyboard) - ] - controller = show_menu( - "Select a controller:", hardware.controllers, filter_func=_filter - ) - - # Sanity check that everything is compatible - if keyboard and controller and not is_compatible(controller, keyboard): - raise FatalError( - f'Keyboard "{keyboard.id}" is not compatible with controller "{controller.id}"' - ) + if isinstance(keyboard, Shield): + if controller is None: + hardware.controllers = [ + c for c in hardware.controllers if is_compatible(c, keyboard) + ] + controller = show_menu( + "Select a controller:", hardware.controllers, filter_func=_filter + ) + + # Sanity check that everything is compatible + if not is_compatible(controller, keyboard): + raise FatalError( + f'Keyboard "{keyboard.id}" is not compatible with controller "{controller.id}"' + ) name = keyboard.id if controller: @@ -112,19 +116,23 @@ def keyboard_add( console.print(f'Run "zmk code {keyboard.id}" to edit the keymap.') -def _filter(item: Board | Shield, text: str): +def _filter(item: Keyboard, text: str): text = text.casefold().strip() return text in item.id.casefold() or text in item.name.casefold() -def _check_keyboard_found(keyboard: Optional[Keyboard], keyboard_id: str): - if keyboard is None: - raise FatalError(f'Could not find a keyboard with ID "{keyboard_id}"') +class KeyboardNotFound(FatalError): + """Fatal error for an invalid keyboard ID""" + + def __init__(self, keyboard_id: str): + super().__init__(f'Could not find a keyboard with ID "{keyboard_id}"') -def _check_controller_found(controller: Optional[Board], controller_id: str): - if controller is None: - raise FatalError(f'Could not find a controller board with ID "{controller_id}"') +class ControllerNotFound(FatalError): + """Fatal error for an invalid controller ID""" + + def __init__(self, controller_id: str): + super().__init__(f'Could not find a controller board with ID "{controller_id}"') def _copy_keyboard_file(repo: Repo, path: Path): @@ -133,12 +141,15 @@ def _copy_keyboard_file(repo: Repo, path: Path): shutil.copy2(path, dest_path) -def _get_build_items(keyboard: Keyboard, controller: Optional[Board]): +def _get_build_items(keyboard: Keyboard, controller: Board | None): boards = [] shields = [] match keyboard: case Shield(id=shield_id, siblings=siblings): + if controller is None: + raise ValueError("controller may not be None if keyboard is a shield") + shields = siblings or [shield_id] boards = [controller.id] @@ -151,7 +162,7 @@ def _get_build_items(keyboard: Keyboard, controller: Optional[Board]): return [BuildItem(board=b, shield=s) for b, s in itertools.product(boards, shields)] -def _add_keyboard(repo: Repo, keyboard: Keyboard, controller: Optional[Board]): +def _add_keyboard(repo: Repo, keyboard: Keyboard, controller: Board | None): _copy_keyboard_file(repo, keyboard.keymap_path) _copy_keyboard_file(repo, keyboard.config_path) diff --git a/zmk/commands/keyboard/list.py b/zmk/commands/keyboard/list.py index 12ec284..ce72d04 100644 --- a/zmk/commands/keyboard/list.py +++ b/zmk/commands/keyboard/list.py @@ -2,7 +2,8 @@ "zmk keyboard list" command. """ -from typing import Annotated, Iterable, Optional +from collections.abc import Iterable +from typing import Annotated import rich import typer @@ -12,10 +13,10 @@ from ...backports import StrEnum from ...build import BuildItem, BuildMatrix +from ...config import get_config from ...exceptions import FatalError from ...hardware import Board, Hardware, Shield, get_hardware, is_compatible from ...util import spinner -from ..config import Config # TODO: allow output as unformatted list # TODO: allow output as more detailed metadata @@ -37,7 +38,7 @@ def _list_build_matrix(ctx: typer.Context, value: bool): console = rich.get_console() - cfg = ctx.find_object(Config) + cfg = get_config(ctx) repo = cfg.get_repo() matrix = BuildMatrix.from_repo(repo) @@ -79,7 +80,7 @@ def add_row(item: BuildItem): def keyboard_list( ctx: typer.Context, _: Annotated[ - Optional[bool], + bool | None, typer.Option( "--build", help="Show the build matrix.", @@ -94,9 +95,9 @@ def keyboard_list( "-t", help="List only items of this type.", ), - ] = "all", + ] = ListType.ALL, board: Annotated[ - Optional[str], + str | None, typer.Option( "--board", "-b", @@ -105,7 +106,7 @@ def keyboard_list( ), ] = None, shield: Annotated[ - Optional[str], + str | None, typer.Option( "--shield", "-s", @@ -114,7 +115,7 @@ def keyboard_list( ), ] = None, interconnect: Annotated[ - Optional[str], + str | None, typer.Option( "--interconnect", "-i", @@ -128,12 +129,12 @@ def keyboard_list( "--standalone", help="List only keyboards with onboard controllers." ), ] = False, -): +) -> None: """List supported keyboards or keyboards in the build matrix.""" console = rich.get_console() - cfg = ctx.find_object(Config) + cfg = get_config(ctx) repo = cfg.get_repo() with spinner("Finding hardware..."): @@ -145,7 +146,11 @@ def keyboard_list( if item is None: raise FatalError(f'Could not find controller board "{board}".') - groups.keyboards = [kb for kb in groups.keyboards if is_compatible(item, kb)] + groups.keyboards = [ + kb + for kb in groups.keyboards + if isinstance(kb, Shield) and is_compatible(item, kb) + ] list_type = ListType.KEYBOARD elif shield: @@ -166,11 +171,13 @@ def keyboard_list( if item is None: raise FatalError(f'Could not find interconnect "{interconnect}".') - groups.controllers = [c for c in groups.controllers if item.id in c.exposes] + groups.controllers = [ + c for c in groups.controllers if c.exposes and item.id in c.exposes + ] groups.keyboards = [ kb for kb in groups.keyboards - if isinstance(kb, Shield) and item.id in kb.requires + if isinstance(kb, Shield) and kb.requires and item.id in kb.requires ] groups.interconnects = [] diff --git a/zmk/commands/keyboard/new.py b/zmk/commands/keyboard/new.py index e31507e..6319c86 100644 --- a/zmk/commands/keyboard/new.py +++ b/zmk/commands/keyboard/new.py @@ -4,17 +4,17 @@ import re from dataclasses import dataclass, field -from typing import Annotated, Optional +from typing import Annotated import rich import typer from rich.prompt import Confirm, InvalidResponse, PromptBase from ...backports import StrEnum +from ...config import get_config from ...exceptions import FatalError from ...menu import detail_list, show_menu from ...templates import get_template_files -from ..config import Config class KeyboardType(StrEnum): @@ -76,19 +76,19 @@ def _validate_short_name(name: str): raise typer.BadParameter(f"Name must be <= {MAX_NAME_LENGTH} characters.") -def _id_callback(value: Optional[str]): +def _id_callback(value: str | None): if value is not None: _validate_id(value) return value -def _name_callback(name: Optional[str]): +def _name_callback(name: str | None): if name is not None: _validate_name(name) return name -def _short_name_callback(name: Optional[str]): +def _short_name_callback(name: str | None): if name is not None: _validate_short_name(name) return name @@ -97,15 +97,15 @@ def _short_name_callback(name: Optional[str]): def keyboard_new( ctx: typer.Context, keyboard_id: Annotated[ - Optional[str], + str | None, typer.Option("--id", "-i", help="Board/shield ID.", callback=_id_callback), ] = None, keyboard_name: Annotated[ - Optional[str], + str | None, typer.Option("--name", "-n", help="Keyboard name.", callback=_name_callback), ] = None, short_name: Annotated[ - Optional[str], + str | None, typer.Option( "--shortname", "-s", @@ -114,7 +114,7 @@ def keyboard_new( ), ] = None, keyboard_type: Annotated[ - Optional[KeyboardType], + KeyboardType | None, typer.Option( "--type", "-t", @@ -122,7 +122,7 @@ def keyboard_new( ), ] = None, keyboard_platform: Annotated[ - Optional[KeyboardPlatform], + KeyboardPlatform | None, typer.Option( "--platform", "-p", @@ -130,15 +130,15 @@ def keyboard_new( ), ] = None, keyboard_layout: Annotated[ - Optional[KeyboardLayout], + KeyboardLayout | None, typer.Option("--layout", "-l", help="Keyboard hardware layout."), ] = None, force: Annotated[ bool, typer.Option("--force", "-f", help="Overwrite existing files.") - ] = None, -): + ] = False, +) -> None: """Create a new keyboard from a template.""" - cfg = ctx.find_object(Config) + cfg = get_config(ctx) repo = cfg.get_repo() board_root = repo.board_root @@ -155,7 +155,7 @@ def keyboard_new( short_name = ShortNamePrompt.ask() if not keyboard_id: - keyboard_id = IdPrompt.ask(name=short_name) + keyboard_id = IdPrompt.ask(prompt=short_name) if not keyboard_type: keyboard_type = _prompt_keyboard_type() @@ -257,12 +257,11 @@ class NamePrompt(NamePromptBase): """Prompt for a keyboard name.""" @classmethod - def validate(cls, value: str): + def validate(cls, value: str) -> None: _validate_name(value) - # pylint: disable=arguments-differ @classmethod - def ask(cls): + def ask(cls) -> str: # pyright: ignore[reportIncompatibleMethodOverride] return super().ask("Enter the name of the keyboard") @@ -270,12 +269,11 @@ class ShortNamePrompt(NamePromptBase): """Prompt for an abbreviated keyboard name.""" @classmethod - def validate(cls, value: str): + def validate(cls, value: str) -> None: _validate_short_name(value) - # pylint: disable=arguments-differ @classmethod - def ask(cls): + def ask(cls) -> str: # pyright: ignore[reportIncompatibleMethodOverride] return super().ask( f"Enter an abbreviated name [dim](<= {MAX_NAME_LENGTH} chars)" ) @@ -285,16 +283,24 @@ class IdPrompt(NamePromptBase): """Prompt for a keyboard identifier.""" @classmethod - def validate(cls, value: str): + def validate(cls, value: str) -> None: _validate_id(value) - # pylint: disable=arguments-differ @classmethod - def ask(cls, name: str): - return super().ask( - "Enter an ID for the keyboard", default=_get_default_id(name) + def ask( # pyright: ignore[reportIncompatibleMethodOverride] + cls, prompt: str + ) -> str: + result = super().ask( + "Enter an ID for the keyboard", default=_get_default_id(prompt) ) + # rich uses ... to indicate no default, but passing ... to the "default" + # parameter causes it to add EllipsisType to the possible return types. + if result == ...: + raise TypeError("ask() returned ...") + + return result + _DEFAULT_ARCH = "arm" _PLATFORM_ARCH: dict[KeyboardPlatform, str] = { @@ -315,7 +321,7 @@ def _get_template( template.data["name"] = keyboard_name template.data["shortname"] = short_name template.data["keyboard_type"] = str(keyboard_type) - template.data["arch"] = None + template.data["arch"] = "" match keyboard_type: case KeyboardType.SHIELD: diff --git a/zmk/commands/keyboard/remove.py b/zmk/commands/keyboard/remove.py index 7218737..e0a2808 100644 --- a/zmk/commands/keyboard/remove.py +++ b/zmk/commands/keyboard/remove.py @@ -6,14 +6,14 @@ import typer from ...build import BuildMatrix +from ...config import get_config from ...menu import show_menu -from ..config import Config # TODO: add options to select items from command line -def keyboard_remove(ctx: typer.Context): +def keyboard_remove(ctx: typer.Context) -> None: """Remove a keyboard from the build.""" - cfg = ctx.find_object(Config) + cfg = get_config(ctx) repo = cfg.get_repo() matrix = BuildMatrix.from_repo(repo) diff --git a/zmk/commands/module/__init__.py b/zmk/commands/module/__init__.py index 6361f4f..2bd7827 100644 --- a/zmk/commands/module/__init__.py +++ b/zmk/commands/module/__init__.py @@ -15,5 +15,5 @@ @app.callback() -def keyboard(): +def keyboard() -> None: """Add or remove Zephyr modules from the build.""" diff --git a/zmk/commands/module/add.py b/zmk/commands/module/add.py index 555454f..4b6f27c 100644 --- a/zmk/commands/module/add.py +++ b/zmk/commands/module/add.py @@ -3,7 +3,7 @@ """ import subprocess -from typing import Annotated, Optional +from typing import Annotated import rich import typer @@ -11,7 +11,7 @@ from rich.prompt import InvalidResponse, Prompt, PromptBase from west.manifest import ImportFlag, Manifest -from ...config import Config +from ...config import get_config from ...exceptions import FatalError from ...prompt import UrlPrompt from ...util import spinner @@ -21,20 +21,20 @@ def module_add( ctx: typer.Context, url: Annotated[ - Optional[str], + str | None, typer.Argument(help="URL of the Git repository to add.", show_default=False), ] = None, revision: Annotated[ - Optional[str], + str | None, typer.Argument(help="Revision to track.", show_default="main"), ] = None, name: Annotated[ - Optional[str], + str | None, typer.Option("--name", "-n", help="Name of the module.", show_default=False), ] = None, -): +) -> None: """Add a Zephyr module to the build.""" - cfg = ctx.find_object(Config) + cfg = get_config(ctx) repo = cfg.get_repo() manifest = Manifest.from_topdir( @@ -88,20 +88,21 @@ def _get_name_from_url(repo_url: str): return repo_url.split("/")[-1].removesuffix(".git") -class NamePrompt(PromptBase): +class NamePrompt(PromptBase[str]): """Prompt for a module name.""" _manifest: Manifest - def __init__(self, manifest: Manifest, *, console: Optional[Console] = None): + def __init__(self, manifest: Manifest, *, console: Console | None = None): super().__init__("Enter a new name", console=console) self._manifest = manifest - # pylint: disable=arguments-renamed, arguments-differ @classmethod - def ask(cls, manifest: Manifest, *, console: Optional[Console] = None): - prompt = cls(manifest, console=console) - return prompt() + def ask( # pyright: ignore[reportIncompatibleMethodOverride] + cls, prompt: Manifest, *, console: Console | None = None + ): + subprompt = cls(prompt, console=console) + return subprompt() def process_response(self, value: str) -> str: value = value.strip() diff --git a/zmk/commands/module/list.py b/zmk/commands/module/list.py index 9a1dc47..8cbf0f0 100644 --- a/zmk/commands/module/list.py +++ b/zmk/commands/module/list.py @@ -8,17 +8,15 @@ from rich.table import Table from west.manifest import ImportFlag, Manifest -from ...config import Config +from ...config import get_config -def module_list( - ctx: typer.Context, -): +def module_list(ctx: typer.Context) -> None: """Print a list of installed Zephyr modules.""" console = rich.get_console() - cfg = ctx.find_object(Config) + cfg = get_config(ctx) repo = cfg.get_repo() manifest = Manifest.from_topdir( diff --git a/zmk/commands/module/remove.py b/zmk/commands/module/remove.py index 739b31c..16e52e9 100644 --- a/zmk/commands/module/remove.py +++ b/zmk/commands/module/remove.py @@ -7,13 +7,13 @@ import stat import subprocess from dataclasses import dataclass -from typing import Annotated, Any, Optional +from typing import Annotated, Any import rich import typer from west.manifest import ImportFlag, Manifest, Project -from ...config import Config +from ...config import get_config from ...exceptions import FatalError from ...menu import Detail, detail_list, show_menu from ...repo import Repo @@ -24,12 +24,12 @@ def module_remove( ctx: typer.Context, module: Annotated[ - Optional[str], + str | None, typer.Argument(help="Name or URL of the module to remove.", show_default=False), ] = None, -): +) -> None: """Remove a Zephyr module from the build.""" - cfg = ctx.find_object(Config) + cfg = get_config(ctx) repo = cfg.get_repo() manifest = Manifest.from_topdir( diff --git a/zmk/commands/west.py b/zmk/commands/west.py index ce86728..96ffab5 100644 --- a/zmk/commands/west.py +++ b/zmk/commands/west.py @@ -2,20 +2,20 @@ "zmk west" command. """ -from typing import Annotated, Optional +from typing import Annotated import typer -from ..config import Config +from ..config import get_config -def west(ctx: typer.Context): +def west(ctx: typer.Context) -> None: # pylint: disable=line-too-long """ Run [link=https://docs.zephyrproject.org/latest/develop/west/index.html]west[/link] in the config repo. """ - cfg = ctx.find_object(Config) + cfg = get_config(ctx) repo = cfg.get_repo() # TODO: detect this better @@ -29,15 +29,15 @@ def west(ctx: typer.Context): def update( ctx: typer.Context, modules: Annotated[ - Optional[list[str]], + list[str] | None, typer.Argument( help="Names of modules to update. Updates all modules if omitted." ), ] = None, -): +) -> None: """Fetch the latest keyboard data.""" - cfg = ctx.find_object(Config) + cfg = get_config(ctx) repo = cfg.get_repo() modules = modules or [] diff --git a/zmk/config.py b/zmk/config.py index 62b457a..8ac5f91 100644 --- a/zmk/config.py +++ b/zmk/config.py @@ -3,10 +3,10 @@ """ from collections import defaultdict +from collections.abc import Generator from configparser import ConfigParser from itertools import chain from pathlib import Path -from typing import Optional import typer @@ -30,7 +30,7 @@ class Config: path: Path force_home: bool - def __init__(self, path: Path, force_home=False) -> None: + def __init__(self, path: Path | None, force_home=False): self.path = path or _default_config_path() self.force_home = force_home @@ -38,7 +38,7 @@ def __init__(self, path: Path, force_home=False) -> None: self._parser = ConfigParser() self._parser.read(self.path, encoding="utf-8") - def write(self): + def write(self) -> None: """Write back to the same file used when calling read()""" self.path.parent.mkdir(parents=True, exist_ok=True) @@ -55,17 +55,17 @@ def getboolean(self, name: str, **kwargs) -> bool: section, option = self._split_option(name) return self._parser.getboolean(section, option, **kwargs) - def set(self, name: str, value: str): + def set(self, name: str, value: str) -> None: """Set a setting""" section, option = self._split_option(name) self._parser.set(section, option, value) - def remove(self, name: str): + def remove(self, name: str) -> None: """Remove a setting""" section, option = self._split_option(name) self._parser.remove_option(section, option) - def items(self): + def items(self) -> Generator[tuple[str, str], None, None]: """Yields ('section.option', 'value') tuples for all settings""" sections = set(chain(self._overrides.keys(), self._parser.sections())) @@ -86,7 +86,7 @@ def _split_option(self, name: str): # Shortcuts for commonly-used settings: @property - def home_path(self) -> Optional[Path]: + def home_path(self) -> Path | None: """ Path to ZMK config repo. """ @@ -94,7 +94,7 @@ def home_path(self) -> Optional[Path]: return Path(home) if home else None @home_path.setter - def home_path(self, value: Path): + def home_path(self, value: Path) -> None: self.set(Settings.USER_HOME, str(value.resolve())) def get_repo(self) -> Repo: @@ -119,5 +119,14 @@ def get_repo(self) -> Repo: return Repo(home) +def get_config(ctx: typer.Context) -> Config: + """Get the Config object for the given context""" + + cfg = ctx.find_object(Config) + if cfg is None: + raise RuntimeError("Could not find Config for current context") + return cfg + + def _default_config_path(): return Path(typer.get_app_dir("zmk", roaming=False)) / "zmk.ini" diff --git a/zmk/exceptions.py b/zmk/exceptions.py index 534af30..963aaa4 100644 --- a/zmk/exceptions.py +++ b/zmk/exceptions.py @@ -3,6 +3,7 @@ """ from pathlib import Path +from typing import cast from click import ClickException from rich.highlighter import Highlighter, ReprHighlighter @@ -18,12 +19,14 @@ class FatalError(ClickException): highlighter: Highlighter - def __init__(self, message: str, highlighter: Highlighter = None) -> None: + def __init__(self, message: str, highlighter: Highlighter | None = None): self.highlighter = highlighter or ReprHighlighter() super().__init__(message) def format_message(self) -> str: - return self.highlighter(Text.from_markup(self.message)) + # format_message() expects a str, but if we convert to a string here, + # we will lose any formatting. + return cast(str, self.highlighter(Text.from_markup(self.message))) class FatalHomeNotSet(FatalError): diff --git a/zmk/hardware.py b/zmk/hardware.py index 3a88c0a..6d2e07e 100644 --- a/zmk/hardware.py +++ b/zmk/hardware.py @@ -2,10 +2,11 @@ Hardware metadata discovery and processing. """ +from collections.abc import Generator, Iterable from dataclasses import dataclass, field from functools import reduce from pathlib import Path -from typing import Any, Generator, Iterable, Literal, Optional, TypeAlias, TypeGuard +from typing import Any, Literal, Self, TypeAlias, TypeGuard import dacite @@ -33,11 +34,11 @@ class Hardware: id: str name: str - file_format: Optional[str] = None - url: Optional[str] = None - description: Optional[str] = None - manufacturer: Optional[str] = None - version: Optional[str] = None + file_format: str | None = None + url: str | None = None + description: str | None = None + manufacturer: str | None = None + version: str | None = None def __str__(self) -> str: return self.id @@ -46,7 +47,7 @@ def __rich__(self) -> Any: return f"{self.id} [dim]{self.name}" @classmethod - def from_dict(cls, data): + def from_dict(cls, data) -> Self: """Read a hardware description from a dict""" return dacite.from_dict(cls, data) @@ -56,20 +57,20 @@ class Interconnect(Hardware): """Description of the connection between two pieces of hardware""" node_labels: dict[str, str] = field(default_factory=dict) - design_guideline: Optional[str] = None + design_guideline: str | None = None @dataclass class Keyboard(Hardware): """Base class for hardware that forms a keyboard""" - siblings: Optional[list[str]] = field(default_factory=list) + siblings: list[str] | None = field(default_factory=list) """List of board/shield IDs for a split keyboard""" - exposes: Optional[list[str]] = field(default_factory=list) + exposes: list[str] | None = field(default_factory=list) """List of interconnect IDs this board/shield provides""" - features: Optional[list[Feature]] = field(default_factory=list) + features: list[Feature] | None = field(default_factory=list) """List of features this board/shield supports""" - variants: Optional[list[Variant]] = field(default_factory=list) + variants: list[Variant] | None = field(default_factory=list) def __post_init__(self): self.siblings = self.siblings or [] @@ -78,12 +79,12 @@ def __post_init__(self): self.variants = self.variants or [] @property - def config_path(self): + def config_path(self) -> Path: """Path to the .conf file for this keyboard""" return self.directory / f"{self.id}.conf" @property - def keymap_path(self): + def keymap_path(self) -> Path: """Path to the .keymap file for this keyboard""" return self.directory / f"{self.id}.keymap" @@ -92,7 +93,7 @@ def keymap_path(self): class Board(Keyboard): """Hardware with a processor. May be a keyboard or a controller.""" - arch: Optional[str] = None + arch: str | None = None outputs: list[Output] = field(default_factory=list) """List of methods by which this board supports sending HID data""" @@ -105,7 +106,7 @@ def __post_init__(self): class Shield(Keyboard): """Hardware that attaches to a board. May be a keyboard or a peripheral.""" - requires: Optional[list[str]] = field(default_factory=list) + requires: list[str] | None = field(default_factory=list) """List of interconnects to which this shield attaches""" def __post_init__(self): @@ -126,17 +127,17 @@ class GroupedHardware: # TODO: add displays and other peripherals? - def find_keyboard(self, item_id: str): + def find_keyboard(self, item_id: str) -> Keyboard | None: """Find a keyboard by ID""" item_id = item_id.casefold() return next((i for i in self.keyboards if i.id.casefold() == item_id), None) - def find_controller(self, item_id: str): + def find_controller(self, item_id: str) -> Board | None: """Find a controller by ID""" item_id = item_id.casefold() return next((i for i in self.controllers if i.id.casefold() == item_id), None) - def find_interconnect(self, item_id: str): + def find_interconnect(self, item_id: str) -> Interconnect | None: """Find an interconnect by ID""" item_id = item_id.casefold() return next( @@ -148,7 +149,7 @@ def find_interconnect(self, item_id: str): def is_keyboard(hardware: Hardware) -> TypeGuard[Keyboard]: """Test whether an item is a keyboard (board or shield supporting keys)""" match hardware: - case Keyboard(features=feat) if "keys" in feat: + case Keyboard(features=feat) if feat and "keys" in feat: return True case _: @@ -165,7 +166,9 @@ def is_interconnect(hardware: Hardware) -> TypeGuard[Interconnect]: return isinstance(hardware, Interconnect) -def is_compatible(base: Board | Shield | Iterable[Board | Shield], shield: Shield): +def is_compatible( + base: Board | Shield | Iterable[Board | Shield], shield: Shield +) -> bool: """ Get whether a shield can be attached to the given hardware. @@ -175,6 +178,9 @@ def is_compatible(base: Board | Shield | Iterable[Board | Shield], shield: Shiel an interconnect provided by another item. """ + if not shield.requires: + return True + base = [base] if isinstance(base, Keyboard) else base exposed = flatten(b.exposes for b in base) diff --git a/zmk/main.py b/zmk/main.py index 7059d07..918e0d6 100644 --- a/zmk/main.py +++ b/zmk/main.py @@ -4,7 +4,7 @@ from importlib import metadata from pathlib import Path -from typing import Annotated, Optional +from typing import Annotated import typer @@ -25,7 +25,7 @@ def _version_callback(version: bool): def main( ctx: typer.Context, config_file: Annotated[ - Optional[Path], + Path | None, typer.Option( envvar="ZMK_CLI_CONFIG", help="Path to the ZMK CLI configuration file." ), @@ -45,8 +45,8 @@ def main( callback=_version_callback, is_eager=True, ), - ] = None, -): + ] = False, +) -> None: """ ZMK Firmware command line tool diff --git a/zmk/menu.py b/zmk/menu.py index b414c7b..9a2ccd7 100644 --- a/zmk/menu.py +++ b/zmk/menu.py @@ -2,8 +2,9 @@ Terminal menus """ +from collections.abc import Callable, Iterable from contextlib import contextmanager -from typing import Any, Callable, Generic, Iterable, Optional, TypeVar +from typing import Any, Generic, TypeVar import rich from rich.console import Console @@ -53,7 +54,7 @@ class TerminalMenu(Generic[T], Highlighter): items: list[T] default_index: int - _filter_func: Optional[Callable[[T, str], bool]] + _filter_func: Callable[[T, str], bool] | None _filter_text: str _filter_items: list[T] _cursor_index: int @@ -70,9 +71,9 @@ def __init__( items: Iterable[T], *, default_index=0, - filter_func: Optional[Callable[[T, str], bool]] = None, - console: Optional[Console] = None, - theme: Optional[Theme] = None, + filter_func: Callable[[T, str], bool] | None = None, + console: Console | None = None, + theme: Theme | None = None, ): """ An interactive terminal menu. @@ -112,7 +113,7 @@ def __init__( self._apply_filter() - def show(self): + def show(self) -> T: """ Displays the menu. @@ -146,11 +147,11 @@ def show(self): self._erase_controls() @property - def has_filter(self): + def has_filter(self) -> bool: """Get whether a filter function is set""" return bool(self._filter_func) - def highlight(self, text: Text): + def highlight(self, text: Text) -> None: normfilter = self._filter_text.casefold().strip() if not normfilter: return @@ -180,7 +181,7 @@ def _context(self): self.console.highlighter = old_highlighter def _apply_filter(self): - if self.has_filter: + if self._filter_func: try: old_focus = self._filter_items[self._focus_index] except IndexError: @@ -190,10 +191,11 @@ def _apply_filter(self): i for i in self.items if self._filter_func(i, self._filter_text) ] - try: - self._focus_index = self._filter_items.index(old_focus) - except ValueError: - pass + if old_focus is not None: + try: + self._focus_index = self._filter_items.index(old_focus) + except ValueError: + pass else: self._filter_items = self.items @@ -246,7 +248,7 @@ def _print_menu(self): overflow="crop", ) - def _print_item(self, item: T, focused: bool, show_more: bool): + def _print_item(self, item: T | str, focused: bool, show_more: bool): style = "ellipsis" if show_more else "focus" if focused else "unfocus" indent = "> " if focused else " " @@ -406,10 +408,10 @@ def show_menu( items: Iterable[T], *, default_index=0, - filter_func: Optional[Callable[[T, str], bool]] = None, - console: Optional[Console] = None, - theme: Optional[Theme] = None, -): + filter_func: Callable[[T, str], bool] | None = None, + console: Console | None = None, + theme: Theme | None = None, +) -> T: """ Displays an interactive menu. @@ -457,7 +459,9 @@ def __rich__(self): # pylint: disable=protected-access @classmethod - def align(cls, items: Iterable["Detail[T]"], console: Optional[Console] = None): + def align( + cls, items: Iterable["Detail[T]"], console: Console | None = None + ) -> list["Detail[T]"]: """Set the padding for each item in the list to align the detail strings.""" items = list(items) console = console or rich.get_console() @@ -474,7 +478,7 @@ def align(cls, items: Iterable["Detail[T]"], console: Optional[Console] = None): def detail_list( - items: Iterable[tuple[T, str]], console: Optional[Console] = None + items: Iterable[tuple[T, str]], console: Console | None = None ) -> list[Detail[T]]: """ Create a list of menu items with a description appended to each item. diff --git a/zmk/prompt.py b/zmk/prompt.py index d734caf..ee55b32 100644 --- a/zmk/prompt.py +++ b/zmk/prompt.py @@ -5,7 +5,7 @@ from rich.prompt import InvalidResponse, PromptBase -class UrlPrompt(PromptBase): +class UrlPrompt(PromptBase[str]): """Prompt for a URL.""" def process_response(self, value: str) -> str: diff --git a/zmk/repo.py b/zmk/repo.py index 3226d2a..f479cca 100644 --- a/zmk/repo.py +++ b/zmk/repo.py @@ -4,10 +4,11 @@ import shutil import subprocess +from collections.abc import Generator from contextlib import redirect_stdout from io import StringIO from pathlib import Path -from typing import Any, Generator, Optional +from typing import Any, Literal, overload from west.app.main import main as west_main @@ -28,7 +29,7 @@ def is_repo(path: Path) -> bool: return (path / _PROJECT_MANIFEST_PATH).is_file() -def find_containing_repo(path: Optional[Path] = None) -> Optional[Path]: +def find_containing_repo(path: Path | None = None) -> Path | None: """Search upwards from the given path for a ZMK config repo.""" path = path or Path() path = path.absolute() @@ -52,7 +53,7 @@ def module_manifest_path(self) -> Path: return self.path / _MODULE_MANIFEST_PATH @property - def board_root(self) -> Optional[Path]: + def board_root(self) -> Path | None: """Path to the "boards" folder.""" # Check for board_root from module manifest @@ -93,7 +94,7 @@ def project_manifest_path(self) -> Path: return self.path / _PROJECT_MANIFEST_PATH @property - def board_root(self) -> Optional[Path]: + def board_root(self) -> Path | None: if root := super().board_root: return root @@ -127,6 +128,15 @@ def west_path(self) -> Path: """Path to the west staging folder.""" return self.path / _WEST_STAGING_PATH + @overload + def git(self, *args: str, capture_output: Literal[False] = False) -> None: ... + + @overload + def git(self, *args: str, capture_output: Literal[True]) -> str: ... + + @overload + def git(self, *args: str, capture_output: bool) -> str | None: ... + def git(self, *args: str, capture_output: bool = False) -> str | None: """ Run Git in the repo. @@ -134,7 +144,7 @@ def git(self, *args: str, capture_output: bool = False) -> str | None: If capture_output is True, the command is run silently in the background and this returns the output as a string. """ - args = ["git", *args] + args = ("git", *args) if capture_output: return subprocess.check_output( @@ -144,7 +154,16 @@ def git(self, *args: str, capture_output: bool = False) -> str | None: subprocess.check_call(args, encoding="utf-8") return None - def run_west(self, *args: str, capture_output: bool = False) -> str | None: + @overload + def run_west(self, *args: str, capture_output: Literal[False] = False) -> None: ... + + @overload + def run_west(self, *args: str, capture_output: Literal[True]) -> str: ... + + @overload + def run_west(self, *args: str, capture_output: bool) -> str | None: ... + + def run_west(self, *args: str, capture_output: bool = False): """ Run west in the west staging folder. @@ -156,7 +175,7 @@ def run_west(self, *args: str, capture_output: bool = False) -> str | None: self.ensure_west_ready() return self._run_west(*args, capture_output=capture_output) - def ensure_west_ready(self): + def ensure_west_ready(self) -> None: """ Ensures the west application is correctly initialized. """ @@ -172,6 +191,15 @@ def ensure_west_ready(self): self._west_ready = True + @overload + def _run_west(self, *args: str, capture_output: Literal[False] = False) -> None: ... + + @overload + def _run_west(self, *args: str, capture_output: Literal[True]) -> str: ... + + @overload + def _run_west(self, *args: str, capture_output: bool) -> str | None: ... + def _run_west(self, *args: str, capture_output=False): if capture_output: with redirect_stdout(StringIO()) as output: diff --git a/zmk/templates/__init__.py b/zmk/templates/__init__.py index df0cf67..a4ceeca 100644 --- a/zmk/templates/__init__.py +++ b/zmk/templates/__init__.py @@ -12,9 +12,9 @@ """ import re -from itertools import pairwise +from collections.abc import Generator from pathlib import Path -from typing import Any, Generator +from typing import Any, cast from mako.lookup import TemplateLookup from mako.template import Template @@ -43,8 +43,8 @@ def get_template_files( file_name = Template(text=file.name, strict_undefined=True) yield ( - file_name.render_unicode(**data), - _ensure_trailing_newline(template.render_unicode(**data)), + cast(str, file_name.render_unicode(**data)), + _ensure_trailing_newline(cast(str, template.render_unicode(**data))), ) diff --git a/zmk/terminal.py b/zmk/terminal.py index 44f710c..06cd015 100644 --- a/zmk/terminal.py +++ b/zmk/terminal.py @@ -2,18 +2,25 @@ Terminal utilities """ +# Ignore missing attributes for platform-specific modules +# pyright: reportAttributeAccessIssue = false + +# Ignore alternative declarations of the same functions +# pyright: reportRedeclaration = false + import os import sys +from collections.abc import Generator from contextlib import contextmanager -def hide_cursor(): +def hide_cursor() -> None: """Hides the terminal cursor.""" sys.stdout.write("\x1b[?25l") sys.stdout.flush() -def show_cursor(): +def show_cursor() -> None: """Unhides the terminal cursor.""" sys.stdout.write("\x1b[?25h") sys.stdout.flush() @@ -62,7 +69,7 @@ def show_cursor(): 83: DELETE, } - def read_key(): + def read_key() -> bytes: """ Waits for a key to be pressed and returns it. @@ -83,7 +90,7 @@ def read_key(): return key @contextmanager - def enable_vt_mode(): + def enable_vt_mode() -> Generator[None, None, None]: """ Context manager which enables virtual terminal processing. """ @@ -102,7 +109,7 @@ def enable_vt_mode(): kernel32.SetConsoleMode(stdout_handle, old_stdout_mode) @contextmanager - def disable_echo(): + def disable_echo() -> Generator[None, None, None]: """ Context manager which disables console echo """ @@ -122,7 +129,7 @@ def disable_echo(): import termios @contextmanager - def enable_vt_mode(): + def enable_vt_mode() -> Generator[None, None, None]: """ Context manager which enables virtual terminal processing. """ @@ -130,7 +137,7 @@ def enable_vt_mode(): yield @contextmanager - def disable_echo(): + def disable_echo() -> Generator[None, None, None]: """ Context manager which disables console echo """ @@ -144,7 +151,7 @@ def disable_echo(): finally: termios.tcsetattr(sys.stdin, termios.TCSAFLUSH, oldattr) - def read_key(): + def read_key() -> bytes: """ Waits for a key to be pressed and returns it. @@ -170,7 +177,7 @@ def get_cursor_pos() -> tuple[int, int]: return (int(row) - 1, int(col) - 1) -def set_cursor_pos(row=0, col=0): +def set_cursor_pos(row=0, col=0) -> None: """ Sets the cursor to the given row and column. Positions are 0-based. """ diff --git a/zmk/util.py b/zmk/util.py index a59ea48..6daafa6 100644 --- a/zmk/util.py +++ b/zmk/util.py @@ -5,9 +5,10 @@ import functools import operator import os +from collections.abc import Generator, Iterable from contextlib import contextmanager from pathlib import Path -from typing import Iterable, Optional, TypeVar +from typing import TypeVar from rich.console import Console from rich.progress import Progress, SpinnerColumn, TextColumn @@ -29,7 +30,7 @@ def splice(text: str, index: int, count: int = 0, insert_text: str = ""): @contextmanager -def set_directory(path: Path): +def set_directory(path: Path) -> Generator[None, None, None]: """Context manager to temporarily change the working directory""" original = Path().absolute() @@ -41,7 +42,7 @@ def set_directory(path: Path): @contextmanager -def spinner(message: str, console: Optional[Console] = None, transient: bool = True): +def spinner(message: str, console: Console | None = None, transient: bool = True): """Context manager which displays a loading spinner for its duration""" with Progress( SpinnerColumn(), diff --git a/zmk/yaml.py b/zmk/yaml.py index 66ddf5a..35bf618 100644 --- a/zmk/yaml.py +++ b/zmk/yaml.py @@ -4,7 +4,7 @@ from io import UnsupportedOperation from pathlib import Path -from typing import IO +from typing import IO, Any import ruamel.yaml @@ -18,7 +18,10 @@ class YAML(ruamel.yaml.YAML): readable or seekable, a leading comment will be overwritten. """ - def dump(self, data, stream: Path | IO = None, *, transform=None): + def dump(self, data, stream: Path | IO | None = None, *, transform=None) -> None: + if stream is None: + raise TypeError("Dumping from a context manager is not supported.") + if isinstance(stream, Path): with stream.open("a+", encoding="utf-8") as f: f.seek(0) @@ -59,6 +62,6 @@ def _seek_to_document_start(stream: IO): return -def read_yaml(path: Path): +def read_yaml(path: Path) -> Any: """Parse a YAML file""" return YAML().load(path)