From e6f1947dad9cbafd38b15862a4004255286a46ce Mon Sep 17 00:00:00 2001
From: qianhao0713 <475483052@qq.com>
Date: Wed, 11 Dec 2024 11:17:19 +0800
Subject: [PATCH] remove per_instance_loss

---
 src/llamafactory/easy_context/__init__.py   |  3 -
 .../zigzag_ring_attn/monkey_patch.py        | 20 -----
 src/llamafactory/hparams/finetuning_args.py |  4 -
 src/llamafactory/train/sft/trainer.py       | 81 +------------------
 4 files changed, 3 insertions(+), 105 deletions(-)

diff --git a/src/llamafactory/easy_context/__init__.py b/src/llamafactory/easy_context/__init__.py
index bffbbba570..a1d1b02a79 100644
--- a/src/llamafactory/easy_context/__init__.py
+++ b/src/llamafactory/easy_context/__init__.py
@@ -2,7 +2,6 @@
 from .dist_flash_attn.monkey_patch import apply_dist_flash_attn_monkey_patch_llama
 from .zigzag_ring_attn.prepare_inputs import prepare_zigzag_ring_attn_inputs, prepare_zigzag_ring_attn_sft_inputs
 from .zigzag_ring_attn.monkey_patch import apply_zigzag_ring_attn_monkey_patch_llama
-from .zigzag_ring_attn.monkey_patch import apply_zigzag_ring_attn_monkey_patch_mistral
 from .unsloth_offloaded_gradient_checkpoint.monkey_patch import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch
 from .ulysses_attn.prepare_inputs import prepare_ulysses_attn_inputs, prepare_ulysses_attn_sft_inputs
 from .ulysses_attn.monkey_patch import apply_ulysses_attn_monkey_patch_llama
@@ -72,8 +71,6 @@ def apply_seq_parallel_monkey_patch(
         return
     elif seq_algo == "zigzag_ring_attn" and model == "llama":
         apply_zigzag_ring_attn_monkey_patch_llama(sp_size=sp_size)
-    elif seq_algo == "zigzag_ring_attn" and model == "mistral":
-        apply_zigzag_ring_attn_monkey_patch_mistral(sp_size=sp_size)
     elif seq_algo == "dist_flash_attn" and model == "llama":
         apply_dist_flash_attn_monkey_patch_llama(sp_size=sp_size, enable_offload=enable_offload, offload_percent=offload_percent)
     elif seq_algo == "ulysses_attn" and model == "llama":
diff --git a/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py b/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py
index d27bcf0de1..250c1ea6f1 100644
--- a/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py
+++ b/src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py
@@ -227,23 +227,3 @@ def apply_zigzag_ring_attn_monkey_patch_llama(sp_size=None):
     transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
         new_decoder_forward
     )
-
-
-def apply_zigzag_ring_attn_monkey_patch_mistral(sp_size=None):
-    sp_group = get_sp_process_group(sp_size)
-    if hasattr(transformers.models.llama.modeling_llama.LlamaFlashAttention2, '_flash_attention_forward'):
-        transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = (
-            partialmethod(new_flash_attn_forward, group=sp_group)
-        )
-    else:
-        transformers.models.llama.modeling_llama._flash_attention_forward = (
-            partial(new_flash_attn_forward_v2, group=sp_group)
-        )
-    if "position_embeddings" in inspect.getfullargspec(transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward).args:
-        transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
-            new_decoder_forward_v2
-        )
-    else:
-        transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
-            new_decoder_forward
-        )
diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index c525502694..5f9889c82b 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -330,10 +330,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=0.0,
         metadata={"help": "0 for remain all activation memory in gpu, 1 for offload all activation memory in cpu"}
     )
-    per_instance_loss: bool = field(
-        default=False,
-        metadata={"help": "if update transformers to 4.46.3, the loss will be calculated in a global batch by default, enable per_instance_loss will calculate loss in each instance"}
-    )
 
     def __post_init__(self):
         def split_arg(arg):
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index d00acdf76d..ca30bcc7a2 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -11,7 +11,7 @@
 from ...extras.logging import get_logger
 from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
 from torch.utils.data import DataLoader
-from transformers.utils import is_datasets_available, is_sagemaker_mp_enabled
+from transformers.utils import is_datasets_available
 from transformers.trainer_utils import seed_worker
 import datasets
 from torch.nn import CrossEntropyLoss
@@ -137,77 +137,6 @@ def save_predictions(self, predict_results: "PredictionOutput") -> None:
 
 
 class CustomSeqParallelTrainer(CustomSeq2SeqTrainer):
-    def training_step(
-        self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
-    ) -> torch.Tensor:
-        """
-        Perform a training step on a batch of inputs.
-
-        Subclass and override to inject custom behavior.
-
-        Args:
-            model (`nn.Module`):
-                The model to train.
-            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
-                The inputs and targets of the model.
-
-                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
-                argument `labels`. Check your model's documentation for all accepted arguments.
-
-        Return:
-            `torch.Tensor`: The tensor with training loss on this batch.
-        """
-        model.train()
-        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
-            self.optimizer.train()
-
-        use_per_instance_loss = self.finetuning_args.per_instance_loss
-
-        inputs = self._prepare_inputs(inputs)
-        if is_sagemaker_mp_enabled():
-            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
-            return loss_mb.reduce_mean().detach().to(self.args.device)
-
-        with self.compute_loss_context_manager():
-            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
-
-        del inputs
-        if (hasattr(self.args, "torch_empty_cache_steps")
-            and self.args.torch_empty_cache_steps is not None
-            and self.state.global_step % self.args.torch_empty_cache_steps == 0
-        ):
-            if is_torch_xpu_available():
-                torch.xpu.empty_cache()
-            elif is_torch_mlu_available():
-                torch.mlu.empty_cache()
-            elif is_torch_musa_available():
-                torch.musa.empty_cache()
-            elif is_torch_npu_available():
-                torch.npu.empty_cache()
-            elif is_torch_mps_available(min_version="2.0"):
-                torch.mps.empty_cache()
-            else:
-                torch.cuda.empty_cache()
-
-        kwargs = {}
-
-        # For LOMO optimizers you need to explicitly use the learnign rate
-        # if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
-        #     kwargs["learning_rate"] = self._get_learning_rate()
-
-        if self.args.n_gpu > 1:
-            loss = loss.mean()  # mean() to average on multi-gpu parallel training
-
-        if self.use_apex:
-            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
-                scaled_loss.backward()
-        else:
-            self.accelerator.backward(loss, **kwargs)
-            # Finally we need to normalize the loss for reporting
-            if num_items_in_batch is None or self.finetuning_args.per_instance_loss:
-                return loss.detach() / self.args.gradient_accumulation_steps
-            return loss.detach()
-
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         """
         How the loss is computed by Trainer. By default, all models return the loss in the first element.
@@ -258,14 +187,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
                 self.args.average_tokens_across_devices = None
             if not hasattr(self, 'model_accepts_loss_kwargs'):
                 self.model_accepts_loss_kwargs= None
-            if self.finetuning_args.per_instance_loss and num_items_in_batch is not None:
-                labels = inputs.pop("labels")
-                valid_label_cnt = (labels!=-100).sum()
-                loss *= num_items_in_batch / valid_label_cnt
-            elif self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
+            if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
                 loss *= self.accelerator.num_processes
             else:
-                if num_items_in_batch is None or self.finetuning_args.per_instance_loss:
+                if num_items_in_batch is None:
                     sp_size = self.finetuning_args.sp_size
                     loss_fn = CrossEntropyLoss(reduction='sum')
                     labels = inputs.pop("labels")