Skip to content

Commit

Permalink
Remove longtime deprecated functions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707025703
  • Loading branch information
mtthss authored and t5-copybara committed Dec 17, 2024
1 parent fb95318 commit 5f03619
Showing 1 changed file with 19 additions and 22 deletions.
41 changes: 19 additions & 22 deletions t5x/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,13 @@ def restore_state(self, state):
# Optax Elementwise Wrapper


def _scale_by_schedule_ctor(state, params_axes):
del state, params_axes
return optax.ScaleByScheduleState( # pytype: disable=wrong-arg-types # numpy-scalars
count=None
)


class OptaxStatePartitionRules:
"""Collection of rules to partition optax states.
Expand Down Expand Up @@ -218,16 +225,10 @@ class OptaxStatePartitionRules:
mu=OptaxStatePartitionRules.derive_params_axes(state.mu, params_axes),
nu=OptaxStatePartitionRules.derive_params_axes(state.nu, params_axes),
),
optax.ScaleByBeliefState: (
lambda state, params_axes: optax.ScaleByBeliefState( # pytype: disable=wrong-arg-types # numpy-scalars
count=None,
mu=OptaxStatePartitionRules.derive_params_axes(
state.mu, params_axes
),
nu=OptaxStatePartitionRules.derive_params_axes(
state.nu, params_axes
),
)
optax.ScaleByBeliefState: lambda state, params_axes: optax.ScaleByBeliefState( # pytype: disable=wrong-arg-types # numpy-scalars
count=None,
mu=OptaxStatePartitionRules.derive_params_axes(state.mu, params_axes),
nu=OptaxStatePartitionRules.derive_params_axes(state.nu, params_axes),
),
optax.ScaleByLionState: lambda state, params_axes: optax.ScaleByLionState( # pytype: disable=wrong-arg-types # numpy-scalars
count=None,
Expand Down Expand Up @@ -258,9 +259,7 @@ class OptaxStatePartitionRules:
optax.ScaleByTrustRatioState: (
lambda state, params_axes: optax.ScaleByTrustRatioState()
),
optax.ScaleByScheduleState: (
lambda state, params_axes: optax.ScaleByScheduleState(count=None) # pytype: disable=wrong-arg-types # numpy-scalars
),
optax.ScaleByScheduleState: _scale_by_schedule_ctor,
optax.ZeroNansState: lambda state, params_axes: optax.ZeroNansState(
found_nan=None
),
Expand All @@ -272,14 +271,12 @@ class OptaxStatePartitionRules:
state.inner_state, params_axes
)
),
optax.InjectHyperparamsState: (
lambda state, params_axes: optax.InjectHyperparamsState( # pytype: disable=wrong-arg-types # jax-ndarray
count=None,
hyperparams=jax.tree.map(lambda x: None, state.hyperparams),
inner_state=OptaxStatePartitionRules.derive_optax_logical_axes(
state.inner_state, params_axes
),
)
optax.InjectHyperparamsState: lambda state, params_axes: optax.InjectHyperparamsState( # pytype: disable=wrong-arg-types # jax-ndarray
count=None,
hyperparams=jax.tree.map(lambda x: None, state.hyperparams),
inner_state=OptaxStatePartitionRules.derive_optax_logical_axes(
state.inner_state, params_axes
),
),
optax.MultiStepsState: lambda state, params_axes: optax.MultiStepsState( # pytype: disable=wrong-arg-types # jax-ndarray
mini_step=None,
Expand All @@ -299,7 +296,7 @@ class OptaxStatePartitionRules:
),
)
),
optax.MaybeUpdateState: lambda state, params_axes: optax.MaybeUpdateState( # pytype: disable=wrong-arg-types # jax-ndarray
optax.ConditionallyTransformState: lambda state, params_axes: optax.ConditionallyTransformState( # pytype: disable=wrong-arg-types # jax-ndarray
inner_state=OptaxStatePartitionRules.derive_optax_logical_axes(
state.inner_state, params_axes
),
Expand Down

0 comments on commit 5f03619

Please sign in to comment.