Skip to content

Commit

Permalink
Use select rather than _cond.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 323757367
Change-Id: Id512245c608a6c2709d7748f99308d6502d7a2d9
  • Loading branch information
tomhennigan authored and copybara-github committed Jul 29, 2020
1 parent 95a719d commit 8079063
Showing 1 changed file with 1 addition and 6 deletions.
7 changes: 1 addition & 6 deletions haiku/_src/moving_averages.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,6 @@ def __init__(self, decay, zero_debias=True, warmup_length=0, name=None):
"average to an initial value. Set zero_debias=False if setting "
"warmup_length to a non-zero value.")

def _cond(self, cond, t, f, dtype):
"""Internal, implements jax.lax.cond without control flow."""
c = cond.astype(dtype)
return c * t + (1. - c) * f

def initialize(self, value):
"""If uninitialized sets the average to ``zeros_like`` the given value."""
base.get_state("hidden", value.shape, value.dtype, init=jnp.zeros)
Expand Down Expand Up @@ -91,7 +86,7 @@ def __call__(self, value, update_stats=True):

decay = jax.lax.convert_element_type(self._decay, value.dtype)
if self._warmup_length > 0:
decay = self._cond(counter <= 0, 0.0, decay, value.dtype)
decay = jax.lax.select(counter <= 0, 0.0, decay)

one = jnp.ones([], value.dtype)
hidden = base.get_state("hidden", value.shape, value.dtype, init=jnp.zeros)
Expand Down

0 comments on commit 8079063

Please sign in to comment.