Skip to content

Commit

Permalink
- Using version guard to fix change that broke backwards compatibilit…
Browse files Browse the repository at this point in the history
…y with some older versions of JAX.

- Adding test of in_pmap.

PiperOrigin-RevId: 695740774
  • Loading branch information
james-martens authored and KfacJaxDev committed Nov 12, 2024
1 parent 8a610fc commit d3cf2cd
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
19 changes: 17 additions & 2 deletions kfac_jax/_src/utils/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,35 @@

from kfac_jax._src.utils import types

jax_version = (
jax.__version_info__ if hasattr(jax, "__version_info__")
else tuple(map(int, jax.__version__.split("."))))


Array = types.Array
Numeric = types.Numeric
PRNGKey = types.PRNGKey
TArrayTree = types.TArrayTree


# TODO(jamesmartens,botev): add a test for this function?
def in_pmap(axis_name: str | None) -> bool:
"""Returns whether we are in a pmap with the given axis name."""

if axis_name is None:
return False

return axis_name in core.unsafe_get_axis_names_DO_NOT_USE()
if jax_version >= (0, 4, 36):
return axis_name in core.unsafe_get_axis_names_DO_NOT_USE()

try:
# The only way to know if we are under `jax.pmap` is to check if the
# function call below raises a `NameError` or not.
core.axis_frame(axis_name)

return True

except NameError:
return False


def wrap_if_pmap(
Expand Down
27 changes: 25 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class TestStableSqrt(parameterized.TestCase):

def test_stable_sqrt(self):
"""Tests calculation of the stable square root."""

x = jnp.asarray([1.0, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 0.0])
expected_y = jnp.sqrt(x)
expected_dx = jnp.minimum(1 / (2 * expected_y), 1000.0)
Expand Down Expand Up @@ -87,14 +88,15 @@ class TestRearrage(parameterized.TestCase):
output_shape=[32, 3600],
),
)
def test_stable_sqrt(
def test_rearrange(
self,
shape: list[int],
spec: str,
transpose_order: list[int],
output_shape: list[int],
):
"""Tests calculation of the stable square root."""
"""Tests rearrange function."""

rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, shape)
y_target = jnp.transpose(x, transpose_order).reshape(output_shape)
Expand All @@ -103,5 +105,26 @@ def test_stable_sqrt(
np.testing.assert_array_equal(y, y_target)


class InPmapTest(absltest.TestCase):
"""Test class for the in_pmap function."""

def test_in_pmap_outside_pmap(self):
self.assertFalse(kfac_jax.utils.in_pmap("my_axis"))

def test_in_pmap_inside_pmap(self):
def f(x):
self.assertTrue(kfac_jax.utils.in_pmap("my_axis"))
return x + 1

jax.pmap(f, axis_name="my_axis")(jnp.ones([jax.local_device_count()]))

def test_in_pmap_inside_pmap_wrong_axis(self):
def f(x):
self.assertFalse(kfac_jax.utils.in_pmap("my_axis"))
return x + 1

jax.pmap(f, axis_name="their_axis")(jnp.ones([jax.local_device_count()]))


if __name__ == "__main__":
absltest.main()

0 comments on commit d3cf2cd

Please sign in to comment.