Skip to content
This repository has been archived by the owner on Feb 21, 2023. It is now read-only.

Commit

Permalink
Merge pull request #77 from aio-libs/pool_growth_over_maxsize_fix
Browse files Browse the repository at this point in the history
fixed pool growth over maxsize
  • Loading branch information
popravich committed Aug 4, 2015
2 parents 792b1b6 + 2c09c4b commit b19047f
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 60 deletions.
101 changes: 58 additions & 43 deletions aioredis/pool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import collections

from .commands import create_redis, Redis
from .log import logger
Expand All @@ -22,7 +23,7 @@ def create_pool(address, *, db=0, password=None, encoding=None,
minsize=minsize, maxsize=maxsize,
commands_factory=commands_factory,
loop=loop)
yield from pool._fill_free()
yield from pool._fill_free(override_min=False)
return pool


Expand All @@ -41,9 +42,10 @@ def __init__(self, address, db=0, password=None, encoding=None,
self._minsize = minsize
self._factory = commands_factory
self._loop = loop
self._pool = asyncio.Queue(maxsize, loop=loop)
self._pool = collections.deque(maxlen=maxsize)
self._used = set()
self._need_wait = None
self._acquiring = 0
self._cond = asyncio.Condition(loop=loop)

@property
def minsize(self):
Expand All @@ -53,28 +55,31 @@ def minsize(self):
@property
def maxsize(self):
"""Maximum pool size."""
return self._pool.maxsize
return self._pool.maxlen

@property
def size(self):
"""Current pool size."""
return self.freesize + len(self._used)
return self.freesize + len(self._used) + self._acquiring

@property
def freesize(self):
"""Current number of free connections."""
return self._pool.qsize()
return len(self._pool)

@asyncio.coroutine
def clear(self):
"""Clear pool connections.
Close and remove all free connections.
"""
while not self._pool.empty():
conn = yield from self._pool.get()
conn.close()
yield from conn.wait_closed()
with (yield from self._cond):
waiters = []
while self._pool:
conn = self._pool.popleft()
conn.close()
waiters.append(conn.wait_closed())
yield from asyncio.gather(*waiters, loop=self._loop)

@property
def db(self):
Expand All @@ -92,41 +97,29 @@ def select(self, db):
All previously acquired connections will be closed when released.
"""
self._need_wait = fut = asyncio.Future(loop=self._loop)
try:
for _ in range(self.freesize):
conn = yield from self._pool.get()
try:
yield from conn.select(db)
finally:
yield from self._pool.put(conn)
with (yield from self._cond):
for i in range(self.freesize):
yield from self._pool[i].select(db)
else:
self._db = db
finally:
self._need_wait = None
fut.set_result(None)

def _wait_select(self):
if self._need_wait is None:
return ()
return self._need_wait

@asyncio.coroutine
def acquire(self):
"""Acquires a connection from free pool.
Creates new connection if needed.
"""
yield from self._wait_select()
yield from self._fill_free()
if self.minsize > 0 or not self._pool.empty():
conn = yield from self._pool.get()
else:
conn = yield from self._create_new_connection()
assert not conn.closed, conn
assert conn not in self._used, (conn, self._used)
self._used.add(conn)
return conn
with (yield from self._cond):
while True:
yield from self._fill_free(override_min=True)
if self.freesize:
conn = self._pool.popleft()
assert not conn.closed, conn
assert conn not in self._used, (conn, self._used)
self._used.add(conn)
return conn
else:
yield from self._cond.wait()

def release(self, conn):
"""Returns used connection back into pool.
Expand All @@ -143,19 +136,34 @@ def release(self, conn):
conn)
conn.close()
elif conn.db == self.db:
try:
self._pool.put_nowait(conn)
except asyncio.QueueFull:
if self.maxsize and self.freesize < self.maxsize:
self._pool.append(conn)
else:
# consider this connection as old and close it.
conn.close()
else:
conn.close()
# FIXME: check event loop is not closed
asyncio.async(self._wakeup(), loop=self._loop)

@asyncio.coroutine
def _fill_free(self):
while self.freesize < self.minsize and self.size < self.maxsize:
conn = yield from self._create_new_connection()
yield from self._pool.put(conn)
def _fill_free(self, *, override_min):
while self.size < self.minsize:
self._acquiring += 1
try:
conn = yield from self._create_new_connection()
self._pool.append(conn)
finally:
self._acquiring -= 1
if self.freesize:
return
if override_min and self.size < self.maxsize:
self._acquiring += 1
try:
conn = yield from self._create_new_connection()
self._pool.append(conn)
finally:
self._acquiring -= 1

@asyncio.coroutine
def _create_new_connection(self):
Expand All @@ -167,6 +175,13 @@ def _create_new_connection(self):
loop=self._loop)
return conn

@asyncio.coroutine
def _wakeup(self, closing_conn=None):
with (yield from self._cond):
self._cond.notify()
if closing_conn is not None:
yield from closing_conn.wait_closed()

def __enter__(self):
raise RuntimeError(
"'yield from' should be used as a context manager expression")
Expand Down
36 changes: 19 additions & 17 deletions tests/pool_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import unittest

from ._testutil import BaseTest, run_until_complete
from aioredis import RedisPool, ReplyError
Expand Down Expand Up @@ -106,9 +105,10 @@ def test_create_no_minsize(self):
self.assertEqual(pool.size, 1)
self.assertEqual(pool.freesize, 0)

with (yield from pool):
self.assertEqual(pool.size, 2)
self.assertEqual(pool.freesize, 0)
with self.assertRaises(asyncio.TimeoutError):
yield from asyncio.wait_for(pool.acquire(),
timeout=0.2,
loop=self.loop)
self.assertEqual(pool.size, 1)
self.assertEqual(pool.freesize, 1)

Expand Down Expand Up @@ -219,15 +219,14 @@ def test():
db = 0
while True:
db = (db + 1) & 1
res = yield from asyncio.gather(pool.select(db),
pool.acquire(),
loop=self.loop)
conn = res[1]
_, conn = yield from asyncio.gather(pool.select(db),
pool.acquire(),
loop=self.loop)
self.assertEqual(pool.db, db)
pool.release(conn)
if conn.db == db:
break
yield from asyncio.wait_for(test(), 10, loop=self.loop)
yield from asyncio.wait_for(test(), 1, loop=self.loop)

@run_until_complete
def test_response_decoding(self):
Expand Down Expand Up @@ -273,30 +272,33 @@ def test_crappy_multiexec(self):
value = yield from redis.get('abc')
self.assertEquals(value, 'def')

@unittest.expectedFailure
# @unittest.expectedFailure
@run_until_complete
def test_pool_size_growth(self):
pool = yield from self.create_pool(
('localhost', self.redis_port),
loop=self.loop,
minsize=1, maxsize=2)
minsize=1, maxsize=1)

done = set()
tasks = []

@asyncio.coroutine
def task1():
def task1(i):
with (yield from pool):
self.assertLessEqual(pool.size, pool.maxsize)
self.assertEqual(pool.freesize, 0)
yield from asyncio.sleep(1, loop=self.loop)
yield from asyncio.sleep(0.2, loop=self.loop)
done.add(i)

@asyncio.coroutine
def task2():
with (yield from pool):
self.assertLessEqual(pool.size, pool.maxsize)
self.assertEqual(pool.freesize, 0)
yield from asyncio.sleep(1, loop=self.loop)
self.assertGreaterEqual(pool.freesize, 0)
self.assertEqual(done, {0, 1})

tasks = []
for _ in range(2):
tasks.append(asyncio.async(task1(), loop=self.loop))
tasks.append(asyncio.async(task1(_), loop=self.loop))
tasks.append(asyncio.async(task2(), loop=self.loop))
yield from asyncio.gather(*tasks, loop=self.loop)

0 comments on commit b19047f

Please sign in to comment.