Skip to content

Commit

Permalink
Typing keys module
Browse files Browse the repository at this point in the history
  • Loading branch information
zxzxwu committed Jan 17, 2024
1 parent 46ceea7 commit 5a60f24
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 29 deletions.
63 changes: 34 additions & 29 deletions bumble/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import dataclasses
import logging
import os
import json
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Any

from .colors import color
from .hci import Address
Expand All @@ -41,26 +42,27 @@


# -----------------------------------------------------------------------------
@dataclasses.dataclass
class PairingKeys:
@dataclasses.dataclass
class Key:
def __init__(self, value, authenticated=False, ediv=None, rand=None):
self.value = value
self.authenticated = authenticated
self.ediv = ediv
self.rand = rand
value: bytes
authenticated: bool = False
ediv: Optional[int] = None
rand: Optional[bytes] = None

@classmethod
def from_dict(cls, key_dict):
def from_dict(cls, key_dict: Dict[str, Any]) -> PairingKeys.Key:
value = bytes.fromhex(key_dict['value'])
authenticated = key_dict.get('authenticated', False)
ediv = key_dict.get('ediv')
rand = key_dict.get('rand')
if rand is not None:
rand = bytes.fromhex(rand)

return cls(value, authenticated, ediv, rand)
return cls(value=value, authenticated=authenticated, ediv=ediv, rand=rand)

def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
key_dict = {'value': self.value.hex(), 'authenticated': self.authenticated}
if self.ediv is not None:
key_dict['ediv'] = self.ediv
Expand All @@ -69,25 +71,26 @@ def to_dict(self):

return key_dict

def __init__(self):
self.address_type = None
self.ltk = None
self.ltk_central = None
self.ltk_peripheral = None
self.irk = None
self.csrk = None
self.link_key = None # Classic
address_type: Optional[int] = None
ltk: Optional[PairingKeys.Key] = None
ltk_central: Optional[PairingKeys.Key] = None
ltk_peripheral: Optional[PairingKeys.Key] = None
irk: Optional[PairingKeys.Key] = None
csrk: Optional[PairingKeys.Key] = None
link_key: Optional[PairingKeys.Key] = None # Classic

@staticmethod
def key_from_dict(keys_dict, key_name):
def key_from_dict(
keys_dict: Dict[str, Any], key_name: str
) -> Optional[PairingKeys.Key]:
key_dict = keys_dict.get(key_name)
if key_dict is None:
return None

return PairingKeys.Key.from_dict(key_dict)

@staticmethod
def from_dict(keys_dict):
def from_dict(keys_dict: Dict[str, Any]) -> PairingKeys:
keys = PairingKeys()

keys.address_type = keys_dict.get('address_type')
Expand All @@ -100,8 +103,8 @@ def from_dict(keys_dict):

return keys

def to_dict(self):
keys = {}
def to_dict(self) -> Dict[str, Any]:
keys: Dict[str, Any] = {}

if self.address_type is not None:
keys['address_type'] = self.address_type
Expand All @@ -126,20 +129,20 @@ def to_dict(self):

return keys

def print(self, prefix=''):
def print(self, prefix: str = '') -> None:
keys_dict = self.to_dict()
for (container_property, value) in keys_dict.items():
for container_property, value in keys_dict.items():
if isinstance(value, dict):
print(f'{prefix}{color(container_property, "cyan")}:')
for (key_property, key_value) in value.items():
for key_property, key_value in value.items():
print(f'{prefix} {color(key_property, "green")}: {key_value}')
else:
print(f'{prefix}{color(container_property, "cyan")}: {value}')


# -----------------------------------------------------------------------------
class KeyStore:
async def delete(self, name: str):
async def delete(self, name: str) -> None:
pass

async def update(self, name: str, keys: PairingKeys) -> None:
Expand All @@ -158,7 +161,7 @@ async def delete_all(self) -> None:
async def get_resolving_keys(self):
all_keys = await self.get_all()
resolving_keys = []
for (name, keys) in all_keys:
for name, keys in all_keys:
if keys.irk is not None:
if keys.address_type is None:
address_type = Address.RANDOM_DEVICE_ADDRESS
Expand All @@ -168,10 +171,10 @@ async def get_resolving_keys(self):

return resolving_keys

async def print(self, prefix=''):
async def print(self, prefix: str = '') -> None:
entries = await self.get_all()
separator = ''
for (name, keys) in entries:
for name, keys in entries:
print(separator + prefix + color(name, 'yellow'))
keys.print(prefix=prefix + ' ')
separator = '\n'
Expand Down Expand Up @@ -229,7 +232,9 @@ class without a namespace. With the default namespace, reading from a file will
DEFAULT_NAMESPACE = '__DEFAULT__'
DEFAULT_BASE_NAME = "keys"

def __init__(self, namespace, filename=None):
def __init__(
self, namespace: Optional[str], filename: Optional[str] = None
) -> None:
self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE

if filename is None:
Expand Down
3 changes: 3 additions & 0 deletions bumble/smp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,6 +1348,9 @@ async def on_pairing(self) -> None:
if self.sc or self.connection.transport == BT_BR_EDR_TRANSPORT:
keys.ltk = PairingKeys.Key(value=self.ltk, authenticated=authenticated)
else:
if not self.peer_ltk:
raise RuntimeError('Peer LTK missing in LE legacy pairing')

our_ltk_key = PairingKeys.Key(
value=self.ltk,
authenticated=authenticated,
Expand Down

0 comments on commit 5a60f24

Please sign in to comment.