diff --git a/aioredis/pool.py b/aioredis/pool.py index 78174b8c5..6850f06d5 100644 --- a/aioredis/pool.py +++ b/aioredis/pool.py @@ -1,4 +1,5 @@ import asyncio +import collections from .commands import create_redis, Redis from .log import logger @@ -22,7 +23,7 @@ def create_pool(address, *, db=0, password=None, encoding=None, minsize=minsize, maxsize=maxsize, commands_factory=commands_factory, loop=loop) - yield from pool._fill_free() + yield from pool._fill_free(override_min=False) return pool @@ -41,9 +42,10 @@ def __init__(self, address, db=0, password=None, encoding=None, self._minsize = minsize self._factory = commands_factory self._loop = loop - self._pool = asyncio.Queue(maxsize, loop=loop) + self._pool = collections.deque(maxlen=maxsize) self._used = set() - self._need_wait = None + self._acquiring = 0 + self._cond = asyncio.Condition(loop=loop) @property def minsize(self): @@ -53,17 +55,17 @@ def minsize(self): @property def maxsize(self): """Maximum pool size.""" - return self._pool.maxsize + return self._pool.maxlen @property def size(self): """Current pool size.""" - return self.freesize + len(self._used) + return self.freesize + len(self._used) + self._acquiring @property def freesize(self): """Current number of free connections.""" - return self._pool.qsize() + return len(self._pool) @asyncio.coroutine def clear(self): @@ -71,10 +73,13 @@ def clear(self): Close and remove all free connections. """ - while not self._pool.empty(): - conn = yield from self._pool.get() - conn.close() - yield from conn.wait_closed() + with (yield from self._cond): + waiters = [] + while self._pool: + conn = self._pool.popleft() + conn.close() + waiters.append(conn.wait_closed()) + yield from asyncio.gather(*waiters, loop=self._loop) @property def db(self): @@ -92,24 +97,11 @@ def select(self, db): All previously acquired connections will be closed when released. """ - self._need_wait = fut = asyncio.Future(loop=self._loop) - try: - for _ in range(self.freesize): - conn = yield from self._pool.get() - try: - yield from conn.select(db) - finally: - yield from self._pool.put(conn) + with (yield from self._cond): + for i in range(self.freesize): + yield from self._pool[i].select(db) else: self._db = db - finally: - self._need_wait = None - fut.set_result(None) - - def _wait_select(self): - if self._need_wait is None: - return () - return self._need_wait @asyncio.coroutine def acquire(self): @@ -117,16 +109,17 @@ def acquire(self): Creates new connection if needed. """ - yield from self._wait_select() - yield from self._fill_free() - if self.minsize > 0 or not self._pool.empty(): - conn = yield from self._pool.get() - else: - conn = yield from self._create_new_connection() - assert not conn.closed, conn - assert conn not in self._used, (conn, self._used) - self._used.add(conn) - return conn + with (yield from self._cond): + while True: + yield from self._fill_free(override_min=True) + if self.freesize: + conn = self._pool.popleft() + assert not conn.closed, conn + assert conn not in self._used, (conn, self._used) + self._used.add(conn) + return conn + else: + yield from self._cond.wait() def release(self, conn): """Returns used connection back into pool. @@ -143,19 +136,34 @@ def release(self, conn): conn) conn.close() elif conn.db == self.db: - try: - self._pool.put_nowait(conn) - except asyncio.QueueFull: + if self.maxsize and self.freesize < self.maxsize: + self._pool.append(conn) + else: # consider this connection as old and close it. conn.close() else: conn.close() + # FIXME: check event loop is not closed + asyncio.async(self._wakeup(), loop=self._loop) @asyncio.coroutine - def _fill_free(self): - while self.freesize < self.minsize and self.size < self.maxsize: - conn = yield from self._create_new_connection() - yield from self._pool.put(conn) + def _fill_free(self, *, override_min): + while self.size < self.minsize: + self._acquiring += 1 + try: + conn = yield from self._create_new_connection() + self._pool.append(conn) + finally: + self._acquiring -= 1 + if self.freesize: + return + if override_min and self.size < self.maxsize: + self._acquiring += 1 + try: + conn = yield from self._create_new_connection() + self._pool.append(conn) + finally: + self._acquiring -= 1 @asyncio.coroutine def _create_new_connection(self): @@ -167,6 +175,13 @@ def _create_new_connection(self): loop=self._loop) return conn + @asyncio.coroutine + def _wakeup(self, closing_conn=None): + with (yield from self._cond): + self._cond.notify() + if closing_conn is not None: + yield from closing_conn.wait_closed() + def __enter__(self): raise RuntimeError( "'yield from' should be used as a context manager expression") diff --git a/tests/pool_test.py b/tests/pool_test.py index d14f4e413..ef92ab2c9 100644 --- a/tests/pool_test.py +++ b/tests/pool_test.py @@ -1,5 +1,4 @@ import asyncio -import unittest from ._testutil import BaseTest, run_until_complete from aioredis import RedisPool, ReplyError @@ -106,9 +105,10 @@ def test_create_no_minsize(self): self.assertEqual(pool.size, 1) self.assertEqual(pool.freesize, 0) - with (yield from pool): - self.assertEqual(pool.size, 2) - self.assertEqual(pool.freesize, 0) + with self.assertRaises(asyncio.TimeoutError): + yield from asyncio.wait_for(pool.acquire(), + timeout=0.2, + loop=self.loop) self.assertEqual(pool.size, 1) self.assertEqual(pool.freesize, 1) @@ -219,15 +219,14 @@ def test(): db = 0 while True: db = (db + 1) & 1 - res = yield from asyncio.gather(pool.select(db), - pool.acquire(), - loop=self.loop) - conn = res[1] + _, conn = yield from asyncio.gather(pool.select(db), + pool.acquire(), + loop=self.loop) self.assertEqual(pool.db, db) pool.release(conn) if conn.db == db: break - yield from asyncio.wait_for(test(), 10, loop=self.loop) + yield from asyncio.wait_for(test(), 1, loop=self.loop) @run_until_complete def test_response_decoding(self): @@ -273,30 +272,33 @@ def test_crappy_multiexec(self): value = yield from redis.get('abc') self.assertEquals(value, 'def') - @unittest.expectedFailure + # @unittest.expectedFailure @run_until_complete def test_pool_size_growth(self): pool = yield from self.create_pool( ('localhost', self.redis_port), loop=self.loop, - minsize=1, maxsize=2) + minsize=1, maxsize=1) + + done = set() + tasks = [] @asyncio.coroutine - def task1(): + def task1(i): with (yield from pool): self.assertLessEqual(pool.size, pool.maxsize) self.assertEqual(pool.freesize, 0) - yield from asyncio.sleep(1, loop=self.loop) + yield from asyncio.sleep(0.2, loop=self.loop) + done.add(i) @asyncio.coroutine def task2(): with (yield from pool): self.assertLessEqual(pool.size, pool.maxsize) - self.assertEqual(pool.freesize, 0) - yield from asyncio.sleep(1, loop=self.loop) + self.assertGreaterEqual(pool.freesize, 0) + self.assertEqual(done, {0, 1}) - tasks = [] for _ in range(2): - tasks.append(asyncio.async(task1(), loop=self.loop)) + tasks.append(asyncio.async(task1(_), loop=self.loop)) tasks.append(asyncio.async(task2(), loop=self.loop)) yield from asyncio.gather(*tasks, loop=self.loop)