mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Update decoder.py
This commit is contained in:
parent
bcec8465c9
commit
00ed2b7567
@ -79,7 +79,6 @@ class Decoder(nn.Module):
|
|||||||
self,
|
self,
|
||||||
embedding_out: torch.Tensor,
|
embedding_out: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
is_training: bool = True,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Add the repeat parameter to the embedding_out tensor.
|
Add the repeat parameter to the embedding_out tensor.
|
||||||
@ -89,16 +88,13 @@ class Decoder(nn.Module):
|
|||||||
A tensor of shape (N, U, decoder_dim).
|
A tensor of shape (N, U, decoder_dim).
|
||||||
k:
|
k:
|
||||||
A tensor of shape (N, U).
|
A tensor of shape (N, U).
|
||||||
Should be (N, S) during training.
|
Should be (N, S + 1) during training.
|
||||||
Should be (N, 1) during inference.
|
Should be (N, 1) during inference.
|
||||||
is_training:
|
is_training:
|
||||||
Whether it is training.
|
Whether it is training.
|
||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape (N, U, decoder_dim).
|
Return a tensor of shape (N, U, decoder_dim).
|
||||||
"""
|
"""
|
||||||
if is_training:
|
|
||||||
k = F.pad(k, (1, 0), mode="constant", value=self.blank_id)
|
|
||||||
|
|
||||||
return embedding_out + (k / (1 + k)).unsqueeze(2) * self.repeat_param
|
return embedding_out + (k / (1 + k)).unsqueeze(2) * self.repeat_param
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -140,7 +136,6 @@ class Decoder(nn.Module):
|
|||||||
embedding_out = self._add_repeat_param(
|
embedding_out = self._add_repeat_param(
|
||||||
embedding_out=embedding_out,
|
embedding_out=embedding_out,
|
||||||
k=k,
|
k=k,
|
||||||
is_training=need_pad,
|
|
||||||
)
|
)
|
||||||
embedding_out = F.relu(embedding_out)
|
embedding_out = F.relu(embedding_out)
|
||||||
return embedding_out
|
return embedding_out
|
||||||
|
Loading…
x
Reference in New Issue
Block a user