diff --git a/egs/wenetspeech4tts/TTS/valle/valle.py b/egs/wenetspeech4tts/TTS/valle/valle.py index bfe047617..40501736b 100644 --- a/egs/wenetspeech4tts/TTS/valle/valle.py +++ b/egs/wenetspeech4tts/TTS/valle/valle.py @@ -1670,9 +1670,7 @@ class VALLE(nn.Module): text_tokens = batch["text_tokens"].to("cpu").detach().numpy() text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy() audio_features = batch["audio_features"].to("cpu").detach().numpy() - audio_features_lens = ( - batch["audio_features_lens"].to("cpu").detach().numpy() - ) + audio_features_lens = batch["audio_features_lens"].to("cpu").detach().numpy() assert text_tokens.ndim == 2 utt_ids, texts = batch["utt_id"], batch["text"] @@ -1681,9 +1679,7 @@ class VALLE(nn.Module): decoder_outputs = predicts[1] if isinstance(decoder_outputs, list): decoder_outputs = decoder_outputs[-1] - decoder_outputs = ( - decoder_outputs.to("cpu").type(torch.float32).detach().numpy() - ) + decoder_outputs = decoder_outputs.to("cpu").type(torch.float32).detach().numpy() vmin, vmax = 0, 1024 # Encodec if decoder_outputs.dtype == np.float32: