Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-12 11:32:19 +00:00
added the missing `visualize` function
This commit is contained in:
parent 58f7875c7e
commit 60c5a1d539
@@ -519,12 +519,16 @@ def main():
                 if split > 1:
                     storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx}"
                 else:
-                    storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}"
+                    storage_path = (
+                        f"{args.output_dir}/{args.prefix}_encodec_{partition}"
+                    )
             else:
                 if split > 1:
                     storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx}"
                 else:
-                    storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}"
+                    storage_path = (
+                        f"{args.output_dir}/{args.prefix}_fbank_{partition}"
+                    )
 
             if args.prefix.lower() in [
                 "ljspeech",
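To make the branching concrete, here is a small sketch of the paths these assignments produce; the hunk itself only wraps over-long lines. The values of args.output_dir, args.prefix, partition and idx below are made-up examples, not taken from this commit:

    # Hypothetical values, for illustration only.
    output_dir, prefix, partition, idx = "data/tokenized", "libritts", "train-clean-100", 3

    # split > 1: the manifest is processed in several splits, so the split index is appended.
    storage_path = f"{output_dir}/{prefix}_encodec_{partition}_{idx}"
    # -> data/tokenized/libritts_encodec_train-clean-100_3

    # split == 1: no index; the parenthesized form only wraps the long line.
    storage_path = (
        f"{output_dir}/{prefix}_fbank_{partition}"
    )
    # -> data/tokenized/libritts_fbank_train-clean-100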
@@ -19,6 +19,8 @@ import random
 from functools import partial
 from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
 
+import matplotlib.pyplot as plt
+import numpy as np
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -1658,6 +1660,89 @@ class VALLE(nn.Module):
         assert len(codes) == 8
         return torch.stack(codes, dim=-1)
 
+    def visualize(
+        self,
+        predicts: Tuple[torch.Tensor],
+        batch: Dict[str, Union[List, torch.Tensor]],
+        output_dir: str,
+        limit: int = 4,
+    ) -> None:
+        text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
+        text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
+        audio_features = batch["audio_features"].to("cpu").detach().numpy()
+        audio_features_lens = (
+            batch["audio_features_lens"].to("cpu").detach().numpy()
+        )
+        assert text_tokens.ndim == 2
+
+        utt_ids, texts = batch["utt_id"], batch["text"]
+
+        encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
+        decoder_outputs = predicts[1]
+        if isinstance(decoder_outputs, list):
+            decoder_outputs = decoder_outputs[-1]
+        decoder_outputs = (
+            decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
+        )
+
+        vmin, vmax = 0, 1024  # Encodec
+        if decoder_outputs.dtype == np.float32:
+            vmin, vmax = -6, 0  # Fbank
+
+        num_figures = 3
+        for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
+            _ = plt.figure(figsize=(14, 8 * num_figures))
+
+            S = text_tokens_lens[b]
+            T = audio_features_lens[b]
+
+            # encoder
+            plt.subplot(num_figures, 1, 1)
+            plt.title(f"Text: {text}")
+            plt.imshow(
+                X=np.transpose(encoder_outputs[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
+                interpolation="nearest",
+            )
+            plt.gca().invert_yaxis()
+            plt.axvline(x=S - 0.4, linewidth=2, color="r")
+            plt.xlabel("Encoder Output")
+            plt.colorbar()
+
+            # decoder
+            plt.subplot(num_figures, 1, 2)
+            plt.imshow(
+                X=np.transpose(decoder_outputs[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
+                interpolation="nearest",
+                vmin=vmin,
+                vmax=vmax,
+            )
+            plt.gca().invert_yaxis()
+            plt.axvline(x=T - 0.4, linewidth=2, color="r")
+            plt.xlabel("Decoder Output")
+            plt.colorbar()
+
+            # target
+            plt.subplot(num_figures, 1, 3)
+            plt.imshow(
+                X=np.transpose(audio_features[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
+                interpolation="nearest",
+                vmin=vmin,
+                vmax=vmax,
+            )
+            plt.gca().invert_yaxis()
+            plt.axvline(x=T - 0.4, linewidth=2, color="r")
+            plt.xlabel("Decoder Target")
+            plt.colorbar()
+
+            plt.savefig(f"{output_dir}/{utt_id}.png")
+            plt.close()
+
+
 # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
 def top_k_top_p_filtering(
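For orientation, a minimal usage sketch of the new method follows; it is an assumption about how visualize would be called from a training or debugging script, not code from this commit. Here model stands for an already constructed VALLE instance, the batch keys mirror what visualize reads, and the shapes and the output directory are illustrative only:

    import torch

    # `model` is assumed to be an already constructed VALLE instance (not shown here).
    batch = {
        "utt_id": ["utt_0001", "utt_0002"],
        "text": ["hello world", "good morning"],
        "text_tokens": torch.randint(0, 100, (2, 16)),   # (N, S)
        "text_tokens_lens": torch.tensor([16, 12]),
        "audio_features": torch.randn(2, 200, 80),       # e.g. fbank targets, (N, T, F)
        "audio_features_lens": torch.tensor([200, 180]),
    }

    # predicts[0]: encoder output (N, S, D); predicts[1]: decoder output (N, T, F),
    # or a list of per-stage outputs, in which case only the last one is plotted.
    predicts = (torch.randn(2, 16, 512), torch.randn(2, 200, 80))

    # Writes one PNG per utterance, e.g. exp/valle/figures/utt_0001.png
    # (the output directory is assumed to exist).
    model.visualize(predicts, batch, output_dir="exp/valle/figures", limit=4)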