diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py index 02c992049..7c3901c20 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py @@ -548,7 +548,7 @@ def main(): # torch.save(avg_checkpoint, filename) else: checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin", weights_only=False + f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin", weights_only=False, map_location="cpu", ) model.load_state_dict(checkpoint, strict=False)