Add blank-penalty to other decoding methods

This commit is contained in:
pkufool 2023-05-25 12:20:29 +08:00
parent 961750c0a9
commit 899f858659
3 changed files with 45 additions and 6 deletions

View File

@ -49,6 +49,7 @@ def fast_beam_search_one_best(
temperature: float = 1.0, temperature: float = 1.0,
subtract_ilme: bool = False, subtract_ilme: bool = False,
ilme_scale: float = 0.1, ilme_scale: float = 0.1,
blank_penalty: float = 0.0,
return_timestamps: bool = False, return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]: ) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@ -92,6 +93,7 @@ def fast_beam_search_one_best(
temperature=temperature, temperature=temperature,
subtract_ilme=subtract_ilme, subtract_ilme=subtract_ilme,
ilme_scale=ilme_scale, ilme_scale=ilme_scale,
blank_penalty=blank_penalty,
) )
best_path = one_best_decoding(lattice) best_path = one_best_decoding(lattice)
@ -114,6 +116,7 @@ def fast_beam_search_nbest_LG(
nbest_scale: float = 0.5, nbest_scale: float = 0.5,
use_double_scores: bool = True, use_double_scores: bool = True,
temperature: float = 1.0, temperature: float = 1.0,
blank_penalty: float = 0.0,
return_timestamps: bool = False, return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]: ) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@ -168,6 +171,7 @@ def fast_beam_search_nbest_LG(
max_states=max_states, max_states=max_states,
max_contexts=max_contexts, max_contexts=max_contexts,
temperature=temperature, temperature=temperature,
blank_penalty=blank_penalty,
) )
nbest = Nbest.from_lattice( nbest = Nbest.from_lattice(
@ -240,6 +244,7 @@ def fast_beam_search_nbest(
nbest_scale: float = 0.5, nbest_scale: float = 0.5,
use_double_scores: bool = True, use_double_scores: bool = True,
temperature: float = 1.0, temperature: float = 1.0,
blank_penalty: float = 0.0,
return_timestamps: bool = False, return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]: ) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@ -293,6 +298,7 @@ def fast_beam_search_nbest(
beam=beam, beam=beam,
max_states=max_states, max_states=max_states,
max_contexts=max_contexts, max_contexts=max_contexts,
blank_penalty=blank_penalty,
temperature=temperature, temperature=temperature,
) )
@ -331,6 +337,7 @@ def fast_beam_search_nbest_oracle(
use_double_scores: bool = True, use_double_scores: bool = True,
nbest_scale: float = 0.5, nbest_scale: float = 0.5,
temperature: float = 1.0, temperature: float = 1.0,
blank_penalty: float = 0.0,
return_timestamps: bool = False, return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]: ) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@ -389,6 +396,7 @@ def fast_beam_search_nbest_oracle(
max_states=max_states, max_states=max_states,
max_contexts=max_contexts, max_contexts=max_contexts,
temperature=temperature, temperature=temperature,
blank_penalty=blank_penalty,
) )
nbest = Nbest.from_lattice( nbest = Nbest.from_lattice(
@ -434,6 +442,7 @@ def fast_beam_search(
temperature: float = 1.0, temperature: float = 1.0,
subtract_ilme: bool = False, subtract_ilme: bool = False,
ilme_scale: float = 0.1, ilme_scale: float = 0.1,
blank_penalty: float = 0.0,
) -> k2.Fsa: ) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@ -503,6 +512,8 @@ def fast_beam_search(
project_input=False, project_input=False,
) )
logits = logits.squeeze(1).squeeze(1) logits = logits.squeeze(1).squeeze(1)
if blank_penalty != 0:
logits[:, 0] -= blank_penalty
log_probs = (logits / temperature).log_softmax(dim=-1) log_probs = (logits / temperature).log_softmax(dim=-1)
if subtract_ilme: if subtract_ilme:
ilme_logits = model.joiner( ilme_logits = model.joiner(
@ -526,6 +537,7 @@ def greedy_search(
model: Transducer, model: Transducer,
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
max_sym_per_frame: int, max_sym_per_frame: int,
blank_penalty: float = 0.0,
return_timestamps: bool = False, return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]: ) -> Union[List[int], DecodingResults]:
"""Greedy search for a single utterance. """Greedy search for a single utterance.
@ -595,6 +607,9 @@ def greedy_search(
) )
# logits is (1, 1, 1, vocab_size) # logits is (1, 1, 1, vocab_size)
if blank_penalty != 0:
logits[:,:,:,0] -= blank_penalty
y = logits.argmax().item() y = logits.argmax().item()
if y not in (blank_id, unk_id): if y not in (blank_id, unk_id):
hyp.append(y) hyp.append(y)
@ -704,7 +719,10 @@ def greedy_search_batch(
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape assert logits.ndim == 2, logits.shape
logits[:, 0] -= blank_penalty
if blank_penalty != 0:
logits[:, 0] -= blank_penalty
y = logits.argmax(dim=1).tolist() y = logits.argmax(dim=1).tolist()
emitted = False emitted = False
for i, v in enumerate(y): for i, v in enumerate(y):
@ -921,6 +939,7 @@ def modified_beam_search(
encoder_out_lens: torch.Tensor, encoder_out_lens: torch.Tensor,
beam: int = 4, beam: int = 4,
temperature: float = 1.0, temperature: float = 1.0,
blank_penalty: float = 0.0,
return_timestamps: bool = False, return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]: ) -> Union[List[List[int]], DecodingResults]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@ -1024,6 +1043,9 @@ def modified_beam_search(
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
if blank_penalty != 0:
logits[:, 0] -= blank_penalty
log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs) log_probs.add_(ys_log_probs)
@ -1628,6 +1650,7 @@ def beam_search(
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
beam: int = 4, beam: int = 4,
temperature: float = 1.0, temperature: float = 1.0,
blank_penalty: float = 0.0,
return_timestamps: bool = False, return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]: ) -> Union[List[int], DecodingResults]:
""" """
@ -1724,6 +1747,9 @@ def beam_search(
project_input=False, project_input=False,
) )
if blank_penalty != 0:
logits[:,:,:,0] -= blank_penalty
# TODO(fangjun): Scale the blank posterior # TODO(fangjun): Scale the blank posterior
log_prob = (logits / temperature).log_softmax(dim=-1) log_prob = (logits / temperature).log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size) # log_prob is (1, 1, 1, vocab_size)

View File

@ -307,7 +307,12 @@ def get_parser():
"--blank-penalty", "--blank-penalty",
type=float, type=float,
default=0.0, default=0.0,
help="", help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
) )
add_model_arguments(parser) add_model_arguments(parser)
@ -373,6 +378,7 @@ def decode_one_batch(
beam=params.beam, beam=params.beam,
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
blank_penalty=params.blank_penalty,
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@ -387,6 +393,7 @@ def decode_one_batch(
max_states=params.max_states, max_states=params.max_states,
num_paths=params.num_paths, num_paths=params.num_paths,
nbest_scale=params.nbest_scale, nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
) )
for hyp in hyp_tokens: for hyp in hyp_tokens:
sentence = "".join([lexicon.word_table[i] for i in hyp]) sentence = "".join([lexicon.word_table[i] for i in hyp])
@ -402,6 +409,7 @@ def decode_one_batch(
max_states=params.max_states, max_states=params.max_states,
num_paths=params.num_paths, num_paths=params.num_paths,
nbest_scale=params.nbest_scale, nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@ -417,6 +425,7 @@ def decode_one_batch(
num_paths=params.num_paths, num_paths=params.num_paths,
ref_texts=graph_compiler.texts_to_ids(supervisions["text"]), ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
nbest_scale=params.nbest_scale, nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@ -431,6 +440,7 @@ def decode_one_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
blank_penalty=params.blank_penalty,
beam=params.beam_size, beam=params.beam_size,
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
@ -447,10 +457,12 @@ def decode_one_batch(
model=model, model=model,
encoder_out=encoder_out_i, encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame, max_sym_per_frame=params.max_sym_per_frame,
blank_penalty=params.blank_penalty,
) )
elif params.decoding_method == "beam_search": elif params.decoding_method == "beam_search":
hyp = beam_search( hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size, model=model, encoder_out=encoder_out_i, beam=params.beam_size,
blank_penalty=params.blank_penalty,
) )
else: else:
raise ValueError( raise ValueError(
@ -458,10 +470,11 @@ def decode_one_batch(
) )
hyps.append([lexicon.token_table[idx] for idx in hyp]) hyps.append([lexicon.token_table[idx] for idx in hyp])
key = f"blank_penalty_{params.blank_penalty}"
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
return {"greedy_search": hyps} return {"greedy_search_" + key: hyps}
elif "fast_beam_search" in params.decoding_method: elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_" key += f"_beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_" key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}" key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method: if "nbest" in params.decoding_method:
@ -472,7 +485,7 @@ def decode_one_batch(
return {key: hyps} return {key: hyps}
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}_" + key: hyps}
def decode_dataset( def decode_dataset(

View File

@ -260,7 +260,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--lr-epochs", "--lr-epochs",
type=float, type=float,
default=3.5, default=1.5,
help="""Number of epochs that affects how rapidly the learning rate decreases. help="""Number of epochs that affects how rapidly the learning rate decreases.
""", """,
) )