
Commit

Merge pull request #4 from ZJLab-DataHub-Security/qianhao_dev
rename variables
qianhao0713 authored Jul 15, 2024
2 parents 37d097a + 4a4ea30 commit 92554c2
Showing 9 changed files with 55 additions and 53 deletions.
9 changes: 5 additions & 4 deletions Llama3-70B.sh
@@ -5,10 +5,11 @@ NGPUS=${NGPUS:-8}
WORLD_SIZE=${WORLD_SIZE:-1}
NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]]
SEQ_LEN=${SEQ_LEN:-32768}
- SP_GROUP_SIZE=${SP_GROUP_SIZE:-1}
+ SP_SIZE=${SP_SIZE:-1}
BATCH_SIZE=${BATCH_SIZE:-1}
export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024'
export WANDB_DISABLED=true
+ export NCCL_DEBUG=WARN
echo ${RANK}/$[WORLD_SIZE]
if [ ${MASTER_ADDR} == 'localhost' ]; then
export MASTER_ADDR=`hostname -i`
@@ -30,12 +31,12 @@ src/train.py \
--do_train \
--finetuning_type full \
--parallel_mode dist_flash_attn \
- --seq_parallel_size ${SP_GROUP_SIZE} \
+ --sp_size ${SP_SIZE} \
--deepspeed examples/deepspeed/ds_z3_offload_config.json \
--dataset long_sft_128k \
--template llama3 \
--cutoff_len ${SEQ_LEN} \
- --max_samples 1000 \
+ --max_steps 10 \
--overwrite_cache \
--preprocessing_num_workers 16 \
--output_dir ./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \
@@ -46,7 +47,7 @@ src/train.py \
--per_device_train_batch_size ${BATCH_SIZE} \
--gradient_accumulation_steps 4 \
--learning_rate 2e-5 \
- --num_train_epochs 3.0 \
+ --num_train_epochs 1.0 \
--lr_scheduler_type cosine \
--warmup_ratio 0.1 \
--bf16 \
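For orientation, a minimal sketch (not part of this commit, assuming the semantics implied by the renamed variable): SP_SIZE sets how many ranks form one sequence-parallel group, and the NUM_PROCESSES = NGPUS * WORLD_SIZE ranks of the launch are split into data-parallel replicas of such groups.

# Illustration only; values are hypothetical (e.g. SP_SIZE=4 NGPUS=8 bash Llama3-70B.sh).
ngpus, world_size, sp_size = 8, 1, 4
num_processes = ngpus * world_size            # mirrors NUM_PROCESSES in the script
if sp_size == -1:                             # -1: one sequence-parallel group spanning all ranks
    sp_size = num_processes
assert num_processes % sp_size == 0
dp_size = num_processes // sp_size            # number of data-parallel replicas
print(f"{dp_size} data-parallel group(s) of {sp_size} sequence-parallel ranks each")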
4 changes: 2 additions & 2 deletions Llama3-8B.sh
@@ -5,7 +5,7 @@ NGPUS=${NGPUS:-8}
WORLD_SIZE=${WORLD_SIZE:-1}
NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]]
SEQ_LEN=${SEQ_LEN:-1024}
- SP_GROUP_SIZE=${SP_GROUP_SIZE:-1}
+ SP_SIZE=${SP_SIZE:-1}
BATCH_SIZE=${BATCH_SIZE:-1}
export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024'
export WANDB_DISABLED=true
@@ -31,7 +31,7 @@ src/train.py \
--finetuning_type full \
--lora_target all \
--parallel_mode dist_flash_attn \
- --seq_parallel_size ${SP_GROUP_SIZE} \
+ --sp_size ${SP_SIZE} \
--deepspeed examples/deepspeed/ds_z3_offload_config.json \
--dataset alpaca_en \
--template llama3 \
48 changes: 24 additions & 24 deletions src/llamafactory/data/collator.py
@@ -87,7 +87,7 @@ class SeqParallelDataCollator(DataCollatorForSeq2Seq):
Data collator for sequence parallel in supervised finetune(sft) stage.
"""
seq_algo: str = "data_parallel",
- seq_parallel_size: int = -1
+ sp_size: int = -1
rank: int = 0
world_size: int = 8
device: Optional[Any] = None
@@ -99,15 +99,15 @@ def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> D
input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"]
labels = batch["labels"]
- seq_rank = self.rank
- seq_worlds_size = self.world_size
- if self.seq_parallel_size != -1:
- dp_rank = self.rank // self.seq_parallel_size
- seq_rank = self.rank % self.seq_parallel_size
- seq_worlds_size = self.seq_parallel_size
+ world_size = self.world_size
+ sp_rank = self.rank
+ if self.sp_size != -1:
+ dp_rank = self.rank // self.sp_size
+ sp_rank = self.rank % self.sp_size
+ world_size = self.sp_size
bs = len(input_ids)
- data_group_size = self.world_size // self.seq_parallel_size
- group_bs = bs // data_group_size
+ dp_size = self.world_size // self.sp_size
+ group_bs = bs // dp_size
input_ids = input_ids[dp_rank * group_bs: (dp_rank + 1) * group_bs]
attention_mask = attention_mask[dp_rank * group_bs: (dp_rank + 1) * group_bs]
labels = labels[dp_rank * group_bs: (dp_rank + 1) * group_bs]
@@ -116,8 +116,8 @@
attention_mask=attention_mask,
position_ids=None,
labels=labels,
- rank=seq_rank,
- world_size=seq_worlds_size,
+ rank=sp_rank,
+ world_size=world_size,
device=self.device)
return batch

@@ -138,27 +138,27 @@ def __call__(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dic
batch = super().__call__(examples)
if self.seq_algo == "data_parallel":
return batch
- seq_rank = self.rank
- seq_worlds_size = self.world_size
- if self.seq_parallel_size != -1:
- dp_rank = self.rank // self.seq_parallel_size
- seq_rank = self.rank % self.seq_parallel_size
- seq_worlds_size = self.seq_parallel_size
+ input_ids = batch["input_ids"]
+ attention_mask = batch["attention_mask"]
+ labels = batch["labels"]
+ world_size = self.world_size
+ sp_rank = self.rank
+ if self.sp_size != -1:
+ dp_rank = self.rank // self.sp_size
+ sp_rank = self.rank % self.sp_size
+ world_size = self.sp_size
bs = len(input_ids)
- data_group_size = self.world_size // self.seq_parallel_size
- group_bs = bs // data_group_size
+ dp_size = self.world_size // self.sp_size
+ group_bs = bs // dp_size
input_ids = input_ids[dp_rank * group_bs: (dp_rank + 1) * group_bs]
attention_mask = attention_mask[dp_rank * group_bs: (dp_rank + 1) * group_bs]
labels = labels[dp_rank * group_bs: (dp_rank + 1) * group_bs]
- input_ids = batch["input_ids"]
- attention_mask = batch["attention_mask"]
- labels = batch["labels"]
batch = prepare_seq_parallel_sft_inputs(self.seq_algo,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=None,
labels=labels,
- rank=seq_rank,
- world_size=seq_worlds_size,
+ rank=sp_rank,
+ world_size=world_size,
device=self.device)
return batch
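A short sketch (illustration only, not code from this commit) of the bookkeeping the collators now perform with the renamed fields: a global rank is decomposed into a data-parallel index and a sequence-parallel index, and each data-parallel group keeps its own contiguous slice of the batch.

# Assumed semantics of sp_size, dp_rank and sp_rank as used above.
def split_rank(rank, world_size, sp_size):
    if sp_size == -1:                      # every rank belongs to one big sequence-parallel group
        return 0, rank, world_size
    return rank // sp_size, rank % sp_size, sp_size   # (dp_rank, sp_rank, sp group size)

def slice_batch(samples, dp_rank, world_size, sp_size):
    dp_size = world_size // sp_size        # number of data-parallel groups
    group_bs = len(samples) // dp_size     # batch share of one group
    return samples[dp_rank * group_bs:(dp_rank + 1) * group_bs]

print(split_rank(5, 8, 4))                   # (1, 1, 4): rank 5 is position 1 of the second group
print(slice_batch(list(range(4)), 1, 8, 4))  # [2, 3]: the second group's half of a 4-sample batch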
4 changes: 2 additions & 2 deletions src/llamafactory/easy_context/__init__.py
@@ -64,7 +64,7 @@ def prepare_seq_parallel_sft_inputs(
raise ValueError(f"Invalid seq_algo: {seq_algo}")

def apply_seq_parallel_monkey_patch(
- seq_algo, model,seq_parallel_size=None
+ seq_algo, model, sp_size=None
):
assert seq_algo in ["zigzag_ring_attn", "dist_flash_attn", "ulysses_attn", "data_parallel"], f"Invalid seq_algo: {seq_algo}"
assert model in ["llama", "mistral"], f"Invalid model: {model}"
@@ -75,7 +75,7 @@ def apply_seq_parallel_monkey_patch(
elif seq_algo == "zigzag_ring_attn" and model == "mistral":
apply_zigzag_ring_attn_monkey_patch_mistral()
elif seq_algo == "dist_flash_attn" and model == "llama":
- apply_dist_flash_attn_monkey_patch_llama(seq_parallel_size=seq_parallel_size)
+ apply_dist_flash_attn_monkey_patch_llama(sp_size=sp_size)
elif seq_algo == "ulysses_attn" and model == "llama":
apply_ulysses_attn_monkey_patch_llama()
else:
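A hypothetical call site (not in the diff) showing the renamed keyword, assuming the import path matches the source layout and that a distributed process group is already initialized:

from llamafactory.easy_context import apply_seq_parallel_monkey_patch

# Patch LLaMA attention for dist_flash_attn with 4 ranks per sequence-parallel group.
apply_seq_parallel_monkey_patch("dist_flash_attn", "llama", sp_size=4)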
@@ -39,7 +39,7 @@
_bwd_send_volume = 0
_bwd_recv_volume = 0

- def initialize_distributed(sequence_parallel_size=None):
+ def initialize_distributed(sp_size=None):
if dist.is_initialized():
if dist.get_rank() == 0:
print(
@@ -55,7 +55,7 @@ def initialize_distributed(sequence_parallel_size=None):
global_world_size = dist.get_world_size()
torch.cuda.set_device(dist.get_rank() % local_world_size)

- _initialize_sequence_parallel(sequence_parallel_size=sequence_parallel_size)
+ _initialize_sequence_parallel(sp_size)
# create_nccl_communicators()

def _initialize_sequence_parallel(sequence_parallel_size=None):
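Note that _initialize_sequence_parallel keeps its old sequence_parallel_size parameter name and is now called positionally with sp_size. For context, a generic sketch (an assumption about the usual pattern, not this repository's code) of how such a helper builds groups of sp_size consecutive ranks:

# Generic sequence-parallel group construction with torch.distributed (illustrative).
import torch.distributed as dist

def build_sequence_parallel_group(sp_size=None):
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    sp_size = world_size if sp_size in (None, -1) else sp_size
    assert world_size % sp_size == 0
    my_group = None
    for start in range(0, world_size, sp_size):
        ranks = list(range(start, start + sp_size))
        group = dist.new_group(ranks)      # every rank must take part in creating every group
        if rank in ranks:
            my_group = group               # keep only the group this rank belongs to
    return my_group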
4 changes: 2 additions & 2 deletions src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py
@@ -602,7 +602,7 @@ def custom_forward(*inputs):
)


- def apply_dist_flash_attn_monkey_patch_llama(seq_parallel_size=None):
- initialize_distributed(sequence_parallel_size=seq_parallel_size)
+ def apply_dist_flash_attn_monkey_patch_llama(sp_size=None):
+ initialize_distributed(sp_size=sp_size)
transformers.models.llama.modeling_llama.LlamaModel.forward = forward
transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = llama_layer_forward
4 changes: 2 additions & 2 deletions src/llamafactory/hparams/finetuning_args.py
@@ -316,10 +316,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default="data_parallel",
metadata={"help": "which sequence parallel mode to use."},
)
- seq_parallel_size: int = field(
+ sp_size: int = field(
default=-1,
metadata={
"help": "used for use seq_parallel and data_parallel simultaneously, -1 for all gpus parallels in sequence_length axis, n for n_gpus makes a sequence_parallel group"
"help": "allow using seq_parallel and data_parallel simultaneously, -1 for all gpus parallels in sequence_length axis, n for n_gpus makes a sequence_parallel group"
}
)

25 changes: 13 additions & 12 deletions src/llamafactory/train/sft/trainer.py
@@ -136,13 +136,14 @@ def save_predictions(self, predict_results: "PredictionOutput") -> None:
writer.write("\n".join(res))

class CustomSeqParallelTrainer(CustomSeq2SeqTrainer):
from transformers.trainer import _is_peft_model, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

def compute_loss(self, model, inputs, return_outputs=False):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Subclass and override for custom behavior.
"""
from transformers.trainer import _is_peft_model, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
if self.label_smoother is not None and "labels" in inputs:
labels = inputs.pop("labels")
else:
@@ -173,7 +174,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
if self.finetuning_args.parallel_mode== "data_parallel":
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
else:
- sp_size = self.finetuning_args.seq_parallel_size
+ sp_size = self.finetuning_args.sp_size
loss_fn = CrossEntropyLoss(reduction='sum')
labels = inputs.pop("labels")
logits = outputs["logits"] if isinstance(outputs, dict) else outputs[1]
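The rest of this branch is collapsed above; for intuition, a hedged sketch (an assumption about how a sum-reduced loss is typically averaged under sequence parallelism, not the commit's code): each rank sums the cross-entropy over its sequence shard, then the loss sum and the count of valid label tokens are all-reduced across the sp_size ranks before dividing.

# Illustrative only: averaging a shard-local summed loss over a sequence-parallel group.
import torch
import torch.distributed as dist

def sp_mean_loss(shard_loss_sum, shard_valid_tokens, sp_group=None):
    # shard_loss_sum: tensor, CrossEntropyLoss(reduction='sum') over this rank's shard
    # shard_valid_tokens: tensor, number of labels != IGNORE_INDEX in the shard
    dist.all_reduce(shard_loss_sum, op=dist.ReduceOp.SUM, group=sp_group)
    dist.all_reduce(shard_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group)
    return shard_loss_sum / shard_valid_tokens.clamp(min=1)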
@@ -229,12 +230,12 @@ def get_train_dataloader(self) -> DataLoader:
dataloader_params["worker_init_fn"] = seed_worker
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel":
- seq_parallel_size = self.finetuning_args.seq_parallel_size
- if seq_parallel_size != -1:
+ sp_size = self.finetuning_args.sp_size
+ if sp_size != -1:
world_size = int(os.environ['WORLD_SIZE'])
- assert seq_parallel_size != 0 and world_size % seq_parallel_size == 0, f"world_size: {world_size} should be devide by seq_parallel_size: {seq_parallel_size}"
- data_parallel_size = world_size // seq_parallel_size
- dataloader_params["batch_size"] = dataloader_params["batch_size"] * data_parallel_size
+ assert sp_size != 0 and world_size % sp_size == 0, f"world_size: {world_size} should be divisible by sp_size: {sp_size}"
+ dp_size = world_size // sp_size
+ dataloader_params["batch_size"] = dataloader_params["batch_size"] * dp_size
return DataLoader(train_dataset, **dataloader_params)
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
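A worked example of the batch-size adjustment above (illustrative numbers, assuming the sp_size semantics of this diff): with WORLD_SIZE=8 and sp_size=4 there are two data-parallel groups, so the DataLoader is created with twice the per-device batch and the collator later hands each group its half.

# per_device_train_batch_size=1, WORLD_SIZE=8, sp_size=4 (hypothetical values)
per_device_bs, world_size, sp_size = 1, 8, 4
dp_size = world_size // sp_size            # 2 data-parallel groups
loader_bs = per_device_bs * dp_size        # batch size passed to the DataLoader -> 2
group_bs = loader_bs // dp_size            # samples each group keeps after slicing -> 1
print(loader_bs, group_bs)                 # 2 1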

@@ -284,11 +285,11 @@ def get_eval_dataloader(self, eval_dataset) -> DataLoader:
self._eval_dataloader = eval_dataloader

if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel":
- seq_parallel_size = self.finetuning_args.seq_parallel_size
- if seq_parallel_size != -1:
+ sp_size = self.finetuning_args.sp_size
+ if sp_size != -1:
world_size = int(os.environ['WORLD_SIZE'])
- assert seq_parallel_size != 0 and world_size % seq_parallel_size == 0, f"world_size: {world_size} should be devide by seq_parallel_size: {seq_parallel_size}"
- data_parallel_size = world_size // seq_parallel_size
- dataloader_params["batch_size"] = dataloader_params["batch_size"] * data_parallel_size
+ assert sp_size != 0 and world_size % sp_size == 0, f"world_size: {world_size} should be divisible by sp_size: {sp_size}"
+ dp_size = world_size // sp_size
+ dataloader_params["batch_size"] = dataloader_params["batch_size"] * dp_size
return eval_dataloader
return self.accelerator.prepare(eval_dataloader)
6 changes: 3 additions & 3 deletions src/llamafactory/train/sft/workflow.py
@@ -37,7 +37,7 @@ def run_sft(
tokenizer = tokenizer_module["tokenizer"]
dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
- apply_seq_parallel_monkey_patch(finetuning_args.parallel_mode, "llama", seq_parallel_size=finetuning_args.seq_parallel_size)
+ apply_seq_parallel_monkey_patch(finetuning_args.parallel_mode, "llama", sp_size=finetuning_args.sp_size)

if training_args.predict_with_generate:
tokenizer.padding_side = "left" # use left-padding in generation
@@ -53,7 +53,7 @@ def run_sft(
pad_to_multiple_of=data_args.cutoff_len if tokenizer.padding_side == "right" else None,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
seq_algo=finetuning_args.parallel_mode,
- seq_parallel_size=finetuning_args.seq_parallel_size,
+ sp_size=finetuning_args.sp_size,
rank=torch.distributed.get_rank(),
world_size=world_size,
device=torch.device("cuda", local_rank)
@@ -87,7 +87,7 @@ def run_sft(
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
- # trainer.save_state()
+ trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

