From f35fa8aa8fe3557199002de887c9f9d6079f4ffb Mon Sep 17 00:00:00 2001
From: Triplecq
Date: Tue, 23 Jan 2024 17:10:10 -0500
Subject: [PATCH] add blank penalty in decoding script

---
 egs/reazonspeech/ASR/zipformer/decode.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/egs/reazonspeech/ASR/zipformer/decode.py b/egs/reazonspeech/ASR/zipformer/decode.py
index bbe363f41..757d1323f 100755
--- a/egs/reazonspeech/ASR/zipformer/decode.py
+++ b/egs/reazonspeech/ASR/zipformer/decode.py
@@ -370,6 +370,19 @@ def get_parser():
           modified_beam_search_LODR.
         """,
     )
+
+    parser.add_argument(
+        "--blank-penalty",
+        type=float,
+        default=0.0,
+        help="""
+        The penalty applied to the blank symbol during decoding.
+        Note: it is a positive value applied to the logits as
+        `logits[:, 0] -= blank_penalty` (assuming logits.shape is
+        [batch_size, vocab] and the blank id is 0).
+        """,
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -457,6 +470,7 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(sp.text2word(hyp))
@@ -471,6 +485,7 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in hyp_tokens:
             hyps.append([word_table[i] for i in hyp])
@@ -485,6 +500,7 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(sp.text2word(hyp))
@@ -500,6 +516,7 @@ def decode_one_batch(
             num_paths=params.num_paths,
             ref_texts=sp.encode(supervisions["text"]),
             nbest_scale=params.nbest_scale,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(sp.text2word(hyp))
@@ -508,6 +525,7 @@ def decode_one_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(sp.text2word(hyp))
@@ -518,6 +536,7 @@ def decode_one_batch(
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
             context_graph=context_graph,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(sp.text2word(hyp))
@@ -591,6 +610,7 @@ def decode_one_batch(
                 )
             hyps.append(sp.text2word(sp.decode(hyp)))
 
+    key = f"blank_penalty_{params.blank_penalty}"
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
     elif "fast_beam_search" in params.decoding_method:
@@ -827,6 +847,8 @@ def main():
                 f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
             )
 
+    params.suffix += f"-blank-penalty-{params.blank_penalty}"
+
     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
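---

Note: for intuition, below is a minimal, self-contained sketch of the
operation described in the --blank-penalty help string. It is not icefall's
actual search code; `apply_blank_penalty` and the toy tensor are illustrative
names only, assuming logits of shape (batch_size, vocab) with blank id 0 as
stated in the help text.

    import torch

    def apply_blank_penalty(
        logits: torch.Tensor, blank_penalty: float, blank_id: int = 0
    ) -> torch.Tensor:
        """Subtract a fixed penalty from the blank logit.

        A positive penalty lowers the blank score, so the search emits
        more non-blank tokens (typically fewer deletions, at the risk of
        more insertions).
        """
        if blank_penalty != 0.0:
            logits = logits.clone()  # keep the caller's tensor intact
            logits[:, blank_id] -= blank_penalty
        return logits

    # Toy usage: with a penalty of 1.0, blank (id 0) stops being the argmax.
    logits = torch.tensor([[2.0, 1.5, 0.1]])
    print(apply_blank_penalty(logits, blank_penalty=1.0).argmax(dim=-1))
    # tensor([1])

Because the penalty is subtracted from the raw logits before any softmax or
argmax, the same value biases each decoding method this patch touches
(greedy, beam, and the fast_beam_search variants) toward non-blank emissions
in a consistent way.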