fix decoder for emformer_rnnt2

This commit is contained in:
Fangjun Kuang 2024-01-26 16:58:44 +08:00
parent 283227c0c5
commit ed68914fe2

View File

@ -91,7 +91,7 @@ class Decoder(nn.Module):
Returns:
Return a tensor of shape (N, U, embedding_dim).
"""
embedding_out = self.embedding(y)
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True: