
Commit

Merge pull request #20 from kwangjinoh/master
Fixed 12_X RNN hidden shape bugs
hunkim authored Mar 25, 2018
2 parents 59c86cc + 11f238e commit d7f74ae
Showing 4 changed files with 14 additions and 12 deletions.
8 changes: 5 additions & 3 deletions 12_1_rnn_basics.py
@@ -11,8 +11,7 @@
 # One cell RNN input_dim (4) -> output_dim (2). sequence: 5
 cell = nn.RNN(input_size=4, hidden_size=2, batch_first=True)

-# (num_layers * num_directions, batch, hidden_size)
-# (batch, num_layers * num_directions, hidden_size) for batch_first=True
+# (num_layers * num_directions, batch, hidden_size) whether batch_first=True or False
 hidden = (Variable(torch.randn(1, 1, 2)))

 # Propagate input through RNN
@@ -32,6 +31,9 @@
 print("sequence input size", inputs.size(), "out size", out.size())


+# hidden : (num_layers * num_directions, batch, hidden_size) whether batch_first=True or False
+hidden = Variable(torch.randn(1, 3, 2))
+
 # One cell RNN input_dim (4) -> output_dim (2). sequence: 5, batch 3
 # 3 batches 'hello', 'eolll', 'lleel'
 # rank = (3, 5, 4)
@@ -50,7 +52,7 @@
 cell = nn.RNN(input_size=4, hidden_size=2)

 # The given dimensions dim0 and dim1 are swapped.
-inputs = inputs.transpose(3, dim1=1, dim2=2)
+inputs = inputs.transpose(dim0=0, dim1=1)
 # Propagate input through RNN
 # Input: (seq_len, batch_size, input_size) when batch_first=False (default)
 # S x B x I
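
The comments introduced here state the key rule: nn.RNN returns its hidden state as (num_layers * num_directions, batch, hidden_size) no matter how batch_first is set; only the input and output tensors change layout. A minimal check of that rule, using plain tensors (Variable is unnecessary in modern PyTorch) and the same 4-input/2-hidden cell as the tutorial:

import torch
import torch.nn as nn

batch, seq_len, input_size, hidden_size = 3, 5, 4, 2

cell_bf = nn.RNN(input_size=input_size, hidden_size=hidden_size, batch_first=True)
cell_sf = nn.RNN(input_size=input_size, hidden_size=hidden_size)  # seq-first default

x_bf = torch.randn(batch, seq_len, input_size)  # (B, S, I) for batch_first=True
x_sf = x_bf.transpose(0, 1)                     # (S, B, I) for batch_first=False

# Hidden layout is identical in both cases: (num_layers * num_directions, B, H)
h0 = torch.zeros(1, batch, hidden_size)

out_bf, h_bf = cell_bf(x_bf, h0)
out_sf, h_sf = cell_sf(x_sf, h0)

print(out_bf.size())             # torch.Size([3, 5, 2]) -- output follows batch_first
print(out_sf.size())             # torch.Size([5, 3, 2])
print(h_bf.size(), h_sf.size())  # both torch.Size([1, 3, 2]) -- hidden does not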
6 changes: 3 additions & 3 deletions 12_2_hello_rnn.py
@@ -44,14 +44,14 @@ def forward(self, hidden, x):

         # Propagate input through RNN
         # Input: (batch, seq_len, input_size)
-        # hidden: (batch, num_layers * num_directions, hidden_size)
+        # hidden: (num_layers * num_directions, batch, hidden_size)
         out, hidden = self.rnn(x, hidden)
         return hidden, out.view(-1, num_classes)

     def init_hidden(self):
         # Initialize hidden and cell states
-        # (batch, num_layers * num_directions, hidden_size) for batch_first=True
-        return Variable(torch.zeros(batch_size, num_layers, hidden_size))
+        # (num_layers * num_directions, batch, hidden_size)
+        return Variable(torch.zeros(num_layers, batch_size, hidden_size))


 # Instantiate RNN model
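
Worth noting why the old init_hidden() shape survived until this fix: with the tutorial's batch_size = 1 and num_layers = 1, the swapped tensor happens to have the identical shape (1, 1, hidden_size). A sketch with illustrative sizes (batch 3, two layers, not the tutorial's) where the swap fails immediately:

import torch
import torch.nn as nn

cell = nn.RNN(input_size=4, hidden_size=5, num_layers=2, batch_first=True)
x = torch.randn(3, 6, 4)       # (batch=3, seq_len=6, input_size=4)

h0_ok = torch.zeros(2, 3, 5)   # (num_layers, batch, hidden_size) -- the fixed layout
out, hn = cell(x, h0_ok)       # runs fine

h0_bad = torch.zeros(3, 2, 5)  # (batch, num_layers, hidden_size) -- the old layout
try:
    cell(x, h0_bad)
except RuntimeError as err:
    print(err)                 # e.g. "Expected hidden size (2, 3, 5), got [3, 2, 5]"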
6 changes: 3 additions & 3 deletions 12_3_hello_rnn_seq.py
@@ -46,16 +46,16 @@ def __init__(self, num_classes, input_size, hidden_size, num_layers):

     def forward(self, x):
         # Initialize hidden and cell states
-        # (batch, num_layers * num_directions, hidden_size) for batch_first=True
+        # (num_layers * num_directions, batch, hidden_size) for batch_first=True
         h_0 = Variable(torch.zeros(
-            x.size(0), self.num_layers, self.hidden_size))
+            self.num_layers, x.size(0), self.hidden_size))

         # Reshape input
         x.view(x.size(0), self.sequence_length, self.input_size)

         # Propagate input through RNN
         # Input: (batch, seq_len, input_size)
-        # h_0: (batch, num_layers * num_directions, hidden_size)
+        # h_0: (num_layers * num_directions, batch, hidden_size)

         out, _ = self.rnn(x, h_0)
         return out.view(-1, num_classes)
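
The corrected 12_3 pattern, building h_0 inside forward() from x.size(0), means one module instance serves any batch size. A self-contained sketch of the idea (class name and sizes are made up for illustration):

import torch
import torch.nn as nn

class SeqRNN(nn.Module):
    def __init__(self, input_size=4, hidden_size=5, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)

    def forward(self, x):  # x: (batch, seq_len, input_size)
        # Hidden is created per call, sized from the incoming batch dimension.
        h_0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        out, _ = self.rnn(x, h_0)
        return out         # (batch, seq_len, hidden_size)

model = SeqRNN()
print(model(torch.randn(1, 6, 4)).size())  # torch.Size([1, 6, 5])
print(model(torch.randn(8, 6, 4)).size())  # torch.Size([8, 6, 5]) -- batch may vary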
6 changes: 3 additions & 3 deletions 12_4_hello_rnn_emb.py
@@ -36,16 +36,16 @@ def __init__(self):

     def forward(self, x):
         # Initialize hidden and cell states
-        # (batch, num_layers * num_directions, hidden_size) for batch_first=True
+        # (num_layers * num_directions, batch, hidden_size)
         h_0 = Variable(torch.zeros(
-            x.size(0), num_layers, hidden_size))
+            self.num_layers, x.size(0), self.hidden_size))

         emb = self.embedding(x)
         emb = emb.view(batch_size, sequence_length, -1)

         # Propagate embedding through RNN
         # Input: (batch, seq_len, embedding_size)
-        # h_0: (batch, num_layers * num_directions, hidden_size)
+        # h_0: (num_layers * num_directions, batch, hidden_size)
         out, _ = self.rnn(emb, h_0)
         return self.fc(out.view(-1, num_classes))
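
For the embedding variant the same h_0 layout applies; the shapes flow embedding -> RNN -> linear. A standalone walk-through assuming illustrative sizes close to the tutorial's (vocab 5, embedding_size 10, hidden_size 5, num_classes 5):

import torch
import torch.nn as nn

embedding = nn.Embedding(5, 10)                  # vocab of 5 ids -> 10-dim vectors
rnn = nn.RNN(input_size=10, hidden_size=5, batch_first=True)
fc = nn.Linear(5, 5)                             # hidden_size -> num_classes

x = torch.randint(0, 5, (1, 6))                  # (batch, seq_len) of token ids
emb = embedding(x)                               # (1, 6, 10)
h_0 = torch.zeros(1, 1, 5)                       # (num_layers * num_directions, batch, hidden)
out, _ = rnn(emb, h_0)                           # out: (1, 6, 5)
logits = fc(out.reshape(-1, 5))                  # reshape, since out may be non-contiguous
print(logits.size())                             # torch.Size([6, 5]) -- one row per time step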
