Refactor, add libritts

This commit is contained in:
pkufool 2024-12-13 19:39:55 +08:00
parent 6e07cb91e3
commit d8a0a40955
19 changed files with 2476 additions and 266 deletions

View File

@ -3,7 +3,7 @@ from typing import List, Optional, Tuple
import torch
from torch import nn
from torch.nn import Conv2d
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm
from torchaudio.transforms import Spectrogram

View File

@ -0,0 +1,371 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script exports the Vocos vocoder generator from PyTorch to ONNX.

Usage:

./vocos/export-onnx.py \
  --exp-dir vocos/exp \
  --epoch 30 \
  --avg 1 \
  --fp16 True

It will generate a file vocos-epoch-30-avg-1.onnx inside --exp-dir, plus
quantized variants (vocos-epoch-30-avg-1.fp16.onnx and
vocos-epoch-30-avg-1.int8.onnx, depending on the options).

See ./onnx_pretrained.py for how to use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple
import onnx
import torch
import torch.nn as nn
from onnxconverter_common import float16
from onnxruntime.quantization import QuantType, quantize_dynamic
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, num_tokens, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--sampling-rate",
type=int,
default=24000,
help="The sampleing rate of libritts dataset",
)
parser.add_argument(
"--frame-shift",
type=int,
default=256,
help="Frame shift.",
)
parser.add_argument(
"--frame-length",
type=int,
default=1024,
help="Frame shift.",
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
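As a hedged illustration (not part of this commit), with --use-averaged-model only the two endpoint checkpoints of the averaging range are actually read:

# Assumed values for illustration only.
epoch, avg = 30, 9
exp_dir = "vocos/exp"
filename_start = f"{exp_dir}/epoch-{epoch - avg}.pt"  # excluded endpoint of the range (epoch-avg, epoch]
filename_end = f"{exp_dir}/epoch-{epoch}.pt"
print(filename_start, filename_end)  # vocos/exp/epoch-21.pt vocos/exp/epoch-30.pt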
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--fp16",
type=str2bool,
default=False,
help="Whether to export models in fp16",
)
add_model_arguments(parser)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
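A minimal round-trip check (not part of this commit; the file name is assumed) showing how the metadata written by add_meta_data() can be read back:

import onnx

m = onnx.load("vocos-epoch-30-avg-1.onnx")  # assumed file name
print({p.key: p.value for p in m.metadata_props})
# e.g. {"model_type": "Vocos", "version": "1", "model_author": "k2-fsa", "comment": "ConvNext Vocos"}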
def export_model_onnx(
model: nn.Module,
model_filename: str,
opset_version: int = 13,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- encoder_out: a tensor of shape (N, joiner_dim)
- decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
"""
input_tensor = torch.rand((2, 80, 100), dtype=torch.float32)
torch.onnx.export(
model,
(input_tensor,),
model_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"features",
],
output_names=["audio"],
dynamic_axes={
"features": {0: "N", 2: "F"},
"audio": {0: "N", 1: "T"},
},
)
meta_data = {
"model_type": "Vocos",
"version": "1",
"model_author": "k2-fsa",
"comment": "ConvNext Vocos",
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=model_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
params.device = device
logging.info(params)
logging.info("About to create model")
model = get_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.eval()
vocos = model.generator
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
opset_version = 13
logging.info("Exporting model")
model_filename = params.exp_dir / f"vocos-{suffix}.onnx"
export_model_onnx(
vocos,
model_filename,
opset_version=opset_version,
)
logging.info(f"Exported vocos generator to {model_filename}")
if params.fp16:
logging.info("Generate fp16 models")
model = onnx.load(model_filename)
model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
model_filename_fp16 = params.exp_dir / f"vocos-{suffix}.fp16.onnx"
onnx.save(model_fp16, model_filename_fp16)
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
model_filename_int8 = params.exp_dir / f"vocos-{suffix}.int8.onnx"
quantize_dynamic(
model_input=model_filename,
model_output=model_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
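There is no onnx_check.py in this recipe; below is a minimal sketch (not part of this commit) of how the exported fp32 model could be compared against the PyTorch generator, reusing the `vocos` and `model_filename` variables from main() and an assumed tolerance:

import numpy as np
import onnxruntime as ort

features = torch.rand(1, 80, 100, dtype=torch.float32)  # (N, feature_dim, num_frames)
torch_audio = vocos(features)  # eager PyTorch output, shape (N, num_samples)

session = ort.InferenceSession(str(model_filename), providers=["CPUExecutionProvider"])
(onnx_audio,) = session.run(["audio"], {"features": features.numpy()})

assert np.allclose(torch_audio.numpy(), onnx_audio, atol=1e-4)  # tolerance is an assumption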

407
egs/libritts/TTS/vocos/export.py Executable file
View File

@ -0,0 +1,407 @@
#!/usr/bin/env python3
#
# Copyright 2024 Xiaomi Corporation (Author: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
Note: This is an example for the libritts dataset; if you are using a different
dataset, you should change the argument values according to your dataset.
(1) Export to torchscript model using torch.jit.script()
./vocos/export.py \
--exp-dir ./vocos/exp \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("jit_script.pt")`.
Check ./jit_pretrained.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`

./vocos/export.py \
  --exp-dir ./vocos/exp \
  --epoch 30 \
  --avg 9

It will generate a file `generator.pt` in the given `exp_dir`. You can later
load it with torch.load() and model.load_state_dict().

Check ./pretrained.py for its usage.
"""
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import torch
from torch import Tensor, nn
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
)
from icefall.utils import str2bool
from utils import load_checkpoint
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--sampling-rate",
type=int,
default=24000,
help="The sampleing rate of libritts dataset",
)
parser.add_argument(
"--frame-shift",
type=int,
default=256,
help="Frame shift.",
)
parser.add_argument(
"--frame-length",
type=int,
default=1024,
help="Frame shift.",
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="vocos/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named jit_script.pt.
Check ./jit_pretrained.py for how to use it.
""",
)
add_model_arguments(parser)
return parser
class EncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
def forward(
self, features: Tensor, feature_lengths: Tensor
) -> Tuple[Tensor, Tensor]:
"""
Args:
features: (N, T, C)
feature_lengths: (N,)
"""
x, x_lens = self.encoder_embed(features, feature_lengths)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return encoder_out, encoder_out_lens
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
params.device = device
logging.info(f"device: {device}")
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.eval()
model = model.generator
if params.jit is True:
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
filename = "jit_script.pt"
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
model.save(str(params.exp_dir / filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "generator.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
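A small hedged check (not part of this commit; the path is assumed) of what the state_dict export produces; ./pretrained.py consumes the same structure:

import torch

state = torch.load("vocos/exp/generator.pt", map_location="cpu")
print(list(state.keys()))   # ["model"]
print(len(state["model"]))  # number of generator parameter/buffer tensors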

View File

@ -1,122 +1,154 @@
-import torch
-from torch import nn
-from typing import Optional
-class AdaLayerNorm(nn.Module):
-"""
-Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
-Args:
-num_embeddings (int): Number of embeddings.
-embedding_dim (int): Dimension of the embeddings.
-"""
-def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
-super().__init__()
-self.eps = eps
-self.dim = embedding_dim
-self.scale = nn.Embedding(
-num_embeddings=num_embeddings, embedding_dim=embedding_dim
-)
-self.shift = nn.Embedding(
-num_embeddings=num_embeddings, embedding_dim=embedding_dim
-)
-torch.nn.init.ones_(self.scale.weight)
-torch.nn.init.zeros_(self.shift.weight)
-def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
-scale = self.scale(cond_embedding_id)
-shift = self.shift(cond_embedding_id)
-x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
-x = x * scale + shift
-return x
-class ISTFT(nn.Module):
-"""
-Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
-windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
-See issue: https://github.com/pytorch/pytorch/issues/62323
-Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
-Args:
-n_fft (int): Size of Fourier transform.
-hop_length (int): The distance between neighboring sliding window frames.
-win_length (int): The size of window frame and STFT filter.
-padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
-"""
-def __init__(
-self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
-):
-super().__init__()
-if padding not in ["center", "same"]:
-raise ValueError("Padding must be 'center' or 'same'.")
-self.padding = padding
-self.n_fft = n_fft
-self.hop_length = hop_length
-self.win_length = win_length
-window = torch.hann_window(win_length)
-self.register_buffer("window", window)
-def forward(self, spec: torch.Tensor) -> torch.Tensor:
-"""
-Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
-Args:
-spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
-N is the number of frequency bins, and T is the number of time frames.
-Returns:
-Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
-"""
-if self.padding == "center":
-# Fallback to pytorch native implementation
-return torch.istft(
-spec,
-self.n_fft,
-self.hop_length,
-self.win_length,
-self.window,
-center=True,
-)
-elif self.padding == "same":
-pad = (self.win_length - self.hop_length) // 2
-else:
-raise ValueError("Padding must be 'center' or 'same'.")
-assert spec.dim() == 3, "Expected a 3D tensor as input"
-B, N, T = spec.shape
-# Inverse FFT
-ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
-ifft = ifft * self.window[None, :, None]
-# Overlap and Add
-output_size = (T - 1) * self.hop_length + self.win_length
-y = torch.nn.functional.fold(
-ifft,
-output_size=(1, output_size),
-kernel_size=(1, self.win_length),
-stride=(1, self.hop_length),
-)[:, 0, 0, :]
-# Window envelope
-window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
-window_envelope = torch.nn.functional.fold(
-window_sq,
-output_size=(1, output_size),
-kernel_size=(1, self.win_length),
-stride=(1, self.hop_length),
-).squeeze()
-# Normalize
-norm_indexes = window_envelope > 1e-11
-y[:, norm_indexes] = y[:, norm_indexes] / window_envelope[norm_indexes]
-return y
+import logging
+from typing import Optional
+import numpy as np
+import torch
+from torch import nn
+from torch.autograd import Variable
+from torch.nn import functional as F
+def window_sumsquare(
+window: torch.Tensor,
+n_samples: int,
+hop_length: int = 256,
+win_length: int = 1024,
+):
+"""
+Compute the sum-square envelope of a window function at a given hop length.
+This is used to estimate modulation effects induced by windowing
+observations in short-time fourier transforms.
+Parameters
+----------
+window : string, tuple, number, callable, or list-like
+Window specification, as in `get_window`
+n_samples : int > 0
+The number of expected samples.
+hop_length : int > 0
+The number of samples to advance between frames
+win_length :
+The length of the window function.
+Returns
+-------
+wss : torch.Tensor, The sum-squared envelope of the window function.
+"""
+n_frames = (n_samples - win_length) // hop_length + 1
+output_size = (n_frames - 1) * hop_length + win_length
+device = window.device
+# Window envelope
+window_sq = window.square().expand(1, n_frames, -1).transpose(1, 2)
+window_envelope = torch.nn.functional.fold(
+window_sq,
+output_size=(1, output_size),
+kernel_size=(1, win_length),
+stride=(1, hop_length),
+).squeeze()
+window_envelope = torch.nn.functional.pad(
+window_envelope, (0, n_samples - output_size)
+)
+return window_envelope
+class ISTFT(torch.nn.Module):
+"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
+def __init__(
+self,
+filter_length: int = 1024,
+hop_length: int = 256,
+win_length: int = 1024,
+padding: str = "none",
+window_type: str = "povey",
+max_samples: int = 1440000,  # 1440000 / 24000 = 60s
+):
+super(ISTFT, self).__init__()
+self.filter_length = filter_length
+self.hop_length = hop_length
+self.win_length = win_length
+self.padding = padding
+scale = self.filter_length / self.hop_length
+fourier_basis = np.fft.fft(np.eye(self.filter_length))
+cutoff = int((self.filter_length / 2 + 1))
+fourier_basis = np.vstack(
+[np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
+)
+inverse_basis = torch.FloatTensor(
+np.linalg.pinv(scale * fourier_basis).T[:, None, :]
+)
+assert filter_length >= win_length
+# Consistence with lhotse, search "create_frame_window" in https://github.com/lhotse-speech/lhotse
+assert window_type in [
+"hanning",
+"povey",
+], f"Only 'hanning' and 'povey' windows are supported, given {window_type}."
+fft_window = torch.hann_window(win_length, periodic=False)
+if window_type == "povey":
+fft_window = fft_window.pow(0.85)
+if filter_length > win_length:
+pad_size = (filter_length - win_length) // 2
+fft_window = torch.nn.functional.pad(fft_window, (pad_size, pad_size))
+window_sum = window_sumsquare(
+window=fft_window,
+n_samples=max_samples,
+hop_length=hop_length,
+win_length=filter_length,
+)
+inverse_basis *= fft_window
+self.register_buffer("inverse_basis", inverse_basis.float())
+self.register_buffer("fft_window", fft_window)
+self.register_buffer("window_sum", window_sum)
+self.tiny = torch.finfo(torch.float16).tiny
+def forward(self, magnitude, phase):
+magnitude_phase = torch.cat(
+[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
+)
+inverse_transform = F.conv_transpose1d(
+magnitude_phase,
+Variable(self.inverse_basis, requires_grad=False),
+stride=self.hop_length,
+padding=0,
+)
+inverse_transform = inverse_transform.squeeze(1)
+window_sum = self.window_sum
+if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+if self.window_sum.size(-1) < inverse_transform.size(-1):
+logging.warning(
+f"The precomputed `window_sumsquare` is too small, recomputing, "
+f"from {self.window_sum.size(-1)} to {inverse_transform.size(-1)}"
+)
+window_sum = window_sumsquare(
+window=self.fft_window,
+n_samples=inverse_transform.size(-1),
+win_length=self.filter_length,
+hop_length=self.hop_length,
+)
+window_sum = window_sum[: inverse_transform.size(-1)]
+approx_nonzero_indices = (window_sum > self.tiny).nonzero().squeeze()
+inverse_transform[:, approx_nonzero_indices] /= window_sum[
+approx_nonzero_indices
+]
+# scale by hop ratio
+inverse_transform *= float(self.filter_length) / self.hop_length
+assert self.padding in ["none", "same", "center"]
+if self.padding == "center":
+pad_len = self.filter_length // 2
+elif self.padding == "same":
+pad_len = (self.filter_length - self.hop_length) // 2
+else:
+return inverse_transform
+return inverse_transform[:, pad_len:-pad_len]
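A hedged sanity check (not part of this commit) of window_sumsquare() above: it should equal a naive overlap-add of the squared analysis window; the module name in the import is an assumption.

import torch

from generator import window_sumsquare  # assumed module name for the file above


def naive_window_sumsquare(window, n_samples, hop_length, win_length):
    # Accumulate the squared window at every hop position.
    wss = torch.zeros(n_samples)
    n_frames = (n_samples - win_length) // hop_length + 1
    for t in range(n_frames):
        start = t * hop_length
        wss[start : start + win_length] += window.square()
    return wss


win = torch.hann_window(1024, periodic=False).pow(0.85)  # the "povey" window used above
fast = window_sumsquare(window=win, n_samples=24000, hop_length=256, win_length=1024)
slow = naive_window_sumsquare(win, 24000, 256, 1024)
assert torch.allclose(fast, slow, atol=1e-5)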
class ConvNeXtBlock(nn.Module):
@ -127,8 +159,6 @@ class ConvNeXtBlock(nn.Module):
intermediate_dim (int): Dimensionality of the intermediate layer.
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
Defaults to None.
-adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
-None means non-conditional LayerNorm. Defaults to None.
"""
def __init__(
@ -136,20 +166,14 @@ class ConvNeXtBlock(nn.Module):
dim: int,
intermediate_dim: int,
layer_scale_init_value: Optional[float] = None,
-adanorm_num_embeddings: Optional[int] = None,
):
super().__init__()
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=7, padding=3, groups=dim
)  # depthwise conv
-self.adanorm = adanorm_num_embeddings is not None
-if adanorm_num_embeddings:
-self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
-else:
self.norm = nn.LayerNorm(dim, eps=1e-6)
-self.pwconv1 = nn.Linear(
-dim, intermediate_dim
-)  # pointwise/1x1 convs, implemented with linear layers
+# pointwise/1x1 convs, implemented with linear layers
+self.pwconv1 = nn.Linear(dim, intermediate_dim)
self.act = nn.GELU()
self.pwconv2 = nn.Linear(intermediate_dim, dim)
self.gamma = (
@ -159,15 +183,12 @@ class ConvNeXtBlock(nn.Module):
)
def forward(
-self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
+self,
+x: torch.Tensor,
) -> torch.Tensor:
residual = x
x = self.dwconv(x)
x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
-if self.adanorm:
-assert cond_embedding_id is not None
-x = self.norm(x, cond_embedding_id)
-else:
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
@ -189,28 +210,22 @@ class Generator(torch.nn.Module):
hop_length: int = 256,
intermediate_dim: int = 1536,
num_layers: int = 8,
-padding: str = "same",
-layer_scale_init_value: Optional[float] = None,
-adanorm_num_embeddings: Optional[int] = None,
+padding: str = "none",
+max_samples: int = 1440000,  # 1440000 / 24000 = 60s
):
super(Generator, self).__init__()
self.feature_dim = feature_dim
self.embed = nn.Conv1d(feature_dim, dim, kernel_size=7, padding=3)
-self.adanorm = adanorm_num_embeddings is not None
-if adanorm_num_embeddings:
-self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
-else:
self.norm = nn.LayerNorm(dim, eps=1e-6)
-layer_scale_init_value = layer_scale_init_value or 1 / num_layers
+layer_scale_init_value = 1 / num_layers
self.convnext = nn.ModuleList(
[
ConvNeXtBlock(
dim=dim,
intermediate_dim=intermediate_dim,
layer_scale_init_value=layer_scale_init_value,
-adanorm_num_embeddings=adanorm_num_embeddings,
)
for _ in range(num_layers)
]
@ -221,7 +236,11 @@ class Generator(torch.nn.Module):
self.out_proj = torch.nn.Linear(dim, n_fft + 2)
self.istft = ISTFT(
-n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
+filter_length=n_fft,
+hop_length=hop_length,
+win_length=n_fft,
+padding=padding,
+max_samples=max_samples,
)
def _init_weights(self, m): def _init_weights(self, m):
@ -229,29 +248,17 @@ class Generator(torch.nn.Module):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
-def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
-bandwidth_id = kwargs.get("bandwidth_id", None)
+def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.embed(x)
-if self.adanorm:
-assert bandwidth_id is not None
-x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
-else:
x = self.norm(x.transpose(1, 2))
x = x.transpose(1, 2)
for conv_block in self.convnext:
-x = conv_block(x, cond_embedding_id=bandwidth_id)
+x = conv_block(x)
x = self.final_layer_norm(x.transpose(1, 2))
x = self.out_proj(x).transpose(1, 2)
-mag, p = x.chunk(2, dim=1)
+mag, phase = x.chunk(2, dim=1)
mag = torch.exp(mag)
-mag = torch.clip(
-mag, max=1e2
-)  # safeguard to prevent excessively large magnitudes
-x = torch.cos(p)
-y = torch.sin(p)
-S = mag * (x + 1j * y)
-audio = self.istft(S)
+# safeguard to prevent excessively large magnitudes
+mag = torch.clip(mag, max=1e2)
+audio = self.istft(mag, phase)
return audio
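A hedged illustration (not part of this commit) of why passing (mag, phase) to the new ISTFT carries the same information as the complex spectrogram the old code built: the concatenated [mag*cos(phase), mag*sin(phase)] consumed by ISTFT.forward() is exactly the real and imaginary parts of S = mag * (cos(phase) + 1j*sin(phase)).

import torch

mag = torch.rand(2, 513, 100)
phase = torch.rand(2, 513, 100) * 6.28
S = mag * (torch.cos(phase) + 1j * torch.sin(phase))  # old complex spectrogram
stacked = torch.cat([mag * torch.cos(phase), mag * torch.sin(phase)], dim=1)
assert torch.allclose(stacked[:, :513], S.real)
assert torch.allclose(stacked[:, 513:], S.imag)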

40
egs/libritts/TTS/vocos/infer.py Normal file → Executable file
View File

@ -20,6 +20,7 @@ import argparse
import json
import logging
import math
+import time
import os
from functools import partial
from pathlib import Path
@ -29,7 +30,7 @@ import torch.nn as nn
from lhotse.utils import fix_random_seed
from scipy.io.wavfile import write
from train import add_model_arguments, get_model, get_params
-from tts_datamodule import LJSpeechTtsDataModule
+from tts_datamodule import LibriTTSDataModule
from icefall.checkpoint import (
average_checkpoints,
@ -89,7 +90,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
-default="flow_match/exp",
+default="vocos/exp",
help="The experiment dir",
)
@ -128,22 +129,31 @@ def decode_one_batch(
cut_ids = [cut.id for cut in batch["cut"]]
+infer_time = 0
+audio_time = 0
features = batch["features"]  # (B, T, F)
utt_durations = batch["features_lens"]
x = features.permute(0, 2, 1)  # (B, F, T)
+audio_time += torch.sum(utt_durations)
+start = time.time()
audios = model(x.to(device))  # (B, T)
+infer_time += time.time() - start
wav_dir = f"{params.res_dir}/{params.suffix}"
os.makedirs(wav_dir, exist_ok=True)
for i in range(audios.shape[0]):
-audio = audios[i][
-: int(utt_durations[i] * params.frame_shift_ms / 1000 * 22050)
-]
+audio = audios[i][: int(utt_durations[i] * 256)]
audio = audio.cpu().squeeze().numpy()
-write(f"{wav_dir}/{cut_ids[i]}.wav", 22050, audio)
+write(f"{wav_dir}/{cut_ids[i]}.wav", 24000, audio)
+print(f"RTF : {infer_time / (audio_time * (256/24000))}")
def decode_dataset(
@ -173,7 +183,7 @@ def decode_dataset(
with open(f"{params.res_dir}/{test_set}.scp", "w", encoding="utf8") as f:
for batch_idx, batch in enumerate(dl):
-texts = batch["text"]
+# texts = batch["text"]
cut_ids = [cut.id for cut in batch["cut"]]
decode_one_batch(
@ -182,12 +192,12 @@ def decode_dataset(
batch=batch,
)
-assert len(texts) == len(cut_ids), (len(texts), len(cut_ids))
-for i in range(len(texts)):
-f.write(f"{cut_ids[i]}\t{texts[i]}\n")
-num_cuts += len(texts)
+# assert len(texts) == len(cut_ids), (len(texts), len(cut_ids))
+# for i in range(len(texts)):
+#     f.write(f"{cut_ids[i]}\t{texts[i]}\n")
+# num_cuts += len(texts)
if batch_idx % 50 == 0:
batch_str = f"{batch_idx}/{num_batches}"
@ -200,7 +210,7 @@ def decode_dataset(
@torch.no_grad()
def main():
parser = get_parser()
-LJSpeechTtsDataModule.add_arguments(parser)
+LibriTTSDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@ -318,11 +328,11 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
-ljspeech = LJSpeechTtsDataModule(args)
-test_cuts = ljspeech.test_cuts()
-test_dl = ljspeech.test_dataloaders(test_cuts)
+libritts = LibriTTSDataModule(args)
+test_cuts = libritts.test_clean_cuts()
+test_dl = libritts.test_dataloaders(test_cuts)
test_sets = ["test"]
test_dls = [test_dl]
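A hedged helper (not part of this commit) making the RTF computation in decode_one_batch() explicit: audio_time accumulates feature-frame counts, so the audio duration in seconds is num_frames * hop_length / sampling_rate.

def real_time_factor(
    infer_seconds: float,
    num_frames: int,
    hop_length: int = 256,
    sampling_rate: int = 24000,
) -> float:
    audio_seconds = num_frames * hop_length / sampling_rate
    return infer_seconds / audio_seconds


# e.g. 0.5 s of compute for 937 frames (~10 s of audio) gives RTF ~= 0.05
print(real_time_factor(0.5, 937))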

View File

@ -19,8 +19,9 @@ class Vocos(torch.nn.Module):
hop_length: int = 256,
intermediate_dim: int = 1536,
num_layers: int = 8,
-padding: str = "same",
+padding: str = "none",
sample_rate: int = 24000,
+max_seconds: int = 60,
):
super(Vocos, self).__init__()
self.generator = Generator(
@ -31,6 +32,7 @@ class Vocos(torch.nn.Module):
num_layers=num_layers,
intermediate_dim=intermediate_dim,
padding=padding,
+max_samples=int(sample_rate * max_seconds),
)
self.mpd = MultiPeriodDiscriminator()

View File

@ -0,0 +1,268 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads the exported vocos ONNX model and uses it to generate waves.

You can use ./export-onnx.py to export the model, e.g.

./vocos/export-onnx.py \
  --exp-dir vocos/exp \
  --epoch 30 \
  --avg 1

It will generate vocos-epoch-30-avg-1.onnx inside vocos/exp.

Then you can run this file with the exported model, e.g.

./vocos/onnx_pretrained.py \
  --model-filename vocos/exp/vocos-epoch-30-avg-1.onnx \
  /path/to/foo.wav \
  /path/to/bar.wav

The generated audios will be written to the directory given by --output-dir
(relative to the model directory).
"""
import argparse
import logging
import math
from pathlib import Path
from typing import List, Tuple
import onnxruntime as ort
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from lhotse import Fbank, FbankConfig
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--sampling-rate",
type=int,
default=24000,
help="The sampleing rate of libritts dataset",
)
parser.add_argument(
"--frame-shift",
type=int,
default=256,
help="Frame shift.",
)
parser.add_argument(
"--frame-length",
type=int,
default=1024,
help="Frame shift.",
)
parser.add_argument(
"--use-fft-mag",
type=str2bool,
default=True,
help="Whether to use magnitude of fbank, false to use power energy.",
)
parser.add_argument(
"--output-dir",
type=str,
default="generated_audios",
help="The generated will be written to.",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
class OnnxModel:
def __init__(
self,
model_filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 4
self.session_opts = session_opts
self.init_model(model_filename)
def init_model(self, model_filename: str):
self.model = ort.InferenceSession(
model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
def run_model(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C)
x_lens:
A 2-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a tuple containing:
- encoder_out, its shape is (N, T', joiner_dim)
- encoder_out_lens, its shape is (N,)
"""
out = self.model.run(
[
self.model.get_outputs()[0].name,
],
{
self.model.get_inputs()[0].name: x.numpy(),
},
)
return torch.from_numpy(out[0])
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
output_dir = Path(args.model_filename).parent / args.output_dir
output_dir.mkdir(exist_ok=True)
args.output_dir = output_dir
logging.info(vars(args))
model = OnnxModel(model_filename=args.model_filename)
config = FbankConfig(
sampling_rate=args.sampling_rate,
frame_length=args.frame_length / args.sampling_rate, # (in second),
frame_shift=args.frame_shift / args.sampling_rate, # (in second)
use_fft_mag=args.use_fft_mag,
)
fbank = Fbank(config)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files, expected_sample_rate=args.sampling_rate
)
wave_lengths = [w.size(0) for w in waves]
waves = pad_sequence(waves, batch_first=True, padding_value=0)
logging.info(f"waves : {waves.shape}")
features = fbank.extract_batch(waves, sampling_rate=args.sampling_rate)
if features.dim() == 2:
features = features.unsqueeze(0)
features = features.permute(0, 2, 1)
logging.info(f"features : {features.shape}")
logging.info("Generating started")
# model forward
audios = model.run_model(features)
for i, filename in enumerate(args.sound_files):
audio = audios[i : i + 1, 0 : wave_lengths[i]]
ofilename = args.output_dir / filename.split("/")[-1]
logging.info(f"Writting audio : {ofilename}")
torchaudio.save(str(ofilename), audio.cpu(), args.sampling_rate)
logging.info("Generating Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
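A hedged sketch (not part of this commit) of the FbankConfig used above: lhotse expects frame_length and frame_shift in seconds, hence the division by the sampling rate (1024/24000 ~ 42.7 ms windows advanced by 256/24000 ~ 10.7 ms).

import torch
from lhotse import Fbank, FbankConfig

config = FbankConfig(
    sampling_rate=24000,
    frame_length=1024 / 24000,  # seconds
    frame_shift=256 / 24000,    # seconds
    use_fft_mag=True,
)
fbank = Fbank(config)
feats = fbank.extract_batch(torch.randn(1, 24000), sampling_rate=24000)
print(feats.shape)  # roughly (1, 94, 80): (batch, frames, mel bins)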

View File

@ -0,0 +1,196 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint (e.g. the `generator.pt` exported by ./vocos/export.py)
and uses it to re-synthesize waves, for example:

./vocos/pretrained.py \
  --checkpoint vocos/exp/generator.pt \
  /path/to/foo.wav \
  /path/to/bar.wav
"""
import argparse
import logging
import math
from pathlib import Path
from typing import List
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
from lhotse import Fbank, FbankConfig
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--sampling-rate",
type=int,
default=24000,
help="The sampleing rate of libritts dataset",
)
parser.add_argument(
"--frame-shift",
type=int,
default=256,
help="Frame shift.",
)
parser.add_argument(
"--frame-length",
type=int,
default=1024,
help="Frame shift.",
)
parser.add_argument(
"--use-fft-mag",
type=str2bool,
default=True,
help="Whether to use magnitude of fbank, false to use power energy.",
)
parser.add_argument(
"--output-dir",
type=str,
default="generated_audios",
help="The generated will be written to.",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
params.device = device
output_dir = Path(params.checkpoint).parent / params.output_dir
output_dir.mkdir(exist_ok=True)
params.output_dir = output_dir
logging.info(f"{params}")
logging.info("Creating model")
model = get_model(params)
model = model.generator
checkpoint = torch.load(params.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
logging.info("Constructing Fbank computer")
config = FbankConfig(
sampling_rate=params.sampling_rate,
frame_length=params.frame_length / params.sampling_rate, # (in second),
frame_shift=params.frame_shift / params.sampling_rate, # (in second)
use_fft_mag=params.use_fft_mag,
)
fbank = Fbank(config)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sampling_rate
)
wave_lengths = [w.size(0) for w in waves]
waves = pad_sequence(waves, batch_first=True, padding_value=0)
features = (
fbank.extract_batch(waves, sampling_rate=params.sampling_rate)
.permute(0, 2, 1)
.to(device)
)
logging.info("Generating started")
# model forward
audios = model(features)
for i, filename in enumerate(params.sound_files):
audio = audios[i : i + 1, 0 : wave_lengths[i]]
ofilename = params.output_dir / filename.split("/")[-1]
logging.info(f"Writting audio : {ofilename}")
torchaudio.save(str(ofilename), audio.cpu(), params.sampling_rate)
logging.info("Generating Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -52,9 +52,11 @@ from utils import (
save_checkpoint,
plot_spectrogram,
get_cosine_schedule_with_warmup,
+save_checkpoint_with_global_batch_idx,
)
from icefall import diagnostics
+from icefall.checkpoint import remove_checkpoints, update_averaged_model
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
@ -91,6 +93,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Intermediate dim of ConvNeXt module.",
)
parser.add_argument(
"--max-seconds",
type=int,
default=60,
help="""
The length (in seconds) of the precomputed normalization window sum square
(required by istft). This argument mainly matters for onnx export; it determines
the maximum audio length that can be properly normalized.
Note: you can still generate audios longer than this value with the exported onnx model,
but the part beyond this value will not be normalized.
The larger this value is the bigger the exported onnx model will be.
""",
)
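A hedged illustration (not part of this commit) of what --max-seconds controls: it only sizes the precomputed window_sum buffer registered inside ISTFT, so it mainly affects the size of the exported onnx model.

sample_rate, max_seconds = 24000, 60
max_samples = sample_rate * max_seconds       # 1_440_000, matching the default in vocos.py
buffer_mib = max_samples * 4 / (1024 * 1024)  # float32 bytes
print(max_samples, f"{buffer_mib:.1f} MiB")   # 1440000 ~5.5 MiB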
def get_parser():
parser = argparse.ArgumentParser(
@ -203,6 +219,16 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--keep-last-epoch-k",
type=int,
default=50,
help="""Only keep this number of checkpoints on disk.
For instance, if it is 3, there are only 3 checkpoints
in the exp-dir with filenames `epoch-xxx.pt`.
""",
)
parser.add_argument(
"--average-period",
type=int,
@ -290,8 +316,8 @@ def get_params() -> AttributeDict:
"valid_interval": 500, "valid_interval": 500,
"feature_dim": 80, "feature_dim": 80,
"segment_size": 16384, "segment_size": 16384,
"adam_b1": 0.8, "adam_b1": 0.9,
"adam_b2": 0.9, "adam_b2": 0.99,
"warmup_steps": 0, "warmup_steps": 0,
"max_steps": 2000000, "max_steps": 2000000,
"env_info": get_env_info(), "env_info": get_env_info(),
@ -311,6 +337,7 @@ def get_model(params: AttributeDict) -> nn.Module:
intermediate_dim=params.intermediate_dim,
num_layers=params.num_layers,
sample_rate=params.sampling_rate,
+max_seconds=params.max_seconds,
).to(device)
num_param_gen = sum([p.numel() for p in model.generator.parameters()])
@ -479,11 +506,6 @@ def compute_discriminator_loss(
info["loss_disc_mrd"] = loss_mrd.detach().cpu().item() info["loss_disc_mrd"] = loss_mrd.detach().cpu().item()
info["loss_disc_mpd"] = loss_mpd.detach().cpu().item() info["loss_disc_mpd"] = loss_mpd.detach().cpu().item()
for i in range(len(loss_mpd_real)):
info[f"loss_disc_mpd_period_{i+1}"] = loss_mpd_real[i] + loss_mpd_gen[i]
for i in range(len(loss_mrd_real)):
info[f"loss_disc_mrd_resolution_{i+1}"] = loss_mrd_real[i] + loss_mrd_gen[i]
return loss_disc_all, info return loss_disc_all, info
@ -497,6 +519,7 @@ def train_one_epoch(
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
+model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
@ -542,6 +565,7 @@ def train_one_epoch(
save_checkpoint(
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
model=model,
+model_avg=model_avg,
params=params,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
@ -588,6 +612,7 @@ def train_one_epoch(
loss_disc.backward()
optimizer_d.step()
+scheduler_d.step()
optimizer_g.zero_grad()
loss_gen, loss_gen_info = compute_generator_loss(
@ -599,6 +624,7 @@ def train_one_epoch(
loss_gen.backward()
optimizer_g.step()
+scheduler_g.step()
loss_info = loss_gen_info + loss_disc_info
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_gen_info
@ -611,6 +637,39 @@ def train_one_epoch(
if params.print_diagnostics and batch_idx == 5:
return
if (
rank == 0
and params.batch_idx_train > 0
and params.batch_idx_train % params.average_period == 0
):
update_averaged_model(
params=params,
model_cur=model,
model_avg=model_avg,
)
if (
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
model=model,
model_avg=model_avg,
params=params,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
rank=rank,
)
if params.batch_idx_train % 100 == 0 and params.use_fp16:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
@ -641,8 +700,8 @@ def train_one_epoch(
f"Epoch {params.cur_epoch}, batch {batch_idx}, " f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
f"loss[{loss_info}], tot_loss[{tot_loss}], " f"loss[{loss_info}], tot_loss[{tot_loss}], "
f"cur_lr_g: {cur_lr_g:.2e}, " f"cur_lr_g: {cur_lr_g:.4e}, "
f"cur_lr_d: {cur_lr_d:.2e}, " f"cur_lr_d: {cur_lr_d:.4e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
) )
@ -685,8 +744,6 @@ def train_one_epoch(
tb_writer, "train/valid_", params.batch_idx_train tb_writer, "train/valid_", params.batch_idx_train
) )
scheduler_g.step()
scheduler_d.step()
loss_value = tot_loss["loss_gen"] loss_value = tot_loss["loss_gen"]
params.train_loss = loss_value params.train_loss = loss_value
if params.train_loss < params.best_train_loss: if params.train_loss < params.best_train_loss:
@ -766,7 +823,7 @@ def compute_validation_loss(
params.sampling_rate,
)
-logging.info(f"RTF : {infer_time / (audio_time * 10 / 1000)}")
+logging.info(f"Validation RTF : {infer_time / (audio_time * 10 / 1000)}")
if world_size > 1:
tot_loss.reduce(device)
@ -811,15 +868,22 @@ def run(rank, world_size, args):
device = torch.device("cpu") device = torch.device("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda", rank) device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
params.device = device params.device = device
logging.info(params) logging.info(params)
logging.info("About to create model")
logging.info("About to create model")
model = get_model(params) model = get_model(params)
assert params.save_every_n >= params.average_period
model_avg: Optional[nn.Module] = None
if rank == 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)
assert params.start_epoch > 0, params.start_epoch assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(params=params, model=model) checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg
)
model = model.to(device) model = model.to(device)
generator = model.generator generator = model.generator
@ -915,6 +979,7 @@ def run(rank, world_size, args):
train_one_epoch( train_one_epoch(
params=params, params=params,
model=model, model=model,
model_avg=model_avg,
optimizer_g=optimizer_g, optimizer_g=optimizer_g,
optimizer_d=optimizer_d, optimizer_d=optimizer_d,
scheduler_g=scheduler_g, scheduler_g=scheduler_g,
@ -936,6 +1001,7 @@ def run(rank, world_size, args):
filename=filename, filename=filename,
params=params, params=params,
model=model, model=model,
model_avg=model_avg,
optimizer_g=optimizer_g, optimizer_g=optimizer_g,
optimizer_d=optimizer_d, optimizer_d=optimizer_d,
scheduler_g=scheduler_g, scheduler_g=scheduler_g,
@ -945,21 +1011,6 @@ def run(rank, world_size, args):
rank=rank, rank=rank,
) )
if params.batch_idx_train % params.save_every_n == 0:
filename = params.exp_dir / f"checkpoint-{params.batch_idx_train}.pt"
save_checkpoint(
filename=filename,
params=params,
model=model,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
if rank == 0:
if params.best_train_epoch == params.cur_epoch: if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt" best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename) copyfile(src=filename, dst=best_train_filename)
@ -968,6 +1019,13 @@ def run(rank, world_size, args):
best_valid_filename = params.exp_dir / "best-valid-loss.pt" best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename) copyfile(src=filename, dst=best_valid_filename)
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_epoch_k,
prefix="epoch",
rank=rank,
)
logging.info("Done!") logging.info("Done!")
if world_size > 1: if world_size > 1:

View File

@ -34,6 +34,69 @@ def plot_spectrogram(spectrogram):
return fig return fig
def save_checkpoint_with_global_batch_idx(
out_dir: Path,
global_batch_idx: int,
model: Union[nn.Module, DDP],
model_avg: Optional[nn.Module] = None,
params: Optional[Dict[str, Any]] = None,
optimizer_g: Optional[Optimizer] = None,
optimizer_d: Optional[Optimizer] = None,
scheduler_g: Optional[LRScheduler] = None,
scheduler_d: Optional[LRScheduler] = None,
scaler: Optional[GradScaler] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
):
"""Save training info after processing given number of batches.
Args:
out_dir:
The directory to save the checkpoint.
global_batch_idx:
The number of batches processed so far from the very start of the
training. The saved checkpoint will have the following filename:
f'out_dir / checkpoint-{global_batch_idx}.pt'
model:
The neural network model whose `state_dict` will be saved in the
checkpoint.
model_avg:
The stored model averaged from the start of training.
params:
A dict of training configurations to be saved.
      optimizer_g / optimizer_d:
        The optimizers for the generator and the discriminators. Their
        `state_dict` will be saved.
      scheduler_g / scheduler_d:
        The learning rate schedulers for the generator and the discriminators.
        Their `state_dict` will be saved.
      scaler:
        The scaler used for mixed precision training. Its `state_dict` will
be saved.
sampler:
The sampler used in the training dataset.
rank:
The rank ID used in DDP training of the current node. Set it to 0
if DDP is not used.
"""
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
save_checkpoint(
filename=filename,
model=model,
model_avg=model_avg,
params=params,
optimizer_g=optimizer_g,
scheduler_g=scheduler_g,
optimizer_d=optimizer_d,
scheduler_d=scheduler_d,
scaler=scaler,
sampler=sampler,
rank=rank,
)
def load_checkpoint( def load_checkpoint(
filename: Path, filename: Path,
model: nn.Module, model: nn.Module,

View File

@ -0,0 +1,287 @@
"""
Calculate Frechet Speech Distance between two speech directories.
Adapted from: https://github.com/gudgud96/frechet-audio-distance/blob/main/frechet_audio_distance/fad.py
"""
import argparse
import logging
import os
from multiprocessing.dummy import Pool as ThreadPool
import librosa
import numpy as np
import soundfile as sf
import torch
from scipy import linalg
from tqdm import tqdm
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
logging.basicConfig(level=logging.INFO)
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--real-path", type=str, help="path of the real speech directory"
)
parser.add_argument(
"--eval-path", type=str, help="path of the evaluated speech directory"
)
parser.add_argument(
"--model-path",
type=str,
default="model/huggingface/wav2vec2_base",
help="path of the wav2vec 2.0 model directory",
)
parser.add_argument(
"--real-embds-path",
type=str,
default=None,
help="path of the real embedding directory",
)
parser.add_argument(
"--eval-embds-path",
type=str,
default=None,
help="path of the evaluated embedding directory",
)
return parser
class FrechetSpeechDistance:
def __init__(
self,
model_path="resources/wav2vec2_base",
pca_dim=128,
speech_load_worker=8,
):
"""
Initialize FSD
"""
self.sample_rate = 16000
self.channels = 1
self.device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
logging.info("[Frechet Speech Distance] Using device: {}".format(self.device))
self.speech_load_worker = speech_load_worker
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_path)
self.model = Wav2Vec2Model.from_pretrained(model_path)
self.model.to(self.device)
self.model.eval()
self.pca_dim = pca_dim
def load_speech_files(self, dir, dtype="float32"):
def _load_speech_task(fname, sample_rate, channels, dtype="float32"):
if dtype not in ["float64", "float32", "int32", "int16"]:
raise ValueError(f"dtype not supported: {dtype}")
wav_data, sr = sf.read(fname, dtype=dtype)
# For integer type PCM input, convert to [-1.0, +1.0]
if dtype == "int16":
wav_data = wav_data / 32768.0
elif dtype == "int32":
wav_data = wav_data / float(2**31)
# Convert to mono
assert channels in [1, 2], "channels must be 1 or 2"
if len(wav_data.shape) > channels:
wav_data = np.mean(wav_data, axis=1)
if sr != sample_rate:
                wav_data = librosa.resample(
                    wav_data, orig_sr=sr, target_sr=sample_rate
                )
return wav_data
task_results = []
pool = ThreadPool(self.speech_load_worker)
logging.info("[Frechet Speech Distance] Loading speech from {}...".format(dir))
for fname in os.listdir(dir):
res = pool.apply_async(
_load_speech_task,
args=(os.path.join(dir, fname), self.sample_rate, self.channels, dtype),
)
task_results.append(res)
pool.close()
pool.join()
return [k.get() for k in task_results]
def get_embeddings(self, x):
"""
Get embeddings
Params:
-- x : a list of np.ndarray speech samples
-- sr : sampling rate.
"""
embd_lst = []
try:
for speech in tqdm(x):
input_features = self.feature_extractor(
speech, sampling_rate=self.sample_rate, return_tensors="pt"
).input_values.to(self.device)
with torch.no_grad():
embd = self.model(input_features).last_hidden_state.mean(1)
if embd.device != torch.device("cpu"):
embd = embd.cpu()
if torch.is_tensor(embd):
embd = embd.detach().numpy()
embd_lst.append(embd)
except Exception as e:
print(
"[Frechet Speech Distance] get_embeddings throw an exception: {}".format(
str(e)
)
)
return np.concatenate(embd_lst, axis=0)
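    # Note: each `embd` above is the time-averaged last hidden state of
    # wav2vec 2.0, i.e. one vector per utterance, so the concatenated result
    # has shape (num_utterances, hidden_dim); hidden_dim is 768 for a
    # wav2vec2-base checkpoint such as the default --model-path.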
def calculate_embd_statistics(self, embd_lst):
if isinstance(embd_lst, list):
embd_lst = np.array(embd_lst)
mu = np.mean(embd_lst, axis=0)
sigma = np.cov(embd_lst, rowvar=False)
return mu, sigma
def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
"""
Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
Numpy implementation of the Frechet Distance.
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
and X_2 ~ N(mu_2, C_2) is
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
Stable version by Dougal J. Sutherland.
Params:
        -- mu1   : Numpy array containing the sample mean of the embeddings
                   (as returned by `get_embeddings`) for generated samples.
        -- mu2   : The sample mean over embeddings, precalculated on a
                   representative data set.
        -- sigma1: The covariance matrix over embeddings for generated samples.
        -- sigma2: The covariance matrix over embeddings, precalculated on a
                   representative data set.
Returns:
-- : The Frechet Distance.
"""
mu1 = np.atleast_1d(mu1)
mu2 = np.atleast_1d(mu2)
sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2)
assert (
mu1.shape == mu2.shape
), "Training and test mean vectors have different lengths"
assert (
sigma1.shape == sigma2.shape
), "Training and test covariances have different dimensions"
diff = mu1 - mu2
# Product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2).astype(complex), disp=False)
if not np.isfinite(covmean).all():
msg = (
"fid calculation produces singular product; "
"adding %s to diagonal of cov estimates"
) % eps
logging.info(msg)
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm(
(sigma1 + offset).dot(sigma2 + offset).astype(complex)
)
# Numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
m = np.max(np.abs(covmean.imag))
raise ValueError("Imaginary component {}".format(m))
covmean = covmean.real
tr_covmean = np.trace(covmean)
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
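    # A quick sanity check of the formula in the scalar (1-D) case: with
    # mu1 = 0, sigma1 = 4, mu2 = 1, sigma2 = 1,
    #   d^2 = (0 - 1)^2 + 4 + 1 - 2 * sqrt(4 * 1) = 2,
    # which is what this method returns for
    #   mu1=np.array([0.0]), sigma1=np.array([[4.0]]),
    #   mu2=np.array([1.0]), sigma2=np.array([[1.0]]).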
def score(
self,
real_path,
eval_path,
real_embds_path=None,
eval_embds_path=None,
dtype="float32",
):
"""
Computes the Frechet Speech Distance (FSD) between two directories of speech files.
Parameters:
- real_path (str): Path to the directory containing real speech files.
- eval_path (str): Path to the directory containing evaluation speech files.
- real_embds_path (str, optional): Path to save/load real speech embeddings (e.g., /folder/bkg_embs.npy). If None, embeddings won't be saved.
- eval_embds_path (str, optional): Path to save/load evaluation speech embeddings (e.g., /folder/test_embs.npy). If None, embeddings won't be saved.
- dtype (str, optional): Data type for loading speech. Default is "float32".
Returns:
- float: The Frechet Speech Distance (FSD) score between the two directories of speech files.
"""
# Load or compute real embeddings
if real_embds_path is not None and os.path.exists(real_embds_path):
logging.info(
f"[Frechet Speech Distance] Loading embeddings from {real_embds_path}..."
)
embds_real = np.load(real_embds_path)
else:
speech_real = self.load_speech_files(real_path, dtype=dtype)
embds_real = self.get_embeddings(speech_real)
if real_embds_path:
os.makedirs(os.path.dirname(real_embds_path), exist_ok=True)
np.save(real_embds_path, embds_real)
# Load or compute eval embeddings
if eval_embds_path is not None and os.path.exists(eval_embds_path):
logging.info(
f"[Frechet Speech Distance] Loading embeddings from {eval_embds_path}..."
)
embds_eval = np.load(eval_embds_path)
else:
speech_eval = self.load_speech_files(eval_path, dtype=dtype)
embds_eval = self.get_embeddings(speech_eval)
if eval_embds_path:
os.makedirs(os.path.dirname(eval_embds_path), exist_ok=True)
np.save(eval_embds_path, embds_eval)
# Check if embeddings are empty
if len(embds_real) == 0:
logging.info("[Frechet Speech Distance] real set dir is empty, exiting...")
return -10.46
if len(embds_eval) == 0:
logging.info("[Frechet Speech Distance] eval set dir is empty, exiting...")
return -1
# Compute statistics and FSD score
mu_real, sigma_real = self.calculate_embd_statistics(embds_real)
mu_eval, sigma_eval = self.calculate_embd_statistics(embds_eval)
fsd_score = self.calculate_frechet_distance(
mu_real, sigma_real, mu_eval, sigma_eval
)
return fsd_score
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
FSD = FrechetSpeechDistance(model_path=args.model_path)
score = FSD.score(
args.real_path, args.eval_path, args.real_embds_path, args.eval_embds_path
)
logging.info(f"FSD score: {score:.2f}")

View File

@ -0,0 +1,139 @@
"""
Calculate WER with Whisper model
"""
import argparse
import logging
import os
import re
from pathlib import Path
from typing import List, Tuple
import librosa
import soundfile as sf
import torch
from num2words import num2words
from tqdm import tqdm
from transformers import pipeline
from icefall.utils import store_transcripts, write_error_stats
logging.basicConfig(level=logging.INFO)
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--wav-path", type=str, help="path of the speech directory")
parser.add_argument("--decode-path", type=str, help="path of the speech directory")
parser.add_argument(
"--model-path",
type=str,
default="model/huggingface/whisper_medium",
help="path of the huggingface whisper model",
)
parser.add_argument(
"--transcript-path",
type=str,
default="data/transcript/test.tsv",
help="path of the transcript tsv file",
)
parser.add_argument(
"--batch-size", type=int, default=64, help="decoding batch size"
)
parser.add_argument(
"--device", type=str, default="cuda:0", help="decoding device, cuda:0 or cpu"
)
return parser
def post_process(text: str):
def convert_numbers(match):
return num2words(match.group())
text = re.sub(r"\b\d{1,2}\b", convert_numbers, text)
text = re.sub(r"[^a-zA-Z0-9']", " ", text.lower())
text = re.sub(r"\s+", " ", text)
return text
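# A rough example of the normalization above (exact whitespace may differ):
#   post_process("Chapter 2, it was the best of times!")
#   -> "chapter two it was the best of times "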
def save_results(
res_dir: str,
results: List[Tuple[str, List[str], List[str]]],
):
if not os.path.exists(res_dir):
os.makedirs(res_dir)
recog_path = os.path.join(res_dir, "recogs.txt")
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
errs_filename = os.path.join(res_dir, "errs.txt")
with open(errs_filename, "w") as f:
_ = write_error_stats(f, "test", results, enable_log=True)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
class SpeechEvalDataset(torch.utils.data.Dataset):
def __init__(self, wav_path: str, transcript_path: str):
super().__init__()
self.audio_name = []
self.audio_paths = []
self.transcripts = []
with Path(transcript_path).open("r", encoding="utf8") as f:
meta = [item.split("\t") for item in f.read().rstrip().split("\n")]
for item in meta:
self.audio_name.append(item[0])
self.audio_paths.append(Path(wav_path, item[0] + ".wav"))
self.transcripts.append(item[1])
def __len__(self):
return len(self.audio_paths)
def __getitem__(self, index: int):
audio, sampling_rate = sf.read(self.audio_paths[index])
item = {
"array": librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000),
"sampling_rate": 16000,
"reference": self.transcripts[index],
"audio_name": self.audio_name[index],
}
return item
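# The transcript file is a TSV with one "<utt_id>\t<reference text>" pair per
# line; the audio for each entry is read from <wav-path>/<utt_id>.wav and
# resampled to 16 kHz before being fed to Whisper.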
def main(args):
batch_size = args.batch_size
pipe = pipeline(
"automatic-speech-recognition",
model=args.model_path,
device=args.device,
tokenizer=args.model_path,
)
dataset = SpeechEvalDataset(args.wav_path, args.transcript_path)
results = []
bar = tqdm(
pipe(
dataset,
generate_kwargs={"language": "english", "task": "transcribe"},
batch_size=batch_size,
),
total=len(dataset),
)
for out in bar:
results.append(
(
out["audio_name"][0],
post_process(out["reference"][0].strip()).split(),
post_process(out["text"].strip()).split(),
)
)
save_results(args.decode_path, results)
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
main(args)
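# Example invocation (the script name and the wav/decode paths are
# illustrative; the remaining values match the defaults above):
#
#   python3 compute_wer.py \
#     --wav-path vocos/exp/generated_wavs \
#     --decode-path vocos/exp/wer \
#     --model-path model/huggingface/whisper_medium \
#     --transcript-path data/transcript/test.tsv \
#     --batch-size 64 \
#     --device cuda:0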

View File

@ -0,0 +1 @@
../../../libritts/TTS/vocos/export-onnx.py

View File

@ -0,0 +1 @@
../../../libritts/TTS/vocos/export.py

egs/ljspeech/TTS/vocos/infer.py Executable file
View File

@ -0,0 +1,340 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
# Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
import math
import os
from functools import partial
from pathlib import Path
import torch
import torch.nn as nn
from lhotse.utils import fix_random_seed
from scipy.io.wavfile import write
from train import add_model_arguments, get_model, get_params
from tts_datamodule import LJSpeechTtsDataModule
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import AttributeDict, setup_logger, str2bool
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=100,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="flow_match/exp",
help="The experiment dir",
)
parser.add_argument(
"--generate-dir",
type=str,
default="generated_wavs",
help="Path name of the generated wavs",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
batch: dict,
):
"""
Args:
params:
It's the return value of :func:`get_params`.
      model:
        The vocoder model that converts acoustic features to waveforms.
      batch:
        It is the return value from iterating the dataloader created by
        :class:`LJSpeechTtsDataModule`. See its documentation for the format
        of the `batch`.
    Returns:
      Return nothing. The generated waveforms are written as wav files to
      `{params.res_dir}/{params.suffix}`, one file per cut id.
"""
device = next(model.parameters()).device
cut_ids = [cut.id for cut in batch["cut"]]
features = batch["features"] # (B, T, F)
utt_durations = batch["features_lens"]
x = features.permute(0, 2, 1) # (B, F, T)
audios = model(x.to(device)) # (B, T)
wav_dir = f"{params.res_dir}/{params.suffix}"
os.makedirs(wav_dir, exist_ok=True)
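    # Each generated waveform is trimmed back to the length implied by its
    # number of feature frames: with the hard-coded frame_shift=256 and
    # frame_length=1024 below, num_samples = (num_frames - 1) * 256 + 1024.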
for i in range(audios.shape[0]):
audio = audios[i][: (utt_durations[i] - 1) * 256 + 1024]
audio = audio.cpu().squeeze().numpy()
write(f"{wav_dir}/{cut_ids[i]}.wav", 22050, audio)
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
test_set: str,
):
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
        The vocoder model that converts acoustic features to waveforms.
test_set:
The name of the test_set
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
with open(f"{params.res_dir}/{test_set}.scp", "w", encoding="utf8") as f:
for batch_idx, batch in enumerate(dl):
texts = batch["text"]
cut_ids = [cut.id for cut in batch["cut"]]
decode_one_batch(
params=params,
model=model,
batch=batch,
)
assert len(texts) == len(cut_ids), (len(texts), len(cut_ids))
for i in range(len(texts)):
f.write(f"{cut_ids[i]}\t{texts[i]}\n")
num_cuts += len(texts)
if batch_idx % 50 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
@torch.no_grad()
def main():
parser = get_parser()
LJSpeechTtsDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / params.generate_dir
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
params.device = device
logging.info(f"Device: {device}")
logging.info(params)
fix_random_seed(666)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model = model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
    # we need cut ids to name the generated wavs.
args.return_cuts = True
ljspeech = LJSpeechTtsDataModule(args)
test_cuts = ljspeech.test_cuts()
test_dl = ljspeech.test_dataloaders(test_cuts)
test_sets = ["test"]
test_dls = [test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
decode_dataset(
dl=test_dl,
params=params,
model=model,
test_set=test_set,
)
logging.info("Done!")
if __name__ == "__main__":
main()
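# Example invocation (run from egs/ljspeech/TTS; the epoch/avg values are
# illustrative):
#
#   ./vocos/infer.py \
#     --epoch 100 \
#     --avg 10 \
#     --exp-dir vocos/exp \
#     --generate-dir generated_wavs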

View File

@ -0,0 +1 @@
../../../libritts/TTS/vocos/onnx_pretrained.py

View File

@ -0,0 +1 @@
../../../libritts/TTS/vocos/pretrained.py

View File

@ -52,9 +52,11 @@ from utils import (
save_checkpoint, save_checkpoint,
plot_spectrogram, plot_spectrogram,
get_cosine_schedule_with_warmup, get_cosine_schedule_with_warmup,
save_checkpoint_with_global_batch_idx,
) )
from icefall import diagnostics from icefall import diagnostics
from icefall.checkpoint import remove_checkpoints, update_averaged_model
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
@ -65,7 +67,7 @@ from icefall.utils import (
str2bool, str2bool,
get_parameter_groups_with_lrs, get_parameter_groups_with_lrs,
) )
from models import Vocos from model import Vocos
from lhotse import Fbank, FbankConfig from lhotse import Fbank, FbankConfig
@ -91,6 +93,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Intermediate dim of ConvNeXt module.", help="Intermediate dim of ConvNeXt module.",
) )
parser.add_argument(
"--max-seconds",
type=int,
default=60,
help="""
        The length (in seconds) of the precomputed normalization window sum square
        required by istft. This argument is only used for onnx export; it determines
        the maximum length of audio that can be properly normalized.
        Note: you can still generate audios longer than this value with the exported
        onnx model, but the part beyond this value will not be normalized.
        The larger this value is, the bigger the exported onnx model will be.
""",
)
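    # Rough intuition for --max-seconds (a sketch, assuming the head precomputes
    # the istft overlap-add normalization described above): the precomputed buffer
    # covers about max_seconds * sampling_rate output samples; audio beyond that
    # length is still generated but is not divided by the window sum square, and
    # a larger buffer makes the exported onnx file bigger.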
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -203,6 +219,16 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--keep-last-epoch-k",
type=int,
default=50,
help="""Only keep this number of checkpoints on disk.
        For instance, if it is 3, at most 3 checkpoints with filenames
        `epoch-xxx.pt` are kept in the exp-dir.
""",
)
parser.add_argument( parser.add_argument(
"--average-period", "--average-period",
type=int, type=int,
@ -290,8 +316,8 @@ def get_params() -> AttributeDict:
"valid_interval": 500, "valid_interval": 500,
"feature_dim": 80, "feature_dim": 80,
"segment_size": 16384, "segment_size": 16384,
"adam_b1": 0.8, "adam_b1": 0.9,
"adam_b2": 0.9, "adam_b2": 0.99,
"warmup_steps": 0, "warmup_steps": 0,
"max_steps": 2000000, "max_steps": 2000000,
"env_info": get_env_info(), "env_info": get_env_info(),
@ -311,18 +337,17 @@ def get_model(params: AttributeDict) -> nn.Module:
intermediate_dim=params.intermediate_dim, intermediate_dim=params.intermediate_dim,
num_layers=params.num_layers, num_layers=params.num_layers,
sample_rate=params.sampling_rate, sample_rate=params.sampling_rate,
max_seconds=params.max_seconds,
).to(device) ).to(device)
num_param_head = sum([p.numel() for p in model.head.parameters()]) num_param_gen = sum([p.numel() for p in model.generator.parameters()])
logging.info(f"Number of Head parameters : {num_param_head}") logging.info(f"Number of Generator parameters : {num_param_gen}")
num_param_bone = sum([p.numel() for p in model.backbone.parameters()])
logging.info(f"Number of Generator parameters : {num_param_bone}")
num_param_mpd = sum([p.numel() for p in model.mpd.parameters()]) num_param_mpd = sum([p.numel() for p in model.mpd.parameters()])
logging.info(f"Number of MultiPeriodDiscriminator parameters : {num_param_mpd}") logging.info(f"Number of MultiPeriodDiscriminator parameters : {num_param_mpd}")
num_param_mrd = sum([p.numel() for p in model.mrd.parameters()]) num_param_mrd = sum([p.numel() for p in model.mrd.parameters()])
logging.info(f"Number of MultiResolutionDiscriminator parameters : {num_param_mrd}") logging.info(f"Number of MultiResolutionDiscriminator parameters : {num_param_mrd}")
logging.info( logging.info(
f"Number of model parameters : {num_param_head + num_param_bone + num_param_mpd + num_param_mrd}" f"Number of model parameters : {num_param_gen + num_param_mpd + num_param_mrd}"
) )
return model return model
@ -481,11 +506,6 @@ def compute_discriminator_loss(
info["loss_disc_mrd"] = loss_mrd.detach().cpu().item() info["loss_disc_mrd"] = loss_mrd.detach().cpu().item()
info["loss_disc_mpd"] = loss_mpd.detach().cpu().item() info["loss_disc_mpd"] = loss_mpd.detach().cpu().item()
for i in range(len(loss_mpd_real)):
info[f"loss_disc_mpd_period_{i+1}"] = loss_mpd_real[i] + loss_mpd_gen[i]
for i in range(len(loss_mrd_real)):
info[f"loss_disc_mrd_resolution_{i+1}"] = loss_mrd_real[i] + loss_mrd_gen[i]
return loss_disc_all, info return loss_disc_all, info
@ -499,6 +519,7 @@ def train_one_epoch(
train_dl: torch.utils.data.DataLoader, train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler, scaler: GradScaler,
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None, tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1, world_size: int = 1,
rank: int = 0, rank: int = 0,
@ -544,6 +565,7 @@ def train_one_epoch(
save_checkpoint( save_checkpoint(
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
model=model, model=model,
model_avg=model_avg,
params=params, params=params,
optimizer_g=optimizer_g, optimizer_g=optimizer_g,
optimizer_d=optimizer_d, optimizer_d=optimizer_d,
@ -566,10 +588,6 @@ def train_one_epoch(
params.segment_size - params.frame_length params.segment_size - params.frame_length
) // params.frame_shift + 1 ) // params.frame_shift + 1
# segment_frames = (
# params.segment_size + params.frame_shift // 2
# ) // params.frame_shift
start_p = random.randint(0, features_lens.min() - (segment_frames + 1)) start_p = random.randint(0, features_lens.min() - (segment_frames + 1))
features = features[:, start_p : start_p + segment_frames, :].permute( features = features[:, start_p : start_p + segment_frames, :].permute(
@ -594,6 +612,7 @@ def train_one_epoch(
loss_disc.backward() loss_disc.backward()
optimizer_d.step() optimizer_d.step()
scheduler_d.step()
optimizer_g.zero_grad() optimizer_g.zero_grad()
loss_gen, loss_gen_info = compute_generator_loss( loss_gen, loss_gen_info = compute_generator_loss(
@ -605,6 +624,7 @@ def train_one_epoch(
loss_gen.backward() loss_gen.backward()
optimizer_g.step() optimizer_g.step()
scheduler_g.step()
loss_info = loss_gen_info + loss_disc_info loss_info = loss_gen_info + loss_disc_info
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_gen_info tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_gen_info
@ -617,6 +637,39 @@ def train_one_epoch(
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 5:
return return
if (
rank == 0
and params.batch_idx_train > 0
and params.batch_idx_train % params.average_period == 0
):
update_averaged_model(
params=params,
model_cur=model,
model_avg=model_avg,
)
if (
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
model=model,
model_avg=model_avg,
params=params,
                optimizer_g=optimizer_g,
                optimizer_d=optimizer_d,
                scheduler_g=scheduler_g,
                scheduler_d=scheduler_d,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
rank=rank,
)
if params.batch_idx_train % 100 == 0 and params.use_fp16: if params.batch_idx_train % 100 == 0 and params.use_fp16:
# If the grad scale was less than 1, try increasing it. The _growth_interval # If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different # of the grad scaler is configurable, but we can't configure it to have different
@ -647,8 +700,8 @@ def train_one_epoch(
f"Epoch {params.cur_epoch}, batch {batch_idx}, " f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
f"loss[{loss_info}], tot_loss[{tot_loss}], " f"loss[{loss_info}], tot_loss[{tot_loss}], "
f"cur_lr_g: {cur_lr_g:.2e}, " f"cur_lr_g: {cur_lr_g:.4e}, "
f"cur_lr_d: {cur_lr_d:.2e}, " f"cur_lr_d: {cur_lr_d:.4e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
) )
@ -668,11 +721,10 @@ def train_one_epoch(
"train/grad_scale", cur_grad_scale, params.batch_idx_train "train/grad_scale", cur_grad_scale, params.batch_idx_train
) )
# if ( if (
# params.batch_idx_train % params.valid_interval == 0 params.batch_idx_train % params.valid_interval == 0
# and not params.print_diagnostics and not params.print_diagnostics
# ): ):
if True:
logging.info("Computing validation loss") logging.info("Computing validation loss")
valid_info = compute_validation_loss( valid_info = compute_validation_loss(
params=params, params=params,
@ -692,8 +744,6 @@ def train_one_epoch(
tb_writer, "train/valid_", params.batch_idx_train tb_writer, "train/valid_", params.batch_idx_train
) )
scheduler_g.step()
scheduler_d.step()
loss_value = tot_loss["loss_gen"] loss_value = tot_loss["loss_gen"]
params.train_loss = loss_value params.train_loss = loss_value
if params.train_loss < params.best_train_loss: if params.train_loss < params.best_train_loss:
@ -773,7 +823,7 @@ def compute_validation_loss(
params.sampling_rate, params.sampling_rate,
) )
logging.info(f"RTF : {infer_time / (audio_time * 10 / 1000)}") logging.info(f"Validation RTF : {infer_time / (audio_time * 10 / 1000)}")
if world_size > 1: if world_size > 1:
tot_loss.reduce(device) tot_loss.reduce(device)
@ -818,19 +868,25 @@ def run(rank, world_size, args):
device = torch.device("cpu") device = torch.device("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda", rank) device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
params.device = device params.device = device
logging.info(params) logging.info(params)
logging.info("About to create model")
logging.info("About to create model")
model = get_model(params) model = get_model(params)
assert params.save_every_n >= params.average_period
model_avg: Optional[nn.Module] = None
if rank == 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)
assert params.start_epoch > 0, params.start_epoch assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(params=params, model=model) checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg
)
model = model.to(device) model = model.to(device)
head = model.head generator = model.generator
backbone = model.backbone
mrd = model.mrd mrd = model.mrd
mpd = model.mpd mpd = model.mpd
if world_size > 1: if world_size > 1:
@ -838,7 +894,7 @@ def run(rank, world_size, args):
model = DDP(model, device_ids=[rank], find_unused_parameters=True) model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer_g = torch.optim.AdamW( optimizer_g = torch.optim.AdamW(
itertools.chain(head.parameters(), backbone.parameters()), generator.parameters(),
params.learning_rate, params.learning_rate,
betas=[params.adam_b1, params.adam_b2], betas=[params.adam_b1, params.adam_b2],
) )
@ -923,6 +979,7 @@ def run(rank, world_size, args):
train_one_epoch( train_one_epoch(
params=params, params=params,
model=model, model=model,
model_avg=model_avg,
optimizer_g=optimizer_g, optimizer_g=optimizer_g,
optimizer_d=optimizer_d, optimizer_d=optimizer_d,
scheduler_g=scheduler_g, scheduler_g=scheduler_g,
@ -944,6 +1001,7 @@ def run(rank, world_size, args):
filename=filename, filename=filename,
params=params, params=params,
model=model, model=model,
model_avg=model_avg,
optimizer_g=optimizer_g, optimizer_g=optimizer_g,
optimizer_d=optimizer_d, optimizer_d=optimizer_d,
scheduler_g=scheduler_g, scheduler_g=scheduler_g,
@ -953,21 +1011,6 @@ def run(rank, world_size, args):
rank=rank, rank=rank,
) )
if params.batch_idx_train % params.save_every_n == 0:
filename = params.exp_dir / f"checkpoint-{params.batch_idx_train}.pt"
save_checkpoint(
filename=filename,
params=params,
model=model,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
if rank == 0:
if params.best_train_epoch == params.cur_epoch: if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt" best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename) copyfile(src=filename, dst=best_train_filename)
@ -976,6 +1019,13 @@ def run(rank, world_size, args):
best_valid_filename = params.exp_dir / "best-valid-loss.pt" best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename) copyfile(src=filename, dst=best_valid_filename)
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_epoch_k,
prefix="epoch",
rank=rank,
)
logging.info("Done!") logging.info("Done!")
if world_size > 1: if world_size > 1:
@ -997,7 +1047,8 @@ def main():
run(rank=0, world_size=1, args=args) run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
torch.set_num_threads(1) torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
if __name__ == "__main__":
main() main()

View File

@ -250,18 +250,22 @@ def save_checkpoint_with_global_batch_idx(
) )
def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]: def find_checkpoints(
out_dir: Path,
iteration: int = 0,
prefix: str = "checkpoint",
) -> List[str]:
"""Find all available checkpoints in a directory. """Find all available checkpoints in a directory.
The checkpoint filenames have the form: `checkpoint-xxx.pt` The checkpoint filenames have the form: `{prefix}-xxx.pt`
where xxx is a numerical value. where xxx is a numerical value.
Assume you have the following checkpoints in the folder `foo`: Assume you have the following checkpoints in the folder `foo`:
- checkpoint-1.pt - {prefix}-1.pt
- checkpoint-20.pt - {prefix}-20.pt
- checkpoint-300.pt - {prefix}-300.pt
- checkpoint-4000.pt - {prefix}-4000.pt
Case 1 (Return all checkpoints):: Case 1 (Return all checkpoints)::
@ -290,8 +294,8 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
Return a list of checkpoint filenames, sorted in descending Return a list of checkpoint filenames, sorted in descending
order by the numerical value in the filename. order by the numerical value in the filename.
""" """
checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt")) checkpoints = list(glob.glob(f"{out_dir}/{prefix}-[0-9]*.pt"))
pattern = re.compile(r"checkpoint-([0-9]+).pt") pattern = re.compile(rf"{prefix}-([0-9]+).pt")
iter_checkpoints = [] iter_checkpoints = []
for c in checkpoints: for c in checkpoints:
result = pattern.search(c) result = pattern.search(c)
@ -316,12 +320,13 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
def remove_checkpoints( def remove_checkpoints(
out_dir: Path, out_dir: Path,
topk: int, topk: int,
prefix: str = "checkpoint",
rank: int = 0, rank: int = 0,
): ):
"""Remove checkpoints from the given directory. """Remove checkpoints from the given directory.
We assume that checkpoint filename has the form `checkpoint-xxx.pt` We assume that checkpoint filename has the form `{prefix}-xxx.pt`
where xxx is a number, representing the number of processed batches where xxx is a number, representing the number of processed batches/epochs
when saving that checkpoint. We sort checkpoints by filename and keep when saving that checkpoint. We sort checkpoints by filename and keep
only the `topk` checkpoints with the highest `xxx`. only the `topk` checkpoints with the highest `xxx`.
@ -330,6 +335,8 @@ def remove_checkpoints(
The directory containing checkpoints to be removed. The directory containing checkpoints to be removed.
topk: topk:
Number of checkpoints to keep. Number of checkpoints to keep.
prefix:
        The prefix of the checkpoint filename, normally `epoch` or `checkpoint`.
rank: rank:
If using DDP for training, it is the rank of the current node. If using DDP for training, it is the rank of the current node.
Use 0 if no DDP is used for training. Use 0 if no DDP is used for training.
@ -337,7 +344,7 @@ def remove_checkpoints(
assert topk >= 1, topk assert topk >= 1, topk
if rank != 0: if rank != 0:
return return
checkpoints = find_checkpoints(out_dir) checkpoints = find_checkpoints(out_dir, prefix=prefix)
if len(checkpoints) == 0: if len(checkpoints) == 0:
logging.warn(f"No checkpoints found in {out_dir}") logging.warn(f"No checkpoints found in {out_dir}")