from local

This commit is contained in:
dohe0342 2022-12-10 15:07:44 +09:00
parent 87b12f97c0
commit ed7085a03f
3 changed files with 25 additions and 0 deletions

View File

@ -196,3 +196,28 @@ class Transducer(nn.Module):
)
return (simple_loss, pruned_loss, ctc_output)
def decode(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
sp,
):
from beam_search import greedy_search_batch, greedy_search_batch_target_input
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
hyps = []
#hyp_tokens = greedy_search_batch_target_input(self, encoder_out, x_lens, decoder_out)
hyp_tokens = greedy_search_batch(self, encoder_out, x_lens)#, decoder_out)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
return hyps