diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index 0280193ca..0c1fd98dc 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -626,6 +626,7 @@ def greedy_search_batch(
     model: Transducer,
     encoder_out: torch.Tensor,
     encoder_out_lens: torch.Tensor,
+    blank_penalty: float = 0,
     return_timestamps: bool = False,
 ) -> Union[List[List[int]], DecodingResults]:
     """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
@@ -701,6 +702,7 @@ def greedy_search_batch(
         logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)

         assert logits.ndim == 2, logits.shape
+        logits[:, 0] -= blank_penalty
         y = logits.argmax(dim=1).tolist()
         emitted = False
         for i, v in enumerate(y):
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index c9e30e737..6e3d04d85 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -292,7 +292,7 @@ class WenetSpeechAsrDataModule:
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
-                buffer_size=30000,
+                buffer_size=300000,
                 drop_last=True,
             )
         else:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless7/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless7/decode.py
index 30c66ac39..231cbdccd 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless7/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless7/decode.py
@@ -125,6 +125,7 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model

+from lhotse.cut import Cut
 from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
@@ -302,6 +303,13 @@ def get_parser():
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )

+    parser.add_argument(
+        "--blank-penalty",
+        type=float,
+        default=0.0,
+        help="The penalty applied on the blank symbol during decoding; a positive value lowers the blank logit and makes emitting blank less likely.",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -414,9 +422,7 @@ def decode_one_batch(
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
-            model=model,
-            encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
+            model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, blank_penalty=params.blank_penalty,
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@@ -444,9 +450,7 @@
                 )
             elif params.decoding_method == "beam_search":
                 hyp = beam_search(
-                    model=model,
-                    encoder_out=encoder_out_i,
-                    beam=params.beam_size,
+                    model=model, encoder_out=encoder_out_i, beam=params.beam_size,
                 )
             else:
                 raise ValueError(
@@ -625,6 +629,7 @@ def main():
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+        params.suffix += f"-blank-penalty-{params.blank_penalty}"

     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
@@ -751,17 +756,30 @@
     args.return_cuts = True
     wenetspeech = WenetSpeechAsrDataModule(args)

+    def remove_short_utt(c: Cut):
+        T = ((c.num_frames - 7) // 2 + 1) // 2
+        if T <= 0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
+            )
+        return T > 0
+
     dev_cuts = wenetspeech.valid_cuts()
+    dev_cuts = dev_cuts.filter(remove_short_utt)
     dev_dl = wenetspeech.valid_dataloaders(dev_cuts)

     test_net_cuts = wenetspeech.test_net_cuts()
+    test_net_cuts = test_net_cuts.filter(remove_short_utt)
     test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)

     test_meeting_cuts = wenetspeech.test_meeting_cuts()
+    test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt)
     test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)

     test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
     test_dl = [dev_dl, test_net_dl, test_meeting_dl]
+    # test_sets = ["TEST_MEETING"]
+    # test_dl = [test_meeting_dl]

     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
@@ -774,9 +792,7 @@
         )

         save_results(
-            params=params,
-            test_set_name=test_set,
-            results_dict=results_dict,
+            params=params, test_set_name=test_set, results_dict=results_dict,
         )

     logging.info("Done!")
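
For context on the two behavioral changes above: the new blank_penalty is subtracted from the blank logit (id 0) before the per-frame argmax in greedy_search_batch, so frames where blank barely wins flip to a non-blank symbol while confident blanks survive. A minimal sketch, not part of the patch, with made-up logit values:

    import torch

    # Toy logits for two frames over a 3-symbol vocabulary (blank = id 0);
    # the values are invented purely for illustration.
    logits = torch.tensor(
        [
            [1.5, 1.0, 0.2],  # blank barely beats token 1
            [3.0, 0.1, 0.0],  # blank wins clearly
        ]
    )
    blank_penalty = 2.0

    print(logits.argmax(dim=1))    # tensor([0, 0]): both frames emit blank
    logits[:, 0] -= blank_penalty  # the line added to greedy_search_batch
    print(logits.argmax(dim=1))    # tensor([1, 0]): the borderline frame now emits token 1

Because the penalty is also appended to params.suffix, different values (e.g. --blank-penalty 0.7) can be swept at decode time without overwriting each other's result files. Separately, the remove_short_utt filter mirrors the encoder's frame subsampling via T = ((num_frames - 7) // 2 + 1) // 2: a cut needs num_frames >= 9 to yield at least one encoder frame (((9 - 7) // 2 + 1) // 2 = 1, while 8 input frames give 0), so shorter cuts are excluded up front instead of failing inside the model.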