diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index 9a35750e0..1ae7c2fdb 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -47,10 +47,19 @@ def get_args(): """, ) + parser.add_argument( + "--h-graph", + type=str, + help="""one of ["H", "Trivial"] + H: k2.ctc_topo + Trivial: k2.trivial_graph + """, + ) + return parser.parse_args() -def compile_HLG(lang_dir: str) -> k2.Fsa: +def compile_HLG(lang_dir: str, h_graph: str = "H") -> k2.Fsa: """ Args: lang_dir: @@ -62,7 +71,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: lexicon = Lexicon(lang_dir) max_token_id = max(lexicon.tokens) logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") - H = k2.ctc_topo(max_token_id) + + if h_graph == "H": + H = k2.ctc_topo(max_token_id) + elif h_graph == "Trivial": + H = k2.trivial_graph(max_token_id - 1) + else: + raise ValueError(f"Unsupported h_graph: {h_graph}") + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) if Path("data/lm/G_3_gram.pt").is_file(): @@ -138,15 +154,17 @@ def main(): args = get_args() lang_dir = Path(args.lang_dir) - if (lang_dir / "HLG.pt").is_file(): - logging.info(f"{lang_dir}/HLG.pt already exists - skipping") + if (lang_dir / f"{args.h_graph}LG.pt").is_file(): + logging.info( + f"{lang_dir}/{args.h_graph}LG.pt already exists - skipping" + ) return logging.info(f"Processing {lang_dir}") HLG = compile_HLG(lang_dir) - logging.info(f"Saving HLG.pt to {lang_dir}") - torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") + logging.info(f"Saving {args.h_graph}LG.pt to {lang_dir}") + torch.save(HLG.as_dict(), f"{lang_dir}/{args.h_graph}LG.pt") if __name__ == "__main__": diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 6b6190a09..b19fe12f5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -14,16 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import warnings from dataclasses import dataclass from typing import Dict, List, Optional import k2 import torch +from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence from model import Transducer -from icefall.decode import Nbest, one_best_decoding -from icefall.utils import get_texts +from icefall.decode import get_lattice, Nbest, one_best_decoding +from icefall.utils import get_alignments, get_texts def fast_beam_search_one_best( @@ -534,6 +536,8 @@ def greedy_search_batch( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, + decoding_graph: Optional[k2.Fsa] = None, + ngram_rescoring: bool = False, ) -> List[List[int]]: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. Args: @@ -583,6 +587,18 @@ def greedy_search_batch( encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + if ngram_rescoring: + vocab_size = model.decoder.vocab_size + total_t = encoder_out.shape[0] + # cached all joiner outputs during greedy search, + # from which non-blank frames are selected before n-gram rescoring. + all_logits = torch.zeros([total_t, vocab_size], device=device) + + # A flag indicating a frame is a blank frame or not. + # 0 for blank frame and 1 for non-blank frame. + # Used to select non-blank frames for n-gram rescoring. + non_blank_flag = torch.zeros([total_t], device=device) + offset = 0 for batch_size in batch_size_list: start = offset @@ -600,7 +616,36 @@ def greedy_search_batch( # logits'shape (batch_size, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) - assert logits.ndim == 2, logits.shape + + if ngram_rescoring: + all_logits[start:end] = logits + + assert logits.ndim == 2, logits.shape + logits_argmax = logits.argmax(dim=1) + logits_softmax = logits.softmax(dim=1) + + # detailed in below fuction verify_non_blank_logits. + selection_verification = True + + # 0 for blank frame and 1 for non-blank frame. + non_blank_flag[start:end] = torch.where( + logits_argmax == blank_id, 0, 1 + ) + + if False: + # In paper: https://arxiv.org/pdf/2101.06856.pdf + # A gama_blank threshold value is used to determinze blank frame. + # Currently, results are worse than baseline greedy_search + # and also very sensitive to gama_blank. + # (TODO): debug this later. + gama_blank = 0.50 + non_blank_flag[start:end] = torch.where( + logits_softmax[:, 0] >= gama_blank, 0, 1 + ) + + # function verify_non_blank_logits only works with logits_argmax == blank_id. + selection_verification = False + y = logits.argmax(dim=1).tolist() emitted = False for i, v in enumerate(y): @@ -624,6 +669,105 @@ def greedy_search_batch( for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) + if not ngram_rescoring: + return ans + + assert decoding_graph is not None + + # Transform logits to shape [N, T, vocab_size] format to make it easier + # to select non-blank frames. + packed_all_logits = PackedSequence( + all_logits, torch.tensor(batch_size_list) + ) + all_logits_unpacked, _ = pad_packed_sequence( + packed_all_logits, batch_first=True + ) + + # Transform non_blank_flag to shape [N, T] + packed_non_blank_flag = PackedSequence( + non_blank_flag, torch.tensor(batch_size_list) + ) + non_blank_flag_unpacked, _ = pad_packed_sequence( + packed_non_blank_flag, batch_first=True + ) + + non_blank_logits_lens = torch.sum(non_blank_flag_unpacked, dim=1) + max_frame_to_rescore = non_blank_logits_lens.max() + + non_blank_logits = torch.zeros( + [N, int(max_frame_to_rescore), vocab_size], device=device + ) + + # torch.index_select only acceptec a single dimension to index from. + # So we need generate non_blank_logits one by one. + # Maybe there is another efficient way to do this. + for i in range(N): + cur_non_blank_index = torch.where(non_blank_flag_unpacked[i, :] != 0)[0] + assert non_blank_logits_lens[i] == cur_non_blank_index.shape[0] + non_blank_logits[ + i, : int(non_blank_logits_lens[i]), : + ] = torch.index_select( + all_logits_unpacked[i, :], 0, cur_non_blank_index + ) + + def verify_non_blank_logits(): + # A way to verify non_blank_logits are selected correctly from all_logits. + hyps_before_rescore = non_blank_logits.argmax(dim=2) + for i in range(N): + usi = unsorted_indices[i] + hyp_to_verify = hyps_before_rescore[usi][ + : int(non_blank_logits_lens[usi]) + ].tolist() + assert ans[i] == hyp_to_verify + logging.info("Verified non-blank logits.") + + # TODO: skip verification after we finally get a workable rescoring method. + if selection_verification: + verify_non_blank_logits() + + # Split log_softmax into two seperate steps, + # so we cound do blank deweight in probability domain if needed. + logits_to_rescore_softmax = non_blank_logits.softmax(dim=2) + logits_to_rescore = logits_to_rescore_softmax.log() + + # In paper: https://arxiv.org/pdf/2101.06856.pdf + # blank deweight is applied before non_blank frames selected. + # However, in current setup, that results in a higher WER. + # So just put this blank deweight before ngram rescoring. + # (TODO): debug this blank deweight issue. + + blank_deweight = 100 + logits_to_rescore[:, :, 0] -= blank_deweight + + supervision_segments = torch.zeros([N, 3], dtype=torch.int32) + supervision_segments[:, 0] = torch.arange(0, N, dtype=torch.int32) + supervision_segments[:, 2] = non_blank_logits_lens.to(torch.int32) + + lattice = get_lattice( + nnet_output=logits_to_rescore, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=20, + output_beam=8, + min_active_states=30, + max_active_states=1000, + subsampling_factor=1, + ) + + lm_weight = 0.3 # (TODO): tuning this. + lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight) + + best_path = one_best_decoding( + lattice=lattice, + use_double_scores=True, + ) + + token_ids = get_alignments(best_path, "labels") + + ans = [] + for i in range(N): + usi = unsorted_indices[i] + ans.append(token_ids[usi][: int(non_blank_logits_lens[usi])]) return ans diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index 701cad73c..60006c2e2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -136,6 +136,28 @@ def get_parser(): "`epoch` are loaded for averaging. ", ) + parser.add_argument( + "--ngram-rescoring", + type=str2bool, + default=False, + help="Whether to use ngram_rescoring.", + ) + + parser.add_argument( + "--decoding-graph", + type=str, + default="trivial_graph", + help="one of [trivial_grpah, HLG, Trival_LG, LG]" + "used by greedy_search_batch with ngram-rescoring=True.", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="./data/lang_bpe_500/", + help="Path to decoding graphs", + ) + parser.add_argument( "--exp-dir", type=str, @@ -293,6 +315,8 @@ def decode_one_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, + decoding_graph=decoding_graph, + ngram_rescoring=params.ngram_rescoring, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -498,6 +522,10 @@ def main(): if params.use_averaged_model: params.suffix += "-use-averaged-model" + if params.ngram_rescoring: + params.suffix += "-ngram-rescoring" + params.suffix += f"-{params.decoding_graph}" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -605,6 +633,26 @@ def main(): else: decoding_graph = None + if params.ngram_rescoring and params.decoding_method == "greedy_search": + assert params.decoding_graph in [ + "trivial_graph", + "HLG", + "Trivial_LG", + ], f"Unsupported decoding graph {params.decoding_graph}" + if params.decoding_graph == "trivial_graph": + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) + else: + decoding_graph = k2.Fsa.from_dict( + torch.load( + f"data/lang_bpe_500/{params.decoding_graph}.pt", + map_location=device, + ) + ) + + decoding_graph.lm_scores = decoding_graph.scores.clone() + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}")