Fixed rnn hidden shape bugs
kwangjinoh committed Mar 24, 2018
1 parent e47853e commit 11f238e
Showing 3 changed files with 9 additions and 9 deletions.
12_2_hello_rnn.py (3 additions, 3 deletions)
@@ -44,14 +44,14 @@ def forward(self, hidden, x):
 
         # Propagate input through RNN
         # Input: (batch, seq_len, input_size)
-        # hidden: (batch, num_layers * num_directions, hidden_size)
+        # hidden: (num_layers * num_directions, batch, hidden_size)
         out, hidden = self.rnn(x, hidden)
         return hidden, out.view(-1, num_classes)
 
     def init_hidden(self):
         # Initialize hidden and cell states
-        # (batch, num_layers * num_directions, hidden_size) for batch_first=True
-        return Variable(torch.zeros(batch_size, num_layers, hidden_size))
+        # (num_layers * num_directions, batch, hidden_size)
+        return Variable(torch.zeros(num_layers, batch_size, hidden_size))
 
 
 # Instantiate RNN model
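
The shape convention behind all three fixes: batch_first=True only reorders the input and output tensors of nn.RNN; the hidden state is always (num_layers * num_directions, batch, hidden_size). A minimal sketch checking this, assuming a current PyTorch where Variable is no longer needed (sizes are illustrative, not taken from the scripts):

import torch
import torch.nn as nn

batch_size, seq_len, input_size, hidden_size, num_layers = 1, 6, 5, 5, 1

rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
x = torch.zeros(batch_size, seq_len, input_size)        # (batch, seq_len, input_size)
h_0 = torch.zeros(num_layers, batch_size, hidden_size)  # not (batch, num_layers, ...)

out, h_n = rnn(x, h_0)
print(out.shape)  # torch.Size([1, 6, 5]): (batch, seq_len, hidden_size)
print(h_n.shape)  # torch.Size([1, 1, 5]): (num_layers * num_directions, batch, hidden_size)
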
12_3_hello_rnn_seq.py (3 additions, 3 deletions)
@@ -46,16 +46,16 @@ def __init__(self, num_classes, input_size, hidden_size, num_layers):
 
     def forward(self, x):
         # Initialize hidden and cell states
-        # (batch, num_layers * num_directions, hidden_size) for batch_first=True
+        # (num_layers * num_directions, batch, hidden_size) for batch_first=True
         h_0 = Variable(torch.zeros(
-            x.size(0), self.num_layers, self.hidden_size))
+            self.num_layers, x.size(0), self.hidden_size))
 
         # Reshape input
         x.view(x.size(0), self.sequence_length, self.input_size)
 
         # Propagate input through RNN
         # Input: (batch, seq_len, input_size)
-        # h_0: (batch, num_layers * num_directions, hidden_size)
+        # h_0: (num_layers * num_directions, batch, hidden_size)
 
         out, _ = self.rnn(x, h_0)
         return out.view(-1, num_classes)
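
Why the old ordering was a bug rather than a crash in these scripts: whenever batch_size equals num_layers, the transposed tensor happens to have the right shape and the code runs with a silently misassigned hidden state; for any other batch size, nn.RNN rejects it. A quick demonstration, again assuming a current PyTorch with illustrative sizes (note, unrelated to this commit, that the context line x.view(...) above discards its result, since view is not in-place; it would need x = x.view(...) to take effect):

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=5, hidden_size=5, num_layers=1, batch_first=True)
x = torch.zeros(3, 6, 5)        # batch=3, seq_len=6, input_size=5

bad_h0 = torch.zeros(3, 1, 5)   # (batch, num_layers, hidden): the old ordering
good_h0 = torch.zeros(1, 3, 5)  # (num_layers, batch, hidden): the fixed ordering

try:
    rnn(x, bad_h0)
except RuntimeError as err:
    print(err)                  # PyTorch reports the expected hidden size

out, h_n = rnn(x, good_h0)      # runs: out is (3, 6, 5), h_n is (1, 3, 5)
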
12_4_hello_rnn_emb.py (3 additions, 3 deletions)
@@ -36,16 +36,16 @@ def __init__(self):
 
     def forward(self, x):
         # Initialize hidden and cell states
-        # (batch, num_layers * num_directions, hidden_size) for batch_first=True
+        # (num_layers * num_directions, batch, hidden_size)
         h_0 = Variable(torch.zeros(
-            x.size(0), num_layers, hidden_size))
+            self.num_layers, x.size(0), self.hidden_size))
 
         emb = self.embedding(x)
         emb = emb.view(batch_size, sequence_length, -1)
 
         # Propagate embedding through RNN
         # Input: (batch, seq_len, embedding_size)
-        # h_0: (batch, num_layers * num_directions, hidden_size)
+        # h_0: (num_layers * num_directions, batch, hidden_size)
         out, _ = self.rnn(emb, h_0)
         return self.fc(out.view(-1, num_classes))
 
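
The embedding variant needs the same fixed h_0; only the RNN input changes, coming from nn.Embedding instead of raw vectors. A standalone sketch under the same assumptions (current PyTorch, hypothetical toy sizes). The script's out.view(-1, num_classes) before self.fc suggests hidden_size equals num_classes there; the sketch reshapes by hidden_size, the shape-safe general form:

import torch
import torch.nn as nn

num_classes, embedding_size = 5, 10   # illustrative sizes
hidden_size, num_layers = 5, 1
batch_size, seq_len = 1, 6

embedding = nn.Embedding(num_classes, embedding_size)
rnn = nn.RNN(embedding_size, hidden_size, num_layers, batch_first=True)
fc = nn.Linear(hidden_size, num_classes)

x = torch.zeros(batch_size, seq_len, dtype=torch.long)  # token ids
h_0 = torch.zeros(num_layers, batch_size, hidden_size)  # the fixed shape

emb = embedding(x)                         # (batch, seq_len, embedding_size)
out, h_n = rnn(emb, h_0)                   # out: (batch, seq_len, hidden_size)
logits = fc(out.reshape(-1, hidden_size))  # (batch * seq_len, num_classes)
print(logits.shape)                        # torch.Size([6, 5])
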
