Skip to content

Commit

Permalink
Merge pull request #39 from Yoctol/positional-encode
Browse files Browse the repository at this point in the history
Positional encode.
  • Loading branch information
GBLin5566 authored Jan 29, 2019
2 parents a6fc265 + 81b1986 commit fb38a9d
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 1 deletion.
2 changes: 1 addition & 1 deletion talos/layers/conv1d_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs
**kwargs,
):
super().__init__(
filters=filters,
Expand Down
41 changes: 41 additions & 0 deletions talos/layers/positional_encode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np
import tensorflow as tf


class PositionalEncode(tf.keras.layers.Layer):
    """Add a fixed sinusoidal positional encoding to a rank-3 input.

    Axes 1 and 2 of the input are treated as sequence length and feature
    dimension. For position ``p`` and feature index ``i`` the added value is
    ``amplitude * cos(p / base ** (2 * i / dim) + offset_i)``, where
    ``offset_i`` is ``-pi / 2`` for even ``i`` (which turns the cosine into
    a sine) and ``0`` for odd ``i``.

    Args:
        base: base of the geometric progression of wave lengths.
        amplitude: scale applied to the encoding before it is added.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(
            self,
            base: float = 1e4,
            amplitude: float = 1.,
            **kwargs,
        ):
        super().__init__(**kwargs)
        self.base = base
        self.amplitude = amplitude
        self.supports_masking = True
        self.input_spec = tf.keras.layers.InputSpec(ndim=3)

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """Return ``inputs`` plus the positional encoding, broadcast over axis 0.

        Raises:
            ValueError: if axis 1 or axis 2 of ``inputs`` has no static size,
                because the encoding table is built with numpy at graph
                construction time.
        """
        dtype = inputs.dtype
        maxlen, dim = inputs.shape.as_list()[1:]
        if maxlen is None or dim is None:
            # Fail early with a readable message instead of an obscure
            # numpy error raised from np.arange(None).
            raise ValueError(
                'PositionalEncode requires statically known sequence length '
                'and feature dimension, got input shape {}.'.format(inputs.shape),
            )
        pe = self._get_positional_encode_tensor(maxlen, dim, dtype)
        return inputs + pe[tf.newaxis, :, :]

    def _get_positional_encode_tensor(self, maxlen, dim, dtype):
        """Build the constant encoding table of shape [maxlen, dim]."""
        position_range = np.arange(maxlen)  # shape [L]
        dim_range = np.arange(dim)  # shape [D]
        # NOTE(review): the exponent uses the raw feature index, so every
        # feature gets its own wave length (up to base ** (2 * (D - 1) / D)).
        # Confirm this is intended rather than one wave length per sin/cos pair.
        wave_length = np.power(self.base, 2. * dim_range / dim)  # shape [D]

        # [-pi / 2, 0, ...]: shifts cos to sin on even feature indices and
        # leaves odd indices as cos, shape [D]
        offset = (-np.pi / 2.) * ((dim_range + 1) % 2)

        theta = (
            position_range[:, np.newaxis] / wave_length[np.newaxis, :]
            + offset[np.newaxis, :]
        )
        outputs_np = self.amplitude * np.cos(theta)
        return tf.constant(outputs_np, dtype=dtype)  # shape [L, D]

    def compute_output_shape(self, input_shape):
        """The encoding is added elementwise, so the shape is unchanged."""
        return input_shape

    def compute_mask(self, inputs, mask):
        """Pass the incoming mask through untouched."""
        return mask

    def get_config(self):
        """Return the layer config including ``base`` and ``amplitude``.

        Without this, the constructor arguments would be dropped when the
        layer is serialized or rebuilt from its config.
        """
        config = super().get_config()
        config.update({
            'base': self.base,
            'amplitude': self.amplitude,
        })
        return config
33 changes: 33 additions & 0 deletions talos/layers/tests/test_positional_encode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest

import numpy as np
import tensorflow as tf

from ..positional_encode import PositionalEncode


@pytest.fixture(scope='module')
def layer():
    """Provide a default-configured PositionalEncode, built once per module."""
    encoder = PositionalEncode()
    return encoder


def test_output_shape(layer):
    """The layer's output keeps the static shape of its input."""
    expected_shape = [5, 4, 3]
    encoded = layer(tf.zeros(expected_shape))
    assert encoded.shape.as_list() == expected_shape


def test_output_val(layer, sess):
    """At position 0 the layer adds [0, 1, 0, 1] to every batch row."""
    batch = tf.constant([
        [[1., 2., 3., 4.]],
        [[5., 6., 7., 8.]],
    ])
    encoded = layer(batch)
    batch_val, encoded_val = sess.run([batch, encoded])

    # sin0, cos0, sin0, cos0 -- the encoding of position 0 for 4 features
    first_position_encoding = np.array([0., 1., 0., 1.])
    # Broadcasting over the batch axis gives the same row for every sample.
    expected = batch_val[:, 0] + first_position_encoding[np.newaxis, :]

    np.testing.assert_array_almost_equal(encoded_val[:, 0], expected)

0 comments on commit fb38a9d

Please sign in to comment.