From aeaf2a063b0b6df2140bd11617b717da6bedab67 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 27 Oct 2024 11:37:08 -0700 Subject: [PATCH] Add --lost-dist-impl argument to pick different distributed loss implementations --- src/open_clip/factory.py | 2 ++ src/open_clip/loss.py | 39 ++++++++++++++++++++--------------- src/open_clip_train/params.py | 6 ++++++ 3 files changed, 30 insertions(+), 17 deletions(-) diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 03ccc4f06..c6a9e9eac 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -448,7 +448,9 @@ def create_loss(args): return SigLipLoss( rank=args.rank, world_size=args.world_size, + dist_impl=args.loss_dist_impl, # siglip has multiple distributed implementations to choose from ) + return ClipLoss( local_loss=args.local_loss, gather_with_grad=args.gather_with_grad, diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 9b39dbf31..b3e6dd256 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch import torch.nn as nn from torch.nn import functional as F @@ -102,8 +104,14 @@ def get_ground_truth(self, device, num_logits) -> torch.Tensor: def get_logits(self, image_features, text_features, logit_scale): if self.world_size > 1: all_image_features, all_text_features = gather_features( - image_features, text_features, - self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) + image_features, + text_features, + local_loss=self.local_loss, + gather_with_grad=self.gather_with_grad, + rank=self.rank, + world_size=self.world_size, + use_horovod=self.use_horovod, + ) if self.local_loss: logits_per_image = logit_scale * image_features @ all_text_features.T @@ -158,12 +166,11 @@ def __init__( self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): - - clip_loss = torch.tensor(0) - if self.clip_loss_weight: clip_loss = super().forward(image_features, text_features, logit_scale) clip_loss = self.clip_loss_weight * clip_loss + else: + clip_loss = torch.tensor(0, device=logits.device) caption_loss = self.caption_loss( logits.permute(0, 2, 1), @@ -316,19 +323,17 @@ class SigLipLoss(nn.Module): """ def __init__( self, - cache_labels=False, - rank=0, - world_size=1, - use_horovod=False, - impl='bidir', + cache_labels: bool = False, + rank: int = 0, + world_size: int = 1, + dist_impl: Optional[str] = None, ): super().__init__() self.cache_labels = cache_labels self.rank = rank self.world_size = world_size - assert not use_horovod # FIXME need to look at hvd ops for ring transfers - self.use_horovod = use_horovod - self.impl = impl + self.dist_impl = dist_impl or 'bidir' # default to bidir exchange for now, this will likely change + assert self.dist_impl in ('bidir', 'shift', 'reduce', 'gather') # cache state FIXME cache not currently used, worthwhile? self.prev_num_logits = 0 @@ -361,7 +366,7 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output loss = self._loss(image_features, text_features, logit_scale, logit_bias) if self.world_size > 1: - if self.impl == 'bidir': + if self.dist_impl == 'bidir': right_rank = (self.rank + 1) % self.world_size left_rank = (self.rank - 1 + self.world_size) % self.world_size text_features_to_right = text_features_to_left = text_features @@ -396,7 +401,7 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output logit_bias, negative_only=True, ) - elif self.impl == "shift": + elif self.dist_impl == "shift": right_rank = (self.rank + 1) % self.world_size left_rank = (self.rank - 1 + self.world_size) % self.world_size text_features_to_right = text_features @@ -414,7 +419,7 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output negative_only=True, ) text_features_to_right = text_features_from_left - elif self.impl == "reduce": + elif self.dist_impl == "reduce": for i in range(self.world_size): text_from_other = torch.distributed.nn.all_reduce( text_features * (self.rank == i), @@ -427,7 +432,7 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output logit_bias, negative_only=True, ) - elif self.impl == "gather": + elif self.dist_impl == "gather": all_text = torch.distributed.nn.all_gather(text_features) for i in range(self.world_size): loss += float(i != self.rank) * self._loss( diff --git a/src/open_clip_train/params.py b/src/open_clip_train/params.py index 2cf5ded30..63e6f6c8a 100644 --- a/src/open_clip_train/params.py +++ b/src/open_clip_train/params.py @@ -469,6 +469,12 @@ def parse_args(args): action="store_true", help='Use SigLip (sigmoid) loss.' ) + parser.add_argument( + "--loss-dist-impl", + default=None, + type=str, + help='A string to specify a specific distributed loss implementation.' + ) args = parser.parse_args(args)