SigLip memory consumption increases as we scale number of GPUs #942
Comments
@khalidsaifullaah yeah, it's not working quite as efficiently as it should. I feel my current isend/irecv impl, while reasonable in theory, may not be a well-optimized approach. Looking at the big_vision codebase, where the SigLIP authors keep the original versions of the models, there's no chunked sigmoid impl in the current code, but there is one in a deprecated file. Interestingly, a comment there states that the 'ppermute' version (which should be equivalent to a send/recv neighbour exchange, and should itself be more optimal, being one high-level op instead of several) used more memory than doing the 'hot potato' with all_reduce instead. Hmm https://github.com/google-research/big_vision/blob/46b2456f54b9d4f829d1925b78943372b376153d/big_vision/trainers/proj/image_text/_deprecated_contrastive.py#L168-L200 I was thinking of trying an all_reduce impl... all_reduce and all_gather should be among the most optimized collective ops through the software stack and network, as they are the most heavily used.
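To make the two exchange patterns concrete, here is a single-process sketch in plain numpy (no `torch.distributed`), where list indices stand in for ranks. `ring_exchange` mimics a send/recv (or `ppermute`) shift to the next rank; `reduce_broadcast` mimics an all_reduce-based 'hot potato' under the assumption that, at step `i`, only rank `i` contributes a non-zero block, so the SUM equals rank `i`'s block everywhere. Both function names are hypothetical; this is not the open_clip or big_vision code, only a check that both patterns deliver every other rank's block to each rank.

```python
import numpy as np


def ring_exchange(blocks):
    """Neighbour exchange: each step, every rank passes the block it holds
    to rank + 1 (like send/recv or a ppermute shift).
    Returns, per rank, the list of foreign blocks it received, in order."""
    world = len(blocks)
    held = list(blocks)                       # rank r starts with its own block
    seen = [[] for _ in range(world)]
    for _ in range(world - 1):
        # rank r receives what rank r - 1 held in the previous step
        held = [held[(r - 1) % world] for r in range(world)]
        for r in range(world):
            seen[r].append(held[r])
    return seen


def reduce_broadcast(blocks):
    """'Hot potato' via sum-reduce (assumed scheme): at step i only rank i
    contributes its block and every other rank contributes zeros, so the
    summed result equals block i on every rank."""
    world = len(blocks)
    seen = [[] for _ in range(world)]
    for i in range(world):
        contributions = [b if r == i else np.zeros_like(b) for r, b in enumerate(blocks)]
        reduced = np.sum(contributions, axis=0)   # what all_reduce(SUM) would return
        for r in range(world):
            if r != i:                            # a rank already has its own block
                seen[r].append(reduced)
    return seen


# Usage: 4 simulated ranks, each holding a (2, 3) block filled with its rank id.
blocks = [np.full((2, 3), float(r)) for r in range(4)]
print([[int(b[0, 0]) for b in got] for got in ring_exchange(blocks)])
print([[int(b[0, 0]) for b in got] for got in reduce_broadcast(blocks)])
```

The ring version takes `world - 1` point-to-point steps that each move one block; the reduce version runs `world` full collectives whose inputs are mostly zeros, so it typically moves more bytes in total but leans on a single, heavily optimized op, which is the trade-off being weighed here.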
@khalidsaifullaah I'm experimenting with different impls of the loss in #971 to see if any of them scale better... feel free to try them; feedback would be welcome.
Oh awesome, I had moved on to implementing a different dist loss. However, in my quick test of the new commit, I still seem to get OOM when scaling horizontally to 128 GPUs (mbsz=2, impl="reduce"). I'll do more tests on the other impls.
Awesome! From my observations, when training with SigLIP loss on 100+ GPUs, I noticed it was considerably slower than with CLIP loss. It would be really helpful if you could also report the 'train/batch_time' metric for each implementation type @khalidsaifullaah @rwightman
@khalidsaifullaah @long8v FWIW, I wouldn't necessarily treat "no extra overhead as the world size increases" as the passing criterion. With gradient buffers, allocator behaviour, etc., there's still likely to be some impact from the world size. It should be more efficient than CLIP loss, though.
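The intuition in this thread can be put in rough numbers. The hypothetical helper below counts only the logit entries a single device builds under three assumed strategies: gathering and scoring the full global batch, gathering but scoring only the device's own rows, and the chunked sigmoid loop with one `b x b` block live per step. It deliberately ignores features, tensors kept for the backward pass, gradient buffers and allocator behaviour, which is where the remaining world-size dependence mentioned above could come from.

```python
def logit_entries_per_device(b, world_size):
    """Rough count of logit entries one device materializes for the loss.

    b: per-device micro-batch size; world_size: number of devices.
    Counts logits only: features, tensors saved for backward, gradient
    buffers and allocator overhead are all ignored.
    """
    global_batch = b * world_size
    return {
        # all_gather both modalities, every device scores the full global batch
        "gather_full": global_batch * global_batch,
        # all_gather, but each device only scores its own b rows
        "gather_local_rows": b * global_batch,
        # chunked sigmoid: one b x b block is live at any one step
        "chunked_block": b * b,
    }


for world_size in (8, 16, 32, 64, 128):
    print(world_size, logit_entries_per_device(10, world_size))
```

Only the first two entries grow with `world_size`, which matches the expectation that the chunked loss should not need more logit memory per device as nodes are added.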
From the SigLIP paper, my understanding is that it doesn't require any `all_gather` and always performs a local `b x b` computation iteratively, where `b` is `micro_batch_size` (see this section from the paper). So if I can fit, let's say, a `micro_batch_size` of 10 (on 8 GPUs), and I then increase the number of GPUs to 16, 32, 64, 128, ..., my memory consumption should (more or less) remain the same, just like normal DDP. Put simply, we should be able to scale `world_batch_size` or the number of nodes while keeping `micro_batch_size` constant (in theory), right?

But what I've observed is that memory consumption spikes as I increase `world_batch_size` (number of nodes), and I need to lower my `micro_batch_size` (even to as low as `2` for 128 devices).

… `micro_batch_size` constant it allows you to scale `world_batch_size`? It could also be the case that they use some sort of TPU trick (I don't have much insight there).

I could be totally wrong on both of these, so I'd be glad to know whether anyone has tried scaling `world_batch_size` and seen similar results, so I can validate my hypothesis.
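To sanity-check the hypothesis above, here is a single-process numpy simulation, assuming contiguous sharding and the pairwise sigmoid formulation (positives on the diagonal, every other pair a negative). The function names and the temperature/bias values are illustrative assumptions, not open_clip's API. It shows that a ring-ordered chunked loop, where each simulated device only ever builds a `b x b` logits block, gives the same loss as scoring the whole global batch at once. In a real autograd setting each step's block and received text shard may still be retained for the backward pass, so this checks the math, not peak memory.

```python
import numpy as np


def log_sigmoid(x):
    # numerically stable log(sigmoid(x)) = -log(1 + exp(-x))
    return -np.logaddexp(0.0, -x)


def l2_normalize(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def sigmoid_loss_full(img, txt, t, bias):
    """Reference: pairwise sigmoid loss over the whole global batch at once.

    img, txt: (N, d) normalized embeddings; img[i] pairs with txt[i].
    """
    logits = t * img @ txt.T + bias              # (N, N): grows with world size
    labels = 2.0 * np.eye(len(img)) - 1.0        # +1 on the diagonal, -1 elsewhere
    return -log_sigmoid(labels * logits).sum() / len(img)


def sigmoid_loss_chunked(img_shards, txt_shards, t, bias):
    """Simulated multi-device version of the same loss.

    Device r keeps its own image shard and sees one text shard per step,
    in ring order, so it only ever builds a (b, b) logits block.
    """
    world = len(img_shards)
    global_n = sum(len(s) for s in img_shards)
    total = 0.0
    for r in range(world):
        img = img_shards[r]
        for step in range(world):
            src = (r - step) % world             # step 0 uses the device's own shard
            txt = txt_shards[src]
            logits = t * img @ txt.T + bias      # (b, b) block only
            if src == r:
                labels = 2.0 * np.eye(len(img)) - 1.0   # own shard holds the positives
            else:
                labels = -np.ones_like(logits)          # other shards: negatives only
            total += -log_sigmoid(labels * logits).sum()
    return total / global_n


# Usage: 4 simulated devices, micro-batch 3, embedding dim 8.
rng = np.random.default_rng(0)
world, b, d = 4, 3, 8
img = l2_normalize(rng.normal(size=(world * b, d)))
txt = l2_normalize(rng.normal(size=(world * b, d)))
t, bias = 10.0, -10.0                            # illustrative temperature and bias
full = sigmoid_loss_full(img, txt, t, bias)
chunked = sigmoid_loss_chunked(np.split(img, world), np.split(txt, world), t, bias)
print(full, chunked, np.isclose(full, chunked))
```

Because the off-diagonal shards contribute negatives only, a device never needs the full `N x N` matrix; only the number of steps, not the block size, grows with the world size.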