Skip to content

Commit

Permalink
Merge pull request #65 from james-certn/global-pre-job-hook
Browse files Browse the repository at this point in the history
Add hook to run generic pre- and post-task logic
  • Loading branch information
j4mie authored Jun 25, 2024
2 parents 9a8facd + 32b1ac9 commit 01a506c
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 14 deletions.
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,35 @@ JOBS = {
}
```

#### Pre & Post Task Hooks
You can also run pre task or post task hooks, which happen in the normal processing of your `Job` instances and are executed inside the worker process.

Both pre and post task hooks receive your `Job` instance as their only argument. Here's an example:

```python
def my_pre_task_hook(job):
... # configure something before running your task
```

To ensure these hooks are run, simply add a `pre_task_hook` or `post_task_hook` key (or both, if needed) to your job config like so:

```python
JOBS = {
"my_job": {
"tasks": ["project.common.jobs.my_task"],
"pre_task_hook": "project.common.jobs.my_pre_task_hook",
"post_task_hook": "project.common.jobs.my_post_task_hook",
},
}
```

Notes:

* If the `pre_task_hook` fails (raises an exception), the task function is not run, and django-db-queue behaves as if the task function itself had failed: the failure hook is called, and the job is goes into the `FAILED` state.
* The `post_task_hook` is always run, even if the job fails. In this case, it runs after the `failure_hook`.
* If the `post_task_hook` raises an exception, this is logged but the the job is **not marked as failed** and the failure hook does not run. This is because the `post_task_hook` might need to perform cleanup that always happens after the task, no matter whether it succeeds or fails.


### Start the worker

In another terminal:
Expand Down
2 changes: 1 addition & 1 deletion django_dbq/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.1.0"
__version__ = "3.2.0"
21 changes: 9 additions & 12 deletions django_dbq/management/commands/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,26 +74,23 @@ def _process_job(self):
self.current_job = job

try:
task_function = import_string(job.next_task)
task_function(job)
job.run_pre_task_hook()
job.run_next_task()
job.update_next_task()

if not job.next_task:
job.state = Job.STATES.COMPLETE
else:
job.state = Job.STATES.READY
except Exception as exception:
logger.exception("Job id=%s failed", job.pk)
job.state = Job.STATES.FAILED

failure_hook_name = job.get_failure_hook_name()
if failure_hook_name:
logger.info(
"Running failure hook %s for job id=%s", failure_hook_name, job.pk
)
failure_hook_function = import_string(failure_hook_name)
failure_hook_function(job, exception)
else:
logger.info("No failure hook for job id=%s", job.pk)
job.run_failure_hook(exception)
finally:
try:
job.run_post_task_hook()
except:
logger.exception("Job id=%s post_task_hook failed", job.pk)

logger.info(
'Updating job: name="%s" id=%s state=%s next_task=%s',
Expand Down
35 changes: 34 additions & 1 deletion django_dbq/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from django.utils.module_loading import import_string
from django_dbq.tasks import (
get_next_task_name,
get_pre_task_hook_name,
get_post_task_hook_name,
get_failure_hook_name,
get_creation_hook_name,
)
Expand Down Expand Up @@ -126,16 +128,47 @@ def save(self, *args, **kwargs):
def update_next_task(self):
self.next_task = get_next_task_name(self.name, self.next_task) or ""

def run_next_task(self):
next_task_function = import_string(self.next_task)
next_task_function(self)

def get_pre_task_hook_name(self):
return get_pre_task_hook_name(self.name)

def get_post_task_hook_name(self):
return get_post_task_hook_name(self.name)

def get_failure_hook_name(self):
return get_failure_hook_name(self.name)

def get_creation_hook_name(self):
return get_creation_hook_name(self.name)

def run_pre_task_hook(self):
pre_task_hook_name = self.get_pre_task_hook_name()
if pre_task_hook_name:
logger.info("Running pre_task hook %s for job", pre_task_hook_name)
pre_task_hook_function = import_string(pre_task_hook_name)
pre_task_hook_function(self)

def run_post_task_hook(self):
post_task_hook_name = self.get_post_task_hook_name()
if post_task_hook_name:
logger.info("Running post_task hook %s for job", post_task_hook_name)
post_task_hook_function = import_string(post_task_hook_name)
post_task_hook_function(self)

def run_failure_hook(self, exception):
failure_hook_name = self.get_failure_hook_name()
if failure_hook_name:
logger.info("Running failure hook %s for job", failure_hook_name)
failure_hook_function = import_string(failure_hook_name)
failure_hook_function(self, exception)

def run_creation_hook(self):
creation_hook_name = self.get_creation_hook_name()
if creation_hook_name:
logger.info("Running creation hook %s for new job", creation_hook_name)
logger.info("Running creation hook %s for job", creation_hook_name)
creation_hook_function = import_string(creation_hook_name)
creation_hook_function(self)

Expand Down
12 changes: 12 additions & 0 deletions django_dbq/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@


TASK_LIST_KEY = "tasks"
PRE_TASK_HOOK_KEY = "pre_task_hook"
POST_TASK_HOOK_KEY = "post_task_hook"
FAILURE_HOOK_KEY = "failure_hook"
CREATION_HOOK_KEY = "creation_hook"

Expand All @@ -24,6 +26,16 @@ def get_next_task_name(job_name, current_task=None):
return None


def get_pre_task_hook_name(job_name):
"""Return the name of the pre task hook for the given job (as a string) or None"""
return settings.JOBS[job_name].get(PRE_TASK_HOOK_KEY)


def get_post_task_hook_name(job_name):
"""Return the name of the post_task hook for the given job (as a string) or None"""
return settings.JOBS[job_name].get(POST_TASK_HOOK_KEY)


def get_failure_hook_name(job_name):
"""Return the name of the failure hook for the given job (as a string) or None"""
return settings.JOBS[job_name].get(FAILURE_HOOK_KEY)
Expand Down
52 changes: 52 additions & 0 deletions django_dbq/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,25 @@ def failing_task(job):
raise Exception("uh oh")


def pre_task_hook(job):
job.workspace["output"] = "pre task hook ran"
job.workspace["job_id"] = str(job.id)


def post_task_hook(job):
job.workspace["output"] = "post task hook ran"
job.workspace["job_id"] = str(job.id)


def failure_hook(job, exception):
job.workspace["output"] = "failure hook ran"
job.workspace["exception"] = str(exception)
job.workspace["job_id"] = str(job.id)


def creation_hook(job):
job.workspace["output"] = "creation hook ran"
job.workspace["job_id"] = str(job.id)


@override_settings(JOBS={"testjob": {"tasks": ["a"]}})
Expand Down Expand Up @@ -316,6 +329,7 @@ def test_creation_hook(self):
job = Job.objects.create(name="testjob")
job = Job.objects.get()
self.assertEqual(job.workspace["output"], "creation hook ran")
self.assertEqual(job.workspace["job_id"], str(job.id))

def test_creation_hook_only_runs_on_create(self):
job = Job.objects.create(name="testjob")
Expand All @@ -326,6 +340,42 @@ def test_creation_hook_only_runs_on_create(self):
self.assertEqual(job.workspace["output"], "creation hook output removed")


@override_settings(
JOBS={
"testjob": {
"tasks": ["django_dbq.tests.test_task"],
"pre_task_hook": "django_dbq.tests.pre_task_hook",
}
}
)
class JobPreTaskHookTestCase(TestCase):
def test_pre_task_hook(self):
job = Job.objects.create(name="testjob")
Worker("default", 1)._process_job()
job = Job.objects.get()
self.assertEqual(job.state, Job.STATES.COMPLETE)
self.assertEqual(job.workspace["output"], "pre task hook ran")
self.assertEqual(job.workspace["job_id"], str(job.id))


@override_settings(
JOBS={
"testjob": {
"tasks": ["django_dbq.tests.test_task"],
"post_task_hook": "django_dbq.tests.post_task_hook",
}
}
)
class JobPostTaskHookTestCase(TestCase):
def test_post_task_hook(self):
job = Job.objects.create(name="testjob")
Worker("default", 1)._process_job()
job = Job.objects.get()
self.assertEqual(job.state, Job.STATES.COMPLETE)
self.assertEqual(job.workspace["output"], "post task hook ran")
self.assertEqual(job.workspace["job_id"], str(job.id))


@override_settings(
JOBS={
"testjob": {
Expand All @@ -341,6 +391,8 @@ def test_failure_hook(self):
job = Job.objects.get()
self.assertEqual(job.state, Job.STATES.FAILED)
self.assertEqual(job.workspace["output"], "failure hook ran")
self.assertIn("uh oh", job.workspace["exception"])
self.assertEqual(job.workspace["job_id"], str(job.id))


@override_settings(JOBS={"testjob": {"tasks": ["a"]}})
Expand Down

0 comments on commit 01a506c

Please sign in to comment.