Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] version.py: guard git calls during wheel builds (sans .git) #2590

Merged
merged 2 commits into from
May 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 71 additions & 26 deletions apis/python/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,22 @@ def err(*args, **kwargs):
print(*args, file=sys.stderr, **kwargs)


def lines(*cmd, drop_trailing_newline: bool = True, **kwargs) -> List[str]:
def lines(
*cmd, drop_trailing_newline: bool = True, stderr=DEVNULL, **kwargs
) -> List[str]:
"""Run a command and return its output as a list of lines.

Strip trailing newlines, and drop the last line if it's empty, by default."""
lns = [ln.rstrip("\n") for ln in check_output(cmd, **kwargs).decode().splitlines()]
lns = [
ln.rstrip("\n")
for ln in check_output(cmd, stderr=stderr, **kwargs).decode().splitlines()
]
if lns and drop_trailing_newline and not lns[-1]:
lns.pop()
return lns


def line(*cmd, **kwargs) -> Optional[str]:
def line(*cmd, **kwargs) -> str:
"""Run a command, verify exactly one line of stdout, return it."""
lns = lines(*cmd, **kwargs)
if len(lns) != 1:
Expand All @@ -113,8 +118,11 @@ def line(*cmd, **kwargs) -> Optional[str]:

def get_latest_tag() -> Optional[str]:
"""Return the most recent local Git tag of the form `[0-9].*.*` (or `None` if none exist)."""
tags = lines("git", "tag", "--list", "--sort=v:refname", "[0-9].*.*")
return tags[-1] if tags else None
try:
tags = lines("git", "tag", "--list", "--sort=v:refname", "[0-9].*.*")
return tags[-1] if tags else None
except CalledProcessError:
return None


def get_latest_remote_tag(remote: str) -> str:
Expand All @@ -133,6 +141,40 @@ def get_sha_base10() -> int:
return int(sha, 16)


def get_only_remote() -> Optional[str]:
"""Find the only remote Git repository, if one exists."""
try:
remotes = lines("git", "remote")
if len(remotes) == 1:
return remotes[0]
except CalledProcessError:
pass
return None


def get_default_remote() -> Optional[str]:
"""Find a Git remote to parse a most recent release tag from.

- If the current branch tracks a remote branch, use that remote
- Otherwise, if there's only one remote, use that
- Otherwise, return `None`
"""
try:
tracked_branch = line(
"git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"
)
tracked_remote = tracked_branch.split("/")[0]
err(f"Parsed tracked remote {tracked_remote} from branch {tracked_branch}")
return tracked_remote
except CalledProcessError:
remote = get_only_remote()
if remote:
err(f"Checking tags at default/only remote {remote}")
return remote
else:
return None


def get_git_version() -> Optional[str]:
"""Construct a PEP440-compatible version string that encodes various Git state.

Expand All @@ -149,9 +191,7 @@ def get_git_version() -> Optional[str]:
abbreviated Git SHA, converted to base 10 for PEP440 compliance).
"""
try:
git_version = line(
"git", "describe", "--long", "--tags", "--match", "[0-9]*.*", stderr=DEVNULL
)
git_version = line("git", "describe", "--long", "--tags", "--match", "[0-9]*.*")
except CalledProcessError:
git_version = None

Expand All @@ -174,23 +214,28 @@ def get_git_version() -> Optional[str]:
if latest_tag:
err(f"Git traversal returned {ver}, using latest local tag {latest_tag}")
else:
try:
tracked_branch = line(
"git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"
)
tracked_remote = tracked_branch.split("/")[0]
err(
f"Parsed tracked remote {tracked_remote} from branch {tracked_branch}"
)
except CalledProcessError:
tracked_remote = line("git", "remote")
err(f"Checking tags at default/only remote {tracked_remote}")
latest_tag = get_latest_remote_tag(tracked_remote)
err(
f"Git traversal returned {ver}, using latest tag {latest_tag} from tracked remote {tracked_remote}"
)

return f"{latest_tag}.post0.dev{get_sha_base10()}"
# Didn't find a suitable local tag, look for a tracked/default "remote", and find its
# latest release tag
tracked_remote = get_default_remote()
if tracked_remote:
try:
latest_tag = get_latest_remote_tag(tracked_remote)
err(
f"Git traversal returned {ver}, using latest tag {latest_tag} from tracked remote {tracked_remote}"
)
except CalledProcessError:
err(f"Failed to find tags in remote {tracked_remote}")
return None
else:
err("Failed to find a suitable remote for tag traversal")
return None

try:
sha_base10 = get_sha_base10()
return f"{latest_tag}.post0.dev{sha_base10}"
except CalledProcessError:
err("Failed to find current SHA")
return None
else:
commits = int(m.group("commits"))
if commits:
Expand All @@ -200,7 +245,7 @@ def get_git_version() -> Optional[str]:
return ver


def read_release_version():
def read_release_version() -> Optional[str]:
try:
with open(RELEASE_VERSION_FILE) as fd:
ver = fd.readline().strip()
Expand Down
Loading