Skip to content

Commit

Permalink
get_initial_loop_state code factorization
Browse files Browse the repository at this point in the history
  • Loading branch information
varisd authored and jlibovicky committed Mar 26, 2019
1 parent 4a11fcd commit 3e5ab5a
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 29 deletions.
26 changes: 14 additions & 12 deletions neuralmonkey/decoders/autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,14 @@ def runtime_logprobs(self) -> tf.Tensor:
def output_dimension(self) -> int:
raise NotImplementedError("Abstract property")

def get_initial_loop_state(self) -> LoopState:
def get_initial_feedables(self) -> DecoderFeedables:
return DecoderFeedables(
step=tf.constant(0, tf.int32),
finished=tf.zeros([self.batch_size], dtype=tf.bool),
embedded_input=self.embed_input_symbols(self.go_symbols),
other=None)

def get_initial_histories(self) -> DecoderHistories:
output_states = tf.zeros(
shape=[0, self.batch_size, self.embedding_size],
dtype=tf.float32,
Expand All @@ -400,25 +406,21 @@ def get_initial_loop_state(self) -> LoopState:
dtype=tf.float32,
name="hist_logits")

feedables = DecoderFeedables(
step=tf.constant(0, tf.int32),
finished=tf.zeros([self.batch_size], dtype=tf.bool),
embedded_input=self.embed_input_symbols(self.go_symbols),
other=None)

histories = DecoderHistories(
return DecoderHistories(
logits=logits,
output_states=output_states,
output_mask=output_mask,
output_symbols=output_symbols,
other=None)

constants = DecoderConstants(train_inputs=self.train_inputs)
def get_initial_constants(self) -> DecoderConstants:
return DecoderConstants(train_inputs=self.train_inputs)

def get_initial_loop_state(self) -> LoopState:
return LoopState(
histories=histories,
constants=constants,
feedables=feedables)
feedables=self.get_initial_feedables(),
histories=self.get_initial_histories(),
constants=self.get_initial_constants())

def loop_continue_criterion(self, *args) -> tf.Tensor:
"""Decide whether to break out of the while loop.
Expand Down
18 changes: 9 additions & 9 deletions neuralmonkey/decoders/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typeguard import check_argument_types

from neuralmonkey.decoders.autoregressive import (
AutoregressiveDecoder, LoopState)
AutoregressiveDecoder, DecoderFeedables, DecoderHistories, LoopState)
from neuralmonkey.attention.base_attention import BaseAttention
from neuralmonkey.vocabulary import Vocabulary
from neuralmonkey.model.sequence import EmbeddedSequence
Expand Down Expand Up @@ -357,17 +357,20 @@ def next_state(self, loop_state: LoopState) -> Tuple[tf.Tensor, Any, Any]:

return (output, new_feedables, new_histories)

def get_initial_loop_state(self) -> LoopState:
default_ls = AutoregressiveDecoder.get_initial_loop_state(self)
feedables = default_ls.feedables
histories = default_ls.histories
def get_initial_feedables(self) -> DecoderFeedables:
feedables = AutoregressiveDecoder.get_initial_feedables(self)

rnn_feedables = RNNFeedables(
prev_contexts=[tf.zeros([self.batch_size, a.context_vector_size])
for a in self.attentions],
prev_rnn_state=self.initial_state,
prev_rnn_output=self.initial_state)

return feedables._replace(other=rnn_feedables)

def get_initial_histories(self) -> DecoderHistories:
histories = AutoregressiveDecoder.get_initial_histories(self)

rnn_histories = RNNHistories(
rnn_outputs=tf.zeros(
shape=[0, self.batch_size, self.rnn_size],
Expand All @@ -376,10 +379,7 @@ def get_initial_loop_state(self) -> LoopState:
attention_histories=[a.initial_loop_state()
for a in self.attentions if a is not None])

return LoopState(
histories=histories._replace(other=rnn_histories),
constants=default_ls.constants,
feedables=feedables._replace(other=rnn_feedables))
return histories._replace(other=rnn_histories)

def finalize_loop(self, final_loop_state: LoopState,
train_mode: bool) -> None:
Expand Down
16 changes: 8 additions & 8 deletions neuralmonkey/decoders/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,10 +452,8 @@ def train_loop_result(self) -> LoopState:
histories=histories,
constants=decoder_ls.constants)

def get_initial_loop_state(self) -> LoopState:
default_ls = AutoregressiveDecoder.get_initial_loop_state(self)
feedables = default_ls.feedables
histories = default_ls.histories
def get_initial_feedables(self) -> DecoderFeedables:
feedables = AutoregressiveDecoder.get_initial_feedables(self)

tr_feedables = TransformerFeedables(
input_sequence=tf.zeros(
Expand All @@ -467,6 +465,11 @@ def get_initial_loop_state(self) -> LoopState:
dtype=tf.float32,
name="input_mask"))

return feedables._replace(other=tr_feedables)

def get_initial_histories(self) -> DecoderHistories:
histories = AutoregressiveDecoder.get_initial_histories(self)

# TODO: record histories properly
tr_histories = tf.zeros([])
# tr_histories = TransformerHistories(
Expand All @@ -479,10 +482,7 @@ def get_initial_loop_state(self) -> LoopState:
# self.n_heads_enc)
# for a in range(self.depth)])

return LoopState(
histories=histories._replace(other=tr_histories),
constants=default_ls.constants,
feedables=feedables._replace(other=tr_feedables))
return histories._replace(other=tr_histories)

def next_state(self, loop_state: LoopState) -> Tuple[tf.Tensor, Any, Any]:
feedables = loop_state.feedables
Expand Down

0 comments on commit 3e5ab5a

Please sign in to comment.