Update greedy search for modified decoder.

This commit is contained in:
Fangjun Kuang 2021-12-21 11:37:56 +08:00
parent 04977175a3
commit afec6b6cae
3 changed files with 18 additions and 14 deletions

View File

@ -36,13 +36,17 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
decoder_out = model.decoder(sos)
sos = torch.tensor([blank_id] * context_size, device=device).reshape(
1, context_size
)
decoder_out = model.decoder(sos, need_pad=False)
T = encoder_out.size(1)
t = 0
hyp = []
hyp = [blank_id] * context_size
sym_per_frame = 0
sym_per_utt = 0
@ -57,14 +61,14 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
logits = model.joiner(current_encoder_out, decoder_out)
# logits is (1, 1, 1, vocab_size)
log_prob = logits.log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
# TODO: Use logits.argmax()
y = log_prob.argmax()
y = logits.argmax().item()
if y != blank_id:
hyp.append(y.item())
y = y.reshape(1, 1)
decoder_out = model.decoder(y)
hyp.append(y)
decoder_input = torch.tensor(
[hyp[-context_size:]], device=device
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
sym_per_utt += 1
sym_per_frame += 1
@ -72,6 +76,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
if y == blank_id or sym_per_frame > max_sym_per_frame:
sym_per_frame = 0
t += 1
hyp = hyp[context_size:] # remove blanks
return hyp

View File

@ -70,7 +70,7 @@ class Decoder(nn.Module):
bias=False,
)
def forward(self, y: torch.Tensor) -> torch.Tensor:
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""
Args:
y:
@ -81,7 +81,7 @@ class Decoder(nn.Module):
embeding_out = self.embedding(y)
if self.context_size > 1:
embeding_out = embeding_out.permute(0, 2, 1)
if self.training is True:
if need_pad is True:
embeding_out = F.pad(
embeding_out, pad=(self.context_size - 1, 0)
)

View File

@ -47,9 +47,8 @@ def test_decoder():
assert y.shape == (N, U, embedding_dim)
# for inference
decoder.eval()
x = torch.randint(low=0, high=vocab_size, size=(N, context_size))
y = decoder(x)
y = decoder(x, need_pad=False)
assert y.shape == (N, 1, embedding_dim)