Patches to make decoding work correctly at utt start, for greedy_search
This commit is contained in:
parent
e25ca74955
commit
daa55d5a3c
@ -279,7 +279,7 @@ def greedy_search(
|
|||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
|
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
[blank_id] * context_size, device=device, dtype=torch.int64
|
[-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64
|
||||||
).reshape(1, context_size)
|
).reshape(1, context_size)
|
||||||
|
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
@ -373,7 +373,7 @@ def greedy_search_batch(
|
|||||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
assert N == batch_size_list[0], (N, batch_size_list)
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
|
|
||||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)]
|
||||||
|
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
hyps,
|
hyps,
|
||||||
|
|||||||
@ -85,7 +85,9 @@ class Decoder(nn.Module):
|
|||||||
Return a tensor of shape (N, U, decoder_dim).
|
Return a tensor of shape (N, U, decoder_dim).
|
||||||
"""
|
"""
|
||||||
y = y.to(torch.int64)
|
y = y.to(torch.int64)
|
||||||
embedding_out = self.embedding(y)
|
# this stuff about clamp() is a temporary fix for a mismatch
|
||||||
|
# at utterance start, we use negative ids in beam_search.py
|
||||||
|
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
|
||||||
if self.context_size > 1:
|
if self.context_size > 1:
|
||||||
embedding_out = embedding_out.permute(0, 2, 1)
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
if need_pad is True:
|
if need_pad is True:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user