mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00

Refactor, add libritts

This commit is contained in:
  parent 6e07cb91e3
  commit d8a0a40955
@@ -3,7 +3,7 @@ from typing import List, Optional, Tuple
 import torch
 from torch import nn
 from torch.nn import Conv2d
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm
 from torchaudio.transforms import Spectrogram
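The only change in this hunk is the import path: `torch.nn.utils.weight_norm` is deprecated since PyTorch 2.1 in favor of the parametrization-based `torch.nn.utils.parametrizations.weight_norm`. A minimal sketch of the migration (hypothetical module, not from the diff):

import torch
from torch import nn
from torch.nn.utils.parametrizations import weight_norm  # new location

# Call sites are unchanged: both versions wrap a module in place.
conv = weight_norm(nn.Conv1d(16, 32, kernel_size=3, padding=1))

# Difference under the hood: the norm/direction tensors now live under
# conv.parametrizations.weight (original0/original1) instead of the old
# conv.weight_g / conv.weight_v attributes.
x = torch.randn(1, 16, 100)
y = conv(x)  # (1, 32, 100)
print(y.shape)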
371  egs/libritts/TTS/vocos/export-onnx.py  (Executable file)
@@ -0,0 +1,371 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)

"""
This script exports a Vocos vocoder model from PyTorch to ONNX.

Usage (the values follow the argparse defaults below):

./vocos/export-onnx.py \
  --exp-dir vocos/exp \
  --epoch 28 \
  --avg 15 \
  --fp16 1

It will generate the following files inside the given --exp-dir:

  - vocos-epoch-28-avg-15.onnx
  - vocos-epoch-28-avg-15.fp16.onnx (only when --fp16 is given)
  - vocos-epoch-28-avg-15.int8.onnx (only when --fp16 is given)

See ./onnx_pretrained.py for how to use the exported ONNX models.
"""

import argparse
import logging
from pathlib import Path
from typing import Dict

import onnx
import torch
import torch.nn as nn
from onnxconverter_common import float16
from onnxruntime.quantization import QuantType, quantize_dynamic
from train import add_model_arguments, get_model, get_params

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import str2bool


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--sampling-rate",
        type=int,
        default=24000,
        help="The sampling rate of the libritts dataset",
    )

    parser.add_argument(
        "--frame-shift",
        type=int,
        default=256,
        help="Frame shift in samples.",
    )

    parser.add_argument(
        "--frame-length",
        type=int,
        default=1024,
        help="Frame length in samples.",
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 0.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load the averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging.",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="vocos/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, logs, etc., are saved
        """,
    )

    parser.add_argument(
        "--fp16",
        type=str2bool,
        default=False,
        help="Whether to export models in fp16",
    )

    add_model_arguments(parser)

    return parser


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = value

    onnx.save(model, filename)


def export_model_onnx(
    model: nn.Module,
    model_filename: str,
    opset_version: int = 13,
) -> None:
    """Export the vocoder generator to ONNX format.

    The exported model has one input:

      - features: a float32 tensor of shape (N, C, T)

    and produces one output:

      - audio: a float32 tensor of shape (N, T')
    """
    input_tensor = torch.rand((2, 80, 100), dtype=torch.float32)

    torch.onnx.export(
        model,
        (input_tensor,),
        model_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=[
            "features",
        ],
        output_names=["audio"],
        dynamic_axes={
            "features": {0: "N", 2: "F"},
            "audio": {0: "N", 1: "T"},
        },
    )

    meta_data = {
        "model_type": "Vocos",
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "ConvNext Vocos",
    }
    logging.info(f"meta_data: {meta_data}")

    add_meta_data(filename=model_filename, meta_data=meta_data)


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    params.device = device
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)
    model.to(device)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.eval()
    vocos = model.generator

    if params.iter > 0:
        suffix = f"iter-{params.iter}"
    else:
        suffix = f"epoch-{params.epoch}"

    suffix += f"-avg-{params.avg}"

    opset_version = 13

    logging.info("Exporting model")
    model_filename = params.exp_dir / f"vocos-{suffix}.onnx"
    export_model_onnx(
        vocos,
        model_filename,
        opset_version=opset_version,
    )
    logging.info(f"Exported vocos generator to {model_filename}")

    if params.fp16:
        logging.info("Generate fp16 models")

        model = onnx.load(model_filename)
        model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
        model_filename_fp16 = params.exp_dir / f"vocos-{suffix}.fp16.onnx"
        onnx.save(model_fp16, model_filename_fp16)

        # Generate int8 quantization models
        # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

        logging.info("Generate int8 quantization models")

        model_filename_int8 = params.exp_dir / f"vocos-{suffix}.int8.onnx"
        quantize_dynamic(
            model_input=model_filename,
            model_output=model_filename_int8,
            op_types_to_quantize=["MatMul"],
            weight_type=QuantType.QInt8,
        )


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
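A quick way to sanity-check the exported file is to load it with onnxruntime; a minimal sketch (not part of the commit; the filename assumes the defaults above):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(
    "vocos/exp/vocos-epoch-28-avg-15.onnx",
    providers=["CPUExecutionProvider"],
)

# The key-value pairs written by add_meta_data() come back via the model
# metadata map.
meta = sess.get_modelmeta().custom_metadata_map
print(meta.get("model_type"), meta.get("comment"))  # "Vocos", "ConvNext Vocos"

# Feed a dummy 80-bin feature batch; the output should be an (N, T') waveform.
features = np.random.randn(1, 80, 50).astype(np.float32)
(audio,) = sess.run(["audio"], {"features": features})
print(audio.shape)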
407  egs/libritts/TTS/vocos/export.py  (Executable file)
@@ -0,0 +1,407 @@
#!/usr/bin/env python3
#
# Copyright 2024 Xiaomi Corporation (Author: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:

Note: This is an example for the libritts dataset. If you are using a
different dataset, you should change the argument values accordingly.

(1) Export to torchscript model using torch.jit.script()

./vocos/export.py \
  --exp-dir ./vocos/exp \
  --epoch 30 \
  --avg 9 \
  --jit 1

It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("jit_script.pt")`.

Check ./jit_pretrained.py for its usage.

Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.

(2) Export `model.state_dict()`

./vocos/export.py \
  --exp-dir ./vocos/exp \
  --epoch 30 \
  --avg 9

It will generate a file `generator.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.

Check ./pretrained.py for its usage.
"""

import argparse
import logging
from pathlib import Path
from typing import Tuple

import torch
from torch import Tensor, nn
from train import add_model_arguments, get_model, get_params

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
)

from icefall.utils import make_pad_mask, str2bool
from utils import load_checkpoint


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--sampling-rate",
        type=int,
        default=24000,
        help="The sampling rate of the libritts dataset",
    )

    parser.add_argument(
        "--frame-shift",
        type=int,
        default=256,
        help="Frame shift in samples.",
    )

    parser.add_argument(
        "--frame-length",
        type=int,
        default=1024,
        help="Frame length in samples.",
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=9,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load the averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging.",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="vocos/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, logs, etc., are saved
        """,
    )

    parser.add_argument(
        "--jit",
        type=str2bool,
        default=False,
        help="""True to save a model after applying torch.jit.script.
        It will generate a file named jit_script.pt.
        Check ./jit_pretrained.py for how to use it.
        """,
    )

    add_model_arguments(parser)

    return parser


class EncoderModel(nn.Module):
    """A wrapper for encoder and encoder_embed"""

    def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
        super().__init__()
        self.encoder = encoder
        self.encoder_embed = encoder_embed

    def forward(
        self, features: Tensor, feature_lengths: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          features: (N, T, C)
          feature_lengths: (N,)
        """
        x, x_lens = self.encoder_embed(features, feature_lengths)

        src_key_padding_mask = make_pad_mask(x_lens)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

        return encoder_out, encoder_out_lens


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    params.device = device
    logging.info(f"device: {device}")

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.eval()

    model = model.generator

    if params.jit is True:
        model.encoder = EncoderModel(model.encoder, model.encoder_embed)
        filename = "jit_script.pt"

        logging.info("Using torch.jit.script")
        model = torch.jit.script(model)
        model.save(str(params.exp_dir / filename))
        logging.info(f"Saved to {filename}")
    else:
        logging.info("Not using torchscript. Export model.state_dict()")
        # Save it using a format so that it can be loaded
        # by :func:`load_checkpoint`
        filename = params.exp_dir / "generator.pt"
        torch.save({"model": model.state_dict()}, str(filename))
        logging.info(f"Saved to {filename}")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
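For the state_dict path, a minimal loading sketch (assumptions: run from the recipe directory so the recipe-local train.py is importable, and generator.pt was produced by the script above with model defaults):

import argparse

import torch
from train import add_model_arguments, get_model, get_params  # recipe-local

# Recover the default model hyper-parameters the same way export.py does.
parser = argparse.ArgumentParser()
add_model_arguments(parser)
params = get_params()
params.update(vars(parser.parse_args([])))

model = get_model(params)
ckpt = torch.load("vocos/exp/generator.pt", map_location="cpu")
model.generator.load_state_dict(ckpt["model"])  # export.py saved {"model": ...}
model.generator.eval()

with torch.no_grad():
    # 80 mel bins assumed, matching the (2, 80, 100) dummy input in export-onnx.py
    audio = model.generator(torch.randn(1, 80, 50))  # -> (N, T') waveform
print(audio.shape)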
@@ -1,122 +1,154 @@
-import torch
-from torch import nn
+import logging
 
 from typing import Optional
 
+import numpy as np
+import torch
+from torch import nn
+from torch.autograd import Variable
+from torch.nn import functional as F
 
-class AdaLayerNorm(nn.Module):
-    """
-    Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
-
-    Args:
-        num_embeddings (int): Number of embeddings.
-        embedding_dim (int): Dimension of the embeddings.
-    """
-
-    def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.dim = embedding_dim
-        self.scale = nn.Embedding(
-            num_embeddings=num_embeddings, embedding_dim=embedding_dim
-        )
-        self.shift = nn.Embedding(
-            num_embeddings=num_embeddings, embedding_dim=embedding_dim
-        )
-        torch.nn.init.ones_(self.scale.weight)
-        torch.nn.init.zeros_(self.shift.weight)
-
-    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
-        scale = self.scale(cond_embedding_id)
-        shift = self.shift(cond_embedding_id)
-        x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
-        x = x * scale + shift
-        return x
-
-
-class ISTFT(nn.Module):
-    """
-    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
-    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
-    See issue: https://github.com/pytorch/pytorch/issues/62323
-    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
-
-    Args:
-        n_fft (int): Size of Fourier transform.
-        hop_length (int): The distance between neighboring sliding window frames.
-        win_length (int): The size of window frame and STFT filter.
-        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
-    """
-
-    def __init__(
-        self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
-    ):
-        super().__init__()
-        if padding not in ["center", "same"]:
-            raise ValueError("Padding must be 'center' or 'same'.")
-        self.padding = padding
-        self.n_fft = n_fft
-        self.hop_length = hop_length
-        self.win_length = win_length
-        window = torch.hann_window(win_length)
-        self.register_buffer("window", window)
-
-    def forward(self, spec: torch.Tensor) -> torch.Tensor:
-        """
-        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
-
-        Args:
-            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
-                N is the number of frequency bins, and T is the number of time frames.
-
-        Returns:
-            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
-        """
-        if self.padding == "center":
-            # Fallback to pytorch native implementation
-            return torch.istft(
-                spec,
-                self.n_fft,
-                self.hop_length,
-                self.win_length,
-                self.window,
-                center=True,
-            )
-        elif self.padding == "same":
-            pad = (self.win_length - self.hop_length) // 2
-        else:
-            raise ValueError("Padding must be 'center' or 'same'.")
-
-        assert spec.dim() == 3, "Expected a 3D tensor as input"
-        B, N, T = spec.shape
-
-        # Inverse FFT
-        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
-        ifft = ifft * self.window[None, :, None]
-
-        # Overlap and Add
-        output_size = (T - 1) * self.hop_length + self.win_length
-        y = torch.nn.functional.fold(
-            ifft,
-            output_size=(1, output_size),
-            kernel_size=(1, self.win_length),
-            stride=(1, self.hop_length),
-        )[:, 0, 0, :]
-
-        # Window envelope
-        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
-        window_envelope = torch.nn.functional.fold(
-            window_sq,
-            output_size=(1, output_size),
-            kernel_size=(1, self.win_length),
-            stride=(1, self.hop_length),
-        ).squeeze()
-
-        # Normalize
-        norm_indexes = window_envelope > 1e-11
-        y[:, norm_indexes] = y[:, norm_indexes] / window_envelope[norm_indexes]
-
-        return y
+
+def window_sumsquare(
+    window: torch.Tensor,
+    n_samples: int,
+    hop_length: int = 256,
+    win_length: int = 1024,
+):
+    """
+    Compute the sum-square envelope of a window function at a given hop length.
+    This is used to estimate modulation effects induced by windowing
+    observations in short-time fourier transforms.
+
+    Parameters
+    ----------
+    window : string, tuple, number, callable, or list-like
+        Window specification, as in `get_window`
+    n_samples : int > 0
+        The number of expected samples.
+    hop_length : int > 0
+        The number of samples to advance between frames
+    win_length :
+        The length of the window function.
+
+    Returns
+    -------
+    wss : torch.Tensor, The sum-squared envelope of the window function.
+    """
+    n_frames = (n_samples - win_length) // hop_length + 1
+    output_size = (n_frames - 1) * hop_length + win_length
+    device = window.device
+
+    # Window envelope
+    window_sq = window.square().expand(1, n_frames, -1).transpose(1, 2)
+    window_envelope = torch.nn.functional.fold(
+        window_sq,
+        output_size=(1, output_size),
+        kernel_size=(1, win_length),
+        stride=(1, hop_length),
+    ).squeeze()
+    window_envelope = torch.nn.functional.pad(
+        window_envelope, (0, n_samples - output_size)
+    )
+    return window_envelope
+
+
+class ISTFT(torch.nn.Module):
+    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
+
+    def __init__(
+        self,
+        filter_length: int = 1024,
+        hop_length: int = 256,
+        win_length: int = 1024,
+        padding: str = "none",
+        window_type: str = "povey",
+        max_samples: int = 1440000,  # 1440000 / 24000 = 60s
+    ):
+        super(ISTFT, self).__init__()
+        self.filter_length = filter_length
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.padding = padding
+        scale = self.filter_length / self.hop_length
+        fourier_basis = np.fft.fft(np.eye(self.filter_length))
+
+        cutoff = int((self.filter_length / 2 + 1))
+        fourier_basis = np.vstack(
+            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
+        )
+        inverse_basis = torch.FloatTensor(
+            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
+        )
+
+        assert filter_length >= win_length
+        # Consistence with lhotse, search "create_frame_window" in https://github.com/lhotse-speech/lhotse
+        assert window_type in [
+            "hanning",
+            "povey",
+        ], f"Only 'hanning' and 'povey' windows are supported, given {window_type}."
+        fft_window = torch.hann_window(win_length, periodic=False)
+        if window_type == "povey":
+            fft_window = fft_window.pow(0.85)
+
+        if filter_length > win_length:
+            pad_size = (filter_length - win_length) // 2
+            fft_window = torch.nn.functional.pad(fft_window, (pad_size, pad_size))
+
+        window_sum = window_sumsquare(
+            window=fft_window,
+            n_samples=max_samples,
+            hop_length=hop_length,
+            win_length=filter_length,
+        )
+
+        inverse_basis *= fft_window
+
+        self.register_buffer("inverse_basis", inverse_basis.float())
+        self.register_buffer("fft_window", fft_window)
+        self.register_buffer("window_sum", window_sum)
+        self.tiny = torch.finfo(torch.float16).tiny
+
+    def forward(self, magnitude, phase):
+        magnitude_phase = torch.cat(
+            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
+        )
+        inverse_transform = F.conv_transpose1d(
+            magnitude_phase,
+            Variable(self.inverse_basis, requires_grad=False),
+            stride=self.hop_length,
+            padding=0,
+        )
+        inverse_transform = inverse_transform.squeeze(1)
+
+        window_sum = self.window_sum
+        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+            if self.window_sum.size(-1) < inverse_transform.size(-1):
+                logging.warning(
+                    f"The precomputed `window_sumsquare` is too small, recomputing, "
+                    f"from {self.window_sum.size(-1)} to {inverse_transform.size(-1)}"
+                )
+                window_sum = window_sumsquare(
+                    window=self.fft_window,
+                    n_samples=inverse_transform.size(-1),
+                    win_length=self.filter_length,
+                    hop_length=self.hop_length,
+                )
+        window_sum = window_sum[: inverse_transform.size(-1)]
+        approx_nonzero_indices = (window_sum > self.tiny).nonzero().squeeze()
+
+        inverse_transform[:, approx_nonzero_indices] /= window_sum[
+            approx_nonzero_indices
+        ]
+
+        # scale by hop ratio
+        inverse_transform *= float(self.filter_length) / self.hop_length
+        assert self.padding in ["none", "same", "center"]
+        if self.padding == "center":
+            pad_len = self.filter_length // 2
+        elif self.padding == "same":
+            pad_len = (self.filter_length - self.hop_length) // 2
+        else:
+            return inverse_transform
+
+        return inverse_transform[:, pad_len:-pad_len]
 
 
 class ConvNeXtBlock(nn.Module):
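For orientation, a small calling sketch (assuming the new ISTFT class above is in scope): unlike the old version, which consumed a complex spectrogram, the new forward takes magnitude and phase separately and synthesizes via a transposed convolution with the precomputed inverse Fourier basis.

import torch

istft = ISTFT(
    filter_length=1024,
    hop_length=256,
    win_length=1024,
    padding="none",
    window_type="povey",
    max_samples=24000 * 60,  # up to 60 s of audio at 24 kHz
)

n_bins = 1024 // 2 + 1                       # 513 frequency bins
mag = torch.rand(2, n_bins, 100)             # (B, N, T) magnitudes
phase = torch.rand(2, n_bins, 100) * 2 * torch.pi

audio = istft(mag, phase)                    # (B, (T - 1) * 256 + 1024)
print(audio.shape)                           # torch.Size([2, 26368])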
@@ -127,8 +159,6 @@ class ConvNeXtBlock(nn.Module):
         intermediate_dim (int): Dimensionality of the intermediate layer.
         layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
             Defaults to None.
-        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
-            None means non-conditional LayerNorm. Defaults to None.
     """

     def __init__(
@@ -136,20 +166,14 @@ class ConvNeXtBlock(nn.Module):
         dim: int,
         intermediate_dim: int,
         layer_scale_init_value: Optional[float] = None,
-        adanorm_num_embeddings: Optional[int] = None,
     ):
         super().__init__()
         self.dwconv = nn.Conv1d(
             dim, dim, kernel_size=7, padding=3, groups=dim
         )  # depthwise conv
-        self.adanorm = adanorm_num_embeddings is not None
-        if adanorm_num_embeddings:
-            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
-        else:
-            self.norm = nn.LayerNorm(dim, eps=1e-6)
-        self.pwconv1 = nn.Linear(
-            dim, intermediate_dim
-        )  # pointwise/1x1 convs, implemented with linear layers
+        self.norm = nn.LayerNorm(dim, eps=1e-6)
+        # pointwise/1x1 convs, implemented with linear layers
+        self.pwconv1 = nn.Linear(dim, intermediate_dim)
         self.act = nn.GELU()
         self.pwconv2 = nn.Linear(intermediate_dim, dim)
         self.gamma = (
@@ -159,15 +183,12 @@ class ConvNeXtBlock(nn.Module):
         )

     def forward(
-        self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
+        self,
+        x: torch.Tensor,
     ) -> torch.Tensor:
         residual = x
         x = self.dwconv(x)
         x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
-        if self.adanorm:
-            assert cond_embedding_id is not None
-            x = self.norm(x, cond_embedding_id)
-        else:
-            x = self.norm(x)
+        x = self.norm(x)
         x = self.pwconv1(x)
         x = self.act(x)
@@ -189,28 +210,22 @@ class Generator(torch.nn.Module):
         hop_length: int = 256,
         intermediate_dim: int = 1536,
         num_layers: int = 8,
-        padding: str = "same",
-        layer_scale_init_value: Optional[float] = None,
-        adanorm_num_embeddings: Optional[int] = None,
+        padding: str = "none",
+        max_samples: int = 1440000,  # 1440000 / 24000 = 60s
     ):
         super(Generator, self).__init__()
         self.feature_dim = feature_dim
         self.embed = nn.Conv1d(feature_dim, dim, kernel_size=7, padding=3)

-        self.adanorm = adanorm_num_embeddings is not None
-        if adanorm_num_embeddings:
-            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
-        else:
-            self.norm = nn.LayerNorm(dim, eps=1e-6)
+        self.norm = nn.LayerNorm(dim, eps=1e-6)

-        layer_scale_init_value = layer_scale_init_value or 1 / num_layers
+        layer_scale_init_value = 1 / num_layers
         self.convnext = nn.ModuleList(
             [
                 ConvNeXtBlock(
                     dim=dim,
                     intermediate_dim=intermediate_dim,
                     layer_scale_init_value=layer_scale_init_value,
-                    adanorm_num_embeddings=adanorm_num_embeddings,
                 )
                 for _ in range(num_layers)
             ]
@@ -221,7 +236,11 @@ class Generator(torch.nn.Module):

         self.out_proj = torch.nn.Linear(dim, n_fft + 2)
         self.istft = ISTFT(
-            n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
+            filter_length=n_fft,
+            hop_length=hop_length,
+            win_length=n_fft,
+            padding=padding,
+            max_samples=max_samples,
         )

     def _init_weights(self, m):
@@ -229,29 +248,17 @@ class Generator(torch.nn.Module):
         nn.init.trunc_normal_(m.weight, std=0.02)
         nn.init.constant_(m.bias, 0)

-    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
-        bandwidth_id = kwargs.get("bandwidth_id", None)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.embed(x)
-        if self.adanorm:
-            assert bandwidth_id is not None
-            x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
-        else:
-            x = self.norm(x.transpose(1, 2))
+        x = self.norm(x.transpose(1, 2))
         x = x.transpose(1, 2)
         for conv_block in self.convnext:
-            x = conv_block(x, cond_embedding_id=bandwidth_id)
+            x = conv_block(x)
         x = self.final_layer_norm(x.transpose(1, 2))
         x = self.out_proj(x).transpose(1, 2)
-        mag, p = x.chunk(2, dim=1)
+        mag, phase = x.chunk(2, dim=1)
         mag = torch.exp(mag)
-        mag = torch.clip(
-            mag, max=1e2
-        )  # safeguard to prevent excessively large magnitudes
-        x = torch.cos(p)
-        y = torch.sin(p)
-        S = mag * (x + 1j * y)
-        audio = self.istft(S)
+        # safeguard to prevent excessively large magnitudes
+        mag = torch.clip(mag, max=1e2)
+        audio = self.istft(mag, phase)
         return audio
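Both the old and the new code divide by the overlap-added squared-window envelope; the fold()-based window_sumsquare can be sanity-checked against a naive loop with a short sketch (not part of the commit):

import torch

def naive_window_sumsquare(window, n_samples, hop_length, win_length):
    # Literal overlap-add of the squared window, one frame at a time.
    env = torch.zeros(n_samples)
    n_frames = (n_samples - win_length) // hop_length + 1
    for i in range(n_frames):
        start = i * hop_length
        env[start : start + win_length] += window.square()
    return env

win_length, hop_length, n_samples = 1024, 256, 8192
window = torch.hann_window(win_length, periodic=False).pow(0.85)  # "povey" window

# Same computation expressed with fold(), as in window_sumsquare above.
n_frames = (n_samples - win_length) // hop_length + 1
output_size = (n_frames - 1) * hop_length + win_length
window_sq = window.square().expand(1, n_frames, -1).transpose(1, 2)
env_fold = torch.nn.functional.fold(
    window_sq,
    output_size=(1, output_size),
    kernel_size=(1, win_length),
    stride=(1, hop_length),
).squeeze()
env_fold = torch.nn.functional.pad(env_fold, (0, n_samples - output_size))

expected = naive_window_sumsquare(window, n_samples, hop_length, win_length)
assert torch.allclose(env_fold, expected, atol=1e-5)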
40  egs/libritts/TTS/vocos/infer.py  (Normal file → Executable file)
@@ -20,6 +20,7 @@ import argparse
 import json
 import logging
 import math
+import time
 import os
 from functools import partial
 from pathlib import Path
@@ -29,7 +30,7 @@ import torch.nn as nn
 from lhotse.utils import fix_random_seed
 from scipy.io.wavfile import write
 from train import add_model_arguments, get_model, get_params
-from tts_datamodule import LJSpeechTtsDataModule
+from tts_datamodule import LibriTTSDataModule

 from icefall.checkpoint import (
     average_checkpoints,
@@ -89,7 +90,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="flow_match/exp",
+        default="vocos/exp",
         help="The experiment dir",
     )

@@ -128,22 +129,31 @@ def decode_one_batch(

     cut_ids = [cut.id for cut in batch["cut"]]

+    infer_time = 0
+    audio_time = 0
+
     features = batch["features"]  # (B, T, F)
     utt_durations = batch["features_lens"]

     x = features.permute(0, 2, 1)  # (B, F, T)

+    audio_time += torch.sum(utt_durations)
+
+    start = time.time()
+
     audios = model(x.to(device))  # (B, T)

+    infer_time += time.time() - start
+
     wav_dir = f"{params.res_dir}/{params.suffix}"
     os.makedirs(wav_dir, exist_ok=True)

     for i in range(audios.shape[0]):
-        audio = audios[i][
-            : int(utt_durations[i] * params.frame_shift_ms / 1000 * 22050)
-        ]
+        audio = audios[i][: int(utt_durations[i] * 256)]
         audio = audio.cpu().squeeze().numpy()
-        write(f"{wav_dir}/{cut_ids[i]}.wav", 22050, audio)
+        write(f"{wav_dir}/{cut_ids[i]}.wav", 24000, audio)
+
+    print(f"RTF : {infer_time / (audio_time * (256/24000))}")


 def decode_dataset(
@@ -173,7 +183,7 @@ def decode_dataset(

     with open(f"{params.res_dir}/{test_set}.scp", "w", encoding="utf8") as f:
         for batch_idx, batch in enumerate(dl):
-            texts = batch["text"]
+            # texts = batch["text"]
             cut_ids = [cut.id for cut in batch["cut"]]

             decode_one_batch(
@@ -182,12 +192,12 @@ def decode_dataset(
                 batch=batch,
             )

-            assert len(texts) == len(cut_ids), (len(texts), len(cut_ids))
+            # assert len(texts) == len(cut_ids), (len(texts), len(cut_ids))

-            for i in range(len(texts)):
-                f.write(f"{cut_ids[i]}\t{texts[i]}\n")
+            # for i in range(len(texts)):
+            #     f.write(f"{cut_ids[i]}\t{texts[i]}\n")

-            num_cuts += len(texts)
+            # num_cuts += len(texts)

             if batch_idx % 50 == 0:
                 batch_str = f"{batch_idx}/{num_batches}"
@@ -200,7 +210,7 @@ def decode_dataset(
 @torch.no_grad()
 def main():
     parser = get_parser()
-    LJSpeechTtsDataModule.add_arguments(parser)
+    LibriTTSDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)

@@ -318,11 +328,11 @@ def main():

     # we need cut ids to display recognition results.
     args.return_cuts = True
-    ljspeech = LJSpeechTtsDataModule(args)
+    libritts = LibriTTSDataModule(args)

-    test_cuts = ljspeech.test_cuts()
+    test_cuts = libritts.test_clean_cuts()

-    test_dl = ljspeech.test_dataloaders(test_cuts)
+    test_dl = libritts.test_dataloaders(test_cuts)

     test_sets = ["test"]
     test_dls = [test_dl]
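How the added RTF line is derived (a sketch, using the recipe's 256-sample hop at 24 kHz): `features_lens` counts frames, so frames * hop / sample_rate gives audio seconds, and the real-time factor is compute time divided by audio time.

hop_length, sample_rate = 256, 24000

def rtf(infer_time_s: float, total_frames: int) -> float:
    # Convert the frame count to seconds of audio, then divide.
    audio_seconds = total_frames * hop_length / sample_rate
    return infer_time_s / audio_seconds

print(rtf(0.5, 9375))  # 9375 frames = 100 s of audio -> RTF 0.005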
@@ -19,8 +19,9 @@ class Vocos(torch.nn.Module):
         hop_length: int = 256,
         intermediate_dim: int = 1536,
         num_layers: int = 8,
-        padding: str = "same",
+        padding: str = "none",
         sample_rate: int = 24000,
+        max_seconds: int = 60,
     ):
         super(Vocos, self).__init__()
         self.generator = Generator(
@@ -31,6 +32,7 @@ class Vocos(torch.nn.Module):
             num_layers=num_layers,
             intermediate_dim=intermediate_dim,
             padding=padding,
+            max_samples=int(sample_rate * max_seconds),
         )

         self.mpd = MultiPeriodDiscriminator()
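The new max_seconds knob only bounds the window envelope that ISTFT precomputes; a one-line sketch of the arithmetic it wires through:

sample_rate, max_seconds = 24000, 60
max_samples = int(sample_rate * max_seconds)
assert max_samples == 1_440_000  # matches the 1440000 default inside ISTFT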
268  egs/libritts/TTS/vocos/onnx_pretrained.py  (Executable file)
@ -0,0 +1,268 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This script loads ONNX models and uses them to decode waves.
|
||||||
|
You can use the following command to get the exported models:
|
||||||
|
|
||||||
|
We use the pre-trained model from
|
||||||
|
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||||
|
as an example to show how to use this file.
|
||||||
|
|
||||||
|
1. Download the pre-trained model
|
||||||
|
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
2. Export the model to ONNX
|
||||||
|
|
||||||
|
./zipformer/export-onnx.py \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--causal False
|
||||||
|
|
||||||
|
It will generate the following 3 files inside $repo/exp:
|
||||||
|
|
||||||
|
- encoder-epoch-99-avg-1.onnx
|
||||||
|
- decoder-epoch-99-avg-1.onnx
|
||||||
|
- joiner-epoch-99-avg-1.onnx
|
||||||
|
|
||||||
|
3. Run this file
|
||||||
|
|
||||||
|
./zipformer/onnx_pretrained.py \
|
||||||
|
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||||
|
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||||
|
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import onnxruntime as ort
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from lhotse import Fbank, FbankConfig
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the encoder onnx model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sampling-rate",
|
||||||
|
type=int,
|
||||||
|
default=24000,
|
||||||
|
help="The sampleing rate of libritts dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--frame-shift",
|
||||||
|
type=int,
|
||||||
|
default=256,
|
||||||
|
help="Frame shift.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--frame-length",
|
||||||
|
type=int,
|
||||||
|
default=1024,
|
||||||
|
help="Frame shift.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-fft-mag",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to use magnitude of fbank, false to use power energy.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
type=str,
|
||||||
|
default="generated_audios",
|
||||||
|
help="The generated will be written to.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_files",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="The input sound file(s) to transcribe. "
|
||||||
|
"Supported formats are those supported by torchaudio.load(). "
|
||||||
|
"For example, wav and flac are supported. "
|
||||||
|
"The sample rate has to be 16kHz.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxModel:
    def __init__(
        self,
        model_filename: str,
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 4

        self.session_opts = session_opts

        self.init_model(model_filename)

    def init_model(self, model_filename: str):
        self.model = ort.InferenceSession(
            model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

    def run_model(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 3-D tensor of shape (N, C, T) containing the input features.
        Returns:
          Return a tensor of shape (N, T') containing the generated audio.
        """
        out = self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                self.model.get_inputs()[0].name: x.numpy(),
            },
        )
        return torch.from_numpy(out[0])

def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert (
            sample_rate == expected_sample_rate
        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        # We use only the first channel
        ans.append(wave[0])
    return ans

@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()

    output_dir = Path(args.model_filename).parent / args.output_dir
    output_dir.mkdir(exist_ok=True)
    args.output_dir = output_dir
    logging.info(vars(args))

    model = OnnxModel(model_filename=args.model_filename)

    config = FbankConfig(
        sampling_rate=args.sampling_rate,
        frame_length=args.frame_length / args.sampling_rate,  # (in seconds)
        frame_shift=args.frame_shift / args.sampling_rate,  # (in seconds)
        use_fft_mag=args.use_fft_mag,
    )
    fbank = Fbank(config)

    logging.info(f"Reading sound files: {args.sound_files}")

    waves = read_sound_files(
        filenames=args.sound_files, expected_sample_rate=args.sampling_rate
    )
    wave_lengths = [w.size(0) for w in waves]
    waves = pad_sequence(waves, batch_first=True, padding_value=0)

    logging.info(f"waves : {waves.shape}")

    features = fbank.extract_batch(waves, sampling_rate=args.sampling_rate)

    if features.dim() == 2:
        # A single utterance; restore the batch dimension.
        features = features.unsqueeze(0)

    # (N, T, C) -> (N, C, T), the layout the vocoder expects.
    features = features.permute(0, 2, 1)

    logging.info(f"features : {features.shape}")

    logging.info("Generating started")

    # model forward
    audios = model.run_model(features)

    for i, filename in enumerate(args.sound_files):
        audio = audios[i : i + 1, 0 : wave_lengths[i]]
        ofilename = args.output_dir / filename.split("/")[-1]
        logging.info(f"Writing audio : {ofilename}")
        torchaudio.save(str(ofilename), audio.cpu(), args.sampling_rate)

    logging.info("Generating Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
196
egs/libritts/TTS/vocos/pretrained.py
Executable file
@@ -0,0 +1,196 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:
"""

import argparse
import logging
import math
from pathlib import Path
from typing import List

import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
from lhotse import Fbank, FbankConfig

from icefall.utils import str2bool

def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the checkpoint. "
        "The checkpoint is assumed to be saved by "
        "icefall.checkpoint.save_checkpoint().",
    )

    parser.add_argument(
        "--sampling-rate",
        type=int,
        default=24000,
        help="The sampling rate of the libritts dataset.",
    )

    parser.add_argument(
        "--frame-shift",
        type=int,
        default=256,
        help="Frame shift in samples.",
    )

    parser.add_argument(
        "--frame-length",
        type=int,
        default=1024,
        help="Frame length in samples.",
    )

    parser.add_argument(
        "--use-fft-mag",
        type=str2bool,
        default=True,
        help="Whether to use the magnitude of the fbank spectrum; false to use power energy.",
    )

    parser.add_argument(
        "--output-dir",
        type=str,
        default="generated_audios",
        help="The directory the generated audios will be written to.",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to reconstruct. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to match --sampling-rate.",
    )

    add_model_arguments(parser)

    return parser

def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert (
            sample_rate == expected_sample_rate
        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        # We use only the first channel
        ans.append(wave[0].contiguous())
    return ans

@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()

    params.update(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    params.device = device

    output_dir = Path(params.checkpoint).parent / params.output_dir
    output_dir.mkdir(exist_ok=True)
    params.output_dir = output_dir

    logging.info(f"{params}")

    logging.info("Creating model")
    model = get_model(params)

    # Only the generator is needed for inference.
    model = model.generator

    checkpoint = torch.load(params.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["model"], strict=False)
    model.to(device)
    model.eval()

    logging.info("Constructing Fbank computer")

    config = FbankConfig(
        sampling_rate=params.sampling_rate,
        frame_length=params.frame_length / params.sampling_rate,  # (in seconds)
        frame_shift=params.frame_shift / params.sampling_rate,  # (in seconds)
        use_fft_mag=params.use_fft_mag,
    )
    fbank = Fbank(config)

    logging.info(f"Reading sound files: {params.sound_files}")

    waves = read_sound_files(
        filenames=params.sound_files, expected_sample_rate=params.sampling_rate
    )
    wave_lengths = [w.size(0) for w in waves]
    waves = pad_sequence(waves, batch_first=True, padding_value=0)

    # (N, T, C) -> (N, C, T), the layout the generator expects.
    features = (
        fbank.extract_batch(waves, sampling_rate=params.sampling_rate)
        .permute(0, 2, 1)
        .to(device)
    )

    logging.info("Generating started")

    # model forward
    audios = model(features)

    for i, filename in enumerate(params.sound_files):
        audio = audios[i : i + 1, 0 : wave_lengths[i]]
        ofilename = params.output_dir / filename.split("/")[-1]
        logging.info(f"Writing audio : {ofilename}")
        torchaudio.save(str(ofilename), audio.cpu(), params.sampling_rate)

    logging.info("Generating Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
@@ -52,9 +52,11 @@ from utils import (
    save_checkpoint,
    plot_spectrogram,
    get_cosine_schedule_with_warmup,
+   save_checkpoint_with_global_batch_idx,
)

from icefall import diagnostics
+from icefall.checkpoint import remove_checkpoints, update_averaged_model
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
@@ -91,6 +93,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        help="Intermediate dim of ConvNeXt module.",
    )

+   parser.add_argument(
+       "--max-seconds",
+       type=int,
+       default=60,
+       help="""
+       The length of the precomputed normalization window sum square
+       (required by istft). This argument is only for onnx export; it determines
+       the max length of the audio that can be properly normalized.
+       Note, you can generate audios longer than this value with the exported
+       onnx model, but the part beyond this value will not be normalized.
+       The larger this value is, the bigger the exported onnx model will be.
+       """,
+   )

def get_parser():
    parser = argparse.ArgumentParser(
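
A note on the `--max-seconds` help text above: the "normalization window sum square" is the envelope the inverse STFT divides by when overlap-adding frames, and exporting it as a fixed buffer is what ties the onnx model size to this argument. Below is a rough illustration of that envelope, assuming the recipe defaults of frame_length=1024, frame_shift=256, 24 kHz audio, and a Hann window; a sketch only, not the recipe's export code:

import torch

# Assumed values mirroring the defaults discussed above.
sample_rate = 24000
frame_length = 1024  # analysis window size in samples
frame_shift = 256  # hop size in samples
max_seconds = 60

total = max_seconds * sample_rate
num_frames = (total - frame_length) // frame_shift + 1
window = torch.hann_window(frame_length)

# Sum of squared, overlapped windows: the denominator istft normalizes by.
envelope = torch.zeros(total)
for i in range(num_frames):
    start = i * frame_shift
    envelope[start : start + frame_length] += window**2

# A larger max_seconds means a longer precomputed envelope and hence a
# bigger exported onnx model; audio beyond this length stays unnormalized.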
@@ -203,6 +219,16 @@ def get_parser():
        """,
    )

+   parser.add_argument(
+       "--keep-last-epoch-k",
+       type=int,
+       default=50,
+       help="""Only keep this number of epoch checkpoints on disk.
+       For instance, if it is 3, there are only 3 checkpoints
+       in the exp-dir with filenames `epoch-xxx.pt`.
+       """,
+   )
+
    parser.add_argument(
        "--average-period",
        type=int,
@@ -290,8 +316,8 @@ def get_params() -> AttributeDict:
        "valid_interval": 500,
        "feature_dim": 80,
        "segment_size": 16384,
-       "adam_b1": 0.8,
-       "adam_b2": 0.9,
+       "adam_b1": 0.9,
+       "adam_b2": 0.99,
        "warmup_steps": 0,
        "max_steps": 2000000,
        "env_info": get_env_info(),
@@ -311,6 +337,7 @@ def get_model(params: AttributeDict) -> nn.Module:
        intermediate_dim=params.intermediate_dim,
        num_layers=params.num_layers,
        sample_rate=params.sampling_rate,
+       max_seconds=params.max_seconds,
    ).to(device)

    num_param_gen = sum([p.numel() for p in model.generator.parameters()])
@@ -479,11 +506,6 @@ def compute_discriminator_loss(
    info["loss_disc_mrd"] = loss_mrd.detach().cpu().item()
    info["loss_disc_mpd"] = loss_mpd.detach().cpu().item()

-   for i in range(len(loss_mpd_real)):
-       info[f"loss_disc_mpd_period_{i+1}"] = loss_mpd_real[i] + loss_mpd_gen[i]
-   for i in range(len(loss_mrd_real)):
-       info[f"loss_disc_mrd_resolution_{i+1}"] = loss_mrd_real[i] + loss_mrd_gen[i]
-
    return loss_disc_all, info

@@ -497,6 +519,7 @@ def train_one_epoch(
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
+   model_avg: Optional[nn.Module] = None,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
    rank: int = 0,
@@ -542,6 +565,7 @@ def train_one_epoch(
            save_checkpoint(
                filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
                model=model,
+               model_avg=model_avg,
                params=params,
                optimizer_g=optimizer_g,
                optimizer_d=optimizer_d,
@@ -588,6 +612,7 @@ def train_one_epoch(

        loss_disc.backward()
        optimizer_d.step()
+       scheduler_d.step()

        optimizer_g.zero_grad()
        loss_gen, loss_gen_info = compute_generator_loss(
@@ -599,6 +624,7 @@ def train_one_epoch(

        loss_gen.backward()
        optimizer_g.step()
+       scheduler_g.step()

        loss_info = loss_gen_info + loss_disc_info
        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_gen_info
@@ -611,6 +637,39 @@ def train_one_epoch(
        if params.print_diagnostics and batch_idx == 5:
            return

+       if (
+           rank == 0
+           and params.batch_idx_train > 0
+           and params.batch_idx_train % params.average_period == 0
+       ):
+           update_averaged_model(
+               params=params,
+               model_cur=model,
+               model_avg=model_avg,
+           )
+
+       if (
+           params.batch_idx_train > 0
+           and params.batch_idx_train % params.save_every_n == 0
+       ):
+           save_checkpoint_with_global_batch_idx(
+               out_dir=params.exp_dir,
+               global_batch_idx=params.batch_idx_train,
+               model=model,
+               model_avg=model_avg,
+               params=params,
+               optimizer_g=optimizer_g,
+               optimizer_d=optimizer_d,
+               scheduler_g=scheduler_g,
+               scheduler_d=scheduler_d,
+               sampler=train_dl.sampler,
+               scaler=scaler,
+               rank=rank,
+           )
+           remove_checkpoints(
+               out_dir=params.exp_dir,
+               topk=params.keep_last_k,
+               rank=rank,
+           )
+
        if params.batch_idx_train % 100 == 0 and params.use_fp16:
            # If the grad scale was less than 1, try increasing it. The _growth_interval
            # of the grad scaler is configurable, but we can't configure it to have different
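
For context on the `update_averaged_model` call added above: it maintains a running average of the model weights, refreshed every `average_period` batches. The sketch below shows the general idea of blending the current weights into a stored average; it is a simplified illustration, not icefall's exact implementation (the real weight schedule lives in icefall.checkpoint):

import torch
import torch.nn as nn

def blend_into_average(model_cur: nn.Module, model_avg: nn.Module, weight: float):
    """In place: model_avg = (1 - weight) * model_avg + weight * model_cur.

    A weight that shrinks over training (e.g. average_period / batch_idx_train)
    makes every visited checkpoint contribute roughly equally to the mean.
    """
    avg_state = model_avg.state_dict()
    cur_state = model_cur.state_dict()
    for name, avg_param in avg_state.items():
        if avg_param.is_floating_point():
            # state_dict() returns references, so this mutates model_avg.
            avg_param.mul_(1.0 - weight).add_(
                cur_state[name].to(avg_param.dtype), alpha=weight
            )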
@@ -641,8 +700,8 @@ def train_one_epoch(
                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
                f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
                f"loss[{loss_info}], tot_loss[{tot_loss}], "
-               f"cur_lr_g: {cur_lr_g:.2e}, "
-               f"cur_lr_d: {cur_lr_d:.2e}, "
+               f"cur_lr_g: {cur_lr_g:.4e}, "
+               f"cur_lr_d: {cur_lr_d:.4e}, "
                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
            )
@@ -685,8 +744,6 @@ def train_one_epoch(
                tb_writer, "train/valid_", params.batch_idx_train
            )

-   scheduler_g.step()
-   scheduler_d.step()
    loss_value = tot_loss["loss_gen"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
@@ -766,7 +823,7 @@ def compute_validation_loss(
                params.sampling_rate,
            )

-   logging.info(f"RTF : {infer_time / (audio_time * 10 / 1000)}")
+   logging.info(f"Validation RTF : {infer_time / (audio_time * 10 / 1000)}")

    if world_size > 1:
        tot_loss.reduce(device)
@@ -811,15 +868,22 @@ def run(rank, world_size, args):
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
-   logging.info(f"Device: {device}")
    params.device = device
    logging.info(params)
-   logging.info("About to create model")
+
+   logging.info("About to create model")
    model = get_model(params)
+
+   assert params.save_every_n >= params.average_period
+   model_avg: Optional[nn.Module] = None
+   if rank == 0:
+       # model_avg is only used with rank 0
+       model_avg = copy.deepcopy(model).to(torch.float64)
+
    assert params.start_epoch > 0, params.start_epoch
-   checkpoints = load_checkpoint_if_available(params=params, model=model)
+   checkpoints = load_checkpoint_if_available(
+       params=params, model=model, model_avg=model_avg
+   )

    model = model.to(device)
    generator = model.generator
@@ -915,6 +979,7 @@ def run(rank, world_size, args):
        train_one_epoch(
            params=params,
            model=model,
+           model_avg=model_avg,
            optimizer_g=optimizer_g,
            optimizer_d=optimizer_d,
            scheduler_g=scheduler_g,
@@ -936,6 +1001,7 @@ def run(rank, world_size, args):
            filename=filename,
            params=params,
            model=model,
+           model_avg=model_avg,
            optimizer_g=optimizer_g,
            optimizer_d=optimizer_d,
            scheduler_g=scheduler_g,
@@ -945,21 +1011,6 @@ def run(rank, world_size, args):
            rank=rank,
        )

-       if params.batch_idx_train % params.save_every_n == 0:
-           filename = params.exp_dir / f"checkpoint-{params.batch_idx_train}.pt"
-           save_checkpoint(
-               filename=filename,
-               params=params,
-               model=model,
-               optimizer_g=optimizer_g,
-               optimizer_d=optimizer_d,
-               scheduler_g=scheduler_g,
-               scheduler_d=scheduler_d,
-               sampler=train_dl.sampler,
-               scaler=scaler,
-               rank=rank,
-           )
-       if rank == 0:
        if params.best_train_epoch == params.cur_epoch:
            best_train_filename = params.exp_dir / "best-train-loss.pt"
            copyfile(src=filename, dst=best_train_filename)
@@ -968,6 +1019,13 @@ def run(rank, world_size, args):
            best_valid_filename = params.exp_dir / "best-valid-loss.pt"
            copyfile(src=filename, dst=best_valid_filename)

+       remove_checkpoints(
+           out_dir=params.exp_dir,
+           topk=params.keep_last_epoch_k,
+           prefix="epoch",
+           rank=rank,
+       )
+
    logging.info("Done!")

    if world_size > 1:
@@ -34,6 +34,69 @@ def plot_spectrogram(spectrogram):
    return fig


+def save_checkpoint_with_global_batch_idx(
+   out_dir: Path,
+   global_batch_idx: int,
+   model: Union[nn.Module, DDP],
+   model_avg: Optional[nn.Module] = None,
+   params: Optional[Dict[str, Any]] = None,
+   optimizer_g: Optional[Optimizer] = None,
+   optimizer_d: Optional[Optimizer] = None,
+   scheduler_g: Optional[LRScheduler] = None,
+   scheduler_d: Optional[LRScheduler] = None,
+   scaler: Optional[GradScaler] = None,
+   sampler: Optional[CutSampler] = None,
+   rank: int = 0,
+):
+   """Save training info after processing given number of batches.
+
+   Args:
+     out_dir:
+       The directory to save the checkpoint.
+     global_batch_idx:
+       The number of batches processed so far from the very start of the
+       training. The saved checkpoint will have the following filename:
+
+           f'out_dir / checkpoint-{global_batch_idx}.pt'
+     model:
+       The neural network model whose `state_dict` will be saved in the
+       checkpoint.
+     model_avg:
+       The stored model averaged from the start of training.
+     params:
+       A dict of training configurations to be saved.
+     optimizer_g / optimizer_d:
+       The optimizers used in the training. Their `state_dict` will be saved.
+     scheduler_g / scheduler_d:
+       The learning rate schedulers used in the training. Their `state_dict`
+       will be saved.
+     scaler:
+       The scaler used for mixed precision training. Its `state_dict` will
+       be saved.
+     sampler:
+       The sampler used in the training dataset.
+     rank:
+       The rank ID used in DDP training of the current node. Set it to 0
+       if DDP is not used.
+   """
+   out_dir = Path(out_dir)
+   out_dir.mkdir(parents=True, exist_ok=True)
+   filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
+   save_checkpoint(
+       filename=filename,
+       model=model,
+       model_avg=model_avg,
+       params=params,
+       optimizer_g=optimizer_g,
+       scheduler_g=scheduler_g,
+       optimizer_d=optimizer_d,
+       scheduler_d=scheduler_d,
+       scaler=scaler,
+       sampler=sampler,
+       rank=rank,
+   )
+
+
def load_checkpoint(
    filename: Path,
    model: nn.Module,
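
The checkpoint-{global_batch_idx}.pt naming scheme above is what makes pruning by recency possible: `remove_checkpoints` only has to sort filenames by their embedded index. A rough sketch of that pruning idea follows (an illustration only; the actual implementation lives in icefall.checkpoint):

from pathlib import Path

def prune_checkpoints(out_dir: Path, topk: int, prefix: str = "checkpoint"):
    # Collect files named f"{prefix}-<idx>.pt" and sort newest-first by <idx>.
    ckpts = sorted(
        Path(out_dir).glob(f"{prefix}-*.pt"),
        key=lambda p: int(p.stem.split("-")[-1]),
        reverse=True,
    )
    # Keep the topk most recent checkpoints and delete the rest.
    for stale in ckpts[topk:]:
        stale.unlink()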
287
egs/ljspeech/TTS/local/evaluate_fsd.py
Normal file
@@ -0,0 +1,287 @@
"""
|
||||||
|
Calculate Frechet Speech Distance betweeen two speech directories.
|
||||||
|
Adapted from: https://github.com/gudgud96/frechet-audio-distance/blob/main/frechet_audio_distance/fad.py
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from multiprocessing.dummy import Pool as ThreadPool
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
from scipy import linalg
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--real-path", type=str, help="path of the real speech directory"
    )
    parser.add_argument(
        "--eval-path", type=str, help="path of the evaluated speech directory"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="model/huggingface/wav2vec2_base",
        help="path of the wav2vec 2.0 model directory",
    )
    parser.add_argument(
        "--real-embds-path",
        type=str,
        default=None,
        help="path of the real embedding directory",
    )
    parser.add_argument(
        "--eval-embds-path",
        type=str,
        default=None,
        help="path of the evaluated embedding directory",
    )
    return parser

class FrechetSpeechDistance:
    def __init__(
        self,
        model_path="resources/wav2vec2_base",
        pca_dim=128,
        speech_load_worker=8,
    ):
        """
        Initialize FSD
        """
        self.sample_rate = 16000
        self.channels = 1
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        logging.info("[Frechet Speech Distance] Using device: {}".format(self.device))
        self.speech_load_worker = speech_load_worker
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_path)
        self.model = Wav2Vec2Model.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()
        self.pca_dim = pca_dim

    def load_speech_files(self, dir, dtype="float32"):
        def _load_speech_task(fname, sample_rate, channels, dtype="float32"):
            if dtype not in ["float64", "float32", "int32", "int16"]:
                raise ValueError(f"dtype not supported: {dtype}")

            wav_data, sr = sf.read(fname, dtype=dtype)
            # For integer type PCM input, convert to [-1.0, +1.0]
            if dtype == "int16":
                wav_data = wav_data / 32768.0
            elif dtype == "int32":
                wav_data = wav_data / float(2**31)

            # Convert to mono
            assert channels in [1, 2], "channels must be 1 or 2"
            if len(wav_data.shape) > channels:
                wav_data = np.mean(wav_data, axis=1)

            # Resample to the model's 16 kHz input rate if needed.
            if sr != sample_rate:
                wav_data = librosa.resample(
                    wav_data, orig_sr=sr, target_sr=sample_rate
                )

            return wav_data

        task_results = []

        pool = ThreadPool(self.speech_load_worker)

        logging.info("[Frechet Speech Distance] Loading speech from {}...".format(dir))
        for fname in os.listdir(dir):
            res = pool.apply_async(
                _load_speech_task,
                args=(os.path.join(dir, fname), self.sample_rate, self.channels, dtype),
            )
            task_results.append(res)
        pool.close()
        pool.join()

        return [k.get() for k in task_results]

    def get_embeddings(self, x):
        """
        Get embeddings
        Params:
        -- x   : a list of np.ndarray speech samples
        -- sr  : sampling rate.
        """
        embd_lst = []
        try:
            for speech in tqdm(x):
                input_features = self.feature_extractor(
                    speech, sampling_rate=self.sample_rate, return_tensors="pt"
                ).input_values.to(self.device)
                with torch.no_grad():
                    embd = self.model(input_features).last_hidden_state.mean(1)

                if embd.device != torch.device("cpu"):
                    embd = embd.cpu()

                if torch.is_tensor(embd):
                    embd = embd.detach().numpy()

                embd_lst.append(embd)
        except Exception as e:
            print(
                "[Frechet Speech Distance] get_embeddings threw an exception: {}".format(
                    str(e)
                )
            )

        return np.concatenate(embd_lst, axis=0)

    def calculate_embd_statistics(self, embd_lst):
        if isinstance(embd_lst, list):
            embd_lst = np.array(embd_lst)
        mu = np.mean(embd_lst, axis=0)
        sigma = np.cov(embd_lst, rowvar=False)
        return mu, sigma

    def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
        """
        Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py

        Numpy implementation of the Frechet Distance.
        The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
        and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
        Stable version by Dougal J. Sutherland.
        Params:
        -- mu1   : Numpy array containing the activations of a layer of the
                   inception net (like returned by the function 'get_predictions')
                   for generated samples.
        -- mu2   : The sample mean over activations, precalculated on a
                   representative data set.
        -- sigma1: The covariance matrix over activations for generated samples.
        -- sigma2: The covariance matrix over activations, precalculated on a
                   representative data set.
        Returns:
        --   : The Frechet Distance.
        """

        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)

        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)

        assert (
            mu1.shape == mu2.shape
        ), "Training and test mean vectors have different lengths"
        assert (
            sigma1.shape == sigma2.shape
        ), "Training and test covariances have different dimensions"

        diff = mu1 - mu2

        # Product might be almost singular
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2).astype(complex), disp=False)
        if not np.isfinite(covmean).all():
            msg = (
                "fid calculation produces singular product; "
                "adding %s to diagonal of cov estimates"
            ) % eps
            logging.info(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm(
                (sigma1 + offset).dot(sigma2 + offset).astype(complex)
            )

        # Numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError("Imaginary component {}".format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)

        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

    def score(
        self,
        real_path,
        eval_path,
        real_embds_path=None,
        eval_embds_path=None,
        dtype="float32",
    ):
        """
        Computes the Frechet Speech Distance (FSD) between two directories of speech files.

        Parameters:
        - real_path (str): Path to the directory containing real speech files.
        - eval_path (str): Path to the directory containing evaluation speech files.
        - real_embds_path (str, optional): Path to save/load real speech embeddings (e.g., /folder/bkg_embs.npy). If None, embeddings won't be saved.
        - eval_embds_path (str, optional): Path to save/load evaluation speech embeddings (e.g., /folder/test_embs.npy). If None, embeddings won't be saved.
        - dtype (str, optional): Data type for loading speech. Default is "float32".

        Returns:
        - float: The Frechet Speech Distance (FSD) score between the two directories of speech files.
        """
        # Load or compute real embeddings
        if real_embds_path is not None and os.path.exists(real_embds_path):
            logging.info(
                f"[Frechet Speech Distance] Loading embeddings from {real_embds_path}..."
            )
            embds_real = np.load(real_embds_path)
        else:
            speech_real = self.load_speech_files(real_path, dtype=dtype)
            embds_real = self.get_embeddings(speech_real)
            if real_embds_path:
                os.makedirs(os.path.dirname(real_embds_path), exist_ok=True)
                np.save(real_embds_path, embds_real)

        # Load or compute eval embeddings
        if eval_embds_path is not None and os.path.exists(eval_embds_path):
            logging.info(
                f"[Frechet Speech Distance] Loading embeddings from {eval_embds_path}..."
            )
            embds_eval = np.load(eval_embds_path)
        else:
            speech_eval = self.load_speech_files(eval_path, dtype=dtype)
            embds_eval = self.get_embeddings(speech_eval)
            if eval_embds_path:
                os.makedirs(os.path.dirname(eval_embds_path), exist_ok=True)
                np.save(eval_embds_path, embds_eval)

        # Check if embeddings are empty
        if len(embds_real) == 0:
            logging.info("[Frechet Speech Distance] real set dir is empty, exiting...")
            return -10.46
        if len(embds_eval) == 0:
            logging.info("[Frechet Speech Distance] eval set dir is empty, exiting...")
            return -1

        # Compute statistics and FSD score
        mu_real, sigma_real = self.calculate_embd_statistics(embds_real)
        mu_eval, sigma_eval = self.calculate_embd_statistics(embds_eval)

        fsd_score = self.calculate_frechet_distance(
            mu_real, sigma_real, mu_eval, sigma_eval
        )

        return fsd_score


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    FSD = FrechetSpeechDistance(model_path=args.model_path)
    score = FSD.score(
        args.real_path, args.eval_path, args.real_embds_path, args.eval_embds_path
    )
    logging.info(f"FSD score: {score:.2f}")
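
The closed-form distance computed by calculate_frechet_distance above is easy to sanity-check on toy Gaussians where the answer is known. A minimal check (numpy/scipy only, independent of the recipe):

import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
# Two synthetic "embedding" sets from known Gaussians: unit covariance,
# means shifted by 1 in each of 4 dimensions.
a = rng.normal(loc=0.0, scale=1.0, size=(50000, 4))
b = rng.normal(loc=1.0, scale=1.0, size=(50000, 4))

mu1, sigma1 = a.mean(axis=0), np.cov(a, rowvar=False)
mu2, sigma2 = b.mean(axis=0), np.cov(b, rowvar=False)

diff = mu1 - mu2
covmean = linalg.sqrtm(sigma1 @ sigma2).real
d2 = diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

# With identical unit covariances the trace terms cancel and the distance
# reduces to ||mu1 - mu2||^2, so d2 should be close to 4.
print(d2)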
139
egs/ljspeech/TTS/local/evaluate_wer_whisper.py
Normal file
@@ -0,0 +1,139 @@
"""
|
||||||
|
Calculate WER with Whisper model
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
from num2words import num2words
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import pipeline
|
||||||
|
|
||||||
|
from icefall.utils import store_transcripts, write_error_stats
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
    parser = argparse.ArgumentParser()

    parser.add_argument("--wav-path", type=str, help="path of the speech directory")
    parser.add_argument(
        "--decode-path", type=str, help="path of the decoding result directory"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="model/huggingface/whisper_medium",
        help="path of the huggingface whisper model",
    )
    parser.add_argument(
        "--transcript-path",
        type=str,
        default="data/transcript/test.tsv",
        help="path of the transcript tsv file",
    )
    parser.add_argument(
        "--batch-size", type=int, default=64, help="decoding batch size"
    )
    parser.add_argument(
        "--device", type=str, default="cuda:0", help="decoding device, cuda:0 or cpu"
    )
    return parser

def post_process(text: str):
    # Spell out small integers so they match written-form references.
    def convert_numbers(match):
        return num2words(match.group())

    text = re.sub(r"\b\d{1,2}\b", convert_numbers, text)
    text = re.sub(r"[^a-zA-Z0-9']", " ", text.lower())
    text = re.sub(r"\s+", " ", text)
    return text


def save_results(
    res_dir: str,
    results: List[Tuple[str, List[str], List[str]]],
):
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)
    recog_path = os.path.join(res_dir, "recogs.txt")
    results = sorted(results)
    store_transcripts(filename=recog_path, texts=results)
    logging.info(f"The transcripts are stored in {recog_path}")

    errs_filename = os.path.join(res_dir, "errs.txt")
    with open(errs_filename, "w") as f:
        _ = write_error_stats(f, "test", results, enable_log=True)
        logging.info("Wrote detailed error stats to {}".format(errs_filename))

class SpeechEvalDataset(torch.utils.data.Dataset):
    def __init__(self, wav_path: str, transcript_path: str):
        super().__init__()
        self.audio_name = []
        self.audio_paths = []
        self.transcripts = []
        with Path(transcript_path).open("r", encoding="utf8") as f:
            meta = [item.split("\t") for item in f.read().rstrip().split("\n")]
        for item in meta:
            self.audio_name.append(item[0])
            self.audio_paths.append(Path(wav_path, item[0] + ".wav"))
            self.transcripts.append(item[1])

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, index: int):
        audio, sampling_rate = sf.read(self.audio_paths[index])
        item = {
            # Whisper expects 16 kHz input.
            "array": librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000),
            "sampling_rate": 16000,
            "reference": self.transcripts[index],
            "audio_name": self.audio_name[index],
        }
        return item


def main(args):

    batch_size = args.batch_size

    pipe = pipeline(
        "automatic-speech-recognition",
        model=args.model_path,
        device=args.device,
        tokenizer=args.model_path,
    )

    dataset = SpeechEvalDataset(args.wav_path, args.transcript_path)

    results = []
    bar = tqdm(
        pipe(
            dataset,
            generate_kwargs={"language": "english", "task": "transcribe"},
            batch_size=batch_size,
        ),
        total=len(dataset),
    )
    for out in bar:
        results.append(
            (
                out["audio_name"][0],
                post_process(out["reference"][0].strip()).split(),
                post_process(out["text"].strip()).split(),
            )
        )
    save_results(args.decode_path, results)


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    main(args)
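
To see what post_process does to a hypothesis before scoring, here is a quick illustrative check (the exact spellings come from the num2words library; the trailing space is an artifact of the whitespace squeeze):

# Reusing post_process as defined above:
print(post_process("Hello, World! It is 7 o'clock."))
# -> "hello world it is seven o'clock "
# (lowercased, punctuation stripped, small integers spelled out)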
1
egs/ljspeech/TTS/vocos/export-onnx.py
Symbolic link
@@ -0,0 +1 @@
../../../libritts/TTS/vocos/export-onnx.py
1
egs/ljspeech/TTS/vocos/export.py
Symbolic link
@@ -0,0 +1 @@
../../../libritts/TTS/vocos/export.py
340
egs/ljspeech/TTS/vocos/infer.py
Executable file
@@ -0,0 +1,340 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
#                                        Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import logging
import math
import os
from functools import partial
from pathlib import Path

import torch
import torch.nn as nn
from lhotse.utils import fix_random_seed
from scipy.io.wavfile import write
from train import add_model_arguments, get_model, get_params
from tts_datamodule import LJSpeechTtsDataModule

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import AttributeDict, setup_logger, str2bool

LOG_EPS = math.log(1e-10)

def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=100,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=10,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=False,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="flow_match/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--generate-dir",
        type=str,
        default="generated_wavs",
        help="Path name of the generated wavs",
    )

    add_model_arguments(parser)

    return parser

def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    batch: dict,
):
    """
    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The feature-to-waveform neural model (the vocoder generator).
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
    Returns:
      Nothing is returned; the generated audios are written to disk.
    """
    device = next(model.parameters()).device

    cut_ids = [cut.id for cut in batch["cut"]]

    features = batch["features"]  # (B, T, F)
    utt_durations = batch["features_lens"]

    x = features.permute(0, 2, 1)  # (B, F, T)

    audios = model(x.to(device))  # (B, T)

    wav_dir = f"{params.res_dir}/{params.suffix}"
    os.makedirs(wav_dir, exist_ok=True)

    for i in range(audios.shape[0]):
        # Number of samples covered by T frames with frame_shift=256 and
        # frame_length=1024 is (T - 1) * 256 + 1024.
        audio = audios[i][: (utt_durations[i] - 1) * 256 + 1024]
        audio = audio.cpu().squeeze().numpy()
        write(f"{wav_dir}/{cut_ids[i]}.wav", 22050, audio)

def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    test_set: str,
):
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The feature-to-waveform neural model (the vocoder generator).
      test_set:
        The name of the test_set
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    with open(f"{params.res_dir}/{test_set}.scp", "w", encoding="utf8") as f:
        for batch_idx, batch in enumerate(dl):
            texts = batch["text"]
            cut_ids = [cut.id for cut in batch["cut"]]

            decode_one_batch(
                params=params,
                model=model,
                batch=batch,
            )

            assert len(texts) == len(cut_ids), (len(texts), len(cut_ids))

            for i in range(len(texts)):
                f.write(f"{cut_ids[i]}\t{texts[i]}\n")

            num_cuts += len(texts)

            if batch_idx % 50 == 0:
                batch_str = f"{batch_idx}/{num_batches}"

                logging.info(
                    f"batch {batch_str}, cuts processed until now is {num_cuts}"
                )

@torch.no_grad()
def main():
    parser = get_parser()
    LJSpeechTtsDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    params.res_dir = params.exp_dir / params.generate_dir

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    params.device = device

    logging.info(f"Device: {device}")

    logging.info(params)
    fix_random_seed(666)

    logging.info("About to create model")
    model = get_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model = model.to(device)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    ljspeech = LJSpeechTtsDataModule(args)

    test_cuts = ljspeech.test_cuts()

    test_dl = ljspeech.test_dataloaders(test_cuts)

    test_sets = ["test"]
    test_dls = [test_dl]

    for test_set, test_dl in zip(test_sets, test_dls):
        decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            test_set=test_set,
        )

    logging.info("Done!")


if __name__ == "__main__":
    main()
1
egs/ljspeech/TTS/vocos/onnx_pretrained.py
Symbolic link
@@ -0,0 +1 @@
../../../libritts/TTS/vocos/onnx_pretrained.py
1
egs/ljspeech/TTS/vocos/pretrained.py
Symbolic link
@@ -0,0 +1 @@
../../../libritts/TTS/vocos/pretrained.py
@@ -52,9 +52,11 @@ from utils import (
    save_checkpoint,
    plot_spectrogram,
    get_cosine_schedule_with_warmup,
+   save_checkpoint_with_global_batch_idx,
)

from icefall import diagnostics
+from icefall.checkpoint import remove_checkpoints, update_averaged_model
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
@@ -65,7 +67,7 @@ from icefall.utils import (
    str2bool,
    get_parameter_groups_with_lrs,
)
-from models import Vocos
+from model import Vocos
from lhotse import Fbank, FbankConfig

@@ -91,6 +93,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Intermediate dim of ConvNeXt module.",
     )
 
+    parser.add_argument(
+        "--max-seconds",
+        type=int,
+        default=60,
+        help="""
+        The length (in seconds) of the precomputed normalization window
+        sum-square buffer required by istft. This argument is only used for
+        ONNX export; it determines the maximum audio length that can be
+        properly normalized. You can generate audio longer than this value
+        with the exported ONNX model, but the part beyond it is not normalized.
+        The larger this value, the bigger the exported ONNX model.
+        """,
+    )
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
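
The --max-seconds help above refers to the overlap-add normalization buffer used by the inverse STFT. A minimal sketch of precomputing such a window sum-square buffer, assuming a Hann window and example n_fft/hop_length/sample_rate values (the recipe's actual head may differ):

import torch

def precompute_window_sumsquare(
    max_seconds: int = 60,
    sample_rate: int = 22050,  # assumed example value
    n_fft: int = 1024,         # assumed example value
    hop_length: int = 256,     # assumed example value
) -> torch.Tensor:
    # Sum of squared, hop-shifted analysis windows; istft divides the
    # overlap-added signal by this envelope to undo the windowing.
    window = torch.hann_window(n_fft)
    n_frames = (max_seconds * sample_rate) // hop_length
    buf = torch.zeros(n_fft + hop_length * (n_frames - 1))
    for i in range(n_frames):
        start = i * hop_length
        buf[start : start + n_fft] += window**2
    return buf

Because the buffer length grows linearly with max_seconds, baking it into the exported graph is what makes the ONNX file larger.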
@@ -203,6 +219,16 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--keep-last-epoch-k",
+        type=int,
+        default=50,
+        help="""Only keep this number of epoch checkpoints on disk.
+        For instance, if it is 3, at most 3 checkpoints with filenames
+        `epoch-xxx.pt` are kept in the exp-dir.
+        """,
+    )
+
     parser.add_argument(
         "--average-period",
         type=int,
@@ -290,8 +316,8 @@ def get_params() -> AttributeDict:
         "valid_interval": 500,
         "feature_dim": 80,
         "segment_size": 16384,
-        "adam_b1": 0.8,
-        "adam_b2": 0.9,
+        "adam_b1": 0.9,
+        "adam_b2": 0.99,
         "warmup_steps": 0,
         "max_steps": 2000000,
         "env_info": get_env_info(),
@@ -311,18 +337,17 @@ def get_model(params: AttributeDict) -> nn.Module:
         intermediate_dim=params.intermediate_dim,
         num_layers=params.num_layers,
         sample_rate=params.sampling_rate,
+        max_seconds=params.max_seconds,
     ).to(device)
 
-    num_param_head = sum([p.numel() for p in model.head.parameters()])
-    logging.info(f"Number of Head parameters : {num_param_head}")
-    num_param_bone = sum([p.numel() for p in model.backbone.parameters()])
-    logging.info(f"Number of Generator parameters : {num_param_bone}")
+    num_param_gen = sum([p.numel() for p in model.generator.parameters()])
+    logging.info(f"Number of Generator parameters : {num_param_gen}")
     num_param_mpd = sum([p.numel() for p in model.mpd.parameters()])
     logging.info(f"Number of MultiPeriodDiscriminator parameters : {num_param_mpd}")
     num_param_mrd = sum([p.numel() for p in model.mrd.parameters()])
     logging.info(f"Number of MultiResolutionDiscriminator parameters : {num_param_mrd}")
     logging.info(
-        f"Number of model parameters : {num_param_head + num_param_bone + num_param_mpd + num_param_mrd}"
+        f"Number of model parameters : {num_param_gen + num_param_mpd + num_param_mrd}"
     )
     return model
 
@@ -481,11 +506,6 @@ def compute_discriminator_loss(
     info["loss_disc_mrd"] = loss_mrd.detach().cpu().item()
     info["loss_disc_mpd"] = loss_mpd.detach().cpu().item()
 
-    for i in range(len(loss_mpd_real)):
-        info[f"loss_disc_mpd_period_{i+1}"] = loss_mpd_real[i] + loss_mpd_gen[i]
-    for i in range(len(loss_mrd_real)):
-        info[f"loss_disc_mrd_resolution_{i+1}"] = loss_mrd_real[i] + loss_mrd_gen[i]
-
     return loss_disc_all, info
 
 
@@ -499,6 +519,7 @@ def train_one_epoch(
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -544,6 +565,7 @@ def train_one_epoch(
             save_checkpoint(
                 filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
                 model=model,
+                model_avg=model_avg,
                 params=params,
                 optimizer_g=optimizer_g,
                 optimizer_d=optimizer_d,
@@ -566,10 +588,6 @@ def train_one_epoch(
             params.segment_size - params.frame_length
         ) // params.frame_shift + 1
 
-        # segment_frames = (
-        #     params.segment_size + params.frame_shift // 2
-        # ) // params.frame_shift
-
        start_p = random.randint(0, features_lens.min() - (segment_frames + 1))
 
        features = features[:, start_p : start_p + segment_frames, :].permute(
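
A quick arithmetic check of the segment_frames formula kept above, with segment_size = 16384 from get_params (frame_length and frame_shift are assumed example values):

segment_size = 16384  # samples, from get_params
frame_length = 1024   # samples per analysis frame (assumed example value)
frame_shift = 256     # hop size in samples (assumed example value)
segment_frames = (segment_size - frame_length) // frame_shift + 1
print(segment_frames)  # (16384 - 1024) // 256 + 1 == 61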
@@ -594,6 +612,7 @@ def train_one_epoch(
 
         loss_disc.backward()
         optimizer_d.step()
+        scheduler_d.step()
 
         optimizer_g.zero_grad()
         loss_gen, loss_gen_info = compute_generator_loss(
@@ -605,6 +624,7 @@ def train_one_epoch(
 
         loss_gen.backward()
         optimizer_g.step()
+        scheduler_g.step()
 
         loss_info = loss_gen_info + loss_disc_info
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_gen_info
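
Stepping scheduler_d/scheduler_g once per optimizer step (rather than once per epoch; the per-epoch calls are deleted in the -692 hunk below) matches a schedule defined in optimization steps, which is how warmup_steps and max_steps in get_params are counted. A self-contained sketch of a cosine-with-warmup schedule stepped per batch, under the assumption that the local get_cosine_schedule_with_warmup behaves like the common transformers-style helper:

import math
import torch

def cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    # linear warmup to the base lr, then a cosine decay to zero
    def lr_lambda(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=5e-4, betas=(0.9, 0.99))
sched = cosine_schedule_with_warmup(opt, num_warmup_steps=0, num_training_steps=2_000_000)
for _ in range(3):
    opt.step()    # per batch: optimizer first...
    sched.step()  # ...then scheduler, as in the hunks above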
@@ -617,6 +637,39 @@ def train_one_epoch(
         if params.print_diagnostics and batch_idx == 5:
             return
 
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
         if params.batch_idx_train % 100 == 0 and params.use_fp16:
             # If the grad scale was less than 1, try increasing it. The _growth_interval
             # of the grad scaler is configurable, but we can't configure it to have different
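
update_averaged_model above maintains a running average of the training weights, refreshed every average_period batches. A minimal sketch of such an incremental mean (an illustration; icefall's actual weighting is based on batch indices and keeps model_avg in float64):

import torch

@torch.no_grad()
def update_averaged_model_sketch(
    model_cur: torch.nn.Module, model_avg: torch.nn.Module, num_updates: int
) -> None:
    # incremental mean: avg += (cur - avg) / num_updates
    for p_avg, p_cur in zip(model_avg.parameters(), model_cur.parameters()):
        p_avg += (p_cur.to(p_avg.dtype) - p_avg) / num_updates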
@@ -647,8 +700,8 @@ def train_one_epoch(
                 f"Epoch {params.cur_epoch}, batch {batch_idx}, "
                 f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
                 f"loss[{loss_info}], tot_loss[{tot_loss}], "
-                f"cur_lr_g: {cur_lr_g:.2e}, "
-                f"cur_lr_d: {cur_lr_d:.2e}, "
+                f"cur_lr_g: {cur_lr_g:.4e}, "
+                f"cur_lr_d: {cur_lr_d:.4e}, "
                 + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
             )
 
@@ -668,11 +721,10 @@ def train_one_epoch(
                     "train/grad_scale", cur_grad_scale, params.batch_idx_train
                 )
 
-        # if (
-        #     params.batch_idx_train % params.valid_interval == 0
-        #     and not params.print_diagnostics
-        # ):
-        if True:
+        if (
+            params.batch_idx_train % params.valid_interval == 0
+            and not params.print_diagnostics
+        ):
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
@@ -692,8 +744,6 @@ def train_one_epoch(
                 tb_writer, "train/valid_", params.batch_idx_train
             )
 
-    scheduler_g.step()
-    scheduler_d.step()
     loss_value = tot_loss["loss_gen"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
@@ -773,7 +823,7 @@ def compute_validation_loss(
                 params.sampling_rate,
             )
 
-    logging.info(f"RTF : {infer_time / (audio_time * 10 / 1000)}")
+    logging.info(f"Validation RTF : {infer_time / (audio_time * 10 / 1000)}")
 
     if world_size > 1:
         tot_loss.reduce(device)
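
The renamed log line reports the real-time factor. Reading it requires one assumption: audio_time counts 10 ms feature frames, so audio_time * 10 / 1000 is the duration of the synthesized audio in seconds.

def real_time_factor(
    infer_time_s: float, num_frames: int, frame_shift_ms: float = 10.0
) -> float:
    # RTF = time spent synthesizing / duration synthesized; < 1 means
    # faster than real time
    return infer_time_s / (num_frames * frame_shift_ms / 1000.0)

print(real_time_factor(infer_time_s=2.5, num_frames=1000))  # 10 s of audio -> 0.25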
@@ -818,19 +868,25 @@ def run(rank, world_size, args):
     device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
-    logging.info(f"Device: {device}")
 
     params.device = device
     logging.info(params)
-    logging.info("About to create model")
 
+    logging.info("About to create model")
     model = get_model(params)
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
 
     assert params.start_epoch > 0, params.start_epoch
-    checkpoints = load_checkpoint_if_available(params=params, model=model)
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
 
     model = model.to(device)
-    head = model.head
-    backbone = model.backbone
+    generator = model.generator
     mrd = model.mrd
     mpd = model.mpd
@@ -838,7 +894,7 @@ def run(rank, world_size, args):
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
     optimizer_g = torch.optim.AdamW(
-        itertools.chain(head.parameters(), backbone.parameters()),
+        generator.parameters(),
         params.learning_rate,
         betas=[params.adam_b1, params.adam_b2],
     )
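
Keeping model_avg in float64 guards the long sequence of tiny averaging updates against float32 rounding. A small illustration (not from the commit):

import torch

x32 = torch.tensor(1.0, dtype=torch.float32)
x64 = torch.tensor(1.0, dtype=torch.float64)
x32 += 1e-8  # below float32's epsilon near 1.0 (~1.2e-7): rounds back to 1.0
x64 += 1e-8  # retained in float64
print(x32.item() == 1.0, x64.item() == 1.0)  # True False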
@@ -923,6 +979,7 @@ def run(rank, world_size, args):
         train_one_epoch(
             params=params,
             model=model,
+            model_avg=model_avg,
             optimizer_g=optimizer_g,
             optimizer_d=optimizer_d,
             scheduler_g=scheduler_g,
@@ -944,6 +1001,7 @@ def run(rank, world_size, args):
             filename=filename,
             params=params,
             model=model,
+            model_avg=model_avg,
             optimizer_g=optimizer_g,
             optimizer_d=optimizer_d,
             scheduler_g=scheduler_g,
@@ -953,21 +1011,6 @@ def run(rank, world_size, args):
             rank=rank,
         )
 
-        if params.batch_idx_train % params.save_every_n == 0:
-            filename = params.exp_dir / f"checkpoint-{params.batch_idx_train}.pt"
-            save_checkpoint(
-                filename=filename,
-                params=params,
-                model=model,
-                optimizer_g=optimizer_g,
-                optimizer_d=optimizer_d,
-                scheduler_g=scheduler_g,
-                scheduler_d=scheduler_d,
-                sampler=train_dl.sampler,
-                scaler=scaler,
-                rank=rank,
-            )
-        if rank == 0:
         if params.best_train_epoch == params.cur_epoch:
             best_train_filename = params.exp_dir / "best-train-loss.pt"
             copyfile(src=filename, dst=best_train_filename)
@@ -976,6 +1019,13 @@ def run(rank, world_size, args):
                 best_valid_filename = params.exp_dir / "best-valid-loss.pt"
                 copyfile(src=filename, dst=best_valid_filename)
 
+    remove_checkpoints(
+        out_dir=params.exp_dir,
+        topk=params.keep_last_epoch_k,
+        prefix="epoch",
+        rank=rank,
+    )
+
     logging.info("Done!")
 
     if world_size > 1:
@@ -997,7 +1047,8 @@ def main():
     run(rank=0, world_size=1, args=args)
 
 
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
 
 if __name__ == "__main__":
-    torch.set_num_threads(1)
-    torch.set_num_interop_threads(1)
     main()
@@ -250,18 +250,22 @@ def save_checkpoint_with_global_batch_idx(
     )
 
 
-def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
+def find_checkpoints(
+    out_dir: Path,
+    iteration: int = 0,
+    prefix: str = "checkpoint",
+) -> List[str]:
     """Find all available checkpoints in a directory.
 
-    The checkpoint filenames have the form: `checkpoint-xxx.pt`
+    The checkpoint filenames have the form: `{prefix}-xxx.pt`
     where xxx is a numerical value.
 
     Assume you have the following checkpoints in the folder `foo`:
 
-    - checkpoint-1.pt
-    - checkpoint-20.pt
-    - checkpoint-300.pt
-    - checkpoint-4000.pt
+    - {prefix}-1.pt
+    - {prefix}-20.pt
+    - {prefix}-300.pt
+    - {prefix}-4000.pt
 
     Case 1 (Return all checkpoints)::
@@ -290,8 +294,8 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
     Return a list of checkpoint filenames, sorted in descending
     order by the numerical value in the filename.
     """
-    checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
-    pattern = re.compile(r"checkpoint-([0-9]+).pt")
+    checkpoints = list(glob.glob(f"{out_dir}/{prefix}-[0-9]*.pt"))
+    pattern = re.compile(rf"{prefix}-([0-9]+).pt")
     iter_checkpoints = []
     for c in checkpoints:
         result = pattern.search(c)
@@ -316,12 +320,13 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
 def remove_checkpoints(
     out_dir: Path,
     topk: int,
+    prefix: str = "checkpoint",
     rank: int = 0,
 ):
     """Remove checkpoints from the given directory.
 
-    We assume that checkpoint filename has the form `checkpoint-xxx.pt`
-    where xxx is a number, representing the number of processed batches
+    We assume that the checkpoint filename has the form `{prefix}-xxx.pt`
+    where xxx is a number, representing the number of processed batches/epochs
     when saving that checkpoint. We sort checkpoints by filename and keep
     only the `topk` checkpoints with the highest `xxx`.
 
@@ -330,6 +335,8 @@ def remove_checkpoints(
         The directory containing checkpoints to be removed.
       topk:
         Number of checkpoints to keep.
+      prefix:
+        The prefix of the checkpoint filename, normally `epoch` or `checkpoint`.
       rank:
         If using DDP for training, it is the rank of the current node.
         Use 0 if no DDP is used for training.
@@ -337,7 +344,7 @@ def remove_checkpoints(
     assert topk >= 1, topk
     if rank != 0:
         return
-    checkpoints = find_checkpoints(out_dir)
+    checkpoints = find_checkpoints(out_dir, prefix=prefix)
 
     if len(checkpoints) == 0:
         logging.warn(f"No checkpoints found in {out_dir}")
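
A hypothetical usage sketch of the new prefix parameter, managing epoch checkpoints alongside batch checkpoints (exp_dir is an example path):

from pathlib import Path
from icefall.checkpoint import find_checkpoints, remove_checkpoints

exp_dir = Path("vocos/exp")                       # example path
print(find_checkpoints(exp_dir, prefix="epoch"))  # e.g. [".../epoch-30.pt", ".../epoch-29.pt", ...]
remove_checkpoints(out_dir=exp_dir, topk=50, prefix="epoch", rank=0)  # keep the newest 50 epochs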