From 39a02f7c30e82bdb16452dc3ab1e686830cd84c0 Mon Sep 17 00:00:00 2001
From: jinzr
Date: Fri, 17 Nov 2023 17:06:23 +0800
Subject: [PATCH] added blank penalty

---
 egs/multi_zh-hans/ASR/zipformer/decode.py | 24 +++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/egs/multi_zh-hans/ASR/zipformer/decode.py b/egs/multi_zh-hans/ASR/zipformer/decode.py
index 2d3510fc1..89e3dfa98 100755
--- a/egs/multi_zh-hans/ASR/zipformer/decode.py
+++ b/egs/multi_zh-hans/ASR/zipformer/decode.py
@@ -310,6 +310,18 @@ def get_parser():
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )
 
+    parser.add_argument(
+        "--blank-penalty",
+        type=float,
+        default=0.0,
+        help="""
+        The penalty applied on blank symbol during decoding.
+        Note: It is a positive value that would be applied to logits like
+        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
+        [batch_size, vocab] and blank id is 0).
+        """,
+    )
+
     parser.add_argument(
         "--use-shallow-fusion",
         type=str2bool,
@@ -460,6 +472,7 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -474,6 +487,7 @@
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in hyp_tokens:
             hyps.append([word_table[i] for i in hyp])
@@ -488,6 +502,7 @@
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -503,6 +518,7 @@
             num_paths=params.num_paths,
             ref_texts=sp.encode(supervisions["text"]),
             nbest_scale=params.nbest_scale,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -511,6 +527,7 @@
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -521,6 +538,7 @@
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
             context_graph=context_graph,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -531,6 +549,7 @@
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
             LM=LM,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -544,6 +563,7 @@
             LODR_lm_scale=ngram_lm_scale,
             LM=LM,
             context_graph=context_graph,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -556,6 +576,7 @@
             beam=params.beam_size,
             LM=LM,
             lm_scale_list=lm_scale_list,
+            blank_penalty=params.blank_penalty,
         )
     elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
         lm_scale_list = [0.02 * i for i in range(2, 30)]
@@ -568,6 +589,7 @@
             LODR_lm=ngram_lm,
             sp=sp,
             lm_scale_list=lm_scale_list,
+            blank_penalty=params.blank_penalty,
         )
     else:
         batch_size = encoder_out.size(0)
@@ -581,12 +603,14 @@
                     model=model,
                     encoder_out=encoder_out_i,
                     max_sym_per_frame=params.max_sym_per_frame,
+                    blank_penalty=params.blank_penalty,
                 )
             elif params.decoding_method == "beam_search":
                 hyp = beam_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
+                    blank_penalty=params.blank_penalty,
                 )
             else:
                 raise ValueError(
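
Note for reviewers: below is a minimal sketch of the operation the --blank-penalty
help text describes, assuming logits of shape [batch_size, vocab] with blank id 0
as stated above. The helper name `apply_blank_penalty` is hypothetical and not
part of icefall's API; in this patch the penalty is applied inside the decoding
functions (greedy_search, beam_search, fast_beam_search_*, etc.) that
`blank_penalty` is threaded into.

    import torch

    def apply_blank_penalty(logits: torch.Tensor, blank_penalty: float) -> torch.Tensor:
        """Subtract a positive penalty from the blank logit (id 0) so the
        decoder emits blank less often. Hypothetical helper for illustration;
        mirrors the `logits[:, 0] -= blank_penalty` note in the help text."""
        if blank_penalty != 0.0:
            logits = logits.clone()        # avoid mutating the caller's tensor
            logits[:, 0] -= blank_penalty  # penalize only the blank symbol
        return logits

    # Example: penalize blank by 0.5 for a batch of 4 over a 500-token vocab.
    logits = torch.randn(4, 500)
    penalized = apply_blank_penalty(logits, blank_penalty=0.5)
    assert torch.equal(penalized[:, 1:], logits[:, 1:])  # non-blank logits unchanged

With the patch applied, the flag is passed like any other decoding option,
e.g. `./zipformer/decode.py --decoding-method greedy_search --blank-penalty 0.5 ...`
(the value 0.5 is illustrative; the default of 0.0 keeps current behavior).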