diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 91b9a29b..53f3708f 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -692,7 +692,7 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten lang_losses_list = list(lang_losses.keys()) # Compute losses - if isinstance(outputs[0], torch.Tensor): + if len(outputs) > 0 and isinstance(outputs[0], torch.Tensor): # Multilingual losses for loss, lang_code in zip(outputs, lang_codes): lang_losses[lang_losses_list[lang_code]].append(loss) @@ -703,7 +703,7 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten if not lang_losses[ lang ]: # If the list is empty --> Set local language loss to -1 to exclude it from the global computation - lang_losses[lang] = torch.tensor(-1, dtype=torch.float32) + lang_losses[lang] = torch.tensor(-1, dtype=torch.float32, device="cuda") else: # If we have at least 1 loss from a given language --> compute local language loss mean lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang]))