Skip to content

Commit

Permalink
Avoid cancelling futures
Browse files Browse the repository at this point in the history
  • Loading branch information
zxzxwu committed Dec 30, 2024
1 parent 80d60aa commit d10bd7d
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 28 deletions.
10 changes: 9 additions & 1 deletion bumble/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import asyncio
import enum
import struct
from typing import List, Optional, Tuple, Union, cast, Dict
Expand Down Expand Up @@ -160,6 +160,14 @@ class UnreachableError(BaseBumbleError):
"""The code path raising this error should be unreachable."""


class CancelledError(BaseBumbleError, asyncio.CancelledError):
"""Operation has been cancelled or aborted."""


class BearerLostError(BaseBumbleError):
"""Bearer transport (ACL, L2CAP Channel) has been terminated or lost."""


class ConnectionError(BaseError): # pylint: disable=redefined-builtin
"""Connection Error"""

Expand Down
6 changes: 4 additions & 2 deletions bumble/gatt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,9 +1060,11 @@ async def write_value(
)
)

def on_disconnection(self, _) -> None:
def on_disconnection(self, reason: int) -> None:
if self.pending_response and not self.pending_response.done():
self.pending_response.cancel()
self.pending_response.set_exception(
core.BearerLostError(f"Connection terminated, reason={reason}")
)

def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
logger.debug(
Expand Down
10 changes: 6 additions & 4 deletions bumble/rfcomm.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,10 +727,10 @@ async def drain(self) -> None:
def abort(self) -> None:
logger.debug(f'aborting DLC: {self}')
if self.connection_result:
self.connection_result.cancel()
self.connection_result.set_exception(core.CancelledError("Aborted"))
self.connection_result = None
if self.disconnection_result:
self.disconnection_result.cancel()
self.disconnection_result.set_exception(core.CancelledError("Aborted"))
self.disconnection_result = None
self.change_state(DLC.State.RESET)
self.emit('close')
Expand Down Expand Up @@ -1011,10 +1011,12 @@ def on_dlc_disconnection(self, dlc: DLC) -> None:
def on_l2cap_channel_close(self) -> None:
logger.debug('L2CAP channel closed, cleaning up')
if self.open_result:
self.open_result.cancel()
self.open_result.set_exception(core.BearerLostError("L2CAP channel closed"))
self.open_result = None
if self.disconnection_result:
self.disconnection_result.cancel()
self.disconnection_result.set_exception(
core.BearerLostError("L2CAP channel closed")
)
self.disconnection_result = None
for dlc in self.dlcs.values():
dlc.abort()
Expand Down
43 changes: 26 additions & 17 deletions bumble/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import enum
import functools
import logging
import sys
import warnings
from typing import (
Any,
Expand Down Expand Up @@ -185,33 +184,43 @@ def close(self) -> None:


class AbortableEventEmitter(EventEmitter):
def abort_on(self, event: str, awaitable: Awaitable[_T]) -> Awaitable[_T]:
def abort_on(self, event: str, awaitable: Awaitable[_T]) -> asyncio.Future[_T]:
"""
Set a coroutine or future to abort when an event occur.
"""
future = asyncio.ensure_future(awaitable)
if future.done():
return future
inner_future = asyncio.ensure_future(awaitable)
if inner_future.done():
return inner_future

exposed_future: asyncio.Future[_T]
if isinstance(inner_future, asyncio.Task):
exposed_future = asyncio.get_running_loop().create_future()
else:
exposed_future = inner_future

def on_event(*_):
if future.done():
if exposed_future.done():
return
msg = f'abort: {event} event occurred.'
if isinstance(future, asyncio.Task):
# python < 3.9 does not support passing a message on `Task.cancel`
if sys.version_info < (3, 9, 0):
future.cancel()
else:
future.cancel(msg)
else:
future.set_exception(asyncio.CancelledError(msg))
if isinstance(inner_future, asyncio.Task):
inner_future.cancel()

from bumble.core import CancelledError

exposed_future.set_exception(
CancelledError(f'abort: {event} event occurred.')
)

def on_done(_):
self.remove_listener(event, on_event)
if exposed_future is not inner_future:
try:
exposed_future.set_result(inner_future.result())
except BaseException as e:
exposed_future.set_exception(e)

self.on(event, on_event)
future.add_done_callback(on_done)
return future
inner_future.add_done_callback(on_done)
return exposed_future


# -----------------------------------------------------------------------------
Expand Down
6 changes: 2 additions & 4 deletions tests/device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
BT_BR_EDR_TRANSPORT,
BT_LE_TRANSPORT,
BT_PERIPHERAL_ROLE,
CancelledError,
ConnectionParameters,
)
from bumble.device import (
Expand Down Expand Up @@ -259,11 +260,8 @@ async def test_flush():
d0 = Device(host=Host(None, None))
task = d0.abort_on('flush', asyncio.sleep(10000))
await d0.host.flush()
try:
with pytest.raises(CancelledError):
await task
assert False
except asyncio.CancelledError:
pass


# -----------------------------------------------------------------------------
Expand Down
74 changes: 74 additions & 0 deletions tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import contextlib
import logging
import os
import pytest
from unittest.mock import MagicMock

from pyee import EventEmitter

from bumble import core
from bumble import utils


Expand Down Expand Up @@ -95,6 +98,77 @@ class Foo(utils.OpenIntEnum):
print(list(Foo))


# -----------------------------------------------------------------------------
async def test_abort_on_coroutine_aborted():
ee = utils.AbortableEventEmitter()

future = ee.abort_on('e', asyncio.Event().wait())
ee.emit('e')

with pytest.raises(core.CancelledError):
await future


# -----------------------------------------------------------------------------
async def test_abort_on_coroutine_non_aborted():
ee = utils.AbortableEventEmitter()
event = asyncio.Event()

future = ee.abort_on('e', event.wait())
event.set()

await future


# -----------------------------------------------------------------------------
async def test_abort_on_coroutine_exception():
ee = utils.AbortableEventEmitter()

async def foo():
raise Exception("test")

future = ee.abort_on('e', foo())
with pytest.raises(Exception) as e:
await future
assert e.value.args == ("test",)


# -----------------------------------------------------------------------------
async def test_abort_on_future_aborted():
ee = utils.AbortableEventEmitter()
real_future = asyncio.get_running_loop().create_future()

future = ee.abort_on('e', real_future)
ee.emit('e')

with pytest.raises(core.CancelledError):
await future


# -----------------------------------------------------------------------------
async def test_abort_on_future_non_aborted():
ee = utils.AbortableEventEmitter()
real_future = asyncio.get_running_loop().create_future()

future = ee.abort_on('e', real_future)
real_future.set_result(None)

await future


# -----------------------------------------------------------------------------
async def test_abort_on_future_exception():
ee = utils.AbortableEventEmitter()
real_future = asyncio.get_running_loop().create_future()

future = ee.abort_on('e', real_future)
real_future.set_exception(Exception("test"))

with pytest.raises(Exception) as e:
await future
assert e.value.args == ("test",)


# -----------------------------------------------------------------------------
def run_tests():
test_on()
Expand Down

0 comments on commit d10bd7d

Please sign in to comment.