Skip to content

Commit

Permalink
Fix pyright failures from numpy 2.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Dec 19, 2024
1 parent 613dd0e commit 3553ae1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion diffrax/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def static_select(pred: BoolScalarLike, a: ArrayLike, b: ArrayLike) -> ArrayLike
# This in turn allows us to perform some trace-time optimisations that XLA isn't
# smart enough to do on its own.
if isinstance(pred, (np.ndarray, np.generic)) and pred.shape == ():
pred = pred.item()
pred = cast(BoolScalarLike, pred.item())
if pred is True:
return a
elif pred is False:
Expand Down
5 changes: 3 additions & 2 deletions diffrax/_progress_meter.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,9 @@ def _step_bar(bar: list[float], progress: FloatScalarLike) -> None:
if eqx.is_array(progress):
# May not be an array when called with `JAX_DISABLE_JIT=1`
progress = cast(Union[Array, np.ndarray], progress)
progress = progress.item()
progress = cast(float, progress)
progress = cast(float, progress.item())
else:
progress = cast(float, progress)
bar[0] = progress
print(f"{100 * progress:.2f}%")

Expand Down

0 comments on commit 3553ae1

Please sign in to comment.