add blank penalty

This commit is contained in:
pkufool 2023-05-24 22:15:20 +08:00
parent 42513d2e98
commit 4f28e15a1d
3 changed files with 28 additions and 10 deletions

View File

@ -626,6 +626,7 @@ def greedy_search_batch(
model: Transducer, model: Transducer,
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor, encoder_out_lens: torch.Tensor,
blank_penalty: float = 0,
return_timestamps: bool = False, return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]: ) -> Union[List[List[int]], DecodingResults]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
@ -701,6 +702,7 @@ def greedy_search_batch(
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape assert logits.ndim == 2, logits.shape
logits[:, 0] -= blank_penalty
y = logits.argmax(dim=1).tolist() y = logits.argmax(dim=1).tolist()
emitted = False emitted = False
for i, v in enumerate(y): for i, v in enumerate(y):

View File

@ -292,7 +292,7 @@ class WenetSpeechAsrDataModule:
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets, num_buckets=self.args.num_buckets,
buffer_size=30000, buffer_size=300000,
drop_last=True, drop_last=True,
) )
else: else:

View File

@ -125,6 +125,7 @@ from beam_search import (
) )
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from lhotse.cut import Cut
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,
@ -302,6 +303,13 @@ def get_parser():
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
) )
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="",
)
add_model_arguments(parser) add_model_arguments(parser)
return parser return parser
@ -414,9 +422,7 @@ def decode_one_batch(
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, blank_penalty=params.blank_penalty,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@ -444,9 +450,7 @@ def decode_one_batch(
) )
elif params.decoding_method == "beam_search": elif params.decoding_method == "beam_search":
hyp = beam_search( hyp = beam_search(
model=model, model=model, encoder_out=encoder_out_i, beam=params.beam_size,
encoder_out=encoder_out_i,
beam=params.beam_size,
) )
else: else:
raise ValueError( raise ValueError(
@ -625,6 +629,7 @@ def main():
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-blank-penalty-{params.blank_penalty}"
if params.use_averaged_model: if params.use_averaged_model:
params.suffix += "-use-averaged-model" params.suffix += "-use-averaged-model"
@ -751,17 +756,30 @@ def main():
args.return_cuts = True args.return_cuts = True
wenetspeech = WenetSpeechAsrDataModule(args) wenetspeech = WenetSpeechAsrDataModule(args)
def remove_short_utt(c: Cut):
T = ((c.num_frames - 7) // 2 + 1) // 2
if T <= 0:
logging.warning(
f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
)
return T > 0
dev_cuts = wenetspeech.valid_cuts() dev_cuts = wenetspeech.valid_cuts()
dev_cuts = dev_cuts.filter(remove_short_utt)
dev_dl = wenetspeech.valid_dataloaders(dev_cuts) dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
test_net_cuts = wenetspeech.test_net_cuts() test_net_cuts = wenetspeech.test_net_cuts()
test_net_cuts = test_net_cuts.filter(remove_short_utt)
test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
test_meeting_cuts = wenetspeech.test_meeting_cuts() test_meeting_cuts = wenetspeech.test_meeting_cuts()
test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt)
test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_dl = [dev_dl, test_net_dl, test_meeting_dl] test_dl = [dev_dl, test_net_dl, test_meeting_dl]
# test_sets = ["TEST_MEETING"]
# test_dl = [test_meeting_dl]
for test_set, test_dl in zip(test_sets, test_dl): for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset( results_dict = decode_dataset(
@ -774,9 +792,7 @@ def main():
) )
save_results( save_results(
params=params, params=params, test_set_name=test_set, results_dict=results_dict,
test_set_name=test_set,
results_dict=results_dict,
) )
logging.info("Done!") logging.info("Done!")