diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer.py b/egs/wenetspeech4tts/TTS/f5-tts/infer.py index 4db628a66..8d38af1ca 100644 --- a/egs/wenetspeech4tts/TTS/f5-tts/infer.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/infer.py @@ -281,9 +281,8 @@ def main(): vocoder = vocoder.eval().to(device) model = get_model(args).eval().to(device) - checkpoint = torch.load(args.model_path, map_location="cpu", weights_only=True) - - if "model_state_dict" or "ema_model_state_dict" in checkpoint: + checkpoint = torch.load(args.model_path, map_location="cpu") + if "ema_model_state_dict" in checkpoint or 'model_state_dict' in checkpoint: model = load_F5_TTS_pretrained_checkpoint(model, args.model_path) else: _ = load_checkpoint(