Skip to content

Commit

Permalink
Merge pull request #115 from Yoctol/selectable-weight-decay
Browse files Browse the repository at this point in the history
Selectable weight decay
  • Loading branch information
noobOriented authored Oct 10, 2019
2 parents 7a0bf69 + c1667a4 commit f22957d
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 3 deletions.
2 changes: 1 addition & 1 deletion talos/__version__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__title__ = 'talos'
__version__ = '1.6.0'
__version__ = '1.6.1'
__description__ = 'Powerful Neural Network Builder'
__author__ = 'Jsaon'
31 changes: 30 additions & 1 deletion talos/optimizers/tests/test_weight_decay.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

import numpy as np
import tensorflow as tf

Expand All @@ -14,7 +16,7 @@ def test_weight_decay(sess):
x = tf.Variable(x_val)
z = tf.Variable(z_val)
y = tf.pow(x, 3) # dy/dx = 3x^2
train_op = optimizer.minimize(y, var_list=[x])
train_op = optimizer.minimize(y, var_list=[x, z])

sess.run(tf.variables_initializer([x, z]))
sess.run(train_op)
Expand All @@ -23,3 +25,30 @@ def test_weight_decay(sess):
x_val * (1. - decay_rate) - lr * 3 * (x_val ** 2),
)
np.testing.assert_almost_equal(sess.run(z), z_val) # keep since it's not updated


@pytest.mark.parametrize('var_filter', ['collection', 'callable'])
def test_weight_decay_with_filter(var_filter, sess):
    """Decay is applied only to variables selected by ``variable_filter``.

    Runs the same scenario twice: once with a container-style filter
    (a set of variables) and once with a callable predicate.
    """
    lr, decay_rate = 0.2, 0.1
    x_init, z_init = 2., 1.
    x = tf.Variable(x_init, name='x')
    z = tf.Variable(z_init, name='z')

    if var_filter == 'collection':
        selected = {x}
    else:
        selected = lambda v: 'x' in v.name
    optimizer = WeightDecay(
        tf.train.GradientDescentOptimizer(lr),
        decay_rate=decay_rate,
        variable_filter=selected,
    )
    loss = tf.pow(x, 3) + z  # d(loss)/dx = 3x^2, d(loss)/dz = 1
    train_op = optimizer.minimize(loss, var_list=[x, z])

    sess.run(tf.variables_initializer([x, z]))
    sess.run(train_op)
    # x is selected by the filter: decayed, then stepped by gradient descent.
    np.testing.assert_almost_equal(
        sess.run(x),
        x_init * (1. - decay_rate) - lr * 3 * (x_init ** 2),
    )
    # z is excluded by the filter: only the plain gradient step applies.
    np.testing.assert_almost_equal(sess.run(z), z_init - lr)
15 changes: 14 additions & 1 deletion talos/optimizers/weight_decay.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Callable, Container, Union

import tensorflow as tf


Expand All @@ -11,14 +13,25 @@ def __init__(
decay_rate: float,
use_locking: bool = False,
name: str = 'WeightDecay',
variable_filter: Union[Container[tf.Variable], Callable[[tf.Variable], bool]] = None,
):
super().__init__(use_locking, name)
self.optimizer = optimizer
self.decay_rate = decay_rate
self.decay_rate_tensor = tf.convert_to_tensor(decay_rate)
self.variable_filter = variable_filter

def apply_gradients(self, grads_and_vars, global_step=None, name=None):
var_list = [v for g, v in grads_and_vars if g is not None]
if self.variable_filter is None:
def need_decay(var):
return True
elif hasattr(self.variable_filter, '__contains__'):
def need_decay(var):
return var in self.variable_filter
else:
need_decay = self.variable_filter

var_list = [v for g, v in grads_and_vars if g is not None and need_decay(v)]

decay_value = [
tf.cast(self.decay_rate_tensor, dtype=v.dtype.base_dtype) * v
Expand Down

0 comments on commit f22957d

Please sign in to comment.