Skip to content

Commit

Permalink
[hk] Improve error message from using with_rng outside of hk.transform
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 314326619
Change-Id: I191612602d9c54f3feaec52a69b353a83164e432
  • Loading branch information
trevorcai authored and copybara-github committed Jun 2, 2020
1 parent a7d498d commit 31a2fc1
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
1 change: 1 addition & 0 deletions haiku/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,4 +723,5 @@ def with_rng(key: PRNGKey):
Returns:
Context manager under which the given sequence is active.
"""
assert_context("with_rng")
return current_frame().rng_stack(PRNGSequence(key))
6 changes: 6 additions & 0 deletions haiku/_src/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,12 @@ def test_with_rng(self, seed):
self.assertNotEqual(without_decorator_out, expected_output)
self.assertEqual(with_decorator_out, expected_output)

def test_with_rng_no_transform(self):
with self.assertRaisesRegex(ValueError,
"must be used as part of an `hk.transform`"):
with base.with_rng(jax.random.PRNGKey(428)):
pass

def test_new_context(self):
with base.new_context() as ctx:
pass
Expand Down

0 comments on commit 31a2fc1

Please sign in to comment.