diff --git a/bumble/core.py b/bumble/core.py index f6d42dd5..4cfa8ad9 100644 --- a/bumble/core.py +++ b/bumble/core.py @@ -16,7 +16,7 @@ # Imports # ----------------------------------------------------------------------------- from __future__ import annotations -import dataclasses +import asyncio import enum import struct from typing import List, Optional, Tuple, Union, cast, Dict @@ -160,6 +160,14 @@ class UnreachableError(BaseBumbleError): """The code path raising this error should be unreachable.""" +class CancelledError(BaseBumbleError, asyncio.CancelledError): + """Operation has been cancelled or aborted.""" + + +class BearerLostError(BaseBumbleError): + """Bearer transport (ACL, L2CAP Channel) has been terminated or lost.""" + + class ConnectionError(BaseError): # pylint: disable=redefined-builtin """Connection Error""" diff --git a/bumble/gatt_client.py b/bumble/gatt_client.py index 1362b1ed..bcc7dd6c 100644 --- a/bumble/gatt_client.py +++ b/bumble/gatt_client.py @@ -1060,9 +1060,11 @@ async def write_value( ) ) - def on_disconnection(self, _) -> None: + def on_disconnection(self, reason: int) -> None: if self.pending_response and not self.pending_response.done(): - self.pending_response.cancel() + self.pending_response.set_exception( + core.BearerLostError(f"Connection terminated, reason={reason}") + ) def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None: logger.debug( diff --git a/bumble/rfcomm.py b/bumble/rfcomm.py index 2de7374c..9343eb55 100644 --- a/bumble/rfcomm.py +++ b/bumble/rfcomm.py @@ -727,10 +727,10 @@ async def drain(self) -> None: def abort(self) -> None: logger.debug(f'aborting DLC: {self}') if self.connection_result: - self.connection_result.cancel() + self.connection_result.set_exception(core.CancelledError("Aborted")) self.connection_result = None if self.disconnection_result: - self.disconnection_result.cancel() + self.disconnection_result.set_exception(core.CancelledError("Aborted")) self.disconnection_result = None self.change_state(DLC.State.RESET) self.emit('close') @@ -1011,10 +1011,12 @@ def on_dlc_disconnection(self, dlc: DLC) -> None: def on_l2cap_channel_close(self) -> None: logger.debug('L2CAP channel closed, cleaning up') if self.open_result: - self.open_result.cancel() + self.open_result.set_exception(core.BearerLostError("L2CAP channel closed")) self.open_result = None if self.disconnection_result: - self.disconnection_result.cancel() + self.disconnection_result.set_exception( + core.BearerLostError("L2CAP channel closed") + ) self.disconnection_result = None for dlc in self.dlcs.values(): dlc.abort() diff --git a/bumble/utils.py b/bumble/utils.py index d8864bb1..3cef0990 100644 --- a/bumble/utils.py +++ b/bumble/utils.py @@ -21,7 +21,6 @@ import enum import functools import logging -import sys import warnings from typing import ( Any, @@ -185,33 +184,43 @@ def close(self) -> None: class AbortableEventEmitter(EventEmitter): - def abort_on(self, event: str, awaitable: Awaitable[_T]) -> Awaitable[_T]: + def abort_on(self, event: str, awaitable: Awaitable[_T]) -> asyncio.Future[_T]: """ Set a coroutine or future to abort when an event occur. """ - future = asyncio.ensure_future(awaitable) - if future.done(): - return future + inner_future = asyncio.ensure_future(awaitable) + if inner_future.done(): + return inner_future + + exposed_future: asyncio.Future[_T] + if isinstance(inner_future, asyncio.Task): + exposed_future = asyncio.get_running_loop().create_future() + else: + exposed_future = inner_future def on_event(*_): - if future.done(): + if exposed_future.done(): return - msg = f'abort: {event} event occurred.' - if isinstance(future, asyncio.Task): - # python < 3.9 does not support passing a message on `Task.cancel` - if sys.version_info < (3, 9, 0): - future.cancel() - else: - future.cancel(msg) - else: - future.set_exception(asyncio.CancelledError(msg)) + if isinstance(inner_future, asyncio.Task): + inner_future.cancel() + + from bumble.core import CancelledError + + exposed_future.set_exception( + CancelledError(f'abort: {event} event occurred.') + ) def on_done(_): self.remove_listener(event, on_event) + if exposed_future is not inner_future: + try: + exposed_future.set_result(inner_future.result()) + except BaseException as e: + exposed_future.set_exception(e) self.on(event, on_event) - future.add_done_callback(on_done) - return future + inner_future.add_done_callback(on_done) + return exposed_future # ----------------------------------------------------------------------------- diff --git a/tests/device_test.py b/tests/device_test.py index 1f6175ab..4e880dfe 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -25,6 +25,7 @@ BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT, BT_PERIPHERAL_ROLE, + CancelledError, ConnectionParameters, ) from bumble.device import ( @@ -259,11 +260,8 @@ async def test_flush(): d0 = Device(host=Host(None, None)) task = d0.abort_on('flush', asyncio.sleep(10000)) await d0.host.flush() - try: + with pytest.raises(CancelledError): await task - assert False - except asyncio.CancelledError: - pass # ----------------------------------------------------------------------------- diff --git a/tests/utils_test.py b/tests/utils_test.py index 6266f9ef..955acf4a 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -15,13 +15,16 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- +import asyncio import contextlib import logging import os -from unittest.mock import MagicMock +import pytest +from unittest.mock import MagicMock, AsyncMock from pyee import EventEmitter +from bumble import core from bumble import utils @@ -95,6 +98,75 @@ class Foo(utils.OpenIntEnum): print(list(Foo)) +# ----------------------------------------------------------------------------- +async def test_abort_on_coroutine_aborted(): + ee = utils.AbortableEventEmitter() + + future = ee.abort_on('e', asyncio.Event().wait()) + ee.emit('e') + + with pytest.raises(core.CancelledError): + await future + + +# ----------------------------------------------------------------------------- +async def test_abort_on_coroutine_non_aborted(): + ee = utils.AbortableEventEmitter() + event = asyncio.Event() + + future = ee.abort_on('e', event.wait()) + event.set() + + await future + + +# ----------------------------------------------------------------------------- +async def test_abort_on_coroutine_exception(): + ee = utils.AbortableEventEmitter() + coroutine_factory = AsyncMock(side_effect=Exception("test")) + + future = ee.abort_on('e', coroutine_factory()) + with pytest.raises(Exception) as e: + await future + assert e.value.args == ("test",) + + +# ----------------------------------------------------------------------------- +async def test_abort_on_future_aborted(): + ee = utils.AbortableEventEmitter() + real_future = asyncio.get_running_loop().create_future() + + future = ee.abort_on('e', real_future) + ee.emit('e') + + with pytest.raises(core.CancelledError): + await future + + +# ----------------------------------------------------------------------------- +async def test_abort_on_future_non_aborted(): + ee = utils.AbortableEventEmitter() + real_future = asyncio.get_running_loop().create_future() + + future = ee.abort_on('e', real_future) + real_future.set_result(None) + + await future + + +# ----------------------------------------------------------------------------- +async def test_abort_on_future_exception(): + ee = utils.AbortableEventEmitter() + real_future = asyncio.get_running_loop().create_future() + + future = ee.abort_on('e', real_future) + real_future.set_exception(Exception("test")) + + with pytest.raises(Exception) as e: + await future + assert e.value.args == ("test",) + + # ----------------------------------------------------------------------------- def run_tests(): test_on()