mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 12:02:21 +00:00
black formatted
This commit is contained in:
parent
60c5a1d539
commit
ce73643af6
@ -1670,9 +1670,7 @@ class VALLE(nn.Module):
|
|||||||
text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
|
text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
|
||||||
text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
|
text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
|
||||||
audio_features = batch["audio_features"].to("cpu").detach().numpy()
|
audio_features = batch["audio_features"].to("cpu").detach().numpy()
|
||||||
audio_features_lens = (
|
audio_features_lens = batch["audio_features_lens"].to("cpu").detach().numpy()
|
||||||
batch["audio_features_lens"].to("cpu").detach().numpy()
|
|
||||||
)
|
|
||||||
assert text_tokens.ndim == 2
|
assert text_tokens.ndim == 2
|
||||||
|
|
||||||
utt_ids, texts = batch["utt_id"], batch["text"]
|
utt_ids, texts = batch["utt_id"], batch["text"]
|
||||||
@ -1681,9 +1679,7 @@ class VALLE(nn.Module):
|
|||||||
decoder_outputs = predicts[1]
|
decoder_outputs = predicts[1]
|
||||||
if isinstance(decoder_outputs, list):
|
if isinstance(decoder_outputs, list):
|
||||||
decoder_outputs = decoder_outputs[-1]
|
decoder_outputs = decoder_outputs[-1]
|
||||||
decoder_outputs = (
|
decoder_outputs = decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
|
||||||
decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
|
|
||||||
)
|
|
||||||
|
|
||||||
vmin, vmax = 0, 1024 # Encodec
|
vmin, vmax = 0, 1024 # Encodec
|
||||||
if decoder_outputs.dtype == np.float32:
|
if decoder_outputs.dtype == np.float32:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user