From 689354fcdec441d7c51a145c90e56c51adc3ff7c Mon Sep 17 00:00:00 2001 From: Florian Mounier Date: Tue, 4 Jun 2024 17:02:34 +0200 Subject: [PATCH 1/3] [IMP] queue_job: Add split method --- queue_job/delay.py | 44 ++++++++++++ queue_job/tests/__init__.py | 1 + queue_job/tests/test_delayable_split.py | 94 +++++++++++++++++++++++++ 3 files changed, 139 insertions(+) create mode 100644 queue_job/tests/test_delayable_split.py diff --git a/queue_job/delay.py b/queue_job/delay.py index 4b2ed5c001..1836ce8550 100644 --- a/queue_job/delay.py +++ b/queue_job/delay.py @@ -534,6 +534,50 @@ def delay(self): """Delay the whole graph""" self._graph.delay() + def split(self, size): + """Split the Delayable into a DelayableGroup containing batches + of size `size` + """ + if not self._job_method: + raise ValueError("No method set on the Delayable") + + total_records = len(self.recordset) + + delayables = [] + for index in range(0, total_records, size): + recordset = self.recordset[index : index + size] + delayable = Delayable( + recordset, + priority=self.priority, + eta=self.eta, + max_retries=self.max_retries, + description=self.description, + channel=self.channel, + identity_key=self.identity_key, + ) + # Update the __self__ + delayable._job_method = getattr(recordset, self._job_method.__name__) + delayable._job_args = self._job_args + delayable._job_kwargs = self._job_kwargs + + delayables.append(delayable) + + description = self.description or ( + self._job_method.__doc__.splitlines()[0].strip() + if self._job_method.__doc__ + else "{}.{}".format(self.recordset._name, self._job_method.__name__) + ) + for index, delayable in enumerate(delayables): + delayable.set( + description="%s (split %s/%s)" + % (description, index + 1, len(delayables)) + ) + + # Prevent warning on deletion + self._generated_job = True + + return DelayableGroup(*delayables) + def _build_job(self): if self._generated_job: return self._generated_job diff --git a/queue_job/tests/__init__.py b/queue_job/tests/__init__.py index e0ff9576a5..db53ac3a60 100644 --- a/queue_job/tests/__init__.py +++ b/queue_job/tests/__init__.py @@ -1,6 +1,7 @@ from . import test_runner_channels from . import test_runner_runner from . import test_delayable +from . import test_delayable_split from . import test_json_field from . import test_model_job_channel from . import test_model_job_function diff --git a/queue_job/tests/test_delayable_split.py b/queue_job/tests/test_delayable_split.py new file mode 100644 index 0000000000..469983aab6 --- /dev/null +++ b/queue_job/tests/test_delayable_split.py @@ -0,0 +1,94 @@ +# Copyright 2024 Akretion (http://www.akretion.com). +# @author Florian Mounier +# License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl). + +import unittest + +# pylint: disable=odoo-addons-relative-import +from odoo.addons.queue_job.delay import Delayable + + +class TestDelayableSplit(unittest.TestCase): + def setUp(self): + super().setUp() + + class FakeRecordSet(list): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._name = "recordset" + + def __getitem__(self, key): + if isinstance(key, slice): + return FakeRecordSet(super().__getitem__(key)) + return super().__getitem__(key) + + def method(self, arg, kwarg=None): + """Method to be called""" + return arg, kwarg + + self.FakeRecordSet = FakeRecordSet + + def test_delayable_split_no_method_call_beforehand(self): + dl = Delayable(self.FakeRecordSet(range(20))) + with self.assertRaises(ValueError): + dl.split(3) + + def test_delayable_split_10_3(self): + dl = Delayable(self.FakeRecordSet(range(10))) + dl.method("arg", kwarg="kwarg") + group = dl.split(3) + self.assertEqual(len(group._delayables), 4) + delayables = sorted(list(group._delayables), key=lambda x: x.description) + self.assertEqual(delayables[0].recordset, self.FakeRecordSet([0, 1, 2])) + self.assertEqual(delayables[1].recordset, self.FakeRecordSet([3, 4, 5])) + self.assertEqual(delayables[2].recordset, self.FakeRecordSet([6, 7, 8])) + self.assertEqual(delayables[3].recordset, self.FakeRecordSet([9])) + self.assertEqual(delayables[0].description, "Method to be called (split 1/4)") + self.assertEqual(delayables[1].description, "Method to be called (split 2/4)") + self.assertEqual(delayables[2].description, "Method to be called (split 3/4)") + self.assertEqual(delayables[3].description, "Method to be called (split 4/4)") + self.assertNotEqual(delayables[0]._job_method, dl._job_method) + self.assertNotEqual(delayables[1]._job_method, dl._job_method) + self.assertNotEqual(delayables[2]._job_method, dl._job_method) + self.assertNotEqual(delayables[3]._job_method, dl._job_method) + self.assertEqual(delayables[0]._job_method.__name__, dl._job_method.__name__) + self.assertEqual(delayables[1]._job_method.__name__, dl._job_method.__name__) + self.assertEqual(delayables[2]._job_method.__name__, dl._job_method.__name__) + self.assertEqual(delayables[3]._job_method.__name__, dl._job_method.__name__) + self.assertEqual(delayables[0]._job_args, ("arg",)) + self.assertEqual(delayables[1]._job_args, ("arg",)) + self.assertEqual(delayables[2]._job_args, ("arg",)) + self.assertEqual(delayables[3]._job_args, ("arg",)) + self.assertEqual(delayables[0]._job_kwargs, {"kwarg": "kwarg"}) + self.assertEqual(delayables[1]._job_kwargs, {"kwarg": "kwarg"}) + self.assertEqual(delayables[2]._job_kwargs, {"kwarg": "kwarg"}) + self.assertEqual(delayables[3]._job_kwargs, {"kwarg": "kwarg"}) + + def test_delayable_split_10_5(self): + dl = Delayable(self.FakeRecordSet(range(10))) + dl.method("arg", kwarg="kwarg") + group = dl.split(5) + self.assertEqual(len(group._delayables), 2) + delayables = sorted(list(group._delayables), key=lambda x: x.description) + self.assertEqual(delayables[0].recordset, self.FakeRecordSet([0, 1, 2, 3, 4])) + self.assertEqual(delayables[1].recordset, self.FakeRecordSet([5, 6, 7, 8, 9])) + self.assertEqual(delayables[0].description, "Method to be called (split 1/2)") + self.assertEqual(delayables[1].description, "Method to be called (split 2/2)") + + def test_delayable_split_10_10(self): + dl = Delayable(self.FakeRecordSet(range(10))) + dl.method("arg", kwarg="kwarg") + group = dl.split(10) + self.assertEqual(len(group._delayables), 1) + delayables = sorted(list(group._delayables), key=lambda x: x.description) + self.assertEqual(delayables[0].recordset, self.FakeRecordSet(range(10))) + self.assertEqual(delayables[0].description, "Method to be called (split 1/1)") + + def test_delayable_split_10_20(self): + dl = Delayable(self.FakeRecordSet(range(10))) + dl.method("arg", kwarg="kwarg") + group = dl.split(20) + self.assertEqual(len(group._delayables), 1) + delayables = sorted(list(group._delayables), key=lambda x: x.description) + self.assertEqual(delayables[0].recordset, self.FakeRecordSet(range(10))) + self.assertEqual(delayables[0].description, "Method to be called (split 1/1)") From 5459b92994306b08046bcbeaf86686d9e424e72e Mon Sep 17 00:00:00 2001 From: Florian Mounier Date: Tue, 4 Jun 2024 17:12:25 +0200 Subject: [PATCH 2/3] [FIX] queue_job: Migrate unittest.TestCase tests to common.TransactionCase --- queue_job/tests/test_delayable.py | 6 +++--- queue_job/tests/test_delayable_split.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/queue_job/tests/test_delayable.py b/queue_job/tests/test_delayable.py index c7295ea2b1..1495a9e05b 100644 --- a/queue_job/tests/test_delayable.py +++ b/queue_job/tests/test_delayable.py @@ -1,15 +1,15 @@ # copyright 2019 Camptocamp # license agpl-3.0 or later (http://www.gnu.org/licenses/agpl.html) -import unittest - import mock +from odoo.tests import common + # pylint: disable=odoo-addons-relative-import from odoo.addons.queue_job.delay import Delayable, DelayableGraph -class TestDelayable(unittest.TestCase): +class TestDelayable(common.TransactionCase): def setUp(self): super().setUp() self.recordset = mock.MagicMock(name="recordset") diff --git a/queue_job/tests/test_delayable_split.py b/queue_job/tests/test_delayable_split.py index 469983aab6..8047141723 100644 --- a/queue_job/tests/test_delayable_split.py +++ b/queue_job/tests/test_delayable_split.py @@ -2,13 +2,13 @@ # @author Florian Mounier # License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl). -import unittest +from odoo.tests import common # pylint: disable=odoo-addons-relative-import from odoo.addons.queue_job.delay import Delayable -class TestDelayableSplit(unittest.TestCase): +class TestDelayableSplit(common.TransactionCase): def setUp(self): super().setUp() From 3865f141cd377040887856c98a074b7810ecdbb1 Mon Sep 17 00:00:00 2001 From: Florian Mounier Date: Tue, 2 Jul 2024 12:42:29 +0200 Subject: [PATCH 3/3] [IMP] queue_job: Add chain parameter on split method --- queue_job/delay.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/queue_job/delay.py b/queue_job/delay.py index 1836ce8550..00603b6840 100644 --- a/queue_job/delay.py +++ b/queue_job/delay.py @@ -534,9 +534,9 @@ def delay(self): """Delay the whole graph""" self._graph.delay() - def split(self, size): - """Split the Delayable into a DelayableGroup containing batches - of size `size` + def split(self, size, chain=False): + """Split the Delayable into a DelayableGroup or DelayableChain + if `chain` is True containing batches of size `size` """ if not self._job_method: raise ValueError("No method set on the Delayable") @@ -576,7 +576,7 @@ def split(self, size): # Prevent warning on deletion self._generated_job = True - return DelayableGroup(*delayables) + return (DelayableChain if chain else DelayableGroup)(*delayables) def _build_job(self): if self._generated_job: