
Commit

Merge pull request #437 from rwightman/agc
Adaptive Gradient Clipping (AGC) Impl
rwightman authored Feb 16, 2021
2 parents 5f9aff3 + 361fd0f commit 4ea5931
Showing 9 changed files with 106 additions and 17 deletions.
8 changes: 8 additions & 0 deletions README.md
@@ -2,6 +2,13 @@

## What's New

### Feb 16, 2021
* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated with PyTorch gradient clipping via a mode arg that defaults to the previous 'norm' mode. For backward arg compatibility, the clip-grad arg must be specified to enable clipping when using train.py.
  * AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc`
  * PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0`
  * PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value`
* AGC performance is definitely sensitive to the clipping factor. More experimentation is needed to determine good values for smaller batch sizes and optimizers besides those in the paper. So far I've found .001-.005 is necessary for stable RMSProp training w/ NFNet/NF-ResNet.

### Feb 12, 2021
* Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs

@@ -238,6 +245,7 @@ Several (less common) features that I often utilize in my projects are included.
* Efficient Channel Attention - ECA (https://arxiv.org/abs/1910.03151)
* Blur Pooling (https://arxiv.org/abs/1904.11486)
* Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper?
* Adaptive Gradient Clipping (https://arxiv.org/abs/2102.06171, https://github.com/deepmind/deepmind-research/tree/master/nfnets)

## Results

2 changes: 1 addition & 1 deletion timm/models/__init__.py
@@ -31,7 +31,7 @@
from .xception_aligned import *

from .factory import create_model
from .helpers import load_checkpoint, resume_checkpoint
from .helpers import load_checkpoint, resume_checkpoint, model_parameters
from .layers import TestTimePoolHead, apply_test_time_pool
from .layers import convert_splitbn_model
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
20 changes: 13 additions & 7 deletions timm/models/helpers.py
@@ -113,10 +113,9 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
digits of the SHA256 hash of the contents of the file. The hash is used to
ensure unique names and to verify the contents of the file. Default: False
"""
if cfg is None:
cfg = getattr(model, 'default_cfg')
if cfg is None or 'url' not in cfg or not cfg['url']:
_logger.warning("Pretrained model URL does not exist, using random initialization.")
cfg = cfg or getattr(model, 'default_cfg')
if cfg is None or not cfg.get('url', None):
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
return
url = cfg['url']

@@ -174,9 +173,8 @@ def adapt_input_conv(in_chans, conv_weight):


def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
if cfg is None:
cfg = getattr(model, 'default_cfg')
if cfg is None or 'url' not in cfg or not cfg['url']:
cfg = cfg or getattr(model, 'default_cfg')
if cfg is None or not cfg.get('url', None):
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
return

@@ -376,3 +374,11 @@ def build_model_with_cfg(
model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg

return model


def model_parameters(model, exclude_head=False):
if exclude_head:
# FIXME this is a bit of a quick and dirty hack to skip classifier head params based on ordering
return [p for p in model.parameters()][:-2]
else:
return model.parameters()
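
As a quick illustration (not part of the diff), `model_parameters` with `exclude_head=True` simply drops the last two parameters, which are the classifier weight and bias by ordering; the toy model below is hypothetical:

```python
# Illustrative sketch only, assuming timm at this commit; the toy model is made up.
import torch.nn as nn
from timm.models import model_parameters

toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Flatten(), nn.Linear(8 * 30 * 30, 10))
all_params = list(toy.parameters())                      # conv weight/bias, fc weight/bias
body_params = model_parameters(toy, exclude_head=True)   # drops the final two (fc weight/bias)
assert len(body_params) == len(all_params) - 2
```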
2 changes: 2 additions & 0 deletions timm/utils/__init__.py
@@ -1,4 +1,6 @@
from .agc import adaptive_clip_grad
from .checkpoint_saver import CheckpointSaver
from .clip_grad import dispatch_clip_grad
from .cuda import ApexScaler, NativeScaler
from .distributed import distribute_bn, reduce_tensor
from .jit import set_jit_legacy
42 changes: 42 additions & 0 deletions timm/utils/agc.py
@@ -0,0 +1,42 @@
""" Adaptive Gradient Clipping
An impl of AGC, as per (https://arxiv.org/abs/2102.06171):
@article{brock2021high,
author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
title={High-Performance Large-Scale Image Recognition Without Normalization},
journal={arXiv preprint arXiv:2102.06171},
year={2021}
}
Code references:
* Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets
* Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c
Hacked together by / Copyright 2021 Ross Wightman
"""
import torch


def unitwise_norm(x, norm_type=2.0):
if x.ndim <= 1:
return x.norm(norm_type)
else:
# works for nn.ConvNd and nn.Linear where output dim is first in the kernel/weight tensor
# might need special cases for other weights (possibly MHA) where this may not be true
return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)


def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
for p in parameters:
if p.grad is None:
continue
p_data = p.detach()
g_data = p.grad.detach()
max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
grad_norm = unitwise_norm(g_data, norm_type=norm_type)
clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))
new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad)
p.grad.detach().copy_(new_grads)
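
A minimal usage sketch (not part of the diff) of calling the new `adaptive_clip_grad` directly in a plain, non-AMP training step; the toy model, data, and optimizer here are placeholders, and the head exclusion mirrors the train.py change further down:

```python
# Hedged sketch, assuming timm at this commit: AGC applied manually between backward() and step().
import torch
import torch.nn as nn
from timm.models import model_parameters
from timm.utils import adaptive_clip_grad

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 32 * 32, 10))
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
inputs, targets = torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))

optimizer.zero_grad()
loss = nn.functional.cross_entropy(model(inputs), targets)
loss.backward()
# Clip everything except the classifier head, using the default clip factor from the README note.
adaptive_clip_grad(model_parameters(model, exclude_head=True), clip_factor=0.01)
optimizer.step()
```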
23 changes: 23 additions & 0 deletions timm/utils/clip_grad.py
@@ -0,0 +1,23 @@
import torch

from timm.utils.agc import adaptive_clip_grad


def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
""" Dispatch to gradient clipping method
Args:
parameters (Iterable): model parameters to clip
value (float): clipping value/factor/norm, mode dependent
mode (str): clipping mode, one of 'norm', 'value', 'agc'
norm_type (float): p-norm, default 2.0
"""
if mode == 'norm':
torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
elif mode == 'value':
torch.nn.utils.clip_grad_value_(parameters, value)
elif mode == 'agc':
adaptive_clip_grad(parameters, value, norm_type=norm_type)
else:
assert False, f"Unknown clip mode ({mode})."

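For reference (not part of the diff), the dispatcher can be exercised on its own; the dummy parameter below is just for illustration:

```python
# Hedged sketch: the three supported clip modes applied to a single dummy parameter.
import torch
from timm.utils import dispatch_clip_grad

p = torch.nn.Parameter(torch.randn(10, 10))
p.grad = torch.randn(10, 10)

dispatch_clip_grad([p], 1.0)                 # 'norm' (default): global-norm clipping, the old behaviour
dispatch_clip_grad([p], 10.0, mode='value')  # element-wise value clipping
dispatch_clip_grad([p], 0.01, mode='agc')    # unit-wise adaptive clipping from agc.py
```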
10 changes: 6 additions & 4 deletions timm/utils/cuda.py
@@ -11,15 +11,17 @@
amp = None
has_apex = False

from .clip_grad import dispatch_clip_grad


class ApexScaler:
state_dict_key = "amp"

def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward(create_graph=create_graph)
if clip_grad is not None:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad)
dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
optimizer.step()

def state_dict(self):
@@ -37,12 +39,12 @@ class NativeScaler:
def __init__(self):
self._scaler = torch.cuda.amp.GradScaler()

def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
self._scaler.scale(loss).backward(create_graph=create_graph)
if clip_grad is not None:
assert parameters is not None
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
self._scaler.step(optimizer)
self._scaler.update()

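Below, a hedged sketch (not part of the diff, assumes a CUDA device) of how the extended `NativeScaler` call is meant to be used in a mixed-precision step; it mirrors the train.py change that follows, with a placeholder model and random data:

```python
# Hedged sketch: native AMP step with AGC clipping routed through the updated NativeScaler.
import torch
import torch.nn as nn
from timm.models import model_parameters
from timm.utils import NativeScaler

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.Flatten(), nn.Linear(8 * 32 * 32, 10)).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_scaler = NativeScaler()
inputs = torch.randn(8, 3, 32, 32, device='cuda')
targets = torch.randint(0, 10, (8,), device='cuda')

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = nn.functional.cross_entropy(model(inputs), targets)
# backward, unscale, clip (AGC here), optimizer step and scaler update all happen inside this call
loss_scaler(
    loss, optimizer, clip_grad=0.01, clip_mode='agc',
    parameters=model_parameters(model, exclude_head=True))
```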
2 changes: 1 addition & 1 deletion timm/version.py
@@ -1 +1 @@
__version__ = '0.4.3'
__version__ = '0.4.4'
14 changes: 10 additions & 4 deletions train.py
@@ -29,7 +29,7 @@
from torch.nn.parallel import DistributedDataParallel as NativeDDP

from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
Expand Down Expand Up @@ -116,7 +116,8 @@
help='weight decay (default: 0.0001)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')

parser.add_argument('--clip-mode', type=str, default='norm',
help='Gradient clipping mode. One of ("norm", "value", "agc")')


# Learning rate schedule parameters
Expand Down Expand Up @@ -637,11 +638,16 @@ def train_one_epoch(
optimizer.zero_grad()
if loss_scaler is not None:
loss_scaler(
loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
loss, optimizer,
clip_grad=args.clip_grad, clip_mode=args.clip_mode,
parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
create_graph=second_order)
else:
loss.backward(create_graph=second_order)
if args.clip_grad is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
dispatch_clip_grad(
model_parameters(model, exclude_head='agc' in args.clip_mode),
value=args.clip_grad, mode=args.clip_mode)
optimizer.step()

if model_ema is not None:
