mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
add blank penalty in decoding script
This commit is contained in:
parent
a8e9dc2488
commit
f35fa8aa8f
@ -370,6 +370,19 @@ def get_parser():
|
|||||||
modified_beam_search_LODR.
|
modified_beam_search_LODR.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--blank-penalty",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="""
|
||||||
|
The penalty applied on blank symbol during decoding.
|
||||||
|
Note: It is a positive value that would be applied to logits like
|
||||||
|
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||||
|
[batch_size, vocab] and blank id is 0).
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -457,6 +470,7 @@ def decode_one_batch(
|
|||||||
beam=params.beam,
|
beam=params.beam,
|
||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(sp.text2word(hyp))
|
hyps.append(sp.text2word(hyp))
|
||||||
@ -471,6 +485,7 @@ def decode_one_batch(
|
|||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
num_paths=params.num_paths,
|
num_paths=params.num_paths,
|
||||||
nbest_scale=params.nbest_scale,
|
nbest_scale=params.nbest_scale,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
for hyp in hyp_tokens:
|
for hyp in hyp_tokens:
|
||||||
hyps.append([word_table[i] for i in hyp])
|
hyps.append([word_table[i] for i in hyp])
|
||||||
@ -485,6 +500,7 @@ def decode_one_batch(
|
|||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
num_paths=params.num_paths,
|
num_paths=params.num_paths,
|
||||||
nbest_scale=params.nbest_scale,
|
nbest_scale=params.nbest_scale,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(sp.text2word(hyp))
|
hyps.append(sp.text2word(hyp))
|
||||||
@ -500,6 +516,7 @@ def decode_one_batch(
|
|||||||
num_paths=params.num_paths,
|
num_paths=params.num_paths,
|
||||||
ref_texts=sp.encode(supervisions["text"]),
|
ref_texts=sp.encode(supervisions["text"]),
|
||||||
nbest_scale=params.nbest_scale,
|
nbest_scale=params.nbest_scale,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(sp.text2word(hyp))
|
hyps.append(sp.text2word(hyp))
|
||||||
@ -508,6 +525,7 @@ def decode_one_batch(
|
|||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(sp.text2word(hyp))
|
hyps.append(sp.text2word(hyp))
|
||||||
@ -518,6 +536,7 @@ def decode_one_batch(
|
|||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
context_graph=context_graph,
|
context_graph=context_graph,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(sp.text2word(hyp))
|
hyps.append(sp.text2word(hyp))
|
||||||
@ -591,6 +610,7 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
hyps.append(sp.text2word(sp.decode(hyp)))
|
hyps.append(sp.text2word(sp.decode(hyp)))
|
||||||
|
|
||||||
|
key = f"blank_penalty_{params.blank_penalty}"
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
return {"greedy_search": hyps}
|
return {"greedy_search": hyps}
|
||||||
elif "fast_beam_search" in params.decoding_method:
|
elif "fast_beam_search" in params.decoding_method:
|
||||||
@ -827,6 +847,8 @@ def main():
|
|||||||
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
|
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
params.suffix += f"-blank-penalty-{params.blank_penalty}"
|
||||||
|
|
||||||
if params.use_averaged_model:
|
if params.use_averaged_model:
|
||||||
params.suffix += "-use-averaged-model"
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user