black formatted

This commit is contained in:
zr_jin 2024-12-06 10:44:14 +08:00
parent 60c5a1d539
commit ce73643af6

View File

@ -1670,9 +1670,7 @@ class VALLE(nn.Module):
text_tokens = batch["text_tokens"].to("cpu").detach().numpy() text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy() text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
audio_features = batch["audio_features"].to("cpu").detach().numpy() audio_features = batch["audio_features"].to("cpu").detach().numpy()
audio_features_lens = ( audio_features_lens = batch["audio_features_lens"].to("cpu").detach().numpy()
batch["audio_features_lens"].to("cpu").detach().numpy()
)
assert text_tokens.ndim == 2 assert text_tokens.ndim == 2
utt_ids, texts = batch["utt_id"], batch["text"] utt_ids, texts = batch["utt_id"], batch["text"]
@ -1681,9 +1679,7 @@ class VALLE(nn.Module):
decoder_outputs = predicts[1] decoder_outputs = predicts[1]
if isinstance(decoder_outputs, list): if isinstance(decoder_outputs, list):
decoder_outputs = decoder_outputs[-1] decoder_outputs = decoder_outputs[-1]
decoder_outputs = ( decoder_outputs = decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
)
vmin, vmax = 0, 1024 # Encodec vmin, vmax = 0, 1024 # Encodec
if decoder_outputs.dtype == np.float32: if decoder_outputs.dtype == np.float32: