Skip to content

Commit

Permalink
Merge pull request #28 from Yoctol/fix-sequential-add
Browse files Browse the repository at this point in the history
remove the unexpected part in Sequential.add()
  • Loading branch information
noobOriented authored Nov 16, 2018
2 parents 9ebe1c4 + 42ea4c8 commit 212e979
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 2 deletions.
38 changes: 36 additions & 2 deletions talos/module/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,52 @@
from tensorflow.python.keras.engine.network import Network
from tensorflow.python.keras.engine.sequential import Sequential as keras_Sequential
from tensorflow.python.util import tf_inspect
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.training.checkpointable import base as checkpointable


class Sequential(keras_Sequential):

# override: add **kwargs
# HACK override: remove adding InputLayer part!!
# https://github.com/tensorflow/tensorflow/blob/r1.11/tensorflow/python/keras/engine/sequential.py#L122-L188
# just copy/paste the source code and remove L142-L170
@checkpointable.no_automatic_dependency_tracking
def add(self, layer):
# source code L137
if not isinstance(layer, base_layer.Layer):
raise TypeError(
'The added layer must be '
'an instance of class Layer. '
f'Found: {layer}',
)
self.built = False

# source code L172
if self.outputs:
# If the model is being built continuously on top of an input layer:
# refresh its output.
output_tensor = layer(self.outputs[0])
if isinstance(output_tensor, list):
raise TypeError(
'All layers in a Sequential model '
'should have a single output tensor. '
'For multi-output layers, '
'use the functional API.',
)
self.outputs = [output_tensor]

self._layers.append(layer)
if self._layers:
self._track_layers(self._layers)

# HACK override: add **kwargs
# https://github.com/tensorflow/tensorflow/blob/r1.11/tensorflow/python/keras/engine/sequential.py#L227-L233
def call(self, inputs, training=None, mask=None, **kwargs):
outputs, _ = self._call_and_compute_mask(
inputs, training=training, mask=mask, **kwargs)
return outputs

# override: add **kwargs
# HACK override: add **kwargs
# https://github.com/tensorflow/tensorflow/blob/r1.11/tensorflow/python/keras/engine/sequential.py#L235-L257
def _call_and_compute_mask(self, inputs, training=None, mask=None, **kwargs):
if not self.built:
Expand Down
28 changes: 28 additions & 0 deletions talos/module/tests/test_sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,34 @@ def graph():
yield graph


def test_build_sublayers_when_first_called(graph):
sequential = Sequential([
tf.keras.layers.Embedding(20, 10),
tf.keras.layers.LSTM(10, return_sequences=True),
tf.keras.layers.Dense(5),
tf.keras.layers.MaxPooling1D(),
])
assert all(not layer.built for layer in sequential.layers)
inputs = tf.zeros([1, 3], dtype=tf.float32)
sequential(inputs)
assert all(layer.built for layer in sequential.layers)


def test_context_manager_work_when_first_called(graph):
sequential = Sequential([
tf.keras.layers.Embedding(20, 10),
tf.keras.layers.LSTM(10, return_sequences=True),
tf.keras.layers.Dense(5),
tf.keras.layers.MaxPooling1D(),
])
with tf.variable_scope('scope'):
inputs = tf.zeros([1, 3], dtype=tf.float32)
sequential(inputs)
variables = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
assert all(var.graph is graph for var in variables)
assert all(var.name.startswith('scope') for var in variables)


def test_additional_inputs(graph):

class LayerNeedSeqlen(tf.keras.layers.Layer):
Expand Down

0 comments on commit 212e979

Please sign in to comment.