Skip to content

Commit

Permalink
Merge pull request #103 from Yoctol/extend_embedding_dims
Browse files Browse the repository at this point in the history
Embedding can extend its embedding dimensions (`extend_dims`)
  • Loading branch information
noobOriented authored Jun 12, 2019
2 parents 90920dc + 0b32e0c commit 0002e5f
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 34 deletions.
4 changes: 3 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
.DEFAULT_GOAL := all

TARGET = talos

.PHONY: install
install:
pip install -U pip wheel setuptools
Expand All @@ -12,7 +14,7 @@ lint:

.PHONY: test
test:
pytest --cov=talos/ --cov-fail-under=80
pytest ${TARGET} --cov=talos/ --cov-fail-under=80

.PHONY: test-report
test-report:
Expand Down
29 changes: 25 additions & 4 deletions talos/layers/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
self.dropout = dropout

self.auxiliary_tokens = 0
self.extend_dims = 0
self._constant = False

@tf_utils.shape_type_conversion
Expand All @@ -60,6 +61,9 @@ def build(self, input_shape):
constraint=self.embeddings_constraint,
trainable=self.trainable,
)

self.total_embeddings = self.embeddings

if self.auxiliary_tokens > 0:
# HACK, since Layer.add_weight will take
# the intersection of trainable (in arg) and self.trainable
Expand All @@ -74,12 +78,27 @@ def build(self, input_shape):
)
self.trainable = original_trainable
self.total_embeddings = tf.concat(
[self.embeddings, self.auxiliary_embeddings],
[self.total_embeddings, self.auxiliary_embeddings],
axis=0,
name='total_embeddings',
name='embeddings_with_auxiliary_tokens',
)
else:
self.total_embeddings = self.embeddings

if self.extend_dims > 0:
original_trainable = self.trainable
self.trainable = True
vocab_size, embeddings_dim = self.total_embeddings.shape.as_list()
self.extend_embeddings = self.add_weight(
shape=(vocab_size, embeddings_dim + self.extend_dims),
name='extend_embeddings_dims',
trainable=True,
)
self.trainable = original_trainable
self.total_embeddings = tf.concat(
[self.total_embeddings, self.extend_embeddings],
axis=1,
name='embeddings_with_extended_dims',
)
self.total_embeddings = tf.identity(self.total_embeddings, name='total_embeddings')
self.built = True

@property
Expand All @@ -101,6 +120,7 @@ def from_weights(
mask_index: Union[int, Sequence[int]] = None,
constant: bool = False,
auxiliary_tokens: int = 0,
extend_dims: int = 0,
dropout: float = None,
**kwargs,
):
Expand Down Expand Up @@ -142,6 +162,7 @@ def from_weights(
layer.trainable = False
layer._constant = True
layer.auxiliary_tokens = auxiliary_tokens
layer.extend_dims = extend_dims
return layer

def _standardize_mask_index(self, mask_index):
Expand Down
77 changes: 48 additions & 29 deletions talos/layers/tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,43 @@ def test_auxiliary_tokens_partially_trainable(inputs, sess, constant):
)


@pytest.mark.parametrize('constant', [False, True])
def test_extend_dims_partially_trainable(inputs, sess, constant):
    """With `extend_dims`, only the extension columns should be trainable.

    Builds an Embedding from fixed (trainable=False) pretrained weights with
    2 extra embedding dimensions, then runs one SGD step and checks that the
    original columns are untouched while the extended columns change.
    `inputs`/`sess` are pytest fixtures — presumably an int placeholder of
    shape (batch, maxlen) and a tf.Session; confirm against conftest.
    """
    maxlen = inputs.shape[1].value  # static sequence length from the placeholder
    vocab_size = 5
    original_embedding_size = 3
    embed_layer = Embedding.from_weights(
        np.random.uniform(size=[vocab_size, original_embedding_size]).astype(np.float32),
        constant=constant,
        trainable=False,  # pretrained part frozen; only the extension trains
        extend_dims=2,
    )
    word_vec = embed_layer(inputs)
    # Exactly one trainable variable: the (vocab_size, extend_dims) extension.
    assert len(embed_layer.trainable_variables) == 1
    # When constant=True the pretrained weights are baked in as a constant,
    # so no non-trainable variable exists for them.
    assert len(embed_layer.non_trainable_variables) == (0 if constant else 1)
    assert len(embed_layer.variables) == (1 if constant else 2)

    update_op = tf.train.GradientDescentOptimizer(0.1).minimize(tf.reduce_sum(word_vec))

    sess.run(tf.variables_initializer(var_list=embed_layer.variables))

    original_weights_val = sess.run(embed_layer.total_embeddings)
    sess.run(update_op, feed_dict={inputs: np.random.choice(vocab_size, size=[10, maxlen])})
    new_weights_val = sess.run(embed_layer.total_embeddings)

    # After the update, the pretrained columns must be unchanged...
    np.testing.assert_array_almost_equal(
        original_weights_val[:, : original_embedding_size],
        new_weights_val[:, : original_embedding_size],
    )
    # ...while the extended columns must have moved (so equality must fail).
    with pytest.raises(AssertionError):
        np.testing.assert_array_almost_equal(
            original_weights_val[:, original_embedding_size:],  # extend dims
            new_weights_val[:, original_embedding_size:],
        )


@pytest.mark.parametrize('invalid_weights', [
np.zeros([5]),
np.zeros([1, 2, 3]),
Expand All @@ -122,17 +159,23 @@ def test_construct_from_invalid_weights_raise(invalid_weights):
Embedding.from_weights(invalid_weights)


@pytest.mark.parametrize('constant,auxiliary_tokens', [
(True, 0),
(True, 2),
(False, 2),
@pytest.mark.parametrize('constant,auxiliary_tokens,extend_dims', [
(True, 0, 0),
(True, 2, 0),
(True, 0, 2),
(True, 2, 10),
(False, 0, 0),
(False, 2, 0),
(False, 0, 2),
(False, 2, 10),
])
def test_freeze_success(inputs, sess, constant, auxiliary_tokens):
def test_freeze_success(inputs, sess, constant, auxiliary_tokens, extend_dims):
# build graph with constant embedding layer
embed_layer = Embedding.from_weights(
np.random.rand(5, 10).astype(np.float32),
constant=constant,
auxiliary_tokens=auxiliary_tokens,
extend_dims=extend_dims,
)
outputs = embed_layer(inputs)
sess.run(tf.variables_initializer(var_list=embed_layer.variables))
Expand All @@ -157,30 +200,6 @@ def test_freeze_success(inputs, sess, constant, auxiliary_tokens):
np.testing.assert_array_almost_equal(outputs_val, new_outputs_val)


@pytest.mark.parametrize('constant,auxiliary_tokens', [
    (False, 0),  # NOTE only fail in this case
])
def test_freeze_fail(inputs, sess, constant, auxiliary_tokens):
    """Freezing a purely variable (non-constant) embedding graph should fail.

    Converts the graph's variables to constants and expects that loading the
    frozen GraphDef raises ValueError. NOTE(review): this test was removed by
    this commit — after the change, freezing apparently succeeds in all
    parametrized cases (see test_freeze_success); confirm against the diff.
    """
    # build graph with variable embedding layer
    embed_layer = Embedding.from_weights(
        np.random.rand(5, 10).astype(np.float32),
        constant=constant,
        auxiliary_tokens=auxiliary_tokens,
    )
    outputs = embed_layer(inputs)

    sess.run(tf.variables_initializer(var_list=embed_layer.variables))

    # Fold all variables into constants, keeping only the output node's subgraph.
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess=sess,
        input_graph_def=sess.graph_def,
        output_node_names=[outputs.op.name],  # node name == op name
    )

    # Importing the frozen graph is expected to fail for this configuration.
    with pytest.raises(ValueError):
        create_session_from_graphdef(frozen_graph_def)


def create_session_from_graphdef(graph_def):
"""
Create new session from given tf.GraphDef object
Expand Down

0 comments on commit 0002e5f

Please sign in to comment.