From 60f304774eee3853f14f5e47fb8707e88ad58be2 Mon Sep 17 00:00:00 2001 From: Negar Foroutan Eghlidi Date: Thu, 22 Aug 2024 15:17:58 +0200 Subject: [PATCH] Fix a device issue. --- src/nanotron/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 25c4d315..47d83fbe 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -695,7 +695,7 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten if not lang_losses[ lang ]: # If the list is empty --> Set local language loss to -1 to exclude it from the global computation - lang_losses[lang] = torch.tensor(-1, dtype=torch.float32) + lang_losses[lang] = torch.tensor(-1, dtype=torch.float32, device="cuda") else: # If we have at least 1 loss from a given language --> compute local language loss mean lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang]))