diff --git a/bumble/device.py b/bumble/device.py index bcf256fa..fb5f7258 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -528,11 +528,12 @@ def __int__(self) -> int: return int(properties) - @staticmethod + @classmethod def from_advertising_type( + cls: Type[AdvertisingEventProperties], advertising_type: AdvertisingType, ) -> AdvertisingEventProperties: - return AdvertisingEventProperties( + return cls( is_connectable=advertising_type.is_connectable, is_scannable=advertising_type.is_scannable, is_directed=advertising_type.is_directed, @@ -711,6 +712,16 @@ async def set_random_address(self, random_address: Address) -> None: async def start( self, duration: float = 0.0, max_advertising_events: int = 0 ) -> None: + """ + Start advertising. + + Args: + duration: How long to advertise for, in seconds. Use 0 (the default) for + an unlimited duration, unless this advertising set is a High Duty Cycle + Directed Advertisement type. + max_advertising_events: Maximum number of events to advertise for. Use 0 + (the default) for an unlimited number of advertisements. + """ await self.device.send_command( HCI_LE_Set_Extended_Advertising_Enable_Command( enable=1, @@ -2154,11 +2165,11 @@ async def create_advertising_set( if periodic_advertising_parameters: # TODO: call LE Set Periodic Advertising Parameters command - pass + raise NotImplementedError('periodic advertising not yet supported') if periodic_advertising_data: # TODO: call LE Set Periodic Advertising Data command - pass + raise NotImplementedError('periodic advertising not yet supported') except HCI_Error as error: # Remove the advertising set so that it doesn't stay dangling in the @@ -2199,11 +2210,10 @@ def is_advertising(self): if self.legacy_advertising_set and self.legacy_advertising_set.enabled: return True - for advertising_set in self.extended_advertising_sets.values(): - if advertising_set.enabled: - return True - - return False + return any( + advertising_set.enabled + for advertising_set in self.extended_advertising_sets.values() + ) async def start_scanning( self, @@ -3532,7 +3542,10 @@ def on_advertising_set_termination( number_of_completed_extended_advertising_events, ): if not ( - advertising_set := self.extended_advertising_sets.get(advertising_handle) + advertising_set := ( + self.extended_advertising_sets.get(advertising_handle) + or self.legacy_advertising_set + ) ): logger.warning(f'advertising set {advertising_handle} not found') return @@ -3565,9 +3578,9 @@ def on_advertising_set_termination( lambda _: self.abort_on('flush', advertising_set.start()), ) - self.emit_le_connection(connection) + self._emit_le_connection(connection) - def emit_le_connection(self, connection: Connection) -> None: + def _emit_le_connection(self, connection: Connection) -> None: # If supported, read which PHY we're connected with before # notifying listeners of the new connection. if self.host.supports_command(HCI_LE_READ_PHY_COMMAND): @@ -3642,6 +3655,7 @@ def on_connection( # We were connected via a legacy advertisement. if self.legacy_advertiser: own_address_type = self.legacy_advertiser.own_address_type + self.legacy_advertiser = None else: # This should not happen, but just in case, pick a default. logger.warning("connection without an advertiser") @@ -3684,7 +3698,7 @@ def on_connection( if role == HCI_CENTRAL_ROLE or not self.supports_le_extended_advertising: # We can emit now, we have all the info we need - self.emit_le_connection(connection) + self._emit_le_connection(connection) @host_event_handler def on_connection_failure(self, transport, peer_address, error_code): diff --git a/bumble/host.py b/bumble/host.py index d0bac9c3..52a9a73a 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -28,145 +28,15 @@ from bumble.l2cap import L2CAP_PDU from bumble.snoop import Snooper from bumble import drivers - -from .hci import ( - Address, - HCI_ACL_DATA_PACKET, - HCI_COMMAND_PACKET, - HCI_EVENT_PACKET, - HCI_ISO_DATA_PACKET, - HCI_LE_READ_BUFFER_SIZE_COMMAND, - HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND, - HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND, - HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND, - HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND, - HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND, - HCI_READ_BUFFER_SIZE_COMMAND, - HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND, - HCI_RESET_COMMAND, - HCI_SUCCESS, - HCI_SUPPORTED_COMMANDS_FLAGS, - HCI_SYNCHRONOUS_DATA_PACKET, - HCI_VERSION_BLUETOOTH_CORE_4_0, - HCI_INQUIRY_COMPLETE_EVENT, - HCI_INQUIRY_RESULT_EVENT, - HCI_CONNECTION_COMPLETE_EVENT, - HCI_CONNECTION_REQUEST_EVENT, - HCI_DISCONNECTION_COMPLETE_EVENT, - HCI_AUTHENTICATION_COMPLETE_EVENT, - HCI_REMOTE_NAME_REQUEST_COMPLETE_EVENT, - HCI_ENCRYPTION_CHANGE_EVENT, - HCI_CHANGE_CONNECTION_LINK_KEY_COMPLETE_EVENT, - HCI_LINK_KEY_TYPE_CHANGED_EVENT, - HCI_READ_REMOTE_SUPPORTED_FEATURES_COMPLETE_EVENT, - HCI_READ_REMOTE_VERSION_INFORMATION_COMPLETE_EVENT, - HCI_QOS_SETUP_COMPLETE_EVENT, - HCI_HARDWARE_ERROR_EVENT, - HCI_FLUSH_OCCURRED_EVENT, - HCI_ROLE_CHANGE_EVENT, - HCI_MODE_CHANGE_EVENT, - HCI_RETURN_LINK_KEYS_EVENT, - HCI_PIN_CODE_REQUEST_EVENT, - HCI_LINK_KEY_REQUEST_EVENT, - HCI_LINK_KEY_NOTIFICATION_EVENT, - HCI_LOOPBACK_COMMAND_EVENT, - HCI_DATA_BUFFER_OVERFLOW_EVENT, - HCI_MAX_SLOTS_CHANGE_EVENT, - HCI_READ_CLOCK_OFFSET_COMPLETE_EVENT, - HCI_CONNECTION_PACKET_TYPE_CHANGED_EVENT, - HCI_QOS_VIOLATION_EVENT, - HCI_PAGE_SCAN_REPETITION_MODE_CHANGE_EVENT, - HCI_FLOW_SPECIFICATION_COMPLETE_EVENT, - HCI_INQUIRY_RESULT_WITH_RSSI_EVENT, - HCI_READ_REMOTE_EXTENDED_FEATURES_COMPLETE_EVENT, - HCI_SYNCHRONOUS_CONNECTION_COMPLETE_EVENT, - HCI_SYNCHRONOUS_CONNECTION_CHANGED_EVENT, - HCI_SNIFF_SUBRATING_EVENT, - HCI_EXTENDED_INQUIRY_RESULT_EVENT, - HCI_ENCRYPTION_KEY_REFRESH_COMPLETE_EVENT, - HCI_IO_CAPABILITY_REQUEST_EVENT, - HCI_IO_CAPABILITY_RESPONSE_EVENT, - HCI_USER_CONFIRMATION_REQUEST_EVENT, - HCI_USER_PASSKEY_REQUEST_EVENT, - HCI_REMOTE_OOB_DATA_REQUEST_EVENT, - HCI_SIMPLE_PAIRING_COMPLETE_EVENT, - HCI_LINK_SUPERVISION_TIMEOUT_CHANGED_EVENT, - HCI_ENHANCED_FLUSH_COMPLETE_EVENT, - HCI_USER_PASSKEY_NOTIFICATION_EVENT, - HCI_KEYPRESS_NOTIFICATION_EVENT, - HCI_REMOTE_HOST_SUPPORTED_FEATURES_NOTIFICATION_EVENT, - HCI_LE_META_EVENT, - HCI_LE_CONNECTION_COMPLETE_EVENT, - HCI_LE_ADVERTISING_REPORT_EVENT, - HCI_LE_CONNECTION_UPDATE_COMPLETE_EVENT, - HCI_LE_READ_REMOTE_FEATURES_COMPLETE_EVENT, - HCI_LE_LONG_TERM_KEY_REQUEST_EVENT, - HCI_LE_REMOTE_CONNECTION_PARAMETER_REQUEST_EVENT, - HCI_LE_DATA_LENGTH_CHANGE_EVENT, - HCI_LE_READ_LOCAL_P_256_PUBLIC_KEY_COMPLETE_EVENT, - HCI_LE_GENERATE_DHKEY_COMPLETE_EVENT, - HCI_LE_ENHANCED_CONNECTION_COMPLETE_EVENT, - HCI_LE_DIRECTED_ADVERTISING_REPORT_EVENT, - HCI_LE_PHY_UPDATE_COMPLETE_EVENT, - HCI_LE_EXTENDED_ADVERTISING_REPORT_EVENT, - HCI_LE_PERIODIC_ADVERTISING_SYNC_ESTABLISHED_EVENT, - HCI_LE_PERIODIC_ADVERTISING_REPORT_EVENT, - HCI_LE_PERIODIC_ADVERTISING_SYNC_LOST_EVENT, - HCI_LE_SCAN_TIMEOUT_EVENT, - HCI_LE_ADVERTISING_SET_TERMINATED_EVENT, - HCI_LE_SCAN_REQUEST_RECEIVED_EVENT, - HCI_LE_CONNECTIONLESS_IQ_REPORT_EVENT, - HCI_LE_CONNECTION_IQ_REPORT_EVENT, - HCI_LE_CTE_REQUEST_FAILED_EVENT, - HCI_LE_PERIODIC_ADVERTISING_SYNC_TRANSFER_RECEIVED_EVENT, - HCI_LE_CIS_ESTABLISHED_EVENT, - HCI_LE_CIS_REQUEST_EVENT, - HCI_LE_CREATE_BIG_COMPLETE_EVENT, - HCI_LE_TERMINATE_BIG_COMPLETE_EVENT, - HCI_LE_BIG_SYNC_ESTABLISHED_EVENT, - HCI_LE_BIG_SYNC_LOST_EVENT, - HCI_LE_REQUEST_PEER_SCA_COMPLETE_EVENT, - HCI_LE_PATH_LOSS_THRESHOLD_EVENT, - HCI_LE_TRANSMIT_POWER_REPORTING_EVENT, - HCI_LE_BIGINFO_ADVERTISING_REPORT_EVENT, - HCI_LE_SUBRATE_CHANGE_EVENT, - HCI_AclDataPacket, - HCI_AclDataPacketAssembler, - HCI_Command, - HCI_Command_Complete_Event, - HCI_Constant, - HCI_Error, - HCI_Event, - HCI_IsoDataPacket, - HCI_LE_Long_Term_Key_Request_Negative_Reply_Command, - HCI_LE_Long_Term_Key_Request_Reply_Command, - HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command, - HCI_LE_Read_Buffer_Size_Command, - HCI_LE_Read_Local_Supported_Features_Command, - HCI_LE_Read_Suggested_Default_Data_Length_Command, - HCI_LE_Remote_Connection_Parameter_Request_Reply_Command, - HCI_LE_Set_Event_Mask_Command, - HCI_LE_Write_Suggested_Default_Data_Length_Command, - HCI_LE_Read_Maximum_Advertising_Data_Length_Command, - HCI_Link_Key_Request_Negative_Reply_Command, - HCI_Link_Key_Request_Reply_Command, - HCI_Packet, - HCI_Read_Buffer_Size_Command, - HCI_Read_Local_Supported_Commands_Command, - HCI_Read_Local_Version_Information_Command, - HCI_Reset_Command, - HCI_Set_Event_Mask_Command, - HCI_SynchronousDataPacket, - LeFeatureMask, -) -from .core import ( +from bumble import hci +from bumble.core import ( BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT, ConnectionPHY, ConnectionParameters, ) -from .utils import AbortableEventEmitter -from .transport.common import TransportLostError +from bumble.utils import AbortableEventEmitter +from bumble.transport.common import TransportLostError if TYPE_CHECKING: from .transport.common import TransportSink, TransportSource @@ -186,15 +56,15 @@ def __init__( self, max_packet_size: int, max_in_flight: int, - send: Callable[[HCI_Packet], None], + send: Callable[[hci.HCI_Packet], None], ) -> None: self.max_packet_size = max_packet_size self.max_in_flight = max_in_flight self.in_flight = 0 self.send = send - self.packets: Deque[HCI_AclDataPacket] = collections.deque() + self.packets: Deque[hci.HCI_AclDataPacket] = collections.deque() - def enqueue(self, packet: HCI_AclDataPacket) -> None: + def enqueue(self, packet: hci.HCI_AclDataPacket) -> None: self.packets.appendleft(packet) self.check_queue() @@ -226,11 +96,13 @@ def on_packets_completed(self, packet_count: int) -> None: # ----------------------------------------------------------------------------- class Connection: - def __init__(self, host: Host, handle: int, peer_address: Address, transport: int): + def __init__( + self, host: Host, handle: int, peer_address: hci.Address, transport: int + ): self.host = host self.handle = handle self.peer_address = peer_address - self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) + self.assembler = hci.HCI_AclDataPacketAssembler(self.on_acl_pdu) self.transport = transport acl_packet_queue: Optional[AclPacketQueue] = ( host.le_acl_packet_queue @@ -240,7 +112,7 @@ def __init__(self, host: Host, handle: int, peer_address: Address, transport: in assert acl_packet_queue self.acl_packet_queue = acl_packet_queue - def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None: + def on_hci_acl_data_packet(self, packet: hci.HCI_AclDataPacket) -> None: self.assembler.feed_packet(packet) def on_acl_pdu(self, pdu: bytes) -> None: @@ -251,14 +123,14 @@ def on_acl_pdu(self, pdu: bytes) -> None: # ----------------------------------------------------------------------------- @dataclasses.dataclass class ScoLink: - peer_address: Address + peer_address: hci.Address handle: int # ----------------------------------------------------------------------------- @dataclasses.dataclass class CisLink: - peer_address: Address + peer_address: hci.Address handle: int @@ -274,7 +146,7 @@ class Host(AbortableEventEmitter): long_term_key_provider: Optional[ Callable[[int, bytes, int], Awaitable[Optional[bytes]]] ] - link_key_provider: Optional[Callable[[Address], Awaitable[Optional[bytes]]]] + link_key_provider: Optional[Callable[[hci.Address], Awaitable[Optional[bytes]]]] def __init__( self, @@ -311,7 +183,7 @@ def __init__( def find_connection_by_bd_addr( self, - bd_addr: Address, + bd_addr: hci.Address, transport: Optional[int] = None, check_address_type: bool = False, ) -> Optional[Connection]: @@ -353,80 +225,80 @@ async def reset(self, driver_factory=drivers.get_driver_for_host): # Send a reset command unless a driver has already done so. if reset_needed: - await self.send_command(HCI_Reset_Command(), check_result=True) + await self.send_command(hci.HCI_Reset_Command(), check_result=True) self.ready = True response = await self.send_command( - HCI_Read_Local_Supported_Commands_Command(), check_result=True + hci.HCI_Read_Local_Supported_Commands_Command(), check_result=True ) self.local_supported_commands = response.return_parameters.supported_commands - if self.supports_command(HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND): + if self.supports_command(hci.HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND): response = await self.send_command( - HCI_LE_Read_Local_Supported_Features_Command(), check_result=True + hci.HCI_LE_Read_Local_Supported_Features_Command(), check_result=True ) self.local_le_features = struct.unpack( ' None: source.set_packet_sink(self) self.hci_metadata = getattr(source, 'metadata', self.hci_metadata) - def send_hci_packet(self, packet: HCI_Packet) -> None: + def send_hci_packet(self, packet: hci.HCI_Packet) -> None: logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {packet}') if self.snooper: self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER) @@ -622,11 +499,12 @@ async def send_command(self, command, check_result=False): else: status = response.return_parameters.status - if status != HCI_SUCCESS: + if status != hci.HCI_SUCCESS: logger.warning( - f'{command.name} failed ({HCI_Constant.error_name(status)})' + f'{command.name} failed ' + f'({hci.HCI_Constant.error_name(status)})' ) - raise HCI_Error(status) + raise hci.HCI_Error(status) return response except Exception as error: @@ -639,8 +517,8 @@ async def send_command(self, command, check_result=False): self.pending_response = None # Use this method to send a command from a task - def send_command_sync(self, command: HCI_Command) -> None: - async def send_command(command: HCI_Command) -> None: + def send_command_sync(self, command: hci.HCI_Command) -> None: + async def send_command(command: hci.HCI_Command) -> None: await self.send_command(command) asyncio.create_task(send_command(command)) @@ -665,7 +543,7 @@ def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None: pb_flag = 0 while bytes_remaining: data_total_length = min(bytes_remaining, packet_queue.max_packet_size) - acl_packet = HCI_AclDataPacket( + acl_packet = hci.HCI_AclDataPacket( connection_handle=connection_handle, pb_flag=pb_flag, bc_flag=0, @@ -680,7 +558,7 @@ def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None: def supports_command(self, command): # Find the support flag position for this command - for octet, flags in enumerate(HCI_SUPPORTED_COMMANDS_FLAGS): + for octet, flags in enumerate(hci.HCI_SUPPORTED_COMMANDS_FLAGS): for flag_position, value in enumerate(flags): if value == command: # Check if the flag is set @@ -695,16 +573,16 @@ def supports_command(self, command): def supported_commands(self): commands = [] for octet, flags in enumerate(self.local_supported_commands): - if octet < len(HCI_SUPPORTED_COMMANDS_FLAGS): + if octet < len(hci.HCI_SUPPORTED_COMMANDS_FLAGS): for flag in range(8): if flags & (1 << flag) != 0: - command = HCI_SUPPORTED_COMMANDS_FLAGS[octet][flag] + command = hci.HCI_SUPPORTED_COMMANDS_FLAGS[octet][flag] if command is not None: commands.append(command) return commands - def supports_le_features(self, feature: LeFeatureMask) -> bool: + def supports_le_features(self, feature: hci.LeFeatureMask) -> bool: return (self.local_le_features & feature) == feature @property @@ -715,10 +593,10 @@ def supported_le_features(self): # Packet Sink protocol (packets coming from the controller via HCI) def on_packet(self, packet: bytes) -> None: - hci_packet = HCI_Packet.from_bytes(packet) + hci_packet = hci.HCI_Packet.from_bytes(packet) if self.ready or ( - isinstance(hci_packet, HCI_Command_Complete_Event) - and hci_packet.command_opcode == HCI_RESET_COMMAND + isinstance(hci_packet, hci.HCI_Command_Complete_Event) + and hci_packet.command_opcode == hci.HCI_RESET_COMMAND ): self.on_hci_packet(hci_packet) else: @@ -731,44 +609,44 @@ def on_transport_lost(self): self.emit('flush') - def on_hci_packet(self, packet: HCI_Packet) -> None: + def on_hci_packet(self, packet: hci.HCI_Packet) -> None: logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}') if self.snooper: self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST) # If the packet is a command, invoke the handler for this packet - if packet.hci_packet_type == HCI_COMMAND_PACKET: - self.on_hci_command_packet(cast(HCI_Command, packet)) - elif packet.hci_packet_type == HCI_EVENT_PACKET: - self.on_hci_event_packet(cast(HCI_Event, packet)) - elif packet.hci_packet_type == HCI_ACL_DATA_PACKET: - self.on_hci_acl_data_packet(cast(HCI_AclDataPacket, packet)) - elif packet.hci_packet_type == HCI_SYNCHRONOUS_DATA_PACKET: - self.on_hci_sco_data_packet(cast(HCI_SynchronousDataPacket, packet)) - elif packet.hci_packet_type == HCI_ISO_DATA_PACKET: - self.on_hci_iso_data_packet(cast(HCI_IsoDataPacket, packet)) + if packet.hci_packet_type == hci.HCI_COMMAND_PACKET: + self.on_hci_command_packet(cast(hci.HCI_Command, packet)) + elif packet.hci_packet_type == hci.HCI_EVENT_PACKET: + self.on_hci_event_packet(cast(hci.HCI_Event, packet)) + elif packet.hci_packet_type == hci.HCI_ACL_DATA_PACKET: + self.on_hci_acl_data_packet(cast(hci.HCI_AclDataPacket, packet)) + elif packet.hci_packet_type == hci.HCI_SYNCHRONOUS_DATA_PACKET: + self.on_hci_sco_data_packet(cast(hci.HCI_SynchronousDataPacket, packet)) + elif packet.hci_packet_type == hci.HCI_ISO_DATA_PACKET: + self.on_hci_iso_data_packet(cast(hci.HCI_IsoDataPacket, packet)) else: logger.warning(f'!!! unknown packet type {packet.hci_packet_type}') - def on_hci_command_packet(self, command: HCI_Command) -> None: + def on_hci_command_packet(self, command: hci.HCI_Command) -> None: logger.warning(f'!!! unexpected command packet: {command}') - def on_hci_event_packet(self, event: HCI_Event) -> None: + def on_hci_event_packet(self, event: hci.HCI_Event) -> None: handler_name = f'on_{event.name.lower()}' handler = getattr(self, handler_name, self.on_hci_event) handler(event) - def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None: + def on_hci_acl_data_packet(self, packet: hci.HCI_AclDataPacket) -> None: # Look for the connection to which this data belongs if connection := self.connections.get(packet.connection_handle): connection.on_hci_acl_data_packet(packet) - def on_hci_sco_data_packet(self, packet: HCI_SynchronousDataPacket) -> None: + def on_hci_sco_data_packet(self, packet: hci.HCI_SynchronousDataPacket) -> None: # Experimental self.emit('sco_packet', packet.connection_handle, packet) - def on_hci_iso_data_packet(self, packet: HCI_IsoDataPacket) -> None: + def on_hci_iso_data_packet(self, packet: hci.HCI_IsoDataPacket) -> None: # Experimental self.emit('iso_packet', packet.connection_handle, packet) @@ -832,11 +710,11 @@ def on_hci_connection_request_event(self, event): def on_hci_le_connection_complete_event(self, event): # Check if this is a cancellation - if event.status == HCI_SUCCESS: + if event.status == hci.HCI_SUCCESS: # Create/update the connection logger.debug( f'### LE CONNECTION: [0x{event.connection_handle:04X}] ' - f'{event.peer_address} as {HCI_Constant.role_name(event.role)}' + f'{event.peer_address} as {hci.HCI_Constant.role_name(event.role)}' ) connection = self.connections.get(event.connection_handle) @@ -876,7 +754,7 @@ def on_hci_le_enhanced_connection_complete_event(self, event): self.on_hci_le_connection_complete_event(event) def on_hci_connection_complete_event(self, event): - if event.status == HCI_SUCCESS: + if event.status == hci.HCI_SUCCESS: # Create/update the connection logger.debug( f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] ' @@ -923,7 +801,7 @@ def on_hci_disconnection_complete_event(self, event): logger.warning('!!! DISCONNECTION COMPLETE: unknown handle') return - if event.status == HCI_SUCCESS: + if event.status == hci.HCI_SUCCESS: logger.debug( f'### DISCONNECTION: [0x{handle:04X}] ' f'{connection.peer_address} ' @@ -932,7 +810,9 @@ def on_hci_disconnection_complete_event(self, event): # Notify the listeners self.emit('disconnection', handle, event.reason) - ( + + # Remove the handle reference + _ = ( self.connections.pop(handle, 0) or self.cis_links.pop(handle, 0) or self.sco_links.pop(handle, 0) @@ -949,7 +829,7 @@ def on_hci_le_connection_update_complete_event(self, event): return # Notify the client - if event.status == HCI_SUCCESS: + if event.status == hci.HCI_SUCCESS: connection_parameters = ConnectionParameters( event.connection_interval, event.peripheral_latency, @@ -969,7 +849,7 @@ def on_hci_le_phy_update_complete_event(self, event): return # Notify the client - if event.status == HCI_SUCCESS: + if event.status == hci.HCI_SUCCESS: connection_phy = ConnectionPHY(event.tx_phy, event.rx_phy) self.emit('connection_phy_update', connection.handle, connection_phy) else: @@ -1002,10 +882,10 @@ def on_hci_le_cis_request_event(self, event): def on_hci_le_cis_established_event(self, event): # The remaining parameters are unused for now. - if event.status == HCI_SUCCESS: + if event.status == hci.HCI_SUCCESS: self.cis_links[event.connection_handle] = CisLink( handle=event.connection_handle, - peer_address=Address.ANY, + peer_address=hci.Address.ANY, ) self.emit('cis_establishment', event.connection_handle) else: @@ -1021,7 +901,7 @@ def on_hci_le_remote_connection_parameter_request_event(self, event): # For now, just accept everything # TODO: delegate the decision self.send_command_sync( - HCI_LE_Remote_Connection_Parameter_Request_Reply_Command( + hci.HCI_LE_Remote_Connection_Parameter_Request_Reply_Command( connection_handle=event.connection_handle, interval_min=event.interval_min, interval_max=event.interval_max, @@ -1052,12 +932,12 @@ async def send_long_term_key(): ), ) if long_term_key: - response = HCI_LE_Long_Term_Key_Request_Reply_Command( + response = hci.HCI_LE_Long_Term_Key_Request_Reply_Command( connection_handle=event.connection_handle, long_term_key=long_term_key, ) else: - response = HCI_LE_Long_Term_Key_Request_Negative_Reply_Command( + response = hci.HCI_LE_Long_Term_Key_Request_Negative_Reply_Command( connection_handle=event.connection_handle ) @@ -1066,7 +946,7 @@ async def send_long_term_key(): asyncio.create_task(send_long_term_key()) def on_hci_synchronous_connection_complete_event(self, event): - if event.status == HCI_SUCCESS: + if event.status == hci.HCI_SUCCESS: # Create/update the connection logger.debug( f'### SCO CONNECTION: [0x{event.connection_handle:04X}] ' @@ -1095,16 +975,16 @@ def on_hci_synchronous_connection_changed_event(self, event): pass def on_hci_role_change_event(self, event): - if event.status == HCI_SUCCESS: + if event.status == hci.HCI_SUCCESS: logger.debug( f'role change for {event.bd_addr}: ' - f'{HCI_Constant.role_name(event.new_role)}' + f'{hci.HCI_Constant.role_name(event.new_role)}' ) self.emit('role_change', event.bd_addr, event.new_role) else: logger.debug( f'role change for {event.bd_addr} failed: ' - f'{HCI_Constant.error_name(event.status)}' + f'{hci.HCI_Constant.error_name(event.status)}' ) self.emit('role_change_failure', event.bd_addr, event.status) @@ -1120,7 +1000,7 @@ def on_hci_le_data_length_change_event(self, event): def on_hci_authentication_complete_event(self, event): # Notify the client - if event.status == HCI_SUCCESS: + if event.status == hci.HCI_SUCCESS: self.emit('connection_authentication', event.connection_handle) else: self.emit( @@ -1131,7 +1011,7 @@ def on_hci_authentication_complete_event(self, event): def on_hci_encryption_change_event(self, event): # Notify the client - if event.status == HCI_SUCCESS: + if event.status == hci.HCI_SUCCESS: self.emit( 'connection_encryption_change', event.connection_handle, @@ -1144,7 +1024,7 @@ def on_hci_encryption_change_event(self, event): def on_hci_encryption_key_refresh_complete_event(self, event): # Notify the client - if event.status == HCI_SUCCESS: + if event.status == hci.HCI_SUCCESS: self.emit('connection_encryption_key_refresh', event.connection_handle) else: self.emit( @@ -1165,16 +1045,16 @@ def on_hci_page_scan_repetition_mode_change_event(self, event): def on_hci_link_key_notification_event(self, event): logger.debug( f'link key for {event.bd_addr}: {event.link_key.hex()}, ' - f'type={HCI_Constant.link_key_type_name(event.key_type)}' + f'type={hci.HCI_Constant.link_key_type_name(event.key_type)}' ) self.emit('link_key', event.bd_addr, event.link_key, event.key_type) def on_hci_simple_pairing_complete_event(self, event): logger.debug( f'simple pairing complete for {event.bd_addr}: ' - f'status={HCI_Constant.status_name(event.status)}' + f'status={hci.HCI_Constant.status_name(event.status)}' ) - if event.status == HCI_SUCCESS: + if event.status == hci.HCI_SUCCESS: self.emit('classic_pairing', event.bd_addr) else: self.emit('classic_pairing_failure', event.bd_addr, event.status) @@ -1194,11 +1074,11 @@ async def send_link_key(): self.link_key_provider(event.bd_addr), ) if link_key: - response = HCI_Link_Key_Request_Reply_Command( + response = hci.HCI_Link_Key_Request_Reply_Command( bd_addr=event.bd_addr, link_key=link_key ) else: - response = HCI_Link_Key_Request_Negative_Reply_Command( + response = hci.HCI_Link_Key_Request_Negative_Reply_Command( bd_addr=event.bd_addr ) @@ -1255,7 +1135,7 @@ def on_hci_extended_inquiry_result_event(self, event): ) def on_hci_remote_name_request_complete_event(self, event): - if event.status != HCI_SUCCESS: + if event.status != hci.HCI_SUCCESS: self.emit('remote_name_failure', event.bd_addr, event.status) else: utf8_name = event.remote_name @@ -1273,7 +1153,7 @@ def on_hci_remote_host_supported_features_notification_event(self, event): ) def on_hci_le_read_remote_features_complete_event(self, event): - if event.status != HCI_SUCCESS: + if event.status != hci.HCI_SUCCESS: self.emit( 'le_remote_features_failure', event.connection_handle, event.status ) diff --git a/examples/run_extended_advertiser_2.py b/examples/run_extended_advertiser_2.py index db4d6199..735e1c5f 100644 --- a/examples/run_extended_advertiser_2.py +++ b/examples/run_extended_advertiser_2.py @@ -1,4 +1,4 @@ -# Copyright 2021-2022 Google LLC +# Copyright 2021-2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/device_test.py b/tests/device_test.py index 89aa567f..5d872826 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -50,7 +50,8 @@ GATT_APPEARANCE_CHARACTERISTIC, ) -from .test_utils import TwoDevices +from .test_utils import TwoDevices, async_barrier + # ----------------------------------------------------------------------------- # Logging @@ -310,14 +311,18 @@ async def test_legacy_advertising_disconnection(auto_restart): ConnectionParameters(0, 0, 0), ) - device.start_advertising = mock.AsyncMock() + device.on_advertising_set_termination( + HCI_SUCCESS, device.legacy_advertising_set.advertising_handle, 0x0001, 0 + ) device.on_disconnection(0x0001, 0) + await async_barrier() + await async_barrier() if auto_restart: assert device.is_advertising else: - not device.is_advertising + assert not device.is_advertising # ----------------------------------------------------------------------------- diff --git a/tests/gatt_test.py b/tests/gatt_test.py index 19dff2f2..e3c92097 100644 --- a/tests/gatt_test.py +++ b/tests/gatt_test.py @@ -50,6 +50,7 @@ ATT_Error_Response, ATT_Read_By_Group_Type_Request, ) +from .test_utils import async_barrier # ----------------------------------------------------------------------------- @@ -456,13 +457,6 @@ def __init__(self): self.paired = [None, None, None] -# ----------------------------------------------------------------------------- -async def async_barrier(): - ready = asyncio.get_running_loop().create_future() - asyncio.get_running_loop().call_soon(ready.set_result, None) - await ready - - # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_read_write(): diff --git a/tests/test_utils.py b/tests/test_utils.py index 331b1860..d193d6e5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio from typing import List, Optional from bumble.controller import Controller @@ -22,6 +26,7 @@ from bumble.hci import Address +# ----------------------------------------------------------------------------- class TwoDevices: connections: List[Optional[Connection]] @@ -75,3 +80,10 @@ async def setup_connection(self) -> None: def __getitem__(self, index: int) -> Device: return self.devices[index] + + +# ----------------------------------------------------------------------------- +async def async_barrier(): + ready = asyncio.get_running_loop().create_future() + asyncio.get_running_loop().call_soon(ready.set_result, None) + await ready