[jvpvjp] Batch norm coverage with decomposition #877

Merged (1 commit) on Jun 17, 2022
101 changes: 94 additions & 7 deletions functorch/_src/decompositions.py
@@ -47,6 +47,18 @@ def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
return min - torch.log1p(z), buffer


def recompute_mean_var(input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool):
    # For these norm decompositions, the jvp version is identical to the core decomposition
    # except for this step: we recompute the mean and variance so that they track gradients through input

mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
    eps = torch.pow(1 / rstd, 2) - var  # recover the eps used in the forward; this makes me so sad inside
    eps = eps.detach()  # eps itself should not carry gradients
    rstd = 1 / torch.sqrt(var + eps)  # rebuild rstd on top of the differentiable var
return mean, rstd
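
As a quick aside (not part of the diff): the helper backs the original eps out of the saved rstd, detaches it, and rebuilds rstd from a freshly computed variance, so the returned statistics match the saved ones numerically but are differentiable with respect to input. A minimal sanity check with hypothetical tensors x and saved_rstd:

import torch

x = torch.randn(2, 3, 4, requires_grad=True)
dims = [1, 2]
# simulate the rstd saved by a forward pass (no grad history)
saved_var = torch.var(x, dim=dims, unbiased=False, keepdim=True).detach()
saved_rstd = 1 / torch.sqrt(saved_var + 1e-5)
mean, rstd = recompute_mean_var(x, saved_rstd, dims, keepdim=True)
assert torch.allclose(rstd, saved_rstd)           # same values as the saved statistic
assert mean.requires_grad and rstd.requires_grad  # but now differentiable w.r.t. x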


@register_decomposition_for_jvp(aten.native_layer_norm_backward)
def native_layer_norm_backward(
grad_out: Tensor,
@@ -80,13 +92,7 @@ def native_layer_norm_backward(
input.new_zeros(input_shape[axis:]),
)

# this is exactly the same as the other decomposition except for here. We recompute the mean and variance
# so that they track gradients through input
mean_ = torch.mean(input, dim=inner_dim_indices, keepdim=True)
var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=True)
eps = torch.pow(1 / rstd, 2) - var # this makes me so sad inside
eps = eps.detach()
rstd_ = 1 / torch.sqrt(var + eps)
mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)

x_hat = (input - mean_) * rstd_
if weight is not None:
@@ -128,3 +134,84 @@ def native_layer_norm_backward(
d_bias = torch.zeros(()) # should be None but doesn't work with vjp

return (d_input, d_weight, d_bias)


def prod(x: List[int]):
r = 1
for i in x:
r *= i
return r


@register_decomposition(aten.native_batch_norm_backward)  # switch to @register_decomposition_for_jvp once this lands in core
def native_batch_norm_backward(
grad_out: Tensor,
input: Tensor,
weight: Optional[Tensor],
running_mean: Optional[Tensor],
running_var: Optional[Tensor],
save_mean: Optional[Tensor],
save_invstd: Optional[Tensor],
train: bool,
eps: float,
output_mask: List[bool],
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
input_shape = input.shape
input_rank = input.dim()
assert input_rank >= 2, "rank of the input must be at least 2"

axis = 1
num_features = prod(input_shape) / input_shape[axis]
mean = save_mean
invstd = save_invstd
if train:
assert save_mean is not None and save_invstd is not None, "when train=True, save_mean and save_invstd are required"

        reduction_dims = [0] + list(range(2, input.dim()))
        assert invstd is not None  # for typing
        mean, invstd = recompute_mean_var(input, invstd, reduction_dims, keepdim=False)
else:
assert running_mean is not None and running_var is not None
mean = running_mean
invstd = torch.rsqrt(running_var + eps)

broadcast_mask = [1] * input_rank
broadcast_mask[axis] = input_shape[axis]

reduction_axes: List[int] = []
for i in range(input_rank):
if i != axis:
reduction_axes.append(i)

mean = torch.reshape(mean, broadcast_mask)
norm = 1.0 / num_features
grad_output_sum = torch.sum(grad_out, reduction_axes)
dot_p = torch.sum(grad_out * (input - mean), reduction_axes)

grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)

if weight is None:
grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
else:
grad_scale = torch.reshape(invstd * weight, broadcast_mask)

if train:
proj = (input - mean) * proj_scale
grad_input = ((grad_out - proj) - grad_mean) * grad_scale
else:
grad_input = grad_out * grad_scale

if output_mask[1]:
grad_weight = dot_p * invstd
elif weight is not None:
grad_weight = torch.zeros_like(weight) # should be None but doesn't work with vjp
else:
grad_weight = torch.zeros(()) # should be None but doesn't work with vjp

if output_mask[2]:
grad_bias = grad_output_sum
else:
grad_bias = torch.zeros_like(grad_output_sum) # should be None but doesn't work with vjp
Comment on lines +208 to +215
Contributor

sad, but it is what it is


return (grad_input, grad_weight, grad_bias)
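
For context, a rough sketch (not the PR's test code) of the composition this decomposition is meant to cover: forward-mode (jvp) over the reverse-mode (vjp) of batch_norm, which routes through native_batch_norm_backward. It assumes functorch's top-level jvp/vjp API and made-up shapes:

import torch
import torch.nn.functional as F
from functorch import jvp, vjp

x = torch.randn(4, 3, 8, 8)
weight, bias = torch.randn(3), torch.randn(3)
cotangent = torch.randn(4, 3, 8, 8)

def bn(x):
    # training=True with no running stats keeps the example purely functional
    return F.batch_norm(x, None, None, weight, bias, training=True)

def vjp_of_bn(x, cot):
    _, vjp_fn = vjp(bn, x)
    return vjp_fn(cot)[0]  # gradient of bn's output w.r.t. x

# Forward-mode through the backward pass is what exercises the decomposition above.
tangents = (torch.randn_like(x), torch.randn_like(cotangent))
grad_x, grad_x_tangent = jvp(vjp_of_bn, (x, cotangent), tangents)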
2 changes: 2 additions & 0 deletions functorch/_src/eager_transforms.py
@@ -1339,5 +1339,7 @@ def _register_python_decomposition_vmap(decomp):
_register_jit_decomposition(torch.ops.aten._softmax_backward_data.default)
_register_jit_decomposition(torch.ops.aten.log_sigmoid_forward.default)
_register_jit_decomposition(torch.ops.aten.native_layer_norm_backward.default)
_register_jit_decomposition(torch.ops.aten.native_batch_norm_backward.default)
_register_jit_decomposition(torch.ops.aten.cudnn_batch_norm_backward.default)
_register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
_register_python_decomposition_vmap(torch.ops.aten.addr.default)
2 changes: 2 additions & 0 deletions functorch/csrc/DynamicLayer.cpp
@@ -502,6 +502,8 @@ TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
OP_DECOMPOSE(log_sigmoid);
JVP_DECOMP(log_sigmoid_forward);
JVP_DECOMP(native_layer_norm_backward);
JVP_DECOMP(native_batch_norm_backward);
JVP_DECOMP(cudnn_batch_norm_backward);
}


5 changes: 2 additions & 3 deletions test/test_ops.py
@@ -1149,8 +1149,6 @@ def get_vjp(cotangents, *primals):
xfail('logdet', ''),
xfail('nanmean', ''),
xfail('nansum', ''),
xfail('nn.functional.batch_norm', ''),
xfail('nn.functional.batch_norm', 'without_cudnn', device_type='cuda'),
xfail('nn.functional.embedding'),
xfail('nn.functional.embedding', 'functorch'),
xfail('nn.functional.embedding_bag', ''),
@@ -1249,7 +1247,8 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
'softmax',
'log_softmax',
'nn.functional.cross_entropy',
'nn.functional.layer_norm'
'nn.functional.layer_norm',
'nn.functional.batch_norm',
}
if op.name in FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH:
self.assertFalse(op.supports_fwgrad_bwgrad,