Fix inference tokenizer
danbraunai committed Nov 30, 2024
1 parent a088f43 · commit 981657b
Showing 3 changed files with 10 additions and 8 deletions.
.vscode/launch.json (3 changes: 2 additions & 1 deletion)

@@ -5,11 +5,12 @@
     "version": "0.2.0",
     "configurations": [
         {
-            "name": "train llama",
+            "name": "train llama d12",
             "type": "debugpy",
             "request": "launch",
             "program": "${workspaceFolder}/simple_stories_train/train_llama.py",
             "args": [
+                "${workspaceFolder}/simple_stories_train/d12_config.yaml"
             ],
             "console": "integratedTerminal",
             "justMyCode": true,
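For reference, this debug configuration is equivalent to invoking the training script with the config path as its single positional argument. A minimal sketch of the entry point it implies, assuming a fire.Fire wrapper (train_llama.py imports fire and main() accepts config_path_or_obj, both visible in the diff below; the exact wiring is not shown here):

from pathlib import Path

import fire


def main(config_path_or_obj: Path | str | None = None, **kwargs) -> None:
    # Placeholder body: the real main() in train_llama.py loads a Config
    # from this YAML path and runs training.
    print(f"would train with config: {config_path_or_obj}")


if __name__ == "__main__":
    # python simple_stories_train/train_llama.py simple_stories_train/d12_config.yaml
    fire.Fire(main)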
simple_stories_train/d12_config.yaml (2 changes: 2 additions & 0 deletions)

@@ -4,6 +4,7 @@ train_dataset_config:
   is_tokenized: false
   tokenizer_file_path: simple_stories_train/tokenizer/stories-3072.json
   split: train
+  streaming: false
   n_ctx: 1024
   seed: 0
   column_name: story
@@ -12,6 +13,7 @@ val_dataset_config:
   is_tokenized: false
   tokenizer_file_path: simple_stories_train/tokenizer/stories-3072.json
   split: test
+  streaming: false
   n_ctx: 1024
   seed: 0
   column_name: story
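The new streaming: false keys make the loading mode explicit for both splits. A minimal sketch of what the flag plausibly controls, assuming the loader is built on Hugging Face datasets (not confirmed by this diff):

from datasets import load_dataset

# streaming=False (as now pinned in d12_config.yaml): the split is downloaded
# in full and returned as a map-style Dataset with random access and a length.
# streaming=True would instead return an IterableDataset that yields examples
# lazily. "org/simple-stories" is a hypothetical dataset name; the real one
# is not shown in this diff.
ds = load_dataset("org/simple-stories", split="train", streaming=False)
print(ds[0]["story"])  # column_name: story, per the config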
simple_stories_train/train_llama.py (13 changes: 6 additions & 7 deletions)

@@ -31,7 +31,6 @@

 import fire
 import numpy as np
-import tiktoken
 import torch
 import torch._inductor.config as torch_inductor_config
 import torch.distributed as dist
@@ -217,7 +216,7 @@ def main(config_path_or_obj: Path | str | Config | None = None, **kwargs: Any) -
     torch.set_float32_matmul_precision("high")

     # init (and write) the tokenizer
-    enc: tiktoken.core.Encoding = tiktoken.get_encoding("gpt2")
+    # enc: tiktoken.core.Encoding = tiktoken.get_encoding("gpt2")

     model_config = MODEL_CONFIGS[config.model_name]
     model = Llama(model_config)
@@ -232,7 +231,7 @@ def main(config_path_or_obj: Path | str | Config | None = None, **kwargs: Any) -
         print0("compiling the model...")
         model: nn.Module = torch.compile(model) # type: ignore[reportArgumentType]

-    train_loader, _ = create_data_loader(
+    train_loader, train_tokenizer = create_data_loader(
         dataset_config=config.train_dataset_config,
         batch_size=B,
         buffer_size=1000,
@@ -345,19 +344,19 @@ def get_lr(it: int) -> float:
             # before we end, let's also do one round of inference
             # we'll kick off the generation with "<|endoftext|>", which designates the start of a
             # new sequence
-            start_ids = [enc.eot_token]
+            start_ids = [train_tokenizer.token_to_id("[EOS]")]
             xg = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
             max_new_tokens = 32
             temperature = 1.0
             top_k = 40
             yg = raw_model.generate(xg, max_new_tokens, temperature=temperature, top_k=top_k)
             print0("---------------")
-            print0(enc.decode(yg[0].tolist()))
+            print0(train_tokenizer.decode(yg[0].tolist()))
             print0("---------------")
             # log to wandb
             if config.wandb_project is not None and master_process:
-                generations.append([step, enc.decode(yg[0].tolist())])
-                log_generations(step, generations)
+                generations.append([step, train_tokenizer.decode(yg[0].tolist())])
+                log_generations(step, generations)

             # bit confusing: we want to make sure to eval and sample on 0th iteration
             # but also after the very last iteration. so we loop for step <= num_iterations
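The substance of the fix: inference previously seeded and decoded generations with tiktoken's GPT-2 encoding while training used the custom stories-3072.json tokenizer, so samples were decoded against the wrong vocabulary. The token_to_id and decode calls match the Hugging Face tokenizers API, so train_tokenizer is presumably a tokenizers.Tokenizer built from the configured JSON file (an assumption; create_data_loader's internals are not shown). A minimal sketch of the new inference path:

import torch
from tokenizers import Tokenizer

# Load the same tokenizer the data loader uses (path from d12_config.yaml);
# assumes the tokenizers library, which token_to_id/decode suggest.
tokenizer = Tokenizer.from_file("simple_stories_train/tokenizer/stories-3072.json")

# Seed generation with [EOS], replacing the old GPT-2 enc.eot_token.
eos_id = tokenizer.token_to_id("[EOS]")
assert eos_id is not None, "tokenizer must define an [EOS] token"
xg = torch.tensor([eos_id], dtype=torch.long)[None, ...]  # shape (1, 1)

# With a trained model in scope, generation then mirrors the diff:
# yg = raw_model.generate(xg, 32, temperature=1.0, top_k=40)
# print(tokenizer.decode(yg[0].tolist()))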
