mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
Performed end-to-end testing of the VALL-E recipe (#1818)
* Added the missing ``visualize`` function * Minor fixes
This commit is contained in:
parent
bdd0f85704
commit
6e6b022e41
@ -516,9 +516,19 @@ def main():
|
|||||||
for idx, part in enumerate(cut_sets):
|
for idx, part in enumerate(cut_sets):
|
||||||
if args.audio_extractor:
|
if args.audio_extractor:
|
||||||
if args.audio_extractor == "Encodec":
|
if args.audio_extractor == "Encodec":
|
||||||
storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx if split > 1 else ''}"
|
if split > 1:
|
||||||
|
storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx}"
|
||||||
|
else:
|
||||||
|
storage_path = (
|
||||||
|
f"{args.output_dir}/{args.prefix}_encodec_{partition}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx if split > 1 else ''}"
|
if split > 1:
|
||||||
|
storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx}"
|
||||||
|
else:
|
||||||
|
storage_path = (
|
||||||
|
f"{args.output_dir}/{args.prefix}_fbank_{partition}"
|
||||||
|
)
|
||||||
|
|
||||||
if args.prefix.lower() in [
|
if args.prefix.lower() in [
|
||||||
"ljspeech",
|
"ljspeech",
|
||||||
@ -587,9 +597,11 @@ def main():
|
|||||||
].normalized_text, "normalized_text is None"
|
].normalized_text, "normalized_text is None"
|
||||||
|
|
||||||
# Save each part with an index if split > 1
|
# Save each part with an index if split > 1
|
||||||
cuts_filename = (
|
if split > 1:
|
||||||
f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{args.suffix}"
|
cuts_filename = f"{prefix}cuts_{partition}.{idx}.{args.suffix}"
|
||||||
)
|
else:
|
||||||
|
cuts_filename = f"{prefix}cuts_{partition}.{args.suffix}"
|
||||||
|
|
||||||
part.to_file(f"{args.output_dir}/{cuts_filename}")
|
part.to_file(f"{args.output_dir}/{cuts_filename}")
|
||||||
logging.info(f"Saved {cuts_filename}")
|
logging.info(f"Saved {cuts_filename}")
|
||||||
|
|
||||||
|
@ -86,7 +86,7 @@ def get_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--checkpoint",
|
"--checkpoint",
|
||||||
type=str,
|
type=str,
|
||||||
default="exp/vallf_nano_full/checkpoint-100000.pt",
|
default="./valle/exp/checkpoint-100000.pt",
|
||||||
help="Path to the saved checkpoint.",
|
help="Path to the saved checkpoint.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
2
egs/wenetspeech4tts/TTS/valle/requirements.txt
Normal file
2
egs/wenetspeech4tts/TTS/valle/requirements.txt
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
phonemizer==3.2.1
|
||||||
|
git+https://github.com/facebookresearch/encodec.git
|
@ -4,6 +4,7 @@
|
|||||||
# Mingshuang Luo)
|
# Mingshuang Luo)
|
||||||
# Copyright 2023 (authors: Feiteng Li)
|
# Copyright 2023 (authors: Feiteng Li)
|
||||||
# Copyright 2024 (authors: Yuekai Zhang)
|
# Copyright 2024 (authors: Yuekai Zhang)
|
||||||
|
# Copyright 2024 Tsinghua University (authors: Zengrui Jin,)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -48,10 +49,8 @@ python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max
|
|||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import random
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import nullcontext
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
@ -216,7 +215,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="exp/valle_dev",
|
default="./valle/exp",
|
||||||
help="""The experiment dir.
|
help="""The experiment dir.
|
||||||
It specifies the directory where all training related
|
It specifies the directory where all training related
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
@ -686,9 +685,9 @@ def compute_validation_loss(
|
|||||||
output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}")
|
output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}")
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
if isinstance(model, DDP):
|
if isinstance(model, DDP):
|
||||||
model.module.visualize(predicts, batch, output_dir=output_dir)
|
model.module.visualize(predicts, batch, tokenizer, output_dir=output_dir)
|
||||||
else:
|
else:
|
||||||
model.visualize(predicts, batch, output_dir=output_dir)
|
model.visualize(predicts, batch, tokenizer, output_dir=output_dir)
|
||||||
|
|
||||||
return tot_loss
|
return tot_loss
|
||||||
|
|
||||||
|
@ -19,8 +19,11 @@ import random
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from tokenizer import TextTokenCollater
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn import Linear, Module
|
from torch.nn import Linear, Module
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
@ -1658,6 +1661,88 @@ class VALLE(nn.Module):
|
|||||||
assert len(codes) == 8
|
assert len(codes) == 8
|
||||||
return torch.stack(codes, dim=-1)
|
return torch.stack(codes, dim=-1)
|
||||||
|
|
||||||
|
def visualize(
    self,
    predicts: Tuple[torch.Tensor],
    batch: Dict[str, Union[List, torch.Tensor]],
    tokenizer: TextTokenCollater,
    output_dir: str,
    limit: int = 4,
) -> None:
    """Save one diagnostic PNG per utterance for the first ``limit`` items.

    Each figure stacks three heat maps: ``predicts[0]`` ("Encoder Output"),
    ``predicts[1]`` ("Decoder Output") and ``batch["features"]``
    ("Decoder Target"). The image is written to
    ``{output_dir}/{cut.id}.png``.

    Args:
      predicts:
        Indexable pair of model outputs. ``predicts[1]`` may be a list, in
        which case only its last element is plotted.
      batch:
        Must provide the keys ``"features"``, ``"features_lens"``,
        ``"tokens"``, ``"text"`` and ``"cut"`` (each cut exposes ``.id``).
      tokenizer:
        Callable applied to ``batch["tokens"]``; it must return
        ``(text_tokens, text_tokens_lens)`` with a 2-D ``text_tokens``.
      output_dir:
        Existing directory that receives the PNG files.
      limit:
        Maximum number of utterances of the batch to plot.
    """
    audio_features = batch["features"].to("cpu").detach().numpy()
    audio_features_lens = batch["features_lens"].to("cpu").detach().numpy()

    tokens = batch["tokens"]
    text_tokens, text_tokens_lens = tokenizer(tokens)
    assert text_tokens.ndim == 2

    texts = batch["text"]
    utt_ids = [cut.id for cut in batch["cut"]]

    encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
    decoder_outputs = predicts[1]
    if isinstance(decoder_outputs, list):
        decoder_outputs = decoder_outputs[-1]

    # Choose the colour range from the tensor's own dtype BEFORE casting it
    # to float32 for numpy. Testing the dtype after the cast is always true,
    # which made the integer-code range below unreachable.
    if torch.is_floating_point(decoder_outputs):
        vmin, vmax = -6, 0  # Fbank
    else:
        vmin, vmax = 0, 1024  # Encodec
    decoder_outputs = decoder_outputs.to("cpu").type(torch.float32).detach().numpy()

    num_figures = 3
    for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
        fig = plt.figure(figsize=(14, 8 * num_figures))
        try:
            S = text_tokens_lens[b]
            T = audio_features_lens[b]

            # encoder
            plt.subplot(num_figures, 1, 1)
            plt.title(f"Text: {text}")
            plt.imshow(
                X=np.transpose(encoder_outputs[b]),
                cmap=plt.get_cmap("jet"),
                aspect="auto",
                interpolation="nearest",
            )
            plt.gca().invert_yaxis()
            # Red line at the text length reported by the tokenizer
            # (presumably the boundary between real tokens and padding).
            plt.axvline(x=S - 0.4, linewidth=2, color="r")
            plt.xlabel("Encoder Output")
            plt.colorbar()

            # decoder
            plt.subplot(num_figures, 1, 2)
            plt.imshow(
                X=np.transpose(decoder_outputs[b]),
                cmap=plt.get_cmap("jet"),
                aspect="auto",
                interpolation="nearest",
                vmin=vmin,
                vmax=vmax,
            )
            plt.gca().invert_yaxis()
            # Red line at the feature length from batch["features_lens"].
            plt.axvline(x=T - 0.4, linewidth=2, color="r")
            plt.xlabel("Decoder Output")
            plt.colorbar()

            # target
            plt.subplot(num_figures, 1, 3)
            plt.imshow(
                X=np.transpose(audio_features[b]),
                cmap=plt.get_cmap("jet"),
                aspect="auto",
                interpolation="nearest",
                vmin=vmin,
                vmax=vmax,
            )
            plt.gca().invert_yaxis()
            plt.axvline(x=T - 0.4, linewidth=2, color="r")
            plt.xlabel("Decoder Target")
            plt.colorbar()

            plt.savefig(f"{output_dir}/{utt_id}.png")
        finally:
            # Always release the figure, even if plotting or saving fails,
            # so repeated validation runs do not accumulate open figures.
            plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
|
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
|
||||||
def top_k_top_p_filtering(
|
def top_k_top_p_filtering(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user