From 9d3d5495ce9ec13f49dd559687b2fcd37868a1ff Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Sat, 23 Nov 2024 15:56:14 -0800 Subject: [PATCH] only use `__bytes__` when not argument is needed. --- apps/hci_bridge.py | 2 +- bumble/att.py | 5 +--- bumble/controller.py | 4 +-- bumble/core.py | 3 -- bumble/device.py | 2 +- bumble/gatt.py | 2 +- bumble/gatt_client.py | 6 ++-- bumble/gatt_server.py | 4 +-- bumble/hci.py | 40 +++++-------------------- bumble/host.py | 2 +- bumble/l2cap.py | 10 ++----- bumble/profiles/bap.py | 68 +++++++++++++++++++----------------------- bumble/sdp.py | 8 +---- bumble/smp.py | 7 ++--- tests/gatt_test.py | 4 +-- tests/hci_test.py | 8 ++--- tests/self_test.py | 2 +- 17 files changed, 62 insertions(+), 115 deletions(-) diff --git a/apps/hci_bridge.py b/apps/hci_bridge.py index 1d1f9a19..00093a0f 100644 --- a/apps/hci_bridge.py +++ b/apps/hci_bridge.py @@ -83,7 +83,7 @@ def host_to_controller_filter(hci_packet): return_parameters=bytes([hci.HCI_SUCCESS]), ) # Return a packet with 'respond to sender' set to True - return (response.to_bytes(), True) + return (bytes(response), True) return None diff --git a/bumble/att.py b/bumble/att.py index 15ad8c69..741c413d 100644 --- a/bumble/att.py +++ b/bumble/att.py @@ -291,9 +291,6 @@ def __init__(self, pdu=None, **kwargs): def init_from_bytes(self, pdu, offset): return HCI_Object.init_from_bytes(self, pdu, offset, self.fields) - def to_bytes(self): - return self.pdu - @property def is_command(self): return ((self.op_code >> 6) & 1) == 1 @@ -303,7 +300,7 @@ def has_authentication_signature(self): return ((self.op_code >> 7) & 1) == 1 def __bytes__(self): - return self.to_bytes() + return self.pdu def __str__(self): result = color(self.name, 'yellow') diff --git a/bumble/controller.py b/bumble/controller.py index 267f3e55..03d3c14d 100644 --- a/bumble/controller.py +++ b/bumble/controller.py @@ -314,7 +314,7 @@ def send_hci_packet(self, packet): f'{color("CONTROLLER -> HOST", "green")}: {packet}' ) if self.host: - self.host.on_packet(packet.to_bytes()) + self.host.on_packet(bytes(packet)) # This method allows the controller to emulate the same API as a transport source async def wait_for_termination(self): @@ -1192,7 +1192,7 @@ def on_hci_read_bd_addr_command(self, _command): See Bluetooth spec Vol 4, Part E - 7.4.6 Read BD_ADDR Command ''' bd_addr = ( - self._public_address.to_bytes() + bytes(self._public_address) if self._public_address is not None else bytes(6) ) diff --git a/bumble/core.py b/bumble/core.py index 5aec826d..f6d42dd5 100644 --- a/bumble/core.py +++ b/bumble/core.py @@ -1624,9 +1624,6 @@ def __bytes__(self): [bytes([len(x[1]) + 1, x[0]]) + x[1] for x in self.ad_structures] ) - def to_bytes(self) -> bytes: - return bytes(self) - def to_string(self, separator=', '): return separator.join( [AdvertisingData.ad_data_to_string(x[0], x[1]) for x in self.ad_structures] diff --git a/bumble/device.py b/bumble/device.py index 908e8a4f..866ef166 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -1986,7 +1986,7 @@ def find_connection_by_bd_addr( check_address_type: bool = False, ) -> Optional[Connection]: for connection in self.connections.values(): - if connection.peer_address.to_bytes() == bd_addr.to_bytes(): + if bytes(connection.peer_address) == bytes(bd_addr): if ( check_address_type and connection.peer_address.address_type != bd_addr.address_type diff --git a/bumble/gatt.py b/bumble/gatt.py index ea65116d..86450fb1 100644 --- a/bumble/gatt.py +++ b/bumble/gatt.py @@ -410,7 +410,7 @@ class IncludedServiceDeclaration(Attribute): def __init__(self, service: Service) -> None: declaration_bytes = struct.pack( - ' None: logger.debug( f'GATT Command from client: [0x{self.connection.handle:04X}] {command}' ) - self.send_gatt_pdu(command.to_bytes()) + self.send_gatt_pdu(bytes(command)) async def send_request(self, request: ATT_PDU): logger.debug( @@ -310,7 +310,7 @@ async def send_request(self, request: ATT_PDU): self.pending_request = request try: - self.send_gatt_pdu(request.to_bytes()) + self.send_gatt_pdu(bytes(request)) response = await asyncio.wait_for( self.pending_response, GATT_REQUEST_TIMEOUT ) @@ -328,7 +328,7 @@ def send_confirmation(self, confirmation: ATT_Handle_Value_Confirmation) -> None f'GATT Confirmation from client: [0x{self.connection.handle:04X}] ' f'{confirmation}' ) - self.send_gatt_pdu(confirmation.to_bytes()) + self.send_gatt_pdu(bytes(confirmation)) async def request_mtu(self, mtu: int) -> int: # Check the range diff --git a/bumble/gatt_server.py b/bumble/gatt_server.py index 0ee673c0..15a2a827 100644 --- a/bumble/gatt_server.py +++ b/bumble/gatt_server.py @@ -353,7 +353,7 @@ def send_response(self, connection: Connection, response: ATT_PDU) -> None: logger.debug( f'GATT Response from server: [0x{connection.handle:04X}] {response}' ) - self.send_gatt_pdu(connection.handle, response.to_bytes()) + self.send_gatt_pdu(connection.handle, bytes(response)) async def notify_subscriber( self, @@ -450,7 +450,7 @@ async def indicate_subscriber( ) try: - self.send_gatt_pdu(connection.handle, indication.to_bytes()) + self.send_gatt_pdu(connection.handle, bytes(indication)) await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT) except asyncio.TimeoutError as error: logger.warning(color('!!! GATT Indicate timeout', 'red')) diff --git a/bumble/hci.py b/bumble/hci.py index ce26596e..1f1aa2ab 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -1496,14 +1496,11 @@ def parse_from_bytes(cls, data: bytes, offset: int) -> tuple[int, CodingFormat]: def from_bytes(cls, data: bytes) -> CodingFormat: return cls.parse_from_bytes(data, 0)[1] - def to_bytes(self) -> bytes: + def __bytes__(self) -> bytes: return struct.pack( ' bytes: - return self.to_bytes() - # ----------------------------------------------------------------------------- class HCI_Constant: @@ -1720,7 +1717,7 @@ def serialize_field(field_value, field_type): field_length = len(field_bytes) field_bytes = bytes([field_length]) + field_bytes elif isinstance(field_value, (bytes, bytearray)) or hasattr( - field_value, 'to_bytes' + field_value, '__bytes__' ): field_bytes = bytes(field_value) if isinstance(field_type, int) and 4 < field_type <= 256: @@ -1765,7 +1762,7 @@ def dict_to_bytes(hci_object, fields): def from_bytes(cls, data, offset, fields): return cls(fields, **cls.dict_from_bytes(data, offset, fields)) - def to_bytes(self): + def __bytes__(self): return HCI_Object.dict_to_bytes(self.__dict__, self.fields) @staticmethod @@ -1860,9 +1857,6 @@ def format_fields(hci_object, fields, indentation='', value_mappers=None): for field_name, field_value in field_strings ) - def __bytes__(self): - return self.to_bytes() - def __init__(self, fields, **kwargs): self.fields = fields self.init_from_fields(self, fields, kwargs) @@ -2037,9 +2031,6 @@ def is_resolvable(self): def is_static(self): return self.is_random and (self.address_bytes[5] >> 6 == 3) - def to_bytes(self): - return self.address_bytes - def to_string(self, with_type_qualifier=True): ''' String representation of the address, MSB first, with an optional type @@ -2051,7 +2042,7 @@ def to_string(self, with_type_qualifier=True): return result + '/P' def __bytes__(self): - return self.to_bytes() + return self.address_bytes def __hash__(self): return hash(self.address_bytes) @@ -2257,16 +2248,13 @@ def __init__(self, op_code=-1, parameters=None, **kwargs): self.op_code = op_code self.parameters = parameters - def to_bytes(self): + def __bytes__(self): parameters = b'' if self.parameters is None else self.parameters return ( struct.pack(' HCI_AclDataPacket: connection_handle, pb_flag, bc_flag, data_total_length, data ) - def to_bytes(self): + def __bytes__(self): h = (self.pb_flag << 12) | (self.bc_flag << 14) | self.connection_handle return ( struct.pack(' HCI_SynchronousDataPacket: connection_handle, packet_status, data_total_length, data ) - def to_bytes(self) -> bytes: + def __bytes__(self) -> bytes: h = (self.packet_status << 12) | self.connection_handle return ( struct.pack(' bytes: - return self.to_bytes() - def __str__(self) -> str: return ( f'{color("SCO", "blue")}: ' @@ -6891,9 +6870,6 @@ def from_bytes(packet: bytes) -> HCI_IsoDataPacket: ) def __bytes__(self) -> bytes: - return self.to_bytes() - - def to_bytes(self) -> bytes: fmt = ' Optional[Connection]: for connection in self.connections.values(): - if connection.peer_address.to_bytes() == bd_addr.to_bytes(): + if bytes(connection.peer_address) == bytes(bd_addr): if ( check_address_type and connection.peer_address.address_type != bd_addr.address_type diff --git a/bumble/l2cap.py b/bumble/l2cap.py index 75e554c8..a7a944d0 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -225,7 +225,7 @@ def from_bytes(data: bytes) -> L2CAP_PDU: return L2CAP_PDU(l2cap_pdu_cid, l2cap_pdu_payload) - def to_bytes(self) -> bytes: + def __bytes__(self) -> bytes: header = struct.pack(' None: self.cid = cid self.payload = payload - def __bytes__(self) -> bytes: - return self.to_bytes() - def __str__(self) -> str: return f'{color("L2CAP", "green")} [CID={self.cid}]: {self.payload.hex()}' @@ -333,11 +330,8 @@ def __init__(self, pdu=None, **kwargs) -> None: def init_from_bytes(self, pdu, offset): return HCI_Object.init_from_bytes(self, pdu, offset, self.fields) - def to_bytes(self) -> bytes: - return self.pdu - def __bytes__(self) -> bytes: - return self.to_bytes() + return self.pdu def __str__(self) -> str: result = f'{color(self.name, "yellow")} [ID={self.identifier}]' diff --git a/bumble/profiles/bap.py b/bumble/profiles/bap.py index 7bae6578..fb1afa64 100644 --- a/bumble/profiles/bap.py +++ b/bumble/profiles/bap.py @@ -265,7 +265,7 @@ def __bytes__(self) -> bytes: core.AdvertisingData.SERVICE_DATA_16_BIT_UUID, struct.pack( '<2sBIB', - gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE.to_bytes(), + bytes(gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE), self.announcement_type, self.available_audio_contexts, len(self.metadata), @@ -487,24 +487,23 @@ class BroadcastAudioAnnouncement: def from_bytes(cls, data: bytes) -> Self: return cls(int.from_bytes(data[:3], 'little')) - def to_bytes(self) -> bytes: - return self.broadcast_id.to_bytes(3, 'little') - def __bytes__(self) -> bytes: - return self.to_bytes() + return self.broadcast_id.to_bytes(3, 'little') def get_advertising_data(self) -> bytes: - return core.AdvertisingData( - [ - ( - core.AdvertisingData.SERVICE_DATA_16_BIT_UUID, + return bytes( + core.AdvertisingData( + [ ( - gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE.to_bytes() - + self.to_bytes() - ), - ) - ] - ).to_bytes() + core.AdvertisingData.SERVICE_DATA_16_BIT_UUID, + ( + bytes(gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE) + + bytes(self) + ), + ) + ] + ) + ) @dataclasses.dataclass @@ -514,7 +513,7 @@ class BIS: index: int codec_specific_configuration: CodecSpecificConfiguration - def to_bytes(self) -> bytes: + def __bytes__(self) -> bytes: codec_specific_configuration_bytes = bytes( self.codec_specific_configuration ) @@ -523,9 +522,6 @@ def to_bytes(self) -> bytes: + codec_specific_configuration_bytes ) - def __bytes__(self) -> bytes: - return self.to_bytes() - @dataclasses.dataclass class Subgroup: codec_id: hci.CodingFormat @@ -533,14 +529,14 @@ class Subgroup: metadata: le_audio.Metadata bis: List[BasicAudioAnnouncement.BIS] - def to_bytes(self) -> bytes: + def __bytes__(self) -> bytes: metadata_bytes = bytes(self.metadata) codec_specific_configuration_bytes = bytes( self.codec_specific_configuration ) return ( bytes([len(self.bis)]) - + self.codec_id.to_bytes() + + bytes(self.codec_id) + bytes([len(codec_specific_configuration_bytes)]) + codec_specific_configuration_bytes + bytes([len(metadata_bytes)]) @@ -548,9 +544,6 @@ def to_bytes(self) -> bytes: + b''.join(map(bytes, self.bis)) ) - def __bytes__(self) -> bytes: - return self.to_bytes() - presentation_delay: int subgroups: List[BasicAudioAnnouncement.Subgroup] @@ -607,25 +600,24 @@ def from_bytes(cls, data: bytes) -> Self: return cls(presentation_delay, subgroups) - def to_bytes(self) -> bytes: + def __bytes__(self) -> bytes: return ( self.presentation_delay.to_bytes(3, 'little') + bytes([len(self.subgroups)]) + b''.join(map(bytes, self.subgroups)) ) - def __bytes__(self) -> bytes: - return self.to_bytes() - def get_advertising_data(self) -> bytes: - return core.AdvertisingData( - [ - ( - core.AdvertisingData.SERVICE_DATA_16_BIT_UUID, + return bytes( + core.AdvertisingData( + [ ( - gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE.to_bytes() - + self.to_bytes() - ), - ) - ] - ).to_bytes() + core.AdvertisingData.SERVICE_DATA_16_BIT_UUID, + ( + bytes(gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE) + + bytes(self) + ), + ) + ] + ) + ) diff --git a/bumble/sdp.py b/bumble/sdp.py index 90ac07c7..826bd598 100644 --- a/bumble/sdp.py +++ b/bumble/sdp.py @@ -344,9 +344,6 @@ def from_bytes(data): ] # Keep a copy so we can re-serialize to an exact replica return result - def to_bytes(self): - return bytes(self) - def __bytes__(self): # Return early if we have a cache if self.bytes: @@ -623,11 +620,8 @@ def __init__(self, pdu=None, transaction_id=0, **kwargs): def init_from_bytes(self, pdu, offset): return HCI_Object.init_from_bytes(self, pdu, offset, self.fields) - def to_bytes(self): - return self.pdu - def __bytes__(self): - return self.to_bytes() + return self.pdu def __str__(self): result = f'{color(self.name, "blue")} [TID={self.transaction_id}]' diff --git a/bumble/smp.py b/bumble/smp.py index 37031048..5d1d4873 100644 --- a/bumble/smp.py +++ b/bumble/smp.py @@ -298,11 +298,8 @@ def __init__(self, pdu: Optional[bytes] = None, **kwargs: Any) -> None: def init_from_bytes(self, pdu: bytes, offset: int) -> None: return HCI_Object.init_from_bytes(self, pdu, offset, self.fields) - def to_bytes(self): - return self.pdu - def __bytes__(self): - return self.to_bytes() + return self.pdu def __str__(self): result = color(self.name, 'yellow') @@ -1949,7 +1946,7 @@ def send_command(self, connection: Connection, command: SMP_Command) -> None: f'{connection.peer_address}: {command}' ) cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID - connection.send_l2cap_pdu(cid, command.to_bytes()) + connection.send_l2cap_pdu(cid, bytes(command)) def on_smp_security_request_command( self, connection: Connection, request: SMP_Security_Request_Command diff --git a/tests/gatt_test.py b/tests/gatt_test.py index 8d73eb3f..bd296517 100644 --- a/tests/gatt_test.py +++ b/tests/gatt_test.py @@ -57,7 +57,7 @@ # ----------------------------------------------------------------------------- def basic_check(x): - pdu = x.to_bytes() + pdu = bytes(x) parsed = ATT_PDU.from_bytes(pdu) x_str = str(x) parsed_str = str(parsed) @@ -74,7 +74,7 @@ def test_UUID(): assert str(u) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6' v = UUID(str(u)) assert str(v) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6' - w = UUID.from_bytes(v.to_bytes()) + w = UUID.from_bytes(bytes(v)) assert str(w) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6' u1 = UUID.from_16_bits(0x1234) diff --git a/tests/hci_test.py b/tests/hci_test.py index 1b69cda3..30ab0d73 100644 --- a/tests/hci_test.py +++ b/tests/hci_test.py @@ -75,13 +75,13 @@ def basic_check(x): - packet = x.to_bytes() + packet = bytes(x) print(packet.hex()) parsed = HCI_Packet.from_bytes(packet) x_str = str(x) parsed_str = str(parsed) print(x_str) - parsed_bytes = parsed.to_bytes() + parsed_bytes = bytes(parsed) assert x_str == parsed_str assert packet == parsed_bytes @@ -188,7 +188,7 @@ def test_HCI_Command_Complete_Event(): return_parameters=bytes([7]), ) basic_check(event) - event = HCI_Packet.from_bytes(event.to_bytes()) + event = HCI_Packet.from_bytes(bytes(event)) assert event.return_parameters == 7 # With a simple status as an integer status @@ -562,7 +562,7 @@ def test_iso_data_packet(): '6281bc77ed6a3206d984bcdabee6be831c699cb50e2' ) - assert packet.to_bytes() == data + assert bytes(packet) == data # ----------------------------------------------------------------------------- diff --git a/tests/self_test.py b/tests/self_test.py index 5c68ea06..6654a546 100644 --- a/tests/self_test.py +++ b/tests/self_test.py @@ -240,7 +240,7 @@ async def test_self_gatt(): result = await peer.discover_included_services(result[0]) assert len(result) == 2 # Service UUID is only present when the UUID is 16-bit Bluetooth UUID - assert result[1].uuid.to_bytes() == s3.uuid.to_bytes() + assert bytes(result[1].uuid) == bytes(s3.uuid) # -----------------------------------------------------------------------------