From aba1ac0cea08e0badc70a4e48637cbfeb4100c71 Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Mon, 7 Aug 2023 11:34:58 -0700 Subject: [PATCH] use a dict instead of a series of ifs (+6 squashed commits) Squashed commits: [90f2024] fix import order [0edd321] add a few docstrings [77a0ac0] wip [adcf159] wip [96cbd67] wip [d8bfbab] wip (+1 squashed commit) Squashed commits: [43b4d66] wip (+2 squashed commits) Squashed commits: [3dafaa8] wip [5844026] wip (+1 squashed commit) Squashed commits: [4cbb35a] wip (+1 squashed commit) Squashed commits: [4d2b6d3] wip (+4 squashed commits) Squashed commits: [f2da510] wip [318c119] wip [923b4eb] wip [9d46365] wip use a dict instead of a series of ifs (+6 squashed commits) Squashed commits: [90f2024] fix import order [0edd321] add a few docstrings [77a0ac0] wip [adcf159] wip [96cbd67] wip [d8bfbab] wip --- .vscode/settings.json | 2 + bumble/a2dp.py | 16 +- bumble/avc.py | 520 ++++++++++ bumble/avctp.py | 291 ++++++ bumble/avdtp.py | 8 +- bumble/avrcp.py | 1916 +++++++++++++++++++++++++++++++++++ bumble/core.py | 18 +- bumble/helpers.py | 79 +- bumble/sdp.py | 6 +- bumble/utils.py | 65 +- examples/avrcp_as_sink.html | 274 +++++ examples/run_avrcp.py | 408 ++++++++ tests/avrcp_test.py | 246 +++++ tests/utils_test.py | 36 +- 14 files changed, 3831 insertions(+), 54 deletions(-) create mode 100644 bumble/avc.py create mode 100644 bumble/avctp.py create mode 100644 bumble/avrcp.py create mode 100644 examples/avrcp_as_sink.html create mode 100644 examples/run_avrcp.py create mode 100644 tests/avrcp_test.py diff --git a/.vscode/settings.json b/.vscode/settings.json index 93e9ece3..054260ed 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -12,7 +12,9 @@ "ASHA", "asyncio", "ATRAC", + "avctp", "avdtp", + "avrcp", "bitpool", "bitstruct", "BSCP", diff --git a/bumble/a2dp.py b/bumble/a2dp.py index 6d8fc478..653a0426 100644 --- a/bumble/a2dp.py +++ b/bumble/a2dp.py @@ -184,8 +184,12 @@ def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3)) SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence( [ - DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE), - DataElement.unsigned_integer_16(version_int), + DataElement.sequence( + [ + DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE), + DataElement.unsigned_integer_16(version_int), + ] + ) ] ), ), @@ -234,8 +238,12 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)): SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence( [ - DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE), - DataElement.unsigned_integer_16(version_int), + DataElement.sequence( + [ + DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE), + DataElement.unsigned_integer_16(version_int), + ] + ) ] ), ), diff --git a/bumble/avc.py b/bumble/avc.py new file mode 100644 index 00000000..1d0a7dc9 --- /dev/null +++ b/bumble/avc.py @@ -0,0 +1,520 @@ +# Copyright 2021-2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations +import enum +import struct +from typing import Dict, Type, Union, Tuple + +from bumble.utils import OpenIntEnum + + +# ----------------------------------------------------------------------------- +class Frame: + class SubunitType(enum.IntEnum): + # AV/C Digital Interface Command Set General Specification Version 4.1 + # Table 7.4 + MONITOR = 0x00 + AUDIO = 0x01 + PRINTER = 0x02 + DISC = 0x03 + TAPE_RECORDER_OR_PLAYER = 0x04 + TUNER = 0x05 + CA = 0x06 + CAMERA = 0x07 + PANEL = 0x09 + BULLETIN_BOARD = 0x0A + VENDOR_UNIQUE = 0x1C + EXTENDED = 0x1E + UNIT = 0x1F + + class OperationCode(OpenIntEnum): + # 0x00 - 0x0F: Unit and subunit commands + VENDOR_DEPENDENT = 0x00 + RESERVE = 0x01 + PLUG_INFO = 0x02 + + # 0x10 - 0x3F: Unit commands + DIGITAL_OUTPUT = 0x10 + DIGITAL_INPUT = 0x11 + CHANNEL_USAGE = 0x12 + OUTPUT_PLUG_SIGNAL_FORMAT = 0x18 + INPUT_PLUG_SIGNAL_FORMAT = 0x19 + GENERAL_BUS_SETUP = 0x1F + CONNECT_AV = 0x20 + DISCONNECT_AV = 0x21 + CONNECTIONS = 0x22 + CONNECT = 0x24 + DISCONNECT = 0x25 + UNIT_INFO = 0x30 + SUBUNIT_INFO = 0x31 + + # 0x40 - 0x7F: Subunit commands + PASS_THROUGH = 0x7C + GUI_UPDATE = 0x7D + PUSH_GUI_DATA = 0x7E + USER_ACTION = 0x7F + + # 0xA0 - 0xBF: Unit and subunit commands + VERSION = 0xB0 + POWER = 0xB2 + + subunit_type: SubunitType + subunit_id: int + opcode: OperationCode + operands: bytes + + @staticmethod + def subclass(subclass): + # Infer the opcode from the class name + if subclass.__name__.endswith("CommandFrame"): + short_name = subclass.__name__.replace("CommandFrame", "") + category_class = CommandFrame + elif subclass.__name__.endswith("ResponseFrame"): + short_name = subclass.__name__.replace("ResponseFrame", "") + category_class = ResponseFrame + else: + raise ValueError(f"invalid subclass name {subclass.__name__}") + + uppercase_indexes = [ + i for i in range(len(short_name)) if short_name[i].isupper() + ] + uppercase_indexes.append(len(short_name)) + words = [ + short_name[uppercase_indexes[i] : uppercase_indexes[i + 1]].upper() + for i in range(len(uppercase_indexes) - 1) + ] + opcode_name = "_".join(words) + opcode = Frame.OperationCode[opcode_name] + category_class.subclasses[opcode] = subclass + return subclass + + @staticmethod + def from_bytes(data: bytes) -> Frame: + if data[0] >> 4 != 0: + raise ValueError("first 4 bits must be 0s") + + ctype_or_response = data[0] & 0xF + subunit_type = Frame.SubunitType(data[1] >> 3) + subunit_id = data[1] & 7 + + if subunit_type == Frame.SubunitType.EXTENDED: + # Not supported + raise NotImplementedError("extended subunit types not supported") + + if subunit_id < 5: + opcode_offset = 2 + elif subunit_id == 5: + # Extended to the next byte + extension = data[2] + if extension == 0: + raise ValueError("extended subunit ID value reserved") + if extension == 0xFF: + subunit_id = 5 + 254 + data[3] + opcode_offset = 4 + else: + subunit_id = 5 + extension + opcode_offset = 3 + + elif subunit_id == 6: + raise ValueError("reserved subunit ID") + + opcode = Frame.OperationCode(data[opcode_offset]) + operands = data[opcode_offset + 1 :] + + # Look for a registered subclass + if ctype_or_response < 8: + # Command + ctype = CommandFrame.CommandType(ctype_or_response) + if c_subclass := CommandFrame.subclasses.get(opcode): + return c_subclass( + ctype, + subunit_type, + subunit_id, + *c_subclass.parse_operands(operands), + ) + return CommandFrame(ctype, subunit_type, subunit_id, opcode, operands) + else: + # Response + response = ResponseFrame.ResponseCode(ctype_or_response) + if r_subclass := ResponseFrame.subclasses.get(opcode): + return r_subclass( + response, + subunit_type, + subunit_id, + *r_subclass.parse_operands(operands), + ) + return ResponseFrame(response, subunit_type, subunit_id, opcode, operands) + + def to_bytes( + self, + ctype_or_response: Union[CommandFrame.CommandType, ResponseFrame.ResponseCode], + ) -> bytes: + # TODO: support extended subunit types and ids. + return ( + bytes( + [ + ctype_or_response, + self.subunit_type << 3 | self.subunit_id, + self.opcode, + ] + ) + + self.operands + ) + + def to_string(self, extra: str) -> str: + return ( + f"{self.__class__.__name__}({extra}" + f"subunit_type={self.subunit_type.name}, " + f"subunit_id=0x{self.subunit_id:02X}, " + f"opcode={self.opcode.name}, " + f"operands={self.operands.hex()})" + ) + + def __init__( + self, + subunit_type: SubunitType, + subunit_id: int, + opcode: OperationCode, + operands: bytes, + ) -> None: + self.subunit_type = subunit_type + self.subunit_id = subunit_id + self.opcode = opcode + self.operands = operands + + +# ----------------------------------------------------------------------------- +class CommandFrame(Frame): + class CommandType(OpenIntEnum): + # AV/C Digital Interface Command Set General Specification Version 4.1 + # Table 7.1 + CONTROL = 0x00 + STATUS = 0x01 + SPECIFIC_INQUIRY = 0x02 + NOTIFY = 0x03 + GENERAL_INQUIRY = 0x04 + + subclasses: Dict[Frame.OperationCode, Type[CommandFrame]] = {} + ctype: CommandType + + @staticmethod + def parse_operands(operands: bytes) -> Tuple: + raise NotImplementedError + + def __init__( + self, + ctype: CommandType, + subunit_type: Frame.SubunitType, + subunit_id: int, + opcode: Frame.OperationCode, + operands: bytes, + ) -> None: + super().__init__(subunit_type, subunit_id, opcode, operands) + self.ctype = ctype + + def __bytes__(self): + return self.to_bytes(self.ctype) + + def __str__(self): + return self.to_string(f"ctype={self.ctype.name}, ") + + +# ----------------------------------------------------------------------------- +class ResponseFrame(Frame): + class ResponseCode(OpenIntEnum): + # AV/C Digital Interface Command Set General Specification Version 4.1 + # Table 7.2 + NOT_IMPLEMENTED = 0x08 + ACCEPTED = 0x09 + REJECTED = 0x0A + IN_TRANSITION = 0x0B + IMPLEMENTED_OR_STABLE = 0x0C + CHANGED = 0x0D + INTERIM = 0x0F + + subclasses: Dict[Frame.OperationCode, Type[ResponseFrame]] = {} + response: ResponseCode + + @staticmethod + def parse_operands(operands: bytes) -> Tuple: + raise NotImplementedError + + def __init__( + self, + response: ResponseCode, + subunit_type: Frame.SubunitType, + subunit_id: int, + opcode: Frame.OperationCode, + operands: bytes, + ) -> None: + super().__init__(subunit_type, subunit_id, opcode, operands) + self.response = response + + def __bytes__(self): + return self.to_bytes(self.response) + + def __str__(self): + return self.to_string(f"response={self.response.name}, ") + + +# ----------------------------------------------------------------------------- +class VendorDependentFrame: + company_id: int + vendor_dependent_data: bytes + + @staticmethod + def parse_operands(operands: bytes) -> Tuple: + return ( + struct.unpack(">I", b"\x00" + operands[:3])[0], + operands[3:], + ) + + def make_operands(self) -> bytes: + return struct.pack(">I", self.company_id)[1:] + self.vendor_dependent_data + + def __init__(self, company_id: int, vendor_dependent_data: bytes): + self.company_id = company_id + self.vendor_dependent_data = vendor_dependent_data + + +# ----------------------------------------------------------------------------- +@Frame.subclass +class VendorDependentCommandFrame(VendorDependentFrame, CommandFrame): + def __init__( + self, + ctype: CommandFrame.CommandType, + subunit_type: Frame.SubunitType, + subunit_id: int, + company_id: int, + vendor_dependent_data: bytes, + ) -> None: + VendorDependentFrame.__init__(self, company_id, vendor_dependent_data) + CommandFrame.__init__( + self, + ctype, + subunit_type, + subunit_id, + Frame.OperationCode.VENDOR_DEPENDENT, + self.make_operands(), + ) + + def __str__(self): + return ( + f"VendorDependentCommandFrame(ctype={self.ctype.name}, " + f"subunit_type={self.subunit_type.name}, " + f"subunit_id=0x{self.subunit_id:02X}, " + f"company_id=0x{self.company_id:06X}, " + f"vendor_dependent_data={self.vendor_dependent_data.hex()})" + ) + + +# ----------------------------------------------------------------------------- +@Frame.subclass +class VendorDependentResponseFrame(VendorDependentFrame, ResponseFrame): + def __init__( + self, + response: ResponseFrame.ResponseCode, + subunit_type: Frame.SubunitType, + subunit_id: int, + company_id: int, + vendor_dependent_data: bytes, + ) -> None: + VendorDependentFrame.__init__(self, company_id, vendor_dependent_data) + ResponseFrame.__init__( + self, + response, + subunit_type, + subunit_id, + Frame.OperationCode.VENDOR_DEPENDENT, + self.make_operands(), + ) + + def __str__(self): + return ( + f"VendorDependentResponseFrame(response={self.response.name}, " + f"subunit_type={self.subunit_type.name}, " + f"subunit_id=0x{self.subunit_id:02X}, " + f"company_id=0x{self.company_id:06X}, " + f"vendor_dependent_data={self.vendor_dependent_data.hex()})" + ) + + +# ----------------------------------------------------------------------------- +class PassThroughFrame: + """ + See AV/C Panel Subunit Specification 1.1 - 9.4 PASS THROUGH control command + """ + + class StateFlag(enum.IntEnum): + PRESSED = 0 + RELEASED = 1 + + class OperationId(OpenIntEnum): + SELECT = 0x00 + UP = 0x01 + DOWN = 0x01 + LEFT = 0x03 + RIGHT = 0x04 + RIGHT_UP = 0x05 + RIGHT_DOWN = 0x06 + LEFT_UP = 0x07 + LEFT_DOWN = 0x08 + ROOT_MENU = 0x09 + SETUP_MENU = 0x0A + CONTENTS_MENU = 0x0B + FAVORITE_MENU = 0x0C + EXIT = 0x0D + NUMBER_0 = 0x20 + NUMBER_1 = 0x21 + NUMBER_2 = 0x22 + NUMBER_3 = 0x23 + NUMBER_4 = 0x24 + NUMBER_5 = 0x25 + NUMBER_6 = 0x26 + NUMBER_7 = 0x27 + NUMBER_8 = 0x28 + NUMBER_9 = 0x29 + DOT = 0x2A + ENTER = 0x2B + CLEAR = 0x2C + CHANNEL_UP = 0x30 + CHANNEL_DOWN = 0x31 + PREVIOUS_CHANNEL = 0x32 + SOUND_SELECT = 0x33 + INPUT_SELECT = 0x34 + DISPLAY_INFORMATION = 0x35 + HELP = 0x36 + PAGE_UP = 0x37 + PAGE_DOWN = 0x38 + POWER = 0x40 + VOLUME_UP = 0x41 + VOLUME_DOWN = 0x42 + MUTE = 0x43 + PLAY = 0x44 + STOP = 0x45 + PAUSE = 0x46 + RECORD = 0x47 + REWIND = 0x48 + FAST_FORWARD = 0x49 + EJECT = 0x4A + FORWARD = 0x4B + BACKWARD = 0x4C + ANGLE = 0x50 + SUBPICTURE = 0x51 + F1 = 0x71 + F2 = 0x72 + F3 = 0x73 + F4 = 0x74 + F5 = 0x75 + VENDOR_UNIQUE = 0x7E + + state_flag: StateFlag + operation_id: OperationId + operation_data: bytes + + @staticmethod + def parse_operands(operands: bytes) -> Tuple: + return ( + PassThroughFrame.StateFlag(operands[0] >> 7), + PassThroughFrame.OperationId(operands[0] & 0x7F), + operands[1 : 1 + operands[1]], + ) + + def make_operands(self): + return ( + bytes([self.state_flag << 7 | self.operation_id, len(self.operation_data)]) + + self.operation_data + ) + + def __init__( + self, + state_flag: StateFlag, + operation_id: OperationId, + operation_data: bytes, + ) -> None: + if len(operation_data) > 255: + raise ValueError("operation data must be <= 255 bytes") + self.state_flag = state_flag + self.operation_id = operation_id + self.operation_data = operation_data + + +# ----------------------------------------------------------------------------- +@Frame.subclass +class PassThroughCommandFrame(PassThroughFrame, CommandFrame): + def __init__( + self, + ctype: CommandFrame.CommandType, + subunit_type: Frame.SubunitType, + subunit_id: int, + state_flag: PassThroughFrame.StateFlag, + operation_id: PassThroughFrame.OperationId, + operation_data: bytes, + ) -> None: + PassThroughFrame.__init__(self, state_flag, operation_id, operation_data) + CommandFrame.__init__( + self, + ctype, + subunit_type, + subunit_id, + Frame.OperationCode.PASS_THROUGH, + self.make_operands(), + ) + + def __str__(self): + return ( + f"PassThroughCommandFrame(ctype={self.ctype.name}, " + f"subunit_type={self.subunit_type.name}, " + f"subunit_id=0x{self.subunit_id:02X}, " + f"state_flag={self.state_flag.name}, " + f"operation_id={self.operation_id.name}, " + f"operation_data={self.operation_data.hex()})" + ) + + +# ----------------------------------------------------------------------------- +@Frame.subclass +class PassThroughResponseFrame(PassThroughFrame, ResponseFrame): + def __init__( + self, + response: ResponseFrame.ResponseCode, + subunit_type: Frame.SubunitType, + subunit_id: int, + state_flag: PassThroughFrame.StateFlag, + operation_id: PassThroughFrame.OperationId, + operation_data: bytes, + ) -> None: + PassThroughFrame.__init__(self, state_flag, operation_id, operation_data) + ResponseFrame.__init__( + self, + response, + subunit_type, + subunit_id, + Frame.OperationCode.PASS_THROUGH, + self.make_operands(), + ) + + def __str__(self): + return ( + f"PassThroughResponseFrame(response={self.response.name}, " + f"subunit_type={self.subunit_type.name}, " + f"subunit_id=0x{self.subunit_id:02X}, " + f"state_flag={self.state_flag.name}, " + f"operation_id={self.operation_id.name}, " + f"operation_data={self.operation_data.hex()})" + ) diff --git a/bumble/avctp.py b/bumble/avctp.py new file mode 100644 index 00000000..22713249 --- /dev/null +++ b/bumble/avctp.py @@ -0,0 +1,291 @@ +# Copyright 2021-2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations +from enum import IntEnum +import logging +import struct +from typing import Callable, cast, Dict, Optional + +from bumble.colors import color +from bumble import avc +from bumble import l2cap + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- +AVCTP_PSM = 0x0017 +AVCTP_BROWSING_PSM = 0x001B + + +# ----------------------------------------------------------------------------- +class MessageAssembler: + Callback = Callable[[int, bool, bool, int, bytes], None] + + transaction_label: int + pid: int + c_r: int + ipid: int + payload: bytes + number_of_packets: int + packets_received: int + + def __init__(self, callback: Callback) -> None: + self.callback = callback + self.reset() + + def reset(self) -> None: + self.packets_received = 0 + self.transaction_label = -1 + self.pid = -1 + self.c_r = -1 + self.ipid = -1 + self.payload = b'' + self.number_of_packets = 0 + self.packet_count = 0 + + def on_pdu(self, pdu: bytes) -> None: + self.packets_received += 1 + + transaction_label = pdu[0] >> 4 + packet_type = Protocol.PacketType((pdu[0] >> 2) & 3) + c_r = (pdu[0] >> 1) & 1 + ipid = pdu[0] & 1 + + if c_r == 0 and ipid != 0: + logger.warning("invalid IPID in command frame") + self.reset() + return + + pid_offset = 1 + if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.START): + if self.transaction_label >= 0: + # We are already in a transaction + logger.warning("received START or SINGLE fragment while in transaction") + self.reset() + self.packets_received = 1 + + if packet_type == Protocol.PacketType.START: + self.number_of_packets = pdu[1] + pid_offset = 2 + + pid = struct.unpack_from(">H", pdu, pid_offset)[0] + self.payload += pdu[pid_offset + 2 :] + + if packet_type in (Protocol.PacketType.CONTINUE, Protocol.PacketType.END): + if transaction_label != self.transaction_label: + logger.warning("transaction label does not match") + self.reset() + return + + if pid != self.pid: + logger.warning("PID does not match") + self.reset() + return + + if c_r != self.c_r: + logger.warning("C/R does not match") + self.reset() + return + + if self.packets_received > self.number_of_packets: + logger.warning("too many fragments in transaction") + self.reset() + return + + if packet_type == Protocol.PacketType.END: + if self.packets_received != self.number_of_packets: + logger.warning("premature END") + self.reset() + return + else: + self.transaction_label = transaction_label + self.c_r = c_r + self.ipid = ipid + self.pid = pid + + if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.END): + self.on_message_complete() + + def on_message_complete(self): + try: + self.callback( + self.transaction_label, + self.c_r == 0, + self.ipid != 0, + self.pid, + self.payload, + ) + except Exception as error: + logger.exception(color(f"!!! exception in callback: {error}", "red")) + + self.reset() + + +# ----------------------------------------------------------------------------- +class Protocol: + CommandHandler = Callable[[int, avc.CommandFrame], None] + command_handlers: Dict[int, CommandHandler] # Command handlers, by PID + ResponseHandler = Callable[[int, Optional[avc.ResponseFrame]], None] + response_handlers: Dict[int, ResponseHandler] # Response handlers, by PID + next_transaction_label: int + message_assembler: MessageAssembler + + class PacketType(IntEnum): + SINGLE = 0b00 + START = 0b01 + CONTINUE = 0b10 + END = 0b11 + + def __init__(self, l2cap_channel: l2cap.ClassicChannel) -> None: + self.command_handlers = {} + self.response_handlers = {} + self.l2cap_channel = l2cap_channel + self.message_assembler = MessageAssembler(self.on_message) + + # Register to receive PDUs from the channel + l2cap_channel.sink = self.on_pdu + l2cap_channel.on("open", self.on_l2cap_channel_open) + l2cap_channel.on("close", self.on_l2cap_channel_close) + + def on_l2cap_channel_open(self): + logger.debug(color("<<< AVCTP channel open", "magenta")) + + def on_l2cap_channel_close(self): + logger.debug(color("<<< AVCTP channel closed", "magenta")) + + def on_pdu(self, pdu: bytes) -> None: + self.message_assembler.on_pdu(pdu) + + def on_message( + self, + transaction_label: int, + is_command: bool, + ipid: bool, + pid: int, + payload: bytes, + ) -> None: + logger.debug( + f"<<< AVCTP Message: pid={pid}, " + f"transaction_label={transaction_label}, " + f"is_command={is_command}, " + f"ipid={ipid}, " + f"payload={payload.hex()}" + ) + + # Check for invalid PID responses. + if ipid: + logger.debug(f"received IPID for PID={pid}") + + # Find the appropriate handler. + if is_command: + if pid not in self.command_handlers: + logger.warning(f"no command handler for PID {pid}") + self.send_ipid(transaction_label, pid) + return + + command_frame = cast(avc.CommandFrame, avc.Frame.from_bytes(payload)) + self.command_handlers[pid](transaction_label, command_frame) + else: + if pid not in self.response_handlers: + logger.warning(f"no response handler for PID {pid}") + return + + # By convention, for an ipid, send a None payload to the response handler. + if ipid: + response_frame = None + else: + response_frame = cast(avc.ResponseFrame, avc.Frame.from_bytes(payload)) + + self.response_handlers[pid](transaction_label, response_frame) + + def send_message( + self, + transaction_label: int, + is_command: bool, + ipid: bool, + pid: int, + payload: bytes, + ): + # TODO: fragment large messages + packet_type = Protocol.PacketType.SINGLE + pdu = ( + struct.pack( + ">BH", + transaction_label << 4 + | packet_type << 2 + | (0 if is_command else 1) << 1 + | (1 if ipid else 0), + pid, + ) + + payload + ) + self.l2cap_channel.send_pdu(pdu) + + def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None: + logger.debug( + ">>> AVCTP command: " + f"transaction_label={transaction_label}, " + f"pid={pid}, " + f"payload={payload.hex()}" + ) + self.send_message(transaction_label, True, False, pid, payload) + + def send_response(self, transaction_label: int, pid: int, payload: bytes): + logger.debug( + ">>> AVCTP response: " + f"transaction_label={transaction_label}, " + f"pid={pid}, " + f"payload={payload.hex()}" + ) + self.send_message(transaction_label, False, False, pid, payload) + + def send_ipid(self, transaction_label: int, pid: int) -> None: + logger.debug( + ">>> AVCTP ipid: " f"transaction_label={transaction_label}, " f"pid={pid}" + ) + self.send_message(transaction_label, False, True, pid, b'') + + def register_command_handler( + self, pid: int, handler: Protocol.CommandHandler + ) -> None: + self.command_handlers[pid] = handler + + def unregister_command_handler( + self, pid: int, handler: Protocol.CommandHandler + ) -> None: + if pid not in self.command_handlers or self.command_handlers[pid] != handler: + raise ValueError("command handler not registered") + del self.command_handlers[pid] + + def register_response_handler( + self, pid: int, handler: Protocol.ResponseHandler + ) -> None: + self.response_handlers[pid] = handler + + def unregister_response_handler( + self, pid: int, handler: Protocol.ResponseHandler + ) -> None: + if pid not in self.response_handlers or self.response_handlers[pid] != handler: + raise ValueError("response handler not registered") + del self.response_handlers[pid] diff --git a/bumble/avdtp.py b/bumble/avdtp.py index 103597f9..3be1e157 100644 --- a/bumble/avdtp.py +++ b/bumble/avdtp.py @@ -241,7 +241,10 @@ async def find_avdtp_service_with_sdp_client( ) if profile_descriptor_list: for profile_descriptor in profile_descriptor_list.value: - if len(profile_descriptor.value) >= 2: + if ( + profile_descriptor.type == sdp.DataElement.SEQUENCE + and len(profile_descriptor.value) >= 2 + ): avdtp_version_major = profile_descriptor.value[1].value >> 8 avdtp_version_minor = profile_descriptor.value[1].value & 0xFF return (avdtp_version_major, avdtp_version_minor) @@ -511,7 +514,8 @@ def on_message_complete(self) -> None: try: self.callback(self.transaction_label, message) except Exception as error: - logger.warning(color(f'!!! exception in callback: {error}')) + logger.exception(color(f'!!! exception in callback: {error}', 'red')) + self.reset() diff --git a/bumble/avrcp.py b/bumble/avrcp.py new file mode 100644 index 00000000..aef6dd55 --- /dev/null +++ b/bumble/avrcp.py @@ -0,0 +1,1916 @@ +# Copyright 2021-2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations +import asyncio +from dataclasses import dataclass +import enum +import logging +import struct +from typing import ( + AsyncIterator, + Awaitable, + Callable, + cast, + Dict, + Iterable, + List, + Optional, + Sequence, + SupportsBytes, + Tuple, + Type, + TypeVar, + Union, +) + +import pyee + +from bumble.colors import color +from bumble.device import Device, Connection +from bumble.sdp import ( + SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, + SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, + SDP_PUBLIC_BROWSE_ROOT, + SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, + SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, + SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, + SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID, + DataElement, + ServiceAttribute, +) +from bumble.utils import AsyncRunner, OpenIntEnum +from bumble.core import ( + ProtocolError, + BT_L2CAP_PROTOCOL_ID, + BT_AVCTP_PROTOCOL_ID, + BT_AV_REMOTE_CONTROL_SERVICE, + BT_AV_REMOTE_CONTROL_CONTROLLER_SERVICE, + BT_AV_REMOTE_CONTROL_TARGET_SERVICE, +) +from bumble import l2cap +from bumble import avc +from bumble import avctp +from bumble import utils + + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- +AVRCP_PID = 0x110E +AVRCP_BLUETOOTH_SIG_COMPANY_ID = 0x001958 + + +# ----------------------------------------------------------------------------- +def make_controller_service_sdp_records( + service_record_handle: int, + avctp_version: Tuple[int, int] = (1, 4), + avrcp_version: Tuple[int, int] = (1, 6), + supported_features: int = 1, +) -> List[ServiceAttribute]: + # TODO: support a way to compute the supported features from a feature list + avctp_version_int = avctp_version[0] << 8 | avctp_version[1] + avrcp_version_int = avrcp_version[0] << 8 | avrcp_version[1] + + return [ + ServiceAttribute( + SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, + DataElement.unsigned_integer_32(service_record_handle), + ), + ServiceAttribute( + SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, + DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]), + ), + ServiceAttribute( + SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.uuid(BT_AV_REMOTE_CONTROL_SERVICE), + DataElement.uuid(BT_AV_REMOTE_CONTROL_CONTROLLER_SERVICE), + ] + ), + ), + ServiceAttribute( + SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.sequence( + [ + DataElement.uuid(BT_L2CAP_PROTOCOL_ID), + DataElement.unsigned_integer_16(avctp.AVCTP_PSM), + ] + ), + DataElement.sequence( + [ + DataElement.uuid(BT_AVCTP_PROTOCOL_ID), + DataElement.unsigned_integer_16(avctp_version_int), + ] + ), + ] + ), + ), + ServiceAttribute( + SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.sequence( + [ + DataElement.uuid(BT_AV_REMOTE_CONTROL_SERVICE), + DataElement.unsigned_integer_16(avrcp_version_int), + ] + ), + ] + ), + ), + ServiceAttribute( + SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID, + DataElement.unsigned_integer_16(supported_features), + ), + ] + + +# ----------------------------------------------------------------------------- +def make_target_service_sdp_records( + service_record_handle: int, + avctp_version: Tuple[int, int] = (1, 4), + avrcp_version: Tuple[int, int] = (1, 6), + supported_features: int = 0x23, +) -> List[ServiceAttribute]: + # TODO: support a way to compute the supported features from a feature list + avctp_version_int = avctp_version[0] << 8 | avctp_version[1] + avrcp_version_int = avrcp_version[0] << 8 | avrcp_version[1] + + return [ + ServiceAttribute( + SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, + DataElement.unsigned_integer_32(service_record_handle), + ), + ServiceAttribute( + SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, + DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]), + ), + ServiceAttribute( + SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.uuid(BT_AV_REMOTE_CONTROL_TARGET_SERVICE), + ] + ), + ), + ServiceAttribute( + SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.sequence( + [ + DataElement.uuid(BT_L2CAP_PROTOCOL_ID), + DataElement.unsigned_integer_16(avctp.AVCTP_PSM), + ] + ), + DataElement.sequence( + [ + DataElement.uuid(BT_AVCTP_PROTOCOL_ID), + DataElement.unsigned_integer_16(avctp_version_int), + ] + ), + ] + ), + ), + ServiceAttribute( + SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.sequence( + [ + DataElement.uuid(BT_AV_REMOTE_CONTROL_SERVICE), + DataElement.unsigned_integer_16(avrcp_version_int), + ] + ), + ] + ), + ), + ServiceAttribute( + SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID, + DataElement.unsigned_integer_16(supported_features), + ), + ] + + +# ----------------------------------------------------------------------------- +def _decode_attribute_value(value: bytes, character_set: CharacterSetId) -> str: + try: + if character_set == CharacterSetId.UTF_8: + return value.decode("utf-8") + return value.decode("ascii") + except UnicodeDecodeError: + logger.warning(f"cannot decode string with bytes: {value.hex()}") + return "" + + +# ----------------------------------------------------------------------------- +class PduAssembler: + """ + PDU Assembler to support fragmented PDUs are defined in: + Audio/Video Remote Control / Profile Specification + 6.3.1 AVRCP specific AV//C commands + """ + + pdu_id: Optional[Protocol.PduId] + payload: bytes + + def __init__(self, callback: Callable[[Protocol.PduId, bytes], None]) -> None: + self.callback = callback + self.reset() + + def reset(self) -> None: + self.pdu_id = None + self.parameter = b'' + + def on_pdu(self, pdu: bytes) -> None: + pdu_id = Protocol.PduId(pdu[0]) + packet_type = Protocol.PacketType(pdu[1] & 3) + parameter_length = struct.unpack_from('>H', pdu, 2)[0] + parameter = pdu[4 : 4 + parameter_length] + if len(parameter) != parameter_length: + logger.warning("parameter length exceeds pdu size") + self.reset() + return + + if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.START): + if self.pdu_id is not None: + # We are already in a PDU + logger.warning("received START or SINGLE fragment while in pdu") + self.reset() + + if packet_type in (Protocol.PacketType.CONTINUE, Protocol.PacketType.END): + if pdu_id != self.pdu_id: + logger.warning("PID does not match") + self.reset() + return + else: + self.pdu_id = pdu_id + + self.parameter += parameter + + if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.END): + self.on_pdu_complete() + + def on_pdu_complete(self) -> None: + assert self.pdu_id is not None + try: + self.callback(self.pdu_id, self.parameter) + except Exception as error: + logger.exception(color(f'!!! exception in callback: {error}', 'red')) + + self.reset() + + +# ----------------------------------------------------------------------------- +@dataclass +class Command: + pdu_id: Protocol.PduId + parameter: bytes + + def to_string(self, properties: Dict[str, str]) -> str: + properties_str = ",".join( + [f"{name}={value}" for name, value in properties.items()] + ) + return f"Command[{self.pdu_id.name}]({properties_str})" + + def __str__(self) -> str: + return self.to_string({"parameters": self.parameter.hex()}) + + def __repr__(self) -> str: + return str(self) + + +# ----------------------------------------------------------------------------- +class GetCapabilitiesCommand(Command): + class CapabilityId(OpenIntEnum): + COMPANY_ID = 0x02 + EVENTS_SUPPORTED = 0x03 + + capability_id: CapabilityId + + @classmethod + def from_bytes(cls, pdu: bytes) -> GetCapabilitiesCommand: + return cls(cls.CapabilityId(pdu[0])) + + def __init__(self, capability_id: CapabilityId) -> None: + super().__init__(Protocol.PduId.GET_CAPABILITIES, bytes([capability_id])) + self.capability_id = capability_id + + def __str__(self) -> str: + return self.to_string({"capability_id": self.capability_id.name}) + + +# ----------------------------------------------------------------------------- +class GetPlayStatusCommand(Command): + @classmethod + def from_bytes(cls, _: bytes) -> GetPlayStatusCommand: + return cls() + + def __init__(self) -> None: + super().__init__(Protocol.PduId.GET_PLAY_STATUS, b'') + + +# ----------------------------------------------------------------------------- +class GetElementAttributesCommand(Command): + identifier: int + attribute_ids: List[MediaAttributeId] + + @classmethod + def from_bytes(cls, pdu: bytes) -> GetElementAttributesCommand: + identifier = struct.unpack_from(">Q", pdu)[0] + num_attributes = pdu[8] + attribute_ids = [MediaAttributeId(pdu[9 + i]) for i in range(num_attributes)] + return cls(identifier, attribute_ids) + + def __init__( + self, identifier: int, attribute_ids: Sequence[MediaAttributeId] + ) -> None: + parameter = struct.pack(">QB", identifier, len(attribute_ids)) + b''.join( + [struct.pack(">I", int(attribute_id)) for attribute_id in attribute_ids] + ) + super().__init__(Protocol.PduId.GET_ELEMENT_ATTRIBUTES, parameter) + self.identifier = identifier + self.attribute_ids = list(attribute_ids) + + +# ----------------------------------------------------------------------------- +class SetAbsoluteVolumeCommand(Command): + MAXIMUM_VOLUME = 0x7F + + volume: int + + @classmethod + def from_bytes(cls, pdu: bytes) -> SetAbsoluteVolumeCommand: + return cls(pdu[0]) + + def __init__(self, volume: int) -> None: + super().__init__(Protocol.PduId.SET_ABSOLUTE_VOLUME, bytes([volume])) + self.volume = volume + + def __str__(self) -> str: + return self.to_string({"volume": str(self.volume)}) + + +# ----------------------------------------------------------------------------- +class RegisterNotificationCommand(Command): + event_id: EventId + playback_interval: int + + @classmethod + def from_bytes(cls, pdu: bytes) -> RegisterNotificationCommand: + event_id = EventId(pdu[0]) + playback_interval = struct.unpack_from(">I", pdu, 1)[0] + return cls(event_id, playback_interval) + + def __init__(self, event_id: EventId, playback_interval: int) -> None: + super().__init__( + Protocol.PduId.REGISTER_NOTIFICATION, + struct.pack(">BI", int(event_id), playback_interval), + ) + self.event_id = event_id + self.playback_interval = playback_interval + + def __str__(self) -> str: + return self.to_string( + { + "event_id": self.event_id.name, + "playback_interval": str(self.playback_interval), + } + ) + + +# ----------------------------------------------------------------------------- +@dataclass +class Response: + pdu_id: Protocol.PduId + parameter: bytes + + def to_string(self, properties: Dict[str, str]) -> str: + properties_str = ",".join( + [f"{name}={value}" for name, value in properties.items()] + ) + return f"Response[{self.pdu_id.name}]({properties_str})" + + def __str__(self) -> str: + return self.to_string({"parameter": self.parameter.hex()}) + + def __repr__(self) -> str: + return str(self) + + +# ----------------------------------------------------------------------------- +class RejectedResponse(Response): + status_code: Protocol.StatusCode + + @classmethod + def from_bytes(cls, pdu_id: Protocol.PduId, pdu: bytes) -> RejectedResponse: + return cls(pdu_id, Protocol.StatusCode(pdu[0])) + + def __init__( + self, pdu_id: Protocol.PduId, status_code: Protocol.StatusCode + ) -> None: + super().__init__(pdu_id, bytes([int(status_code)])) + self.status_code = status_code + + def __str__(self) -> str: + return self.to_string( + { + "status_code": self.status_code.name, + } + ) + + +# ----------------------------------------------------------------------------- +class NotImplementedResponse(Response): + @classmethod + def from_bytes(cls, pdu_id: Protocol.PduId, pdu: bytes) -> NotImplementedResponse: + return cls(pdu_id, pdu[1:]) + + +# ----------------------------------------------------------------------------- +class GetCapabilitiesResponse(Response): + capability_id: GetCapabilitiesCommand.CapabilityId + capabilities: List[SupportsBytes] + + @classmethod + def from_bytes(cls, pdu: bytes) -> GetCapabilitiesResponse: + if len(pdu) < 2: + # Possibly a reject response. + return cls(GetCapabilitiesCommand.CapabilityId(0), []) + + # Assume that the payloads all follow the same pattern: + # + capability_id = GetCapabilitiesCommand.CapabilityId(pdu[0]) + capability_count = pdu[1] + + capabilities: List[SupportsBytes] + if capability_id == GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED: + capabilities = [EventId(pdu[2 + x]) for x in range(capability_count)] + else: + capability_size = (len(pdu) - 2) // capability_count + capabilities = [ + pdu[x : x + capability_size] + for x in range(2, len(pdu), capability_size) + ] + + return cls(capability_id, capabilities) + + def __init__( + self, + capability_id: GetCapabilitiesCommand.CapabilityId, + capabilities: Sequence[SupportsBytes], + ) -> None: + super().__init__( + Protocol.PduId.GET_CAPABILITIES, + bytes([capability_id, len(capabilities)]) + + b''.join(bytes(capability) for capability in capabilities), + ) + self.capability_id = capability_id + self.capabilities = list(capabilities) + + def __str__(self) -> str: + return self.to_string( + { + "capability_id": self.capability_id.name, + "capabilities": str(self.capabilities), + } + ) + + +# ----------------------------------------------------------------------------- +class GetPlayStatusResponse(Response): + song_length: int + song_position: int + play_status: PlayStatus + + @classmethod + def from_bytes(cls, pdu: bytes) -> GetPlayStatusResponse: + (song_length, song_position) = struct.unpack_from(">II", pdu, 0) + play_status = PlayStatus(pdu[8]) + + return cls(song_length, song_position, play_status) + + def __init__( + self, + song_length: int, + song_position: int, + play_status: PlayStatus, + ) -> None: + super().__init__( + Protocol.PduId.GET_PLAY_STATUS, + struct.pack(">IIB", song_length, song_position, int(play_status)), + ) + self.song_length = song_length + self.song_position = song_position + self.play_status = play_status + + def __str__(self) -> str: + return self.to_string( + { + "song_length": str(self.song_length), + "song_position": str(self.song_position), + "play_status": self.play_status.name, + } + ) + + +# ----------------------------------------------------------------------------- +class GetElementAttributesResponse(Response): + attributes: List[MediaAttribute] + + @classmethod + def from_bytes(cls, pdu: bytes) -> GetElementAttributesResponse: + num_attributes = pdu[0] + offset = 1 + attributes: List[MediaAttribute] = [] + for _ in range(num_attributes): + ( + attribute_id_int, + character_set_id_int, + attribute_value_length, + ) = struct.unpack_from(">IHH", pdu, offset) + attribute_value_bytes = pdu[ + offset + 8 : offset + 8 + attribute_value_length + ] + attribute_id = MediaAttributeId(attribute_id_int) + character_set_id = CharacterSetId(character_set_id_int) + attribute_value = _decode_attribute_value( + attribute_value_bytes, character_set_id + ) + attributes.append( + MediaAttribute(attribute_id, character_set_id, attribute_value) + ) + offset += 8 + attribute_value_length + + return cls(attributes) + + def __init__(self, attributes: Sequence[MediaAttribute]) -> None: + parameter = bytes([len(attributes)]) + for attribute in attributes: + attribute_value_bytes = attribute.attribute_value.encode("utf-8") + parameter += ( + struct.pack( + ">IHH", + int(attribute.attribute_id), + int(CharacterSetId.UTF_8), + len(attribute_value_bytes), + ) + + attribute_value_bytes + ) + super().__init__( + Protocol.PduId.GET_ELEMENT_ATTRIBUTES, + parameter, + ) + self.attributes = list(attributes) + + def __str__(self) -> str: + attribute_strs = [str(attribute) for attribute in self.attributes] + return self.to_string( + { + "attributes": f"[{', '.join(attribute_strs)}]", + } + ) + + +# ----------------------------------------------------------------------------- +class SetAbsoluteVolumeResponse(Response): + volume: int + + @classmethod + def from_bytes(cls, pdu: bytes) -> SetAbsoluteVolumeResponse: + return cls(pdu[0]) + + def __init__(self, volume: int) -> None: + super().__init__(Protocol.PduId.SET_ABSOLUTE_VOLUME, bytes([volume])) + self.volume = volume + + def __str__(self) -> str: + return self.to_string({"volume": str(self.volume)}) + + +# ----------------------------------------------------------------------------- +class RegisterNotificationResponse(Response): + event: Event + + @classmethod + def from_bytes(cls, pdu: bytes) -> RegisterNotificationResponse: + return cls(Event.from_bytes(pdu)) + + def __init__(self, event: Event) -> None: + super().__init__( + Protocol.PduId.REGISTER_NOTIFICATION, + bytes(event), + ) + self.event = event + + def __str__(self) -> str: + return self.to_string( + { + "event": str(self.event), + } + ) + + +# ----------------------------------------------------------------------------- +class EventId(OpenIntEnum): + PLAYBACK_STATUS_CHANGED = 0x01 + TRACK_CHANGED = 0x02 + TRACK_REACHED_END = 0x03 + TRACK_REACHED_START = 0x04 + PLAYBACK_POS_CHANGED = 0x05 + BATT_STATUS_CHANGED = 0x06 + SYSTEM_STATUS_CHANGED = 0x07 + PLAYER_APPLICATION_SETTING_CHANGED = 0x08 + NOW_PLAYING_CONTENT_CHANGED = 0x09 + AVAILABLE_PLAYERS_CHANGED = 0x0A + ADDRESSED_PLAYER_CHANGED = 0x0B + UIDS_CHANGED = 0x0C + VOLUME_CHANGED = 0x0D + + def __bytes__(self) -> bytes: + return bytes([int(self)]) + + +# ----------------------------------------------------------------------------- +class CharacterSetId(OpenIntEnum): + UTF_8 = 0x06 + + +# ----------------------------------------------------------------------------- +class MediaAttributeId(OpenIntEnum): + TITLE = 0x01 + ARTIST_NAME = 0x02 + ALBUM_NAME = 0x03 + TRACK_NUMBER = 0x04 + TOTAL_NUMBER_OF_TRACKS = 0x05 + GENRE = 0x06 + PLAYING_TIME = 0x07 + DEFAULT_COVER_ART = 0x08 + + +# ----------------------------------------------------------------------------- +@dataclass +class MediaAttribute: + attribute_id: MediaAttributeId + character_set_id: CharacterSetId + attribute_value: str + + +# ----------------------------------------------------------------------------- +class PlayStatus(OpenIntEnum): + STOPPED = 0x00 + PLAYING = 0x01 + PAUSED = 0x02 + FWD_SEEK = 0x03 + REV_SEEK = 0x04 + ERROR = 0xFF + + +# ----------------------------------------------------------------------------- +@dataclass +class SongAndPlayStatus: + song_length: int + song_position: int + play_status: PlayStatus + + +# ----------------------------------------------------------------------------- +class ApplicationSetting: + class AttributeId(OpenIntEnum): + EQUALIZER_ON_OFF = 0x01 + REPEAT_MODE = 0x02 + SHUFFLE_ON_OFF = 0x03 + SCAN_ON_OFF = 0x04 + + class EqualizerOnOffStatus(OpenIntEnum): + OFF = 0x01 + ON = 0x02 + + class RepeatModeStatus(OpenIntEnum): + OFF = 0x01 + SINGLE_TRACK_REPEAT = 0x02 + ALL_TRACK_REPEAT = 0x03 + GROUP_REPEAT = 0x04 + + class ShuffleOnOffStatus(OpenIntEnum): + OFF = 0x01 + ALL_TRACKS_SHUFFLE = 0x02 + GROUP_SHUFFLE = 0x03 + + class ScanOnOffStatus(OpenIntEnum): + OFF = 0x01 + ALL_TRACKS_SCAN = 0x02 + GROUP_SCAN = 0x03 + + class GenericValue(OpenIntEnum): + pass + + +# ----------------------------------------------------------------------------- +@dataclass +class Event: + event_id: EventId + + @classmethod + def from_bytes(cls, pdu: bytes) -> Event: + event_id = EventId(pdu[0]) + subclass = EVENT_SUBCLASSES.get(event_id, GenericEvent) + return subclass.from_bytes(pdu) + + def __bytes__(self) -> bytes: + return bytes([self.event_id]) + + +# ----------------------------------------------------------------------------- +@dataclass +class GenericEvent(Event): + data: bytes + + @classmethod + def from_bytes(cls, pdu: bytes) -> GenericEvent: + return cls(event_id=EventId(pdu[0]), data=pdu[1:]) + + def __bytes__(self) -> bytes: + return bytes([self.event_id]) + self.data + + +# ----------------------------------------------------------------------------- +@dataclass +class PlaybackStatusChangedEvent(Event): + play_status: PlayStatus + + @classmethod + def from_bytes(cls, pdu: bytes) -> PlaybackStatusChangedEvent: + return cls(play_status=PlayStatus(pdu[1])) + + def __init__(self, play_status: PlayStatus) -> None: + super().__init__(EventId.PLAYBACK_STATUS_CHANGED) + self.play_status = play_status + + def __bytes__(self) -> bytes: + return bytes([self.event_id]) + bytes([self.play_status]) + + +# ----------------------------------------------------------------------------- +@dataclass +class PlaybackPositionChangedEvent(Event): + playback_position: int + + @classmethod + def from_bytes(cls, pdu: bytes) -> PlaybackPositionChangedEvent: + return cls(playback_position=struct.unpack_from(">I", pdu, 1)[0]) + + def __init__(self, playback_position: int) -> None: + super().__init__(EventId.PLAYBACK_POS_CHANGED) + self.playback_position = playback_position + + def __bytes__(self) -> bytes: + return bytes([self.event_id]) + struct.pack(">I", self.playback_position) + + +# ----------------------------------------------------------------------------- +@dataclass +class TrackChangedEvent(Event): + identifier: bytes + + @classmethod + def from_bytes(cls, pdu: bytes) -> TrackChangedEvent: + return cls(identifier=pdu[1:]) + + def __init__(self, identifier: bytes) -> None: + super().__init__(EventId.TRACK_CHANGED) + self.identifier = identifier + + def __bytes__(self) -> bytes: + return bytes([self.event_id]) + self.identifier + + +# ----------------------------------------------------------------------------- +@dataclass +class PlayerApplicationSettingChangedEvent(Event): + @dataclass + class Setting: + attribute_id: ApplicationSetting.AttributeId + value_id: OpenIntEnum + + player_application_settings: List[Setting] + + @classmethod + def from_bytes(cls, pdu: bytes) -> PlayerApplicationSettingChangedEvent: + def setting(attribute_id_int: int, value_id_int: int): + attribute_id = ApplicationSetting.AttributeId(attribute_id_int) + value_id: OpenIntEnum + if attribute_id == ApplicationSetting.AttributeId.EQUALIZER_ON_OFF: + value_id = ApplicationSetting.EqualizerOnOffStatus(value_id_int) + elif attribute_id == ApplicationSetting.AttributeId.REPEAT_MODE: + value_id = ApplicationSetting.RepeatModeStatus(value_id_int) + elif attribute_id == ApplicationSetting.AttributeId.SHUFFLE_ON_OFF: + value_id = ApplicationSetting.ShuffleOnOffStatus(value_id_int) + elif attribute_id == ApplicationSetting.AttributeId.SCAN_ON_OFF: + value_id = ApplicationSetting.ScanOnOffStatus(value_id_int) + else: + value_id = ApplicationSetting.GenericValue(value_id_int) + + return cls.Setting(attribute_id, value_id) + + settings = [ + setting(pdu[2 + (i * 2)], pdu[2 + (i * 2) + 1]) for i in range(pdu[1]) + ] + return cls(player_application_settings=settings) + + def __init__(self, player_application_settings: Sequence[Setting]) -> None: + super().__init__(EventId.PLAYER_APPLICATION_SETTING_CHANGED) + self.player_application_settings = list(player_application_settings) + + def __bytes__(self) -> bytes: + return ( + bytes([self.event_id]) + + bytes([len(self.player_application_settings)]) + + b''.join( + [ + bytes([setting.attribute_id, setting.value_id]) + for setting in self.player_application_settings + ] + ) + ) + + +# ----------------------------------------------------------------------------- +@dataclass +class NowPlayingContentChangedEvent(Event): + @classmethod + def from_bytes(cls, pdu: bytes) -> NowPlayingContentChangedEvent: + return cls() + + def __init__(self) -> None: + super().__init__(EventId.NOW_PLAYING_CONTENT_CHANGED) + + +# ----------------------------------------------------------------------------- +@dataclass +class AvailablePlayersChangedEvent(Event): + @classmethod + def from_bytes(cls, pdu: bytes) -> AvailablePlayersChangedEvent: + return cls() + + def __init__(self) -> None: + super().__init__(EventId.AVAILABLE_PLAYERS_CHANGED) + + +# ----------------------------------------------------------------------------- +@dataclass +class AddressedPlayerChangedEvent(Event): + @dataclass + class Player: + player_id: int + uid_counter: int + + @classmethod + def from_bytes(cls, pdu: bytes) -> AddressedPlayerChangedEvent: + player_id, uid_counter = struct.unpack_from(" None: + super().__init__(EventId.ADDRESSED_PLAYER_CHANGED) + self.player = player + + def __bytes__(self) -> bytes: + return bytes([self.event_id]) + struct.pack( + ">HH", self.player.player_id, self.player.uid_counter + ) + + +# ----------------------------------------------------------------------------- +@dataclass +class UidsChangedEvent(Event): + uid_counter: int + + @classmethod + def from_bytes(cls, pdu: bytes) -> UidsChangedEvent: + return cls(uid_counter=struct.unpack_from(">H", pdu, 1)[0]) + + def __init__(self, uid_counter: int) -> None: + super().__init__(EventId.UIDS_CHANGED) + self.uid_counter = uid_counter + + def __bytes__(self) -> bytes: + return bytes([self.event_id]) + struct.pack(">H", self.uid_counter) + + +# ----------------------------------------------------------------------------- +@dataclass +class VolumeChangedEvent(Event): + volume: int + + @classmethod + def from_bytes(cls, pdu: bytes) -> VolumeChangedEvent: + return cls(volume=pdu[1]) + + def __init__(self, volume: int) -> None: + super().__init__(EventId.VOLUME_CHANGED) + self.volume = volume + + def __bytes__(self) -> bytes: + return bytes([self.event_id]) + bytes([self.volume]) + + +# ----------------------------------------------------------------------------- +EVENT_SUBCLASSES: Dict[EventId, Type[Event]] = { + EventId.PLAYBACK_STATUS_CHANGED: PlaybackStatusChangedEvent, + EventId.PLAYBACK_POS_CHANGED: PlaybackPositionChangedEvent, + EventId.TRACK_CHANGED: TrackChangedEvent, + EventId.PLAYER_APPLICATION_SETTING_CHANGED: PlayerApplicationSettingChangedEvent, + EventId.NOW_PLAYING_CONTENT_CHANGED: NowPlayingContentChangedEvent, + EventId.AVAILABLE_PLAYERS_CHANGED: AvailablePlayersChangedEvent, + EventId.ADDRESSED_PLAYER_CHANGED: AddressedPlayerChangedEvent, + EventId.UIDS_CHANGED: UidsChangedEvent, + EventId.VOLUME_CHANGED: VolumeChangedEvent, +} + + +# ----------------------------------------------------------------------------- +class Delegate: + """ + Base class for AVRCP delegates. + + All the methods are async, even if they don't always need to be, so that + delegates that do need to wait for an async result may do so. + """ + + class Error(Exception): + """The delegate method failed, with a specified status code.""" + + def __init__(self, status_code: Protocol.StatusCode) -> None: + self.status_code = status_code + + supported_events: List[EventId] + volume: int + + def __init__(self, supported_events: Iterable[EventId] = ()) -> None: + self.supported_events = list(supported_events) + self.volume = 0 + + async def get_supported_events(self) -> List[EventId]: + return self.supported_events + + async def set_absolute_volume(self, volume: int) -> None: + """ + Set the absolute volume. + + Returns: the effective volume that was set. + """ + logger.debug(f"@@@ set_absolute_volume: volume={volume}") + self.volume = volume + + async def get_absolute_volume(self) -> int: + return self.volume + + # TODO add other delegate methods + + +# ----------------------------------------------------------------------------- +class Protocol(pyee.EventEmitter): + """AVRCP Controller and Target protocol.""" + + class PacketType(enum.IntEnum): + SINGLE = 0b00 + START = 0b01 + CONTINUE = 0b10 + END = 0b11 + + class PduId(OpenIntEnum): + GET_CAPABILITIES = 0x10 + LIST_PLAYER_APPLICATION_SETTING_ATTRIBUTES = 0x11 + LIST_PLAYER_APPLICATION_SETTING_VALUES = 0x12 + GET_CURRENT_PLAYER_APPLICATION_SETTING_VALUE = 0x13 + SET_PLAYER_APPLICATION_SETTING_VALUE = 0x14 + GET_PLAYER_APPLICATION_SETTING_ATTRIBUTE_TEXT = 0x15 + GET_PLAYER_APPLICATION_SETTING_VALUE_TEXT = 0x16 + INFORM_DISPLAYABLE_CHARACTER_SET = 0x17 + INFORM_BATTERY_STATUS_OF_CT = 0x18 + GET_ELEMENT_ATTRIBUTES = 0x20 + GET_PLAY_STATUS = 0x30 + REGISTER_NOTIFICATION = 0x31 + REQUEST_CONTINUING_RESPONSE = 0x40 + ABORT_CONTINUING_RESPONSE = 0x41 + SET_ABSOLUTE_VOLUME = 0x50 + SET_ADDRESSED_PLAYER = 0x60 + SET_BROWSED_PLAYER = 0x70 + GET_FOLDER_ITEMS = 0x71 + GET_TOTAL_NUMBER_OF_ITEMS = 0x75 + + class StatusCode(OpenIntEnum): + INVALID_COMMAND = 0x00 + INVALID_PARAMETER = 0x01 + PARAMETER_CONTENT_ERROR = 0x02 + INTERNAL_ERROR = 0x03 + OPERATION_COMPLETED = 0x04 + UID_CHANGED = 0x05 + INVALID_DIRECTION = 0x07 + NOT_A_DIRECTORY = 0x08 + DOES_NOT_EXIST = 0x09 + INVALID_SCOPE = 0x0A + RANGE_OUT_OF_BOUNDS = 0x0B + FOLDER_ITEM_IS_NOT_PLAYABLE = 0x0C + MEDIA_IN_USE = 0x0D + NOW_PLAYING_LIST_FULL = 0x0E + SEARCH_NOT_SUPPORTED = 0x0F + SEARCH_IN_PROGRESS = 0x10 + INVALID_PLAYER_ID = 0x11 + PLAYER_NOT_BROWSABLE = 0x12 + PLAYER_NOT_ADDRESSED = 0x13 + NO_VALID_SEARCH_RESULTS = 0x14 + NO_AVAILABLE_PLAYERS = 0x15 + ADDRESSED_PLAYER_CHANGED = 0x16 + + class InvalidPidError(Exception): + """A response frame with ipid==1 was received.""" + + class NotPendingError(Exception): + """There is no pending command for a transaction label.""" + + class MismatchedResponseError(Exception): + """The response type does not corresponding to the request type.""" + + def __init__(self, response: Response) -> None: + self.response = response + + class UnexpectedResponseTypeError(Exception): + """The response type is not the expected one.""" + + def __init__(self, response: Protocol.ResponseContext) -> None: + self.response = response + + class UnexpectedResponseCodeError(Exception): + """The response code was not the expected one.""" + + def __init__( + self, response_code: avc.ResponseFrame.ResponseCode, response: Response + ) -> None: + self.response_code = response_code + self.response = response + + class PendingCommand: + response: asyncio.Future + + def __init__(self, transaction_label: int) -> None: + self.transaction_label = transaction_label + self.reset() + + def reset(self): + self.response = asyncio.get_running_loop().create_future() + + @dataclass + class ReceiveCommandState: + transaction_label: int + command_type: avc.CommandFrame.CommandType + + @dataclass + class ReceiveResponseState: + transaction_label: int + response_code: avc.ResponseFrame.ResponseCode + + @dataclass + class ResponseContext: + transaction_label: int + response: Response + + @dataclass + class FinalResponse(ResponseContext): + response_code: avc.ResponseFrame.ResponseCode + + @dataclass + class InterimResponse(ResponseContext): + final: Awaitable[Protocol.FinalResponse] + + @dataclass + class NotificationListener: + transaction_label: int + register_notification_command: RegisterNotificationCommand + + delegate: Delegate + send_transaction_label: int + command_pdu_assembler: PduAssembler + receive_command_state: Optional[ReceiveCommandState] + response_pdu_assembler: PduAssembler + receive_response_state: Optional[ReceiveResponseState] + avctp_protocol: Optional[avctp.Protocol] + free_commands: asyncio.Queue + pending_commands: Dict[int, PendingCommand] # Pending commands, by label + notification_listeners: Dict[EventId, NotificationListener] + + @staticmethod + def _check_vendor_dependent_frame( + frame: Union[avc.VendorDependentCommandFrame, avc.VendorDependentResponseFrame] + ) -> bool: + if frame.company_id != AVRCP_BLUETOOTH_SIG_COMPANY_ID: + logger.debug("unsupported company id, ignoring") + return False + + if frame.subunit_type != avc.Frame.SubunitType.PANEL or frame.subunit_id != 0: + logger.debug("unsupported subunit") + return False + + return True + + def __init__(self, delegate: Optional[Delegate] = None) -> None: + super().__init__() + self.delegate = delegate if delegate else Delegate() + self.command_pdu_assembler = PduAssembler(self._on_command_pdu) + self.receive_command_state = None + self.response_pdu_assembler = PduAssembler(self._on_response_pdu) + self.receive_response_state = None + self.avctp_protocol = None + self.notification_listeners = {} + + # Create an initial pool of free commands + self.pending_commands = {} + self.free_commands = asyncio.Queue() + for transaction_label in range(16): + self.free_commands.put_nowait(self.PendingCommand(transaction_label)) + + def listen(self, device: Device) -> None: + """ + Listen for incoming connections. + + A 'connection' event will be emitted when a connection is made, and a 'start' + event will be emitted when the protocol is ready to be used on that connection. + """ + device.register_l2cap_server(avctp.AVCTP_PSM, self._on_avctp_connection) + + async def connect(self, connection: Connection) -> None: + """ + Connect to a peer. + """ + avctp_channel = await connection.create_l2cap_channel( + l2cap.ClassicChannelSpec(psm=avctp.AVCTP_PSM) + ) + self._on_avctp_channel_open(avctp_channel) + + async def _obtain_pending_command(self) -> PendingCommand: + pending_command = await self.free_commands.get() + self.pending_commands[pending_command.transaction_label] = pending_command + return pending_command + + def recycle_pending_command(self, pending_command: PendingCommand) -> None: + pending_command.reset() + del self.pending_commands[pending_command.transaction_label] + self.free_commands.put_nowait(pending_command) + logger.debug(f"recycled pending command, {self.free_commands.qsize()} free") + + _R = TypeVar('_R') + + @staticmethod + def _check_response( + response_context: ResponseContext, expected_type: Type[_R] + ) -> _R: + if isinstance(response_context, Protocol.FinalResponse): + if ( + response_context.response_code + != avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE + ): + raise Protocol.UnexpectedResponseCodeError( + response_context.response_code, response_context.response + ) + + if not (isinstance(response_context.response, expected_type)): + raise Protocol.MismatchedResponseError(response_context.response) + + return response_context.response + + raise Protocol.UnexpectedResponseTypeError(response_context) + + def _delegate_command( + self, transaction_label: int, command: Command, method: Awaitable + ) -> None: + async def call(): + try: + await method + except Delegate.Error as error: + self.send_rejected_avrcp_response( + transaction_label, + command.pdu_id, + error.status_code, + ) + except Exception: + logger.exception("delegate method raised exception") + self.send_rejected_avrcp_response( + transaction_label, + command.pdu_id, + Protocol.StatusCode.INTERNAL_ERROR, + ) + + utils.AsyncRunner.spawn(call()) + + async def get_supported_events(self) -> List[EventId]: + """Get the list of events supported by the connected peer.""" + response_context = await self.send_avrcp_command( + avc.CommandFrame.CommandType.STATUS, + GetCapabilitiesCommand( + GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED + ), + ) + response = self._check_response(response_context, GetCapabilitiesResponse) + return cast(List[EventId], response.capabilities) + + async def get_play_status(self) -> SongAndPlayStatus: + """Get the play status of the connected peer.""" + response_context = await self.send_avrcp_command( + avc.CommandFrame.CommandType.STATUS, GetPlayStatusCommand() + ) + response = self._check_response(response_context, GetPlayStatusResponse) + return SongAndPlayStatus( + response.song_length, response.song_position, response.play_status + ) + + async def get_element_attributes( + self, element_identifier: int, attribute_ids: Sequence[MediaAttributeId] + ) -> List[MediaAttribute]: + """Get element attributes from the connected peer.""" + response_context = await self.send_avrcp_command( + avc.CommandFrame.CommandType.STATUS, + GetElementAttributesCommand(element_identifier, attribute_ids), + ) + response = self._check_response(response_context, GetElementAttributesResponse) + return response.attributes + + async def monitor_events( + self, event_id: EventId, playback_interval: int = 0 + ) -> AsyncIterator[Event]: + """ + Monitor events emitted from a peer. + + This generator yields Event objects. + """ + + def check_response(response) -> Event: + if not isinstance(response, RegisterNotificationResponse): + raise self.MismatchedResponseError(response) + + return response.event + + while True: + response = await self.send_avrcp_command( + avc.CommandFrame.CommandType.NOTIFY, + RegisterNotificationCommand(event_id, playback_interval), + ) + + if isinstance(response, self.InterimResponse): + logger.debug(f"interim: {response}") + yield check_response(response.response) + + logger.debug("waiting for final response") + response = await response.final + + if not isinstance(response, self.FinalResponse): + raise self.UnexpectedResponseTypeError(response) + + logger.debug(f"final: {response}") + if response.response_code != avc.ResponseFrame.ResponseCode.CHANGED: + raise self.UnexpectedResponseCodeError( + response.response_code, response.response + ) + + yield check_response(response.response) + + async def monitor_playback_status( + self, + ) -> AsyncIterator[PlayStatus]: + """Monitor Playback Status changes from the connected peer.""" + async for event in self.monitor_events(EventId.PLAYBACK_STATUS_CHANGED, 0): + if not isinstance(event, PlaybackStatusChangedEvent): + logger.warning("unexpected event class") + continue + yield event.play_status + + async def monitor_track_changed( + self, + ) -> AsyncIterator[bytes]: + """Monitor Track changes from the connected peer.""" + async for event in self.monitor_events(EventId.TRACK_CHANGED, 0): + if not isinstance(event, TrackChangedEvent): + logger.warning("unexpected event class") + continue + yield event.identifier + + async def monitor_playback_position( + self, playback_interval: int + ) -> AsyncIterator[int]: + """Monitor Playback Position changes from the connected peer.""" + async for event in self.monitor_events( + EventId.PLAYBACK_POS_CHANGED, playback_interval + ): + if not isinstance(event, PlaybackPositionChangedEvent): + logger.warning("unexpected event class") + continue + yield event.playback_position + + async def monitor_player_application_settings( + self, + ) -> AsyncIterator[List[PlayerApplicationSettingChangedEvent.Setting]]: + """Monitor Player Application Setting changes from the connected peer.""" + async for event in self.monitor_events( + EventId.PLAYER_APPLICATION_SETTING_CHANGED, 0 + ): + if not isinstance(event, PlayerApplicationSettingChangedEvent): + logger.warning("unexpected event class") + continue + yield event.player_application_settings + + async def monitor_now_playing_content(self) -> AsyncIterator[None]: + """Monitor Now Playing changes from the connected peer.""" + async for event in self.monitor_events(EventId.NOW_PLAYING_CONTENT_CHANGED, 0): + if not isinstance(event, NowPlayingContentChangedEvent): + logger.warning("unexpected event class") + continue + yield None + + async def monitor_available_players(self) -> AsyncIterator[None]: + """Monitor Available Players changes from the connected peer.""" + async for event in self.monitor_events(EventId.AVAILABLE_PLAYERS_CHANGED, 0): + if not isinstance(event, AvailablePlayersChangedEvent): + logger.warning("unexpected event class") + continue + yield None + + async def monitor_addressed_player( + self, + ) -> AsyncIterator[AddressedPlayerChangedEvent.Player]: + """Monitor Addressed Player changes from the connected peer.""" + async for event in self.monitor_events(EventId.ADDRESSED_PLAYER_CHANGED, 0): + if not isinstance(event, AddressedPlayerChangedEvent): + logger.warning("unexpected event class") + continue + yield event.player + + async def monitor_uids( + self, + ) -> AsyncIterator[int]: + """Monitor UID changes from the connected peer.""" + async for event in self.monitor_events(EventId.UIDS_CHANGED, 0): + if not isinstance(event, UidsChangedEvent): + logger.warning("unexpected event class") + continue + yield event.uid_counter + + async def monitor_volume( + self, + ) -> AsyncIterator[int]: + """Monitor Volume changes from the connected peer.""" + async for event in self.monitor_events(EventId.VOLUME_CHANGED, 0): + if not isinstance(event, VolumeChangedEvent): + logger.warning("unexpected event class") + continue + yield event.volume + + def notify_event(self, event: Event): + """Notify an event to the connected peer.""" + if (listener := self.notification_listeners.get(event.event_id)) is None: + logger.debug(f"no listener for {event.event_id.name}") + return + + # Emit the notification. + notification = RegisterNotificationResponse(event) + self.send_avrcp_response( + listener.transaction_label, + avc.ResponseFrame.ResponseCode.CHANGED, + notification, + ) + + # Remove the listener (they will need to re-register). + del self.notification_listeners[event.event_id] + + def notify_playback_status_changed(self, status: PlayStatus) -> None: + """Notify the connected peer of a Playback Status change.""" + self.notify_event(PlaybackStatusChangedEvent(status)) + + def notify_track_changed(self, identifier: bytes) -> None: + """Notify the connected peer of a Track change.""" + if len(identifier) != 8: + raise ValueError("identifier must be 8 bytes") + self.notify_event(TrackChangedEvent(identifier)) + + def notify_playback_position_changed(self, position: int) -> None: + """Notify the connected peer of a Position change.""" + self.notify_event(PlaybackPositionChangedEvent(position)) + + def notify_player_application_settings_changed( + self, settings: Sequence[PlayerApplicationSettingChangedEvent.Setting] + ) -> None: + """Notify the connected peer of an Player Application Setting change.""" + self.notify_event( + PlayerApplicationSettingChangedEvent(settings), + ) + + def notify_now_playing_content_changed(self) -> None: + """Notify the connected peer of a Now Playing change.""" + self.notify_event(NowPlayingContentChangedEvent()) + + def notify_available_players_changed(self) -> None: + """Notify the connected peer of an Available Players change.""" + self.notify_event(AvailablePlayersChangedEvent()) + + def notify_addressed_player_changed( + self, player: AddressedPlayerChangedEvent.Player + ) -> None: + """Notify the connected peer of an Addressed Player change.""" + self.notify_event(AddressedPlayerChangedEvent(player)) + + def notify_uids_changed(self, uid_counter: int) -> None: + """Notify the connected peer of a UID change.""" + self.notify_event(UidsChangedEvent(uid_counter)) + + def notify_volume_changed(self, volume: int) -> None: + """Notify the connected peer of a Volume change.""" + self.notify_event(VolumeChangedEvent(volume)) + + def _register_notification_listener( + self, transaction_label: int, command: RegisterNotificationCommand + ) -> None: + listener = self.NotificationListener(transaction_label, command) + self.notification_listeners[command.event_id] = listener + + def _on_avctp_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None: + logger.debug("AVCTP connection established") + l2cap_channel.on("open", lambda: self._on_avctp_channel_open(l2cap_channel)) + + self.emit("connection") + + def _on_avctp_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None: + logger.debug("AVCTP channel open") + if self.avctp_protocol is not None: + # TODO: find a better strategy instead of just closing + logger.warning("AVCTP protocol already active, closing connection") + AsyncRunner.spawn(l2cap_channel.disconnect()) + return + + self.avctp_protocol = avctp.Protocol(l2cap_channel) + self.avctp_protocol.register_command_handler(AVRCP_PID, self._on_avctp_command) + self.avctp_protocol.register_response_handler( + AVRCP_PID, self._on_avctp_response + ) + l2cap_channel.on("close", self._on_avctp_channel_close) + + self.emit("start") + + def _on_avctp_channel_close(self) -> None: + logger.debug("AVCTP channel closed") + self.avctp_protocol = None + + self.emit("stop") + + def _on_avctp_command( + self, transaction_label: int, command: avc.CommandFrame + ) -> None: + logger.debug( + f"<<< AVCTP Command, transaction_label={transaction_label}: " f"{command}" + ) + + # Only the PANEL subunit type with subunit ID 0 is supported in this profile. + if ( + command.subunit_type != avc.Frame.SubunitType.PANEL + or command.subunit_id != 0 + ): + logger.debug("subunit not supported") + self.send_not_implemented_response(transaction_label, command) + return + + if isinstance(command, avc.VendorDependentCommandFrame): + if not self._check_vendor_dependent_frame(command): + return + + if self.receive_command_state is None: + self.receive_command_state = self.ReceiveCommandState( + transaction_label=transaction_label, command_type=command.ctype + ) + elif ( + self.receive_command_state.transaction_label != transaction_label + or self.receive_command_state.command_type != command.ctype + ): + # We're in the middle of some other PDU + logger.warning("received interleaved PDU, resetting state") + self.command_pdu_assembler.reset() + self.receive_command_state = None + return + else: + self.receive_command_state.command_type = command.ctype + self.receive_command_state.transaction_label = transaction_label + + self.command_pdu_assembler.on_pdu(command.vendor_dependent_data) + return + + if isinstance(command, avc.PassThroughCommandFrame): + # TODO: delegate + response = avc.PassThroughResponseFrame( + avc.ResponseFrame.ResponseCode.ACCEPTED, + avc.Frame.SubunitType.PANEL, + 0, + command.state_flag, + command.operation_id, + command.operation_data, + ) + self.send_response(transaction_label, response) + return + + # TODO handle other types + self.send_not_implemented_response(transaction_label, command) + + def _on_avctp_response( + self, transaction_label: int, response: Optional[avc.ResponseFrame] + ) -> None: + logger.debug( + f"<<< AVCTP Response, transaction_label={transaction_label}: {response}" + ) + + # Check that we have a pending command that matches this response. + if not (pending_command := self.pending_commands.get(transaction_label)): + logger.warning("no pending command with this transaction label") + return + + # A None response means an invalid PID was used in the request. + if response is None: + pending_command.response.set_exception(self.InvalidPidError()) + + if isinstance(response, avc.VendorDependentResponseFrame): + if not self._check_vendor_dependent_frame(response): + return + + if self.receive_response_state is None: + self.receive_response_state = self.ReceiveResponseState( + transaction_label=transaction_label, response_code=response.response + ) + elif ( + self.receive_response_state.transaction_label != transaction_label + or self.receive_response_state.response_code != response.response + ): + # We're in the middle of some other PDU + logger.warning("received interleaved PDU, resetting state") + self.response_pdu_assembler.reset() + self.receive_response_state = None + return + else: + self.receive_response_state.response_code = response.response + self.receive_response_state.transaction_label = transaction_label + + self.response_pdu_assembler.on_pdu(response.vendor_dependent_data) + return + + if isinstance(response, avc.PassThroughResponseFrame): + pending_command.response.set_result(response) + + # TODO handle other types + + self.recycle_pending_command(pending_command) + + def _on_command_pdu(self, pdu_id: PduId, pdu: bytes) -> None: + logger.debug(f"<<< AVRCP command PDU [pdu_id={pdu_id.name}]: {pdu.hex()}") + + assert self.receive_command_state is not None + transaction_label = self.receive_command_state.transaction_label + + # Dispatch the command. + # NOTE: with a small number of supported commands, a manual dispatch like this + # is Ok, but if/when more commands are supported, a lookup dispatch mechanism + # would be more appropriate. + # TODO: switch on ctype + if self.receive_command_state.command_type in ( + avc.CommandFrame.CommandType.CONTROL, + avc.CommandFrame.CommandType.STATUS, + avc.CommandFrame.CommandType.NOTIFY, + ): + # TODO: catch exceptions from delegates + if pdu_id == self.PduId.GET_CAPABILITIES: + self._on_get_capabilities_command( + transaction_label, GetCapabilitiesCommand.from_bytes(pdu) + ) + elif pdu_id == self.PduId.SET_ABSOLUTE_VOLUME: + self._on_set_absolute_volume_command( + transaction_label, SetAbsoluteVolumeCommand.from_bytes(pdu) + ) + elif pdu_id == self.PduId.REGISTER_NOTIFICATION: + self._on_register_notification_command( + transaction_label, RegisterNotificationCommand.from_bytes(pdu) + ) + else: + # Not supported. + # TODO: check that this is the right way to respond in this case. + logger.debug("unsupported PDU ID") + self.send_rejected_avrcp_response( + transaction_label, pdu_id, self.StatusCode.INVALID_PARAMETER + ) + else: + logger.debug("unsupported command type") + self.send_rejected_avrcp_response( + transaction_label, pdu_id, self.StatusCode.INVALID_COMMAND + ) + + self.receive_command_state = None + + def _on_response_pdu(self, pdu_id: PduId, pdu: bytes) -> None: + logger.debug(f"<<< AVRCP response PDU [pdu_id={pdu_id.name}]: {pdu.hex()}") + + assert self.receive_response_state is not None + + transaction_label = self.receive_response_state.transaction_label + response_code = self.receive_response_state.response_code + self.receive_response_state = None + + # Check that we have a pending command that matches this response. + if not (pending_command := self.pending_commands.get(transaction_label)): + logger.warning("no pending command with this transaction label") + return + + # Convert the PDU bytes into a response object. + # NOTE: with a small number of supported responses, a manual switch like this + # is Ok, but if/when more responses are supported, a lookup mechanism would be + # more appropriate. + response: Optional[Response] = None + if response_code == avc.ResponseFrame.ResponseCode.REJECTED: + response = RejectedResponse.from_bytes(pdu_id, pdu) + elif response_code == avc.ResponseFrame.ResponseCode.NOT_IMPLEMENTED: + response = NotImplementedResponse.from_bytes(pdu_id, pdu) + elif response_code in ( + avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE, + avc.ResponseFrame.ResponseCode.INTERIM, + avc.ResponseFrame.ResponseCode.CHANGED, + avc.ResponseFrame.ResponseCode.ACCEPTED, + ): + if pdu_id == self.PduId.GET_CAPABILITIES: + response = GetCapabilitiesResponse.from_bytes(pdu) + elif pdu_id == self.PduId.GET_PLAY_STATUS: + response = GetPlayStatusResponse.from_bytes(pdu) + elif pdu_id == self.PduId.GET_ELEMENT_ATTRIBUTES: + response = GetElementAttributesResponse.from_bytes(pdu) + elif pdu_id == self.PduId.SET_ABSOLUTE_VOLUME: + response = SetAbsoluteVolumeResponse.from_bytes(pdu) + elif pdu_id == self.PduId.REGISTER_NOTIFICATION: + response = RegisterNotificationResponse.from_bytes(pdu) + else: + logger.debug("unexpected PDU ID") + pending_command.response.set_exception( + ProtocolError( + error_code=None, + error_namespace="avrcp", + details="unexpected PDU ID", + ) + ) + else: + logger.debug("unexpected response code") + pending_command.response.set_exception( + ProtocolError( + error_code=None, + error_namespace="avrcp", + details="unexpected response code", + ) + ) + + if response is None: + self.recycle_pending_command(pending_command) + return + + logger.debug(f"<<< AVRCP response: {response}") + + # Make the response available to the waiter. + if response_code == avc.ResponseFrame.ResponseCode.INTERIM: + pending_interim_response = pending_command.response + pending_command.reset() + pending_interim_response.set_result( + self.InterimResponse( + pending_command.transaction_label, + response, + pending_command.response, + ) + ) + else: + pending_command.response.set_result( + self.FinalResponse( + pending_command.transaction_label, + response, + response_code, + ) + ) + self.recycle_pending_command(pending_command) + + def send_command(self, transaction_label: int, command: avc.CommandFrame) -> None: + logger.debug(f">>> AVRCP command: {command}") + + if self.avctp_protocol is None: + logger.warning("trying to send command while avctp_protocol is None") + return + + self.avctp_protocol.send_command(transaction_label, AVRCP_PID, bytes(command)) + + async def send_passthrough_command( + self, command: avc.PassThroughCommandFrame + ) -> avc.PassThroughResponseFrame: + # Wait for a free command slot. + pending_command = await self._obtain_pending_command() + + # Send the command. + self.send_command(pending_command.transaction_label, command) + + # Wait for the response. + return await pending_command.response + + async def send_key_event( + self, key: avc.PassThroughCommandFrame.OperationId, pressed: bool + ) -> avc.PassThroughResponseFrame: + """Send a key event to the connected peer.""" + return await self.send_passthrough_command( + avc.PassThroughCommandFrame( + avc.CommandFrame.CommandType.CONTROL, + avc.Frame.SubunitType.PANEL, + 0, + avc.PassThroughFrame.StateFlag.PRESSED + if pressed + else avc.PassThroughFrame.StateFlag.RELEASED, + key, + b'', + ) + ) + + async def send_avrcp_command( + self, command_type: avc.CommandFrame.CommandType, command: Command + ) -> ResponseContext: + # Wait for a free command slot. + pending_command = await self._obtain_pending_command() + + # TODO: fragmentation + # Send the command. + logger.debug(f">>> AVRCP command PDU: {command}") + pdu = ( + struct.pack(">BBH", command.pdu_id, 0, len(command.parameter)) + + command.parameter + ) + command_frame = avc.VendorDependentCommandFrame( + command_type, + avc.Frame.SubunitType.PANEL, + 0, + AVRCP_BLUETOOTH_SIG_COMPANY_ID, + pdu, + ) + self.send_command(pending_command.transaction_label, command_frame) + + # Wait for the response. + return await pending_command.response + + def send_response( + self, transaction_label: int, response: avc.ResponseFrame + ) -> None: + assert self.avctp_protocol is not None + logger.debug(f">>> AVRCP response: {response}") + self.avctp_protocol.send_response(transaction_label, AVRCP_PID, bytes(response)) + + def send_passthrough_response( + self, + transaction_label: int, + command: avc.PassThroughCommandFrame, + response_code: avc.ResponseFrame.ResponseCode, + ): + response = avc.PassThroughResponseFrame( + response_code, + avc.Frame.SubunitType.PANEL, + 0, + command.state_flag, + command.operation_id, + command.operation_data, + ) + self.send_response(transaction_label, response) + + def send_avrcp_response( + self, + transaction_label: int, + response_code: avc.ResponseFrame.ResponseCode, + response: Response, + ) -> None: + # TODO: fragmentation + logger.debug(f">>> AVRCP response PDU: {response}") + pdu = ( + struct.pack(">BBH", response.pdu_id, 0, len(response.parameter)) + + response.parameter + ) + response_frame = avc.VendorDependentResponseFrame( + response_code, + avc.Frame.SubunitType.PANEL, + 0, + AVRCP_BLUETOOTH_SIG_COMPANY_ID, + pdu, + ) + self.send_response(transaction_label, response_frame) + + def send_not_implemented_response( + self, transaction_label: int, command: avc.CommandFrame + ) -> None: + response = avc.ResponseFrame( + avc.ResponseFrame.ResponseCode.NOT_IMPLEMENTED, + command.subunit_type, + command.subunit_id, + command.opcode, + command.operands, + ) + self.send_response(transaction_label, response) + + def send_rejected_avrcp_response( + self, transaction_label: int, pdu_id: Protocol.PduId, status_code: StatusCode + ) -> None: + self.send_avrcp_response( + transaction_label, + avc.ResponseFrame.ResponseCode.REJECTED, + RejectedResponse(pdu_id, status_code), + ) + + def _on_get_capabilities_command( + self, transaction_label: int, command: GetCapabilitiesCommand + ) -> None: + logger.debug(f"<<< AVRCP command PDU: {command}") + + async def get_supported_events(): + if ( + command.capability_id + != GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED + ): + raise Protocol.InvalidParameterError + + supported_events = await self.delegate.get_supported_events() + self.send_avrcp_response( + transaction_label, + avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE, + GetCapabilitiesResponse(command.capability_id, supported_events), + ) + + self._delegate_command(transaction_label, command, get_supported_events()) + + def _on_set_absolute_volume_command( + self, transaction_label: int, command: SetAbsoluteVolumeCommand + ) -> None: + logger.debug(f"<<< AVRCP command PDU: {command}") + + async def set_absolute_volume(): + await self.delegate.set_absolute_volume(command.volume) + effective_volume = await self.delegate.get_absolute_volume() + self.send_avrcp_response( + transaction_label, + avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE, + SetAbsoluteVolumeResponse(effective_volume), + ) + + self._delegate_command(transaction_label, command, set_absolute_volume()) + + def _on_register_notification_command( + self, transaction_label: int, command: RegisterNotificationCommand + ) -> None: + logger.debug(f"<<< AVRCP command PDU: {command}") + + async def register_notification(): + # Check if the event is supported. + supported_events = await self.delegate.get_supported_events() + if command.event_id in supported_events: + if command.event_id == EventId.VOLUME_CHANGED: + volume = await self.delegate.get_absolute_volume() + response = RegisterNotificationResponse(VolumeChangedEvent(volume)) + self.send_avrcp_response( + transaction_label, + avc.ResponseFrame.ResponseCode.INTERIM, + response, + ) + self._register_notification_listener(transaction_label, command) + return + + if command.event_id == EventId.PLAYBACK_STATUS_CHANGED: + # TODO: testing only, use delegate + response = RegisterNotificationResponse( + PlaybackStatusChangedEvent(play_status=PlayStatus.PLAYING) + ) + self.send_avrcp_response( + transaction_label, + avc.ResponseFrame.ResponseCode.INTERIM, + response, + ) + self._register_notification_listener(transaction_label, command) + return + + self._delegate_command(transaction_label, command, register_notification()) diff --git a/bumble/core.py b/bumble/core.py index 5bbddd0b..dce721a4 100644 --- a/bumble/core.py +++ b/bumble/core.py @@ -97,12 +97,16 @@ def __str__(self): namespace = f'{self.error_namespace}/' else: namespace = '' - error_text = { - (True, True): f'{self.error_name} [0x{self.error_code:X}]', - (True, False): self.error_name, - (False, True): f'0x{self.error_code:X}', - (False, False): '', - }[(self.error_name != '', self.error_code is not None)] + have_name = self.error_name != '' + have_code = self.error_code is not None + if have_name and have_code: + error_text = f'{self.error_name} [0x{self.error_code:X}]' + elif have_name and not have_code: + error_text = self.error_name + elif not have_name and have_code: + error_text = f'0x{self.error_code:X}' + else: + error_text = '' return f'{type(self).__name__}({namespace}{error_text})' @@ -319,7 +323,7 @@ def __str__(self) -> str: BT_HARDCOPY_CONTROL_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x0012, 'HardcopyControlChannel') BT_HARDCOPY_DATA_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x0014, 'HardcopyDataChannel') BT_HARDCOPY_NOTIFICATION_PROTOCOL_ID = UUID.from_16_bits(0x0016, 'HardcopyNotification') -BT_AVTCP_PROTOCOL_ID = UUID.from_16_bits(0x0017, 'AVCTP') +BT_AVCTP_PROTOCOL_ID = UUID.from_16_bits(0x0017, 'AVCTP') BT_AVDTP_PROTOCOL_ID = UUID.from_16_bits(0x0019, 'AVDTP') BT_CMTP_PROTOCOL_ID = UUID.from_16_bits(0x001B, 'CMTP') BT_MCAP_CONTROL_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x001E, 'MCAPControlChannel') diff --git a/bumble/helpers.py b/bumble/helpers.py index 93bbc24e..a3633f1b 100644 --- a/bumble/helpers.py +++ b/bumble/helpers.py @@ -50,6 +50,11 @@ from bumble.rfcomm import RFCOMM_Frame, RFCOMM_PSM from bumble.sdp import SDP_PDU, SDP_PSM from bumble import crypto +from bumble.avdtp import MessageAssembler as AVDTP_MessageAssembler, AVDTP_PSM +from bumble.avctp import MessageAssembler as AVCTP_MessageAssembler, AVCTP_PSM +from bumble.avrcp import AVRCP_PID +from bumble.avc import Frame as AVC_Frame + # ----------------------------------------------------------------------------- # Logging @@ -61,9 +66,12 @@ PSM_NAMES = { RFCOMM_PSM: 'RFCOMM', SDP_PSM: 'SDP', - avdtp.AVDTP_PSM: 'AVDTP', + AVDTP_PSM: 'AVDTP', + AVCTP_PSM: 'AVCTP' + # TODO: add more PSM values } +AVCTP_PID_NAMES = {AVRCP_PID: 'AVRCP'} # ----------------------------------------------------------------------------- class PacketTracer: @@ -76,12 +84,15 @@ def __init__(self, analyzer: PacketTracer.Analyzer) -> None: self.analyzer = analyzer self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid + self.avctp_assemblers = {} # AVCTP assemblers, by source_cid + self.avrcp_assemblers = {} # AVRCP assemblers, by source_cid self.psms = {} # PSM, by source_cid self.peer = None # pylint: disable=too-many-nested-blocks def on_acl_pdu(self, pdu: bytes) -> None: l2cap_pdu = L2CAP_PDU.from_bytes(pdu) + self.analyzer.emit(l2cap_pdu) if l2cap_pdu.cid == ATT_CID: att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload) @@ -103,25 +114,32 @@ def on_acl_pdu(self, pdu: bytes) -> None: connection_response.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL ): - if self.peer: - if psm := self.peer.psms.get( + if self.peer and ( + psm := self.peer.psms.get( connection_response.source_cid - ): - # Found a pending connection - self.psms[connection_response.destination_cid] = psm - - # For AVDTP connections, create a packet assembler for - # each direction - if psm == avdtp.AVDTP_PSM: - self.avdtp_assemblers[ - connection_response.source_cid - ] = avdtp.MessageAssembler(self.on_avdtp_message) - self.peer.avdtp_assemblers[ - connection_response.destination_cid - ] = avdtp.MessageAssembler( - self.peer.on_avdtp_message - ) - + ) + ): + # Found a pending connection + self.psms[connection_response.destination_cid] = psm + + # For AVDTP connections, create a packet assembler for + # each direction + if psm == avdtp.AVDTP_PSM: + self.avdtp_assemblers[ + connection_response.source_cid + ] = avdtp.MessageAssembler(self.on_avdtp_message) + self.peer.avdtp_assemblers[ + connection_response.destination_cid + ] = avdtp.MessageAssembler( + self.peer.on_avdtp_message + ) + elif psm == AVCTP_PSM: + self.avctp_assemblers[ + connection_response.source_cid + ] = AVCTP_MessageAssembler(self.on_avctp_message) + self.peer.avctp_assemblers[ + connection_response.destination_cid + ] = AVCTP_MessageAssembler(self.peer.on_avctp_message) else: # Try to find the PSM associated with this PDU if self.peer and (psm := self.peer.psms.get(l2cap_pdu.cid)): @@ -139,6 +157,14 @@ def on_acl_pdu(self, pdu: bytes) -> None: assembler = self.avdtp_assemblers.get(l2cap_pdu.cid) if assembler: assembler.on_pdu(l2cap_pdu.payload) + elif psm == AVCTP_PSM: + self.analyzer.emit( + f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, ' + f'PSM=AVCTP]: {l2cap_pdu.payload.hex()}' + ) + assembler = self.avctp_assemblers.get(l2cap_pdu.cid) + if assembler: + assembler.on_pdu(l2cap_pdu.payload) else: psm_string = name_or_number(PSM_NAMES, psm) self.analyzer.emit( @@ -155,6 +181,21 @@ def on_avdtp_message( f'{color("AVDTP", "green")} [{transaction_label}] {message}' ) + def on_avctp_message(self, transaction_label: int, is_command: bool, ipid: bool, pid: int, payload: bytes): + if pid == AVRCP_PID: + avc_frame = AVC_Frame.from_bytes(payload) + details = str(avc_frame) + else: + details = payload.hex() + + c_r = 'Command' if is_command else 'Response' + self.analyzer.emit( + f'{color("AVCTP", "green")} ' + f'{c_r}[{transaction_label}][{name_or_number(AVCTP_PID_NAMES, pid)}] ' + f'{"#" if ipid else ""}' + f'{details}' + ) + def feed_packet(self, packet: HCI_AclDataPacket) -> None: self.packet_assembler.feed_packet(packet) diff --git a/bumble/sdp.py b/bumble/sdp.py index 099efabb..749e2956 100644 --- a/bumble/sdp.py +++ b/bumble/sdp.py @@ -97,7 +97,8 @@ SDP_ICON_URL_ATTRIBUTE_ID = 0X000C SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0X000D -# Attribute Identifier (cf. Assigned Numbers for Service Discovery) + +# Profile-specific Attribute Identifiers (cf. Assigned Numbers for Service Discovery) # used by AVRCP, HFP and A2DP SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID = 0x0311 @@ -115,7 +116,8 @@ SDP_DOCUMENTATION_URL_ATTRIBUTE_ID: 'SDP_DOCUMENTATION_URL_ATTRIBUTE_ID', SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID: 'SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID', SDP_ICON_URL_ATTRIBUTE_ID: 'SDP_ICON_URL_ATTRIBUTE_ID', - SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID' + SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID', + SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID: 'SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID', } SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot') diff --git a/bumble/utils.py b/bumble/utils.py index 552140b1..1bb84c62 100644 --- a/bumble/utils.py +++ b/bumble/utils.py @@ -17,9 +17,10 @@ # ----------------------------------------------------------------------------- from __future__ import annotations import asyncio -import logging -import traceback import collections +import enum +import functools +import logging import sys import warnings from typing import ( @@ -34,7 +35,8 @@ Union, overload, ) -from functools import wraps, partial +import traceback + from pyee import EventEmitter from .colors import color @@ -131,13 +133,14 @@ def on( Args: emitter: EventEmitter to watch event: Event name - handler: (Optional) Event handler. When nothing is passed, this method works as a decorator. + handler: (Optional) Event handler. When nothing is passed, this method + works as a decorator. ''' - def wrapper(f: _Handler) -> _Handler: - self.handlers.append((emitter, event, f)) - emitter.on(event, f) - return f + def wrapper(wrapped: _Handler) -> _Handler: + self.handlers.append((emitter, event, wrapped)) + emitter.on(event, wrapped) + return wrapped return wrapper if handler is None else wrapper(handler) @@ -157,13 +160,14 @@ def once( Args: emitter: EventEmitter to watch event: Event name - handler: (Optional) Event handler. When nothing passed, this method works as a decorator. + handler: (Optional) Event handler. When nothing passed, this method works + as a decorator. ''' - def wrapper(f: _Handler) -> _Handler: - self.handlers.append((emitter, event, f)) - emitter.once(event, f) - return f + def wrapper(wrapped: _Handler) -> _Handler: + self.handlers.append((emitter, event, wrapped)) + emitter.once(event, wrapped) + return wrapped return wrapper if handler is None else wrapper(handler) @@ -276,7 +280,7 @@ def run_in_task(queue=None): """ def decorator(func): - @wraps(func) + @functools.wraps(func) def wrapper(*args, **kwargs): coroutine = func(*args, **kwargs) if queue is None: @@ -410,30 +414,35 @@ async def pump(self): self.check_pump() +# ----------------------------------------------------------------------------- async def async_call(function, *args, **kwargs): """ - Immediately calls the function with provided args and kwargs, wrapping it in an async function. - Rust's `pyo3_asyncio` library needs functions to be marked async to properly inject a running loop. + Immediately calls the function with provided args and kwargs, wrapping it in an + async function. + Rust's `pyo3_asyncio` library needs functions to be marked async to properly inject + a running loop. result = await async_call(some_function, ...) """ return function(*args, **kwargs) +# ----------------------------------------------------------------------------- def wrap_async(function): """ Wraps the provided function in an async function. """ - return partial(async_call, function) + return functools.partial(async_call, function) +# ----------------------------------------------------------------------------- def deprecated(msg: str): """ Throw deprecation warning before execution. """ def wrapper(function): - @wraps(function) + @functools.wraps(function) def inner(*args, **kwargs): warnings.warn(msg, DeprecationWarning) return function(*args, **kwargs) @@ -443,6 +452,7 @@ def inner(*args, **kwargs): return wrapper +# ----------------------------------------------------------------------------- def experimental(msg: str): """ Throws a future warning before execution. @@ -457,3 +467,22 @@ def inner(*args, **kwargs): return inner return wrapper + + +# ----------------------------------------------------------------------------- +class OpenIntEnum(enum.IntEnum): + """ + Subclass of enum.IntEnum that can hold integer values outside the set of + predefined values. This is convenient for implementing protocols where some + integer constants may be added over time. + """ + + @classmethod + def _missing_(cls, value): + if not isinstance(value, int): + return None + + obj = int.__new__(cls, value) + obj._value_ = value + obj._name_ = f"{cls.__name__}[{value}]" + return obj diff --git a/examples/avrcp_as_sink.html b/examples/avrcp_as_sink.html new file mode 100644 index 00000000..7c967442 --- /dev/null +++ b/examples/avrcp_as_sink.html @@ -0,0 +1,274 @@ + + + + + + + Server Port
+
+
+

+
+
+
+
+
+
+
+ + + VOLUME: + +   +
+ + + + + + + + + + + + + + + + + + + + + + +
PLAYBACK STATUS
POSITION
TRACK
ADDRESSED PLAYER
UID COUNTER
SUPPORTED EVENTS
PLAYER SETTINGS
+ + + + \ No newline at end of file diff --git a/examples/run_avrcp.py b/examples/run_avrcp.py new file mode 100644 index 00000000..4bb41437 --- /dev/null +++ b/examples/run_avrcp.py @@ -0,0 +1,408 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations +import asyncio +import json +import sys +import os +import logging +import websockets + +from bumble.device import Device +from bumble.transport import open_transport_or_link +from bumble.core import BT_BR_EDR_TRANSPORT +from bumble import avc +from bumble import avrcp +from bumble import avdtp +from bumble import a2dp +from bumble import utils + + +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +def sdp_records(): + a2dp_sink_service_record_handle = 0x00010001 + avrcp_controller_service_record_handle = 0x00010002 + avrcp_target_service_record_handle = 0x00010003 + # pylint: disable=line-too-long + return { + a2dp_sink_service_record_handle: a2dp.make_audio_sink_service_sdp_records( + a2dp_sink_service_record_handle + ), + avrcp_controller_service_record_handle: avrcp.make_controller_service_sdp_records( + avrcp_controller_service_record_handle + ), + avrcp_target_service_record_handle: avrcp.make_target_service_sdp_records( + avrcp_controller_service_record_handle + ), + } + + +# ----------------------------------------------------------------------------- +def codec_capabilities(): + return avdtp.MediaCodecCapabilities( + media_type=avdtp.AVDTP_AUDIO_MEDIA_TYPE, + media_codec_type=a2dp.A2DP_SBC_CODEC_TYPE, + media_codec_information=a2dp.SbcMediaCodecInformation.from_lists( + sampling_frequencies=[48000, 44100, 32000, 16000], + channel_modes=[ + a2dp.SBC_MONO_CHANNEL_MODE, + a2dp.SBC_DUAL_CHANNEL_MODE, + a2dp.SBC_STEREO_CHANNEL_MODE, + a2dp.SBC_JOINT_STEREO_CHANNEL_MODE, + ], + block_lengths=[4, 8, 12, 16], + subbands=[4, 8], + allocation_methods=[ + a2dp.SBC_LOUDNESS_ALLOCATION_METHOD, + a2dp.SBC_SNR_ALLOCATION_METHOD, + ], + minimum_bitpool_value=2, + maximum_bitpool_value=53, + ), + ) + + +# ----------------------------------------------------------------------------- +def on_avdtp_connection(server): + # Add a sink endpoint to the server + sink = server.add_sink(codec_capabilities()) + sink.on('rtp_packet', on_rtp_packet) + + +# ----------------------------------------------------------------------------- +def on_rtp_packet(packet): + print(f'RTP: {packet}') + + +# ----------------------------------------------------------------------------- +def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketServer): + async def get_supported_events(): + events = await avrcp_protocol.get_supported_events() + print("SUPPORTED EVENTS:", events) + websocket_server.send_message( + { + "type": "supported-events", + "params": {"events": [event.name for event in events]}, + } + ) + + if avrcp.EventId.TRACK_CHANGED in events: + utils.AsyncRunner.spawn(monitor_track_changed()) + + if avrcp.EventId.PLAYBACK_STATUS_CHANGED in events: + utils.AsyncRunner.spawn(monitor_playback_status()) + + if avrcp.EventId.PLAYBACK_POS_CHANGED in events: + utils.AsyncRunner.spawn(monitor_playback_position()) + + if avrcp.EventId.PLAYER_APPLICATION_SETTING_CHANGED in events: + utils.AsyncRunner.spawn(monitor_player_application_settings()) + + if avrcp.EventId.AVAILABLE_PLAYERS_CHANGED in events: + utils.AsyncRunner.spawn(monitor_available_players()) + + if avrcp.EventId.ADDRESSED_PLAYER_CHANGED in events: + utils.AsyncRunner.spawn(monitor_addressed_player()) + + if avrcp.EventId.UIDS_CHANGED in events: + utils.AsyncRunner.spawn(monitor_uids()) + + if avrcp.EventId.VOLUME_CHANGED in events: + utils.AsyncRunner.spawn(monitor_volume()) + + utils.AsyncRunner.spawn(get_supported_events()) + + async def monitor_track_changed(): + async for identifier in avrcp_protocol.monitor_track_changed(): + print("TRACK CHANGED:", identifier.hex()) + websocket_server.send_message( + {"type": "track-changed", "params": {"identifier": identifier.hex()}} + ) + + async def monitor_playback_status(): + async for playback_status in avrcp_protocol.monitor_playback_status(): + print("PLAYBACK STATUS CHANGED:", playback_status.name) + websocket_server.send_message( + { + "type": "playback-status-changed", + "params": {"status": playback_status.name}, + } + ) + + async def monitor_playback_position(): + async for playback_position in avrcp_protocol.monitor_playback_position( + playback_interval=1 + ): + print("PLAYBACK POSITION CHANGED:", playback_position) + websocket_server.send_message( + { + "type": "playback-position-changed", + "params": {"position": playback_position}, + } + ) + + async def monitor_player_application_settings(): + async for settings in avrcp_protocol.monitor_player_application_settings(): + print("PLAYER APPLICATION SETTINGS:", settings) + settings_as_dict = [ + {"attribute": setting.attribute_id.name, "value": setting.value_id.name} + for setting in settings + ] + websocket_server.send_message( + { + "type": "player-settings-changed", + "params": {"settings": settings_as_dict}, + } + ) + + async def monitor_available_players(): + async for _ in avrcp_protocol.monitor_available_players(): + print("AVAILABLE PLAYERS CHANGED") + websocket_server.send_message( + {"type": "available-players-changed", "params": {}} + ) + + async def monitor_addressed_player(): + async for player in avrcp_protocol.monitor_addressed_player(): + print("ADDRESSED PLAYER CHANGED") + websocket_server.send_message( + { + "type": "addressed-player-changed", + "params": { + "player": { + "player_id": player.player_id, + "uid_counter": player.uid_counter, + } + }, + } + ) + + async def monitor_uids(): + async for uid_counter in avrcp_protocol.monitor_uids(): + print("UIDS CHANGED") + websocket_server.send_message( + { + "type": "uids-changed", + "params": { + "uid_counter": uid_counter, + }, + } + ) + + async def monitor_volume(): + async for volume in avrcp_protocol.monitor_volume(): + print("VOLUME CHANGED:", volume) + websocket_server.send_message( + {"type": "volume-changed", "params": {"volume": volume}} + ) + + +# ----------------------------------------------------------------------------- +class WebSocketServer: + def __init__( + self, avrcp_protocol: avrcp.Protocol, avrcp_delegate: Delegate + ) -> None: + self.socket = None + self.delegate = None + self.avrcp_protocol = avrcp_protocol + self.avrcp_delegate = avrcp_delegate + + async def start(self) -> None: + # pylint: disable-next=no-member + await websockets.serve(self.serve, 'localhost', 8989) # type: ignore + + async def serve(self, socket, _path) -> None: + print('### WebSocket connected') + self.socket = socket + while True: + try: + message = await socket.recv() + print('Received: ', str(message)) + + parsed = json.loads(message) + message_type = parsed['type'] + if message_type == 'send-key-down': + await self.on_send_key_down(parsed) + elif message_type == 'send-key-up': + await self.on_send_key_up(parsed) + elif message_type == 'set-volume': + await self.on_set_volume(parsed) + elif message_type == 'get-play-status': + await self.on_get_play_status() + elif message_type == 'get-element-attributes': + await self.on_get_element_attributes() + except websockets.exceptions.ConnectionClosedOK: + self.socket = None + break + + async def on_send_key_down(self, message: dict) -> None: + key = avc.PassThroughFrame.OperationId[message["key"]] + await self.avrcp_protocol.send_key_event(key, True) + + async def on_send_key_up(self, message: dict) -> None: + key = avc.PassThroughFrame.OperationId[message["key"]] + await self.avrcp_protocol.send_key_event(key, False) + + async def on_set_volume(self, message: dict) -> None: + volume = message["volume"] + self.avrcp_delegate.volume = volume + self.avrcp_protocol.notify_volume_changed(volume) + + async def on_get_play_status(self) -> None: + play_status = await self.avrcp_protocol.get_play_status() + self.send_message( + { + "type": "get-play-status-response", + "params": { + "song_length": play_status.song_length, + "song_position": play_status.song_position, + "play_status": play_status.play_status.name, + }, + } + ) + + async def on_get_element_attributes(self) -> None: + attributes = await self.avrcp_protocol.get_element_attributes( + 0, + [ + avrcp.MediaAttributeId.TITLE, + avrcp.MediaAttributeId.ARTIST_NAME, + avrcp.MediaAttributeId.ALBUM_NAME, + avrcp.MediaAttributeId.TRACK_NUMBER, + avrcp.MediaAttributeId.TOTAL_NUMBER_OF_TRACKS, + avrcp.MediaAttributeId.GENRE, + avrcp.MediaAttributeId.PLAYING_TIME, + avrcp.MediaAttributeId.DEFAULT_COVER_ART, + ], + ) + self.send_message( + { + "type": "get-element-attributes-response", + "params": [ + { + "attribute_id": attribute.attribute_id.name, + "attribute_value": attribute.attribute_value, + } + for attribute in attributes + ], + } + ) + + def send_message(self, message: dict) -> None: + if self.socket is None: + print("no socket, dropping message") + return + serialized = json.dumps(message) + utils.AsyncRunner.spawn(self.socket.send(serialized)) + + +# ----------------------------------------------------------------------------- +class Delegate(avrcp.Delegate): + def __init__(self): + super().__init__( + [avrcp.EventId.VOLUME_CHANGED, avrcp.EventId.PLAYBACK_STATUS_CHANGED] + ) + self.websocket_server = None + + async def set_absolute_volume(self, volume: int) -> None: + await super().set_absolute_volume(volume) + if self.websocket_server is not None: + self.websocket_server.send_message( + {"type": "set-volume", "params": {"volume": volume}} + ) + + +# ----------------------------------------------------------------------------- +async def main(): + if len(sys.argv) < 3: + print( + 'Usage: run_avrcp_controller.py ' + ' []' + ) + print('example: run_avrcp_controller.py classic1.json usb:0') + return + + print('<<< connecting to HCI...') + async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink): + print('<<< connected') + + # Create a device + device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink) + device.classic_enabled = True + + # Setup the SDP to expose the sink service + device.sdp_service_records = sdp_records() + + # Start the controller + await device.power_on() + + # Create a listener to wait for AVDTP connections + listener = avdtp.Listener(avdtp.Listener.create_registrar(device)) + listener.on('connection', on_avdtp_connection) + + avrcp_delegate = Delegate() + avrcp_protocol = avrcp.Protocol(avrcp_delegate) + avrcp_protocol.listen(device) + + websocket_server = WebSocketServer(avrcp_protocol, avrcp_delegate) + avrcp_delegate.websocket_server = websocket_server + avrcp_protocol.on( + "start", lambda: on_avrcp_start(avrcp_protocol, websocket_server) + ) + await websocket_server.start() + + if len(sys.argv) >= 5: + # Connect to the peer + target_address = sys.argv[4] + print(f'=== Connecting to {target_address}...') + connection = await device.connect( + target_address, transport=BT_BR_EDR_TRANSPORT + ) + print(f'=== Connected to {connection.peer_address}!') + + # Request authentication + print('*** Authenticating...') + await connection.authenticate() + print('*** Authenticated') + + # Enable encryption + print('*** Enabling encryption...') + await connection.encrypt() + print('*** Encryption on') + + server = await avdtp.Protocol.connect(connection) + listener.set_server(connection, server) + sink = server.add_sink(codec_capabilities()) + sink.on('rtp_packet', on_rtp_packet) + + await avrcp_protocol.connect(connection) + + else: + # Start being discoverable and connectable + await device.set_discoverable(True) + await device.set_connectable(True) + + await asyncio.get_event_loop().create_future() + + +# ----------------------------------------------------------------------------- +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +asyncio.run(main()) diff --git a/tests/avrcp_test.py b/tests/avrcp_test.py new file mode 100644 index 00000000..103f3608 --- /dev/null +++ b/tests/avrcp_test.py @@ -0,0 +1,246 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import struct + +import pytest + +from bumble import core +from bumble import device +from bumble import host +from bumble import controller +from bumble import link +from bumble import avc +from bumble import avrcp +from bumble import avctp +from bumble.transport import common + + +# ----------------------------------------------------------------------------- +class TwoDevices: + def __init__(self): + self.connections = [None, None] + + addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0'] + self.link = link.LocalLink() + self.controllers = [ + controller.Controller('C1', link=self.link, public_address=addresses[0]), + controller.Controller('C2', link=self.link, public_address=addresses[1]), + ] + self.devices = [ + device.Device( + address=addresses[0], + host=host.Host( + self.controllers[0], common.AsyncPipeSink(self.controllers[0]) + ), + ), + device.Device( + address=addresses[1], + host=host.Host( + self.controllers[1], common.AsyncPipeSink(self.controllers[1]) + ), + ), + ] + self.devices[0].classic_enabled = True + self.devices[1].classic_enabled = True + self.connections = [None, None] + self.protocols = [None, None] + + def on_connection(self, which, connection): + self.connections[which] = connection + + async def setup_connections(self): + await self.devices[0].power_on() + await self.devices[1].power_on() + + self.connections = await asyncio.gather( + self.devices[0].connect( + self.devices[1].public_address, core.BT_BR_EDR_TRANSPORT + ), + self.devices[1].accept(self.devices[0].public_address), + ) + + self.protocols = [avrcp.Protocol(), avrcp.Protocol()] + self.protocols[0].listen(self.devices[1]) + await self.protocols[1].connect(self.connections[0]) + + +# ----------------------------------------------------------------------------- +def test_frame_parser(): + with pytest.raises(ValueError) as error: + avc.Frame.from_bytes(bytes.fromhex("11480000")) + + x = bytes.fromhex("014D0208") + frame = avc.Frame.from_bytes(x) + assert frame.subunit_type == avc.Frame.SubunitType.PANEL + assert frame.subunit_id == 7 + assert frame.opcode == 8 + + x = bytes.fromhex("014DFF0108") + frame = avc.Frame.from_bytes(x) + assert frame.subunit_type == avc.Frame.SubunitType.PANEL + assert frame.subunit_id == 260 + assert frame.opcode == 8 + + x = bytes.fromhex("0148000019581000000103") + + frame = avc.Frame.from_bytes(x) + + assert isinstance(frame, avc.CommandFrame) + assert frame.ctype == avc.CommandFrame.CommandType.STATUS + assert frame.subunit_type == avc.Frame.SubunitType.PANEL + assert frame.subunit_id == 0 + assert frame.opcode == 0 + + +# ----------------------------------------------------------------------------- +def test_vendor_dependent_command(): + x = bytes.fromhex("0148000019581000000103") + frame = avc.Frame.from_bytes(x) + assert isinstance(frame, avc.VendorDependentCommandFrame) + assert frame.company_id == 0x1958 + assert frame.vendor_dependent_data == bytes.fromhex("1000000103") + + frame = avc.VendorDependentCommandFrame( + avc.CommandFrame.CommandType.STATUS, + avc.Frame.SubunitType.PANEL, + 0, + 0x1958, + bytes.fromhex("1000000103"), + ) + assert bytes(frame) == x + + +# ----------------------------------------------------------------------------- +def test_avctp_message_assembler(): + received_message = [] + + def on_message(transaction_label, is_response, ipid, pid, payload): + received_message.append((transaction_label, is_response, ipid, pid, payload)) + + assembler = avctp.MessageAssembler(on_message) + + payload = bytes.fromhex("01") + assembler.on_pdu(bytes([1 << 4 | 0b00 << 2 | 1 << 1 | 0, 0x11, 0x22]) + payload) + assert received_message + assert received_message[0] == (1, False, False, 0x1122, payload) + + received_message = [] + payload = bytes.fromhex("010203") + assembler.on_pdu(bytes([1 << 4 | 0b01 << 2 | 1 << 1 | 0, 0x11, 0x22]) + payload) + assert len(received_message) == 0 + assembler.on_pdu(bytes([1 << 4 | 0b00 << 2 | 1 << 1 | 0, 0x11, 0x22]) + payload) + assert received_message + assert received_message[0] == (1, False, False, 0x1122, payload) + + received_message = [] + payload = bytes.fromhex("010203") + assembler.on_pdu( + bytes([1 << 4 | 0b01 << 2 | 1 << 1 | 0, 3, 0x11, 0x22]) + payload[0:1] + ) + assembler.on_pdu( + bytes([1 << 4 | 0b10 << 2 | 1 << 1 | 0, 0x11, 0x22]) + payload[1:2] + ) + assembler.on_pdu( + bytes([1 << 4 | 0b11 << 2 | 1 << 1 | 0, 0x11, 0x22]) + payload[2:3] + ) + assert received_message + assert received_message[0] == (1, False, False, 0x1122, payload) + + # received_message = [] + # parameter = bytes.fromhex("010203") + # assembler.on_pdu(struct.pack(">BBH", 0x10, 0b11, len(parameter)) + parameter) + # assert len(received_message) == 0 + + +# ----------------------------------------------------------------------------- +def test_avrcp_pdu_assembler(): + received_pdus = [] + + def on_pdu(pdu_id, parameter): + received_pdus.append((pdu_id, parameter)) + + assembler = avrcp.PduAssembler(on_pdu) + + parameter = bytes.fromhex("01") + assembler.on_pdu(struct.pack(">BBH", 0x10, 0b00, len(parameter)) + parameter) + assert received_pdus + assert received_pdus[0] == (0x10, parameter) + + received_pdus = [] + parameter = bytes.fromhex("010203") + assembler.on_pdu(struct.pack(">BBH", 0x10, 0b01, len(parameter)) + parameter) + assert len(received_pdus) == 0 + assembler.on_pdu(struct.pack(">BBH", 0x10, 0b00, len(parameter)) + parameter) + assert received_pdus + assert received_pdus[0] == (0x10, parameter) + + received_pdus = [] + parameter = bytes.fromhex("010203") + assembler.on_pdu(struct.pack(">BBH", 0x10, 0b01, 1) + parameter[0:1]) + assembler.on_pdu(struct.pack(">BBH", 0x10, 0b10, 1) + parameter[1:2]) + assembler.on_pdu(struct.pack(">BBH", 0x10, 0b11, 1) + parameter[2:3]) + assert received_pdus + assert received_pdus[0] == (0x10, parameter) + + received_pdus = [] + parameter = bytes.fromhex("010203") + assembler.on_pdu(struct.pack(">BBH", 0x10, 0b11, len(parameter)) + parameter) + assert len(received_pdus) == 0 + + +def test_passthrough_commands(): + play_pressed = avc.PassThroughCommandFrame( + avc.CommandFrame.CommandType.CONTROL, + avc.CommandFrame.SubunitType.PANEL, + 0, + avc.PassThroughCommandFrame.StateFlag.PRESSED, + avc.PassThroughCommandFrame.OperationId.PLAY, + b'', + ) + + play_pressed_bytes = bytes(play_pressed) + parsed = avc.Frame.from_bytes(play_pressed_bytes) + assert isinstance(parsed, avc.PassThroughCommandFrame) + assert parsed.operation_id == avc.PassThroughCommandFrame.OperationId.PLAY + assert bytes(parsed) == play_pressed_bytes + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_get_supported_events(): + two_devices = TwoDevices() + await two_devices.setup_connections() + + supported_events = await two_devices.protocols[0].get_supported_events() + assert supported_events == [] + + delegate1 = avrcp.Delegate([avrcp.EventId.VOLUME_CHANGED]) + two_devices.protocols[0].delegate = delegate1 + supported_events = await two_devices.protocols[1].get_supported_events() + assert supported_events == [avrcp.EventId.VOLUME_CHANGED] + + +# ----------------------------------------------------------------------------- +if __name__ == '__main__': + test_frame_parser() + test_vendor_dependent_command() + test_avctp_message_assembler() + test_avrcp_pdu_assembler() + test_passthrough_commands() + test_get_supported_events() diff --git a/tests/utils_test.py b/tests/utils_test.py index d6f57807..6266f9ef 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -12,15 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- import contextlib import logging import os +from unittest.mock import MagicMock -from bumble import utils from pyee import EventEmitter -from unittest.mock import MagicMock + +from bumble import utils +# ----------------------------------------------------------------------------- def test_on() -> None: emitter = EventEmitter() with contextlib.closing(utils.EventWatcher()) as context: @@ -33,6 +38,7 @@ def test_on() -> None: assert mock.call_count == 1 +# ----------------------------------------------------------------------------- def test_on_decorator() -> None: emitter = EventEmitter() with contextlib.closing(utils.EventWatcher()) as context: @@ -48,6 +54,7 @@ def on_event(*_) -> None: assert mock.call_count == 1 +# ----------------------------------------------------------------------------- def test_multiple_handlers() -> None: emitter = EventEmitter() with contextlib.closing(utils.EventWatcher()) as context: @@ -64,6 +71,30 @@ def test_multiple_handlers() -> None: mock.assert_called_once_with('b') +# ----------------------------------------------------------------------------- +def test_open_int_enums(): + class Foo(utils.OpenIntEnum): + FOO = 1 + BAR = 2 + BLA = 3 + + x = Foo(1) + assert x.name == "FOO" + assert x.value == 1 + assert int(x) == 1 + assert x == 1 + assert x + 1 == 2 + + x = Foo(4) + assert x.name == "Foo[4]" + assert x.value == 4 + assert int(x) == 4 + assert x == 4 + assert x + 1 == 5 + + print(list(Foo)) + + # ----------------------------------------------------------------------------- def run_tests(): test_on() @@ -75,3 +106,4 @@ def run_tests(): if __name__ == '__main__': logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) run_tests() + test_open_int_enums()