mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
Update greedy search for modified decoder.
This commit is contained in:
parent
04977175a3
commit
afec6b6cae
@ -36,13 +36,17 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
||||
# support only batch_size == 1 for now
|
||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||
blank_id = model.decoder.blank_id
|
||||
context_size = model.decoder.context_size
|
||||
|
||||
device = model.device
|
||||
|
||||
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
|
||||
decoder_out = model.decoder(sos)
|
||||
sos = torch.tensor([blank_id] * context_size, device=device).reshape(
|
||||
1, context_size
|
||||
)
|
||||
decoder_out = model.decoder(sos, need_pad=False)
|
||||
T = encoder_out.size(1)
|
||||
t = 0
|
||||
hyp = []
|
||||
hyp = [blank_id] * context_size
|
||||
|
||||
sym_per_frame = 0
|
||||
sym_per_utt = 0
|
||||
@ -57,14 +61,14 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
||||
logits = model.joiner(current_encoder_out, decoder_out)
|
||||
# logits is (1, 1, 1, vocab_size)
|
||||
|
||||
log_prob = logits.log_softmax(dim=-1)
|
||||
# log_prob is (1, 1, 1, vocab_size)
|
||||
# TODO: Use logits.argmax()
|
||||
y = log_prob.argmax()
|
||||
y = logits.argmax().item()
|
||||
if y != blank_id:
|
||||
hyp.append(y.item())
|
||||
y = y.reshape(1, 1)
|
||||
decoder_out = model.decoder(y)
|
||||
hyp.append(y)
|
||||
decoder_input = torch.tensor(
|
||||
[hyp[-context_size:]], device=device
|
||||
).reshape(1, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
|
||||
sym_per_utt += 1
|
||||
sym_per_frame += 1
|
||||
@ -72,6 +76,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
||||
if y == blank_id or sym_per_frame > max_sym_per_frame:
|
||||
sym_per_frame = 0
|
||||
t += 1
|
||||
hyp = hyp[context_size:] # remove blanks
|
||||
|
||||
return hyp
|
||||
|
||||
|
@ -70,7 +70,7 @@ class Decoder(nn.Module):
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
@ -81,7 +81,7 @@ class Decoder(nn.Module):
|
||||
embeding_out = self.embedding(y)
|
||||
if self.context_size > 1:
|
||||
embeding_out = embeding_out.permute(0, 2, 1)
|
||||
if self.training is True:
|
||||
if need_pad is True:
|
||||
embeding_out = F.pad(
|
||||
embeding_out, pad=(self.context_size - 1, 0)
|
||||
)
|
||||
|
@ -47,9 +47,8 @@ def test_decoder():
|
||||
assert y.shape == (N, U, embedding_dim)
|
||||
|
||||
# for inference
|
||||
decoder.eval()
|
||||
x = torch.randint(low=0, high=vocab_size, size=(N, context_size))
|
||||
y = decoder(x)
|
||||
y = decoder(x, need_pad=False)
|
||||
assert y.shape == (N, 1, embedding_dim)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user