This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

[EXAMPLE] Add llama finetune #923

Open · wants to merge 12 commits into main

19 changes: 12 additions & 7 deletions alpa/parallel_method.py
```
--- a/alpa/parallel_method.py
+++ b/alpa/parallel_method.py
@@ -277,16 +277,21 @@ def get_3d_parallel_method(num_micro_batches: int,
     assert num_mesh_devices % num_devices_per_host == 0
     physical_mesh_shape = (num_mesh_devices // num_devices_per_host,
                            num_devices_per_host)
+    if pipeline_parallel == num_devices:
+        manual_sharding_option = None
 
     # If no pipeline parallel, degenerate into shard parallel
     if pp == 1 and allow_degenerate_into_shard_parallel:
-        return ShardParallel(num_micro_batches=num_micro_batches,
-                             auto_sharding_option=AutoShardingOption(
-                                 prefer_reduce_scatter=True,
-                                 force_batch_dim_to_mesh_dim=0),
-                             devices=get_global_physical_mesh(
-                                 create_if_not_exist=True).get_logical_mesh(
-                                     [data_parallel, operator_parallel]))
+        return ShardParallel(
+            num_micro_batches=num_micro_batches,
+            auto_sharding_option=AutoShardingOption(
+                enable_auto_sharding=manual_sharding_option is None,
+                prefer_reduce_scatter=True,
+                force_batch_dim_to_mesh_dim=0),
+            devices=get_global_physical_mesh(
+                create_if_not_exist=True).get_logical_mesh(
+                    [data_parallel, operator_parallel]),
+            manual_sharding_option=manual_sharding_option)
 
     # Return pipeshard parallel
     if manual_layer_num is not None:
```
16 changes: 9 additions & 7 deletions alpa/util.py
```
--- a/alpa/util.py
+++ b/alpa/util.py
@@ -1663,20 +1663,22 @@ def compute_gpt_tflops(batch_size,
                         num_gpus,
                         latency,
                         backward=True,
-                        checkpoint_activations=False):
+                        checkpoint_activations=False,
+                        intermediate_size=None):
     """
     Compute the Tera Flop Operations (TFLOP) per second per GPU
     for GPT-like models.
     """
-    factor = 24
+    factor = 2
     if backward:
-        factor += 48
+        factor += 4
     if checkpoint_activations:
-        factor += 24
+        factor += 2
+    if intermediate_size is None:
+        intermediate_size = hidden_size * 4
 
-    total_flop = (factor * batch_size * seq_len *
-                  (hidden_size**2) * num_layers * (1 + seq_len /
-                                                   (6 * hidden_size)) +
+    total_flop = ((factor * num_layers * batch_size * seq_len * hidden_size *
+                   (4 * hidden_size + 2 * intermediate_size + 2 * seq_len)) +
                   6 * batch_size * seq_len * hidden_size * vocab_size)
     # Note: The above formula does not count the first embedding table lookup
     # because it is a sparse operation.
```
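As a quick sanity check (not part of the diff): with the default `intermediate_size = 4 * hidden_size`, the new per-layer term reduces to the old one, so existing GPT numbers are unchanged; the extra argument only matters for LLaMA-style models whose MLP width is not four times the hidden size.

```
# Sanity check (not in the PR): with intermediate_size = 4 * hidden_size,
# the new forward-only per-layer FLOP term equals the old one.
h, s = 4096, 2048                            # example hidden size / sequence length
old = 24 * h**2 * (1 + s / (6 * h))          # old term (factor = 24)
new = 2 * h * (4 * h + 2 * (4 * h) + 2 * s)  # new term (factor = 2)
assert abs(old - new) < 1e-6 * new           # equal up to floating-point rounding
```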
61 changes: 61 additions & 0 deletions examples/llama_finetune/README.md
This script needs a few monkey patches to the original EasyLM model definition:

##### Fix Import Errors

EasyLM is based on jax 0.4, while this branch is tested against jax 0.3.22, so some import errors need to be fixed:

```
--- a/EasyLM/jax_utils.py
+++ b/EasyLM/jax_utils.py
@@ -10,8 +10,8 @@ import dill
 import flax
 import jax
 import jax.numpy as jnp
-from jax.sharding import PartitionSpec as PS
-from jax.sharding import Mesh
+from jax.experimental.pjit import PartitionSpec as PS
+from jax.interpreters.pxla import Mesh
 from jax.experimental.pjit import with_sharding_constraint as _with_sharding_constraint
 from jax.experimental.pjit import pjit
 from jax.interpreters import pxla
```

```
--- a/EasyLM/models/llama/llama_model.py
+++ b/EasyLM/models/llama/llama_model.py
@@ -8,7 +8,7 @@ import numpy as np
 import jax
 import jax.numpy as jnp
 from jax import lax
-from jax.sharding import PartitionSpec as PS
+from jax.experimental.pjit import PartitionSpec as PS
 import flax.linen as nn
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
```
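Alternatively, a minimal sketch (not part of the PR) of avoiding source edits on jax 0.3.22: register a small `jax.sharding` alias module before importing EasyLM. The shim only covers the two names EasyLM imports above, so the source patches remain the safer option.

```
# Hedged alternative to patching EasyLM's files: alias the jax 0.3.x locations
# under the "jax.sharding" name that EasyLM imports from. Assumes jax 0.3.22.
import sys
import types

import jax
from jax.experimental.pjit import PartitionSpec
from jax.interpreters.pxla import Mesh

_shim = types.ModuleType("jax.sharding")
_shim.PartitionSpec = PartitionSpec
_shim.Mesh = Mesh
sys.modules["jax.sharding"] = _shim
jax.sharding = _shim  # so both "from jax.sharding import ..." and jax.sharding.* resolve

# import EasyLM.jax_utils  # now imports without editing the source
```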

##### Support marking pipeline boundaries
We use manual pipeline boundaries, even though automatic layer construction works in most cases, so we add a marker at the end of each transformer layer.

We will monkey-patch this in the training script later.

```
--- a/EasyLM/models/llama/llama_model.py
+++ b/EasyLM/models/llama/llama_model.py
@@ -31,6 +31,7 @@ from mlxu import function_args_to_config, load_pickle, open_file
 from EasyLM.jax_utils import (
     with_sharding_constraint, get_jax_mesh, get_gradient_checkpoint_policy
 )
+from alpa import mark_pipeline_boundary
 
 
 LLAMA_STANDARD_CONFIGS = {
@@ -829,6 +830,7 @@ class FlaxLLaMABlockCollection(nn.Module):
                 output_attentions,
                 fcm_mask,
             )
+            mark_pipeline_boundary()
             hidden_states = layer_outputs[0]
 
             if output_attentions:
```
142 changes: 142 additions & 0 deletions examples/llama_finetune/hf_datasets.py
```
import json
from typing import Dict

from datasets import Dataset
import numpy as np
import transformers

from fastchat.conversation import get_default_conv_template, SeparatorStyle


def preprocess(sources, tokenizer: transformers.PreTrainedTokenizer,
               ignore_token_id) -> Dict:
    conv = get_default_conv_template("vicuna").copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors="np",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = np.copy(input_ids)

    assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO

    # Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int((target != tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = ignore_token_id
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            round_len = len(tokenizer(rou).input_ids)
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            target[cur_len:cur_len + instruction_len] = ignore_token_id

            cur_len += round_len
        target[cur_len:] = ignore_token_id

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = ignore_token_id
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)")

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=np.array(input_ids != tokenizer.pad_token_id),
    )


class LazySupervisedDataset:
    """Dataset for supervised fine-tuning."""

    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer,
                 ignore_token_id):
        super(LazySupervisedDataset, self).__init__()
        print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.raw_data = raw_data
        self.cached_data_dict = {}
        self.ignore_token_id = ignore_token_id

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, i):
        if i in self.cached_data_dict:
            return self.cached_data_dict[i]

        ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer,
                         self.ignore_token_id)
        ret = dict(
            input_ids=ret["input_ids"][0],
            labels=ret["labels"][0],
            attention_mask=ret["attention_mask"][0],
        )
        self.cached_data_dict[i] = ret

        return ret

    def iter(self):

        def gen():
            for i in range(len(self)):
                yield self[i]

        return gen


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_path, ignore_token_id) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    print("Loading data...")
    raw_data = json.load(open(data_path, "r"))

    # Split train/test
    perm = np.random.permutation(len(raw_data))
    split = int(len(perm) * 0.98)
    train_indices = perm[:split]
    eval_indices = perm[split:]
    train_raw_data = [raw_data[i] for i in train_indices]
    eval_raw_data = [raw_data[i] for i in eval_indices]
    print(f"#train {len(train_raw_data)}, #eval {len(eval_raw_data)}")

    train_dataset = LazySupervisedDataset(train_raw_data,
                                          tokenizer=tokenizer,
                                          ignore_token_id=ignore_token_id)
    eval_dataset = LazySupervisedDataset(eval_raw_data,
                                         tokenizer=tokenizer,
                                         ignore_token_id=ignore_token_id)
    train_dataset = Dataset.from_generator(train_dataset.iter())
    eval_dataset = Dataset.from_generator(eval_dataset.iter())
    return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)
```
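For reference, a minimal usage sketch (not part of the PR, assumed to run from the example directory). The model path, data path, and `model_max_length` are placeholders; `-100` is the conventional ignore index for causal-LM losses.

```
import transformers
from hf_datasets import make_supervised_data_module

tokenizer = transformers.AutoTokenizer.from_pretrained(
    "./llama-7b",                 # placeholder model path
    model_max_length=512,
    padding_side="right",
    use_fast=False)
tokenizer.pad_token = tokenizer.unk_token  # LLaMA has no pad token by default

data_module = make_supervised_data_module(
    tokenizer, data_path="./sharegpt_clean.json", ignore_token_id=-100)
print(data_module["train_dataset"])
```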
45 changes: 45 additions & 0 deletions examples/llama_finetune/hf_jax_conversion.py
```
import transformers


def import_hf_model(model_name_or_path):
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
    )
    return model


def hf_to_jax_weight(hf_model):
    state_dict = hf_model.state_dict()
    num_heads = hf_model.config.num_attention_heads
    dim = hf_model.config.hidden_size

    # inverse function of EasyLM's convert_easylm_to_hf.write_model.permute
    def inv_permute(w):
        return w.reshape(num_heads, 2, dim // num_heads // 2, dim).transpose(1, 2).reshape(dim, dim)

    jax_weights = {
        'transformer': {
            'wte': {'embedding': state_dict['model.embed_tokens.weight'].numpy()},
            'ln_f': {'kernel': state_dict['model.norm.weight'].numpy()},
            'h': {
                '%d' % (layer): {
                    'attention': {
                        'wq': {'kernel': inv_permute(state_dict['model.layers.%d.self_attn.q_proj.weight' % (layer)]).numpy().transpose()},
                        'wk': {'kernel': inv_permute(state_dict['model.layers.%d.self_attn.k_proj.weight' % (layer)]).numpy().transpose()},
                        'wv': {'kernel': state_dict['model.layers.%d.self_attn.v_proj.weight' % (layer)].numpy().transpose()},
                        'wo': {'kernel': state_dict['model.layers.%d.self_attn.o_proj.weight' % (layer)].numpy().transpose()},
                    },
                    'feed_forward': {
                        'w1': {'kernel': state_dict['model.layers.%d.mlp.gate_proj.weight' % (layer)].numpy().transpose()},
                        'w2': {'kernel': state_dict['model.layers.%d.mlp.down_proj.weight' % (layer)].numpy().transpose()},
                        'w3': {'kernel': state_dict['model.layers.%d.mlp.up_proj.weight' % (layer)].numpy().transpose()},
                    },
                    'attention_norm': {'kernel': state_dict['model.layers.%d.input_layernorm.weight' % (layer)].numpy()},
                    'ffn_norm': {'kernel': state_dict['model.layers.%d.post_attention_layernorm.weight' % (layer)].numpy()},
                }
                for layer in range(hf_model.config.num_hidden_layers)},
        },
        'lm_head': {'kernel': state_dict["lm_head.weight"].numpy().transpose()},
    }
    return jax_weights


if __name__ == "__main__":
    hf_model = import_hf_model("./llama-7b")
    jax_params = hf_to_jax_weight(hf_model)
    # EasyLM uses fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))
    # to store the params.
```
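The storage step described in the final comment, written out as a minimal sketch (not part of the PR, assumed to run from the example directory); the model and output paths are placeholders.

```
import flax.serialization

from hf_jax_conversion import import_hf_model, hf_to_jax_weight

hf_model = import_hf_model("./llama-7b")                # placeholder model path
jax_weights = hf_to_jax_weight(hf_model)
with open("./llama-7b-easylm.msgpack", "wb") as fout:   # placeholder output path
    fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))
```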
18 changes: 18 additions & 0 deletions examples/llama_finetune/monkey_patch.py
```
from functools import partial

import jax
import jax.numpy as jnp

from EasyLM.models.llama.llama_model import FlaxLLaMAForCausalLMModule


def do_monkey_patch():
    # TODO: jax 0.3.22 does not support eval shape with static args well. Remove
    # after rebasing to jax 0.4, use the model's _do_init=False then.
    def init_dummy(self, *args, **kwargs):
        avals = jax.eval_shape(partial(self._backup_init, **kwargs), *args)
        return jax.tree_util.tree_map(
            lambda x: jnp.full(x.shape, 1e-8, x.dtype), avals)

    if not hasattr(FlaxLLaMAForCausalLMModule, "_backup_init"):
        FlaxLLaMAForCausalLMModule._backup_init = FlaxLLaMAForCausalLMModule.init
        FlaxLLaMAForCausalLMModule.init = init_dummy
```
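A short usage note (an assumption about how the training script uses this, since that script is not shown here): the patch must be applied before `init()` is called.

```
from monkey_patch import do_monkey_patch

do_monkey_patch()
# From here on, FlaxLLaMAForCausalLMModule.init(...) only runs jax.eval_shape on
# the backed-up initializer and returns arrays filled with 1e-8, so no real
# weights are materialized on the driver process.
```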