From 1b44875efd77abae8905342a0d8f3a2c0597b2e6 Mon Sep 17 00:00:00 2001 From: Anibal Angulo Date: Sat, 4 May 2024 08:26:56 -0600 Subject: [PATCH 1/3] Added redis stream option for job delivery --- arq/cli.py | 8 +++++++- arq/connections.py | 6 +++++- arq/constants.py | 3 +++ arq/worker.py | 49 +++++++++++++++++++++++++++++++++++----------- 4 files changed, 53 insertions(+), 13 deletions(-) diff --git a/arq/cli.py b/arq/cli.py index 3d3aa300..58919f83 100644 --- a/arq/cli.py +++ b/arq/cli.py @@ -16,6 +16,7 @@ from .typing import WorkerSettingsType burst_help = 'Batch mode: exit once no jobs are found in any queue.' +stream_help = 'Stream mode: use redis streams for job delivery. Does not support batch mode.' health_check_help = 'Health Check: run a health check and exit.' watch_help = 'Watch a directory and reload the worker upon changes.' verbose_help = 'Enable verbose output.' @@ -26,11 +27,14 @@ @click.version_option(VERSION, '-V', '--version', prog_name='arq') @click.argument('worker-settings', type=str, required=True) @click.option('--burst/--no-burst', default=None, help=burst_help) +@click.option('--stream/--no-stream', default=None, help=stream_help) @click.option('--check', is_flag=True, help=health_check_help) @click.option('--watch', type=click.Path(exists=True, dir_okay=True, file_okay=False), help=watch_help) @click.option('-v', '--verbose', is_flag=True, help=verbose_help) @click.option('--custom-log-dict', type=str, help=logdict_help) -def cli(*, worker_settings: str, burst: bool, check: bool, watch: str, verbose: bool, custom_log_dict: str) -> None: +def cli( + *, worker_settings: str, burst: bool, stream: bool, check: bool, watch: str, verbose: bool, custom_log_dict: str +) -> None: """ Job queues in python with asyncio and redis. @@ -48,6 +52,8 @@ def cli(*, worker_settings: str, burst: bool, check: bool, watch: str, verbose: exit(check_health(worker_settings_)) else: kwargs = {} if burst is None else {'burst': burst} + if stream: + kwargs['stream'] = stream if watch: asyncio.run(watch_reload(watch, worker_settings_)) else: diff --git a/arq/connections.py b/arq/connections.py index c1058890..c3194cc1 100644 --- a/arq/connections.py +++ b/arq/connections.py @@ -13,7 +13,7 @@ from redis.asyncio.sentinel import Sentinel from redis.exceptions import RedisError, WatchError -from .constants import default_queue_name, expires_extra_ms, job_key_prefix, result_key_prefix +from .constants import default_queue_name, expires_extra_ms, job_key_prefix, result_key_prefix, stream_prefix from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job from .utils import timestamp_ms, to_ms, to_unix_ms @@ -120,6 +120,7 @@ async def enqueue_job( self, function: str, *args: Any, + _use_stream: bool = False, _job_id: Optional[str] = None, _queue_name: Optional[str] = None, _defer_until: Optional[datetime] = None, @@ -133,6 +134,7 @@ async def enqueue_job( :param function: Name of the function to call :param args: args to pass to the function + :param _use_stream: queue the job through redis streams. Stream mode must be enabled in worker. 
:param _job_id: ID of the job, can be used to enforce job uniqueness :param _queue_name: queue of the job, can be used to create job in different queue :param _defer_until: datetime at which to run the job @@ -171,6 +173,8 @@ async def enqueue_job( job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer) pipe.multi() + if _use_stream: + pipe.xadd(stream_prefix + _queue_name, {job_key_prefix: job}) pipe.psetex(job_key, expires_ms, job) pipe.zadd(_queue_name, {job_id: score}) try: diff --git a/arq/constants.py b/arq/constants.py index 84c009aa..31bd19ca 100644 --- a/arq/constants.py +++ b/arq/constants.py @@ -1,4 +1,7 @@ default_queue_name = 'arq:queue' +default_worker_name = 'arq:worker' +default_worker_group = 'arq:workers' +stream_prefix = 'arq:stream:' job_key_prefix = 'arq:job:' in_progress_key_prefix = 'arq:in-progress:' result_key_prefix = 'arq:result:' diff --git a/arq/worker.py b/arq/worker.py index 4c33b677..72db408c 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -20,6 +20,8 @@ abort_job_max_age, abort_jobs_ss, default_queue_name, + default_worker_group, + default_worker_name, expires_extra_ms, health_check_key_suffix, in_progress_key_prefix, @@ -27,6 +29,7 @@ keep_cronjob_progress, result_key_prefix, retry_key_prefix, + stream_prefix, ) from .utils import ( args_to_string, @@ -144,10 +147,13 @@ class Worker: :param functions: list of functions to register, can either be raw coroutine functions or the result of :func:`arq.worker.func`. :param queue_name: queue name to get jobs from + :param worker_name: unique name to identify this worker + :param worker_group: worker group that this worker belongs to :param cron_jobs: list of cron jobs to run, use :func:`arq.cron.cron` to create them :param redis_settings: settings for creating a redis connection :param redis_pool: existing redis pool, generally None :param burst: whether to stop the worker once all jobs have been run + :param stream: whether to constantly listen for new jobs from a redis stream :param on_startup: coroutine function to run at startup :param on_shutdown: coroutine function to run at shutdown :param on_job_start: coroutine function to run on job start @@ -188,10 +194,13 @@ def __init__( functions: Sequence[Union[Function, 'WorkerCoroutine']] = (), *, queue_name: Optional[str] = default_queue_name, + worker_name: Optional[str] = None, + worker_group: Optional[str] = None, cron_jobs: Optional[Sequence[CronJob]] = None, redis_settings: Optional[RedisSettings] = None, redis_pool: Optional[ArqRedis] = None, burst: bool = False, + stream: bool = False, on_startup: Optional['StartupShutdown'] = None, on_shutdown: Optional['StartupShutdown'] = None, on_job_start: Optional['StartupShutdown'] = None, @@ -234,6 +243,10 @@ def __init__( if len(self.functions) == 0: raise RuntimeError('at least one function or cron_job must be registered') self.burst = burst + self.stream = stream + if stream is True: + self.worker_name = worker_name if worker_name is not None else default_worker_name + self.worker_group = worker_group if worker_group is not None else default_worker_group self.on_startup = on_startup self.on_shutdown = on_shutdown self.on_job_start = on_job_start @@ -357,17 +370,31 @@ async def main(self) -> None: if self.on_startup: await self.on_startup(self.ctx) - async for _ in poll(self.poll_delay_s): - await self._poll_iteration() - - if self.burst: - if 0 <= self.max_burst_jobs <= self._jobs_started(): - await asyncio.gather(*self.tasks.values()) - return None - queued_jobs 
= await self.pool.zcard(self.queue_name) - if queued_jobs == 0: - await asyncio.gather(*self.tasks.values()) - return None + if self.stream is False: + async for _ in poll(self.poll_delay_s): + await self._poll_iteration() + + if self.burst: + if 0 <= self.max_burst_jobs <= self._jobs_started(): + await asyncio.gather(*self.tasks.values()) + return None + queued_jobs = await self.pool.zcard(self.queue_name) + if queued_jobs == 0: + await asyncio.gather(*self.tasks.values()) + return None + else: + stream_name = stream_prefix + self.queue_name + + with contextlib.suppress(ResponseError): + await self.pool.xgroup_create(stream_name, self.worker_group, '$', mkstream=True) + logger.info('Stream consumer group created with name: %s', self.worker_group) + + while True: + if event := await self.pool.xreadgroup( + consumername=self.worker_name, groupname=self.worker_group, streams={stream_name: '>'}, block=0 + ): + await self._poll_iteration() + await self.pool.xack(stream_name, self.worker_group, event[0][1][0][0]) # type: ignore[no-untyped-call] async def _poll_iteration(self) -> None: """ From 77de228ab38294ca6095d8ff46b2a5d6b42252cb Mon Sep 17 00:00:00 2001 From: ajac-zero Date: Wed, 12 Jun 2024 19:50:31 -0600 Subject: [PATCH 2/3] Added Stream Worker tests --- arq/worker.py | 19 ++- tests/conftest.py | 26 ++++ tests/test_worker.py | 345 +++++++++++++++++++++++++++---------------- 3 files changed, 260 insertions(+), 130 deletions(-) diff --git a/arq/worker.py b/arq/worker.py index 72db408c..81ebc1c8 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -8,7 +8,7 @@ from functools import partial from signal import Signals from time import time -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union, cast from redis.exceptions import ResponseError, WatchError @@ -389,12 +389,23 @@ async def main(self) -> None: await self.pool.xgroup_create(stream_name, self.worker_group, '$', mkstream=True) logger.info('Stream consumer group created with name: %s', self.worker_group) - while True: + async def read_messages(name: Literal['>', '0']): # type: ignore[no-untyped-def] if event := await self.pool.xreadgroup( - consumername=self.worker_name, groupname=self.worker_group, streams={stream_name: '>'}, block=0 + consumername=self.worker_name, groupname=self.worker_group, streams={stream_name: name}, block=0 ): await self._poll_iteration() - await self.pool.xack(stream_name, self.worker_group, event[0][1][0][0]) # type: ignore[no-untyped-call] + + for message in event[0][1]: + await self.pool.xack(stream_name, self.worker_group, message[0]) # type: ignore[no-untyped-call] + + # Heartbeat before blocking, or health check will fail previous to receiving first message + await self.heart_beat() + + if self.burst: + await read_messages('0') + else: + while True: + await read_messages('>') async def _poll_iteration(self) -> None: """ diff --git a/tests/conftest.py b/tests/conftest.py index b9332eed..350e1741 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -79,6 +79,32 @@ def create(functions=[], burst=True, poll_delay=0, max_jobs=10, arq_redis=arq_re await worker_.close() +poll_worker = worker + + +@pytest.fixture +async def stream_worker(arq_redis): + worker_: Worker = None + + def create(functions=[], burst=True, stream=True, poll_delay=0, max_jobs=10, arq_redis=arq_redis, **kwargs): + nonlocal worker_ + worker_ = Worker( + functions=functions, + 
redis_pool=arq_redis, + burst=burst, + stream=stream, + poll_delay=poll_delay, + max_jobs=max_jobs, + **kwargs, + ) + return worker_ + + yield create + + if worker_: + await worker_.close() + + @pytest.fixture async def worker_retry(arq_redis_retry): worker_retry_: Worker = None diff --git a/tests/test_worker.py b/tests/test_worker.py index a25f0f1d..9e7c6803 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -35,28 +35,38 @@ async def fails(ctx): raise TypeError('my type error') -def test_no_jobs(arq_redis: ArqRedis, loop, mocker): +def test_no_jobs(arq_redis: ArqRedis, loop, mocker, _stream: bool = False): class Settings: functions = [func(foobar, name='foobar')] burst = True poll_delay = 0 queue_read_limit = 10 + stream = _stream - loop.run_until_complete(arq_redis.enqueue_job('foobar')) + loop.run_until_complete(arq_redis.enqueue_job('foobar', _use_stream=_stream)) mocker.patch('asyncio.get_event_loop', lambda: loop) worker = run_worker(Settings) assert worker.jobs_complete == 1 assert str(worker) == '' -def test_health_check_direct(loop): +def test_no_job_stream(arq_redis: ArqRedis, loop, mocker): + assert test_no_jobs(arq_redis, loop, mocker, _stream=True) is None + + +def test_health_check_direct(loop, _stream: bool = False): class Settings: + stream = _stream pass asyncio.set_event_loop(loop) assert check_health(Settings) == 1 +def test_health_check_direct_stream(loop): + assert test_health_check_direct(loop, _stream=True) is None + + async def test_health_check_fails(): assert 1 == await async_check_health(None) @@ -66,11 +76,15 @@ async def test_health_check_pass(arq_redis): assert 0 == await async_check_health(None) -async def test_set_health_check_key(arq_redis: ArqRedis, worker): - await arq_redis.enqueue_job('foobar', _job_id='testing') +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_set_health_check_key(arq_redis: ArqRedis, worker, stream): + await arq_redis.enqueue_job('foobar', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[func(foobar, keep_result=0)], health_check_key='arq:test:health-check') await worker.main() - assert sorted(await arq_redis.keys('*')) == [b'arq:test:health-check'] + expected_keys = [b'arq:test:health-check'] + if stream: + expected_keys.insert(0, b'arq:stream:arq:queue') + assert sorted(await arq_redis.keys('*')) == expected_keys async def test_handle_sig(caplog, arq_redis: ArqRedis): @@ -111,14 +125,15 @@ async def test_handle_no_sig(caplog): assert worker.tasks[1].cancel.call_count == 1 -async def test_worker_signal_completes_job_before_shutting_down(caplog, arq_redis: ArqRedis, worker): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_worker_signal_completes_job_before_shutting_down(caplog, arq_redis: ArqRedis, worker, stream): caplog.set_level(logging.INFO) async def sleep_job(ctx, time): await asyncio.sleep(time) - await arq_redis.enqueue_job('sleep_job', 0.2, _job_id='short_sleep') # should be completed - await arq_redis.enqueue_job('sleep_job', 5, _job_id='long_sleep') # should be cancelled + await arq_redis.enqueue_job('sleep_job', 0.2, _job_id='short_sleep', _use_stream=stream) # should be completed + await arq_redis.enqueue_job('sleep_job', 5, _job_id='long_sleep', _use_stream=stream) # should be cancelled worker = worker( functions=[func(sleep_job, name='sleep_job', max_tries=1)], job_completion_wait=0.5, @@ -143,9 +158,10 @@ async def 
sleep_job(ctx, time): assert worker.jobs_failed == 0 -async def test_job_successful(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_job_successful(arq_redis: ArqRedis, worker, stream, caplog): caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('foobar', _job_id='testing') + await arq_redis.enqueue_job('foobar', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[foobar]) assert worker.jobs_complete == 0 assert worker.jobs_failed == 0 @@ -159,9 +175,10 @@ async def test_job_successful(arq_redis: ArqRedis, worker, caplog): assert 'X.XXs → testing:foobar()\n X.XXs ← testing:foobar ● 42' in log -async def test_job_successful_no_result_logging(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_job_successful_no_result_logging(arq_redis: ArqRedis, worker, stream, caplog): caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('foobar', _job_id='testing') + await arq_redis.enqueue_job('foobar', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[foobar], log_results=False) await worker.main() @@ -170,13 +187,14 @@ async def test_job_successful_no_result_logging(arq_redis: ArqRedis, worker, cap assert '42' not in log -async def test_job_retry(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_job_retry(arq_redis: ArqRedis, worker, stream, caplog): async def retry(ctx): if ctx['job_try'] <= 2: raise Retry(defer=0.01) caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('retry', _job_id='testing') + await arq_redis.enqueue_job('retry', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[func(retry, name='retry')]) await worker.main() assert worker.jobs_complete == 1 @@ -189,12 +207,13 @@ async def retry(ctx): assert '0.XXs ← testing:retry ●' in log -async def test_job_retry_dont_retry(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_job_retry_dont_retry(arq_redis: ArqRedis, worker, stream, caplog): async def retry(ctx): raise Retry(defer=0.01) caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('retry', _job_id='testing') + await arq_redis.enqueue_job('retry', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[func(retry, name='retry')]) with pytest.raises(FailedJobs) as exc_info: await worker.run_check(retry_jobs=False) @@ -204,12 +223,13 @@ async def retry(ctx): assert '! 
testing:retry failed, Retry: \n' in caplog.text -async def test_job_retry_max_jobs(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_job_retry_max_jobs(arq_redis: ArqRedis, worker, stream, caplog): async def retry(ctx): raise Retry(defer=0.01) caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('retry', _job_id='testing') + await arq_redis.enqueue_job('retry', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[func(retry, name='retry')]) assert await worker.run_check(max_burst_jobs=1) == 0 assert worker.jobs_complete == 0 @@ -221,9 +241,10 @@ async def retry(ctx): assert '0.XXs → testing:retry() try=2\n' not in log -async def test_job_job_not_found(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_job_job_not_found(arq_redis: ArqRedis, worker, stream, caplog): caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('missing', _job_id='testing') + await arq_redis.enqueue_job('missing', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[foobar]) await worker.main() assert worker.jobs_complete == 0 @@ -234,9 +255,10 @@ async def test_job_job_not_found(arq_redis: ArqRedis, worker, caplog): assert "job testing, function 'missing' not found" in log -async def test_job_job_not_found_run_check(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_job_job_not_found_run_check(arq_redis: ArqRedis, worker, stream, caplog): caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('missing', _job_id='testing') + await arq_redis.enqueue_job('missing', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[foobar]) with pytest.raises(FailedJobs) as exc_info: await worker.run_check() @@ -248,12 +270,13 @@ async def test_job_job_not_found_run_check(arq_redis: ArqRedis, worker, caplog): assert failure != 123 # check the __eq__ method of JobExecutionFailed -async def test_retry_lots(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_retry_lots(arq_redis: ArqRedis, worker, stream, caplog): async def retry(ctx): raise Retry() caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('retry', _job_id='testing') + await arq_redis.enqueue_job('retry', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[func(retry, name='retry')]) await worker.main() assert worker.jobs_complete == 0 @@ -264,34 +287,37 @@ async def retry(ctx): assert ' X.XXs ! 
testing:retry max retries 5 exceeded' in log -async def test_retry_lots_without_keep_result(arq_redis: ArqRedis, worker): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_retry_lots_without_keep_result(arq_redis: ArqRedis, worker, stream): async def retry(ctx): raise Retry() - await arq_redis.enqueue_job('retry', _job_id='testing') + await arq_redis.enqueue_job('retry', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[func(retry, name='retry')], keep_result=0) await worker.main() # Should not raise MultiExecError -async def test_retry_lots_check(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_retry_lots_check(arq_redis: ArqRedis, worker, stream, caplog): async def retry(ctx): raise Retry() caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('retry', _job_id='testing') + await arq_redis.enqueue_job('retry', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[func(retry, name='retry')]) with pytest.raises(FailedJobs, match='max 5 retries exceeded'): await worker.run_check() +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) @pytest.mark.skipif(sys.version_info >= (3, 8), reason='3.8 deals with CancelledError differently') -async def test_cancel_error(arq_redis: ArqRedis, worker, caplog): +async def test_cancel_error(arq_redis: ArqRedis, worker, stream, caplog): async def retry(ctx): if ctx['job_try'] == 1: raise asyncio.CancelledError() caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('retry', _job_id='testing') + await arq_redis.enqueue_job('retry', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[func(retry, name='retry')]) await worker.main() assert worker.jobs_complete == 1 @@ -302,13 +328,14 @@ async def retry(ctx): assert 'X.XXs ↻ testing:retry cancelled, will be run again' in log -async def test_retry_job_error(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_retry_job_error(arq_redis: ArqRedis, worker, stream, caplog): async def retry(ctx): if ctx['job_try'] == 1: raise RetryJob() caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('retry', _job_id='testing') + await arq_redis.enqueue_job('retry', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[func(retry, name='retry')]) await worker.main() assert worker.jobs_complete == 1 @@ -319,9 +346,10 @@ async def retry(ctx): assert 'X.XXs ↻ testing:retry cancelled, will be run again' in log -async def test_job_expired(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_job_expired(arq_redis: ArqRedis, worker, stream, caplog): caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('foobar', _job_id='testing') + await arq_redis.enqueue_job('foobar', _job_id='testing', _use_stream=stream) await arq_redis.delete(job_key_prefix + 'testing') worker: Worker = worker(functions=[foobar]) assert worker.jobs_complete == 0 @@ -336,9 +364,10 @@ async def test_job_expired(arq_redis: ArqRedis, worker, caplog): assert 'job testing expired' in log -async def test_job_expired_run_check(arq_redis: ArqRedis, worker, caplog): 
+@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_job_expired_run_check(arq_redis: ArqRedis, worker, stream, caplog): caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('foobar', _job_id='testing') + await arq_redis.enqueue_job('foobar', _job_id='testing', _use_stream=stream) await arq_redis.delete(job_key_prefix + 'testing') worker: Worker = worker(functions=[foobar]) with pytest.raises(FailedJobs) as exc_info: @@ -373,9 +402,10 @@ async def test_default_job_expiry(arq_redis: ArqRedis, worker, caplog, extra_job assert time_to_live_ms == pytest.approx(wait_time) -async def test_job_old(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_job_old(arq_redis: ArqRedis, worker, stream, caplog): caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('foobar', _job_id='testing', _defer_by=-2) + await arq_redis.enqueue_job('foobar', _job_id='testing', _defer_by=-2, _use_stream=stream) worker: Worker = worker(functions=[foobar]) assert worker.jobs_complete == 0 assert worker.jobs_failed == 0 @@ -393,9 +423,10 @@ async def test_retry_repr(): assert str(Retry(123)) == '' -async def test_str_function(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_str_function(arq_redis: ArqRedis, worker, stream, caplog): caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('asyncio.sleep', _job_id='testing') + await arq_redis.enqueue_job('asyncio.sleep', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=['asyncio.sleep']) assert worker.jobs_complete == 0 assert worker.jobs_failed == 0 @@ -409,7 +440,8 @@ async def test_str_function(arq_redis: ArqRedis, worker, caplog): assert '0.XXs ! 
testing:asyncio.sleep failed, TypeError' in log -async def test_startup_shutdown(arq_redis: ArqRedis, worker): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_startup_shutdown(arq_redis: ArqRedis, worker, stream): calls = [] async def startup(ctx): @@ -418,7 +450,7 @@ async def startup(ctx): async def shutdown(ctx): calls.append('shutdown') - await arq_redis.enqueue_job('foobar', _job_id='testing') + await arq_redis.enqueue_job('foobar', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[foobar], on_startup=startup, on_shutdown=shutdown) await worker.main() await worker.close() @@ -435,9 +467,10 @@ async def error_function(ctx): raise CustomError('this is the error') -async def test_exc_extra(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_exc_extra(arq_redis: ArqRedis, worker, stream, caplog): caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('error_function', _job_id='testing') + await arq_redis.enqueue_job('error_function', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[error_function]) await worker.main() assert worker.jobs_failed == 1 @@ -448,7 +481,8 @@ async def test_exc_extra(arq_redis: ArqRedis, worker, caplog): assert error.extra == {'x': 'y'} -async def test_unpickleable(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_unpickleable(arq_redis: ArqRedis, worker, stream, caplog): caplog.set_level(logging.INFO) class Foo: @@ -457,7 +491,7 @@ class Foo: async def example(ctx): return Foo() - await arq_redis.enqueue_job('example', _job_id='testing') + await arq_redis.enqueue_job('example', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[func(example, name='example')]) await worker.main() @@ -465,9 +499,10 @@ async def example(ctx): assert 'error serializing result of testing:example' in log -async def test_log_health_check(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_log_health_check(arq_redis: ArqRedis, worker, stream, caplog): caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('foobar', _job_id='testing') + await arq_redis.enqueue_job('foobar', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[foobar], health_check_interval=0) await worker.main() await worker.main() @@ -479,62 +514,96 @@ async def test_log_health_check(arq_redis: ArqRedis, worker, caplog): assert 'recording health' in caplog.text -async def test_remain_keys(arq_redis: ArqRedis, worker, create_pool): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_remain_keys(arq_redis: ArqRedis, worker, stream, create_pool): redis2 = await create_pool(RedisSettings()) - await arq_redis.enqueue_job('foobar', _job_id='testing') - assert sorted(await redis2.keys('*')) == [b'arq:job:testing', b'arq:queue'] + await arq_redis.enqueue_job('foobar', _job_id='testing', _use_stream=stream) + expected_keys = [b'arq:job:testing', b'arq:queue'] + if stream: + expected_keys.insert(2, b'arq:stream:arq:queue') + assert sorted(await redis2.keys('*')) == expected_keys worker: Worker = 
worker(functions=[foobar]) await worker.main() - assert sorted(await redis2.keys('*')) == [b'arq:queue:health-check', b'arq:result:testing'] + expected_keys = [b'arq:queue:health-check', b'arq:result:testing'] + if stream: + expected_keys.insert(2, b'arq:stream:arq:queue') + assert sorted(await redis2.keys('*')) == expected_keys await worker.close() - assert sorted(await redis2.keys('*')) == [b'arq:result:testing'] - - -async def test_remain_keys_no_results(arq_redis: ArqRedis, worker): - await arq_redis.enqueue_job('foobar', _job_id='testing') - assert sorted(await arq_redis.keys('*')) == [b'arq:job:testing', b'arq:queue'] + expected_keys = [b'arq:result:testing'] + if stream: + expected_keys.insert(1, b'arq:stream:arq:queue') + assert sorted(await redis2.keys('*')) == expected_keys + + +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_remain_keys_no_results(arq_redis: ArqRedis, worker, stream): + await arq_redis.enqueue_job('foobar', _job_id='testing', _use_stream=stream) + expected_keys = [b'arq:job:testing', b'arq:queue'] + if stream: + expected_keys.insert(2, b'arq:stream:arq:queue') + assert sorted(await arq_redis.keys('*')) == expected_keys worker: Worker = worker(functions=[func(foobar, keep_result=0)]) await worker.main() - assert sorted(await arq_redis.keys('*')) == [b'arq:queue:health-check'] - - -async def test_remain_keys_keep_results_forever_in_function(arq_redis: ArqRedis, worker): - await arq_redis.enqueue_job('foobar', _job_id='testing') - assert sorted(await arq_redis.keys('*')) == [b'arq:job:testing', b'arq:queue'] + expected_keys = [b'arq:queue:health-check'] + if stream: + expected_keys.insert(1, b'arq:stream:arq:queue') + assert sorted(await arq_redis.keys('*')) == expected_keys + + +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_remain_keys_keep_results_forever_in_function(arq_redis: ArqRedis, worker, stream): + await arq_redis.enqueue_job('foobar', _job_id='testing', _use_stream=stream) + expected_keys = [b'arq:job:testing', b'arq:queue'] + if stream: + expected_keys.insert(2, b'arq:stream:arq:queue') + assert sorted(await arq_redis.keys('*')) == expected_keys worker: Worker = worker(functions=[func(foobar, keep_result_forever=True)]) await worker.main() - assert sorted(await arq_redis.keys('*')) == [b'arq:queue:health-check', b'arq:result:testing'] + expected_keys = [b'arq:queue:health-check', b'arq:result:testing'] + if stream: + expected_keys.insert(2, b'arq:stream:arq:queue') + assert sorted(await arq_redis.keys('*')) == expected_keys ttl_result = await arq_redis.ttl('arq:result:testing') assert ttl_result == -1 -async def test_remain_keys_keep_results_forever(arq_redis: ArqRedis, worker): - await arq_redis.enqueue_job('foobar', _job_id='testing') - assert sorted(await arq_redis.keys('*')) == [b'arq:job:testing', b'arq:queue'] +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_remain_keys_keep_results_forever(arq_redis: ArqRedis, worker, stream): + await arq_redis.enqueue_job('foobar', _job_id='testing', _use_stream=stream) + expected_keys = [b'arq:job:testing', b'arq:queue'] + if stream: + expected_keys.insert(2, b'arq:stream:arq:queue') + assert sorted(await arq_redis.keys('*')) == expected_keys worker: Worker = worker(functions=[func(foobar)], keep_result_forever=True) await worker.main() - assert sorted(await 
arq_redis.keys('*')) == [b'arq:queue:health-check', b'arq:result:testing'] + expected_keys = [b'arq:queue:health-check', b'arq:result:testing'] + if stream: + expected_keys.insert(2, b'arq:stream:arq:queue') + assert sorted(await arq_redis.keys('*')) == expected_keys ttl_result = await arq_redis.ttl('arq:result:testing') assert ttl_result == -1 -async def test_run_check_passes(arq_redis: ArqRedis, worker): - await arq_redis.enqueue_job('foobar') - await arq_redis.enqueue_job('foobar') +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_run_check_passes(arq_redis: ArqRedis, worker, stream): + await arq_redis.enqueue_job('foobar', _use_stream=stream) + await arq_redis.enqueue_job('foobar', _use_stream=stream) worker: Worker = worker(functions=[func(foobar, name='foobar')]) assert 2 == await worker.run_check() -async def test_run_check_error(arq_redis: ArqRedis, worker): - await arq_redis.enqueue_job('fails') +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_run_check_error(arq_redis: ArqRedis, worker, stream): + await arq_redis.enqueue_job('fails', _use_stream=stream) worker: Worker = worker(functions=[func(fails, name='fails')]) with pytest.raises(FailedJobs, match=r"1 job failed TypeError\('my type error'"): await worker.run_check() -async def test_run_check_error2(arq_redis: ArqRedis, worker): - await arq_redis.enqueue_job('fails') - await arq_redis.enqueue_job('fails') +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_run_check_error2(arq_redis: ArqRedis, worker, stream): + await arq_redis.enqueue_job('fails', _use_stream=stream) + await arq_redis.enqueue_job('fails', _use_stream=stream) worker: Worker = worker(functions=[func(fails, name='fails')]) with pytest.raises(FailedJobs, match='2 jobs failed:\n') as exc_info: await worker.run_check() @@ -551,11 +620,12 @@ async def return_something(ctx): assert (worker.jobs_complete, worker.jobs_failed, worker.jobs_retried) == (1, 0, 0) -async def test_return_exception(arq_redis: ArqRedis, worker): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_return_exception(arq_redis: ArqRedis, worker, stream): async def return_error(ctx): return TypeError('xxx') - j = await arq_redis.enqueue_job('return_error') + j = await arq_redis.enqueue_job('return_error', _use_stream=stream) worker: Worker = worker(functions=[func(return_error, name='return_error')]) await worker.main() assert (worker.jobs_complete, worker.jobs_failed, worker.jobs_retried) == (1, 0, 0) @@ -565,8 +635,9 @@ async def return_error(ctx): assert info.success is True -async def test_error_success(arq_redis: ArqRedis, worker): - j = await arq_redis.enqueue_job('fails') +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_error_success(arq_redis: ArqRedis, worker, stream): + j = await arq_redis.enqueue_job('fails', _use_stream=stream) worker: Worker = worker(functions=[func(fails, name='fails')]) await worker.main() assert (worker.jobs_complete, worker.jobs_failed, worker.jobs_retried) == (0, 1, 0) @@ -574,9 +645,10 @@ async def test_error_success(arq_redis: ArqRedis, worker): assert info.success is False -async def test_many_jobs_expire(arq_redis: ArqRedis, worker, caplog): 
+@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_many_jobs_expire(arq_redis: ArqRedis, worker, stream, caplog): caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('foobar') + await arq_redis.enqueue_job('foobar', _use_stream=stream) await asyncio.gather(*[arq_redis.zadd(default_queue_name, {f'testing-{i}': 1}) for i in range(100)]) worker: Worker = worker(functions=[foobar]) assert worker.jobs_complete == 0 @@ -592,8 +664,9 @@ async def test_many_jobs_expire(arq_redis: ArqRedis, worker, caplog): assert log.count(' expired') == 100 -async def test_repeat_job_result(arq_redis: ArqRedis, worker): - j1 = await arq_redis.enqueue_job('foobar', _job_id='job_id') +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_repeat_job_result(arq_redis: ArqRedis, worker, stream): + j1 = await arq_redis.enqueue_job('foobar', _job_id='job_id', _use_stream=stream) assert isinstance(j1, Job) assert await j1.status() == JobStatus.queued @@ -605,9 +678,10 @@ async def test_repeat_job_result(arq_redis: ArqRedis, worker): assert await arq_redis.enqueue_job('foobar', _job_id='job_id') is None -async def test_queue_read_limit_equals_max_jobs(arq_redis: ArqRedis, worker): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_queue_read_limit_equals_max_jobs(arq_redis: ArqRedis, worker, stream): for _ in range(4): - await arq_redis.enqueue_job('foobar') + await arq_redis.enqueue_job('foobar', _use_stream=stream) assert await arq_redis.zcard(default_queue_name) == 4 worker: Worker = worker(functions=[foobar], queue_read_limit=2) @@ -663,8 +737,9 @@ async def test_custom_queue_read_limit(arq_redis: ArqRedis, worker): assert worker.jobs_retried == 0 -async def test_custom_serializers(arq_redis_msgpack: ArqRedis, worker): - j = await arq_redis_msgpack.enqueue_job('foobar', _job_id='job_id') +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_custom_serializers(arq_redis_msgpack: ArqRedis, worker, stream): + j = await arq_redis_msgpack.enqueue_job('foobar', _job_id='job_id', _use_stream=stream) worker: Worker = worker( functions=[foobar], job_serializer=msgpack.packb, job_deserializer=functools.partial(msgpack.unpackb, raw=False) ) @@ -685,16 +760,18 @@ def __setstate__(self, state): @pytest.mark.skipif(sys.version_info < (3, 7), reason='repr(exc) is ugly in 3.6') -async def test_deserialization_error(arq_redis: ArqRedis, worker): - await arq_redis.enqueue_job('foobar', UnpickleFails('hello'), _job_id='job_id') +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_deserialization_error(arq_redis: ArqRedis, worker, stream): + await arq_redis.enqueue_job('foobar', UnpickleFails('hello'), _job_id='job_id', _use_stream=stream) worker: Worker = worker(functions=[foobar]) with pytest.raises(FailedJobs) as exc_info: await worker.run_check() assert str(exc_info.value) == "1 job failed DeserializationError('unable to deserialize job')" -async def test_incompatible_serializers_1(arq_redis_msgpack: ArqRedis, worker): - await arq_redis_msgpack.enqueue_job('foobar', _job_id='job_id') +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def 
test_incompatible_serializers_1(arq_redis_msgpack: ArqRedis, worker, stream): + await arq_redis_msgpack.enqueue_job('foobar', _job_id='job_id', _use_stream=stream) worker: Worker = worker(functions=[foobar]) await worker.main() assert worker.jobs_complete == 0 @@ -702,8 +779,9 @@ async def test_incompatible_serializers_1(arq_redis_msgpack: ArqRedis, worker): assert worker.jobs_retried == 0 -async def test_incompatible_serializers_2(arq_redis: ArqRedis, worker): - await arq_redis.enqueue_job('foobar', _job_id='job_id') +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_incompatible_serializers_2(arq_redis: ArqRedis, worker, stream): + await arq_redis.enqueue_job('foobar', _job_id='job_id', _use_stream=stream) worker: Worker = worker( functions=[foobar], job_serializer=msgpack.packb, job_deserializer=functools.partial(msgpack.unpackb, raw=False) ) @@ -713,7 +791,8 @@ async def test_incompatible_serializers_2(arq_redis: ArqRedis, worker): assert worker.jobs_retried == 0 -async def test_max_jobs_completes(arq_redis: ArqRedis, worker): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_max_jobs_completes(arq_redis: ArqRedis, worker, stream): v = 0 async def raise_second_time(ctx): @@ -722,21 +801,22 @@ async def raise_second_time(ctx): if v > 1: raise ValueError('xxx') - await arq_redis.enqueue_job('raise_second_time') - await arq_redis.enqueue_job('raise_second_time') - await arq_redis.enqueue_job('raise_second_time') + await arq_redis.enqueue_job('raise_second_time', _use_stream=stream) + await arq_redis.enqueue_job('raise_second_time', _use_stream=stream) + await arq_redis.enqueue_job('raise_second_time', _use_stream=stream) worker: Worker = worker(functions=[func(raise_second_time, name='raise_second_time')]) with pytest.raises(FailedJobs) as exc_info: await worker.run_check(max_burst_jobs=3) assert repr(exc_info.value).startswith('<2 jobs failed:') -async def test_max_bursts_sub_call(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_max_bursts_sub_call(arq_redis: ArqRedis, worker, stream, caplog): async def foo(ctx, v): return v + 1 async def bar(ctx, v): - await ctx['redis'].enqueue_job('foo', v + 1) + await ctx['redis'].enqueue_job('foo', v + 1, _use_stream=stream) caplog.set_level(logging.INFO) await arq_redis.enqueue_job('bar', 10) @@ -749,13 +829,14 @@ async def bar(ctx, v): assert 'foo' in caplog.text -async def test_max_bursts_multiple(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_max_bursts_multiple(arq_redis: ArqRedis, worker, stream, caplog): async def foo(ctx, v): return v + 1 caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('foo', 1) - await arq_redis.enqueue_job('foo', 2) + await arq_redis.enqueue_job('foo', 1, _use_stream=stream) + await arq_redis.enqueue_job('foo', 2, _use_stream=stream) worker: Worker = worker(functions=[func(foo, name='foo')]) assert await worker.run_check(max_burst_jobs=1) == 1 assert worker.jobs_complete == 1 @@ -769,12 +850,13 @@ async def foo(ctx, v): assert 'foo(1)' not in caplog.text -async def test_max_bursts_dont_get(arq_redis: ArqRedis, worker): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', 
True)], indirect=['worker']) +async def test_max_bursts_dont_get(arq_redis: ArqRedis, worker, stream): async def foo(ctx, v): return v + 1 - await arq_redis.enqueue_job('foo', 1) - await arq_redis.enqueue_job('foo', 2) + await arq_redis.enqueue_job('foo', 1, _use_stream=stream) + await arq_redis.enqueue_job('foo', 2, _use_stream=stream) worker: Worker = worker(functions=[func(foo, name='foo')]) worker.max_burst_jobs = 0 @@ -783,12 +865,13 @@ async def foo(ctx, v): assert len(worker.tasks) == 0 -async def test_non_burst(arq_redis: ArqRedis, worker, caplog, loop): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_non_burst(arq_redis: ArqRedis, worker, stream, caplog, loop): async def foo(ctx, v): return v + 1 caplog.set_level(logging.INFO) - await arq_redis.enqueue_job('foo', 1, _job_id='testing') + await arq_redis.enqueue_job('foo', 1, _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[func(foo, name='foo')]) worker.burst = False t = loop.create_task(worker.main()) @@ -800,7 +883,8 @@ async def foo(ctx, v): assert '← testing:foo ● 2' in caplog.text -async def test_multi_exec(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_multi_exec(arq_redis: ArqRedis, worker, stream, caplog): c = 0 async def foo(ctx, v): @@ -809,7 +893,7 @@ async def foo(ctx, v): return v + 1 caplog.set_level(logging.DEBUG, logger='arq.worker') - await arq_redis.enqueue_job('foo', 1, _job_id='testing') + await arq_redis.enqueue_job('foo', 1, _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[func(foo, name='foo')]) await asyncio.gather(*[worker.start_jobs([b'testing']) for _ in range(5)]) # debug(caplog.text) @@ -819,7 +903,8 @@ async def foo(ctx, v): # assert 'WatchVariableError' not in caplog.text -async def test_abort_job(arq_redis: ArqRedis, worker, caplog, loop): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_abort_job(arq_redis: ArqRedis, worker, stream, caplog, loop): async def longfunc(ctx): await asyncio.sleep(3600) @@ -829,7 +914,7 @@ async def wait_and_abort(job, delay=0.1): caplog.set_level(logging.INFO) await arq_redis.zadd(abort_jobs_ss, {b'foobar': int(1e9)}) - job = await arq_redis.enqueue_job('longfunc', _job_id='testing') + job = await arq_redis.enqueue_job('longfunc', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[func(longfunc, name='longfunc')], allow_abort_jobs=True, poll_delay=0.1) assert worker.jobs_complete == 0 @@ -852,13 +937,14 @@ async def test_abort_job_which_is_not_in_queue(arq_redis: ArqRedis): assert await job.abort() is False -async def test_abort_job_before(arq_redis: ArqRedis, worker, caplog, loop): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_abort_job_before(arq_redis: ArqRedis, worker, stream, caplog, loop): async def longfunc(ctx): await asyncio.sleep(3600) caplog.set_level(logging.INFO) - job = await arq_redis.enqueue_job('longfunc', _job_id='testing') + job = await arq_redis.enqueue_job('longfunc', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[func(longfunc, name='longfunc')], allow_abort_jobs=True, poll_delay=0.1) assert worker.jobs_complete == 0 @@ -878,14 +964,15 @@ async def longfunc(ctx): assert worker.tasks 
== {} -async def test_abort_deferred_job_before(arq_redis: ArqRedis, worker, caplog, loop): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_abort_deferred_job_before(arq_redis: ArqRedis, worker, stream, caplog, loop): async def longfunc(ctx): await asyncio.sleep(3600) caplog.set_level(logging.INFO) job = await arq_redis.enqueue_job( - 'longfunc', _job_id='testing', _defer_until=datetime.now(timezone.utc) + timedelta(days=1) + 'longfunc', _job_id='testing', _defer_until=datetime.now(timezone.utc) + timedelta(days=1), _use_stream=stream ) worker: Worker = worker(functions=[func(longfunc, name='longfunc')], allow_abort_jobs=True, poll_delay=0.1) @@ -908,7 +995,8 @@ async def longfunc(ctx): assert worker.tasks == {} -async def test_not_abort_job(arq_redis: ArqRedis, worker, caplog, loop): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_not_abort_job(arq_redis: ArqRedis, worker, stream, caplog, loop): async def shortfunc(ctx): await asyncio.sleep(0.2) @@ -917,7 +1005,7 @@ async def wait_and_abort(job, delay=0.1): assert await job.abort() is False caplog.set_level(logging.INFO) - job = await arq_redis.enqueue_job('shortfunc', _job_id='testing') + job = await arq_redis.enqueue_job('shortfunc', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[func(shortfunc, name='shortfunc')], poll_delay=0.1) assert worker.jobs_complete == 0 @@ -935,12 +1023,13 @@ async def wait_and_abort(job, delay=0.1): assert worker.job_tasks == {} -async def test_job_timeout(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_job_timeout(arq_redis: ArqRedis, worker, stream, caplog): async def longfunc(ctx): await asyncio.sleep(0.3) caplog.set_level(logging.ERROR) - await arq_redis.enqueue_job('longfunc', _job_id='testing') + await arq_redis.enqueue_job('longfunc', _job_id='testing', _use_stream=stream) worker: Worker = worker(functions=[func(longfunc, name='longfunc')], job_timeout=0.2, poll_delay=0.1) assert worker.jobs_complete == 0 assert worker.jobs_failed == 0 @@ -953,7 +1042,8 @@ async def longfunc(ctx): assert 'X.XXs ! 
testing:longfunc failed, TimeoutError:' in log -async def test_on_job(arq_redis: ArqRedis, worker): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_on_job(arq_redis: ArqRedis, worker, stream): result = {'called': 0} async def on_start(ctx): @@ -971,7 +1061,7 @@ async def after_end(ctx): async def test(ctx): return - await arq_redis.enqueue_job('func', _job_id='testing') + await arq_redis.enqueue_job('func', _job_id='testing', _use_stream=stream) worker: Worker = worker( functions=[func(test, name='func')], on_job_start=on_start, @@ -991,7 +1081,8 @@ async def test(ctx): assert result['called'] == 4 -async def test_job_cancel_on_max_jobs(arq_redis: ArqRedis, worker, caplog): +@pytest.mark.parametrize('worker, stream', [('poll_worker', False), ('stream_worker', True)], indirect=['worker']) +async def test_job_cancel_on_max_jobs(arq_redis: ArqRedis, worker, stream, caplog): async def longfunc(ctx): await asyncio.sleep(3600) @@ -1001,7 +1092,7 @@ async def wait_and_abort(job, delay=0.1): caplog.set_level(logging.INFO) await arq_redis.zadd(abort_jobs_ss, {b'foobar': int(1e9)}) - job = await arq_redis.enqueue_job('longfunc', _job_id='testing') + job = await arq_redis.enqueue_job('longfunc', _job_id='testing', _use_stream=stream) worker: Worker = worker( functions=[func(longfunc, name='longfunc')], allow_abort_jobs=True, poll_delay=0.1, max_jobs=1 @@ -1021,6 +1112,7 @@ async def wait_and_abort(job, delay=0.1): assert worker.job_tasks == {} +@pytest.mark.parametrize('worker', ['poll_worker', 'stream_worker'], indirect=['worker']) async def test_worker_timezone_defaults_to_system_timezone(worker): worker = worker(functions=[func(foobar)]) assert worker.timezone is not None @@ -1072,6 +1164,7 @@ async def test_worker_retry(mocker, worker_retry, exception_thrown): redis.exceptions.TimeoutError('Timeout reading from host'), ], ) +@pytest.mark.parametrize('worker', ['poll_worker', 'stream_worker'], indirect=['worker']) async def test_worker_crash(mocker, worker, exception_thrown): # Testing redis exceptions, no retry settings specified worker = worker(functions=[func(foobar)]) From 5747d484a63b120338f17d49f0462fc06230209b Mon Sep 17 00:00:00 2001 From: ajac-zero Date: Wed, 12 Jun 2024 21:15:05 -0600 Subject: [PATCH 3/3] Fixed mypy typings in stream implementation --- arq/worker.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/arq/worker.py b/arq/worker.py index 81ebc1c8..2dd4f184 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -389,14 +389,15 @@ async def main(self) -> None: await self.pool.xgroup_create(stream_name, self.worker_group, '$', mkstream=True) logger.info('Stream consumer group created with name: %s', self.worker_group) - async def read_messages(name: Literal['>', '0']): # type: ignore[no-untyped-def] + async def read_messages(name: Literal['>', '0']) -> None: if event := await self.pool.xreadgroup( consumername=self.worker_name, groupname=self.worker_group, streams={stream_name: name}, block=0 ): await self._poll_iteration() + acknowledge = cast(Callable[..., Any], self.pool.xack) for message in event[0][1]: - await self.pool.xack(stream_name, self.worker_group, message[0]) # type: ignore[no-untyped-call] + await acknowledge(stream_name, self.worker_group, message[0]) # Heartbeat before blocking, or health check will fail previous to receiving first message await self.heart_beat()
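Usage sketch (not part of the patch series above): a minimal example of how the new stream option could be exercised once these patches are applied. The module name demo.py, the download job and the WorkerSettings class are placeholders chosen for illustration; create_pool, RedisSettings and the arq CLI are existing arq APIs, while stream=True, --stream and _use_stream=True are the options introduced by this series. Assumes a Redis server reachable with default connection settings.

    # demo.py
    import asyncio

    from arq import create_pool
    from arq.connections import RedisSettings


    async def download(ctx, url):
        # placeholder job function; any registered coroutine behaves the same
        return f'downloaded {url}'


    async def enqueue():
        redis = await create_pool(RedisSettings())
        # _use_stream=True additionally XADDs the serialized job to
        # arq:stream:arq:queue, so a stream-mode worker blocked on XREADGROUP
        # is woken immediately instead of waiting for its next poll.
        await redis.enqueue_job('download', 'https://example.com', _use_stream=True)


    class WorkerSettings:
        functions = [download]
        stream = True  # consume jobs via the Redis stream rather than polling
        # worker_name and worker_group default to 'arq:worker' / 'arq:workers'


    if __name__ == '__main__':
        asyncio.run(enqueue())

    # Run the worker in stream mode (equivalent to stream = True above):
    #   arq demo.WorkerSettings --stream

On startup a stream-mode worker creates the consumer group on arq:stream:arq:queue (mkstream=True), blocks on XREADGROUP, runs a poll iteration when an entry arrives and then XACKs it, so job pickup latency no longer depends on poll_delay.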