batch norm forward over reverse coverage with decomposition
samdow committed Jun 15, 2022
1 parent 2b16530 commit 915aecb
Showing 4 changed files with 97 additions and 10 deletions.
100 changes: 93 additions & 7 deletions functorch/_src/decompositions.py
@@ -47,6 +47,17 @@ def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
    return min - torch.log1p(z), buffer


def recompute_mean_var(input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool):
    # For most norm decompositions, the jvp version matches the one in core except for this step:
    # we recompute the mean and variance so that they track gradients through input
    mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
    eps = eps.detach()
    rstd = 1 / torch.sqrt(var + eps)
    return mean, rstd
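
# A quick sanity check of the recomputation trick (illustrative sketch, not from the
# commit; it assumes only standard torch behavior): feeding back a detached rstd
# recovers the original eps, so the returned rstd matches numerically while mean and
# rstd now require grad through `input`.
#
#   x = torch.randn(4, 3, requires_grad=True)
#   var, mean = torch.var_mean(x, dim=[0, 1], unbiased=False)
#   rstd = 1 / torch.sqrt(var + 1e-5)
#   mean2, rstd2 = recompute_mean_var(x, rstd.detach(), [0, 1], keepdim=False)
#   assert torch.allclose(rstd2, rstd)
#   assert mean2.requires_grad and rstd2.requires_grad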

@register_decomposition_for_jvp(aten.native_layer_norm_backward)
def native_layer_norm_backward(
    grad_out: Tensor,
@@ -80,13 +91,7 @@ def native_layer_norm_backward(
            input.new_zeros(input_shape[axis:]),
        )

-    # this is exactly the same as the other decomposition except for here. We recompute the mean and variance
-    # so that they track gradients through input
-    mean_ = torch.mean(input, dim=inner_dim_indices, keepdim=True)
-    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=True)
-    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
-    eps = eps.detach()
-    rstd_ = 1 / torch.sqrt(var + eps)
+    mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)

    x_hat = (input - mean_) * rstd_
    if weight is not None:
@@ -128,3 +133,84 @@ def native_layer_norm_backward(
        d_bias = torch.zeros(())  # should be None but doesn't work with vjp

    return (d_input, d_weight, d_bias)


def prod(x: List[int]):
    r = 1
    for i in x:
        r *= i
    return r


@register_decomposition(aten.native_batch_norm_backward) # @register_decomposition_for_jvp after in core
def native_batch_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    weight: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_invstd: Optional[Tensor],
    train: bool,
    eps: float,
    output_mask: List[bool],
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_rank = input.dim()
    assert input_rank >= 2, "rank of the input must be at least 2"

    axis = 1
    num_features = prod(input_shape) / input_shape[axis]  # elements reduced over per channel
    mean = save_mean
    invstd = save_invstd
    if train:
        assert save_mean is not None and save_invstd is not None, "when train=True, save_mean and save_invstd are required"

        reduction_dims = [0] + list(range(2, input.dim()))
        assert invstd is not None  # for typing
        mean, invstd = recompute_mean_var(input, invstd, reduction_dims, keepdim=False)
    else:
        assert running_mean is not None and running_var is not None
        mean = running_mean
        invstd = torch.rsqrt(running_var + eps)

    broadcast_mask = [1] * input_rank
    broadcast_mask[axis] = input_shape[axis]

    reduction_axes: List[int] = []  # every dim except the channel axis
    for i in range(input_rank):
        if i != axis:
            reduction_axes.append(i)

    mean = torch.reshape(mean, broadcast_mask)
    norm = 1.0 / num_features
    grad_output_sum = torch.sum(grad_out, reduction_axes)
    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)

    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)

    if weight is None:
        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
    else:
        grad_scale = torch.reshape(invstd * weight, broadcast_mask)

    if train:
        # d/dx of (x - mean) * invstd * weight:
        #   (grad_out - mean(grad_out) - x_hat * mean(grad_out * x_hat)) * invstd * weight
        # with the means taken over reduction_axes
        proj = (input - mean) * proj_scale
        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
    else:
        grad_input = grad_out * grad_scale

    if output_mask[1]:
        grad_weight = dot_p * invstd
    elif weight is not None:
        grad_weight = torch.zeros_like(weight)  # should be None but doesn't work with vjp
    else:
        grad_weight = torch.zeros(())  # should be None but doesn't work with vjp

    if output_mask[2]:
        grad_bias = grad_output_sum
    else:
        grad_bias = torch.zeros_like(grad_output_sum)  # should be None but doesn't work with vjp

    return (grad_input, grad_weight, grad_bias)
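
The point of decomposing native_batch_norm_backward into differentiable ops is that functorch can now run forward-mode AD over batch norm's reverse-mode gradient (forward-over-reverse). A minimal sketch of roughly the pattern the new test coverage exercises (the jvp/vjp calls are functorch's public API; the shapes and helper functions are made up for illustration):

import torch
import torch.nn.functional as F
from functorch import jvp, vjp

x = torch.randn(2, 3, 4)
weight, bias = torch.randn(3), torch.randn(3)
cotangent = torch.randn(2, 3, 4)

def f(x, weight, bias):
    # training-mode batch norm without running stats
    return F.batch_norm(x, None, None, weight, bias, training=True)

def get_vjp(x, weight, bias):
    # reverse mode: build the vjp of f and pull back a fixed cotangent
    _, vjp_fn = vjp(f, x, weight, bias)
    return vjp_fn(cotangent)

primals = (x, weight, bias)
tangents = tuple(torch.randn_like(p) for p in primals)
# forward mode over the reverse-mode gradients; this path now goes through the
# decomposition above rather than failing for lack of a forward-over-reverse formula
_, out_tangents = jvp(get_vjp, primals, tangents)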
1 change: 1 addition & 0 deletions functorch/_src/eager_transforms.py
@@ -1339,5 +1339,6 @@ def _register_python_decomposition_vmap(decomp):
_register_jit_decomposition(torch.ops.aten._softmax_backward_data.default)
_register_jit_decomposition(torch.ops.aten.log_sigmoid_forward.default)
_register_jit_decomposition(torch.ops.aten.native_layer_norm_backward.default)
_register_jit_decomposition(torch.ops.aten.native_batch_norm_backward.default, use_python=True)
_register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
_register_python_decomposition_vmap(torch.ops.aten.addr.default)
1 change: 1 addition & 0 deletions functorch/csrc/DynamicLayer.cpp
@@ -502,6 +502,7 @@ TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
OP_DECOMPOSE(log_sigmoid);
JVP_DECOMP(log_sigmoid_forward);
JVP_DECOMP(native_layer_norm_backward);
JVP_DECOMP(native_batch_norm_backward);
}


5 changes: 2 additions & 3 deletions test/test_ops.py
@@ -1146,8 +1146,6 @@ def get_vjp(cotangents, *primals):
        xfail('logdet', ''),
        xfail('nanmean', ''),
        xfail('nansum', ''),
-        xfail('nn.functional.batch_norm', ''),
-        xfail('nn.functional.batch_norm', 'without_cudnn', device_type='cuda'),
        xfail('nn.functional.embedding'),
        xfail('nn.functional.embedding', 'functorch'),
        xfail('nn.functional.embedding_bag', ''),
@@ -1246,7 +1244,8 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
            'softmax',
            'log_softmax',
            'nn.functional.cross_entropy',
-            'nn.functional.layer_norm'
+            'nn.functional.layer_norm',
+            'nn.functional.batch_norm',
        }
        if op.name in FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH:
            self.assertFalse(op.supports_fwgrad_bwgrad,
