Skip to content

Commit

Permalink
fix the mistake of offset and add test for value
Browse files Browse the repository at this point in the history
  • Loading branch information
noobOriented committed Jan 29, 2019
1 parent 691a2a8 commit 81b1986
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
2 changes: 1 addition & 1 deletion talos/layers/positional_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _get_positional_encode_tensor(self, maxlen, dim, dtype):
dim_range = np.arange(dim) # shape [D]
wave_length = np.power(self.base, 2. * dim_range / dim) # shape [D]

offset = -np.pi * ((dim_range % 2) + 1)
offset = (-np.pi / 2.) * ((dim_range + 1) % 2)
# [-pi / 2, 0, ...] for convert sin to cos on even dim, shape [D]

theta = position_range[:, np.newaxis] / wave_length[np.newaxis, :] + offset[np.newaxis, :]
Expand Down
27 changes: 25 additions & 2 deletions talos/layers/tests/test_positional_encode.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,33 @@
import pytest

import numpy as np
import tensorflow as tf

from ..positional_encode import PositionalEncode


def test_output_shape():
layer = PositionalEncode()
@pytest.fixture(scope='module')
def layer():
return PositionalEncode()


def test_output_shape(layer):
inputs = tf.zeros([5, 4, 3])
outputs = layer(inputs)
assert outputs.shape.as_list() == inputs.shape.as_list()


def test_output_val(layer, sess):
inputs = tf.constant([
[[1., 2., 3., 4.]],
[[5., 6., 7., 8.]],
])
outputs = layer(inputs)
inputs_val, outputs_val = sess.run([inputs, outputs])
np.testing.assert_array_almost_equal(
outputs_val[:, 0],
inputs_val[:, 0] + np.array([
[0., 1., 0., 1.] # sin0, cos0, sin0,
for _ in range(len(inputs_val))
]),
)

0 comments on commit 81b1986

Please sign in to comment.