diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py index 2ddfdf09d..f4d30d28a 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py @@ -823,6 +823,7 @@ def run(rank, world_size, args): sampler_state_dict = None if params.sampler_state_dict_path: sampler_state_dict = torch.load(params.sampler_state_dict_path) + sampler_state_dict["max_duration"] = params.max_duration # TODO: load sampler state dict train_dl = data_module.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict