From f3fca0c81b6f88ede10c6597033d920509ebb086 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 25 Dec 2024 02:17:18 +0000 Subject: [PATCH] fix infer error --- egs/wenetspeech4tts/TTS/f5-tts/infer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer.py b/egs/wenetspeech4tts/TTS/f5-tts/infer.py index 4db628a66..8d38af1ca 100644 --- a/egs/wenetspeech4tts/TTS/f5-tts/infer.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/infer.py @@ -281,9 +281,8 @@ def main(): vocoder = vocoder.eval().to(device) model = get_model(args).eval().to(device) - checkpoint = torch.load(args.model_path, map_location="cpu", weights_only=True) - - if "model_state_dict" or "ema_model_state_dict" in checkpoint: + checkpoint = torch.load(args.model_path, map_location="cpu") + if "ema_model_state_dict" in checkpoint or 'model_state_dict' in checkpoint: model = load_F5_TTS_pretrained_checkpoint(model, args.model_path) else: _ = load_checkpoint(