From d6d8ddc01bfd5b538d5ba481cadac76e309463ac Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Fri, 15 Dec 2023 22:33:33 +0800 Subject: [PATCH] Manage lifecycle of CIS and SCO links in host --- bumble/device.py | 46 ++++++++++++++++++++----------------------- bumble/host.py | 47 +++++++++++++++++++++++++++++++++++++++----- tests/device_test.py | 5 ++--- tests/hfp_test.py | 25 +++++++++++++---------- 4 files changed, 80 insertions(+), 43 deletions(-) diff --git a/bumble/device.py b/bumble/device.py index 6d5c00c80..37173ea5f 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -3078,34 +3078,30 @@ async def create_cis(self, cis_acl_pairs: List[Tuple[int, int]]) -> List[CisLink cig_id=cig_id, ) - result = await self.send_command( - HCI_LE_Create_CIS_Command( - cis_connection_handle=[p[0] for p in cis_acl_pairs], - acl_connection_handle=[p[1] for p in cis_acl_pairs], - ), - ) - if result.status != HCI_COMMAND_STATUS_PENDING: - logger.warning( - 'HCI_LE_Create_CIS_Command failed: ' - f'{HCI_Constant.error_name(result.status)}' - ) - raise HCI_StatusError(result) - - pending_cis_establishments: Dict[int, asyncio.Future[CisLink]] = {} - for cis_handle, _ in cis_acl_pairs: - pending_cis_establishments[ - cis_handle - ] = asyncio.get_running_loop().create_future() - with closing(EventWatcher()) as watcher: + pending_cis_establishments = { + cis_handle: asyncio.get_running_loop().create_future() + for cis_handle, _ in cis_acl_pairs + } @watcher.on(self, 'cis_establishment') def on_cis_establishment(cis_link: CisLink) -> None: - if pending_future := pending_cis_establishments.get( - cis_link.handle, None - ): + if pending_future := pending_cis_establishments.get(cis_link.handle): pending_future.set_result(cis_link) + result = await self.send_command( + HCI_LE_Create_CIS_Command( + cis_connection_handle=[p[0] for p in cis_acl_pairs], + acl_connection_handle=[p[1] for p in cis_acl_pairs], + ), + ) + if result.status != HCI_COMMAND_STATUS_PENDING: + logger.warning( + 'HCI_LE_Create_CIS_Command failed: ' + f'{HCI_Constant.error_name(result.status)}' + ) + raise HCI_StatusError(result) + return await asyncio.gather(*pending_cis_establishments.values()) # [LE only] @@ -3753,7 +3749,7 @@ def on_sco_connection_failure( @host_event_handler @experimental('Only for testing') def on_sco_packet(self, sco_handle: int, packet: HCI_SynchronousDataPacket) -> None: - if sco_link := self.sco_links.get(sco_handle, None): + if sco_link := self.sco_links.get(sco_handle): sco_link.emit('pdu', packet) # [LE only] @@ -3833,7 +3829,7 @@ def on_cis_establishment(self, cis_handle: int) -> None: @experimental('Only for testing') def on_cis_establishment_failure(self, cis_handle: int, status: int) -> None: logger.debug(f'*** CIS Establishment Failure: cis=[0x{cis_handle:04X}] ***') - if cis_link := self.cis_links.pop(cis_handle, None): + if cis_link := self.cis_links.pop(cis_handle): cis_link.emit('establishment_failure') self.emit('cis_establishment_failure', cis_handle, status) @@ -3841,7 +3837,7 @@ def on_cis_establishment_failure(self, cis_handle: int, status: int) -> None: @host_event_handler @experimental('Only for testing') def on_iso_packet(self, handle: int, packet: HCI_IsoDataPacket) -> None: - if cis_link := self.cis_links.get(handle, None): + if cis_link := self.cis_links.get(handle): cis_link.emit('pdu', packet) @host_event_handler diff --git a/bumble/host.py b/bumble/host.py index f28d6fc68..2f198e7dd 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -18,6 +18,7 @@ from __future__ import annotations import asyncio import collections +import dataclasses import logging import struct @@ -161,9 +162,25 @@ def on_acl_pdu(self, pdu: bytes) -> None: self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload) +# ----------------------------------------------------------------------------- +@dataclasses.dataclass +class ScoLink: + peer_address: Address + handle: int + + +# ----------------------------------------------------------------------------- +@dataclasses.dataclass +class CisLink: + peer_address: Address + handle: int + + # ----------------------------------------------------------------------------- class Host(AbortableEventEmitter): connections: Dict[int, Connection] + cis_links: Dict[int, CisLink] + sco_links: Dict[int, ScoLink] acl_packet_queue: Optional[AclPacketQueue] = None le_acl_packet_queue: Optional[AclPacketQueue] = None hci_sink: Optional[TransportSink] = None @@ -183,6 +200,8 @@ def __init__( self.hci_metadata = {} self.ready = False # True when we can accept incoming packets self.connections = {} # Connections, by connection handle + self.cis_links = {} # CIS links, by connection handle + self.sco_links = {} # SCO links, by connection handle self.pending_command = None self.pending_response = None self.local_version = None @@ -696,25 +715,34 @@ def on_hci_connection_complete_event(self, event): def on_hci_disconnection_complete_event(self, event): # Find the connection - if (connection := self.connections.get(event.connection_handle)) is None: + handle = event.connection_handle + if ( + connection := ( + self.connections.get(handle) + or self.cis_links.get(handle) + or self.sco_links.get(handle) + ) + ) is None: logger.warning('!!! DISCONNECTION COMPLETE: unknown handle') return if event.status == HCI_SUCCESS: logger.debug( - f'### DISCONNECTION: [0x{event.connection_handle:04X}] ' + f'### DISCONNECTION: [0x{handle:04X}] ' f'{connection.peer_address} ' f'reason={event.reason}' ) - del self.connections[event.connection_handle] # Notify the listeners - self.emit('disconnection', event.connection_handle, event.reason) + self.emit('disconnection', handle, event.reason) + del self.connections[handle] + del self.cis_links[handle] + del self.sco_links[handle] else: logger.debug(f'### DISCONNECTION FAILED: {event.status}') # Notify the listeners - self.emit('disconnection_failure', event.connection_handle, event.status) + self.emit('disconnection_failure', handle, event.status) def on_hci_le_connection_update_complete_event(self, event): if (connection := self.connections.get(event.connection_handle)) is None: @@ -775,6 +803,10 @@ def on_hci_le_cis_request_event(self, event): def on_hci_le_cis_established_event(self, event): # The remaining parameters are unused for now. if event.status == HCI_SUCCESS: + self.cis_links[event.connection_handle] = CisLink( + handle=event.connection_handle, + peer_address=Address.ANY, + ) self.emit('cis_establishment', event.connection_handle) else: self.emit( @@ -841,6 +873,11 @@ def on_hci_synchronous_connection_complete_event(self, event): f'{event.bd_addr}' ) + self.sco_links[event.connection_handle] = ScoLink( + peer_address=event.bd_addr, + handle=event.connection_handle, + ) + # Notify the client self.emit( 'sco_connection', diff --git a/tests/device_test.py b/tests/device_test.py index d2b51d86a..8af1e9cd5 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -467,9 +467,8 @@ def on_cis_request( await asyncio.gather(*peripheral_cis_futures.values()) assert len(cis_links) == 2 - # TODO: Fix Host CIS support. - # await cis_links[0].disconnect() - # await cis_links[1].disconnect() + await cis_links[0].disconnect() + await cis_links[1].disconnect() # ----------------------------------------------------------------------------- diff --git a/tests/hfp_test.py b/tests/hfp_test.py index d94488aa5..dc281805e 100644 --- a/tests/hfp_test.py +++ b/tests/hfp_test.py @@ -24,11 +24,11 @@ from .test_utils import TwoDevices from bumble import core -from bumble import device from bumble import hfp from bumble import rfcomm from bumble import hci + # ----------------------------------------------------------------------------- # Logging # ----------------------------------------------------------------------------- @@ -109,7 +109,7 @@ async def test_sco_setup(): devices[1].accept(devices[0].public_address), ) - def on_sco_request(_connection: device.Connection, _link_type: int): + def on_sco_request(_connection, _link_type: int): connections[1].abort_on( 'disconnection', devices[1].send_command( @@ -124,17 +124,13 @@ def on_sco_request(_connection: device.Connection, _link_type: int): devices[1].on('sco_request', on_sco_request) - sco_connections = [ + sco_connection_futures = [ asyncio.get_running_loop().create_future(), asyncio.get_running_loop().create_future(), ] - devices[0].on( - 'sco_connection', lambda sco_link: sco_connections[0].set_result(sco_link) - ) - devices[1].on( - 'sco_connection', lambda sco_link: sco_connections[1].set_result(sco_link) - ) + for device, future in zip(devices, sco_connection_futures): + device.on('sco_connection', future.set_result) await devices[0].send_command( hci.HCI_Enhanced_Setup_Synchronous_Connection_Command( @@ -142,8 +138,17 @@ def on_sco_request(_connection: device.Connection, _link_type: int): **hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.ESCO_CVSD_S1].asdict(), ) ) + sco_connections = await asyncio.gather(*sco_connection_futures) + + sco_disconnection_futures = [ + asyncio.get_running_loop().create_future(), + asyncio.get_running_loop().create_future(), + ] + for future, sco_connection in zip(sco_disconnection_futures, sco_connections): + sco_connection.on('disconnection', future.set_result) - await asyncio.gather(*sco_connections) + await sco_connections[0].disconnect() + await asyncio.gather(*sco_disconnection_futures) # -----------------------------------------------------------------------------