Skip to content

Commit

Permalink
Fix host resolution when local dns does not resolve mdns (#636)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Nov 11, 2023
1 parent c1a0500 commit 634c739
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
15 changes: 10 additions & 5 deletions aioesphomeapi/host_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import contextlib
import logging
import socket
from dataclasses import dataclass
from ipaddress import IPv4Address, IPv6Address
Expand All @@ -12,8 +13,12 @@

from .core import APIConnectionError, ResolveAPIError

_LOGGER = logging.getLogger(__name__)

ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf, None]

SERVICE_TYPE = "_esphomelib._tcp.local."


@dataclass(frozen=True)
class Sockaddr:
Expand Down Expand Up @@ -89,11 +94,11 @@ async def _async_resolve_host_zeroconf(
timeout: float = 3.0,
zeroconf_instance: ZeroconfInstanceType = None,
) -> list[AddrInfo]:
service_type = "_esphomelib._tcp.local."
service_name = f"{host}.{service_type}"
service_name = f"{host}.{SERVICE_TYPE}"

_LOGGER.debug("Resolving host %s via mDNS", service_name)
info = await _async_zeroconf_get_service_info(
zeroconf_instance, service_type, service_name, timeout
zeroconf_instance, SERVICE_TYPE, service_name, timeout
)

if info is None:
Expand Down Expand Up @@ -197,8 +202,8 @@ async def async_resolve_host(
addrs: list[AddrInfo] = []

zc_error = None
if host.endswith(".local"):
name = host[: -len(".local")]
if "." not in host or host.endswith(".local"):
name = host.partition(".")[0]
try:
addrs.extend(
await _async_resolve_host_zeroconf(
Expand Down
8 changes: 7 additions & 1 deletion aioesphomeapi/log_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import sys
from datetime import datetime

from zeroconf.asyncio import AsyncZeroconf

from .api_pb2 import SubscribeLogsResponse # type: ignore
from .client import APIClient
from .log_runner import async_run
Expand All @@ -27,11 +29,14 @@ async def main(argv: list[str]) -> None:
datefmt="%Y-%m-%d %H:%M:%S",
)

aiozc = AsyncZeroconf()

cli = APIClient(
args.address,
args.port,
args.password or "",
noise_psk=args.noise_psk,
zeroconf_instance=aiozc.zeroconf,
keepalive=10,
)

Expand All @@ -41,11 +46,12 @@ def on_log(msg: SubscribeLogsResponse) -> None:
text = message.decode("utf8", "backslashreplace")
print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}")

stop = await async_run(cli, on_log)
stop = await async_run(cli, on_log, aio_zeroconf_instance=aiozc)
try:
while True:
await asyncio.sleep(60)
finally:
await aiozc.async_close()
await stop()


Expand Down
5 changes: 3 additions & 2 deletions tests/test_log_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,15 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec
class PatchableAPIClient(APIClient):
pass

async_zeroconf = get_mock_async_zeroconf()

cli = PatchableAPIClient(
address=Estr("1.2.3.4"),
port=6052,
password=None,
noise_psk=None,
expected_name=Estr("fake"),
zeroconf_instance=async_zeroconf.zeroconf,
)
messages = []

Expand All @@ -60,8 +63,6 @@ async def _wait_subscribe_cli(*args, **kwargs):
await original_subscribe_logs(*args, **kwargs)
subscribed.set()

async_zeroconf = get_mock_async_zeroconf()

with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
), patch.object(cli, "subscribe_logs", _wait_subscribe_cli):
Expand Down

0 comments on commit 634c739

Please sign in to comment.