diff --git a/models.py b/models.py
index 2ebd036..f4042f5 100644
--- a/models.py
+++ b/models.py
@@ -156,7 +156,7 @@ def __init__(self,
     if self.n_vocab!=0:
       self.emb = nn.Embedding(n_vocab, hidden_channels)
       if emotion_embedding:
-        self.emotion_emb = nn.Linear(1024, hidden_channels)
+        self.emo_proj = nn.Linear(1024, hidden_channels)
       nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
 
     self.encoder = attentions.Encoder(
@@ -172,7 +172,7 @@ def forward(self, x, x_lengths, emotion_embedding=None):
     if self.n_vocab!=0:
       x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
     if emotion_embedding is not None:
-      x = x + self.emotion_emb(emotion_embedding.unsqueeze(1))
+      x = x + self.emo_proj(emotion_embedding.unsqueeze(1))
     x = torch.transpose(x, 1, -1) # [b, h, t]
     x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
 
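Note on the rename: `emo_proj` is the same `nn.Linear(1024, hidden_channels)` layer previously named `emotion_emb`. It projects an utterance-level 1024-dim emotion embedding into the encoder's hidden size, and `unsqueeze(1)` lets the projected vector broadcast across the time axis when added to the token embeddings. A minimal standalone sketch of that shape handling follows; the batch size, token count, hidden_channels value, and tensor values are illustrative assumptions, not taken from the repository.

import torch
import torch.nn as nn

# Assumed shapes for illustration: batch of 2, 13 tokens, hidden size 192.
b, t, hidden_channels = 2, 13, 192
x = torch.randn(b, t, hidden_channels)       # token embeddings, [b, t, h]
emotion_embedding = torch.randn(b, 1024)     # one 1024-dim emotion vector per utterance

emo_proj = nn.Linear(1024, hidden_channels)  # same layer as in the diff

# unsqueeze(1): [b, 1024] -> [b, 1, 1024]; the projection yields [b, 1, h],
# which broadcasts over the t axis when added to x.
x = x + emo_proj(emotion_embedding.unsqueeze(1))
print(x.shape)                               # torch.Size([2, 13, 192])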