Supporting aimv2 encoders #2379

Merged
merged 10 commits on Dec 31, 2024
10 changes: 5 additions & 5 deletions tests/test_models.py
@@ -53,13 +53,13 @@
'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2'
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*'
]

# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*',
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',
]
@@ -72,11 +72,11 @@
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', '*huge*', '*giant*', '*gigantic*',
'*enormous*', 'maxvit_xlarge*', 'regnet*1280', 'regnet*2560']
NON_STD_EXCLUDE_FILTERS = ['*huge*', '*giant*', '*gigantic*', '*enormous*']
'*enormous*', 'maxvit_xlarge*', 'regnet*1280', 'regnet*2560', '*_1b_*', '*_3b_*']
NON_STD_EXCLUDE_FILTERS = ['*huge*', '*giant*', '*gigantic*', '*enormous*', '*_1b_*', '*_3b_*']
else:
EXCLUDE_FILTERS = ['*enormous*']
NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*']
NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*', '*_3b_*']

EXCLUDE_JIT_FILTERS = ['hiera_*']

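For reference, the new filter entries are glob-style patterns; below is a minimal sketch of how such patterns would exclude the large 1B/3B aimv2 variants, assuming fnmatch-style matching as used elsewhere in timm (the model names are illustrative, not taken from this PR):

```python
from fnmatch import fnmatch

# Hypothetical model names checked against the exclude patterns added above.
EXCLUDE = ['*enormous*', '*_1b_*', '*_3b_*']
models = ['aimv2_large_patch14_224', 'aimv2_1b_patch14_224', 'aimv2_3b_patch14_448']

kept = [m for m in models if not any(fnmatch(m, p) for p in EXCLUDE)]
print(kept)  # ['aimv2_large_patch14_224'] -- the 1B/3B variants are filtered out
```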
2 changes: 1 addition & 1 deletion timm/layers/__init__.py
@@ -34,7 +34,7 @@
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
from .padding import get_padding, get_same_padding, pad_same
4 changes: 3 additions & 1 deletion timm/layers/create_norm.py
@@ -10,7 +10,7 @@

import torch.nn as nn

from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d
from torchvision.ops.misc import FrozenBatchNorm2d

_NORM_MAP = dict(
@@ -23,6 +23,8 @@
layernorm2d=LayerNorm2d,
rmsnorm=RmsNorm,
rmsnorm2d=RmsNorm2d,
simplenorm=SimpleNorm,
simplenorm2d=SimpleNorm2d,
frozenbatchnorm2d=FrozenBatchNorm2d,
)
_NORM_TYPES = {m for n, m in _NORM_MAP.items()}
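Assuming the factory helpers in this file resolve names through `_NORM_MAP` the same way as for the existing entries, the new keys would be usable roughly like this (a sketch, not part of the diff):

```python
import torch
from timm.layers import get_norm_layer

norm_cls = get_norm_layer('simplenorm2d')   # resolves via _NORM_MAP to SimpleNorm2d
norm = norm_cls(64)                         # per-channel weight, no bias
y = norm(torch.randn(2, 64, 14, 14))        # NCHW in, NCHW out
print(type(norm).__name__, tuple(y.shape))  # SimpleNorm2d (2, 64, 14, 14)
```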
68 changes: 62 additions & 6 deletions timm/layers/fast_norm.py
@@ -24,6 +24,8 @@
has_apex_rmsnorm = False


has_torch_rms_norm = hasattr(F, 'rms_norm')

# fast (ie lower precision LN) can be disabled with this flag if issues crop up
_USE_FAST_NORM = False # defaulting to False for now

@@ -75,7 +77,6 @@ def fast_group_norm(
if is_autocast_enabled(x.device.type):
# normally native AMP casts GN inputs to float32
# here we use the low precision autocast dtype
# FIXME what to do re CPU autocast?
dt = get_autocast_dtype(x.device.type)
x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None

@@ -101,7 +102,6 @@ def fast_layer_norm(
# normally native AMP casts LN inputs to float32
# apex LN does not, this is behaving like Apex
dt = get_autocast_dtype(x.device.type)
# FIXME what to do re CPU autocast?
x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None

with torch.amp.autocast(device_type=x.device.type, enabled=False):
@@ -115,15 +115,16 @@ def rms_norm(
eps: float = 1e-5,
):
norm_ndim = len(normalized_shape)
v = x.pow(2)
if torch.jit.is_scripting():
# ndim = len(x.shape)
# dims = list(range(ndim - norm_ndim, ndim)) # this doesn't work on pytorch <= 1.13.x
# NOTE -ve dims cause torchscript to crash in some cases, out of options to work around
assert norm_ndim == 1
v = torch.var(x, dim=-1).unsqueeze(-1) # ts crashes with -ve dim + keepdim=True
v = torch.mean(v, dim=-1).unsqueeze(-1) # ts crashes with -ve dim + keepdim=True
else:
dims = tuple(range(-1, -norm_ndim - 1, -1))
v = torch.var(x, dim=dims, keepdim=True)
v = torch.mean(v, dim=dims, keepdim=True)
x = x * torch.rsqrt(v + eps)
if weight is not None:
x = x * weight
@@ -146,5 +147,60 @@ def fast_rms_norm(
else:
return fused_rms_norm_affine(x, weight, normalized_shape, eps)

# fallback
return rms_norm(x, normalized_shape, weight, eps)
if is_autocast_enabled(x.device.type):
# normally native AMP casts LN inputs to float32
# apex LN does not, this is behaving like Apex
dt = get_autocast_dtype(x.device.type)
x, weight = x.to(dt), weight.to(dt)

with torch.amp.autocast(device_type=x.device.type, enabled=False):
if has_torch_rms_norm:
x = F.rms_norm(x, normalized_shape, weight, eps)
else:
x = rms_norm(x, normalized_shape, weight, eps)

return x


def simple_norm(
x: torch.Tensor,
normalized_shape: List[int],
weight: Optional[torch.Tensor] = None,
eps: float = 1e-5,
):
norm_ndim = len(normalized_shape)
if torch.jit.is_scripting():
# ndim = len(x.shape)
# dims = list(range(ndim - norm_ndim, ndim)) # this doesn't work on pytorch <= 1.13.x
# NOTE -ve dims cause torchscript to crash in some cases, out of options to work around
assert norm_ndim == 1
v = torch.var(x, dim=-1).unsqueeze(-1) # ts crashes with -ve dim + keepdim=True
else:
dims = tuple(range(-1, -norm_ndim - 1, -1))
v = torch.var(x, dim=dims, keepdim=True)
x = x * torch.rsqrt(v + eps)
if weight is not None:
x = x * weight
return x


def fast_simple_norm(
x: torch.Tensor,
normalized_shape: List[int],
weight: Optional[torch.Tensor] = None,
eps: float = 1e-5,
) -> torch.Tensor:
if torch.jit.is_scripting():
# this must be by itself, cannot merge with has_apex_rmsnorm
return simple_norm(x, normalized_shape, weight, eps)

if is_autocast_enabled(x.device.type):
# normally native AMP casts LN inputs to float32
# apex LN does not, this is behaving like Apex
dt = get_autocast_dtype(x.device.type)
x, weight = x.to(dt), weight.to(dt)

with torch.amp.autocast(device_type=x.device.type, enabled=False):
x = simple_norm(x, normalized_shape, weight, eps)
return x

9 changes: 5 additions & 4 deletions timm/layers/mlp.py
@@ -83,9 +83,9 @@ def __init__(

def init_weights(self):
# override init of fc1 w/ gate portion set to weight near zero, bias=1
fc1_mid = self.fc1.bias.shape[0] // 2
nn.init.ones_(self.fc1.bias[fc1_mid:])
nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)
if self.fc1.bias is not None:
nn.init.ones_(self.fc1.bias[self.fc1.bias.shape[0] // 2:])
nn.init.normal_(self.fc1.weight[self.fc1.weight.shape[0] // 2:], std=1e-6)

def forward(self, x):
x = self.fc1(x)
@@ -132,7 +132,8 @@ def __init__(

def init_weights(self):
# override init of fc1 w/ gate portion set to weight near zero, bias=1
nn.init.ones_(self.fc1_g.bias)
if self.fc1_g.bias is not None:
nn.init.ones_(self.fc1_g.bias)
nn.init.normal_(self.fc1_g.weight, std=1e-6)

def forward(self, x):
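As a quick check of the guarded init above, a sketch along these lines exercises the bias-free path, assuming `SwiGLU` takes a `bias` argument like the other timm MLP variants:

```python
import torch
from timm.layers import SwiGLU

mlp = SwiGLU(in_features=256, hidden_features=512, bias=False)
mlp.init_weights()                    # previously failed when fc1_g.bias was None
print(mlp.fc1_g.bias is None)         # True
print(mlp.fc1_g.weight.std().item())  # ~1e-6 from the near-zero gate init
```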
124 changes: 115 additions & 9 deletions timm/layers/norm.py
@@ -11,17 +11,24 @@
import torch.nn as nn
import torch.nn.functional as F

from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm, simple_norm

try:
from torch.nn.functional import rms_norm
except ImportError:
from .fast_norm import rms_norm


class GroupNorm(nn.GroupNorm):
_fast_norm: torch.jit.Final[bool]

def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
# NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
super().__init__(num_groups, num_channels, eps=eps, affine=affine)
self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)

def forward(self, x):
if self.fast_norm:
if self._fast_norm:
return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
else:
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
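
The `torch.jit.Final` annotations added throughout this file let TorchScript treat the flag as a constant; a minimal, self-contained illustration of the pattern (not code from this PR):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    # Final marks the attribute as a scripting-time constant, so the branch below
    # can be resolved when the module is scripted rather than kept as dynamic state.
    _fast_norm: torch.jit.Final[bool]

    def __init__(self, fast: bool):
        super().__init__()
        self._fast_norm = fast

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._fast_norm:
            return x * 2
        return x

scripted = torch.jit.script(Demo(True))
print(scripted(torch.ones(2)))  # tensor([2., 2.])
```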
@@ -31,13 +38,14 @@ class GroupNorm1(nn.GroupNorm):
""" Group Normalization with 1 group.
Input: tensor in shape [B, C, *]
"""
_fast_norm: torch.jit.Final[bool]

def __init__(self, num_channels, **kwargs):
super().__init__(1, num_channels, **kwargs)
self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.fast_norm:
if self._fast_norm:
return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
else:
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
@@ -46,6 +54,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class LayerNorm(nn.LayerNorm):
""" LayerNorm w/ fast norm option
"""
_fast_norm: torch.jit.Final[bool]

def __init__(self, num_channels, eps=1e-6, affine=True):
super().__init__(num_channels, eps=eps, elementwise_affine=affine)
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
Expand All @@ -60,6 +70,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

class LayerNorm2d(nn.LayerNorm):
""" LayerNorm for channels of '2D' spatial NCHW tensors """
_fast_norm: torch.jit.Final[bool]

def __init__(self, num_channels, eps=1e-6, affine=True):
super().__init__(num_channels, eps=eps, elementwise_affine=affine)
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
@@ -121,10 +133,11 @@ def forward(self, x) -> torch.Tensor:
class RmsNorm(nn.Module):
""" RmsNorm w/ fast (apex) norm if available
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
_fast_norm: bool

def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
@@ -136,6 +149,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = affine
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)

if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
@@ -150,17 +165,21 @@ def reset_parameters(self) -> None:
def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE fast norm fallback needs our rms norm impl, so both paths through here.
# Since there is no built-in PyTorch impl, always use APEX RmsNorm if it is installed.
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
if self._fast_norm:
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
else:
x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
return x


class RmsNorm2d(nn.Module):
""" RmsNorm w/ fast (apex) norm if available
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
_fast_norm: bool

def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
@@ -172,6 +191,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = affine
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)

if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
@@ -187,6 +208,91 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.permute(0, 2, 3, 1)
# NOTE fast norm fallback needs our rms norm impl, so both paths through here.
# Since there is no built-in PyTorch impl, always use APEX RmsNorm if it is installed.
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
if self._fast_norm:
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
else:
x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
x = x.permute(0, 3, 1, 2)
return x


class SimpleNorm(nn.Module):
""" SimpleNorm (x / std(x))
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
_fast_norm: bool

def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
normalized_shape = channels
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = affine
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)

if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
self.register_parameter('weight', None)

self.reset_parameters()

def reset_parameters(self) -> None:
if self.elementwise_affine:
nn.init.ones_(self.weight)

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self._fast_norm:
x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
else:
x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
return x


class SimpleNorm2d(nn.Module):
""" SimpleNorm for NCHW tensors
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
_fast_norm: bool

def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
normalized_shape = channels
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = affine
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)

if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
self.register_parameter('weight', None)

self.reset_parameters()

def reset_parameters(self) -> None:
if self.elementwise_affine:
nn.init.ones_(self.weight)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.permute(0, 2, 3, 1)
if self._fast_norm:
x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
else:
x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
x = x.permute(0, 3, 1, 2)
return x
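
A minimal usage sketch of the classes added above, assuming they are exported from `timm.layers` as the `__init__.py` change indicates:

```python
import torch
from timm.layers import SimpleNorm, SimpleNorm2d

torch.manual_seed(0)
x = torch.randn(2, 32, 7, 7)  # NCHW feature map

y2d = SimpleNorm2d(32)(x)
# SimpleNorm2d permutes to NHWC, normalizes over channels, and permutes back,
# so it should match SimpleNorm applied to a channels-last view of the same tensor.
y = SimpleNorm(32)(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(y2d, y))  # True (both start from the default all-ones weight)
```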