Skip to content

Commit

Permalink
handle aborting sync jobs on shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
onlyann committed Oct 13, 2024
1 parent a8c7fc3 commit 266bef2
Show file tree
Hide file tree
Showing 12 changed files with 217 additions and 51 deletions.
26 changes: 18 additions & 8 deletions docs/howto/advanced/shutdown.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,23 @@ A worker will keep running until:
- [task.cancel](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.cancel) is called on the task created from `app.run_worker_async`

When a worker is requested to stop, it will attempt to gracefully shut down by waiting for all running jobs to complete.
If a `shutdown_timeout` option is specified, the worker will attempt to abort all jobs that have not completed by that time. Cancelling the `run_worker_async` task a second time also results in the worker aborting running jobs.
If a `shutdown_graceful_timeout` option is specified, the worker will attempt to abort all jobs that have not completed by that time. Cancelling the `run_worker_async` task a second time also results in the worker aborting running jobs.

The worker will then wait for all jobs to complete.


:::{note}
The worker aborts its remaining jobs by calling [task.cancel](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.cancel) on the underlying asyncio task that runs the job.
The worker aborts its remaining jobs by:
- setting the context so that `JobContext.should_abort` returns `AbortReason.SHUTDOWN`
- calling [task.cancel](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.cancel) on the underlying asyncio task that runs the job when the job is asynchronous

Jobs that do not respect the request to abort will prevent the worker from shutting down until they complete. In a way, it will remain a graceful shutdown for those jobs even after `shutdown_graceful_timeout`.

It is possible for that task to handle `asyncio.CancelledError` and even suppress the cancellation.
For more information, see {doc}`./cancellation`.

Currently, Procrastinate does not provide a built-in method to forcefully terminate a worker. This is something you would want to do with your process manager (e.g. systemd, Docker, Kubernetes), which typically offers options to control process termination. In that case, your jobs will be considered stale, see {doc}`../production/retry_stalled_jobs`.
:::

## Examples

### Run a worker until no job is left
Expand All @@ -29,15 +39,15 @@ async with app.open_async():
async with app.open_async():
# give jobs up to 10 seconds to complete when a stop signal is received
# all jobs still running after 10 seconds are aborted
# In the absence of shutdown_timeout, the task will complete when all jobs have completed.
await app.run_worker_async(shutdown_timeout=10)
# In the absence of shutdown_graceful_timeout, the task will complete when all jobs have completed.
await app.run_worker_async(shutdown_graceful_timeout=10)
```

### Run a worker until its Task is cancelled

```python
async with app.open_async():
worker = asyncio.create_task(app run_worker_async())
worker = asyncio.create_task(app.run_worker_async())
# eventually
worker.cancel()
try:
Expand All @@ -51,7 +61,7 @@ async with app.open_async():

```python
async with app.open_async():
worker = asyncio.create_task(app.run_worker_async(shutdown_timeout=10))
worker = asyncio.create_task(app.run_worker_async(shutdown_graceful_timeout=10))
# eventually
worker.cancel()
try:
Expand All @@ -66,7 +76,7 @@ async with app.open_async():

```python
async with app.open_async():
# Notice that shutdown_timeout is not specified
# Notice that shutdown_graceful_timeout is not specified
worker = asyncio.create_task(app.run_worker_async())

# eventually
Expand Down
7 changes: 4 additions & 3 deletions procrastinate/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class WorkerOptions(TypedDict):
wait: NotRequired[bool]
fetch_job_polling_interval: NotRequired[float]
abort_job_polling_interval: NotRequired[float]
shutdown_timeout: NotRequired[float]
shutdown_graceful_timeout: NotRequired[float]
listen_notify: NotRequired[bool]
delete_jobs: NotRequired[str | jobs.DeleteJobCondition]
additional_context: NotRequired[dict[str, Any]]
Expand Down Expand Up @@ -289,10 +289,11 @@ async def run_worker_async(self, **kwargs: Unpack[WorkerOptions]) -> None:
mechanism and can reasonably be set to a higher value.
(defaults to 5.0)
shutdown_timeout: ``float``
shutdown_graceful_timeout: ``float``
Indicates the maximum duration (in seconds) the worker waits for jobs to
complete when requested stop. Jobs that have not been completed by that time
complete when requested to stop. Jobs that have not been completed by that time
are aborted. A value of None corresponds to no timeout.
(defaults to None)
listen_notify : ``bool``
If ``True``, allocates a connection from the pool to
Expand Down
7 changes: 7 additions & 0 deletions procrastinate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,13 @@ def configure_worker_parser(subparsers: argparse._SubParsersAction):
help="How often to polling for abort requests",
envvar="WORKER_ABORT_JOB_POLLING_INTERVAL",
)
add_argument(
worker_parser,
"--shutdown-graceful-timeout",
type=float,
help="How long to wait for jobs to complete when shutting down before aborting them",
envvar="WORKER_SHUTDOWN_GRACEFUL_TIMEOUT",
)
add_argument(
worker_parser,
"-w",
Expand Down
14 changes: 13 additions & 1 deletion procrastinate/job_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import time
from enum import Enum
from typing import Any, Callable, Iterable

import attr
Expand Down Expand Up @@ -32,6 +33,17 @@ def as_dict(self):
return result


class AbortReason(Enum):
"""
An enumeration of reasons a job is being aborted
"""

USER_REQUEST = "user_request" #: The user requested to abort the job
SHUTDOWN = (
"shutdown" #: The job is being aborted as part of shutting down the worker
)


@attr.dataclass(frozen=True, kw_only=True)
class JobContext:
"""
Expand All @@ -51,7 +63,7 @@ class JobContext:

additional_context: dict = attr.ib(factory=dict)

should_abort: Callable[[], bool]
should_abort: Callable[[], AbortReason | None]

def evolve(self, **update: Any) -> JobContext:
return attr.evolve(self, **update)
Expand Down
2 changes: 1 addition & 1 deletion procrastinate/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ def list_jobs(
status: str | None = None,
lock: str | None = None,
queueing_lock: str | None = None,
) -> Iterable[jobs.Job]:
) -> list[jobs.Job]:
"""
Sync version of `list_jobs_async`
"""
Expand Down
80 changes: 55 additions & 25 deletions procrastinate/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
wait: bool = True,
fetch_job_polling_interval: float = FETCH_JOB_POLLING_INTERVAL,
abort_job_polling_interval: float = ABORT_JOB_POLLING_INTERVAL,
shutdown_timeout: float | None = None,
shutdown_graceful_timeout: float | None = None,
listen_notify: bool = True,
delete_jobs: str | jobs.DeleteJobCondition | None = None,
additional_context: dict[str, Any] | None = None,
Expand Down Expand Up @@ -71,8 +71,8 @@ def __init__(
self._running_jobs: dict[asyncio.Task, job_context.JobContext] = {}
self._job_semaphore = asyncio.Semaphore(self.concurrency)
self._stop_event = asyncio.Event()
self.shutdown_timeout = shutdown_timeout
self._job_ids_to_abort = set()
self.shutdown_graceful_timeout = shutdown_graceful_timeout
self._job_ids_to_abort: dict[int, job_context.AbortReason] = dict()

def stop(self):
if self._stop_event.is_set():
Expand Down Expand Up @@ -149,7 +149,8 @@ async def _persist_job_status(
job=job, status=status, delete_job=delete_job
)

self._job_ids_to_abort.discard(job.id)
assert job.id
self._job_ids_to_abort.pop(job.id, None)

self.logger.debug(
f"Acknowledged job completion {job.call_string}",
Expand All @@ -171,8 +172,10 @@ def _log_job_outcome(
):
if status == jobs.Status.SUCCEEDED:
log_action, log_title = "job_success", "Success"
elif status == jobs.Status.ABORTED:
elif status == jobs.Status.ABORTED and not job_retry:
log_action, log_title = "job_aborted", "Aborted"
elif status == jobs.Status.ABORTED and job_retry:
log_action, log_title = "job_aborted_retry", "Aborted, to retry"
elif job_retry:
log_action, log_title = "job_error_retry", "Error, to retry"
else:
Expand Down Expand Up @@ -252,7 +255,10 @@ async def ensure_async() -> Callable[..., Awaitable]:
except BaseException as e:
exc_info = e

if not isinstance(e, exceptions.JobAborted):
# aborted job can be retried if it is caused by a shutdown.
if not (isinstance(e, exceptions.JobAborted)) or (
context.should_abort() == job_context.AbortReason.SHUTDOWN
):
job_retry = (
task.get_retry_exception(exception=e, job=job) if task else None
)
Expand Down Expand Up @@ -321,7 +327,6 @@ async def _fetch_and_process_jobs(self):
break

job_id = job.id
assert job_id

context = job_context.JobContext(
app=self.app,
Expand All @@ -331,7 +336,9 @@ async def _fetch_and_process_jobs(self):
if self.additional_context
else {},
job=job,
should_abort=lambda: job_id in self._job_ids_to_abort,
should_abort=lambda: self._job_ids_to_abort.get(job_id)
if job_id
else None,
start_timestamp=time.time(),
)
job_task = asyncio.create_task(
Expand All @@ -357,19 +364,11 @@ async def run(self):
try:
# shield the loop task from cancellation
# instead, a stop event is set to enable graceful shutdown
await utils.wait_any(asyncio.shield(loop_task), self._stop_event.wait())
if self._stop_event.is_set():
try:
await asyncio.wait_for(loop_task, timeout=self.shutdown_timeout)
except asyncio.TimeoutError:
pass
await asyncio.shield(loop_task)
except asyncio.CancelledError:
# worker.run is cancelled, usually by cancelling app.run_worker_async
self.stop()
try:
await asyncio.wait_for(loop_task, timeout=self.shutdown_timeout)
except asyncio.TimeoutError:
pass
await loop_task
raise

async def _handle_notification(
Expand Down Expand Up @@ -404,16 +403,24 @@ async def _poll_jobs_to_abort(self):

def _handle_abort_jobs_requested(self, job_ids: Iterable[int]):
running_job_ids = {c.job.id for c in self._running_jobs.values() if c.job.id}
new_job_ids_to_abort = (running_job_ids & set(job_ids)) - self._job_ids_to_abort
new_job_ids_to_abort = (running_job_ids & set(job_ids)) - set(
self._job_ids_to_abort
)

for process_job_task, context in self._running_jobs.items():
if context.job.id in new_job_ids_to_abort:
self._abort_job(process_job_task, context)
self._abort_job(
process_job_task, context, job_context.AbortReason.USER_REQUEST
)

def _abort_job(
self, process_job_task: asyncio.Task, context: job_context.JobContext
self,
process_job_task: asyncio.Task,
context: job_context.JobContext,
reason: job_context.AbortReason,
):
self._job_ids_to_abort.add(context.job.id)
assert context.job.id
self._job_ids_to_abort[context.job.id] = reason

log_message: str
task = self.app.tasks.get(context.job.task_name)
Expand Down Expand Up @@ -447,16 +454,39 @@ async def _shutdown(self, side_tasks: list[asyncio.Task]):
),
)

# wait for any in progress job to complete processing
# use return_exceptions to not cancel other job tasks if one was to fail
await asyncio.gather(*self._running_jobs, return_exceptions=True)
if self._running_jobs:
await asyncio.wait(
self._running_jobs, timeout=self.shutdown_graceful_timeout
)

# As a reminder, tasks have a done callback that
# removes them from the self._running_jobs dict,
# so as the tasks stop, this dict will shrink.
if self._running_jobs:
self.logger.info(
f"{len(self._running_jobs)} jobs still running after graceful timeout. Aborting them",
extra=self._log_extra(
action="stop_worker",
queues=self.queues,
context=None,
job_result=None,
),
)
await self._abort_running_jobs()

self.logger.info(
f"Stopped worker on {utils.queues_display(self.queues)}",
extra=self._log_extra(
action="stop_worker", queues=self.queues, context=None, job_result=None
),
)

async def _abort_running_jobs(self):
for task, context in self._running_jobs.items():
self._abort_job(task, context, job_context.AbortReason.SHUTDOWN)

await asyncio.gather(*self._running_jobs, return_exceptions=True)

def _start_side_tasks(self) -> list[asyncio.Task]:
"""Start side tasks such as periodic deferrer and notification listener"""
side_tasks = [
Expand Down
Loading

0 comments on commit 266bef2

Please sign in to comment.