added blank penalty

Author: jinzr
Date: 2023-11-17 17:06:23 +08:00
Parent: b6bcd4dcf4
Commit: 39a02f7c30


@@ -310,6 +310,18 @@ def get_parser():
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied to the blank symbol during decoding.
Note: It is a positive value subtracted from the blank logit,
i.e. `logits[:, 0] -= blank_penalty` (assuming logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
parser.add_argument(
"--use-shallow-fusion",
type=str2bool,
@@ -460,6 +472,7 @@ def decode_one_batch(
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@@ -474,6 +487,7 @@ def decode_one_batch(
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
@@ -488,6 +502,7 @@ def decode_one_batch(
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@@ -503,6 +518,7 @@ def decode_one_batch(
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@@ -511,6 +527,7 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@@ -521,6 +538,7 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
context_graph=context_graph,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@@ -531,6 +549,7 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LM=LM,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@@ -544,6 +563,7 @@ def decode_one_batch(
LODR_lm_scale=ngram_lm_scale,
LM=LM,
context_graph=context_graph,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@@ -556,6 +576,7 @@ def decode_one_batch(
beam=params.beam_size,
LM=LM,
lm_scale_list=lm_scale_list,
blank_penalty=params.blank_penalty,
)
elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
lm_scale_list = [0.02 * i for i in range(2, 30)]
@@ -568,6 +589,7 @@ def decode_one_batch(
LODR_lm=ngram_lm,
sp=sp,
lm_scale_list=lm_scale_list,
blank_penalty=params.blank_penalty,
)
else:
batch_size = encoder_out.size(0)
@@ -581,12 +603,14 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
blank_penalty=params.blank_penalty,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
blank_penalty=params.blank_penalty,
)
else:
raise ValueError(
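
For reference, a minimal sketch of what the new flag does at the decoding step, following the subtraction described in the --blank-penalty help text above. The helper name apply_blank_penalty is illustrative (not from the commit), and it assumes logits of shape [batch_size, vocab] with blank id 0; the commit itself performs the equivalent subtraction inside each search routine.

import torch

def apply_blank_penalty(logits: torch.Tensor, blank_penalty: float) -> torch.Tensor:
    """Subtract a constant from the blank logit before the argmax/softmax
    step of decoding, making blank emissions less likely.

    Assumes logits has shape [batch_size, vocab] and blank id is 0,
    as stated in the --blank-penalty help text.
    """
    if blank_penalty != 0.0:
        # Clone so callers that reuse the raw logits are unaffected.
        logits = logits.clone()
        logits[:, 0] -= blank_penalty
    return logits

# Example: a positive penalty breaks a tie between blank and a real token
# in favor of the real token.
logits = torch.tensor([[2.0, 2.0, 1.0]])  # blank logit ties with token 1
print(apply_blank_penalty(logits, blank_penalty=0.5).argmax(dim=-1))  # tensor([1])

With the default of 0.0 the subtraction is a no-op and decoding is unchanged, which is why every call site in the diff can pass params.blank_penalty unconditionally.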