From 1ebf714fb758942266ef8a8fdcae54c5061f762c Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Sat, 16 Jul 2022 13:37:31 +0800 Subject: [PATCH] remove hlg related modifications --- egs/librispeech/ASR/local/compile_hlg.py | 30 ++++--------------- .../beam_search.py | 6 +--- .../pruned_transducer_stateless6/decode.py | 2 -- 3 files changed, 7 insertions(+), 31 deletions(-) diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index 1ae7c2fdb..9a35750e0 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -47,19 +47,10 @@ def get_args(): """, ) - parser.add_argument( - "--h-graph", - type=str, - help="""one of ["H", "Trivial"] - H: k2.ctc_topo - Trivial: k2.trivial_graph - """, - ) - return parser.parse_args() -def compile_HLG(lang_dir: str, h_graph: str = "H") -> k2.Fsa: +def compile_HLG(lang_dir: str) -> k2.Fsa: """ Args: lang_dir: @@ -71,14 +62,7 @@ def compile_HLG(lang_dir: str, h_graph: str = "H") -> k2.Fsa: lexicon = Lexicon(lang_dir) max_token_id = max(lexicon.tokens) logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") - - if h_graph == "H": - H = k2.ctc_topo(max_token_id) - elif h_graph == "Trivial": - H = k2.trivial_graph(max_token_id - 1) - else: - raise ValueError(f"Unsupported h_graph: {h_graph}") - + H = k2.ctc_topo(max_token_id) L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) if Path("data/lm/G_3_gram.pt").is_file(): @@ -154,17 +138,15 @@ def main(): args = get_args() lang_dir = Path(args.lang_dir) - if (lang_dir / f"{args.h_graph}LG.pt").is_file(): - logging.info( - f"{lang_dir}/{args.h_graph}LG.pt already exists - skipping" - ) + if (lang_dir / "HLG.pt").is_file(): + logging.info(f"{lang_dir}/HLG.pt already exists - skipping") return logging.info(f"Processing {lang_dir}") HLG = compile_HLG(lang_dir) - logging.info(f"Saving {args.h_graph}LG.pt to {lang_dir}") - torch.save(HLG.as_dict(), f"{lang_dir}/{args.h_graph}LG.pt") + logging.info(f"Saving HLG.pt to {lang_dir}") + torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") if __name__ == "__main__": diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 783ac5070..38643c270 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -614,15 +614,14 @@ def greedy_search_batch( logits = model.joiner( current_encoder_out, decoder_out.unsqueeze(1), project_input=False ) - # logits'shape (batch_size, 1, 1, vocab_size) + # logits'shape (batch_size, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) if ngram_rescoring: all_logits[start:end] = logits assert logits.ndim == 2, logits.shape - logits_argmax = logits.argmax(dim=1) logits_softmax = logits.softmax(dim=1) @@ -729,9 +728,6 @@ def greedy_search_batch( subsampling_factor=1, ) - lm_weight = 0.5 # (TODO): tuning this. - lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight) - best_path = one_best_decoding( lattice=lattice, use_double_scores=True, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index fd12f8f29..7aed94674 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -659,8 +659,6 @@ def main(): ) decoding_graph = k2.add_epsilon_self_loops(decoding_graph) - decoding_graph.lm_scores = decoding_graph.scores.clone() - num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}")