Skip to content

Commit

Permalink
Avoid cancelling futures
Browse files Browse the repository at this point in the history
  • Loading branch information
zxzxwu committed Dec 31, 2024
1 parent 80d60aa commit 03c1f7b
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 29 deletions.
10 changes: 9 additions & 1 deletion bumble/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import asyncio
import enum
import struct
from typing import List, Optional, Tuple, Union, cast, Dict
Expand Down Expand Up @@ -160,6 +160,14 @@ class UnreachableError(BaseBumbleError):
"""The code path raising this error should be unreachable."""


class CancelledError(BaseBumbleError, asyncio.CancelledError):
"""Operation has been cancelled or aborted."""


class BearerLostError(BaseBumbleError):
"""Bearer transport (ACL, L2CAP Channel) has been terminated or lost."""


class ConnectionError(BaseError): # pylint: disable=redefined-builtin
"""Connection Error"""

Expand Down
6 changes: 4 additions & 2 deletions bumble/gatt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,9 +1060,11 @@ async def write_value(
)
)

def on_disconnection(self, _) -> None:
def on_disconnection(self, reason: int) -> None:
if self.pending_response and not self.pending_response.done():
self.pending_response.cancel()
self.pending_response.set_exception(
core.BearerLostError(f"Connection terminated, reason={reason}")
)

def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
logger.debug(
Expand Down
10 changes: 6 additions & 4 deletions bumble/rfcomm.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,10 +727,10 @@ async def drain(self) -> None:
def abort(self) -> None:
logger.debug(f'aborting DLC: {self}')
if self.connection_result:
self.connection_result.cancel()
self.connection_result.set_exception(core.CancelledError("Aborted"))
self.connection_result = None
if self.disconnection_result:
self.disconnection_result.cancel()
self.disconnection_result.set_exception(core.CancelledError("Aborted"))
self.disconnection_result = None
self.change_state(DLC.State.RESET)
self.emit('close')
Expand Down Expand Up @@ -1011,10 +1011,12 @@ def on_dlc_disconnection(self, dlc: DLC) -> None:
def on_l2cap_channel_close(self) -> None:
logger.debug('L2CAP channel closed, cleaning up')
if self.open_result:
self.open_result.cancel()
self.open_result.set_exception(core.BearerLostError("L2CAP channel closed"))
self.open_result = None
if self.disconnection_result:
self.disconnection_result.cancel()
self.disconnection_result.set_exception(
core.BearerLostError("L2CAP channel closed")
)
self.disconnection_result = None
for dlc in self.dlcs.values():
dlc.abort()
Expand Down
43 changes: 26 additions & 17 deletions bumble/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import enum
import functools
import logging
import sys
import warnings
from typing import (
Any,
Expand Down Expand Up @@ -185,33 +184,43 @@ def close(self) -> None:


class AbortableEventEmitter(EventEmitter):
def abort_on(self, event: str, awaitable: Awaitable[_T]) -> Awaitable[_T]:
def abort_on(self, event: str, awaitable: Awaitable[_T]) -> asyncio.Future[_T]:
"""
Set a coroutine or future to abort when an event occur.
"""
future = asyncio.ensure_future(awaitable)
if future.done():
return future
inner_future = asyncio.ensure_future(awaitable)
if inner_future.done():
return inner_future

exposed_future: asyncio.Future[_T]
if isinstance(inner_future, asyncio.Task):
exposed_future = asyncio.get_running_loop().create_future()
else:
exposed_future = inner_future

def on_event(*_):
if future.done():
if exposed_future.done():
return
msg = f'abort: {event} event occurred.'
if isinstance(future, asyncio.Task):
# python < 3.9 does not support passing a message on `Task.cancel`
if sys.version_info < (3, 9, 0):
future.cancel()
else:
future.cancel(msg)
else:
future.set_exception(asyncio.CancelledError(msg))
if isinstance(inner_future, asyncio.Task):
inner_future.cancel()

from bumble.core import CancelledError

exposed_future.set_exception(
CancelledError(f'abort: {event} event occurred.')
)

def on_done(_):
self.remove_listener(event, on_event)
if exposed_future is not inner_future:
try:
exposed_future.set_result(inner_future.result())
except BaseException as e:
exposed_future.set_exception(e)

self.on(event, on_event)
future.add_done_callback(on_done)
return future
inner_future.add_done_callback(on_done)
return exposed_future


# -----------------------------------------------------------------------------
Expand Down
6 changes: 2 additions & 4 deletions tests/device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
BT_BR_EDR_TRANSPORT,
BT_LE_TRANSPORT,
BT_PERIPHERAL_ROLE,
CancelledError,
ConnectionParameters,
)
from bumble.device import (
Expand Down Expand Up @@ -259,11 +260,8 @@ async def test_flush():
d0 = Device(host=Host(None, None))
task = d0.abort_on('flush', asyncio.sleep(10000))
await d0.host.flush()
try:
with pytest.raises(CancelledError):
await task
assert False
except asyncio.CancelledError:
pass


# -----------------------------------------------------------------------------
Expand Down
74 changes: 73 additions & 1 deletion tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import contextlib
import logging
import os
from unittest.mock import MagicMock
import pytest
from unittest.mock import MagicMock, AsyncMock

from pyee import EventEmitter

from bumble import core
from bumble import utils


Expand Down Expand Up @@ -95,6 +98,75 @@ class Foo(utils.OpenIntEnum):
print(list(Foo))


# -----------------------------------------------------------------------------
async def test_abort_on_coroutine_aborted():
ee = utils.AbortableEventEmitter()

future = ee.abort_on('e', asyncio.Event().wait())
ee.emit('e')

with pytest.raises(core.CancelledError):
await future


# -----------------------------------------------------------------------------
async def test_abort_on_coroutine_non_aborted():
ee = utils.AbortableEventEmitter()
event = asyncio.Event()

future = ee.abort_on('e', event.wait())
event.set()

await future


# -----------------------------------------------------------------------------
async def test_abort_on_coroutine_exception():
ee = utils.AbortableEventEmitter()
coroutine_factory = AsyncMock(side_effect=Exception("test"))

future = ee.abort_on('e', coroutine_factory())
with pytest.raises(Exception) as e:
await future
assert e.value.args == ("test",)


# -----------------------------------------------------------------------------
async def test_abort_on_future_aborted():
ee = utils.AbortableEventEmitter()
real_future = asyncio.get_running_loop().create_future()

future = ee.abort_on('e', real_future)
ee.emit('e')

with pytest.raises(core.CancelledError):
await future


# -----------------------------------------------------------------------------
async def test_abort_on_future_non_aborted():
ee = utils.AbortableEventEmitter()
real_future = asyncio.get_running_loop().create_future()

future = ee.abort_on('e', real_future)
real_future.set_result(None)

await future


# -----------------------------------------------------------------------------
async def test_abort_on_future_exception():
ee = utils.AbortableEventEmitter()
real_future = asyncio.get_running_loop().create_future()

future = ee.abort_on('e', real_future)
real_future.set_exception(Exception("test"))

with pytest.raises(Exception) as e:
await future
assert e.value.args == ("test",)


# -----------------------------------------------------------------------------
def run_tests():
test_on()
Expand Down

0 comments on commit 03c1f7b

Please sign in to comment.