diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py
index 3b24ad597..4f96fb7eb 100755
--- a/egs/librispeech/ASR/conformer_ctc3/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc3/decode.py
@@ -95,7 +95,10 @@ from icefall.decode import (
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    convert_timestamp,
     get_texts,
+    make_pad_mask,
+    parse_bpe_start_end_pairs,
     parse_fsa_timestamps_and_texts,
     setup_logger,
     store_transcripts_and_timestamps,
@@ -170,21 +173,24 @@ def get_parser():
         default="ctc-decoding",
         help="""Decoding method.
         Supported values are:
-        - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
+        - (0) ctc-greedy-search. It uses a sentence piece model,
+          i.e., lang_dir/bpe.model, to convert word pieces to words.
+          It needs neither a lexicon nor an n-gram LM.
+        - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece
           model, i.e., lang_dir/bpe.model, to convert word pieces to words.
           It needs neither a lexicon nor an n-gram LM.
-        - (1) 1best. Extract the best path from the decoding lattice as the
+        - (2) 1best. Extract the best path from the decoding lattice as the
          decoding result.
-        - (2) nbest. Extract n paths from the decoding lattice; the path
+        - (3) nbest. Extract n paths from the decoding lattice; the path
          with the highest score is the decoding result.
-        - (3) nbest-rescoring. Extract n paths from the decoding lattice,
+        - (4) nbest-rescoring. Extract n paths from the decoding lattice,
          rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
          the highest score is the decoding result.
-        - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
+        - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
          n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
          is the decoding result.
          you have trained an RNN LM using ./rnn_lm/train.py
-        - (5) nbest-oracle. Its WER is the lower bound of any n-best
+        - (6) nbest-oracle. Its WER is the lower bound of any n-best
          rescoring method can achieve. Useful for debugging n-best
          rescoring method.
         """,
@@ -272,6 +278,101 @@ def get_decoding_params() -> AttributeDict:
     return params
 
 
+def ctc_greedy_search(
+    ctc_probs: torch.Tensor,
+    nnet_output_lens: torch.Tensor,
+    sp: spm.SentencePieceProcessor,
+    subsampling_factor: int = 4,
+    frame_shift_ms: float = 10,
+) -> Tuple[List[Tuple[float, float]], List[List[str]]]:
+    """Apply CTC greedy search.
+
+    Args:
+      ctc_probs (torch.Tensor):
+        (batch, max_len, feat_dim)
+      nnet_output_lens (torch.Tensor):
+        (batch, )
+      sp:
+        The BPE model.
+      subsampling_factor:
+        The subsampling factor of the model.
+      frame_shift_ms:
+        Frame shift in milliseconds between two contiguous frames.
+    Returns:
+      utt_time_pairs:
+        A list of pair lists. utt_time_pairs[i] is a list of
+        (start-time, end-time) pairs, one for each word in
+        utterance-i.
+      utt_words:
+        A list of str lists. utt_words[i] is the word list of utterance-i.
+ """ + topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1) + topk_index = topk_index.squeeze(2) # (B, maxlen) + mask = make_pad_mask(nnet_output_lens) + topk_index = topk_index.masked_fill_(mask, 0) # (B, maxlen) + hyps = [hyp.tolist() for hyp in topk_index] + + def get_first_tokens(tokens: List[int]) -> List[bool]: + is_first_token = [] + first_tokens = [] + for t in range(len(tokens)): + if tokens[t] != 0 and (t == 0 or tokens[t - 1] != tokens[t]): + is_first_token.append(True) + first_tokens.append(tokens[t]) + else: + is_first_token.append(False) + return first_tokens, is_first_token + + utt_time_pairs = [] + utt_words = [] + for utt in range(len(hyps)): + first_tokens, is_first_token = get_first_tokens(hyps[utt]) + all_tokens = sp.id_to_piece(hyps[utt]) + index_pairs = parse_bpe_start_end_pairs(all_tokens, is_first_token) + words = sp.decode(first_tokens).split() + assert len(index_pairs) == len(words), ( + len(index_pairs), + len(words), + all_tokens, + ) + start = convert_timestamp( + frames=[i[0] for i in index_pairs], + subsampling_factor=subsampling_factor, + frame_shift_ms=frame_shift_ms, + ) + end = convert_timestamp( + # The duration in frames is (end_frame_index - start_frame_index + 1) + frames=[i[1] + 1 for i in index_pairs], + subsampling_factor=subsampling_factor, + frame_shift_ms=frame_shift_ms, + ) + utt_time_pairs.append(list(zip(start, end))) + utt_words.append(words) + + return utt_time_pairs, utt_words + + +def remove_duplicates_and_blank(hyp: List[int]) -> Tuple[List[int], List[int]]: + # modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py + new_hyp: List[int] = [] + time: List[Tuple[int, int]] = [] + cur = 0 + start, end = -1, -1 + while cur < len(hyp): + if hyp[cur] != 0: + new_hyp.append(hyp[cur]) + start = cur + prev = cur + while cur < len(hyp) and hyp[cur] == hyp[prev]: + if start != -1: + end = cur + cur += 1 + if start != -1 and end != -1: + time.append((start, end)) + start, end = -1, -1 + return new_hyp, time + + def decode_one_batch( params: AttributeDict, model: nn.Module, @@ -363,6 +464,17 @@ def decode_one_batch( nnet_output = model.get_ctc_output(encoder_out) # nnet_output is (N, T, C) + if params.decoding_method == "ctc-greedy-search": + timestamps, hyps = ctc_greedy_search( + ctc_probs=nnet_output, + nnet_output_lens=encoder_out_lens, + sp=bpe_model, + subsampling_factor=params.subsampling_factor, + frame_shift_ms=params.frame_shift_ms, + ) + key = "ctc-greedy-search" + return {key: (hyps, timestamps)} + supervision_segments = torch.stack( ( supervisions["sequence_idx"], @@ -699,6 +811,7 @@ def main(): params.update(vars(args)) assert params.decoding_method in ( + "ctc-greedy-search", "ctc-decoding", "1best", "nbest", @@ -752,7 +865,7 @@ def main(): params.sos_id = sos_id params.eos_id = eos_id - if params.decoding_method == "ctc-decoding": + if params.decoding_method in ["ctc-decoding", "ctc-greedy-search"]: HLG = None H = k2.ctc_topo( max_token=max_token_id,