[feature] support signal handling in AWS batch #17

Open · wants to merge 1 commit into base: main
40 changes: 38 additions & 2 deletions pyfgaws/batch/api.py
@@ -5,6 +5,7 @@
import copy
import enum
import logging
import signal
import sys
from typing import Any
from typing import Dict
@@ -14,9 +15,9 @@
from typing import Union

import botocore
import mypy_boto3_batch as batch
import namegenerator
from botocore.waiter import Waiter as BotoWaiter
from mypy_boto3_batch.type_defs import ArrayPropertiesTypeDef # noqa
from mypy_boto3_batch.type_defs import ContainerDetailTypeDef # noqa
from mypy_boto3_batch.type_defs import ContainerOverridesTypeDef # noqa
@@ -126,6 +127,9 @@ class BatchJob:
retry_strategy: the retry strategy to use for failed jobs from the `submit_job` operation.
timeout: the timeout configuration
logger: logger to write status messages
cancel_on: cancel a submitted batch job if one of the given signals is encountered.
terminate_on: terminate a submitted batch job if one of the given signals is
encountered; this will take precedence over cancel.
"""

def __init__(
@@ -148,6 +152,8 @@ def __init__(
retry_strategy: Optional[RetryStrategyTypeDef] = None,
timeout: Optional[JobTimeoutTypeDef] = None,
logger: Optional[logging.Logger] = None,
cancel_on: Optional[List[int]] = None,
terminate_on: Optional[List[int]] = None,
) -> None:

self.client: batch.Client = client
@@ -192,6 +198,9 @@ def __init__(

self.job_id: Optional[str] = None

self.cancel_on: Optional[List[int]] = cancel_on
self.terminate_on: Optional[List[int]] = terminate_on

def _add_to_container_overrides(
self, key: _ContainerOverridesTypes, value: Optional[Any]
) -> None:
@@ -202,14 +211,23 @@ def _add_to_container_overrides(
self.container_overrides[key] = value

@classmethod
def from_id(cls, client: batch.Client, job_id: str) -> "BatchJob":
def from_id(
cls,
client: batch.Client,
job_id: str,
cancel_on: Optional[List[int]] = None,
terminate_on: Optional[List[int]] = None,
) -> "BatchJob":
""""Builds a batch job from the given ID.

Will lookup the job to retrieve job information.

Args:
client: the AWS batch client
job_id: the job identifier
cancel_on: cancel a submitted batch job if one of the given signals is encountered.
terminate_on: terminate a submitted batch job if one of the given signals is
encountered; this will take precedence over cancel.
"""
jobs_response = client.describe_jobs(jobs=[job_id])
jobs = jobs_response["jobs"]
@@ -247,6 +265,8 @@ def add_to_container_overrides(key: _ContainerOverridesTypes) -> None:
container_overrides=container_overrides,
retry_strategy=job_info.get("retryStrategy"),
timeout=job_info.get("timeout"),
cancel_on=cancel_on,
terminate_on=terminate_on,
)

job.job_id = job_id
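For illustration, attaching to an existing job and wiring up the new signal handling might look roughly like this; the region, the job ID, and calling `wait_on_running` with no arguments are assumptions for the sketch, not part of this PR:

```python
# Sketch only: resume an existing job and cancel/terminate it if the watcher is interrupted.
import signal

import boto3
import mypy_boto3_batch as batch

from pyfgaws.batch.api import BatchJob

client: batch.Client = boto3.client(service_name="batch", region_name="us-east-1")  # type: ignore

job = BatchJob.from_id(
    client=client,
    job_id="example-job-id",        # placeholder
    cancel_on=[signal.SIGINT],      # Ctrl-C -> cancel the job
    terminate_on=[signal.SIGTERM],  # SIGTERM -> terminate the job
)
job.wait_on_running()  # an interrupt while waiting cancels or terminates the job
```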
@@ -342,6 +362,7 @@ def wait_on(
max_attempts: Optional[int] = None,
delay: Optional[int] = None,
after_success: bool = False,
terminate_on_signal: bool = False,
) -> batch.type_defs.JobDetailTypeDef:
"""Waits for the given states with associated success or failure.

@@ -352,6 +373,8 @@
status_to_state: mapping of status to success (true) or failure (false) state
max_attempts: the maximum # of attempts until reaching the given state.
delay: the delay before waiting
after_success: true to treat all statuses after the last successful input status as
success, otherwise failure.
"""
assert len(status_to_state) > 0, "No statuses given"
assert any(value for value in status_to_state.values()), "No statuses with success set."
@@ -391,6 +414,19 @@ def wait_on(
model: botocore.waiter.WaiterModel = botocore.waiter.WaiterModel(config)
waiter: BotoWaiter = botocore.waiter.create_waiter_with_client(name, model, self.client)
# Install the cancel/terminate handlers before blocking on the waiter so that an
# interrupt received while waiting can act on the job.
if self.cancel_on is not None:
    for code in self.cancel_on:
        # Bind the signal number as a default argument so each handler reports the
        # signal it was registered for (avoids late binding of the loop variable).
        signal.signal(
            code,
            lambda signum, frame, code=code: self.cancel_job(reason=f"Interrupted: {code}"),
        )

if self.terminate_on is not None:
    # Registered after the cancel handlers, so terminate wins if a signal is in both lists.
    for code in self.terminate_on:
        signal.signal(
            code,
            lambda signum, frame, code=code: self.terminate_job(reason=f"Interrupted: {code}"),
        )

waiter.wait(jobs=[self.job_id])
Member Author commented:

this needs a unit test probably. Make a sub-process, then check that it terminates?
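A possible shape for the suggested test, sketched against a stub job rather than a real AWS client; the module name, the stub, and the assertions are assumptions, not part of this PR (POSIX-only):

```python
# test_signal_handling.py -- illustrative sketch only.
import signal
import subprocess
import sys

# The child registers handlers the same way wait_on does, but against a stub job,
# so no AWS calls are made and the blocking waiter is replaced by a sleep.
CHILD_SCRIPT = r"""
import signal, sys, time

class StubJob:
    def cancel_job(self, reason):
        print(f"cancelled: {reason}", flush=True)
        sys.exit(0)

job = StubJob()
for code in (signal.SIGINT, signal.SIGTERM):
    signal.signal(
        code, lambda signum, frame, code=code: job.cancel_job(reason=f"Interrupted: {code}")
    )

print("ready", flush=True)
time.sleep(30)  # stand-in for waiter.wait()
"""


def test_child_cancels_on_sigterm() -> None:
    proc = subprocess.Popen(
        [sys.executable, "-c", CHILD_SCRIPT], stdout=subprocess.PIPE, text=True
    )
    assert proc.stdout is not None
    assert proc.stdout.readline().strip() == "ready"  # handlers installed, child is waiting
    proc.send_signal(signal.SIGTERM)
    out, _ = proc.communicate(timeout=10)
    assert "cancelled: Interrupted:" in out
    assert proc.returncode == 0
```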
return self.describe_job()

def wait_on_running(
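Because the cancel handlers are registered before the terminate handlers, a signal listed in both `cancel_on` and `terminate_on` ends up terminating the job, which is what the docstring means by terminate taking precedence. A minimal stand-alone illustration of that `signal.signal` behavior:

```python
import signal


def cancel_handler(signum, frame):
    print("cancel")


def terminate_handler(signum, frame):
    print("terminate")


# signal.signal replaces any previously installed handler for the same signal,
# so the handler registered last (terminate) is the one that runs.
signal.signal(signal.SIGTERM, cancel_handler)
signal.signal(signal.SIGTERM, terminate_handler)
assert signal.getsignal(signal.SIGTERM) is terminate_handler
```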
21 changes: 21 additions & 0 deletions pyfgaws/batch/tools.py
@@ -4,6 +4,7 @@
"""

import logging
import signal
import threading
import time
from typing import Any
@@ -83,6 +84,8 @@ def run_job(
environment: Optional[KeyValuePairTypeDef] = None,
watch_until: List[Status] = [],
after_success: bool = False,
cancel_on_interrupt: bool = False,
terminate_on_interrupt: bool = False,
) -> None:
"""Submits a batch job and optionally waits for it to reach one of the given states.

@@ -105,13 +108,29 @@
See the `--after-success` option to control this behavior.
after_success: true to treat states after the `watch_until` states as success, otherwise
failure.
cancel_on_interrupt: true to cancel the job if SIGINT or SIGTERM is encountered,
false otherwise. Requires `--watch-until` to be specified.
terminate_on_interrupt: true to terminate the job if SIGINT or SIGTERM is encountered,
false otherwise. Requires `--watch-until` to be specified.
"""
logger = logging.getLogger(__name__)

assert watch_until or not cancel_on_interrupt, "--cancel-on-interrupt requires --watch-until"
assert (
watch_until or not terminate_on_interrupt
), "--terminate-on-interrupt requires --watch-until"

batch_client: batch.Client = boto3.client(
service_name="batch", region_name=region_name # type: ignore
)

cancel_on: Optional[List[int]] = (
[signal.SIGINT, signal.SIGTERM] if cancel_on_interrupt and watch_until else None
)
terminate_on: Optional[List[int]] = (
[signal.SIGINT, signal.SIGTERM] if terminate_on_interrupt and watch_until else None
)

job = BatchJob(
client=batch_client,
queue=queue,
@@ -123,6 +142,8 @@
environment=None if environment is None else [environment],
parameters=parameters,
logger=logger,
cancel_on=cancel_on,
terminate_on=terminate_on,
)

# Submit the job
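The mapping from the new flags to the signal lists handed to `BatchJob` can be summarized with a small stand-alone sketch; the helper function here is illustrative only, not part of the module:

```python
import signal
from typing import List, Optional


def interrupt_signals(flag: bool, watch_until: List[str]) -> Optional[List[int]]:
    """Mirrors run_job: the interrupt flags only take effect when watch_until is non-empty."""
    return [signal.SIGINT, signal.SIGTERM] if flag and watch_until else None


assert interrupt_signals(True, []) is None            # not watching -> no handlers
assert interrupt_signals(False, ["RUNNING"]) is None  # flag not set -> no handlers
assert interrupt_signals(True, ["RUNNING"]) == [signal.SIGINT, signal.SIGTERM]
```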