Commit 3fd1e94

Removed dotenv dependency

PC committed Oct 14, 2024
1 parent 873b84f commit 3fd1e94
Showing 3 changed files with 3 additions and 6 deletions.
pyproject.toml (3 changes: 1 addition & 2 deletions)
@@ -15,8 +15,7 @@ dependencies = [
     "ipykernel",
     "jaxtyping",
     "tiktoken",
-    "transformers",
-    "python-dotenv"
+    "transformers"
 ]
 
 [project.optional-dependencies]
simple_stories_train/train_llama.py (4 changes: 2 additions & 2 deletions)
@@ -63,7 +63,7 @@
 from torch.nn import functional as F
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-from simple_stories_train.utils import (
+from utils import (
     init_wandb,
     is_checkpoint_step,
     log_generations,
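A note on the import switch above: the bare `from utils import (...)` resolves through `sys.path` rather than through the installed `simple_stories_train` package, so it assumes the trainer is launched as a plain script from inside the package directory. A minimal sketch of that assumption (the invocation commands are illustrative, not taken from the commit):

    import sys

    # For `python simple_stories_train/train_llama.py`, sys.path[0] is the
    # script's own directory, so sibling modules import by bare name:
    #     from utils import init_wandb
    # For `python -m simple_stories_train.train_llama` run from the repo
    # root, sys.path[0] is the working directory instead, and the absolute
    # form this commit removes would be needed:
    #     from simple_stories_train.utils import init_wandb
    print(f"bare imports are resolved relative to: {sys.path[0]}")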
@@ -1036,7 +1036,7 @@ def get_lr(it: int) -> float:
         # addition of gradients corresponds to a SUM in the objective, but
         # instead of a SUM we want MEAN, so we scale the loss here
         loss = loss / grad_accum_steps
-        lossf += loss.detach()  # keep track of the mean loss
+        lossf += loss.item()  # keep track of the mean loss
         # backward pass
         if not args.inference_only:
             loss.backward()
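For context on the second hunk: `loss.detach()` returns a zero-dim tensor that stays on the device, so the old `lossf` accumulated a tensor and only synced when it was later read; `loss.item()` converts to a Python float immediately, which forces a host-device synchronization per micro-step on CUDA but keeps `lossf` a plain number. A runnable sketch of the difference (shapes, values, and `grad_accum_steps` are illustrative, not from the repository):

    import torch
    import torch.nn.functional as F

    grad_accum_steps = 4
    lossf = 0.0                       # new behavior: plain Python float
    lossf_tensor = torch.tensor(0.0)  # old behavior: stays a 0-dim tensor

    for _ in range(grad_accum_steps):
        logits = torch.randn(8, 10, requires_grad=True)
        targets = torch.randint(0, 10, (8,))
        loss = F.cross_entropy(logits, targets) / grad_accum_steps

        lossf_tensor += loss.detach()  # old: device-side tensor, no host sync
        lossf += loss.item()           # new: copies scalar to CPU, syncs on CUDA

    print(type(lossf_tensor), type(lossf))
    # <class 'torch.Tensor'> <class 'float'>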
simple_stories_train/utils.py (2 changes: 0 additions & 2 deletions)
@@ -4,7 +4,6 @@
 import torch
 import wandb
-from dotenv import load_dotenv
 from torch import nn
 
 
@@ -37,7 +36,6 @@ def save_model_and_config(save_dir: Path, model: nn.Module, step: int) -> None:
 
 
 def init_wandb(config: Any, project: str) -> None:
-    load_dotenv(override=True)
     wandb.init(
         project=project,
         config=config,
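With `load_dotenv(override=True)` removed, `init_wandb` no longer pulls variables from a local `.env` file, so any credentials that file supplied (most likely `WANDB_API_KEY`; the commit doesn't show the `.env` contents) must now be exported in the shell or stored via `wandb login` before training starts. A minimal sketch of failing fast when it is missing (the guard and the `wandb.init` arguments are illustrative, not part of the commit):

    import os

    import wandb

    # wandb falls back to the WANDB_API_KEY environment variable when no
    # prior `wandb login` has stored a credential; check up front so the
    # run fails with a clear message instead of mid-startup.
    if "WANDB_API_KEY" not in os.environ:
        raise RuntimeError(
            "WANDB_API_KEY is not set; with python-dotenv removed, export "
            "it in the shell before launching training."
        )

    wandb.init(project="simple_stories_train", config={"lr": 3e-4})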
