Skip to content

Commit

Permalink
Merge pull request #368 from google/gbg/driver-load-before-reset
Browse files Browse the repository at this point in the history
support drivers that can't use reset directly.
  • Loading branch information
barbibulle authored Dec 12, 2023
2 parents f0b55a4 + 98ed772 commit a286700
Show file tree
Hide file tree
Showing 10 changed files with 170 additions and 81 deletions.
58 changes: 27 additions & 31 deletions bumble/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,53 +19,49 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import abc
from __future__ import annotations
import logging
import pathlib
import platform
from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING

from . import rtk
from .common import Driver

if TYPE_CHECKING:
from bumble.host import Host

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Driver(abc.ABC):
"""Base class for drivers."""

@staticmethod
async def for_host(_host):
"""Return a driver instance for a host.
Args:
host: Host object for which a driver should be created.
Returns:
A Driver instance if a driver should be instantiated for this host, or
None if no driver instance of this class is needed.
"""
return None

@abc.abstractmethod
async def init_controller(self):
"""Initialize the controller."""


# -----------------------------------------------------------------------------
# Functions
# -----------------------------------------------------------------------------
async def get_driver_for_host(host):
"""Probe all known diver classes until one returns a valid instance for a host,
or none is found.
async def get_driver_for_host(host: Host) -> Optional[Driver]:
"""Probe diver classes until one returns a valid instance for a host, or none is
found.
If a "driver" HCI metadata entry is present, only that driver class will be probed.
"""
if driver := await rtk.Driver.for_host(host):
logger.debug("Instantiated RTK driver")
return driver
driver_classes: Dict[str, Type[Driver]] = {"rtk": rtk.Driver}
probe_list: Iterable[str]
if driver_name := host.hci_metadata.get("driver"):
# Only probe a single driver
probe_list = [driver_name]
else:
# Probe all drivers
probe_list = driver_classes.keys()

for driver_name in probe_list:
if driver_class := driver_classes.get(driver_name):
logger.debug(f"Probing driver class: {driver_name}")
if driver := await driver_class.for_host(host):
logger.debug(f"Instantiated {driver_name} driver")
return driver
else:
logger.debug(f"Skipping unknown driver class: {driver_name}")

return None

Expand Down
45 changes: 45 additions & 0 deletions bumble/drivers/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Common types for drivers.
"""

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import abc


# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Driver(abc.ABC):
"""Base class for drivers."""

@staticmethod
async def for_host(_host):
"""Return a driver instance for a host.
Args:
host: Host object for which a driver should be created.
Returns:
A Driver instance if a driver should be instantiated for this host, or
None if no driver instance of this class is needed.
"""
return None

@abc.abstractmethod
async def init_controller(self):
"""Initialize the controller."""
15 changes: 11 additions & 4 deletions bumble/drivers/rtk.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
HCI_Reset_Command,
HCI_Read_Local_Version_Information_Command,
)

from bumble.drivers import common

# -----------------------------------------------------------------------------
# Logging
Expand Down Expand Up @@ -285,7 +285,7 @@ def __init__(self, firmware):
)


class Driver:
class Driver(common.Driver):
@dataclass
class DriverInfo:
rom: int
Expand Down Expand Up @@ -470,8 +470,12 @@ def check(host):
logger.debug("USB metadata not found")
return False

vendor_id = host.hci_metadata.get("vendor_id", None)
product_id = host.hci_metadata.get("product_id", None)
if host.hci_metadata.get('driver') == 'rtk':
# Forced driver
return True

vendor_id = host.hci_metadata.get("vendor_id")
product_id = host.hci_metadata.get("product_id")
if vendor_id is None or product_id is None:
logger.debug("USB metadata not sufficient")
return False
Expand All @@ -486,6 +490,9 @@ def check(host):

@classmethod
async def driver_info_for_host(cls, host):
await host.send_command(HCI_Reset_Command(), check_result=True)
host.ready = True # Needed to let the host know the controller is ready.

response = await host.send_command(
HCI_Read_Local_Version_Information_Command(), check_result=True
)
Expand Down
38 changes: 21 additions & 17 deletions bumble/host.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import logging
import struct

from typing import Optional, TYPE_CHECKING, Dict, Callable, Awaitable, cast
from typing import Any, Awaitable, Callable, Dict, Optional, Union, cast, TYPE_CHECKING

from bumble.colors import color
from bumble.l2cap import L2CAP_PDU
Expand Down Expand Up @@ -124,7 +124,8 @@ def on_acl_pdu(self, pdu: bytes) -> None:
class Host(AbortableEventEmitter):
connections: Dict[int, Connection]
acl_packet_queue: collections.deque[HCI_AclDataPacket]
hci_sink: TransportSink
hci_sink: Optional[TransportSink] = None
hci_metadata: Dict[str, Any]
long_term_key_provider: Optional[
Callable[[int, bytes, int], Awaitable[Optional[bytes]]]
]
Expand All @@ -137,9 +138,8 @@ def __init__(
) -> None:
super().__init__()

self.hci_metadata = None
self.hci_metadata = {}
self.ready = False # True when we can accept incoming packets
self.reset_done = False
self.connections = {} # Connections, by connection handle
self.pending_command = None
self.pending_response = None
Expand All @@ -162,10 +162,7 @@ def __init__(

# Connect to the source and sink if specified
if controller_source:
controller_source.set_packet_sink(self)
self.hci_metadata = getattr(
controller_source, 'metadata', self.hci_metadata
)
self.set_packet_source(controller_source)
if controller_sink:
self.set_packet_sink(controller_sink)

Expand Down Expand Up @@ -200,17 +197,21 @@ async def reset(self, driver_factory=drivers.get_driver_for_host):
self.ready = False
await self.flush()

await self.send_command(HCI_Reset_Command(), check_result=True)
self.ready = True

# Instantiate and init a driver for the host if needed.
# NOTE: we don't keep a reference to the driver here, because we don't
# currently have a need for the driver later on. But if the driver interface
# evolves, it may be required, then, to store a reference to the driver in
# an object property.
reset_needed = True
if driver_factory is not None:
if driver := await driver_factory(self):
await driver.init_controller()
reset_needed = False

# Send a reset command unless a driver has already done so.
if reset_needed:
await self.send_command(HCI_Reset_Command(), check_result=True)
self.ready = True

response = await self.send_command(
HCI_Read_Local_Supported_Commands_Command(), check_result=True
Expand Down Expand Up @@ -313,25 +314,28 @@ async def reset(self, driver_factory=drivers.get_driver_for_host):
)
)

self.reset_done = True

@property
def controller(self) -> TransportSink:
def controller(self) -> Optional[TransportSink]:
return self.hci_sink

@controller.setter
def controller(self, controller):
def controller(self, controller) -> None:
self.set_packet_sink(controller)
if controller:
controller.set_packet_sink(self)

def set_packet_sink(self, sink: TransportSink) -> None:
def set_packet_sink(self, sink: Optional[TransportSink]) -> None:
self.hci_sink = sink

def set_packet_source(self, source: TransportSource) -> None:
source.set_packet_sink(self)
self.hci_metadata = getattr(source, 'metadata', self.hci_metadata)

def send_hci_packet(self, packet: HCI_Packet) -> None:
if self.snooper:
self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)
self.hci_sink.on_packet(bytes(packet))
if self.hci_sink:
self.hci_sink.on_packet(bytes(packet))

async def send_command(self, command, check_result=False):
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {command}')
Expand Down
Loading

0 comments on commit a286700

Please sign in to comment.