diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v/.model.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v/.model.py.swp
index 5c46a7f58..26cb0bd9c 100644
Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_d2v/.model.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v/.model.py.swp differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.model.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.model.py.swp
index 1f5c2a1bf..2fd391735 100644
Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.model.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.model.py.swp differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/model.py
index 9f2305393..dc250e218 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/model.py
@@ -196,3 +196,28 @@ class Transducer(nn.Module):
         )
         return (simple_loss, pruned_loss, ctc_output)
 
+
+    def decode(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        y: k2.RaggedTensor,
+        sp,
+    ):
+        from beam_search import greedy_search_batch, greedy_search_batch_target_input
+
+        assert x.size(0) == x_lens.size(0) == y.dim0
+
+        encoder_out, x_lens = self.encoder(x, x_lens)
+
+        assert torch.all(x_lens > 0)
+
+        hyps = []
+        #hyp_tokens = greedy_search_batch_target_input(self, encoder_out, x_lens, decoder_out)
+        hyp_tokens = greedy_search_batch(self, encoder_out, x_lens)#, decoder_out)
+
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+
+        return hyps
+