diff --git a/egs/wenetspeech4tts/TTS/valle/valle.py b/egs/wenetspeech4tts/TTS/valle/valle.py index 206b843ba..0cf73926b 100644 --- a/egs/wenetspeech4tts/TTS/valle/valle.py +++ b/egs/wenetspeech4tts/TTS/valle/valle.py @@ -1669,8 +1669,8 @@ class VALLE(nn.Module): output_dir: str, limit: int = 4, ) -> None: - audio_features = batch["audio_features"].to("cpu").detach().numpy() - audio_features_lens = batch["audio_features_lens"].to("cpu").detach().numpy() + audio_features = batch["features"].to("cpu").detach().numpy() + audio_features_lens = batch["features_lens"].to("cpu").detach().numpy() tokens = batch["tokens"] text_tokens, text_tokens_lens = tokenizer(tokens)