From 34f70cec659d6a508d2d9172442247116be860bd Mon Sep 17 00:00:00 2001
From: qianhao0713 <475483052@qq.com>
Date: Mon, 15 Jul 2024 19:08:29 +0800
Subject: [PATCH] add dp&sp hybrid for cpt

---
 src/llamafactory/data/collator.py     |  2 +-
 src/llamafactory/train/pt/trainer.py  | 14 ++++++++++++++
 src/llamafactory/train/pt/workflow.py |  3 ++-
 3 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index c38dfee1c9..29023f25f9 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -129,7 +129,7 @@ class SeqParallelDataCollatorForLanguageModeling(DataCollatorForLanguageModeling
     Reuse the sequence parallel distributing function for sft stage.
     """
     seq_algo: str = "data_parallel"
-    seq_parallel_size: int = -1
+    sp_size: int = -1
     rank: int = 0
     world_size: int = 8
     device: Optional[Any] = None
diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py
index 18f73d8512..6a34b27238 100644
--- a/src/llamafactory/train/pt/trainer.py
+++ b/src/llamafactory/train/pt/trainer.py
@@ -11,6 +11,7 @@ from transformers.trainer_utils import seed_worker
 import datasets
 from torch.nn import CrossEntropyLoss
+import os
 
 if TYPE_CHECKING:
     import torch
@@ -62,6 +63,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
 
         Subclass and override for custom behavior.
         """
+        from transformers.trainer import _is_peft_model, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
         if self.label_smoother is not None and "labels" in inputs:
             labels = inputs.pop("labels")
         else:
@@ -148,6 +150,12 @@ def get_train_dataloader(self) -> DataLoader:
             dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
 
         if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel":
+            sp_size = self.finetuning_args.sp_size
+            if sp_size != -1:
+                world_size = int(os.environ['WORLD_SIZE'])
+                assert sp_size != 0 and world_size % sp_size == 0, f"world_size: {world_size} should be divisible by sp_size: {sp_size}"
+                dp_size = world_size // sp_size
+                dataloader_params["batch_size"] = dataloader_params["batch_size"] * dp_size
             return DataLoader(train_dataset, **dataloader_params)
 
         return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
@@ -197,5 +205,11 @@ def get_eval_dataloader(self, eval_dataset) -> DataLoader:
             self._eval_dataloader = eval_dataloader
 
         if hasattr(data_collator, "seq_algo") and data_collator.seq_algo != "data_parallel":
+            sp_size = self.finetuning_args.sp_size
+            if sp_size != -1:
+                world_size = int(os.environ['WORLD_SIZE'])
+                assert sp_size != 0 and world_size % sp_size == 0, f"world_size: {world_size} should be divisible by sp_size: {sp_size}"
+                dp_size = world_size // sp_size
+                dataloader_params["batch_size"] = dataloader_params["batch_size"] * dp_size
             return eval_dataloader
         return self.accelerator.prepare(eval_dataloader)
\ No newline at end of file
diff --git a/src/llamafactory/train/pt/workflow.py b/src/llamafactory/train/pt/workflow.py
index 1d9a36f673..fc33acffae 100644
--- a/src/llamafactory/train/pt/workflow.py
+++ b/src/llamafactory/train/pt/workflow.py
@@ -32,7 +32,7 @@ def run_pt(
     tokenizer = tokenizer_module["tokenizer"]
     dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
-    apply_seq_parallel_monkey_patch(finetuning_args.parallel_mode, "llama")
+    apply_seq_parallel_monkey_patch(finetuning_args.parallel_mode, "llama", sp_size=finetuning_args.sp_size)
     # data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
     local_rank = int(os.getenv("LOCAL_RANK"))
 
@@ -42,6 +42,7 @@ def run_pt(
         tokenizer=tokenizer,
         mlm=False,
         seq_algo=finetuning_args.parallel_mode,
+        sp_size=finetuning_args.sp_size,
         rank=torch.distributed.get_rank(),
         world_size=torch.distributed.get_world_size(),
         device=torch.device("cuda", local_rank)
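
Note (not part of the patch): below is a minimal standalone sketch of the dp/sp hybrid arithmetic this diff applies, assuming WORLD_SIZE is set by the launcher (e.g. torchrun). The function name and the contiguous-ranks-per-sp-group convention are illustrative assumptions, not taken from this diff.

    import os

    def dp_sp_layout(per_device_batch_size: int, sp_size: int):
        """Return (dp_size, scaled_batch_size, sp_group_of_each_rank)."""
        world_size = int(os.environ["WORLD_SIZE"])  # assumed to be set by the launcher
        assert sp_size > 0 and world_size % sp_size == 0, (
            f"world_size: {world_size} should be divisible by sp_size: {sp_size}"
        )
        dp_size = world_size // sp_size
        # Every sequence-parallel group of sp_size ranks consumes one batch shard,
        # so the dataloader batch size is scaled by the number of dp replicas.
        scaled_batch_size = per_device_batch_size * dp_size
        # Hypothetical grouping for illustration: contiguous ranks share an sp group.
        sp_groups = [rank // sp_size for rank in range(world_size)]
        return dp_size, scaled_batch_size, sp_groups

    if __name__ == "__main__":
        os.environ.setdefault("WORLD_SIZE", "8")
        # With 8 ranks and sp_size=4: dp_size=2, batch size 1 -> 2, groups [0,0,0,0,1,1,1,1]
        print(dp_sp_layout(per_device_batch_size=1, sp_size=4))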