Skip to content

Commit

Permalink
Using version guard to fix change that broke backwards compatibility …
Browse files Browse the repository at this point in the history
…with some older versions of JAX.

PiperOrigin-RevId: 694248080
  • Loading branch information
james-martens authored and KfacJaxDev committed Nov 7, 2024
1 parent c284c12 commit ef97d57
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion kfac_jax/_src/utils/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@

from kfac_jax._src.utils import types

jax_version = (
jax.__version_info__ if hasattr(jax, "__version_info__")
else tuple(map(int, jax.__version__.split("."))))


Array = types.Array
Numeric = types.Numeric
PRNGKey = types.PRNGKey
Expand All @@ -35,7 +40,19 @@ def in_pmap(axis_name: str | None) -> bool:
if axis_name is None:
return False

return axis_name in core.unsafe_get_axis_names_DO_NOT_USE()
if jax_version >= (0, 4, 36):
return axis_name in core.unsafe_get_axis_names_DO_NOT_USE()

else:
try:
# The only way to know if we are under `jax.pmap` is to check if the
# function call below raises a `NameError` or not.
core.axis_frame(axis_name)

return True

except NameError:
return False


def wrap_if_pmap(
Expand Down

0 comments on commit ef97d57

Please sign in to comment.