fix bug when parallel_mode=data_parallel
qianhao0713 committed Dec 11, 2024
1 parent 358cbad commit 1eead84
Showing 2 changed files with 91 additions and 9 deletions.
@@ -697,7 +697,7 @@ def llama_model_forward(
logits = logits.float()

loss = None
if labels is not None:
if labels is not None and hasattr(self, 'loss_function'):
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
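For context, the added hasattr check is a defensive guard: a model object that does not expose a loss_function attribute (for example, one built against a transformers version that does not define it) would otherwise raise AttributeError here, whereas with the guard the forward simply leaves loss as None and the trainer computes the loss itself. A self-contained illustration with a toy class (not the repository's model):

class OldModel:
    pass  # no loss_function attribute on this object

model, labels, loss = OldModel(), [1, 2, 3], None
if labels is not None and hasattr(model, "loss_function"):
    loss = model.loss_function(labels=labels)  # skipped for OldModel
print(loss)  # None -> the loss is computed later by the trainer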
98 changes: 90 additions & 8 deletions src/llamafactory/train/sft/trainer.py
@@ -11,7 +11,7 @@
from ...extras.logging import get_logger
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
from torch.utils.data import DataLoader
from transformers.utils import is_datasets_available
from transformers.utils import is_datasets_available, is_sagemaker_mp_enabled
from transformers.trainer_utils import seed_worker
import datasets
from torch.nn import CrossEntropyLoss
@@ -137,18 +137,91 @@ def save_predictions(self, predict_results: "PredictionOutput") -> None:

class CustomSeqParallelTrainer(CustomSeq2SeqTrainer):

def training_step(
self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
Subclass and override to inject custom behavior.
Args:
model (`nn.Module`):
The model to train.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.
Return:
`torch.Tensor`: The tensor with training loss on this batch.
"""
model.train()
if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
self.optimizer.train()

use_per_instance_loss = self.finetuning_args.per_instance_loss

inputs = self._prepare_inputs(inputs)
if is_sagemaker_mp_enabled():
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
return loss_mb.reduce_mean().detach().to(self.args.device)

with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

del inputs
if (hasattr(self.args, "torch_empty_cache_steps")
and self.args.torch_empty_cache_steps is not None
and self.state.global_step % self.args.torch_empty_cache_steps == 0
):
if is_torch_xpu_available():
torch.xpu.empty_cache()
elif is_torch_mlu_available():
torch.mlu.empty_cache()
elif is_torch_musa_available():
torch.musa.empty_cache()
elif is_torch_npu_available():
torch.npu.empty_cache()
elif is_torch_mps_available(min_version="2.0"):
torch.mps.empty_cache()
else:
torch.cuda.empty_cache()

kwargs = {}

# For LOMO optimizers you need to explicitly use the learning rate
# if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
# kwargs["learning_rate"] = self._get_learning_rate()

if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training

if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
self.accelerator.backward(loss, **kwargs)
# Finally we need to normalize the loss for reporting
if num_items_in_batch is None or self.finetuning_args.per_instance_loss:
return loss.detach() / self.args.gradient_accumulation_steps
return loss.detach()
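The last few lines mirror the upstream Trainer normalization with one extra condition: when per_instance_loss is set, the returned value is also divided by gradient_accumulation_steps, since in that mode compute_loss re-normalizes the loss per micro-batch rather than per accumulated batch. A rough sketch with made-up numbers (not the trainer's real values):

gradient_accumulation_steps = 4
micro_batch_loss = 2.0          # loss computed for this micro-batch
per_instance_loss = True        # finetuning flag assumed for the example
num_items_in_batch = 128        # may be None when token counts are unavailable

if num_items_in_batch is None or per_instance_loss:
    reported = micro_batch_loss / gradient_accumulation_steps
else:
    reported = micro_batch_loss  # already normalized over the whole accumulated batch
print(reported)  # 0.5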

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Subclass and override for custom behavior.
"""
from transformers.trainer import _is_peft_model, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
if not hasattr(self, 'compute_loss_func'):
self.compute_loss_func = None
if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
labels = inputs.pop("labels")
else:
labels = None
if self.model_accepts_loss_kwargs:
if hasattr(self, 'model_accepts_loss_kwargs') and self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch is not None:
loss_kwargs["num_items_in_batch"] = num_items_in_batch
@@ -173,15 +246,23 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
else:
loss = self.label_smoother(outputs, labels)
else:
if isinstance(outputs, dict) and "loss" not in outputs:
raise ValueError(
"The model did not return a loss from the inputs, only the following keys: "
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
)
# We don't use .loss here since the model may return tuples instead of ModelOutput.
if self.finetuning_args.parallel_mode == "data_parallel":
if isinstance(outputs, dict) and "loss" not in outputs:
raise ValueError(
"The model did not return a loss from the inputs, only the following keys: "
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
)
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
if not hasattr(self.args, 'average_tokens_across_devices'):
self.args.average_tokens_across_devices = None
if not hasattr(self, 'model_accepts_loss_kwargs'):
self.model_accepts_loss_kwargs = None
if self.finetuning_args.per_instance_loss and num_items_in_batch is not None:
labels = inputs.pop("labels")
valid_label_cnt = (labels != -100).sum()
loss *= num_items_in_batch / valid_label_cnt
elif self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
loss *= self.accelerator.num_processes
else:
if num_items_in_batch is None or self.finetuning_args.per_instance_loss:
@@ -205,6 +286,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
loss[b]=loss_fn(shift_logits[b], shift_labels[b])/normalizer
loss = loss.mean()*sp_size
else:
assert self.args.average_tokens_across_devices is True, "average_tokens_across_devices must be enabled when parallel_mode is not data_parallel"
loss_fn = CrossEntropyLoss(reduction='sum')
labels = inputs.pop("labels")
logits = outputs["logits"] if isinstance(outputs, dict) else outputs[1]
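Putting the data_parallel branch together: with per_instance_loss enabled and num_items_in_batch available, the loss the model returns (in recent transformers, the token-loss sum divided by num_items_in_batch) is rescaled back to the mean over this micro-batch's own valid tokens. A standalone sketch with illustrative tensors (not the trainer's real inputs):

import torch

labels = torch.tensor([[1, 2, -100, -100], [3, 4, 5, -100]])  # -100 marks ignored positions
token_loss_sum = torch.tensor(6.0)   # hypothetical summed cross-entropy over the 5 valid tokens
num_items_in_batch = 20              # valid tokens across the whole accumulated batch

loss = token_loss_sum / num_items_in_batch           # what the model returns given loss_kwargs
valid_label_cnt = (labels != -100).sum()             # 5 valid tokens in this micro-batch
loss = loss * num_items_in_batch / valid_label_cnt   # back to a per-micro-batch mean
print(float(loss))                                    # 1.2 == 6.0 / 5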
