mirror of https://github.com/k2-fsa/icefall.git
Add blank-penalty to other decoding methods
commit 899f858659
parent 961750c0a9
@@ -49,6 +49,7 @@ def fast_beam_search_one_best(
     temperature: float = 1.0,
     subtract_ilme: bool = False,
     ilme_scale: float = 0.1,
+    blank_penalty: float = 0.0,
     return_timestamps: bool = False,
 ) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.
@@ -92,6 +93,7 @@ def fast_beam_search_one_best(
         temperature=temperature,
         subtract_ilme=subtract_ilme,
         ilme_scale=ilme_scale,
+        blank_penalty=blank_penalty,
     )
 
     best_path = one_best_decoding(lattice)
@@ -114,6 +116,7 @@ def fast_beam_search_nbest_LG(
     nbest_scale: float = 0.5,
     use_double_scores: bool = True,
     temperature: float = 1.0,
+    blank_penalty: float = 0.0,
     return_timestamps: bool = False,
 ) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.
@@ -168,6 +171,7 @@ def fast_beam_search_nbest_LG(
         max_states=max_states,
         max_contexts=max_contexts,
         temperature=temperature,
+        blank_penalty=blank_penalty,
     )
 
     nbest = Nbest.from_lattice(
@@ -240,6 +244,7 @@ def fast_beam_search_nbest(
     nbest_scale: float = 0.5,
     use_double_scores: bool = True,
     temperature: float = 1.0,
+    blank_penalty: float = 0.0,
     return_timestamps: bool = False,
 ) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.
@@ -293,6 +298,7 @@ def fast_beam_search_nbest(
         beam=beam,
         max_states=max_states,
         max_contexts=max_contexts,
+        blank_penalty=blank_penalty,
         temperature=temperature,
     )
 
@@ -331,6 +337,7 @@ def fast_beam_search_nbest_oracle(
     use_double_scores: bool = True,
     nbest_scale: float = 0.5,
     temperature: float = 1.0,
+    blank_penalty: float = 0.0,
     return_timestamps: bool = False,
 ) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.
@@ -389,6 +396,7 @@ def fast_beam_search_nbest_oracle(
         max_states=max_states,
         max_contexts=max_contexts,
         temperature=temperature,
+        blank_penalty=blank_penalty,
     )
 
     nbest = Nbest.from_lattice(
@@ -434,6 +442,7 @@ def fast_beam_search(
     temperature: float = 1.0,
     subtract_ilme: bool = False,
     ilme_scale: float = 0.1,
+    blank_penalty: float = 0.0,
 ) -> k2.Fsa:
     """It limits the maximum number of symbols per frame to 1.
 
@@ -503,6 +512,8 @@ def fast_beam_search(
             project_input=False,
         )
         logits = logits.squeeze(1).squeeze(1)
+        if blank_penalty != 0:
+            logits[:, 0] -= blank_penalty
         log_probs = (logits / temperature).log_softmax(dim=-1)
         if subtract_ilme:
             ilme_logits = model.joiner(
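For context, the heart of this change is the two lines added to fast_beam_search: a positive blank_penalty is subtracted from the blank logit (index 0) before the temperature-scaled log_softmax, lowering the blank posterior so the decoder emits non-blank symbols more readily. A minimal standalone sketch of the same operation (illustrative shapes and values, not part of the commit):

    import torch

    vocab_size = 5
    blank_penalty = 2.0
    temperature = 1.0

    # Dummy joiner output: one row of logits per frame, column 0 = blank.
    logits = torch.randn(3, vocab_size)
    if blank_penalty != 0:
        logits[:, 0] -= blank_penalty  # push the blank logit down

    log_probs = (logits / temperature).log_softmax(dim=-1)
    print(log_probs[:, 0])  # blank log-probs, reduced by the penalty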
@@ -526,6 +537,7 @@ def greedy_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     max_sym_per_frame: int,
+    blank_penalty: float = 0.0,
     return_timestamps: bool = False,
 ) -> Union[List[int], DecodingResults]:
     """Greedy search for a single utterance.
@@ -595,6 +607,9 @@ def greedy_search(
             )
             # logits is (1, 1, 1, vocab_size)
 
+            if blank_penalty != 0:
+                logits[:,:,:,0] -= blank_penalty
+
             y = logits.argmax().item()
             if y not in (blank_id, unk_id):
                 hyp.append(y)
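The same penalty shows up here on a 4-D tensor: in the per-utterance greedy loop the joiner output keeps its (1, 1, 1, vocab_size) shape, so the blank column is indexed as [:, :, :, 0]. A quick sketch (dummy tensor, assuming blank id 0 as in the commit):

    import torch

    logits = torch.randn(1, 1, 1, 10)  # (1, 1, 1, vocab_size)
    blank_penalty = 2.0
    if blank_penalty != 0:
        logits[:, :, :, 0] -= blank_penalty
    y = logits.argmax().item()  # argmax over the flattened tensor, as in greedy_search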
@@ -704,7 +719,10 @@ def greedy_search_batch(
 
         logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
         assert logits.ndim == 2, logits.shape
 
+        if blank_penalty != 0:
+            logits[:, 0] -= blank_penalty
+
         y = logits.argmax(dim=1).tolist()
         emitted = False
         for i, v in enumerate(y):
@@ -921,6 +939,7 @@ def modified_beam_search(
     encoder_out_lens: torch.Tensor,
     beam: int = 4,
     temperature: float = 1.0,
+    blank_penalty: float = 0.0,
     return_timestamps: bool = False,
 ) -> Union[List[List[int]], DecodingResults]:
     """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@@ -1024,6 +1043,9 @@ def modified_beam_search(
 
         logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)
 
+        if blank_penalty != 0:
+            logits[:, 0] -= blank_penalty
+
         log_probs = (logits / temperature).log_softmax(dim=-1)  # (num_hyps, vocab_size)
 
         log_probs.add_(ys_log_probs)
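One subtlety worth noting: the penalty is applied to the raw logits before the division by temperature, so the effective shift of the blank score inside the softmax is blank_penalty / temperature. A small check of that equivalence (illustrative values only):

    import torch

    logits = torch.tensor([[1.0, 0.5, -0.2]])  # column 0 = blank (assumed)
    blank_penalty, temperature = 1.0, 2.0

    penalized = logits.clone()
    penalized[:, 0] -= blank_penalty
    a = (penalized / temperature).log_softmax(dim=-1)

    shift = torch.tensor([blank_penalty / temperature, 0.0, 0.0])
    b = (logits / temperature - shift).log_softmax(dim=-1)
    print(torch.allclose(a, b))  # True: effective shift is blank_penalty / temperature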
@@ -1628,6 +1650,7 @@ def beam_search(
     encoder_out: torch.Tensor,
     beam: int = 4,
     temperature: float = 1.0,
+    blank_penalty: float = 0.0,
     return_timestamps: bool = False,
 ) -> Union[List[int], DecodingResults]:
     """
@@ -1724,6 +1747,9 @@ def beam_search(
                 project_input=False,
             )
 
+            if blank_penalty != 0:
+                logits[:,:,:,0] -= blank_penalty
+
             # TODO(fangjun): Scale the blank posterior
             log_prob = (logits / temperature).log_softmax(dim=-1)
             # log_prob is (1, 1, 1, vocab_size)
@@ -307,7 +307,12 @@ def get_parser():
         "--blank-penalty",
         type=float,
         default=0.0,
-        help="",
+        help="""
+        The penalty applied on blank symbol during decoding.
+        Note: It is a positive value that would be applied to logits like
+        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
+        [batch_size, vocab] and blank id is 0).
+        """,
     )
 
     add_model_arguments(parser)
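As the new help string says, the flag is a positive value subtracted from the blank logit. A worked example of its effect on the blank posterior (numbers are illustrative):

    import torch

    logits = torch.tensor([[2.0, 1.0]])  # index 0 = blank
    print(logits.softmax(dim=-1))        # blank prob ~ 0.73
    logits[:, 0] -= 1.0                  # --blank-penalty 1.0
    print(logits.softmax(dim=-1))        # blank prob drops to 0.50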
@@ -373,6 +378,7 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
+            blank_penalty=params.blank_penalty,
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@@ -387,6 +393,7 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in hyp_tokens:
             sentence = "".join([lexicon.word_table[i] for i in hyp])
@@ -402,6 +409,7 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            blank_penalty=params.blank_penalty,
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@@ -417,6 +425,7 @@ def decode_one_batch(
             num_paths=params.num_paths,
             ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
             nbest_scale=params.nbest_scale,
+            blank_penalty=params.blank_penalty,
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@@ -431,6 +440,7 @@ def decode_one_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
+            blank_penalty=params.blank_penalty,
             beam=params.beam_size,
         )
         for i in range(encoder_out.size(0)):
@@ -447,10 +457,12 @@ def decode_one_batch(
                     model=model,
                     encoder_out=encoder_out_i,
                     max_sym_per_frame=params.max_sym_per_frame,
+                    blank_penalty=params.blank_penalty,
                 )
             elif params.decoding_method == "beam_search":
                 hyp = beam_search(
                     model=model, encoder_out=encoder_out_i, beam=params.beam_size,
+                    blank_penalty=params.blank_penalty,
                 )
             else:
                 raise ValueError(
@@ -458,10 +470,11 @@ def decode_one_batch(
                 )
             hyps.append([lexicon.token_table[idx] for idx in hyp])
 
+    key = f"blank_penalty_{params.blank_penalty}"
     if params.decoding_method == "greedy_search":
-        return {"greedy_search": hyps}
+        return {"greedy_search_" + key: hyps}
     elif "fast_beam_search" in params.decoding_method:
-        key = f"beam_{params.beam}_"
+        key += f"_beam_{params.beam}_"
         key += f"max_contexts_{params.max_contexts}_"
         key += f"max_states_{params.max_states}"
         if "nbest" in params.decoding_method:
@@ -472,7 +485,7 @@ def decode_one_batch(
 
         return {key: hyps}
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}_" + key: hyps}
 
 
 def decode_dataset(
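A side effect of the key changes above: results decoded with different penalties no longer overwrite each other, because the penalty is baked into the result key. A sketch of the keys this code now produces (hypothetical parameter values):

    blank_penalty = 0.5
    beam = 4.0
    max_contexts = 4
    max_states = 8

    key = f"blank_penalty_{blank_penalty}"
    print("greedy_search_" + key)  # greedy_search_blank_penalty_0.5

    key += f"_beam_{beam}_"
    key += f"max_contexts_{max_contexts}_"
    key += f"max_states_{max_states}"
    print(key)  # blank_penalty_0.5_beam_4.0_max_contexts_4_max_states_8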
@@ -260,7 +260,7 @@ def get_parser():
     parser.add_argument(
         "--lr-epochs",
         type=float,
-        default=3.5,
+        default=1.5,
         help="""Number of epochs that affects how rapidly the learning rate decreases.
         """,
    )
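The commit also lowers --lr-epochs from 3.5 to 1.5, which makes the learning rate decay faster across epochs. Assuming the Eden-style schedule icefall recipes typically use (a sketch of the epoch-dependent factor only, not the recipe's exact scheduler code):

    def epoch_factor(epoch: float, lr_epochs: float) -> float:
        # Assumed Eden-style epoch term: decays toward epoch**-0.5 once
        # epoch >> lr_epochs, so a smaller lr_epochs starts the decay earlier.
        return ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25

    for lr_epochs in (3.5, 1.5):
        print(lr_epochs, [round(epoch_factor(e, lr_epochs), 3) for e in (1, 5, 10)])
    # With lr_epochs=1.5 the factor at epoch 10 is ~0.39 vs ~0.58 for 3.5,
    # i.e. the learning rate has decayed further by the same epoch.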