mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Refactor, add libritts
This commit is contained in:
parent
6e07cb91e3
commit
d8a0a40955
@ -3,7 +3,7 @@ from typing import List, Optional, Tuple
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Conv2d
|
||||
from torch.nn.utils import weight_norm
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torchaudio.transforms import Spectrogram
|
||||
|
||||
|
||||
|
371
egs/libritts/TTS/vocos/export-onnx.py
Executable file
371
egs/libritts/TTS/vocos/export-onnx.py
Executable file
@ -0,0 +1,371 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
|
||||
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
|
||||
|
||||
"""
|
||||
This script exports a transducer model from PyTorch to ONNX.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./zipformer/export-onnx.py \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--exp-dir $repo/exp \
|
||||
--num-encoder-layers "2,2,3,4,3,2" \
|
||||
--downsampling-factor "1,2,4,8,4,2" \
|
||||
--feedforward-dim "512,768,1024,1536,1024,768" \
|
||||
--num-heads "4,4,4,8,4,4" \
|
||||
--encoder-dim "192,256,384,512,384,256" \
|
||||
--query-head-dim 32 \
|
||||
--value-head-dim 12 \
|
||||
--pos-head-dim 4 \
|
||||
--pos-dim 48 \
|
||||
--encoder-unmasked-dim "192,192,256,256,256,192" \
|
||||
--cnn-module-kernel "31,31,15,15,15,31" \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--causal False \
|
||||
--chunk-size "16,32,64,-1" \
|
||||
--left-context-frames "64,128,256,-1" \
|
||||
--fp16 True
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
See ./onnx_pretrained.py and ./onnx_check.py for how to
|
||||
use the exported ONNX models.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from onnxconverter_common import float16
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import make_pad_mask, num_tokens, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sampling-rate",
|
||||
type=int,
|
||||
default=24000,
|
||||
help="The sampleing rate of libritts dataset",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frame-shift",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frame-length",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipformer/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to export models in fp16",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
"""Add meta data to an ONNX model. It is changed in-place.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename of the ONNX model to be changed.
|
||||
meta_data:
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = value
|
||||
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
def export_model_onnx(
|
||||
model: nn.Module,
|
||||
model_filename: str,
|
||||
opset_version: int = 13,
|
||||
) -> None:
|
||||
"""Export the joiner model to ONNX format.
|
||||
The exported joiner model has two inputs:
|
||||
|
||||
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- logit: a tensor of shape (N, vocab_size)
|
||||
"""
|
||||
input_tensor = torch.rand((2, 80, 100), dtype=torch.float32)
|
||||
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(input_tensor,),
|
||||
model_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=[
|
||||
"features",
|
||||
],
|
||||
output_names=["audio"],
|
||||
dynamic_axes={
|
||||
"features": {0: "N", 2: "F"},
|
||||
"audio": {0: "N", 1: "T"},
|
||||
},
|
||||
)
|
||||
|
||||
meta_data = {
|
||||
"model_type": "Vocos",
|
||||
"version": "1",
|
||||
"model_author": "k2-fsa",
|
||||
"comment": "ConvNext Vocos",
|
||||
}
|
||||
logging.info(f"meta_data: {meta_data}")
|
||||
|
||||
add_meta_data(filename=model_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
params.device = device
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.eval()
|
||||
vocos = model.generator
|
||||
|
||||
if params.iter > 0:
|
||||
suffix = f"iter-{params.iter}"
|
||||
else:
|
||||
suffix = f"epoch-{params.epoch}"
|
||||
|
||||
suffix += f"-avg-{params.avg}"
|
||||
|
||||
opset_version = 13
|
||||
|
||||
logging.info("Exporting model")
|
||||
model_filename = params.exp_dir / f"vocos-{suffix}.onnx"
|
||||
export_model_onnx(
|
||||
vocos,
|
||||
model_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported vocos generator to {model_filename}")
|
||||
|
||||
if params.fp16:
|
||||
logging.info("Generate fp16 models")
|
||||
|
||||
model = onnx.load(model_filename)
|
||||
model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
|
||||
model_filename_fp16 = params.exp_dir / f"vocos-{suffix}.fp16.onnx"
|
||||
onnx.save(model_fp16, model_filename_fp16)
|
||||
|
||||
# Generate int8 quantization models
|
||||
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||
|
||||
logging.info("Generate int8 quantization models")
|
||||
|
||||
model_filename_int8 = params.exp_dir / f"vocos-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=model_filename,
|
||||
model_output=model_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
407
egs/libritts/TTS/vocos/export.py
Executable file
407
egs/libritts/TTS/vocos/export.py
Executable file
@ -0,0 +1,407 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2024 Xiaomi Corporation (Author: Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
|
||||
Usage:
|
||||
|
||||
Note: This is a example for libritts dataset, if you are using different
|
||||
dataset, you should change the argument values according to your dataset.
|
||||
|
||||
(1) Export to torchscript model using torch.jit.script()
|
||||
|
||||
|
||||
./vocos/export.py \
|
||||
--exp-dir ./vocos/exp \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--jit 1
|
||||
|
||||
It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
|
||||
load it by `torch.jit.load("jit_script.pt")`.
|
||||
|
||||
Check ./jit_pretrained.py for its usage.
|
||||
|
||||
Check https://github.com/k2-fsa/sherpa
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
- For streaming model:
|
||||
|
||||
./zipformer/export.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--causal 1 \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 128 \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--jit 1
|
||||
|
||||
It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`.
|
||||
|
||||
Check ./jit_pretrained_streaming.py for its usage.
|
||||
|
||||
Check https://github.com/k2-fsa/sherpa
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
(2) Export `model.state_dict()`
|
||||
|
||||
- For non-streaming model:
|
||||
|
||||
./zipformer/export.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 9
|
||||
|
||||
- For streaming model:
|
||||
|
||||
./zipformer/export.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--causal 1 \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 9
|
||||
|
||||
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
||||
load it by `icefall.checkpoint.load_checkpoint()`.
|
||||
|
||||
- For non-streaming model:
|
||||
|
||||
To use the generated file with `zipformer/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
./zipformer/decode.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search \
|
||||
--bpe-model data/lang_bpe_500/bpe.model
|
||||
|
||||
- For streaming model:
|
||||
|
||||
To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
|
||||
# simulated streaming decoding
|
||||
./zipformer/decode.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--causal 1 \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 128 \
|
||||
--decoding-method greedy_search \
|
||||
--bpe-model data/lang_bpe_500/bpe.model
|
||||
|
||||
# chunk-wise streaming decoding
|
||||
./zipformer/streaming_decode.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--causal 1 \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 128 \
|
||||
--decoding-method greedy_search \
|
||||
--bpe-model data/lang_bpe_500/bpe.model
|
||||
|
||||
Check ./pretrained.py for its usage.
|
||||
|
||||
Note: If you don't want to train a model from scratch, we have
|
||||
provided one for you. You can get it at
|
||||
|
||||
- non-streaming model:
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
|
||||
- streaming model:
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
|
||||
|
||||
with the following commands:
|
||||
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
|
||||
# You will find the pre-trained models in exp dir
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
)
|
||||
|
||||
from icefall.utils import str2bool
|
||||
from utils import load_checkpoint
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sampling-rate",
|
||||
type=int,
|
||||
default=24000,
|
||||
help="The sampleing rate of libritts dataset",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frame-shift",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frame-length",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="vocos/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
It will generate a file named jit_script.pt.
|
||||
Check ./jit_pretrained.py for how to use it.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class EncoderModel(nn.Module):
|
||||
"""A wrapper for encoder and encoder_embed"""
|
||||
|
||||
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.encoder_embed = encoder_embed
|
||||
|
||||
def forward(
|
||||
self, features: Tensor, feature_lengths: Tensor
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Args:
|
||||
features: (N, T, C)
|
||||
feature_lengths: (N,)
|
||||
"""
|
||||
x, x_lens = self.encoder_embed(features, feature_lengths)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
params.device = device
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.eval()
|
||||
|
||||
model = model.generator
|
||||
|
||||
if params.jit is True:
|
||||
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
|
||||
filename = "jit_script.pt"
|
||||
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
model.save(str(params.exp_dir / filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torchscript. Export model.state_dict()")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "generator.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -1,122 +1,154 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
from torch.nn import functional as F
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
|
||||
def window_sumsquare(
|
||||
window: torch.Tensor,
|
||||
n_samples: int,
|
||||
hop_length: int = 256,
|
||||
win_length: int = 1024,
|
||||
):
|
||||
"""
|
||||
Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
|
||||
|
||||
Args:
|
||||
num_embeddings (int): Number of embeddings.
|
||||
embedding_dim (int): Dimension of the embeddings.
|
||||
Compute the sum-square envelope of a window function at a given hop length.
|
||||
This is used to estimate modulation effects induced by windowing
|
||||
observations in short-time fourier transforms.
|
||||
Parameters
|
||||
----------
|
||||
window : string, tuple, number, callable, or list-like
|
||||
Window specification, as in `get_window`
|
||||
n_samples : int > 0
|
||||
The number of expected samples.
|
||||
hop_length : int > 0
|
||||
The number of samples to advance between frames
|
||||
win_length :
|
||||
The length of the window function.
|
||||
Returns
|
||||
-------
|
||||
wss : torch.Tensor, The sum-squared envelope of the window function.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.dim = embedding_dim
|
||||
self.scale = nn.Embedding(
|
||||
num_embeddings=num_embeddings, embedding_dim=embedding_dim
|
||||
)
|
||||
self.shift = nn.Embedding(
|
||||
num_embeddings=num_embeddings, embedding_dim=embedding_dim
|
||||
)
|
||||
torch.nn.init.ones_(self.scale.weight)
|
||||
torch.nn.init.zeros_(self.shift.weight)
|
||||
n_frames = (n_samples - win_length) // hop_length + 1
|
||||
output_size = (n_frames - 1) * hop_length + win_length
|
||||
device = window.device
|
||||
|
||||
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
|
||||
scale = self.scale(cond_embedding_id)
|
||||
shift = self.shift(cond_embedding_id)
|
||||
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
|
||||
x = x * scale + shift
|
||||
return x
|
||||
# Window envelope
|
||||
window_sq = window.square().expand(1, n_frames, -1).transpose(1, 2)
|
||||
window_envelope = torch.nn.functional.fold(
|
||||
window_sq,
|
||||
output_size=(1, output_size),
|
||||
kernel_size=(1, win_length),
|
||||
stride=(1, hop_length),
|
||||
).squeeze()
|
||||
window_envelope = torch.nn.functional.pad(
|
||||
window_envelope, (0, n_samples - output_size)
|
||||
)
|
||||
return window_envelope
|
||||
|
||||
|
||||
class ISTFT(nn.Module):
|
||||
"""
|
||||
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
|
||||
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
|
||||
See issue: https://github.com/pytorch/pytorch/issues/62323
|
||||
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
|
||||
|
||||
Args:
|
||||
n_fft (int): Size of Fourier transform.
|
||||
hop_length (int): The distance between neighboring sliding window frames.
|
||||
win_length (int): The size of window frame and STFT filter.
|
||||
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
||||
"""
|
||||
class ISTFT(torch.nn.Module):
|
||||
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
||||
|
||||
def __init__(
|
||||
self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
|
||||
self,
|
||||
filter_length: int = 1024,
|
||||
hop_length: int = 256,
|
||||
win_length: int = 1024,
|
||||
padding: str = "none",
|
||||
window_type: str = "povey",
|
||||
max_samples: int = 1440000, # 1440000 / 24000 = 60s
|
||||
):
|
||||
super().__init__()
|
||||
if padding not in ["center", "same"]:
|
||||
raise ValueError("Padding must be 'center' or 'same'.")
|
||||
self.padding = padding
|
||||
self.n_fft = n_fft
|
||||
super(ISTFT, self).__init__()
|
||||
self.filter_length = filter_length
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
window = torch.hann_window(win_length)
|
||||
self.register_buffer("window", window)
|
||||
self.padding = padding
|
||||
scale = self.filter_length / self.hop_length
|
||||
fourier_basis = np.fft.fft(np.eye(self.filter_length))
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
|
||||
cutoff = int((self.filter_length / 2 + 1))
|
||||
fourier_basis = np.vstack(
|
||||
[np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
|
||||
)
|
||||
inverse_basis = torch.FloatTensor(
|
||||
np.linalg.pinv(scale * fourier_basis).T[:, None, :]
|
||||
)
|
||||
|
||||
Args:
|
||||
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
|
||||
N is the number of frequency bins, and T is the number of time frames.
|
||||
assert filter_length >= win_length
|
||||
# Consistence with lhotse, search "create_frame_window" in https://github.com/lhotse-speech/lhotse
|
||||
assert window_type in [
|
||||
"hanning",
|
||||
"povey",
|
||||
], f"Only 'hanning' and 'povey' windows are supported, given {window_type}."
|
||||
fft_window = torch.hann_window(win_length, periodic=False)
|
||||
if window_type == "povey":
|
||||
fft_window = fft_window.pow(0.85)
|
||||
|
||||
Returns:
|
||||
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
|
||||
"""
|
||||
if filter_length > win_length:
|
||||
pad_size = (filter_length - win_length) // 2
|
||||
fft_window = torch.nn.functional.pad(fft_window, (pad_size, pad_size))
|
||||
|
||||
window_sum = window_sumsquare(
|
||||
window=fft_window,
|
||||
n_samples=max_samples,
|
||||
hop_length=hop_length,
|
||||
win_length=filter_length,
|
||||
)
|
||||
|
||||
inverse_basis *= fft_window
|
||||
|
||||
self.register_buffer("inverse_basis", inverse_basis.float())
|
||||
self.register_buffer("fft_window", fft_window)
|
||||
self.register_buffer("window_sum", window_sum)
|
||||
self.tiny = torch.finfo(torch.float16).tiny
|
||||
|
||||
def forward(self, magnitude, phase):
|
||||
magnitude_phase = torch.cat(
|
||||
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
|
||||
)
|
||||
inverse_transform = F.conv_transpose1d(
|
||||
magnitude_phase,
|
||||
Variable(self.inverse_basis, requires_grad=False),
|
||||
stride=self.hop_length,
|
||||
padding=0,
|
||||
)
|
||||
inverse_transform = inverse_transform.squeeze(1)
|
||||
|
||||
window_sum = self.window_sum
|
||||
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
||||
if self.window_sum.size(-1) < inverse_transform.size(-1):
|
||||
logging.warning(
|
||||
f"The precomputed `window_sumsquare` is too small, recomputing, "
|
||||
f"from {self.window_sum.size(-1)} to {inverse_transform.size(-1)}"
|
||||
)
|
||||
window_sum = window_sumsquare(
|
||||
window=self.fft_window,
|
||||
n_samples=inverse_transform.size(-1),
|
||||
win_length=self.filter_length,
|
||||
hop_length=self.hop_length,
|
||||
)
|
||||
window_sum = window_sum[: inverse_transform.size(-1)]
|
||||
approx_nonzero_indices = (window_sum > self.tiny).nonzero().squeeze()
|
||||
|
||||
inverse_transform[:, approx_nonzero_indices] /= window_sum[
|
||||
approx_nonzero_indices
|
||||
]
|
||||
|
||||
# scale by hop ratio
|
||||
inverse_transform *= float(self.filter_length) / self.hop_length
|
||||
assert self.padding in ["none", "same", "center"]
|
||||
if self.padding == "center":
|
||||
# Fallback to pytorch native implementation
|
||||
return torch.istft(
|
||||
spec,
|
||||
self.n_fft,
|
||||
self.hop_length,
|
||||
self.win_length,
|
||||
self.window,
|
||||
center=True,
|
||||
)
|
||||
pad_len = self.filter_length // 2
|
||||
elif self.padding == "same":
|
||||
pad = (self.win_length - self.hop_length) // 2
|
||||
pad_len = (self.filter_length - self.hop_length) // 2
|
||||
else:
|
||||
raise ValueError("Padding must be 'center' or 'same'.")
|
||||
|
||||
assert spec.dim() == 3, "Expected a 3D tensor as input"
|
||||
B, N, T = spec.shape
|
||||
|
||||
# Inverse FFT
|
||||
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
|
||||
ifft = ifft * self.window[None, :, None]
|
||||
|
||||
# Overlap and Add
|
||||
output_size = (T - 1) * self.hop_length + self.win_length
|
||||
y = torch.nn.functional.fold(
|
||||
ifft,
|
||||
output_size=(1, output_size),
|
||||
kernel_size=(1, self.win_length),
|
||||
stride=(1, self.hop_length),
|
||||
)[:, 0, 0, :]
|
||||
|
||||
# Window envelope
|
||||
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
|
||||
window_envelope = torch.nn.functional.fold(
|
||||
window_sq,
|
||||
output_size=(1, output_size),
|
||||
kernel_size=(1, self.win_length),
|
||||
stride=(1, self.hop_length),
|
||||
).squeeze()
|
||||
|
||||
# Normalize
|
||||
norm_indexes = window_envelope > 1e-11
|
||||
y[:, norm_indexes] = y[:, norm_indexes] / window_envelope[norm_indexes]
|
||||
|
||||
return y
|
||||
return inverse_transform
|
||||
return inverse_transform[:, pad_len:-pad_len]
|
||||
|
||||
|
||||
class ConvNeXtBlock(nn.Module):
|
||||
@ -127,8 +159,6 @@ class ConvNeXtBlock(nn.Module):
|
||||
intermediate_dim (int): Dimensionality of the intermediate layer.
|
||||
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
||||
Defaults to None.
|
||||
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
||||
None means non-conditional LayerNorm. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -136,20 +166,14 @@ class ConvNeXtBlock(nn.Module):
|
||||
dim: int,
|
||||
intermediate_dim: int,
|
||||
layer_scale_init_value: Optional[float] = None,
|
||||
adanorm_num_embeddings: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dwconv = nn.Conv1d(
|
||||
dim, dim, kernel_size=7, padding=3, groups=dim
|
||||
) # depthwise conv
|
||||
self.adanorm = adanorm_num_embeddings is not None
|
||||
if adanorm_num_embeddings:
|
||||
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
|
||||
else:
|
||||
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
||||
self.pwconv1 = nn.Linear(
|
||||
dim, intermediate_dim
|
||||
) # pointwise/1x1 convs, implemented with linear layers
|
||||
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
||||
# pointwise/1x1 convs, implemented with linear layers
|
||||
self.pwconv1 = nn.Linear(dim, intermediate_dim)
|
||||
self.act = nn.GELU()
|
||||
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
||||
self.gamma = (
|
||||
@ -159,16 +183,13 @@ class ConvNeXtBlock(nn.Module):
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
residual = x
|
||||
x = self.dwconv(x)
|
||||
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
||||
if self.adanorm:
|
||||
assert cond_embedding_id is not None
|
||||
x = self.norm(x, cond_embedding_id)
|
||||
else:
|
||||
x = self.norm(x)
|
||||
x = self.norm(x)
|
||||
x = self.pwconv1(x)
|
||||
x = self.act(x)
|
||||
x = self.pwconv2(x)
|
||||
@ -189,28 +210,22 @@ class Generator(torch.nn.Module):
|
||||
hop_length: int = 256,
|
||||
intermediate_dim: int = 1536,
|
||||
num_layers: int = 8,
|
||||
padding: str = "same",
|
||||
layer_scale_init_value: Optional[float] = None,
|
||||
adanorm_num_embeddings: Optional[int] = None,
|
||||
padding: str = "none",
|
||||
max_samples: int = 1440000, # 1440000 / 24000 = 60s
|
||||
):
|
||||
super(Generator, self).__init__()
|
||||
self.feature_dim = feature_dim
|
||||
self.embed = nn.Conv1d(feature_dim, dim, kernel_size=7, padding=3)
|
||||
|
||||
self.adanorm = adanorm_num_embeddings is not None
|
||||
if adanorm_num_embeddings:
|
||||
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
|
||||
else:
|
||||
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
||||
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
||||
|
||||
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
|
||||
layer_scale_init_value = 1 / num_layers
|
||||
self.convnext = nn.ModuleList(
|
||||
[
|
||||
ConvNeXtBlock(
|
||||
dim=dim,
|
||||
intermediate_dim=intermediate_dim,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
adanorm_num_embeddings=adanorm_num_embeddings,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
@ -221,7 +236,11 @@ class Generator(torch.nn.Module):
|
||||
|
||||
self.out_proj = torch.nn.Linear(dim, n_fft + 2)
|
||||
self.istft = ISTFT(
|
||||
n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
|
||||
filter_length=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=n_fft,
|
||||
padding=padding,
|
||||
max_samples=max_samples,
|
||||
)
|
||||
|
||||
def _init_weights(self, m):
|
||||
@ -229,29 +248,17 @@ class Generator(torch.nn.Module):
|
||||
nn.init.trunc_normal_(m.weight, std=0.02)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
bandwidth_id = kwargs.get("bandwidth_id", None)
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.embed(x)
|
||||
if self.adanorm:
|
||||
assert bandwidth_id is not None
|
||||
x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
|
||||
else:
|
||||
x = self.norm(x.transpose(1, 2))
|
||||
|
||||
x = self.norm(x.transpose(1, 2))
|
||||
x = x.transpose(1, 2)
|
||||
for conv_block in self.convnext:
|
||||
x = conv_block(x, cond_embedding_id=bandwidth_id)
|
||||
|
||||
x = conv_block(x)
|
||||
x = self.final_layer_norm(x.transpose(1, 2))
|
||||
|
||||
x = self.out_proj(x).transpose(1, 2)
|
||||
mag, p = x.chunk(2, dim=1)
|
||||
mag, phase = x.chunk(2, dim=1)
|
||||
mag = torch.exp(mag)
|
||||
mag = torch.clip(
|
||||
mag, max=1e2
|
||||
) # safeguard to prevent excessively large magnitudes
|
||||
x = torch.cos(p)
|
||||
y = torch.sin(p)
|
||||
S = mag * (x + 1j * y)
|
||||
audio = self.istft(S)
|
||||
# safeguard to prevent excessively large magnitudes
|
||||
mag = torch.clip(mag, max=1e2)
|
||||
audio = self.istft(mag, phase)
|
||||
return audio
|
||||
|
40
egs/libritts/TTS/vocos/infer.py
Normal file → Executable file
40
egs/libritts/TTS/vocos/infer.py
Normal file → Executable file
@ -20,6 +20,7 @@ import argparse
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
@ -29,7 +30,7 @@ import torch.nn as nn
|
||||
from lhotse.utils import fix_random_seed
|
||||
from scipy.io.wavfile import write
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
from tts_datamodule import LJSpeechTtsDataModule
|
||||
from tts_datamodule import LibriTTSDataModule
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -89,7 +90,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="flow_match/exp",
|
||||
default="vocos/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
@ -128,22 +129,31 @@ def decode_one_batch(
|
||||
|
||||
cut_ids = [cut.id for cut in batch["cut"]]
|
||||
|
||||
infer_time = 0
|
||||
audio_time = 0
|
||||
|
||||
features = batch["features"] # (B, T, F)
|
||||
utt_durations = batch["features_lens"]
|
||||
|
||||
x = features.permute(0, 2, 1) # (B, F, T)
|
||||
|
||||
audio_time += torch.sum(utt_durations)
|
||||
|
||||
start = time.time()
|
||||
|
||||
audios = model(x.to(device)) # (B, T)
|
||||
|
||||
infer_time += time.time() - start
|
||||
|
||||
wav_dir = f"{params.res_dir}/{params.suffix}"
|
||||
os.makedirs(wav_dir, exist_ok=True)
|
||||
|
||||
for i in range(audios.shape[0]):
|
||||
audio = audios[i][
|
||||
: int(utt_durations[i] * params.frame_shift_ms / 1000 * 22050)
|
||||
]
|
||||
audio = audios[i][: int(utt_durations[i] * 256)]
|
||||
audio = audio.cpu().squeeze().numpy()
|
||||
write(f"{wav_dir}/{cut_ids[i]}.wav", 22050, audio)
|
||||
write(f"{wav_dir}/{cut_ids[i]}.wav", 24000, audio)
|
||||
|
||||
print(f"RTF : {infer_time / (audio_time * (256/24000))}")
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
@ -173,7 +183,7 @@ def decode_dataset(
|
||||
|
||||
with open(f"{params.res_dir}/{test_set}.scp", "w", encoding="utf8") as f:
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["text"]
|
||||
# texts = batch["text"]
|
||||
cut_ids = [cut.id for cut in batch["cut"]]
|
||||
|
||||
decode_one_batch(
|
||||
@ -182,12 +192,12 @@ def decode_dataset(
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
assert len(texts) == len(cut_ids), (len(texts), len(cut_ids))
|
||||
# assert len(texts) == len(cut_ids), (len(texts), len(cut_ids))
|
||||
|
||||
for i in range(len(texts)):
|
||||
f.write(f"{cut_ids[i]}\t{texts[i]}\n")
|
||||
# for i in range(len(texts)):
|
||||
# f.write(f"{cut_ids[i]}\t{texts[i]}\n")
|
||||
|
||||
num_cuts += len(texts)
|
||||
# num_cuts += len(texts)
|
||||
|
||||
if batch_idx % 50 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
@ -200,7 +210,7 @@ def decode_dataset(
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LJSpeechTtsDataModule.add_arguments(parser)
|
||||
LibriTTSDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
@ -318,11 +328,11 @@ def main():
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
ljspeech = LJSpeechTtsDataModule(args)
|
||||
libritts = LibriTTSDataModule(args)
|
||||
|
||||
test_cuts = ljspeech.test_cuts()
|
||||
test_cuts = libritts.test_clean_cuts()
|
||||
|
||||
test_dl = ljspeech.test_dataloaders(test_cuts)
|
||||
test_dl = libritts.test_dataloaders(test_cuts)
|
||||
|
||||
test_sets = ["test"]
|
||||
test_dls = [test_dl]
|
||||
|
@ -19,8 +19,9 @@ class Vocos(torch.nn.Module):
|
||||
hop_length: int = 256,
|
||||
intermediate_dim: int = 1536,
|
||||
num_layers: int = 8,
|
||||
padding: str = "same",
|
||||
padding: str = "none",
|
||||
sample_rate: int = 24000,
|
||||
max_seconds: int = 60,
|
||||
):
|
||||
super(Vocos, self).__init__()
|
||||
self.generator = Generator(
|
||||
@ -31,6 +32,7 @@ class Vocos(torch.nn.Module):
|
||||
num_layers=num_layers,
|
||||
intermediate_dim=intermediate_dim,
|
||||
padding=padding,
|
||||
max_samples=int(sample_rate * max_seconds),
|
||||
)
|
||||
|
||||
self.mpd = MultiPeriodDiscriminator()
|
||||
|
268
egs/libritts/TTS/vocos/onnx_pretrained.py
Executable file
268
egs/libritts/TTS/vocos/onnx_pretrained.py
Executable file
@ -0,0 +1,268 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads ONNX models and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./zipformer/export-onnx.py \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--exp-dir $repo/exp \
|
||||
--causal False
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
3. Run this file
|
||||
|
||||
./zipformer/onnx_pretrained.py \
|
||||
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from lhotse import Fbank, FbankConfig
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sampling-rate",
|
||||
type=int,
|
||||
default=24000,
|
||||
help="The sampleing rate of libritts dataset",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frame-shift",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frame-length",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-fft-mag",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to use magnitude of fbank, false to use power energy.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="generated_audios",
|
||||
help="The generated will be written to.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
model_filename: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 4
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
self.init_model(model_filename)
|
||||
|
||||
def init_model(self, model_filename: str):
|
||||
self.model = ort.InferenceSession(
|
||||
model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
def run_model(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
x_lens:
|
||||
A 2-D tensor of shape (N,). Its dtype is torch.int64
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- encoder_out, its shape is (N, T', joiner_dim)
|
||||
- encoder_out_lens, its shape is (N,)
|
||||
"""
|
||||
out = self.model.run(
|
||||
[
|
||||
self.model.get_outputs()[0].name,
|
||||
],
|
||||
{
|
||||
self.model.get_inputs()[0].name: x.numpy(),
|
||||
},
|
||||
)
|
||||
return torch.from_numpy(out[0])
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
output_dir = Path(args.model_filename).parent / args.output_dir
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
args.output_dir = output_dir
|
||||
logging.info(vars(args))
|
||||
|
||||
model = OnnxModel(model_filename=args.model_filename)
|
||||
|
||||
config = FbankConfig(
|
||||
sampling_rate=args.sampling_rate,
|
||||
frame_length=args.frame_length / args.sampling_rate, # (in second),
|
||||
frame_shift=args.frame_shift / args.sampling_rate, # (in second)
|
||||
use_fft_mag=args.use_fft_mag,
|
||||
)
|
||||
fbank = Fbank(config)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files, expected_sample_rate=args.sampling_rate
|
||||
)
|
||||
wave_lengths = [w.size(0) for w in waves]
|
||||
waves = pad_sequence(waves, batch_first=True, padding_value=0)
|
||||
|
||||
logging.info(f"waves : {waves.shape}")
|
||||
|
||||
features = fbank.extract_batch(waves, sampling_rate=args.sampling_rate)
|
||||
|
||||
if features.dim() == 2:
|
||||
features = features.unsqueeze(0)
|
||||
|
||||
features = features.permute(0, 2, 1)
|
||||
|
||||
logging.info(f"features : {features.shape}")
|
||||
|
||||
logging.info("Generating started")
|
||||
|
||||
# model forward
|
||||
audios = model.run_model(features)
|
||||
|
||||
for i, filename in enumerate(args.sound_files):
|
||||
audio = audios[i : i + 1, 0 : wave_lengths[i]]
|
||||
ofilename = args.output_dir / filename.split("/")[-1]
|
||||
logging.info(f"Writting audio : {ofilename}")
|
||||
torchaudio.save(str(ofilename), audio.cpu(), args.sampling_rate)
|
||||
|
||||
logging.info("Generating Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
196
egs/libritts/TTS/vocos/pretrained.py
Executable file
196
egs/libritts/TTS/vocos/pretrained.py
Executable file
@ -0,0 +1,196 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads a checkpoint and uses it to decode waves.
|
||||
You can generate the checkpoint with the following command:
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
from lhotse import Fbank, FbankConfig
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sampling-rate",
|
||||
type=int,
|
||||
default=24000,
|
||||
help="The sampleing rate of libritts dataset",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frame-shift",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frame-length",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Frame shift.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-fft-mag",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to use magnitude of fbank, false to use power energy.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="generated_audios",
|
||||
help="The generated will be written to.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0].contiguous())
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
params.device = device
|
||||
|
||||
output_dir = Path(params.checkpoint).parent / params.output_dir
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
params.output_dir = output_dir
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_model(params)
|
||||
|
||||
model = model.generator
|
||||
|
||||
checkpoint = torch.load(params.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
|
||||
config = FbankConfig(
|
||||
sampling_rate=params.sampling_rate,
|
||||
frame_length=params.frame_length / params.sampling_rate, # (in second),
|
||||
frame_shift=params.frame_shift / params.sampling_rate, # (in second)
|
||||
use_fft_mag=params.use_fft_mag,
|
||||
)
|
||||
fbank = Fbank(config)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sampling_rate
|
||||
)
|
||||
wave_lengths = [w.size(0) for w in waves]
|
||||
waves = pad_sequence(waves, batch_first=True, padding_value=0)
|
||||
|
||||
features = (
|
||||
fbank.extract_batch(waves, sampling_rate=params.sampling_rate)
|
||||
.permute(0, 2, 1)
|
||||
.to(device)
|
||||
)
|
||||
|
||||
logging.info("Generating started")
|
||||
|
||||
# model forward
|
||||
audios = model(features)
|
||||
|
||||
for i, filename in enumerate(params.sound_files):
|
||||
audio = audios[i : i + 1, 0 : wave_lengths[i]]
|
||||
ofilename = params.output_dir / filename.split("/")[-1]
|
||||
logging.info(f"Writting audio : {ofilename}")
|
||||
torchaudio.save(str(ofilename), audio.cpu(), params.sampling_rate)
|
||||
|
||||
logging.info("Generating Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -52,9 +52,11 @@ from utils import (
|
||||
save_checkpoint,
|
||||
plot_spectrogram,
|
||||
get_cosine_schedule_with_warmup,
|
||||
save_checkpoint_with_global_batch_idx,
|
||||
)
|
||||
|
||||
from icefall import diagnostics
|
||||
from icefall.checkpoint import remove_checkpoints, update_averaged_model
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
@ -91,6 +93,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
help="Intermediate dim of ConvNeXt module.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-seconds",
|
||||
type=int,
|
||||
default=60,
|
||||
help="""
|
||||
The length of the precomputed normalization window sum square
|
||||
(required by istft). This argument is only for onnx export, it determines
|
||||
the max length of the audio that be properly normalized.
|
||||
Note, you can generate audios longer than this value with the exported onnx model,
|
||||
the part longer than this value will not be normalized yet.
|
||||
The larger this value is the bigger the exported onnx model will be.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -203,6 +219,16 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--keep-last-epoch-k",
|
||||
type=int,
|
||||
default=50,
|
||||
help="""Only keep this number of checkpoints on disk.
|
||||
For instance, if it is 3, there are only 3 checkpoints
|
||||
in the exp-dir with filenames `epoch-xxx.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--average-period",
|
||||
type=int,
|
||||
@ -290,8 +316,8 @@ def get_params() -> AttributeDict:
|
||||
"valid_interval": 500,
|
||||
"feature_dim": 80,
|
||||
"segment_size": 16384,
|
||||
"adam_b1": 0.8,
|
||||
"adam_b2": 0.9,
|
||||
"adam_b1": 0.9,
|
||||
"adam_b2": 0.99,
|
||||
"warmup_steps": 0,
|
||||
"max_steps": 2000000,
|
||||
"env_info": get_env_info(),
|
||||
@ -311,6 +337,7 @@ def get_model(params: AttributeDict) -> nn.Module:
|
||||
intermediate_dim=params.intermediate_dim,
|
||||
num_layers=params.num_layers,
|
||||
sample_rate=params.sampling_rate,
|
||||
max_seconds=params.max_seconds,
|
||||
).to(device)
|
||||
|
||||
num_param_gen = sum([p.numel() for p in model.generator.parameters()])
|
||||
@ -479,11 +506,6 @@ def compute_discriminator_loss(
|
||||
info["loss_disc_mrd"] = loss_mrd.detach().cpu().item()
|
||||
info["loss_disc_mpd"] = loss_mpd.detach().cpu().item()
|
||||
|
||||
for i in range(len(loss_mpd_real)):
|
||||
info[f"loss_disc_mpd_period_{i+1}"] = loss_mpd_real[i] + loss_mpd_gen[i]
|
||||
for i in range(len(loss_mrd_real)):
|
||||
info[f"loss_disc_mrd_resolution_{i+1}"] = loss_mrd_real[i] + loss_mrd_gen[i]
|
||||
|
||||
return loss_disc_all, info
|
||||
|
||||
|
||||
@ -497,6 +519,7 @@ def train_one_epoch(
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
@ -542,6 +565,7 @@ def train_one_epoch(
|
||||
save_checkpoint(
|
||||
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
params=params,
|
||||
optimizer_g=optimizer_g,
|
||||
optimizer_d=optimizer_d,
|
||||
@ -588,6 +612,7 @@ def train_one_epoch(
|
||||
|
||||
loss_disc.backward()
|
||||
optimizer_d.step()
|
||||
scheduler_d.step()
|
||||
|
||||
optimizer_g.zero_grad()
|
||||
loss_gen, loss_gen_info = compute_generator_loss(
|
||||
@ -599,6 +624,7 @@ def train_one_epoch(
|
||||
|
||||
loss_gen.backward()
|
||||
optimizer_g.step()
|
||||
scheduler_g.step()
|
||||
|
||||
loss_info = loss_gen_info + loss_disc_info
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_gen_info
|
||||
@ -611,6 +637,39 @@ def train_one_epoch(
|
||||
if params.print_diagnostics and batch_idx == 5:
|
||||
return
|
||||
|
||||
if (
|
||||
rank == 0
|
||||
and params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.average_period == 0
|
||||
):
|
||||
update_averaged_model(
|
||||
params=params,
|
||||
model_cur=model,
|
||||
model_avg=model_avg,
|
||||
)
|
||||
|
||||
if (
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
sampler=train_dl.sampler,
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if params.batch_idx_train % 100 == 0 and params.use_fp16:
|
||||
# If the grad scale was less than 1, try increasing it. The _growth_interval
|
||||
# of the grad scaler is configurable, but we can't configure it to have different
|
||||
@ -641,8 +700,8 @@ def train_one_epoch(
|
||||
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
||||
f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
|
||||
f"loss[{loss_info}], tot_loss[{tot_loss}], "
|
||||
f"cur_lr_g: {cur_lr_g:.2e}, "
|
||||
f"cur_lr_d: {cur_lr_d:.2e}, "
|
||||
f"cur_lr_g: {cur_lr_g:.4e}, "
|
||||
f"cur_lr_d: {cur_lr_d:.4e}, "
|
||||
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
|
||||
)
|
||||
|
||||
@ -685,8 +744,6 @@ def train_one_epoch(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
|
||||
scheduler_g.step()
|
||||
scheduler_d.step()
|
||||
loss_value = tot_loss["loss_gen"]
|
||||
params.train_loss = loss_value
|
||||
if params.train_loss < params.best_train_loss:
|
||||
@ -766,7 +823,7 @@ def compute_validation_loss(
|
||||
params.sampling_rate,
|
||||
)
|
||||
|
||||
logging.info(f"RTF : {infer_time / (audio_time * 10 / 1000)}")
|
||||
logging.info(f"Validation RTF : {infer_time / (audio_time * 10 / 1000)}")
|
||||
|
||||
if world_size > 1:
|
||||
tot_loss.reduce(device)
|
||||
@ -811,15 +868,22 @@ def run(rank, world_size, args):
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", rank)
|
||||
logging.info(f"Device: {device}")
|
||||
params.device = device
|
||||
logging.info(params)
|
||||
logging.info("About to create model")
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
|
||||
assert params.save_every_n >= params.average_period
|
||||
model_avg: Optional[nn.Module] = None
|
||||
if rank == 0:
|
||||
# model_avg is only used with rank 0
|
||||
model_avg = copy.deepcopy(model).to(torch.float64)
|
||||
|
||||
assert params.start_epoch > 0, params.start_epoch
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
checkpoints = load_checkpoint_if_available(
|
||||
params=params, model=model, model_avg=model_avg
|
||||
)
|
||||
|
||||
model = model.to(device)
|
||||
generator = model.generator
|
||||
@ -915,6 +979,7 @@ def run(rank, world_size, args):
|
||||
train_one_epoch(
|
||||
params=params,
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
optimizer_g=optimizer_g,
|
||||
optimizer_d=optimizer_d,
|
||||
scheduler_g=scheduler_g,
|
||||
@ -936,6 +1001,7 @@ def run(rank, world_size, args):
|
||||
filename=filename,
|
||||
params=params,
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
optimizer_g=optimizer_g,
|
||||
optimizer_d=optimizer_d,
|
||||
scheduler_g=scheduler_g,
|
||||
@ -945,28 +1011,20 @@ def run(rank, world_size, args):
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if params.batch_idx_train % params.save_every_n == 0:
|
||||
filename = params.exp_dir / f"checkpoint-{params.batch_idx_train}.pt"
|
||||
save_checkpoint(
|
||||
filename=filename,
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer_g=optimizer_g,
|
||||
optimizer_d=optimizer_d,
|
||||
scheduler_g=scheduler_g,
|
||||
scheduler_d=scheduler_d,
|
||||
sampler=train_dl.sampler,
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
if rank == 0:
|
||||
if params.best_train_epoch == params.cur_epoch:
|
||||
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||
copyfile(src=filename, dst=best_train_filename)
|
||||
if params.best_train_epoch == params.cur_epoch:
|
||||
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||
copyfile(src=filename, dst=best_train_filename)
|
||||
|
||||
if params.best_valid_epoch == params.cur_epoch:
|
||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
if params.best_valid_epoch == params.cur_epoch:
|
||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_epoch_k,
|
||||
prefix="epoch",
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
@ -34,6 +34,69 @@ def plot_spectrogram(spectrogram):
|
||||
return fig
|
||||
|
||||
|
||||
def save_checkpoint_with_global_batch_idx(
|
||||
out_dir: Path,
|
||||
global_batch_idx: int,
|
||||
model: Union[nn.Module, DDP],
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
optimizer_g: Optional[Optimizer] = None,
|
||||
optimizer_d: Optional[Optimizer] = None,
|
||||
scheduler_g: Optional[LRScheduler] = None,
|
||||
scheduler_d: Optional[LRScheduler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
rank: int = 0,
|
||||
):
|
||||
"""Save training info after processing given number of batches.
|
||||
|
||||
Args:
|
||||
out_dir:
|
||||
The directory to save the checkpoint.
|
||||
global_batch_idx:
|
||||
The number of batches processed so far from the very start of the
|
||||
training. The saved checkpoint will have the following filename:
|
||||
|
||||
f'out_dir / checkpoint-{global_batch_idx}.pt'
|
||||
model:
|
||||
The neural network model whose `state_dict` will be saved in the
|
||||
checkpoint.
|
||||
model_avg:
|
||||
The stored model averaged from the start of training.
|
||||
params:
|
||||
A dict of training configurations to be saved.
|
||||
optimizer:
|
||||
The optimizer used in the training. Its `state_dict` will be saved.
|
||||
scheduler:
|
||||
The learning rate scheduler used in the training. Its `state_dict` will
|
||||
be saved.
|
||||
scaler:
|
||||
The scaler used for mix precision training. Its `state_dict` will
|
||||
be saved.
|
||||
sampler:
|
||||
The sampler used in the training dataset.
|
||||
rank:
|
||||
The rank ID used in DDP training of the current node. Set it to 0
|
||||
if DDP is not used.
|
||||
"""
|
||||
out_dir = Path(out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
|
||||
save_checkpoint(
|
||||
filename=filename,
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
params=params,
|
||||
optimizer_g=optimizer_g,
|
||||
scheduler_g=scheduler_g,
|
||||
optimizer_d=optimizer_d,
|
||||
scheduler_d=scheduler_d,
|
||||
scaler=scaler,
|
||||
sampler=sampler,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
filename: Path,
|
||||
model: nn.Module,
|
||||
|
287
egs/ljspeech/TTS/local/evaluate_fsd.py
Normal file
287
egs/ljspeech/TTS/local/evaluate_fsd.py
Normal file
@ -0,0 +1,287 @@
|
||||
"""
|
||||
Calculate Frechet Speech Distance betweeen two speech directories.
|
||||
Adapted from: https://github.com/gudgud96/frechet-audio-distance/blob/main/frechet_audio_distance/fad.py
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from multiprocessing.dummy import Pool as ThreadPool
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from scipy import linalg
|
||||
from tqdm import tqdm
|
||||
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--real-path", type=str, help="path of the real speech directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval-path", type=str, help="path of the evaluated speech directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default="model/huggingface/wav2vec2_base",
|
||||
help="path of the wav2vec 2.0 model directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--real-embds-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path of the real embedding directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval-embds-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path of the evaluated embedding directory",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
class FrechetSpeechDistance:
|
||||
def __init__(
|
||||
self,
|
||||
model_path="resources/wav2vec2_base",
|
||||
pca_dim=128,
|
||||
speech_load_worker=8,
|
||||
):
|
||||
"""
|
||||
Initialize FSD
|
||||
"""
|
||||
self.sample_rate = 16000
|
||||
self.channels = 1
|
||||
self.device = (
|
||||
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
)
|
||||
logging.info("[Frechet Speech Distance] Using device: {}".format(self.device))
|
||||
self.speech_load_worker = speech_load_worker
|
||||
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_path)
|
||||
self.model = Wav2Vec2Model.from_pretrained(model_path)
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
self.pca_dim = pca_dim
|
||||
|
||||
def load_speech_files(self, dir, dtype="float32"):
|
||||
def _load_speech_task(fname, sample_rate, channels, dtype="float32"):
|
||||
if dtype not in ["float64", "float32", "int32", "int16"]:
|
||||
raise ValueError(f"dtype not supported: {dtype}")
|
||||
|
||||
wav_data, sr = sf.read(fname, dtype=dtype)
|
||||
# For integer type PCM input, convert to [-1.0, +1.0]
|
||||
if dtype == "int16":
|
||||
wav_data = wav_data / 32768.0
|
||||
elif dtype == "int32":
|
||||
wav_data = wav_data / float(2**31)
|
||||
|
||||
# Convert to mono
|
||||
assert channels in [1, 2], "channels must be 1 or 2"
|
||||
if len(wav_data.shape) > channels:
|
||||
wav_data = np.mean(wav_data, axis=1)
|
||||
|
||||
if sr != sample_rate:
|
||||
wav_data = (
|
||||
librosa.resample(wav_data, orig_sr=sr, target_sr=sample_rate),
|
||||
)
|
||||
|
||||
return wav_data
|
||||
|
||||
task_results = []
|
||||
|
||||
pool = ThreadPool(self.speech_load_worker)
|
||||
|
||||
logging.info("[Frechet Speech Distance] Loading speech from {}...".format(dir))
|
||||
for fname in os.listdir(dir):
|
||||
res = pool.apply_async(
|
||||
_load_speech_task,
|
||||
args=(os.path.join(dir, fname), self.sample_rate, self.channels, dtype),
|
||||
)
|
||||
task_results.append(res)
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
return [k.get() for k in task_results]
|
||||
|
||||
def get_embeddings(self, x):
|
||||
"""
|
||||
Get embeddings
|
||||
Params:
|
||||
-- x : a list of np.ndarray speech samples
|
||||
-- sr : sampling rate.
|
||||
"""
|
||||
embd_lst = []
|
||||
try:
|
||||
for speech in tqdm(x):
|
||||
input_features = self.feature_extractor(
|
||||
speech, sampling_rate=self.sample_rate, return_tensors="pt"
|
||||
).input_values.to(self.device)
|
||||
with torch.no_grad():
|
||||
embd = self.model(input_features).last_hidden_state.mean(1)
|
||||
|
||||
if embd.device != torch.device("cpu"):
|
||||
embd = embd.cpu()
|
||||
|
||||
if torch.is_tensor(embd):
|
||||
embd = embd.detach().numpy()
|
||||
|
||||
embd_lst.append(embd)
|
||||
except Exception as e:
|
||||
print(
|
||||
"[Frechet Speech Distance] get_embeddings throw an exception: {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
|
||||
return np.concatenate(embd_lst, axis=0)
|
||||
|
||||
def calculate_embd_statistics(self, embd_lst):
|
||||
if isinstance(embd_lst, list):
|
||||
embd_lst = np.array(embd_lst)
|
||||
mu = np.mean(embd_lst, axis=0)
|
||||
sigma = np.cov(embd_lst, rowvar=False)
|
||||
return mu, sigma
|
||||
|
||||
def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
|
||||
"""
|
||||
Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
|
||||
|
||||
Numpy implementation of the Frechet Distance.
|
||||
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
||||
and X_2 ~ N(mu_2, C_2) is
|
||||
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
||||
Stable version by Dougal J. Sutherland.
|
||||
Params:
|
||||
-- mu1 : Numpy array containing the activations of a layer of the
|
||||
inception net (like returned by the function 'get_predictions')
|
||||
for generated samples.
|
||||
-- mu2 : The sample mean over activations, precalculated on an
|
||||
representative data set.
|
||||
-- sigma1: The covariance matrix over activations for generated samples.
|
||||
-- sigma2: The covariance matrix over activations, precalculated on an
|
||||
representative data set.
|
||||
Returns:
|
||||
-- : The Frechet Distance.
|
||||
"""
|
||||
|
||||
mu1 = np.atleast_1d(mu1)
|
||||
mu2 = np.atleast_1d(mu2)
|
||||
|
||||
sigma1 = np.atleast_2d(sigma1)
|
||||
sigma2 = np.atleast_2d(sigma2)
|
||||
|
||||
assert (
|
||||
mu1.shape == mu2.shape
|
||||
), "Training and test mean vectors have different lengths"
|
||||
assert (
|
||||
sigma1.shape == sigma2.shape
|
||||
), "Training and test covariances have different dimensions"
|
||||
|
||||
diff = mu1 - mu2
|
||||
|
||||
# Product might be almost singular
|
||||
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2).astype(complex), disp=False)
|
||||
if not np.isfinite(covmean).all():
|
||||
msg = (
|
||||
"fid calculation produces singular product; "
|
||||
"adding %s to diagonal of cov estimates"
|
||||
) % eps
|
||||
logging.info(msg)
|
||||
offset = np.eye(sigma1.shape[0]) * eps
|
||||
covmean = linalg.sqrtm(
|
||||
(sigma1 + offset).dot(sigma2 + offset).astype(complex)
|
||||
)
|
||||
|
||||
# Numerical error might give slight imaginary component
|
||||
if np.iscomplexobj(covmean):
|
||||
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
||||
m = np.max(np.abs(covmean.imag))
|
||||
raise ValueError("Imaginary component {}".format(m))
|
||||
covmean = covmean.real
|
||||
|
||||
tr_covmean = np.trace(covmean)
|
||||
|
||||
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
|
||||
|
||||
def score(
|
||||
self,
|
||||
real_path,
|
||||
eval_path,
|
||||
real_embds_path=None,
|
||||
eval_embds_path=None,
|
||||
dtype="float32",
|
||||
):
|
||||
"""
|
||||
Computes the Frechet Speech Distance (FSD) between two directories of speech files.
|
||||
|
||||
Parameters:
|
||||
- real_path (str): Path to the directory containing real speech files.
|
||||
- eval_path (str): Path to the directory containing evaluation speech files.
|
||||
- real_embds_path (str, optional): Path to save/load real speech embeddings (e.g., /folder/bkg_embs.npy). If None, embeddings won't be saved.
|
||||
- eval_embds_path (str, optional): Path to save/load evaluation speech embeddings (e.g., /folder/test_embs.npy). If None, embeddings won't be saved.
|
||||
- dtype (str, optional): Data type for loading speech. Default is "float32".
|
||||
|
||||
Returns:
|
||||
- float: The Frechet Speech Distance (FSD) score between the two directories of speech files.
|
||||
"""
|
||||
# Load or compute real embeddings
|
||||
if real_embds_path is not None and os.path.exists(real_embds_path):
|
||||
logging.info(
|
||||
f"[Frechet Speech Distance] Loading embeddings from {real_embds_path}..."
|
||||
)
|
||||
embds_real = np.load(real_embds_path)
|
||||
else:
|
||||
speech_real = self.load_speech_files(real_path, dtype=dtype)
|
||||
embds_real = self.get_embeddings(speech_real)
|
||||
if real_embds_path:
|
||||
os.makedirs(os.path.dirname(real_embds_path), exist_ok=True)
|
||||
np.save(real_embds_path, embds_real)
|
||||
|
||||
# Load or compute eval embeddings
|
||||
if eval_embds_path is not None and os.path.exists(eval_embds_path):
|
||||
logging.info(
|
||||
f"[Frechet Speech Distance] Loading embeddings from {eval_embds_path}..."
|
||||
)
|
||||
embds_eval = np.load(eval_embds_path)
|
||||
else:
|
||||
speech_eval = self.load_speech_files(eval_path, dtype=dtype)
|
||||
embds_eval = self.get_embeddings(speech_eval)
|
||||
if eval_embds_path:
|
||||
os.makedirs(os.path.dirname(eval_embds_path), exist_ok=True)
|
||||
np.save(eval_embds_path, embds_eval)
|
||||
|
||||
# Check if embeddings are empty
|
||||
if len(embds_real) == 0:
|
||||
logging.info("[Frechet Speech Distance] real set dir is empty, exiting...")
|
||||
return -10.46
|
||||
if len(embds_eval) == 0:
|
||||
logging.info("[Frechet Speech Distance] eval set dir is empty, exiting...")
|
||||
return -1
|
||||
|
||||
# Compute statistics and FSD score
|
||||
mu_real, sigma_real = self.calculate_embd_statistics(embds_real)
|
||||
mu_eval, sigma_eval = self.calculate_embd_statistics(embds_eval)
|
||||
|
||||
fsd_score = self.calculate_frechet_distance(
|
||||
mu_real, sigma_real, mu_eval, sigma_eval
|
||||
)
|
||||
|
||||
return fsd_score
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
FSD = FrechetSpeechDistance(model_path=args.model_path)
|
||||
score = FSD.score(
|
||||
args.real_path, args.eval_path, args.real_embds_path, args.eval_embds_path
|
||||
)
|
||||
logging.info(f"FSD score: {score:.2f}")
|
139
egs/ljspeech/TTS/local/evaluate_wer_whisper.py
Normal file
139
egs/ljspeech/TTS/local/evaluate_wer_whisper.py
Normal file
@ -0,0 +1,139 @@
|
||||
"""
|
||||
Calculate WER with Whisper model
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from num2words import num2words
|
||||
from tqdm import tqdm
|
||||
from transformers import pipeline
|
||||
|
||||
from icefall.utils import store_transcripts, write_error_stats
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--wav-path", type=str, help="path of the speech directory")
|
||||
parser.add_argument("--decode-path", type=str, help="path of the speech directory")
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default="model/huggingface/whisper_medium",
|
||||
help="path of the huggingface whisper model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--transcript-path",
|
||||
type=str,
|
||||
default="data/transcript/test.tsv",
|
||||
help="path of the transcript tsv file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size", type=int, default=64, help="decoding batch size"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cuda:0", help="decoding device, cuda:0 or cpu"
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def post_process(text: str):
|
||||
def convert_numbers(match):
|
||||
return num2words(match.group())
|
||||
|
||||
text = re.sub(r"\b\d{1,2}\b", convert_numbers, text)
|
||||
text = re.sub(r"[^a-zA-Z0-9']", " ", text.lower())
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
return text
|
||||
|
||||
|
||||
def save_results(
|
||||
res_dir: str,
|
||||
results: List[Tuple[str, List[str], List[str]]],
|
||||
):
|
||||
if not os.path.exists(res_dir):
|
||||
os.makedirs(res_dir)
|
||||
recog_path = os.path.join(res_dir, "recogs.txt")
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
errs_filename = os.path.join(res_dir, "errs.txt")
|
||||
with open(errs_filename, "w") as f:
|
||||
_ = write_error_stats(f, "test", results, enable_log=True)
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
|
||||
class SpeechEvalDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, wav_path: str, transcript_path: str):
|
||||
super().__init__()
|
||||
self.audio_name = []
|
||||
self.audio_paths = []
|
||||
self.transcripts = []
|
||||
with Path(transcript_path).open("r", encoding="utf8") as f:
|
||||
meta = [item.split("\t") for item in f.read().rstrip().split("\n")]
|
||||
for item in meta:
|
||||
self.audio_name.append(item[0])
|
||||
self.audio_paths.append(Path(wav_path, item[0] + ".wav"))
|
||||
self.transcripts.append(item[1])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audio_paths)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
audio, sampling_rate = sf.read(self.audio_paths[index])
|
||||
item = {
|
||||
"array": librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000),
|
||||
"sampling_rate": 16000,
|
||||
"reference": self.transcripts[index],
|
||||
"audio_name": self.audio_name[index],
|
||||
}
|
||||
return item
|
||||
|
||||
|
||||
def main(args):
|
||||
|
||||
batch_size = args.batch_size
|
||||
|
||||
pipe = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
model=args.model_path,
|
||||
device=args.device,
|
||||
tokenizer=args.model_path,
|
||||
)
|
||||
|
||||
dataset = SpeechEvalDataset(args.wav_path, args.transcript_path)
|
||||
|
||||
results = []
|
||||
bar = tqdm(
|
||||
pipe(
|
||||
dataset,
|
||||
generate_kwargs={"language": "english", "task": "transcribe"},
|
||||
batch_size=batch_size,
|
||||
),
|
||||
total=len(dataset),
|
||||
)
|
||||
for out in bar:
|
||||
results.append(
|
||||
(
|
||||
out["audio_name"][0],
|
||||
post_process(out["reference"][0].strip()).split(),
|
||||
post_process(out["text"].strip()).split(),
|
||||
)
|
||||
)
|
||||
save_results(args.decode_path, results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
main(args)
|
1
egs/ljspeech/TTS/vocos/export-onnx.py
Symbolic link
1
egs/ljspeech/TTS/vocos/export-onnx.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libritts/TTS/vocos/export-onnx.py
|
1
egs/ljspeech/TTS/vocos/export.py
Symbolic link
1
egs/ljspeech/TTS/vocos/export.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libritts/TTS/vocos/export.py
|
340
egs/ljspeech/TTS/vocos/infer.py
Executable file
340
egs/ljspeech/TTS/vocos/infer.py
Executable file
@ -0,0 +1,340 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
|
||||
# Han Zhu)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from lhotse.utils import fix_random_seed
|
||||
from scipy.io.wavfile import write
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
from tts_datamodule import LJSpeechTtsDataModule
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="flow_match/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--generate-dir",
|
||||
type=str,
|
||||
default="generated_wavs",
|
||||
help="Path name of the generated wavs",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
batch: dict,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The text-to-feature neural model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
|
||||
cut_ids = [cut.id for cut in batch["cut"]]
|
||||
|
||||
features = batch["features"] # (B, T, F)
|
||||
utt_durations = batch["features_lens"]
|
||||
|
||||
x = features.permute(0, 2, 1) # (B, F, T)
|
||||
|
||||
audios = model(x.to(device)) # (B, T)
|
||||
|
||||
wav_dir = f"{params.res_dir}/{params.suffix}"
|
||||
os.makedirs(wav_dir, exist_ok=True)
|
||||
|
||||
for i in range(audios.shape[0]):
|
||||
audio = audios[i][: (utt_durations[i] - 1) * 256 + 1024]
|
||||
audio = audio.cpu().squeeze().numpy()
|
||||
write(f"{wav_dir}/{cut_ids[i]}.wav", 22050, audio)
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
test_set: str,
|
||||
):
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The text-to-feature neural model.
|
||||
test_set:
|
||||
The name of the test_set
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
with open(f"{params.res_dir}/{test_set}.scp", "w", encoding="utf8") as f:
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["text"]
|
||||
cut_ids = [cut.id for cut in batch["cut"]]
|
||||
|
||||
decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
assert len(texts) == len(cut_ids), (len(texts), len(cut_ids))
|
||||
|
||||
for i in range(len(texts)):
|
||||
f.write(f"{cut_ids[i]}\t{texts[i]}\n")
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % 50 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LJSpeechTtsDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
params.res_dir = params.exp_dir / params.generate_dir
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
params.device = device
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
logging.info(params)
|
||||
fix_random_seed(666)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
ljspeech = LJSpeechTtsDataModule(args)
|
||||
|
||||
test_cuts = ljspeech.test_cuts()
|
||||
|
||||
test_dl = ljspeech.test_dataloaders(test_cuts)
|
||||
|
||||
test_sets = ["test"]
|
||||
test_dls = [test_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
test_set=test_set,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/ljspeech/TTS/vocos/onnx_pretrained.py
Symbolic link
1
egs/ljspeech/TTS/vocos/onnx_pretrained.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libritts/TTS/vocos/onnx_pretrained.py
|
1
egs/ljspeech/TTS/vocos/pretrained.py
Symbolic link
1
egs/ljspeech/TTS/vocos/pretrained.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libritts/TTS/vocos/pretrained.py
|
@ -52,9 +52,11 @@ from utils import (
|
||||
save_checkpoint,
|
||||
plot_spectrogram,
|
||||
get_cosine_schedule_with_warmup,
|
||||
save_checkpoint_with_global_batch_idx,
|
||||
)
|
||||
|
||||
from icefall import diagnostics
|
||||
from icefall.checkpoint import remove_checkpoints, update_averaged_model
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
@ -65,7 +67,7 @@ from icefall.utils import (
|
||||
str2bool,
|
||||
get_parameter_groups_with_lrs,
|
||||
)
|
||||
from models import Vocos
|
||||
from model import Vocos
|
||||
from lhotse import Fbank, FbankConfig
|
||||
|
||||
|
||||
@ -91,6 +93,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
help="Intermediate dim of ConvNeXt module.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-seconds",
|
||||
type=int,
|
||||
default=60,
|
||||
help="""
|
||||
The length of the precomputed normalization window sum square
|
||||
(required by istft). This argument is only for onnx export, it determines
|
||||
the max length of the audio that be properly normalized.
|
||||
Note, you can generate audios longer than this value with the exported onnx model,
|
||||
the part longer than this value will not be normalized yet.
|
||||
The larger this value is the bigger the exported onnx model will be.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -203,6 +219,16 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--keep-last-epoch-k",
|
||||
type=int,
|
||||
default=50,
|
||||
help="""Only keep this number of checkpoints on disk.
|
||||
For instance, if it is 3, there are only 3 checkpoints
|
||||
in the exp-dir with filenames `epoch-xxx.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--average-period",
|
||||
type=int,
|
||||
@ -290,8 +316,8 @@ def get_params() -> AttributeDict:
|
||||
"valid_interval": 500,
|
||||
"feature_dim": 80,
|
||||
"segment_size": 16384,
|
||||
"adam_b1": 0.8,
|
||||
"adam_b2": 0.9,
|
||||
"adam_b1": 0.9,
|
||||
"adam_b2": 0.99,
|
||||
"warmup_steps": 0,
|
||||
"max_steps": 2000000,
|
||||
"env_info": get_env_info(),
|
||||
@ -311,18 +337,17 @@ def get_model(params: AttributeDict) -> nn.Module:
|
||||
intermediate_dim=params.intermediate_dim,
|
||||
num_layers=params.num_layers,
|
||||
sample_rate=params.sampling_rate,
|
||||
max_seconds=params.max_seconds,
|
||||
).to(device)
|
||||
|
||||
num_param_head = sum([p.numel() for p in model.head.parameters()])
|
||||
logging.info(f"Number of Head parameters : {num_param_head}")
|
||||
num_param_bone = sum([p.numel() for p in model.backbone.parameters()])
|
||||
logging.info(f"Number of Generator parameters : {num_param_bone}")
|
||||
num_param_gen = sum([p.numel() for p in model.generator.parameters()])
|
||||
logging.info(f"Number of Generator parameters : {num_param_gen}")
|
||||
num_param_mpd = sum([p.numel() for p in model.mpd.parameters()])
|
||||
logging.info(f"Number of MultiPeriodDiscriminator parameters : {num_param_mpd}")
|
||||
num_param_mrd = sum([p.numel() for p in model.mrd.parameters()])
|
||||
logging.info(f"Number of MultiResolutionDiscriminator parameters : {num_param_mrd}")
|
||||
logging.info(
|
||||
f"Number of model parameters : {num_param_head + num_param_bone + num_param_mpd + num_param_mrd}"
|
||||
f"Number of model parameters : {num_param_gen + num_param_mpd + num_param_mrd}"
|
||||
)
|
||||
return model
|
||||
|
||||
@ -481,11 +506,6 @@ def compute_discriminator_loss(
|
||||
info["loss_disc_mrd"] = loss_mrd.detach().cpu().item()
|
||||
info["loss_disc_mpd"] = loss_mpd.detach().cpu().item()
|
||||
|
||||
for i in range(len(loss_mpd_real)):
|
||||
info[f"loss_disc_mpd_period_{i+1}"] = loss_mpd_real[i] + loss_mpd_gen[i]
|
||||
for i in range(len(loss_mrd_real)):
|
||||
info[f"loss_disc_mrd_resolution_{i+1}"] = loss_mrd_real[i] + loss_mrd_gen[i]
|
||||
|
||||
return loss_disc_all, info
|
||||
|
||||
|
||||
@ -499,6 +519,7 @@ def train_one_epoch(
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
@ -544,6 +565,7 @@ def train_one_epoch(
|
||||
save_checkpoint(
|
||||
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
params=params,
|
||||
optimizer_g=optimizer_g,
|
||||
optimizer_d=optimizer_d,
|
||||
@ -566,10 +588,6 @@ def train_one_epoch(
|
||||
params.segment_size - params.frame_length
|
||||
) // params.frame_shift + 1
|
||||
|
||||
# segment_frames = (
|
||||
# params.segment_size + params.frame_shift // 2
|
||||
# ) // params.frame_shift
|
||||
|
||||
start_p = random.randint(0, features_lens.min() - (segment_frames + 1))
|
||||
|
||||
features = features[:, start_p : start_p + segment_frames, :].permute(
|
||||
@ -594,6 +612,7 @@ def train_one_epoch(
|
||||
|
||||
loss_disc.backward()
|
||||
optimizer_d.step()
|
||||
scheduler_d.step()
|
||||
|
||||
optimizer_g.zero_grad()
|
||||
loss_gen, loss_gen_info = compute_generator_loss(
|
||||
@ -605,6 +624,7 @@ def train_one_epoch(
|
||||
|
||||
loss_gen.backward()
|
||||
optimizer_g.step()
|
||||
scheduler_g.step()
|
||||
|
||||
loss_info = loss_gen_info + loss_disc_info
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_gen_info
|
||||
@ -617,6 +637,39 @@ def train_one_epoch(
|
||||
if params.print_diagnostics and batch_idx == 5:
|
||||
return
|
||||
|
||||
if (
|
||||
rank == 0
|
||||
and params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.average_period == 0
|
||||
):
|
||||
update_averaged_model(
|
||||
params=params,
|
||||
model_cur=model,
|
||||
model_avg=model_avg,
|
||||
)
|
||||
|
||||
if (
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
sampler=train_dl.sampler,
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if params.batch_idx_train % 100 == 0 and params.use_fp16:
|
||||
# If the grad scale was less than 1, try increasing it. The _growth_interval
|
||||
# of the grad scaler is configurable, but we can't configure it to have different
|
||||
@ -647,8 +700,8 @@ def train_one_epoch(
|
||||
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
||||
f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
|
||||
f"loss[{loss_info}], tot_loss[{tot_loss}], "
|
||||
f"cur_lr_g: {cur_lr_g:.2e}, "
|
||||
f"cur_lr_d: {cur_lr_d:.2e}, "
|
||||
f"cur_lr_g: {cur_lr_g:.4e}, "
|
||||
f"cur_lr_d: {cur_lr_d:.4e}, "
|
||||
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
|
||||
)
|
||||
|
||||
@ -668,11 +721,10 @@ def train_one_epoch(
|
||||
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
||||
)
|
||||
|
||||
# if (
|
||||
# params.batch_idx_train % params.valid_interval == 0
|
||||
# and not params.print_diagnostics
|
||||
# ):
|
||||
if True:
|
||||
if (
|
||||
params.batch_idx_train % params.valid_interval == 0
|
||||
and not params.print_diagnostics
|
||||
):
|
||||
logging.info("Computing validation loss")
|
||||
valid_info = compute_validation_loss(
|
||||
params=params,
|
||||
@ -692,8 +744,6 @@ def train_one_epoch(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
|
||||
scheduler_g.step()
|
||||
scheduler_d.step()
|
||||
loss_value = tot_loss["loss_gen"]
|
||||
params.train_loss = loss_value
|
||||
if params.train_loss < params.best_train_loss:
|
||||
@ -773,7 +823,7 @@ def compute_validation_loss(
|
||||
params.sampling_rate,
|
||||
)
|
||||
|
||||
logging.info(f"RTF : {infer_time / (audio_time * 10 / 1000)}")
|
||||
logging.info(f"Validation RTF : {infer_time / (audio_time * 10 / 1000)}")
|
||||
|
||||
if world_size > 1:
|
||||
tot_loss.reduce(device)
|
||||
@ -818,19 +868,25 @@ def run(rank, world_size, args):
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", rank)
|
||||
logging.info(f"Device: {device}")
|
||||
params.device = device
|
||||
logging.info(params)
|
||||
logging.info("About to create model")
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
|
||||
assert params.save_every_n >= params.average_period
|
||||
model_avg: Optional[nn.Module] = None
|
||||
if rank == 0:
|
||||
# model_avg is only used with rank 0
|
||||
model_avg = copy.deepcopy(model).to(torch.float64)
|
||||
|
||||
assert params.start_epoch > 0, params.start_epoch
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
checkpoints = load_checkpoint_if_available(
|
||||
params=params, model=model, model_avg=model_avg
|
||||
)
|
||||
|
||||
model = model.to(device)
|
||||
head = model.head
|
||||
backbone = model.backbone
|
||||
generator = model.generator
|
||||
mrd = model.mrd
|
||||
mpd = model.mpd
|
||||
if world_size > 1:
|
||||
@ -838,7 +894,7 @@ def run(rank, world_size, args):
|
||||
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
||||
|
||||
optimizer_g = torch.optim.AdamW(
|
||||
itertools.chain(head.parameters(), backbone.parameters()),
|
||||
generator.parameters(),
|
||||
params.learning_rate,
|
||||
betas=[params.adam_b1, params.adam_b2],
|
||||
)
|
||||
@ -923,6 +979,7 @@ def run(rank, world_size, args):
|
||||
train_one_epoch(
|
||||
params=params,
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
optimizer_g=optimizer_g,
|
||||
optimizer_d=optimizer_d,
|
||||
scheduler_g=scheduler_g,
|
||||
@ -944,6 +1001,7 @@ def run(rank, world_size, args):
|
||||
filename=filename,
|
||||
params=params,
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
optimizer_g=optimizer_g,
|
||||
optimizer_d=optimizer_d,
|
||||
scheduler_g=scheduler_g,
|
||||
@ -953,28 +1011,20 @@ def run(rank, world_size, args):
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if params.batch_idx_train % params.save_every_n == 0:
|
||||
filename = params.exp_dir / f"checkpoint-{params.batch_idx_train}.pt"
|
||||
save_checkpoint(
|
||||
filename=filename,
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer_g=optimizer_g,
|
||||
optimizer_d=optimizer_d,
|
||||
scheduler_g=scheduler_g,
|
||||
scheduler_d=scheduler_d,
|
||||
sampler=train_dl.sampler,
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
if rank == 0:
|
||||
if params.best_train_epoch == params.cur_epoch:
|
||||
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||
copyfile(src=filename, dst=best_train_filename)
|
||||
if params.best_train_epoch == params.cur_epoch:
|
||||
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||
copyfile(src=filename, dst=best_train_filename)
|
||||
|
||||
if params.best_valid_epoch == params.cur_epoch:
|
||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
if params.best_valid_epoch == params.cur_epoch:
|
||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_epoch_k,
|
||||
prefix="epoch",
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
@ -997,7 +1047,8 @@ def main():
|
||||
run(rank=0, world_size=1, args=args)
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
main()
|
||||
|
@ -250,18 +250,22 @@ def save_checkpoint_with_global_batch_idx(
|
||||
)
|
||||
|
||||
|
||||
def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
|
||||
def find_checkpoints(
|
||||
out_dir: Path,
|
||||
iteration: int = 0,
|
||||
prefix: str = "checkpoint",
|
||||
) -> List[str]:
|
||||
"""Find all available checkpoints in a directory.
|
||||
|
||||
The checkpoint filenames have the form: `checkpoint-xxx.pt`
|
||||
The checkpoint filenames have the form: `{prefix}-xxx.pt`
|
||||
where xxx is a numerical value.
|
||||
|
||||
Assume you have the following checkpoints in the folder `foo`:
|
||||
|
||||
- checkpoint-1.pt
|
||||
- checkpoint-20.pt
|
||||
- checkpoint-300.pt
|
||||
- checkpoint-4000.pt
|
||||
- {prefix}-1.pt
|
||||
- {prefix}-20.pt
|
||||
- {prefix}-300.pt
|
||||
- {prefix}-4000.pt
|
||||
|
||||
Case 1 (Return all checkpoints)::
|
||||
|
||||
@ -290,8 +294,8 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
|
||||
Return a list of checkpoint filenames, sorted in descending
|
||||
order by the numerical value in the filename.
|
||||
"""
|
||||
checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
|
||||
pattern = re.compile(r"checkpoint-([0-9]+).pt")
|
||||
checkpoints = list(glob.glob(f"{out_dir}/{prefix}-[0-9]*.pt"))
|
||||
pattern = re.compile(rf"{prefix}-([0-9]+).pt")
|
||||
iter_checkpoints = []
|
||||
for c in checkpoints:
|
||||
result = pattern.search(c)
|
||||
@ -316,12 +320,13 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
|
||||
def remove_checkpoints(
|
||||
out_dir: Path,
|
||||
topk: int,
|
||||
prefix: str = "checkpoint",
|
||||
rank: int = 0,
|
||||
):
|
||||
"""Remove checkpoints from the given directory.
|
||||
|
||||
We assume that checkpoint filename has the form `checkpoint-xxx.pt`
|
||||
where xxx is a number, representing the number of processed batches
|
||||
We assume that checkpoint filename has the form `{prefix}-xxx.pt`
|
||||
where xxx is a number, representing the number of processed batches/epochs
|
||||
when saving that checkpoint. We sort checkpoints by filename and keep
|
||||
only the `topk` checkpoints with the highest `xxx`.
|
||||
|
||||
@ -330,6 +335,8 @@ def remove_checkpoints(
|
||||
The directory containing checkpoints to be removed.
|
||||
topk:
|
||||
Number of checkpoints to keep.
|
||||
prefix:
|
||||
The prefix of the checkpoint filename, normally `epoch`, `checkpoint`.
|
||||
rank:
|
||||
If using DDP for training, it is the rank of the current node.
|
||||
Use 0 if no DDP is used for training.
|
||||
@ -337,7 +344,7 @@ def remove_checkpoints(
|
||||
assert topk >= 1, topk
|
||||
if rank != 0:
|
||||
return
|
||||
checkpoints = find_checkpoints(out_dir)
|
||||
checkpoints = find_checkpoints(out_dir, prefix=prefix)
|
||||
|
||||
if len(checkpoints) == 0:
|
||||
logging.warn(f"No checkpoints found in {out_dir}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user