Skip to content

Commit

Permalink
Fix relationship between optimizer and LR scheduler.
Browse files Browse the repository at this point in the history
- Always initialize the LR scheduler before the first step.
- Correctly call LR scheduler only once per iteration.

See pytorch/pytorch#20124

Note: there is still a problem if you set `update_interval` to
a number of epochs.
  • Loading branch information
catwell committed Dec 20, 2024
1 parent 2e6e1ed commit 7740ba1
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/refiners/training_utils/clock.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(
self,
training_duration: TimeValue,
gradient_accumulation: Step,
- lr_scheduler_interval: TimeValue,
+ lr_scheduler_interval: Iteration | Epoch,
verbose: bool = True,
) -> None:
self.training_duration = training_duration
Expand Down
12 changes: 8 additions & 4 deletions src/refiners/training_utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ def __init__(self, config: ConfigType) -> None:
self._load_models()
self._call_callbacks(event_name="on_init_end")

+ # Ensure the lr_scheduler is initialized before calling `step` on the optimizer.
+ # See `patch_track_step_called` in LRScheduler constructor.
+ assert self.lr_scheduler

@register_callback()
def clock(self, config: ClockConfig) -> TrainingClock:
return TrainingClock(
Expand Down Expand Up @@ -299,10 +303,10 @@ def backward(self) -> None:
self.optimizer.step()
self.optimizer.zero_grad()
self._call_callbacks(event_name="on_optimizer_step_end")
- if self.clock.is_due(self.config.lr_scheduler.update_interval):
- self._call_callbacks(event_name="on_lr_scheduler_step_begin")
- self.lr_scheduler.step()
- self._call_callbacks(event_name="on_lr_scheduler_step_end")
+ if self.clock.is_due(self.config.lr_scheduler.update_interval):
+ self._call_callbacks(event_name="on_lr_scheduler_step_begin")
+ self.lr_scheduler.step()
+ self._call_callbacks(event_name="on_lr_scheduler_step_end")

def step(self, batch: Batch) -> None:
"""Perform a single training step."""
Expand Down

0 comments on commit 7740ba1

Please sign in to comment.