diff --git a/README.md b/README.md
index c4f3a588a1..421bced44e 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,13 @@
 ## What's New
 
+### Feb 16, 2021
+* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via a `--clip-mode` arg that defaults to the previous 'norm' mode. For backwards arg compat, `--clip-grad` must still be specified to enable clipping when using train.py.
+  * AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc`
+  * PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0`
+  * PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value`
+  * AGC performance is definitely sensitive to the clipping factor. More experimentation is needed to determine good values for smaller batch sizes and for optimizers besides those in the paper. So far I've found .001-.005 necessary for stable RMSProp training w/ NFNet/NF-ResNet.
+
 ### Feb 12, 2021
 * Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs
@@ -238,6 +245,7 @@ Several (less common) features that I often utilize in my projects are included.
 * Efficient Channel Attention - ECA (https://arxiv.org/abs/1910.03151)
 * Blur Pooling (https://arxiv.org/abs/1904.11486)
 * Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper?
+* Adaptive Gradient Clipping (https://arxiv.org/abs/2102.06171, https://github.com/deepmind/deepmind-research/tree/master/nfnets)
 
 ## Results
 
diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index dc56848e2e..8d99d19bdc 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -31,7 +31,7 @@
 from .xception_aligned import *
 
 from .factory import create_model
-from .helpers import load_checkpoint, resume_checkpoint
+from .helpers import load_checkpoint, resume_checkpoint, model_parameters
 from .layers import TestTimePoolHead, apply_test_time_pool
 from .layers import convert_splitbn_model
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index d9b501dac9..4d9b8a2853 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -113,10 +113,9 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
             digits of the SHA256 hash of the contents of the file. The hash is used to
             ensure unique names and to verify the contents of the file. Default: False
     """
-    if cfg is None:
-        cfg = getattr(model, 'default_cfg')
-    if cfg is None or 'url' not in cfg or not cfg['url']:
-        _logger.warning("Pretrained model URL does not exist, using random initialization.")
+    cfg = cfg or getattr(model, 'default_cfg')
+    if cfg is None or not cfg.get('url', None):
+        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
 
     url = cfg['url']
@@ -174,9 +173,8 @@ def adapt_input_conv(in_chans, conv_weight):
 
 
 def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
-    if cfg is None:
-        cfg = getattr(model, 'default_cfg')
-    if cfg is None or 'url' not in cfg or not cfg['url']:
+    cfg = cfg or getattr(model, 'default_cfg')
+    if cfg is None or not cfg.get('url', None):
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
 
@@ -376,3 +374,11 @@ def build_model_with_cfg(
         model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg
 
     return model
+
+
+def model_parameters(model, exclude_head=False):
+    if exclude_head:
+        # FIXME this is a bit of a quick and dirty hack to skip classifier head params based on ordering
+        return [p for p in model.parameters()][:-2]
+    else:
+        return model.parameters()
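The `model_parameters` helper added above relies purely on parameter ordering to skip the classifier head (hence the FIXME), which matters for AGC since the final classifier layer is left unclipped. A minimal sketch of the intended behaviour, using a hypothetical toy model that is not part of this diff:

```python
import torch.nn as nn
from timm.models import model_parameters  # exported by the __init__ change above

# hypothetical toy model: the head is the last module registered, so its
# weight and bias are the final two entries yielded by model.parameters()
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))

all_params = list(model.parameters())                           # 4 tensors
body_params = list(model_parameters(model, exclude_head=True))  # head weight/bias skipped
assert len(body_params) == len(all_params) - 2
```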
diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py
index 0f7c4b0555..1c526e8ce1 100644
--- a/timm/utils/__init__.py
+++ b/timm/utils/__init__.py
@@ -1,4 +1,6 @@
+from .agc import adaptive_clip_grad
 from .checkpoint_saver import CheckpointSaver
+from .clip_grad import dispatch_clip_grad
 from .cuda import ApexScaler, NativeScaler
 from .distributed import distribute_bn, reduce_tensor
 from .jit import set_jit_legacy
diff --git a/timm/utils/agc.py b/timm/utils/agc.py
new file mode 100644
index 0000000000..f51401726f
--- /dev/null
+++ b/timm/utils/agc.py
@@ -0,0 +1,42 @@
+""" Adaptive Gradient Clipping
+
+An impl of AGC, as per (https://arxiv.org/abs/2102.06171):
+
+@article{brock2021high,
+  author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
+  title={High-Performance Large-Scale Image Recognition Without Normalization},
+  journal={arXiv preprint arXiv:2102.06171},
+  year={2021}
+}
+
+Code references:
+  * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets
+  * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import torch
+
+
+def unitwise_norm(x, norm_type=2.0):
+    if x.ndim <= 1:
+        return x.norm(norm_type)
+    else:
+        # works for nn.ConvNd and nn.Linear where output dim is first in the kernel/weight tensor
+        # might need special cases for other weights (possibly MHA) where this may not be true
+        return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)
+
+
+def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    for p in parameters:
+        if p.grad is None:
+            continue
+        p_data = p.detach()
+        g_data = p.grad.detach()
+        max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
+        grad_norm = unitwise_norm(g_data, norm_type=norm_type)
+        clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))
+        new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad)
+        p.grad.detach().copy_(new_grads)
diff --git a/timm/utils/clip_grad.py b/timm/utils/clip_grad.py
new file mode 100644
index 0000000000..7eb40697a2
--- /dev/null
+++ b/timm/utils/clip_grad.py
@@ -0,0 +1,23 @@
+import torch
+
+from timm.utils.agc import adaptive_clip_grad
+
+
+def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
+    """ Dispatch to gradient clipping method
+
+    Args:
+        parameters (Iterable): model parameters to clip
+        value (float): clipping value/factor/norm, mode dependent
+        mode (str): clipping mode, one of 'norm', 'value', 'agc'
+        norm_type (float): p-norm, default 2.0
+    """
+    if mode == 'norm':
+        torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
+    elif mode == 'value':
+        torch.nn.utils.clip_grad_value_(parameters, value)
+    elif mode == 'agc':
+        adaptive_clip_grad(parameters, value, norm_type=norm_type)
+    else:
+        assert False, f"Unknown clip mode ({mode})."
+
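As a rough usage sketch (not part of the diff), here is `adaptive_clip_grad` applied to a hypothetical layer with a deliberately oversized gradient; `dispatch_clip_grad` with `mode='agc'` routes to the same function:

```python
import torch
import torch.nn as nn
from timm.utils import adaptive_clip_grad, dispatch_clip_grad

# hypothetical layer with a huge gradient so the clip is visible
layer = nn.Linear(4, 2)
layer.weight.grad = 100. * torch.ones(2, 4)
layer.bias.grad = torch.zeros(2)

# AGC rescales each unit (row of the weight) whose gradient norm exceeds
# clip_factor * max(unit-wise weight norm, eps); smaller gradients pass through
adaptive_clip_grad(layer.parameters(), clip_factor=0.01, eps=1e-3)
print(layer.weight.grad.norm(dim=1))  # each row norm is now ~0.01 * max(row weight norm, 1e-3)

# equivalent call through the dispatcher used by train.py and the loss scalers:
# dispatch_clip_grad(layer.parameters(), value=0.01, mode='agc')
```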
diff --git a/timm/utils/cuda.py b/timm/utils/cuda.py
index bcd29f5801..9e7bddf304 100644
--- a/timm/utils/cuda.py
+++ b/timm/utils/cuda.py
@@ -11,15 +11,17 @@
     amp = None
     has_apex = False
 
+from .clip_grad import dispatch_clip_grad
+
 
 class ApexScaler:
     state_dict_key = "amp"
 
-    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
+    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
         with amp.scale_loss(loss, optimizer) as scaled_loss:
             scaled_loss.backward(create_graph=create_graph)
         if clip_grad is not None:
-            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad)
+            dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
         optimizer.step()
 
     def state_dict(self):
@@ -37,12 +39,12 @@ class NativeScaler:
 
     def __init__(self):
         self._scaler = torch.cuda.amp.GradScaler()
 
-    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
+    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
         self._scaler.scale(loss).backward(create_graph=create_graph)
         if clip_grad is not None:
             assert parameters is not None
             self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
-            torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
         self._scaler.step(optimizer)
         self._scaler.update()
diff --git a/timm/version.py b/timm/version.py
index 908c0bb70e..9a8e054a63 100644
--- a/timm/version.py
+++ b/timm/version.py
@@ -1 +1 @@
-__version__ = '0.4.3'
+__version__ = '0.4.4'
diff --git a/train.py b/train.py
index 0333d72ffd..9abcfed3e5 100755
--- a/train.py
+++ b/train.py
@@ -29,7 +29,7 @@
 from torch.nn.parallel import DistributedDataParallel as NativeDDP
 
 from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
-from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
+from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
 from timm.optim import create_optimizer
@@ -116,7 +116,8 @@
                     help='weight decay (default: 0.0001)')
 parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                     help='Clip gradient norm (default: None, no clipping)')
-
+parser.add_argument('--clip-mode', type=str, default='norm',
+                    help='Gradient clipping mode. One of ("norm", "value", "agc")')
 
 # Learning rate schedule parameters
 
@@ -637,11 +638,16 @@ def train_one_epoch(
         optimizer.zero_grad()
         if loss_scaler is not None:
             loss_scaler(
-                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
+                loss, optimizer,
+                clip_grad=args.clip_grad, clip_mode=args.clip_mode,
+                parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
+                create_graph=second_order)
         else:
             loss.backward(create_graph=second_order)
             if args.clip_grad is not None:
-                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
+                dispatch_clip_grad(
+                    model_parameters(model, exclude_head='agc' in args.clip_mode),
+                    value=args.clip_grad, mode=args.clip_mode)
             optimizer.step()
 
         if model_ema is not None:
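On the AMP side, the scalers now simply forward `clip_grad`/`clip_mode` to `dispatch_clip_grad` after unscaling. A single-step usage sketch (assumes a CUDA device; the toy model, optimizer, and batch are illustrative, not from the diff):

```python
import torch
from timm.utils import NativeScaler

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
loss_scaler = NativeScaler()

x = torch.randn(8, 10, device='cuda')
with torch.cuda.amp.autocast():
    loss = model(x).mean()

# gradients are unscaled in-place, then clipped via dispatch_clip_grad(mode='agc')
loss_scaler(
    loss, optimizer,
    clip_grad=0.01, clip_mode='agc',
    parameters=model.parameters())
```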
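The non-AMP branch of `train_one_epoch` condenses to roughly the sketch below (not the full training loop; `clip_grad`/`clip_mode` stand in for the values parsed from the new `--clip-grad`/`--clip-mode` args):

```python
from timm.models import model_parameters
from timm.utils import dispatch_clip_grad

def train_step(model, loss, optimizer, clip_grad=0.01, clip_mode='agc'):
    optimizer.zero_grad()
    loss.backward()
    if clip_grad is not None:
        dispatch_clip_grad(
            # AGC skips the classifier head params (the paper leaves the final layer unclipped)
            model_parameters(model, exclude_head='agc' in clip_mode),
            value=clip_grad, mode=clip_mode)
    optimizer.step()
```

From the command line this corresponds to the README entry above, e.g. `--clip-grad .01 --clip-mode agc` for AGC with the default factor, or `--clip-grad 1.0` to keep the previous global-norm behaviour.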