From 473efcd5313787a0f78b97e06b2263985f6bf85e Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Sat, 16 Jul 2022 01:19:04 +0800 Subject: [PATCH] add self loop to L --- .../beam_search.py | 39 ++++--------------- .../pruned_transducer_stateless6/decode.py | 12 +++++- icefall/utils.py | 6 ++- 3 files changed, 22 insertions(+), 35 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index b19fe12f5..783ac5070 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -538,6 +538,7 @@ def greedy_search_batch( encoder_out_lens: torch.Tensor, decoding_graph: Optional[k2.Fsa] = None, ngram_rescoring: bool = False, + gamma_blank: float = 1.0, ) -> List[List[int]]: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. Args: @@ -624,27 +625,12 @@ def greedy_search_batch( logits_argmax = logits.argmax(dim=1) logits_softmax = logits.softmax(dim=1) - # detailed in below fuction verify_non_blank_logits. - selection_verification = True # 0 for blank frame and 1 for non-blank frame. non_blank_flag[start:end] = torch.where( - logits_argmax == blank_id, 0, 1 + logits_softmax[:, 0] >= gamma_blank, 0, 1 ) - if False: - # In paper: https://arxiv.org/pdf/2101.06856.pdf - # A gama_blank threshold value is used to determinze blank frame. - # Currently, results are worse than baseline greedy_search - # and also very sensitive to gama_blank. - # (TODO): debug this later. - gama_blank = 0.50 - non_blank_flag[start:end] = torch.where( - logits_softmax[:, 0] >= gama_blank, 0, 1 - ) - - # function verify_non_blank_logits only works with logits_argmax == blank_id. 
- selection_verification = False y = logits.argmax(dim=1).tolist() emitted = False @@ -710,21 +696,10 @@ def greedy_search_batch( all_logits_unpacked[i, :], 0, cur_non_blank_index ) - def verify_non_blank_logits(): - # A way to verify non_blank_logits are selected correctly from all_logits. - hyps_before_rescore = non_blank_logits.argmax(dim=2) - for i in range(N): - usi = unsorted_indices[i] - hyp_to_verify = hyps_before_rescore[usi][ - : int(non_blank_logits_lens[usi]) - ].tolist() - assert ans[i] == hyp_to_verify - logging.info("Verified non-blank logits.") - # TODO: skip verification after we finally get a workable rescoring method. - if selection_verification: - verify_non_blank_logits() + number_selected_frames = non_blank_flag.sum() + logging.info(f"{number_selected_frames} are selected out of {total_t} frames") # Split log_softmax into two seperate steps, # so we cound do blank deweight in probability domain if needed. logits_to_rescore_softmax = non_blank_logits.softmax(dim=2) @@ -736,7 +711,7 @@ def greedy_search_batch( # So just put this blank deweight before ngram rescoring. # (TODO): debug this blank deweight issue. - blank_deweight = 100 + blank_deweight = 0.0 logits_to_rescore[:, :, 0] -= blank_deweight supervision_segments = torch.zeros([N, 3], dtype=torch.int32) @@ -754,7 +729,7 @@ def greedy_search_batch( subsampling_factor=1, ) - lm_weight = 0.3 # (TODO): tuning this. + lm_weight = 0.5 # (TODO): tuning this. 
lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight) best_path = one_best_decoding( @@ -762,7 +737,7 @@ def greedy_search_batch( use_double_scores=True, ) - token_ids = get_alignments(best_path, "labels") + token_ids = get_alignments(best_path, "labels", remove_zero_blank=True) ans = [] for i in range(N): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index 60006c2e2..fd12f8f29 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -241,6 +241,12 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--gamma-blank", + type=float, + default=1.0, + ) + return parser @@ -317,6 +323,7 @@ def decode_one_batch( encoder_out_lens=encoder_out_lens, decoding_graph=decoding_graph, ngram_rescoring=params.ngram_rescoring, + gamma_blank=params.gamma_blank, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -525,6 +532,7 @@ def main(): if params.ngram_rescoring: params.suffix += "-ngram-rescoring" params.suffix += f"-{params.decoding_graph}" + params.suffix += f"-gamma_blank-{params.gamma_blank}" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -636,8 +644,7 @@ def main(): if params.ngram_rescoring and params.decoding_method == "greedy_search": assert params.decoding_graph in [ "trivial_graph", - "HLG", - "Trivial_LG", + "L", ], f"Unsupported decoding graph {params.decoding_graph}" if params.decoding_graph == "trivial_graph": decoding_graph = k2.trivial_graph( @@ -650,6 +657,7 @@ def main(): map_location=device, ) ) + decoding_graph = k2.add_epsilon_self_loops(decoding_graph) decoding_graph.lm_scores = decoding_graph.scores.clone() diff --git a/icefall/utils.py b/icefall/utils.py index 3bfd5e5b1..bf11eae27 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -236,7 +236,9 @@ def 
get_texts( return aux_labels.tolist() -def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]: +def get_alignments( + best_paths: k2.Fsa, kind: str, remove_zero_blank: bool = False +) -> List[List[int]]: """Extract labels or aux_labels from the best-path FSAs. Args: @@ -272,6 +274,8 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]: token_shape, getattr(best_paths, kind).contiguous() ) tokens = tokens.remove_values_eq(-1) + if remove_zero_blank: + tokens = tokens.remove_values_eq(0) return tokens.tolist()