Merge pull request #120 from Yoctol/spectral-weight-decay

Spectral weight decay

noobOriented authored Nov 11, 2019
2 parents a375107 + 24d9b1b commit b88c9b0
Showing 6 changed files with 171 additions and 57 deletions.
2 changes: 1 addition & 1 deletion talos/__version__.py
@@ -1,4 +1,4 @@
__title__ = 'talos'
__version__ = '1.6.2'
__version__ = '1.6.3'
__description__ = 'Powerful Neural Network Builder'
__author__ = 'Jsaon'
99 changes: 99 additions & 0 deletions talos/optimizers/spectral_norm.py
@@ -0,0 +1,99 @@
from typing import Callable, Container, Union

import tensorflow as tf


class SpectralWeightDecay(tf.train.Optimizer):
'''
References:
        1. Decoupled Weight Decay Regularization https://arxiv.org/abs/1711.05101
        2. Spectral Norm Regularization https://arxiv.org/abs/1705.10941
'''

def __init__(
self,
optimizer,
decay_rate: float,
use_locking: bool = False,
name: str = 'SpectralWeightDecay',
variable_filter: Union[Container[tf.Variable], Callable[[tf.Variable], bool]] = None,
):
super().__init__(use_locking, name)
self.optimizer = optimizer
self.decay_rate = decay_rate
self.decay_rate_tensor = tf.convert_to_tensor(decay_rate)
self.variable_filter = variable_filter

def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        var_list, decay_list, update_list = self._get_decay_trips(grads_and_vars)
        with tf.control_dependencies(decay_list):  # cache the decay values before descent mutates the variables
grad_descent_op = self.optimizer.apply_gradients(
grads_and_vars,
global_step=global_step,
)

        # guarantee the gradient step runs before the decay is applied.
with tf.control_dependencies([grad_descent_op]):
decay_op = tf.group(
*[
v.assign_sub(d_v, use_locking=self._use_locking)
                for v, d_v in zip(var_list, decay_list)
],
*update_list,
name=name,
)

return decay_op

def _get_decay_trips(self, grads_and_vars):
if self.variable_filter is None:
def need_decay(var):
                return 'kernel' in var.name and var.shape.ndims >= 2
elif hasattr(self.variable_filter, '__contains__'):
def need_decay(var):
return var in self.variable_filter
else:
need_decay = self.variable_filter

var_list, decay_list, update_list = [], [], []
for g, v in grads_and_vars:
if g is None or not need_decay(v):
continue
if v.shape.ndims < 2:
raise ValueError("Can't apply spectral norm on variable with rank < 2!")
decay_value, update_u = self._build_spectral_norm_variables(v)
rate = tf.cast(self.decay_rate_tensor, dtype=v.dtype.base_dtype)
var_list.append(v)
decay_list.append(rate * decay_value)
update_list.append(update_u)

return var_list, decay_list, update_list

def _build_spectral_norm_variables(self, kernel):
kernel_matrix = to_rank2(kernel) # shape (U, V)
u = self._get_or_make_slot_with_initializer(
kernel,
initializer=tf.keras.initializers.lecun_normal(), # unit vector
shape=kernel_matrix.shape[:1],
dtype=kernel_matrix.dtype,
slot_name='u',
op_name=self._name,
) # shape (U)
v = tf.nn.l2_normalize(tf.linalg.matvec(kernel_matrix, u, transpose_a=True)) # shape (V)
Wv = tf.linalg.matvec(kernel_matrix, v) # shape (U)
        # NOTE power iteration: sigma = u^T W v estimates the largest singular value,
        # and dsigma / dW = u v^T. Decaying 0.5 * sigma^2 instead gives
        # 0.5 dsigma^2 / dW = sigma u v^T = (sigma u) v^T = (Wv) v^T,
        # since u = Wv / ||Wv|| and sigma = ||Wv|| at convergence.
decay_value = Wv[:, tf.newaxis] * v # shape (U, V)
if kernel.shape.ndims > 2:
decay_value = tf.reshape(decay_value, kernel.shape)

new_u = tf.nn.l2_normalize(Wv) # shape (U)
update_u = tf.assign(u, new_u)
return decay_value, update_u


def to_rank2(tensor: tf.Tensor):
if tensor.shape.ndims > 2:
return tf.reshape(tensor, [-1, tensor.shape[-1].value])
return tensor
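
A minimal usage sketch (TF1 graph mode; assuming SpectralWeightDecay is re-exported as talos.optimizers.SpectralWeightDecay — adjust the import path if not):

import numpy as np
import tensorflow as tf

from talos.optimizers import SpectralWeightDecay

# Wrap any tf.train.Optimizer; by default only variables whose name
# contains 'kernel' and whose rank >= 2 receive the spectral decay.
optimizer = SpectralWeightDecay(
    tf.train.GradientDescentOptimizer(0.2),
    decay_rate=0.1,
)
W = tf.Variable(np.random.rand(3, 4), name='kernel')
loss = tf.reduce_sum(tf.square(W))
train_op = optimizer.minimize(loss, var_list=[W])  # gradient step, then W -= rate * (Wv) v^T

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)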
56 changes: 56 additions & 0 deletions talos/optimizers/tests/test_spectral_norm.py
@@ -0,0 +1,56 @@
import pytest

import numpy as np
import tensorflow as tf

from ..spectral_norm import SpectralWeightDecay


def test_spectral_weight_decay_skip_low_rank_by_default(sess):
lr, decay_rate = 0.2, 0.1
x_val = 2.
optimizer = SpectralWeightDecay(
tf.train.GradientDescentOptimizer(lr),
decay_rate=decay_rate,
)
x = tf.Variable(x_val, name='x') # rank 0
y = tf.pow(x, 3) # dy/dx = 3x^2
train_op = optimizer.minimize(y, var_list=[x])

sess.run(tf.variables_initializer([x]))
sess.run(train_op)
np.testing.assert_almost_equal(
sess.run(x),
x_val - lr * 3 * (x_val ** 2),
)


@pytest.mark.parametrize('shape', [
[3, 4],
[3, 4, 5],
])
def test_spectral_weight_decay(shape, sess):
lr, decay_rate = 0.2, 0.1
optimizer = SpectralWeightDecay(
tf.train.GradientDescentOptimizer(lr),
decay_rate=decay_rate,
)

W = tf.Variable(np.random.rand(*shape), name='kernel')
    y = tf.reduce_sum(W)  # dy/dW = 1
train_op = optimizer.minimize(y, var_list=[W])
u = optimizer.get_slot(W, 'u')

assert u.shape.as_list() == [np.prod(shape[:-1])]

sess.run(tf.variables_initializer([W, u]))
W_val, u_val = sess.run([W, u])
v_val = W_val.reshape([-1, shape[-1]]).T @ u_val
v_val /= np.linalg.norm(v_val)
decay_val = decay_rate * np.expand_dims(W_val @ v_val, -1) * v_val

sess.run(train_op)
np.testing.assert_almost_equal(
sess.run(W),
W_val - decay_val - lr * 1.,
)
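
For intuition, the closed-form check above reproduces one step of power iteration. A pure-NumPy sketch (illustrative names only): run a few more steps and sigma converges to the top singular value reported by np.linalg.svd.

import numpy as np

rng = np.random.RandomState(0)
W = rng.rand(3, 4)
u = rng.randn(3)
u /= np.linalg.norm(u)

for _ in range(50):
    v = W.T @ u                 # shape (V)
    v /= np.linalg.norm(v)
    Wv = W @ v                  # shape (U)
    sigma = np.linalg.norm(Wv)  # estimate of the top singular value
    u = Wv / sigma              # power-iteration update of u

print(sigma, np.linalg.svd(W, compute_uv=False)[0])  # nearly identical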
2 changes: 1 addition & 1 deletion talos/optimizers/weight_decay.py
@@ -14,7 +14,7 @@ def __init__(
use_locking: bool = False,
name: str = 'WeightDecay',
variable_filter: Union[Container[tf.Variable], Callable[[tf.Variable], bool]] = None,
sparse_update: bool = True,
sparse_update: bool = False,
):
super().__init__(use_locking, name)
self.optimizer = optimizer
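
With the default flipped to False, sparse decay must now be opted into explicitly. A hedged sketch, assuming the constructor mirrors SpectralWeightDecay above and that sparse_update means decaying only the rows touched by IndexedSlices gradients (inferred from the parameter name, not from this diff):

optimizer = WeightDecay(
    tf.train.AdamOptimizer(1e-3),
    decay_rate=1e-2,
    sparse_update=True,  # opt back in, e.g. for embedding tables
)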
67 changes: 13 additions & 54 deletions talos/spectral_norm/spectral_norm.py
@@ -2,11 +2,7 @@
from typing import Set

import tensorflow as tf
from tensorflow.python.keras.layers.cudnn_recurrent import (
_CuDNNRNN,
CuDNNGRU,
CuDNNLSTM,
)
from tensorflow.python.keras.layers.cudnn_recurrent import _CuDNNRNN

_WEIGHTS_VARIABLE_NAME = "kernel"

@@ -36,21 +32,7 @@ def add_spectral_norm(layer: tf.layers.Layer):
def add_spectral_norm_for_layer(
layer: tf.layers.Layer,
kernel_name: Set[str] = None,
):
if isinstance(layer, (CuDNNGRU, tf.keras.layers.GRUCell)):
weight_split = 3
elif isinstance(layer, (CuDNNLSTM, tf.keras.layers.LSTMCell)):
weight_split = 4
else:
weight_split = 1

_add_spectral_norm_for_layer(layer, kernel_name, weight_split=weight_split)


def _add_spectral_norm_for_layer(
layer: tf.layers.Layer,
kernel_name: Set[str] = None,
weight_split: int = 1,
lipschitz: float = 1.,
):
if layer.built:
raise ValueError("Can't add spectral norm on built layer!")
@@ -72,41 +54,18 @@ def new_add_weight(self, name=None, shape=None, **kwargs):
if len(shape) < 2:
            raise ValueError("Can't apply spectral norm on variable with rank < 2!")

kernel_matrix = to_rank2(kernel) # shape (U, V)
if weight_split > 1:
assert shape[1] % weight_split == 0
split_kernel = tf.split(kernel_matrix, weight_split, axis=1)
sn_list = []
for i, sub_kernel in enumerate(split_kernel):
sn_val, update_u = _build_spectral_norm_variables(
f"{name}_{i}",
sub_kernel,
original_add_weight,
)
sn_list.append(
tf.fill([sub_kernel.shape[1].value], value=sn_val),
) # shape (V // split)
self.add_update(update_u)

spectral_norm = tf.concat(sn_list, axis=0) # shape (V)
else:
spectral_norm, update_u = _build_spectral_norm_variables(
name, kernel_matrix, original_add_weight,
) # shape ()
self.add_update(update_u)

normed_kernel = tf.truediv(
kernel,
spectral_norm + tf.keras.backend.epsilon(),
name=f'{name}_sn',
)
spectral_norm, update_u = _build_spectral_norm_variables(name, kernel, original_add_weight)
self.add_update(update_u)

scale = lipschitz / (spectral_norm + tf.keras.backend.epsilon())
normed_kernel = tf.multiply(kernel, scale, name=f'{name}_sn')
return normed_kernel

layer.add_weight = types.MethodType(new_add_weight, layer)


def _build_spectral_norm_variables(name, kernel, add_weight_func):
assert kernel.shape.ndims == 2
def _build_spectral_norm_variables(name, kernel, add_weight_func=tf.get_variable):
kernel = to_rank2(kernel) # shape (U, V)
u_vector = add_weight_func(
name=f'{name}/left_singular_vector',
shape=(kernel.shape[0].value, ),
@@ -119,12 +78,12 @@ def _build_spectral_norm_variables(name, kernel, add_weight_func):
tf.nn.l2_normalize(tf.linalg.matvec(kernel, u_vector, transpose_a=True)),
name=f'{name}/new_right_singular_vector',
) # shape (V)
unnormed_new_u = tf.linalg.matvec(kernel, new_v) # shape (U)
Wv = tf.linalg.matvec(kernel, new_v) # shape (U)
new_u = tf.stop_gradient(
tf.nn.l2_normalize(unnormed_new_u),
tf.nn.l2_normalize(Wv),
name=f'{name}/new_left_singular_vector',
)
spectral_norm = tf.reduce_sum(new_u * unnormed_new_u, name=f'{name}/singular_value')
) # shape (U)
spectral_norm = tf.tensordot(new_u, Wv, axes=1, name=f'{name}/singular_value')
update_u = tf.assign(u_vector, new_u, name=f'{name}/power_iter')
return spectral_norm, update_u

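
A usage sketch for the simplified API (assumptions: the function is importable as below, and lipschitz rescales the kernel so its spectral norm equals that constant; the layer must be wrapped before it is built, as the ValueError above enforces):

import tensorflow as tf

from talos.spectral_norm import add_spectral_norm_for_layer

layer = tf.keras.layers.Dense(10)
add_spectral_norm_for_layer(layer, lipschitz=1.)  # patches layer.add_weight
outputs = layer(tf.zeros([2, 4]))  # building the layer creates the 'kernel_sn' tensor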
2 changes: 1 addition & 1 deletion talos/spectral_norm/tests/test_spectral_norm.py
@@ -82,7 +82,7 @@ def test_add_spectral_norm(layer, inputs, sess):
u_vector_list = layer.non_trainable_variables

    # Since the normed kernel is now produced by a multiplication
assert all([kernel.op.type == 'RealDiv' for kernel in kernel_list])
assert all([kernel.op.type == 'Mul' for kernel in kernel_list])

sess.run(tf.variables_initializer(layer.variables))

