diff --git a/bumble/keys.py b/bumble/keys.py index 198d5c49..753575e2 100644 --- a/bumble/keys.py +++ b/bumble/keys.py @@ -22,10 +22,11 @@ # ----------------------------------------------------------------------------- from __future__ import annotations import asyncio +import dataclasses import logging import os import json -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Any from .colors import color from .hci import Address @@ -41,16 +42,17 @@ # ----------------------------------------------------------------------------- +@dataclasses.dataclass class PairingKeys: + @dataclasses.dataclass class Key: - def __init__(self, value, authenticated=False, ediv=None, rand=None): - self.value = value - self.authenticated = authenticated - self.ediv = ediv - self.rand = rand + value: bytes + authenticated: bool = False + ediv: Optional[int] = None + rand: Optional[bytes] = None @classmethod - def from_dict(cls, key_dict): + def from_dict(cls, key_dict: Dict[str, Any]) -> PairingKeys.Key: value = bytes.fromhex(key_dict['value']) authenticated = key_dict.get('authenticated', False) ediv = key_dict.get('ediv') @@ -58,9 +60,9 @@ def from_dict(cls, key_dict): if rand is not None: rand = bytes.fromhex(rand) - return cls(value, authenticated, ediv, rand) + return cls(value=value, authenticated=authenticated, ediv=ediv, rand=rand) - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: key_dict = {'value': self.value.hex(), 'authenticated': self.authenticated} if self.ediv is not None: key_dict['ediv'] = self.ediv @@ -69,17 +71,18 @@ def to_dict(self): return key_dict - def __init__(self): - self.address_type = None - self.ltk = None - self.ltk_central = None - self.ltk_peripheral = None - self.irk = None - self.csrk = None - self.link_key = None # Classic + address_type: Optional[int] = None + ltk: Optional[PairingKeys.Key] = None + ltk_central: Optional[PairingKeys.Key] = None + ltk_peripheral: Optional[PairingKeys.Key] = None + irk: Optional[PairingKeys.Key] = None + csrk: Optional[PairingKeys.Key] = None + link_key: Optional[PairingKeys.Key] = None # Classic @staticmethod - def key_from_dict(keys_dict, key_name): + def key_from_dict( + keys_dict: Dict[str, Any], key_name: str + ) -> Optional[PairingKeys.Key]: key_dict = keys_dict.get(key_name) if key_dict is None: return None @@ -87,7 +90,7 @@ def key_from_dict(keys_dict, key_name): return PairingKeys.Key.from_dict(key_dict) @staticmethod - def from_dict(keys_dict): + def from_dict(keys_dict: Dict[str, Any]) -> PairingKeys: keys = PairingKeys() keys.address_type = keys_dict.get('address_type') @@ -100,8 +103,8 @@ def from_dict(keys_dict): return keys - def to_dict(self): - keys = {} + def to_dict(self) -> Dict[str, Any]: + keys: Dict[str, Any] = {} if self.address_type is not None: keys['address_type'] = self.address_type @@ -126,12 +129,12 @@ def to_dict(self): return keys - def print(self, prefix=''): + def print(self, prefix: str = '') -> None: keys_dict = self.to_dict() - for (container_property, value) in keys_dict.items(): + for container_property, value in keys_dict.items(): if isinstance(value, dict): print(f'{prefix}{color(container_property, "cyan")}:') - for (key_property, key_value) in value.items(): + for key_property, key_value in value.items(): print(f'{prefix} {color(key_property, "green")}: {key_value}') else: print(f'{prefix}{color(container_property, "cyan")}: {value}') @@ -139,7 +142,7 @@ def print(self, prefix=''): # ----------------------------------------------------------------------------- class KeyStore: - async def delete(self, name: str): + async def delete(self, name: str) -> None: pass async def update(self, name: str, keys: PairingKeys) -> None: @@ -158,7 +161,7 @@ async def delete_all(self) -> None: async def get_resolving_keys(self): all_keys = await self.get_all() resolving_keys = [] - for (name, keys) in all_keys: + for name, keys in all_keys: if keys.irk is not None: if keys.address_type is None: address_type = Address.RANDOM_DEVICE_ADDRESS @@ -168,10 +171,10 @@ async def get_resolving_keys(self): return resolving_keys - async def print(self, prefix=''): + async def print(self, prefix: str = '') -> None: entries = await self.get_all() separator = '' - for (name, keys) in entries: + for name, keys in entries: print(separator + prefix + color(name, 'yellow')) keys.print(prefix=prefix + ' ') separator = '\n' @@ -229,7 +232,9 @@ class without a namespace. With the default namespace, reading from a file will DEFAULT_NAMESPACE = '__DEFAULT__' DEFAULT_BASE_NAME = "keys" - def __init__(self, namespace, filename=None): + def __init__( + self, namespace: Optional[str], filename: Optional[str] = None + ) -> None: self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE if filename is None: diff --git a/bumble/smp.py b/bumble/smp.py index f8879c6a..4e7dd0b2 100644 --- a/bumble/smp.py +++ b/bumble/smp.py @@ -1348,6 +1348,9 @@ async def on_pairing(self) -> None: if self.sc or self.connection.transport == BT_BR_EDR_TRANSPORT: keys.ltk = PairingKeys.Key(value=self.ltk, authenticated=authenticated) else: + if not self.peer_ltk: + raise RuntimeError('Peer LTK missing in LE legacy pairing') + our_ltk_key = PairingKeys.Key( value=self.ltk, authenticated=authenticated,