From 7225965a4ad199baf890ee3ef213f95ed4ffc202 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 12 Jul 2024 00:01:04 +0000 Subject: [PATCH] add blank penalty --- egs/multi_zh-hans/ASR/zipformer/decode.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/egs/multi_zh-hans/ASR/zipformer/decode.py b/egs/multi_zh-hans/ASR/zipformer/decode.py index 5993f243f..a1d018cd2 100755 --- a/egs/multi_zh-hans/ASR/zipformer/decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/decode.py @@ -303,6 +303,17 @@ def get_parser(): fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) add_model_arguments(parser) return parser @@ -431,6 +442,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -455,6 +467,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, max_sym_per_frame=params.max_sym_per_frame, + blank_penalty=params.blank_penalty, ) elif params.decoding_method == "beam_search": hyp = beam_search( @@ -468,8 +481,9 @@ def decode_one_batch( ) hyps.append(sp.decode(hyp).split()) + key = f"blank_penalty_{params.blank_penalty}" if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} + return {"greedy_search_" + key: hyps} elif "fast_beam_search" in params.decoding_method: key = f"beam_{params.beam}_" key += f"max_contexts_{params.max_contexts}_" @@ -657,6 +671,7 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-blank-penalty-{params.blank_penalty}" if params.use_averaged_model: params.suffix += "-use-averaged-model"