#!/usr/bin/env python3
#
# Copyright 2024 The Chinese University of HK (Author: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script performs model inference on test set.
Usage:
./codec/infer.py \
--epoch 300 \
--exp-dir ./codec/exp \
--max-duration 500
"""
import argparse
import logging
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from statistics import mean
from typing import List, Tuple

import numpy as np
import torch
import torchaudio
from codec_datamodule import LibriTTSCodecDataModule
from pesq import pesq
from pystoi import stoi
from scipy import signal
from torch import nn
from train import get_model, get_params

from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=1000,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="encodec/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--target-bw",
        type=float,
        default=24,
        help="The target bandwidth (in kbps) for the generator",
    )

    return parser


# implementation from https://github.com/yangdongchao/AcademiCodec/blob/master/academicodec/models/encodec/test.py
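# Weight norm is a training-time reparameterization; folding it back into the
# plain convolution weights before inference avoids the reparameterized
# forward pass without changing the model's outputs.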
def remove_encodec_weight_norm(model) -> None:
    from modules import SConv1d
    from modules.seanet import SConvTranspose1d, SEANetResnetBlock
    from torch.nn.utils import remove_weight_norm

    encoder = model.encoder.model
    for key in encoder._modules:
        if isinstance(encoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
            block_modules = encoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(encoder._modules[key], SConv1d):
            remove_weight_norm(encoder._modules[key].conv.conv)

    decoder = model.decoder.model
    for key in decoder._modules:
        if isinstance(decoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
            block_modules = decoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(decoder._modules[key], SConvTranspose1d):
            remove_weight_norm(decoder._modules[key].convtr.convtr)
        elif isinstance(decoder._modules[key], SConv1d):
            remove_weight_norm(decoder._modules[key].conv.conv)
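

# Objective reconstruction metrics: wide-band PESQ and STOI, computed on the
# unpadded portion of each utterance.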
def compute_pesq(
    ref_wav: np.ndarray, gen_wav: np.ndarray, sampling_rate: int
) -> float:
    """Compute the wide-band PESQ score between reference and generated audio."""
    DEFAULT_SAMPLING_RATE = 16000
    # scipy.signal.resample takes a target *number of samples*, not a target
    # rate, so convert the length explicitly to resample to 16 kHz.
    num_samples = round(len(ref_wav) * DEFAULT_SAMPLING_RATE / sampling_rate)
    ref = signal.resample(ref_wav, num_samples)
    deg = signal.resample(gen_wav, num_samples)
    return pesq(fs=DEFAULT_SAMPLING_RATE, ref=ref, deg=deg, mode="wb")


def compute_stoi(
    ref_wav: np.ndarray, gen_wav: np.ndarray, sampling_rate: int
) -> float:
    """Compute STOI score between reference and generated audio."""
    return stoi(x=ref_wav, y=gen_wav, fs_sig=sampling_rate, extended=False)


def infer_dataset(
    dl: torch.utils.data.DataLoader,
    subset: str,
    params: AttributeDict,
    model: nn.Module,
) -> Tuple[float, float]:
    """Decode dataset.
    The ground-truth and generated audio pairs will be saved to
    `params.save_wav_dir`.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      subset:
        The name of the subset.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.

    Returns:
      The average PESQ and STOI scores.
    """

    # Background worker that saves audio to disk.
    def _save_worker(
        subset: str,
        batch_size: int,
        cut_ids: List[str],
        audio: torch.Tensor,
        audio_pred: torch.Tensor,
        audio_lens: List[int],
    ):
        for i in range(batch_size):
            torchaudio.save(
                str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"),
                audio[i : i + 1, : audio_lens[i]],
                sample_rate=params.sampling_rate,
            )
            torchaudio.save(
                str(params.save_wav_dir / subset / f"{cut_ids[i]}_recon.wav"),
                audio_pred[i : i + 1, : audio_lens[i]],
                sample_rate=params.sampling_rate,
            )

    device = next(model.parameters()).device
    num_cuts = 0
    log_interval = 5

    pesq_wb_scores = []
    stoi_scores = []

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    futures = []
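    # A single background thread handles disk I/O so that saving audio
    # overlaps with inference on the main thread.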
    with ThreadPoolExecutor(max_workers=1) as executor:
        for batch_idx, batch in enumerate(dl):
            batch_size = len(batch["audio"])

            audios = batch["audio"]
            audio_lens = batch["audio_lens"].tolist()
            cut_ids = [cut.id for cut in batch["cut"]]
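            # Run the codec end to end: encode, quantize at the target
            # bandwidth, and decode back to waveforms.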
            codes, audio_hats = model.inference(
                audios.to(device), target_bw=params.target_bw
            )
            audio_hats = audio_hats.squeeze(1).cpu()
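            # Score each reconstruction against its ground truth, trimming
            # padding with the true audio length.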
            for cut_id, audio, audio_hat, audio_len in zip(
                cut_ids, audios, audio_hats, audio_lens
            ):
                try:
                    pesq_wb = compute_pesq(
                        ref_wav=audio[:audio_len].numpy(),
                        gen_wav=audio_hat[:audio_len].numpy(),
                        sampling_rate=params.sampling_rate,
                    )
                    pesq_wb_scores.append(pesq_wb)
                except Exception as e:
                    logging.error(f"Error while computing PESQ for cut {cut_id}: {e}")

                stoi_score = compute_stoi(
                    ref_wav=audio[:audio_len].numpy(),
                    gen_wav=audio_hat[:audio_len].numpy(),
                    sampling_rate=params.sampling_rate,
                )
                stoi_scores.append(stoi_score)

            futures.append(
                executor.submit(
                    _save_worker,
                    subset,
                    batch_size,
                    cut_ids,
                    audios,
                    audio_hats,
                    audio_lens,
                )
            )

            num_cuts += batch_size
            if batch_idx % log_interval == 0:
                batch_str = f"{batch_idx}/{num_batches}"
                logging.info(f"batch {batch_str}, cuts processed so far: {num_cuts}")

        # Wait for all background save tasks to finish before returning.
        for f in futures:
            f.result()

    return mean(pesq_wb_scores), mean(stoi_scores)


@torch.no_grad()
def main():
    parser = get_parser()
    LibriTTSCodecDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    params.suffix = f"epoch-{params.epoch}"

    params.res_dir = params.exp_dir / "infer" / params.suffix
    params.save_wav_dir = params.res_dir / "wav"
    params.save_wav_dir.mkdir(parents=True, exist_ok=True)

    setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
    logging.info("Infer started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    # We need cut ids to pair the reconstructed audio with its ground truth.
    args.return_cuts = True
    libritts = LibriTTSCodecDataModule(args)

    logging.info(f"Device: {device}")
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
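    # Fold weight norm back into the plain weights; it is only needed during
    # training.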
    remove_encodec_weight_norm(model)
    model.to(device)
    model.eval()

    encoder = model.encoder
    decoder = model.decoder
    quantizer = model.quantizer
    multi_scale_discriminator = model.multi_scale_discriminator
    multi_period_discriminator = model.multi_period_discriminator
    multi_scale_stft_discriminator = model.multi_scale_stft_discriminator
    num_param_e = sum([p.numel() for p in encoder.parameters()])
    logging.info(f"Number of parameters in encoder: {num_param_e}")
    num_param_d = sum([p.numel() for p in decoder.parameters()])
    logging.info(f"Number of parameters in decoder: {num_param_d}")
    num_param_q = sum([p.numel() for p in quantizer.parameters()])
    logging.info(f"Number of parameters in quantizer: {num_param_q}")
    num_param_ds = (
        sum([p.numel() for p in multi_scale_discriminator.parameters()])
        if multi_scale_discriminator is not None
        else 0
    )
    logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}")
    num_param_dp = (
        sum([p.numel() for p in multi_period_discriminator.parameters()])
        if multi_period_discriminator is not None
        else 0
    )
    logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}")
    num_param_dstft = sum(
        [p.numel() for p in multi_scale_stft_discriminator.parameters()]
    )
    logging.info(
        f"Number of parameters in multi_scale_stft_discriminator: {num_param_dstft}"
    )
    num_param_total = (
        num_param_e
        + num_param_d
        + num_param_q
        + num_param_ds
        + num_param_dp
        + num_param_dstft
    )
    logging.info(f"Total number of parameters: {num_param_total}")
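    # Build dataloaders for the four LibriTTS evaluation subsets.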
    test_clean_cuts = libritts.test_clean_cuts()
    test_clean = libritts.test_dataloaders(test_clean_cuts)

    test_other_cuts = libritts.test_other_cuts()
    test_other = libritts.test_dataloaders(test_other_cuts)

    dev_clean_cuts = libritts.dev_clean_cuts()
    dev_clean = libritts.valid_dataloaders(dev_clean_cuts)

    dev_other_cuts = libritts.dev_other_cuts()
    dev_other = libritts.valid_dataloaders(dev_other_cuts)

    infer_sets = {
        "test-clean": test_clean,
        "test-other": test_other,
        "dev-clean": dev_clean,
        "dev-other": dev_other,
    }

    for subset, dl in infer_sets.items():
        save_wav_dir = params.res_dir / "wav" / subset
        save_wav_dir.mkdir(parents=True, exist_ok=True)
        logging.info(f"Processing {subset} set, saving to {save_wav_dir}")

        # Name the result `stoi_score` to avoid shadowing the imported `stoi`.
        pesq_wb, stoi_score = infer_dataset(
            dl=dl,
            subset=subset,
            params=params,
            model=model,
        )
        logging.info(f"{subset}: PESQ-WB: {pesq_wb:.4f}, STOI: {stoi_score:.4f}")

    logging.info(f"Wav files are saved to {params.save_wav_dir}")
    logging.info("Done!")


if __name__ == "__main__":
    main()