Merge d8a0a4095554e58db0fc8f4e30a6a33932ab37dd into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9
commit cfbd4208cc
egs/libritts/TTS/vocos/discriminators.py (Normal file, 296 lines added)
@@ -0,0 +1,296 @@
from typing import List, Optional, Tuple

import torch
from torch import nn
from torch.nn import Conv2d
from torch.nn.utils.parametrizations import weight_norm
from torchaudio.transforms import Spectrogram


class MultiPeriodDiscriminator(nn.Module):
    """
    Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan.
    Additionally, it allows incorporating conditional information with a learned embeddings table.

    Args:
        periods (tuple[int]): Tuple of periods for each discriminator.
        num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
            Defaults to None.
    """

    def __init__(
        self,
        periods: Tuple[int, ...] = (2, 3, 5, 7, 11),
        num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods]
        )

    def forward(
        self,
        y: torch.Tensor,
        y_hat: torch.Tensor,
        bandwidth_id: Optional[torch.Tensor] = None,
    ) -> Tuple[
        List[torch.Tensor],
        List[torch.Tensor],
        List[List[torch.Tensor]],
        List[List[torch.Tensor]],
    ]:
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for d in self.discriminators:
            y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
            y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorP(nn.Module):
    def __init__(
        self,
        period: int,
        in_channels: int = 1,
        kernel_size: int = 5,
        stride: int = 3,
        lrelu_slope: float = 0.1,
        num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.period = period
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv2d(
                        in_channels,
                        32,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(kernel_size // 2, 0),
                    )
                ),
                weight_norm(
                    Conv2d(
                        32,
                        128,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(kernel_size // 2, 0),
                    )
                ),
                weight_norm(
                    Conv2d(
                        128,
                        512,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(kernel_size // 2, 0),
                    )
                ),
                weight_norm(
                    Conv2d(
                        512,
                        1024,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(kernel_size // 2, 0),
                    )
                ),
                weight_norm(
                    Conv2d(
                        1024,
                        1024,
                        (kernel_size, 1),
                        (1, 1),
                        padding=(kernel_size // 2, 0),
                    )
                ),
            ]
        )
        if num_embeddings is not None:
            self.emb = torch.nn.Embedding(
                num_embeddings=num_embeddings, embedding_dim=1024
            )
            torch.nn.init.zeros_(self.emb.weight)

        self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
        self.lrelu_slope = lrelu_slope

    def forward(
        self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        x = x.unsqueeze(1)
        fmap = []
        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for i, l in enumerate(self.convs):
            x = l(x)
            x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
            if i > 0:
                fmap.append(x)
        if cond_embedding_id is not None:
            emb = self.emb(cond_embedding_id)
            h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
        else:
            h = 0
        x = self.conv_post(x)
        fmap.append(x)
        x += h
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiResolutionDiscriminator(nn.Module):
    def __init__(
        self,
        fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
        num_embeddings: Optional[int] = None,
    ):
        """
        Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
        Additionally, it allows incorporating conditional information with a learned embeddings table.

        Args:
            fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
            num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
                Defaults to None.
        """

        super().__init__()
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorR(window_length=w, num_embeddings=num_embeddings)
                for w in fft_sizes
            ]
        )

    def forward(
        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
    ) -> Tuple[
        List[torch.Tensor],
        List[torch.Tensor],
        List[List[torch.Tensor]],
        List[List[torch.Tensor]],
    ]:
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []

        for d in self.discriminators:
            y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
            y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorR(nn.Module):
    def __init__(
        self,
        window_length: int,
        num_embeddings: Optional[int] = None,
        channels: int = 32,
        hop_factor: float = 0.25,
        bands: Tuple[Tuple[float, float], ...] = (
            (0.0, 0.1),
            (0.1, 0.25),
            (0.25, 0.5),
            (0.5, 0.75),
            (0.75, 1.0),
        ),
    ):
        super().__init__()
        self.window_length = window_length
        self.hop_factor = hop_factor
        self.spec_fn = Spectrogram(
            n_fft=window_length,
            hop_length=int(window_length * hop_factor),
            win_length=window_length,
            power=None,
        )
        n_fft = window_length // 2 + 1
        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
        self.bands = bands
        convs = lambda: nn.ModuleList(
            [
                weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
                weight_norm(
                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
                ),
                weight_norm(
                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
                ),
                weight_norm(
                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
                ),
                weight_norm(
                    nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
                ),
            ]
        )
        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])

        if num_embeddings is not None:
            self.emb = torch.nn.Embedding(
                num_embeddings=num_embeddings, embedding_dim=channels
            )
            torch.nn.init.zeros_(self.emb.weight)

        self.conv_post = weight_norm(
            nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
        )

    def spectrogram(self, x):
        # Remove DC offset
        x = x - x.mean(dim=-1, keepdims=True)
        # Peak normalize the volume of input audio
        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
        x = self.spec_fn(x)
        x = torch.view_as_real(x)
        # x = rearrange(x, "b f t c -> b c t f")
        x = x.permute(0, 3, 2, 1)
        # Split into bands
        x_bands = [x[..., b[0] : b[1]] for b in self.bands]
        return x_bands

    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
        x_bands = self.spectrogram(x)
        fmap = []
        x = []
        for band, stack in zip(x_bands, self.band_convs):
            for i, layer in enumerate(stack):
                band = layer(band)
                band = torch.nn.functional.leaky_relu(band, 0.1)
                if i > 0:
                    fmap.append(band)
            x.append(band)
        x = torch.cat(x, dim=-1)
        if cond_embedding_id is not None:
            emb = self.emb(cond_embedding_id)
            h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
        else:
            h = 0
        x = self.conv_post(x)
        fmap.append(x)
        x += h

        return x, fmap
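A minimal sketch of how these discriminators might be driven during adversarial training. The batch size, waveform length, and the absence of bandwidth conditioning below are illustrative assumptions, not part of this recipe:

import torch

from discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator

mpd = MultiPeriodDiscriminator()
mrd = MultiResolutionDiscriminator()

# Hypothetical real and generated waveforms of shape (batch, num_samples).
y = torch.randn(2, 24000)
y_hat = torch.randn(2, 24000)

# Each call returns per-discriminator logits for the real and the generated
# inputs, plus the intermediate feature maps typically used for
# feature-matching losses.
real_logits, fake_logits, real_fmaps, fake_fmaps = mpd(y=y, y_hat=y_hat)
res_real, res_fake, res_real_fmaps, res_fake_fmaps = mrd(y=y, y_hat=y_hat)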
egs/libritts/TTS/vocos/export-onnx.py (Executable file, 371 lines added)
@@ -0,0 +1,371 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)

"""
This script exports the trained Vocos vocoder (the `generator` sub-module of
the model built by ./vocos/train.py) from PyTorch to ONNX.

Example usage (all flags below are defined in this script; adjust the values
to match your own experiment):

./vocos/export-onnx.py \
  --exp-dir vocos/exp \
  --epoch 28 \
  --avg 15 \
  --fp16 True

It generates the following files inside --exp-dir:

  - vocos-epoch-28-avg-15.onnx
  - vocos-epoch-28-avg-15.fp16.onnx (only when --fp16 is True)
  - vocos-epoch-28-avg-15.int8.onnx

The exported model takes a tensor of mel features named `features` with shape
(N, feature_dim, T) and produces a tensor of waveforms named `audio` with
shape (N, T').
"""

import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple

import onnx
import torch
import torch.nn as nn
from onnxconverter_common import float16
from onnxruntime.quantization import QuantType, quantize_dynamic
from train import add_model_arguments, get_model, get_params

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import make_pad_mask, num_tokens, str2bool


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--sampling-rate",
        type=int,
        default=24000,
        help="The sampling rate of the libritts dataset",
    )

    parser.add_argument(
        "--frame-shift",
        type=int,
        default=256,
        help="Frame shift.",
    )

    parser.add_argument(
        "--frame-length",
        type=int,
        default=1024,
        help="Frame length.",
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 0.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--fp16",
        type=str2bool,
        default=False,
        help="Whether to export models in fp16",
    )

    add_model_arguments(parser)

    return parser


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = value

    onnx.save(model, filename)


def export_model_onnx(
    model: nn.Module,
    model_filename: str,
    opset_version: int = 13,
) -> None:
    """Export the vocos generator to ONNX format.
    The exported model has one input:

      - features: a float32 tensor of shape (N, feature_dim, T)

    and produces one output:

      - audio: a float32 tensor of shape (N, T')
    """
    input_tensor = torch.rand((2, 80, 100), dtype=torch.float32)

    torch.onnx.export(
        model,
        (input_tensor,),
        model_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=[
            "features",
        ],
        output_names=["audio"],
        dynamic_axes={
            "features": {0: "N", 2: "F"},
            "audio": {0: "N", 1: "T"},
        },
    )

    meta_data = {
        "model_type": "Vocos",
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "ConvNext Vocos",
    }
    logging.info(f"meta_data: {meta_data}")

    add_meta_data(filename=model_filename, meta_data=meta_data)


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    params.device = device
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)
    model.to(device)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.eval()
    vocos = model.generator

    if params.iter > 0:
        suffix = f"iter-{params.iter}"
    else:
        suffix = f"epoch-{params.epoch}"

    suffix += f"-avg-{params.avg}"

    opset_version = 13

    logging.info("Exporting model")
    model_filename = params.exp_dir / f"vocos-{suffix}.onnx"
    export_model_onnx(
        vocos,
        model_filename,
        opset_version=opset_version,
    )
    logging.info(f"Exported vocos generator to {model_filename}")

    if params.fp16:
        logging.info("Generate fp16 models")

        model = onnx.load(model_filename)
        model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
        model_filename_fp16 = params.exp_dir / f"vocos-{suffix}.fp16.onnx"
        onnx.save(model_fp16, model_filename_fp16)

    # Generate int8 quantization models
    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

    logging.info("Generate int8 quantization models")

    model_filename_int8 = params.exp_dir / f"vocos-{suffix}.int8.onnx"
    quantize_dynamic(
        model_input=model_filename,
        model_output=model_filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
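Once the ONNX file has been produced, something like the following sketch can run it with onnxruntime; the file path and feature values are placeholders, while the `features`/`audio` names and shapes come from `export_model_onnx` above:

import numpy as np
import onnxruntime as ort

# Placeholder path; the actual name depends on --epoch/--iter and --avg.
session = ort.InferenceSession("vocos/exp/vocos-epoch-28-avg-15.onnx")

# A dummy batch of mel features with shape (N, feature_dim, T).
features = np.random.randn(1, 80, 100).astype(np.float32)

(audio,) = session.run(["audio"], {"features": features})
print(audio.shape)  # (1, num_samples)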
egs/libritts/TTS/vocos/export.py (Executable file, 407 lines added)
@@ -0,0 +1,407 @@
#!/usr/bin/env python3
#
# Copyright 2024 Xiaomi Corporation (Author: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:

Note: This is an example for the libritts dataset. If you are using a
different dataset, you should change the argument values according to
your dataset.

(1) Export to torchscript model using torch.jit.script()

./vocos/export.py \
  --exp-dir ./vocos/exp \
  --epoch 30 \
  --avg 9 \
  --jit 1

It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("jit_script.pt")`.

Check ./jit_pretrained.py for its usage.

Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.

(2) Export `model.state_dict()`

./vocos/export.py \
  --exp-dir ./vocos/exp \
  --epoch 30 \
  --avg 9

It will generate a file `generator.pt` in the given `exp_dir`, which contains
the state dict of the vocos generator. You can later load it with
`torch.load()` or `icefall.checkpoint.load_checkpoint()`.
"""

import argparse
import logging
from pathlib import Path
from typing import List, Tuple

import torch
from torch import Tensor, nn
from train import add_model_arguments, get_model, get_params

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
)

from icefall.utils import make_pad_mask, str2bool
from utils import load_checkpoint


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--sampling-rate",
        type=int,
        default=24000,
        help="The sampling rate of the libritts dataset",
    )

    parser.add_argument(
        "--frame-shift",
        type=int,
        default=256,
        help="Frame shift.",
    )

    parser.add_argument(
        "--frame-length",
        type=int,
        default=1024,
        help="Frame length.",
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=9,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="vocos/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--jit",
        type=str2bool,
        default=False,
        help="""True to save a model after applying torch.jit.script.
        It will generate a file named jit_script.pt.
        Check ./jit_pretrained.py for how to use it.
        """,
    )

    add_model_arguments(parser)

    return parser


class EncoderModel(nn.Module):
    """A wrapper for encoder and encoder_embed"""

    def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
        super().__init__()
        self.encoder = encoder
        self.encoder_embed = encoder_embed

    def forward(
        self, features: Tensor, feature_lengths: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          features: (N, T, C)
          feature_lengths: (N,)
        """
        x, x_lens = self.encoder_embed(features, feature_lengths)

        src_key_padding_mask = make_pad_mask(x_lens)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

        return encoder_out, encoder_out_lens


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    params.device = device
    logging.info(f"device: {device}")

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.eval()

    model = model.generator

    if params.jit is True:
        model.encoder = EncoderModel(model.encoder, model.encoder_embed)
        filename = "jit_script.pt"

        logging.info("Using torch.jit.script")
        model = torch.jit.script(model)
        model.save(str(params.exp_dir / filename))
        logging.info(f"Saved to {filename}")
    else:
        logging.info("Not using torchscript. Export model.state_dict()")
        # Save it using a format so that it can be loaded
        # by :func:`load_checkpoint`
        filename = params.exp_dir / "generator.pt"
        torch.save({"model": model.state_dict()}, str(filename))
        logging.info(f"Saved to {filename}")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
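The non-torchscript branch above saves the generator's state dict under the key "model". A hedged sketch of loading it back for standalone inference; the constructor arguments are the defaults from generator.py and are assumptions here, since they must match the training configuration:

import torch

from generator import Generator

# Assumed hyper-parameters; they must match what was used during training.
generator = Generator(feature_dim=80, dim=512, n_fft=1024, hop_length=256)

state = torch.load("vocos/exp/generator.pt", map_location="cpu")
generator.load_state_dict(state["model"])
generator.eval()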
egs/libritts/TTS/vocos/generator.py (Normal file, 264 lines added)
@@ -0,0 +1,264 @@
import logging
from typing import Optional

import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F


def window_sumsquare(
    window: torch.Tensor,
    n_samples: int,
    hop_length: int = 256,
    win_length: int = 1024,
):
    """
    Compute the sum-square envelope of a window function at a given hop length.
    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.

    Parameters
    ----------
    window : torch.Tensor
        The window function.
    n_samples : int > 0
        The number of expected samples.
    hop_length : int > 0
        The number of samples to advance between frames
    win_length :
        The length of the window function.

    Returns
    -------
    wss : torch.Tensor, The sum-squared envelope of the window function.
    """

    n_frames = (n_samples - win_length) // hop_length + 1
    output_size = (n_frames - 1) * hop_length + win_length
    device = window.device

    # Window envelope
    window_sq = window.square().expand(1, n_frames, -1).transpose(1, 2)
    window_envelope = torch.nn.functional.fold(
        window_sq,
        output_size=(1, output_size),
        kernel_size=(1, win_length),
        stride=(1, hop_length),
    ).squeeze()
    window_envelope = torch.nn.functional.pad(
        window_envelope, (0, n_samples - output_size)
    )
    return window_envelope


class ISTFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""

    def __init__(
        self,
        filter_length: int = 1024,
        hop_length: int = 256,
        win_length: int = 1024,
        padding: str = "none",
        window_type: str = "povey",
        max_samples: int = 1440000,  # 1440000 / 24000 = 60s
    ):
        super(ISTFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.padding = padding
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
        )

        assert filter_length >= win_length
        # Consistency with lhotse, search "create_frame_window" in https://github.com/lhotse-speech/lhotse
        assert window_type in [
            "hanning",
            "povey",
        ], f"Only 'hanning' and 'povey' windows are supported, given {window_type}."
        fft_window = torch.hann_window(win_length, periodic=False)
        if window_type == "povey":
            fft_window = fft_window.pow(0.85)

        if filter_length > win_length:
            pad_size = (filter_length - win_length) // 2
            fft_window = torch.nn.functional.pad(fft_window, (pad_size, pad_size))

        window_sum = window_sumsquare(
            window=fft_window,
            n_samples=max_samples,
            hop_length=hop_length,
            win_length=filter_length,
        )

        inverse_basis *= fft_window

        self.register_buffer("inverse_basis", inverse_basis.float())
        self.register_buffer("fft_window", fft_window)
        self.register_buffer("window_sum", window_sum)
        self.tiny = torch.finfo(torch.float16).tiny

    def forward(self, magnitude, phase):
        magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )
        inverse_transform = F.conv_transpose1d(
            magnitude_phase,
            Variable(self.inverse_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0,
        )
        inverse_transform = inverse_transform.squeeze(1)

        window_sum = self.window_sum
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            if self.window_sum.size(-1) < inverse_transform.size(-1):
                logging.warning(
                    f"The precomputed `window_sumsquare` is too small, recomputing, "
                    f"from {self.window_sum.size(-1)} to {inverse_transform.size(-1)}"
                )
                window_sum = window_sumsquare(
                    window=self.fft_window,
                    n_samples=inverse_transform.size(-1),
                    win_length=self.filter_length,
                    hop_length=self.hop_length,
                )
        window_sum = window_sum[: inverse_transform.size(-1)]
        approx_nonzero_indices = (window_sum > self.tiny).nonzero().squeeze()

        inverse_transform[:, approx_nonzero_indices] /= window_sum[
            approx_nonzero_indices
        ]

        # scale by hop ratio
        inverse_transform *= float(self.filter_length) / self.hop_length
        assert self.padding in ["none", "same", "center"]
        if self.padding == "center":
            pad_len = self.filter_length // 2
        elif self.padding == "same":
            pad_len = (self.filter_length - self.hop_length) // 2
        else:
            return inverse_transform
        return inverse_transform[:, pad_len:-pad_len]


class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        layer_scale_init_value: Optional[float] = None,
    ):
        super().__init__()
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # pointwise/1x1 convs, implemented with linear layers
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        x = residual + x
        return x


class Generator(torch.nn.Module):
    def __init__(
        self,
        feature_dim: int = 80,
        dim: int = 512,
        n_fft: int = 1024,
        hop_length: int = 256,
        intermediate_dim: int = 1536,
        num_layers: int = 8,
        padding: str = "none",
        max_samples: int = 1440000,  # 1440000 / 24000 = 60s
    ):
        super(Generator, self).__init__()
        self.feature_dim = feature_dim
        self.embed = nn.Conv1d(feature_dim, dim, kernel_size=7, padding=3)

        self.norm = nn.LayerNorm(dim, eps=1e-6)

        layer_scale_init_value = 1 / num_layers
        self.convnext = nn.ModuleList(
            [
                ConvNeXtBlock(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                )
                for _ in range(num_layers)
            ]
        )

        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
        self.apply(self._init_weights)

        self.out_proj = torch.nn.Linear(dim, n_fft + 2)
        self.istft = ISTFT(
            filter_length=n_fft,
            hop_length=hop_length,
            win_length=n_fft,
            padding=padding,
            max_samples=max_samples,
        )

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.embed(x)
        x = self.norm(x.transpose(1, 2))
        x = x.transpose(1, 2)
        for conv_block in self.convnext:
            x = conv_block(x)
        x = self.final_layer_norm(x.transpose(1, 2))
        x = self.out_proj(x).transpose(1, 2)
        mag, phase = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        # safeguard to prevent excessively large magnitudes
        mag = torch.clip(mag, max=1e2)
        audio = self.istft(mag, phase)
        return audio
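A small sketch of the feature-to-waveform path implemented by Generator.forward; the feature tensor below is random, and its shape (batch, feature_dim, frames) is the convention assumed by self.embed:

import torch

from generator import Generator

gen = Generator()  # defaults: feature_dim=80, n_fft=1024, hop_length=256

# 100 frames of 80-dim mel features for a single utterance.
mel = torch.randn(1, 80, 100)

with torch.no_grad():
    audio = gen(mel)

# With padding="none" the output length is roughly frames * hop_length
# samples plus the analysis window tail, here on the order of 100 * 256.
print(audio.shape)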
egs/libritts/TTS/vocos/infer.py (Executable file, 352 lines added)
@@ -0,0 +1,352 @@
#!/usr/bin/env python3
# Copyright    2024    Xiaomi Corp.        (authors: Wei Kang
#                                                    Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import logging
import math
import time
import os
from functools import partial
from pathlib import Path

import torch
import torch.nn as nn
from lhotse.utils import fix_random_seed
from scipy.io.wavfile import write
from train import add_model_arguments, get_model, get_params
from tts_datamodule import LibriTTSDataModule

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import AttributeDict, setup_logger, str2bool

LOG_EPS = math.log(1e-10)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=100,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=10,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=False,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="vocos/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--generate-dir",
        type=str,
        default="generated_wavs",
        help="Path name of the generated wavs",
    )

    add_model_arguments(parser)

    return parser


def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    batch: dict,
):
    """
    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The vocoder model that converts mel features into waveforms.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
    Returns:
      Nothing. The generated waveforms are written to
      `{params.res_dir}/{params.suffix}`.
    """
    device = next(model.parameters()).device

    cut_ids = [cut.id for cut in batch["cut"]]

    infer_time = 0
    audio_time = 0

    features = batch["features"]  # (B, T, F)
    utt_durations = batch["features_lens"]

    x = features.permute(0, 2, 1)  # (B, F, T)

    audio_time += torch.sum(utt_durations)

    start = time.time()

    audios = model(x.to(device))  # (B, T)

    infer_time += time.time() - start

    wav_dir = f"{params.res_dir}/{params.suffix}"
    os.makedirs(wav_dir, exist_ok=True)

    for i in range(audios.shape[0]):
        audio = audios[i][: int(utt_durations[i] * 256)]
        audio = audio.cpu().squeeze().numpy()
        write(f"{wav_dir}/{cut_ids[i]}.wav", 24000, audio)

    print(f"RTF : {infer_time / (audio_time * (256/24000))}")


def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    test_set: str,
):
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The vocoder model that converts mel features into waveforms.
      test_set:
        The name of the test_set
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    with open(f"{params.res_dir}/{test_set}.scp", "w", encoding="utf8") as f:
        for batch_idx, batch in enumerate(dl):
            # texts = batch["text"]
            cut_ids = [cut.id for cut in batch["cut"]]

            decode_one_batch(
                params=params,
                model=model,
                batch=batch,
            )

            # assert len(texts) == len(cut_ids), (len(texts), len(cut_ids))

            # for i in range(len(texts)):
            #     f.write(f"{cut_ids[i]}\t{texts[i]}\n")

            # num_cuts += len(texts)

            if batch_idx % 50 == 0:
                batch_str = f"{batch_idx}/{num_batches}"

                logging.info(
                    f"batch {batch_str}, cuts processed until now is {num_cuts}"
                )


@torch.no_grad()
def main():
    parser = get_parser()
    LibriTTSDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    params.res_dir = params.exp_dir / params.generate_dir

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    params.device = device

    logging.info(f"Device: {device}")

    logging.info(params)
    fix_random_seed(666)

    logging.info("About to create model")
    model = get_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model = model.to(device)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to name the generated wavs.
    args.return_cuts = True
    libritts = LibriTTSDataModule(args)

    test_cuts = libritts.test_clean_cuts()

    test_dl = libritts.test_dataloaders(test_cuts)

    test_sets = ["test"]
    test_dls = [test_dl]

    for test_set, test_dl in zip(test_sets, test_dls):
        decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            test_set=test_set,
        )

    logging.info("Done!")


if __name__ == "__main__":
    main()
133
egs/libritts/TTS/vocos/loss.py
Normal file
@ -0,0 +1,133 @@
from typing import List, Tuple

import torch
import torchaudio
from torch import nn

from utils import safe_log


class MelSpecReconstructionLoss(nn.Module):
    """
    L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
    """

    def __init__(
        self,
        sample_rate: int = 24000,
        n_fft: int = 1024,
        hop_length: int = 256,
        n_mels: int = 100,
    ):
        super().__init__()
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            center=True,
            power=1,
        )

    def forward(self, y_hat, y) -> torch.Tensor:
        """
        Args:
            y_hat (Tensor): Predicted audio waveform.
            y (Tensor): Ground truth audio waveform.

        Returns:
            Tensor: L1 loss between the mel-scaled magnitude spectrograms.
        """
        mel_hat = safe_log(self.mel_spec(y_hat))
        mel = safe_log(self.mel_spec(y))

        loss = torch.nn.functional.l1_loss(mel, mel_hat)

        return loss


class GeneratorLoss(nn.Module):
    """
    Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
    """

    def forward(
        self, disc_outputs: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Args:
            disc_outputs (List[Tensor]): List of discriminator outputs.

        Returns:
            Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
                the sub-discriminators
        """
        loss = torch.zeros(
            1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype
        )
        gen_losses = []
        for dg in disc_outputs:
            l = torch.mean(torch.clamp(1 - dg, min=0))
            gen_losses.append(l)
            loss += l

        return loss, gen_losses


class DiscriminatorLoss(nn.Module):
    """
    Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
    """

    def forward(
        self,
        disc_real_outputs: List[torch.Tensor],
        disc_generated_outputs: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        """
        Args:
            disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
            disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.

        Returns:
            Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
                the sub-discriminators for real outputs, and a list of
                loss values for generated outputs.
        """
        loss = torch.zeros(
            1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype
        )
        r_losses = []
        g_losses = []
        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
            r_loss = torch.mean(torch.clamp(1 - dr, min=0))
            g_loss = torch.mean(torch.clamp(1 + dg, min=0))
            loss += r_loss + g_loss
            r_losses.append(r_loss.item())
            g_losses.append(g_loss.item())

        return loss, r_losses, g_losses


class FeatureMatchingLoss(nn.Module):
    """
    Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
    """

    def forward(
        self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
    ) -> torch.Tensor:
        """
        Args:
            fmap_r (List[List[Tensor]]): List of feature maps from real samples.
            fmap_g (List[List[Tensor]]): List of feature maps from generated samples.

        Returns:
            Tensor: The calculated feature matching loss.
        """
        loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype)
        for dr, dg in zip(fmap_r, fmap_g):
            for rl, gl in zip(dr, dg):
                loss += torch.mean(torch.abs(rl - gl))

        return loss
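A minimal sketch of how these modules are typically combined in GAN vocoder training, assuming loss.py and utils.py above are importable; the dummy tensors and the loss weights are illustrative assumptions, not values taken from this recipe:

import torch

# Dummy discriminator outputs / feature maps standing in for MPD/MRD results.
real_scores = [torch.randn(2, 10) for _ in range(3)]
fake_scores = [torch.randn(2, 10) for _ in range(3)]
fmap_real = [[torch.randn(2, 8, 10)] for _ in range(3)]
fmap_fake = [[torch.randn(2, 8, 10)] for _ in range(3)]
real_audio = torch.randn(2, 24000)  # dummy ground-truth waveforms
fake_audio = torch.randn(2, 24000)  # dummy generated waveforms

# Discriminator objective: hinge loss on real vs. generated scores.
d_loss, _, _ = DiscriminatorLoss()(real_scores, fake_scores)

# Generator objective: adversarial + feature matching + mel reconstruction.
g_adv, _ = GeneratorLoss()(fake_scores)
g_fm = FeatureMatchingLoss()(fmap_real, fmap_fake)
g_mel = MelSpecReconstructionLoss(sample_rate=24000)(fake_audio, real_audio)
g_loss = g_adv + 2.0 * g_fm + 45.0 * g_mel  # the weights here are illustrative only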
48
egs/libritts/TTS/vocos/model.py
Normal file
@ -0,0 +1,48 @@
import logging
import torch
from discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator
from generator import Generator
from loss import (
    DiscriminatorLoss,
    GeneratorLoss,
    FeatureMatchingLoss,
    MelSpecReconstructionLoss,
)


class Vocos(torch.nn.Module):
    def __init__(
        self,
        feature_dim: int = 80,
        dim: int = 512,
        n_fft: int = 1024,
        hop_length: int = 256,
        intermediate_dim: int = 1536,
        num_layers: int = 8,
        padding: str = "none",
        sample_rate: int = 24000,
        max_seconds: int = 60,
    ):
        super(Vocos, self).__init__()
        self.generator = Generator(
            feature_dim=feature_dim,
            dim=dim,
            n_fft=n_fft,
            hop_length=hop_length,
            num_layers=num_layers,
            intermediate_dim=intermediate_dim,
            padding=padding,
            max_samples=int(sample_rate * max_seconds),
        )

        self.mpd = MultiPeriodDiscriminator()
        self.mrd = MultiResolutionDiscriminator()

        self.disc_loss = DiscriminatorLoss()
        self.gen_loss = GeneratorLoss()
        self.feat_matching_loss = FeatureMatchingLoss()
        self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate)

    def forward(self, features: torch.Tensor):
        audio = self.generator(features)
        return audio
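A minimal copy-synthesis sketch, assuming the Vocos class above and its dependencies (generator.py, discriminators.py, loss.py) are importable; the feature layout (batch, feature_dim, num_frames) follows how pretrained.py feeds features to the generator:

import torch

model = Vocos(feature_dim=80, n_fft=1024, hop_length=256, sample_rate=24000)
model.eval()

features = torch.randn(1, 80, 200)  # (batch, feature_dim, num_frames), dummy fbank
with torch.no_grad():
    audio = model(features)  # generated waveform, roughly num_frames * hop_length samples
print(audio.shape)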
268
egs/libritts/TTS/vocos/onnx_pretrained.py
Executable file
@ -0,0 +1,268 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This script loads ONNX models and uses them to decode waves.
|
||||||
|
You can use the following command to get the exported models:
|
||||||
|
|
||||||
|
We use the pre-trained model from
|
||||||
|
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||||
|
as an example to show how to use this file.
|
||||||
|
|
||||||
|
1. Download the pre-trained model
|
||||||
|
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
2. Export the model to ONNX
|
||||||
|
|
||||||
|
./zipformer/export-onnx.py \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--causal False
|
||||||
|
|
||||||
|
It will generate the following 3 files inside $repo/exp:
|
||||||
|
|
||||||
|
- encoder-epoch-99-avg-1.onnx
|
||||||
|
- decoder-epoch-99-avg-1.onnx
|
||||||
|
- joiner-epoch-99-avg-1.onnx
|
||||||
|
|
||||||
|
3. Run this file
|
||||||
|
|
||||||
|
./zipformer/onnx_pretrained.py \
|
||||||
|
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||||
|
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||||
|
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import onnxruntime as ort
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from lhotse import Fbank, FbankConfig
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the encoder onnx model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sampling-rate",
|
||||||
|
type=int,
|
||||||
|
default=24000,
|
||||||
|
help="The sampleing rate of libritts dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--frame-shift",
|
||||||
|
type=int,
|
||||||
|
default=256,
|
||||||
|
help="Frame shift.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--frame-length",
|
||||||
|
type=int,
|
||||||
|
default=1024,
|
||||||
|
help="Frame shift.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-fft-mag",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to use magnitude of fbank, false to use power energy.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
type=str,
|
||||||
|
default="generated_audios",
|
||||||
|
help="The generated will be written to.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_files",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="The input sound file(s) to transcribe. "
|
||||||
|
"Supported formats are those supported by torchaudio.load(). "
|
||||||
|
"For example, wav and flac are supported. "
|
||||||
|
"The sample rate has to be 16kHz.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxModel:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_filename: str,
|
||||||
|
):
|
||||||
|
session_opts = ort.SessionOptions()
|
||||||
|
session_opts.inter_op_num_threads = 1
|
||||||
|
session_opts.intra_op_num_threads = 4
|
||||||
|
|
||||||
|
self.session_opts = session_opts
|
||||||
|
|
||||||
|
self.init_model(model_filename)
|
||||||
|
|
||||||
|
def init_model(self, model_filename: str):
|
||||||
|
self.model = ort.InferenceSession(
|
||||||
|
model_filename,
|
||||||
|
sess_options=self.session_opts,
|
||||||
|
providers=["CPUExecutionProvider"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_model(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A 3-D tensor of shape (N, T, C)
|
||||||
|
x_lens:
|
||||||
|
A 2-D tensor of shape (N,). Its dtype is torch.int64
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing:
|
||||||
|
- encoder_out, its shape is (N, T', joiner_dim)
|
||||||
|
- encoder_out_lens, its shape is (N,)
|
||||||
|
"""
|
||||||
|
out = self.model.run(
|
||||||
|
[
|
||||||
|
self.model.get_outputs()[0].name,
|
||||||
|
],
|
||||||
|
{
|
||||||
|
self.model.get_inputs()[0].name: x.numpy(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return torch.from_numpy(out[0])
|
||||||
|
|
||||||
|
|
||||||
|
def read_sound_files(
|
||||||
|
filenames: List[str], expected_sample_rate: float
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||||
|
Args:
|
||||||
|
filenames:
|
||||||
|
A list of sound filenames.
|
||||||
|
expected_sample_rate:
|
||||||
|
The expected sample rate of the sound files.
|
||||||
|
Returns:
|
||||||
|
Return a list of 1-D float32 torch tensors.
|
||||||
|
"""
|
||||||
|
ans = []
|
||||||
|
for f in filenames:
|
||||||
|
wave, sample_rate = torchaudio.load(f)
|
||||||
|
assert (
|
||||||
|
sample_rate == expected_sample_rate
|
||||||
|
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||||
|
# We use only the first channel
|
||||||
|
ans.append(wave[0])
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
output_dir = Path(args.model_filename).parent / args.output_dir
|
||||||
|
output_dir.mkdir(exist_ok=True)
|
||||||
|
args.output_dir = output_dir
|
||||||
|
logging.info(vars(args))
|
||||||
|
|
||||||
|
model = OnnxModel(model_filename=args.model_filename)
|
||||||
|
|
||||||
|
config = FbankConfig(
|
||||||
|
sampling_rate=args.sampling_rate,
|
||||||
|
frame_length=args.frame_length / args.sampling_rate, # (in second),
|
||||||
|
frame_shift=args.frame_shift / args.sampling_rate, # (in second)
|
||||||
|
use_fft_mag=args.use_fft_mag,
|
||||||
|
)
|
||||||
|
fbank = Fbank(config)
|
||||||
|
|
||||||
|
logging.info(f"Reading sound files: {args.sound_files}")
|
||||||
|
|
||||||
|
waves = read_sound_files(
|
||||||
|
filenames=args.sound_files, expected_sample_rate=args.sampling_rate
|
||||||
|
)
|
||||||
|
wave_lengths = [w.size(0) for w in waves]
|
||||||
|
waves = pad_sequence(waves, batch_first=True, padding_value=0)
|
||||||
|
|
||||||
|
logging.info(f"waves : {waves.shape}")
|
||||||
|
|
||||||
|
features = fbank.extract_batch(waves, sampling_rate=args.sampling_rate)
|
||||||
|
|
||||||
|
if features.dim() == 2:
|
||||||
|
features = features.unsqueeze(0)
|
||||||
|
|
||||||
|
features = features.permute(0, 2, 1)
|
||||||
|
|
||||||
|
logging.info(f"features : {features.shape}")
|
||||||
|
|
||||||
|
logging.info("Generating started")
|
||||||
|
|
||||||
|
# model forward
|
||||||
|
audios = model.run_model(features)
|
||||||
|
|
||||||
|
for i, filename in enumerate(args.sound_files):
|
||||||
|
audio = audios[i : i + 1, 0 : wave_lengths[i]]
|
||||||
|
ofilename = args.output_dir / filename.split("/")[-1]
|
||||||
|
logging.info(f"Writting audio : {ofilename}")
|
||||||
|
torchaudio.save(str(ofilename), audio.cpu(), args.sampling_rate)
|
||||||
|
|
||||||
|
logging.info("Generating Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
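A minimal sketch of driving the OnnxModel wrapper above outside of main(); the model path is a placeholder and the feature shape mirrors what main() passes in:

import torch

vocoder = OnnxModel(model_filename="exp/vocos.onnx")  # placeholder path
features = torch.randn(1, 80, 200)  # (batch, num_mel_bins, num_frames)
audio = vocoder.run_model(features)  # one generated waveform per row
print(audio.shape)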
196
egs/libritts/TTS/vocos/pretrained.py
Executable file
@ -0,0 +1,196 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This script loads a checkpoint and uses it to decode waves.
|
||||||
|
You can generate the checkpoint with the following command:
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
from lhotse import Fbank, FbankConfig
|
||||||
|
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the checkpoint. "
|
||||||
|
"The checkpoint is assumed to be saved by "
|
||||||
|
"icefall.checkpoint.save_checkpoint().",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sampling-rate",
|
||||||
|
type=int,
|
||||||
|
default=24000,
|
||||||
|
help="The sampleing rate of libritts dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--frame-shift",
|
||||||
|
type=int,
|
||||||
|
default=256,
|
||||||
|
help="Frame shift.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--frame-length",
|
||||||
|
type=int,
|
||||||
|
default=1024,
|
||||||
|
help="Frame shift.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-fft-mag",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to use magnitude of fbank, false to use power energy.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
type=str,
|
||||||
|
default="generated_audios",
|
||||||
|
help="The generated will be written to.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_files",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="The input sound file(s) to transcribe. "
|
||||||
|
"Supported formats are those supported by torchaudio.load(). "
|
||||||
|
"For example, wav and flac are supported. "
|
||||||
|
"The sample rate has to be 16kHz.",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def read_sound_files(
|
||||||
|
filenames: List[str], expected_sample_rate: float
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||||
|
Args:
|
||||||
|
filenames:
|
||||||
|
A list of sound filenames.
|
||||||
|
expected_sample_rate:
|
||||||
|
The expected sample rate of the sound files.
|
||||||
|
Returns:
|
||||||
|
Return a list of 1-D float32 torch tensors.
|
||||||
|
"""
|
||||||
|
ans = []
|
||||||
|
for f in filenames:
|
||||||
|
wave, sample_rate = torchaudio.load(f)
|
||||||
|
assert (
|
||||||
|
sample_rate == expected_sample_rate
|
||||||
|
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||||
|
# We use only the first channel
|
||||||
|
ans.append(wave[0].contiguous())
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
params.device = device
|
||||||
|
|
||||||
|
output_dir = Path(params.checkpoint).parent / params.output_dir
|
||||||
|
output_dir.mkdir(exist_ok=True)
|
||||||
|
params.output_dir = output_dir
|
||||||
|
|
||||||
|
logging.info(f"{params}")
|
||||||
|
|
||||||
|
logging.info("Creating model")
|
||||||
|
model = get_model(params)
|
||||||
|
|
||||||
|
model = model.generator
|
||||||
|
|
||||||
|
checkpoint = torch.load(params.checkpoint, map_location="cpu")
|
||||||
|
model.load_state_dict(checkpoint["model"], strict=False)
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
logging.info("Constructing Fbank computer")
|
||||||
|
|
||||||
|
config = FbankConfig(
|
||||||
|
sampling_rate=params.sampling_rate,
|
||||||
|
frame_length=params.frame_length / params.sampling_rate, # (in second),
|
||||||
|
frame_shift=params.frame_shift / params.sampling_rate, # (in second)
|
||||||
|
use_fft_mag=params.use_fft_mag,
|
||||||
|
)
|
||||||
|
fbank = Fbank(config)
|
||||||
|
|
||||||
|
logging.info(f"Reading sound files: {params.sound_files}")
|
||||||
|
|
||||||
|
waves = read_sound_files(
|
||||||
|
filenames=params.sound_files, expected_sample_rate=params.sampling_rate
|
||||||
|
)
|
||||||
|
wave_lengths = [w.size(0) for w in waves]
|
||||||
|
waves = pad_sequence(waves, batch_first=True, padding_value=0)
|
||||||
|
|
||||||
|
features = (
|
||||||
|
fbank.extract_batch(waves, sampling_rate=params.sampling_rate)
|
||||||
|
.permute(0, 2, 1)
|
||||||
|
.to(device)
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Generating started")
|
||||||
|
|
||||||
|
# model forward
|
||||||
|
audios = model(features)
|
||||||
|
|
||||||
|
for i, filename in enumerate(params.sound_files):
|
||||||
|
audio = audios[i : i + 1, 0 : wave_lengths[i]]
|
||||||
|
ofilename = params.output_dir / filename.split("/")[-1]
|
||||||
|
logging.info(f"Writting audio : {ofilename}")
|
||||||
|
torchaudio.save(str(ofilename), audio.cpu(), params.sampling_rate)
|
||||||
|
|
||||||
|
logging.info("Generating Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
1054
egs/libritts/TTS/vocos/train.py
Executable file
File diff suppressed because it is too large
419
egs/libritts/TTS/vocos/tts_datamodule.py
Normal file
@ -0,0 +1,419 @@
|
|||||||
|
# Copyright 2021 Piotr Żelasko
|
||||||
|
# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
|
||||||
|
# Zengwei Yao,
|
||||||
|
# Wei Kang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
|
||||||
|
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||||
|
CutConcatenate,
|
||||||
|
CutMix,
|
||||||
|
DynamicBucketingSampler,
|
||||||
|
PrecomputedFeatures,
|
||||||
|
SimpleCutSampler,
|
||||||
|
SpecAugment,
|
||||||
|
SpeechSynthesisDataset,
|
||||||
|
)
|
||||||
|
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||||
|
AudioSamples,
|
||||||
|
OnTheFlyFeatures,
|
||||||
|
)
|
||||||
|
from lhotse.utils import fix_random_seed
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
|
class _SeedWorkers:
|
||||||
|
def __init__(self, seed: int):
|
||||||
|
self.seed = seed
|
||||||
|
|
||||||
|
def __call__(self, worker_id: int):
|
||||||
|
fix_random_seed(self.seed + worker_id)
|
||||||
|
|
||||||
|
|
||||||
|
class LibriTTSDataModule:
|
||||||
|
"""
|
||||||
|
DataModule for tts experiments.
|
||||||
|
It assumes there is always one train and valid dataloader,
|
||||||
|
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||||
|
and test-other).
|
||||||
|
|
||||||
|
It contains all the common data pipeline modules used in TTS
experiments, e.g.:
|
||||||
|
- dynamic batch size,
|
||||||
|
- bucketing samplers,
|
||||||
|
- cut concatenation,
|
||||||
|
- on-the-fly feature extraction
|
||||||
|
|
||||||
|
This class should be derived for specific corpora used in TTS tasks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, args: argparse.Namespace):
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||||
|
group = parser.add_argument_group(
|
||||||
|
title="TTS data related options",
|
||||||
|
description="These options are used for the preparation of "
|
||||||
|
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||||
|
"effective batch sizes, sampling strategies, applied data "
|
||||||
|
"augmentations, etc.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--manifest-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Path("data/fbank"),
|
||||||
|
help="Path to directory with train/valid/test cuts.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--max-duration",
|
||||||
|
type=int,
|
||||||
|
default=200.0,
|
||||||
|
help="Maximum pooled recordings duration (seconds) in a "
|
||||||
|
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--bucketing-sampler",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled, the batches will come from buckets of "
|
||||||
|
"similar duration (saves padding frames).",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--num-buckets",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="The number of buckets for the DynamicBucketingSampler"
|
||||||
|
"(you might want to increase it for larger datasets).",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--on-the-fly-feats",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="When enabled, use on-the-fly cut mixing and feature "
|
||||||
|
"extraction. Will drop existing precomputed feature manifests "
|
||||||
|
"if available.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--shuffle",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled (=default), the examples will be "
|
||||||
|
"shuffled for each epoch.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--drop-last",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to drop last batch. Used by sampler.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--return-cuts",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled, each batch will have the "
|
||||||
|
"field: batch['cut'] with the cuts that "
|
||||||
|
"were used to construct it.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--return-text",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to return the text of the audio.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--return-tokens",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Whether the return the tokens of the text of the audio.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--num-workers",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="The number of training dataloader workers that "
|
||||||
|
"collect the batches.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--sampling-rate",
|
||||||
|
type=int,
|
||||||
|
default=24000,
|
||||||
|
help="The sampleing rate of libritts dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--frame-shift",
|
||||||
|
type=int,
|
||||||
|
default=256,
|
||||||
|
help="Frame shift.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--frame-length",
|
||||||
|
type=int,
|
||||||
|
default=1024,
|
||||||
|
help="Frame shift.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--input-strategy",
|
||||||
|
type=str,
|
||||||
|
default="PrecomputedFeatures",
|
||||||
|
help="AudioSamples or PrecomputedFeatures",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--use-fft-mag",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to use magnitude of fbank, false to use power energy.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def train_dataloaders(
|
||||||
|
self,
|
||||||
|
cuts_train: CutSet,
|
||||||
|
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> DataLoader:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
cuts_train:
|
||||||
|
CutSet for training.
|
||||||
|
sampler_state_dict:
|
||||||
|
The state dict for the training sampler.
|
||||||
|
"""
|
||||||
|
logging.info("About to create train dataset")
|
||||||
|
train = SpeechSynthesisDataset(
|
||||||
|
return_text=self.args.return_text,
|
||||||
|
return_tokens=self.args.return_tokens,
|
||||||
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.on_the_fly_feats:
|
||||||
|
sampling_rate = self.args.sampling_rate
|
||||||
|
config = FbankConfig(
|
||||||
|
sampling_rate=sampling_rate,
|
||||||
|
frame_length=self.args.frame_length / sampling_rate, # (in second),
|
||||||
|
frame_shift=self.args.frame_shift / sampling_rate, # (in second)
|
||||||
|
use_fft_mag=self.args.use_fft_mag,
|
||||||
|
)
|
||||||
|
train = SpeechSynthesisDataset(
|
||||||
|
return_text=self.args.return_text,
|
||||||
|
return_tokens=self.args.return_tokens,
|
||||||
|
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.bucketing_sampler:
|
||||||
|
logging.info("Using DynamicBucketingSampler.")
|
||||||
|
train_sampler = DynamicBucketingSampler(
|
||||||
|
cuts_train,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=self.args.shuffle,
|
||||||
|
num_buckets=self.args.num_buckets,
|
||||||
|
buffer_size=self.args.num_buckets * 2000,
|
||||||
|
shuffle_buffer_size=self.args.num_buckets * 5000,
|
||||||
|
drop_last=self.args.drop_last,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.info("Using SimpleCutSampler.")
|
||||||
|
train_sampler = SimpleCutSampler(
|
||||||
|
cuts_train,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=self.args.shuffle,
|
||||||
|
)
|
||||||
|
logging.info("About to create train dataloader")
|
||||||
|
|
||||||
|
if sampler_state_dict is not None:
|
||||||
|
logging.info("Loading sampler state dict")
|
||||||
|
train_sampler.load_state_dict(sampler_state_dict)
|
||||||
|
|
||||||
|
# 'seed' is derived from the current random state, which will have
|
||||||
|
# previously been set in the main process.
|
||||||
|
seed = torch.randint(0, 100000, ()).item()
|
||||||
|
worker_init_fn = _SeedWorkers(seed)
|
||||||
|
|
||||||
|
train_dl = DataLoader(
|
||||||
|
train,
|
||||||
|
sampler=train_sampler,
|
||||||
|
batch_size=None,
|
||||||
|
num_workers=self.args.num_workers,
|
||||||
|
persistent_workers=False,
|
||||||
|
worker_init_fn=worker_init_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_dl
|
||||||
|
|
||||||
|
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||||
|
logging.info("About to create dev dataset")
|
||||||
|
if self.args.on_the_fly_feats:
|
||||||
|
sampling_rate = self.args.sampling_rate
|
||||||
|
config = FbankConfig(
|
||||||
|
sampling_rate=sampling_rate,
|
||||||
|
frame_length=self.args.frame_length / sampling_rate, # (in second),
|
||||||
|
frame_shift=self.args.frame_shift / sampling_rate, # (in second)
|
||||||
|
use_fft_mag=self.args.use_fft_mag,
|
||||||
|
)
|
||||||
|
validate = SpeechSynthesisDataset(
|
||||||
|
return_text=self.args.return_text,
|
||||||
|
return_tokens=self.args.return_tokens,
|
||||||
|
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
validate = SpeechSynthesisDataset(
|
||||||
|
return_text=self.args.return_text,
|
||||||
|
return_tokens=self.args.return_tokens,
|
||||||
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
valid_sampler = DynamicBucketingSampler(
|
||||||
|
cuts_valid,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
num_buckets=self.args.num_buckets,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
logging.info("About to create valid dataloader")
|
||||||
|
valid_dl = DataLoader(
|
||||||
|
validate,
|
||||||
|
sampler=valid_sampler,
|
||||||
|
batch_size=None,
|
||||||
|
num_workers=2,
|
||||||
|
persistent_workers=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return valid_dl
|
||||||
|
|
||||||
|
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||||
|
logging.info("About to create test dataset")
|
||||||
|
if self.args.on_the_fly_feats:
|
||||||
|
sampling_rate = self.args.sampling_rate
|
||||||
|
config = FbankConfig(
|
||||||
|
sampling_rate=sampling_rate,
|
||||||
|
frame_length=self.args.frame_length / sampling_rate, # (in second),
|
||||||
|
frame_shift=self.args.frame_shift / sampling_rate, # (in second)
|
||||||
|
use_fft_mag=self.args.use_fft_mag,
|
||||||
|
)
|
||||||
|
test = SpeechSynthesisDataset(
|
||||||
|
return_text=self.args.return_text,
|
||||||
|
return_tokens=self.args.return_tokens,
|
||||||
|
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
test = SpeechSynthesisDataset(
|
||||||
|
return_text=self.args.return_text,
|
||||||
|
return_tokens=self.args.return_tokens,
|
||||||
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
test_sampler = DynamicBucketingSampler(
|
||||||
|
cuts,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
num_buckets=self.args.num_buckets,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
logging.info("About to create test dataloader")
|
||||||
|
test_dl = DataLoader(
|
||||||
|
test,
|
||||||
|
batch_size=None,
|
||||||
|
sampler=test_sampler,
|
||||||
|
num_workers=self.args.num_workers,
|
||||||
|
)
|
||||||
|
return test_dl
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def train_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get train cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def train_clean_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get train clean cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "libritts_cuts_train-clean-460.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def train_clean_100_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get train clean 100 cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "libritts_cuts_train-clean-100.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def train_clean_360_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get train clean 360 cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "libritts_cuts_train-clean-360.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def dev_clean_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get dev clean cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def dev_other_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get dev other cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def test_clean_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get test clean cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def test_other_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get test other cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def train_cuts_finetune(self) -> CutSet:
|
||||||
|
logging.info("About to get train cuts finetune")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "libritts_cuts_train_finetune.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def valid_cuts_finetune(self) -> CutSet:
|
||||||
|
logging.info("About to get validation cuts finetune")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "libritts_cuts_valid_finetune.jsonl.gz"
|
||||||
|
)
|
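A minimal usage sketch for LibriTTSDataModule, mirroring how the decoding script uses it; it assumes the LibriTTS manifests have already been prepared under data/fbank, and the "features" batch key is an assumption about lhotse's SpeechSynthesisDataset:

import argparse

parser = argparse.ArgumentParser()
LibriTTSDataModule.add_arguments(parser)
args = parser.parse_args([])  # use the defaults, e.g. manifests under data/fbank

libritts = LibriTTSDataModule(args)
test_cuts = libritts.test_clean_cuts()
test_dl = libritts.test_dataloaders(test_cuts)

for batch in test_dl:
    print(batch["features"].shape)  # assumed key; cuts are in batch["cut"] when --return-cuts=True
    break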
282
egs/libritts/TTS/vocos/utils.py
Normal file
@ -0,0 +1,282 @@
|
|||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import matplotlib
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
from torch.nn.utils import weight_norm
|
||||||
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.cuda.amp import GradScaler
|
||||||
|
from lhotse.dataset.sampling.base import CutSampler
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
|
|
||||||
|
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pylab as plt
|
||||||
|
|
||||||
|
|
||||||
|
def plot_spectrogram(spectrogram):
|
||||||
|
fig, ax = plt.subplots(figsize=(10, 2))
|
||||||
|
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
||||||
|
plt.colorbar(im, ax=ax)
|
||||||
|
|
||||||
|
fig.canvas.draw()
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def save_checkpoint_with_global_batch_idx(
|
||||||
|
out_dir: Path,
|
||||||
|
global_batch_idx: int,
|
||||||
|
model: Union[nn.Module, DDP],
|
||||||
|
model_avg: Optional[nn.Module] = None,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
optimizer_g: Optional[Optimizer] = None,
|
||||||
|
optimizer_d: Optional[Optimizer] = None,
|
||||||
|
scheduler_g: Optional[LRScheduler] = None,
|
||||||
|
scheduler_d: Optional[LRScheduler] = None,
|
||||||
|
scaler: Optional[GradScaler] = None,
|
||||||
|
sampler: Optional[CutSampler] = None,
|
||||||
|
rank: int = 0,
|
||||||
|
):
|
||||||
|
"""Save training info after processing given number of batches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
out_dir:
|
||||||
|
The directory to save the checkpoint.
|
||||||
|
global_batch_idx:
|
||||||
|
The number of batches processed so far from the very start of the
|
||||||
|
training. The saved checkpoint will have the following filename:
|
||||||
|
|
||||||
|
f'out_dir / checkpoint-{global_batch_idx}.pt'
|
||||||
|
model:
|
||||||
|
The neural network model whose `state_dict` will be saved in the
|
||||||
|
checkpoint.
|
||||||
|
model_avg:
|
||||||
|
The stored model averaged from the start of training.
|
||||||
|
params:
|
||||||
|
A dict of training configurations to be saved.
|
||||||
|
optimizer:
|
||||||
|
The optimizer used in the training. Its `state_dict` will be saved.
|
||||||
|
scheduler:
|
||||||
|
The learning rate scheduler used in the training. Its `state_dict` will
|
||||||
|
be saved.
|
||||||
|
scaler:
|
||||||
|
The scaler used for mix precision training. Its `state_dict` will
|
||||||
|
be saved.
|
||||||
|
sampler:
|
||||||
|
The sampler used in the training dataset.
|
||||||
|
rank:
|
||||||
|
The rank ID used in DDP training of the current node. Set it to 0
|
||||||
|
if DDP is not used.
|
||||||
|
"""
|
||||||
|
out_dir = Path(out_dir)
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
|
||||||
|
save_checkpoint(
|
||||||
|
filename=filename,
|
||||||
|
model=model,
|
||||||
|
model_avg=model_avg,
|
||||||
|
params=params,
|
||||||
|
optimizer_g=optimizer_g,
|
||||||
|
scheduler_g=scheduler_g,
|
||||||
|
optimizer_d=optimizer_d,
|
||||||
|
scheduler_d=scheduler_d,
|
||||||
|
scaler=scaler,
|
||||||
|
sampler=sampler,
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint(
|
||||||
|
filename: Path,
|
||||||
|
model: nn.Module,
|
||||||
|
model_avg: Optional[nn.Module] = None,
|
||||||
|
optimizer_g: Optional[Optimizer] = None,
|
||||||
|
optimizer_d: Optional[Optimizer] = None,
|
||||||
|
scheduler_g: Optional[LRScheduler] = None,
|
||||||
|
scheduler_d: Optional[LRScheduler] = None,
|
||||||
|
scaler: Optional[GradScaler] = None,
|
||||||
|
sampler: Optional[CutSampler] = None,
|
||||||
|
strict: bool = False,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
logging.info(f"Loading checkpoint from {filename}")
|
||||||
|
checkpoint = torch.load(filename, map_location="cpu")
|
||||||
|
|
||||||
|
if next(iter(checkpoint["model"])).startswith("module."):
|
||||||
|
logging.info("Loading checkpoint saved by DDP")
|
||||||
|
|
||||||
|
dst_state_dict = model.state_dict()
|
||||||
|
src_state_dict = checkpoint["model"]
|
||||||
|
for key in dst_state_dict.keys():
|
||||||
|
src_key = "{}.{}".format("module", key)
|
||||||
|
dst_state_dict[key] = src_state_dict.pop(src_key)
|
||||||
|
assert len(src_state_dict) == 0
|
||||||
|
model.load_state_dict(dst_state_dict, strict=strict)
|
||||||
|
else:
|
||||||
|
model.load_state_dict(checkpoint["model"], strict=strict)
|
||||||
|
|
||||||
|
checkpoint.pop("model")
|
||||||
|
|
||||||
|
if model_avg is not None and "model_avg" in checkpoint:
|
||||||
|
logging.info("Loading averaged model")
|
||||||
|
model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
|
||||||
|
checkpoint.pop("model_avg")
|
||||||
|
|
||||||
|
def load(name, obj):
|
||||||
|
s = checkpoint.get(name, None)
|
||||||
|
if obj and s:
|
||||||
|
obj.load_state_dict(s)
|
||||||
|
checkpoint.pop(name)
|
||||||
|
|
||||||
|
load("optimizer_g", optimizer_g)
|
||||||
|
load("optimizer_d", optimizer_d)
|
||||||
|
load("scheduler_g", scheduler_g)
|
||||||
|
load("scheduler_d", scheduler_d)
|
||||||
|
load("grad_scaler", scaler)
|
||||||
|
load("sampler", sampler)
|
||||||
|
|
||||||
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def save_checkpoint(
|
||||||
|
filename: Path,
|
||||||
|
model: Union[nn.Module, DDP],
|
||||||
|
model_avg: Optional[nn.Module] = None,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
optimizer_g: Optional[Optimizer] = None,
|
||||||
|
optimizer_d: Optional[Optimizer] = None,
|
||||||
|
scheduler_g: Optional[LRScheduler] = None,
|
||||||
|
scheduler_d: Optional[LRScheduler] = None,
|
||||||
|
scaler: Optional[GradScaler] = None,
|
||||||
|
sampler: Optional[CutSampler] = None,
|
||||||
|
rank: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Save training information to a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename:
|
||||||
|
The checkpoint filename.
|
||||||
|
model:
|
||||||
|
The model to be saved. We only save its `state_dict()`.
|
||||||
|
model_avg:
|
||||||
|
The stored model averaged from the start of training.
|
||||||
|
params:
|
||||||
|
User defined parameters, e.g., epoch, loss.
|
||||||
|
optimizer:
|
||||||
|
The optimizer to be saved. We only save its `state_dict()`.
|
||||||
|
scheduler:
|
||||||
|
The scheduler to be saved. We only save its `state_dict()`.
|
||||||
|
scaler:
|
||||||
|
The GradScaler to be saved. We only save its `state_dict()`.
|
||||||
|
rank:
|
||||||
|
Used in DDP. We save checkpoint only for the node whose rank is 0.
|
||||||
|
Returns:
|
||||||
|
Return None.
|
||||||
|
"""
|
||||||
|
if rank != 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
logging.info(f"Saving checkpoint to {filename}")
|
||||||
|
|
||||||
|
if isinstance(model, DDP):
|
||||||
|
model = model.module
|
||||||
|
|
||||||
|
checkpoint = {
|
||||||
|
"model": model.state_dict(),
|
||||||
|
"optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None,
|
||||||
|
"optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None,
|
||||||
|
"scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None,
|
||||||
|
"scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None,
|
||||||
|
"grad_scaler": scaler.state_dict() if scaler is not None else None,
|
||||||
|
"sampler": sampler.state_dict() if sampler is not None else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
if model_avg is not None:
|
||||||
|
checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
|
||||||
|
|
||||||
|
if params:
|
||||||
|
for k, v in params.items():
|
||||||
|
assert k not in checkpoint
|
||||||
|
checkpoint[k] = v
|
||||||
|
|
||||||
|
torch.save(checkpoint, filename)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cosine_schedule_with_warmup_lr_lambda(
|
||||||
|
current_step: int,
|
||||||
|
*,
|
||||||
|
num_warmup_steps: int,
|
||||||
|
num_training_steps: int,
|
||||||
|
num_cycles: float,
|
||||||
|
min_lr_rate: float = 0.0,
|
||||||
|
):
|
||||||
|
if current_step < num_warmup_steps:
|
||||||
|
return float(current_step) / float(max(1, num_warmup_steps))
|
||||||
|
progress = float(current_step - num_warmup_steps) / float(
|
||||||
|
max(1, num_training_steps - num_warmup_steps)
|
||||||
|
)
|
||||||
|
factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
|
||||||
|
factor = factor * (1 - min_lr_rate) + min_lr_rate
|
||||||
|
return max(0, factor)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cosine_schedule_with_warmup(
|
||||||
|
optimizer: Optimizer,
|
||||||
|
num_warmup_steps: int,
|
||||||
|
num_training_steps: int,
|
||||||
|
num_cycles: float = 0.5,
|
||||||
|
last_epoch: int = -1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
||||||
|
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
|
||||||
|
initial lr set in the optimizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer ([`~torch.optim.Optimizer`]):
|
||||||
|
The optimizer for which to schedule the learning rate.
|
||||||
|
num_warmup_steps (`int`):
|
||||||
|
The number of steps for the warmup phase.
|
||||||
|
num_training_steps (`int`):
|
||||||
|
The total number of training steps.
|
||||||
|
num_cycles (`float`, *optional*, defaults to 0.5):
|
||||||
|
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
|
||||||
|
following a half-cosine).
|
||||||
|
last_epoch (`int`, *optional*, defaults to -1):
|
||||||
|
The index of the last epoch when resuming training.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
||||||
|
"""
|
||||||
|
|
||||||
|
lr_lambda = partial(
|
||||||
|
_get_cosine_schedule_with_warmup_lr_lambda,
|
||||||
|
num_warmup_steps=num_warmup_steps,
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
num_cycles=num_cycles,
|
||||||
|
)
|
||||||
|
return LambdaLR(optimizer, lr_lambda, last_epoch)
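# A minimal usage sketch (illustrative only; the optimizer, warmup length and
# total number of steps below are placeholders):
#
#     params = [torch.nn.Parameter(torch.zeros(1))]
#     optimizer = torch.optim.AdamW(params, lr=2e-4)
#     scheduler = get_cosine_schedule_with_warmup(
#         optimizer, num_warmup_steps=1000, num_training_steps=100000
#     )
#     for step in range(100000):
#         optimizer.step()
#         scheduler.step()  # linear warmup for 1000 steps, then cosine decay towards 0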
|
||||||
|
|
||||||
|
|
||||||
|
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor.
|
||||||
|
clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Element-wise logarithm of the input tensor with clipping applied.
|
||||||
|
"""
|
||||||
|
return torch.log(torch.clip(x, min=clip_val))
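# Example of the clipping behaviour (illustrative only):
#     safe_log(torch.tensor([0.0, 1.0]))
#     # -> tensor([-16.1181,  0.0000]); the zero is clipped to 1e-7 instead of giving -inf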
|
287
egs/ljspeech/TTS/local/evaluate_fsd.py
Normal file
@ -0,0 +1,287 @@
|
|||||||
|
"""
|
||||||
|
Calculate Frechet Speech Distance between two speech directories.
|
||||||
|
Adapted from: https://github.com/gudgud96/frechet-audio-distance/blob/main/frechet_audio_distance/fad.py
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from multiprocessing.dummy import Pool as ThreadPool
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
from scipy import linalg
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--real-path", type=str, help="path of the real speech directory"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--eval-path", type=str, help="path of the evaluated speech directory"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-path",
|
||||||
|
type=str,
|
||||||
|
default="model/huggingface/wav2vec2_base",
|
||||||
|
help="path of the wav2vec 2.0 model directory",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--real-embds-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="path of the real embedding directory",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--eval-embds-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="path of the evaluated embedding directory",
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
class FrechetSpeechDistance:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_path="resources/wav2vec2_base",
|
||||||
|
pca_dim=128,
|
||||||
|
speech_load_worker=8,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize FSD
|
||||||
|
"""
|
||||||
|
self.sample_rate = 16000
|
||||||
|
self.channels = 1
|
||||||
|
self.device = (
|
||||||
|
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
|
)
|
||||||
|
logging.info("[Frechet Speech Distance] Using device: {}".format(self.device))
|
||||||
|
self.speech_load_worker = speech_load_worker
|
||||||
|
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_path)
|
||||||
|
self.model = Wav2Vec2Model.from_pretrained(model_path)
|
||||||
|
self.model.to(self.device)
|
||||||
|
self.model.eval()
|
||||||
|
self.pca_dim = pca_dim
|
||||||
|
|
||||||
|
def load_speech_files(self, dir, dtype="float32"):
|
||||||
|
def _load_speech_task(fname, sample_rate, channels, dtype="float32"):
|
||||||
|
if dtype not in ["float64", "float32", "int32", "int16"]:
|
||||||
|
raise ValueError(f"dtype not supported: {dtype}")
|
||||||
|
|
||||||
|
wav_data, sr = sf.read(fname, dtype=dtype)
|
||||||
|
# For integer type PCM input, convert to [-1.0, +1.0]
|
||||||
|
if dtype == "int16":
|
||||||
|
wav_data = wav_data / 32768.0
|
||||||
|
elif dtype == "int32":
|
||||||
|
wav_data = wav_data / float(2**31)
|
||||||
|
|
||||||
|
# Convert to mono
|
||||||
|
assert channels in [1, 2], "channels must be 1 or 2"
|
||||||
|
if len(wav_data.shape) > channels:
|
||||||
|
wav_data = np.mean(wav_data, axis=1)
|
||||||
|
|
||||||
|
if sr != sample_rate:
|
||||||
|
wav_data = (
|
||||||
|
librosa.resample(wav_data, orig_sr=sr, target_sr=sample_rate),
|
||||||
|
)
|
||||||
|
|
||||||
|
return wav_data
|
||||||
|
|
||||||
|
task_results = []
|
||||||
|
|
||||||
|
pool = ThreadPool(self.speech_load_worker)
|
||||||
|
|
||||||
|
logging.info("[Frechet Speech Distance] Loading speech from {}...".format(dir))
|
||||||
|
for fname in os.listdir(dir):
|
||||||
|
res = pool.apply_async(
|
||||||
|
_load_speech_task,
|
||||||
|
args=(os.path.join(dir, fname), self.sample_rate, self.channels, dtype),
|
||||||
|
)
|
||||||
|
task_results.append(res)
|
||||||
|
pool.close()
|
||||||
|
pool.join()
|
||||||
|
|
||||||
|
return [k.get() for k in task_results]
|
||||||
|
|
||||||
|
def get_embeddings(self, x):
|
||||||
|
"""
|
||||||
|
Get embeddings
|
||||||
|
Params:
|
||||||
|
-- x : a list of np.ndarray speech samples
|
||||||
|
Note: samples are assumed to be at self.sample_rate (16 kHz).
|
||||||
|
"""
|
||||||
|
embd_lst = []
|
||||||
|
try:
|
||||||
|
for speech in tqdm(x):
|
||||||
|
input_features = self.feature_extractor(
|
||||||
|
speech, sampling_rate=self.sample_rate, return_tensors="pt"
|
||||||
|
).input_values.to(self.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
embd = self.model(input_features).last_hidden_state.mean(1)
|
||||||
|
|
||||||
|
if embd.device != torch.device("cpu"):
|
||||||
|
embd = embd.cpu()
|
||||||
|
|
||||||
|
if torch.is_tensor(embd):
|
||||||
|
embd = embd.detach().numpy()
|
||||||
|
|
||||||
|
embd_lst.append(embd)
|
||||||
|
except Exception as e:
|
||||||
|
print(
|
||||||
|
"[Frechet Speech Distance] get_embeddings throw an exception: {}".format(
|
||||||
|
str(e)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return np.concatenate(embd_lst, axis=0)
|
||||||
|
|
||||||
|
def calculate_embd_statistics(self, embd_lst):
|
||||||
|
if isinstance(embd_lst, list):
|
||||||
|
embd_lst = np.array(embd_lst)
|
||||||
|
mu = np.mean(embd_lst, axis=0)
|
||||||
|
sigma = np.cov(embd_lst, rowvar=False)
|
||||||
|
return mu, sigma
|
||||||
|
|
||||||
|
def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
|
||||||
|
"""
|
||||||
|
Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
|
||||||
|
|
||||||
|
Numpy implementation of the Frechet Distance.
|
||||||
|
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
||||||
|
and X_2 ~ N(mu_2, C_2) is
|
||||||
|
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
||||||
|
Stable version by Dougal J. Sutherland.
|
||||||
|
Params:
|
||||||
|
-- mu1 : Numpy array containing the activations of a layer of the
|
||||||
|
inception net (like returned by the function 'get_predictions')
|
||||||
|
for generated samples.
|
||||||
|
-- mu2 : The sample mean over activations, precalculated on an
|
||||||
|
representative data set.
|
||||||
|
-- sigma1: The covariance matrix over activations for generated samples.
|
||||||
|
-- sigma2: The covariance matrix over activations, precalculated on an
|
||||||
|
representative data set.
|
||||||
|
Returns:
|
||||||
|
-- : The Frechet Distance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
mu1 = np.atleast_1d(mu1)
|
||||||
|
mu2 = np.atleast_1d(mu2)
|
||||||
|
|
||||||
|
sigma1 = np.atleast_2d(sigma1)
|
||||||
|
sigma2 = np.atleast_2d(sigma2)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
mu1.shape == mu2.shape
|
||||||
|
), "Training and test mean vectors have different lengths"
|
||||||
|
assert (
|
||||||
|
sigma1.shape == sigma2.shape
|
||||||
|
), "Training and test covariances have different dimensions"
|
||||||
|
|
||||||
|
diff = mu1 - mu2
|
||||||
|
|
||||||
|
# Product might be almost singular
|
||||||
|
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2).astype(complex), disp=False)
|
||||||
|
if not np.isfinite(covmean).all():
|
||||||
|
msg = (
|
||||||
|
"fid calculation produces singular product; "
|
||||||
|
"adding %s to diagonal of cov estimates"
|
||||||
|
) % eps
|
||||||
|
logging.info(msg)
|
||||||
|
offset = np.eye(sigma1.shape[0]) * eps
|
||||||
|
covmean = linalg.sqrtm(
|
||||||
|
(sigma1 + offset).dot(sigma2 + offset).astype(complex)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Numerical error might give slight imaginary component
|
||||||
|
if np.iscomplexobj(covmean):
|
||||||
|
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
||||||
|
m = np.max(np.abs(covmean.imag))
|
||||||
|
raise ValueError("Imaginary component {}".format(m))
|
||||||
|
covmean = covmean.real
|
||||||
|
|
||||||
|
tr_covmean = np.trace(covmean)
|
||||||
|
|
||||||
|
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
|
||||||
|
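As a quick sanity check of the formula above, the distance between two identical Gaussians is zero. A minimal standalone sketch (not part of the recipe; dimensions are arbitrary):

import numpy as np
from scipy import linalg

mu = np.zeros(4)
sigma = np.eye(4)

# d^2 = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2*sqrt(C1*C2))
covmean, _ = linalg.sqrtm(sigma.dot(sigma).astype(complex), disp=False)
d2 = np.sum((mu - mu) ** 2) + np.trace(sigma) + np.trace(sigma) - 2 * np.trace(covmean.real)
print(d2)  # ~0.0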
|
||||||
|
def score(
|
||||||
|
self,
|
||||||
|
real_path,
|
||||||
|
eval_path,
|
||||||
|
real_embds_path=None,
|
||||||
|
eval_embds_path=None,
|
||||||
|
dtype="float32",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Computes the Frechet Speech Distance (FSD) between two directories of speech files.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- real_path (str): Path to the directory containing real speech files.
|
||||||
|
- eval_path (str): Path to the directory containing evaluation speech files.
|
||||||
|
- real_embds_path (str, optional): Path to save/load real speech embeddings (e.g., /folder/bkg_embs.npy). If None, embeddings won't be saved.
|
||||||
|
- eval_embds_path (str, optional): Path to save/load evaluation speech embeddings (e.g., /folder/test_embs.npy). If None, embeddings won't be saved.
|
||||||
|
- dtype (str, optional): Data type for loading speech. Default is "float32".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- float: The Frechet Speech Distance (FSD) score between the two directories of speech files.
|
||||||
|
"""
|
||||||
|
# Load or compute real embeddings
|
||||||
|
if real_embds_path is not None and os.path.exists(real_embds_path):
|
||||||
|
logging.info(
|
||||||
|
f"[Frechet Speech Distance] Loading embeddings from {real_embds_path}..."
|
||||||
|
)
|
||||||
|
embds_real = np.load(real_embds_path)
|
||||||
|
else:
|
||||||
|
speech_real = self.load_speech_files(real_path, dtype=dtype)
|
||||||
|
embds_real = self.get_embeddings(speech_real)
|
||||||
|
if real_embds_path:
|
||||||
|
os.makedirs(os.path.dirname(real_embds_path), exist_ok=True)
|
||||||
|
np.save(real_embds_path, embds_real)
|
||||||
|
|
||||||
|
# Load or compute eval embeddings
|
||||||
|
if eval_embds_path is not None and os.path.exists(eval_embds_path):
|
||||||
|
logging.info(
|
||||||
|
f"[Frechet Speech Distance] Loading embeddings from {eval_embds_path}..."
|
||||||
|
)
|
||||||
|
embds_eval = np.load(eval_embds_path)
|
||||||
|
else:
|
||||||
|
speech_eval = self.load_speech_files(eval_path, dtype=dtype)
|
||||||
|
embds_eval = self.get_embeddings(speech_eval)
|
||||||
|
if eval_embds_path:
|
||||||
|
os.makedirs(os.path.dirname(eval_embds_path), exist_ok=True)
|
||||||
|
np.save(eval_embds_path, embds_eval)
|
||||||
|
|
||||||
|
# Check if embeddings are empty
|
||||||
|
if len(embds_real) == 0:
|
||||||
|
logging.info("[Frechet Speech Distance] real set dir is empty, exiting...")
|
||||||
|
return -1
|
||||||
|
if len(embds_eval) == 0:
|
||||||
|
logging.info("[Frechet Speech Distance] eval set dir is empty, exiting...")
|
||||||
|
return -1
|
||||||
|
|
||||||
|
# Compute statistics and FSD score
|
||||||
|
mu_real, sigma_real = self.calculate_embd_statistics(embds_real)
|
||||||
|
mu_eval, sigma_eval = self.calculate_embd_statistics(embds_eval)
|
||||||
|
|
||||||
|
fsd_score = self.calculate_frechet_distance(
|
||||||
|
mu_real, sigma_real, mu_eval, sigma_eval
|
||||||
|
)
|
||||||
|
|
||||||
|
return fsd_score
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
FSD = FrechetSpeechDistance(model_path=args.model_path)
|
||||||
|
score = FSD.score(
|
||||||
|
args.real_path, args.eval_path, args.real_embds_path, args.eval_embds_path
|
||||||
|
)
|
||||||
|
logging.info(f"FSD score: {score:.2f}")
|
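A minimal usage sketch of the FrechetSpeechDistance class defined above; the model path and wav directories are illustrative placeholders:

from evaluate_fsd import FrechetSpeechDistance

fsd = FrechetSpeechDistance(model_path="model/huggingface/wav2vec2_base")
score = fsd.score(
    real_path="data/wavs/real",            # directory of ground-truth wavs
    eval_path="exp/generated_wavs",        # directory of synthesized wavs
    real_embds_path="exp/embds/real.npy",  # optional embedding caches
    eval_embds_path="exp/embds/eval.npy",
)
print(f"FSD: {score:.2f}")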
139
egs/ljspeech/TTS/local/evaluate_wer_whisper.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
"""
|
||||||
|
Calculate WER with Whisper model
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
from num2words import num2words
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import pipeline
|
||||||
|
|
||||||
|
from icefall.utils import store_transcripts, write_error_stats
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument("--wav-path", type=str, help="path of the speech directory")
|
||||||
|
parser.add_argument("--decode-path", type=str, help="path of the speech directory")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-path",
|
||||||
|
type=str,
|
||||||
|
default="model/huggingface/whisper_medium",
|
||||||
|
help="path of the huggingface whisper model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--transcript-path",
|
||||||
|
type=str,
|
||||||
|
default="data/transcript/test.tsv",
|
||||||
|
help="path of the transcript tsv file",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size", type=int, default=64, help="decoding batch size"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device", type=str, default="cuda:0", help="decoding device, cuda:0 or cpu"
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def post_process(text: str):
|
||||||
|
def convert_numbers(match):
|
||||||
|
return num2words(match.group())
|
||||||
|
|
||||||
|
text = re.sub(r"\b\d{1,2}\b", convert_numbers, text)
|
||||||
|
text = re.sub(r"[^a-zA-Z0-9']", " ", text.lower())
|
||||||
|
text = re.sub(r"\s+", " ", text)
|
||||||
|
return text
|
||||||
|
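For illustration, post_process lower-cases the text, spells out one- and two-digit numbers and strips punctuation; a hypothetical input/output pair:

# post_process("I have 2 cats, OK?")  ->  roughly "i have two cats ok"
print(post_process("I have 2 cats, OK?"))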
|
||||||
|
|
||||||
|
def save_results(
|
||||||
|
res_dir: str,
|
||||||
|
results: List[Tuple[str, List[str], List[str]]],
|
||||||
|
):
|
||||||
|
if not os.path.exists(res_dir):
|
||||||
|
os.makedirs(res_dir)
|
||||||
|
recog_path = os.path.join(res_dir, "recogs.txt")
|
||||||
|
results = sorted(results)
|
||||||
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
|
errs_filename = os.path.join(res_dir, "errs.txt")
|
||||||
|
with open(errs_filename, "w") as f:
|
||||||
|
_ = write_error_stats(f, "test", results, enable_log=True)
|
||||||
|
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechEvalDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(self, wav_path: str, transcript_path: str):
|
||||||
|
super().__init__()
|
||||||
|
self.audio_name = []
|
||||||
|
self.audio_paths = []
|
||||||
|
self.transcripts = []
|
||||||
|
with Path(transcript_path).open("r", encoding="utf8") as f:
|
||||||
|
meta = [item.split("\t") for item in f.read().rstrip().split("\n")]
|
||||||
|
for item in meta:
|
||||||
|
self.audio_name.append(item[0])
|
||||||
|
self.audio_paths.append(Path(wav_path, item[0] + ".wav"))
|
||||||
|
self.transcripts.append(item[1])
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.audio_paths)
|
||||||
|
|
||||||
|
def __getitem__(self, index: int):
|
||||||
|
audio, sampling_rate = sf.read(self.audio_paths[index])
|
||||||
|
item = {
|
||||||
|
"array": librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000),
|
||||||
|
"sampling_rate": 16000,
|
||||||
|
"reference": self.transcripts[index],
|
||||||
|
"audio_name": self.audio_name[index],
|
||||||
|
}
|
||||||
|
return item
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
|
||||||
|
batch_size = args.batch_size
|
||||||
|
|
||||||
|
pipe = pipeline(
|
||||||
|
"automatic-speech-recognition",
|
||||||
|
model=args.model_path,
|
||||||
|
device=args.device,
|
||||||
|
tokenizer=args.model_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = SpeechEvalDataset(args.wav_path, args.transcript_path)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
bar = tqdm(
|
||||||
|
pipe(
|
||||||
|
dataset,
|
||||||
|
generate_kwargs={"language": "english", "task": "transcribe"},
|
||||||
|
batch_size=batch_size,
|
||||||
|
),
|
||||||
|
total=len(dataset),
|
||||||
|
)
|
||||||
|
for out in bar:
|
||||||
|
results.append(
|
||||||
|
(
|
||||||
|
out["audio_name"][0],
|
||||||
|
post_process(out["reference"][0].strip()).split(),
|
||||||
|
post_process(out["text"].strip()).split(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
save_results(args.decode_path, results)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
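A sketch of invoking the evaluation above programmatically instead of from the shell; the model and data paths are illustrative placeholders:

parser = get_parser()
args = parser.parse_args(
    [
        "--wav-path", "exp/generated_wavs",
        "--decode-path", "exp/decode_results",
        "--model-path", "model/huggingface/whisper_medium",
        "--transcript-path", "data/transcript/test.tsv",
        "--batch-size", "16",
        "--device", "cuda:0",
    ]
)
main(args)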
1
egs/ljspeech/TTS/vocos/discriminators.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../libritts/TTS/vocos/discriminators.py
|
1
egs/ljspeech/TTS/vocos/export-onnx.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../libritts/TTS/vocos/export-onnx.py
|
1
egs/ljspeech/TTS/vocos/export.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../libritts/TTS/vocos/export.py
|
1
egs/ljspeech/TTS/vocos/generator.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../libritts/TTS/vocos/generator.py
|
340
egs/ljspeech/TTS/vocos/infer.py
Executable file
@ -0,0 +1,340 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
|
||||||
|
# Han Zhu)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from lhotse.utils import fix_random_seed
|
||||||
|
from scipy.io.wavfile import write
|
||||||
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
from tts_datamodule import LJSpeechTtsDataModule
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||||
|
|
||||||
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=10,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="flow_match/exp",
|
||||||
|
help="The experiment dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--generate-dir",
|
||||||
|
type=str,
|
||||||
|
default="generated_wavs",
|
||||||
|
help="Path name of the generated wavs",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def decode_one_batch(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
batch: dict,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
It's the return value of :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The vocoder model that converts acoustic features to waveforms.
|
||||||
|
batch:
|
||||||
|
It is the return value from iterating
|
||||||
|
`lhotse.dataset.SpeechSynthesisDataset`. See its documentation
|
||||||
|
for the format of the `batch`.
|
||||||
|
Returns:
|
||||||
|
Nothing is returned; the generated waveforms are written as wav files to
|
||||||
|
`{params.res_dir}/{params.suffix}`.
|
||||||
|
"""
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
|
||||||
|
cut_ids = [cut.id for cut in batch["cut"]]
|
||||||
|
|
||||||
|
features = batch["features"] # (B, T, F)
|
||||||
|
utt_durations = batch["features_lens"]
|
||||||
|
|
||||||
|
x = features.permute(0, 2, 1) # (B, F, T)
|
||||||
|
|
||||||
|
audios = model(x.to(device)) # (B, T)
|
||||||
|
|
||||||
|
wav_dir = f"{params.res_dir}/{params.suffix}"
|
||||||
|
os.makedirs(wav_dir, exist_ok=True)
|
||||||
|
|
||||||
|
for i in range(audios.shape[0]):
|
||||||
|
audio = audios[i][: (utt_durations[i] - 1) * 256 + 1024]
|
||||||
|
audio = audio.cpu().squeeze().numpy()
|
||||||
|
write(f"{wav_dir}/{cut_ids[i]}.wav", 22050, audio)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_dataset(
|
||||||
|
dl: torch.utils.data.DataLoader,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
test_set: str,
|
||||||
|
):
|
||||||
|
"""Decode dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dl:
|
||||||
|
PyTorch's dataloader containing the dataset to decode.
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The vocoder model that converts acoustic features to waveforms.
|
||||||
|
test_set:
|
||||||
|
The name of the test_set
|
||||||
|
"""
|
||||||
|
num_cuts = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
num_batches = len(dl)
|
||||||
|
except TypeError:
|
||||||
|
num_batches = "?"
|
||||||
|
|
||||||
|
with open(f"{params.res_dir}/{test_set}.scp", "w", encoding="utf8") as f:
|
||||||
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
texts = batch["text"]
|
||||||
|
cut_ids = [cut.id for cut in batch["cut"]]
|
||||||
|
|
||||||
|
decode_one_batch(
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
batch=batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(texts) == len(cut_ids), (len(texts), len(cut_ids))
|
||||||
|
|
||||||
|
for i in range(len(texts)):
|
||||||
|
f.write(f"{cut_ids[i]}\t{texts[i]}\n")
|
||||||
|
|
||||||
|
num_cuts += len(texts)
|
||||||
|
|
||||||
|
if batch_idx % 50 == 0:
|
||||||
|
batch_str = f"{batch_idx}/{num_batches}"
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
LJSpeechTtsDataModule.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
params.res_dir = params.exp_dir / params.generate_dir
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||||
|
else:
|
||||||
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
|
if params.use_averaged_model:
|
||||||
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
|
logging.info("Decoding started")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
params.device = device
|
||||||
|
|
||||||
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
fix_random_seed(666)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_model(params)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model = model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
# we need cut ids to display recognition results.
|
||||||
|
args.return_cuts = True
|
||||||
|
ljspeech = LJSpeechTtsDataModule(args)
|
||||||
|
|
||||||
|
test_cuts = ljspeech.test_cuts()
|
||||||
|
|
||||||
|
test_dl = ljspeech.test_dataloaders(test_cuts)
|
||||||
|
|
||||||
|
test_sets = ["test"]
|
||||||
|
test_dls = [test_dl]
|
||||||
|
|
||||||
|
for test_set, test_dl in zip(test_sets, test_dls):
|
||||||
|
decode_dataset(
|
||||||
|
dl=test_dl,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
test_set=test_set,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
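To recap the shape convention used in decode_one_batch above, a minimal sketch of the vocoder call with dummy tensors (the 256/1024 values correspond to the default frame shift and frame length of this recipe):

import torch

B, T, F = 2, 100, 80                 # batch size, frames, feature bins (dummy values)
features = torch.randn(B, T, F)      # batch["features"]
x = features.permute(0, 2, 1)        # (B, F, T), as the generator expects
# audios = model(x)                  # (B, num_samples); for each utterance the output
# is trimmed to (num_frames - 1) * 256 + 1024 samples before writing the wav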
1
egs/ljspeech/TTS/vocos/loss.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../libritts/TTS/vocos/loss.py
|
1
egs/ljspeech/TTS/vocos/model.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../libritts/TTS/vocos/model.py
|
1
egs/ljspeech/TTS/vocos/onnx_pretrained.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../libritts/TTS/vocos/onnx_pretrained.py
|
1
egs/ljspeech/TTS/vocos/pretrained.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../libritts/TTS/vocos/pretrained.py
|
1054
egs/ljspeech/TTS/vocos/train.py
Executable file
File diff suppressed because it is too large
372
egs/ljspeech/TTS/vocos/tts_datamodule.py
Normal file
@ -0,0 +1,372 @@
|
|||||||
|
# Copyright 2021 Piotr Żelasko
|
||||||
|
# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
|
||||||
|
# Zengwei Yao,
|
||||||
|
# Wei Kang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
|
||||||
|
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||||
|
CutConcatenate,
|
||||||
|
CutMix,
|
||||||
|
DynamicBucketingSampler,
|
||||||
|
PrecomputedFeatures,
|
||||||
|
SimpleCutSampler,
|
||||||
|
SpecAugment,
|
||||||
|
SpeechSynthesisDataset,
|
||||||
|
)
|
||||||
|
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||||
|
AudioSamples,
|
||||||
|
OnTheFlyFeatures,
|
||||||
|
)
|
||||||
|
from lhotse.utils import fix_random_seed
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
|
class _SeedWorkers:
|
||||||
|
def __init__(self, seed: int):
|
||||||
|
self.seed = seed
|
||||||
|
|
||||||
|
def __call__(self, worker_id: int):
|
||||||
|
fix_random_seed(self.seed + worker_id)
|
||||||
|
|
||||||
|
|
||||||
|
class LJSpeechTtsDataModule:
|
||||||
|
"""
|
||||||
|
DataModule for tts experiments.
|
||||||
|
It assumes there is always one train and valid dataloader,
|
||||||
|
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||||
|
and test-other).
|
||||||
|
|
||||||
|
It contains all the common data pipeline modules used in TTS
|
||||||
|
experiments, e.g.:
|
||||||
|
- dynamic batch size,
|
||||||
|
- bucketing samplers,
|
||||||
|
- cut concatenation,
|
||||||
|
- on-the-fly feature extraction
|
||||||
|
|
||||||
|
This class should be derived for specific corpora used in TTS tasks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, args: argparse.Namespace):
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||||
|
group = parser.add_argument_group(
|
||||||
|
title="TTS data related options",
|
||||||
|
description="These options are used for the preparation of "
|
||||||
|
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||||
|
"effective batch sizes, sampling strategies, applied data "
|
||||||
|
"augmentations, etc.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--manifest-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Path("data/fbank"),
|
||||||
|
help="Path to directory with train/valid/test cuts.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--max-duration",
|
||||||
|
type=int,
|
||||||
|
default=200.0,
|
||||||
|
help="Maximum pooled recordings duration (seconds) in a "
|
||||||
|
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--bucketing-sampler",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled, the batches will come from buckets of "
|
||||||
|
"similar duration (saves padding frames).",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--num-buckets",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="The number of buckets for the DynamicBucketingSampler"
|
||||||
|
"(you might want to increase it for larger datasets).",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--on-the-fly-feats",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="When enabled, use on-the-fly cut mixing and feature "
|
||||||
|
"extraction. Will drop existing precomputed feature manifests "
|
||||||
|
"if available.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--shuffle",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled (=default), the examples will be "
|
||||||
|
"shuffled for each epoch.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--drop-last",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to drop last batch. Used by sampler.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--return-cuts",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled, each batch will have the "
|
||||||
|
"field: batch['cut'] with the cuts that "
|
||||||
|
"were used to construct it.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--num-workers",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The number of training dataloader workers that "
|
||||||
|
"collect the batches.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--sampling-rate",
|
||||||
|
type=int,
|
||||||
|
default=22050,
|
||||||
|
help="The sampleing rate of ljspeech dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--frame-shift",
|
||||||
|
type=int,
|
||||||
|
default=256,
|
||||||
|
help="Frame shift.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--frame-length",
|
||||||
|
type=int,
|
||||||
|
default=1024,
|
||||||
|
help="Frame shift.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--input-strategy",
|
||||||
|
type=str,
|
||||||
|
default="PrecomputedFeatures",
|
||||||
|
help="AudioSamples or PrecomputedFeatures",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--use-fft-mag",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to use magnitude of fbank, false to use power energy.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def train_dataloaders(
|
||||||
|
self,
|
||||||
|
cuts_train: CutSet,
|
||||||
|
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> DataLoader:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
cuts_train:
|
||||||
|
CutSet for training.
|
||||||
|
sampler_state_dict:
|
||||||
|
The state dict for the training sampler.
|
||||||
|
"""
|
||||||
|
logging.info("About to create train dataset")
|
||||||
|
train = SpeechSynthesisDataset(
|
||||||
|
return_text=True,
|
||||||
|
return_tokens=False,
|
||||||
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.on_the_fly_feats:
|
||||||
|
sampling_rate = self.args.sampling_rate
|
||||||
|
config = FbankConfig(
|
||||||
|
sampling_rate=sampling_rate,
|
||||||
|
frame_length=self.args.frame_length / sampling_rate,  # (in seconds)
|
||||||
|
frame_shift=self.args.frame_shift / sampling_rate,  # (in seconds)
|
||||||
|
use_fft_mag=self.args.use_fft_mag,
|
||||||
|
)
|
||||||
|
train = SpeechSynthesisDataset(
|
||||||
|
return_text=True,
|
||||||
|
return_tokens=False,
|
||||||
|
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.bucketing_sampler:
|
||||||
|
logging.info("Using DynamicBucketingSampler.")
|
||||||
|
train_sampler = DynamicBucketingSampler(
|
||||||
|
cuts_train,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=self.args.shuffle,
|
||||||
|
num_buckets=self.args.num_buckets,
|
||||||
|
buffer_size=self.args.num_buckets * 2000,
|
||||||
|
shuffle_buffer_size=self.args.num_buckets * 5000,
|
||||||
|
drop_last=self.args.drop_last,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.info("Using SimpleCutSampler.")
|
||||||
|
train_sampler = SimpleCutSampler(
|
||||||
|
cuts_train,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=self.args.shuffle,
|
||||||
|
)
|
||||||
|
logging.info("About to create train dataloader")
|
||||||
|
|
||||||
|
if sampler_state_dict is not None:
|
||||||
|
logging.info("Loading sampler state dict")
|
||||||
|
train_sampler.load_state_dict(sampler_state_dict)
|
||||||
|
|
||||||
|
# 'seed' is derived from the current random state, which will have
|
||||||
|
# previously been set in the main process.
|
||||||
|
seed = torch.randint(0, 100000, ()).item()
|
||||||
|
worker_init_fn = _SeedWorkers(seed)
|
||||||
|
|
||||||
|
train_dl = DataLoader(
|
||||||
|
train,
|
||||||
|
sampler=train_sampler,
|
||||||
|
batch_size=None,
|
||||||
|
num_workers=self.args.num_workers,
|
||||||
|
persistent_workers=False,
|
||||||
|
worker_init_fn=worker_init_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_dl
|
||||||
|
|
||||||
|
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||||
|
logging.info("About to create dev dataset")
|
||||||
|
if self.args.on_the_fly_feats:
|
||||||
|
sampling_rate = self.args.sampling_rate
|
||||||
|
config = FbankConfig(
|
||||||
|
sampling_rate=sampling_rate,
|
||||||
|
frame_length=self.args.frame_length / sampling_rate,  # (in seconds)
|
||||||
|
frame_shift=self.args.frame_shift / sampling_rate,  # (in seconds)
|
||||||
|
use_fft_mag=self.args.use_fft_mag,
|
||||||
|
)
|
||||||
|
validate = SpeechSynthesisDataset(
|
||||||
|
return_text=True,
|
||||||
|
return_tokens=False,
|
||||||
|
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
validate = SpeechSynthesisDataset(
|
||||||
|
return_text=True,
|
||||||
|
return_tokens=False,
|
||||||
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
valid_sampler = DynamicBucketingSampler(
|
||||||
|
cuts_valid,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
num_buckets=self.args.num_buckets,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
logging.info("About to create valid dataloader")
|
||||||
|
valid_dl = DataLoader(
|
||||||
|
validate,
|
||||||
|
sampler=valid_sampler,
|
||||||
|
batch_size=None,
|
||||||
|
num_workers=2,
|
||||||
|
persistent_workers=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return valid_dl
|
||||||
|
|
||||||
|
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||||
|
logging.info("About to create test dataset")
|
||||||
|
if self.args.on_the_fly_feats:
|
||||||
|
sampling_rate = self.args.sampling_rate
|
||||||
|
config = FbankConfig(
|
||||||
|
sampling_rate=sampling_rate,
|
||||||
|
frame_length=self.args.frame_length / sampling_rate,  # (in seconds)
|
||||||
|
frame_shift=self.args.frame_shift / sampling_rate,  # (in seconds)
|
||||||
|
use_fft_mag=self.args.use_fft_mag,
|
||||||
|
)
|
||||||
|
test = SpeechSynthesisDataset(
|
||||||
|
return_text=True,
|
||||||
|
return_tokens=False,
|
||||||
|
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
test = SpeechSynthesisDataset(
|
||||||
|
return_text=True,
|
||||||
|
return_tokens=False,
|
||||||
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
test_sampler = DynamicBucketingSampler(
|
||||||
|
cuts,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
num_buckets=self.args.num_buckets,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
logging.info("About to create test dataloader")
|
||||||
|
test_dl = DataLoader(
|
||||||
|
test,
|
||||||
|
batch_size=None,
|
||||||
|
sampler=test_sampler,
|
||||||
|
num_workers=self.args.num_workers,
|
||||||
|
)
|
||||||
|
return test_dl
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def train_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get train cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def valid_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get validation cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def train_cuts_finetune(self) -> CutSet:
|
||||||
|
logging.info("About to get train cuts finetune")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "ljspeech_cuts_train_finetune.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def valid_cuts_finetune(self) -> CutSet:
|
||||||
|
logging.info("About to get validation cuts finetune")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "ljspeech_cuts_valid_finetune.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def test_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get test cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
|
||||||
|
)
|
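A minimal sketch of wiring the datamodule above into a script (argument values are illustrative):

import argparse

parser = argparse.ArgumentParser()
LJSpeechTtsDataModule.add_arguments(parser)
args = parser.parse_args(["--manifest-dir", "data/fbank", "--max-duration", "200"])

ljspeech = LJSpeechTtsDataModule(args)
train_dl = ljspeech.train_dataloaders(ljspeech.train_cuts())
valid_dl = ljspeech.valid_dataloaders(ljspeech.valid_cuts())

for batch in train_dl:
    features = batch["features"]  # (B, T, F) fbank features
    break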
1
egs/ljspeech/TTS/vocos/utils.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../libritts/TTS/vocos/utils.py
|
@ -251,18 +251,22 @@ def save_checkpoint_with_global_batch_idx(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
|
def find_checkpoints(
|
||||||
|
out_dir: Path,
|
||||||
|
iteration: int = 0,
|
||||||
|
prefix: str = "checkpoint",
|
||||||
|
) -> List[str]:
|
||||||
"""Find all available checkpoints in a directory.
|
"""Find all available checkpoints in a directory.
|
||||||
|
|
||||||
The checkpoint filenames have the form: `checkpoint-xxx.pt`
|
The checkpoint filenames have the form: `{prefix}-xxx.pt`
|
||||||
where xxx is a numerical value.
|
where xxx is a numerical value.
|
||||||
|
|
||||||
Assume you have the following checkpoints in the folder `foo`:
|
Assume you have the following checkpoints in the folder `foo`:
|
||||||
|
|
||||||
- checkpoint-1.pt
|
- {prefix}-1.pt
|
||||||
- checkpoint-20.pt
|
- {prefix}-20.pt
|
||||||
- checkpoint-300.pt
|
- {prefix}-300.pt
|
||||||
- checkpoint-4000.pt
|
- {prefix}-4000.pt
|
||||||
|
|
||||||
Case 1 (Return all checkpoints)::
|
Case 1 (Return all checkpoints)::
|
||||||
|
|
||||||
@ -291,8 +295,8 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
|
|||||||
Return a list of checkpoint filenames, sorted in descending
|
Return a list of checkpoint filenames, sorted in descending
|
||||||
order by the numerical value in the filename.
|
order by the numerical value in the filename.
|
||||||
"""
|
"""
|
||||||
checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
|
checkpoints = list(glob.glob(f"{out_dir}/{prefix}-[0-9]*.pt"))
|
||||||
pattern = re.compile(r"checkpoint-([0-9]+).pt")
|
pattern = re.compile(rf"{prefix}-([0-9]+).pt")
|
||||||
iter_checkpoints = []
|
iter_checkpoints = []
|
||||||
for c in checkpoints:
|
for c in checkpoints:
|
||||||
result = pattern.search(c)
|
result = pattern.search(c)
|
||||||
@ -317,12 +321,13 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
|
|||||||
def remove_checkpoints(
|
def remove_checkpoints(
|
||||||
out_dir: Path,
|
out_dir: Path,
|
||||||
topk: int,
|
topk: int,
|
||||||
|
prefix: str = "checkpoint",
|
||||||
rank: int = 0,
|
rank: int = 0,
|
||||||
):
|
):
|
||||||
"""Remove checkpoints from the given directory.
|
"""Remove checkpoints from the given directory.
|
||||||
|
|
||||||
We assume that checkpoint filename has the form `checkpoint-xxx.pt`
|
We assume that checkpoint filename has the form `{prefix}-xxx.pt`
|
||||||
where xxx is a number, representing the number of processed batches
|
where xxx is a number, representing the number of processed batches/epochs
|
||||||
when saving that checkpoint. We sort checkpoints by filename and keep
|
when saving that checkpoint. We sort checkpoints by filename and keep
|
||||||
only the `topk` checkpoints with the highest `xxx`.
|
only the `topk` checkpoints with the highest `xxx`.
|
||||||
|
|
||||||
@ -331,6 +336,8 @@ def remove_checkpoints(
|
|||||||
The directory containing checkpoints to be removed.
|
The directory containing checkpoints to be removed.
|
||||||
topk:
|
topk:
|
||||||
Number of checkpoints to keep.
|
Number of checkpoints to keep.
|
||||||
|
prefix:
|
||||||
|
The prefix of the checkpoint filename, normally `epoch`, `checkpoint`.
|
||||||
rank:
|
rank:
|
||||||
If using DDP for training, it is the rank of the current node.
|
If using DDP for training, it is the rank of the current node.
|
||||||
Use 0 if no DDP is used for training.
|
Use 0 if no DDP is used for training.
|
||||||
@ -338,7 +345,7 @@ def remove_checkpoints(
|
|||||||
assert topk >= 1, topk
|
assert topk >= 1, topk
|
||||||
if rank != 0:
|
if rank != 0:
|
||||||
return
|
return
|
||||||
checkpoints = find_checkpoints(out_dir)
|
checkpoints = find_checkpoints(out_dir, prefix=prefix)
|
||||||
|
|
||||||
if len(checkpoints) == 0:
|
if len(checkpoints) == 0:
|
||||||
logging.warn(f"No checkpoints found in {out_dir}")
|
logging.warn(f"No checkpoints found in {out_dir}")
|
||||||
|
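A small sketch of how the new prefix argument can be used by recipes that save epoch-style checkpoints (the experiment directory is a placeholder):

from icefall.checkpoint import find_checkpoints, remove_checkpoints

# List checkpoints such as exp/epoch-3.pt, sorted by epoch number in descending order.
ckpts = find_checkpoints("vocos/exp", prefix="epoch")

# Keep only the 5 most recent epoch checkpoints (rank 0 only under DDP).
remove_checkpoints("vocos/exp", topk=5, prefix="epoch", rank=0)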