From d5e513528ecee20a5eaf74bafa8dce451f401f80 Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Mon, 15 Jul 2024 15:39:24 +0800 Subject: [PATCH 1/3] rename variables --- Llama3-70B.sh | 4 +- Llama3-8B.sh | 4 +- src/llamafactory/data/collator.py | 48 +++++++++---------- src/llamafactory/easy_context/__init__.py | 4 +- .../dist_flash_attn/async_communication.py | 4 +- .../dist_flash_attn/monkey_patch.py | 4 +- src/llamafactory/hparams/finetuning_args.py | 4 +- src/llamafactory/train/sft/trainer.py | 23 ++++----- src/llamafactory/train/sft/workflow.py | 6 +-- 9 files changed, 51 insertions(+), 50 deletions(-) diff --git a/Llama3-70B.sh b/Llama3-70B.sh index 4cb5bb407c..72fe710ea2 100644 --- a/Llama3-70B.sh +++ b/Llama3-70B.sh @@ -5,7 +5,7 @@ NGPUS=${NGPUS:-8} WORLD_SIZE=${WORLD_SIZE:-1} NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]] SEQ_LEN=${SEQ_LEN:-32768} -SP_GROUP_SIZE=${SP_GROUP_SIZE:-1} +SP_SIZE=${SP_SIZE:-1} BATCH_SIZE=${BATCH_SIZE:-1} export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' export WANDB_DISABLED=true @@ -30,7 +30,7 @@ src/train.py \ --do_train \ --finetuning_type full \ --parallel_mode dist_flash_attn \ ---seq_parallel_size ${SP_GROUP_SIZE} \ +--sp_size ${SP_SIZE} \ --deepspeed examples/deepspeed/ds_z3_offload_config.json \ --dataset long_sft_128k \ --template llama3 \ diff --git a/Llama3-8B.sh b/Llama3-8B.sh index 46c081b412..0e67d291e9 100644 --- a/Llama3-8B.sh +++ b/Llama3-8B.sh @@ -5,7 +5,7 @@ NGPUS=${NGPUS:-8} WORLD_SIZE=${WORLD_SIZE:-1} NUM_PROCESSES=$[${NGPUS}*$[WORLD_SIZE]] SEQ_LEN=${SEQ_LEN:-1024} -SP_GROUP_SIZE=${SP_GROUP_SIZE:-1} +SP_SIZE=${SP_SIZE:-1} BATCH_SIZE=${BATCH_SIZE:-1} export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' export WANDB_DISABLED=true @@ -31,7 +31,7 @@ src/train.py \ --finetuning_type full \ --lora_target all \ --parallel_mode dist_flash_attn \ ---seq_parallel_size ${SP_GROUP_SIZE} \ +--sp_size ${SP_SIZE} \ --deepspeed examples/deepspeed/ds_z3_offload_config.json \ --dataset alpaca_en \ --template llama3 \ diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 8a0ae59c23..c38dfee1c9 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -87,7 +87,7 @@ class SeqParallelDataCollator(DataCollatorForSeq2Seq): Data collator for sequence parallel in supervised finetune(sft) stage. 
""" seq_algo: str = "data_parallel", - seq_parallel_size: int = -1 + sp_size: int = -1 rank: int = 0 world_size: int = 8 device: Optional[Any] = None @@ -99,15 +99,15 @@ def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> D input_ids = batch["input_ids"] attention_mask = batch["attention_mask"] labels = batch["labels"] - seq_rank = self.rank - seq_worlds_size = self.world_size - if self.seq_parallel_size != -1: - dp_rank = self.rank // self.seq_parallel_size - seq_rank = self.rank % self.seq_parallel_size - seq_worlds_size = self.seq_parallel_size + world_size = self.world_size + sp_rank = self.rank + if self.sp_size != -1: + dp_rank = self.rank // self.sp_size + sp_rank = self.rank % self.sp_size + world_size = self.sp_size bs = len(input_ids) - data_group_size = self.world_size // self.seq_parallel_size - group_bs = bs // data_group_size + dp_size = self.world_size // self.sp_size + group_bs = bs // dp_size input_ids = input_ids[dp_rank * group_bs: (dp_rank + 1) * group_bs] attention_mask = attention_mask[dp_rank * group_bs: (dp_rank + 1) * group_bs] labels = labels[dp_rank * group_bs: (dp_rank + 1) * group_bs] @@ -116,8 +116,8 @@ def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> D attention_mask=attention_mask, position_ids=None, labels=labels, - rank=seq_rank, - world_size=seq_worlds_size, + rank=sp_rank, + world_size=world_size, device=self.device) return batch @@ -138,27 +138,27 @@ def __call__(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dic batch = super().__call__(examples) if self.seq_algo == "data_parallel": return batch - seq_rank = self.rank - seq_worlds_size = self.world_size - if self.seq_parallel_size != -1: - dp_rank = self.rank // self.seq_parallel_size - seq_rank = self.rank % self.seq_parallel_size - seq_worlds_size = self.seq_parallel_size + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + labels = batch["labels"] + world_size = self.world_size + sp_rank = self.rank + if self.sp_size != -1: + dp_rank = self.rank // self.sp_size + sp_rank = self.rank % self.sp_size + world_size = self.sp_size bs = len(input_ids) - data_group_size = self.world_size // self.seq_parallel_size - group_bs = bs // data_group_size + dp_size = self.world_size // self.sp_size + group_bs = bs // dp_size input_ids = input_ids[dp_rank * group_bs: (dp_rank + 1) * group_bs] attention_mask = attention_mask[dp_rank * group_bs: (dp_rank + 1) * group_bs] labels = labels[dp_rank * group_bs: (dp_rank + 1) * group_bs] - input_ids = batch["input_ids"] - attention_mask = batch["attention_mask"] - labels = batch["labels"] batch = prepare_seq_parallel_sft_inputs(self.seq_algo, input_ids=input_ids, attention_mask=attention_mask, position_ids=None, labels=labels, - rank=seq_rank, - world_size=seq_worlds_size, + rank=sp_rank, + world_size=world_size, device=self.device) return batch diff --git a/src/llamafactory/easy_context/__init__.py b/src/llamafactory/easy_context/__init__.py index bef8ecab74..8f72c15786 100644 --- a/src/llamafactory/easy_context/__init__.py +++ b/src/llamafactory/easy_context/__init__.py @@ -64,7 +64,7 @@ def prepare_seq_parallel_sft_inputs( raise ValueError(f"Invalid seq_algo: {seq_algo}") def apply_seq_parallel_monkey_patch( - seq_algo, model,seq_parallel_size=None + seq_algo, model, sp_size=None ): assert seq_algo in ["zigzag_ring_attn", "dist_flash_attn", "ulysses_attn", "data_parallel"], f"Invalid seq_algo: {seq_algo}" assert model in ["llama", "mistral"], f"Invalid model: {model}" @@ 
-75,7 +75,7 @@ def apply_seq_parallel_monkey_patch( elif seq_algo == "zigzag_ring_attn" and model == "mistral": apply_zigzag_ring_attn_monkey_patch_mistral() elif seq_algo == "dist_flash_attn" and model == "llama": - apply_dist_flash_attn_monkey_patch_llama(seq_parallel_size=seq_parallel_size) + apply_dist_flash_attn_monkey_patch_llama(sp_size=sp_size) elif seq_algo == "ulysses_attn" and model == "llama": apply_ulysses_attn_monkey_patch_llama() else: diff --git a/src/llamafactory/easy_context/dist_flash_attn/async_communication.py b/src/llamafactory/easy_context/dist_flash_attn/async_communication.py index e67e88a4f7..68b35b5ae6 100644 --- a/src/llamafactory/easy_context/dist_flash_attn/async_communication.py +++ b/src/llamafactory/easy_context/dist_flash_attn/async_communication.py @@ -39,7 +39,7 @@ _bwd_send_volume = 0 _bwd_recv_volume = 0 -def initialize_distributed(sequence_parallel_size=None): +def initialize_distributed(sp_size=None): if dist.is_initialized(): if dist.get_rank() == 0: print( @@ -55,7 +55,7 @@ def initialize_distributed(sequence_parallel_size=None): global_world_size = dist.get_world_size() torch.cuda.set_device(dist.get_rank() % local_world_size) - _initialize_sequence_parallel(sequence_parallel_size=sequence_parallel_size) + _initialize_sequence_parallel(sp_size) # create_nccl_communicators() def _initialize_sequence_parallel(sequence_parallel_size=None): diff --git a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py index 63ba6f4973..317b8b2748 100644 --- a/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py +++ b/src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py @@ -602,7 +602,7 @@ def custom_forward(*inputs): ) -def apply_dist_flash_attn_monkey_patch_llama(seq_parallel_size=None): - initialize_distributed(sequence_parallel_size=seq_parallel_size) +def apply_dist_flash_attn_monkey_patch_llama(sp_size=None): + initialize_distributed(sp_size=sp_size) transformers.models.llama.modeling_llama.LlamaModel.forward = forward transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = llama_layer_forward diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index b636afdf37..eddbf3217f 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -316,10 +316,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA default="data_parallel", metadata={"help": "which sequence parallel mode to use."}, ) - seq_parallel_size: int = field( + sp_size: int = field( default=-1, metadata={ - "help": "used for use seq_parallel and data_parallel simultaneously, -1 for all gpus parallels in sequence_length axis, n for n_gpus makes a sequence_parallel group" + "help": "allow using seq_parallel and data_parallel simultaneously, -1 for all gpus parallels in sequence_length axis, n for n_gpus makes a sequence_parallel group" } ) diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 2abc5e8e40..6b1627f93c 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -136,13 +136,14 @@ def save_predictions(self, predict_results: "PredictionOutput") -> None: writer.write("\n".join(res)) class CustomSeqParallelTrainer(CustomSeq2SeqTrainer): - from transformers.trainer import _is_peft_model, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + def compute_loss(self, model, inputs, return_outputs=False): 
""" How the loss is computed by Trainer. By default, all models return the loss in the first element. Subclass and override for custom behavior. """ + from transformers.trainer import _is_peft_model, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES if self.label_smoother is not None and "labels" in inputs: labels = inputs.pop("labels") else: @@ -229,12 +230,12 @@ def get_train_dataloader(self) -> DataLoader: dataloader_params["worker_init_fn"] = seed_worker dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel": - seq_parallel_size = self.finetuning_args.seq_parallel_size - if seq_parallel_size != -1: + sp_size = self.finetuning_args.sp_size + if sp_size != -1: world_size = int(os.environ['WORLD_SIZE']) - assert seq_parallel_size != 0 and world_size % seq_parallel_size == 0, f"world_size: {world_size} should be devide by seq_parallel_size: {seq_parallel_size}" - data_parallel_size = world_size // seq_parallel_size - dataloader_params["batch_size"] = dataloader_params["batch_size"] * data_parallel_size + assert sp_size != 0 and world_size % sp_size == 0, f"world_size: {world_size} should be devide by seq_parallel_size: {sp_size}" + dp_size = world_size // sp_size + dataloader_params["batch_size"] = dataloader_params["batch_size"] * dp_size return DataLoader(train_dataset, **dataloader_params) return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) @@ -284,11 +285,11 @@ def get_eval_dataloader(self, eval_dataset) -> DataLoader: self._eval_dataloader = eval_dataloader if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel": - seq_parallel_size = self.finetuning_args.seq_parallel_size - if seq_parallel_size != -1: + sp_size = self.finetuning_args.sp_size + if sp_size != -1: world_size = int(os.environ['WORLD_SIZE']) - assert seq_parallel_size != 0 and world_size % seq_parallel_size == 0, f"world_size: {world_size} should be devide by seq_parallel_size: {seq_parallel_size}" - data_parallel_size = world_size // seq_parallel_size - dataloader_params["batch_size"] = dataloader_params["batch_size"] * data_parallel_size + assert sp_size != 0 and world_size % sp_size == 0, f"world_size: {world_size} should be devide by seq_parallel_size: {sp_size}" + dp_size = world_size // sp_size + dataloader_params["batch_size"] = dataloader_params["batch_size"] * dp_size return eval_dataloader return self.accelerator.prepare(eval_dataloader) diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index c3a082ad38..e570d3ef71 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -37,7 +37,7 @@ def run_sft( tokenizer = tokenizer_module["tokenizer"] dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) - apply_seq_parallel_monkey_patch(finetuning_args.parallel_mode, "llama", seq_parallel_size=finetuning_args.seq_parallel_size) + apply_seq_parallel_monkey_patch(finetuning_args.parallel_mode, "llama", sp_size=finetuning_args.sp_size) if training_args.predict_with_generate: tokenizer.padding_side = "left" # use left-padding in generation @@ -53,7 +53,7 @@ def run_sft( pad_to_multiple_of=data_args.cutoff_len if tokenizer.padding_side == "right" else None, label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, 
seq_algo=finetuning_args.parallel_mode, - seq_parallel_size=finetuning_args.seq_parallel_size, + sp_size=finetuning_args.sp_size, rank=torch.distributed.get_rank(), world_size=world_size, device=torch.device("cuda", local_rank) @@ -87,7 +87,7 @@ def run_sft( trainer.save_model() trainer.log_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics) - # trainer.save_state() + trainer.save_state() if trainer.is_world_process_zero() and finetuning_args.plot_loss: plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) From d31c8f764c7ac1f4e2c26f371b0ea2d4adfd872c Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Mon, 15 Jul 2024 15:49:58 +0800 Subject: [PATCH 2/3] fix 70b launch shell --- Llama3-70B.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Llama3-70B.sh b/Llama3-70B.sh index 72fe710ea2..b139933734 100644 --- a/Llama3-70B.sh +++ b/Llama3-70B.sh @@ -9,6 +9,7 @@ SP_SIZE=${SP_SIZE:-1} BATCH_SIZE=${BATCH_SIZE:-1} export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:1024' export WANDB_DISABLED=true +export NCCL_DEBUG=WARN echo ${RANK}/$[WORLD_SIZE] if [ ${MASTER_ADDR} == 'localhost' ]; then export MASTER_ADDR=`hostname -i` @@ -35,7 +36,7 @@ src/train.py \ --dataset long_sft_128k \ --template llama3 \ --cutoff_len ${SEQ_LEN} \ ---max_samples 1000 \ +--max_steps 10 \ --overwrite_cache \ --preprocessing_num_workers 16 \ --output_dir ./output/70B_32K_bs_1M_rope_1M_step_1000_lr_2e-5 \ @@ -46,7 +47,7 @@ src/train.py \ --per_device_train_batch_size ${BATCH_SIZE} \ --gradient_accumulation_steps 4 \ --learning_rate 2e-5 \ ---num_train_epochs 3.0 \ +--num_train_epochs 1.0 \ --lr_scheduler_type cosine \ --warmup_ratio 0.1 \ --bf16 \ From 4a4ea30960a10865c1afe0849b0d3a18a3aec7ff Mon Sep 17 00:00:00 2001 From: qianhao0713 <475483052@qq.com> Date: Mon, 15 Jul 2024 16:08:37 +0800 Subject: [PATCH 3/3] fix bug --- src/llamafactory/train/sft/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 6b1627f93c..37de8b8ea2 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -174,7 +174,7 @@ def compute_loss(self, model, inputs, return_outputs=False): if self.finetuning_args.parallel_mode== "data_parallel": loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] else: - sp_size = self.finetuning_args.seq_parallel_size + sp_size = self.finetuning_args.sp_size loss_fn = CrossEntropyLoss(reduction='sum') labels = inputs.pop("labels") logits = outputs["logits"] if isinstance(outputs, dict) else outputs[1]
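
A minimal standalone sketch of the rank/batch bookkeeping that patch 1 renames (seq_parallel_size -> sp_size), so the sharding scheme is easy to verify: with WORLD_SIZE ranks and SP_SIZE ranks per sequence-parallel group, rank r maps to dp_rank = r // sp_size and sp_rank = r % sp_size, and the collator slices the global batch into world_size // sp_size chunks, one per group. shard_batch below is an illustrative helper written for this note, not a function in the repository; sp_size == -1 follows the collator's convention that all GPUs form a single sequence-parallel group.

from typing import Any, Dict, List


def shard_batch(batch: List[Any], rank: int, world_size: int, sp_size: int) -> Dict[str, Any]:
    """Return the slice of the global batch that `rank` should process.

    sp_size == -1 keeps the whole batch on every rank (one sequence-parallel
    group spanning all GPUs); only the sequence dimension is split later.
    """
    if sp_size == -1:
        return {"dp_rank": 0, "sp_rank": rank, "sp_world_size": world_size, "samples": list(batch)}
    # Mirrors the assertion added in get_train_dataloader / get_eval_dataloader.
    assert sp_size != 0 and world_size % sp_size == 0, (
        f"world_size {world_size} must be divisible by sp_size {sp_size}"
    )
    dp_size = world_size // sp_size      # number of data-parallel groups
    dp_rank = rank // sp_size            # which group this rank belongs to
    sp_rank = rank % sp_size             # position of this rank inside its group
    group_bs = len(batch) // dp_size     # samples handled by one group
    samples = batch[dp_rank * group_bs:(dp_rank + 1) * group_bs]
    return {"dp_rank": dp_rank, "sp_rank": sp_rank, "sp_world_size": sp_size, "samples": samples}


if __name__ == "__main__":
    # 8 GPUs with SP_SIZE=4 -> two data-parallel groups of four sequence-parallel
    # ranks each; ranks 0-3 share sample_0, ranks 4-7 share sample_1.
    global_batch = ["sample_0", "sample_1"]
    for rank in range(8):
        print(rank, shard_batch(global_batch, rank, world_size=8, sp_size=4))

This is also why get_train_dataloader and get_eval_dataloader multiply the per-device batch size by dp_size when a non-data_parallel seq_algo is active: each sequence-parallel group consumes one per-device batch, so the dataloader has to yield dp_size of them per step for the collator's slicing to line up.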