Skip to content

Commit

Permalink
Ensure frame_helper is always closed before the underlying socket (#602)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Oct 24, 2023
1 parent 9ecf2fc commit e1c42e9
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 5 deletions.
7 changes: 5 additions & 2 deletions aioesphomeapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None:
"""Step 2 in connect process: connect the socket."""
debug_enable = self._debug_enabled()
sock = socket.socket(family=addr.family, type=addr.type, proto=addr.proto)
self._socket = sock
sock.setblocking(False)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# Try to reduce the pressure on esphome device as it measures
Expand Down Expand Up @@ -319,7 +320,6 @@ async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None:
except OSError as err:
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err

self._socket = sock
if debug_enable is True:
_LOGGER.debug(
"%s: Opened socket to %s:%s (%s)",
Expand Down Expand Up @@ -359,14 +359,17 @@ async def _connect_init_frame_helper(self) -> None:
sock=self._socket,
)

# Set the frame helper right away to ensure
# the socket gets closed if we fail to handshake
self._frame_helper = fh

try:
await fh.perform_handshake(HANDSHAKE_TIMEOUT)
except asyncio_TimeoutError as err:
raise TimeoutAPIError("Handshake timed out") from err
except OSError as err:
raise HandshakeAPIError(f"Handshake failed: {err}") from err
self._set_connection_state(ConnectionState.HANDSHAKE_COMPLETE)
self._frame_helper = fh

async def _connect_hello(self) -> None:
"""Step 4 in connect process: send hello and get api version."""
Expand Down
113 changes: 110 additions & 3 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,21 @@
import asyncio
import socket
from datetime import timedelta
from typing import Optional
from typing import Any, Coroutine, Optional
from unittest.mock import AsyncMock

import pytest
from mock import MagicMock, patch

from aioesphomeapi import APIConnectionError
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi.api_pb2 import DeviceInfoResponse, HelloResponse
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
from aioesphomeapi.core import RequiresEncryptionAPIError
from aioesphomeapi.core import (
APIConnectionError,
HandshakeAPIError,
RequiresEncryptionAPIError,
TimeoutAPIError,
)
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr

from .common import async_fire_time_changed, utcnow
Expand Down Expand Up @@ -335,3 +340,105 @@ def on_msg(msg):
remove()
await conn.force_disconnect()
await asyncio.sleep(0)


@pytest.mark.parametrize(
("exception_map"),
[
(OSError("Socket error"), HandshakeAPIError),
(asyncio.TimeoutError, TimeoutAPIError),
(asyncio.CancelledError, APIConnectionError),
],
)
@pytest.mark.asyncio
async def test_plaintext_connection_fails_handshake(
conn: APIConnection,
resolve_host: AsyncMock,
socket_socket: MagicMock,
exception_map: tuple[Exception, Exception],
) -> None:
"""Test that the frame helper is closed before the underlying socket.
If we don't do this, asyncio will get confused and not release the socket.
"""
loop = asyncio.get_event_loop()
exception, raised_exception = exception_map
protocol = _get_mock_protocol(conn)
messages = []
protocol: Optional[APIPlaintextFrameHelper] = None
transport = MagicMock()
connected = asyncio.Event()

class APIPlaintextFrameHelperHandshakeException(APIPlaintextFrameHelper):
"""Plaintext frame helper that raises exception on handshake."""

def perform_handshake(self, timeout: float) -> Coroutine[Any, Any, None]:
raise exception

def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol

def on_msg(msg):
messages.append(msg)

remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
transport = MagicMock()

with patch(
"aioesphomeapi.connection.APIPlaintextFrameHelper",
APIPlaintextFrameHelperHandshakeException,
), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()

assert conn._socket is not None
assert conn._frame_helper is not None

protocol.data_received(
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
)
protocol.data_received(b"5stackatomproxy")
protocol.data_received(b"\x00\x00$")
protocol.data_received(b"\x00\x00\x04")
protocol.data_received(
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d'
)
protocol.data_received(
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
)

call_order = []

def _socket_close_call():
call_order.append("socket_close")

def _frame_helper_close_call():
call_order.append("frame_helper_close")

with patch.object(
conn._socket, "close", side_effect=_socket_close_call
), patch.object(
conn._frame_helper, "close", side_effect=_frame_helper_close_call
), pytest.raises(
raised_exception
):
await asyncio.sleep(0)
await connect_task

# Ensure the frame helper is closed before the socket
# so asyncio releases the socket
assert call_order == ["frame_helper_close", "socket_close"]
assert not conn.is_connected
assert len(messages) == 2
assert isinstance(messages[0], HelloResponse)
assert isinstance(messages[1], DeviceInfoResponse)
assert messages[1].name == "m5stackatomproxy"
remove()
await conn.force_disconnect()
await asyncio.sleep(0)

0 comments on commit e1c42e9

Please sign in to comment.