Skip to content

Commit

Permalink
Adding step rejection feature
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 695090781
  • Loading branch information
james-martens authored and KfacJaxDev committed Nov 12, 2024
1 parent d3cf2cd commit 5217df8
Showing 1 changed file with 48 additions and 5 deletions.
53 changes: 48 additions & 5 deletions kfac_jax/_src/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def __init__(
damping_upper_threshold: Numeric = 0.75,
always_use_exact_qmodel_for_damping_adjustment: bool = False,
precon_damping_mult: Numeric = 1.0,
use_step_rejection: bool = False,
reject_damping_increase_factor: float = 1.0,
norm_constraint: Numeric | None = None,
num_burnin_steps: int = 10,
estimation_mode: str | None = None,
Expand Down Expand Up @@ -163,7 +165,7 @@ def __init__(
scale of the objective, so that if you multiply your loss by some factor you
should do the same for the damping. Roughly speaking, larger damping values
constrain the update vector to a smaller region around zero, which is needed
in general since the second-order approximations that underly second-order
in general since the second-order approximations that underlie second-order
methods can break down for large updates. (In gradient descent the learning
rate plays an analogous role.) The relationship between the damping
parameter and the radius of this region is complicated and depends on the
Expand Down Expand Up @@ -283,6 +285,12 @@ def __init__(
precon_damping_mult: Scalar. Multiplies the damping used in the
preconditioner (vs the exact quadratic model) by this value.
(Default: 1.0)
use_step_rejection: Whether or not to reject the step whenever the loss
on the current batch goes up after the update. This option offers
robustness at the cost of doing more work per step (unless adaptive
damping with Levenberg-Marquardt is used). (Default: ``False``)
reject_damping_increase_factor: The damping parameter is increased by this
factor if the step is rejected. (Default: ``1.0``)
norm_constraint: Scalar. If specified, the update is scaled down so that
its approximate squared Fisher norm ``v^T F v`` is at most the specified
value. (Note that here ``F`` is the approximate curvature matrix, not
Expand Down Expand Up @@ -448,6 +456,9 @@ def __init__(
always_use_exact_qmodel_for_damping_adjustment)
self._precon_damping_mult = precon_damping_mult

self._use_step_rejection = use_step_rejection
self._reject_damping_increase_factor = reject_damping_increase_factor

self._norm_constraint = norm_constraint
self._num_burnin_steps = num_burnin_steps
self._curvature_ema = curvature_ema
Expand Down Expand Up @@ -1163,12 +1174,11 @@ def _step(
damping=damping,
func_args=func_args)

# Compute delta and update velocities
# Compute the parameter update (delta)
delta = self.weighted_sum_of_objects(vectors, coefficients)
state.velocities = delta

# Update parameters
params = jax.tree_util.tree_map(jnp.add, params, delta)
new_params = jax.tree_util.tree_map(jnp.add, params, delta)

# Optionally compute the reduction ratio and update the damping
if self._use_adaptive_damping:
Expand All @@ -1179,12 +1189,42 @@ def _step(
lambda args: (args[0], self._invalid_metric_value,
self._invalid_metric_value),
operand=(state.damping, loss, quad_model_change,
(params,) + func_args[1:])
(new_params,) + func_args[1:])
)

new_loss_is_valid = self.should_update_damping(state)

else:
# If not adjusting the damping we don't compute these here and just set
# them to self._invalid_metric_value.
new_loss, rho = self._invalid_metric_value, self._invalid_metric_value
new_loss_is_valid = False

if self._use_step_rejection:

# Don't recompute it if we already have it from before.
new_loss = lax.cond(
jnp.logical_not(new_loss_is_valid), # static eval when possible?
lambda: self.compute_loss_value((new_params,) + func_args[1:],
state=state),
lambda: new_loss,
)

# Sync (possibly redundant)
new_loss = utils.pmean_if_pmap(new_loss, self.pmap_axis_name)

reject_step = jnp.logical_or(jnp.isnan(new_loss), new_loss > loss)

params, state.velocities, state.damping = lax.cond(
reject_step,
lambda: (params, state.velocities, state.damping),
lambda: (new_params, delta,
self._reject_damping_increase_factor * state.damping))

else:
# stop the linter from complaining about uninitialized variable
reject_step = False
params, state.velocities = new_params, delta

# Compute per-device and total batch size
batch_size = self._batch_size_extractor(func_args[-1])
Expand Down Expand Up @@ -1217,6 +1257,9 @@ def _step(
scaled_grad_norm_sq=scaled_grad_norm_sq,
)

if self._use_step_rejection:
stats["step_rejected"] = reject_step

if aux is not None:
aux = utils.pmean_if_pmap(aux, self.pmap_axis_name)
stats["aux"] = aux
Expand Down

0 comments on commit 5217df8

Please sign in to comment.