mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Merge d8a0a4095554e58db0fc8f4e30a6a33932ab37dd into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9
This commit is contained in:
commit
cfbd4208cc
296
egs/libritts/TTS/vocos/discriminators.py
Normal file
296
egs/libritts/TTS/vocos/discriminators.py
Normal file
@ -0,0 +1,296 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Conv2d
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torchaudio.transforms import Spectrogram
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(nn.Module):
|
||||
"""
|
||||
Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan.
|
||||
Additionally, it allows incorporating conditional information with a learned embeddings table.
|
||||
|
||||
Args:
|
||||
periods (tuple[int]): Tuple of periods for each discriminator.
|
||||
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
periods: Tuple[int, ...] = (2, 3, 5, 7, 11),
|
||||
num_embeddings: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
y_hat: torch.Tensor,
|
||||
bandwidth_id: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for d in self.discriminators:
|
||||
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
|
||||
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class DiscriminatorP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
period: int,
|
||||
in_channels: int = 1,
|
||||
kernel_size: int = 5,
|
||||
stride: int = 3,
|
||||
lrelu_slope: float = 0.1,
|
||||
num_embeddings: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.period = period
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv2d(
|
||||
in_channels,
|
||||
32,
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(kernel_size // 2, 0),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv2d(
|
||||
32,
|
||||
128,
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(kernel_size // 2, 0),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv2d(
|
||||
128,
|
||||
512,
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(kernel_size // 2, 0),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv2d(
|
||||
512,
|
||||
1024,
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(kernel_size // 2, 0),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv2d(
|
||||
1024,
|
||||
1024,
|
||||
(kernel_size, 1),
|
||||
(1, 1),
|
||||
padding=(kernel_size // 2, 0),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
if num_embeddings is not None:
|
||||
self.emb = torch.nn.Embedding(
|
||||
num_embeddings=num_embeddings, embedding_dim=1024
|
||||
)
|
||||
torch.nn.init.zeros_(self.emb.weight)
|
||||
|
||||
self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||
self.lrelu_slope = lrelu_slope
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
x = x.unsqueeze(1)
|
||||
fmap = []
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
|
||||
for i, l in enumerate(self.convs):
|
||||
x = l(x)
|
||||
x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
|
||||
if i > 0:
|
||||
fmap.append(x)
|
||||
if cond_embedding_id is not None:
|
||||
emb = self.emb(cond_embedding_id)
|
||||
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
|
||||
else:
|
||||
h = 0
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x += h
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiResolutionDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
|
||||
num_embeddings: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
|
||||
Additionally, it allows incorporating conditional information with a learned embeddings table.
|
||||
|
||||
Args:
|
||||
fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
|
||||
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorR(window_length=w, num_embeddings=num_embeddings)
|
||||
for w in fft_sizes
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
|
||||
for d in self.discriminators:
|
||||
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
|
||||
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class DiscriminatorR(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
window_length: int,
|
||||
num_embeddings: Optional[int] = None,
|
||||
channels: int = 32,
|
||||
hop_factor: float = 0.25,
|
||||
bands: Tuple[Tuple[float, float], ...] = (
|
||||
(0.0, 0.1),
|
||||
(0.1, 0.25),
|
||||
(0.25, 0.5),
|
||||
(0.5, 0.75),
|
||||
(0.75, 1.0),
|
||||
),
|
||||
):
|
||||
super().__init__()
|
||||
self.window_length = window_length
|
||||
self.hop_factor = hop_factor
|
||||
self.spec_fn = Spectrogram(
|
||||
n_fft=window_length,
|
||||
hop_length=int(window_length * hop_factor),
|
||||
win_length=window_length,
|
||||
power=None,
|
||||
)
|
||||
n_fft = window_length // 2 + 1
|
||||
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
|
||||
self.bands = bands
|
||||
convs = lambda: nn.ModuleList(
|
||||
[
|
||||
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
|
||||
),
|
||||
]
|
||||
)
|
||||
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
|
||||
|
||||
if num_embeddings is not None:
|
||||
self.emb = torch.nn.Embedding(
|
||||
num_embeddings=num_embeddings, embedding_dim=channels
|
||||
)
|
||||
torch.nn.init.zeros_(self.emb.weight)
|
||||
|
||||
self.conv_post = weight_norm(
|
||||
nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
|
||||
)
|
||||
|
||||
def spectrogram(self, x):
|
||||
# Remove DC offset
|
||||
x = x - x.mean(dim=-1, keepdims=True)
|
||||
# Peak normalize the volume of input audio
|
||||
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
|
||||
x = self.spec_fn(x)
|
||||
x = torch.view_as_real(x)
|
||||
# x = rearrange(x, "b f t c -> b c t f")
|
||||
x = x.permute(0, 3, 2, 1)
|
||||
# Split into bands
|
||||
x_bands = [x[..., b[0] : b[1]] for b in self.bands]
|
||||
return x_bands
|
||||
|
||||
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
|
||||
x_bands = self.spectrogram(x)
|
||||
fmap = []
|
||||
x = []
|
||||
for band, stack in zip(x_bands, self.band_convs):
|
||||
for i, layer in enumerate(stack):
|
||||
band = layer(band)
|
||||
band = torch.nn.functional.leaky_relu(band, 0.1)
|
||||
if i > 0:
|
||||
fmap.append(band)
|
||||
x.append(band)
|
||||
x = torch.cat(x, dim=-1)
|
||||
if cond_embedding_id is not None:
|
||||
emb = self.emb(cond_embedding_id)
|
||||
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
|
||||
else:
|
||||
h = 0
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x += h
|
||||
|
||||
return x, fmap
|
371
egs/libritts/TTS/vocos/export-onnx.py
Executable file
371
egs/libritts/TTS/vocos/export-onnx.py
Executable file
@ -0,0 +1,371 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
|
||||
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
|
||||
|
||||
"""
|
||||
This script exports a transducer model from PyTorch to ONNX.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./zipformer/export-onnx.py \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--exp-dir $repo/exp \
|
||||
--num-encoder-layers "2,2,3,4,3,2" \
|
||||
--downsampling-factor "1,2,4,8,4,2" \
|
||||
--feedforward-dim "512,768,1024,1536,1024,768" \
|
||||
--num-heads "4,4,4,8,4,4" \
|
||||
--encoder-dim "192,256,384,512,384,256" \
|
||||
--query-head-dim 32 \
|
||||
--value-head-dim 12 \
|
||||
--pos-head-dim 4 \
|
||||
--pos-dim 48 \
|
||||
--encoder-unmasked-dim "192,192,256,256,256,192" \
|
||||
--cnn-module-kernel "31,31,15,15,15,31" \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--causal False \
|
||||
--chunk-size "16,32,64,-1" \
|
||||
--left-context-frames "64,128,256,-1" \
|
||||
--fp16 True
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
See ./onnx_pretrained.py and ./onnx_check.py for how to
|
||||
use the exported ONNX models.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from onnxconverter_common import float16
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import make_pad_mask, num_tokens, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sampling-rate",
|
||||
type=int,
|
||||
default=24000,
|
||||
help="The sampleing rate of libritts dataset",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frame-shift",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frame-length",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipformer/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to export models in fp16",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
"""Add meta data to an ONNX model. It is changed in-place.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename of the ONNX model to be changed.
|
||||
meta_data:
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = value
|
||||
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
def export_model_onnx(
|
||||
model: nn.Module,
|
||||
model_filename: str,
|
||||
opset_version: int = 13,
|
||||
) -> None:
|
||||
"""Export the joiner model to ONNX format.
|
||||
The exported joiner model has two inputs:
|
||||
|
||||
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- logit: a tensor of shape (N, vocab_size)
|
||||
"""
|
||||
input_tensor = torch.rand((2, 80, 100), dtype=torch.float32)
|
||||
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(input_tensor,),
|
||||
model_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=[
|
||||
"features",
|
||||
],
|
||||
output_names=["audio"],
|
||||
dynamic_axes={
|
||||
"features": {0: "N", 2: "F"},
|
||||
"audio": {0: "N", 1: "T"},
|
||||
},
|
||||
)
|
||||
|
||||
meta_data = {
|
||||
"model_type": "Vocos",
|
||||
"version": "1",
|
||||
"model_author": "k2-fsa",
|
||||
"comment": "ConvNext Vocos",
|
||||
}
|
||||
logging.info(f"meta_data: {meta_data}")
|
||||
|
||||
add_meta_data(filename=model_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
params.device = device
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.eval()
|
||||
vocos = model.generator
|
||||
|
||||
if params.iter > 0:
|
||||
suffix = f"iter-{params.iter}"
|
||||
else:
|
||||
suffix = f"epoch-{params.epoch}"
|
||||
|
||||
suffix += f"-avg-{params.avg}"
|
||||
|
||||
opset_version = 13
|
||||
|
||||
logging.info("Exporting model")
|
||||
model_filename = params.exp_dir / f"vocos-{suffix}.onnx"
|
||||
export_model_onnx(
|
||||
vocos,
|
||||
model_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported vocos generator to {model_filename}")
|
||||
|
||||
if params.fp16:
|
||||
logging.info("Generate fp16 models")
|
||||
|
||||
model = onnx.load(model_filename)
|
||||
model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
|
||||
model_filename_fp16 = params.exp_dir / f"vocos-{suffix}.fp16.onnx"
|
||||
onnx.save(model_fp16, model_filename_fp16)
|
||||
|
||||
# Generate int8 quantization models
|
||||
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||
|
||||
logging.info("Generate int8 quantization models")
|
||||
|
||||
model_filename_int8 = params.exp_dir / f"vocos-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=model_filename,
|
||||
model_output=model_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
407
egs/libritts/TTS/vocos/export.py
Executable file
407
egs/libritts/TTS/vocos/export.py
Executable file
@ -0,0 +1,407 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2024 Xiaomi Corporation (Author: Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
|
||||
Usage:
|
||||
|
||||
Note: This is a example for libritts dataset, if you are using different
|
||||
dataset, you should change the argument values according to your dataset.
|
||||
|
||||
(1) Export to torchscript model using torch.jit.script()
|
||||
|
||||
|
||||
./vocos/export.py \
|
||||
--exp-dir ./vocos/exp \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--jit 1
|
||||
|
||||
It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
|
||||
load it by `torch.jit.load("jit_script.pt")`.
|
||||
|
||||
Check ./jit_pretrained.py for its usage.
|
||||
|
||||
Check https://github.com/k2-fsa/sherpa
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
- For streaming model:
|
||||
|
||||
./zipformer/export.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--causal 1 \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 128 \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--jit 1
|
||||
|
||||
It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`.
|
||||
|
||||
Check ./jit_pretrained_streaming.py for its usage.
|
||||
|
||||
Check https://github.com/k2-fsa/sherpa
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
(2) Export `model.state_dict()`
|
||||
|
||||
- For non-streaming model:
|
||||
|
||||
./zipformer/export.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 9
|
||||
|
||||
- For streaming model:
|
||||
|
||||
./zipformer/export.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--causal 1 \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 9
|
||||
|
||||
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
||||
load it by `icefall.checkpoint.load_checkpoint()`.
|
||||
|
||||
- For non-streaming model:
|
||||
|
||||
To use the generated file with `zipformer/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
./zipformer/decode.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search \
|
||||
--bpe-model data/lang_bpe_500/bpe.model
|
||||
|
||||
- For streaming model:
|
||||
|
||||
To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
|
||||
# simulated streaming decoding
|
||||
./zipformer/decode.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--causal 1 \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 128 \
|
||||
--decoding-method greedy_search \
|
||||
--bpe-model data/lang_bpe_500/bpe.model
|
||||
|
||||
# chunk-wise streaming decoding
|
||||
./zipformer/streaming_decode.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--causal 1 \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 128 \
|
||||
--decoding-method greedy_search \
|
||||
--bpe-model data/lang_bpe_500/bpe.model
|
||||
|
||||
Check ./pretrained.py for its usage.
|
||||
|
||||
Note: If you don't want to train a model from scratch, we have
|
||||
provided one for you. You can get it at
|
||||
|
||||
- non-streaming model:
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
|
||||
- streaming model:
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
|
||||
|
||||
with the following commands:
|
||||
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
|
||||
# You will find the pre-trained models in exp dir
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
)
|
||||
|
||||
from icefall.utils import str2bool
|
||||
from utils import load_checkpoint
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sampling-rate",
|
||||
type=int,
|
||||
default=24000,
|
||||
help="The sampleing rate of libritts dataset",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frame-shift",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frame-length",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="vocos/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
It will generate a file named jit_script.pt.
|
||||
Check ./jit_pretrained.py for how to use it.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class EncoderModel(nn.Module):
|
||||
"""A wrapper for encoder and encoder_embed"""
|
||||
|
||||
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.encoder_embed = encoder_embed
|
||||
|
||||
def forward(
|
||||
self, features: Tensor, feature_lengths: Tensor
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Args:
|
||||
features: (N, T, C)
|
||||
feature_lengths: (N,)
|
||||
"""
|
||||
x, x_lens = self.encoder_embed(features, feature_lengths)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
params.device = device
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.eval()
|
||||
|
||||
model = model.generator
|
||||
|
||||
if params.jit is True:
|
||||
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
|
||||
filename = "jit_script.pt"
|
||||
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
model.save(str(params.exp_dir / filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torchscript. Export model.state_dict()")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "generator.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
264
egs/libritts/TTS/vocos/generator.py
Normal file
264
egs/libritts/TTS/vocos/generator.py
Normal file
@ -0,0 +1,264 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def window_sumsquare(
|
||||
window: torch.Tensor,
|
||||
n_samples: int,
|
||||
hop_length: int = 256,
|
||||
win_length: int = 1024,
|
||||
):
|
||||
"""
|
||||
Compute the sum-square envelope of a window function at a given hop length.
|
||||
This is used to estimate modulation effects induced by windowing
|
||||
observations in short-time fourier transforms.
|
||||
Parameters
|
||||
----------
|
||||
window : string, tuple, number, callable, or list-like
|
||||
Window specification, as in `get_window`
|
||||
n_samples : int > 0
|
||||
The number of expected samples.
|
||||
hop_length : int > 0
|
||||
The number of samples to advance between frames
|
||||
win_length :
|
||||
The length of the window function.
|
||||
Returns
|
||||
-------
|
||||
wss : torch.Tensor, The sum-squared envelope of the window function.
|
||||
"""
|
||||
|
||||
n_frames = (n_samples - win_length) // hop_length + 1
|
||||
output_size = (n_frames - 1) * hop_length + win_length
|
||||
device = window.device
|
||||
|
||||
# Window envelope
|
||||
window_sq = window.square().expand(1, n_frames, -1).transpose(1, 2)
|
||||
window_envelope = torch.nn.functional.fold(
|
||||
window_sq,
|
||||
output_size=(1, output_size),
|
||||
kernel_size=(1, win_length),
|
||||
stride=(1, hop_length),
|
||||
).squeeze()
|
||||
window_envelope = torch.nn.functional.pad(
|
||||
window_envelope, (0, n_samples - output_size)
|
||||
)
|
||||
return window_envelope
|
||||
|
||||
|
||||
class ISTFT(torch.nn.Module):
|
||||
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter_length: int = 1024,
|
||||
hop_length: int = 256,
|
||||
win_length: int = 1024,
|
||||
padding: str = "none",
|
||||
window_type: str = "povey",
|
||||
max_samples: int = 1440000, # 1440000 / 24000 = 60s
|
||||
):
|
||||
super(ISTFT, self).__init__()
|
||||
self.filter_length = filter_length
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
self.padding = padding
|
||||
scale = self.filter_length / self.hop_length
|
||||
fourier_basis = np.fft.fft(np.eye(self.filter_length))
|
||||
|
||||
cutoff = int((self.filter_length / 2 + 1))
|
||||
fourier_basis = np.vstack(
|
||||
[np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
|
||||
)
|
||||
inverse_basis = torch.FloatTensor(
|
||||
np.linalg.pinv(scale * fourier_basis).T[:, None, :]
|
||||
)
|
||||
|
||||
assert filter_length >= win_length
|
||||
# Consistence with lhotse, search "create_frame_window" in https://github.com/lhotse-speech/lhotse
|
||||
assert window_type in [
|
||||
"hanning",
|
||||
"povey",
|
||||
], f"Only 'hanning' and 'povey' windows are supported, given {window_type}."
|
||||
fft_window = torch.hann_window(win_length, periodic=False)
|
||||
if window_type == "povey":
|
||||
fft_window = fft_window.pow(0.85)
|
||||
|
||||
if filter_length > win_length:
|
||||
pad_size = (filter_length - win_length) // 2
|
||||
fft_window = torch.nn.functional.pad(fft_window, (pad_size, pad_size))
|
||||
|
||||
window_sum = window_sumsquare(
|
||||
window=fft_window,
|
||||
n_samples=max_samples,
|
||||
hop_length=hop_length,
|
||||
win_length=filter_length,
|
||||
)
|
||||
|
||||
inverse_basis *= fft_window
|
||||
|
||||
self.register_buffer("inverse_basis", inverse_basis.float())
|
||||
self.register_buffer("fft_window", fft_window)
|
||||
self.register_buffer("window_sum", window_sum)
|
||||
self.tiny = torch.finfo(torch.float16).tiny
|
||||
|
||||
def forward(self, magnitude, phase):
|
||||
magnitude_phase = torch.cat(
|
||||
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
|
||||
)
|
||||
inverse_transform = F.conv_transpose1d(
|
||||
magnitude_phase,
|
||||
Variable(self.inverse_basis, requires_grad=False),
|
||||
stride=self.hop_length,
|
||||
padding=0,
|
||||
)
|
||||
inverse_transform = inverse_transform.squeeze(1)
|
||||
|
||||
window_sum = self.window_sum
|
||||
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
||||
if self.window_sum.size(-1) < inverse_transform.size(-1):
|
||||
logging.warning(
|
||||
f"The precomputed `window_sumsquare` is too small, recomputing, "
|
||||
f"from {self.window_sum.size(-1)} to {inverse_transform.size(-1)}"
|
||||
)
|
||||
window_sum = window_sumsquare(
|
||||
window=self.fft_window,
|
||||
n_samples=inverse_transform.size(-1),
|
||||
win_length=self.filter_length,
|
||||
hop_length=self.hop_length,
|
||||
)
|
||||
window_sum = window_sum[: inverse_transform.size(-1)]
|
||||
approx_nonzero_indices = (window_sum > self.tiny).nonzero().squeeze()
|
||||
|
||||
inverse_transform[:, approx_nonzero_indices] /= window_sum[
|
||||
approx_nonzero_indices
|
||||
]
|
||||
|
||||
# scale by hop ratio
|
||||
inverse_transform *= float(self.filter_length) / self.hop_length
|
||||
assert self.padding in ["none", "same", "center"]
|
||||
if self.padding == "center":
|
||||
pad_len = self.filter_length // 2
|
||||
elif self.padding == "same":
|
||||
pad_len = (self.filter_length - self.hop_length) // 2
|
||||
else:
|
||||
return inverse_transform
|
||||
return inverse_transform[:, pad_len:-pad_len]
|
||||
|
||||
|
||||
class ConvNeXtBlock(nn.Module):
|
||||
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
intermediate_dim (int): Dimensionality of the intermediate layer.
|
||||
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
intermediate_dim: int,
|
||||
layer_scale_init_value: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dwconv = nn.Conv1d(
|
||||
dim, dim, kernel_size=7, padding=3, groups=dim
|
||||
) # depthwise conv
|
||||
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
||||
# pointwise/1x1 convs, implemented with linear layers
|
||||
self.pwconv1 = nn.Linear(dim, intermediate_dim)
|
||||
self.act = nn.GELU()
|
||||
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
||||
self.gamma = (
|
||||
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
||||
if layer_scale_init_value > 0
|
||||
else None
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
residual = x
|
||||
x = self.dwconv(x)
|
||||
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
||||
x = self.norm(x)
|
||||
x = self.pwconv1(x)
|
||||
x = self.act(x)
|
||||
x = self.pwconv2(x)
|
||||
if self.gamma is not None:
|
||||
x = self.gamma * x
|
||||
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
|
||||
|
||||
x = residual + x
|
||||
return x
|
||||
|
||||
|
||||
class Generator(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
feature_dim: int = 80,
|
||||
dim: int = 512,
|
||||
n_fft: int = 1024,
|
||||
hop_length: int = 256,
|
||||
intermediate_dim: int = 1536,
|
||||
num_layers: int = 8,
|
||||
padding: str = "none",
|
||||
max_samples: int = 1440000, # 1440000 / 24000 = 60s
|
||||
):
|
||||
super(Generator, self).__init__()
|
||||
self.feature_dim = feature_dim
|
||||
self.embed = nn.Conv1d(feature_dim, dim, kernel_size=7, padding=3)
|
||||
|
||||
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
||||
|
||||
layer_scale_init_value = 1 / num_layers
|
||||
self.convnext = nn.ModuleList(
|
||||
[
|
||||
ConvNeXtBlock(
|
||||
dim=dim,
|
||||
intermediate_dim=intermediate_dim,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
self.out_proj = torch.nn.Linear(dim, n_fft + 2)
|
||||
self.istft = ISTFT(
|
||||
filter_length=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=n_fft,
|
||||
padding=padding,
|
||||
max_samples=max_samples,
|
||||
)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
||||
nn.init.trunc_normal_(m.weight, std=0.02)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.embed(x)
|
||||
x = self.norm(x.transpose(1, 2))
|
||||
x = x.transpose(1, 2)
|
||||
for conv_block in self.convnext:
|
||||
x = conv_block(x)
|
||||
x = self.final_layer_norm(x.transpose(1, 2))
|
||||
x = self.out_proj(x).transpose(1, 2)
|
||||
mag, phase = x.chunk(2, dim=1)
|
||||
mag = torch.exp(mag)
|
||||
# safeguard to prevent excessively large magnitudes
|
||||
mag = torch.clip(mag, max=1e2)
|
||||
audio = self.istft(mag, phase)
|
||||
return audio
|
352
egs/libritts/TTS/vocos/infer.py
Executable file
352
egs/libritts/TTS/vocos/infer.py
Executable file
@ -0,0 +1,352 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
|
||||
# Han Zhu)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from lhotse.utils import fix_random_seed
|
||||
from scipy.io.wavfile import write
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
from tts_datamodule import LibriTTSDataModule
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="vocos/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--generate-dir",
|
||||
type=str,
|
||||
default="generated_wavs",
|
||||
help="Path name of the generated wavs",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
batch: dict,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The text-to-feature neural model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
|
||||
cut_ids = [cut.id for cut in batch["cut"]]
|
||||
|
||||
infer_time = 0
|
||||
audio_time = 0
|
||||
|
||||
features = batch["features"] # (B, T, F)
|
||||
utt_durations = batch["features_lens"]
|
||||
|
||||
x = features.permute(0, 2, 1) # (B, F, T)
|
||||
|
||||
audio_time += torch.sum(utt_durations)
|
||||
|
||||
start = time.time()
|
||||
|
||||
audios = model(x.to(device)) # (B, T)
|
||||
|
||||
infer_time += time.time() - start
|
||||
|
||||
wav_dir = f"{params.res_dir}/{params.suffix}"
|
||||
os.makedirs(wav_dir, exist_ok=True)
|
||||
|
||||
for i in range(audios.shape[0]):
|
||||
audio = audios[i][: int(utt_durations[i] * 256)]
|
||||
audio = audio.cpu().squeeze().numpy()
|
||||
write(f"{wav_dir}/{cut_ids[i]}.wav", 24000, audio)
|
||||
|
||||
print(f"RTF : {infer_time / (audio_time * (256/24000))}")
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
test_set: str,
|
||||
):
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The text-to-feature neural model.
|
||||
test_set:
|
||||
The name of the test_set
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
with open(f"{params.res_dir}/{test_set}.scp", "w", encoding="utf8") as f:
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
# texts = batch["text"]
|
||||
cut_ids = [cut.id for cut in batch["cut"]]
|
||||
|
||||
decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
# assert len(texts) == len(cut_ids), (len(texts), len(cut_ids))
|
||||
|
||||
# for i in range(len(texts)):
|
||||
# f.write(f"{cut_ids[i]}\t{texts[i]}\n")
|
||||
|
||||
# num_cuts += len(texts)
|
||||
|
||||
if batch_idx % 50 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriTTSDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
params.res_dir = params.exp_dir / params.generate_dir
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
params.device = device
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
logging.info(params)
|
||||
fix_random_seed(666)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
libritts = LibriTTSDataModule(args)
|
||||
|
||||
test_cuts = libritts.test_clean_cuts()
|
||||
|
||||
test_dl = libritts.test_dataloaders(test_cuts)
|
||||
|
||||
test_sets = ["test"]
|
||||
test_dls = [test_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
test_set=test_set,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
133
egs/libritts/TTS/vocos/loss.py
Normal file
133
egs/libritts/TTS/vocos/loss.py
Normal file
@ -0,0 +1,133 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch import nn
|
||||
|
||||
from utils import safe_log
|
||||
|
||||
|
||||
class MelSpecReconstructionLoss(nn.Module):
|
||||
"""
|
||||
L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate: int = 24000,
|
||||
n_fft: int = 1024,
|
||||
hop_length: int = 256,
|
||||
n_mels: int = 100,
|
||||
):
|
||||
super().__init__()
|
||||
self.mel_spec = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=sample_rate,
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
n_mels=n_mels,
|
||||
center=True,
|
||||
power=1,
|
||||
)
|
||||
|
||||
def forward(self, y_hat, y) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y_hat (Tensor): Predicted audio waveform.
|
||||
y (Tensor): Ground truth audio waveform.
|
||||
|
||||
Returns:
|
||||
Tensor: L1 loss between the mel-scaled magnitude spectrograms.
|
||||
"""
|
||||
mel_hat = safe_log(self.mel_spec(y_hat))
|
||||
mel = safe_log(self.mel_spec(y))
|
||||
|
||||
loss = torch.nn.functional.l1_loss(mel, mel_hat)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class GeneratorLoss(nn.Module):
|
||||
"""
|
||||
Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self, disc_outputs: List[torch.Tensor]
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
disc_outputs (List[Tensor]): List of discriminator outputs.
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
|
||||
the sub-discriminators
|
||||
"""
|
||||
loss = torch.zeros(
|
||||
1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype
|
||||
)
|
||||
gen_losses = []
|
||||
for dg in disc_outputs:
|
||||
l = torch.mean(torch.clamp(1 - dg, min=0))
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
|
||||
|
||||
|
||||
class DiscriminatorLoss(nn.Module):
|
||||
"""
|
||||
Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
disc_real_outputs: List[torch.Tensor],
|
||||
disc_generated_outputs: List[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
|
||||
disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
|
||||
the sub-discriminators for real outputs, and a list of
|
||||
loss values for generated outputs.
|
||||
"""
|
||||
loss = torch.zeros(
|
||||
1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype
|
||||
)
|
||||
r_losses = []
|
||||
g_losses = []
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
r_loss = torch.mean(torch.clamp(1 - dr, min=0))
|
||||
g_loss = torch.mean(torch.clamp(1 + dg, min=0))
|
||||
loss += r_loss + g_loss
|
||||
r_losses.append(r_loss.item())
|
||||
g_losses.append(g_loss.item())
|
||||
|
||||
return loss, r_losses, g_losses
|
||||
|
||||
|
||||
class FeatureMatchingLoss(nn.Module):
|
||||
"""
|
||||
Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
fmap_r (List[List[Tensor]]): List of feature maps from real samples.
|
||||
fmap_g (List[List[Tensor]]): List of feature maps from generated samples.
|
||||
|
||||
Returns:
|
||||
Tensor: The calculated feature matching loss.
|
||||
"""
|
||||
loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype)
|
||||
for dr, dg in zip(fmap_r, fmap_g):
|
||||
for rl, gl in zip(dr, dg):
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
|
||||
return loss
|
48
egs/libritts/TTS/vocos/model.py
Normal file
48
egs/libritts/TTS/vocos/model.py
Normal file
@ -0,0 +1,48 @@
|
||||
import logging
|
||||
import torch
|
||||
from discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator
|
||||
from generator import Generator
|
||||
from loss import (
|
||||
DiscriminatorLoss,
|
||||
GeneratorLoss,
|
||||
FeatureMatchingLoss,
|
||||
MelSpecReconstructionLoss,
|
||||
)
|
||||
|
||||
|
||||
class Vocos(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
feature_dim: int = 80,
|
||||
dim: int = 512,
|
||||
n_fft: int = 1024,
|
||||
hop_length: int = 256,
|
||||
intermediate_dim: int = 1536,
|
||||
num_layers: int = 8,
|
||||
padding: str = "none",
|
||||
sample_rate: int = 24000,
|
||||
max_seconds: int = 60,
|
||||
):
|
||||
super(Vocos, self).__init__()
|
||||
self.generator = Generator(
|
||||
feature_dim=feature_dim,
|
||||
dim=dim,
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
num_layers=num_layers,
|
||||
intermediate_dim=intermediate_dim,
|
||||
padding=padding,
|
||||
max_samples=int(sample_rate * max_seconds),
|
||||
)
|
||||
|
||||
self.mpd = MultiPeriodDiscriminator()
|
||||
self.mrd = MultiResolutionDiscriminator()
|
||||
|
||||
self.disc_loss = DiscriminatorLoss()
|
||||
self.gen_loss = GeneratorLoss()
|
||||
self.feat_matching_loss = FeatureMatchingLoss()
|
||||
self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate)
|
||||
|
||||
def forward(self, features: torch.Tensor):
|
||||
audio = self.generator(features)
|
||||
return audio
|
268
egs/libritts/TTS/vocos/onnx_pretrained.py
Executable file
268
egs/libritts/TTS/vocos/onnx_pretrained.py
Executable file
@ -0,0 +1,268 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads ONNX models and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./zipformer/export-onnx.py \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--exp-dir $repo/exp \
|
||||
--causal False
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
3. Run this file
|
||||
|
||||
./zipformer/onnx_pretrained.py \
|
||||
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from lhotse import Fbank, FbankConfig
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sampling-rate",
|
||||
type=int,
|
||||
default=24000,
|
||||
help="The sampleing rate of libritts dataset",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frame-shift",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frame-length",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-fft-mag",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to use magnitude of fbank, false to use power energy.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="generated_audios",
|
||||
help="The generated will be written to.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
model_filename: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 4
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
self.init_model(model_filename)
|
||||
|
||||
def init_model(self, model_filename: str):
|
||||
self.model = ort.InferenceSession(
|
||||
model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
def run_model(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
x_lens:
|
||||
A 2-D tensor of shape (N,). Its dtype is torch.int64
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- encoder_out, its shape is (N, T', joiner_dim)
|
||||
- encoder_out_lens, its shape is (N,)
|
||||
"""
|
||||
out = self.model.run(
|
||||
[
|
||||
self.model.get_outputs()[0].name,
|
||||
],
|
||||
{
|
||||
self.model.get_inputs()[0].name: x.numpy(),
|
||||
},
|
||||
)
|
||||
return torch.from_numpy(out[0])
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
output_dir = Path(args.model_filename).parent / args.output_dir
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
args.output_dir = output_dir
|
||||
logging.info(vars(args))
|
||||
|
||||
model = OnnxModel(model_filename=args.model_filename)
|
||||
|
||||
config = FbankConfig(
|
||||
sampling_rate=args.sampling_rate,
|
||||
frame_length=args.frame_length / args.sampling_rate, # (in second),
|
||||
frame_shift=args.frame_shift / args.sampling_rate, # (in second)
|
||||
use_fft_mag=args.use_fft_mag,
|
||||
)
|
||||
fbank = Fbank(config)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files, expected_sample_rate=args.sampling_rate
|
||||
)
|
||||
wave_lengths = [w.size(0) for w in waves]
|
||||
waves = pad_sequence(waves, batch_first=True, padding_value=0)
|
||||
|
||||
logging.info(f"waves : {waves.shape}")
|
||||
|
||||
features = fbank.extract_batch(waves, sampling_rate=args.sampling_rate)
|
||||
|
||||
if features.dim() == 2:
|
||||
features = features.unsqueeze(0)
|
||||
|
||||
features = features.permute(0, 2, 1)
|
||||
|
||||
logging.info(f"features : {features.shape}")
|
||||
|
||||
logging.info("Generating started")
|
||||
|
||||
# model forward
|
||||
audios = model.run_model(features)
|
||||
|
||||
for i, filename in enumerate(args.sound_files):
|
||||
audio = audios[i : i + 1, 0 : wave_lengths[i]]
|
||||
ofilename = args.output_dir / filename.split("/")[-1]
|
||||
logging.info(f"Writting audio : {ofilename}")
|
||||
torchaudio.save(str(ofilename), audio.cpu(), args.sampling_rate)
|
||||
|
||||
logging.info("Generating Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
196
egs/libritts/TTS/vocos/pretrained.py
Executable file
196
egs/libritts/TTS/vocos/pretrained.py
Executable file
@ -0,0 +1,196 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads a checkpoint and uses it to decode waves.
|
||||
You can generate the checkpoint with the following command:
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
from lhotse import Fbank, FbankConfig
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sampling-rate",
|
||||
type=int,
|
||||
default=24000,
|
||||
help="The sampleing rate of libritts dataset",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frame-shift",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frame-length",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-fft-mag",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to use magnitude of fbank, false to use power energy.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="generated_audios",
|
||||
help="The generated will be written to.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0].contiguous())
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
params.device = device
|
||||
|
||||
output_dir = Path(params.checkpoint).parent / params.output_dir
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
params.output_dir = output_dir
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_model(params)
|
||||
|
||||
model = model.generator
|
||||
|
||||
checkpoint = torch.load(params.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
|
||||
config = FbankConfig(
|
||||
sampling_rate=params.sampling_rate,
|
||||
frame_length=params.frame_length / params.sampling_rate, # (in second),
|
||||
frame_shift=params.frame_shift / params.sampling_rate, # (in second)
|
||||
use_fft_mag=params.use_fft_mag,
|
||||
)
|
||||
fbank = Fbank(config)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sampling_rate
|
||||
)
|
||||
wave_lengths = [w.size(0) for w in waves]
|
||||
waves = pad_sequence(waves, batch_first=True, padding_value=0)
|
||||
|
||||
features = (
|
||||
fbank.extract_batch(waves, sampling_rate=params.sampling_rate)
|
||||
.permute(0, 2, 1)
|
||||
.to(device)
|
||||
)
|
||||
|
||||
logging.info("Generating started")
|
||||
|
||||
# model forward
|
||||
audios = model(features)
|
||||
|
||||
for i, filename in enumerate(params.sound_files):
|
||||
audio = audios[i : i + 1, 0 : wave_lengths[i]]
|
||||
ofilename = params.output_dir / filename.split("/")[-1]
|
||||
logging.info(f"Writting audio : {ofilename}")
|
||||
torchaudio.save(str(ofilename), audio.cpu(), params.sampling_rate)
|
||||
|
||||
logging.info("Generating Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1054
egs/libritts/TTS/vocos/train.py
Executable file
1054
egs/libritts/TTS/vocos/train.py
Executable file
File diff suppressed because it is too large
Load Diff
419
egs/libritts/TTS/vocos/tts_datamodule.py
Normal file
419
egs/libritts/TTS/vocos/tts_datamodule.py
Normal file
@ -0,0 +1,419 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
|
||||
# Zengwei Yao,
|
||||
# Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
PrecomputedFeatures,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
SpeechSynthesisDataset,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
AudioSamples,
|
||||
OnTheFlyFeatures,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class LibriTTSDataModule:
|
||||
"""
|
||||
DataModule for tts experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="TTS data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-text",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to return the text of the audio.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-tokens",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether the return the tokens of the text of the audio.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--sampling-rate",
|
||||
type=int,
|
||||
default=24000,
|
||||
help="The sampleing rate of libritts dataset",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--frame-shift",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--frame-length",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--input-strategy",
|
||||
type=str,
|
||||
default="PrecomputedFeatures",
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--use-fft-mag",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to use magnitude of fbank, false to use power energy.",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
logging.info("About to create train dataset")
|
||||
train = SpeechSynthesisDataset(
|
||||
return_text=self.args.return_text,
|
||||
return_tokens=self.args.return_tokens,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = self.args.sampling_rate
|
||||
config = FbankConfig(
|
||||
sampling_rate=sampling_rate,
|
||||
frame_length=self.args.frame_length / sampling_rate, # (in second),
|
||||
frame_shift=self.args.frame_shift / sampling_rate, # (in second)
|
||||
use_fft_mag=self.args.use_fft_mag,
|
||||
)
|
||||
train = SpeechSynthesisDataset(
|
||||
return_text=self.args.return_text,
|
||||
return_tokens=self.args.return_tokens,
|
||||
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
buffer_size=self.args.num_buckets * 2000,
|
||||
shuffle_buffer_size=self.args.num_buckets * 5000,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = self.args.sampling_rate
|
||||
config = FbankConfig(
|
||||
sampling_rate=sampling_rate,
|
||||
frame_length=self.args.frame_length / sampling_rate, # (in second),
|
||||
frame_shift=self.args.frame_shift / sampling_rate, # (in second)
|
||||
use_fft_mag=self.args.use_fft_mag,
|
||||
)
|
||||
validate = SpeechSynthesisDataset(
|
||||
return_text=self.args.return_text,
|
||||
return_tokens=self.args.return_tokens,
|
||||
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = SpeechSynthesisDataset(
|
||||
return_text=self.args.return_text,
|
||||
return_tokens=self.args.return_tokens,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
num_buckets=self.args.num_buckets,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create valid dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.info("About to create test dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = self.args.sampling_rate
|
||||
config = FbankConfig(
|
||||
sampling_rate=sampling_rate,
|
||||
frame_length=self.args.frame_length / sampling_rate, # (in second),
|
||||
frame_shift=self.args.frame_shift / sampling_rate, # (in second)
|
||||
use_fft_mag=self.args.use_fft_mag,
|
||||
)
|
||||
test = SpeechSynthesisDataset(
|
||||
return_text=self.args.return_text,
|
||||
return_tokens=self.args.return_tokens,
|
||||
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
test = SpeechSynthesisDataset(
|
||||
return_text=self.args.return_text,
|
||||
return_tokens=self.args.return_tokens,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
test_sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
num_buckets=self.args.num_buckets,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=test_sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get train clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libritts_cuts_train-clean-460.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_100_cuts(self) -> CutSet:
|
||||
logging.info("About to get train clean 100 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libritts_cuts_train-clean-100.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_360_cuts(self) -> CutSet:
|
||||
logging.info("About to get train clean 360 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libritts_cuts_train-clean-360.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get test clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get test other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts_finetune(self) -> CutSet:
|
||||
logging.info("About to get train cuts finetune")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libritts_cuts_train_finetune.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def valid_cuts_finetune(self) -> CutSet:
|
||||
logging.info("About to get validation cuts finetune")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libritts_cuts_valid_finetune.jsonl.gz"
|
||||
)
|
282
egs/libritts/TTS/vocos/utils.py
Normal file
282
egs/libritts/TTS/vocos/utils.py
Normal file
@ -0,0 +1,282 @@
|
||||
import glob
|
||||
import os
|
||||
import logging
|
||||
import matplotlib
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from torch.nn.utils import weight_norm
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from torch.optim import Optimizer
|
||||
from torch.cuda.amp import GradScaler
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from torch import Tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pylab as plt
|
||||
|
||||
|
||||
def plot_spectrogram(spectrogram):
|
||||
fig, ax = plt.subplots(figsize=(10, 2))
|
||||
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
||||
plt.colorbar(im, ax=ax)
|
||||
|
||||
fig.canvas.draw()
|
||||
plt.close()
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def save_checkpoint_with_global_batch_idx(
|
||||
out_dir: Path,
|
||||
global_batch_idx: int,
|
||||
model: Union[nn.Module, DDP],
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
optimizer_g: Optional[Optimizer] = None,
|
||||
optimizer_d: Optional[Optimizer] = None,
|
||||
scheduler_g: Optional[LRScheduler] = None,
|
||||
scheduler_d: Optional[LRScheduler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
rank: int = 0,
|
||||
):
|
||||
"""Save training info after processing given number of batches.
|
||||
|
||||
Args:
|
||||
out_dir:
|
||||
The directory to save the checkpoint.
|
||||
global_batch_idx:
|
||||
The number of batches processed so far from the very start of the
|
||||
training. The saved checkpoint will have the following filename:
|
||||
|
||||
f'out_dir / checkpoint-{global_batch_idx}.pt'
|
||||
model:
|
||||
The neural network model whose `state_dict` will be saved in the
|
||||
checkpoint.
|
||||
model_avg:
|
||||
The stored model averaged from the start of training.
|
||||
params:
|
||||
A dict of training configurations to be saved.
|
||||
optimizer:
|
||||
The optimizer used in the training. Its `state_dict` will be saved.
|
||||
scheduler:
|
||||
The learning rate scheduler used in the training. Its `state_dict` will
|
||||
be saved.
|
||||
scaler:
|
||||
The scaler used for mix precision training. Its `state_dict` will
|
||||
be saved.
|
||||
sampler:
|
||||
The sampler used in the training dataset.
|
||||
rank:
|
||||
The rank ID used in DDP training of the current node. Set it to 0
|
||||
if DDP is not used.
|
||||
"""
|
||||
out_dir = Path(out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
|
||||
save_checkpoint(
|
||||
filename=filename,
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
params=params,
|
||||
optimizer_g=optimizer_g,
|
||||
scheduler_g=scheduler_g,
|
||||
optimizer_d=optimizer_d,
|
||||
scheduler_d=scheduler_d,
|
||||
scaler=scaler,
|
||||
sampler=sampler,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
filename: Path,
|
||||
model: nn.Module,
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
optimizer_g: Optional[Optimizer] = None,
|
||||
optimizer_d: Optional[Optimizer] = None,
|
||||
scheduler_g: Optional[LRScheduler] = None,
|
||||
scheduler_d: Optional[LRScheduler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
strict: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
logging.info(f"Loading checkpoint from {filename}")
|
||||
checkpoint = torch.load(filename, map_location="cpu")
|
||||
|
||||
if next(iter(checkpoint["model"])).startswith("module."):
|
||||
logging.info("Loading checkpoint saved by DDP")
|
||||
|
||||
dst_state_dict = model.state_dict()
|
||||
src_state_dict = checkpoint["model"]
|
||||
for key in dst_state_dict.keys():
|
||||
src_key = "{}.{}".format("module", key)
|
||||
dst_state_dict[key] = src_state_dict.pop(src_key)
|
||||
assert len(src_state_dict) == 0
|
||||
model.load_state_dict(dst_state_dict, strict=strict)
|
||||
else:
|
||||
model.load_state_dict(checkpoint["model"], strict=strict)
|
||||
|
||||
checkpoint.pop("model")
|
||||
|
||||
if model_avg is not None and "model_avg" in checkpoint:
|
||||
logging.info("Loading averaged model")
|
||||
model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
|
||||
checkpoint.pop("model_avg")
|
||||
|
||||
def load(name, obj):
|
||||
s = checkpoint.get(name, None)
|
||||
if obj and s:
|
||||
obj.load_state_dict(s)
|
||||
checkpoint.pop(name)
|
||||
|
||||
load("optimizer_g", optimizer_g)
|
||||
load("optimizer_d", optimizer_d)
|
||||
load("scheduler_g", scheduler_g)
|
||||
load("scheduler_d", scheduler_d)
|
||||
load("grad_scaler", scaler)
|
||||
load("sampler", sampler)
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
filename: Path,
|
||||
model: Union[nn.Module, DDP],
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
optimizer_g: Optional[Optimizer] = None,
|
||||
optimizer_d: Optional[Optimizer] = None,
|
||||
scheduler_g: Optional[LRScheduler] = None,
|
||||
scheduler_d: Optional[LRScheduler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save training information to a file.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
The checkpoint filename.
|
||||
model:
|
||||
The model to be saved. We only save its `state_dict()`.
|
||||
model_avg:
|
||||
The stored model averaged from the start of training.
|
||||
params:
|
||||
User defined parameters, e.g., epoch, loss.
|
||||
optimizer:
|
||||
The optimizer to be saved. We only save its `state_dict()`.
|
||||
scheduler:
|
||||
The scheduler to be saved. We only save its `state_dict()`.
|
||||
scalar:
|
||||
The GradScaler to be saved. We only save its `state_dict()`.
|
||||
rank:
|
||||
Used in DDP. We save checkpoint only for the node whose rank is 0.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
if rank != 0:
|
||||
return
|
||||
|
||||
logging.info(f"Saving checkpoint to {filename}")
|
||||
|
||||
if isinstance(model, DDP):
|
||||
model = model.module
|
||||
|
||||
checkpoint = {
|
||||
"model": model.state_dict(),
|
||||
"optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None,
|
||||
"optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None,
|
||||
"scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None,
|
||||
"scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None,
|
||||
"grad_scaler": scaler.state_dict() if scaler is not None else None,
|
||||
"sampler": sampler.state_dict() if sampler is not None else None,
|
||||
}
|
||||
|
||||
if model_avg is not None:
|
||||
checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
|
||||
|
||||
if params:
|
||||
for k, v in params.items():
|
||||
assert k not in checkpoint
|
||||
checkpoint[k] = v
|
||||
|
||||
torch.save(checkpoint, filename)
|
||||
|
||||
|
||||
def _get_cosine_schedule_with_warmup_lr_lambda(
|
||||
current_step: int,
|
||||
*,
|
||||
num_warmup_steps: int,
|
||||
num_training_steps: int,
|
||||
num_cycles: float,
|
||||
min_lr_rate: float = 0.0,
|
||||
):
|
||||
if current_step < num_warmup_steps:
|
||||
return float(current_step) / float(max(1, num_warmup_steps))
|
||||
progress = float(current_step - num_warmup_steps) / float(
|
||||
max(1, num_training_steps - num_warmup_steps)
|
||||
)
|
||||
factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
|
||||
factor = factor * (1 - min_lr_rate) + min_lr_rate
|
||||
return max(0, factor)
|
||||
|
||||
|
||||
def get_cosine_schedule_with_warmup(
|
||||
optimizer: Optimizer,
|
||||
num_warmup_steps: int,
|
||||
num_training_steps: int,
|
||||
num_cycles: float = 0.5,
|
||||
last_epoch: int = -1,
|
||||
):
|
||||
"""
|
||||
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
||||
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
|
||||
initial lr set in the optimizer.
|
||||
|
||||
Args:
|
||||
optimizer ([`~torch.optim.Optimizer`]):
|
||||
The optimizer for which to schedule the learning rate.
|
||||
num_warmup_steps (`int`):
|
||||
The number of steps for the warmup phase.
|
||||
num_training_steps (`int`):
|
||||
The total number of training steps.
|
||||
num_cycles (`float`, *optional*, defaults to 0.5):
|
||||
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
|
||||
following a half-cosine).
|
||||
last_epoch (`int`, *optional*, defaults to -1):
|
||||
The index of the last epoch when resuming training.
|
||||
|
||||
Return:
|
||||
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
||||
"""
|
||||
|
||||
lr_lambda = partial(
|
||||
_get_cosine_schedule_with_warmup_lr_lambda,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=num_training_steps,
|
||||
num_cycles=num_cycles,
|
||||
)
|
||||
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
||||
|
||||
|
||||
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
|
||||
"""
|
||||
Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor.
|
||||
clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
|
||||
|
||||
Returns:
|
||||
Tensor: Element-wise logarithm of the input tensor with clipping applied.
|
||||
"""
|
||||
return torch.log(torch.clip(x, min=clip_val))
|
287
egs/ljspeech/TTS/local/evaluate_fsd.py
Normal file
287
egs/ljspeech/TTS/local/evaluate_fsd.py
Normal file
@ -0,0 +1,287 @@
|
||||
"""
|
||||
Calculate Frechet Speech Distance betweeen two speech directories.
|
||||
Adapted from: https://github.com/gudgud96/frechet-audio-distance/blob/main/frechet_audio_distance/fad.py
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from multiprocessing.dummy import Pool as ThreadPool
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from scipy import linalg
|
||||
from tqdm import tqdm
|
||||
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--real-path", type=str, help="path of the real speech directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval-path", type=str, help="path of the evaluated speech directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default="model/huggingface/wav2vec2_base",
|
||||
help="path of the wav2vec 2.0 model directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--real-embds-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path of the real embedding directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval-embds-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path of the evaluated embedding directory",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
class FrechetSpeechDistance:
|
||||
def __init__(
|
||||
self,
|
||||
model_path="resources/wav2vec2_base",
|
||||
pca_dim=128,
|
||||
speech_load_worker=8,
|
||||
):
|
||||
"""
|
||||
Initialize FSD
|
||||
"""
|
||||
self.sample_rate = 16000
|
||||
self.channels = 1
|
||||
self.device = (
|
||||
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
)
|
||||
logging.info("[Frechet Speech Distance] Using device: {}".format(self.device))
|
||||
self.speech_load_worker = speech_load_worker
|
||||
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_path)
|
||||
self.model = Wav2Vec2Model.from_pretrained(model_path)
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
self.pca_dim = pca_dim
|
||||
|
||||
def load_speech_files(self, dir, dtype="float32"):
|
||||
def _load_speech_task(fname, sample_rate, channels, dtype="float32"):
|
||||
if dtype not in ["float64", "float32", "int32", "int16"]:
|
||||
raise ValueError(f"dtype not supported: {dtype}")
|
||||
|
||||
wav_data, sr = sf.read(fname, dtype=dtype)
|
||||
# For integer type PCM input, convert to [-1.0, +1.0]
|
||||
if dtype == "int16":
|
||||
wav_data = wav_data / 32768.0
|
||||
elif dtype == "int32":
|
||||
wav_data = wav_data / float(2**31)
|
||||
|
||||
# Convert to mono
|
||||
assert channels in [1, 2], "channels must be 1 or 2"
|
||||
if len(wav_data.shape) > channels:
|
||||
wav_data = np.mean(wav_data, axis=1)
|
||||
|
||||
if sr != sample_rate:
|
||||
wav_data = (
|
||||
librosa.resample(wav_data, orig_sr=sr, target_sr=sample_rate),
|
||||
)
|
||||
|
||||
return wav_data
|
||||
|
||||
task_results = []
|
||||
|
||||
pool = ThreadPool(self.speech_load_worker)
|
||||
|
||||
logging.info("[Frechet Speech Distance] Loading speech from {}...".format(dir))
|
||||
for fname in os.listdir(dir):
|
||||
res = pool.apply_async(
|
||||
_load_speech_task,
|
||||
args=(os.path.join(dir, fname), self.sample_rate, self.channels, dtype),
|
||||
)
|
||||
task_results.append(res)
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
return [k.get() for k in task_results]
|
||||
|
||||
def get_embeddings(self, x):
|
||||
"""
|
||||
Get embeddings
|
||||
Params:
|
||||
-- x : a list of np.ndarray speech samples
|
||||
-- sr : sampling rate.
|
||||
"""
|
||||
embd_lst = []
|
||||
try:
|
||||
for speech in tqdm(x):
|
||||
input_features = self.feature_extractor(
|
||||
speech, sampling_rate=self.sample_rate, return_tensors="pt"
|
||||
).input_values.to(self.device)
|
||||
with torch.no_grad():
|
||||
embd = self.model(input_features).last_hidden_state.mean(1)
|
||||
|
||||
if embd.device != torch.device("cpu"):
|
||||
embd = embd.cpu()
|
||||
|
||||
if torch.is_tensor(embd):
|
||||
embd = embd.detach().numpy()
|
||||
|
||||
embd_lst.append(embd)
|
||||
except Exception as e:
|
||||
print(
|
||||
"[Frechet Speech Distance] get_embeddings throw an exception: {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
|
||||
return np.concatenate(embd_lst, axis=0)
|
||||
|
||||
def calculate_embd_statistics(self, embd_lst):
|
||||
if isinstance(embd_lst, list):
|
||||
embd_lst = np.array(embd_lst)
|
||||
mu = np.mean(embd_lst, axis=0)
|
||||
sigma = np.cov(embd_lst, rowvar=False)
|
||||
return mu, sigma
|
||||
|
||||
def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
|
||||
"""
|
||||
Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
|
||||
|
||||
Numpy implementation of the Frechet Distance.
|
||||
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
||||
and X_2 ~ N(mu_2, C_2) is
|
||||
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
||||
Stable version by Dougal J. Sutherland.
|
||||
Params:
|
||||
-- mu1 : Numpy array containing the activations of a layer of the
|
||||
inception net (like returned by the function 'get_predictions')
|
||||
for generated samples.
|
||||
-- mu2 : The sample mean over activations, precalculated on an
|
||||
representative data set.
|
||||
-- sigma1: The covariance matrix over activations for generated samples.
|
||||
-- sigma2: The covariance matrix over activations, precalculated on an
|
||||
representative data set.
|
||||
Returns:
|
||||
-- : The Frechet Distance.
|
||||
"""
|
||||
|
||||
mu1 = np.atleast_1d(mu1)
|
||||
mu2 = np.atleast_1d(mu2)
|
||||
|
||||
sigma1 = np.atleast_2d(sigma1)
|
||||
sigma2 = np.atleast_2d(sigma2)
|
||||
|
||||
assert (
|
||||
mu1.shape == mu2.shape
|
||||
), "Training and test mean vectors have different lengths"
|
||||
assert (
|
||||
sigma1.shape == sigma2.shape
|
||||
), "Training and test covariances have different dimensions"
|
||||
|
||||
diff = mu1 - mu2
|
||||
|
||||
# Product might be almost singular
|
||||
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2).astype(complex), disp=False)
|
||||
if not np.isfinite(covmean).all():
|
||||
msg = (
|
||||
"fid calculation produces singular product; "
|
||||
"adding %s to diagonal of cov estimates"
|
||||
) % eps
|
||||
logging.info(msg)
|
||||
offset = np.eye(sigma1.shape[0]) * eps
|
||||
covmean = linalg.sqrtm(
|
||||
(sigma1 + offset).dot(sigma2 + offset).astype(complex)
|
||||
)
|
||||
|
||||
# Numerical error might give slight imaginary component
|
||||
if np.iscomplexobj(covmean):
|
||||
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
||||
m = np.max(np.abs(covmean.imag))
|
||||
raise ValueError("Imaginary component {}".format(m))
|
||||
covmean = covmean.real
|
||||
|
||||
tr_covmean = np.trace(covmean)
|
||||
|
||||
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
|
||||
|
||||
def score(
|
||||
self,
|
||||
real_path,
|
||||
eval_path,
|
||||
real_embds_path=None,
|
||||
eval_embds_path=None,
|
||||
dtype="float32",
|
||||
):
|
||||
"""
|
||||
Computes the Frechet Speech Distance (FSD) between two directories of speech files.
|
||||
|
||||
Parameters:
|
||||
- real_path (str): Path to the directory containing real speech files.
|
||||
- eval_path (str): Path to the directory containing evaluation speech files.
|
||||
- real_embds_path (str, optional): Path to save/load real speech embeddings (e.g., /folder/bkg_embs.npy). If None, embeddings won't be saved.
|
||||
- eval_embds_path (str, optional): Path to save/load evaluation speech embeddings (e.g., /folder/test_embs.npy). If None, embeddings won't be saved.
|
||||
- dtype (str, optional): Data type for loading speech. Default is "float32".
|
||||
|
||||
Returns:
|
||||
- float: The Frechet Speech Distance (FSD) score between the two directories of speech files.
|
||||
"""
|
||||
# Load or compute real embeddings
|
||||
if real_embds_path is not None and os.path.exists(real_embds_path):
|
||||
logging.info(
|
||||
f"[Frechet Speech Distance] Loading embeddings from {real_embds_path}..."
|
||||
)
|
||||
embds_real = np.load(real_embds_path)
|
||||
else:
|
||||
speech_real = self.load_speech_files(real_path, dtype=dtype)
|
||||
embds_real = self.get_embeddings(speech_real)
|
||||
if real_embds_path:
|
||||
os.makedirs(os.path.dirname(real_embds_path), exist_ok=True)
|
||||
np.save(real_embds_path, embds_real)
|
||||
|
||||
# Load or compute eval embeddings
|
||||
if eval_embds_path is not None and os.path.exists(eval_embds_path):
|
||||
logging.info(
|
||||
f"[Frechet Speech Distance] Loading embeddings from {eval_embds_path}..."
|
||||
)
|
||||
embds_eval = np.load(eval_embds_path)
|
||||
else:
|
||||
speech_eval = self.load_speech_files(eval_path, dtype=dtype)
|
||||
embds_eval = self.get_embeddings(speech_eval)
|
||||
if eval_embds_path:
|
||||
os.makedirs(os.path.dirname(eval_embds_path), exist_ok=True)
|
||||
np.save(eval_embds_path, embds_eval)
|
||||
|
||||
# Check if embeddings are empty
|
||||
if len(embds_real) == 0:
|
||||
logging.info("[Frechet Speech Distance] real set dir is empty, exiting...")
|
||||
return -10.46
|
||||
if len(embds_eval) == 0:
|
||||
logging.info("[Frechet Speech Distance] eval set dir is empty, exiting...")
|
||||
return -1
|
||||
|
||||
# Compute statistics and FSD score
|
||||
mu_real, sigma_real = self.calculate_embd_statistics(embds_real)
|
||||
mu_eval, sigma_eval = self.calculate_embd_statistics(embds_eval)
|
||||
|
||||
fsd_score = self.calculate_frechet_distance(
|
||||
mu_real, sigma_real, mu_eval, sigma_eval
|
||||
)
|
||||
|
||||
return fsd_score
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
FSD = FrechetSpeechDistance(model_path=args.model_path)
|
||||
score = FSD.score(
|
||||
args.real_path, args.eval_path, args.real_embds_path, args.eval_embds_path
|
||||
)
|
||||
logging.info(f"FSD score: {score:.2f}")
|
139
egs/ljspeech/TTS/local/evaluate_wer_whisper.py
Normal file
139
egs/ljspeech/TTS/local/evaluate_wer_whisper.py
Normal file
@ -0,0 +1,139 @@
|
||||
"""
|
||||
Calculate WER with Whisper model
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from num2words import num2words
|
||||
from tqdm import tqdm
|
||||
from transformers import pipeline
|
||||
|
||||
from icefall.utils import store_transcripts, write_error_stats
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--wav-path", type=str, help="path of the speech directory")
|
||||
parser.add_argument("--decode-path", type=str, help="path of the speech directory")
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default="model/huggingface/whisper_medium",
|
||||
help="path of the huggingface whisper model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--transcript-path",
|
||||
type=str,
|
||||
default="data/transcript/test.tsv",
|
||||
help="path of the transcript tsv file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size", type=int, default=64, help="decoding batch size"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cuda:0", help="decoding device, cuda:0 or cpu"
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def post_process(text: str):
|
||||
def convert_numbers(match):
|
||||
return num2words(match.group())
|
||||
|
||||
text = re.sub(r"\b\d{1,2}\b", convert_numbers, text)
|
||||
text = re.sub(r"[^a-zA-Z0-9']", " ", text.lower())
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
return text
|
||||
|
||||
|
||||
def save_results(
|
||||
res_dir: str,
|
||||
results: List[Tuple[str, List[str], List[str]]],
|
||||
):
|
||||
if not os.path.exists(res_dir):
|
||||
os.makedirs(res_dir)
|
||||
recog_path = os.path.join(res_dir, "recogs.txt")
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
errs_filename = os.path.join(res_dir, "errs.txt")
|
||||
with open(errs_filename, "w") as f:
|
||||
_ = write_error_stats(f, "test", results, enable_log=True)
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
|
||||
class SpeechEvalDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, wav_path: str, transcript_path: str):
|
||||
super().__init__()
|
||||
self.audio_name = []
|
||||
self.audio_paths = []
|
||||
self.transcripts = []
|
||||
with Path(transcript_path).open("r", encoding="utf8") as f:
|
||||
meta = [item.split("\t") for item in f.read().rstrip().split("\n")]
|
||||
for item in meta:
|
||||
self.audio_name.append(item[0])
|
||||
self.audio_paths.append(Path(wav_path, item[0] + ".wav"))
|
||||
self.transcripts.append(item[1])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audio_paths)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
audio, sampling_rate = sf.read(self.audio_paths[index])
|
||||
item = {
|
||||
"array": librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000),
|
||||
"sampling_rate": 16000,
|
||||
"reference": self.transcripts[index],
|
||||
"audio_name": self.audio_name[index],
|
||||
}
|
||||
return item
|
||||
|
||||
|
||||
def main(args):
|
||||
|
||||
batch_size = args.batch_size
|
||||
|
||||
pipe = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
model=args.model_path,
|
||||
device=args.device,
|
||||
tokenizer=args.model_path,
|
||||
)
|
||||
|
||||
dataset = SpeechEvalDataset(args.wav_path, args.transcript_path)
|
||||
|
||||
results = []
|
||||
bar = tqdm(
|
||||
pipe(
|
||||
dataset,
|
||||
generate_kwargs={"language": "english", "task": "transcribe"},
|
||||
batch_size=batch_size,
|
||||
),
|
||||
total=len(dataset),
|
||||
)
|
||||
for out in bar:
|
||||
results.append(
|
||||
(
|
||||
out["audio_name"][0],
|
||||
post_process(out["reference"][0].strip()).split(),
|
||||
post_process(out["text"].strip()).split(),
|
||||
)
|
||||
)
|
||||
save_results(args.decode_path, results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
main(args)
|
1
egs/ljspeech/TTS/vocos/discriminators.py
Symbolic link
1
egs/ljspeech/TTS/vocos/discriminators.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libritts/TTS/vocos/discriminators.py
|
1
egs/ljspeech/TTS/vocos/export-onnx.py
Symbolic link
1
egs/ljspeech/TTS/vocos/export-onnx.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libritts/TTS/vocos/export-onnx.py
|
1
egs/ljspeech/TTS/vocos/export.py
Symbolic link
1
egs/ljspeech/TTS/vocos/export.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libritts/TTS/vocos/export.py
|
1
egs/ljspeech/TTS/vocos/generator.py
Symbolic link
1
egs/ljspeech/TTS/vocos/generator.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libritts/TTS/vocos/generator.py
|
340
egs/ljspeech/TTS/vocos/infer.py
Executable file
340
egs/ljspeech/TTS/vocos/infer.py
Executable file
@ -0,0 +1,340 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
|
||||
# Han Zhu)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from lhotse.utils import fix_random_seed
|
||||
from scipy.io.wavfile import write
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
from tts_datamodule import LJSpeechTtsDataModule
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="flow_match/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--generate-dir",
|
||||
type=str,
|
||||
default="generated_wavs",
|
||||
help="Path name of the generated wavs",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
batch: dict,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The text-to-feature neural model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
|
||||
cut_ids = [cut.id for cut in batch["cut"]]
|
||||
|
||||
features = batch["features"] # (B, T, F)
|
||||
utt_durations = batch["features_lens"]
|
||||
|
||||
x = features.permute(0, 2, 1) # (B, F, T)
|
||||
|
||||
audios = model(x.to(device)) # (B, T)
|
||||
|
||||
wav_dir = f"{params.res_dir}/{params.suffix}"
|
||||
os.makedirs(wav_dir, exist_ok=True)
|
||||
|
||||
for i in range(audios.shape[0]):
|
||||
audio = audios[i][: (utt_durations[i] - 1) * 256 + 1024]
|
||||
audio = audio.cpu().squeeze().numpy()
|
||||
write(f"{wav_dir}/{cut_ids[i]}.wav", 22050, audio)
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
test_set: str,
|
||||
):
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The text-to-feature neural model.
|
||||
test_set:
|
||||
The name of the test_set
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
with open(f"{params.res_dir}/{test_set}.scp", "w", encoding="utf8") as f:
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["text"]
|
||||
cut_ids = [cut.id for cut in batch["cut"]]
|
||||
|
||||
decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
assert len(texts) == len(cut_ids), (len(texts), len(cut_ids))
|
||||
|
||||
for i in range(len(texts)):
|
||||
f.write(f"{cut_ids[i]}\t{texts[i]}\n")
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % 50 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LJSpeechTtsDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
params.res_dir = params.exp_dir / params.generate_dir
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
params.device = device
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
logging.info(params)
|
||||
fix_random_seed(666)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
ljspeech = LJSpeechTtsDataModule(args)
|
||||
|
||||
test_cuts = ljspeech.test_cuts()
|
||||
|
||||
test_dl = ljspeech.test_dataloaders(test_cuts)
|
||||
|
||||
test_sets = ["test"]
|
||||
test_dls = [test_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
test_set=test_set,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/ljspeech/TTS/vocos/loss.py
Symbolic link
1
egs/ljspeech/TTS/vocos/loss.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libritts/TTS/vocos/loss.py
|
1
egs/ljspeech/TTS/vocos/model.py
Symbolic link
1
egs/ljspeech/TTS/vocos/model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libritts/TTS/vocos/model.py
|
1
egs/ljspeech/TTS/vocos/onnx_pretrained.py
Symbolic link
1
egs/ljspeech/TTS/vocos/onnx_pretrained.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libritts/TTS/vocos/onnx_pretrained.py
|
1
egs/ljspeech/TTS/vocos/pretrained.py
Symbolic link
1
egs/ljspeech/TTS/vocos/pretrained.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libritts/TTS/vocos/pretrained.py
|
1054
egs/ljspeech/TTS/vocos/train.py
Executable file
1054
egs/ljspeech/TTS/vocos/train.py
Executable file
File diff suppressed because it is too large
Load Diff
372
egs/ljspeech/TTS/vocos/tts_datamodule.py
Normal file
372
egs/ljspeech/TTS/vocos/tts_datamodule.py
Normal file
@ -0,0 +1,372 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
|
||||
# Zengwei Yao,
|
||||
# Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
PrecomputedFeatures,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
SpeechSynthesisDataset,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
AudioSamples,
|
||||
OnTheFlyFeatures,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class LJSpeechTtsDataModule:
|
||||
"""
|
||||
DataModule for tts experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="TTS data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--sampling-rate",
|
||||
type=int,
|
||||
default=22050,
|
||||
help="The sampleing rate of ljspeech dataset",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--frame-shift",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--frame-length",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--input-strategy",
|
||||
type=str,
|
||||
default="PrecomputedFeatures",
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--use-fft-mag",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to use magnitude of fbank, false to use power energy.",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
logging.info("About to create train dataset")
|
||||
train = SpeechSynthesisDataset(
|
||||
return_text=True,
|
||||
return_tokens=False,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = self.args.sampling_rate
|
||||
config = FbankConfig(
|
||||
sampling_rate=sampling_rate,
|
||||
frame_length=self.args.frame_length / sampling_rate, # (in second),
|
||||
frame_shift=self.args.frame_shift / sampling_rate, # (in second)
|
||||
use_fft_mag=self.args.use_fft_mag,
|
||||
)
|
||||
train = SpeechSynthesisDataset(
|
||||
return_text=True,
|
||||
return_tokens=False,
|
||||
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
buffer_size=self.args.num_buckets * 2000,
|
||||
shuffle_buffer_size=self.args.num_buckets * 5000,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = self.args.sampling_rate
|
||||
config = FbankConfig(
|
||||
sampling_rate=sampling_rate,
|
||||
frame_length=self.args.frame_length / sampling_rate, # (in second),
|
||||
frame_shift=self.args.frame_shift / sampling_rate, # (in second)
|
||||
use_fft_mag=self.args.use_fft_mag,
|
||||
)
|
||||
validate = SpeechSynthesisDataset(
|
||||
return_text=True,
|
||||
return_tokens=False,
|
||||
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = SpeechSynthesisDataset(
|
||||
return_text=True,
|
||||
return_tokens=False,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
num_buckets=self.args.num_buckets,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create valid dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.info("About to create test dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = self.args.sampling_rate
|
||||
config = FbankConfig(
|
||||
sampling_rate=sampling_rate,
|
||||
frame_length=self.args.frame_length / sampling_rate, # (in second),
|
||||
frame_shift=self.args.frame_shift / sampling_rate, # (in second)
|
||||
use_fft_mag=self.args.use_fft_mag,
|
||||
)
|
||||
test = SpeechSynthesisDataset(
|
||||
return_text=True,
|
||||
return_tokens=False,
|
||||
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
test = SpeechSynthesisDataset(
|
||||
return_text=True,
|
||||
return_tokens=False,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
test_sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
num_buckets=self.args.num_buckets,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=test_sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def valid_cuts(self) -> CutSet:
|
||||
logging.info("About to get validation cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts_finetune(self) -> CutSet:
|
||||
logging.info("About to get train cuts finetune")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "ljspeech_cuts_train_finetune.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def valid_cuts_finetune(self) -> CutSet:
|
||||
logging.info("About to get validation cuts finetune")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "ljspeech_cuts_valid_finetune.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
|
||||
)
|
1
egs/ljspeech/TTS/vocos/utils.py
Symbolic link
1
egs/ljspeech/TTS/vocos/utils.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libritts/TTS/vocos/utils.py
|
@ -251,18 +251,22 @@ def save_checkpoint_with_global_batch_idx(
|
||||
)
|
||||
|
||||
|
||||
def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
|
||||
def find_checkpoints(
|
||||
out_dir: Path,
|
||||
iteration: int = 0,
|
||||
prefix: str = "checkpoint",
|
||||
) -> List[str]:
|
||||
"""Find all available checkpoints in a directory.
|
||||
|
||||
The checkpoint filenames have the form: `checkpoint-xxx.pt`
|
||||
The checkpoint filenames have the form: `{prefix}-xxx.pt`
|
||||
where xxx is a numerical value.
|
||||
|
||||
Assume you have the following checkpoints in the folder `foo`:
|
||||
|
||||
- checkpoint-1.pt
|
||||
- checkpoint-20.pt
|
||||
- checkpoint-300.pt
|
||||
- checkpoint-4000.pt
|
||||
- {prefix}-1.pt
|
||||
- {prefix}-20.pt
|
||||
- {prefix}-300.pt
|
||||
- {prefix}-4000.pt
|
||||
|
||||
Case 1 (Return all checkpoints)::
|
||||
|
||||
@ -291,8 +295,8 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
|
||||
Return a list of checkpoint filenames, sorted in descending
|
||||
order by the numerical value in the filename.
|
||||
"""
|
||||
checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
|
||||
pattern = re.compile(r"checkpoint-([0-9]+).pt")
|
||||
checkpoints = list(glob.glob(f"{out_dir}/{prefix}-[0-9]*.pt"))
|
||||
pattern = re.compile(rf"{prefix}-([0-9]+).pt")
|
||||
iter_checkpoints = []
|
||||
for c in checkpoints:
|
||||
result = pattern.search(c)
|
||||
@ -317,12 +321,13 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
|
||||
def remove_checkpoints(
|
||||
out_dir: Path,
|
||||
topk: int,
|
||||
prefix: str = "checkpoint",
|
||||
rank: int = 0,
|
||||
):
|
||||
"""Remove checkpoints from the given directory.
|
||||
|
||||
We assume that checkpoint filename has the form `checkpoint-xxx.pt`
|
||||
where xxx is a number, representing the number of processed batches
|
||||
We assume that checkpoint filename has the form `{prefix}-xxx.pt`
|
||||
where xxx is a number, representing the number of processed batches/epochs
|
||||
when saving that checkpoint. We sort checkpoints by filename and keep
|
||||
only the `topk` checkpoints with the highest `xxx`.
|
||||
|
||||
@ -331,6 +336,8 @@ def remove_checkpoints(
|
||||
The directory containing checkpoints to be removed.
|
||||
topk:
|
||||
Number of checkpoints to keep.
|
||||
prefix:
|
||||
The prefix of the checkpoint filename, normally `epoch`, `checkpoint`.
|
||||
rank:
|
||||
If using DDP for training, it is the rank of the current node.
|
||||
Use 0 if no DDP is used for training.
|
||||
@ -338,7 +345,7 @@ def remove_checkpoints(
|
||||
assert topk >= 1, topk
|
||||
if rank != 0:
|
||||
return
|
||||
checkpoints = find_checkpoints(out_dir)
|
||||
checkpoints = find_checkpoints(out_dir, prefix=prefix)
|
||||
|
||||
if len(checkpoints) == 0:
|
||||
logging.warn(f"No checkpoints found in {out_dir}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user