diff --git a/12_1_rnn_basics.py b/12_1_rnn_basics.py index fc39d3f..d26ceb0 100644 --- a/12_1_rnn_basics.py +++ b/12_1_rnn_basics.py @@ -12,7 +12,7 @@ cell = nn.RNN(input_size=4, hidden_size=2, batch_first=True) # (num_layers * num_directions, batch, hidden_size) whether batch_first=True or False -hidden = (Variable(torch.randn(1, 1, 2))) +hidden = Variable(torch.randn(1, 1, 2)) # Propagate input through RNN # Input: (batch, seq_len, input_size) when batch_first=True