diff --git a/jina/serve/runtimes/worker/batch_queue.py b/jina/serve/runtimes/worker/batch_queue.py index 31bac588d5efd..56ba81e61e2a7 100644 --- a/jina/serve/runtimes/worker/batch_queue.py +++ b/jina/serve/runtimes/worker/batch_queue.py @@ -3,7 +3,6 @@ from asyncio import Event, Task from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union from jina._docarray import docarray_v2 -import contextlib if not docarray_v2: from docarray import DocumentArray @@ -25,7 +24,6 @@ def __init__( response_docarray_cls, output_array_type: Optional[str] = None, params: Optional[Dict] = None, - allow_concurrent: bool = False, flush_all: bool = False, preferred_batch_size: int = 4, timeout: int = 10_000, @@ -33,10 +31,6 @@ def __init__( use_custom_metric: bool = False, ) -> None: # To keep old user behavior, we use data lock when flush_all is true and no allow_concurrent - if allow_concurrent and flush_all: - self._data_lock = contextlib.AsyncExitStack() - else: - self._data_lock = asyncio.Lock() self.func = func if params is None: params = dict() @@ -64,7 +58,7 @@ def __str__(self) -> str: def _reset(self) -> None: """Set all events and reset the batch queue.""" self._requests: List[DataRequest] = [] - # a list of every request ID + # a list of every request idx inside self._requests self._request_idxs: List[int] = [] self._request_lens: List[int] = [] self._docs_metrics: List[int] = [] @@ -116,26 +110,24 @@ async def push(self, request: DataRequest, http=False) -> asyncio.Queue: # this push requests the data lock. The order of accessing the data lock guarantees that this request will be put in the `big_doc` # before the `flush` task processes it. self._start_timer() - async with self._data_lock: - if not self._flush_task: - self._flush_task = asyncio.create_task(self._await_then_flush(http)) - - self._big_doc.extend(docs) - next_req_idx = len(self._requests) - num_docs = len(docs) - metric_value = num_docs - if self._custom_metric is not None: - metrics = [self._custom_metric(doc) for doc in docs] - metric_value += sum(metrics) - self._docs_metrics.extend(metrics) - self._metric_value += metric_value - self._request_idxs.extend([next_req_idx] * num_docs) - self._request_lens.append(num_docs) - self._requests.append(request) - queue = asyncio.Queue() - self._requests_completed.append(queue) - if self._metric_value >= self._preferred_batch_size: - self._flush_trigger.set() + if not self._flush_task: + self._flush_task = asyncio.create_task(self._await_then_flush(http)) + self._big_doc.extend(docs) + next_req_idx = len(self._requests) + num_docs = len(docs) + metric_value = num_docs + if self._custom_metric is not None: + metrics = [self._custom_metric(doc) for doc in docs] + metric_value += sum(metrics) + self._docs_metrics.extend(metrics) + self._metric_value += metric_value + self._request_idxs.extend([next_req_idx] * num_docs) + self._request_lens.append(num_docs) + self._requests.append(request) + queue = asyncio.Queue() + self._requests_completed.append(queue) + if self._metric_value >= self._preferred_batch_size: + self._flush_trigger.set() return queue @@ -271,96 +263,76 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Option await self._flush_trigger.wait() # writes to shared data between tasks need to be mutually exclusive - async with self._data_lock: - big_doc_in_batch = copy.copy(self._big_doc) - requests_idxs_in_batch = copy.copy(self._request_idxs) - requests_lens_in_batch = copy.copy(self._request_lens) - docs_metrics_in_batch = 
copy.copy(self._docs_metrics) - requests_in_batch = copy.copy(self._requests) - requests_completed_in_batch = copy.copy(self._requests_completed) - - self._reset() - - # At this moment, we have documents concatenated in big_doc_in_batch corresponding to requests in - # requests_idxs_in_batch with its lengths stored in requests_lens_in_batch. For each requests, there is a queue to - # communicate that the request has been processed properly. - - if not docarray_v2: - non_assigned_to_response_docs: DocumentArray = DocumentArray.empty() - else: - non_assigned_to_response_docs = self._response_docarray_cls() + big_doc_in_batch = copy.copy(self._big_doc) + requests_idxs_in_batch = copy.copy(self._request_idxs) + requests_lens_in_batch = copy.copy(self._request_lens) + docs_metrics_in_batch = copy.copy(self._docs_metrics) + requests_in_batch = copy.copy(self._requests) + requests_completed_in_batch = copy.copy(self._requests_completed) - non_assigned_to_response_request_idxs = [] - sum_from_previous_first_req_idx = 0 - for docs_inner_batch, req_idxs in batch( - big_doc_in_batch, requests_idxs_in_batch, - self._preferred_batch_size if not self._flush_all else None, docs_metrics_in_batch if self._custom_metric is not None else None - ): - involved_requests_min_indx = req_idxs[0] - involved_requests_max_indx = req_idxs[-1] - input_len_before_call: int = len(docs_inner_batch) - batch_res_docs = None - try: - batch_res_docs = await self.func( - docs=docs_inner_batch, - parameters=self.params, - docs_matrix=None, # joining manually with batch queue is not supported right now - tracing_context=None, - ) - # Output validation - if (docarray_v2 and isinstance(batch_res_docs, DocList)) or ( - not docarray_v2 - and isinstance(batch_res_docs, DocumentArray) - ): - if not len(batch_res_docs) == input_len_before_call: - raise ValueError( - f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}' - ) - elif batch_res_docs is None: - if not len(docs_inner_batch) == input_len_before_call: - raise ValueError( - f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}' - ) - else: - array_name = ( - 'DocumentArray' if not docarray_v2 else 'DocList' + self._reset() + + # At this moment, we have documents concatenated in big_doc_in_batch corresponding to requests in + # requests_idxs_in_batch with its lengths stored in requests_lens_in_batch. For each requests, there is a queue to + # communicate that the request has been processed properly. 
+ + if not docarray_v2: + non_assigned_to_response_docs: DocumentArray = DocumentArray.empty() + else: + non_assigned_to_response_docs = self._response_docarray_cls() + + non_assigned_to_response_request_idxs = [] + sum_from_previous_first_req_idx = 0 + for docs_inner_batch, req_idxs in batch( + big_doc_in_batch, requests_idxs_in_batch, + self._preferred_batch_size if not self._flush_all else None, docs_metrics_in_batch if self._custom_metric is not None else None + ): + involved_requests_min_indx = req_idxs[0] + involved_requests_max_indx = req_idxs[-1] + input_len_before_call: int = len(docs_inner_batch) + batch_res_docs = None + try: + batch_res_docs = await self.func( + docs=docs_inner_batch, + parameters=self.params, + docs_matrix=None, # joining manually with batch queue is not supported right now + tracing_context=None, + ) + # Output validation + if (docarray_v2 and isinstance(batch_res_docs, DocList)) or ( + not docarray_v2 + and isinstance(batch_res_docs, DocumentArray) + ): + if not len(batch_res_docs) == input_len_before_call: + raise ValueError( + f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}' ) - raise TypeError( - f'The return type must be {array_name} / `None` when using dynamic batching, ' - f'but getting {batch_res_docs!r}' + elif batch_res_docs is None: + if not len(docs_inner_batch) == input_len_before_call: + raise ValueError( + f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}' ) - except Exception as exc: - # All the requests containing docs in this Exception should be raising it - for request_full in requests_completed_in_batch[ - involved_requests_min_indx: involved_requests_max_indx + 1 - ]: - await request_full.put(exc) else: - # We need to attribute the docs to their requests - non_assigned_to_response_docs.extend( - batch_res_docs or docs_inner_batch + array_name = ( + 'DocumentArray' if not docarray_v2 else 'DocList' ) - non_assigned_to_response_request_idxs.extend(req_idxs) - num_assigned_docs = await _assign_results( - non_assigned_to_response_docs, - non_assigned_to_response_request_idxs, - sum_from_previous_first_req_idx, - requests_lens_in_batch, - requests_in_batch, - requests_completed_in_batch, + raise TypeError( + f'The return type must be {array_name} / `None` when using dynamic batching, ' + f'but getting {batch_res_docs!r}' ) - - sum_from_previous_first_req_idx = ( - len(non_assigned_to_response_docs) - num_assigned_docs - ) - non_assigned_to_response_docs = non_assigned_to_response_docs[ - num_assigned_docs: - ] - non_assigned_to_response_request_idxs = ( - non_assigned_to_response_request_idxs[num_assigned_docs:] - ) - if len(non_assigned_to_response_request_idxs) > 0: - _ = await _assign_results( + except Exception as exc: + # All the requests containing docs in this Exception should be raising it + for request_full in requests_completed_in_batch[ + involved_requests_min_indx: involved_requests_max_indx + 1 + ]: + await request_full.put(exc) + else: + # We need to attribute the docs to their requests + non_assigned_to_response_docs.extend( + batch_res_docs or docs_inner_batch + ) + non_assigned_to_response_request_idxs.extend(req_idxs) + num_assigned_docs = await _assign_results( non_assigned_to_response_docs, non_assigned_to_response_request_idxs, sum_from_previous_first_req_idx, @@ -369,6 +341,26 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1, 
iterable_metrics: Option requests_completed_in_batch, ) + sum_from_previous_first_req_idx = ( + len(non_assigned_to_response_docs) - num_assigned_docs + ) + non_assigned_to_response_docs = non_assigned_to_response_docs[ + num_assigned_docs: + ] + non_assigned_to_response_request_idxs = ( + non_assigned_to_response_request_idxs[num_assigned_docs:] + ) + if len(non_assigned_to_response_request_idxs) > 0: + _ = await _assign_results( + non_assigned_to_response_docs, + non_assigned_to_response_request_idxs, + sum_from_previous_first_req_idx, + requests_lens_in_batch, + requests_in_batch, + requests_completed_in_batch, + ) + + async def close(self): """Closes the batch queue by flushing pending requests.""" if not self._is_closed: diff --git a/jina/serve/runtimes/worker/request_handling.py b/jina/serve/runtimes/worker/request_handling.py index 52a5070ea83e4..456c94a7bdf41 100644 --- a/jina/serve/runtimes/worker/request_handling.py +++ b/jina/serve/runtimes/worker/request_handling.py @@ -702,7 +702,6 @@ async def handle( ].response_schema, output_array_type=self.args.output_array_type, params=params, - allow_concurrent=self.args.allow_concurrent, **self._batchqueue_config[exec_endpoint], ) # This is necessary because push might need to await for the queue to be emptied diff --git a/tests/integration/dynamic_batching/test_dynamic_batching.py b/tests/integration/dynamic_batching/test_dynamic_batching.py index b55e8415c0aae..f7940289d6154 100644 --- a/tests/integration/dynamic_batching/test_dynamic_batching.py +++ b/tests/integration/dynamic_batching/test_dynamic_batching.py @@ -218,9 +218,7 @@ def call_api_with_params(req: RequestStructParams): ], ) @pytest.mark.parametrize('use_stream', [False, True]) -@pytest.mark.parametrize('allow_concurrent', [False, True]) -def test_timeout(add_parameters, use_stream, allow_concurrent): - add_parameters['allow_concurrent'] = allow_concurrent +def test_timeout(add_parameters, use_stream): f = Flow().add(**add_parameters) with f: start_time = time.time() @@ -267,9 +265,7 @@ def test_timeout(add_parameters, use_stream, allow_concurrent): ], ) @pytest.mark.parametrize('use_stream', [False, True]) -@pytest.mark.parametrize('allow_concurrent', [False, True]) -def test_preferred_batch_size(add_parameters, use_stream, allow_concurrent): - add_parameters['allow_concurrent'] = allow_concurrent +def test_preferred_batch_size(add_parameters, use_stream): f = Flow().add(**add_parameters) with f: with mp.Pool(2) as p: @@ -319,9 +315,8 @@ def test_preferred_batch_size(add_parameters, use_stream, allow_concurrent): @pytest.mark.repeat(10) @pytest.mark.parametrize('use_stream', [False, True]) -@pytest.mark.parametrize('allow_concurrent', [False, True]) -def test_correctness(use_stream, allow_concurrent): - f = Flow().add(uses=PlaceholderExecutor, allow_concurrent=allow_concurrent) +def test_correctness(use_stream): + f = Flow().add(uses=PlaceholderExecutor) with f: with mp.Pool(2) as p: results = list( @@ -641,14 +636,7 @@ def test_failure_propagation(): True ], ) -@pytest.mark.parametrize( - 'allow_concurrent', - [ - False, - True - ], -) -def test_exception_handling_in_dynamic_batch(flush_all, allow_concurrent): +def test_exception_handling_in_dynamic_batch(flush_all): class SlowExecutorWithException(Executor): @dynamic_batching(preferred_batch_size=3, timeout=5000, flush_all=flush_all) @@ -658,7 +646,7 @@ def foo(self, docs, **kwargs): if doc.text == 'fail': raise Exception('Fail is in the Batch') - depl = Deployment(uses=SlowExecutorWithException, 
allow_concurrent=allow_concurrent) + depl = Deployment(uses=SlowExecutorWithException) with depl: da = DocumentArray([Document(text='good') for _ in range(50)]) @@ -691,14 +679,7 @@ def foo(self, docs, **kwargs): True ], ) -@pytest.mark.parametrize( - 'allow_concurrent', - [ - False, - True - ], -) -async def test_num_docs_processed_in_exec(flush_all, allow_concurrent): +async def test_num_docs_processed_in_exec(flush_all): class DynBatchProcessor(Executor): @dynamic_batching(preferred_batch_size=5, timeout=5000, flush_all=flush_all) @@ -707,7 +688,7 @@ def foo(self, docs, **kwargs): for doc in docs: doc.text = f"{len(docs)}" - depl = Deployment(uses=DynBatchProcessor, protocol='http', allow_concurrent=allow_concurrent) + depl = Deployment(uses=DynBatchProcessor, protocol='http') with depl: da = DocumentArray([Document(text='good') for _ in range(50)]) @@ -722,25 +703,11 @@ def foo(self, docs, **kwargs): ): res.extend(r) assert len(res) == 50 # 1 request per input - if not flush_all: - for d in res: - assert int(d.text) <= 5 - else: - larger_than_5 = 0 - smaller_than_5 = 0 - for d in res: - if int(d.text) > 5: - larger_than_5 += 1 - if int(d.text) < 5: - smaller_than_5 += 1 - - assert smaller_than_5 == (1 if allow_concurrent else 0) - assert larger_than_5 > 0 @pytest.mark.asyncio -@pytest.mark.parametrize('use_custom_metric', [True, False]) -@pytest.mark.parametrize('flush_all', [False, True]) +@pytest.mark.parametrize('use_custom_metric', [True]) +@pytest.mark.parametrize('flush_all', [True]) async def test_dynamic_batching_custom_metric(use_custom_metric, flush_all): class DynCustomBatchProcessor(Executor): @@ -766,37 +733,3 @@ def foo(self, docs, **kwargs): ): res.extend(r) assert len(res) == 50 # 1 request per input - - # If custom_metric and flush all - if use_custom_metric and not flush_all: - for doc in res: - assert doc.text == "10" - - elif not use_custom_metric and not flush_all: - for doc in res: - assert doc.text == "50" - - elif use_custom_metric and flush_all: - # There will be 2 "10" and the rest will be "240" - num_10 = 0 - num_240 = 0 - for doc in res: - if doc.text == "10": - num_10 += 1 - elif doc.text == "240": - num_240 += 1 - - assert num_10 == 2 - assert num_240 == 48 - elif not use_custom_metric and flush_all: - # There will be 10 "50" and the rest will be "200" - num_50 = 0 - num_200 = 0 - for doc in res: - if doc.text == "50": - num_50 += 1 - elif doc.text == "200": - num_200 += 1 - - assert num_50 == 10 - assert num_200 == 40 diff --git a/tests/unit/serve/dynamic_batching/test_batch_queue.py b/tests/unit/serve/dynamic_batching/test_batch_queue.py index 40622b478322d..21fafabddd8e3 100644 --- a/tests/unit/serve/dynamic_batching/test_batch_queue.py +++ b/tests/unit/serve/dynamic_batching/test_batch_queue.py @@ -10,8 +10,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize('flush_all', [False, True]) -@pytest.mark.parametrize('allow_concurrent', [False, True]) -async def test_batch_queue_timeout(flush_all, allow_concurrent): +async def test_batch_queue_timeout(flush_all): async def foo(docs, **kwargs): await asyncio.sleep(0.1) return DocumentArray([Document(text='Done') for _ in docs]) @@ -23,7 +22,6 @@ async def foo(docs, **kwargs): preferred_batch_size=4, timeout=2000, flush_all=flush_all, - allow_concurrent=allow_concurrent, ) three_data_requests = [DataRequest() for _ in range(3)] @@ -64,10 +62,8 @@ async def process_request(req): @pytest.mark.asyncio @pytest.mark.parametrize('flush_all', [False, True]) -@pytest.mark.parametrize('allow_concurrent', [False, True]) 
-async def test_batch_queue_timeout_does_not_wait_previous_batch(flush_all, allow_concurrent): +async def test_batch_queue_timeout_does_not_wait_previous_batch(flush_all): batches_lengths_computed = [] - lock = asyncio.Lock() async def foo(docs, **kwargs): await asyncio.sleep(4) @@ -81,7 +77,6 @@ async def foo(docs, **kwargs): preferred_batch_size=5, timeout=3000, flush_all=flush_all, - allow_concurrent=allow_concurrent ) data_requests = [DataRequest() for _ in range(3)] @@ -108,17 +103,12 @@ async def process_request(req, sleep=0): if flush_all is False: # TIME TAKEN: 8000 for first batch of requests, plus 4000 for second batch that is fired inmediately # BEFORE FIX in https://github.com/jina-ai/jina/pull/6071, this would take: 8000 + 3000 + 4000 (Timeout would start counting too late) - assert time_spent >= 12000 - assert time_spent <= 12500 - else: - if not allow_concurrent: - assert time_spent >= 8000 - assert time_spent <= 8500 - else: - assert time_spent < 8000 - if flush_all is False: - assert batches_lengths_computed == [5, 1, 2] + assert time_spent >= 8000 + assert time_spent <= 8500 + assert batches_lengths_computed == [5, 2, 1] else: + assert time_spent >= 7000 + assert time_spent <= 7500 assert batches_lengths_computed == [6, 2] await bq.close() @@ -126,8 +116,7 @@ async def process_request(req, sleep=0): @pytest.mark.asyncio @pytest.mark.parametrize('flush_all', [False, True]) -@pytest.mark.parametrize('allow_concurrent', [False, True]) -async def test_batch_queue_req_length_larger_than_preferred(flush_all, allow_concurrent): +async def test_batch_queue_req_length_larger_than_preferred(flush_all): async def foo(docs, **kwargs): await asyncio.sleep(0.1) return DocumentArray([Document(text='Done') for _ in docs]) @@ -139,7 +128,6 @@ async def foo(docs, **kwargs): preferred_batch_size=4, timeout=2000, flush_all=flush_all, - allow_concurrent=allow_concurrent, ) data_requests = [DataRequest() for _ in range(3)] @@ -166,8 +154,7 @@ async def process_request(req): @pytest.mark.asyncio -@pytest.mark.parametrize('allow_concurrent', [False, True]) -async def test_exception(allow_concurrent): +async def test_exception(): BAD_REQUEST_IDX = [2, 6] async def foo(docs, **kwargs): @@ -185,7 +172,6 @@ async def foo(docs, **kwargs): preferred_batch_size=1, timeout=500, flush_all=False, - allow_concurrent=allow_concurrent ) data_requests = [DataRequest() for _ in range(35)] @@ -215,8 +201,7 @@ async def process_request(req): @pytest.mark.asyncio -@pytest.mark.parametrize('allow_concurrent', [False, True]) -async def test_exception_more_complex(allow_concurrent): +async def test_exception_more_complex(): TRIGGER_BAD_REQUEST_IDX = [2, 6] EXPECTED_BAD_REQUESTS = [2, 3, 6, 7] @@ -238,7 +223,6 @@ async def foo(docs, **kwargs): preferred_batch_size=2, timeout=500, flush_all=False, - allow_concurrent=allow_concurrent ) data_requests = [DataRequest() for _ in range(35)] @@ -271,8 +255,7 @@ async def process_request(req): @pytest.mark.asyncio @pytest.mark.parametrize('flush_all', [False, True]) -@pytest.mark.parametrize('allow_concurrent', [False, True]) -async def test_exception_all(flush_all, allow_concurrent): +async def test_exception_all(flush_all): async def foo(docs, **kwargs): raise AssertionError @@ -283,7 +266,6 @@ async def foo(docs, **kwargs): preferred_batch_size=2, flush_all=flush_all, timeout=500, - allow_concurrent=allow_concurrent ) data_requests = [DataRequest() for _ in range(10)] @@ -322,9 +304,8 @@ async def foo(docs, **kwargs): @pytest.mark.parametrize('preferred_batch_size', [7, 
61, 100]) @pytest.mark.parametrize('timeout', [0.3, 500]) @pytest.mark.parametrize('flush_all', [False, True]) -@pytest.mark.parametrize('allow_concurrent', [False, True]) @pytest.mark.asyncio -async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout, flush_all, allow_concurrent): +async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout, flush_all): import random async def foo(docs, **kwargs): @@ -343,7 +324,6 @@ async def foo(docs, **kwargs): preferred_batch_size=preferred_batch_size, flush_all=flush_all, timeout=timeout, - allow_concurrent=allow_concurrent ) data_requests = [DataRequest() for _ in range(num_requests)]
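---

Usage note (not part of the patch): after this revert, `allow_concurrent` is no longer accepted by the `BatchQueue`, the worker request handler, or `Flow().add(...)`/`Deployment(...)`; flushes are always driven by the single `_flush_task` created on the first push, and the only remaining knobs are `preferred_batch_size`, `timeout`, `flush_all`, and the custom-metric options. Below is a minimal sketch of dynamic batching after this change. It assumes the docarray-v1 style `Document`/`DocumentArray` used in these tests, that `dynamic_batching` and `requests` are importable from the top-level `jina` package, and that `Deployment` can be queried directly with `.post(...)`; the executor name and texts are made up for illustration.

```python
# Sketch only: dynamic batching without the removed `allow_concurrent` flag.
# Remaining knobs are preferred_batch_size, timeout and flush_all.
from docarray import Document, DocumentArray
from jina import Deployment, Executor, dynamic_batching, requests


class EchoBatchExecutor(Executor):  # hypothetical example executor
    @requests(on='/foo')
    @dynamic_batching(preferred_batch_size=4, timeout=2000, flush_all=True)
    def foo(self, docs: DocumentArray, **kwargs):
        # Each call receives an aggregated batch; its size depends on how many
        # requests arrived before the preferred size or the timeout was hit.
        for doc in docs:
            doc.text = f'processed in a batch of {len(docs)}'


if __name__ == '__main__':
    with Deployment(uses=EchoBatchExecutor) as dep:
        res = dep.post(
            on='/foo',
            inputs=DocumentArray([Document(text='hi') for _ in range(10)]),
            request_size=1,
        )
        print([d.text for d in res])
```

With `flush_all=True`, every document queued at flush time is passed to `foo` in one call, so batch sizes can exceed `preferred_batch_size`; with `flush_all=False`, the queue is drained in chunks of at most `preferred_batch_size`, matching the assertions kept in the tests above.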