mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Update decoder.py
This commit is contained in:
parent
bcec8465c9
commit
00ed2b7567
@ -79,7 +79,6 @@ class Decoder(nn.Module):
|
||||
self,
|
||||
embedding_out: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
is_training: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Add the repeat parameter to the embedding_out tensor.
|
||||
@ -89,16 +88,13 @@ class Decoder(nn.Module):
|
||||
A tensor of shape (N, U, decoder_dim).
|
||||
k:
|
||||
A tensor of shape (N, U).
|
||||
Should be (N, S) during training.
|
||||
Should be (N, S + 1) during training.
|
||||
Should be (N, 1) during inference.
|
||||
is_training:
|
||||
Whether it is training.
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, decoder_dim).
|
||||
"""
|
||||
if is_training:
|
||||
k = F.pad(k, (1, 0), mode="constant", value=self.blank_id)
|
||||
|
||||
return embedding_out + (k / (1 + k)).unsqueeze(2) * self.repeat_param
|
||||
|
||||
def forward(
|
||||
@ -140,7 +136,6 @@ class Decoder(nn.Module):
|
||||
embedding_out = self._add_repeat_param(
|
||||
embedding_out=embedding_out,
|
||||
k=k,
|
||||
is_training=need_pad,
|
||||
)
|
||||
embedding_out = F.relu(embedding_out)
|
||||
return embedding_out
|
||||
|
Loading…
x
Reference in New Issue
Block a user