Skip to content

Commit

Permalink
Manage lifecycle of CIS and SCO links in host
Browse files Browse the repository at this point in the history
  • Loading branch information
zxzxwu committed Jan 18, 2024
1 parent 45c4c4f commit 92eae70
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 40 deletions.
46 changes: 21 additions & 25 deletions bumble/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -3078,34 +3078,30 @@ async def create_cis(self, cis_acl_pairs: List[Tuple[int, int]]) -> List[CisLink
cig_id=cig_id,
)

result = await self.send_command(
HCI_LE_Create_CIS_Command(
cis_connection_handle=[p[0] for p in cis_acl_pairs],
acl_connection_handle=[p[1] for p in cis_acl_pairs],
),
)
if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_LE_Create_CIS_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
)
raise HCI_StatusError(result)

pending_cis_establishments: Dict[int, asyncio.Future[CisLink]] = {}
for cis_handle, _ in cis_acl_pairs:
pending_cis_establishments[
cis_handle
] = asyncio.get_running_loop().create_future()

with closing(EventWatcher()) as watcher:
pending_cis_establishments = {
cis_handle: asyncio.get_running_loop().create_future()
for cis_handle, _ in cis_acl_pairs
}

@watcher.on(self, 'cis_establishment')
def on_cis_establishment(cis_link: CisLink) -> None:
if pending_future := pending_cis_establishments.get(
cis_link.handle, None
):
if pending_future := pending_cis_establishments.get(cis_link.handle):
pending_future.set_result(cis_link)

result = await self.send_command(
HCI_LE_Create_CIS_Command(
cis_connection_handle=[p[0] for p in cis_acl_pairs],
acl_connection_handle=[p[1] for p in cis_acl_pairs],
),
)
if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_LE_Create_CIS_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
)
raise HCI_StatusError(result)

return await asyncio.gather(*pending_cis_establishments.values())

# [LE only]
Expand Down Expand Up @@ -3753,7 +3749,7 @@ def on_sco_connection_failure(
@host_event_handler
@experimental('Only for testing')
def on_sco_packet(self, sco_handle: int, packet: HCI_SynchronousDataPacket) -> None:
if sco_link := self.sco_links.get(sco_handle, None):
if sco_link := self.sco_links.get(sco_handle):
sco_link.emit('pdu', packet)

# [LE only]
Expand Down Expand Up @@ -3833,15 +3829,15 @@ def on_cis_establishment(self, cis_handle: int) -> None:
@experimental('Only for testing')
def on_cis_establishment_failure(self, cis_handle: int, status: int) -> None:
logger.debug(f'*** CIS Establishment Failure: cis=[0x{cis_handle:04X}] ***')
if cis_link := self.cis_links.pop(cis_handle, None):
if cis_link := self.cis_links.pop(cis_handle):
cis_link.emit('establishment_failure')
self.emit('cis_establishment_failure', cis_handle, status)

# [LE only]
@host_event_handler
@experimental('Only for testing')
def on_iso_packet(self, handle: int, packet: HCI_IsoDataPacket) -> None:
if cis_link := self.cis_links.get(handle, None):
if cis_link := self.cis_links.get(handle):
cis_link.emit('pdu', packet)

@host_event_handler
Expand Down
40 changes: 38 additions & 2 deletions bumble/host.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations
import asyncio
import collections
import dataclasses
import logging
import struct

Expand Down Expand Up @@ -161,9 +162,25 @@ def on_acl_pdu(self, pdu: bytes) -> None:
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)


# -----------------------------------------------------------------------------
@dataclasses.dataclass
class ScoLink:
peer_address: Address
handle: int


# -----------------------------------------------------------------------------
@dataclasses.dataclass
class CisLink:
peer_address: Address
handle: int


# -----------------------------------------------------------------------------
class Host(AbortableEventEmitter):
connections: Dict[int, Connection]
cis_links: Dict[int, CisLink]
sco_links: Dict[int, ScoLink]
acl_packet_queue: Optional[AclPacketQueue] = None
le_acl_packet_queue: Optional[AclPacketQueue] = None
hci_sink: Optional[TransportSink] = None
Expand All @@ -183,6 +200,8 @@ def __init__(
self.hci_metadata = {}
self.ready = False # True when we can accept incoming packets
self.connections = {} # Connections, by connection handle
self.cis_links = {} # CIS links, by connection handle
self.sco_links = {} # SCO links, by connection handle
self.pending_command = None
self.pending_response = None
self.local_version = None
Expand Down Expand Up @@ -696,7 +715,13 @@ def on_hci_connection_complete_event(self, event):

def on_hci_disconnection_complete_event(self, event):
# Find the connection
if (connection := self.connections.get(event.connection_handle)) is None:
if (
connection := (
self.connections.get(event.connection_handle)
or self.cis_links.get(event.connection_handle)
or self.sco_links.get(event.connection_handle)
)
) is None:
logger.warning('!!! DISCONNECTION COMPLETE: unknown handle')
return

Expand All @@ -706,10 +731,12 @@ def on_hci_disconnection_complete_event(self, event):
f'{connection.peer_address} '
f'reason={event.reason}'
)
del self.connections[event.connection_handle]

# Notify the listeners
self.emit('disconnection', event.connection_handle, event.reason)
del self.connections[event.connection_handle]
del self.cis_links[event.connection_handle]
del self.sco_links[event.connection_handle]
else:
logger.debug(f'### DISCONNECTION FAILED: {event.status}')

Expand Down Expand Up @@ -775,6 +802,10 @@ def on_hci_le_cis_request_event(self, event):
def on_hci_le_cis_established_event(self, event):
# The remaining parameters are unused for now.
if event.status == HCI_SUCCESS:
self.cis_links[event.connection_handle] = CisLink(
handle=event.connection_handle,
peer_address=Address.ANY,
)
self.emit('cis_establishment', event.connection_handle)
else:
self.emit(
Expand Down Expand Up @@ -841,6 +872,11 @@ def on_hci_synchronous_connection_complete_event(self, event):
f'{event.bd_addr}'
)

self.sco_links[event.connection_handle] = ScoLink(
peer_address=event.bd_addr,
handle=event.connection_handle,
)

# Notify the client
self.emit(
'sco_connection',
Expand Down
5 changes: 2 additions & 3 deletions tests/device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,9 +467,8 @@ def on_cis_request(
await asyncio.gather(*peripheral_cis_futures.values())
assert len(cis_links) == 2

# TODO: Fix Host CIS support.
# await cis_links[0].disconnect()
# await cis_links[1].disconnect()
await cis_links[0].disconnect()
await cis_links[1].disconnect()


# -----------------------------------------------------------------------------
Expand Down
25 changes: 15 additions & 10 deletions tests/hfp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@

from .test_utils import TwoDevices
from bumble import core
from bumble import device
from bumble import hfp
from bumble import rfcomm
from bumble import hci


# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -109,7 +109,7 @@ async def test_sco_setup():
devices[1].accept(devices[0].public_address),
)

def on_sco_request(_connection: device.Connection, _link_type: int):
def on_sco_request(_connection, _link_type: int):
connections[1].abort_on(
'disconnection',
devices[1].send_command(
Expand All @@ -124,26 +124,31 @@ def on_sco_request(_connection: device.Connection, _link_type: int):

devices[1].on('sco_request', on_sco_request)

sco_connections = [
sco_connection_futures = [
asyncio.get_running_loop().create_future(),
asyncio.get_running_loop().create_future(),
]

devices[0].on(
'sco_connection', lambda sco_link: sco_connections[0].set_result(sco_link)
)
devices[1].on(
'sco_connection', lambda sco_link: sco_connections[1].set_result(sco_link)
)
for device, future in zip(devices, sco_connection_futures):
device.on('sco_connection', future.set_result)

await devices[0].send_command(
hci.HCI_Enhanced_Setup_Synchronous_Connection_Command(
connection_handle=connections[0].handle,
**hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.ESCO_CVSD_S1].asdict(),
)
)
sco_connections = await asyncio.gather(*sco_connection_futures)

sco_disconnection_futures = [
asyncio.get_running_loop().create_future(),
asyncio.get_running_loop().create_future(),
]
for future, sco_connection in zip(sco_disconnection_futures, sco_connections):
sco_connection.on('disconnection', future.set_result)

await asyncio.gather(*sco_connections)
await sco_connections[0].disconnect()
await asyncio.gather(*sco_disconnection_futures)


# -----------------------------------------------------------------------------
Expand Down

0 comments on commit 92eae70

Please sign in to comment.