From 87c76a4a0e5cc0088278525e04cfe11f12b31cd0 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Thu, 14 Dec 2023 23:52:04 +0800 Subject: [PATCH] Complete CSIP and CAP Also add random address generation functions. --- bumble/crypto.py | 10 +++ bumble/gatt.py | 7 +- bumble/profiles/cap.py | 52 +++++++++++++++ bumble/profiles/csip.py | 66 +++++++++++++++++-- examples/run_csis_servers.py | 116 +++++++++++++++++++++++++++++++++ examples/run_unicast_server.py | 53 +++++++++------ tests/cap_test.py | 71 ++++++++++++++++++++ tests/csip_test.py | 37 +++++++++++ 8 files changed, 385 insertions(+), 27 deletions(-) create mode 100644 bumble/profiles/cap.py create mode 100644 examples/run_csis_servers.py create mode 100644 tests/cap_test.py diff --git a/bumble/crypto.py b/bumble/crypto.py index d267c3b7..af95160b 100644 --- a/bumble/crypto.py +++ b/bumble/crypto.py @@ -100,6 +100,16 @@ def dh(self, public_key_x: bytes, public_key_y: bytes) -> bytes: # ----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- +def generate_prand() -> bytes: + '''Generates random 3 bytes, with the 2 most significant bits of 0b01. + + See Bluetooth spec, Vol 6, Part E - Table 1.2. + ''' + prand_bytes = secrets.token_bytes(6) + return prand_bytes[:2] + bytes([(prand_bytes[2] & 0b01111111) | 0b01000000]) + + # ----------------------------------------------------------------------------- def xor(x: bytes, y: bytes) -> bytes: assert len(x) == len(y) diff --git a/bumble/gatt.py b/bumble/gatt.py index da5934c5..5e270244 100644 --- a/bumble/gatt.py +++ b/bumble/gatt.py @@ -368,9 +368,12 @@ class TemplateService(Service): UUID: UUID def __init__( - self, characteristics: List[Characteristic], primary: bool = True + self, + characteristics: List[Characteristic], + primary: bool = True, + included_services: List[Service] = [], ) -> None: - super().__init__(self.UUID, characteristics, primary) + super().__init__(self.UUID, characteristics, primary, included_services) # ----------------------------------------------------------------------------- diff --git a/bumble/profiles/cap.py b/bumble/profiles/cap.py new file mode 100644 index 00000000..476f908f --- /dev/null +++ b/bumble/profiles/cap.py @@ -0,0 +1,52 @@ +# Copyright 2021-2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations + +from bumble import gatt +from bumble import gatt_client +from bumble.profiles import csip + + +# ----------------------------------------------------------------------------- +# Server +# ----------------------------------------------------------------------------- +class CommonAudioServiceService(gatt.TemplateService): + UUID = gatt.GATT_COMMON_AUDIO_SERVICE + + def __init__( + self, + coordinated_set_identification_service: csip.CoordinatedSetIdentificationService, + ) -> None: + self.coordinated_set_identification_service = ( + coordinated_set_identification_service + ) + super().__init__( + characteristics=[], + included_services=[coordinated_set_identification_service], + ) + + +# ----------------------------------------------------------------------------- +# Client +# ----------------------------------------------------------------------------- +class CommonAudioServiceServiceProxy(gatt_client.ProfileServiceProxy): + SERVICE_CLASS = CommonAudioServiceService + + def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None: + self.service_proxy = service_proxy diff --git a/bumble/profiles/csip.py b/bumble/profiles/csip.py index 9657246b..c82b413a 100644 --- a/bumble/profiles/csip.py +++ b/bumble/profiles/csip.py @@ -21,6 +21,9 @@ import struct from typing import Optional +from bumble import core +from bumble import crypto +from bumble import device from bumble import gatt from bumble import gatt_client @@ -43,9 +46,43 @@ class MemberLock(enum.IntEnum): # ----------------------------------------------------------------------------- -# Utils +# Crypto Toolbox # ----------------------------------------------------------------------------- -# TODO: Implement RSI Generator +def s1(m: bytes) -> bytes: + ''' + Coordinated Set Identification Service - 4.3 s1 SALT generation function. + ''' + return crypto.aes_cmac(m[::-1], bytes(16))[::-1] + + +def k1(n: bytes, salt: bytes, p: bytes) -> bytes: + ''' + Coordinated Set Identification Service - 4.4 k1 derivation function. + ''' + t = crypto.aes_cmac(n[::-1], salt[::-1]) + return crypto.aes_cmac(p[::-1], t)[::-1] + + +def sef(k: bytes, r: bytes) -> bytes: + ''' + Coordinated Set Identification Service - 4.5 SIRK encryption function sef. + ''' + return crypto.xor(k1(k, s1(b'SIRKenc'[::-1]), b'csis'[::-1]), r) + + +def sih(k: bytes, r: bytes) -> bytes: + ''' + Coordinated Set Identification Service - 4.7 Resolvable Set Identifier hash function sih. + ''' + return crypto.e(k, r + bytes(13))[:3] + + +def generate_rsi(sirk: bytes) -> bytes: + ''' + Coordinated Set Identification Service - 4.8 Resolvable Set Identifier generation operation. + ''' + prand = crypto.generate_prand() + return sih(sirk, prand) + prand # ----------------------------------------------------------------------------- @@ -54,6 +91,7 @@ class MemberLock(enum.IntEnum): class CoordinatedSetIdentificationService(gatt.TemplateService): UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE + set_identity_resolving_key: bytes set_identity_resolving_key_characteristic: gatt.Characteristic coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None set_member_lock_characteristic: Optional[gatt.Characteristic] = None @@ -62,19 +100,21 @@ class CoordinatedSetIdentificationService(gatt.TemplateService): def __init__( self, set_identity_resolving_key: bytes, + set_identity_resolving_key_type: SirkType, coordinated_set_size: Optional[int] = None, set_member_lock: Optional[MemberLock] = None, set_member_rank: Optional[int] = None, ) -> None: characteristics = [] + self.set_identity_resolving_key = set_identity_resolving_key + self.set_identity_resolving_key_type = set_identity_resolving_key_type self.set_identity_resolving_key_characteristic = gatt.Characteristic( uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC, properties=gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.NOTIFY, permissions=gatt.Characteristic.Permissions.READABLE, - # TODO: Implement encrypted SIRK reader. - value=struct.pack('B', SirkType.PLAINTEXT) + set_identity_resolving_key, + value=gatt.CharacteristicValue(read=self.on_sirk_read), ) characteristics.append(self.set_identity_resolving_key_characteristic) @@ -112,6 +152,24 @@ def __init__( super().__init__(characteristics) + def on_sirk_read(self, _connection: device.Connection) -> bytes: + if self.set_identity_resolving_key_type == SirkType.PLAINTEXT: + return bytes([SirkType.PLAINTEXT]) + self.set_identity_resolving_key + else: + raise NotImplementedError('TODO: Pending async Characteristic read.') + + def get_advertising_data(self) -> bytes: + return bytes( + core.AdvertisingData( + [ + ( + core.AdvertisingData.RESOLVABLE_SET_IDENTIFIER, + generate_rsi(self.set_identity_resolving_key), + ), + ] + ) + ) + # ----------------------------------------------------------------------------- # Client diff --git a/examples/run_csis_servers.py b/examples/run_csis_servers.py new file mode 100644 index 00000000..88d49a16 --- /dev/null +++ b/examples/run_csis_servers.py @@ -0,0 +1,116 @@ +# Copyright 2021-2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import logging +import sys +import os +import secrets + +from bumble.core import AdvertisingData +from bumble.device import Device +from bumble.hci import ( + Address, + OwnAddressType, + HCI_LE_Set_Extended_Advertising_Parameters_Command, +) +from bumble.profiles.cap import CommonAudioServiceService +from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType + +from bumble.transport import open_transport_or_link + + +# ----------------------------------------------------------------------------- +async def main() -> None: + if len(sys.argv) < 3: + print( + 'Usage: run_cig_setup.py ' + ' ' + ) + print( + 'example: run_cig_setup.py device1.json' + 'tcp-client:127.0.0.1:6402 tcp-client:127.0.0.1:6402' + ) + return + + print('<<< connecting to HCI...') + hci_transports = await asyncio.gather( + open_transport_or_link(sys.argv[2]), open_transport_or_link(sys.argv[3]) + ) + print('<<< connected') + + devices = [ + Device.from_config_file_with_hci( + sys.argv[1], hci_transport.source, hci_transport.sink + ) + for hci_transport in hci_transports + ] + + sirk = secrets.token_bytes(16) + + for i, device in enumerate(devices): + device.random_address = Address(secrets.token_bytes(6)) + await device.power_on() + csis = CoordinatedSetIdentificationService( + set_identity_resolving_key=sirk, + set_identity_resolving_key_type=SirkType.PLAINTEXT, + coordinated_set_size=2, + ) + device.add_service(CommonAudioServiceService(csis)) + advertising_data = ( + bytes( + AdvertisingData( + [ + ( + AdvertisingData.COMPLETE_LOCAL_NAME, + bytes(f'Bumble LE Audio-{i}', 'utf-8'), + ), + ( + AdvertisingData.FLAGS, + bytes( + [ + AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG + | AdvertisingData.BR_EDR_HOST_FLAG + | AdvertisingData.BR_EDR_CONTROLLER_FLAG + ] + ), + ), + ( + AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, + bytes(CoordinatedSetIdentificationService.UUID), + ), + ] + ) + ) + + csis.get_advertising_data() + ) + await device.start_extended_advertising( + advertising_properties=( + HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties.CONNECTABLE_ADVERTISING + ), + own_address_type=OwnAddressType.RANDOM, + advertising_data=advertising_data, + ) + + await asyncio.gather( + *[hci_transport.source.terminated for hci_transport in hci_transports] + ) + + +# ----------------------------------------------------------------------------- +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +asyncio.run(main()) diff --git a/examples/run_unicast_server.py b/examples/run_unicast_server.py index e71cbeff..35e124d4 100644 --- a/examples/run_unicast_server.py +++ b/examples/run_unicast_server.py @@ -20,6 +20,7 @@ import sys import os import struct +import secrets from bumble.core import AdvertisingData from bumble.device import Device, CisLink from bumble.hci import ( @@ -39,6 +40,8 @@ PublishedAudioCapabilitiesService, AudioStreamControlService, ) +from bumble.profiles.cap import CommonAudioServiceService +from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType from bumble.transport import open_transport_or_link @@ -60,6 +63,11 @@ async def main() -> None: await device.power_on() + csis = CoordinatedSetIdentificationService( + set_identity_resolving_key=secrets.token_bytes(16), + set_identity_resolving_key_type=SirkType.PLAINTEXT, + ) + device.add_service(CommonAudioServiceService(csis)) device.add_service( PublishedAudioCapabilitiesService( supported_source_context=ContextType.PROHIBITED, @@ -108,29 +116,32 @@ async def main() -> None: device.add_service(AudioStreamControlService(device, sink_ase_id=[1, 2])) - advertising_data = bytes( - AdvertisingData( - [ - ( - AdvertisingData.COMPLETE_LOCAL_NAME, - bytes('Bumble LE Audio', 'utf-8'), - ), - ( - AdvertisingData.FLAGS, - bytes( - [ - AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG - | AdvertisingData.BR_EDR_HOST_FLAG - | AdvertisingData.BR_EDR_CONTROLLER_FLAG - ] + advertising_data = ( + bytes( + AdvertisingData( + [ + ( + AdvertisingData.COMPLETE_LOCAL_NAME, + bytes('Bumble LE Audio', 'utf-8'), ), - ), - ( - AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, - bytes(PublishedAudioCapabilitiesService.UUID), - ), - ] + ( + AdvertisingData.FLAGS, + bytes( + [ + AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG + | AdvertisingData.BR_EDR_HOST_FLAG + | AdvertisingData.BR_EDR_CONTROLLER_FLAG + ] + ), + ), + ( + AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, + bytes(PublishedAudioCapabilitiesService.UUID), + ), + ] + ) ) + + csis.get_advertising_data() ) subprocess = await asyncio.create_subprocess_shell( f'dlc3 | ffplay pipe:0', diff --git a/tests/cap_test.py b/tests/cap_test.py new file mode 100644 index 00000000..ab5ab816 --- /dev/null +++ b/tests/cap_test.py @@ -0,0 +1,71 @@ +# Copyright 2021-2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import os +import pytest +import logging + +from bumble import device +from bumble import gatt +from bumble.profiles import cap +from bumble.profiles import csip +from .test_utils import TwoDevices + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_cas(): + SIRK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa') + + devices = TwoDevices() + devices[0].add_service( + cap.CommonAudioServiceService( + csip.CoordinatedSetIdentificationService( + set_identity_resolving_key=SIRK, + set_identity_resolving_key_type=csip.SirkType.PLAINTEXT, + ) + ) + ) + + await devices.setup_connection() + peer = device.Peer(devices.connections[1]) + cas_client = await peer.discover_service_and_create_proxy( + cap.CommonAudioServiceServiceProxy + ) + + included_services = await peer.discover_included_services(cas_client.service_proxy) + assert any( + service.uuid == gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE + for service in included_services + ) + + +# ----------------------------------------------------------------------------- +async def run(): + await test_cas() + + +# ----------------------------------------------------------------------------- +if __name__ == '__main__': + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + asyncio.run(run()) diff --git a/tests/csip_test.py b/tests/csip_test.py index 6f2c7fda..5899d81f 100644 --- a/tests/csip_test.py +++ b/tests/csip_test.py @@ -31,6 +31,41 @@ logger = logging.getLogger(__name__) +# ----------------------------------------------------------------------------- +def test_s1(): + assert ( + csip.s1(b'SIRKenc'[::-1]) + == bytes.fromhex('6901983f 18149e82 3c7d133a 7d774572')[::-1] + ) + + +# ----------------------------------------------------------------------------- +def test_k1(): + K = bytes.fromhex('676e1b9b d448696f 061ec622 3ce5ced9')[::-1] + SALT = csip.s1(b'SIRKenc'[::-1]) + P = b'csis'[::-1] + assert ( + csip.k1(K, SALT, P) + == bytes.fromhex('5277453c c094d982 b0e8ee53 2f2d1f8b')[::-1] + ) + + +# ----------------------------------------------------------------------------- +def test_sih(): + SIRK = bytes.fromhex('457d7d09 21a1fd22 cecd8c86 dd72cccd')[::-1] + PRAND = bytes.fromhex('69f563')[::-1] + assert csip.sih(SIRK, PRAND) == bytes.fromhex('1948da')[::-1] + + +# ----------------------------------------------------------------------------- +def test_sef(): + SIRK = bytes.fromhex('457d7d09 21a1fd22 cecd8c86 dd72cccd')[::-1] + K = bytes.fromhex('676e1b9b d448696f 061ec622 3ce5ced9')[::-1] + assert ( + csip.sef(K, SIRK) == bytes.fromhex('170a3835 e13524a0 7e2562d5 f25fd346')[::-1] + ) + + # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_csis(): @@ -40,6 +75,7 @@ async def test_csis(): devices[0].add_service( csip.CoordinatedSetIdentificationService( set_identity_resolving_key=SIRK, + set_identity_resolving_key_type=csip.SirkType.PLAINTEXT, coordinated_set_size=2, set_member_lock=csip.MemberLock.UNLOCKED, set_member_rank=0, @@ -65,6 +101,7 @@ async def test_csis(): # ----------------------------------------------------------------------------- async def run(): + test_sih() await test_csis()