added blank penalty

jinzr 2023-11-17 17:06:23 +08:00
parent b6bcd4dcf4
commit 39a02f7c30


@@ -310,6 +310,18 @@ def get_parser():
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )

+    parser.add_argument(
+        "--blank-penalty",
+        type=float,
+        default=0.0,
+        help="""
+        The penalty applied on blank symbol during decoding.
+        Note: It is a positive value that would be applied to logits like
+        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
+        [batch_size, vocab] and blank id is 0).
+        """,
+    )
+
     parser.add_argument(
         "--use-shallow-fusion",
         type=str2bool,
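
The help text above describes the whole mechanism: a positive constant is subtracted from the blank logit before the next symbol is chosen, which discourages blank emissions (and hence deletion errors). A minimal, self-contained sketch of that step, assuming the [batch_size, vocab] logits layout and blank id 0 stated in the help string (`apply_blank_penalty` is an illustrative name, not a function from this commit):

import torch

def apply_blank_penalty(
    logits: torch.Tensor, blank_penalty: float, blank_id: int = 0
) -> torch.Tensor:
    # Equivalent to the help string's `logits[:, 0] -= blank_penalty`,
    # generalized to an arbitrary blank id and kept side-effect free.
    if blank_penalty != 0.0:
        logits = logits.clone()
        logits[:, blank_id] -= blank_penalty
    return logits

# With the penalty applied, blank wins the argmax less often:
logits = torch.tensor([[2.0, 1.0, 1.9]])            # blank (id 0) barely wins
print(apply_blank_penalty(logits, 0.0).argmax(-1))  # tensor([0]): emit blank
print(apply_blank_penalty(logits, 0.5).argmax(-1))  # tensor([2]): emit symbol 2

Since the default is 0.0 and the subtraction is skipped in that case, existing decoding runs are unaffected unless the flag is set.
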
@@ -460,6 +472,7 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -474,6 +487,7 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in hyp_tokens:
             hyps.append([word_table[i] for i in hyp])
@@ -488,6 +502,7 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -503,6 +518,7 @@ def decode_one_batch(
             num_paths=params.num_paths,
             ref_texts=sp.encode(supervisions["text"]),
             nbest_scale=params.nbest_scale,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -511,6 +527,7 @@ def decode_one_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -521,6 +538,7 @@ def decode_one_batch(
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
             context_graph=context_graph,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -531,6 +549,7 @@ def decode_one_batch(
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
             LM=LM,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -544,6 +563,7 @@ def decode_one_batch(
             LODR_lm_scale=ngram_lm_scale,
             LM=LM,
             context_graph=context_graph,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -556,6 +576,7 @@ def decode_one_batch(
             beam=params.beam_size,
             LM=LM,
             lm_scale_list=lm_scale_list,
+            blank_penalty=params.blank_penalty,
         )
     elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
         lm_scale_list = [0.02 * i for i in range(2, 30)]
@@ -568,6 +589,7 @@ def decode_one_batch(
             LODR_lm=ngram_lm,
             sp=sp,
             lm_scale_list=lm_scale_list,
+            blank_penalty=params.blank_penalty,
         )
     else:
         batch_size = encoder_out.size(0)
@@ -581,12 +603,14 @@ def decode_one_batch(
                     model=model,
                     encoder_out=encoder_out_i,
                     max_sym_per_frame=params.max_sym_per_frame,
+                    blank_penalty=params.blank_penalty,
                 )
             elif params.decoding_method == "beam_search":
                 hyp = beam_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
+                    blank_penalty=params.blank_penalty,
                 )
             else:
                 raise ValueError(
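
The remaining hunks are plumbing: every decoding branch in decode_one_batch (the fast beam search variants, the modified beam search variants with and without LM fusion and LODR, greedy search, and beam search) now forwards params.blank_penalty, so the flag takes effect whatever --decoding-method is chosen. A hypothetical invocation; the script name and every flag other than --blank-penalty are illustrative, not taken from this commit:

python decode.py --decoding-method beam_search --blank-penalty 1.5
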