Merge "Merge remote-tracking branch 'aosp/upstream-main' into master" am: 127794c6a5 am: 5761825687 am: 5bd8a87b1b am: 84b96d9a9a
Original change: https://android-review.googlesource.com/c/platform/external/python/bumble/+/2514427
Change-Id: Ib1306d451cfbd7abb83024c60addda0046f68b83
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
diff --git a/apps/bench.py b/apps/bench.py
new file mode 100644
index 0000000..19cdcfa
--- /dev/null
+++ b/apps/bench.py
@@ -0,0 +1,1207 @@
+# Copyright 2021-2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# -----------------------------------------------------------------------------
+# Imports
+# -----------------------------------------------------------------------------
+import asyncio
+import enum
+import logging
+import os
+import struct
+import time
+
+import click
+
+from bumble.core import (
+ BT_BR_EDR_TRANSPORT,
+ BT_LE_TRANSPORT,
+ BT_L2CAP_PROTOCOL_ID,
+ BT_RFCOMM_PROTOCOL_ID,
+ UUID,
+ CommandTimeoutError,
+)
+from bumble.colors import color
+from bumble.device import Connection, ConnectionParametersPreferences, Device, Peer
+from bumble.gatt import Characteristic, CharacteristicValue, Service
+from bumble.hci import (
+ HCI_LE_1M_PHY,
+ HCI_LE_2M_PHY,
+ HCI_LE_CODED_PHY,
+ HCI_Constant,
+ HCI_Error,
+ HCI_StatusError,
+)
+from bumble.sdp import (
+ SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
+ SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
+ SDP_PUBLIC_BROWSE_ROOT,
+ SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
+ SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
+ DataElement,
+ ServiceAttribute,
+)
+from bumble.transport import open_transport_or_link
+import bumble.rfcomm
+import bumble.core
+from bumble.utils import AsyncRunner
+
+
+# -----------------------------------------------------------------------------
+# Logging
+# -----------------------------------------------------------------------------
# Module-level logger for this app.
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Identities given to the devices created by Central and Peripheral; the
# peripheral address is also the default target of `central --peripheral`.
DEFAULT_CENTRAL_ADDRESS = 'F0:F0:F0:F0:F0:F0'
DEFAULT_CENTRAL_NAME = 'Speed Central'
DEFAULT_PERIPHERAL_ADDRESS = 'F1:F1:F1:F1:F1:F1'
DEFAULT_PERIPHERAL_NAME = 'Speed Peripheral'

# GATT "speed" service: TX is written by the GATT client, RX is notified by the
# GATT server.
SPEED_SERVICE_UUID = '50DB505C-8AC4-4738-8448-3B1D9CC09CC5'
SPEED_TX_UUID = 'E789C754-41A1-45F4-A948-A0A1A90DBA53'
SPEED_RX_UUID = '016A2CC7-E14B-4819-935F-1F56EAE4098D'

# Default L2CAP connection-oriented channel settings for L2capClient/L2capServer.
DEFAULT_L2CAP_PSM = 1234
DEFAULT_L2CAP_MAX_CREDITS = 128
DEFAULT_L2CAP_MTU = 1022
DEFAULT_L2CAP_MPS = 1024

# Seconds to sleep after a role finishes, before the HCI transport is closed.
DEFAULT_LINGER_TIME = 1.0

# RFCOMM channel the client opens and the server listens on.
DEFAULT_RFCOMM_CHANNEL = 8
+
+# -----------------------------------------------------------------------------
+# Utils
+# -----------------------------------------------------------------------------
def parse_packet(packet):
    """Split a raw benchmark packet into its type and its payload.

    Returns a ``(PacketType, payload)`` tuple, where the payload is everything
    after the leading type byte. Prints a diagnostic and raises ValueError if
    the packet is empty or its first byte is not a known PacketType.
    """
    if not packet:
        print(
            color(f'!!! Packet too short (got {len(packet)} bytes, need >= 1)', 'red')
        )
        raise ValueError('packet too short')

    type_byte, payload = packet[0], packet[1:]
    try:
        return (PacketType(type_byte), payload)
    except ValueError:
        print(color(f'!!! Invalid packet type 0x{type_byte:02X}', 'red'))
        raise
+
+
def parse_packet_sequence(packet_data):
    """Decode the header of a SEQUENCE/ACK payload.

    ``packet_data`` is a payload with the type byte already stripped (as
    returned by parse_packet). Its first 5 bytes are a signed 8-bit flags field
    followed by a big-endian unsigned 32-bit packet index; any further bytes
    are ignored.

    Returns ``(flags, index)``. Prints a diagnostic and raises ValueError when
    fewer than 5 bytes are available.
    """
    header_size = struct.calcsize('>bI')  # 5: int8 flags + uint32 index
    if len(packet_data) >= header_size:
        return struct.unpack_from('>bI', packet_data, 0)

    print(
        color(
            f'!!!Packet too short (got {len(packet_data)} bytes, need >= 5)',
            'red',
        )
    )
    raise ValueError('packet too short')
+
+
def le_phy_name(phy_id):
    """Return a short display name ('1M', '2M', 'CODED') for an LE PHY id.

    Ids without a short name fall back to HCI_Constant.le_phy_name().
    """
    short_names = {
        HCI_LE_1M_PHY: '1M',
        HCI_LE_2M_PHY: '2M',
        HCI_LE_CODED_PHY: 'CODED',
    }
    fallback = HCI_Constant.le_phy_name(phy_id)
    return short_names.get(phy_id, fallback)
+
+
def print_connection(connection):
    """Print a one-line summary of a connection.

    For LE connections the line includes the connection parameters (interval
    scaled by 1.25, peripheral latency, supervision timeout scaled by 10), the
    data length and the RX/TX PHYs. Every connection shows its ATT MTU.
    """
    phy_state = data_length = connection_parameters = ''

    if connection.transport == BT_LE_TRANSPORT:
        params = connection.parameters
        phy = connection.phy
        connection_parameters = (
            f'Parameters={params.connection_interval * 1.25:.2f}/'
            f'{params.peripheral_latency}/'
            f'{params.supervision_timeout * 10} '
        )
        data_length = f'DL={connection.data_length}'
        phy_state = f'PHY=RX:{le_phy_name(phy.rx_phy)}/TX:{le_phy_name(phy.tx_phy)}'

    summary = ' '.join(
        (connection_parameters, data_length, phy_state, f'MTU={connection.att_mtu}')
    )
    print(f'{color("@@@ Connection:", "yellow")} {summary}')
+
+
def make_sdp_records(channel):
    """Build the SDP records advertising the benchmark RFCOMM service.

    Returns a dict mapping the single record handle (0x00010001) to its
    attributes: the record handle, the public browse group, the benchmark
    service class UUID, and an L2CAP + RFCOMM protocol descriptor list that
    carries ``channel``.
    """
    record_handle = 0x00010001
    service_class_uuid = UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')

    l2cap_descriptor = DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)])
    rfcomm_descriptor = DataElement.sequence(
        [
            DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
            DataElement.unsigned_integer_8(channel),
        ]
    )

    attributes = [
        ServiceAttribute(
            SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
            DataElement.unsigned_integer_32(record_handle),
        ),
        ServiceAttribute(
            SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
            DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
        ),
        ServiceAttribute(
            SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
            DataElement.sequence([DataElement.uuid(service_class_uuid)]),
        ),
        ServiceAttribute(
            SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
            DataElement.sequence([l2cap_descriptor, rfcomm_descriptor]),
        ),
    ]
    return {record_handle: attributes}
+
+
class PacketType(enum.IntEnum):
    """Type tag carried in the first byte of every benchmark packet."""

    RESET = 0  # start of a run; receiving roles reset their counters
    SEQUENCE = 1  # numbered data packet: flags, index, then padding
    ACK = 2  # acknowledgement echoing a SEQUENCE packet's flags and index


# Bit in a SEQUENCE/ACK flags byte marking the final packet of a run.
PACKET_FLAG_LAST = 1
+
+# -----------------------------------------------------------------------------
+# Sender
+# -----------------------------------------------------------------------------
class Sender:
    """Throughput role that streams a fixed number of packets to the peer.

    Once the packet I/O is ready (and after an optional start delay), sends a
    RESET followed by ``packet_count`` SEQUENCE packets of ``packet_size`` bytes,
    the last one flagged with PACKET_FLAG_LAST. It then waits for the peer's ACK
    and prints the average send rate in bytes per second, measured from just
    after the RESET until the ACK arrived.
    """

    def __init__(self, packet_io, start_delay, packet_size, packet_count):
        self.tx_start_delay = start_delay
        self.tx_packet_size = packet_size
        self.tx_packet_count = packet_count
        self.packet_io = packet_io
        # Incoming packets (the final ACK) are delivered to on_packet_received.
        self.packet_io.packet_listener = self
        self.start_time = 0
        self.bytes_sent = 0
        self.done = asyncio.Event()

    def reset(self):
        # Intentionally a no-op (Peripheral calls it on disconnection).
        pass

    async def run(self):
        print(color('--- Waiting for I/O to be ready...', 'blue'))
        await self.packet_io.ready.wait()
        print(color('--- Go!', 'blue'))

        if self.tx_start_delay:
            print(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
            await asyncio.sleep(self.tx_start_delay)  # FIXME

        print(color('=== Sending RESET', 'magenta'))
        await self.packet_io.send_packet(bytes([PacketType.RESET]))
        self.start_time = time.time()

        last_index = self.tx_packet_count - 1
        for index in range(self.tx_packet_count):
            flags = PACKET_FLAG_LAST if index == last_index else 0
            # 6-byte header (type, flags, index) followed by zero padding.
            header = struct.pack('>bbI', PacketType.SEQUENCE, flags, index)
            packet = header + bytes(self.tx_packet_size - 6)
            print(color(f'Sending packet {index}: {len(packet)} bytes', 'yellow'))
            self.bytes_sent += len(packet)
            await self.packet_io.send_packet(packet)

        await self.done.wait()
        print(color('=== Done!', 'magenta'))

    def on_packet_received(self, packet):
        try:
            packet_type, _ = parse_packet(packet)
        except ValueError:
            return  # parse_packet already printed a diagnostic

        if packet_type != PacketType.ACK:
            return

        elapsed = time.time() - self.start_time
        average_tx_speed = self.bytes_sent / elapsed
        print(
            color(
                f'@@@ Received ACK. Speed: average={average_tx_speed:.4f}'
                f' ({self.bytes_sent} bytes in {elapsed:.2f} seconds)',
                'green',
            )
        )
        self.done.set()
+
+
+# -----------------------------------------------------------------------------
+# Receiver
+# -----------------------------------------------------------------------------
class Receiver:
    """Throughput role that consumes a Sender's packet stream.

    For every SEQUENCE packet it prints the instant and running-average receive
    rates in bytes per second and warns about out-of-order indices. When the
    packet flagged with PACKET_FLAG_LAST arrives it sends an ACK back, and
    run() completes.
    """

    def __init__(self, packet_io):
        self.reset()
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.done = asyncio.Event()

    def reset(self):
        """Clear per-run state (called on RESET and by Peripheral on disconnection)."""
        self.expected_packet_index = 0
        self.start_timestamp = 0.0
        self.last_timestamp = 0.0
        self.bytes_received = 0

    @staticmethod
    def _rate(byte_count, seconds):
        """Return ``byte_count / seconds``, or infinity when no time has elapsed."""
        # With a coarse time.time() clock two packets can share a timestamp;
        # dividing by zero here would raise inside the transport callback.
        if seconds <= 0:
            return float('inf')
        return byte_count / seconds

    def on_packet_received(self, packet):
        try:
            packet_type, packet_data = parse_packet(packet)
        except ValueError:
            return

        now = time.time()

        if packet_type == PacketType.RESET:
            print(color('=== Received RESET', 'magenta'))
            self.reset()
            # Measure both the average and the first instant rate from the
            # RESET, not from the 0.0 placeholder set by reset().
            self.start_timestamp = now
            self.last_timestamp = now
            return

        try:
            packet_flags, packet_index = parse_packet_sequence(packet_data)
        except ValueError:
            return
        print(
            f'<<< Received packet {packet_index}: '
            f'flags=0x{packet_flags:02X}, {len(packet)} bytes'
        )

        if packet_index != self.expected_packet_index:
            print(
                color(
                    f'!!! Unexpected packet, expected {self.expected_packet_index} '
                    f'but received {packet_index}'
                )
            )

        self.bytes_received += len(packet)
        instant_rx_speed = self._rate(len(packet), now - self.last_timestamp)
        average_rx_speed = self._rate(self.bytes_received, now - self.start_timestamp)
        print(
            color(
                f'Speed: instant={instant_rx_speed:.4f}, average={average_rx_speed:.4f}',
                'yellow',
            )
        )

        self.last_timestamp = now
        self.expected_packet_index = packet_index + 1

        if packet_flags & PACKET_FLAG_LAST:
            # The ACK echoes the last packet's flags and index.
            AsyncRunner.spawn(
                self.packet_io.send_packet(
                    struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index)
                )
            )
            print(color('@@@ Received last packet', 'green'))
            self.done.set()

    async def run(self):
        await self.done.wait()
        print(color('=== Done!', 'magenta'))
+
+
+# -----------------------------------------------------------------------------
+# Ping
+# -----------------------------------------------------------------------------
class Ping:
    """Latency role: sends numbered pings one at a time and times each ACK.

    Each SEQUENCE packet is sent only after the ACK for the previous one has
    arrived. The round-trip time of every ping is recorded in milliseconds.
    The average is printed once the ACK for the packet flagged with
    PACKET_FLAG_LAST has been received.
    """

    def __init__(self, packet_io, start_delay, packet_size, packet_count):
        self.tx_start_delay = start_delay
        self.tx_packet_size = packet_size
        self.tx_packet_count = packet_count
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.done = asyncio.Event()
        self.current_packet_index = 0
        self.ping_sent_time = 0.0
        self.latencies = []

    def reset(self):
        # Intentionally a no-op (Peripheral calls it on disconnection).
        pass

    async def run(self):
        print(color('--- Waiting for I/O to be ready...', 'blue'))
        await self.packet_io.ready.wait()
        print(color('--- Go!', 'blue'))

        if self.tx_start_delay:
            print(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
            await asyncio.sleep(self.tx_start_delay)  # FIXME

        print(color('=== Sending RESET', 'magenta'))
        await self.packet_io.send_packet(bytes([PacketType.RESET]))

        await self.send_next_ping()

        await self.done.wait()
        # ``done`` is only set after an ACK was recorded, so latencies is non-empty.
        average_latency = sum(self.latencies) / len(self.latencies)
        print(color(f'@@@ Average latency: {average_latency:.2f}'))
        print(color('=== Done!', 'magenta'))

    async def send_next_ping(self):
        is_last = self.current_packet_index == self.tx_packet_count - 1
        packet = struct.pack(
            '>bbI',
            PacketType.SEQUENCE,
            PACKET_FLAG_LAST if is_last else 0,
            self.current_packet_index,
        ) + bytes(self.tx_packet_size - 6)
        print(color(f'Sending packet {self.current_packet_index}', 'yellow'))
        self.ping_sent_time = time.time()
        await self.packet_io.send_packet(packet)

    def on_packet_received(self, packet):
        elapsed = time.time() - self.ping_sent_time

        try:
            packet_type, packet_data = parse_packet(packet)
        except ValueError:
            return

        # Only ACKs drive the ping loop: any other packet must neither trigger
        # another ping nor end the run without a recorded latency.
        if packet_type != PacketType.ACK:
            return

        try:
            packet_flags, packet_index = parse_packet_sequence(packet_data)
        except ValueError:
            return

        latency = elapsed * 1000
        self.latencies.append(latency)
        print(
            color(
                f'<<< Received ACK [{packet_index}], latency={latency:.2f}ms',
                'green',
            )
        )

        if packet_index == self.current_packet_index:
            self.current_packet_index += 1
        else:
            print(
                color(
                    f'!!! Unexpected packet, expected {self.current_packet_index} '
                    f'but received {packet_index}'
                )
            )

        if packet_flags & PACKET_FLAG_LAST:
            self.done.set()
            return

        AsyncRunner.spawn(self.send_next_ping())
+
+
+# -----------------------------------------------------------------------------
+# Pong
+# -----------------------------------------------------------------------------
class Pong:
    """Latency responder: ACKs every SEQUENCE packet it receives.

    Each ACK echoes the received packet's flags and index and is sent
    asynchronously. run() completes once the packet flagged with
    PACKET_FLAG_LAST has been received.
    """

    def __init__(self, packet_io):
        self.reset()
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.done = asyncio.Event()

    def reset(self):
        self.expected_packet_index = 0

    def on_packet_received(self, packet):
        try:
            packet_type, packet_data = parse_packet(packet)
        except ValueError:
            return

        if packet_type == PacketType.RESET:
            print(color('=== Received RESET', 'magenta'))
            self.reset()
            return

        try:
            flags, index = parse_packet_sequence(packet_data)
        except ValueError:
            return

        print(
            color(
                f'<<< Received packet {index}: '
                f'flags=0x{flags:02X}, {len(packet)} bytes',
                'green',
            )
        )

        if index != self.expected_packet_index:
            print(
                color(
                    f'!!! Unexpected packet, expected {self.expected_packet_index} '
                    f'but received {index}'
                )
            )
        self.expected_packet_index = index + 1

        ack = struct.pack('>bbI', PacketType.ACK, flags, index)
        AsyncRunner.spawn(self.packet_io.send_packet(ack))

        if flags & PACKET_FLAG_LAST:
            self.done.set()

    async def run(self):
        await self.done.wait()
        print(color('=== Done!', 'magenta'))
+
+
+# -----------------------------------------------------------------------------
+# GattClient
+# -----------------------------------------------------------------------------
class GattClient:
    """GATT-client packet I/O that talks to a peer's speed service.

    On connection, optionally requests an ATT MTU, then discovers the speed
    service and its TX/RX characteristics. It subscribes to RX notifications
    (incoming packets) and sets ``ready``; outgoing packets are written to TX.
    If the service or a characteristic is missing, setup stops after printing
    an error and ``ready`` is never set.
    """

    def __init__(self, _device, att_mtu=None):
        self.att_mtu = att_mtu
        self.speed_rx = None
        self.speed_tx = None
        self.packet_listener = None
        self.ready = asyncio.Event()

    @staticmethod
    def _first_or_report(found, what):
        """Return the first element of ``found``, or print an error and return None."""
        if found:
            return found[0]
        print(color(f'!!! {what} not found', 'red'))
        return None

    async def on_connection(self, connection):
        peer = Peer(connection)

        if self.att_mtu:
            print(color(f'*** Requesting MTU update: {self.att_mtu}', 'blue'))
            await peer.request_mtu(self.att_mtu)

        print(color('*** Discovering services...', 'blue'))
        await peer.discover_services()
        speed_service = self._first_or_report(
            peer.get_services_by_uuid(SPEED_SERVICE_UUID), 'Speed Service'
        )
        if speed_service is None:
            return

        print(color('*** Discovering characteristics...', 'blue'))
        await speed_service.discover_characteristics()

        speed_tx = self._first_or_report(
            speed_service.get_characteristics_by_uuid(SPEED_TX_UUID), 'Speed TX'
        )
        if speed_tx is None:
            return
        self.speed_tx = speed_tx

        speed_rx = self._first_or_report(
            speed_service.get_characteristics_by_uuid(SPEED_RX_UUID), 'Speed RX'
        )
        if speed_rx is None:
            return
        self.speed_rx = speed_rx

        print(color('*** Subscribing to RX', 'blue'))
        await self.speed_rx.subscribe(self.on_packet_received)

        print(color('*** Discovery complete', 'blue'))

        connection.on('disconnection', self.on_disconnection)
        self.ready.set()

    def on_disconnection(self, _):
        self.ready.clear()

    def on_packet_received(self, packet):
        listener = self.packet_listener
        if listener:
            listener.on_packet_received(packet)

    async def send_packet(self, packet):
        await self.speed_tx.write_value(packet)
+
+
+# -----------------------------------------------------------------------------
+# GattServer
+# -----------------------------------------------------------------------------
class GattServer:
    """GATT-server packet I/O built on the speed service.

    Packets from the client arrive as writes to the TX characteristic. Packets
    to the client are sent as notifications of the RX characteristic.
    ``ready`` is set while the client has notifications enabled on RX and is
    cleared on un-subscription or disconnection.
    """

    def __init__(self, device):
        self.device = device
        self.packet_listener = None
        self.ready = asyncio.Event()

        # TX is written by the client; RX is notified to the client.
        self.speed_tx = Characteristic(
            SPEED_TX_UUID,
            Characteristic.WRITE,
            Characteristic.WRITEABLE,
            CharacteristicValue(write=self.on_tx_write),
        )
        self.speed_rx = Characteristic(SPEED_RX_UUID, Characteristic.NOTIFY, 0)
        device.add_services(
            [Service(SPEED_SERVICE_UUID, [self.speed_tx, self.speed_rx])]
        )

        self.speed_rx.on('subscription', self.on_rx_subscription)

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        self.ready.clear()

    def on_rx_subscription(self, _connection, notify_enabled, _indicate_enabled):
        if not notify_enabled:
            print(color('*** RX un-subscription', 'blue'))
            self.ready.clear()
            return
        print(color('*** RX subscription', 'blue'))
        self.ready.set()

    def on_tx_write(self, _, value):
        listener = self.packet_listener
        if listener:
            listener.on_packet_received(value)

    async def send_packet(self, packet):
        await self.device.notify_subscribers(self.speed_rx, packet)
+
+
+# -----------------------------------------------------------------------------
+# StreamedPacketIO
+# -----------------------------------------------------------------------------
class StreamedPacketIO:
    """Packet framing over a byte stream (L2CAP channel, RFCOMM DLC).

    Each packet is sent as a 2-byte big-endian length prefix followed by the
    packet bytes. on_packet() reassembles frames from arbitrarily split chunks
    and hands each complete packet to ``packet_listener.on_packet_received``.
    Subclasses set ``io_sink`` to the underlying stream's write function.
    """

    def __init__(self):
        self.packet_listener = None
        self.io_sink = None
        self.rx_packet = b''
        self.rx_packet_header = b''
        self.rx_packet_need = 0

    def _deliver_rx_packet(self):
        """Hand the reassembled packet to the listener and clear per-frame state."""
        if self.packet_listener:
            self.packet_listener.on_packet_received(self.rx_packet)
        self.rx_packet = b''
        self.rx_packet_header = b''

    def on_packet(self, packet):
        """Consume a chunk of stream data, delivering every completed packet."""
        while packet:
            if self.rx_packet_need:
                # Accumulate the body of the current packet.
                chunk = packet[: self.rx_packet_need]
                self.rx_packet += chunk
                packet = packet[len(chunk) :]
                self.rx_packet_need -= len(chunk)
                if not self.rx_packet_need:
                    self._deliver_rx_packet()
            else:
                # Accumulate the 2-byte length prefix of the next packet.
                header_bytes_needed = 2 - len(self.rx_packet_header)
                header_bytes = packet[:header_bytes_needed]
                self.rx_packet_header += header_bytes
                packet = packet[len(header_bytes) :]
                if len(self.rx_packet_header) != 2:
                    return  # chunk exhausted mid-header; wait for more data
                self.rx_packet_need = struct.unpack('>H', self.rx_packet_header)[0]
                if not self.rx_packet_need:
                    # Zero-length frame: deliver it now and reset the header;
                    # otherwise the stale header would make this loop spin forever.
                    self._deliver_rx_packet()

    async def send_packet(self, packet):
        """Frame ``packet`` with its length and write it to ``io_sink``.

        Prints a message and drops the packet if no sink is set. struct.pack
        raises struct.error for packets longer than 65535 bytes.
        """
        if not self.io_sink:
            print(color('!!! No sink, dropping packet', 'red'))
            return

        # pylint: disable-next=not-callable
        self.io_sink(struct.pack('>H', len(packet)) + packet)
+
+
+# -----------------------------------------------------------------------------
+# L2capClient
+# -----------------------------------------------------------------------------
class L2capClient(StreamedPacketIO):
    """Packet I/O over an L2CAP connection-oriented channel that we open.

    On connection, opens a channel to ``psm`` with the given credit/MTU/MPS
    settings. Incoming data goes through StreamedPacketIO's length-prefixed
    deframing, and ``ready`` is set once the channel is open. If opening
    fails, an error is printed and ``ready`` stays clear.
    """

    def __init__(
        self,
        _device,
        psm=DEFAULT_L2CAP_PSM,
        max_credits=DEFAULT_L2CAP_MAX_CREDITS,
        mtu=DEFAULT_L2CAP_MTU,
        mps=DEFAULT_L2CAP_MPS,
    ):
        super().__init__()
        self.psm, self.max_credits, self.mtu, self.mps = psm, max_credits, mtu, mps
        self.ready = asyncio.Event()

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

        print(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
        try:
            channel = await connection.open_l2cap_channel(
                psm=self.psm,
                max_credits=self.max_credits,
                mtu=self.mtu,
                mps=self.mps,
            )
            print(color('*** L2CAP channel:', 'cyan'), channel)
        except Exception as error:  # report any failure and give up on this link
            print(color(f'!!! Connection failed: {error}', 'red'))
            return

        # Deframe incoming data; write outgoing frames straight to the channel.
        channel.sink = self.on_packet
        channel.on('close', self.on_l2cap_close)
        self.io_sink = channel.write

        self.ready.set()

    def on_disconnection(self, _):
        pass

    def on_l2cap_close(self):
        print(color('*** L2CAP channel closed', 'red'))
+
+
+# -----------------------------------------------------------------------------
+# L2capServer
+# -----------------------------------------------------------------------------
class L2capServer(StreamedPacketIO):
    """Packet I/O over an incoming L2CAP connection-oriented channel.

    Registers an L2CAP CoC server on ``psm`` when constructed. When a client
    opens a channel, its data goes through StreamedPacketIO's length-prefixed
    deframing, outgoing frames are written to it, and ``ready`` is set.
    """

    def __init__(
        self,
        device,
        psm=DEFAULT_L2CAP_PSM,
        max_credits=DEFAULT_L2CAP_MAX_CREDITS,
        mtu=DEFAULT_L2CAP_MTU,
        mps=DEFAULT_L2CAP_MPS,
    ):
        super().__init__()
        # Currently open channel, or None when no channel is open.
        self.l2cap_channel = None
        self.ready = asyncio.Event()

        # Listen for incoming L2CAP CoC connections
        device.register_l2cap_channel_server(
            psm=psm,
            server=self.on_l2cap_channel,
            max_credits=max_credits,
            mtu=mtu,
            mps=mps,
        )
        print(color(f'### Listening for CoC connection on PSM {psm}', 'yellow'))

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        pass

    def on_l2cap_channel(self, l2cap_channel):
        print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)

        # Keep a reference to the open channel; on_l2cap_close() clears it.
        self.l2cap_channel = l2cap_channel
        self.io_sink = l2cap_channel.write
        l2cap_channel.on('close', self.on_l2cap_close)
        l2cap_channel.sink = self.on_packet

        self.ready.set()

    def on_l2cap_close(self):
        print(color('*** L2CAP channel closed', 'red'))
        self.l2cap_channel = None
+
+
+# -----------------------------------------------------------------------------
+# RfcommClient
+# -----------------------------------------------------------------------------
class RfcommClient(StreamedPacketIO):
    """Packet I/O over an RFCOMM DLC that we open on DEFAULT_RFCOMM_CHANNEL.

    On connection, starts an RFCOMM multiplexer and opens the DLC. Its data
    goes through StreamedPacketIO's length-prefixed deframing, and ``ready``
    is set. If the DLC cannot be opened, the multiplexer is disconnected and
    ``ready`` stays clear.
    """

    def __init__(self, device):
        super().__init__()
        self.device = device
        self.ready = asyncio.Event()

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

        print(color('*** Starting RFCOMM client...', 'blue'))
        rfcomm_client = bumble.rfcomm.Client(self.device, connection)
        mux = await rfcomm_client.start()
        print(color('*** Started', 'blue'))

        channel = DEFAULT_RFCOMM_CHANNEL
        print(color(f'### Opening session for channel {channel}...', 'yellow'))
        try:
            dlc = await mux.open_dlc(channel)
            print(color('### Session open', 'yellow'), dlc)
        except bumble.core.ConnectionError as error:
            print(color(f'!!! Session open failed: {error}', 'red'))
            await mux.disconnect()
            return

        dlc.sink = self.on_packet
        self.io_sink = dlc.write
        self.ready.set()

    def on_disconnection(self, _):
        pass
+
+
+# -----------------------------------------------------------------------------
+# RfcommServer
+# -----------------------------------------------------------------------------
class RfcommServer(StreamedPacketIO):
    """Packet I/O over an incoming RFCOMM DLC.

    Listens on DEFAULT_RFCOMM_CHANNEL and publishes an SDP record for the
    channel actually assigned. When a client connects a DLC, its data goes
    through StreamedPacketIO's length-prefixed deframing, outgoing frames are
    written to it, and ``ready`` is set.
    """

    def __init__(self, device):
        super().__init__()
        self.ready = asyncio.Event()

        # Create and register a server
        rfcomm_server = bumble.rfcomm.Server(device)

        # Listen for incoming DLC connections
        channel_number = rfcomm_server.listen(self.on_dlc, DEFAULT_RFCOMM_CHANNEL)

        # Setup the SDP to advertise this channel
        device.sdp_service_records = make_sdp_records(channel_number)

        print(
            color(
                f'### Listening for RFComm connection on channel {channel_number}',
                'yellow',
            )
        )

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        pass

    def on_dlc(self, dlc):
        print(color('*** DLC connected:', 'blue'), dlc)
        dlc.sink = self.on_packet
        self.io_sink = dlc.write
        # Unblock roles (Sender/Ping) that wait on ``ready`` before sending,
        # as the other server modes do once their transport is usable.
        self.ready.set()
+
+
+# -----------------------------------------------------------------------------
+# Central
+# -----------------------------------------------------------------------------
class Central(Connection.Listener):
    """Benchmark central: connects to a peripheral, then runs a role over it.

    Opens the HCI transport, creates the device and connects to
    ``peripheral_address`` (BR/EDR when ``classic`` is true, LE otherwise).
    It then lets the mode finish its connection setup, optionally requests a
    PHY, runs the role to completion and lingers for DEFAULT_LINGER_TIME.

    Args:
        transport: HCI transport spec passed to open_transport_or_link().
        peripheral_address: address or name of the peripheral to connect to.
        classic: connect over BR/EDR instead of LE.
        role_factory: callable(mode) returning the role to run.
        mode_factory: callable(device) returning the packet-I/O mode.
        connection_interval: if set, requested as both the minimum and the
            maximum connection interval (the CLI documents it as ms).
        phy: optional '1m', '2m' or 'coded' PHY requested after connecting.
    """

    def __init__(
        self,
        transport,
        peripheral_address,
        classic,
        role_factory,
        mode_factory,
        connection_interval,
        phy,
    ):
        super().__init__()
        self.transport = transport
        self.peripheral_address = peripheral_address
        self.classic = classic
        self.role_factory = role_factory
        self.mode_factory = mode_factory
        self.device = None
        self.connection = None

        # Map the --phy choice to its HCI PHY constant.
        if phy:
            self.phy = {
                '1m': HCI_LE_1M_PHY,
                '2m': HCI_LE_2M_PHY,
                'coded': HCI_LE_CODED_PHY,
            }[phy]
        else:
            self.phy = None

        if connection_interval:
            # Pin the interval by requesting it as both minimum and maximum.
            connection_parameter_preferences = ConnectionParametersPreferences()
            connection_parameter_preferences.connection_interval_min = (
                connection_interval
            )
            connection_parameter_preferences.connection_interval_max = (
                connection_interval
            )

            # Preferences for the 1M PHY are always set.
            self.connection_parameter_preferences = {
                HCI_LE_1M_PHY: connection_parameter_preferences,
            }

            if self.phy not in (None, HCI_LE_1M_PHY):
                # Add a connection parameters entry for this PHY.
                self.connection_parameter_preferences[
                    self.phy
                ] = connection_parameter_preferences
        else:
            # Passed to device.connect() as None (no explicit preferences).
            self.connection_parameter_preferences = None

    async def run(self):
        print(color('>>> Connecting to HCI...', 'green'))
        async with await open_transport_or_link(self.transport) as (
            hci_source,
            hci_sink,
        ):
            print(color('>>> Connected', 'green'))

            central_address = DEFAULT_CENTRAL_ADDRESS
            self.device = Device.with_hci(
                DEFAULT_CENTRAL_NAME, central_address, hci_source, hci_sink
            )
            mode = self.mode_factory(self.device)
            role = self.role_factory(mode)
            self.device.classic_enabled = self.classic

            await self.device.power_on()

            print(color(f'### Connecting to {self.peripheral_address}...', 'cyan'))
            try:
                self.connection = await self.device.connect(
                    self.peripheral_address,
                    connection_parameters_preferences=self.connection_parameter_preferences,
                    transport=BT_BR_EDR_TRANSPORT if self.classic else BT_LE_TRANSPORT,
                )
            except CommandTimeoutError:
                print(color('!!! Connection timed out', 'red'))
                return
            except bumble.core.ConnectionError as error:
                print(color(f'!!! Connection error: {error}', 'red'))
                return
            except HCI_StatusError as error:
                # Same color as the other connection-failure messages above.
                print(color(f'!!! Connection failed: {error.error_name}', 'red'))
                return
            print(color('### Connected', 'cyan'))
            self.connection.listener = self
            print_connection(self.connection)

            await mode.on_connection(self.connection)

            # Set the PHY if requested
            if self.phy is not None:
                try:
                    await self.connection.set_phy(
                        tx_phys=[self.phy], rx_phys=[self.phy]
                    )
                except HCI_Error as error:
                    print(
                        color(
                            f'!!! Unable to set the PHY: {error.error_name}', 'yellow'
                        )
                    )

            await role.run()
            await asyncio.sleep(DEFAULT_LINGER_TIME)

    def on_disconnection(self, reason):
        print(color(f'!!! Disconnection: reason={reason}', 'red'))
        self.connection = None

    def on_connection_parameters_update(self):
        print_connection(self.connection)

    def on_connection_phy_update(self):
        print_connection(self.connection)

    def on_connection_att_mtu_update(self):
        print_connection(self.connection)

    def on_connection_data_length_change(self):
        print_connection(self.connection)
+
+
+# -----------------------------------------------------------------------------
+# Peripheral
+# -----------------------------------------------------------------------------
class Peripheral(Device.Listener, Connection.Listener):
    """Benchmark peripheral: waits for a connection, then runs a role over it.

    Opens the HCI transport, creates the device and makes it reachable:
    discoverable + connectable for BR/EDR, auto-restarting advertising for LE.
    It then waits until on_connection() fires, lets the mode finish its
    connection setup, runs the role to completion and lingers for
    DEFAULT_LINGER_TIME.
    """

    def __init__(self, transport, classic, role_factory, mode_factory):
        self.transport = transport
        self.classic = classic
        self.role_factory = role_factory
        self.mode_factory = mode_factory
        # Created in run() once the HCI transport is open.
        self.device = None
        self.mode = None
        self.role = None
        # Filled in by on_connection(); run() waits on ``connected``.
        self.connection = None
        self.connected = asyncio.Event()

    async def run(self):
        print(color('>>> Connecting to HCI...', 'green'))
        async with await open_transport_or_link(self.transport) as (
            hci_source,
            hci_sink,
        ):
            print(color('>>> Connected', 'green'))

            peripheral_address = DEFAULT_PERIPHERAL_ADDRESS
            self.device = Device.with_hci(
                DEFAULT_PERIPHERAL_NAME, peripheral_address, hci_source, hci_sink
            )
            self.device.listener = self
            self.mode = self.mode_factory(self.device)
            self.role = self.role_factory(self.mode)
            self.device.classic_enabled = self.classic

            await self.device.power_on()

            if self.classic:
                await self.device.set_discoverable(True)
                await self.device.set_connectable(True)
                listen_address = self.device.public_address
            else:
                await self.device.start_advertising(auto_restart=True)
                listen_address = peripheral_address

            print(color(f'### Waiting for connection on {listen_address}...', 'cyan'))
            await self.connected.wait()
            print(color('### Connected', 'cyan'))

            await self.mode.on_connection(self.connection)
            await self.role.run()
            await asyncio.sleep(DEFAULT_LINGER_TIME)

    def on_connection(self, connection):
        connection.listener = self
        self.connection = connection
        self.connected.set()

    def on_disconnection(self, reason):
        print(color(f'!!! Disconnection: reason={reason}', 'red'))
        self.connection = None
        self.role.reset()

    def _show_connection(self):
        """Print the current connection summary (shared by the update callbacks)."""
        print_connection(self.connection)

    def on_connection_parameters_update(self):
        self._show_connection()

    def on_connection_phy_update(self):
        self._show_connection()

    def on_connection_att_mtu_update(self):
        self._show_connection()

    def on_connection_data_length_change(self):
        self._show_connection()
+
+
+# -----------------------------------------------------------------------------
def create_mode_factory(ctx, default_mode):
    """Return a callable that builds the packet-I/O mode object for a device.

    The mode name is read from ``ctx.obj['mode']`` (the --mode option), falling
    back to ``default_mode``. The returned callable takes the device and raises
    ValueError('invalid mode') for an unknown mode name.
    """
    mode = ctx.obj['mode']
    if mode is None:
        mode = default_mode

    # Builders are lambdas so that each mode class (and ctx.obj['att_mtu']) is
    # only looked up when that mode is actually created.
    builders = {
        'gatt-client': lambda device: GattClient(device, att_mtu=ctx.obj['att_mtu']),
        'gatt-server': lambda device: GattServer(device),
        'l2cap-client': lambda device: L2capClient(device),
        'l2cap-server': lambda device: L2capServer(device),
        'rfcomm-client': lambda device: RfcommClient(device),
        'rfcomm-server': lambda device: RfcommServer(device),
    }

    def create_mode(device):
        builder = builders.get(mode)
        if builder is None:
            raise ValueError('invalid mode')
        return builder(device)

    return create_mode
+
+
+# -----------------------------------------------------------------------------
def create_role_factory(ctx, default_role):
    """Return a callable that builds the benchmark role on top of a packet I/O.

    The role name is read from ``ctx.obj['role']`` (the --role option), falling
    back to ``default_role``. Sender and Ping also take the start delay, packet
    size and packet count from ``ctx.obj``. The returned callable raises
    ValueError('invalid role') for an unknown role name.
    """
    role = ctx.obj['role']
    if role is None:
        role = default_role

    def traffic_options():
        # Settings shared by the two packet-emitting roles.
        return {
            'start_delay': ctx.obj['start_delay'],
            'packet_size': ctx.obj['packet_size'],
            'packet_count': ctx.obj['packet_count'],
        }

    builders = {
        'sender': lambda packet_io: Sender(packet_io, **traffic_options()),
        'receiver': lambda packet_io: Receiver(packet_io),
        'ping': lambda packet_io: Ping(packet_io, **traffic_options()),
        'pong': lambda packet_io: Pong(packet_io),
    }

    def create_role(packet_io):
        builder = builders.get(role)
        if builder is None:
            raise ValueError('invalid role')
        return builder(packet_io)

    return create_role
+
+
+# -----------------------------------------------------------------------------
+# Main
+# -----------------------------------------------------------------------------
@click.group()
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
@click.option('--role', type=click.Choice(['sender', 'receiver', 'ping', 'pong']))
@click.option(
    '--mode',
    type=click.Choice(
        [
            'gatt-client',
            'gatt-server',
            'l2cap-client',
            'l2cap-server',
            'rfcomm-client',
            'rfcomm-server',
        ]
    ),
)
@click.option(
    '--att-mtu',
    metavar='MTU',
    type=click.IntRange(23, 517),
    help='GATT MTU (gatt-client mode)',
)
@click.option(
    '--packet-size',
    '-s',
    metavar='SIZE',
    type=click.IntRange(8, 4096),
    default=500,
    help='Packet size (sender or ping role)',
)
@click.option(
    '--packet-count',
    '-c',
    metavar='COUNT',
    # At least one packet is required: with 0 the sender never gets an ACK and
    # ping never sends a packet flagged as last, so neither run would finish.
    type=click.IntRange(min=1),
    default=10,
    help='Packet count (sender or ping role)',
)
@click.option(
    '--start-delay',
    '-sd',
    metavar='SECONDS',
    type=int,
    default=1,
    help='Start delay (sender or ping role)',
)
@click.pass_context
def bench(
    ctx, device_config, role, mode, att_mtu, packet_size, packet_count, start_delay
):
    """Run a throughput (sender/receiver) or latency (ping/pong) benchmark."""
    # Options are stashed in ctx.obj for the central/peripheral subcommands.
    ctx.ensure_object(dict)
    # NOTE(review): device_config is stored but not read anywhere in this file.
    ctx.obj['device_config'] = device_config
    ctx.obj['role'] = role
    ctx.obj['mode'] = mode
    ctx.obj['att_mtu'] = att_mtu
    ctx.obj['packet_size'] = packet_size
    ctx.obj['packet_count'] = packet_count
    ctx.obj['start_delay'] = start_delay

    # RFCOMM modes run over BR/EDR ("classic"); all other modes run over LE.
    ctx.obj['classic'] = mode in ('rfcomm-client', 'rfcomm-server')
+
+
@bench.command()
@click.argument('transport')
@click.option(
    '--peripheral',
    'peripheral_address',
    metavar='ADDRESS_OR_NAME',
    default=DEFAULT_PERIPHERAL_ADDRESS,
    help='Address or name to connect to',
)
@click.option(
    '--connection-interval',
    '--ci',
    metavar='CONNECTION_INTERVAL',
    type=int,
    help='Connection interval (in ms)',
)
@click.option('--phy', type=click.Choice(['1m', '2m', 'coded']), help='PHY to use')
@click.pass_context
def central(ctx, transport, peripheral_address, connection_interval, phy):
    """Run as a central (initiates the connection)"""
    # Without --role/--mode, the central is a GATT client that sends packets.
    role_factory = create_role_factory(ctx, 'sender')
    mode_factory = create_mode_factory(ctx, 'gatt-client')
    app = Central(
        transport,
        peripheral_address,
        ctx.obj['classic'],
        role_factory,
        mode_factory,
        connection_interval,
        phy,
    )
    asyncio.run(app.run())
+
+
@bench.command()
@click.argument('transport')
@click.pass_context
def peripheral(ctx, transport):
    """Run as a peripheral (waits for a connection)"""
    # Without --role/--mode, the peripheral is a GATT server that receives.
    role_factory = create_role_factory(ctx, 'receiver')
    mode_factory = create_mode_factory(ctx, 'gatt-server')
    app = Peripheral(transport, ctx.obj['classic'], role_factory, mode_factory)
    asyncio.run(app.run())
+
+
def main():
    """Configure logging from BUMBLE_LOGLEVEL (default INFO) and run the CLI."""
    log_level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO')
    logging.basicConfig(level=log_level.upper())
    bench()
+
+
+# -----------------------------------------------------------------------------
# Script entry point when this file is executed directly.
if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
diff --git a/apps/console.py b/apps/console.py
index b7c30c7..26223d7 100644
--- a/apps/console.py
+++ b/apps/console.py
@@ -24,6 +24,7 @@
import os
import random
import re
+from typing import Optional
from collections import OrderedDict
import click
@@ -58,6 +59,7 @@
from bumble.utils import AsyncRunner
from bumble.transport import open_transport_or_link
from bumble.gatt import Characteristic, Service, CharacteristicDeclaration, Descriptor
+from bumble.gatt_client import CharacteristicProxy
from bumble.hci import (
HCI_Constant,
HCI_LE_1M_PHY,
@@ -119,6 +121,8 @@
# Console App
# -----------------------------------------------------------------------------
class ConsoleApp:
+ connected_peer: Optional[Peer]
+
def __init__(self):
self.known_addresses = set()
self.known_attributes = []
@@ -218,7 +222,7 @@
filter=Condition(lambda: self.top_tab == 'local-services'),
),
ConditionalContainer(
- Frame(Window(self.remote_services_text), title='Remove Services'),
+ Frame(Window(self.remote_services_text), title='Remote Services'),
filter=Condition(lambda: self.top_tab == 'remote-services'),
),
ConditionalContainer(
@@ -490,7 +494,9 @@
self.show_attributes(attributes)
- def find_characteristic(self, param):
+ def find_characteristic(self, param) -> Optional[CharacteristicProxy]:
+ if not self.connected_peer:
+ return None
parts = param.split('.')
if len(parts) == 2:
service_uuid = UUID(parts[0]) if parts[0] != '*' else None
diff --git a/apps/controller_info.py b/apps/controller_info.py
index 9c9345e..4707983 100644
--- a/apps/controller_info.py
+++ b/apps/controller_info.py
@@ -30,6 +30,8 @@
HCI_VERSION_NAMES,
LMP_VERSION_NAMES,
HCI_Command,
+ HCI_Command_Complete_Event,
+ HCI_Command_Status_Event,
HCI_READ_BD_ADDR_COMMAND,
HCI_Read_BD_ADDR_Command,
HCI_READ_LOCAL_NAME_COMMAND,
@@ -46,10 +48,19 @@
# -----------------------------------------------------------------------------
+def command_succeeded(response):
+ if isinstance(response, HCI_Command_Status_Event):
+ return response.status == HCI_SUCCESS
+ if isinstance(response, HCI_Command_Complete_Event):
+ return response.return_parameters.status == HCI_SUCCESS
+ return False
+
+
+# -----------------------------------------------------------------------------
async def get_classic_info(host):
if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
response = await host.send_command(HCI_Read_BD_ADDR_Command())
- if response.return_parameters.status == HCI_SUCCESS:
+ if command_succeeded(response):
print()
print(
color('Classic Address:', 'yellow'), response.return_parameters.bd_addr
@@ -57,7 +68,7 @@
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):
response = await host.send_command(HCI_Read_Local_Name_Command())
- if response.return_parameters.status == HCI_SUCCESS:
+ if command_succeeded(response):
print()
print(
color('Local Name:', 'yellow'),
@@ -73,7 +84,7 @@
response = await host.send_command(
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
)
- if response.return_parameters.status == HCI_SUCCESS:
+ if command_succeeded(response):
print(
color('LE Number Of Supported Advertising Sets:', 'yellow'),
response.return_parameters.num_supported_advertising_sets,
@@ -84,7 +95,7 @@
response = await host.send_command(
HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
)
- if response.return_parameters.status == HCI_SUCCESS:
+ if command_succeeded(response):
print(
color('LE Maximum Advertising Data Length:', 'yellow'),
response.return_parameters.max_advertising_data_length,
@@ -93,7 +104,7 @@
if host.supports_command(HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND):
response = await host.send_command(HCI_LE_Read_Maximum_Data_Length_Command())
- if response.return_parameters.status == HCI_SUCCESS:
+ if command_succeeded(response):
print(
color('Maximum Data Length:', 'yellow'),
(
diff --git a/bumble/att.py b/bumble/att.py
index 2a15f00..8311d18 100644
--- a/bumble/att.py
+++ b/bumble/att.py
@@ -23,12 +23,13 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
+import functools
import struct
from pyee import EventEmitter
from typing import Dict, Type, TYPE_CHECKING
-from bumble.core import UUID, name_or_number
-from bumble.hci import HCI_Object, key_with_value
+from bumble.core import UUID, name_or_number, get_dict_key_by_value, ProtocolError
+from bumble.hci import HCI_Object, key_with_value, HCI_Constant
from bumble.colors import color
if TYPE_CHECKING:
@@ -184,13 +185,18 @@
# -----------------------------------------------------------------------------
# Exceptions
# -----------------------------------------------------------------------------
-class ATT_Error(Exception):
- def __init__(self, error_code, att_handle=0x0000):
- self.error_code = error_code
+class ATT_Error(ProtocolError):
+ def __init__(self, error_code, att_handle=0x0000, message=''):
+ super().__init__(
+ error_code,
+ error_namespace='att',
+ error_name=ATT_PDU.error_name(error_code),
+ )
self.att_handle = att_handle
+ self.message = message
def __str__(self):
- return f'ATT_Error({ATT_PDU.error_name(self.error_code)})'
+ return f'ATT_Error(error={self.error_name}, handle={self.att_handle:04X}): {self.message}'
# -----------------------------------------------------------------------------
@@ -725,11 +731,38 @@
READ_REQUIRES_AUTHORIZATION = 0x40
WRITE_REQUIRES_AUTHORIZATION = 0x80
+ PERMISSION_NAMES = {
+ READABLE: 'READABLE',
+ WRITEABLE: 'WRITEABLE',
+ READ_REQUIRES_ENCRYPTION: 'READ_REQUIRES_ENCRYPTION',
+ WRITE_REQUIRES_ENCRYPTION: 'WRITE_REQUIRES_ENCRYPTION',
+ READ_REQUIRES_AUTHENTICATION: 'READ_REQUIRES_AUTHENTICATION',
+ WRITE_REQUIRES_AUTHENTICATION: 'WRITE_REQUIRES_AUTHENTICATION',
+ READ_REQUIRES_AUTHORIZATION: 'READ_REQUIRES_AUTHORIZATION',
+ WRITE_REQUIRES_AUTHORIZATION: 'WRITE_REQUIRES_AUTHORIZATION',
+ }
+
+ @staticmethod
+ def string_to_permissions(permissions_str: str):
+ try:
+ return functools.reduce(
+ lambda x, y: x | get_dict_key_by_value(Attribute.PERMISSION_NAMES, y),
+ permissions_str.split(","),
+ 0,
+ )
+ except TypeError:
+ raise TypeError(
+ f"Attribute::permissions error:\nExpected a string containing any of the keys, seperated by commas: {','.join(Attribute.PERMISSION_NAMES.values())}\nGot: {permissions_str}"
+ )
+
def __init__(self, attribute_type, permissions, value=b''):
EventEmitter.__init__(self)
self.handle = 0
self.end_group_handle = 0
- self.permissions = permissions
+ if isinstance(permissions, str):
+ self.permissions = self.string_to_permissions(permissions)
+ else:
+ self.permissions = permissions
# Convert the type to a UUID object if it isn't already
if isinstance(attribute_type, str):
diff --git a/bumble/controller.py b/bumble/controller.py
index da5d4cf..cd7de3d 100644
--- a/bumble/controller.py
+++ b/bumble/controller.py
@@ -21,7 +21,12 @@
import random
import struct
from bumble.colors import color
-from bumble.core import BT_CENTRAL_ROLE, BT_PERIPHERAL_ROLE
+from bumble.core import (
+ BT_CENTRAL_ROLE,
+ BT_PERIPHERAL_ROLE,
+ BT_LE_TRANSPORT,
+ BT_BR_EDR_TRANSPORT,
+)
from bumble.hci import (
HCI_ACL_DATA_PACKET,
@@ -29,17 +34,21 @@
HCI_COMMAND_PACKET,
HCI_COMMAND_STATUS_PENDING,
HCI_CONNECTION_TIMEOUT_ERROR,
+ HCI_CONTROLLER_BUSY_ERROR,
HCI_EVENT_PACKET,
HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR,
HCI_LE_1M_PHY,
HCI_SUCCESS,
HCI_UNKNOWN_HCI_COMMAND_ERROR,
+ HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
HCI_VERSION_BLUETOOTH_CORE_5_0,
Address,
HCI_AclDataPacket,
HCI_AclDataPacketAssembler,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
+ HCI_Connection_Complete_Event,
+ HCI_Connection_Request_Event,
HCI_Disconnection_Complete_Event,
HCI_Encryption_Change_Event,
HCI_LE_Advertising_Report_Event,
@@ -47,7 +56,9 @@
HCI_LE_Read_Remote_Features_Complete_Event,
HCI_Number_Of_Completed_Packets_Event,
HCI_Packet,
+ HCI_Role_Change_Event,
)
+from typing import Optional, Union, Dict
# -----------------------------------------------------------------------------
@@ -65,13 +76,14 @@
# -----------------------------------------------------------------------------
class Connection:
- def __init__(self, controller, handle, role, peer_address, link):
+ def __init__(self, controller, handle, role, peer_address, link, transport):
self.controller = controller
self.handle = handle
self.role = role
self.peer_address = peer_address
self.link = link
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
+ self.transport = transport
def on_hci_acl_data_packet(self, packet):
self.assembler.feed_packet(packet)
@@ -82,23 +94,33 @@
def on_acl_pdu(self, data):
if self.link:
self.link.send_acl_data(
- self.controller.random_address, self.peer_address, data
+ self.controller, self.peer_address, self.transport, data
)
# -----------------------------------------------------------------------------
class Controller:
- def __init__(self, name, host_source=None, host_sink=None, link=None):
+ def __init__(
+ self,
+ name,
+ host_source=None,
+ host_sink=None,
+ link=None,
+ public_address: Optional[Union[bytes, str, Address]] = None,
+ ):
self.name = name
self.hci_sink = None
self.link = link
- self.central_connections = (
- {}
- ) # Connections where this controller is the central
- self.peripheral_connections = (
- {}
- ) # Connections where this controller is the peripheral
+ self.central_connections: Dict[
+ Address, Connection
+ ] = {} # Connections where this controller is the central
+ self.peripheral_connections: Dict[
+ Address, Connection
+ ] = {} # Connections where this controller is the peripheral
+ self.classic_connections: Dict[
+ Address, Connection
+ ] = {} # Connections in BR/EDR
self.hci_version = HCI_VERSION_BLUETOOTH_CORE_5_0
self.hci_revision = 0
@@ -148,7 +170,14 @@
self.advertising_timer_handle = None
self._random_address = Address('00:00:00:00:00:00')
- self._public_address = None
+ if isinstance(public_address, Address):
+ self._public_address = public_address
+ elif public_address is not None:
+ self._public_address = Address(
+ public_address, Address.PUBLIC_DEVICE_ADDRESS
+ )
+ else:
+ self._public_address = Address('00:00:00:00:00:00')
# Set the source and sink interfaces
if host_source:
@@ -271,7 +300,9 @@
handle = 0
max_handle = 0
for connection in itertools.chain(
- self.central_connections.values(), self.peripheral_connections.values()
+ self.central_connections.values(),
+ self.peripheral_connections.values(),
+ self.classic_connections.values(),
):
max_handle = max(max_handle, connection.handle)
if connection.handle == handle:
@@ -279,14 +310,19 @@
handle = max_handle + 1
return handle
- def find_connection_by_address(self, address):
+ def find_le_connection_by_address(self, address):
return self.central_connections.get(address) or self.peripheral_connections.get(
address
)
+ def find_classic_connection_by_address(self, address):
+ return self.classic_connections.get(address)
+
def find_connection_by_handle(self, handle):
for connection in itertools.chain(
- self.central_connections.values(), self.peripheral_connections.values()
+ self.central_connections.values(),
+ self.peripheral_connections.values(),
+ self.classic_connections.values(),
):
if connection.handle == handle:
return connection
@@ -298,6 +334,12 @@
return connection
return None
+ def find_classic_connection_by_handle(self, handle):
+ for connection in self.classic_connections.values():
+ if connection.handle == handle:
+ return connection
+ return None
+
def on_link_central_connected(self, central_address):
'''
Called when an incoming connection occurs from a central on the link
@@ -310,7 +352,12 @@
if connection is None:
connection_handle = self.allocate_connection_handle()
connection = Connection(
- self, connection_handle, BT_PERIPHERAL_ROLE, peer_address, self.link
+ self,
+ connection_handle,
+ BT_PERIPHERAL_ROLE,
+ peer_address,
+ self.link,
+ BT_LE_TRANSPORT,
)
self.peripheral_connections[peer_address] = connection
logger.debug(f'New PERIPHERAL connection handle: 0x{connection_handle:04X}')
@@ -364,7 +411,12 @@
if connection is None:
connection_handle = self.allocate_connection_handle()
connection = Connection(
- self, connection_handle, BT_CENTRAL_ROLE, peer_address, self.link
+ self,
+ connection_handle,
+ BT_CENTRAL_ROLE,
+ peer_address,
+ self.link,
+ BT_LE_TRANSPORT,
)
self.central_connections[peer_address] = connection
logger.debug(
@@ -432,16 +484,19 @@
def on_link_encrypted(self, peer_address, _rand, _ediv, _ltk):
# For now, just setup the encryption without asking the host
- if connection := self.find_connection_by_address(peer_address):
+ if connection := self.find_le_connection_by_address(peer_address):
self.send_hci_packet(
HCI_Encryption_Change_Event(
status=0, connection_handle=connection.handle, encryption_enabled=1
)
)
- def on_link_acl_data(self, sender_address, data):
+ def on_link_acl_data(self, sender_address, transport, data):
# Look for the connection to which this data belongs
- connection = self.find_connection_by_address(sender_address)
+ if transport == BT_LE_TRANSPORT:
+ connection = self.find_le_connection_by_address(sender_address)
+ else:
+ connection = self.find_classic_connection_by_address(sender_address)
if connection is None:
logger.warning(f'!!! no connection for {sender_address}')
return
@@ -479,6 +534,87 @@
self.send_hci_packet(HCI_LE_Advertising_Report_Event([report]))
############################################################
+ # Classic link connections
+ ############################################################
+
+ def on_classic_connection_request(self, peer_address, link_type):
+ self.send_hci_packet(
+ HCI_Connection_Request_Event(
+ bd_addr=peer_address,
+ class_of_device=0,
+ link_type=link_type,
+ )
+ )
+
+ def on_classic_connection_complete(self, peer_address, status):
+ if status == HCI_SUCCESS:
+ # Allocate (or reuse) a connection handle
+ peer_address = peer_address
+ connection = self.classic_connections.get(peer_address)
+ if connection is None:
+ connection_handle = self.allocate_connection_handle()
+ connection = Connection(
+ controller=self,
+ handle=connection_handle,
+ # Role doesn't matter in Classic, because roles are managed by HCI_Role_Change and HCI_Role_Discovery
+ role=BT_CENTRAL_ROLE,
+ peer_address=peer_address,
+ link=self.link,
+ transport=BT_BR_EDR_TRANSPORT,
+ )
+ self.classic_connections[peer_address] = connection
+ logger.debug(
+ f'New CLASSIC connection handle: 0x{connection_handle:04X}'
+ )
+ else:
+ connection_handle = connection.handle
+ self.send_hci_packet(
+ HCI_Connection_Complete_Event(
+ status=status,
+ connection_handle=connection_handle,
+ bd_addr=peer_address,
+ encryption_enabled=False,
+ link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
+ )
+ )
+ else:
+ connection = None
+ self.send_hci_packet(
+ HCI_Connection_Complete_Event(
+ status=status,
+ connection_handle=0,
+ bd_addr=peer_address,
+ encryption_enabled=False,
+ link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
+ )
+ )
+
+ def on_classic_disconnected(self, peer_address, reason):
+ # Send a disconnection complete event
+ if connection := self.classic_connections.get(peer_address):
+ self.send_hci_packet(
+ HCI_Disconnection_Complete_Event(
+ status=HCI_SUCCESS,
+ connection_handle=connection.handle,
+ reason=reason,
+ )
+ )
+
+ # Remove the connection
+ del self.classic_connections[peer_address]
+ else:
+ logger.warning(f'!!! No classic connection found for {peer_address}')
+
+ def on_classic_role_change(self, peer_address, new_role):
+ self.send_hci_packet(
+ HCI_Role_Change_Event(
+ status=HCI_SUCCESS,
+ bd_addr=peer_address,
+ new_role=new_role,
+ )
+ )
+
+ ############################################################
# Advertising support
############################################################
def on_advertising_timer_fired(self):
@@ -521,7 +657,31 @@
See Bluetooth spec Vol 2, Part E - 7.1.5 Create Connection command
'''
- # TODO: classic mode not supported yet
+ if self.link is None:
+ return
+ logger.debug(f'Connection request to {command.bd_addr}')
+
+ # Check that we don't already have a pending connection
+ if self.link.get_pending_connection():
+ self.send_hci_packet(
+ HCI_Command_Status_Event(
+ status=HCI_CONTROLLER_BUSY_ERROR,
+ num_hci_command_packets=1,
+ command_opcode=command.op_code,
+ )
+ )
+ return
+
+ self.link.classic_connect(self, command.bd_addr)
+
+ # Say that the connection is pending
+ self.send_hci_packet(
+ HCI_Command_Status_Event(
+ status=HCI_COMMAND_STATUS_PENDING,
+ num_hci_command_packets=1,
+ command_opcode=command.op_code,
+ )
+ )
def on_hci_disconnect_command(self, command):
'''
@@ -537,19 +697,57 @@
)
# Notify the link of the disconnection
- if not (
- connection := self.find_central_connection_by_handle(
- command.connection_handle
- )
- ):
- logger.warning('connection not found')
- return
+ handle = command.connection_handle
+ if connection := self.find_central_connection_by_handle(handle):
+ if self.link:
+ self.link.disconnect(
+ self.random_address, connection.peer_address, command
+ )
+ else:
+ # Remove the connection
+ del self.central_connections[connection.peer_address]
+ elif connection := self.find_classic_connection_by_handle(handle):
+ if self.link:
+ self.link.classic_disconnect(
+ self,
+ connection.peer_address,
+ HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
+ )
+ else:
+ # Remove the connection
+ del self.classic_connections[connection.peer_address]
- if self.link:
- self.link.disconnect(self.random_address, connection.peer_address, command)
- else:
- # Remove the connection
- del self.central_connections[connection.peer_address]
+ def on_hci_accept_connection_request_command(self, command):
+ '''
+ See Bluetooth spec Vol 2, Part E - 7.1.8 Accept Connection Request command
+ '''
+
+ if self.link is None:
+ return
+ self.send_hci_packet(
+ HCI_Command_Status_Event(
+ status=HCI_SUCCESS,
+ num_hci_command_packets=1,
+ command_opcode=command.op_code,
+ )
+ )
+ self.link.classic_accept_connection(self, command.bd_addr, command.role)
+
+ def on_hci_switch_role_command(self, command):
+ '''
+ See Bluetooth spec Vol 2, Part E - 7.2.8 Switch Role command
+ '''
+
+ if self.link is None:
+ return
+ self.send_hci_packet(
+ HCI_Command_Status_Event(
+ status=HCI_SUCCESS,
+ num_hci_command_packets=1,
+ command_opcode=command.op_code,
+ )
+ )
+ self.link.classic_switch_role(self, command.bd_addr, command.role)
def on_hci_set_event_mask_command(self, command):
'''
@@ -627,6 +825,12 @@
ret = HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR
return bytes([ret])
+ def on_hci_write_extended_inquiry_response_command(self, _command):
+ '''
+ See Bluetooth spec Vol 2, Part E - 7.3.56 Write Extended Inquiry Response Command
+ '''
+ return bytes([HCI_SUCCESS])
+
def on_hci_write_simple_pairing_mode_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.59 Write Simple Pairing Mode Command
diff --git a/bumble/decoder.py b/bumble/decoder.py
new file mode 100644
index 0000000..2eb70bc
--- /dev/null
+++ b/bumble/decoder.py
@@ -0,0 +1,416 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# -----------------------------------------------------------------------------
+# Constants
+# -----------------------------------------------------------------------------
+# fmt: off
+
+WL = [-60, -30, 58, 172, 334, 538, 1198, 3042]
+RL42 = [0, 7, 6, 5, 4, 3, 2, 1, 7, 6, 5, 4, 3, 2, 1, 0]
+ILB = [
+ 2048,
+ 2093,
+ 2139,
+ 2186,
+ 2233,
+ 2282,
+ 2332,
+ 2383,
+ 2435,
+ 2489,
+ 2543,
+ 2599,
+ 2656,
+ 2714,
+ 2774,
+ 2834,
+ 2896,
+ 2960,
+ 3025,
+ 3091,
+ 3158,
+ 3228,
+ 3298,
+ 3371,
+ 3444,
+ 3520,
+ 3597,
+ 3676,
+ 3756,
+ 3838,
+ 3922,
+ 4008,
+]
+WH = [0, -214, 798]
+RH2 = [2, 1, 2, 1]
+# Values in QM2/QM4/QM6 are left-shifted by three bits relative to the original G.722 specification.
+QM2 = [-7408, -1616, 7408, 1616]
+QM4 = [
+ 0,
+ -20456,
+ -12896,
+ -8968,
+ -6288,
+ -4240,
+ -2584,
+ -1200,
+ 20456,
+ 12896,
+ 8968,
+ 6288,
+ 4240,
+ 2584,
+ 1200,
+ 0,
+]
+QM6 = [
+ -136,
+ -136,
+ -136,
+ -136,
+ -24808,
+ -21904,
+ -19008,
+ -16704,
+ -14984,
+ -13512,
+ -12280,
+ -11192,
+ -10232,
+ -9360,
+ -8576,
+ -7856,
+ -7192,
+ -6576,
+ -6000,
+ -5456,
+ -4944,
+ -4464,
+ -4008,
+ -3576,
+ -3168,
+ -2776,
+ -2400,
+ -2032,
+ -1688,
+ -1360,
+ -1040,
+ -728,
+ 24808,
+ 21904,
+ 19008,
+ 16704,
+ 14984,
+ 13512,
+ 12280,
+ 11192,
+ 10232,
+ 9360,
+ 8576,
+ 7856,
+ 7192,
+ 6576,
+ 6000,
+ 5456,
+ 4944,
+ 4464,
+ 4008,
+ 3576,
+ 3168,
+ 2776,
+ 2400,
+ 2032,
+ 1688,
+ 1360,
+ 1040,
+ 728,
+ 432,
+ 136,
+ -432,
+ -136,
+]
+QMF_COEFFS = [3, -11, 12, 32, -210, 951, 3876, -805, 362, -156, 53, -11]
+
+# fmt: on
+
+
+# -----------------------------------------------------------------------------
+# Classes
+# -----------------------------------------------------------------------------
+class G722Decoder(object):
+ """G.722 decoder with bitrate 64kbit/s.
+
+ For the Blocks in the sub-band decoders, please refer to the G.722
+ specification for the required information. G722 specification:
+ https://www.itu.int/rec/T-REC-G.722-201209-I
+ """
+
+ def __init__(self):
+ self._x = [0] * 24
+ self._band = [Band(), Band()]
+ # The initial value in BLOCK 3L
+ self._band[0].det = 32
+ # The initial value in BLOCK 3H
+ self._band[1].det = 8
+
+ def decode_frame(self, encoded_data) -> bytearray:
+ result_array = bytearray(len(encoded_data) * 4)
+ self.g722_decode(result_array, encoded_data)
+ return result_array
+
+ def g722_decode(self, result_array, encoded_data) -> int:
+ """Decode the data frame using g722 decoder."""
+ result_length = 0
+
+ for code in encoded_data:
+ higher_bits = (code >> 6) & 0x03
+ lower_bits = code & 0x3F
+
+ rlow = self.lower_sub_band_decoder(lower_bits)
+ rhigh = self.higher_sub_band_decoder(higher_bits)
+
+ # Apply the receive QMF
+ self._x[:22] = self._x[2:]
+ self._x[22] = rlow + rhigh
+ self._x[23] = rlow - rhigh
+
+ xout2 = sum(self._x[2 * i] * QMF_COEFFS[i] for i in range(12))
+ xout1 = sum(self._x[2 * i + 1] * QMF_COEFFS[11 - i] for i in range(12))
+
+ result_length = self.update_decoded_result(
+ xout1, result_length, result_array
+ )
+ result_length = self.update_decoded_result(
+ xout2, result_length, result_array
+ )
+
+ return result_length
+
+ def update_decoded_result(self, xout, byte_length, byte_array) -> int:
+ result = (int)(xout >> 11)
+ bytes_result = result.to_bytes(2, 'little', signed=True)
+ byte_array[byte_length] = bytes_result[0]
+ byte_array[byte_length + 1] = bytes_result[1]
+ return byte_length + 2
+
+ def lower_sub_band_decoder(self, lower_bits) -> int:
+ """Lower sub-band decoder for last six bits."""
+
+ # Block 5L
+ # INVQBL
+ wd1 = lower_bits
+ wd2 = QM6[wd1]
+ wd1 >>= 2
+ wd2 = (self._band[0].det * wd2) >> 15
+ # RECONS
+ rlow = self._band[0].s + wd2
+
+ # Block 6L
+ # LIMIT
+ if rlow > 16383:
+ rlow = 16383
+ elif rlow < -16384:
+ rlow = -16384
+
+ # Block 2L
+ # INVQAL
+ wd2 = QM4[wd1]
+ dlowt = (self._band[0].det * wd2) >> 15
+
+ # Block 3L
+ # LOGSCL
+ wd2 = RL42[wd1]
+ wd1 = (self._band[0].nb * 127) >> 7
+ wd1 += WL[wd2]
+
+ if wd1 < 0:
+ wd1 = 0
+ elif wd1 > 18432:
+ wd1 = 18432
+
+ self._band[0].nb = wd1
+
+ # SCALEL
+ wd1 = (self._band[0].nb >> 6) & 31
+ wd2 = 8 - (self._band[0].nb >> 11)
+
+ if wd2 < 0:
+ wd3 = ILB[wd1] << -wd2
+ else:
+ wd3 = ILB[wd1] >> wd2
+
+ self._band[0].det = wd3 << 2
+
+ # Block 4L
+ self._band[0].block4(dlowt)
+
+ return rlow
+
+ def higher_sub_band_decoder(self, higher_bits) -> int:
+ """Higher sub-band decoder for first two bits."""
+
+ # Block 2H
+ # INVQAH
+ wd2 = QM2[higher_bits]
+ dhigh = (self._band[1].det * wd2) >> 15
+
+ # Block 5H
+ # RECONS
+ rhigh = dhigh + self._band[1].s
+
+ # Block 6H
+ # LIMIT
+ if rhigh > 16383:
+ rhigh = 16383
+ elif rhigh < -16384:
+ rhigh = -16384
+
+ # Block 3H
+ # LOGSCH
+ wd2 = RH2[higher_bits]
+ wd1 = (self._band[1].nb * 127) >> 7
+ wd1 += WH[wd2]
+
+ if wd1 < 0:
+ wd1 = 0
+ elif wd1 > 22528:
+ wd1 = 22528
+ self._band[1].nb = wd1
+
+ # SCALEH
+ wd1 = (self._band[1].nb >> 6) & 31
+ wd2 = 10 - (self._band[1].nb >> 11)
+
+ if wd2 < 0:
+ wd3 = ILB[wd1] << -wd2
+ else:
+ wd3 = ILB[wd1] >> wd2
+ self._band[1].det = wd3 << 2
+
+ # Block 4H
+ self._band[1].block4(dhigh)
+
+ return rhigh
+
+
+# -----------------------------------------------------------------------------
+class Band(object):
+ """Per-sub-band state for G.722 decode processing."""
+
+ s: int = 0
+ nb: int = 0
+ det: int = 0
+
+ def __init__(self):
+ self._sp = 0
+ self._sz = 0
+ self._r = [0] * 3
+ self._a = [0] * 3
+ self._ap = [0] * 3
+ self._p = [0] * 3
+ self._d = [0] * 7
+ self._b = [0] * 7
+ self._bp = [0] * 7
+ self._sg = [0] * 7
+
+ def saturate(self, amp: int) -> int:
+ if amp > 32767:
+ return 32767
+ elif amp < -32768:
+ return -32768
+ else:
+ return amp
+
+ def block4(self, d: int) -> None:
+ """Block4 for both lower and higher sub-band decoder."""
+ wd1 = 0
+ wd2 = 0
+ wd3 = 0
+
+ # RECONS
+ self._d[0] = d
+ self._r[0] = self.saturate(self.s + d)
+
+ # PARREC
+ self._p[0] = self.saturate(self._sz + d)
+
+ # UPPOL2
+ for i in range(3):
+ self._sg[i] = (self._p[i]) >> 15
+ wd1 = self.saturate((self._a[1]) << 2)
+ wd2 = -wd1 if self._sg[0] == self._sg[1] else wd1
+
+ if wd2 > 32767:
+ wd2 = 32767
+
+ wd3 = 128 if self._sg[0] == self._sg[2] else -128
+ wd3 += wd2 >> 7
+ wd3 += (self._a[2] * 32512) >> 15
+
+ if wd3 > 12288:
+ wd3 = 12288
+ elif wd3 < -12288:
+ wd3 = -12288
+ self._ap[2] = wd3
+
+ # UPPOL1
+ self._sg[0] = (self._p[0]) >> 15
+ self._sg[1] = (self._p[1]) >> 15
+ wd1 = 192 if self._sg[0] == self._sg[1] else -192
+ wd2 = (self._a[1] * 32640) >> 15
+
+ self._ap[1] = self.saturate(wd1 + wd2)
+ wd3 = self.saturate(15360 - self._ap[2])
+
+ if self._ap[1] > wd3:
+ self._ap[1] = wd3
+ elif self._ap[1] < -wd3:
+ self._ap[1] = -wd3
+
+ # UPZERO
+ wd1 = 0 if d == 0 else 128
+ self._sg[0] = d >> 15
+ for i in range(1, 7):
+ self._sg[i] = (self._d[i]) >> 15
+ wd2 = wd1 if self._sg[i] == self._sg[0] else -wd1
+ wd3 = (self._b[i] * 32640) >> 15
+ self._bp[i] = self.saturate(wd2 + wd3)
+
+ # DELAYA
+ for i in range(6, 0, -1):
+ self._d[i] = self._d[i - 1]
+ self._b[i] = self._bp[i]
+
+ for i in range(2, 0, -1):
+ self._r[i] = self._r[i - 1]
+ self._p[i] = self._p[i - 1]
+ self._a[i] = self._ap[i]
+
+ # FILTEP
+ self._sp = 0
+ for i in range(1, 3):
+ wd1 = self.saturate(self._r[i] + self._r[i])
+ self._sp += (self._a[i] * wd1) >> 15
+ self._sp = self.saturate(self._sp)
+
+ # FILTEZ
+ self._sz = 0
+ for i in range(6, 0, -1):
+ wd1 = self.saturate(self._d[i] + self._d[i])
+ self._sz += (self._b[i] * wd1) >> 15
+ self._sz = self.saturate(self._sz)
+
+ # PREDIC
+ self.s = self.saturate(self._sp + self._sz)
diff --git a/bumble/device.py b/bumble/device.py
index 512bb1d..25fa099 100644
--- a/bumble/device.py
+++ b/bumble/device.py
@@ -50,6 +50,7 @@
HCI_LE_EXTENDED_CREATE_CONNECTION_COMMAND,
HCI_LE_RAND_COMMAND,
HCI_LE_READ_PHY_COMMAND,
+ HCI_LE_SET_PHY_COMMAND,
HCI_MITM_NOT_REQUIRED_GENERAL_BONDING_AUTHENTICATION_REQUIREMENTS,
HCI_MITM_NOT_REQUIRED_NO_BONDING_AUTHENTICATION_REQUIREMENTS,
HCI_MITM_REQUIRED_GENERAL_BONDING_AUTHENTICATION_REQUIREMENTS,
@@ -94,10 +95,13 @@
HCI_LE_Set_Scan_Enable_Command,
HCI_LE_Set_Scan_Parameters_Command,
HCI_LE_Set_Scan_Response_Data_Command,
+ HCI_PIN_Code_Request_Reply_Command,
+ HCI_PIN_Code_Request_Negative_Reply_Command,
HCI_Read_BD_ADDR_Command,
HCI_Read_RSSI_Command,
HCI_Reject_Connection_Request_Command,
HCI_Remote_Name_Request_Command,
+ HCI_Switch_Role_Command,
HCI_Set_Connection_Encryption_Command,
HCI_StatusError,
HCI_User_Confirmation_Request_Negative_Reply_Command,
@@ -310,6 +314,9 @@
def update(self, report):
advertisement = Advertisement.from_advertising_report(report)
+ if advertisement is None:
+ return None
+
result = None
if advertisement.is_scan_response:
@@ -615,7 +622,9 @@
assert self.transport == BT_BR_EDR_TRANSPORT
self.handle = handle
self.peer_resolvable_address = peer_resolvable_address
- self.role = role
+ # Quirk: the role might already be known before the connection is complete
+ if self.role is None:
+ self.role = role
self.parameters = parameters
@property
@@ -663,6 +672,9 @@
async def encrypt(self, enable: bool = True) -> None:
return await self.device.encrypt(self, enable)
+ async def switch_role(self, role: int) -> None:
+ return await self.device.switch_role(self, role)
+
async def sustain(self, timeout=None):
"""Idles the current task waiting for a disconnect or timeout"""
@@ -739,6 +751,7 @@
self.le_enabled = True
# LE host enable 2nd parameter
self.le_simultaneous_enabled = True
+ self.classic_enabled = False
self.classic_sc_enabled = True
self.classic_ssp_enabled = True
self.classic_accept_any = True
@@ -768,6 +781,7 @@
self.le_simultaneous_enabled = config.get(
'le_simultaneous_enabled', self.le_simultaneous_enabled
)
+ self.classic_enabled = config.get('classic_enabled', self.classic_enabled)
self.classic_sc_enabled = config.get(
'classic_sc_enabled', self.classic_sc_enabled
)
@@ -979,6 +993,7 @@
self.keystore = KeyStore.create_for_device(config)
self.irk = config.irk
self.le_enabled = config.le_enabled
+ self.classic_enabled = config.classic_enabled
self.le_simultaneous_enabled = config.le_simultaneous_enabled
self.classic_ssp_enabled = config.classic_ssp_enabled
self.classic_sc_enabled = config.classic_sc_enabled
@@ -991,15 +1006,20 @@
for characteristic in service.get("characteristics", []):
descriptors = []
for descriptor in characteristic.get("descriptors", []):
+ # Leave this check until 5/25/2023
+ if descriptor.get("permission", False):
+ raise Exception(
+ "Error parsing Device Config's GATT Services. The key 'permission' must be renamed to 'permissions'"
+ )
new_descriptor = Descriptor(
attribute_type=descriptor["descriptor_type"],
- permissions=descriptor["permission"],
+ permissions=descriptor["permissions"],
)
descriptors.append(new_descriptor)
new_characteristic = Characteristic(
uuid=characteristic["uuid"],
properties=characteristic["properties"],
- permissions=int(characteristic["permissions"], 0),
+ permissions=characteristic["permissions"],
descriptors=descriptors,
)
characteristics.append(new_characteristic)
@@ -1242,6 +1262,11 @@
# Done
self.powered_on = True
+ async def power_off(self) -> None:
+ if self.powered_on:
+ await self.host.flush()
+ self.powered_on = False
+
def supports_le_feature(self, feature):
return self.host.supports_le_feature(feature)
@@ -1666,7 +1691,7 @@
)
)
if not phys:
- raise ValueError('least one supported PHY needed')
+ raise ValueError('at least one supported PHY needed')
phy_count = len(phys)
initiating_phys = phy_list_to_bits(phys)
@@ -1807,7 +1832,7 @@
try:
return await self.abort_on('flush', pending_connection)
- except ConnectionError as error:
+ except core.ConnectionError as error:
raise core.TimeoutError() from error
finally:
self.remove_listener('connection', on_connection)
@@ -2041,21 +2066,31 @@
async def set_connection_phy(
self, connection, tx_phys=None, rx_phys=None, phy_options=None
):
+ if not self.host.supports_command(HCI_LE_SET_PHY_COMMAND):
+ logger.warning('ignoring request, command not supported')
+ return
+
all_phys_bits = (1 if tx_phys is None else 0) | (
(1 if rx_phys is None else 0) << 1
)
- return await self.send_command(
+ result = await self.send_command(
HCI_LE_Set_PHY_Command(
connection_handle=connection.handle,
all_phys=all_phys_bits,
tx_phys=phy_list_to_bits(tx_phys),
rx_phys=phy_list_to_bits(rx_phys),
phy_options=0 if phy_options is None else int(phy_options),
- ),
- check_result=True,
+ )
)
+ if result.status != HCI_COMMAND_STATUS_PENDING:
+ logger.warning(
+ 'HCI_LE_Set_PHY_Command failed: '
+ f'{HCI_Constant.error_name(result.status)}'
+ )
+ raise HCI_StatusError(result)
+
async def set_default_phy(self, tx_phys=None, rx_phys=None):
all_phys_bits = (1 if tx_phys is None else 0) | (
(1 if rx_phys is None else 0) << 1
@@ -2290,6 +2325,34 @@
)
# [Classic only]
+ async def switch_role(self, connection: Connection, role: int):
+ pending_role_change = asyncio.get_running_loop().create_future()
+
+ def on_role_change(new_role):
+ pending_role_change.set_result(new_role)
+
+ def on_role_change_failure(error_code):
+ pending_role_change.set_exception(HCI_Error(error_code))
+
+ connection.on('role_change', on_role_change)
+ connection.on('role_change_failure', on_role_change_failure)
+
+ try:
+ result = await self.send_command(
+ HCI_Switch_Role_Command(bd_addr=connection.peer_address, role=role) # type: ignore[call-arg]
+ )
+ if result.status != HCI_COMMAND_STATUS_PENDING:
+ logger.warning(
+ 'HCI_Switch_Role_Command failed: '
+ f'{HCI_Constant.error_name(result.status)}'
+ )
+ raise HCI_StatusError(result)
+ await connection.abort_on('disconnection', pending_role_change)
+ finally:
+ connection.remove_listener('role_change', on_role_change)
+ connection.remove_listener('role_change_failure', on_role_change_failure)
+
+ # [Classic only]
async def request_remote_name(self, remote: Union[Address, Connection]) -> str:
# Set up event handlers
pending_name = asyncio.get_running_loop().create_future()
@@ -2494,7 +2557,7 @@
self.advertising = False
# Notify listeners
- error = ConnectionError(
+ error = core.ConnectionError(
error_code,
transport,
peer_address,
@@ -2567,7 +2630,7 @@
@with_connection_from_handle
def on_disconnection_failure(self, connection, error_code):
logger.debug(f'*** Disconnection failed: {error_code}')
- error = ConnectionError(
+ error = core.ConnectionError(
error_code,
connection.transport,
connection.peer_address,
@@ -2765,6 +2828,54 @@
# [Classic only]
@host_event_handler
@with_connection_from_address
+ def on_pin_code_request(self, connection):
+ # classic legacy pairing
+ # Ask what the pairing config should be for this connection
+ pairing_config = self.pairing_config_factory(connection)
+
+ can_input = pairing_config.delegate.io_capability in (
+ smp.SMP_KEYBOARD_ONLY_IO_CAPABILITY,
+ smp.SMP_KEYBOARD_DISPLAY_IO_CAPABILITY,
+ )
+
+ # respond the pin code
+ if can_input:
+
+ async def get_pin_code():
+ pin_code = await connection.abort_on(
+ 'disconnection', pairing_config.delegate.get_string(16)
+ )
+
+ if pin_code is not None:
+ pin_code = bytes(pin_code, encoding='utf-8')
+ pin_code_len = len(pin_code)
+ assert 0 < pin_code_len <= 16, "pin_code should be 1-16 bytes"
+ await self.host.send_command(
+ HCI_PIN_Code_Request_Reply_Command(
+ bd_addr=connection.peer_address,
+ pin_code_length=pin_code_len,
+ pin_code=pin_code,
+ )
+ )
+ else:
+ logger.debug("delegate.get_string() returned None")
+ await self.host.send_command(
+ HCI_PIN_Code_Request_Negative_Reply_Command(
+ bd_addr=connection.peer_address
+ )
+ )
+
+ asyncio.create_task(get_pin_code())
+ else:
+ self.host.send_command_sync(
+ HCI_PIN_Code_Request_Negative_Reply_Command(
+ bd_addr=connection.peer_address
+ )
+ )
+
+ # [Classic only]
+ @host_event_handler
+ @with_connection_from_address
def on_authentication_user_passkey_notification(self, connection, passkey):
# Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection)
@@ -2776,7 +2887,7 @@
# [Classic only]
@host_event_handler
@try_with_connection_from_address
- def on_remote_name(self, connection, address, remote_name):
+ def on_remote_name(self, connection: Connection, address, remote_name):
# Try to decode the name
try:
remote_name = remote_name.decode('utf-8')
@@ -2794,7 +2905,7 @@
# [Classic only]
@host_event_handler
@try_with_connection_from_address
- def on_remote_name_failure(self, connection, address, error):
+ def on_remote_name_failure(self, connection: Connection, address, error):
if connection:
connection.emit('remote_name_failure', error)
self.emit('remote_name_failure', address, error)
@@ -2905,6 +3016,21 @@
)
connection.emit('connection_data_length_change')
+ # [Classic only]
+ @host_event_handler
+ @with_connection_from_address
+ def on_role_change(self, connection, new_role):
+ connection.role = new_role
+ connection.emit('role_change', new_role)
+
+ # [Classic only]
+ @host_event_handler
+ @try_with_connection_from_address
+ def on_role_change_failure(self, connection, address, error):
+ if connection:
+ connection.emit('role_change_failure', error)
+ self.emit('role_change_failure', address, error)
+
@with_connection_from_handle
def on_pairing_start(self, connection):
connection.emit('pairing_start')
diff --git a/bumble/gatt_client.py b/bumble/gatt_client.py
index 2fd7573..25add18 100644
--- a/bumble/gatt_client.py
+++ b/bumble/gatt_client.py
@@ -23,9 +23,11 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
+from __future__ import annotations
import asyncio
import logging
import struct
+from typing import List, Optional
from pyee import EventEmitter
@@ -50,6 +52,7 @@
ATT_Read_Request,
ATT_Write_Command,
ATT_Write_Request,
+ ATT_Error,
)
from . import core
from .core import UUID, InvalidStateError, ProtocolError
@@ -59,6 +62,7 @@
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_REQUEST_TIMEOUT,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
+ Service,
Characteristic,
ClientCharacteristicConfigurationBits,
)
@@ -73,6 +77,8 @@
# Proxies
# -----------------------------------------------------------------------------
class AttributeProxy(EventEmitter):
+ client: Client
+
def __init__(self, client, handle, end_group_handle, attribute_type):
EventEmitter.__init__(self)
self.client = client
@@ -101,6 +107,9 @@
class ServiceProxy(AttributeProxy):
+ uuid: UUID
+ characteristics: List[CharacteristicProxy]
+
@staticmethod
def from_client(service_class, client, service_uuid):
# The service and its characteristics are considered to have already been
@@ -130,6 +139,8 @@
class CharacteristicProxy(AttributeProxy):
+ descriptors: List[DescriptorProxy]
+
def __init__(self, client, handle, end_group_handle, uuid, properties):
super().__init__(client, handle, end_group_handle, uuid)
self.uuid = uuid
@@ -201,6 +212,8 @@
# GATT Client
# -----------------------------------------------------------------------------
class Client:
+ services: List[ServiceProxy]
+
def __init__(self, connection):
self.connection = connection
self.mtu_exchange_done = False
@@ -306,7 +319,7 @@
if not already_known:
self.services.append(service)
- async def discover_services(self, uuids=None):
+ async def discover_services(self, uuids=None) -> List[ServiceProxy]:
'''
See Vol 3, Part G - 4.4.1 Discover All Primary Services
'''
@@ -332,8 +345,10 @@
'!!! unexpected error while discovering services: '
f'{HCI_Constant.error_name(response.error_code)}'
)
- # TODO raise appropriate exception
- return
+ raise ATT_Error(
+ error_code=response.error_code,
+ message='Unexpected error while discovering services',
+ )
break
for (
@@ -349,7 +364,7 @@
logger.warning(
f'bogus handle values: {attribute_handle} {end_group_handle}'
)
- return
+ return []
# Create a service proxy for this service
service = ServiceProxy(
@@ -452,7 +467,9 @@
# TODO
return []
- async def discover_characteristics(self, uuids, service):
+ async def discover_characteristics(
+ self, uuids, service: Optional[ServiceProxy]
+ ) -> List[CharacteristicProxy]:
'''
See Vol 3, Part G - 4.6.1 Discover All Characteristics of a Service and 4.6.2
Discover Characteristics by UUID
@@ -465,12 +482,12 @@
services = [service] if service else self.services
# Perform characteristic discovery for each service
- discovered_characteristics = []
+ discovered_characteristics: List[CharacteristicProxy] = []
for service in services:
starting_handle = service.handle
ending_handle = service.end_group_handle
- characteristics = []
+ characteristics: List[CharacteristicProxy] = []
while starting_handle <= ending_handle:
response = await self.send_request(
ATT_Read_By_Type_Request(
@@ -491,8 +508,10 @@
'!!! unexpected error while discovering characteristics: '
f'{HCI_Constant.error_name(response.error_code)}'
)
- # TODO raise appropriate exception
- return
+ raise ATT_Error(
+ error_code=response.error_code,
+ message='Unexpected error while discovering characteristics',
+ )
break
# Stop if for some reason the list was empty
@@ -535,8 +554,11 @@
return discovered_characteristics
async def discover_descriptors(
- self, characteristic=None, start_handle=None, end_handle=None
- ):
+ self,
+ characteristic: Optional[CharacteristicProxy] = None,
+ start_handle=None,
+ end_handle=None,
+ ) -> List[DescriptorProxy]:
'''
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
'''
@@ -549,7 +571,7 @@
else:
return []
- descriptors = []
+ descriptors: List[DescriptorProxy] = []
while starting_handle <= ending_handle:
response = await self.send_request(
ATT_Find_Information_Request(
diff --git a/bumble/gatt_server.py b/bumble/gatt_server.py
index d82f273..3a5953a 100644
--- a/bumble/gatt_server.py
+++ b/bumble/gatt_server.py
@@ -691,7 +691,7 @@
length=entry_size, attribute_data_list=b''.join(attribute_data_list)
)
else:
- logging.warning(f"not found {request}")
+ logging.debug(f"not found {request}")
self.send_response(connection, response)
diff --git a/bumble/hci.py b/bumble/hci.py
index bcab7a7..9b5793d 100644
--- a/bumble/hci.py
+++ b/bumble/hci.py
@@ -1491,7 +1491,7 @@
elif field_type == -2:
# 16-bit signed
field_value = struct.unpack_from('<h', data, offset)[0]
- offset += 1
+ offset += 2
elif field_type == 3:
# 24-bit unsigned
padded = data[offset : offset + 3] + bytes([0])
@@ -2099,6 +2099,24 @@
# -----------------------------------------------------------------------------
@HCI_Command.command(
+ fields=[
+ ('bd_addr', Address.parse_address),
+ ('pin_code_length', 1),
+ ('pin_code', 16),
+ ],
+ return_parameters_fields=[
+ ('status', STATUS_SPEC),
+ ('bd_addr', Address.parse_address),
+ ],
+)
+class HCI_PIN_Code_Request_Reply_Command(HCI_Command):
+ '''
+ See Bluetooth spec @ 7.1.12 PIN Code Request Reply Command
+ '''
+
+
+# -----------------------------------------------------------------------------
+@HCI_Command.command(
fields=[('bd_addr', Address.parse_address)],
return_parameters_fields=[
('status', STATUS_SPEC),
diff --git a/bumble/host.py b/bumble/host.py
index 9f667a1..9e05c8c 100644
--- a/bumble/host.py
+++ b/bumble/host.py
@@ -24,7 +24,10 @@
from bumble.l2cap import L2CAP_PDU
from bumble.snoop import Snooper
+from typing import Optional
+
from .hci import (
+ Address,
HCI_ACL_DATA_PACKET,
HCI_COMMAND_COMPLETE_EVENT,
HCI_COMMAND_PACKET,
@@ -53,7 +56,6 @@
HCI_LE_Write_Suggested_Default_Data_Length_Command,
HCI_Link_Key_Request_Negative_Reply_Command,
HCI_Link_Key_Request_Reply_Command,
- HCI_PIN_Code_Request_Negative_Reply_Command,
HCI_Packet,
HCI_Read_Buffer_Size_Command,
HCI_Read_Local_Supported_Commands_Command,
@@ -142,6 +144,24 @@
if controller_sink:
self.set_packet_sink(controller_sink)
+ def find_connection_by_bd_addr(
+ self,
+ bd_addr: Address,
+ transport: Optional[int] = None,
+ check_address_type: bool = False,
+ ) -> Optional[Connection]:
+ for connection in self.connections.values():
+ if connection.peer_address.to_bytes() == bd_addr.to_bytes():
+ if (
+ check_address_type
+ and connection.peer_address.address_type != bd_addr.address_type
+ ):
+ continue
+ if transport is None or connection.transport == transport:
+ return connection
+
+ return None
+
async def flush(self) -> None:
# Make sure no command is pending
await self.command_semaphore.acquire()
@@ -719,12 +739,17 @@
f'role change for {event.bd_addr}: '
f'{HCI_Constant.role_name(event.new_role)}'
)
- # TODO: lookup the connection and update the role
+ if connection := self.find_connection_by_bd_addr(
+ event.bd_addr, BT_BR_EDR_TRANSPORT
+ ):
+ connection.role = event.new_role
+ self.emit('role_change', event.bd_addr, event.new_role)
else:
logger.debug(
f'role change for {event.bd_addr} failed: '
f'{HCI_Constant.error_name(event.status)}'
)
+ self.emit('role_change_failure', event.bd_addr, event.status)
def on_hci_le_data_length_change_event(self, event):
self.emit(
@@ -794,11 +819,7 @@
)
def on_hci_pin_code_request_event(self, event):
- # For now, just refuse all requests
- # TODO: delegate the decision
- self.send_command_sync(
- HCI_PIN_Code_Request_Negative_Reply_Command(bd_addr=event.bd_addr)
- )
+ self.emit('pin_code_request', event.bd_addr)
def on_hci_link_key_request_event(self, event):
async def send_link_key():
diff --git a/bumble/l2cap.py b/bumble/l2cap.py
index 2610adc..ef7fdab 100644
--- a/bumble/l2cap.py
+++ b/bumble/l2cap.py
@@ -796,6 +796,11 @@
self.disconnection_result = asyncio.get_running_loop().create_future()
return await self.disconnection_result
+ def abort(self):
+ if self.state == self.OPEN:
+ self.change_state(self.CLOSED)
+ self.emit('close')
+
def send_configure_request(self):
options = L2CAP_Control_Frame.encode_configuration_options(
[
@@ -1105,6 +1110,10 @@
self.disconnection_result = asyncio.get_running_loop().create_future()
return await self.disconnection_result
+ def abort(self):
+ if self.state == self.CONNECTED:
+ self.change_state(self.DISCONNECTED)
+
def on_pdu(self, pdu):
if self.sink is None:
logger.warning('received pdu without a sink')
@@ -1492,8 +1501,12 @@
def on_disconnection(self, connection_handle, _reason):
logger.debug(f'disconnection from {connection_handle}, cleaning up channels')
if connection_handle in self.channels:
+ for _, channel in self.channels[connection_handle].items():
+ channel.abort()
del self.channels[connection_handle]
if connection_handle in self.le_coc_channels:
+ for _, channel in self.le_coc_channels[connection_handle].items():
+ channel.abort()
del self.le_coc_channels[connection_handle]
if connection_handle in self.identifiers:
del self.identifiers[connection_handle]
diff --git a/bumble/link.py b/bumble/link.py
index 82dd9db..85ad96e 100644
--- a/bumble/link.py
+++ b/bumble/link.py
@@ -19,12 +19,15 @@
import asyncio
from functools import partial
+from bumble.core import BT_PERIPHERAL_ROLE, BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT
from bumble.colors import color
from bumble.hci import (
Address,
HCI_SUCCESS,
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
HCI_CONNECTION_TIMEOUT_ERROR,
+ HCI_PAGE_TIMEOUT_ERROR,
+ HCI_Connection_Complete_Event,
)
# -----------------------------------------------------------------------------
@@ -57,6 +60,11 @@
def __init__(self):
self.controllers = set()
self.pending_connection = None
+ self.pending_classic_connection = None
+
+ ############################################################
+ # Common utils
+ ############################################################
def add_controller(self, controller):
logger.debug(f'new controller: {controller}')
@@ -71,22 +79,39 @@
return controller
return None
- def on_address_changed(self, controller):
- pass
+ def find_classic_controller(self, address):
+ for controller in self.controllers:
+ if controller.public_address == address:
+ return controller
+ return None
def get_pending_connection(self):
return self.pending_connection
+ ############################################################
+ # LE handlers
+ ############################################################
+
+ def on_address_changed(self, controller):
+ pass
+
def send_advertising_data(self, sender_address, data):
# Send the advertising data to all controllers, except the sender
for controller in self.controllers:
if controller.random_address != sender_address:
controller.on_link_advertising_data(sender_address, data)
- def send_acl_data(self, sender_address, destination_address, data):
+ def send_acl_data(self, sender_controller, destination_address, transport, data):
# Send the data to the first controller with a matching address
- if controller := self.find_controller(destination_address):
- controller.on_link_acl_data(sender_address, data)
+ if transport == BT_LE_TRANSPORT:
+ destination_controller = self.find_controller(destination_address)
+ source_address = sender_controller.random_address
+ elif transport == BT_BR_EDR_TRANSPORT:
+ destination_controller = self.find_classic_controller(destination_address)
+ source_address = sender_controller.public_address
+
+ if destination_controller is not None:
+ destination_controller.on_link_acl_data(source_address, transport, data)
def on_connection_complete(self):
# Check that we expect this call
@@ -163,6 +188,89 @@
if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)
+ ############################################################
+ # Classic handlers
+ ############################################################
+
+ def classic_connect(self, initiator_controller, responder_address):
+ logger.debug(
+ f'[Classic] {initiator_controller.public_address} connects to {responder_address}'
+ )
+ responder_controller = self.find_classic_controller(responder_address)
+ if responder_controller is None:
+ initiator_controller.on_classic_connection_complete(
+ responder_address, HCI_PAGE_TIMEOUT_ERROR
+ )
+ return
+ self.pending_classic_connection = (initiator_controller, responder_controller)
+
+ responder_controller.on_classic_connection_request(
+ initiator_controller.public_address,
+ HCI_Connection_Complete_Event.ACL_LINK_TYPE,
+ )
+
+ def classic_accept_connection(
+ self, responder_controller, initiator_address, responder_role
+ ):
+ logger.debug(
+ f'[Classic] {responder_controller.public_address} accepts to connect {initiator_address}'
+ )
+ initiator_controller = self.find_classic_controller(initiator_address)
+ if initiator_controller is None:
+ responder_controller.on_classic_connection_complete(
+ responder_controller.public_address, HCI_PAGE_TIMEOUT_ERROR
+ )
+ return
+
+ async def task():
+ if responder_role != BT_PERIPHERAL_ROLE:
+ initiator_controller.on_classic_role_change(
+ responder_controller.public_address, int(not (responder_role))
+ )
+ initiator_controller.on_classic_connection_complete(
+ responder_controller.public_address, HCI_SUCCESS
+ )
+
+ asyncio.create_task(task())
+ responder_controller.on_classic_role_change(
+ initiator_controller.public_address, responder_role
+ )
+ responder_controller.on_classic_connection_complete(
+ initiator_controller.public_address, HCI_SUCCESS
+ )
+ self.pending_classic_connection = None
+
+ def classic_disconnect(self, initiator_controller, responder_address, reason):
+ logger.debug(
+ f'[Classic] {initiator_controller.public_address} disconnects {responder_address}'
+ )
+ responder_controller = self.find_classic_controller(responder_address)
+
+ async def task():
+ initiator_controller.on_classic_disconnected(responder_address, reason)
+
+ asyncio.create_task(task())
+ responder_controller.on_classic_disconnected(
+ initiator_controller.public_address, reason
+ )
+
+ def classic_switch_role(
+ self, initiator_controller, responder_address, initiator_new_role
+ ):
+ responder_controller = self.find_classic_controller(responder_address)
+ if responder_controller is None:
+ return
+
+ async def task():
+ initiator_controller.on_classic_role_change(
+ responder_address, initiator_new_role
+ )
+
+ asyncio.create_task(task())
+ responder_controller.on_classic_role_change(
+ initiator_controller.public_address, int(not (initiator_new_role))
+ )
+
# -----------------------------------------------------------------------------
class RemoteLink:
@@ -200,6 +308,9 @@
def get_pending_connection(self):
return self.pending_connection
+ def get_pending_classic_connection(self):
+ return self.pending_classic_connection
+
async def wait_until_connected(self):
await self.websocket
@@ -366,7 +477,8 @@
async def send_acl_data_to_relay(self, peer_address, data):
await self.send_targeted_message(peer_address, f'acl:{data.hex()}')
- def send_acl_data(self, _, peer_address, data):
+ def send_acl_data(self, _, peer_address, _transport, data):
+ # TODO: handle different transport
self.execute(partial(self.send_acl_data_to_relay, peer_address, data))
async def send_connection_request_to_relay(self, peer_address):
diff --git a/bumble/profiles/asha_service.py b/bumble/profiles/asha_service.py
index fabaa28..1b1e93a 100644
--- a/bumble/profiles/asha_service.py
+++ b/bumble/profiles/asha_service.py
@@ -20,6 +20,7 @@
import logging
from typing import List
from ..core import AdvertisingData
+from ..device import Device, Connection
from ..gatt import (
GATT_ASHA_SERVICE,
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
@@ -31,7 +32,7 @@
Characteristic,
CharacteristicValue,
)
-from ..device import Device
+from ..utils import AsyncRunner
# -----------------------------------------------------------------------------
# Logging
@@ -55,16 +56,16 @@
self.hisyncid = hisyncid
self.capability = capability # Device Capabilities [Left, Monaural]
self.device = device
- self.emitted_data_name = 'ASHA_data_' + str(self.capability)
self.audio_out_data = b''
self.psm = psm # a non-zero psm is mainly for testing purpose
# Handler for volume control
- def on_volume_write(_connection, value):
+ def on_volume_write(connection, value):
logger.info(f'--- VOLUME Write:{value[0]}')
+ self.emit('volume', connection, value[0])
# Handler for audio control commands
- def on_audio_control_point_write(_connection, value):
+ def on_audio_control_point_write(connection: Connection, value):
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0]
if opcode == AshaService.OPCODE_START:
@@ -76,14 +77,29 @@
f'volume={value[3]}, '
f'otherstate={value[4]}'
)
+ self.emit(
+ 'start',
+ connection,
+ {
+ 'codec': value[1],
+ 'audiotype': value[2],
+ 'volume': value[3],
+ 'otherstate': value[4],
+ },
+ )
elif opcode == AshaService.OPCODE_STOP:
logger.info('### STOP')
+ self.emit('stop', connection)
elif opcode == AshaService.OPCODE_STATUS:
logger.info(f'### STATUS: connected={value[1]}')
- # TODO Respond with a status
- # asyncio.create_task(device.notify_subscribers(audio_status_characteristic,
- # force=True))
+ # OPCODE_STATUS does not need audio status point update
+ if opcode != AshaService.OPCODE_STATUS:
+ AsyncRunner.spawn(
+ device.notify_subscribers(
+ self.audio_status_characteristic, force=True
+ )
+ )
self.read_only_properties_characteristic = Characteristic(
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
@@ -126,7 +142,7 @@
def on_data(data):
logging.debug(f'<<< data received:{data}')
- self.emit(self.emitted_data_name, data)
+ self.emit('data', channel.connection, data)
self.audio_out_data += data
channel.sink = on_data
diff --git a/bumble/rfcomm.py b/bumble/rfcomm.py
index a6b02ba..dbb2795 100644
--- a/bumble/rfcomm.py
+++ b/bumble/rfcomm.py
@@ -852,17 +852,27 @@
# Register ourselves with the L2CAP channel manager
device.register_l2cap_server(RFCOMM_PSM, self.on_connection)
- def listen(self, acceptor):
- # Find a free channel number
- for channel in range(
- RFCOMM_DYNAMIC_CHANNEL_NUMBER_START, RFCOMM_DYNAMIC_CHANNEL_NUMBER_END + 1
- ):
- if channel not in self.acceptors:
- self.acceptors[channel] = acceptor
- return channel
+ def listen(self, acceptor, channel=0):
+ if channel:
+ if channel in self.acceptors:
+ # Busy
+ return 0
+ else:
+ # Find a free channel number
+ for candidate in range(
+ RFCOMM_DYNAMIC_CHANNEL_NUMBER_START,
+ RFCOMM_DYNAMIC_CHANNEL_NUMBER_END + 1,
+ ):
+ if candidate not in self.acceptors:
+ channel = candidate
+ break
- # All channels used...
- return 0
+ if channel == 0:
+ # All channels used...
+ return 0
+
+ self.acceptors[channel] = acceptor
+ return channel
def on_connection(self, l2cap_channel):
logger.debug(f'+++ new L2CAP connection: {l2cap_channel}')
diff --git a/bumble/smp.py b/bumble/smp.py
index 9a72cb0..1714743 100644
--- a/bumble/smp.py
+++ b/bumble/smp.py
@@ -522,9 +522,19 @@
async def compare_numbers(self, number: int, digits: int) -> bool:
return True
- async def get_number(self) -> int:
+ async def get_number(self) -> Optional[int]:
+ '''
+ Returns an optional number as an answer to a passkey request.
+ Returning `None` will result in a negative reply.
+ '''
return 0
+ async def get_string(self, max_length) -> Optional[str]:
+ '''
+ Returns a string whose utf-8 encoding is up to max_length bytes.
+ '''
+ return None
+
# pylint: disable-next=unused-argument
async def display_number(self, number: int, digits: int) -> None:
pass
diff --git a/bumble/utils.py b/bumble/utils.py
index 474fff2..8a55684 100644
--- a/bumble/utils.py
+++ b/bumble/utils.py
@@ -20,7 +20,7 @@
import traceback
import collections
import sys
-from typing import Awaitable, TypeVar
+from typing import Awaitable, Set, TypeVar
from functools import wraps
from pyee import EventEmitter
@@ -157,6 +157,9 @@
# Shared default queue
default_queue = WorkQueue()
+ # Shared set of running tasks
+ running_tasks: Set[Awaitable] = set()
+
@staticmethod
def run_in_task(queue=None):
"""
@@ -187,6 +190,19 @@
return decorator
+ @staticmethod
+ def spawn(coroutine):
+ """
+ Spawn a task to run a coroutine in a "fire and forget" mode.
+
+ Using this method instead of just calling `asyncio.create_task(coroutine)`
+ is necessary when you don't keep a reference to the task, because `asyncio`
+ only keeps weak references to alive tasks.
+ """
+ task = asyncio.create_task(coroutine)
+ AsyncRunner.running_tasks.add(task)
+ task.add_done_callback(AsyncRunner.running_tasks.remove)
+
# -----------------------------------------------------------------------------
class FlowControlAsyncPipe:
diff --git a/docs/mkdocs/mkdocs.yml b/docs/mkdocs/mkdocs.yml
index 8cabd0c..0ddc982 100644
--- a/docs/mkdocs/mkdocs.yml
+++ b/docs/mkdocs/mkdocs.yml
@@ -43,7 +43,7 @@
- Apps & Tools:
- Overview: apps_and_tools/index.md
- Console: apps_and_tools/console.md
- - Link Relay: apps_and_tools/link_relay.md
+ - Bench: apps_and_tools/bench.md
- HCI Bridge: apps_and_tools/hci_bridge.md
- Golden Gate Bridge: apps_and_tools/gg_bridge.md
- Show: apps_and_tools/show.md
@@ -51,6 +51,7 @@
- Pair: apps_and_tools/pair.md
- Unbond: apps_and_tools/unbond.md
- USB Probe: apps_and_tools/usb_probe.md
+ - Link Relay: apps_and_tools/link_relay.md
- Hardware:
- Overview: hardware/index.md
- Platforms:
@@ -62,7 +63,7 @@
- Examples:
- Overview: examples/index.md
-copyright: Copyright 2021-2022 Google LLC
+copyright: Copyright 2021-2023 Google LLC
theme:
name: 'material'
diff --git a/docs/mkdocs/src/apps_and_tools/bench.md b/docs/mkdocs/src/apps_and_tools/bench.md
new file mode 100644
index 0000000..db785d6
--- /dev/null
+++ b/docs/mkdocs/src/apps_and_tools/bench.md
@@ -0,0 +1,158 @@
+BENCH TOOL
+==========
+
+The "bench" tool implements a number of different ways of measuring the
+throughput and/or latency between two devices.
+
+# General Usage
+
+```
+Usage: bench.py [OPTIONS] COMMAND [ARGS]...
+
+Options:
+ --device-config FILENAME Device configuration file
+ --role [sender|receiver|ping|pong]
+ --mode [gatt-client|gatt-server|l2cap-client|l2cap-server|rfcomm-client|rfcomm-server]
+ --att-mtu MTU GATT MTU (gatt-client mode) [23<=x<=517]
+ -s, --packet-size SIZE Packet size (server role) [8<=x<=4096]
+ -c, --packet-count COUNT Packet count (server role)
+ -sd, --start-delay SECONDS Start delay (server role)
+ --help Show this message and exit.
+
+Commands:
+ central Run as a central (initiates the connection)
+ peripheral Run as a peripheral (waits for a connection)
+```
+
+## Options for the ``central`` Command
+```
+Usage: bumble-bench central [OPTIONS] TRANSPORT
+
+ Run as a central (initiates the connection)
+
+Options:
+ --peripheral ADDRESS_OR_NAME Address or name to connect to
+ --connection-interval, --ci CONNECTION_INTERVAL
+ Connection interval (in ms)
+ --phy [1m|2m|coded] PHY to use
+ --help Show this message and exit.
+```
+
+
+To test one device against another, one of the two devices must be running
+the ``peripheral`` command and the other the ``central`` command. The device
+running the ``peripheral`` command will accept connections from the device
+running the ``central`` command.
+When using Bluetooth LE (all modes except ``rfcomm-server`` and ``rfcomm-client``),
+the default addresses configured in the tool should be sufficient. But when using
+Bluetooth Classic, the address of the Peripheral must be specified on the Central
+using the ``--peripheral`` option. The address will be printed by the Peripheral when
+it starts.
+
+Regardless of whether the device is the Central or the Peripheral, each device selects a
+``mode`` and a ``role`` to run as. The ``mode`` and ``role`` of the Central and Peripheral
+must be compatible.
+
+Device 1 mode | Device 2 mode
+------------------|------------------
+``gatt-client`` | ``gatt-server``
+``l2cap-client`` | ``l2cap-server``
+``rfcomm-client`` | ``rfcomm-server``
+
+Device 1 role | Device 2 role
+--------------|--------------
+``sender`` | ``receiver``
+``ping`` | ``pong``
+
+
+# Examples
+
+In the following examples, we have two USB Bluetooth controllers, one on `usb:0` and
+the other on `usb:1`, and two consoles/terminals. We will run a command in each.
+
+!!! example "GATT Throughput"
+ Using the default mode and role for the Central and Peripheral.
+
+ In the first console/terminal:
+ ```
+ $ bumble-bench peripheral usb:0
+ ```
+
+ In the second console/terminal:
+ ```
+ $ bumble-bench central usb:1
+ ```
+
+ In this default configuration, the Central runs a Sender, as a GATT client,
+ connecting to the Peripheral running a Receiver, as a GATT server.
+
+!!! example "L2CAP Throughput"
+ In the first console/terminal:
+ ```
+ $ bumble-bench --mode l2cap-server peripheral usb:0
+ ```
+
+ In the second console/terminal:
+ ```
+ $ bumble-bench --mode l2cap-client central usb:1
+ ```
+
+!!! example "RFCOMM Throughput"
+ In the first console/terminal:
+ ```
+ $ bumble-bench --mode rfcomm-server peripheral usb:0
+ ```
+
+ NOTE: the BT address of the Peripheral will be printed out; use it with the
+ ``--peripheral`` option for the Central.
+
+ In this example, we use a larger packet size and packet count than the default.
+
+ In the second console/terminal:
+ ```
+ $ bumble-bench --mode rfcomm-client --packet-size 2000 --packet-count 100 central --peripheral 00:16:A4:5A:40:F2 usb:1
+ ```
+
+!!! example "Ping/Pong Latency"
+ In the first console/terminal:
+ ```
+ $ bumble-bench --role pong peripheral usb:0
+ ```
+
+ In the second console/terminal:
+ ```
+ $ bumble-bench --role ping central usb:1
+ ```
+
+!!! example "Reversed modes with GATT and custom connection interval"
+ In the first console/terminal:
+ ```
+ $ bumble-bench --mode gatt-client peripheral usb:0
+ ```
+
+ In the second console/terminal:
+ ```
+ $ bumble-bench --mode gatt-server central --ci 10 usb:1
+ ```
+
+!!! example "Reversed modes with L2CAP and custom PHY"
+ In the first console/terminal:
+ ```
+ $ bumble-bench --mode l2cap-client peripheral usb:0
+ ```
+
+ In the second console/terminal:
+ ```
+ $ bumble-bench --mode l2cap-server central --phy 2m usb:1
+ ```
+
+!!! example "Reversed roles with L2CAP"
+ In the first console/terminal:
+ ```
+ $ bumble-bench --mode l2cap-client --role sender peripheral usb:0
+ ```
+
+ In the second console/terminal:
+ ```
+ $ bumble-bench --mode l2cap-server --role receiver central usb:1
+ ```
diff --git a/docs/mkdocs/src/apps_and_tools/index.md b/docs/mkdocs/src/apps_and_tools/index.md
index f588738..fe7af56 100644
--- a/docs/mkdocs/src/apps_and_tools/index.md
+++ b/docs/mkdocs/src/apps_and_tools/index.md
@@ -5,6 +5,7 @@
These include:
* [Console](console.md) - an interactive text-based console
+ * [Bench](bench.md) - Speed and Latency benchmarking between two devices (LE and Classic)
* [Pair](pair.md) - Pair/bond two devices (LE and Classic)
* [Unbond](unbond.md) - Remove a previously established bond
* [HCI Bridge](hci_bridge.md) - a HCI transport bridge to connect two HCI transports and filter/snoop the HCI packets
diff --git a/docs/mkdocs/src/index.md b/docs/mkdocs/src/index.md
index fb1e155..c81f7ff 100644
--- a/docs/mkdocs/src/index.md
+++ b/docs/mkdocs/src/index.md
@@ -8,8 +8,7 @@
eventually added. Support for BLE is therefore currently somewhat more advanced than for Classic.
!!! warning
- This project is still very much experimental and in an alpha state where a lot of things are still missing or broken, and what's there changes frequently.
- Also, there are still a few hardcoded values/parameters in some of the examples and apps which need to be changed (those will eventually be command line arguments, as appropriate)
+ This project is still in an early state of development where some things are still missing or broken, and what's implemented may change and evolve frequently.
Overview
--------
diff --git a/setup.cfg b/setup.cfg
index 662dd5c..0a4aae3 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -57,6 +57,7 @@
bumble-unbond = bumble.apps.unbond:main
bumble-usb-probe = bumble.apps.usb_probe:main
bumble-link-relay = bumble.apps.link_relay.link_relay:main
+ bumble-bench = bumble.apps.bench:main
[options.package_data]
* = py.typed, *.pyi
diff --git a/tests/a2dp_test.py b/tests/a2dp_test.py
index e499531..92f7915 100644
--- a/tests/a2dp_test.py
+++ b/tests/a2dp_test.py
@@ -21,6 +21,7 @@
import pytest
from bumble.controller import Controller
+from bumble.core import BT_BR_EDR_TRANSPORT
from bumble.link import LocalLink
from bumble.device import Device
from bumble.host import Host
@@ -58,18 +59,19 @@
def __init__(self):
self.connections = [None, None]
+ addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0']
self.link = LocalLink()
self.controllers = [
- Controller('C1', link=self.link),
- Controller('C2', link=self.link),
+ Controller('C1', link=self.link, public_address=addresses[0]),
+ Controller('C2', link=self.link, public_address=addresses[1]),
]
self.devices = [
Device(
- address='F0:F1:F2:F3:F4:F5',
+ address=addresses[0],
host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])),
),
Device(
- address='F5:F4:F3:F2:F1:F0',
+ address=addresses[1],
host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])),
),
]
@@ -79,6 +81,9 @@
def on_connection(self, which, connection):
self.connections[which] = connection
+    def on_paired(self, which, keys):  # record `keys` for device index `which`
+        self.paired[which] = keys  # NOTE(review): confirm __init__ sets self.paired
+
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
@@ -94,12 +99,21 @@
'connection', lambda connection: two_devices.on_connection(1, connection)
)
+ # Enable Classic connections
+ two_devices.devices[0].classic_enabled = True
+ two_devices.devices[1].classic_enabled = True
+
# Start
await two_devices.devices[0].power_on()
await two_devices.devices[1].power_on()
# Connect the two devices
- await two_devices.devices[0].connect(two_devices.devices[1].random_address)
+ await asyncio.gather(
+ two_devices.devices[0].connect(
+ two_devices.devices[1].public_address, transport=BT_BR_EDR_TRANSPORT
+ ),
+ two_devices.devices[1].accept(two_devices.devices[0].public_address),
+ )
# Check the post conditions
assert two_devices.connections[0] is not None
@@ -152,6 +166,9 @@
@pytest.mark.asyncio
async def test_source_sink_1():
two_devices = TwoDevices()
+ # Enable Classic connections
+ two_devices.devices[0].classic_enabled = True
+ two_devices.devices[1].classic_enabled = True
await two_devices.devices[0].power_on()
await two_devices.devices[1].power_on()
@@ -171,9 +188,16 @@
listener = Listener(Listener.create_registrar(two_devices.devices[1]))
listener.on('connection', on_avdtp_connection)
- connection = await two_devices.devices[0].connect(
- two_devices.devices[1].random_address
- )
+    async def make_connection():  # BR/EDR connect + accept run concurrently
+        connections = await asyncio.gather(
+            two_devices.devices[0].connect(
+                two_devices.devices[1].public_address, BT_BR_EDR_TRANSPORT
+            ),
+            two_devices.devices[1].accept(two_devices.devices[0].public_address),
+        )
+        return connections[0]  # gather keeps order: device 0's connection
+
+ connection = await make_connection()
client = await Protocol.connect(connection)
endpoints = await client.discover_remote_endpoints()
assert len(endpoints) == 1
diff --git a/tests/decoder_test.py b/tests/decoder_test.py
new file mode 100644
index 0000000..3be2c63
--- /dev/null
+++ b/tests/decoder_test.py
@@ -0,0 +1,47 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# -----------------------------------------------------------------------------
+# Imports
+# -----------------------------------------------------------------------------
+import hashlib
+import os
+from bumble.decoder import G722Decoder
+
+
+# -----------------------------------------------------------------------------
+def test_decode_file():
+ decoder = G722Decoder()
+ output_bytes = bytearray()
+
+ with open(
+ os.path.join(os.path.dirname(__file__), 'g722_sample.g722'), 'rb'
+ ) as file:
+ file_content = file.read()
+ frame_length = 80
+ data_length = int(len(file_content) / frame_length)
+
+ for i in range(0, data_length):
+ decoded_data = decoder.decode_frame(
+ file_content[i * frame_length : i * frame_length + frame_length]
+ )
+ output_bytes.extend(decoded_data)
+
+ result = hashlib.md5(output_bytes).hexdigest()
+ assert result == 'b58e0cdd012d12f5633fc796c3b0fbd4'
+
+
+# -----------------------------------------------------------------------------
+if __name__ == '__main__':  # allow running this test module directly
+    test_decode_file()
diff --git a/tests/g722_sample.g722 b/tests/g722_sample.g722
new file mode 100644
index 0000000..432cd6c
--- /dev/null
+++ b/tests/g722_sample.g722
Binary files differ
diff --git a/tests/gatt_test.py b/tests/gatt_test.py
index 2473623..70bbdb8 100644
--- a/tests/gatt_test.py
+++ b/tests/gatt_test.py
@@ -37,10 +37,12 @@
Service,
Characteristic,
CharacteristicValue,
+ Descriptor,
)
from bumble.transport import AsyncPipeSink
from bumble.core import UUID
from bumble.att import (
+ Attribute,
ATT_EXCHANGE_MTU_REQUEST,
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_PDU,
@@ -862,6 +864,29 @@
# -----------------------------------------------------------------------------
+def test_attribute_string_to_permissions():
+ assert Attribute.string_to_permissions('READABLE') == 1
+ assert Attribute.string_to_permissions('WRITEABLE') == 2
+ assert Attribute.string_to_permissions('READABLE,WRITEABLE') == 3
+
+
+# -----------------------------------------------------------------------------
+def test_charracteristic_permissions():
+ characteristic = Characteristic(
+ 'FDB159DB-036C-49E3-B3DB-6325AC750806',
+ Characteristic.READ | Characteristic.WRITE | Characteristic.NOTIFY,
+ 'READABLE,WRITEABLE',
+ )
+ assert characteristic.permissions == 3
+
+
+# -----------------------------------------------------------------------------
+def test_descriptor_permissions():
+ descriptor = Descriptor('2902', 'READABLE,WRITEABLE')
+ assert descriptor.permissions == 3
+
+
+# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
test_UUID()
diff --git a/tests/hci_test.py b/tests/hci_test.py
index 14e5182..af68e86 100644
--- a/tests/hci_test.py
+++ b/tests/hci_test.py
@@ -52,6 +52,7 @@
HCI_LE_Set_Scan_Parameters_Command,
HCI_Number_Of_Completed_Packets_Event,
HCI_Packet,
+ HCI_PIN_Code_Request_Reply_Command,
HCI_Read_Local_Supported_Commands_Command,
HCI_Read_Local_Supported_Features_Command,
HCI_Read_Local_Version_Information_Command,
@@ -213,6 +214,23 @@
# -----------------------------------------------------------------------------
+def test_HCI_PIN_Code_Request_Reply_Command():
+ pin_code = b'1234'
+ pin_code_length = len(pin_code)
+ # here to make the test pass, we need to
+ # pad pin_code, as HCI_Object.format_fields
+ # does not do it for us
+ padded_pin_code = pin_code + bytes(16 - pin_code_length)
+ command = HCI_PIN_Code_Request_Reply_Command(
+ bd_addr=Address(
+ '00:11:22:33:44:55', address_type=Address.PUBLIC_DEVICE_ADDRESS
+ ),
+ pin_code_length=pin_code_length,
+ pin_code=padded_pin_code,
+ )
+ basic_check(command)
+
+
def test_HCI_Reset_Command():
command = HCI_Reset_Command()
basic_check(command)
@@ -440,6 +458,7 @@
def run_test_commands():
test_HCI_Command()
test_HCI_Reset_Command()
+ test_HCI_PIN_Code_Request_Reply_Command()
test_HCI_Read_Local_Version_Information_Command()
test_HCI_Read_Local_Supported_Commands_Command()
test_HCI_Read_Local_Supported_Features_Command()
diff --git a/tests/self_test.py b/tests/self_test.py
index 751825f..d6b16ec 100644
--- a/tests/self_test.py
+++ b/tests/self_test.py
@@ -22,6 +22,7 @@
import pytest
from bumble.controller import Controller
+from bumble.core import BT_BR_EDR_TRANSPORT, BT_PERIPHERAL_ROLE, BT_CENTRAL_ROLE
from bumble.link import LocalLink
from bumble.device import Device, Peer
from bumble.host import Host
@@ -47,18 +48,19 @@
def __init__(self):
self.connections = [None, None]
+ addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0']
self.link = LocalLink()
self.controllers = [
- Controller('C1', link=self.link),
- Controller('C2', link=self.link),
+ Controller('C1', link=self.link, public_address=addresses[0]),
+ Controller('C2', link=self.link, public_address=addresses[1]),
]
self.devices = [
Device(
- address='F0:F1:F2:F3:F4:F5',
+ address=addresses[0],
host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])),
),
Device(
- address='F5:F4:F3:F2:F1:F0',
+ address=addresses[1],
host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])),
),
]
@@ -100,6 +102,60 @@
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ 'responder_role,',
+ (BT_CENTRAL_ROLE, BT_PERIPHERAL_ROLE),
+)
+async def test_self_classic_connection(responder_role):
+ # Create two devices, each with a controller, attached to the same link
+ two_devices = TwoDevices()
+
+ # Attach listeners
+ two_devices.devices[0].on(
+ 'connection', lambda connection: two_devices.on_connection(0, connection)
+ )
+ two_devices.devices[1].on(
+ 'connection', lambda connection: two_devices.on_connection(1, connection)
+ )
+
+ # Enable Classic connections
+ two_devices.devices[0].classic_enabled = True
+ two_devices.devices[1].classic_enabled = True
+
+ # Start
+ await two_devices.devices[0].power_on()
+ await two_devices.devices[1].power_on()
+
+ # Connect the two devices
+ await asyncio.gather(
+ two_devices.devices[0].connect(
+ two_devices.devices[1].public_address, transport=BT_BR_EDR_TRANSPORT
+ ),
+ two_devices.devices[1].accept(
+ two_devices.devices[0].public_address, responder_role
+ ),
+ )
+
+ # Check the post conditions
+ assert two_devices.connections[0] is not None
+ assert two_devices.connections[1] is not None
+
+ # Check the role
+ assert two_devices.connections[0].role != responder_role
+ assert two_devices.connections[1].role == responder_role
+
+ # Role switch
+ await two_devices.connections[0].switch_role(responder_role)
+
+ # Check the role
+ assert two_devices.connections[0].role == responder_role
+ assert two_devices.connections[1].role != responder_role
+
+ await two_devices.connections[0].disconnect()
+
+
+# -----------------------------------------------------------------------------
+@pytest.mark.asyncio
async def test_self_gatt():
# Create two devices, each with a controller, attached to the same link
two_devices = TwoDevices()