diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
index f386bcdd0..88b831f2f 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
@@ -561,7 +561,8 @@ def main():
         return True
 
     # test_sets_cuts = multi_dataset.test_cuts()
-    test_sets_cuts = multi_dataset.aishell_test_cuts()
+    # test_sets_cuts = multi_dataset.aishell_test_cuts()
+    test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts()
 
     test_sets = test_sets_cuts.keys()
     test_dls = [
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json b/egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json
index 9d2cef08f..730937a21 100644
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json
@@ -19,14 +19,14 @@
     "optimizer": {
         "type": "Adam",
         "params": {
-            "lr": 5e-4
+            "lr": 1e-4
         }
     },
     "scheduler": {
         "type": "WarmupLR",
         "params": {
             "warmup_min_lr": 0,
-            "warmup_max_lr": 5e-4,
+            "warmup_max_lr": 1e-4,
             "warmup_num_steps": 100
         }
     },
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
index 52bbf8c64..9b3ef6e69 100644
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
@@ -230,20 +230,20 @@ class MultiDataset:
 
         return {
             "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
-            # "aishell_test": aishell_test_cuts,
-            # "aishell_dev": aishell_dev_cuts,
-            # "ali-meeting_test": alimeeting_test_cuts,
-            # "ali-meeting_eval": alimeeting_eval_cuts,
-            # "aishell-4_test": aishell4_test_cuts,
-            # "aishell-2_test": aishell2_test_cuts,
-            # "aishell-2_dev": aishell2_dev_cuts,
-            # "magicdata_test": magicdata_test_cuts,
-            # "magicdata_dev": magicdata_dev_cuts,
-            # "kespeech-asr_test": kespeech_test_cuts,
-            # "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
-            # "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
-            # "wenetspeech-net_test": wenetspeech_test_net_cuts,
-            # "wenetspeech_dev": wenetspeech_dev_cuts,
+            "aishell_test": aishell_test_cuts,
+            "aishell_dev": aishell_dev_cuts,
+            "ali-meeting_test": alimeeting_test_cuts,
+            "ali-meeting_eval": alimeeting_eval_cuts,
+            "aishell-4_test": aishell4_test_cuts,
+            "aishell-2_test": aishell2_test_cuts,
+            "aishell-2_dev": aishell2_dev_cuts,
+            "magicdata_test": magicdata_test_cuts,
+            "magicdata_dev": magicdata_dev_cuts,
+            "kespeech-asr_test": kespeech_test_cuts,
+            "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
+            "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
+            "wenetspeech-net_test": wenetspeech_test_net_cuts,
+            "wenetspeech_dev": wenetspeech_dev_cuts,
         }
 
     def aishell_train_cuts(self) -> CutSet:
@@ -317,4 +317,17 @@ class MultiDataset:
 
         return {
             "aishell2_test": aishell2_test_cuts,
+        }
+
+    def wenetspeech_test_meeting_cuts(self) -> CutSet:
+        logging.info("About to get multidataset test cuts")
+
+        # WeNetSpeech
+        logging.info("Loading WeNetSpeech set in lazy mode")
+        wenetspeech_test_meeting_cuts = load_manifest_lazy(
+            self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz"
+        )
+
+        return {
+            "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
         }
\ No newline at end of file
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
index 10023ec9a..21b615930 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
@@ -823,14 +823,14 @@ def run(rank, world_size, args):
         # an utterance duration distribution for your dataset to select
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            )
+            # logging.warning(
+            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            # )
             return False
 
         return True
 
-    # train_cuts = multi_dataset.train_cuts()
-    train_cuts = multi_dataset.aishell_train_cuts()
+    train_cuts = multi_dataset.train_cuts()
+    # train_cuts = multi_dataset.aishell_train_cuts()
     # train_cuts = multi_dataset.aishell2_train_cuts()
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
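
For reference, a minimal sketch (not part of the diff) of how the new `wenetspeech_test_meeting_cuts()` accessor is consumed once these changes land. Only `self.fbank_dir` and the method itself appear in the diff; the `MultiDataset("data/fbank")` construction and the manifest location are assumptions.

```python
# Minimal sketch, assuming MultiDataset is constructed from an fbank directory
# as elsewhere in this recipe; "data/fbank" is a hypothetical path.
from multi_dataset import MultiDataset

multi_dataset = MultiDataset("data/fbank")

# Mirrors the decode.py change: decode only the WenetSpeech TEST_MEETING split.
test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts()
test_sets = test_sets_cuts.keys()
for name in test_sets:
    # e.g. "wenetspeech-meeting_test" -> lazily loaded CutSet
    print(name, test_sets_cuts[name])
```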