mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
Update greedy search for modified decoder.
This commit is contained in:
parent
04977175a3
commit
afec6b6cae
@ -36,13 +36,17 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
|||||||
# support only batch_size == 1 for now
|
# support only batch_size == 1 for now
|
||||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
|
||||||
device = model.device
|
device = model.device
|
||||||
|
|
||||||
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
|
sos = torch.tensor([blank_id] * context_size, device=device).reshape(
|
||||||
decoder_out = model.decoder(sos)
|
1, context_size
|
||||||
|
)
|
||||||
|
decoder_out = model.decoder(sos, need_pad=False)
|
||||||
T = encoder_out.size(1)
|
T = encoder_out.size(1)
|
||||||
t = 0
|
t = 0
|
||||||
hyp = []
|
hyp = [blank_id] * context_size
|
||||||
|
|
||||||
sym_per_frame = 0
|
sym_per_frame = 0
|
||||||
sym_per_utt = 0
|
sym_per_utt = 0
|
||||||
@ -57,14 +61,14 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
|||||||
logits = model.joiner(current_encoder_out, decoder_out)
|
logits = model.joiner(current_encoder_out, decoder_out)
|
||||||
# logits is (1, 1, 1, vocab_size)
|
# logits is (1, 1, 1, vocab_size)
|
||||||
|
|
||||||
log_prob = logits.log_softmax(dim=-1)
|
y = logits.argmax().item()
|
||||||
# log_prob is (1, 1, 1, vocab_size)
|
|
||||||
# TODO: Use logits.argmax()
|
|
||||||
y = log_prob.argmax()
|
|
||||||
if y != blank_id:
|
if y != blank_id:
|
||||||
hyp.append(y.item())
|
hyp.append(y)
|
||||||
y = y.reshape(1, 1)
|
decoder_input = torch.tensor(
|
||||||
decoder_out = model.decoder(y)
|
[hyp[-context_size:]], device=device
|
||||||
|
).reshape(1, context_size)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
|
|
||||||
sym_per_utt += 1
|
sym_per_utt += 1
|
||||||
sym_per_frame += 1
|
sym_per_frame += 1
|
||||||
@ -72,6 +76,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
|||||||
if y == blank_id or sym_per_frame > max_sym_per_frame:
|
if y == blank_id or sym_per_frame > max_sym_per_frame:
|
||||||
sym_per_frame = 0
|
sym_per_frame = 0
|
||||||
t += 1
|
t += 1
|
||||||
|
hyp = hyp[context_size:] # remove blanks
|
||||||
|
|
||||||
return hyp
|
return hyp
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ class Decoder(nn.Module):
|
|||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
y:
|
y:
|
||||||
@ -81,7 +81,7 @@ class Decoder(nn.Module):
|
|||||||
embeding_out = self.embedding(y)
|
embeding_out = self.embedding(y)
|
||||||
if self.context_size > 1:
|
if self.context_size > 1:
|
||||||
embeding_out = embeding_out.permute(0, 2, 1)
|
embeding_out = embeding_out.permute(0, 2, 1)
|
||||||
if self.training is True:
|
if need_pad is True:
|
||||||
embeding_out = F.pad(
|
embeding_out = F.pad(
|
||||||
embeding_out, pad=(self.context_size - 1, 0)
|
embeding_out, pad=(self.context_size - 1, 0)
|
||||||
)
|
)
|
||||||
|
@ -47,9 +47,8 @@ def test_decoder():
|
|||||||
assert y.shape == (N, U, embedding_dim)
|
assert y.shape == (N, U, embedding_dim)
|
||||||
|
|
||||||
# for inference
|
# for inference
|
||||||
decoder.eval()
|
|
||||||
x = torch.randint(low=0, high=vocab_size, size=(N, context_size))
|
x = torch.randint(low=0, high=vocab_size, size=(N, context_size))
|
||||||
y = decoder(x)
|
y = decoder(x, need_pad=False)
|
||||||
assert y.shape == (N, 1, embedding_dim)
|
assert y.shape == (N, 1, embedding_dim)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user