Skip to content

Commit

Permalink
Changes to how signals are caught and handled in awatch (#136)
Browse files Browse the repository at this point in the history
* Don't use a signal handler, but catch CancelledError and set the stop event there

* The exit_on_signal is now also no longer needed

* Better also handle KeyboardInterrupt the same way

* Raise KeyboardInterrupt when asked for

* split 'stop' and 'signal' case; log warning to match test expectation

* no need to mock open_signal_receiver anymore; set exit_code to 'stop', which should be the expected one here

* removed unused import

* tweaks and docs changes for signal handling

* fix coverage

* fix test_watch_dont_raise_interrupt

Co-authored-by: Samuel Colvin <s@muelcolvin.com>
  • Loading branch information
justvanrossum and samuelcolvin authored May 16, 2022
1 parent 86762c9 commit 7ab3cb9
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 53 deletions.
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ enum WatcherEnum {
struct RustNotify {
changes: Arc<Mutex<HashSet<(u8, String)>>>,
error: Arc<Mutex<Option<String>>>,
debug: bool,
watcher: WatcherEnum,
}

Expand Down Expand Up @@ -138,6 +139,7 @@ impl RustNotify {
Ok(RustNotify {
changes,
error,
debug,
watcher,
})
}
Expand Down Expand Up @@ -186,6 +188,9 @@ impl RustNotify {

if let Some(is_set) = stop_event_is_set {
if is_set.call0()?.is_true()? {
if self.debug {
eprintln!("stop event set, stopping...");
}
self.clear();
return Ok("stop".to_object(py));
}
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ def watch(self, debounce_ms: int, step_ms: int, timeout_ms: int, cancel_event):
from typing import Literal, Protocol

class MockRustType(Protocol):
def __call__(self, changes: ChangesType, *, exit_code: Literal['signal', 'stop', 'timeout'] = 'signal') -> Any:
def __call__(self, changes: ChangesType, *, exit_code: Literal['signal', 'stop', 'timeout'] = 'stop') -> Any:
...


@pytest.fixture
def mock_rust_notify(mocker):
def mock(changes: ChangesType, *, exit_code: str = 'signal'):
def mock(changes: ChangesType, *, exit_code: str = 'stop'):
m = MockRustNotify(changes, exit_code)
mocker.patch('watchfiles.main.RustNotify', return_value=m)
return m
Expand Down
68 changes: 50 additions & 18 deletions tests/test_watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,27 @@ async def test_await_stop_event(tmp_path: Path, write_soon):
stop_event.set()


def test_watch_interrupt(mock_rust_notify: 'MockRustType'):
mock_rust_notify([{(1, 'foo.txt')}])
def test_watch_raise_interrupt(mock_rust_notify: 'MockRustType'):
mock_rust_notify([{(1, 'foo.txt')}], exit_code='signal')

w = watch('.', raise_interrupt=True)
assert next(w) == {(Change.added, 'foo.txt')}
with pytest.raises(KeyboardInterrupt):
next(w)


def test_watch_dont_raise_interrupt(mock_rust_notify: 'MockRustType', caplog):
caplog.set_level('WARNING', 'watchfiles')
mock_rust_notify([{(1, 'foo.txt')}], exit_code='signal')

w = watch('.', raise_interrupt=False)
assert next(w) == {(Change.added, 'foo.txt')}
with pytest.raises(StopIteration):
next(w)

assert caplog.text == 'watchfiles.main WARNING: KeyboardInterrupt caught, stopping watch\n'


@contextmanager
def mock_open_signal_receiver(signal):
async def signals():
Expand All @@ -69,31 +81,26 @@ async def signals():
yield signals()


@pytest.mark.skipif(sys.platform == 'win32', reason='fails on windows')
async def test_awatch_interrupt_raise(mocker, mock_rust_notify: 'MockRustType'):
mocker.patch('watchfiles.main.anyio.open_signal_receiver', side_effect=mock_open_signal_receiver)
mock_rust_notify([{(1, 'foo.txt')}])
async def test_awatch_unexpected_signal(mock_rust_notify: 'MockRustType'):
mock_rust_notify([{(1, 'foo.txt')}], exit_code='signal')

count = 0
with pytest.raises(KeyboardInterrupt):
with pytest.raises(RuntimeError, match='watch thread unexpectedly received a signal'):
async for _ in awatch('.'):
count += 1

assert count == 1


@pytest.mark.skipif(sys.platform == 'win32', reason='fails on windows')
async def test_awatch_interrupt_warning(mocker, mock_rust_notify: 'MockRustType', caplog):
caplog.set_level('INFO', 'watchfiles')
mocker.patch('watchfiles.main.anyio.open_signal_receiver', side_effect=mock_open_signal_receiver)
async def test_awatch_interrupt_warning(mock_rust_notify: 'MockRustType', caplog):
mock_rust_notify([{(1, 'foo.txt')}])

count = 0
async for _ in awatch('.', raise_interrupt=False):
count += 1
with pytest.warns(DeprecationWarning, match='raise_interrupt is deprecated, KeyboardInterrupt will cause this'):
async for _ in awatch('.', raise_interrupt=False):
count += 1

assert count == 1
assert 'WARNING: KeyboardInterrupt caught, stopping awatch' in caplog.text


def test_watch_no_yield(mock_rust_notify: 'MockRustType', caplog):
Expand All @@ -119,7 +126,7 @@ async def test_awatch_no_yield(mock_rust_notify: 'MockRustType', caplog):


def test_watch_timeout(mock_rust_notify: 'MockRustType', caplog):
mock = mock_rust_notify(['timeout', {(1, 'spam.py')}], exit_code='stop')
mock = mock_rust_notify(['timeout', {(1, 'spam.py')}])

caplog.set_level('DEBUG', 'watchfiles')
change_list = []
Expand All @@ -135,7 +142,7 @@ def test_watch_timeout(mock_rust_notify: 'MockRustType', caplog):


def test_watch_yield_on_timeout(mock_rust_notify: 'MockRustType'):
mock = mock_rust_notify(['timeout', {(1, 'spam.py')}], exit_code='stop')
mock = mock_rust_notify(['timeout', {(1, 'spam.py')}])

change_list = []
for changes in watch('.', yield_on_timeout=True):
Expand All @@ -146,7 +153,7 @@ def test_watch_yield_on_timeout(mock_rust_notify: 'MockRustType'):


async def test_awatch_timeout(mock_rust_notify: 'MockRustType', caplog):
mock = mock_rust_notify(['timeout', {(1, 'spam.py')}], exit_code='stop')
mock = mock_rust_notify(['timeout', {(1, 'spam.py')}])

caplog.set_level('DEBUG', 'watchfiles')
change_list = []
Expand All @@ -162,7 +169,7 @@ async def test_awatch_timeout(mock_rust_notify: 'MockRustType', caplog):


async def test_awatch_yield_on_timeout(mock_rust_notify: 'MockRustType'):
mock = mock_rust_notify(['timeout', {(1, 'spam.py')}], exit_code='stop')
mock = mock_rust_notify(['timeout', {(1, 'spam.py')}])

change_list = []
async for changes in awatch('.', yield_on_timeout=True):
Expand All @@ -182,3 +189,28 @@ def test_calc_async_timeout_posix():
def test_calc_async_timeout_win():
assert _calc_async_timeout(123) == 123
assert _calc_async_timeout(None) == 1_000


class MockRustNotifyRaise:
def __init__(self):
self.i = 0

def watch(self, *args):
if self.i == 1:
raise KeyboardInterrupt('test error')
self.i += 1
return {(Change.added, 'spam.py')}


async def test_awatch_interrupt_raise(mocker, caplog):
mocker.patch('watchfiles.main.RustNotify', return_value=MockRustNotifyRaise())

count = 0
stop_event = threading.Event()
with pytest.raises(KeyboardInterrupt, match='test error'):
async for _ in awatch('.', stop_event=stop_event):
count += 1

# event is set because it's set while handling the KeyboardInterrupt
assert stop_event.is_set()
assert count == 1
59 changes: 32 additions & 27 deletions watchfiles/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import signal
import sys
import warnings
from enum import IntEnum
from pathlib import Path
from typing import TYPE_CHECKING, AsyncGenerator, Callable, Generator, Optional, Set, Tuple, Union
Expand Down Expand Up @@ -132,8 +132,7 @@ async def awatch( # noqa C901
rust_timeout: Optional[int] = None,
yield_on_timeout: bool = False,
debug: bool = False,
raise_interrupt: bool = True,
exit_on_signal: bool = True,
raise_interrupt: Optional[bool] = None,
force_polling: bool = False,
poll_delay_ms: int = 30,
) -> AsyncGenerator[Set[FileChange], None]:
Expand All @@ -143,6 +142,9 @@ async def awatch( # noqa C901
All async methods use [anyio](https://anyio.readthedocs.io/en/latest/) to run the event loop.
Unlike [`watch`][watchfiles.watch] `KeyboardInterrupt` cannot be suppressed by `awatch` so they need to be caught
where `asyncio.run` or equivalent is called.
Args:
*paths: filesystem paths to watch.
watch_filter: matches the same argument of [`watch`][watchfiles.watch].
Expand All @@ -154,8 +156,9 @@ async def awatch( # noqa C901
see [#110](https://github.com/samuelcolvin/watchfiles/issues/110).
yield_on_timeout: matches the same argument of [`watch`][watchfiles.watch].
debug: matches the same argument of [`watch`][watchfiles.watch].
raise_interrupt: matches the same argument of [`watch`][watchfiles.watch].
exit_on_signal: whether to watch for, and exit upon `SIGINT`, ignored on Windows where signals don't work.
raise_interrupt: This is deprecated, `KeyboardInterrupt` will cause this coroutine to be cancelled and then
be raised by the top level `asyncio.run` call or equivalent, and should be caught there.
See [#136](https://github.com/samuelcolvin/watchfiles/issues/136)
force_polling: if true, always use polling instead of file system notifications.
poll_delay_ms: delay between polling for changes, only used if `force_polling=True`.
Expand All @@ -170,7 +173,11 @@ async def main():
async for changes in awatch('./first/dir', './second/dir'):
print(changes)
asyncio.run(main())
if __name__ == '__main__':
try:
asyncio.run(main())
except KeyboardInterrupt:
print('stopped via KeyboardInterrupt')
```
```py title="Example of awatch usage with a stop event"
Expand All @@ -195,44 +202,42 @@ async def stop_soon():
asyncio.run(main())
```
"""
if raise_interrupt is not None:
warnings.warn(
'raise_interrupt is deprecated, KeyboardInterrupt will cause this coroutine to be cancelled and then '
'be raised by the top level asyncio.run call or equivalent, and should be caught there. See #136.',
DeprecationWarning,
)

if stop_event is None:
stop_event_: 'AnyEvent' = anyio.Event()
else:
stop_event_ = stop_event
interrupted = False

async def signal_handler() -> None:
nonlocal interrupted

with anyio.open_signal_receiver(signal.SIGINT) as signals:
async for _ in signals:
interrupted = True
stop_event_.set()
break

watcher = RustNotify([str(p) for p in paths], debug, force_polling, poll_delay_ms)
timeout = _calc_async_timeout(rust_timeout)
CancelledError = anyio.get_cancelled_exc_class()

while True:
async with anyio.create_task_group() as tg:
# add_signal_handler is not implemented on Windows repeat ctrl+c should still stops the watcher
if exit_on_signal and sys.platform != 'win32':
tg.start_soon(signal_handler)
raw_changes = await anyio.to_thread.run_sync(watcher.watch, debounce, step, timeout, stop_event_)
try:
raw_changes = await anyio.to_thread.run_sync(watcher.watch, debounce, step, timeout, stop_event_)
except (CancelledError, KeyboardInterrupt):
stop_event_.set()
# suppressing KeyboardInterrupt wouldn't stop it getting raised by the top level asyncio.run call
raise
tg.cancel_scope.cancel()

if raw_changes == 'timeout':
if yield_on_timeout:
yield set()
else:
logger.debug('rust notify timeout, continuing')
elif raw_changes == 'stop' or raw_changes == 'signal':
# cover both cases here although in theory the watch thread should never get a signal
if interrupted:
if raise_interrupt:
raise KeyboardInterrupt
else:
logger.warning('KeyboardInterrupt caught, stopping awatch')
elif raw_changes == 'stop':
return
elif raw_changes == 'signal':
# in theory the watch thread should never get a signal
raise RuntimeError('watch thread unexpectedly received a signal')
else:
changes = _prep_changes(raw_changes, watch_filter)
if changes:
Expand Down
13 changes: 7 additions & 6 deletions watchfiles/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ async def arun_process(
Starting and stopping the process and watching for changes is done in a separate thread.
As with `run_process`, internally `arun_process` uses [`awatch`][watchfiles.awatch] with `raise_interrupt=False`
so the function exits cleanly upon `Ctrl+C`.
As with `run_process`, internally `arun_process` uses [`awatch`][watchfiles.awatch], however `KeyboardInterrupt`
cannot be caught and suppressed in `awatch` so these errors need to be caught separately, see below.
```py title="Example of arun_process usage"
import asyncio
Expand All @@ -171,7 +171,10 @@ async def main():
await arun_process('.', target=foobar, args=(1, 2), callback=callback)
if __name__ == '__main__':
asyncio.run(main())
try:
asyncio.run(main())
except KeyboardInterrupt:
print('stopped via KeyboardInterrupt')
```
"""
import inspect
Expand All @@ -183,9 +186,7 @@ async def main():
process = await anyio.to_thread.run_sync(start_process, target, target_type, args, kwargs)
reloads = 0

async for changes in awatch(
*paths, watch_filter=watch_filter, debounce=debounce, step=step, debug=debug, raise_interrupt=False
):
async for changes in awatch(*paths, watch_filter=watch_filter, debounce=debounce, step=step, debug=debug):
if callback is not None:
r = callback(changes)
if inspect.isawaitable(r):
Expand Down

0 comments on commit 7ab3cb9

Please sign in to comment.