Update valle.py

This commit is contained in:
zr_jin 2024-12-06 13:57:59 +08:00
parent 2504036f5b
commit 94126e7f38

View File

@ -1669,8 +1669,8 @@ class VALLE(nn.Module):
output_dir: str,
limit: int = 4,
) -> None:
audio_features = batch["audio_features"].to("cpu").detach().numpy()
audio_features_lens = batch["audio_features_lens"].to("cpu").detach().numpy()
audio_features = batch["features"].to("cpu").detach().numpy()
audio_features_lens = batch["features_lens"].to("cpu").detach().numpy()
tokens = batch["tokens"]
text_tokens, text_tokens_lens = tokenizer(tokens)