diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 659b586d..ddfe3122 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -580,6 +580,7 @@ async def bluetooth_device_connect( # pylint: disable=too-many-locals, too-many ) timeout_expired = False connect_ok = False + unhandled_exception = False try: await connect_future connect_ok = True @@ -606,11 +607,18 @@ async def bluetooth_device_connect( # pylint: disable=too-many-locals, too-many f"after {timeout}s, disconnect timed out: {disconnect_timed_out}, " f" after {disconnect_timeout}s" ) from err + except BaseException: + unhandled_exception = True + raise finally: - if not connect_ok and not timeout_expired: + if unhandled_exception or (not connect_ok and not timeout_expired): unsub() if not timeout_expired: timeout_handle.cancel() + if unhandled_exception: + # Make sure to disconnect if we had an unhandled exception + # as otherwise the connection will be left open. + self._bluetooth_disconnect_no_wait(address) return unsub @@ -717,6 +725,14 @@ def _raise_for_ble_connection_change( f"({response.error})" ) + def _bluetooth_disconnect_no_wait(self, address: int) -> None: + """Disconnect from a Bluetooth device without waiting for a response.""" + self._get_connection().send_message( + BluetoothDeviceRequest( + address=address, request_type=BluetoothDeviceRequestType.DISCONNECT + ) + ) + async def _bluetooth_device_request( self, address: int, diff --git a/tests/test_client.py b/tests/test_client.py index bba93564..7458e6e8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2194,6 +2194,13 @@ def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None await connect_task assert states == [] + # Ensure the disconnect request is written + assert len(transport.writelines.mock_calls) == 2 + req = BluetoothDeviceRequest( + address=1234, request_type=BluetoothDeviceRequestType.DISCONNECT + ).SerializeToString() + assert transport.writelines.mock_calls[-1] == call([b"\x00", b"\x05", b"D", req]) + handlers_after = len(list(itertools.chain(*connection._message_handlers.values()))) # Make sure we do not leak message handlers assert handlers_after == handlers_before