Merge pull request #10 from danbraunai/dataloader
Wrote simple streaming dataloader
lennart-finke authored Oct 17, 2024
2 parents 3fd1e94 + 2cf917b commit a932f89
Showing 7 changed files with 3,574 additions and 134 deletions.
15 changes: 3 additions & 12 deletions README.md
@@ -27,18 +27,9 @@ make test-all # Run all tests
```

## Usage
First, download the TinyShakespeare dataset:
```bash
python simple_stories_train/tinyshakespeare.py
```

Then, train llama on it:
`python simple_stories_train/train_llama.py --model d12 --input_bin
simple_stories_train/tinyshakespeare/tiny_shakespeare_train.bin --input_val_bin
simple_stories_train/tinyshakespeare/tiny_shakespeare_val.bin --wandb_project
simple-stories-training --val_loss_every 3 --sample_every 3`
Training a simple model:
`python simple_stories_train/train_llama.py --model d2 --sequence_length 1024 --total_batch_size=4096`

You may be asked to enter your wandb API key. You can find it in your [wandb account settings](https://wandb.ai/settings). Alternatively, to avoid entering your API key on program execution, you can set the environment variable `WANDB_API_KEY` to your API key, or put it in a
`.env` file under the root of the repository.
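
For reference, a minimal sketch (not part of this commit) of how the key could be picked up from a `.env` file at runtime; it assumes the `python-dotenv` package, which is not among the dependencies in this diff:

```python
# Illustrative sketch only: load WANDB_API_KEY from a .env file at the repo root.
# Assumes the python-dotenv package (not part of this diff) is installed.
import os

from dotenv import load_dotenv

load_dotenv()  # reads key=value pairs from ./.env into the environment, if present
api_key = os.environ.get("WANDB_API_KEY")
if api_key is None:
    print("WANDB_API_KEY not set; wandb will prompt for it at program start.")
```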
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -15,7 +15,8 @@ dependencies = [
"ipykernel",
"jaxtyping",
"tiktoken",
"transformers"
"transformers",
"datasets"
]

[project.optional-dependencies]
202 changes: 202 additions & 0 deletions simple_stories_train/dataloaders.py
@@ -0,0 +1,202 @@
from typing import Any

import numpy as np
import torch
from datasets import Dataset, IterableDataset, load_dataset
from datasets.distributed import split_dataset_by_node
from numpy.typing import NDArray
from pydantic import BaseModel
from tokenizers import Tokenizer
from torch.utils.data import DataLoader

"""
The bulk of this file is copied from https://github.com/ApolloResearch/e2e_sae
licensed under MIT, (c) 2024 ApolloResearch.
"""

class DatasetConfig(BaseModel):
dataset_name: str
is_tokenized: bool = True
tokenizer_file_path: str
streaming: bool = True
split: str
n_ctx: int
seed: int | None = None
column_name: str = "input_ids"
ddp_rank: int = 0
ddp_world_size: int = 1
"""The name of the column in the dataset that contains the data (tokenized or non-tokenized).
Typically 'input_ids' for datasets stored with e2e_sae/scripts/upload_hf_dataset.py, or "tokens"
for datasets tokenized in TransformerLens (e.g. NeelNanda/pile-10k)."""


def _keep_single_column(dataset: Dataset, col_name: str) -> Dataset:
"""
Acts on a HuggingFace dataset to delete all columns apart from a single column name - useful
when we want to tokenize and mix together different strings.
"""
for key in dataset.features: # pyright: ignore[reportAttributeAccessIssue]
if key != col_name:
dataset = dataset.remove_columns(key)
return dataset


def tokenize_and_concatenate(
dataset: Dataset,
tokenizer: Tokenizer,
max_length: int = 1024,
column_name: str = "story",
add_bos_token: bool = False,
num_proc: int = 10,
to_lower: bool = True,
) -> Dataset:
"""Helper function to tokenizer and concatenate a dataset of text. This converts the text to
tokens, concatenates them (separated by EOS tokens) and then reshapes them into a 2D array of
shape (____, sequence_length), dropping the last batch. Tokenizers are much faster if
parallelised, so we chop the string into 20, feed it into the tokenizer, in parallel with
padding, then remove padding at the end.
NOTE: Adapted from
https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/utils.py#L267
to handle IterableDataset.
TODO: Fix typing of tokenizer
This tokenization is useful for training language models, as it allows us to efficiently train
on a large corpus of text of varying lengths (without, eg, a lot of truncation or padding).
Further, for models with absolute positional encodings, this avoids privileging early tokens
(eg, news articles often begin with CNN, and models may learn to use early positional
    encodings to predict these).
Args:
dataset: The dataset to tokenize, assumed to be a HuggingFace text dataset. Can be a regular
Dataset or an IterableDataset.
tokenizer: The tokenizer. Assumed to have a bos_token_id and an eos_token_id.
max_length: The length of the context window of the sequence. Defaults to 1024.
        column_name: The name of the text column in the dataset. Defaults to 'story'.
add_bos_token: Add BOS token at the beginning of each sequence. Defaults to False as this
is not done during training.
Returns:
Dataset or IterableDataset: Returns the tokenized dataset, as a dataset of tensors, with a
single column called "input_ids".
    Note: There is a bug when inputting very small datasets (eg, <1 batch per process) where it
    just outputs nothing. I'm not super sure why.
"""
dataset = _keep_single_column(dataset, column_name)
seq_len = max_length - 1 if add_bos_token else max_length

def tokenize_function(
examples: dict[str, list[str]],
) -> dict[
str,
NDArray[np.signedinteger[Any]],
]:
padding_token_id = 0
tokenizer.enable_padding(pad_id=padding_token_id)

text = examples[column_name]
# Concatenate all the text into a single string, separated by EOS tokens
full_text = "[EOS]".join(text)

# Split the text into chunks for parallel tokenization
num_chunks = 20
chunk_length = (len(full_text) - 1) // num_chunks + 1
chunks = [full_text[i * chunk_length : (i + 1) * chunk_length] for i in range(num_chunks)]

# Tokenize the chunks using the Tokenizer library
if to_lower:
chunks = [chunk.lower().replace("[eos]", "[EOS]") for chunk in chunks]
tokens = [
tokenizer.encode(chunk).ids for chunk in chunks
] # Get token IDs for each chunk
tokens = np.concatenate(tokens) # Flatten the list of token IDs

# Drop padding tokens (if applicable)
tokens = tokens[tokens != padding_token_id]

# Calculate number of batches and adjust the tokens accordingly
num_tokens = len(tokens)
num_batches = num_tokens // seq_len
tokens = tokens[: seq_len * num_batches]

# Reshape tokens into batches
tokens = tokens.reshape((num_batches, seq_len))

# Optionally, add BOS token at the beginning of each sequence
if add_bos_token:
bos_token_id = tokenizer.token_to_id("[BOS]")
prefix = np.full((num_batches, 1), bos_token_id)
tokens = np.concatenate([prefix, tokens], axis=1)

return {"input_ids": tokens}

# Apply the tokenization function to the dataset
if isinstance(dataset, IterableDataset):
tokenized_dataset = dataset.map(
tokenize_function, batched=True, remove_columns=[column_name]
)
else:
tokenized_dataset = dataset.map(
tokenize_function, batched=True, remove_columns=[column_name], num_proc=num_proc
)

tokenized_dataset = tokenized_dataset.with_format("torch")

return tokenized_dataset



def create_data_loader(
dataset_config: DatasetConfig, batch_size: int, buffer_size: int = 1000, global_seed: int = 0
) -> tuple[DataLoader[Any], Tokenizer]:
"""Create a DataLoader for the given dataset.
Args:
dataset_config: The configuration for the dataset.
batch_size: The batch size.
buffer_size: The buffer size for streaming datasets.
global_seed: Used for shuffling if dataset_config.seed is None.
Returns:
A tuple of the DataLoader and the tokenizer.
"""
dataset = load_dataset(
dataset_config.dataset_name, streaming=dataset_config.streaming, split=dataset_config.split
)
seed = dataset_config.seed if dataset_config.seed is not None else global_seed
if dataset_config.streaming:
assert isinstance(dataset, IterableDataset)
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
else:
dataset = dataset.shuffle(seed=seed)
dataset = split_dataset_by_node(dataset, dataset_config.ddp_rank, dataset_config.ddp_world_size) # type: ignore

tokenizer = Tokenizer.from_file(dataset_config.tokenizer_file_path)

torch_dataset: Dataset
if dataset_config.is_tokenized:
torch_dataset = dataset.with_format("torch") # type: ignore
# Get a sample from the dataset and check if it's tokenized and what the n_ctx is
# Note that the dataset may be streamed, so we can't just index into it
sample = next(iter(torch_dataset))[dataset_config.column_name] # type: ignore
assert (
isinstance(sample, torch.Tensor) and sample.ndim == 1
), "Expected the dataset to be tokenized."
assert len(sample) == dataset_config.n_ctx, "n_ctx does not match the tokenized length."

else:
torch_dataset = tokenize_and_concatenate(
dataset, # type: ignore
tokenizer,
max_length=dataset_config.n_ctx,
add_bos_token=False,
)

loader = DataLoader[Any](
torch_dataset, # type: ignore
batch_size=batch_size,
shuffle=False
)
return loader, tokenizer
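
For illustration, here is a minimal usage sketch of `create_data_loader` (not part of the diff). The dataset name, tokenizer path, and context length are assumptions mirroring the `test_tokenizer` helper in `tokenizer.py`, and the import assumes the script is run from inside `simple_stories_train/`:

```python
# Illustrative sketch only: stream raw stories, tokenize on the fly, pull one batch.
from dataloaders import DatasetConfig, create_data_loader

config = DatasetConfig(
    dataset_name="lennart-finke/SimpleStories",          # assumed dataset
    is_tokenized=False,                                   # tokenize via tokenize_and_concatenate
    tokenizer_file_path="tokenizers/stories-4096.json",   # assumed tokenizer file
    streaming=True,
    split="train",
    n_ctx=512,
    seed=42,
    column_name="story",
    ddp_rank=0,
    ddp_world_size=1,
)

loader, tokenizer = create_data_loader(config, batch_size=8)
batch = next(iter(loader))
print(batch["input_ids"].shape)  # expected: torch.Size([8, 512]) for this config
```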
119 changes: 119 additions & 0 deletions simple_stories_train/tokenizer.py
@@ -0,0 +1,119 @@
"""
TODO: Use the finalized tokenizer.
This file is copied almost 1:1 from Nix Goldowsky-Dill's adaptation of the tokenizer in https://github.com/juand-r/tiny_tokenizer.
"""

from itertools import chain
from pathlib import Path

from datasets import Dataset, DatasetDict, load_dataset
from tokenizers import Tokenizer, pre_tokenizers
from tokenizers.models import WordPiece
from tokenizers.pre_tokenizers import Digits, Punctuation, Whitespace
from tokenizers.trainers import WordPieceTrainer

from dataloaders import DatasetConfig, create_data_loader

OUT_DIR = Path("tokenizers")


def clean_dataset(dataset="lennart-finke/SimpleStories") -> DatasetDict:
dataset = load_dataset(dataset)
trans = str.maketrans(
{"\u201d": '"', "\u201c": '"', "\u2019": "'", "\u2018": "'", "\u2014": "-", "\u2026": "..."}
)
cleaned = [
s.translate(trans).encode("ascii", "ignore").decode("ascii").lower()
for s in dataset["train"]["story"] # pyright: ignore
]

n_train = int(len(cleaned) * 0.9)
train, validation = cleaned[:n_train], cleaned[n_train:]

train_ds = Dataset.from_dict(dict(story=train))
validation_ds = Dataset.from_dict(dict(story=validation))

return DatasetDict({"train": train_ds, "validation": validation_ds})


def train_tokenizer(vocab_size=4096, dataset="lennart-finke/SimpleStories") -> None:
data = clean_dataset(dataset=dataset)["train"]["story"]

# Split almost all types into individual tokens
pre_tokenizer = pre_tokenizers.Sequence(
[Punctuation(), Whitespace(), Digits(individual_digits=True)]
)

    # The tokenizer itself is a WordPiece model, which is generally smaller than a byte-pair encoding
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]")) # pyright: ignore
tokenizer.pre_tokenizer = pre_tokenizer # pyright: ignore

print("training tokenizer...")
# Train the tokenizer
trainer = WordPieceTrainer(vocab_size=vocab_size, special_tokens=["[UNK]", "[EOS]"])
tokenizer.train_from_iterator(data, trainer=trainer, length=len(data))
OUT_DIR.mkdir(parents=True, exist_ok=True)
tokenizer.save(f"{OUT_DIR}/stories-{vocab_size}.json")

def test_tokenizer(filepath: str, dataset: str = "lennart-finke/SimpleStories") -> None:
dataset_name = dataset
split = "train"

context_width = 512
dataset_config = DatasetConfig(
dataset_name=dataset_name,
is_tokenized=False,
tokenizer_file_path=filepath,
streaming=True,
split=split,
n_ctx=context_width,
seed=42,
column_name="story",
ddp_rank=0,
ddp_world_size=1,
)

batch_size = 1
buffer_size = 1000
global_seed = 0

loader, tokenizer = create_data_loader(dataset_config, batch_size, buffer_size, global_seed)
batch = next(iter(loader))
words = tokenizer.decode_batch(batch["input_ids"].tolist(), skip_special_tokens=False)
print(words)

def load_tokenizer(vocab_size=3072) -> Tokenizer:
return Tokenizer.from_file(f"{OUT_DIR}/stories-{vocab_size}.json")


def print_split_words(story_tokens: list[str]) -> None:
for i, token in enumerate(story_tokens):
if token.startswith("##") and (i == 0 or not story_tokens[i - 1].startswith("##")):
word_start = story_tokens[i - 1] if i > 0 else ""
word_parts = [token]

for next_token in story_tokens[i + 1 :]:
if not next_token.startswith("##"):
break
word_parts.append(next_token)

print(f"{word_start} {' '.join(word_parts)}")


def analysis(vocab_size=3072) -> None:
tokenizer = load_tokenizer(vocab_size)
stories = clean_dataset(dataset="lennart-finke/SimpleStories")["validation"]["story"]
tokenized_stories = [tokenizer.encode(story).tokens for story in stories]

all_tokens = list(chain.from_iterable(tokenized_stories))
partial_word_toks = len([token for token in all_tokens if token.startswith("##")])

print(f"Tokens per story: {len(all_tokens) / len(stories):.2f}")
print(f"Proportion of partial word tokens: {partial_word_toks / len(all_tokens):.2%}")

for story_idx, story_tokens in enumerate(tokenized_stories[:25]):
print(f"\nStory {story_idx}")
print_split_words(story_tokens)

if __name__ == "__main__":
test_tokenizer("simple_stories_train/tokenizer/stories-3072.json")
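
For illustration, a minimal sketch (not part of the commit) of the workflow these functions are built for: train a WordPiece tokenizer, sanity-check it through the streaming loader, then inspect token statistics. The vocabulary size and output path are assumptions mirroring the defaults above, and the import assumes the script runs from inside `simple_stories_train/`:

```python
# Illustrative sketch only: end-to-end tokenizer workflow using the functions above.
from tokenizer import analysis, test_tokenizer, train_tokenizer

# 1. Train a 4096-entry WordPiece vocabulary; saves tokenizers/stories-4096.json.
train_tokenizer(vocab_size=4096, dataset="lennart-finke/SimpleStories")

# 2. Decode one streamed batch with the new tokenizer to eyeball the output.
test_tokenizer("tokenizers/stories-4096.json")

# 3. Report tokens per story and the share of "##" word-piece continuations.
analysis(vocab_size=4096)
```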