Skip to content

Commit

Permalink
Make DeviceConfiguration dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
zxzxwu committed May 5, 2024
1 parent 1b33c9e commit a5ac5f2
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 72 deletions.
140 changes: 72 additions & 68 deletions bumble/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
from enum import IntEnum
import copy
import functools
import json
import asyncio
Expand All @@ -40,6 +41,7 @@
overload,
TYPE_CHECKING,
)
from typing_extensions import Self

from pyee import EventEmitter

Expand Down Expand Up @@ -1252,75 +1254,47 @@ def __str__(self):


# -----------------------------------------------------------------------------
@dataclass
class DeviceConfiguration:
def __init__(self) -> None:
# Setup defaults
self.name = DEVICE_DEFAULT_NAME
self.address = Address(DEVICE_DEFAULT_ADDRESS)
self.class_of_device = DEVICE_DEFAULT_CLASS_OF_DEVICE
self.scan_response_data = DEVICE_DEFAULT_SCAN_RESPONSE_DATA
self.advertising_interval_min = DEVICE_DEFAULT_ADVERTISING_INTERVAL
self.advertising_interval_max = DEVICE_DEFAULT_ADVERTISING_INTERVAL
self.le_enabled = True
# LE host enable 2nd parameter
self.le_simultaneous_enabled = False
self.classic_enabled = False
self.classic_sc_enabled = True
self.classic_ssp_enabled = True
self.classic_smp_enabled = True
self.classic_accept_any = True
self.connectable = True
self.discoverable = True
self.advertising_data = bytes(
AdvertisingData(
[(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))]
)
# Setup defaults
name: str = DEVICE_DEFAULT_NAME
address: Address = Address(DEVICE_DEFAULT_ADDRESS)
class_of_device: int = DEVICE_DEFAULT_CLASS_OF_DEVICE
scan_response_data: bytes = DEVICE_DEFAULT_SCAN_RESPONSE_DATA
advertising_interval_min: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
advertising_interval_max: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
le_enabled: bool = True
# LE host enable 2nd parameter
le_simultaneous_enabled: bool = False
classic_enabled: bool = False
classic_sc_enabled: bool = True
classic_ssp_enabled: bool = True
classic_smp_enabled: bool = True
classic_accept_any: bool = True
connectable: bool = True
discoverable: bool = True
advertising_data: bytes = bytes(
AdvertisingData(
[(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(DEVICE_DEFAULT_NAME, 'utf-8'))]
)
self.irk = bytes(16) # This really must be changed for any level of security
self.keystore = None
)
irk: bytes = bytes(16) # This really must be changed for any level of security
keystore: Optional[str] = None
address_resolution_offload: bool = False
cis_enabled: bool = False

def __post_init__(self) -> None:
self.gatt_services: List[Dict[str, Any]] = []
self.address_resolution_offload = False
self.cis_enabled = False

def load_from_dict(self, config: Dict[str, Any]) -> None:
config = copy.deepcopy(config)

# Load simple properties
self.name = config.get('name', self.name)
if address := config.get('address', None):
if address := config.pop('address', None):
self.address = Address(address)
self.class_of_device = config.get('class_of_device', self.class_of_device)
self.advertising_interval_min = config.get(
'advertising_interval', self.advertising_interval_min
)
self.advertising_interval_max = self.advertising_interval_min
self.keystore = config.get('keystore')
self.le_enabled = config.get('le_enabled', self.le_enabled)
self.le_simultaneous_enabled = config.get(
'le_simultaneous_enabled', self.le_simultaneous_enabled
)
self.classic_enabled = config.get('classic_enabled', self.classic_enabled)
self.classic_sc_enabled = config.get(
'classic_sc_enabled', self.classic_sc_enabled
)
self.classic_ssp_enabled = config.get(
'classic_ssp_enabled', self.classic_ssp_enabled
)
self.classic_smp_enabled = config.get(
'classic_smp_enabled', self.classic_smp_enabled
)
self.classic_accept_any = config.get(
'classic_accept_any', self.classic_accept_any
)
self.connectable = config.get('connectable', self.connectable)
self.discoverable = config.get('discoverable', self.discoverable)
self.gatt_services = config.get('gatt_services', self.gatt_services)
self.address_resolution_offload = config.get(
'address_resolution_offload', self.address_resolution_offload
)
self.cis_enabled = config.get('cis_enabled', self.cis_enabled)

# Load or synthesize an IRK
irk = config.get('irk')
if irk:
if irk := config.pop('irk', None):
self.irk = bytes.fromhex(irk)
elif self.address != Address(DEVICE_DEFAULT_ADDRESS):
# Construct an IRK from the address bytes
Expand All @@ -1332,21 +1306,53 @@ def load_from_dict(self, config: Dict[str, Any]) -> None:
# Fallback - when both IRK and address are not set, randomly generate an IRK.
self.irk = secrets.token_bytes(16)

if (name := config.pop('name', None)) is not None:
self.name = name

# Load advertising data
advertising_data = config.get('advertising_data')
if advertising_data:
if advertising_data := config.pop('advertising_data', None):
self.advertising_data = bytes.fromhex(advertising_data)
elif config.get('name') is not None:
elif name is not None:
self.advertising_data = bytes(
AdvertisingData(
[(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))]
)
)

def load_from_file(self, filename):
# Load advertising interval (for backward compatibility)
if advertising_interval := config.pop('advertising_interval', None):
self.advertising_interval_min = advertising_interval
self.advertising_interval_max = advertising_interval
if (
'advertising_interval_max' in config
or 'advertising_interval_min' in config
):
logger.warning(
'Trying to set both advertising_interval and '
'advertising_interval_min/max, advertising_interval will be'
'ignored.'
)

# Load data in primitive types.
for key, value in config.items():
setattr(self, key, value)

def load_from_file(self, filename: str) -> None:
with open(filename, 'r', encoding='utf-8') as file:
self.load_from_dict(json.load(file))

@classmethod
def from_file(cls: Type[Self], filename: str) -> Self:
config = cls()
config.load_from_file(filename)
return config

@classmethod
def from_dict(cls: Type[Self], config: Dict[str, Any]) -> Self:
device_config = cls()
device_config.load_from_dict(config)
return device_config


# -----------------------------------------------------------------------------
# Decorators used with the following Device class
Expand Down Expand Up @@ -1470,8 +1476,7 @@ def with_hci(

@classmethod
def from_config_file(cls, filename: str) -> Device:
config = DeviceConfiguration()
config.load_from_file(filename)
config = DeviceConfiguration.from_file(filename)
return cls(config=config)

@classmethod
Expand All @@ -1488,8 +1493,7 @@ def from_config_with_hci(
def from_config_file_with_hci(
cls, filename: str, hci_source: TransportSource, hci_sink: TransportSink
) -> Device:
config = DeviceConfiguration()
config.load_from_file(filename)
config = DeviceConfiguration.from_file(filename)
return cls.from_config_with_hci(config, hci_source, hci_sink)

def __init__(
Expand Down
11 changes: 7 additions & 4 deletions bumble/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
import logging
import os
import json
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
from typing_extensions import Self

from .colors import color
from .hci import Address
Expand Down Expand Up @@ -253,8 +254,10 @@ def __init__(self, namespace, filename=None):

logger.debug(f'JSON keystore: {self.filename}')

@staticmethod
def from_device(device: Device, filename=None) -> Optional[JsonKeyStore]:
@classmethod
def from_device(
cls: Type[Self], device: Device, filename: Optional[str] = None
) -> Self:
if not filename:
# Extract the filename from the config if there is one
if device.config.keystore is not None:
Expand All @@ -270,7 +273,7 @@ def from_device(device: Device, filename=None) -> Optional[JsonKeyStore]:
else:
namespace = JsonKeyStore.DEFAULT_NAMESPACE

return JsonKeyStore(namespace, filename)
return cls(namespace, filename)

async def load(self):
# Try to open the file, without failing. If the file does not exist, it
Expand Down

0 comments on commit a5ac5f2

Please sign in to comment.