From 0a99ceb6baaaed381ac598a315e5a003309c8702 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Thu, 14 Jul 2022 00:01:28 +0800 Subject: [PATCH 1/3] psd algorithm --- egs/librispeech/ASR/local/compile_hlg.py | 30 +++- .../beam_search.py | 150 +++++++++++++++++- .../pruned_transducer_stateless6/decode.py | 48 ++++++ 3 files changed, 219 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index 9a35750e0..1ae7c2fdb 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -47,10 +47,19 @@ def get_args(): """, ) + parser.add_argument( + "--h-graph", + type=str, + help="""one of ["H", "Trivial"] + H: k2.ctc_topo + Trivial: k2.trivial_graph + """, + ) + return parser.parse_args() -def compile_HLG(lang_dir: str) -> k2.Fsa: +def compile_HLG(lang_dir: str, h_graph: str = "H") -> k2.Fsa: """ Args: lang_dir: @@ -62,7 +71,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: lexicon = Lexicon(lang_dir) max_token_id = max(lexicon.tokens) logging.info(f"Building ctc_topo. 
max_token_id: {max_token_id}") - H = k2.ctc_topo(max_token_id) + + if h_graph == "H": + H = k2.ctc_topo(max_token_id) + elif h_graph == "Trivial": + H = k2.trivial_graph(max_token_id - 1) + else: + raise ValueError(f"Unsupported h_graph: {h_graph}") + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) if Path("data/lm/G_3_gram.pt").is_file(): @@ -138,15 +154,17 @@ def main(): args = get_args() lang_dir = Path(args.lang_dir) - if (lang_dir / "HLG.pt").is_file(): - logging.info(f"{lang_dir}/HLG.pt already exists - skipping") + if (lang_dir / f"{args.h_graph}LG.pt").is_file(): + logging.info( + f"{lang_dir}/{args.h_graph}LG.pt already exists - skipping" + ) return logging.info(f"Processing {lang_dir}") HLG = compile_HLG(lang_dir) - logging.info(f"Saving HLG.pt to {lang_dir}") - torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") + logging.info(f"Saving {args.h_graph}LG.pt to {lang_dir}") + torch.save(HLG.as_dict(), f"{lang_dir}/{args.h_graph}LG.pt") if __name__ == "__main__": diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 6b6190a09..b19fe12f5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -14,16 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import logging import warnings from dataclasses import dataclass from typing import Dict, List, Optional import k2 import torch +from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence from model import Transducer -from icefall.decode import Nbest, one_best_decoding -from icefall.utils import get_texts +from icefall.decode import get_lattice, Nbest, one_best_decoding +from icefall.utils import get_alignments, get_texts def fast_beam_search_one_best( @@ -534,6 +536,8 @@ def greedy_search_batch( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, + decoding_graph: Optional[k2.Fsa] = None, + ngram_rescoring: bool = False, ) -> List[List[int]]: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. Args: @@ -583,6 +587,18 @@ def greedy_search_batch( encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + if ngram_rescoring: + vocab_size = model.decoder.vocab_size + total_t = encoder_out.shape[0] + # cached all joiner outputs during greedy search, + # from which non-blank frames are selected before n-gram rescoring. + all_logits = torch.zeros([total_t, vocab_size], device=device) + + # A flag indicating a frame is a blank frame or not. + # 0 for blank frame and 1 for non-blank frame. + # Used to select non-blank frames for n-gram rescoring. + non_blank_flag = torch.zeros([total_t], device=device) + offset = 0 for batch_size in batch_size_list: start = offset @@ -600,7 +616,36 @@ def greedy_search_batch( # logits'shape (batch_size, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) - assert logits.ndim == 2, logits.shape + + if ngram_rescoring: + all_logits[start:end] = logits + + assert logits.ndim == 2, logits.shape + logits_argmax = logits.argmax(dim=1) + logits_softmax = logits.softmax(dim=1) + + # detailed in below fuction verify_non_blank_logits. + selection_verification = True + + # 0 for blank frame and 1 for non-blank frame. 
+ non_blank_flag[start:end] = torch.where( + logits_argmax == blank_id, 0, 1 + ) + + if False: + # In paper: https://arxiv.org/pdf/2101.06856.pdf + # A gama_blank threshold value is used to determinze blank frame. + # Currently, results are worse than baseline greedy_search + # and also very sensitive to gama_blank. + # (TODO): debug this later. + gama_blank = 0.50 + non_blank_flag[start:end] = torch.where( + logits_softmax[:, 0] >= gama_blank, 0, 1 + ) + + # function verify_non_blank_logits only works with logits_argmax == blank_id. + selection_verification = False + y = logits.argmax(dim=1).tolist() emitted = False for i, v in enumerate(y): @@ -624,6 +669,105 @@ def greedy_search_batch( for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) + if not ngram_rescoring: + return ans + + assert decoding_graph is not None + + # Transform logits to shape [N, T, vocab_size] format to make it easier + # to select non-blank frames. + packed_all_logits = PackedSequence( + all_logits, torch.tensor(batch_size_list) + ) + all_logits_unpacked, _ = pad_packed_sequence( + packed_all_logits, batch_first=True + ) + + # Transform non_blank_flag to shape [N, T] + packed_non_blank_flag = PackedSequence( + non_blank_flag, torch.tensor(batch_size_list) + ) + non_blank_flag_unpacked, _ = pad_packed_sequence( + packed_non_blank_flag, batch_first=True + ) + + non_blank_logits_lens = torch.sum(non_blank_flag_unpacked, dim=1) + max_frame_to_rescore = non_blank_logits_lens.max() + + non_blank_logits = torch.zeros( + [N, int(max_frame_to_rescore), vocab_size], device=device + ) + + # torch.index_select only acceptec a single dimension to index from. + # So we need generate non_blank_logits one by one. + # Maybe there is another efficient way to do this. 
+ for i in range(N): + cur_non_blank_index = torch.where(non_blank_flag_unpacked[i, :] != 0)[0] + assert non_blank_logits_lens[i] == cur_non_blank_index.shape[0] + non_blank_logits[ + i, : int(non_blank_logits_lens[i]), : + ] = torch.index_select( + all_logits_unpacked[i, :], 0, cur_non_blank_index + ) + + def verify_non_blank_logits(): + # A way to verify non_blank_logits are selected correctly from all_logits. + hyps_before_rescore = non_blank_logits.argmax(dim=2) + for i in range(N): + usi = unsorted_indices[i] + hyp_to_verify = hyps_before_rescore[usi][ + : int(non_blank_logits_lens[usi]) + ].tolist() + assert ans[i] == hyp_to_verify + logging.info("Verified non-blank logits.") + + # TODO: skip verification after we finally get a workable rescoring method. + if selection_verification: + verify_non_blank_logits() + + # Split log_softmax into two seperate steps, + # so we cound do blank deweight in probability domain if needed. + logits_to_rescore_softmax = non_blank_logits.softmax(dim=2) + logits_to_rescore = logits_to_rescore_softmax.log() + + # In paper: https://arxiv.org/pdf/2101.06856.pdf + # blank deweight is applied before non_blank frames selected. + # However, in current setup, that results in a higher WER. + # So just put this blank deweight before ngram rescoring. + # (TODO): debug this blank deweight issue. + + blank_deweight = 100 + logits_to_rescore[:, :, 0] -= blank_deweight + + supervision_segments = torch.zeros([N, 3], dtype=torch.int32) + supervision_segments[:, 0] = torch.arange(0, N, dtype=torch.int32) + supervision_segments[:, 2] = non_blank_logits_lens.to(torch.int32) + + lattice = get_lattice( + nnet_output=logits_to_rescore, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=20, + output_beam=8, + min_active_states=30, + max_active_states=1000, + subsampling_factor=1, + ) + + lm_weight = 0.3 # (TODO): tuning this. 
+ lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight) + + best_path = one_best_decoding( + lattice=lattice, + use_double_scores=True, + ) + + token_ids = get_alignments(best_path, "labels") + + ans = [] + for i in range(N): + usi = unsorted_indices[i] + ans.append(token_ids[usi][: int(non_blank_logits_lens[usi])]) return ans diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index 701cad73c..60006c2e2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -136,6 +136,28 @@ def get_parser(): "`epoch` are loaded for averaging. ", ) + parser.add_argument( + "--ngram-rescoring", + type=str2bool, + default=False, + help="Whether to use ngram_rescoring.", + ) + + parser.add_argument( + "--decoding-graph", + type=str, + default="trivial_graph", + help="one of [trivial_grpah, HLG, Trival_LG, LG]" + "used by greedy_search_batch with ngram-rescoring=True.", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="./data/lang_bpe_500/", + help="Path to decoding graphs", + ) + parser.add_argument( "--exp-dir", type=str, @@ -293,6 +315,8 @@ def decode_one_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, + decoding_graph=decoding_graph, + ngram_rescoring=params.ngram_rescoring, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -498,6 +522,10 @@ def main(): if params.use_averaged_model: params.suffix += "-use-averaged-model" + if params.ngram_rescoring: + params.suffix += "-ngram-rescoring" + params.suffix += f"-{params.decoding_graph}" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -605,6 +633,26 @@ def main(): else: decoding_graph = None + if params.ngram_rescoring and params.decoding_method == "greedy_search": + assert params.decoding_graph in [ + "trivial_graph", + "HLG", + "Trivial_LG", 
+ ], f"Unsupported decoding graph {params.decoding_graph}" + if params.decoding_graph == "trivial_graph": + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) + else: + decoding_graph = k2.Fsa.from_dict( + torch.load( + f"{params.lang_dir}/{params.decoding_graph}.pt", + map_location=device, + ) + ) + + decoding_graph.lm_scores = decoding_graph.scores.clone() + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") From 473efcd5313787a0f78b97e06b2263985f6bf85e Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Sat, 16 Jul 2022 01:19:04 +0800 Subject: [PATCH 2/3] add self loop to L --- .../beam_search.py | 39 ++++--------------- .../pruned_transducer_stateless6/decode.py | 12 +++++- icefall/utils.py | 6 ++- 3 files changed, 22 insertions(+), 35 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index b19fe12f5..783ac5070 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -538,6 +538,7 @@ def greedy_search_batch( encoder_out_lens: torch.Tensor, decoding_graph: Optional[k2.Fsa] = None, ngram_rescoring: bool = False, + gamma_blank: float = 1.0, ) -> List[List[int]]: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. Args: @@ -624,27 +625,12 @@ def greedy_search_batch( logits_argmax = logits.argmax(dim=1) logits_softmax = logits.softmax(dim=1) - # detailed in below fuction verify_non_blank_logits. - selection_verification = True # 0 for blank frame and 1 for non-blank frame. non_blank_flag[start:end] = torch.where( - logits_argmax == blank_id, 0, 1 + logits_softmax[:, 0] >= gamma_blank, 0, 1 ) - if False: - # In paper: https://arxiv.org/pdf/2101.06856.pdf - # A gama_blank threshold value is used to determinze blank frame. 
- # Currently, results are worse than baseline greedy_search - # and also very sensitive to gama_blank. - # (TODO): debug this later. - gama_blank = 0.50 - non_blank_flag[start:end] = torch.where( - logits_softmax[:, 0] >= gama_blank, 0, 1 - ) - - # function verify_non_blank_logits only works with logits_argmax == blank_id. - selection_verification = False y = logits.argmax(dim=1).tolist() emitted = False @@ -710,21 +696,10 @@ def greedy_search_batch( all_logits_unpacked[i, :], 0, cur_non_blank_index ) - def verify_non_blank_logits(): - # A way to verify non_blank_logits are selected correctly from all_logits. - hyps_before_rescore = non_blank_logits.argmax(dim=2) - for i in range(N): - usi = unsorted_indices[i] - hyp_to_verify = hyps_before_rescore[usi][ - : int(non_blank_logits_lens[usi]) - ].tolist() - assert ans[i] == hyp_to_verify - logging.info("Verified non-blank logits.") - # TODO: skip verification after we finally get a workable rescoring method. - if selection_verification: - verify_non_blank_logits() + number_selected_frames = non_blank_flag.sum() + logging.info(f"{number_selected_frames} are selected out of {total_t} frames") # Split log_softmax into two seperate steps, # so we cound do blank deweight in probability domain if needed. logits_to_rescore_softmax = non_blank_logits.softmax(dim=2) @@ -736,7 +711,7 @@ def greedy_search_batch( # So just put this blank deweight before ngram rescoring. # (TODO): debug this blank deweight issue. - blank_deweight = 100 + blank_deweight = 0.0 logits_to_rescore[:, :, 0] -= blank_deweight supervision_segments = torch.zeros([N, 3], dtype=torch.int32) @@ -754,7 +729,7 @@ def greedy_search_batch( subsampling_factor=1, ) - lm_weight = 0.3 # (TODO): tuning this. + lm_weight = 0.5 # (TODO): tuning this. 
lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight) best_path = one_best_decoding( @@ -762,7 +737,7 @@ def greedy_search_batch( use_double_scores=True, ) - token_ids = get_alignments(best_path, "labels") + token_ids = get_alignments(best_path, "labels", remove_zero_blank=True) ans = [] for i in range(N): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index 60006c2e2..fd12f8f29 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -241,6 +241,12 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--gamma-blank", + type=float, + default=1.0, + ) + return parser @@ -317,6 +323,7 @@ def decode_one_batch( encoder_out_lens=encoder_out_lens, decoding_graph=decoding_graph, ngram_rescoring=params.ngram_rescoring, + gamma_blank=params.gamma_blank, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -525,6 +532,7 @@ def main(): if params.ngram_rescoring: params.suffix += "-ngram-rescoring" params.suffix += f"-{params.decoding_graph}" + params.suffix += f"-gamma_blank-{params.gamma_blank}" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -636,8 +644,7 @@ def main(): if params.ngram_rescoring and params.decoding_method == "greedy_search": assert params.decoding_graph in [ "trivial_graph", - "HLG", - "Trivial_LG", + "L", ], f"Unsupported decoding graph {params.decoding_graph}" if params.decoding_graph == "trivial_graph": decoding_graph = k2.trivial_graph( @@ -650,6 +657,7 @@ def main(): map_location=device, ) ) + decoding_graph = k2.add_epsilon_self_loops(decoding_graph) decoding_graph.lm_scores = decoding_graph.scores.clone() diff --git a/icefall/utils.py b/icefall/utils.py index 3bfd5e5b1..bf11eae27 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -236,7 +236,9 @@ def 
get_texts( return aux_labels.tolist() -def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]: +def get_alignments( + best_paths: k2.Fsa, kind: str, remove_zero_blank: bool = False +) -> List[List[int]]: """Extract labels or aux_labels from the best-path FSAs. Args: @@ -272,6 +274,8 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]: token_shape, getattr(best_paths, kind).contiguous() ) tokens = tokens.remove_values_eq(-1) + if remove_zero_blank: + tokens = tokens.remove_values_eq(0) return tokens.tolist() From 1ebf714fb758942266ef8a8fdcae54c5061f762c Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Sat, 16 Jul 2022 13:37:31 +0800 Subject: [PATCH 3/3] remove hlg related modifications --- egs/librispeech/ASR/local/compile_hlg.py | 30 ++++--------------- .../beam_search.py | 6 +--- .../pruned_transducer_stateless6/decode.py | 2 -- 3 files changed, 7 insertions(+), 31 deletions(-) diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index 1ae7c2fdb..9a35750e0 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -47,19 +47,10 @@ def get_args(): """, ) - parser.add_argument( - "--h-graph", - type=str, - help="""one of ["H", "Trivial"] - H: k2.ctc_topo - Trivial: k2.trivial_graph - """, - ) - return parser.parse_args() -def compile_HLG(lang_dir: str, h_graph: str = "H") -> k2.Fsa: +def compile_HLG(lang_dir: str) -> k2.Fsa: """ Args: lang_dir: @@ -71,14 +62,7 @@ def compile_HLG(lang_dir: str, h_graph: str = "H") -> k2.Fsa: lexicon = Lexicon(lang_dir) max_token_id = max(lexicon.tokens) logging.info(f"Building ctc_topo. 
max_token_id: {max_token_id}") - - if h_graph == "H": - H = k2.ctc_topo(max_token_id) - elif h_graph == "Trivial": - H = k2.trivial_graph(max_token_id - 1) - else: - raise ValueError(f"Unsupported h_graph: {h_graph}") - + H = k2.ctc_topo(max_token_id) L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) if Path("data/lm/G_3_gram.pt").is_file(): @@ -154,17 +138,15 @@ def main(): args = get_args() lang_dir = Path(args.lang_dir) - if (lang_dir / f"{args.h_graph}LG.pt").is_file(): - logging.info( - f"{lang_dir}/{args.h_graph}LG.pt already exists - skipping" - ) + if (lang_dir / "HLG.pt").is_file(): + logging.info(f"{lang_dir}/HLG.pt already exists - skipping") return logging.info(f"Processing {lang_dir}") HLG = compile_HLG(lang_dir) - logging.info(f"Saving {args.h_graph}LG.pt to {lang_dir}") - torch.save(HLG.as_dict(), f"{lang_dir}/{args.h_graph}LG.pt") + logging.info(f"Saving HLG.pt to {lang_dir}") + torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") if __name__ == "__main__": diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 783ac5070..38643c270 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -614,15 +614,14 @@ def greedy_search_batch( logits = model.joiner( current_encoder_out, decoder_out.unsqueeze(1), project_input=False ) - # logits'shape (batch_size, 1, 1, vocab_size) + # logits'shape (batch_size, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) if ngram_rescoring: all_logits[start:end] = logits assert logits.ndim == 2, logits.shape - logits_argmax = logits.argmax(dim=1) logits_softmax = logits.softmax(dim=1) @@ -729,9 +728,6 @@ def greedy_search_batch( subsampling_factor=1, ) - lm_weight = 0.5 # (TODO): tuning this. 
- lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight) - best_path = one_best_decoding( lattice=lattice, use_double_scores=True, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index fd12f8f29..7aed94674 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -659,8 +659,6 @@ def main(): ) decoding_graph = k2.add_epsilon_self_loops(decoding_graph) - decoding_graph.lm_scores = decoding_graph.scores.clone() - num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}")