Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: refactor all asserts into raise <exception>, close #371 #379

Merged
merged 1 commit into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions arq/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ class RedisSettings:
@classmethod
def from_dsn(cls, dsn: str) -> 'RedisSettings':
conf = urlparse(dsn)
assert conf.scheme in {'redis', 'rediss', 'unix'}, 'invalid DSN scheme'
if conf.scheme not in {'redis', 'rediss', 'unix'}:
raise RuntimeError('invalid DSN scheme')
query_db = parse_qs(conf.query).get('db')
if query_db:
# e.g. redis://localhost:6379?db=1
Expand Down Expand Up @@ -138,7 +139,8 @@ async def enqueue_job(
_queue_name = self.default_queue_name
job_id = _job_id or uuid4().hex
job_key = job_key_prefix + job_id
assert not (_defer_until and _defer_by), "use either 'defer_until' or 'defer_by' or neither, not both"
if _defer_until and _defer_by:
raise RuntimeError("use either 'defer_until' or 'defer_by' or neither, not both")

defer_by_ms = to_ms(_defer_by)
expires_ms = to_ms(_expires)
Expand Down Expand Up @@ -190,7 +192,8 @@ async def all_job_results(self) -> List[JobResult]:
async def _get_job_def(self, job_id: bytes, score: int) -> JobDef:
key = job_key_prefix + job_id.decode()
v = await self.get(key)
assert v is not None, f'job "{key}" not found'
if v is None:
raise RuntimeError(f'job "{key}" not found')
jd = deserialize_job(v, deserializer=self.job_deserializer)
jd.score = score
return jd
Expand Down Expand Up @@ -221,9 +224,8 @@ async def create_pool(
"""
settings: RedisSettings = RedisSettings() if settings_ is None else settings_

assert not (
type(settings.host) is str and settings.sentinel
), "str provided for 'host' but 'sentinel' is true; list of sentinels expected"
if isinstance(settings.host, str) and settings.sentinel:
raise RuntimeError("str provided for 'host' but 'sentinel' is true; list of sentinels expected")

if settings.sentinel:

Expand Down
11 changes: 7 additions & 4 deletions arq/cron.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@ def _get_next_dt(dt_: datetime, options: Options) -> Optional[datetime]: # noqa
next_v = getattr(dt_, field)
if isinstance(v, int):
mismatch = next_v != v
else:
assert isinstance(v, (set, list, tuple)), v
elif isinstance(v, (set, list, tuple)):
mismatch = next_v not in v
else:
raise RuntimeError(v)
# print(field, v, next_v, mismatch)
if mismatch:
micro = max(dt_.microsecond - options.microsecond, 0)
Expand All @@ -82,7 +83,8 @@ def _get_next_dt(dt_: datetime, options: Options) -> Optional[datetime]: # noqa
elif field == 'second':
return dt_ + timedelta(seconds=1) - timedelta(microseconds=micro)
else:
assert field == 'microsecond', field
if field != 'microsecond':
raise RuntimeError(field)
return dt_ + timedelta(microseconds=options.microsecond - dt_.microsecond)
return None

Expand Down Expand Up @@ -173,7 +175,8 @@ def cron(
else:
coroutine_ = coroutine

assert asyncio.iscoroutinefunction(coroutine_), f'{coroutine_} is not a coroutine function'
if not asyncio.iscoroutinefunction(coroutine_):
raise RuntimeError(f'{coroutine_} is not a coroutine function')
timeout = to_seconds(timeout)
keep_result = to_seconds(keep_result)

Expand Down
9 changes: 6 additions & 3 deletions arq/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def func(
else:
coroutine_ = coroutine

assert asyncio.iscoroutinefunction(coroutine_), f'{coroutine_} is not a coroutine function'
if not asyncio.iscoroutinefunction(coroutine_):
raise RuntimeError(f'{coroutine_} is not a coroutine function')
timeout = to_seconds(timeout)
keep_result = to_seconds(keep_result)

Expand Down Expand Up @@ -226,10 +227,12 @@ def __init__(
self.queue_name = queue_name
self.cron_jobs: List[CronJob] = []
if cron_jobs is not None:
assert all(isinstance(cj, CronJob) for cj in cron_jobs), 'cron_jobs, must be instances of CronJob'
if not all(isinstance(cj, CronJob) for cj in cron_jobs):
raise RuntimeError('cron_jobs, must be instances of CronJob')
self.cron_jobs = list(cron_jobs)
self.functions.update({cj.name: cj for cj in self.cron_jobs})
assert len(self.functions) > 0, 'at least one function or cron_job must be registered'
if len(self.functions) == 0:
raise RuntimeError('at least one function or cron_job must be registered')
self.burst = burst
self.on_startup = on_startup
self.on_shutdown = on_shutdown
Expand Down