added the missing `visualize` function

zr_jin 2024-12-06 10:12:58 +08:00
parent 58f7875c7e
commit 60c5a1d539
2 changed files with 91 additions and 2 deletions


@@ -519,12 +519,16 @@ def main():
             if split > 1:
                 storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx}"
             else:
-                storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}"
+                storage_path = (
+                    f"{args.output_dir}/{args.prefix}_encodec_{partition}"
+                )
         else:
             if split > 1:
                 storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx}"
             else:
-                storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}"
+                storage_path = (
+                    f"{args.output_dir}/{args.prefix}_fbank_{partition}"
+                )
         if args.prefix.lower() in [
             "ljspeech",

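To make the naming scheme concrete, here is a small standalone sketch (not part of the commit; the values for output_dir, prefix, partition, split, and idx are illustrative assumptions):

# Standalone sketch of the storage-path naming preserved by the hunk above;
# every concrete value here is an illustrative assumption.
output_dir, prefix, partition = "data/fbank", "ljspeech", "train"

for split, idx in [(1, 0), (4, 2)]:
    if split > 1:
        # per-chunk path when a partition is written out in several splits
        storage_path = f"{output_dir}/{prefix}_fbank_{partition}_{idx}"
    else:
        storage_path = f"{output_dir}/{prefix}_fbank_{partition}"
    print(storage_path)
# prints:
#   data/fbank/ljspeech_fbank_train
#   data/fbank/ljspeech_fbank_train_2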

@@ -19,6 +19,8 @@ import random
 from functools import partial
 from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

+import matplotlib.pyplot as plt
+import numpy as np
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -1658,6 +1660,89 @@ class VALLE(nn.Module):
         assert len(codes) == 8
         return torch.stack(codes, dim=-1)
+
+    def visualize(
+        self,
+        predicts: Tuple[torch.Tensor],
+        batch: Dict[str, Union[List, torch.Tensor]],
+        output_dir: str,
+        limit: int = 4,
+    ) -> None:
+        text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
+        text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
+        audio_features = batch["audio_features"].to("cpu").detach().numpy()
+        audio_features_lens = (
+            batch["audio_features_lens"].to("cpu").detach().numpy()
+        )
+        assert text_tokens.ndim == 2
+
+        utt_ids, texts = batch["utt_id"], batch["text"]
+
+        encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
+        decoder_outputs = predicts[1]
+        if isinstance(decoder_outputs, list):
+            decoder_outputs = decoder_outputs[-1]
+        decoder_outputs = (
+            decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
+        )
+
+        vmin, vmax = 0, 1024  # Encodec
+        if decoder_outputs.dtype == np.float32:
+            vmin, vmax = -6, 0  # Fbank
+
+        num_figures = 3
+        for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
+            _ = plt.figure(figsize=(14, 8 * num_figures))
+
+            S = text_tokens_lens[b]
+            T = audio_features_lens[b]
+
+            # encoder
+            plt.subplot(num_figures, 1, 1)
+            plt.title(f"Text: {text}")
+            plt.imshow(
+                X=np.transpose(encoder_outputs[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
+                interpolation="nearest",
+            )
+            plt.gca().invert_yaxis()
+            plt.axvline(x=S - 0.4, linewidth=2, color="r")
+            plt.xlabel("Encoder Output")
+            plt.colorbar()
+
+            # decoder
+            plt.subplot(num_figures, 1, 2)
+            plt.imshow(
+                X=np.transpose(decoder_outputs[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
+                interpolation="nearest",
+                vmin=vmin,
+                vmax=vmax,
+            )
+            plt.gca().invert_yaxis()
+            plt.axvline(x=T - 0.4, linewidth=2, color="r")
+            plt.xlabel("Decoder Output")
+            plt.colorbar()
+
+            # target
+            plt.subplot(num_figures, 1, 3)
+            plt.imshow(
+                X=np.transpose(audio_features[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
+                interpolation="nearest",
+                vmin=vmin,
+                vmax=vmax,
+            )
+            plt.gca().invert_yaxis()
+            plt.axvline(x=T - 0.4, linewidth=2, color="r")
+            plt.xlabel("Decoder Target")
+            plt.colorbar()
+
+            plt.savefig(f"{output_dir}/{utt_id}.png")
+            plt.close()


 # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
 def top_k_top_p_filtering(
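As a usage sketch (not part of the commit): the new method could be driven from a validation step roughly as below. Only visualize()'s signature and the batch keys it reads (text_tokens, text_tokens_lens, audio_features, audio_features_lens, utt_id, text) come from the diff; model, batch, compute_loss, and the output path are assumptions.

# Hypothetical caller for the new VALLE.visualize() method; see caveats above.
import os

import torch


def plot_validation_batch(model, batch, compute_loss, output_dir="exp/valle/figures"):
    os.makedirs(output_dir, exist_ok=True)  # assumed figure location
    model.eval()
    with torch.no_grad():
        # Assumption: the training code exposes the model's raw outputs as an
        # (encoder_outputs, decoder_outputs) tuple next to the loss;
        # decoder_outputs may be a list of per-stage tensors, in which case
        # visualize() plots the last one.
        loss, predicts = compute_loss(model, batch)
    # Writes one PNG per utterance (up to `limit`) to output_dir/<utt_id>.png,
    # with heatmaps for encoder output, decoder output, and the target features.
    model.visualize(predicts, batch, output_dir=output_dir, limit=4)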