From 825f439e344d3c51aad364072f2208f2f78541c0 Mon Sep 17 00:00:00 2001
From: samdow
Date: Wed, 15 Jun 2022 11:13:08 -0400
Subject: [PATCH] batch norm forward over reverse coverage with decomposition

---
 functorch/_src/decompositions.py   | 101 +++++++++++++++++++++++++++--
 functorch/_src/eager_transforms.py |   2 +
 functorch/csrc/DynamicLayer.cpp    |   2 +
 test/test_ops.py                   |   5 +-
 4 files changed, 100 insertions(+), 10 deletions(-)

diff --git a/functorch/_src/decompositions.py b/functorch/_src/decompositions.py
index 02b08009b..8a1cc09ce 100644
--- a/functorch/_src/decompositions.py
+++ b/functorch/_src/decompositions.py
@@ -47,6 +47,18 @@ def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
     return min - torch.log1p(z), buffer
 
 
+def recompute_mean_var(input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool):
+    # For most norm decompositions, this is the only place they differ from the core version:
+    # we recompute the mean and variance so that they track gradients through input
+
+    mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
+    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
+    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
+    eps = eps.detach()
+    rstd = 1 / torch.sqrt(var + eps)
+    return mean, rstd
+
+
 @register_decomposition_for_jvp(aten.native_layer_norm_backward)
 def native_layer_norm_backward(
     grad_out: Tensor,
@@ -80,13 +92,7 @@ def native_layer_norm_backward(
             input.new_zeros(input_shape[axis:]),
         )
 
-    # this is exactly the same as the other decomposition except for here. We recompute the mean and variance
-    # so that they track gradients through input
-    mean_ = torch.mean(input, dim=inner_dim_indices, keepdim=True)
-    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=True)
-    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
-    eps = eps.detach()
-    rstd_ = 1 / torch.sqrt(var + eps)
+    mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)
 
     x_hat = (input - mean_) * rstd_
     if weight is not None:
@@ -128,3 +134,84 @@
         d_bias = torch.zeros(())  # should be None but doesn't work with vjp
 
     return (d_input, d_weight, d_bias)
+
+
+def prod(x: List[int]):
+    r = 1
+    for i in x:
+        r *= i
+    return r
+
+
+@register_decomposition(aten.native_batch_norm_backward)  # switch to @register_decomposition_for_jvp once it's in core
+def native_batch_norm_backward(
+    grad_out: Tensor,
+    input: Tensor,
+    weight: Optional[Tensor],
+    running_mean: Optional[Tensor],
+    running_var: Optional[Tensor],
+    save_mean: Optional[Tensor],
+    save_invstd: Optional[Tensor],
+    train: bool,
+    eps: float,
+    output_mask: List[bool],
+) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+    input_shape = input.shape
+    input_rank = input.dim()
+    assert input_rank >= 2, "rank of the input must be at least 2"
+
+    axis = 1
+    num_features = prod(input_shape) / input_shape[axis]
+    mean = save_mean
+    invstd = save_invstd
+    if train:
+        assert save_mean is not None and save_invstd is not None, "when train=True, save_mean and save_invstd are required"
+
+        reduction_dims = [0] + list(range(2, input.dim()))
+        assert invstd is not None  # for typing
+        mean, invstd = recompute_mean_var(input, invstd, reduction_dims, keepdim=False)
+    else:
+        assert running_mean is not None and running_var is not None
+        mean = running_mean
+        invstd = torch.rsqrt(running_var + eps)
+
+    broadcast_mask = [1] * input_rank
+    broadcast_mask[axis] = input_shape[axis]
+
+    reduction_axes: List[int] = []
+    for i in range(input_rank):
+        if i != axis:
+            reduction_axes.append(i)
+
+    mean = torch.reshape(mean, broadcast_mask)
+    norm = 1.0 / num_features
+    grad_output_sum = torch.sum(grad_out, reduction_axes)
+    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)
+
+    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
+    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)
+
+    if weight is None:
+        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
+    else:
+        grad_scale = torch.reshape(invstd * weight, broadcast_mask)
+
+    if train:
+        proj = (input - mean) * proj_scale
+        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
+    else:
+        grad_input = grad_out * grad_scale
+
+    if output_mask[1]:
+        grad_weight = dot_p * invstd
+    elif weight is not None:
+        grad_weight = torch.zeros_like(weight)  # should be None but doesn't work with vjp
+    else:
+        grad_weight = torch.zeros(())  # should be None but doesn't work with vjp
+
+    if output_mask[2]:
+        grad_bias = grad_output_sum
+    else:
+        grad_bias = torch.zeros_like(grad_output_sum)  # should be None but doesn't work with vjp
+
+    return (grad_input, grad_weight, grad_bias)
diff --git a/functorch/_src/eager_transforms.py b/functorch/_src/eager_transforms.py
index a1fc68189..76e4874ac 100644
--- a/functorch/_src/eager_transforms.py
+++ b/functorch/_src/eager_transforms.py
@@ -1339,5 +1339,7 @@ def _register_python_decomposition_vmap(decomp):
 _register_jit_decomposition(torch.ops.aten._softmax_backward_data.default)
 _register_jit_decomposition(torch.ops.aten.log_sigmoid_forward.default)
 _register_jit_decomposition(torch.ops.aten.native_layer_norm_backward.default)
+_register_jit_decomposition(torch.ops.aten.native_batch_norm_backward.default)
+_register_jit_decomposition(torch.ops.aten.cudnn_batch_norm_backward.default)
 _register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
 _register_python_decomposition_vmap(torch.ops.aten.addr.default)
diff --git a/functorch/csrc/DynamicLayer.cpp b/functorch/csrc/DynamicLayer.cpp
index 65d1d2477..8bfd38835 100644
--- a/functorch/csrc/DynamicLayer.cpp
+++ b/functorch/csrc/DynamicLayer.cpp
@@ -502,6 +502,8 @@ TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
   OP_DECOMPOSE(log_sigmoid);
   JVP_DECOMP(log_sigmoid_forward);
   JVP_DECOMP(native_layer_norm_backward);
+  JVP_DECOMP(native_batch_norm_backward);
+  JVP_DECOMP(cudnn_batch_norm_backward);
 }
diff --git a/test/test_ops.py b/test/test_ops.py
index dda8182ac..cd74b71e9 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1149,8 +1149,6 @@ def get_vjp(cotangents, *primals):
         xfail('logdet', ''),
         xfail('nanmean', ''),
         xfail('nansum', ''),
-        xfail('nn.functional.batch_norm', ''),
-        xfail('nn.functional.batch_norm', 'without_cudnn', device_type='cuda'),
         xfail('nn.functional.embedding'),
         xfail('nn.functional.embedding', 'functorch'),
         xfail('nn.functional.embedding_bag', ''),
@@ -1249,7 +1247,8 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
             'softmax',
             'log_softmax',
             'nn.functional.cross_entropy',
-            'nn.functional.layer_norm'
+            'nn.functional.layer_norm',
+            'nn.functional.batch_norm',
         }
         if op.name in FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH:
             self.assertFalse(op.supports_fwgrad_bwgrad,