mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 16:14:17 +00:00
added blank penalty
This commit is contained in:
parent
b6bcd4dcf4
commit
39a02f7c30
@ -310,6 +310,18 @@ def get_parser():
|
|||||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--blank-penalty",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="""
|
||||||
|
The penalty applied on blank symbol during decoding.
|
||||||
|
Note: It is a positive value that would be applied to logits like
|
||||||
|
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||||
|
[batch_size, vocab] and blank id is 0).
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-shallow-fusion",
|
"--use-shallow-fusion",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -460,6 +472,7 @@ def decode_one_batch(
|
|||||||
beam=params.beam,
|
beam=params.beam,
|
||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -474,6 +487,7 @@ def decode_one_batch(
|
|||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
num_paths=params.num_paths,
|
num_paths=params.num_paths,
|
||||||
nbest_scale=params.nbest_scale,
|
nbest_scale=params.nbest_scale,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
for hyp in hyp_tokens:
|
for hyp in hyp_tokens:
|
||||||
hyps.append([word_table[i] for i in hyp])
|
hyps.append([word_table[i] for i in hyp])
|
||||||
@ -488,6 +502,7 @@ def decode_one_batch(
|
|||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
num_paths=params.num_paths,
|
num_paths=params.num_paths,
|
||||||
nbest_scale=params.nbest_scale,
|
nbest_scale=params.nbest_scale,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -503,6 +518,7 @@ def decode_one_batch(
|
|||||||
num_paths=params.num_paths,
|
num_paths=params.num_paths,
|
||||||
ref_texts=sp.encode(supervisions["text"]),
|
ref_texts=sp.encode(supervisions["text"]),
|
||||||
nbest_scale=params.nbest_scale,
|
nbest_scale=params.nbest_scale,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -511,6 +527,7 @@ def decode_one_batch(
|
|||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -521,6 +538,7 @@ def decode_one_batch(
|
|||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
context_graph=context_graph,
|
context_graph=context_graph,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -531,6 +549,7 @@ def decode_one_batch(
|
|||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
LM=LM,
|
LM=LM,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -544,6 +563,7 @@ def decode_one_batch(
|
|||||||
LODR_lm_scale=ngram_lm_scale,
|
LODR_lm_scale=ngram_lm_scale,
|
||||||
LM=LM,
|
LM=LM,
|
||||||
context_graph=context_graph,
|
context_graph=context_graph,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -556,6 +576,7 @@ def decode_one_batch(
|
|||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
LM=LM,
|
LM=LM,
|
||||||
lm_scale_list=lm_scale_list,
|
lm_scale_list=lm_scale_list,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
|
elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
|
||||||
lm_scale_list = [0.02 * i for i in range(2, 30)]
|
lm_scale_list = [0.02 * i for i in range(2, 30)]
|
||||||
@ -568,6 +589,7 @@ def decode_one_batch(
|
|||||||
LODR_lm=ngram_lm,
|
LODR_lm=ngram_lm,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
lm_scale_list=lm_scale_list,
|
lm_scale_list=lm_scale_list,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
batch_size = encoder_out.size(0)
|
batch_size = encoder_out.size(0)
|
||||||
@ -581,12 +603,14 @@ def decode_one_batch(
|
|||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out_i,
|
encoder_out=encoder_out_i,
|
||||||
max_sym_per_frame=params.max_sym_per_frame,
|
max_sym_per_frame=params.max_sym_per_frame,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
elif params.decoding_method == "beam_search":
|
elif params.decoding_method == "beam_search":
|
||||||
hyp = beam_search(
|
hyp = beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out_i,
|
encoder_out=encoder_out_i,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user