diff --git a/egs/libritts/CODEC/encodec/infer.py b/egs/libritts/CODEC/encodec/infer.py
index dccff984d..c407b4a59 100755
--- a/egs/libritts/CODEC/encodec/infer.py
+++ b/egs/libritts/CODEC/encodec/infer.py
@@ -30,12 +30,16 @@
 import argparse
 import logging
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
-from typing import Dict, List
+from statistics import mean
+from typing import List, Tuple
 
+import numpy as np
 import torch
-import torch.nn.functional as F
 import torchaudio
 from codec_datamodule import LibriTTSCodecDataModule
+from pesq import pesq
+from pystoi import stoi
+from scipy import signal
 from torch import nn
 from train import get_model, get_params
@@ -105,12 +109,25 @@ def remove_encodec_weight_norm(model) -> None:
             remove_weight_norm(decoder._modules[key].conv.conv)
 
 
+def compute_pesq(ref_wav: np.ndarray, gen_wav: np.ndarray) -> float:
+    """Compute PESQ score between reference and generated audio."""
+    DEFAULT_SAMPLING_RATE = 16000
+    ref = signal.resample(ref_wav, DEFAULT_SAMPLING_RATE)
+    deg = signal.resample(gen_wav, DEFAULT_SAMPLING_RATE)
+    return pesq(fs=DEFAULT_SAMPLING_RATE, ref=ref, deg=deg, mode="wb")
+
+
+def compute_stoi(ref_wav: np.ndarray, gen_wav: np.ndarray, sampling_rate: int) -> float:
+    """Compute STOI score between reference and generated audio."""
+    return stoi(x=ref_wav, y=gen_wav, fs_sig=sampling_rate, extended=False)
+
+
 def infer_dataset(
     dl: torch.utils.data.DataLoader,
     subset: str,
     params: AttributeDict,
     model: nn.Module,
-) -> None:
+) -> Tuple[float, float]:
     """Decode dataset.
     The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
 
@@ -123,6 +140,9 @@ def infer_dataset(
        It is returned by :func:`get_params`.
      model:
        The neural model.
+
+    Returns:
+      The average PESQ and STOI scores.
     """
 
     # Background worker save audios to disk.
@@ -150,6 +170,9 @@ def infer_dataset(
     num_cuts = 0
     log_interval = 5
 
+    pesq_wb_scores = []
+    stoi_scores = []
+
     try:
         num_batches = len(dl)
     except TypeError:
@@ -169,6 +192,25 @@ def infer_dataset(
             )
             audio_hats = audio_hats.squeeze(1).cpu()
 
+            for cut_id, audio, audio_hat, audio_len in zip(
+                cut_ids, audios, audio_hats, audio_lens
+            ):
+                try:
+                    pesq_wb = compute_pesq(
+                        ref_wav=audio[:audio_len].numpy(),
+                        gen_wav=audio_hat[:audio_len].numpy(),
+                    )
+                    pesq_wb_scores.append(pesq_wb)
+                except Exception as e:
+                    logging.error(f"Error while computing PESQ for cut {cut_id}: {e}")
+
+                stoi_score = compute_stoi(
+                    ref_wav=audio[:audio_len].numpy(),
+                    gen_wav=audio_hat[:audio_len].numpy(),
+                    sampling_rate=params.sampling_rate,
+                )
+                stoi_scores.append(stoi_score)
+
             futures.append(
                 executor.submit(
                     _save_worker,
@@ -192,6 +234,7 @@ def infer_dataset(
         # return results
         for f in futures:
             f.result()
+    return mean(pesq_wb_scores), mean(stoi_scores)
 
 
 @torch.no_grad()
@@ -285,12 +328,13 @@ def main():
 
         logging.info(f"Processing {subset} set, saving to {save_wav_dir}")
 
-        infer_dataset(
+        pesq_wb, stoi = infer_dataset(
            dl=dl,
            subset=subset,
            params=params,
            model=model,
        )
+        logging.info(f"{subset}: PESQ-WB: {pesq_wb:.4f}, STOI: {stoi:.4f}")
 
     logging.info(f"Wav files are saved to {params.save_wav_dir}")
     logging.info("Done!")
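
For reference, the snippet below is a minimal, self-contained sketch of the two metrics this patch reports, useful for spot-checking a single ground-truth/reconstructed pair outside of infer.py. It calls the same pesq, pystoi, and scipy APIs as the patch; the file names are placeholders, and the resampling here targets a 16 kHz rate for the wideband PESQ computation.

# Standalone sanity check of the PESQ-WB / STOI metrics used above.
# Assumptions: "ref.wav" and "gen.wav" are placeholder names for a
# ground-truth/reconstructed pair; any common sampling rate works.
import torchaudio
from pesq import pesq
from pystoi import stoi
from scipy import signal

ref, sr = torchaudio.load("ref.wav")  # (channels, samples)
gen, _ = torchaudio.load("gen.wav")

# Trim to a common length and drop the channel dimension.
n = min(ref.shape[-1], gen.shape[-1])
ref = ref[0, :n].numpy()
gen = gen[0, :n].numpy()

# Wideband PESQ is defined at 16 kHz, so resample both signals to that rate.
PESQ_SR = 16000
n_16k = int(n * PESQ_SR / sr)
pesq_wb = pesq(
    fs=PESQ_SR,
    ref=signal.resample(ref, n_16k),
    deg=signal.resample(gen, n_16k),
    mode="wb",
)

# STOI is computed at the native sampling rate (classic, non-extended variant).
stoi_score = stoi(x=ref, y=gen, fs_sig=sr, extended=False)

print(f"PESQ-WB: {pesq_wb:.4f}, STOI: {stoi_score:.4f}")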