From a499843fe5f2638c556ae035f0b6c6d1b6ee59f0 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Sat, 20 Jan 2024 00:13:18 +0800 Subject: [PATCH] Convert HCI role and transport into enums As they are widely used as parameters, making them enums can significantly improve the static check. --- bumble/controller.py | 31 ++++++-------- bumble/core.py | 17 ++++++-- bumble/device.py | 100 +++++++++++++++++++++---------------------- bumble/host.py | 33 +++++++++----- 4 files changed, 97 insertions(+), 84 deletions(-) diff --git a/bumble/controller.py b/bumble/controller.py index 374f138d..ce3710a6 100644 --- a/bumble/controller.py +++ b/bumble/controller.py @@ -24,12 +24,7 @@ import random import struct from bumble.colors import color -from bumble.core import ( - BT_CENTRAL_ROLE, - BT_PERIPHERAL_ROLE, - BT_LE_TRANSPORT, - BT_BR_EDR_TRANSPORT, -) +from bumble.core import PhysicalTransport, Role from bumble.hci import ( HCI_ACL_DATA_PACKET, @@ -98,10 +93,10 @@ class CisLink: class Connection: controller: Controller handle: int - role: int + role: Role peer_address: Address link: Any - transport: int + transport: PhysicalTransport link_type: int def __post_init__(self): @@ -388,10 +383,10 @@ def on_link_central_connected(self, central_address): connection = Connection( controller=self, handle=connection_handle, - role=BT_PERIPHERAL_ROLE, + role=Role.PERIPHERAL, peer_address=peer_address, link=self.link, - transport=BT_LE_TRANSPORT, + transport=PhysicalTransport.LE, link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE, ) self.peripheral_connections[peer_address] = connection @@ -448,10 +443,10 @@ def on_link_peripheral_connection_complete( connection = Connection( controller=self, handle=connection_handle, - role=BT_CENTRAL_ROLE, + role=Role.CENTRAL, peer_address=peer_address, link=self.link, - transport=BT_LE_TRANSPORT, + transport=PhysicalTransport.LE, link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE, ) self.central_connections[peer_address] = connection @@ -467,7 +462,7 @@ def on_link_peripheral_connection_complete( HCI_LE_Connection_Complete_Event( status=status, connection_handle=connection.handle if connection else 0, - role=BT_CENTRAL_ROLE, + role=Role.CENTRAL, peer_address_type=le_create_connection_command.peer_address_type, peer_address=le_create_connection_command.peer_address, connection_interval=le_create_connection_command.connection_interval_min, @@ -529,7 +524,7 @@ def on_link_encrypted(self, peer_address, _rand, _ediv, _ltk): def on_link_acl_data(self, sender_address, transport, data): # Look for the connection to which this data belongs - if transport == BT_LE_TRANSPORT: + if transport == PhysicalTransport.LE: connection = self.find_le_connection_by_address(sender_address) else: connection = self.find_classic_connection_by_address(sender_address) @@ -691,10 +686,10 @@ def on_classic_connection_complete(self, peer_address, status): controller=self, handle=connection_handle, # Role doesn't matter in Classic because they are managed by HCI_Role_Change and HCI_Role_Discovery - role=BT_CENTRAL_ROLE, + role=Role.CENTRAL, peer_address=peer_address, link=self.link, - transport=BT_BR_EDR_TRANSPORT, + transport=PhysicalTransport.BR_EDR, link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE, ) self.classic_connections[peer_address] = connection @@ -759,10 +754,10 @@ def on_classic_sco_connection_complete( controller=self, handle=connection_handle, # Role doesn't matter in SCO. - role=BT_CENTRAL_ROLE, + role=Role.CENTRAL, peer_address=peer_address, link=self.link, - transport=BT_BR_EDR_TRANSPORT, + transport=PhysicalTransport.BR_EDR, link_type=link_type, ) self.classic_connections[peer_address] = connection diff --git a/bumble/core.py b/bumble/core.py index 2722b870..d35a7f3b 100644 --- a/bumble/core.py +++ b/bumble/core.py @@ -28,11 +28,20 @@ # ----------------------------------------------------------------------------- # fmt: off -BT_CENTRAL_ROLE = 0 -BT_PERIPHERAL_ROLE = 1 +class PhysicalTransport(enum.IntEnum): + BR_EDR = 0 + LE = 1 -BT_BR_EDR_TRANSPORT = 0 -BT_LE_TRANSPORT = 1 +class Role(enum.IntEnum): + CENTRAL = 0 + PERIPHERAL = 1 + +# For backward compatibility. +BT_CENTRAL_ROLE = Role.CENTRAL +BT_PERIPHERAL_ROLE = Role.PERIPHERAL + +BT_BR_EDR_TRANSPORT = PhysicalTransport.BR_EDR +BT_LE_TRANSPORT = PhysicalTransport.LE # fmt: on diff --git a/bumble/device.py b/bumble/device.py index 6d5c00c8..0c77dcce 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -155,10 +155,8 @@ from .host import Host from .gap import GenericAccessService from .core import ( - BT_BR_EDR_TRANSPORT, - BT_CENTRAL_ROLE, - BT_LE_TRANSPORT, - BT_PERIPHERAL_ROLE, + PhysicalTransport, + Role, AdvertisingData, ConnectionParameterUpdateError, CommandTimeoutError, @@ -676,12 +674,12 @@ async def disconnect( class Connection(CompositeEventEmitter): device: Device handle: int - transport: int + transport: PhysicalTransport self_address: Address peer_address: Address peer_resolvable_address: Optional[Address] peer_le_features: Optional[LeFeatureMask] - role: int + role: Role encryption: int authenticated: bool sc: bool @@ -770,7 +768,7 @@ def incomplete(cls, device, peer_address, role): return cls( device, None, - BT_BR_EDR_TRANSPORT, + PhysicalTransport.BR_EDR, device.public_address, peer_address, None, @@ -785,7 +783,7 @@ def complete(self, handle, parameters): Finish an incomplete connection upon completion. """ assert self.handle is None - assert self.transport == BT_BR_EDR_TRANSPORT + assert self.transport == PhysicalTransport.BR_EDR self.handle = handle self.parameters = parameters @@ -793,9 +791,9 @@ def complete(self, handle, parameters): def role_name(self): if self.role is None: return 'NOT-SET' - if self.role == BT_CENTRAL_ROLE: + if self.role == Role.CENTRAL: return 'CENTRAL' - if self.role == BT_PERIPHERAL_ROLE: + if self.role == Role.PERIPHERAL: return 'PERIPHERAL' return f'UNKNOWN[{self.role}]' @@ -855,7 +853,7 @@ async def authenticate(self) -> None: async def encrypt(self, enable: bool = True) -> None: return await self.device.encrypt(self, enable) - async def switch_role(self, role: int) -> None: + async def switch_role(self, role: Role) -> None: return await self.device.switch_role(self, role) async def sustain(self, timeout: Optional[float] = None) -> None: @@ -1344,7 +1342,7 @@ def lookup_connection(self, connection_handle: int) -> Optional[Connection]: def find_connection_by_bd_addr( self, bd_addr: Address, - transport: Optional[int] = None, + transport: Optional[PhysicalTransport] = None, check_address_type: bool = False, ) -> Optional[Connection]: for connection in self.connections.values(): @@ -2123,7 +2121,7 @@ async def set_connectable(self, connectable: bool = True) -> None: async def connect( self, peer_address: Union[Address, str], - transport: int = BT_LE_TRANSPORT, + transport: PhysicalTransport = PhysicalTransport.LE, connection_parameters_preferences: Optional[ Dict[int, ConnectionParametersPreferences] ] = None, @@ -2144,17 +2142,17 @@ async def connect( ''' # Check parameters - if transport not in (BT_LE_TRANSPORT, BT_BR_EDR_TRANSPORT): + if transport not in (PhysicalTransport.LE, PhysicalTransport.BR_EDR): raise ValueError('invalid transport') # Adjust the transport automatically if we need to - if transport == BT_LE_TRANSPORT and not self.le_enabled: - transport = BT_BR_EDR_TRANSPORT - elif transport == BT_BR_EDR_TRANSPORT and not self.classic_enabled: - transport = BT_LE_TRANSPORT + if transport == PhysicalTransport.LE and not self.le_enabled: + transport = PhysicalTransport.BR_EDR + elif transport == PhysicalTransport.BR_EDR and not self.classic_enabled: + transport = PhysicalTransport.LE # Check that there isn't already a pending connection - if transport == BT_LE_TRANSPORT and self.is_le_connecting: + if transport == PhysicalTransport.LE and self.is_le_connecting: raise InvalidStateError('connection already pending') if isinstance(peer_address, str): @@ -2171,7 +2169,7 @@ async def connect( else: # All BR/EDR addresses should be public addresses if ( - transport == BT_BR_EDR_TRANSPORT + transport == PhysicalTransport.BR_EDR and peer_address.address_type != Address.PUBLIC_DEVICE_ADDRESS ): raise ValueError('BR/EDR addresses must be PUBLIC') @@ -2179,7 +2177,7 @@ async def connect( assert isinstance(peer_address, Address) def on_connection(connection): - if transport == BT_LE_TRANSPORT or ( + if transport == PhysicalTransport.LE or ( # match BR/EDR connection event against peer address connection.transport == transport and connection.peer_address == peer_address @@ -2187,7 +2185,7 @@ def on_connection(connection): pending_connection.set_result(connection) def on_connection_failure(error): - if transport == BT_LE_TRANSPORT or ( + if transport == PhysicalTransport.LE or ( # match BR/EDR connection failure event against peer address error.transport == transport and error.peer_address == peer_address @@ -2201,7 +2199,7 @@ def on_connection_failure(error): try: # Tell the controller to connect - if transport == BT_LE_TRANSPORT: + if transport == PhysicalTransport.LE: if connection_parameters_preferences is None: if connection_parameters_preferences is None: connection_parameters_preferences = { @@ -2327,7 +2325,7 @@ def on_connection_failure(error): else: # Save pending connection self.pending_connections[peer_address] = Connection.incomplete( - self, peer_address, BT_CENTRAL_ROLE + self, peer_address, Role.CENTRAL ) # TODO: allow passing other settings @@ -2346,7 +2344,7 @@ def on_connection_failure(error): raise HCI_StatusError(result) # Wait for the connection process to complete - if transport == BT_LE_TRANSPORT: + if transport == PhysicalTransport.LE: self.le_connecting = True if timeout is None: @@ -2357,7 +2355,7 @@ def on_connection_failure(error): asyncio.shield(pending_connection), timeout ) except asyncio.TimeoutError: - if transport == BT_LE_TRANSPORT: + if transport == PhysicalTransport.LE: await self.send_command(HCI_LE_Create_Connection_Cancel_Command()) else: await self.send_command( @@ -2371,7 +2369,7 @@ def on_connection_failure(error): finally: self.remove_listener('connection', on_connection) self.remove_listener('connection_failure', on_connection_failure) - if transport == BT_LE_TRANSPORT: + if transport == PhysicalTransport.LE: self.le_connecting = False self.connect_own_address_type = None else: @@ -2380,7 +2378,7 @@ def on_connection_failure(error): async def accept( self, peer_address: Union[Address, str] = Address.ANY, - role: int = BT_PERIPHERAL_ROLE, + role: Role = Role.PERIPHERAL, timeout: Optional[float] = DEVICE_DEFAULT_CONNECT_TIMEOUT, ) -> Connection: ''' @@ -2401,7 +2399,7 @@ async def accept( # If the address is not parsable, assume it is a name instead logger.debug('looking for peer by name') peer_address = await self.find_peer_by_name( - peer_address, BT_BR_EDR_TRANSPORT + peer_address, PhysicalTransport.BR_EDR ) # TODO: timeout assert isinstance(peer_address, Address) @@ -2449,14 +2447,14 @@ async def accept( def on_connection(connection): if ( - connection.transport == BT_BR_EDR_TRANSPORT + connection.transport == PhysicalTransport.BR_EDR and connection.peer_address == peer_address ): pending_connection.set_result(connection) def on_connection_failure(error): if ( - error.transport == BT_BR_EDR_TRANSPORT + error.transport == PhysicalTransport.BR_EDR and error.peer_address == peer_address ): pending_connection.set_exception(error) @@ -2469,7 +2467,7 @@ def on_connection_failure(error): # command, this connection is still considered Peripheral until an eventual # role change event. self.pending_connections[peer_address] = Connection.incomplete( - self, peer_address, BT_PERIPHERAL_ROLE + self, peer_address, Role.PERIPHERAL ) try: @@ -2524,7 +2522,7 @@ async def cancel_connection(self, peer_address=None): # If the address is not parsable, assume it is a name instead logger.debug('looking for peer by name') peer_address = await self.find_peer_by_name( - peer_address, BT_BR_EDR_TRANSPORT + peer_address, PhysicalTransport.BR_EDR ) # TODO: timeout await self.send_command( @@ -2594,7 +2592,7 @@ async def update_connection_parameters( ''' if use_l2cap: - if connection.role != BT_PERIPHERAL_ROLE: + if connection.role != Role.PERIPHERAL: raise InvalidStateError( 'only peripheral can update connection parameters with l2cap' ) @@ -2679,7 +2677,7 @@ async def set_default_phy(self, tx_phys=None, rx_phys=None): check_result=True, ) - async def find_peer_by_name(self, name, transport=BT_LE_TRANSPORT): + async def find_peer_by_name(self, name, transport=PhysicalTransport.LE): """ Scan for a peer with a give name and return its address and transport """ @@ -2700,7 +2698,7 @@ def on_peer_found(address, ad_data): was_scanning = self.scanning was_discovering = self.discovering try: - if transport == BT_LE_TRANSPORT: + if transport == PhysicalTransport.LE: event_name = 'advertisement' handler = self.on( event_name, @@ -2712,7 +2710,7 @@ def on_peer_found(address, ad_data): if not self.scanning: await self.start_scanning(filter_duplicates=True) - elif transport == BT_BR_EDR_TRANSPORT: + elif transport == PhysicalTransport.BR_EDR: event_name = 'inquiry_result' handler = self.on( event_name, @@ -2731,9 +2729,9 @@ def on_peer_found(address, ad_data): if handler is not None: self.remove_listener(event_name, handler) - if transport == BT_LE_TRANSPORT and not was_scanning: + if transport == PhysicalTransport.LE and not was_scanning: await self.stop_scanning() - elif transport == BT_BR_EDR_TRANSPORT and not was_discovering: + elif transport == PhysicalTransport.BR_EDR and not was_discovering: await self.stop_discovery() @property @@ -2779,10 +2777,10 @@ async def get_long_term_key( if keys.ltk: return keys.ltk.value - if connection.role == BT_CENTRAL_ROLE and keys.ltk_central: + if connection.role == Role.CENTRAL and keys.ltk_central: return keys.ltk_central.value - if connection.role == BT_PERIPHERAL_ROLE and keys.ltk_peripheral: + if connection.role == Role.PERIPHERAL and keys.ltk_peripheral: return keys.ltk_peripheral.value return None @@ -2840,7 +2838,7 @@ def on_authentication_failure(error_code): ) async def encrypt(self, connection, enable=True): - if not enable and connection.transport == BT_LE_TRANSPORT: + if not enable and connection.transport == PhysicalTransport.LE: raise ValueError('`enable` parameter is classic only.') # Set up event handlers @@ -2857,7 +2855,7 @@ def on_encryption_failure(error_code): # Request the encryption try: - if connection.transport == BT_LE_TRANSPORT: + if connection.transport == PhysicalTransport.LE: # Look for a key in the key store if self.keystore is None: raise RuntimeError('no key store') @@ -2933,7 +2931,7 @@ async def update_keys(self, address: str, keys: PairingKeys) -> None: self.emit('key_store_update') # [Classic only] - async def switch_role(self, connection: Connection, role: int): + async def switch_role(self, connection: Connection, role: Role): pending_role_change = asyncio.get_running_loop().create_future() def on_role_change(new_role): @@ -3199,7 +3197,7 @@ def on_link_key(self, bd_addr, link_key, key_type): self.abort_on('flush', self.update_keys(str(bd_addr), pairing_keys)) if connection := self.find_connection_by_bd_addr( - bd_addr, transport=BT_BR_EDR_TRANSPORT + bd_addr, transport=PhysicalTransport.BR_EDR ): connection.link_key_type = key_type @@ -3246,7 +3244,7 @@ def on_connection( peer_resolvable_address = None - if transport == BT_BR_EDR_TRANSPORT: + if transport == PhysicalTransport.BR_EDR: # Create a new connection connection = self.pending_connections.pop(peer_address) connection.complete(connection_handle, connection_parameters) @@ -3331,7 +3329,7 @@ def on_connection_failure(self, transport, peer_address, error_code): # For directed advertising, this means a timeout if ( - transport == BT_LE_TRANSPORT + transport == PhysicalTransport.LE and self.legacy_advertiser and self.legacy_advertiser.advertising_type.is_directed ): @@ -3358,7 +3356,7 @@ def on_connection_request(self, bd_addr, class_of_device, link_type): HCI_Connection_Complete_Event.ESCO_LINK_TYPE, ): if connection := self.find_connection_by_bd_addr( - bd_addr, transport=BT_BR_EDR_TRANSPORT + bd_addr, transport=PhysicalTransport.BR_EDR ): self.emit('sco_request', connection, link_type) else: @@ -3379,7 +3377,7 @@ def on_connection_request(self, bd_addr, class_of_device, link_type): elif self.classic_accept_any: # Save pending connection self.pending_connections[bd_addr] = Connection.incomplete( - self, bd_addr, BT_PERIPHERAL_ROLE + self, bd_addr, Role.PERIPHERAL ) self.host.send_command_sync( @@ -3855,14 +3853,14 @@ def on_connection_encryption_change(self, connection, encryption): connection.encryption = encryption if ( not connection.authenticated - and connection.transport == BT_BR_EDR_TRANSPORT + and connection.transport == PhysicalTransport.BR_EDR and encryption == HCI_Encryption_Change_Event.AES_CCM ): connection.authenticated = True connection.sc = True if ( not connection.authenticated - and connection.transport == BT_LE_TRANSPORT + and connection.transport == PhysicalTransport.LE and encryption == HCI_Encryption_Change_Event.E0_OR_AES_CCM ): connection.authenticated = True diff --git a/bumble/host.py b/bumble/host.py index f28d6fc6..c5d384f7 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -73,8 +73,7 @@ LeFeatureMask, ) from .core import ( - BT_BR_EDR_TRANSPORT, - BT_LE_TRANSPORT, + PhysicalTransport, ConnectionPHY, ConnectionParameters, ) @@ -139,7 +138,13 @@ def on_packets_completed(self, packet_count: int) -> None: # ----------------------------------------------------------------------------- class Connection: - def __init__(self, host: Host, handle: int, peer_address: Address, transport: int): + def __init__( + self, + host: Host, + handle: int, + peer_address: Address, + transport: PhysicalTransport, + ): self.host = host self.handle = handle self.peer_address = peer_address @@ -147,7 +152,7 @@ def __init__(self, host: Host, handle: int, peer_address: Address, transport: in self.transport = transport acl_packet_queue: Optional[AclPacketQueue] = ( host.le_acl_packet_queue - if transport == BT_LE_TRANSPORT + if transport == PhysicalTransport.LE else host.acl_packet_queue ) assert acl_packet_queue @@ -205,7 +210,7 @@ def __init__( def find_connection_by_bd_addr( self, bd_addr: Address, - transport: Optional[int] = None, + transport: Optional[PhysicalTransport] = None, check_address_type: bool = False, ) -> Optional[Connection]: for connection in self.connections.values(): @@ -629,7 +634,7 @@ def on_hci_le_connection_complete_event(self, event): self, event.connection_handle, event.peer_address, - BT_LE_TRANSPORT, + PhysicalTransport.LE, ) self.connections[event.connection_handle] = connection @@ -642,7 +647,7 @@ def on_hci_le_connection_complete_event(self, event): self.emit( 'connection', event.connection_handle, - BT_LE_TRANSPORT, + PhysicalTransport.LE, event.peer_address, event.role, connection_parameters, @@ -652,7 +657,10 @@ def on_hci_le_connection_complete_event(self, event): # Notify the listeners self.emit( - 'connection_failure', BT_LE_TRANSPORT, event.peer_address, event.status + 'connection_failure', + PhysicalTransport.LE, + event.peer_address, + event.status, ) def on_hci_le_enhanced_connection_complete_event(self, event): @@ -673,7 +681,7 @@ def on_hci_connection_complete_event(self, event): self, event.connection_handle, event.bd_addr, - BT_BR_EDR_TRANSPORT, + PhysicalTransport.BR_EDR, ) self.connections[event.connection_handle] = connection @@ -681,7 +689,7 @@ def on_hci_connection_complete_event(self, event): self.emit( 'connection', event.connection_handle, - BT_BR_EDR_TRANSPORT, + PhysicalTransport.BR_EDR, event.bd_addr, None, None, @@ -691,7 +699,10 @@ def on_hci_connection_complete_event(self, event): # Notify the client self.emit( - 'connection_failure', BT_BR_EDR_TRANSPORT, event.bd_addr, event.status + 'connection_failure', + PhysicalTransport.BR_EDR, + event.bd_addr, + event.status, ) def on_hci_disconnection_complete_event(self, event):