Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend RedisSettings to include redis Retry Helper settings #387

Merged
merged 14 commits into from
Apr 1, 2024
8 changes: 8 additions & 0 deletions arq/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from uuid import uuid4

from redis.asyncio import ConnectionPool, Redis
from redis.asyncio.retry import Retry
from redis.asyncio.sentinel import Sentinel
from redis.exceptions import RedisError, WatchError

Expand Down Expand Up @@ -47,6 +48,10 @@ class RedisSettings:
sentinel: bool = False
sentinel_master: str = 'mymaster'

retry_on_timeout: bool = False
retry_on_error: Optional[List[Exception]] = None
retry: Optional[Retry] = None

@classmethod
def from_dsn(cls, dsn: str) -> 'RedisSettings':
conf = urlparse(dsn)
Expand Down Expand Up @@ -253,6 +258,9 @@ def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis:
ssl_ca_certs=settings.ssl_ca_certs,
ssl_ca_data=settings.ssl_ca_data,
ssl_check_hostname=settings.ssl_check_hostname,
retry=settings.retry,
retry_on_timeout=settings.retry_on_timeout,
retry_on_error=settings.retry_on_error,
)

while True:
Expand Down
40 changes: 40 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

import msgpack
import pytest
import redis.exceptions
from redis.asyncio.retry import Retry
from redis.backoff import NoBackoff
from redislite import Redis

from arq.connections import ArqRedis, create_pool
Expand Down Expand Up @@ -52,6 +55,21 @@ async def arq_redis_msgpack(loop):
await redis_.close(close_connection_pool=True)


@pytest.fixture
async def arq_redis_retry(loop):
redis_ = ArqRedis(
host='localhost',
port=6379,
encoding='utf-8',
retry=Retry(backoff=NoBackoff(), retries=3),
retry_on_timeout=True,
retry_on_error=[redis.exceptions.ConnectionError],
)
await redis_.flushall()
yield redis_
await redis_.close(close_connection_pool=True)


@pytest.fixture
async def worker(arq_redis):
worker_: Worker = None
Expand All @@ -69,6 +87,28 @@ def create(functions=[], burst=True, poll_delay=0, max_jobs=10, arq_redis=arq_re
await worker_.close()


@pytest.fixture
async def worker_retry(arq_redis_retry):
worker_retry_: Worker = None

def create(functions=[], burst=True, poll_delay=0, max_jobs=10, arq_redis=arq_redis_retry, **kwargs):
nonlocal worker_retry_
worker_retry_ = Worker(
functions=functions,
redis_pool=arq_redis,
burst=burst,
poll_delay=poll_delay,
max_jobs=max_jobs,
**kwargs,
)
return worker_retry_

yield create

if worker_retry_:
await worker_retry_.close()


@pytest.fixture(name='create_pool')
async def fix_create_pool(loop):
pools = []
Expand Down
5 changes: 4 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_settings_changed():
"RedisSettings(host='localhost', port=123, unix_socket_path=None, database=0, username=None, password=None, "
"ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs='required', ssl_ca_certs=None, "
'ssl_ca_data=None, ssl_check_hostname=False, conn_timeout=1, conn_retries=5, conn_retry_delay=1, '
"sentinel=False, sentinel_master='mymaster')"
"sentinel=False, sentinel_master='mymaster', retry_on_timeout=False, retry_on_error=None, retry=None)"
) == str(settings)


Expand Down Expand Up @@ -112,6 +112,9 @@ def test_redis_settings_validation():
class Settings(BaseModel):
redis_settings: RedisSettings

class Config:
arbitrary_types_allowed = True

@validator('redis_settings', always=True, pre=True)
def parse_redis_settings(cls, v):
if isinstance(v, str):
Expand Down
79 changes: 78 additions & 1 deletion tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import signal
import sys
from datetime import datetime, timedelta
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import msgpack
import pytest
import redis.exceptions

from arq.connections import ArqRedis, RedisSettings
from arq.constants import abort_jobs_ss, default_queue_name, expires_extra_ms, health_check_key_suffix, job_key_prefix
Expand Down Expand Up @@ -1018,3 +1019,79 @@ async def test_worker_timezone_defaults_to_system_timezone(worker):
worker = worker(functions=[func(foobar)])
assert worker.timezone is not None
assert worker.timezone == datetime.now().astimezone().tzinfo


@pytest.mark.parametrize(
'exception_thrown',
[
redis.exceptions.ConnectionError('Error while reading from host'),
redis.exceptions.TimeoutError('Timeout reading from host'),
],
)
async def test_worker_retry(mocker, worker_retry, exception_thrown):
# Testing redis exceptions, with retry settings specified
worker = worker_retry(functions=[func(foobar)])

# patch db read_response to mimic connection exceptions
p = patch.object(worker.pool.connection_pool.connection_class, 'read_response', side_effect=exception_thrown)

# baseline
await worker.main()
await worker._poll_iteration()

# spy method handling call_with_retry failure
spy = mocker.spy(worker.pool, '_disconnect_raise')

try:
# start patch
p.start()

# assert exception thrown
with pytest.raises(type(exception_thrown)):
await worker._poll_iteration()

# assert retry counts and no exception thrown during '_disconnect_raise'
assert spy.call_count == 4 # retries setting + 1
assert spy.spy_exception is None

finally:
# stop patch to allow worker cleanup
p.stop()


@pytest.mark.parametrize(
'exception_thrown',
[
redis.exceptions.ConnectionError('Error while reading from host'),
redis.exceptions.TimeoutError('Timeout reading from host'),
],
)
async def test_worker_crash(mocker, worker, exception_thrown):
# Testing redis exceptions, no retry settings specified
worker = worker(functions=[func(foobar)])

# patch db read_response to mimic connection exceptions
p = patch.object(worker.pool.connection_pool.connection_class, 'read_response', side_effect=exception_thrown)

# baseline
await worker.main()
await worker._poll_iteration()

# spy method handling call_with_retry failure
spy = mocker.spy(worker.pool, '_disconnect_raise')

try:
# start patch
p.start()

# assert exception thrown
with pytest.raises(type(exception_thrown)):
await worker._poll_iteration()

# assert no retry counts and exception thrown during '_disconnect_raise'
assert spy.call_count == 1
assert spy.spy_exception == exception_thrown

finally:
# stop patch to allow worker cleanup
p.stop()
Loading