Update greedy search for modified decoder.

This commit is contained in:
Fangjun Kuang 2021-12-21 11:37:56 +08:00
parent 04977175a3
commit afec6b6cae
3 changed files with 18 additions and 14 deletions

View File

@ -36,13 +36,17 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1) sos = torch.tensor([blank_id] * context_size, device=device).reshape(
decoder_out = model.decoder(sos) 1, context_size
)
decoder_out = model.decoder(sos, need_pad=False)
T = encoder_out.size(1)
t = 0
hyp = [] hyp = [blank_id] * context_size
sym_per_frame = 0
sym_per_utt = 0
@ -57,14 +61,14 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
logits = model.joiner(current_encoder_out, decoder_out)
# logits is (1, 1, 1, vocab_size)
log_prob = logits.log_softmax(dim=-1) y = logits.argmax().item()
# log_prob is (1, 1, 1, vocab_size)
# TODO: Use logits.argmax()
y = log_prob.argmax()
if y != blank_id:
hyp.append(y.item()) hyp.append(y)
y = y.reshape(1, 1) decoder_input = torch.tensor(
decoder_out = model.decoder(y) [hyp[-context_size:]], device=device
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
sym_per_utt += 1
sym_per_frame += 1
@ -72,6 +76,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
if y == blank_id or sym_per_frame > max_sym_per_frame:
sym_per_frame = 0
t += 1
hyp = hyp[context_size:] # remove blanks
return hyp

View File

@ -70,7 +70,7 @@ class Decoder(nn.Module):
bias=False,
)
def forward(self, y: torch.Tensor) -> torch.Tensor: def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
""" """
Args: Args:
y: y:
@ -81,7 +81,7 @@ class Decoder(nn.Module):
embeding_out = self.embedding(y)
if self.context_size > 1:
embeding_out = embeding_out.permute(0, 2, 1)
if self.training is True: if need_pad is True:
embeding_out = F.pad(
embeding_out, pad=(self.context_size - 1, 0)
)

View File

@ -47,9 +47,8 @@ def test_decoder():
assert y.shape == (N, U, embedding_dim)
# for inference # for inference
decoder.eval()
x = torch.randint(low=0, high=vocab_size, size=(N, context_size))
y = decoder(x) y = decoder(x, need_pad=False)
assert y.shape == (N, 1, embedding_dim)