diff --git a/.github/workflows/nightly.yaml b/.github/workflows/nightly.yaml index 4052b9a..4041ba5 100644 --- a/.github/workflows/nightly.yaml +++ b/.github/workflows/nightly.yaml @@ -190,19 +190,12 @@ jobs: $VENV_PYTHON -m pip install --require-hashes -r requirements.txt - name: Set Artichoke Rust toolchain version - shell: python + shell: bash id: rust_toolchain - working-directory: artichoke run: | - import os - import tomllib - - with open("rust-toolchain.toml", "rb") as f: - data = tomllib.load(f) - toolchain = data["toolchain"]["channel"] - print(f"Rust toolchain version: {toolchain}") - with open(os.environ["GITHUB_OUTPUT"], "a") as f: - print(f"version={toolchain}", file=f) + $VENV_PYTHON -m artichoke_nightly.rust_toolchain_version \ + --file artichoke/rust-toolchain.toml \ + --format github - name: Install Rust toolchain uses: artichoke/setup-rust/build-and-test@v1.12.1 diff --git a/artichoke_nightly/github_actions.py b/artichoke_nightly/github_actions.py index 38c9317..d2d4019 100644 --- a/artichoke_nightly/github_actions.py +++ b/artichoke_nightly/github_actions.py @@ -8,8 +8,11 @@ def set_output(*, name: str, value: str) -> None: """ Set an output for a GitHub Actions job. - https://docs.github.com/en/actions/using-jobs/defining-outputs-for-jobs - https://github.blog/changelog/2022-10-11-github-actions-deprecating-save-state-and-set-output-commands/ + See the GitHub Actions documentation for [defining output for jobs] and + changes to [deprecate the set-output command]. + + [defining output for jobs]: https://docs.github.com/en/actions/using-jobs/defining-outputs-for-jobs + [deprecate the set-output command]: https://github.blog/changelog/2022-10-11-github-actions-deprecating-save-state-and-set-output-commands/ """ if github_output := os.getenv("GITHUB_OUTPUT"): @@ -22,8 +25,18 @@ def log_group(group: str) -> Iterator[None]: """ Create an expandable log group in GitHub Actions job logs. - https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#grouping-log-lines + Only prints log group markers when running in GitHub Actions CI. See the GitHub + Actions documentation for [grouping log lines]. + + Args: + group (str): The name of the log group. + + [grouping log lines]: https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#grouping-log-lines """ + if os.getenv("CI") != "true" or os.getenv("GITHUB_ACTIONS") != "true": + # Do nothing if not running in GitHub Actions + yield + return print(f"::group::{group}") try: @@ -33,7 +46,7 @@ def log_group(group: str) -> Iterator[None]: def emit_metadata() -> None: - if os.getenv("CI") != "true": + if os.getenv("CI") != "true" or os.getenv("GITHUB_ACTIONS") != "true": return with log_group("Workflow metadata"): if repository := os.getenv("GITHUB_REPOSITORY"): @@ -55,6 +68,24 @@ def emit_metadata() -> None: def runner_tempdir() -> Path | None: + """ + Get the temporary directory used by the GitHub Actions runner. + + This function retrieves the path to the temporary directory used by the GitHub + Actions runner during job execution. The directory path is taken from the + `RUNNER_TEMP` environment variable, which is set by GitHub Actions. This + directory is used for storing temporary files generated during the job run. + + Returns: + Optional[Path]: A Path object pointing to the runner's temporary directory if + the `RUNNER_TEMP` environment variable is set; otherwise, returns None. + + Example: + >>> temp_dir = runner_tempdir() + >>> if temp_dir: + ... print(f"Temporary directory: {temp_dir}") + """ + if temp := os.getenv("RUNNER_TEMP"): return Path(temp) return None diff --git a/artichoke_nightly/rust_toolchain_version.py b/artichoke_nightly/rust_toolchain_version.py new file mode 100755 index 0000000..dbe34b8 --- /dev/null +++ b/artichoke_nightly/rust_toolchain_version.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 + +import argparse +import sys +import tomllib +from dataclasses import dataclass, field +from enum import StrEnum +from pathlib import Path +from typing import assert_never + +from .github_actions import emit_metadata, log_group, set_output + + +class OutputFormat(StrEnum): + """Enum for output format options.""" + + PLAIN = "plain" + GITHUB = "github" + + +@dataclass(frozen=True, kw_only=True) +class Args: + """Dataclass to store command line arguments.""" + + file: Path = field(metadata={"help": "Path to the rust-toolchain.toml file."}) + format: OutputFormat = field( + metadata={"help": "Output format: either 'plain' or 'github'."} + ) + + +def parse_args() -> Args: + """Parse command line arguments into an Args dataclass.""" + parser = argparse.ArgumentParser(description="Set Rust toolchain version.") + parser.add_argument( + "-f", + "--file", + type=Path, + required=True, + help="Path to the rust-toolchain.toml file.", + ) + parser.add_argument( + "--format", + type=OutputFormat, + choices=list(OutputFormat), + default=OutputFormat.PLAIN, + help="Output format: either 'plain' or 'github'.", + ) + args = parser.parse_args() + return Args(file=args.file, format=args.format) + + +def read_toolchain_version(file_path: Path) -> str: + """ + Read the Rust toolchain version from the rust-toolchain.toml file. + + Args: + file_path (Path): Path to the rust-toolchain.toml file. + + Returns: + str: The Rust toolchain version specified in the TOML file. + + Raises: + FileNotFoundError: If the file does not exist or cannot be accessed. + ValueError: If the TOML file is malformed. + TypeError: If the TOML file does not contain the expected structure. + """ + try: + with file_path.open("rb") as f: + data = tomllib.load(f) + except tomllib.TOMLDecodeError as e: + raise ValueError(f"Failed to parse rust-toolchain.toml file: {e}") from e + + # Validate the structure and type of the expected keys + toolchain = data.get("toolchain") + if toolchain is None: + raise TypeError("Malformed rust-toolchain.toml: 'toolchain' stanza is missing.") + if not isinstance(toolchain, dict): + raise TypeError( + "Malformed rust-toolchain.toml: 'toolchain' should be a dictionary." + ) + + channel = toolchain.get("channel") + if channel is None: + raise TypeError("Malformed rust-toolchain.toml: 'channel' field is missing.") + if not isinstance(channel, str): + raise TypeError("Malformed rust-toolchain.toml: 'channel' should be a string.") + if not channel: + raise ValueError("Malformed rust-toolchain.toml: 'channel' is empty.") + + return channel + + +def format_output(toolchain_version: str, output_format: OutputFormat) -> None: + """ + Format the output based on the selected format. + + Args: + toolchain_version (str): The Rust toolchain version. + output_format (OutputFormat): The desired output format, either 'plain' or + 'github'. + """ + match output_format: + case OutputFormat.PLAIN: + print(toolchain_version) + case OutputFormat.GITHUB: + set_output(name="version", value=toolchain_version) + case _: + assert_never(output_format) + + +def main() -> int: + """Main function to set Rust toolchain version.""" + args = parse_args() + emit_metadata() + + with log_group("Setting Rust toolchain version"): + try: + toolchain_version = read_toolchain_version(args.file) + format_output(toolchain_version, args.format) + except (FileNotFoundError, OSError, ValueError, TypeError) as e: + print(e, file=sys.stderr) + return 1 + + return 0 + + +if __name__ == "__main__": + sys.exit(main())