added pesq and stoi for reconstruction performance evaluation

This commit is contained in:
JinZr 2024-09-08 15:37:06 +08:00
parent c43977ea05
commit 1e65a976d0

View File

@ -30,12 +30,16 @@ import argparse
import logging import logging
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from pathlib import Path from pathlib import Path
from typing import Dict, List from statistics import mean
from typing import List, Tuple
import numpy as np
import torch import torch
import torch.nn.functional as F
import torchaudio import torchaudio
from codec_datamodule import LibriTTSCodecDataModule from codec_datamodule import LibriTTSCodecDataModule
from pesq import pesq
from pystoi import stoi
from scipy import signal
from torch import nn from torch import nn
from train import get_model, get_params from train import get_model, get_params
@ -105,12 +109,25 @@ def remove_encodec_weight_norm(model) -> None:
remove_weight_norm(decoder._modules[key].conv.conv) remove_weight_norm(decoder._modules[key].conv.conv)
def compute_pesq(ref_wav: np.ndarray, gen_wav: np.ndarray) -> float:
"""Compute PESQ score between reference and generated audio."""
DEFAULT_SAMPLING_RATE = 16000
ref = signal.resample(ref_wav, DEFAULT_SAMPLING_RATE)
deg = signal.resample(gen_wav, DEFAULT_SAMPLING_RATE)
return pesq(fs=DEFAULT_SAMPLING_RATE, ref=ref, deg=deg, mode="wb")
def compute_stoi(ref_wav: np.ndarray, gen_wav: np.ndarray, sampling_rate: int) -> float:
"""Compute STOI score between reference and generated audio."""
return stoi(x=ref_wav, y=gen_wav, fs_sig=sampling_rate, extended=False)
def infer_dataset( def infer_dataset(
dl: torch.utils.data.DataLoader, dl: torch.utils.data.DataLoader,
subset: str, subset: str,
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
) -> None: ) -> Tuple[float, float]:
"""Decode dataset. """Decode dataset.
The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
@ -123,6 +140,9 @@ def infer_dataset(
It is returned by :func:`get_params`. It is returned by :func:`get_params`.
model: model:
The neural model. The neural model.
Returns:
The average PESQ and STOI scores.
""" """
# Background worker save audios to disk. # Background worker save audios to disk.
@ -150,6 +170,9 @@ def infer_dataset(
num_cuts = 0 num_cuts = 0
log_interval = 5 log_interval = 5
pesq_wb_scores = []
stoi_scores = []
try: try:
num_batches = len(dl) num_batches = len(dl)
except TypeError: except TypeError:
@ -169,6 +192,25 @@ def infer_dataset(
) )
audio_hats = audio_hats.squeeze(1).cpu() audio_hats = audio_hats.squeeze(1).cpu()
for cut_id, audio, audio_hat, audio_len in zip(
cut_ids, audios, audio_hats, audio_lens
):
try:
pesq_wb = compute_pesq(
ref_wav=audio[:audio_len].numpy(),
gen_wav=audio_hat[:audio_len].numpy(),
)
pesq_wb_scores.append(pesq_wb)
except Exception as e:
logging.error(f"Error while computing PESQ for cut {cut_id}: {e}")
stoi_score = compute_stoi(
ref_wav=audio[:audio_len].numpy(),
gen_wav=audio_hat[:audio_len].numpy(),
sampling_rate=params.sampling_rate,
)
stoi_scores.append(stoi_score)
futures.append( futures.append(
executor.submit( executor.submit(
_save_worker, _save_worker,
@ -192,6 +234,7 @@ def infer_dataset(
# return results # return results
for f in futures: for f in futures:
f.result() f.result()
return mean(pesq_wb_scores), mean(stoi_scores)
@torch.no_grad() @torch.no_grad()
@ -285,12 +328,13 @@ def main():
logging.info(f"Processing {subset} set, saving to {save_wav_dir}") logging.info(f"Processing {subset} set, saving to {save_wav_dir}")
infer_dataset( pesq_wb, stoi = infer_dataset(
dl=dl, dl=dl,
subset=subset, subset=subset,
params=params, params=params,
model=model, model=model,
) )
logging.info(f"{subset}: PESQ-WB: {pesq_wb:.4f}, STOI: {stoi:.4f}")
logging.info(f"Wav files are saved to {params.save_wav_dir}") logging.info(f"Wav files are saved to {params.save_wav_dir}")
logging.info("Done!") logging.info("Done!")