An issue on loss function #10
Comments
I was wondering whether the multiplication by T squared is really helpful, because if T=20 the soft loss will dominate the total loss. Also, there is no need to add an extra softmax for the hard target, since it is already embedded in nn.functional.cross_entropy. @lhyfst
As @erichhhhho pointed out, there is indeed no need to manually add an extra softmax. From the reference paper, it looks like T^2 is only required when using BOTH hard and soft targets.
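For reference, here is a minimal sketch (not from the repo; the variable names are illustrative) showing that F.cross_entropy already applies log-softmax internally, so feeding it pre-softmaxed outputs computes a different, flatter loss:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)           # raw model outputs, no softmax applied
labels = torch.randint(0, 10, (4,))

# F.cross_entropy(logits, labels) == nll_loss(log_softmax(logits), labels),
# so it expects raw logits as input.
loss_from_logits = F.cross_entropy(logits, labels)
assert torch.allclose(loss_from_logits,
                      F.nll_loss(F.log_softmax(logits, dim=1), labels))

# Adding an extra softmax first effectively computes cross-entropy on
# softmax(softmax(logits)), which is a different (less peaked) quantity.
loss_double_softmax = F.cross_entropy(F.softmax(logits, dim=1), labels)
print(loss_from_logits.item(), loss_double_softmax.item())
```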
Thank you, everybody! So, why is the first part of the KD loss function in distill_mnist.py multiplied by 2?
As per distiller, KD_Loss is effectively the following equation: α * kl_divergence + β * cross_entropy. And Hinton et al. (2015) originally used a weighted average, i.e. β = 1 − α.
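As an illustration of that weighted form, a minimal sketch assuming the usual Hinton-style formulation with β = 1 − α (the function name and defaults here are illustrative, not the repo's exact code):

```python
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.9):
    # Soft term: KL between temperature-softened student and teacher distributions,
    # scaled by T^2 so its gradient magnitude matches the hard term (Hinton et al. 2015).
    soft = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                    F.softmax(teacher_logits / T, dim=1),
                    reduction="batchmean") * (T * T)
    # Hard term: ordinary cross-entropy on raw logits (log-softmax is applied internally).
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1.0 - alpha) * hard
```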
I suggest that both training loss functions, without KD and with KD, should add a softmax, because the outputs of the models are produced without softmax. Like this:
https://github.com/peterliht/knowledge-distillation-pytorch/blob/e4c40132fed5a45e39a6ef7a77b15e5d389186f8/model/net.py#L100-L114
==>
KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1), F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \
          F.cross_entropy(F.softmax(outputs, dim=1), labels) * (1. - alpha)
&
https://github.com/peterliht/knowledge-distillation-pytorch/blob/e4c40132fed5a45e39a6ef7a77b15e5d389186f8/model/net.py#L83-L97
==>
return nn.CrossEntropyLoss()(F.softmax(outputs,dim=1), labels)
For another thing, why is the first part of the KD loss function in distill_mnist.py multiplied by 2?
https://github.com/peterliht/knowledge-distillation-pytorch/blob/e4c40132fed5a45e39a6ef7a77b15e5d389186f8/mnist/distill_mnist.py#L96-L97
One more thing: it is not necessary to multiply by T*T if we distill using only soft targets (see the sketch after the link below).
https://github.com/peterliht/knowledge-distillation-pytorch/blob/e4c40132fed5a45e39a6ef7a77b15e5d389186f8/mnist/distill_mnist_unlabeled.py#L96-L97
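For example, a soft-target-only loss could be sketched as below (illustrative names, not the repo's code); with no hard term in the sum, a constant T^2 factor only rescales the whole loss, which is equivalent to changing the learning rate:

```python
import torch.nn.functional as F

def soft_only_loss(student_logits, teacher_logits, T=20.0):
    # Only the soft (teacher-matching) term is present, so a T*T factor would
    # uniformly scale the loss and its gradients and can therefore be dropped.
    return F.kl_div(F.log_softmax(student_logits / T, dim=1),
                    F.softmax(teacher_logits / T, dim=1),
                    reduction="batchmean")
```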
Reference
Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the Knowledge in a Neural Network. arXiv:1503.02531.