Skip to content

Commit

Permalink
Add drop_connect impl to try during training, fix a few comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed May 30, 2019
1 parent 0fc4cca commit 4efecfd
Showing 1 changed file with 38 additions and 14 deletions.
52 changes: 38 additions & 14 deletions models/genmobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,13 @@ class _BlockBuilder:
"""

def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
act_fn=None, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
drop_connect_rate=0., act_fn=None, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
folded_bn=False, padding_same=False, verbose=False):
self.channel_multiplier = channel_multiplier
self.channel_divisor = channel_divisor
self.channel_min = channel_min
self.drop_connect_rate = drop_connect_rate
self.act_fn = act_fn
self.se_gate_fn = se_gate_fn
self.se_reduce_mid = se_reduce_mid
Expand Down Expand Up @@ -310,10 +311,12 @@ def _make_block(self, ba):
print('args:', ba)
# could replace this if with lambdas or functools binding if variety increases
if bt == 'ir':
ba['drop_connect_rate'] = self.drop_connect_rate
ba['se_gate_fn'] = self.se_gate_fn
ba['se_reduce_mid'] = self.se_reduce_mid
block = InvertedResidual(**ba)
elif bt == 'ds' or bt == 'dsa':
ba['drop_connect_rate'] = self.drop_connect_rate
block = DepthwiseSeparableConv(**ba)
elif bt == 'ca':
block = CascadeConv(**ba)
Expand Down Expand Up @@ -402,6 +405,19 @@ def hard_sigmoid(x):
return F.relu6(x + 3.) / 6.


def drop_connect(inputs, training=False, drop_connect_rate=0.):
"""Apply drop connect."""
if not training:
return inputs

keep_prob = 1 - drop_connect_rate
random_tensor = keep_prob + torch.rand(
(inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device)
random_tensor.floor_() # binarize
output = inputs.div(keep_prob) * random_tensor
return output


class ChannelShuffle(nn.Module):
# FIXME haven't used yet
def __init__(self, groups):
Expand Down Expand Up @@ -474,13 +490,14 @@ def __init__(self, in_chs, out_chs, kernel_size,
stride=1, act_fn=F.relu, noskip=False, pw_act=False,
se_ratio=0., se_gate_fn=torch.sigmoid,
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
folded_bn=False, padding_same=False):
folded_bn=False, padding_same=False, drop_connect_rate=0.):
super(DepthwiseSeparableConv, self).__init__()
assert stride in [1, 2]
self.has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
self.has_pw_act = pw_act # activation after point-wise conv
self.act_fn = act_fn
self.drop_connect_rate = drop_connect_rate
dw_padding = _padding_arg(kernel_size // 2, padding_same)
pw_padding = _padding_arg(0, padding_same)

Expand Down Expand Up @@ -515,7 +532,9 @@ def forward(self, x):
x = self.act_fn(x)

if self.has_residual:
x += residual # FIXME add drop-connect
if self.drop_connect_rate > 0.:
x = drop_connect(x, self.training, self.drop_connect_rate)
x += residual
return x


Expand Down Expand Up @@ -557,12 +576,13 @@ def __init__(self, in_chs, out_chs, kernel_size,
se_ratio=0., se_reduce_mid=False, se_gate_fn=torch.sigmoid,
shuffle_type=None, pw_group=1,
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
folded_bn=False, padding_same=False):
folded_bn=False, padding_same=False, drop_connect_rate=0.):
super(InvertedResidual, self).__init__()
mid_chs = int(in_chs * exp_ratio)
self.has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.act_fn = act_fn
self.drop_connect_rate = drop_connect_rate
dw_padding = _padding_arg(kernel_size // 2, padding_same)
pw_padding = _padding_arg(0, padding_same)

Expand Down Expand Up @@ -619,7 +639,9 @@ def forward(self, x):
x = self.bn3(x)

if self.has_residual:
x += residual # FIXME add drop-connect
if self.drop_connect_rate > 0.:
x = drop_connect(x, self.training, self.drop_connect_rate)
x += residual

# NOTE maskrcnn_benchmark building blocks have an SE module defined here for some variants

Expand All @@ -643,12 +665,14 @@ class GenMobileNet(nn.Module):
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
drop_rate=0., act_fn=F.relu, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
drop_rate=0., drop_connect_rate=0., act_fn=F.relu,
se_gate_fn=torch.sigmoid, se_reduce_mid=False,
global_pool='avg', head_conv='default', weight_init='goog',
folded_bn=False, padding_same=False):
folded_bn=False, padding_same=False,):
super(GenMobileNet, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.drop_connect_rate = drop_connect_rate
self.act_fn = act_fn
self.num_features = num_features

Expand All @@ -661,7 +685,7 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_f

builder = _BlockBuilder(
channel_multiplier, channel_divisor, channel_min,
act_fn, se_gate_fn, se_reduce_mid,
drop_connect_rate, act_fn, se_gate_fn, se_reduce_mid,
bn_momentum, bn_eps, folded_bn, padding_same, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(in_chs, block_args))
in_chs = builder.in_chs
Expand Down Expand Up @@ -1090,7 +1114,7 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):


def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
"""Creates a MobileNet-V3 model.
"""Creates an EfficientNet model.
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
Paper: https://arxiv.org/abs/1905.11946
Expand Down Expand Up @@ -1347,7 +1371,7 @@ def spnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
def efficientnet_b0(num_classes, in_chans=3, pretrained=False, **kwargs):
""" EfficientNet """
default_cfg = default_cfgs['efficientnet_b0']
# NOTE dropout should be 0.2 for train
# NOTE for train, drop_rate should be 0.2
model = _gen_efficientnet(
channel_multiplier=1.0, depth_multiplier=1.0,
num_classes=num_classes, in_chans=in_chans, **kwargs)
Expand All @@ -1360,7 +1384,7 @@ def efficientnet_b0(num_classes, in_chans=3, pretrained=False, **kwargs):
def efficientnet_b1(num_classes, in_chans=3, pretrained=False, **kwargs):
""" EfficientNet """
default_cfg = default_cfgs['efficientnet_b1']
# NOTE dropout should be 0.2 for train
# NOTE for train, drop_rate should be 0.2
model = _gen_efficientnet(
channel_multiplier=1.0, depth_multiplier=1.1,
num_classes=num_classes, in_chans=in_chans, **kwargs)
Expand All @@ -1373,7 +1397,7 @@ def efficientnet_b1(num_classes, in_chans=3, pretrained=False, **kwargs):
def efficientnet_b2(num_classes, in_chans=3, pretrained=False, **kwargs):
""" EfficientNet """
default_cfg = default_cfgs['efficientnet_b2']
# NOTE dropout should be 0.3 for train
# NOTE for train, drop_rate should be 0.3
model = _gen_efficientnet(
channel_multiplier=1.1, depth_multiplier=1.2,
num_classes=num_classes, in_chans=in_chans, **kwargs)
Expand All @@ -1386,7 +1410,7 @@ def efficientnet_b2(num_classes, in_chans=3, pretrained=False, **kwargs):
def efficientnet_b3(num_classes, in_chans=3, pretrained=False, **kwargs):
""" EfficientNet """
default_cfg = default_cfgs['efficientnet_b3']
# NOTE dropout should be 0.3 for train
# NOTE for train, drop_rate should be 0.3
model = _gen_efficientnet(
channel_multiplier=1.2, depth_multiplier=1.4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
Expand All @@ -1399,7 +1423,7 @@ def efficientnet_b3(num_classes, in_chans=3, pretrained=False, **kwargs):
def efficientnet_b4(num_classes, in_chans=3, pretrained=False, **kwargs):
""" EfficientNet """
default_cfg = default_cfgs['efficientnet_b4']
# NOTE dropout should be 0.4 for train
# NOTE for train, drop_rate should be 0.4
model = _gen_efficientnet(
channel_multiplier=1.4, depth_multiplier=1.8,
num_classes=num_classes, in_chans=in_chans, **kwargs)
Expand Down

0 comments on commit 4efecfd

Please sign in to comment.