
Commit

Merge pull request #13 from ryanhoangt/wandb-checkpoint-upload
Save model checkpoints to Wandb
lennart-finke authored Nov 2, 2024
2 parents a932f89 + ee532f6 commit 53db0b3
Showing 2 changed files with 23 additions and 7 deletions.
simple_stories_train/train_llama.py: 9 changes (5 additions, 4 deletions)

@@ -47,6 +47,7 @@
 import time
 from contextlib import nullcontext
 from dataclasses import dataclass
+from datetime import datetime
 from pathlib import Path
 
 import numpy as np
@@ -62,7 +63,6 @@
 from torch.distributed.optim import ZeroRedundancyOptimizer
 from torch.nn import functional as F
 from torch.nn.parallel import DistributedDataParallel as DDP
-
 from utils import (
     init_wandb,
     is_checkpoint_step,
@@ -860,7 +860,8 @@ def get_lr(it: int) -> float:
 checkpoints_dir = None
 output_dir = None
 if args.output_dir:
-    output_dir = Path(args.output_dir)
+    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+    output_dir = Path(args.output_dir) / f"{timestamp}"
     output_dir.mkdir(parents=True, exist_ok=True)
     logfile = output_dir / "main.log"
     # create the log file "main.log" inside it, and wipe it clean
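The timestamped subdirectory gives each launch its own run directory under --output_dir, so repeated runs no longer overwrite each other's logs and checkpoints. A minimal sketch of the resulting layout (the base directory name "out" is illustrative, not from this repo):

from datetime import datetime
from pathlib import Path

# Each launch creates something like out/2024-11-02_14-30-05/checkpoints/
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_dir = Path("out") / f"{timestamp}"
checkpoints_dir = output_dir / "checkpoints"
checkpoints_dir.mkdir(parents=True, exist_ok=True)
print(checkpoints_dir)  # e.g. out/2024-11-02_14-30-05/checkpoints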
@@ -870,7 +871,7 @@ def get_lr(it: int) -> float:
     # set our checkpoints directory and save off the initilized model
     checkpoints_dir = output_dir / "checkpoints"
     checkpoints_dir.mkdir(parents=True, exist_ok=True)
-    save_model_and_config(checkpoints_dir, raw_model, step=0)
+    save_model_and_config(checkpoints_dir, raw_model, args.__dict__, step=0)
 
 if device == "cuda":
     torch.cuda.reset_peak_memory_stats()
@@ -1001,7 +1002,7 @@ def get_lr(it: int) -> float:
         f.write("s:%d trl:%f\n" % (step, lossf))
 
     if checkpoints_dir is not None and is_checkpoint_step(step):
-        save_model_and_config(checkpoints_dir, raw_model, step=step)
+        save_model_and_config(checkpoints_dir, raw_model, args.__dict__, step=step)
 
     # keep track of smooth timings, last 20 iterations
     if step > 1 and step > args.num_iterations - 20:
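The checkpoint schedule here comes from is_checkpoint_step, whose implementation appears in utils.py below: it selects every power of two below 1000 and every multiple of 1000 thereafter. A quick way to see the schedule (the range bound is illustrative):

# Reproduces the schedule implemented by is_checkpoint_step in utils.py:
# (step & (step - 1)) == 0 is the standard power-of-two bit trick.
def is_checkpoint_step(step: int) -> bool:
    return (0 < step < 1000 and (step & (step - 1)) == 0) or step % 1000 == 0

print([s for s in range(1, 3001) if is_checkpoint_step(s)])
# [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1000, 2000, 3000]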
simple_stories_train/utils.py: 21 changes (18 additions, 3 deletions)

@@ -4,6 +4,7 @@

 import torch
 import wandb
+import yaml
 from torch import nn


@@ -21,18 +22,32 @@ def is_checkpoint_step(step: int) -> bool:
     return (0 < step < 1000 and (step & (step - 1)) == 0) or step % 1000 == 0
 
 
-def save_model_and_config(save_dir: Path, model: nn.Module, step: int) -> None:
-    """Save the model to disk. Also save the config file if it doesn't exist.
+def save_model_and_config(
+    save_dir: Path,
+    model: nn.Module,
+    config_dict: dict[str, Any],
+    step: int,
+    config_filename: str = "final_config.yaml",
+) -> None:
+    """Save the model to disk and wandb. Also save the config file if it doesn't exist.
 
     Args:
         save_dir: The directory to save the model and config to.
         model: The model to save.
         step: The current step (used in the model filename).
     """
     save_dir.mkdir(parents=True, exist_ok=True)
-    model_file = save_dir / f"model_step_{step}.pt"
+    config_file = save_dir / config_filename
+    if not config_file.exists():
+        with open(config_file, "w") as f:
+            yaml.dump(config_dict, f)
+    model_file_name = f"model_step_{step}.pt"
+    model_file = save_dir / model_file_name
     torch.save(model.state_dict(), model_file)
     print0(f"Saved model to {model_file}")
+    if config_dict.get("wandb_project"):
+        wandb.save(str(model_file), policy="now", base_path=save_dir)
+        print0(f"Saved model to wandb: {str(model_file_name)}")
 
 
 def init_wandb(config: Any, project: str) -> None:
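With the new signature, the call site passes the parsed CLI arguments as a plain dict (args.__dict__ in train_llama.py above), and the function handles both the on-disk copy and the wandb upload. A minimal usage sketch, assuming an active wandb run (set up via init_wandb) and illustrative model/config values:

from pathlib import Path

from torch import nn

from utils import save_model_and_config

model = nn.Linear(8, 8)  # stand-in for the trained model
config = {"wandb_project": "simple-stories", "lr": 3e-4}  # illustrative values

# Writes checkpoints/final_config.yaml (once) and checkpoints/model_step_100.pt,
# then uploads the .pt file to the current wandb run because "wandb_project"
# is set; policy="now" uploads immediately instead of waiting for run end.
save_model_and_config(Path("checkpoints"), model, config, step=100)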
