#!/usr/bin/env python3
#
# Copyright 2024 The Chinese University of HK (Author: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script performs model inference on the test sets.

Usage:
./encodec/infer.py \
    --epoch 1000 \
    --exp-dir ./encodec/exp \
    --max-duration 500
"""

import argparse
import logging
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List

import torch
import torchaudio
from codec_datamodule import LibriTTSCodecDataModule
from torch import nn
from train import get_model, get_params

from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=1000,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="encodec/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--target-bw",
        type=float,
        default=7.5,
        help="The target bandwidth for the generator",
    )

    return parser


# implementation from https://github.com/yangdongchao/AcademiCodec/blob/master/academicodec/models/encodec/test.py
def remove_encodec_weight_norm(model) -> None:
    """Remove weight norm from the EnCodec encoder/decoder convolutions.
    Weight norm is only needed during training; folding it into the plain
    weights simplifies inference."""
    from modules import SConv1d
    from modules.seanet import SConvTranspose1d, SEANetResnetBlock
    from torch.nn.utils import remove_weight_norm

    encoder = model.encoder.model
    for key in encoder._modules:
        if isinstance(encoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
            block_modules = encoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(encoder._modules[key], SConv1d):
            remove_weight_norm(encoder._modules[key].conv.conv)

    decoder = model.decoder.model
    for key in decoder._modules:
        if isinstance(decoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
            block_modules = decoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(decoder._modules[key], SConvTranspose1d):
            remove_weight_norm(decoder._modules[key].convtr.convtr)
        elif isinstance(decoder._modules[key], SConv1d):
            remove_weight_norm(decoder._modules[key].conv.conv)


def infer_dataset(
    dl: torch.utils.data.DataLoader,
    subset: str,
    params: AttributeDict,
    model: nn.Module,
) -> None:
    """Decode the dataset.
    The ground-truth and generated audio pairs will be saved to
    `params.save_wav_dir`.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      subset:
        The name of the subset.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
    """

    # Background worker to save audios to disk.
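    # Saving in a single worker thread overlaps disk I/O with model
    # inference on the next batch, and max_workers=1 keeps the writes
    # serialized in submission order.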
    def _save_worker(
        subset: str,
        batch_size: int,
        cut_ids: List[str],
        audio: torch.Tensor,
        audio_pred: torch.Tensor,
        audio_lens: List[int],
    ):
        for i in range(batch_size):
            torchaudio.save(
                str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"),
                audio[i : i + 1, : audio_lens[i]],
                sample_rate=params.sampling_rate,
            )
            torchaudio.save(
                str(params.save_wav_dir / subset / f"{cut_ids[i]}_recon.wav"),
                audio_pred[i : i + 1, : audio_lens[i]],
                sample_rate=params.sampling_rate,
            )

    device = next(model.parameters()).device
    num_cuts = 0
    log_interval = 5

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    futures = []
    with ThreadPoolExecutor(max_workers=1) as executor:
        for batch_idx, batch in enumerate(dl):
            batch_size = len(batch["audio"])

            audios = batch["audio"]
            audio_lens = batch["audio_lens"].tolist()
            cut_ids = [cut.id for cut in batch["cut"]]

            codes, audio_hats = model.inference(
                audios.to(device), target_bw=params.target_bw
            )
            audio_hats = audio_hats.squeeze(1).cpu()

            futures.append(
                executor.submit(
                    _save_worker,
                    subset,
                    batch_size,
                    cut_ids,
                    audios,
                    audio_hats,
                    audio_lens,
                )
            )

            num_cuts += batch_size

            if batch_idx % log_interval == 0:
                batch_str = f"{batch_idx}/{num_batches}"
                logging.info(
                    f"batch {batch_str}, cuts processed until now is {num_cuts}"
                )

        # Wait for all background save jobs to finish.
        for f in futures:
            f.result()


@torch.no_grad()
def main():
    parser = get_parser()
    LibriTTSCodecDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    params.suffix = f"epoch-{params.epoch}"

    params.res_dir = params.exp_dir / "infer" / params.suffix
    params.save_wav_dir = params.res_dir / "wav"
    params.save_wav_dir.mkdir(parents=True, exist_ok=True)

    setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
    logging.info("Infer started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    # We need cut ids to name the saved ground-truth and reconstructed
    # audio files.
    args.return_cuts = True
    libritts = LibriTTSCodecDataModule(args)

    logging.info(f"Device: {device}")
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    remove_encodec_weight_norm(model)
    model.to(device)
    model.eval()

    encoder = model.encoder
    decoder = model.decoder
    quantizer = model.quantizer
    multi_scale_discriminator = model.multi_scale_discriminator
    multi_period_discriminator = model.multi_period_discriminator
    multi_scale_stft_discriminator = model.multi_scale_stft_discriminator

    num_param_e = sum(p.numel() for p in encoder.parameters())
    logging.info(f"Number of parameters in encoder: {num_param_e}")
    num_param_d = sum(p.numel() for p in decoder.parameters())
    logging.info(f"Number of parameters in decoder: {num_param_d}")
    num_param_q = sum(p.numel() for p in quantizer.parameters())
    logging.info(f"Number of parameters in quantizer: {num_param_q}")
    num_param_ds = sum(p.numel() for p in multi_scale_discriminator.parameters())
    logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}")
    num_param_dp = sum(p.numel() for p in multi_period_discriminator.parameters())
    logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}")
    num_param_dstft = sum(
        p.numel() for p in multi_scale_stft_discriminator.parameters()
    )
    logging.info(
        f"Number of parameters in multi_scale_stft_discriminator: {num_param_dstft}"
    )
    logging.info(
        "Total number of parameters: "
        f"{num_param_e + num_param_d + num_param_q + num_param_ds + num_param_dp + num_param_dstft}"
    )
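    # The discriminators above are loaded from the checkpoint and counted
    # only for logging; they are not used by model.inference() below.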
    test_clean_cuts = libritts.test_clean_cuts()
    test_clean = libritts.test_dataloaders(test_clean_cuts)

    test_other_cuts = libritts.test_other_cuts()
    test_other = libritts.test_dataloaders(test_other_cuts)

    dev_clean_cuts = libritts.dev_clean_cuts()
    dev_clean = libritts.valid_dataloaders(dev_clean_cuts)

    dev_other_cuts = libritts.dev_other_cuts()
    dev_other = libritts.valid_dataloaders(dev_other_cuts)

    infer_sets = {
        "test-clean": test_clean,
        "test-other": test_other,
        "dev-clean": dev_clean,
        "dev-other": dev_other,
    }

    for subset, dl in infer_sets.items():
        save_wav_dir = params.res_dir / "wav" / subset
        save_wav_dir.mkdir(parents=True, exist_ok=True)

        logging.info(f"Processing {subset} set, saving to {save_wav_dir}")

        infer_dataset(
            dl=dl,
            subset=subset,
            params=params,
            model=model,
        )

    logging.info(f"Wav files are saved to {params.save_wav_dir}")
    logging.info("Done!")


if __name__ == "__main__":
    main()