diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py index 30f8ba76d..f3bdb452c 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py @@ -588,7 +588,7 @@ def main(): test_sets_cuts = multi_dataset.aishell_test_cuts() elif params.dataset == "speechio": test_sets_cuts = multi_dataset.speechio_test_cuts() - elif params.dataaset == "wenetspeech_test_meeting": + elif params.dataset == "wenetspeech_test_meeting": test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts() else: test_sets_cuts = multi_dataset.test_cuts() diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py index ddaf6078c..2ddfdf09d 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py @@ -190,6 +190,14 @@ def get_parser(): """, ) + parser.add_argument( + "--sampler-state-dict-path", + type=str, + default=None, + help="""The path to the sampler state dict if it is not None. Training will start from this sampler state dict. + """, + ) + parser.add_argument( "--base-lr", type=float, default=1e-5, help="The base learning rate." ) @@ -813,6 +821,8 @@ def run(rank, world_size, args): # else: # sampler_state_dict = None sampler_state_dict = None + if params.sampler_state_dict_path: + sampler_state_dict = torch.load(params.sampler_state_dict_path) # TODO: load sampler state dict train_dl = data_module.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict