-
-
Notifications
You must be signed in to change notification settings - Fork 4.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #437 from rwightman/agc
Adaptive Gradient Clipping (AGC) Impl
- Loading branch information
Showing
9 changed files
with
106 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
""" Adaptive Gradient Clipping | ||
An impl of AGC, as per (https://arxiv.org/abs/2102.06171): | ||
@article{brock2021high, | ||
author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan}, | ||
title={High-Performance Large-Scale Image Recognition Without Normalization}, | ||
journal={arXiv preprint arXiv:}, | ||
year={2021} | ||
} | ||
Code references: | ||
* Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets | ||
* Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c | ||
Hacked together by / Copyright 2021 Ross Wightman | ||
""" | ||
import torch | ||
|
||
|
||
def unitwise_norm(x, norm_type=2.0): | ||
if x.ndim <= 1: | ||
return x.norm(norm_type) | ||
else: | ||
# works for nn.ConvNd and nn,Linear where output dim is first in the kernel/weight tensor | ||
# might need special cases for other weights (possibly MHA) where this may not be true | ||
return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True) | ||
|
||
|
||
def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0): | ||
if isinstance(parameters, torch.Tensor): | ||
parameters = [parameters] | ||
for p in parameters: | ||
if p.grad is None: | ||
continue | ||
p_data = p.detach() | ||
g_data = p.grad.detach() | ||
max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor) | ||
grad_norm = unitwise_norm(g_data, norm_type=norm_type) | ||
clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6)) | ||
new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad) | ||
p.grad.detach().copy_(new_grads) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import torch | ||
|
||
from timm.utils.agc import adaptive_clip_grad | ||
|
||
|
||
def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0): | ||
""" Dispatch to gradient clipping method | ||
Args: | ||
parameters (Iterable): model parameters to clip | ||
value (float): clipping value/factor/norm, mode dependant | ||
mode (str): clipping mode, one of 'norm', 'value', 'agc' | ||
norm_type (float): p-norm, default 2.0 | ||
""" | ||
if mode == 'norm': | ||
torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type) | ||
elif mode == 'value': | ||
torch.nn.utils.clip_grad_value_(parameters, value) | ||
elif mode == 'agc': | ||
adaptive_clip_grad(parameters, value, norm_type=norm_type) | ||
else: | ||
assert False, f"Unknown clip mode ({mode})." | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = '0.4.3' | ||
__version__ = '0.4.4' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters