diff --git a/.vscode/settings.json b/.vscode/settings.json index b535ada8..777c47b4 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,7 @@ { "cSpell.words": [ "Abortable", + "aiohttp", "altsetting", "ansiblue", "ansicyan", @@ -9,6 +10,7 @@ "ansired", "ansiyellow", "appendleft", + "ascs", "ASHA", "asyncio", "ATRAC", @@ -43,6 +45,7 @@ "keyup", "levelname", "libc", + "liblc", "libusb", "MITM", "MSBC", @@ -78,6 +81,7 @@ "unmuted", "usbmodem", "vhci", + "wasmtime", "websockets", "xcursor", "ycursor" diff --git a/apps/lea_unicast/app.py b/apps/lea_unicast/app.py new file mode 100644 index 00000000..cf0de3da --- /dev/null +++ b/apps/lea_unicast/app.py @@ -0,0 +1,586 @@ +# Copyright 2021-2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations +import asyncio +import datetime +import enum +import functools +from importlib import resources +import json +import os +import logging +import pathlib +from typing import Optional, List, cast +import weakref +import struct + +import ctypes +import wasmtime +import wasmtime.loader +import liblc3 # type: ignore +import logging + +import click +import aiohttp.web + +import bumble +from bumble.core import AdvertisingData +from bumble.colors import color +from bumble.device import Device, DeviceConfiguration, AdvertisingParameters +from bumble.transport import open_transport +from bumble.profiles import bap +from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- +DEFAULT_UI_PORT = 7654 + + +def _sink_pac_record() -> bap.PacRecord: + return bap.PacRecord( + coding_format=CodingFormat(CodecID.LC3), + codec_specific_capabilities=bap.CodecSpecificCapabilities( + supported_sampling_frequencies=( + bap.SupportedSamplingFrequency.FREQ_8000 + | bap.SupportedSamplingFrequency.FREQ_16000 + | bap.SupportedSamplingFrequency.FREQ_24000 + | bap.SupportedSamplingFrequency.FREQ_32000 + | bap.SupportedSamplingFrequency.FREQ_48000 + ), + supported_frame_durations=( + bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED + ), + supported_audio_channel_counts=[1, 2], + min_octets_per_codec_frame=26, + max_octets_per_codec_frame=240, + supported_max_codec_frames_per_sdu=2, + ), + ) + + +def _source_pac_record() -> bap.PacRecord: + return bap.PacRecord( + coding_format=CodingFormat(CodecID.LC3), + codec_specific_capabilities=bap.CodecSpecificCapabilities( + supported_sampling_frequencies=( + bap.SupportedSamplingFrequency.FREQ_8000 + | bap.SupportedSamplingFrequency.FREQ_16000 + | bap.SupportedSamplingFrequency.FREQ_24000 + | bap.SupportedSamplingFrequency.FREQ_32000 + | bap.SupportedSamplingFrequency.FREQ_48000 + ), + supported_frame_durations=( + bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED + ), + supported_audio_channel_counts=[1], + min_octets_per_codec_frame=30, + max_octets_per_codec_frame=100, + supported_max_codec_frames_per_sdu=1, + ), + ) + + +# ----------------------------------------------------------------------------- +# WASM - liblc3 +# ----------------------------------------------------------------------------- +store = wasmtime.loader.store +_memory = cast(wasmtime.Memory, liblc3.memory) +STACK_POINTER = _memory.data_len(store) +_memory.grow(store, 1) +# Mapping wasmtime memory to linear address +memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address( + ctypes.addressof(_memory.data_ptr(store).contents) # type: ignore +) + + +class Liblc3PcmFormat(enum.IntEnum): + S16 = 0 + S24 = 1 + S24_3LE = 2 + FLOAT = 3 + + +MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000) +MAX_ENCODER_SIZE = liblc3.lc3_encoder_size(10000, 48000) + +DECODER_STACK_POINTER = STACK_POINTER +ENCODER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * 2 +DECODE_BUFFER_STACK_POINTER = ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * 2 +ENCODE_BUFFER_STACK_POINTER = DECODE_BUFFER_STACK_POINTER + 8192 +DEFAULT_PCM_SAMPLE_RATE = 48000 +DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16 +DEFAULT_PCM_BYTES_PER_SAMPLE = 2 + + +encoders: List[int] = [] +decoders: List[int] = [] + + +def setup_encoders( + sample_rate_hz: int, frame_duration_us: int, num_channels: int +) -> None: + logger.info( + f"setup_encoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels" + ) + encoders[:num_channels] = [ + liblc3.lc3_setup_encoder( + frame_duration_us, + sample_rate_hz, + DEFAULT_PCM_SAMPLE_RATE, # Input sample rate + ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * i, + ) + for i in range(num_channels) + ] + + +def setup_decoders( + sample_rate_hz: int, frame_duration_us: int, num_channels: int +) -> None: + logger.info( + f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels" + ) + decoders[:num_channels] = [ + liblc3.lc3_setup_decoder( + frame_duration_us, + sample_rate_hz, + DEFAULT_PCM_SAMPLE_RATE, # Output sample rate + DECODER_STACK_POINTER + MAX_DECODER_SIZE * i, + ) + for i in range(num_channels) + ] + + +def decode( + frame_duration_us: int, + num_channels: int, + input_bytes: bytes, +) -> bytes: + if not input_bytes: + return b'' + + input_buffer_offset = DECODE_BUFFER_STACK_POINTER + input_buffer_size = len(input_bytes) + input_bytes_per_frame = input_buffer_size // num_channels + + # Copy into wasm + memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore + + output_buffer_offset = input_buffer_offset + input_buffer_size + output_buffer_size = ( + liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE) + * DEFAULT_PCM_BYTES_PER_SAMPLE + * num_channels + ) + + for i in range(num_channels): + res = liblc3.lc3_decode( + decoders[i], + input_buffer_offset + input_bytes_per_frame * i, + input_bytes_per_frame, + DEFAULT_PCM_FORMAT, + output_buffer_offset + i * DEFAULT_PCM_BYTES_PER_SAMPLE, + num_channels, # Stride + ) + + if res != 0: + logging.error(f"Parsing failed, res={res}") + + # Extract decoded data from the output buffer + return bytes( + memory[output_buffer_offset : output_buffer_offset + output_buffer_size] + ) + + +def encode( + sdu_length: int, + num_channels: int, + stride: int, + input_bytes: bytes, +) -> bytes: + if not input_bytes: + return b'' + + input_buffer_offset = ENCODE_BUFFER_STACK_POINTER + input_buffer_size = len(input_bytes) + + # Copy into wasm + memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore + + output_buffer_offset = input_buffer_offset + input_buffer_size + output_buffer_size = sdu_length + output_frame_size = output_buffer_size // num_channels + + for i in range(num_channels): + res = liblc3.lc3_encode( + encoders[i], + DEFAULT_PCM_FORMAT, + input_buffer_offset + DEFAULT_PCM_BYTES_PER_SAMPLE * i, + stride, + output_frame_size, + output_buffer_offset + output_frame_size * i, + ) + + if res != 0: + logging.error(f"Parsing failed, res={res}") + + # Extract decoded data from the output buffer + return bytes( + memory[output_buffer_offset : output_buffer_offset + output_buffer_size] + ) + + +async def lc3_source_task( + filename: str, + sdu_length: int, + frame_duration_us: int, + device: Device, + cis_handle: int, +) -> None: + with open(filename, 'rb') as f: + header = f.read(44) + assert header[8:12] == b'WAVE' + + pcm_num_channel, pcm_sample_rate, _byte_rate, _block_align, bits_per_sample = ( + struct.unpack(" None: + self.speaker = weakref.ref(speaker) + self.port = port + self.channel_socket = None + + async def start_http(self) -> None: + """Start the UI HTTP server.""" + + app = aiohttp.web.Application() + app.add_routes( + [ + aiohttp.web.get('/', self.get_static), + aiohttp.web.get('/index.html', self.get_static), + aiohttp.web.get('/channel', self.get_channel), + ] + ) + + runner = aiohttp.web.AppRunner(app) + await runner.setup() + site = aiohttp.web.TCPSite(runner, 'localhost', self.port) + print('UI HTTP server at ' + color(f'http://127.0.0.1:{self.port}', 'green')) + await site.start() + + async def get_static(self, request): + path = request.path + if path == '/': + path = '/index.html' + if path.endswith('.html'): + content_type = 'text/html' + elif path.endswith('.js'): + content_type = 'text/javascript' + elif path.endswith('.css'): + content_type = 'text/css' + elif path.endswith('.svg'): + content_type = 'image/svg+xml' + else: + content_type = 'text/plain' + text = ( + resources.files("bumble.apps.lea_unicast") + .joinpath(pathlib.Path(path).relative_to('/')) + .read_text(encoding="utf-8") + ) + return aiohttp.web.Response(text=text, content_type=content_type) + + async def get_channel(self, request): + ws = aiohttp.web.WebSocketResponse() + await ws.prepare(request) + + # Process messages until the socket is closed. + self.channel_socket = ws + async for message in ws: + if message.type == aiohttp.WSMsgType.TEXT: + logger.debug(f'<<< received message: {message.data}') + await self.on_message(message.data) + elif message.type == aiohttp.WSMsgType.ERROR: + logger.debug( + f'channel connection closed with exception {ws.exception()}' + ) + + self.channel_socket = None + logger.debug('--- channel connection closed') + + return ws + + async def on_message(self, message_str: str): + # Parse the message as JSON + message = json.loads(message_str) + + # Dispatch the message + message_type = message['type'] + message_params = message.get('params', {}) + handler = getattr(self, f'on_{message_type}_message') + if handler: + await handler(**message_params) + + async def on_hello_message(self): + await self.send_message( + 'hello', + bumble_version=bumble.__version__, + codec=self.speaker().codec, + streamState=self.speaker().stream_state.name, + ) + if connection := self.speaker().connection: + await self.send_message( + 'connection', + peer_address=connection.peer_address.to_string(False), + peer_name=connection.peer_name, + ) + + async def send_message(self, message_type: str, **kwargs) -> None: + if self.channel_socket is None: + return + + message = {'type': message_type, 'params': kwargs} + await self.channel_socket.send_json(message) + + async def send_audio(self, data: bytes) -> None: + if self.channel_socket is None: + return + + try: + await self.channel_socket.send_bytes(data) + except Exception as error: + logger.warning(f'exception while sending audio packet: {error}') + + +# ----------------------------------------------------------------------------- +class Speaker: + + def __init__( + self, + device_config_path: Optional[str], + ui_port: int, + transport: str, + lc3_input_file_path: str, + ): + self.device_config_path = device_config_path + self.transport = transport + self.lc3_input_file_path = lc3_input_file_path + + # Create an HTTP server for the UI + self.ui_server = UiServer(speaker=self, port=ui_port) + + async def run(self) -> None: + await self.ui_server.start_http() + + async with await open_transport(self.transport) as hci_transport: + # Create a device + if self.device_config_path: + device_config = DeviceConfiguration.from_file(self.device_config_path) + else: + device_config = DeviceConfiguration( + name="Bumble LE Headphone", + class_of_device=0x244418, + keystore="JsonKeyStore", + advertising_interval_min=25, + advertising_interval_max=25, + address=Address('F1:F2:F3:F4:F5:F6'), + ) + + device_config.le_enabled = True + device_config.cis_enabled = True + self.device = Device.from_config_with_hci( + device_config, hci_transport.source, hci_transport.sink + ) + + self.device.add_service( + bap.PublishedAudioCapabilitiesService( + supported_source_context=bap.ContextType(0xFFFF), + available_source_context=bap.ContextType(0xFFFF), + supported_sink_context=bap.ContextType(0xFFFF), # All context types + available_sink_context=bap.ContextType(0xFFFF), # All context types + sink_audio_locations=( + bap.AudioLocation.FRONT_LEFT | bap.AudioLocation.FRONT_RIGHT + ), + sink_pac=[_sink_pac_record()], + source_audio_locations=bap.AudioLocation.FRONT_LEFT, + source_pac=[_source_pac_record()], + ) + ) + + ascs = bap.AudioStreamControlService( + self.device, sink_ase_id=[1], source_ase_id=[2] + ) + self.device.add_service(ascs) + + advertising_data = bytes( + AdvertisingData( + [ + ( + AdvertisingData.COMPLETE_LOCAL_NAME, + bytes(device_config.name, 'utf-8'), + ), + ( + AdvertisingData.FLAGS, + bytes( + [ + AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG + | AdvertisingData.BR_EDR_HOST_FLAG + | AdvertisingData.BR_EDR_CONTROLLER_FLAG + ] + ), + ), + ( + AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, + bytes(bap.PublishedAudioCapabilitiesService.UUID), + ), + ] + ) + ) + bytes(bap.UnicastServerAdvertisingData()) + + def on_pdu(pdu: HCI_IsoDataPacket, ase: bap.AseStateMachine): + codec_config = ase.codec_specific_configuration + assert isinstance(codec_config, bap.CodecSpecificConfiguration) + pcm = decode( + codec_config.frame_duration.us, + codec_config.audio_channel_allocation.num_channels, + pdu.iso_sdu_fragment, + ) + self.device.abort_on('disconnection', self.ui_server.send_audio(pcm)) + + def on_ase_state_change( + state: bap.AseStateMachine.State, + ase: bap.AseStateMachine, + ) -> None: + if state == bap.AseStateMachine.State.STREAMING: + codec_config = ase.codec_specific_configuration + assert isinstance(codec_config, bap.CodecSpecificConfiguration) + assert ase.cis_link + if ase.role == bap.AudioRole.SOURCE: + ase.cis_link.abort_on( + 'disconnection', + lc3_source_task( + filename=self.lc3_input_file_path, + sdu_length=( + codec_config.codec_frames_per_sdu + * codec_config.octets_per_codec_frame + ), + frame_duration_us=codec_config.frame_duration.us, + device=self.device, + cis_handle=ase.cis_link.handle, + ), + ) + else: + ase.cis_link.sink = functools.partial(on_pdu, ase=ase) + elif state == bap.AseStateMachine.State.CODEC_CONFIGURED: + codec_config = ase.codec_specific_configuration + assert isinstance(codec_config, bap.CodecSpecificConfiguration) + if ase.role == bap.AudioRole.SOURCE: + setup_encoders( + codec_config.sampling_frequency.hz, + codec_config.frame_duration.us, + codec_config.audio_channel_allocation.num_channels, + ) + else: + setup_decoders( + codec_config.sampling_frequency.hz, + codec_config.frame_duration.us, + codec_config.audio_channel_allocation.num_channels, + ) + + for ase in ascs.ase_state_machines.values(): + ase.on('state_change', functools.partial(on_ase_state_change, ase=ase)) + + await self.device.power_on() + await self.device.create_advertising_set( + advertising_data=advertising_data, + auto_restart=True, + advertising_parameters=AdvertisingParameters( + primary_advertising_interval_min=100, + primary_advertising_interval_max=100, + ), + ) + + await hci_transport.source.terminated + + +@click.command() +@click.option( + '--ui-port', + 'ui_port', + metavar='HTTP_PORT', + default=DEFAULT_UI_PORT, + show_default=True, + help='HTTP port for the UI server', +) +@click.option('--device-config', metavar='FILENAME', help='Device configuration file') +@click.argument('transport') +@click.argument('lc3_file') +def speaker(ui_port: int, device_config: str, transport: str, lc3_file: str) -> None: + """Run the speaker.""" + + asyncio.run(Speaker(device_config, ui_port, transport, lc3_file).run()) + + +# ----------------------------------------------------------------------------- +def main(): + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) + speaker() + + +# ----------------------------------------------------------------------------- +if __name__ == "__main__": + main() # pylint: disable=no-value-for-parameter diff --git a/apps/lea_unicast/index.html b/apps/lea_unicast/index.html new file mode 100644 index 00000000..fb1e61c9 --- /dev/null +++ b/apps/lea_unicast/index.html @@ -0,0 +1,68 @@ + + + + + + + + + +
+ +
+ + +
+ + + + + + + \ No newline at end of file diff --git a/apps/lea_unicast/liblc3.wasm b/apps/lea_unicast/liblc3.wasm new file mode 100755 index 00000000..e9051058 Binary files /dev/null and b/apps/lea_unicast/liblc3.wasm differ diff --git a/bumble/device.py b/bumble/device.py index bbb2f967..f858c0f8 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -961,8 +961,9 @@ class ScoLink(CompositeEventEmitter): acl_connection: Connection handle: int link_type: int + sink: Optional[Callable[[HCI_SynchronousDataPacket], Any]] = None - def __post_init__(self): + def __post_init__(self) -> None: super().__init__() async def disconnect( @@ -984,8 +985,9 @@ class State(IntEnum): cis_id: int # CIS ID assigned by Central device cig_id: int # CIG ID assigned by Central device state: State = State.PENDING + sink: Optional[Callable[[HCI_IsoDataPacket], Any]] = None - def __post_init__(self): + def __post_init__(self) -> None: super().__init__() async def disconnect( @@ -1518,6 +1520,7 @@ def __init__( self.classic_pending_accepts = { Address.ANY: [] } # Futures, by BD address OR [Futures] for Address.ANY + self.cis_lock = asyncio.Lock() # Own address type cache self.connect_own_address_type = None @@ -3415,26 +3418,44 @@ def on_cis_establishment(cis_link: CisLink) -> None: # [LE only] @experimental('Only for testing.') async def accept_cis_request(self, handle: int) -> CisLink: - result = await self.send_command( - HCI_LE_Accept_CIS_Request_Command(connection_handle=handle), - ) - if result.status != HCI_COMMAND_STATUS_PENDING: - logger.warning( - 'HCI_LE_Accept_CIS_Request_Command failed: ' - f'{HCI_Constant.error_name(result.status)}' - ) - raise HCI_StatusError(result) + """[LE Only] Accepts an incoming CIS request. - pending_cis_establishment = asyncio.get_running_loop().create_future() + When the specified CIS handle is already created, this method returns the + existed CIS link object immediately. - with closing(EventWatcher()) as watcher: + Args: + handle: CIS handle to accept. - @watcher.on(self, 'cis_establishment') - def on_cis_establishment(cis_link: CisLink) -> None: - if cis_link.handle == handle: - pending_cis_establishment.set_result(cis_link) + Returns: + CIS link object on the given handle. + """ + if not (cis_link := self.cis_links.get(handle)): + raise InvalidStateError(f'No pending CIS request of handle {handle}') + + if cis_link.state == CisLink.State.ESTABLISHED: + return cis_link + + async with self.cis_lock: + with closing(EventWatcher()) as watcher: + pending_establishment = asyncio.get_running_loop().create_future() + watcher.on( + cis_link, + 'establishment', + lambda: pending_establishment.set_result(None), + ) + + result = await self.send_command( + HCI_LE_Accept_CIS_Request_Command(connection_handle=handle), + ) + if result.status != HCI_COMMAND_STATUS_PENDING: + logger.warning( + 'HCI_LE_Accept_CIS_Request_Command failed: ' + f'{HCI_Constant.error_name(result.status)}' + ) + raise HCI_StatusError(result) - return await pending_cis_establishment + await pending_establishment + return cis_link # [LE only] @experimental('Only for testing.') @@ -4097,8 +4118,8 @@ def on_sco_connection_failure( @host_event_handler @experimental('Only for testing') def on_sco_packet(self, sco_handle: int, packet: HCI_SynchronousDataPacket) -> None: - if sco_link := self.sco_links.get(sco_handle): - sco_link.emit('pdu', packet) + if (sco_link := self.sco_links.get(sco_handle)) and sco_link.sink: + sco_link.sink(packet) # [LE only] @host_event_handler @@ -4161,8 +4182,8 @@ def on_cis_establishment_failure(self, cis_handle: int, status: int) -> None: @host_event_handler @experimental('Only for testing') def on_iso_packet(self, handle: int, packet: HCI_IsoDataPacket) -> None: - if cis_link := self.cis_links.get(handle): - cis_link.emit('pdu', packet) + if (cis_link := self.cis_links.get(handle)) and cis_link.sink: + cis_link.sink(packet) @host_event_handler @with_connection_from_handle diff --git a/bumble/hci.py b/bumble/hci.py index fba89515..9ef40bf2 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -23,7 +23,7 @@ import logging import secrets import struct -from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union, ClassVar from bumble import crypto from .colors import color @@ -2003,7 +2003,7 @@ class HCI_Packet: Abstract Base class for HCI packets ''' - hci_packet_type: int + hci_packet_type: ClassVar[int] @staticmethod def from_bytes(packet: bytes) -> HCI_Packet: @@ -6192,12 +6192,23 @@ def __str__(self) -> str: # ----------------------------------------------------------------------------- +@dataclasses.dataclass class HCI_IsoDataPacket(HCI_Packet): ''' See Bluetooth spec @ 5.4.5 HCI ISO Data Packets ''' - hci_packet_type = HCI_ISO_DATA_PACKET + hci_packet_type: ClassVar[int] = HCI_ISO_DATA_PACKET + + connection_handle: int + data_total_length: int + iso_sdu_fragment: bytes + pb_flag: int + ts_flag: int = 0 + time_stamp: Optional[int] = None + packet_sequence_number: Optional[int] = None + iso_sdu_length: Optional[int] = None + packet_status_flag: Optional[int] = None @staticmethod def from_bytes(packet: bytes) -> HCI_IsoDataPacket: @@ -6241,28 +6252,6 @@ def from_bytes(packet: bytes) -> HCI_IsoDataPacket: iso_sdu_fragment=iso_sdu_fragment, ) - def __init__( - self, - connection_handle: int, - pb_flag: int, - ts_flag: int, - data_total_length: int, - time_stamp: Optional[int], - packet_sequence_number: Optional[int], - iso_sdu_length: Optional[int], - packet_status_flag: Optional[int], - iso_sdu_fragment: bytes, - ) -> None: - self.connection_handle = connection_handle - self.pb_flag = pb_flag - self.ts_flag = ts_flag - self.data_total_length = data_total_length - self.time_stamp = time_stamp - self.packet_sequence_number = packet_sequence_number - self.iso_sdu_length = iso_sdu_length - self.packet_status_flag = packet_status_flag - self.iso_sdu_fragment = iso_sdu_fragment - def __bytes__(self) -> bytes: return self.to_bytes() diff --git a/bumble/host.py b/bumble/host.py index 609012ac..f9fe7e63 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -719,14 +719,16 @@ def on_hci_number_of_completed_packets_event(self, event): for connection_handle, num_completed_packets in zip( event.connection_handles, event.num_completed_packets ): - if not (connection := self.connections.get(connection_handle)): + if connection := self.connections.get(connection_handle): + connection.acl_packet_queue.on_packets_completed(num_completed_packets) + elif not ( + self.cis_links.get(connection_handle) + or self.sco_links.get(connection_handle) + ): logger.warning( 'received packet completion event for unknown handle ' f'0x{connection_handle:04X}' ) - continue - - connection.acl_packet_queue.on_packets_completed(num_completed_packets) # Classic only def on_hci_connection_request_event(self, event): diff --git a/bumble/profiles/bap.py b/bumble/profiles/bap.py index b54ad1dc..d3b40cbe 100644 --- a/bumble/profiles/bap.py +++ b/bumble/profiles/bap.py @@ -78,6 +78,10 @@ class AudioLocation(enum.IntFlag): LEFT_SURROUND = 0x04000000 RIGHT_SURROUND = 0x08000000 + @property + def num_channels(self) -> int: + return bin(self.value).count('1') + class AudioInputType(enum.IntEnum): '''Bluetooth Assigned Numbers, Section 6.12.2 - Audio Input Type''' @@ -218,6 +222,13 @@ class FrameDuration(enum.IntEnum): DURATION_7500_US = 0x00 DURATION_10000_US = 0x01 + @property + def us(self) -> int: + return { + FrameDuration.DURATION_7500_US: 7500, + FrameDuration.DURATION_10000_US: 10000, + }[self] + class SupportedFrameDuration(enum.IntFlag): '''Bluetooth Assigned Numbers, Section 6.12.4.2 - Frame Duration''' @@ -870,15 +881,22 @@ def on_cis_request( cig_id: int, cis_id: int, ) -> None: - if cis_id == self.cis_id and self.state == self.State.ENABLING: + if ( + cig_id == self.cig_id + and cis_id == self.cis_id + and self.state == self.State.ENABLING + ): acl_connection.abort_on( 'flush', self.service.device.accept_cis_request(cis_handle) ) def on_cis_establishment(self, cis_link: device.CisLink) -> None: - if cis_link.cis_id == self.cis_id and self.state == self.State.ENABLING: - self.state = self.State.STREAMING - self.cis_link = cis_link + if ( + cis_link.cig_id == self.cig_id + and cis_link.cis_id == self.cis_id + and self.state == self.State.ENABLING + ): + cis_link.on('disconnection', self.on_cis_disconnection) async def post_cis_established(): await self.service.device.send_command( @@ -891,9 +909,15 @@ async def post_cis_established(): codec_configuration=b'', ) ) + if self.role == AudioRole.SINK: + self.state = self.State.STREAMING await self.service.device.notify_subscribers(self, self.value) cis_link.acl_connection.abort_on('flush', post_cis_established()) + self.cis_link = cis_link + + def on_cis_disconnection(self, _reason) -> None: + self.cis_link = None def on_config_codec( self, @@ -991,11 +1015,17 @@ def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]: AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, AseReasonCode.NONE, ) - self.state = self.State.DISABLING + if self.role == AudioRole.SINK: + self.state = self.State.QOS_CONFIGURED + else: + self.state = self.State.DISABLING return (AseResponseCode.SUCCESS, AseReasonCode.NONE) def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]: - if self.state != AseStateMachine.State.DISABLING: + if ( + self.role != AudioRole.SOURCE + or self.state != AseStateMachine.State.DISABLING + ): return ( AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, AseReasonCode.NONE, @@ -1044,8 +1074,9 @@ def state(self) -> State: @state.setter def state(self, new_state: State) -> None: - logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}') + logger.info(f'{self} state change -> {colors.color(new_state.name, "cyan")}') self._state = new_state + self.emit('state_change', new_state) @property def value(self): @@ -1118,6 +1149,7 @@ class AudioStreamControlService(gatt.TemplateService): ase_state_machines: Dict[int, AseStateMachine] ase_control_point: gatt.Characteristic + _active_client: Optional[device.Connection] = None def __init__( self, @@ -1155,7 +1187,16 @@ def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args): else: return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE) + def _on_client_disconnected(self, _reason: int) -> None: + for ase in self.ase_state_machines.values(): + ase.state = AseStateMachine.State.IDLE + self._active_client = None + def on_write_ase_control_point(self, connection, data): + if not self._active_client and connection: + self._active_client = connection + connection.once('disconnection', self._on_client_disconnected) + operation = ASE_Operation.from_bytes(data) responses = [] logger.debug(f'*** ASCS Write {operation} ***') diff --git a/examples/run_unicast_server.py b/examples/run_unicast_server.py index 60d2f4ae..1e262905 100644 --- a/examples/run_unicast_server.py +++ b/examples/run_unicast_server.py @@ -16,20 +16,28 @@ # Imports # ----------------------------------------------------------------------------- import asyncio +import datetime +import functools import logging import sys import os +import io import struct import secrets + +from typing import Dict + from bumble.core import AdvertisingData -from bumble.device import Device, CisLink +from bumble.device import Device from bumble.hci import ( CodecID, CodingFormat, HCI_IsoDataPacket, ) from bumble.profiles.bap import ( + AseStateMachine, UnicastServerAdvertisingData, + CodecSpecificConfiguration, CodecSpecificCapabilities, ContextType, AudioLocation, @@ -45,6 +53,32 @@ from bumble.transport import open_transport_or_link +def _sink_pac_record() -> PacRecord: + return PacRecord( + coding_format=CodingFormat(CodecID.LC3), + codec_specific_capabilities=CodecSpecificCapabilities( + supported_sampling_frequencies=( + SupportedSamplingFrequency.FREQ_8000 + | SupportedSamplingFrequency.FREQ_16000 + | SupportedSamplingFrequency.FREQ_24000 + | SupportedSamplingFrequency.FREQ_32000 + | SupportedSamplingFrequency.FREQ_48000 + ), + supported_frame_durations=( + SupportedFrameDuration.DURATION_7500_US_SUPPORTED + | SupportedFrameDuration.DURATION_10000_US_SUPPORTED + ), + supported_audio_channel_counts=[1, 2], + min_octets_per_codec_frame=26, + max_octets_per_codec_frame=240, + supported_max_codec_frames_per_sdu=2, + ), + ) + + +file_outputs: Dict[AseStateMachine, io.BufferedWriter] = {} + + # ----------------------------------------------------------------------------- async def main() -> None: if len(sys.argv) < 3: @@ -71,49 +105,17 @@ async def main() -> None: PublishedAudioCapabilitiesService( supported_source_context=ContextType.PROHIBITED, available_source_context=ContextType.PROHIBITED, - supported_sink_context=ContextType.MEDIA, - available_sink_context=ContextType.MEDIA, + supported_sink_context=ContextType(0xFF), # All context types + available_sink_context=ContextType(0xFF), # All context types sink_audio_locations=( AudioLocation.FRONT_LEFT | AudioLocation.FRONT_RIGHT ), - sink_pac=[ - # Codec Capability Setting 16_2 - PacRecord( - coding_format=CodingFormat(CodecID.LC3), - codec_specific_capabilities=CodecSpecificCapabilities( - supported_sampling_frequencies=( - SupportedSamplingFrequency.FREQ_16000 - ), - supported_frame_durations=( - SupportedFrameDuration.DURATION_10000_US_SUPPORTED - ), - supported_audio_channel_counts=[1], - min_octets_per_codec_frame=40, - max_octets_per_codec_frame=40, - supported_max_codec_frames_per_sdu=1, - ), - ), - # Codec Capability Setting 24_2 - PacRecord( - coding_format=CodingFormat(CodecID.LC3), - codec_specific_capabilities=CodecSpecificCapabilities( - supported_sampling_frequencies=( - SupportedSamplingFrequency.FREQ_48000 - ), - supported_frame_durations=( - SupportedFrameDuration.DURATION_10000_US_SUPPORTED - ), - supported_audio_channel_counts=[1], - min_octets_per_codec_frame=120, - max_octets_per_codec_frame=120, - supported_max_codec_frames_per_sdu=1, - ), - ), - ], + sink_pac=[_sink_pac_record()], ) ) - device.add_service(AudioStreamControlService(device, sink_ase_id=[1, 2])) + ascs = AudioStreamControlService(device, sink_ase_id=[1], source_ase_id=[2]) + device.add_service(ascs) advertising_data = ( bytes( @@ -143,44 +145,57 @@ async def main() -> None: + csis.get_advertising_data() + bytes(UnicastServerAdvertisingData()) ) - subprocess = await asyncio.create_subprocess_shell( - f'dlc3 | ffplay pipe:0', - stdin=asyncio.subprocess.PIPE, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - - stdin = subprocess.stdin - assert stdin - - # Write a fake LC3 header to dlc3. - stdin.write( - bytes([0x1C, 0xCC]) # Header. - + struct.pack( - ' None: + if state != AseStateMachine.State.STREAMING: + if file_output := file_outputs.pop(ase): + file_output.close() + else: + file_output = open(f'{datetime.datetime.now().isoformat()}.lc3', 'wb') + codec_configuration = ase.codec_specific_configuration + assert isinstance(codec_configuration, CodecSpecificConfiguration) + # Write a LC3 header. + file_output.write( + bytes([0x1C, 0xCC]) # Header. + + struct.pack( + '= 1.4.3 types-invoke >= 1.7.3 types-protobuf >= 4.21.0 + wasmtime == 1.0.0 avatar = pandora-avatar == 0.0.9 rootcanal == 1.10.0 ; python_version>='3.10'