Commit

remove per_instance_loss
qianhao0713 committed Dec 11, 2024
1 parent 1eead84 commit e6f1947
Showing 4 changed files with 3 additions and 105 deletions.
3 changes: 0 additions & 3 deletions src/llamafactory/easy_context/__init__.py
@@ -2,7 +2,6 @@
 from .dist_flash_attn.monkey_patch import apply_dist_flash_attn_monkey_patch_llama
 from .zigzag_ring_attn.prepare_inputs import prepare_zigzag_ring_attn_inputs, prepare_zigzag_ring_attn_sft_inputs
 from .zigzag_ring_attn.monkey_patch import apply_zigzag_ring_attn_monkey_patch_llama
-from .zigzag_ring_attn.monkey_patch import apply_zigzag_ring_attn_monkey_patch_mistral
 from .unsloth_offloaded_gradient_checkpoint.monkey_patch import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch
 from .ulysses_attn.prepare_inputs import prepare_ulysses_attn_inputs, prepare_ulysses_attn_sft_inputs
 from .ulysses_attn.monkey_patch import apply_ulysses_attn_monkey_patch_llama
@@ -72,8 +71,6 @@ def apply_seq_parallel_monkey_patch(
         return
     elif seq_algo == "zigzag_ring_attn" and model == "llama":
         apply_zigzag_ring_attn_monkey_patch_llama(sp_size=sp_size)
-    elif seq_algo == "zigzag_ring_attn" and model == "mistral":
-        apply_zigzag_ring_attn_monkey_patch_mistral(sp_size=sp_size)
     elif seq_algo == "dist_flash_attn" and model == "llama":
         apply_dist_flash_attn_monkey_patch_llama(sp_size=sp_size, enable_offload=enable_offload, offload_percent=offload_percent)
     elif seq_algo == "ulysses_attn" and model == "llama":
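
For reference, a minimal sketch of how this dispatcher would be invoked. The argument names (seq_algo, model, sp_size) come from the hunk above, but the concrete call site and the full signature are not part of this diff, so treat the keywords as assumptions. After this commit, "zigzag_ring_attn" is only wired up for model="llama"; the "mistral" branch is gone.

from llamafactory.easy_context import apply_seq_parallel_monkey_patch

# Patch LLaMA's flash-attention and decoder forwards for zigzag ring attention
# over a hypothetical sequence-parallel group of 8 ranks.
apply_seq_parallel_monkey_patch(seq_algo="zigzag_ring_attn", model="llama", sp_size=8)
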
20 changes: 0 additions & 20 deletions src/llamafactory/easy_context/zigzag_ring_attn/monkey_patch.py
@@ -227,23 +227,3 @@ def apply_zigzag_ring_attn_monkey_patch_llama(sp_size=None):
         transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
             new_decoder_forward
         )
-
-
-def apply_zigzag_ring_attn_monkey_patch_mistral(sp_size=None):
-    sp_group = get_sp_process_group(sp_size)
-    if hasattr(transformers.models.llama.modeling_llama.LlamaFlashAttention2, '_flash_attention_forward'):
-        transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = (
-            partialmethod(new_flash_attn_forward, group=sp_group)
-        )
-    else:
-        transformers.models.llama.modeling_llama._flash_attention_forward = (
-            partial(new_flash_attn_forward_v2, group=sp_group)
-        )
-    if "position_embeddings" in inspect.getfullargspec(transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward).args:
-        transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
-            new_decoder_forward_v2
-        )
-    else:
-        transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
-            new_decoder_forward
-        )
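
The deleted helper used the same patching pattern as the retained apply_zigzag_ring_attn_monkey_patch_llama: probe which attribute this transformers version exposes, then rebind the forward with the sequence-parallel group baked in via functools. A self-contained sketch of that pattern, with illustrative stand-in names (DummyAttention, patched_flash_attn_forward) that are not part of the repository:

from functools import partialmethod

# Illustrative stand-in; the real patch targets LlamaFlashAttention2 and
# LlamaDecoderLayer in transformers.models.llama.modeling_llama.
class DummyAttention:
    def _flash_attention_forward(self, hidden_states):
        return hidden_states

def patched_flash_attn_forward(self, hidden_states, group=None):
    # A real replacement would run zigzag ring attention over `group`.
    return hidden_states

sp_group = "sp-group-0"  # placeholder for a torch.distributed process group

# Same dispatch shape as above: check which attribute exists, then rebind it
# with the group baked in. partialmethod (rather than functools.partial) keeps
# the implicit `self` binding when the target is a method on a class.
if hasattr(DummyAttention, "_flash_attention_forward"):
    DummyAttention._flash_attention_forward = partialmethod(
        patched_flash_attn_forward, group=sp_group
    )

print(DummyAttention()._flash_attention_forward("hidden_states"))  # -> "hidden_states"
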
4 changes: 0 additions & 4 deletions src/llamafactory/hparams/finetuning_args.py
@@ -330,10 +330,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=0.0,
         metadata={"help": "0 for remain all activation memory in gpu, 1 for offload all activation memory in cpu"}
     )
-    per_instance_loss: bool = field(
-        default=False,
-        metadata={"help": "if update transformers to 4.46.3, the loss will be calculated in a global batch by default, enable per_instance_loss will calculate loss in each instance"}
-    )
 
     def __post_init__(self):
         def split_arg(arg):
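
The removed help text refers to a normalization change in transformers >= 4.46: the summed token loss is divided by the valid-token count of the whole gradient-accumulated batch (num_items_in_batch), while the per_instance_loss flag rescaled it back to the current micro-batch, mirroring the loss *= num_items_in_batch / valid_label_cnt line removed from trainer.py below. A toy numeric sketch of the two normalizations; all values are made up for illustration:

import torch

summed_token_loss = torch.tensor(24.0)   # sum of per-token CE losses in this micro-batch
valid_label_cnt = torch.tensor(8.0)      # (labels != -100).sum() for this micro-batch
num_items_in_batch = torch.tensor(32.0)  # valid tokens across the whole accumulated batch

# Default global-batch normalization (transformers >= 4.46): every micro-batch is
# divided by the same global token count, so gradient accumulation matches
# training on the full batch at once.
loss_global = summed_token_loss / num_items_in_batch                     # 0.75

# The removed per_instance_loss path rescaled that back to a per-micro-batch average.
loss_per_instance = loss_global * num_items_in_batch / valid_label_cnt  # 3.0

assert torch.isclose(loss_per_instance, summed_token_loss / valid_label_cnt)
print(loss_global.item(), loss_per_instance.item())
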
81 changes: 3 additions & 78 deletions src/llamafactory/train/sft/trainer.py
@@ -11,7 +11,7 @@
 from ...extras.logging import get_logger
 from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
 from torch.utils.data import DataLoader
-from transformers.utils import is_datasets_available, is_sagemaker_mp_enabled
+from transformers.utils import is_datasets_available
 from transformers.trainer_utils import seed_worker
 import datasets
 from torch.nn import CrossEntropyLoss
@@ -137,77 +137,6 @@ def save_predictions(self, predict_results: "PredictionOutput") -> None:
 
 
 class CustomSeqParallelTrainer(CustomSeq2SeqTrainer):
 
-    def training_step(
-        self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
-    ) -> torch.Tensor:
-        """
-        Perform a training step on a batch of inputs.
-
-        Subclass and override to inject custom behavior.
-
-        Args:
-            model (`nn.Module`):
-                The model to train.
-            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
-                The inputs and targets of the model.
-
-                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
-                argument `labels`. Check your model's documentation for all accepted arguments.
-
-        Return:
-            `torch.Tensor`: The tensor with training loss on this batch.
-        """
-        model.train()
-        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
-            self.optimizer.train()
-
-        use_per_instance_loss = self.finetuning_args.per_instance_loss
-
-        inputs = self._prepare_inputs(inputs)
-        if is_sagemaker_mp_enabled():
-            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
-            return loss_mb.reduce_mean().detach().to(self.args.device)
-
-        with self.compute_loss_context_manager():
-            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
-
-        del inputs
-        if (hasattr(self.args, "torch_empty_cache_steps")
-            and self.args.torch_empty_cache_steps is not None
-            and self.state.global_step % self.args.torch_empty_cache_steps == 0
-        ):
-            if is_torch_xpu_available():
-                torch.xpu.empty_cache()
-            elif is_torch_mlu_available():
-                torch.mlu.empty_cache()
-            elif is_torch_musa_available():
-                torch.musa.empty_cache()
-            elif is_torch_npu_available():
-                torch.npu.empty_cache()
-            elif is_torch_mps_available(min_version="2.0"):
-                torch.mps.empty_cache()
-            else:
-                torch.cuda.empty_cache()
-
-        kwargs = {}
-
-        # For LOMO optimizers you need to explicitly use the learnign rate
-        # if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
-        #     kwargs["learning_rate"] = self._get_learning_rate()
-
-        if self.args.n_gpu > 1:
-            loss = loss.mean()  # mean() to average on multi-gpu parallel training
-
-        if self.use_apex:
-            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
-                scaled_loss.backward()
-        else:
-            self.accelerator.backward(loss, **kwargs)
-            # Finally we need to normalize the loss for reporting
-            if num_items_in_batch is None or self.finetuning_args.per_instance_loss:
-                return loss.detach() / self.args.gradient_accumulation_steps
-            return loss.detach()
-
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         """
         How the loss is computed by Trainer. By default, all models return the loss in the first element.
@@ -258,14 +187,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
             self.args.average_tokens_across_devices = None
         if not hasattr(self, 'model_accepts_loss_kwargs'):
             self.model_accepts_loss_kwargs= None
-        if self.finetuning_args.per_instance_loss and num_items_in_batch is not None:
-            labels = inputs.pop("labels")
-            valid_label_cnt = (labels!=-100).sum()
-            loss *= num_items_in_batch / valid_label_cnt
-        elif self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
+        if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
            loss *= self.accelerator.num_processes
         else:
-            if num_items_in_batch is None or self.finetuning_args.per_instance_loss:
+            if num_items_in_batch is None:
                 sp_size = self.finetuning_args.sp_size
                 loss_fn = CrossEntropyLoss(reduction='sum')
                 labels = inputs.pop("labels")
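
The tail of compute_loss is cut off in this view after labels = inputs.pop("labels"), so the following is only a rough sketch of the sum-then-normalize pattern those visible lines set up (CrossEntropyLoss with reduction='sum', divided by the number of valid labels), not the file's actual implementation; it also ignores the cross-rank reduction a real sequence-parallel run (sp_size > 1) would need.

import torch
from torch.nn import CrossEntropyLoss

def per_token_sum_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Causal LM shift: position t predicts token t + 1.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    # reduction='sum' with the default ignore_index=-100 skips prompt/padding tokens.
    loss_fn = CrossEntropyLoss(reduction="sum")
    summed = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    valid_label_cnt = (shift_labels != -100).sum()
    return summed / valid_label_cnt

logits = torch.randn(2, 16, 128)      # (batch, seq_len, vocab)
labels = torch.randint(0, 128, (2, 16))
labels[:, :4] = -100                  # e.g. mask the prompt tokens
print(per_token_sum_loss(logits, labels))
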
