None of the inputs have requires_grad=True. Gradients will be None #573

Open · maxjaritz opened this issue Jul 20, 2023 · 2 comments

maxjaritz commented Jul 20, 2023

I am fine-tuning a model on a custom dataset. At the start of training, I get the warning "None of the inputs have requires_grad=True. Gradients will be None". I made this warning disappear by passing use_reentrant=False to the three checkpoint() calls in transformer.py.
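For illustration, here is a minimal sketch of the change, using a stand-in nn.Linear rather than open_clip's actual residual blocks (the real edit is just adding use_reentrant=False to the existing checkpoint() calls):

import torch
from torch.utils.checkpoint import checkpoint

# Stand-in for one residual block that gets checkpointed (illustrative only).
block = torch.nn.Linear(16, 16)

# Input with requires_grad=False, as when the upstream layers are locked.
x = torch.randn(2, 16)

# Non-reentrant checkpointing: gradients still reach block's parameters
# even though the input tensor itself does not require grad.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
assert block.weight.grad is not None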

Interestingly, simply setting use_reentrant=False also improved performance on the train/val loss and cross-modal retrieval!
[screenshot: train/val loss and cross-modal retrieval curves comparing use_reentrant=True vs. use_reentrant=False]

My training command is:

torchrun --nproc_per_node 8 -m training.main \
--train-data 'mydata/{00000..04089}.tar' \
--val-data 'mydata/{04090..04095}.tar' \
--train-num-samples 16115354 \
--val-num-samples 70965 \
--dataset-type webdataset \
--epochs 10 \
--batch-size 1650 \
--precision amp \
--local-loss \
--gather-with-grad \
--grad-checkpointing \
--ddp-static-graph \
--workers 8 \
--seed 0 \
--lr 0.3e-3 \
--warmup 1220 \
--report-to tensorboard \
--resume "latest" \
--zeroshot-frequency 1 \
--model ViT-B-32 \
--name ... \
--pretrained laion2B-s34B-b79K \
--lock-image \
--lock-image-unlocked-groups 9

The problem does not occur when the following arguments are removed from the training command:

--lock-image \
--lock-image-unlocked-groups 9
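For context, those two flags freeze most of the image tower, so the activations feeding the later (checkpointed) blocks no longer require grad. A rough sketch of the effect, using an illustrative stand-in model rather than open_clip's actual locking code:

import torch

# Illustrative stand-in for an image tower: 12 "groups" of layers.
visual = torch.nn.Sequential(*[torch.nn.Linear(8, 8) for _ in range(12)])

# --lock-image: freeze the whole image tower...
for p in visual.parameters():
    p.requires_grad = False

# ...--lock-image-unlocked-groups 9: then unfreeze the last 9 groups.
for group in list(visual)[-9:]:
    for p in group.parameters():
        p.requires_grad = True

# The frozen early groups emit activations with requires_grad=False, so the
# tensor inputs to the later, checkpointed groups do not require grad.
x = torch.randn(2, 8)
print(visual[:3](x).requires_grad)  # False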

It might be related to the following warning from the PyTorch docs (https://pytorch.org/docs/stable/checkpoint.html):

If use_reentrant=True is specified, at least one of the inputs needs to have requires_grad=True if grads are needed for model inputs, otherwise the checkpointed part of the model won’t have gradients. At least one of the outputs needs to have requires_grad=True as well. Note that this does not apply if use_reentrant=False is specified.
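That condition is easy to reproduce in isolation (assuming PyTorch >= 1.11, where the use_reentrant argument exists):

import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Linear(4, 4)  # trainable parameters inside the checkpointed segment
x = torch.randn(2, 4)          # no tensor input requires grad (frozen upstream layers)

y = checkpoint(block, x, use_reentrant=True)  # emits the warning from this issue
print(y.requires_grad)  # False: the segment is cut out of the autograd graph,
                        # so block's parameters would receive None gradients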

Do you know what the underlying issue is?

rwightman (Collaborator) commented

Hmm, I would have thought this works as long as you don't lock the full image or text tower... but perhaps not; it may not be a good idea to checkpoint parts of the model that have gradients disabled.

We should probably set use_reentrant=False, but it's never been clear to me what the downside of that is. The PyTorch docs mention many pluses of =False, so why was =True the default? hohumm

maxjaritz (Author) commented

In the PyTorch docs, I also saw:

Note that future versions of PyTorch will default to use_reentrant=False. Default: True

maxjaritz reopened this Sep 16, 2023