mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Patches to make decoding work correctly at utt start, for greedy_search
This commit is contained in:
parent
e25ca74955
commit
daa55d5a3c
@ -279,7 +279,7 @@ def greedy_search(
|
||||
device = next(model.parameters()).device
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[blank_id] * context_size, device=device, dtype=torch.int64
|
||||
[-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64
|
||||
).reshape(1, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
@ -373,7 +373,7 @@ def greedy_search_batch(
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)]
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
|
||||
@ -85,7 +85,9 @@ class Decoder(nn.Module):
|
||||
Return a tensor of shape (N, U, decoder_dim).
|
||||
"""
|
||||
y = y.to(torch.int64)
|
||||
embedding_out = self.embedding(y)
|
||||
# this stuff about clamp() is a temporary fix for a mismatch
|
||||
# at utterance start, we use negative ids in beam_search.py
|
||||
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
|
||||
if self.context_size > 1:
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
if need_pad is True:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user