add blank penalty in decoding script

This commit is contained in:
Triplecq 2024-01-23 17:10:10 -05:00
parent a8e9dc2488
commit f35fa8aa8f

View File

@ -370,6 +370,19 @@ def get_parser():
modified_beam_search_LODR.
""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
add_model_arguments(parser)
return parser
@ -457,6 +470,7 @@ def decode_one_batch(
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(sp.text2word(hyp))
@ -471,6 +485,7 @@ def decode_one_batch(
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
@ -485,6 +500,7 @@ def decode_one_batch(
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(sp.text2word(hyp))
@ -500,6 +516,7 @@ def decode_one_batch(
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(sp.text2word(hyp))
@ -508,6 +525,7 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(sp.text2word(hyp))
@ -518,6 +536,7 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
context_graph=context_graph,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(sp.text2word(hyp))
@ -591,6 +610,7 @@ def decode_one_batch(
)
hyps.append(sp.text2word(sp.decode(hyp)))
key = f"blank_penalty_{params.blank_penalty}"
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
@ -827,6 +847,8 @@ def main():
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
)
params.suffix += f"-blank-penalty-{params.blank_penalty}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"