Update decoder.py

Yifan Yang 2023-02-09 18:21:52 +08:00 committed by GitHub
parent bcec8465c9
commit 00ed2b7567


@@ -79,7 +79,6 @@ class Decoder(nn.Module):
         self,
         embedding_out: torch.Tensor,
         k: torch.Tensor,
-        is_training: bool = True,
     ) -> torch.Tensor:
         """
         Add the repeat parameter to the embedding_out tensor.
@@ -89,16 +88,13 @@ class Decoder(nn.Module):
             A tensor of shape (N, U, decoder_dim).
           k:
             A tensor of shape (N, U).
-            Should be (N, S) during training.
+            Should be (N, S + 1) during training.
             Should be (N, 1) during inference.
           is_training:
             Whether it is training.
         Returns:
           Return a tensor of shape (N, U, decoder_dim).
         """
-        if is_training:
-            k = F.pad(k, (1, 0), mode="constant", value=self.blank_id)
         return embedding_out + (k / (1 + k)).unsqueeze(2) * self.repeat_param

     def forward(
@@ -140,7 +136,6 @@ class Decoder(nn.Module):
         embedding_out = self._add_repeat_param(
             embedding_out=embedding_out,
             k=k,
-            is_training=need_pad,
         )
         embedding_out = F.relu(embedding_out)
         return embedding_out
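
For context: the return expression kept by this change scales a learned repeat parameter by k / (1 + k), and the removed branch is what used to pad k with the blank id during training; with it gone, the new docstring says k should already be (N, S + 1) when training, so the padding is expected to happen in the caller. Below is a minimal, hypothetical sketch of both pieces; the blank id, the shapes, the repeat counts, and the shape of repeat_param are assumptions for illustration, not taken from the commit.

```python
import torch
import torch.nn.functional as F

# Assumed values for illustration only.
blank_id = 0
N, S, decoder_dim = 1, 3, 4

# Per-position repeat counts, shape (N, S).
k = torch.tensor([[0.0, 1.0, 3.0]])

# What the removed lines did inside _add_repeat_param during training;
# after this commit the caller is expected to do it, giving k shape (N, S + 1).
k_padded = F.pad(k, (1, 0), mode="constant", value=blank_id)

# The retained return expression: k / (1 + k) grows from 0 toward 1 as the
# repeat count increases, scaling a learned repeat parameter that is added
# to the decoder embedding.
embedding_out = torch.zeros(N, S + 1, decoder_dim)
repeat_param = torch.ones(decoder_dim)  # stand-in for self.repeat_param
out = embedding_out + (k_padded / (1 + k_padded)).unsqueeze(2) * repeat_param
print(out.shape)  # torch.Size([1, 4, 4])
```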