Major refactor (inc adding Pydantic) (#16)
* Ignore ruff forward annotation warning

* Use pydantic and split up train script

* Add .env to .gitignore

* Add default d12 config

* Fix validation

* Add dotenv for wandb

* Allow for cli override args and fix ddp

* Add warning about cpu runs

---------

Co-authored-by: Dan Braun <dan@apolloresearch.ai>
danbraunai and danbraunai-apollo authored Nov 25, 2024
1 parent 7bd136e commit a088f43
Showing 15 changed files with 920 additions and 859 deletions.
1 change: 1 addition & 0 deletions .env.example
@@ -1 +1,2 @@
WANDB_ENTITY=
WANDB_API_KEY=
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
.env
**/out/
tinyshakespeare/
gpt2_tokenizer.bin

3 changes: 0 additions & 3 deletions .vscode/launch.json
@@ -10,9 +10,6 @@
"request": "launch",
"program": "${workspaceFolder}/simple_stories_train/train_llama.py",
"args": [
"--model", "d12",
"--input_bin", "${workspaceFolder}/simple_stories_train/tinyshakespeare/tiny_shakespeare_val.bin",
"--device", "cpu"
],
"console": "integratedTerminal",
"justMyCode": true,
34 changes: 27 additions & 7 deletions README.md
@@ -1,6 +1,14 @@
# simple_stories_train

Project for training small LMs. Designed for training on SimpleStories, an extension of
[TinyStories](https://arxiv.org/abs/2305.07759).


- Training script is based on the efficient [train_gpt2.py](https://github.com/karpathy/llm.c/blob/master/train_gpt2.py) in [llm.c](https://github.com/karpathy/llm.c) (licensed
under MIT ((c) 2024 Andrej Karpathy)).
- Some model architecture implementations are based on
[TransformerLens](https://github.com/TransformerLensOrg/TransformerLens) (licensed under
MIT ((c) 2022 TransformerLensOrg)).

## Installation

@@ -28,11 +36,23 @@ make test-all # Run all tests

## Usage

Training a simple model:
`python simple_stories_train/train_llama.py --model d2 --sequence_length 1024 --total_batch_size=4096`
### Training a model
```
python train_llama.py [PATH/TO/CONFIG.yaml] [--key1 value1 --key2 value2 ...]
```
where
- `PATH/TO/CONFIG.yaml` contains the training config. If no path is provided, a default config will be used.
- `--key1 value1 --key2 value2 ...` override values in the config. Note that if you wish to update a
nested value, you must use dotted notation (e.g. `--train_dataset_config.name my_dataset`).

For a final model, we currently (intend to) run:
`torchrun --standalone --nproc_per_node=8 simple_stories_train/train_llama.py --model d24 --sequence_length 1024 --total_batch_size=16448 --compile 1 --tensorcores=1 --dtype=bfloat16 --wandb 1`
If running on CPU, you may need to set `--compile=False`.

To run on multiple GPUs, use
```
torchrun --standalone --nproc_per_node=N train_llama.py ...
```
where `N` is the number of GPUs to use.

You may be asked to enter your wandb API key. You can find it in your [wandb account settings](https://wandb.ai/settings). Alternatively, to avoid entering your API key on program execution, you can set the environment variable `WANDB_API_KEY` to your API key, or put it in a
`.env` file under the root of the repository.
### Logging with Weights & Biases
To track training with Weights & Biases, you can set the WANDB_PROJECT and WANDB_API_KEY variables in
`.env`. API keys can be obtained from your [Weights & Biases account settings](https://wandb.ai/settings).
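
The commit message notes "Add dotenv for wandb", and `python-dotenv` is added as a dependency in the pyproject.toml diff below. A minimal sketch of how the `WANDB_*` variables from `.env` might reach `wandb`; the exact wiring in `train_llama.py` is not shown in this diff, and the fallback project name here simply follows `d12_config.yaml`:

```python
# Sketch only (not code from this commit): load WANDB_* settings from .env via
# python-dotenv, then initialise wandb. wandb reads WANDB_API_KEY from the
# environment automatically once load_dotenv() has populated it.
import os

import wandb
from dotenv import load_dotenv

load_dotenv()  # reads WANDB_PROJECT / WANDB_API_KEY / WANDB_ENTITY from .env, if present
wandb.init(
    project=os.environ.get("WANDB_PROJECT", "simple-stories"),
    entity=os.environ.get("WANDB_ENTITY") or None,
)
```
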
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -16,7 +16,8 @@ dependencies = [
"jaxtyping",
"tiktoken",
"transformers",
"datasets"
"datasets",
"python-dotenv",
]

[project.optional-dependencies]
@@ -36,6 +37,9 @@ packages = ["simple_stories_train"]
[tool.ruff]
line-length = 100
fix = true
ignore = [
"F722" # Incompatible with jaxtyping
]

[tool.ruff.lint]
select = [
29 changes: 29 additions & 0 deletions simple_stories_train/d12_config.yaml
@@ -0,0 +1,29 @@
wandb_project: simple-stories
train_dataset_config:
name: lennart-finke/SimpleStories
is_tokenized: false
tokenizer_file_path: simple_stories_train/tokenizer/stories-3072.json
split: train
n_ctx: 1024
seed: 0
column_name: story
val_dataset_config:
name: lennart-finke/SimpleStories
is_tokenized: false
tokenizer_file_path: simple_stories_train/tokenizer/stories-3072.json
split: test
n_ctx: 1024
seed: 0
column_name: story
model_name: d12
batch_size: 32
total_batch_size: 32768
num_iterations: 1000
learning_rate: 1e-4
warmup_iters: 10
learning_rate_decay_frac: 1.0
weight_decay: 0.1
grad_clip: 1.0
val_loss_every: 0
val_max_steps: 20
sample_every: 100
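
Since this commit introduces Pydantic for configuration, a minimal sketch of how a YAML config like `d12_config.yaml` could be loaded and validated follows. The top-level `TrainConfig` class is a hypothetical stand-in (the commit's actual config class lives in the refactored train script, not shown in this diff); `DatasetConfig` is the model defined in `simple_stories_train/dataloaders.py` below, and the field names mirror the YAML keys above:

```python
# Hypothetical loader sketch; field names follow d12_config.yaml.
import yaml
from pydantic import BaseModel, ConfigDict

from simple_stories_train.dataloaders import DatasetConfig


class TrainConfig(BaseModel):  # assumed name, for illustration only
    model_config = ConfigDict(extra="forbid")
    wandb_project: str | None = None
    train_dataset_config: DatasetConfig
    val_dataset_config: DatasetConfig
    model_name: str = "d12"
    batch_size: int = 32
    total_batch_size: int = 32768
    num_iterations: int = 1000
    learning_rate: float = 1e-4
    warmup_iters: int = 10
    learning_rate_decay_frac: float = 1.0
    weight_decay: float = 0.1
    grad_clip: float = 1.0
    val_loss_every: int = 0
    val_max_steps: int = 20
    sample_every: int = 100


with open("simple_stories_train/d12_config.yaml") as f:
    config = TrainConfig.model_validate(yaml.safe_load(f))
print(config.model_name, config.learning_rate)
```
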
34 changes: 18 additions & 16 deletions simple_stories_train/dataloaders.py
Expand Up @@ -5,7 +5,7 @@
from datasets import Dataset, IterableDataset, load_dataset
from datasets.distributed import split_dataset_by_node
from numpy.typing import NDArray
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from tokenizers import Tokenizer
from torch.utils.data import DataLoader

@@ -14,17 +14,17 @@
licensed under MIT, (c) 2024 ApolloResearch.
"""


class DatasetConfig(BaseModel):
dataset_name: str
model_config = ConfigDict(extra="forbid", frozen=True)
name: str = "lennart-finke/SimpleStories"
is_tokenized: bool = True
tokenizer_file_path: str
tokenizer_file_path: str = "simple_stories_train/tokenizer/stories-3072.json"
streaming: bool = True
split: str
n_ctx: int
split: str = "train"
n_ctx: int = 1024
seed: int | None = None
column_name: str = "input_ids"
ddp_rank: int = 0
ddp_world_size: int = 1
"""The name of the column in the dataset that contains the data (tokenized or non-tokenized).
Typically 'input_ids' for datasets stored with e2e_sae/scripts/upload_hf_dataset.py, or "tokens"
for datasets tokenized in TransformerLens (e.g. NeelNanda/pile-10k)."""
@@ -108,9 +108,7 @@ def tokenize_function(
# Tokenize the chunks using the Tokenizer library
if to_lower:
chunks = [chunk.lower().replace("[eos]", "[EOS]") for chunk in chunks]
tokens = [
tokenizer.encode(chunk).ids for chunk in chunks
] # Get token IDs for each chunk
tokens = [tokenizer.encode(chunk).ids for chunk in chunks] # Get token IDs for each chunk
tokens = np.concatenate(tokens) # Flatten the list of token IDs

# Drop padding tokens (if applicable)
@@ -147,9 +145,13 @@ def tokenize_function(
return tokenized_dataset



def create_data_loader(
dataset_config: DatasetConfig, batch_size: int, buffer_size: int = 1000, global_seed: int = 0
dataset_config: DatasetConfig,
batch_size: int,
buffer_size: int = 1000,
global_seed: int = 0,
ddp_rank: int = 0,
ddp_world_size: int = 1,
) -> tuple[DataLoader[Any], Tokenizer]:
"""Create a DataLoader for the given dataset.
@@ -163,16 +165,16 @@ def create_data_loader(
A tuple of the DataLoader and the tokenizer.
"""
dataset = load_dataset(
dataset_config.dataset_name, streaming=dataset_config.streaming, split=dataset_config.split
dataset_config.name, streaming=dataset_config.streaming, split=dataset_config.split
)
seed = dataset_config.seed if dataset_config.seed is not None else global_seed
if dataset_config.streaming:
assert isinstance(dataset, IterableDataset)
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
else:
dataset = dataset.shuffle(seed=seed)
dataset = split_dataset_by_node(dataset, dataset_config.ddp_rank, dataset_config.ddp_world_size) # type: ignore
dataset = split_dataset_by_node(dataset, ddp_rank, ddp_world_size) # type: ignore

tokenizer = Tokenizer.from_file(dataset_config.tokenizer_file_path)

torch_dataset: Dataset
@@ -197,6 +199,6 @@ def create_data_loader(
loader = DataLoader[Any](
torch_dataset, # type: ignore
batch_size=batch_size,
shuffle=False
shuffle=False,
)
return loader, tokenizer
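
A hypothetical usage sketch of the refactored dataloader API shown above (not code from this commit): `ddp_rank` and `ddp_world_size` are now passed to `create_data_loader` directly rather than stored on the now-frozen `DatasetConfig`, which keeps the config purely declarative while process topology is supplied at call time. Field values here follow `d12_config.yaml`:

```python
# Sketch of calling the refactored dataloaders API from this commit.
from simple_stories_train.dataloaders import DatasetConfig, create_data_loader

dataset_config = DatasetConfig(
    name="lennart-finke/SimpleStories",
    is_tokenized=False,
    tokenizer_file_path="simple_stories_train/tokenizer/stories-3072.json",
    split="train",
    n_ctx=1024,
    seed=0,
    column_name="story",
)
loader, tokenizer = create_data_loader(
    dataset_config,
    batch_size=32,
    ddp_rank=0,        # single-process run
    ddp_world_size=1,
)
first_batch = next(iter(loader))  # batch structure depends on the collation, not shown in this diff
```
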