Patches to make decoding work correctly at utt start, for greedy_search

Daniel Povey 2022-07-27 09:32:51 +08:00
parent e25ca74955
commit daa55d5a3c
2 changed files with 5 additions and 3 deletions

beam_search.py

@@ -279,7 +279,7 @@ def greedy_search(
     device = next(model.parameters()).device
     decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device, dtype=torch.int64
+        [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64
     ).reshape(1, context_size)
     decoder_out = model.decoder(decoder_input, need_pad=False)
@@ -373,7 +373,7 @@ def greedy_search_batch(
     assert torch.all(encoder_out_lens > 0), encoder_out_lens
     assert N == batch_size_list[0], (N, batch_size_list)
-    hyps = [[blank_id] * context_size for _ in range(N)]
+    hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)]
     decoder_input = torch.tensor(
         hyps,
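
Both hunks make the same change: the decoder's initial context used to be context_size copies of the blank token, and now every position before the single blank is filled with -1. A minimal sketch of the before/after values, assuming context_size = 2 and blank_id = 0 (illustrative values, not taken from the commit):

    context_size = 2  # assumed value for illustration
    blank_id = 0      # assumed value for illustration

    old_context = [blank_id] * context_size               # [0, 0]
    new_context = [-1] * (context_size - 1) + [blank_id]  # [-1, 0]
    # -1 marks "no token yet"; the decoder change below maps it to a
    # zero embedding, so blank is not over-counted at utterance start.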

decoder.py

@@ -85,7 +85,9 @@ class Decoder(nn.Module):
           Return a tensor of shape (N, U, decoder_dim).
         """
         y = y.to(torch.int64)
-        embedding_out = self.embedding(y)
+        # The clamp() below is a temporary fix for a mismatch at utterance
+        # start: beam_search.py now uses negative ids there.
+        embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
         if need_pad is True:
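
To see what the new line does, here is a small self-contained demonstration of the clamp-and-mask trick; the vocabulary size, embedding dimension, and token ids are made up for illustration:

    import torch
    import torch.nn as nn

    embedding = nn.Embedding(num_embeddings=10, embedding_dim=4)
    y = torch.tensor([[-1, 0, 5]])  # -1 = "no token", 0 = blank, 5 = a real token

    # clamp(min=0) turns -1 into a legal index; the (y >= 0) mask then
    # zeroes that position's embedding instead of leaking blank's embedding.
    out = embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)

    assert torch.all(out[:, 0] == 0)  # the -1 position contributes nothing

Masking the output, rather than adding a dedicated padding row to the embedding table, leaves the embedding's shape unchanged for existing checkpoints, which is presumably why the comment calls it a temporary fix.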