fix infer error

This commit is contained in:
root 2024-12-25 02:17:18 +00:00
parent 03d500a414
commit f3fca0c81b

View File

@ -281,9 +281,8 @@ def main():
vocoder = vocoder.eval().to(device)
model = get_model(args).eval().to(device)
checkpoint = torch.load(args.model_path, map_location="cpu", weights_only=True)
if "model_state_dict" or "ema_model_state_dict" in checkpoint:
checkpoint = torch.load(args.model_path, map_location="cpu")
if "ema_model_state_dict" in checkpoint or 'model_state_dict' in checkpoint:
model = load_F5_TTS_pretrained_checkpoint(model, args.model_path)
else:
_ = load_checkpoint(