diff --git a/src/apscheduler/datastores/mongodb.py b/src/apscheduler/datastores/mongodb.py
index 44e2b901..95b54825 100644
--- a/src/apscheduler/datastores/mongodb.py
+++ b/src/apscheduler/datastores/mongodb.py
@@ -2,7 +2,6 @@
 
 import operator
 import sys
-from collections import defaultdict
 from collections.abc import AsyncGenerator, AsyncIterator, Sequence
 from contextlib import AsyncExitStack, asynccontextmanager
 from datetime import datetime, timedelta, timezone
@@ -611,28 +610,8 @@ async def acquire_jobs(
                     doc for doc in documents if doc["_id"] in acquired_job_ids
                 ]
 
-                # Get the number of available job slots per task
-                task_ids: set[str] = {doc["task_id"] for doc in documents}
-                async with await AsyncCursor.create(
-                    lambda: self._tasks.find(
-                        {
-                            "_id": {"$in": list(task_ids)},
-                            "max_running_jobs": {"$ne": None},
-                        },
-                        projection=["_id", "max_running_jobs", "running_jobs"],
-                        session=session,
-                    )
-                ) as cursor:
-                    task_job_slots_left: dict[str, float] = defaultdict(
-                        lambda: float("inf")
-                    )
-                    async for doc in cursor:
-                        task_max_running_jobs = doc["max_running_jobs"]
-                        task_job_slots_left[doc["_id"]] = doc["max_running_jobs"]
-
                 acquired_jobs: list[Job] = []
                 skipped_job_ids: list[UUID] = []
-                running_job_count_increments: dict[str, int] = defaultdict(lambda: 0)
                 for doc in documents:
                     # Deserialize the job
                     doc["id"] = doc.pop("_id")
@@ -683,35 +662,50 @@ async def acquire_jobs(
                         )
                         continue
 
-                    # Skip and un-acquire the job if no more slots are available
-                    if not task_job_slots_left.get(job.task_id, float("inf")):
+                    # Try to increment the task's running jobs count
+                    update_task_result = await to_thread.run_sync(
+                        lambda: self._tasks.update_one(
+                            {
+                                "_id": job.task_id,
+                                "$or": [
+                                    {"max_running_jobs": None},
+                                    {
+                                        "$expr": {
+                                            "$gt": [
+                                                "$max_running_jobs",
+                                                "$running_jobs",
+                                            ]
+                                        }
+                                    },
+                                ],
+                            },
+                            {"$inc": {"running_jobs": 1}},
+                            session=session,
+                        )
+                    )
+                    if not update_task_result.matched_count:
                         self._logger.debug(
-                            "Skipping job %s because task %r has the maximum "
-                            "number of %d jobs already running",
-                            task_max_running_jobs,
+                            "Skipping job %s because task %r has the maximum number of "
+                            "jobs already running",
+                            job.id,
+                            job.task_id,
                         )
                         skipped_job_ids.append(job.id)
                         continue
 
-                    task_job_slots_left[job.task_id] -= 1
-                    running_job_count_increments[job.task_id] += 1
                     job.acquired_by = scheduler_id
                     job.acquired_until = now + lease_duration
                     acquired_jobs.append(job)
                     events.append(JobAcquired.from_job(job, scheduler_id=scheduler_id))
 
-                # Increment the running_jobs field on the tasks of the acquired jobs
-                if writes := [
-                    UpdateOne({"_id": task_id}, {"$inc": {"running_jobs": amount}})
-                    for task_id, amount in running_job_count_increments.items()
-                ]:
-                    await to_thread.run_sync(self._tasks.bulk_write, writes)
-
                # Release jobs skipped due to max job slots being reached
                if skipped_job_ids:
                    await to_thread.run_sync(
                        lambda: self._jobs.update_many(
-                            {"_id": {"$in": skipped_job_ids}},
+                            {
+                                "_id": {"$in": skipped_job_ids},
+                                "acquired_by": scheduler_id,
+                            },
                            {
                                "$unset": {
                                    "acquired_by": True,
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 32d75559..8e10b39b 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -624,6 +624,9 @@ async def test_acquire_jobs_max_number_exceeded(datastore: DataStore) -> None:
         assert job.acquired_by == "worker1"
         assert job.acquired_until
 
+    # Check that no jobs are acquired now that the task is at capacity
+    assert not await datastore.acquire_jobs("worker1", timedelta(seconds=30), 3)
+
     # Release one job, and the worker should be able to acquire the third job
     await datastore.release_job(
         "worker1",
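
Note: the change above replaces read-then-write slot accounting (query each task's max_running_jobs and running_jobs up front, track remaining slots in Python, then bulk-increment running_jobs at the end) with one conditional update_one per job, so the capacity check and the increment happen as a single atomic operation on the server and concurrent schedulers cannot oversubscribe a task. Below is a minimal standalone sketch of that guarded-increment pattern; the connection URL, database and collection names, and the try_acquire_slot helper are illustrative assumptions, not part of the diff:

    from pymongo import MongoClient

    client = MongoClient("mongodb://localhost:27017")
    # Hypothetical collection with max_running_jobs / running_jobs fields
    tasks = client["example_db"]["tasks"]

    def try_acquire_slot(task_id: str) -> bool:
        """Atomically claim one job slot; return False if the task is at capacity."""
        result = tasks.update_one(
            {
                "_id": task_id,
                # Match only if the task has no job limit, or still has free slots
                "$or": [
                    {"max_running_jobs": None},
                    {"$expr": {"$gt": ["$max_running_jobs", "$running_jobs"]}},
                ],
            },
            # Applied only when the filter matched, in the same atomic step
            {"$inc": {"running_jobs": 1}},
        )
        # matched_count == 0 means the filter rejected the task: it was full
        return bool(result.matched_count)

Because the filter and the $inc are part of one update_one call, there is no window between the check and the increment in which another scheduler could claim the last slot.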