diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py
index 385551d06..4f45be9c2 100644
--- a/egs/libritts/CODEC/encodec/encodec.py
+++ b/egs/libritts/CODEC/encodec/encodec.py
@@ -267,13 +267,13 @@ class Encodec(nn.Module):
 
     def decode(self, codes):
         quantized = self.quantizer.decode(codes)
-        o = self.decoder(quantized)
-        return o
+        x_hat = self.decoder(quantized)
+        return x_hat
 
     def inference(self, x, target_bw=None, st=None):
         # setup
         x = x.unsqueeze(1)
 
         codes = self.encode(x, target_bw, st)
-        o = self.decode(codes)
-        return o
+        x_hat = self.decode(codes)
+        return codes, x_hat
diff --git a/egs/libritts/CODEC/encodec/infer.py b/egs/libritts/CODEC/encodec/infer.py
new file mode 100755
index 000000000..dccff984d
--- /dev/null
+++ b/egs/libritts/CODEC/encodec/infer.py
@@ -0,0 +1,300 @@
+#!/usr/bin/env python3
+#
+# Copyright 2024 The Chinese University of HK (Author: Zengrui Jin)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script performs model inference on the test sets.
+
+Usage:
+./encodec/infer.py \
+    --epoch 1000 \
+    --exp-dir ./encodec/exp \
+    --max-duration 500
+"""
+
+
+import argparse
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import Dict, List
+
+import torch
+import torch.nn.functional as F
+import torchaudio
+from codec_datamodule import LibriTTSCodecDataModule
+from torch import nn
+from train import get_model, get_params
+
+from icefall.checkpoint import load_checkpoint
+from icefall.utils import AttributeDict, setup_logger
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=1000,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+ """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="encodec/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--target-bw", + type=float, + default=7.5, + help="The target bandwidth for the generator", + ) + + return parser + + +# implementation from https://github.com/yangdongchao/AcademiCodec/blob/master/academicodec/models/encodec/test.py +def remove_encodec_weight_norm(model) -> None: + from modules import SConv1d + from modules.seanet import SConvTranspose1d, SEANetResnetBlock + from torch.nn.utils import remove_weight_norm + + encoder = model.encoder.model + for key in encoder._modules: + if isinstance(encoder._modules[key], SEANetResnetBlock): + remove_weight_norm(encoder._modules[key].shortcut.conv.conv) + block_modules = encoder._modules[key].block._modules + for skey in block_modules: + if isinstance(block_modules[skey], SConv1d): + remove_weight_norm(block_modules[skey].conv.conv) + elif isinstance(encoder._modules[key], SConv1d): + remove_weight_norm(encoder._modules[key].conv.conv) + + decoder = model.decoder.model + for key in decoder._modules: + if isinstance(decoder._modules[key], SEANetResnetBlock): + remove_weight_norm(decoder._modules[key].shortcut.conv.conv) + block_modules = decoder._modules[key].block._modules + for skey in block_modules: + if isinstance(block_modules[skey], SConv1d): + remove_weight_norm(block_modules[skey].conv.conv) + elif isinstance(decoder._modules[key], SConvTranspose1d): + remove_weight_norm(decoder._modules[key].convtr.convtr) + elif isinstance(decoder._modules[key], SConv1d): + remove_weight_norm(decoder._modules[key].conv.conv) + + +def infer_dataset( + dl: torch.utils.data.DataLoader, + subset: str, + params: AttributeDict, + model: nn.Module, +) -> None: + """Decode dataset. + The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + subset: + The name of the subset. + params: + It is returned by :func:`get_params`. + model: + The neural model. + """ + + # Background worker save audios to disk. + def _save_worker( + subset: str, + batch_size: int, + cut_ids: List[str], + audio: torch.Tensor, + audio_pred: torch.Tensor, + audio_lens: List[int], + ): + for i in range(batch_size): + torchaudio.save( + str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"), + audio[i : i + 1, : audio_lens[i]], + sample_rate=params.sampling_rate, + ) + torchaudio.save( + str(params.save_wav_dir / subset / f"{cut_ids[i]}_recon.wav"), + audio_pred[i : i + 1, : audio_lens[i]], + sample_rate=params.sampling_rate, + ) + + device = next(model.parameters()).device + num_cuts = 0 + log_interval = 5 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+
+    futures = []
+    with ThreadPoolExecutor(max_workers=1) as executor:
+        for batch_idx, batch in enumerate(dl):
+            batch_size = len(batch["audio"])
+
+            audios = batch["audio"]
+            audio_lens = batch["audio_lens"].tolist()
+            cut_ids = [cut.id for cut in batch["cut"]]
+
+            codes, audio_hats = model.inference(
+                audios.to(device), target_bw=params.target_bw
+            )
+            audio_hats = audio_hats.squeeze(1).cpu()
+
+            futures.append(
+                executor.submit(
+                    _save_worker,
+                    subset,
+                    batch_size,
+                    cut_ids,
+                    audios,
+                    audio_hats,
+                    audio_lens,
+                )
+            )
+
+            num_cuts += batch_size
+
+            if batch_idx % log_interval == 0:
+                batch_str = f"{batch_idx}/{num_batches}"
+
+                logging.info(
+                    f"batch {batch_str}, cuts processed until now is {num_cuts}"
+                )
+        # wait for all the background saving tasks to finish
+        for f in futures:
+            f.result()
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriTTSCodecDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    params.suffix = f"epoch-{params.epoch}"
+
+    params.res_dir = params.exp_dir / "infer" / params.suffix
+    params.save_wav_dir = params.res_dir / "wav"
+    params.save_wav_dir.mkdir(parents=True, exist_ok=True)
+
+    setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
+    logging.info("Infer started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    # we need cut ids to display results of both reconstructed and ground-truth audio
+    args.return_cuts = True
+    libritts = LibriTTSCodecDataModule(args)
+
+    logging.info(f"Device: {device}")
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    remove_encodec_weight_norm(model)
+
+    model.to(device)
+    model.eval()
+
+    encoder = model.encoder
+    decoder = model.decoder
+    quantizer = model.quantizer
+    multi_scale_discriminator = model.multi_scale_discriminator
+    multi_period_discriminator = model.multi_period_discriminator
+    multi_scale_stft_discriminator = model.multi_scale_stft_discriminator
+
+    num_param_e = sum([p.numel() for p in encoder.parameters()])
+    logging.info(f"Number of parameters in encoder: {num_param_e}")
+    num_param_d = sum([p.numel() for p in decoder.parameters()])
+    logging.info(f"Number of parameters in decoder: {num_param_d}")
+    num_param_q = sum([p.numel() for p in quantizer.parameters()])
+    logging.info(f"Number of parameters in quantizer: {num_param_q}")
+    num_param_ds = sum([p.numel() for p in multi_scale_discriminator.parameters()])
+    logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}")
+    num_param_dp = sum([p.numel() for p in multi_period_discriminator.parameters()])
+    logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}")
+    num_param_dstft = sum(
+        [p.numel() for p in multi_scale_stft_discriminator.parameters()]
+    )
+    logging.info(
+        f"Number of parameters in multi_scale_stft_discriminator: {num_param_dstft}"
+    )
+    logging.info(
+        f"Total number of parameters: {num_param_e + num_param_d + num_param_q + num_param_ds + num_param_dp + num_param_dstft}"
+    )
+
+    test_clean_cuts = libritts.test_clean_cuts()
+    test_clean = libritts.test_dataloaders(test_clean_cuts)
+
+    test_other_cuts = libritts.test_other_cuts()
+    test_other = libritts.test_dataloaders(test_other_cuts)
+
+    dev_clean_cuts = libritts.dev_clean_cuts()
+    dev_clean = libritts.valid_dataloaders(dev_clean_cuts)
+
+    dev_other_cuts = libritts.dev_other_cuts()
+    dev_other = libritts.valid_dataloaders(dev_other_cuts)
+
+    infer_sets = {
+        "test-clean": test_clean,
+        "test-other": test_other,
+        "dev-clean": dev_clean,
+        "dev-other": dev_other,
+    }
+
+    for subset, dl in infer_sets.items():
+        save_wav_dir = params.res_dir / "wav" / subset
+        save_wav_dir.mkdir(parents=True, exist_ok=True)
+
+        logging.info(f"Processing {subset} set, saving to {save_wav_dir}")
+
+        infer_dataset(
+            dl=dl,
+            subset=subset,
+            params=params,
+            model=model,
+        )
+
+    logging.info(f"Wav files are saved to {params.save_wav_dir}")
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/libritts/CODEC/encodec/quantization/core_vq.py b/egs/libritts/CODEC/encodec/quantization/core_vq.py
index 66d3dcf5d..4719e20f7 100644
--- a/egs/libritts/CODEC/encodec/quantization/core_vq.py
+++ b/egs/libritts/CODEC/encodec/quantization/core_vq.py
@@ -360,7 +360,7 @@ class ResidualVectorQuantization(nn.Module):
         all_indices = []
         n_q = n_q or len(self.layers)
         st = st or 0
-        for layer in self.layers[st:n_q]:  # 设置解码的起止layer
+        for layer in self.layers[st:n_q]:
             indices = layer.encode(residual)
             quantized = layer.decode(indices)
             residual = residual - quantized
diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py
index 842689155..65aec1383 100644
--- a/egs/libritts/CODEC/encodec/train.py
+++ b/egs/libritts/CODEC/encodec/train.py
@@ -136,12 +136,6 @@ def get_parser():
         default=False,
         help="Whether to use half precision training.",
     )
-    parser.add_argument(
-        "--chunk-size",
-        type=int,
-        default=1,
-        help="The chunk size for the dataset (in second).",
-    )
 
     return parser
 
@@ -191,6 +185,7 @@ def get_params() -> AttributeDict:
             "valid_interval": 200,
             "env_info": get_env_info(),
             "sampling_rate": 24000,
+            "chunk_size": 1.0,  # in seconds
             "lambda_adv": 1.0,  # loss scaling coefficient for adversarial loss
             "lambda_wav": 100.0,  # loss scaling coefficient for waveform loss
             "lambda_feat": 1.0,  # loss scaling coefficient for feat loss
@@ -570,7 +565,7 @@ def train_one_epoch(
                     valid_info.write_summary(
                         tb_writer, "train/valid_", params.batch_idx_train
                     )
-                    for index in range(params.num_samples): # 3
+                    for index in range(params.num_samples):  # 3
                         speech_hat_i = speech_hat[index]
                         speech_i = speech[index]
                         if speech_hat_i.dim() > 1:
@@ -655,8 +650,10 @@ def compute_validation_loss(
             # infer for first batch:
             if batch_idx == 0 and rank == 0:
                 inner_model = model.module if isinstance(model, DDP) else model
-                audio_pred = inner_model.inference(x=audio, target_bw=params.target_bw)
-                returned_sample = (audio_pred, audio)
+                _, audio_hat = inner_model.inference(
+                    x=audio, target_bw=params.target_bw
+                )
+                returned_sample = (audio_hat, audio)
 
     if world_size > 1:
         tot_loss.reduce(device)
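
Note (illustrative, not part of the patch): `Encodec.inference()` now returns a `(codes, reconstruction)` tuple instead of only the reconstructed audio, which is why `infer.py` and `compute_validation_loss()` above unpack two values. A minimal sketch of consuming the new return value, assuming `model` has already been built, loaded, and stripped of weight norm exactly as in `main()` of `infer.py`; the input file name is hypothetical and assumed to be 24 kHz mono:

    import torch
    import torchaudio

    wav, sr = torchaudio.load("sample.wav")  # hypothetical input, shape (1, T), 24 kHz
    device = next(model.parameters()).device
    with torch.no_grad():
        # inference() unsqueezes to (B, 1, T) internally and returns (codes, x_hat)
        codes, wav_hat = model.inference(wav.to(device), target_bw=7.5)
    torchaudio.save("sample_recon.wav", wav_hat.squeeze(1).cpu(), sample_rate=sr)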
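
Note (illustrative, not part of the patch): in `ResidualVectorQuantization.encode()`, the `self.layers[st:n_q]` slice, whose removed Chinese comment read roughly "set the start and end layers for decoding", selects which residual VQ layers produce codes: `st` is the first layer used and `n_q` bounds how many layers are applied, which is how a target bandwidth translates into using fewer or more codebooks. A hedged sketch, assuming an already constructed `ResidualVectorQuantization` instance `rvq` and a latent `z` of shape `(batch, dim, frames)`:

    codes_full = rvq.encode(z)               # all quantizer layers: highest bitrate
    codes_coarse = rvq.encode(z, n_q=4)      # only the first 4 layers: coarser, lower bitrate
    codes_tail = rvq.encode(z, st=2, n_q=4)  # layers 2 and 3 only
    z_hat = rvq.decode(codes_coarse)         # sum of the selected layers' codebook vectors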