Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)
performed end-to-end testing of the VALL-E recipe (#1818)

* added the missing ``visualize`` function
* minor fixes

parent bdd0f85704 · commit 6e6b022e41
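In short, the diff touches the wenetspeech4tts VALL-E recipe in four places: feature-preparation output naming when a dataset is split into parts, the default checkpoint and experiment-dir paths, a new valle/requirements.txt, and a new VALLE.visualize method whose signature takes the tokenizer (the call sites in train.py are updated to match).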
@@ -516,9 +516,19 @@ def main():
     for idx, part in enumerate(cut_sets):
         if args.audio_extractor:
             if args.audio_extractor == "Encodec":
-                storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx if split > 1 else ''}"
+                if split > 1:
+                    storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx}"
+                else:
+                    storage_path = (
+                        f"{args.output_dir}/{args.prefix}_encodec_{partition}"
+                    )
             else:
-                storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx if split > 1 else ''}"
+                if split > 1:
+                    storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx}"
+                else:
+                    storage_path = (
+                        f"{args.output_dir}/{args.prefix}_fbank_{partition}"
+                    )

         if args.prefix.lower() in [
             "ljspeech",
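Why the rewrite: with split == 1 the inline conditional {idx if split > 1 else ''} still emits the literal _ separator, so the old one-liner produced storage paths with a trailing underscore. The explicit if/else drops the separator entirely in the unsplit case.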
@@ -587,9 +597,11 @@ def main():
             ].normalized_text, "normalized_text is None"

         # Save each part with an index if split > 1
-        cuts_filename = (
-            f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{args.suffix}"
-        )
+        if split > 1:
+            cuts_filename = f"{prefix}cuts_{partition}.{idx}.{args.suffix}"
+        else:
+            cuts_filename = f"{prefix}cuts_{partition}.{args.suffix}"

         part.to_file(f"{args.output_dir}/{cuts_filename}")
         logging.info(f"Saved {cuts_filename}")
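A quick demonstration, not part of the commit, of the two filename artifacts the hunks above fix; the prefix/partition/suffix values are hypothetical stand-ins for args.prefix, partition, and args.suffix:

    # Hypothetical values standing in for args.prefix, partition, args.suffix.
    prefix, partition, suffix = "wenetspeech4tts_", "train", "jsonl.gz"
    idx, split = 0, 1  # the unsplit case

    # Old storage_path: the "_" separator survives even though idx is dropped.
    print(f"out/{prefix}encodec_{partition}_{idx if split > 1 else ''}")
    # -> out/wenetspeech4tts_encodec_train_   (trailing underscore)

    # Old cuts_filename: the "." separators likewise survive, giving "..".
    print(f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{suffix}")
    # -> wenetspeech4tts_cuts_train..jsonl.gz (double dot)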
@@ -86,7 +86,7 @@ def get_args():
     parser.add_argument(
         "--checkpoint",
         type=str,
-        default="exp/vallf_nano_full/checkpoint-100000.pt",
+        default="./valle/exp/checkpoint-100000.pt",
         help="Path to the saved checkpoint.",
     )
egs/wenetspeech4tts/TTS/valle/requirements.txt  (new file, +2)
@@ -0,0 +1,2 @@
+phonemizer==3.2.1
+git+https://github.com/facebookresearch/encodec.git
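The new file pins the recipe's two extra runtime dependencies; they install with the usual pip install -r egs/wenetspeech4tts/TTS/valle/requirements.txt, where pip resolves the second line by cloning the EnCodec repository.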
@@ -4,6 +4,7 @@
 #                                                  Mingshuang Luo)
 # Copyright    2023                                (authors: Feiteng Li)
 # Copyright    2024                                (authors: Yuekai Zhang)
+# Copyright    2024  Tsinghua University           (authors: Zengrui Jin,)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -48,10 +49,8 @@ python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max
 import argparse
 import copy
 import logging
-import os
-import random
 import warnings
 from contextlib import nullcontext
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union
@@ -216,7 +215,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="exp/valle_dev",
+        default="./valle/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -686,9 +685,9 @@ def compute_validation_loss(
         output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}")
         output_dir.mkdir(parents=True, exist_ok=True)
         if isinstance(model, DDP):
-            model.module.visualize(predicts, batch, output_dir=output_dir)
+            model.module.visualize(predicts, batch, tokenizer, output_dir=output_dir)
         else:
-            model.visualize(predicts, batch, output_dir=output_dir)
+            model.visualize(predicts, batch, tokenizer, output_dir=output_dir)

     return tot_loss
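This is the fix the commit message refers to: VALLE.visualize, added below, now needs the TextTokenCollater to turn batch["tokens"] back into padded token tensors and their lengths for plotting, so both the DDP and non-DDP call sites thread tokenizer through.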
@@ -19,8 +19,11 @@ import random
 from functools import partial
 from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

+import matplotlib.pyplot as plt
+import numpy as np
 import torch
 import torch.nn as nn
+from tokenizer import TextTokenCollater
 from torch import Tensor
 from torch.nn import Linear, Module
 from torch.nn import functional as F
@@ -1658,6 +1661,88 @@ class VALLE(nn.Module):
         assert len(codes) == 8
         return torch.stack(codes, dim=-1)

+    def visualize(
+        self,
+        predicts: Tuple[torch.Tensor],
+        batch: Dict[str, Union[List, torch.Tensor]],
+        tokenizer: TextTokenCollater,
+        output_dir: str,
+        limit: int = 4,
+    ) -> None:
+        audio_features = batch["features"].to("cpu").detach().numpy()
+        audio_features_lens = batch["features_lens"].to("cpu").detach().numpy()
+
+        tokens = batch["tokens"]
+        text_tokens, text_tokens_lens = tokenizer(tokens)
+        assert text_tokens.ndim == 2
+
+        texts = batch["text"]
+        utt_ids = [cut.id for cut in batch["cut"]]
+
+        encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
+        decoder_outputs = predicts[1]
+        if isinstance(decoder_outputs, list):
+            decoder_outputs = decoder_outputs[-1]
+        decoder_outputs = decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
+
+        vmin, vmax = 0, 1024  # Encodec
+        if decoder_outputs.dtype == np.float32:
+            vmin, vmax = -6, 0  # Fbank
+
+        num_figures = 3
+        for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
+            _ = plt.figure(figsize=(14, 8 * num_figures))
+
+            S = text_tokens_lens[b]
+            T = audio_features_lens[b]
+
+            # encoder
+            plt.subplot(num_figures, 1, 1)
+            plt.title(f"Text: {text}")
+            plt.imshow(
+                X=np.transpose(encoder_outputs[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
+                interpolation="nearest",
+            )
+            plt.gca().invert_yaxis()
+            plt.axvline(x=S - 0.4, linewidth=2, color="r")
+            plt.xlabel("Encoder Output")
+            plt.colorbar()
+
+            # decoder
+            plt.subplot(num_figures, 1, 2)
+            plt.imshow(
+                X=np.transpose(decoder_outputs[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
+                interpolation="nearest",
+                vmin=vmin,
+                vmax=vmax,
+            )
+            plt.gca().invert_yaxis()
+            plt.axvline(x=T - 0.4, linewidth=2, color="r")
+            plt.xlabel("Decoder Output")
+            plt.colorbar()
+
+            # target
+            plt.subplot(num_figures, 1, 3)
+            plt.imshow(
+                X=np.transpose(audio_features[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
+                interpolation="nearest",
+                vmin=vmin,
+                vmax=vmax,
+            )
+            plt.gca().invert_yaxis()
+            plt.axvline(x=T - 0.4, linewidth=2, color="r")
+            plt.xlabel("Decoder Target")
+            plt.colorbar()
+
+            plt.savefig(f"{output_dir}/{utt_id}.png")
+            plt.close()
+

 # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
 def top_k_top_p_filtering(
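A minimal usage sketch, not from the commit, mirroring the compute_validation_loss call above; model, predicts, batch, and tokenizer are assumed to exist as in the training loop, and the step directory is a hypothetical example:

    from pathlib import Path

    # Hypothetical eval directory; train.py derives it from params.exp_dir
    # and params.batch_idx_train.
    output_dir = Path("./valle/exp/eval/step-000100")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Writes up to `limit` figures, one per utterance, each with three panels
    # (encoder output, decoder output, decoder target) to <output_dir>/<utt_id>.png.
    model.visualize(predicts, batch, tokenizer, output_dir=output_dir, limit=4)

One detail worth noting in the plotting-range heuristic: float32 decoder outputs are plotted as log-mel fbank in [-6, 0] and anything else as Encodec code indices in [0, 1024], but since decoder_outputs is cast to torch.float32 just before the check, the fbank range is the branch actually taken here.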