Skip to content

Commit

Permalink
Fix the bug where AvgPool can't compute a mask.
Browse files Browse the repository at this point in the history
  • Loading branch information
noobOriented committed Nov 22, 2018
1 parent 7c59e84 commit a548d71
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
21 changes: 15 additions & 6 deletions talos/layers/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,29 @@ class GlobalPooling1D(tf.keras.layers.Layer, abc.ABC):
def __init__(self, **kwargs):
    """Initialise the pooling layer.

    Args:
        **kwargs: Forwarded unchanged to the parent layer constructor.
    """
    super().__init__(**kwargs)
    # Only rank-3 inputs are accepted.
    self.input_spec = tf.keras.layers.InputSpec(ndim=3)
    # Keras looks up ``supports_masking`` (with an "s"); the previous
    # spelling ``support_masking`` merely created an unrelated attribute,
    # so the layer was never actually flagged as mask-aware.
    self.supports_masking = True

def compute_output_shape(self, input_shape):
    """Return the output shape, i.e. the input shape without its axis 1.

    Args:
        input_shape: Anything accepted by ``tf.TensorShape``.

    Returns:
        A ``tf.TensorShape`` made of dimensions 0 and 2 of ``input_shape``.
    """
    dims = tf.TensorShape(input_shape).as_list()
    leading_dim, trailing_dim = dims[0], dims[2]
    return tf.TensorShape([leading_dim, trailing_dim])

def compute_mask(self, inputs, mask):
    """Always return ``None``, regardless of ``inputs`` and ``mask``."""
    return None

@abc.abstractmethod
def call(self, inputs, seqlen=None, mask=None):
    """Abstract hook; this base body does nothing and returns ``None``.

    Args:
        inputs: Unused here.
        seqlen: Unused here; defaults to ``None``.
        mask: Unused here; defaults to ``None``.
    """

def _get_mask(self, inputs, seqlen, mask):
    """Build a mask in ``inputs.dtype`` with an extra axis appended at axis 2.

    ``mask`` takes precedence over ``seqlen`` when both are given.

    Args:
        inputs: Tensor whose dtype is used for the mask and whose static
            axis-1 size is used as ``maxlen`` for the ``seqlen`` branch.
        seqlen: Optional lengths passed to ``tf.sequence_mask``.
        mask: Optional mask; cast to ``inputs.dtype`` when given.

    Returns:
        The mask with a size-1 axis inserted at position 2, so it can be
        multiplied directly with tensors that have a third axis, or
        ``None`` when neither ``mask`` nor ``seqlen`` is given.
    """
    # if there's a mask, use mask first
    if mask is not None:
        return tf.expand_dims(tf.cast(mask, inputs.dtype), axis=2)
    if seqlen is not None:
        maxlen = inputs.shape[1].value
        return tf.expand_dims(
            tf.sequence_mask(seqlen, maxlen=maxlen, dtype=inputs.dtype),
            axis=2,
        )  # shape (N, T, 1)
    return None

Expand All @@ -37,10 +44,9 @@ def call(self, inputs, seqlen=None, mask=None):
return tf.reduce_mean(inputs, axis=1)

# compute mean on True part
casted_mask = tf.expand_dims(casted_mask, axis=2)
# if there's a mask, use mask first
if mask is not None:
true_count = tf.reduce_sum(mask, axis=1)
true_count = tf.reduce_sum(casted_mask, axis=1)
else:
true_count = tf.expand_dims(tf.cast(seqlen, inputs.dtype), axis=1)
return tf.reduce_sum(inputs * casted_mask, axis=1) / true_count
Expand Down Expand Up @@ -122,6 +128,10 @@ def call(
mask: tf.Tensor = None,
) -> tf.Tensor:
# shape (N, T, units)
mask = self._get_mask(inputs, seqlen, mask)
if mask is not None:
inputs *= mask

hidden_outputs = tf.tensordot(inputs, self.candidate_kernel, axes=[[2], [0]])
if self.use_bias:
hidden_outputs = tf.nn.bias_add(hidden_outputs, self.bias)
Expand All @@ -131,10 +141,9 @@ def call(
logits = tf.tensordot(hidden_outputs, self.softmax_kernel, axes=[[2], [0]])
weights = tf.nn.softmax(logits, axis=1)

mask = self._get_mask(inputs, seqlen, mask)
if mask is not None:
# Renormalize for lower seqlen
weights *= tf.expand_dims(mask, axis=2)
weights *= mask
weights /= tf.reduce_sum(weights, axis=1, keepdims=True)

if self.reg_coeff > 0:
Expand Down
10 changes: 10 additions & 0 deletions talos/layers/tests/test_pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,13 @@ def test_average_pooling_mask_value(graph):

expected_outputs_val = np.array([[0.5], [3.5]], dtype=np.float32)
np.testing.assert_array_almost_equal(outputs_val, expected_outputs_val)


def test_given_mask(graph):
    """Smoke test: both pooling layers accept inputs produced by ``Masking``.

    No values are asserted; the test passes when neither call raises.
    """
    attention_pooling = GlobalAttentionPooling1D(units=3, heads=4)
    average_pooling = GlobalAveragePooling1D()
    raw_inputs = tf.placeholder(tf.float32, [None, 5, 1])
    masking_layer = tf.keras.layers.Masking()
    inputs_with_mask = masking_layer(raw_inputs)

    # Same call order as before: attention pooling first, then averaging.
    for pooling_layer in (attention_pooling, average_pooling):
        pooling_layer(inputs_with_mask)

0 comments on commit a548d71

Please sign in to comment.