Merge d8a0a4095554e58db0fc8f4e30a6a33932ab37dd into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9

Wei Kang 2025-07-25 09:16:26 +02:00 committed by GitHub
commit cfbd4208cc
27 changed files with 6309 additions and 11 deletions

egs/libritts/TTS/vocos/discriminators.py

@@ -0,0 +1,296 @@
from typing import List, Optional, Tuple
import torch
from torch import nn
from torch.nn import Conv2d
from torch.nn.utils.parametrizations import weight_norm
from torchaudio.transforms import Spectrogram
class MultiPeriodDiscriminator(nn.Module):
"""
Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan.
Additionally, it allows incorporating conditional information with a learned embeddings table.
Args:
periods (tuple[int]): Tuple of periods for each discriminator.
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
Defaults to None.
"""
def __init__(
self,
periods: Tuple[int, ...] = (2, 3, 5, 7, 11),
num_embeddings: Optional[int] = None,
):
super().__init__()
self.discriminators = nn.ModuleList(
[DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods]
)
def forward(
self,
y: torch.Tensor,
y_hat: torch.Tensor,
bandwidth_id: Optional[torch.Tensor] = None,
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for d in self.discriminators:
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorP(nn.Module):
def __init__(
self,
period: int,
in_channels: int = 1,
kernel_size: int = 5,
stride: int = 3,
lrelu_slope: float = 0.1,
num_embeddings: Optional[int] = None,
):
super().__init__()
self.period = period
self.convs = nn.ModuleList(
[
weight_norm(
Conv2d(
in_channels,
32,
(kernel_size, 1),
(stride, 1),
padding=(kernel_size // 2, 0),
)
),
weight_norm(
Conv2d(
32,
128,
(kernel_size, 1),
(stride, 1),
padding=(kernel_size // 2, 0),
)
),
weight_norm(
Conv2d(
128,
512,
(kernel_size, 1),
(stride, 1),
padding=(kernel_size // 2, 0),
)
),
weight_norm(
Conv2d(
512,
1024,
(kernel_size, 1),
(stride, 1),
padding=(kernel_size // 2, 0),
)
),
weight_norm(
Conv2d(
1024,
1024,
(kernel_size, 1),
(1, 1),
padding=(kernel_size // 2, 0),
)
),
]
)
if num_embeddings is not None:
self.emb = torch.nn.Embedding(
num_embeddings=num_embeddings, embedding_dim=1024
)
torch.nn.init.zeros_(self.emb.weight)
self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
self.lrelu_slope = lrelu_slope
def forward(
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
x = x.unsqueeze(1)
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for i, l in enumerate(self.convs):
x = l(x)
x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
if i > 0:
fmap.append(x)
if cond_embedding_id is not None:
emb = self.emb(cond_embedding_id)
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
else:
h = 0
x = self.conv_post(x)
fmap.append(x)
x += h
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiResolutionDiscriminator(nn.Module):
def __init__(
self,
fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
num_embeddings: Optional[int] = None,
):
"""
Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
Additionally, it allows incorporating conditional information with a learned embeddings table.
Args:
fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
Defaults to None.
"""
super().__init__()
self.discriminators = nn.ModuleList(
[
DiscriminatorR(window_length=w, num_embeddings=num_embeddings)
for w in fft_sizes
]
)
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for d in self.discriminators:
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorR(nn.Module):
def __init__(
self,
window_length: int,
num_embeddings: Optional[int] = None,
channels: int = 32,
hop_factor: float = 0.25,
bands: Tuple[Tuple[float, float], ...] = (
(0.0, 0.1),
(0.1, 0.25),
(0.25, 0.5),
(0.5, 0.75),
(0.75, 1.0),
),
):
super().__init__()
self.window_length = window_length
self.hop_factor = hop_factor
self.spec_fn = Spectrogram(
n_fft=window_length,
hop_length=int(window_length * hop_factor),
win_length=window_length,
power=None,
)
n_fft = window_length // 2 + 1
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
self.bands = bands
convs = lambda: nn.ModuleList(
[
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
weight_norm(
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
),
weight_norm(
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
),
weight_norm(
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
),
weight_norm(
nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
),
]
)
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
if num_embeddings is not None:
self.emb = torch.nn.Embedding(
num_embeddings=num_embeddings, embedding_dim=channels
)
torch.nn.init.zeros_(self.emb.weight)
self.conv_post = weight_norm(
nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
)
def spectrogram(self, x):
# Remove DC offset
x = x - x.mean(dim=-1, keepdims=True)
# Peak normalize the volume of input audio
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
x = self.spec_fn(x)
x = torch.view_as_real(x)
# x = rearrange(x, "b f t c -> b c t f")
x = x.permute(0, 3, 2, 1)
# Split into bands
x_bands = [x[..., b[0] : b[1]] for b in self.bands]
return x_bands
def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None):
x_bands = self.spectrogram(x)
fmap = []
x = []
for band, stack in zip(x_bands, self.band_convs):
for i, layer in enumerate(stack):
band = layer(band)
band = torch.nn.functional.leaky_relu(band, 0.1)
if i > 0:
fmap.append(band)
x.append(band)
x = torch.cat(x, dim=-1)
if cond_embedding_id is not None:
emb = self.emb(cond_embedding_id)
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
else:
h = 0
x = self.conv_post(x)
fmap.append(x)
x += h
return x, fmap
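A minimal usage sketch (not part of this diff) of the two discriminators above on a batch of real and generated waveforms; the (batch, samples) shapes and the `discriminators` module name are assumptions taken from the code and from the imports in the model file later in this diff:

import torch

from discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator

mpd = MultiPeriodDiscriminator()
mrd = MultiResolutionDiscriminator()

y = torch.randn(2, 24000)      # real waveforms, shape (batch, samples)
y_hat = torch.randn(2, 24000)  # generated waveforms, same shape

# Each call returns per-sub-discriminator logits for the real and generated
# inputs, plus the corresponding intermediate feature maps.
real_logits, fake_logits, fmap_real, fmap_fake = mpd(y, y_hat)
print(len(real_logits))  # 5 sub-discriminators, one per period (2, 3, 5, 7, 11)

real_logits, fake_logits, fmap_real, fmap_fake = mrd(y, y_hat)
print(len(real_logits))  # 3 sub-discriminators, one per FFT size (2048, 1024, 512)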

egs/libritts/TTS/vocos/export-onnx.py

@@ -0,0 +1,371 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script exports the Vocos vocoder generator from PyTorch to ONNX.
Example:
./vocos/export-onnx.py \
--exp-dir ./vocos/exp \
--epoch 100 \
--avg 1 \
--fp16 True
It will generate the following files inside the given exp-dir:
- vocos-epoch-100-avg-1.onnx
- vocos-epoch-100-avg-1.fp16.onnx (only generated if --fp16 is True)
- vocos-epoch-100-avg-1.int8.onnx (only generated if --fp16 is True)
See ./onnx_pretrained.py for how to use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple
import onnx
import torch
import torch.nn as nn
from onnxconverter_common import float16
from onnxruntime.quantization import QuantType, quantize_dynamic
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, num_tokens, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--sampling-rate",
type=int,
default=24000,
help="The sampleing rate of libritts dataset",
)
parser.add_argument(
"--frame-shift",
type=int,
default=256,
help="Frame shift.",
)
parser.add_argument(
"--frame-length",
type=int,
default=1024,
help="Frame shift.",
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--fp16",
type=str2bool,
default=False,
help="Whether to export models in fp16",
)
add_model_arguments(parser)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
def export_model_onnx(
model: nn.Module,
model_filename: str,
opset_version: int = 13,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- encoder_out: a tensor of shape (N, joiner_dim)
- decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
"""
input_tensor = torch.rand((2, 80, 100), dtype=torch.float32)
torch.onnx.export(
model,
(input_tensor,),
model_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"features",
],
output_names=["audio"],
dynamic_axes={
"features": {0: "N", 2: "F"},
"audio": {0: "N", 1: "T"},
},
)
meta_data = {
"model_type": "Vocos",
"version": "1",
"model_author": "k2-fsa",
"comment": "ConvNext Vocos",
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=model_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
params.device = device
logging.info(params)
logging.info("About to create model")
model = get_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.eval()
vocos = model.generator
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
opset_version = 13
logging.info("Exporting model")
model_filename = params.exp_dir / f"vocos-{suffix}.onnx"
export_model_onnx(
vocos,
model_filename,
opset_version=opset_version,
)
logging.info(f"Exported vocos generator to {model_filename}")
if params.fp16:
logging.info("Generate fp16 models")
model = onnx.load(model_filename)
model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
model_filename_fp16 = params.exp_dir / f"vocos-{suffix}.fp16.onnx"
onnx.save(model_fp16, model_filename_fp16)
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
model_filename_int8 = params.exp_dir / f"vocos-{suffix}.int8.onnx"
quantize_dynamic(
model_input=model_filename,
model_output=model_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
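A short sketch (not part of this diff) of running the exported model with onnxruntime; the file name and feature shape below are assumptions derived from the export code above. See onnx_pretrained.py later in this diff for the complete script.

import numpy as np
import onnxruntime as ort

# Hypothetical path: the script above writes vocos-{suffix}.onnx into --exp-dir.
session = ort.InferenceSession(
    "vocos/exp/vocos-epoch-100-avg-1.onnx",
    providers=["CPUExecutionProvider"],
)

# The input "features" has shape (N, feature_dim, num_frames); feature_dim is 80,
# matching the dummy tensor used during export.
features = np.random.randn(1, 80, 100).astype(np.float32)
(audio,) = session.run(["audio"], {"features": features})
print(audio.shape)  # (1, num_samples)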

egs/libritts/TTS/vocos/export.py Executable file

@@ -0,0 +1,407 @@
#!/usr/bin/env python3
#
# Copyright 2024 Xiaomi Corporation (Author: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
Note: This is an example for the libritts dataset; if you are using a different
dataset, you should change the argument values according to your dataset.
(1) Export to torchscript model using torch.jit.script()
./vocos/export.py \
--exp-dir ./vocos/exp \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("jit_script.pt")`.
Check ./jit_pretrained.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
./vocos/export.py \
--exp-dir ./vocos/exp \
--epoch 30 \
--avg 9
It will generate a file `generator.pt` in the given `exp_dir`. You can later
load it with `torch.load()` and pass it to ./pretrained.py via `--checkpoint`.
Check ./pretrained.py for its usage.
"""
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import torch
from torch import Tensor, nn
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
)
from icefall.utils import str2bool
from utils import load_checkpoint
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--sampling-rate",
type=int,
default=24000,
help="The sampleing rate of libritts dataset",
)
parser.add_argument(
"--frame-shift",
type=int,
default=256,
help="Frame shift.",
)
parser.add_argument(
"--frame-length",
type=int,
default=1024,
help="Frame shift.",
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="vocos/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named jit_script.pt.
Check ./jit_pretrained.py for how to use it.
""",
)
add_model_arguments(parser)
return parser
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
params.device = device
logging.info(f"device: {device}")
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.eval()
model = model.generator
if params.jit is True:
filename = "jit_script.pt"
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
model.save(str(params.exp_dir / filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "generator.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

egs/libritts/TTS/vocos/generator.py

@@ -0,0 +1,264 @@
import logging
from typing import Optional
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F
def window_sumsquare(
window: torch.Tensor,
n_samples: int,
hop_length: int = 256,
win_length: int = 1024,
):
"""
Compute the sum-square envelope of a window function at a given hop length.
This is used to estimate modulation effects induced by windowing
observations in short-time Fourier transforms.
Parameters
----------
window : torch.Tensor
The window function as a 1-D tensor.
n_samples : int > 0
The number of expected samples.
hop_length : int > 0
The number of samples to advance between frames.
win_length : int > 0
The length of the window function.
Returns
-------
wss : torch.Tensor, the sum-squared envelope of the window function.
"""
n_frames = (n_samples - win_length) // hop_length + 1
output_size = (n_frames - 1) * hop_length + win_length
device = window.device
# Window envelope
window_sq = window.square().expand(1, n_frames, -1).transpose(1, 2)
window_envelope = torch.nn.functional.fold(
window_sq,
output_size=(1, output_size),
kernel_size=(1, win_length),
stride=(1, hop_length),
).squeeze()
window_envelope = torch.nn.functional.pad(
window_envelope, (0, n_samples - output_size)
)
return window_envelope
class ISTFT(torch.nn.Module):
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
def __init__(
self,
filter_length: int = 1024,
hop_length: int = 256,
win_length: int = 1024,
padding: str = "none",
window_type: str = "povey",
max_samples: int = 1440000, # 1440000 / 24000 = 60s
):
super(ISTFT, self).__init__()
self.filter_length = filter_length
self.hop_length = hop_length
self.win_length = win_length
self.padding = padding
scale = self.filter_length / self.hop_length
fourier_basis = np.fft.fft(np.eye(self.filter_length))
cutoff = int((self.filter_length / 2 + 1))
fourier_basis = np.vstack(
[np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
)
inverse_basis = torch.FloatTensor(
np.linalg.pinv(scale * fourier_basis).T[:, None, :]
)
assert filter_length >= win_length
# Consistent with lhotse; search "create_frame_window" in https://github.com/lhotse-speech/lhotse
assert window_type in [
"hanning",
"povey",
], f"Only 'hanning' and 'povey' windows are supported, given {window_type}."
fft_window = torch.hann_window(win_length, periodic=False)
if window_type == "povey":
fft_window = fft_window.pow(0.85)
if filter_length > win_length:
pad_size = (filter_length - win_length) // 2
fft_window = torch.nn.functional.pad(fft_window, (pad_size, pad_size))
window_sum = window_sumsquare(
window=fft_window,
n_samples=max_samples,
hop_length=hop_length,
win_length=filter_length,
)
inverse_basis *= fft_window
self.register_buffer("inverse_basis", inverse_basis.float())
self.register_buffer("fft_window", fft_window)
self.register_buffer("window_sum", window_sum)
self.tiny = torch.finfo(torch.float16).tiny
def forward(self, magnitude, phase):
magnitude_phase = torch.cat(
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
)
inverse_transform = F.conv_transpose1d(
magnitude_phase,
Variable(self.inverse_basis, requires_grad=False),
stride=self.hop_length,
padding=0,
)
inverse_transform = inverse_transform.squeeze(1)
window_sum = self.window_sum
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
if self.window_sum.size(-1) < inverse_transform.size(-1):
logging.warning(
f"The precomputed `window_sumsquare` is too small, recomputing, "
f"from {self.window_sum.size(-1)} to {inverse_transform.size(-1)}"
)
window_sum = window_sumsquare(
window=self.fft_window,
n_samples=inverse_transform.size(-1),
win_length=self.filter_length,
hop_length=self.hop_length,
)
window_sum = window_sum[: inverse_transform.size(-1)]
approx_nonzero_indices = (window_sum > self.tiny).nonzero().squeeze()
inverse_transform[:, approx_nonzero_indices] /= window_sum[
approx_nonzero_indices
]
# scale by hop ratio
inverse_transform *= float(self.filter_length) / self.hop_length
assert self.padding in ["none", "same", "center"]
if self.padding == "center":
pad_len = self.filter_length // 2
elif self.padding == "same":
pad_len = (self.filter_length - self.hop_length) // 2
else:
return inverse_transform
return inverse_transform[:, pad_len:-pad_len]
class ConvNeXtBlock(nn.Module):
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
Args:
dim (int): Number of input channels.
intermediate_dim (int): Dimensionality of the intermediate layer.
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
Defaults to None.
"""
def __init__(
self,
dim: int,
intermediate_dim: int,
layer_scale_init_value: Optional[float] = None,
):
super().__init__()
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=7, padding=3, groups=dim
) # depthwise conv
self.norm = nn.LayerNorm(dim, eps=1e-6)
# pointwise/1x1 convs, implemented with linear layers
self.pwconv1 = nn.Linear(dim, intermediate_dim)
self.act = nn.GELU()
self.pwconv2 = nn.Linear(intermediate_dim, dim)
self.gamma = (
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
if layer_scale_init_value is not None and layer_scale_init_value > 0
else None
)
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
residual = x
x = self.dwconv(x)
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
x = residual + x
return x
class Generator(torch.nn.Module):
def __init__(
self,
feature_dim: int = 80,
dim: int = 512,
n_fft: int = 1024,
hop_length: int = 256,
intermediate_dim: int = 1536,
num_layers: int = 8,
padding: str = "none",
max_samples: int = 1440000, # 1440000 / 24000 = 60s
):
super(Generator, self).__init__()
self.feature_dim = feature_dim
self.embed = nn.Conv1d(feature_dim, dim, kernel_size=7, padding=3)
self.norm = nn.LayerNorm(dim, eps=1e-6)
layer_scale_init_value = 1 / num_layers
self.convnext = nn.ModuleList(
[
ConvNeXtBlock(
dim=dim,
intermediate_dim=intermediate_dim,
layer_scale_init_value=layer_scale_init_value,
)
for _ in range(num_layers)
]
)
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
self.apply(self._init_weights)
self.out_proj = torch.nn.Linear(dim, n_fft + 2)
self.istft = ISTFT(
filter_length=n_fft,
hop_length=hop_length,
win_length=n_fft,
padding=padding,
max_samples=max_samples,
)
def _init_weights(self, m):
if isinstance(m, (nn.Conv1d, nn.Linear)):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.embed(x)
x = self.norm(x.transpose(1, 2))
x = x.transpose(1, 2)
for conv_block in self.convnext:
x = conv_block(x)
x = self.final_layer_norm(x.transpose(1, 2))
x = self.out_proj(x).transpose(1, 2)
mag, phase = x.chunk(2, dim=1)
mag = torch.exp(mag)
# safeguard to prevent excessively large magnitudes
mag = torch.clip(mag, max=1e2)
audio = self.istft(mag, phase)
return audio
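A quick sanity-check sketch (not part of this diff) for the Generator above, feeding random fbank-like features and checking the output length; the shapes follow the defaults in the code, and the `generator` module name matches the import in the model file later in this diff:

import torch

from generator import Generator

gen = Generator(feature_dim=80, n_fft=1024, hop_length=256, padding="none")
gen.eval()

feats = torch.randn(2, 80, 100)  # (batch, feature_dim, num_frames)
with torch.no_grad():
    audio = gen(feats)

# With padding="none" the output covers (num_frames - 1) * hop_length + n_fft samples.
print(audio.shape)  # torch.Size([2, 26368])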

egs/libritts/TTS/vocos/infer.py Executable file

@@ -0,0 +1,352 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
# Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
import math
import time
import os
from functools import partial
from pathlib import Path
import torch
import torch.nn as nn
from lhotse.utils import fix_random_seed
from scipy.io.wavfile import write
from train import add_model_arguments, get_model, get_params
from tts_datamodule import LibriTTSDataModule
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import AttributeDict, setup_logger, str2bool
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=100,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="vocos/exp",
help="The experiment dir",
)
parser.add_argument(
"--generate-dir",
type=str,
default="generated_wavs",
help="Path name of the generated wavs",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
batch: dict,
):
"""
Args:
params:
It's the return value of :func:`get_params`.
model:
The vocoder model (generating waveforms from acoustic features).
batch:
It is the return value from iterating
`lhotse.dataset.SpeechSynthesisDataset`. See its documentation
for the format of the `batch`.
Returns:
Return nothing. The generated waveforms are written as wav files to
`{params.res_dir}/{params.suffix}`.
"""
device = next(model.parameters()).device
cut_ids = [cut.id for cut in batch["cut"]]
infer_time = 0
audio_time = 0
features = batch["features"] # (B, T, F)
utt_durations = batch["features_lens"]
x = features.permute(0, 2, 1) # (B, F, T)
audio_time += torch.sum(utt_durations)
start = time.time()
audios = model(x.to(device)) # (B, T)
infer_time += time.time() - start
wav_dir = f"{params.res_dir}/{params.suffix}"
os.makedirs(wav_dir, exist_ok=True)
for i in range(audios.shape[0]):
audio = audios[i][: int(utt_durations[i] * 256)]
audio = audio.cpu().squeeze().numpy()
write(f"{wav_dir}/{cut_ids[i]}.wav", 24000, audio)
print(f"RTF : {infer_time / (audio_time * (256/24000))}")
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
test_set: str,
):
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The vocoder model (generating waveforms from acoustic features).
test_set:
The name of the test_set
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
with open(f"{params.res_dir}/{test_set}.scp", "w", encoding="utf8") as f:
for batch_idx, batch in enumerate(dl):
# texts = batch["text"]
cut_ids = [cut.id for cut in batch["cut"]]
decode_one_batch(
params=params,
model=model,
batch=batch,
)
# assert len(texts) == len(cut_ids), (len(texts), len(cut_ids))
# for i in range(len(texts)):
# f.write(f"{cut_ids[i]}\t{texts[i]}\n")
# num_cuts += len(texts)
if batch_idx % 50 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
@torch.no_grad()
def main():
parser = get_parser()
LibriTTSDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / params.generate_dir
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
params.device = device
logging.info(f"Device: {device}")
logging.info(params)
fix_random_seed(666)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model = model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to name the generated wav files.
args.return_cuts = True
libritts = LibriTTSDataModule(args)
test_cuts = libritts.test_clean_cuts()
test_dl = libritts.test_dataloaders(test_cuts)
test_sets = ["test"]
test_dls = [test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
decode_dataset(
dl=test_dl,
params=params,
model=model,
test_set=test_set,
)
logging.info("Done!")
if __name__ == "__main__":
main()

egs/libritts/TTS/vocos/loss.py

@@ -0,0 +1,133 @@
from typing import List, Tuple
import torch
import torchaudio
from torch import nn
from utils import safe_log
class MelSpecReconstructionLoss(nn.Module):
"""
L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
"""
def __init__(
self,
sample_rate: int = 24000,
n_fft: int = 1024,
hop_length: int = 256,
n_mels: int = 100,
):
super().__init__()
self.mel_spec = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft,
hop_length=hop_length,
n_mels=n_mels,
center=True,
power=1,
)
def forward(self, y_hat, y) -> torch.Tensor:
"""
Args:
y_hat (Tensor): Predicted audio waveform.
y (Tensor): Ground truth audio waveform.
Returns:
Tensor: L1 loss between the mel-scaled magnitude spectrograms.
"""
mel_hat = safe_log(self.mel_spec(y_hat))
mel = safe_log(self.mel_spec(y))
loss = torch.nn.functional.l1_loss(mel, mel_hat)
return loss
class GeneratorLoss(nn.Module):
"""
Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
"""
def forward(
self, disc_outputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Args:
disc_outputs (List[Tensor]): List of discriminator outputs.
Returns:
Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
the sub-discriminators
"""
loss = torch.zeros(
1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype
)
gen_losses = []
for dg in disc_outputs:
l = torch.mean(torch.clamp(1 - dg, min=0))
gen_losses.append(l)
loss += l
return loss, gen_losses
class DiscriminatorLoss(nn.Module):
"""
Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
"""
def forward(
self,
disc_real_outputs: List[torch.Tensor],
disc_generated_outputs: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
"""
Args:
disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.
Returns:
Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
the sub-discriminators for real outputs, and a list of
loss values for generated outputs.
"""
loss = torch.zeros(
1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype
)
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
r_loss = torch.mean(torch.clamp(1 - dr, min=0))
g_loss = torch.mean(torch.clamp(1 + dg, min=0))
loss += r_loss + g_loss
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
class FeatureMatchingLoss(nn.Module):
"""
Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
"""
def forward(
self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
) -> torch.Tensor:
"""
Args:
fmap_r (List[List[Tensor]]): List of feature maps from real samples.
fmap_g (List[List[Tensor]]): List of feature maps from generated samples.
Returns:
Tensor: The calculated feature matching loss.
"""
loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype)
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
loss += torch.mean(torch.abs(rl - gl))
return loss
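A small sketch (not part of this diff) of how the hinge-style losses above combine the discriminator outputs during training; the logits and feature maps are dummy tensors with assumed shapes:

import torch

from loss import DiscriminatorLoss, FeatureMatchingLoss, GeneratorLoss

# Dummy per-sub-discriminator logits and feature maps for a batch of 2.
real_logits = [torch.randn(2, 100), torch.randn(2, 120)]
fake_logits = [torch.randn(2, 100), torch.randn(2, 120)]
fmap_real = [[torch.randn(2, 32, 10, 5)], [torch.randn(2, 32, 12, 5)]]
fmap_fake = [[torch.randn(2, 32, 10, 5)], [torch.randn(2, 32, 12, 5)]]

# Discriminator update: hinge loss on real vs. generated logits.
d_loss, r_losses, g_losses = DiscriminatorLoss()(real_logits, fake_logits)

# Generator update: adversarial hinge term plus feature matching on the fmaps.
g_adv, per_disc = GeneratorLoss()(fake_logits)
g_fm = FeatureMatchingLoss()(fmap_real, fmap_fake)
print(d_loss.item(), g_adv.item(), g_fm.item())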

egs/libritts/TTS/vocos/model.py

@@ -0,0 +1,48 @@
import logging
import torch
from discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator
from generator import Generator
from loss import (
DiscriminatorLoss,
GeneratorLoss,
FeatureMatchingLoss,
MelSpecReconstructionLoss,
)
class Vocos(torch.nn.Module):
def __init__(
self,
feature_dim: int = 80,
dim: int = 512,
n_fft: int = 1024,
hop_length: int = 256,
intermediate_dim: int = 1536,
num_layers: int = 8,
padding: str = "none",
sample_rate: int = 24000,
max_seconds: int = 60,
):
super(Vocos, self).__init__()
self.generator = Generator(
feature_dim=feature_dim,
dim=dim,
n_fft=n_fft,
hop_length=hop_length,
num_layers=num_layers,
intermediate_dim=intermediate_dim,
padding=padding,
max_samples=int(sample_rate * max_seconds),
)
self.mpd = MultiPeriodDiscriminator()
self.mrd = MultiResolutionDiscriminator()
self.disc_loss = DiscriminatorLoss()
self.gen_loss = GeneratorLoss()
self.feat_matching_loss = FeatureMatchingLoss()
self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate)
def forward(self, features: torch.Tensor):
audio = self.generator(features)
return audio

egs/libritts/TTS/vocos/onnx_pretrained.py

@@ -0,0 +1,268 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads the exported Vocos ONNX model and uses it to generate
audio from the fbank features of the given wave files.
You can use ./export-onnx.py to export the model, e.g.:
./vocos/export-onnx.py \
--exp-dir ./vocos/exp \
--epoch 100 \
--avg 1
It will generate vocos-epoch-100-avg-1.onnx inside the given exp-dir.
Then run this file:
./vocos/onnx_pretrained.py \
--model-filename ./vocos/exp/vocos-epoch-100-avg-1.onnx \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from pathlib import Path
from typing import List, Tuple
import onnxruntime as ort
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from lhotse import Fbank, FbankConfig
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--sampling-rate",
type=int,
default=24000,
help="The sampleing rate of libritts dataset",
)
parser.add_argument(
"--frame-shift",
type=int,
default=256,
help="Frame shift.",
)
parser.add_argument(
"--frame-length",
type=int,
default=1024,
help="Frame shift.",
)
parser.add_argument(
"--use-fft-mag",
type=str2bool,
default=True,
help="Whether to use magnitude of fbank, false to use power energy.",
)
parser.add_argument(
"--output-dir",
type=str,
default="generated_audios",
help="The generated will be written to.",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
class OnnxModel:
def __init__(
self,
model_filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 4
self.session_opts = session_opts
self.init_model(model_filename)
def init_model(self, model_filename: str):
self.model = ort.InferenceSession(
model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
def run_model(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, C, T) containing the fbank features,
where C is the feature dimension and T is the number of frames.
Returns:
Return the generated audio, a 2-D tensor of shape (N, num_samples).
"""
out = self.model.run(
[
self.model.get_outputs()[0].name,
],
{
self.model.get_inputs()[0].name: x.numpy(),
},
)
return torch.from_numpy(out[0])
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
output_dir = Path(args.model_filename).parent / args.output_dir
output_dir.mkdir(exist_ok=True)
args.output_dir = output_dir
logging.info(vars(args))
model = OnnxModel(model_filename=args.model_filename)
config = FbankConfig(
sampling_rate=args.sampling_rate,
frame_length=args.frame_length / args.sampling_rate, # (in second),
frame_shift=args.frame_shift / args.sampling_rate, # (in second)
use_fft_mag=args.use_fft_mag,
)
fbank = Fbank(config)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files, expected_sample_rate=args.sampling_rate
)
wave_lengths = [w.size(0) for w in waves]
waves = pad_sequence(waves, batch_first=True, padding_value=0)
logging.info(f"waves : {waves.shape}")
features = fbank.extract_batch(waves, sampling_rate=args.sampling_rate)
if features.dim() == 2:
features = features.unsqueeze(0)
features = features.permute(0, 2, 1)
logging.info(f"features : {features.shape}")
logging.info("Generating started")
# model forward
audios = model.run_model(features)
for i, filename in enumerate(args.sound_files):
audio = audios[i : i + 1, 0 : wave_lengths[i]]
ofilename = args.output_dir / filename.split("/")[-1]
logging.info(f"Writting audio : {ofilename}")
torchaudio.save(str(ofilename), audio.cpu(), args.sampling_rate)
logging.info("Generating Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

egs/libritts/TTS/vocos/pretrained.py

@@ -0,0 +1,196 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to generate audio from the fbank
features of the given wave files.
You can generate the checkpoint with the following command:
./vocos/export.py \
--exp-dir ./vocos/exp \
--epoch 30 \
--avg 9
which writes `generator.pt` into the given exp-dir; pass it to this script
via `--checkpoint`.
"""
import argparse
import logging
import math
from pathlib import Path
from typing import List
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
from lhotse import Fbank, FbankConfig
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--sampling-rate",
type=int,
default=24000,
help="The sampleing rate of libritts dataset",
)
parser.add_argument(
"--frame-shift",
type=int,
default=256,
help="Frame shift.",
)
parser.add_argument(
"--frame-length",
type=int,
default=1024,
help="Frame shift.",
)
parser.add_argument(
"--use-fft-mag",
type=str2bool,
default=True,
help="Whether to use magnitude of fbank, false to use power energy.",
)
parser.add_argument(
"--output-dir",
type=str,
default="generated_audios",
help="The generated will be written to.",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
params.device = device
output_dir = Path(params.checkpoint).parent / params.output_dir
output_dir.mkdir(exist_ok=True)
params.output_dir = output_dir
logging.info(f"{params}")
logging.info("Creating model")
model = get_model(params)
model = model.generator
checkpoint = torch.load(params.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
logging.info("Constructing Fbank computer")
config = FbankConfig(
sampling_rate=params.sampling_rate,
frame_length=params.frame_length / params.sampling_rate, # (in second),
frame_shift=params.frame_shift / params.sampling_rate, # (in second)
use_fft_mag=params.use_fft_mag,
)
fbank = Fbank(config)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sampling_rate
)
wave_lengths = [w.size(0) for w in waves]
waves = pad_sequence(waves, batch_first=True, padding_value=0)
features = (
fbank.extract_batch(waves, sampling_rate=params.sampling_rate)
.permute(0, 2, 1)
.to(device)
)
logging.info("Generating started")
# model forward
audios = model(features)
for i, filename in enumerate(params.sound_files):
audio = audios[i : i + 1, 0 : wave_lengths[i]]
ofilename = params.output_dir / filename.split("/")[-1]
logging.info(f"Writting audio : {ofilename}")
torchaudio.save(str(ofilename), audio.cpu(), params.sampling_rate)
logging.info("Generating Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

1054
egs/libritts/TTS/vocos/train.py Executable file

File diff suppressed because it is too large


@ -0,0 +1,419 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
# Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class LibriTTSDataModule:
"""
DataModule for tts experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="TTS data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--return-text",
type=str2bool,
default=True,
help="Whether to return the text of the audio.",
)
group.add_argument(
"--return-tokens",
type=str2bool,
default=False,
help="Whether the return the tokens of the text of the audio.",
)
group.add_argument(
"--num-workers",
type=int,
default=4,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--sampling-rate",
type=int,
default=24000,
help="The sampleing rate of libritts dataset",
)
group.add_argument(
"--frame-shift",
type=int,
default=256,
help="Frame shift.",
)
group.add_argument(
"--frame-length",
type=int,
default=1024,
help="Frame shift.",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
group.add_argument(
"--use-fft-mag",
type=str2bool,
default=True,
help="Whether to use magnitude of fbank, false to use power energy.",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_text=self.args.return_text,
return_tokens=self.args.return_tokens,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
sampling_rate = self.args.sampling_rate
config = FbankConfig(
sampling_rate=sampling_rate,
frame_length=self.args.frame_length / sampling_rate, # (in second),
frame_shift=self.args.frame_shift / sampling_rate, # (in second)
use_fft_mag=self.args.use_fft_mag,
)
train = SpeechSynthesisDataset(
return_text=self.args.return_text,
return_tokens=self.args.return_tokens,
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
sampling_rate = self.args.sampling_rate
config = FbankConfig(
sampling_rate=sampling_rate,
frame_length=self.args.frame_length / sampling_rate, # (in second),
frame_shift=self.args.frame_shift / sampling_rate, # (in second)
use_fft_mag=self.args.use_fft_mag,
)
validate = SpeechSynthesisDataset(
return_text=self.args.return_text,
return_tokens=self.args.return_tokens,
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
return_cuts=self.args.return_cuts,
)
else:
validate = SpeechSynthesisDataset(
return_text=self.args.return_text,
return_tokens=self.args.return_tokens,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create valid dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
if self.args.on_the_fly_feats:
sampling_rate = self.args.sampling_rate
config = FbankConfig(
sampling_rate=sampling_rate,
frame_length=self.args.frame_length / sampling_rate, # (in second),
frame_shift=self.args.frame_shift / sampling_rate, # (in second)
use_fft_mag=self.args.use_fft_mag,
)
test = SpeechSynthesisDataset(
return_text=self.args.return_text,
return_tokens=self.args.return_tokens,
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
return_cuts=self.args.return_cuts,
)
else:
test = SpeechSynthesisDataset(
return_text=self.args.return_text,
return_tokens=self.args.return_tokens,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
test_sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=test_sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz"
)
@lru_cache()
def train_clean_cuts(self) -> CutSet:
logging.info("About to get train clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_train-clean-460.jsonl.gz"
)
@lru_cache()
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train clean 100 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_train-clean-100.jsonl.gz"
)
@lru_cache()
def train_clean_360_cuts(self) -> CutSet:
logging.info("About to get train clean 360 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_train-clean-360.jsonl.gz"
)
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz"
)
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz"
)
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz"
)
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz"
)
@lru_cache()
def train_cuts_finetune(self) -> CutSet:
logging.info("About to get train cuts finetune")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_train_finetune.jsonl.gz"
)
@lru_cache()
def valid_cuts_finetune(self) -> CutSet:
logging.info("About to get validation cuts finetune")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_valid_finetune.jsonl.gz"
)
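# A minimal usage sketch (an illustration, not part of the recipe): it assumes
# the default manifest layout under data/fbank and that train.py adds its own
# model/training arguments before parsing.
def _example_build_train_dl() -> DataLoader:
    parser = argparse.ArgumentParser()
    LibriTTSDataModule.add_arguments(parser)
    args = parser.parse_args([])  # keep the defaults declared above
    libritts = LibriTTSDataModule(args)
    train_cuts = libritts.train_cuts()
    return libritts.train_dataloaders(train_cuts)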


@ -0,0 +1,282 @@
import glob
import os
import logging
import matplotlib
import math
import torch
import torch.nn as nn
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from torch.nn.utils import weight_norm
from torch.optim.lr_scheduler import LRScheduler
from torch.optim import Optimizer
from torch.cuda.amp import GradScaler
from lhotse.dataset.sampling.base import CutSampler
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import LambdaLR
matplotlib.use("Agg")
import matplotlib.pylab as plt
def plot_spectrogram(spectrogram):
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
fig.canvas.draw()
plt.close()
return fig
def save_checkpoint_with_global_batch_idx(
out_dir: Path,
global_batch_idx: int,
model: Union[nn.Module, DDP],
model_avg: Optional[nn.Module] = None,
params: Optional[Dict[str, Any]] = None,
optimizer_g: Optional[Optimizer] = None,
optimizer_d: Optional[Optimizer] = None,
scheduler_g: Optional[LRScheduler] = None,
scheduler_d: Optional[LRScheduler] = None,
scaler: Optional[GradScaler] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
):
"""Save training info after processing given number of batches.
Args:
out_dir:
The directory to save the checkpoint.
global_batch_idx:
The number of batches processed so far from the very start of the
training. The saved checkpoint will have the following filename:
f'out_dir / checkpoint-{global_batch_idx}.pt'
model:
The neural network model whose `state_dict` will be saved in the
checkpoint.
model_avg:
The stored model averaged from the start of training.
params:
A dict of training configurations to be saved.
optimizer_g & optimizer_d:
The generator and discriminator optimizers used in the training.
Their `state_dict` will be saved.
scheduler_g & scheduler_d:
The generator and discriminator learning rate schedulers used in the
training. Their `state_dict` will be saved.
scaler:
The scaler used for mixed precision training. Its `state_dict` will
be saved.
sampler:
The sampler used in the training dataset.
rank:
The rank ID used in DDP training of the current node. Set it to 0
if DDP is not used.
"""
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
save_checkpoint(
filename=filename,
model=model,
model_avg=model_avg,
params=params,
optimizer_g=optimizer_g,
scheduler_g=scheduler_g,
optimizer_d=optimizer_d,
scheduler_d=scheduler_d,
scaler=scaler,
sampler=sampler,
rank=rank,
)
def load_checkpoint(
filename: Path,
model: nn.Module,
model_avg: Optional[nn.Module] = None,
optimizer_g: Optional[Optimizer] = None,
optimizer_d: Optional[Optimizer] = None,
scheduler_g: Optional[LRScheduler] = None,
scheduler_d: Optional[LRScheduler] = None,
scaler: Optional[GradScaler] = None,
sampler: Optional[CutSampler] = None,
strict: bool = False,
) -> Dict[str, Any]:
logging.info(f"Loading checkpoint from {filename}")
checkpoint = torch.load(filename, map_location="cpu")
if next(iter(checkpoint["model"])).startswith("module."):
logging.info("Loading checkpoint saved by DDP")
dst_state_dict = model.state_dict()
src_state_dict = checkpoint["model"]
for key in dst_state_dict.keys():
src_key = "{}.{}".format("module", key)
dst_state_dict[key] = src_state_dict.pop(src_key)
assert len(src_state_dict) == 0
model.load_state_dict(dst_state_dict, strict=strict)
else:
model.load_state_dict(checkpoint["model"], strict=strict)
checkpoint.pop("model")
if model_avg is not None and "model_avg" in checkpoint:
logging.info("Loading averaged model")
model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
checkpoint.pop("model_avg")
def load(name, obj):
s = checkpoint.get(name, None)
if obj and s:
obj.load_state_dict(s)
checkpoint.pop(name)
load("optimizer_g", optimizer_g)
load("optimizer_d", optimizer_d)
load("scheduler_g", scheduler_g)
load("scheduler_d", scheduler_d)
load("grad_scaler", scaler)
load("sampler", sampler)
return checkpoint
def save_checkpoint(
filename: Path,
model: Union[nn.Module, DDP],
model_avg: Optional[nn.Module] = None,
params: Optional[Dict[str, Any]] = None,
optimizer_g: Optional[Optimizer] = None,
optimizer_d: Optional[Optimizer] = None,
scheduler_g: Optional[LRScheduler] = None,
scheduler_d: Optional[LRScheduler] = None,
scaler: Optional[GradScaler] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
) -> None:
"""Save training information to a file.
Args:
filename:
The checkpoint filename.
model:
The model to be saved. We only save its `state_dict()`.
model_avg:
The stored model averaged from the start of training.
params:
User defined parameters, e.g., epoch, loss.
optimizer_g & optimizer_d:
The generator and discriminator optimizers to be saved. We only save
their `state_dict()`.
scheduler_g & scheduler_d:
The generator and discriminator schedulers to be saved. We only save
their `state_dict()`.
scaler:
The GradScaler to be saved. We only save its `state_dict()`.
rank:
Used in DDP. We save checkpoint only for the node whose rank is 0.
Returns:
Return None.
"""
if rank != 0:
return
logging.info(f"Saving checkpoint to {filename}")
if isinstance(model, DDP):
model = model.module
checkpoint = {
"model": model.state_dict(),
"optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None,
"optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None,
"scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None,
"scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None,
"grad_scaler": scaler.state_dict() if scaler is not None else None,
"sampler": sampler.state_dict() if sampler is not None else None,
}
if model_avg is not None:
checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
if params:
for k, v in params.items():
assert k not in checkpoint
checkpoint[k] = v
torch.save(checkpoint, filename)
def _get_cosine_schedule_with_warmup_lr_lambda(
current_step: int,
*,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: float,
min_lr_rate: float = 0.0,
):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)
)
factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
factor = factor * (1 - min_lr_rate) + min_lr_rate
return max(0, factor)
def get_cosine_schedule_with_warmup(
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function from the
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly from 0 to the
initial lr set in the optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
num_cycles (`float`, *optional*, defaults to 0.5):
The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
following a half-cosine).
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
lr_lambda = partial(
_get_cosine_schedule_with_warmup_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_cycles=num_cycles,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
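# A minimal usage sketch (illustrative only: the module name `generator`, the
# learning rate and the step counts are assumptions, not values from the
# recipe). The training script is expected to create one scheduler per
# optimizer, e.g. one for the generator and one for the discriminator.
def _example_cosine_schedule(generator: nn.Module) -> LambdaLR:
    optimizer_g = torch.optim.AdamW(generator.parameters(), lr=2e-4)
    scheduler_g = get_cosine_schedule_with_warmup(
        optimizer_g,
        num_warmup_steps=1000,  # linear warm-up from 0 to the base lr
        num_training_steps=100000,  # then cosine decay towards 0
    )
    return scheduler_g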
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
"""
Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
Args:
x (Tensor): Input tensor.
clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
Returns:
Tensor: Element-wise logarithm of the input tensor with clipping applied.
"""
return torch.log(torch.clip(x, min=clip_val))


@ -0,0 +1,287 @@
"""
Calculate the Frechet Speech Distance between two speech directories.
Adapted from: https://github.com/gudgud96/frechet-audio-distance/blob/main/frechet_audio_distance/fad.py
"""
import argparse
import logging
import os
from multiprocessing.dummy import Pool as ThreadPool
import librosa
import numpy as np
import soundfile as sf
import torch
from scipy import linalg
from tqdm import tqdm
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
logging.basicConfig(level=logging.INFO)
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--real-path", type=str, help="path of the real speech directory"
)
parser.add_argument(
"--eval-path", type=str, help="path of the evaluated speech directory"
)
parser.add_argument(
"--model-path",
type=str,
default="model/huggingface/wav2vec2_base",
help="path of the wav2vec 2.0 model directory",
)
parser.add_argument(
"--real-embds-path",
type=str,
default=None,
help="path of the real embedding directory",
)
parser.add_argument(
"--eval-embds-path",
type=str,
default=None,
help="path of the evaluated embedding directory",
)
return parser
class FrechetSpeechDistance:
def __init__(
self,
model_path="resources/wav2vec2_base",
pca_dim=128,
speech_load_worker=8,
):
"""
Initialize FSD
"""
self.sample_rate = 16000
self.channels = 1
self.device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
logging.info("[Frechet Speech Distance] Using device: {}".format(self.device))
self.speech_load_worker = speech_load_worker
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_path)
self.model = Wav2Vec2Model.from_pretrained(model_path)
self.model.to(self.device)
self.model.eval()
self.pca_dim = pca_dim
def load_speech_files(self, dir, dtype="float32"):
def _load_speech_task(fname, sample_rate, channels, dtype="float32"):
if dtype not in ["float64", "float32", "int32", "int16"]:
raise ValueError(f"dtype not supported: {dtype}")
wav_data, sr = sf.read(fname, dtype=dtype)
# For integer type PCM input, convert to [-1.0, +1.0]
if dtype == "int16":
wav_data = wav_data / 32768.0
elif dtype == "int32":
wav_data = wav_data / float(2**31)
# Convert to mono
assert channels in [1, 2], "channels must be 1 or 2"
if len(wav_data.shape) > channels:
wav_data = np.mean(wav_data, axis=1)
if sr != sample_rate:
wav_data = librosa.resample(wav_data, orig_sr=sr, target_sr=sample_rate)
return wav_data
task_results = []
pool = ThreadPool(self.speech_load_worker)
logging.info("[Frechet Speech Distance] Loading speech from {}...".format(dir))
for fname in os.listdir(dir):
res = pool.apply_async(
_load_speech_task,
args=(os.path.join(dir, fname), self.sample_rate, self.channels, dtype),
)
task_results.append(res)
pool.close()
pool.join()
return [k.get() for k in task_results]
def get_embeddings(self, x):
"""
Get embeddings
Params:
-- x : a list of np.ndarray speech samples
"""
embd_lst = []
try:
for speech in tqdm(x):
input_features = self.feature_extractor(
speech, sampling_rate=self.sample_rate, return_tensors="pt"
).input_values.to(self.device)
with torch.no_grad():
embd = self.model(input_features).last_hidden_state.mean(1)
if embd.device != torch.device("cpu"):
embd = embd.cpu()
if torch.is_tensor(embd):
embd = embd.detach().numpy()
embd_lst.append(embd)
except Exception as e:
print(
"[Frechet Speech Distance] get_embeddings throw an exception: {}".format(
str(e)
)
)
return np.concatenate(embd_lst, axis=0)
def calculate_embd_statistics(self, embd_lst):
if isinstance(embd_lst, list):
embd_lst = np.array(embd_lst)
mu = np.mean(embd_lst, axis=0)
sigma = np.cov(embd_lst, rowvar=False)
return mu, sigma
def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
"""
Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
Numpy implementation of the Frechet Distance.
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
and X_2 ~ N(mu_2, C_2) is
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
Stable version by Dougal J. Sutherland.
Params:
-- mu1 : Numpy array containing the activations of a layer of the
inception net (like returned by the function 'get_predictions')
for generated samples.
-- mu2 : The sample mean over activations, precalculated on a
representative data set.
-- sigma1: The covariance matrix over activations for generated samples.
-- sigma2: The covariance matrix over activations, precalculated on a
representative data set.
Returns:
-- : The Frechet Distance.
"""
mu1 = np.atleast_1d(mu1)
mu2 = np.atleast_1d(mu2)
sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2)
assert (
mu1.shape == mu2.shape
), "Training and test mean vectors have different lengths"
assert (
sigma1.shape == sigma2.shape
), "Training and test covariances have different dimensions"
diff = mu1 - mu2
# Product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2).astype(complex), disp=False)
if not np.isfinite(covmean).all():
msg = (
"fid calculation produces singular product; "
"adding %s to diagonal of cov estimates"
) % eps
logging.info(msg)
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm(
(sigma1 + offset).dot(sigma2 + offset).astype(complex)
)
# Numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
m = np.max(np.abs(covmean.imag))
raise ValueError("Imaginary component {}".format(m))
covmean = covmean.real
tr_covmean = np.trace(covmean)
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
def score(
self,
real_path,
eval_path,
real_embds_path=None,
eval_embds_path=None,
dtype="float32",
):
"""
Computes the Frechet Speech Distance (FSD) between two directories of speech files.
Parameters:
- real_path (str): Path to the directory containing real speech files.
- eval_path (str): Path to the directory containing evaluation speech files.
- real_embds_path (str, optional): Path to save/load real speech embeddings (e.g., /folder/bkg_embs.npy). If None, embeddings won't be saved.
- eval_embds_path (str, optional): Path to save/load evaluation speech embeddings (e.g., /folder/test_embs.npy). If None, embeddings won't be saved.
- dtype (str, optional): Data type for loading speech. Default is "float32".
Returns:
- float: The Frechet Speech Distance (FSD) score between the two directories of speech files.
"""
# Load or compute real embeddings
if real_embds_path is not None and os.path.exists(real_embds_path):
logging.info(
f"[Frechet Speech Distance] Loading embeddings from {real_embds_path}..."
)
embds_real = np.load(real_embds_path)
else:
speech_real = self.load_speech_files(real_path, dtype=dtype)
embds_real = self.get_embeddings(speech_real)
if real_embds_path:
os.makedirs(os.path.dirname(real_embds_path), exist_ok=True)
np.save(real_embds_path, embds_real)
# Load or compute eval embeddings
if eval_embds_path is not None and os.path.exists(eval_embds_path):
logging.info(
f"[Frechet Speech Distance] Loading embeddings from {eval_embds_path}..."
)
embds_eval = np.load(eval_embds_path)
else:
speech_eval = self.load_speech_files(eval_path, dtype=dtype)
embds_eval = self.get_embeddings(speech_eval)
if eval_embds_path:
os.makedirs(os.path.dirname(eval_embds_path), exist_ok=True)
np.save(eval_embds_path, embds_eval)
# Check if embeddings are empty
if len(embds_real) == 0:
logging.info("[Frechet Speech Distance] real set dir is empty, exiting...")
return -10.46
if len(embds_eval) == 0:
logging.info("[Frechet Speech Distance] eval set dir is empty, exiting...")
return -1
# Compute statistics and FSD score
mu_real, sigma_real = self.calculate_embd_statistics(embds_real)
mu_eval, sigma_eval = self.calculate_embd_statistics(embds_eval)
fsd_score = self.calculate_frechet_distance(
mu_real, sigma_real, mu_eval, sigma_eval
)
return fsd_score
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
FSD = FrechetSpeechDistance(model_path=args.model_path)
score = FSD.score(
args.real_path, args.eval_path, args.real_embds_path, args.eval_embds_path
)
logging.info(f"FSD score: {score:.2f}")


@ -0,0 +1,139 @@
"""
Calculate WER with Whisper model
"""
import argparse
import logging
import os
import re
from pathlib import Path
from typing import List, Tuple
import librosa
import soundfile as sf
import torch
from num2words import num2words
from tqdm import tqdm
from transformers import pipeline
from icefall.utils import store_transcripts, write_error_stats
logging.basicConfig(level=logging.INFO)
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--wav-path", type=str, help="path of the speech directory")
parser.add_argument("--decode-path", type=str, help="path of the speech directory")
parser.add_argument(
"--model-path",
type=str,
default="model/huggingface/whisper_medium",
help="path of the huggingface whisper model",
)
parser.add_argument(
"--transcript-path",
type=str,
default="data/transcript/test.tsv",
help="path of the transcript tsv file",
)
parser.add_argument(
"--batch-size", type=int, default=64, help="decoding batch size"
)
parser.add_argument(
"--device", type=str, default="cuda:0", help="decoding device, cuda:0 or cpu"
)
return parser
def post_process(text: str):
def convert_numbers(match):
return num2words(match.group())
text = re.sub(r"\b\d{1,2}\b", convert_numbers, text)
text = re.sub(r"[^a-zA-Z0-9']", " ", text.lower())
text = re.sub(r"\s+", " ", text)
return text
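# For reference (illustrative input, not from the recipe): the normalization
# above maps a hypothesis like "I have 2 cats, OK?" to roughly
# "i have two cats ok", which is then split into words for WER scoring.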
def save_results(
res_dir: str,
results: List[Tuple[str, List[str], List[str]]],
):
if not os.path.exists(res_dir):
os.makedirs(res_dir)
recog_path = os.path.join(res_dir, "recogs.txt")
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
errs_filename = os.path.join(res_dir, "errs.txt")
with open(errs_filename, "w") as f:
_ = write_error_stats(f, "test", results, enable_log=True)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
class SpeechEvalDataset(torch.utils.data.Dataset):
def __init__(self, wav_path: str, transcript_path: str):
super().__init__()
self.audio_name = []
self.audio_paths = []
self.transcripts = []
with Path(transcript_path).open("r", encoding="utf8") as f:
meta = [item.split("\t") for item in f.read().rstrip().split("\n")]
for item in meta:
self.audio_name.append(item[0])
self.audio_paths.append(Path(wav_path, item[0] + ".wav"))
self.transcripts.append(item[1])
def __len__(self):
return len(self.audio_paths)
def __getitem__(self, index: int):
audio, sampling_rate = sf.read(self.audio_paths[index])
item = {
"array": librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000),
"sampling_rate": 16000,
"reference": self.transcripts[index],
"audio_name": self.audio_name[index],
}
return item
def main(args):
batch_size = args.batch_size
pipe = pipeline(
"automatic-speech-recognition",
model=args.model_path,
device=args.device,
tokenizer=args.model_path,
)
dataset = SpeechEvalDataset(args.wav_path, args.transcript_path)
results = []
bar = tqdm(
pipe(
dataset,
generate_kwargs={"language": "english", "task": "transcribe"},
batch_size=batch_size,
),
total=len(dataset),
)
for out in bar:
results.append(
(
out["audio_name"][0],
post_process(out["reference"][0].strip()).split(),
post_process(out["text"].strip()).split(),
)
)
save_results(args.decode_path, results)
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
main(args)
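# Example invocation (a sketch: the script name and paths are placeholders;
# the transcript tsv is expected to contain "<utt_id>\t<text>" lines, as read
# by SpeechEvalDataset above):
#
#   python3 evaluate_wer_whisper.py \
#     --wav-path /path/to/generated_wavs \
#     --decode-path /path/to/decode_results \
#     --model-path model/huggingface/whisper_medium \
#     --transcript-path data/transcript/test.tsv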


@ -0,0 +1 @@
../../../libritts/TTS/vocos/discriminators.py


@ -0,0 +1 @@
../../../libritts/TTS/vocos/export-onnx.py


@ -0,0 +1 @@
../../../libritts/TTS/vocos/export.py


@ -0,0 +1 @@
../../../libritts/TTS/vocos/generator.py

340
egs/ljspeech/TTS/vocos/infer.py Executable file

@ -0,0 +1,340 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
# Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
import math
import os
from functools import partial
from pathlib import Path
import torch
import torch.nn as nn
from lhotse.utils import fix_random_seed
from scipy.io.wavfile import write
from train import add_model_arguments, get_model, get_params
from tts_datamodule import LJSpeechTtsDataModule
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import AttributeDict, setup_logger, str2bool
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=100,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="flow_match/exp",
help="The experiment dir",
)
parser.add_argument(
"--generate-dir",
type=str,
default="generated_wavs",
help="Path name of the generated wavs",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
batch: dict,
):
"""
Args:
params:
It's the return value of :func:`get_params`.
model:
The vocoder generator that converts features to waveforms.
batch:
It is the return value from iterating
`lhotse.dataset.SpeechSynthesisDataset`. See its documentation
for the format of the `batch`.
Returns:
None. The generated waveforms are written to disk under
`params.res_dir`.
"""
device = next(model.parameters()).device
cut_ids = [cut.id for cut in batch["cut"]]
features = batch["features"] # (B, T, F)
utt_durations = batch["features_lens"]
x = features.permute(0, 2, 1) # (B, F, T)
audios = model(x.to(device)) # (B, T)
wav_dir = f"{params.res_dir}/{params.suffix}"
os.makedirs(wav_dir, exist_ok=True)
for i in range(audios.shape[0]):
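# Keep only the samples covered by the un-padded feature frames: with the
# default frame shift of 256 samples and frame length of 1024 samples
# (see tts_datamodule.py), T frames span (T - 1) * 256 + 1024 samples;
# 22050 below is the LJSpeech sampling rate.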
audio = audios[i][: (utt_durations[i] - 1) * 256 + 1024]
audio = audio.cpu().squeeze().numpy()
write(f"{wav_dir}/{cut_ids[i]}.wav", 22050, audio)
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
test_set: str,
):
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The vocoder generator that converts features to waveforms.
test_set:
The name of the test_set
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
with open(f"{params.res_dir}/{test_set}.scp", "w", encoding="utf8") as f:
for batch_idx, batch in enumerate(dl):
texts = batch["text"]
cut_ids = [cut.id for cut in batch["cut"]]
decode_one_batch(
params=params,
model=model,
batch=batch,
)
assert len(texts) == len(cut_ids), (len(texts), len(cut_ids))
for i in range(len(texts)):
f.write(f"{cut_ids[i]}\t{texts[i]}\n")
num_cuts += len(texts)
if batch_idx % 50 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
@torch.no_grad()
def main():
parser = get_parser()
LJSpeechTtsDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / params.generate_dir
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
params.device = device
logging.info(f"Device: {device}")
logging.info(params)
fix_random_seed(666)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model = model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to name the generated wav files.
args.return_cuts = True
ljspeech = LJSpeechTtsDataModule(args)
test_cuts = ljspeech.test_cuts()
test_dl = ljspeech.test_dataloaders(test_cuts)
test_sets = ["test"]
test_dls = [test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
decode_dataset(
dl=test_dl,
params=params,
model=model,
test_set=test_set,
)
logging.info("Done!")
if __name__ == "__main__":
main()
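# Example invocation (a sketch: --exp-dir and the epoch/avg values below are
# illustrative, not defaults from this file):
#
#   ./vocos/infer.py \
#     --epoch 100 \
#     --avg 10 \
#     --exp-dir vocos/exp \
#     --generate-dir generated_wavs \
#     --manifest-dir data/fbank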


@ -0,0 +1 @@
../../../libritts/TTS/vocos/loss.py


@ -0,0 +1 @@
../../../libritts/TTS/vocos/model.py


@ -0,0 +1 @@
../../../libritts/TTS/vocos/onnx_pretrained.py


@ -0,0 +1 @@
../../../libritts/TTS/vocos/pretrained.py

1054
egs/ljspeech/TTS/vocos/train.py Executable file

File diff suppressed because it is too large


@ -0,0 +1,372 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
# Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class LJSpeechTtsDataModule:
"""
DataModule for tts experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="TTS data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--sampling-rate",
type=int,
default=22050,
help="The sampleing rate of ljspeech dataset",
)
group.add_argument(
"--frame-shift",
type=int,
default=256,
help="Frame shift.",
)
group.add_argument(
"--frame-length",
type=int,
default=1024,
help="Frame shift.",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
group.add_argument(
"--use-fft-mag",
type=str2bool,
default=True,
help="Whether to use magnitude of fbank, false to use power energy.",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_text=True,
return_tokens=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
sampling_rate = self.args.sampling_rate
config = FbankConfig(
sampling_rate=sampling_rate,
frame_length=self.args.frame_length / sampling_rate, # (in second),
frame_shift=self.args.frame_shift / sampling_rate, # (in second)
use_fft_mag=self.args.use_fft_mag,
)
train = SpeechSynthesisDataset(
return_text=True,
return_tokens=False,
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
sampling_rate = self.args.sampling_rate
config = FbankConfig(
sampling_rate=sampling_rate,
frame_length=self.args.frame_length / sampling_rate, # (in second),
frame_shift=self.args.frame_shift / sampling_rate, # (in second)
use_fft_mag=self.args.use_fft_mag,
)
validate = SpeechSynthesisDataset(
return_text=True,
return_tokens=False,
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
return_cuts=self.args.return_cuts,
)
else:
validate = SpeechSynthesisDataset(
return_text=True,
return_tokens=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create valid dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
if self.args.on_the_fly_feats:
sampling_rate = self.args.sampling_rate
config = FbankConfig(
sampling_rate=sampling_rate,
frame_length=self.args.frame_length / sampling_rate, # (in second),
frame_shift=self.args.frame_shift / sampling_rate, # (in second)
use_fft_mag=self.args.use_fft_mag,
)
test = SpeechSynthesisDataset(
return_text=True,
return_tokens=False,
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
return_cuts=self.args.return_cuts,
)
else:
test = SpeechSynthesisDataset(
return_text=True,
return_tokens=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
test_sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=test_sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
)
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get validation cuts")
return load_manifest_lazy(
self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
)
@lru_cache()
def train_cuts_finetune(self) -> CutSet:
logging.info("About to get train cuts finetune")
return load_manifest_lazy(
self.args.manifest_dir / "ljspeech_cuts_train_finetune.jsonl.gz"
)
@lru_cache()
def valid_cuts_finetune(self) -> CutSet:
logging.info("About to get validation cuts finetune")
return load_manifest_lazy(
self.args.manifest_dir / "ljspeech_cuts_valid_finetune.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
)


@ -0,0 +1 @@
../../../libritts/TTS/vocos/utils.py


@ -251,18 +251,22 @@ def save_checkpoint_with_global_batch_idx(
)
def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
def find_checkpoints(
out_dir: Path,
iteration: int = 0,
prefix: str = "checkpoint",
) -> List[str]:
"""Find all available checkpoints in a directory.
The checkpoint filenames have the form: `checkpoint-xxx.pt`
The checkpoint filenames have the form: `{prefix}-xxx.pt`
where xxx is a numerical value.
Assume you have the following checkpoints in the folder `foo`:
- checkpoint-1.pt
- checkpoint-20.pt
- checkpoint-300.pt
- checkpoint-4000.pt
- {prefix}-1.pt
- {prefix}-20.pt
- {prefix}-300.pt
- {prefix}-4000.pt
Case 1 (Return all checkpoints)::
@ -291,8 +295,8 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
Return a list of checkpoint filenames, sorted in descending
order by the numerical value in the filename.
"""
checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
pattern = re.compile(r"checkpoint-([0-9]+).pt")
checkpoints = list(glob.glob(f"{out_dir}/{prefix}-[0-9]*.pt"))
pattern = re.compile(rf"{prefix}-([0-9]+).pt")
iter_checkpoints = []
for c in checkpoints:
result = pattern.search(c)
@ -317,12 +321,13 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
def remove_checkpoints(
out_dir: Path,
topk: int,
prefix: str = "checkpoint",
rank: int = 0,
):
"""Remove checkpoints from the given directory.
We assume that checkpoint filename has the form `checkpoint-xxx.pt`
where xxx is a number, representing the number of processed batches
We assume that checkpoint filename has the form `{prefix}-xxx.pt`
where xxx is a number, representing the number of processed batches/epochs
when saving that checkpoint. We sort checkpoints by filename and keep
only the `topk` checkpoints with the highest `xxx`.
@ -331,6 +336,8 @@ def remove_checkpoints(
The directory containing checkpoints to be removed.
topk:
Number of checkpoints to keep.
prefix:
The prefix of the checkpoint filename, normally `epoch`, `checkpoint`.
rank:
If using DDP for training, it is the rank of the current node.
Use 0 if no DDP is used for training.
@ -338,7 +345,7 @@ def remove_checkpoints(
assert topk >= 1, topk
if rank != 0:
return
checkpoints = find_checkpoints(out_dir)
checkpoints = find_checkpoints(out_dir, prefix=prefix)
if len(checkpoints) == 0:
logging.warn(f"No checkpoints found in {out_dir}")
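# A small usage sketch for the new `prefix` argument (the directory name is a
# placeholder): list epoch checkpoints and keep only the three newest ones.
#
#   find_checkpoints(Path("vocos/exp"), prefix="epoch")
#   remove_checkpoints(Path("vocos/exp"), topk=3, prefix="epoch")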