Merge pull request #14 from danbraunai/dataloader
Fixed issues with DDP
lennart-finke authored Nov 2, 2024
2 parents 53db0b3 + 28fa529 commit 026ae35
Showing 3 changed files with 21 additions and 13 deletions.
Binary file added .DS_Store
Binary file not shown.
3 changes: 3 additions & 0 deletions README.md
@@ -31,5 +31,8 @@ make test-all # Run all tests
Training a simple model:
`python simple_stories_train/train_llama.py --model d2 --sequence_length 1024 --total_batch_size=4096`

For a final model, we currently (intend to) run:
`torchrun --standalone --nproc_per_node=8 simple_stories_train/train_llama.py --model d24 --sequence_length 1024 --total_batch_size=16448 --compile 1 --tensorcores=1 --dtype=bfloat16 --wandb 1`

You may be asked to enter your wandb API key. You can find it in your [wandb account settings](https://wandb.ai/settings). Alternatively, to avoid entering your API key on program execution, you can set the environment variable `WANDB_API_KEY` to your API key, or put it in a
`.env` file under the root of the repository.
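For reference, a minimal sketch of reading the key from a `.env` file, assuming the `python-dotenv` package is installed (the repository may load the file differently):

```python
import os

from dotenv import load_dotenv  # assumes the python-dotenv package is installed

# read key=value pairs from a .env file in the current directory into os.environ
load_dotenv()

if os.environ.get("WANDB_API_KEY") is None:
    print("WANDB_API_KEY not set; wandb will prompt for it at run time")
```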
31 changes: 18 additions & 13 deletions simple_stories_train/train_llama.py
@@ -72,7 +72,6 @@
save_model_and_config,
)


# using a global to toggle flash-attention
FLASH = 0

@@ -388,6 +387,7 @@ def forward(
if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.lm_head(x)
targets = targets.long()
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
)
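For context on the `targets.long()` cast added above: `F.cross_entropy` expects class-index targets as 64-bit integers, so targets stored in a smaller integer dtype must be cast before computing the loss. A minimal, repo-independent sketch:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 8, 100)                             # (batch, seq_len, vocab_size)
targets = torch.randint(0, 100, (4, 8), dtype=torch.int32)  # e.g. token ids loaded as int32

# cross_entropy requires int64 (long) class indices; cast before flattening for the loss
loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)), targets.long().view(-1), ignore_index=-1
)
print(loss.item())
```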
@@ -642,6 +642,7 @@ def generate(
# python -> C bridge
parser.add_argument("--write_tensors", type=int, default=1, help="write tensors to disk")
# wandb settings
parser.add_argument("--wandb", type=int, default=0, help="use wandb?")
parser.add_argument("--wandb_project", type=str, default="", help="wandb project name")
args = parser.parse_args()

@@ -824,7 +825,8 @@ def generate(

# -------------------------------------------------------------------------
# main training loop
init_wandb(args, args.wandb_project)
if args.wandb:
init_wandb(args, args.wandb_project)

# here we wrap model into DDP container
if ddp:
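The wrapping itself is collapsed in this view; for context only, a generic sketch of the standard DDP wrap (not part of this commit; script name and launch command are assumptions, run under `torchrun`):

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# run with: torchrun --standalone --nproc_per_node=2 ddp_sketch.py
dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")

model = torch.nn.Linear(16, 16).to(device)
# wrapping in DDP makes backward() all-reduce (average) gradients across ranks
model = DDP(model, device_ids=[local_rank] if torch.cuda.is_available() else None)

out = model(torch.randn(4, 16, device=device))
out.sum().backward()
dist.destroy_process_group()
```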
@@ -897,7 +899,8 @@ def get_lr(it: int) -> float:
val_loss += loss.item()
val_loss /= args.val_max_steps
# log to wandb
log_metrics(step, {"val_loss": val_loss})
if args.wandb:
log_metrics(step, {"val_loss": val_loss})
# log to console and to file
print0(f"val loss {val_loss}")
if master_process and logfile is not None:
@@ -921,7 +924,8 @@ def get_lr(it: int) -> float:
print0(enc.decode(yg[0].tolist()))
print0("---------------")
# log to wandb
generations.append([step, enc.decode(yg[0].tolist())])
if args.wandb:
generations.append([step, enc.decode(yg[0].tolist())])
log_generations(step, generations)

# bit confusing: we want to make sure to eval and sample on 0th iteration
@@ -938,7 +942,7 @@ def get_lr(it: int) -> float:
# micro-batch loop where we do gradient accumulation to reach desired total batch size
lossf = Tensor(
[0.0]
) # for getting the mean loss (as simple float) over the accumulation steps
).to(device) # for getting the mean loss (as simple float) over the accumulation steps
for micro_step in range(grad_accum_steps):
# fetch a batch
bat = next(train_loader)["input_ids"].to(torch.int)
@@ -958,7 +962,7 @@ def get_lr(it: int) -> float:
# addition of gradients corresponds to a SUM in the objective, but
# instead of a SUM we want MEAN, so we scale the loss here
loss = loss / grad_accum_steps
lossf += loss.item() # keep track of the mean loss
lossf += loss.detach() # keep track of the mean loss

# backward pass
if not args.inference_only:
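The two changes above keep the running loss on the device and accumulate `loss.detach()` instead of `loss.item()`, so there is no host synchronization on every micro-step and the running value stays a tensor. A minimal, repo-independent sketch of the accumulation pattern:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
grad_accum_steps = 4

# accumulator lives on the same device as the per-micro-step loss
lossf = torch.zeros(1, device=device)

for micro_step in range(grad_accum_steps):
    loss = torch.rand((), device=device)   # stand-in for the model's scalar loss
    # gradient accumulation sums gradients; dividing here turns that SUM into a MEAN
    loss = loss / grad_accum_steps
    # detach() tracks the value without keeping the graph or forcing a host sync,
    # unlike .item(), which would synchronize with the GPU on every micro-step
    lossf += loss.detach()

print(float(lossf))  # pull the value to the host once, after the loop
```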
@@ -989,13 +993,14 @@ def get_lr(it: int) -> float:
f"step {step:4d}/{args.num_iterations} | train loss {lossf:.6f} | norm {norm:.4f} | lr {lr:.2e} | ({(t1-t0)*1000:.2f} ms | {tokens_per_second:.0f} tok/s)"
)
# log to wandb
log_metrics(
step,
{
"train_loss": lossf,
"lr": lr,
},
)
if args.wandb:
log_metrics(
step,
{
"train_loss": lossf,
"lr": lr,
},
)
# log to logfile
if master_process and logfile is not None:
with open(logfile, "a") as f:
