diff --git a/egs/librispeech/ASR/conformer_ctc2/.decode.py.swp b/egs/librispeech/ASR/conformer_ctc2/.decode.py.swp index 0cffc404f..d9393c931 100644 Binary files a/egs/librispeech/ASR/conformer_ctc2/.decode.py.swp and b/egs/librispeech/ASR/conformer_ctc2/.decode.py.swp differ diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py index 73cf62b76..1851c057a 100755 --- a/egs/librispeech/ASR/conformer_ctc2/decode.py +++ b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -225,6 +225,29 @@ def get_params() -> AttributeDict: return params +def greedy_search( + ctc_probs: torch.Tensor, + mask: torch.Tensor, +) -> List[List[int]]: + """Apply CTC greedy search + Args: + ctc_probs (torch.Tensor): (batch, max_len, num_bpe) + mask (torch.Tensor): (batch, max_len) + Returns: + best path result + """ + + _, max_index = ctc_probs.max(2) # (B, maxlen) + max_index = max_index.masked_fill_(mask, 0) # (B, maxlen) + + ret_hyps = [] + for hyp in max_index: + hyp = torch.unique_consecutive(hyp) + hyp = hyp[hyp > 0].tolist() + ret_hyps.append(hyp) + return ret_hyps + + def ctc_greedy_search( ctc_probs: torch.Tensor, mask: torch.Tensor,