mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
added pesq and stoi for reconstruction performance evaluation
This commit is contained in:
parent
c43977ea05
commit
1e65a976d0
@ -30,12 +30,16 @@ import argparse
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from statistics import mean
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from codec_datamodule import LibriTTSCodecDataModule
|
||||
from pesq import pesq
|
||||
from pystoi import stoi
|
||||
from scipy import signal
|
||||
from torch import nn
|
||||
from train import get_model, get_params
|
||||
|
||||
@ -105,12 +109,25 @@ def remove_encodec_weight_norm(model) -> None:
|
||||
remove_weight_norm(decoder._modules[key].conv.conv)
|
||||
|
||||
|
||||
def compute_pesq(
    ref_wav: np.ndarray, gen_wav: np.ndarray, sampling_rate: int = 24000
) -> float:
    """Compute the wide-band PESQ score between reference and generated audio.

    PESQ is evaluated here at 16 kHz in "wb" mode, so both signals are
    resampled from ``sampling_rate`` to 16 kHz first.

    Args:
      ref_wav:
        The reference (ground-truth) waveform, time along the first axis.
      gen_wav:
        The generated waveform, time along the first axis.
      sampling_rate:
        Sampling rate in Hz of both input waveforms.
        NOTE(review): the default of 24000 assumes LibriTTS-rate audio;
        confirm it and preferably pass ``params.sampling_rate`` explicitly
        from the caller.

    Returns:
      The PESQ-WB score returned by ``pesq``.

    Raises:
      ValueError: If ``sampling_rate`` is not positive.
    """
    DEFAULT_SAMPLING_RATE = 16000

    if sampling_rate <= 0:
        raise ValueError(f"sampling_rate must be positive, got {sampling_rate}")

    def _to_pesq_rate(wav: np.ndarray) -> np.ndarray:
        # Nothing to do when the audio is already at the PESQ rate.
        if sampling_rate == DEFAULT_SAMPLING_RATE:
            return wav
        # The second argument of scipy.signal.resample is the *number of
        # output samples*, not a rate.  Scale the length by the rate ratio so
        # the duration of the utterance is preserved.
        num_samples = int(round(len(wav) * DEFAULT_SAMPLING_RATE / sampling_rate))
        return signal.resample(wav, num_samples)

    ref = _to_pesq_rate(ref_wav)
    deg = _to_pesq_rate(gen_wav)
    return pesq(fs=DEFAULT_SAMPLING_RATE, ref=ref, deg=deg, mode="wb")
|
||||
|
||||
|
||||
def compute_stoi(
    ref_wav: np.ndarray,
    gen_wav: np.ndarray,
    sampling_rate: int,
) -> float:
    """Return the STOI score of the generated audio against the reference.

    Args:
      ref_wav:
        The reference waveform, forwarded to ``stoi`` as ``x``.
      gen_wav:
        The generated waveform, forwarded to ``stoi`` as ``y``.
      sampling_rate:
        Sampling rate of both waveforms, forwarded to ``stoi`` as ``fs_sig``.

    Returns:
      The score computed by ``stoi`` with ``extended=False``, i.e. the
      non-extended variant of the measure.
    """
    score = stoi(ref_wav, gen_wav, sampling_rate, extended=False)
    return score
|
||||
|
||||
|
||||
def infer_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
subset: str,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
) -> None:
|
||||
) -> Tuple[float, float]:
|
||||
"""Decode dataset.
|
||||
The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
|
||||
|
||||
@ -123,6 +140,9 @@ def infer_dataset(
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
|
||||
Returns:
|
||||
The average PESQ and STOI scores.
|
||||
"""
|
||||
|
||||
# Background worker saves audio files to disk.
|
||||
@ -150,6 +170,9 @@ def infer_dataset(
|
||||
num_cuts = 0
|
||||
log_interval = 5
|
||||
|
||||
pesq_wb_scores = []
|
||||
stoi_scores = []
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
@ -169,6 +192,25 @@ def infer_dataset(
|
||||
)
|
||||
audio_hats = audio_hats.squeeze(1).cpu()
|
||||
|
||||
for cut_id, audio, audio_hat, audio_len in zip(
|
||||
cut_ids, audios, audio_hats, audio_lens
|
||||
):
|
||||
try:
|
||||
pesq_wb = compute_pesq(
|
||||
ref_wav=audio[:audio_len].numpy(),
|
||||
gen_wav=audio_hat[:audio_len].numpy(),
|
||||
)
|
||||
pesq_wb_scores.append(pesq_wb)
|
||||
except Exception as e:
|
||||
logging.error(f"Error while computing PESQ for cut {cut_id}: {e}")
|
||||
|
||||
stoi_score = compute_stoi(
|
||||
ref_wav=audio[:audio_len].numpy(),
|
||||
gen_wav=audio_hat[:audio_len].numpy(),
|
||||
sampling_rate=params.sampling_rate,
|
||||
)
|
||||
stoi_scores.append(stoi_score)
|
||||
|
||||
futures.append(
|
||||
executor.submit(
|
||||
_save_worker,
|
||||
@ -192,6 +234,7 @@ def infer_dataset(
|
||||
# return results
|
||||
for f in futures:
|
||||
f.result()
|
||||
return mean(pesq_wb_scores), mean(stoi_scores)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@ -285,12 +328,13 @@ def main():
|
||||
|
||||
logging.info(f"Processing {subset} set, saving to {save_wav_dir}")
|
||||
|
||||
infer_dataset(
|
||||
pesq_wb, stoi = infer_dataset(
|
||||
dl=dl,
|
||||
subset=subset,
|
||||
params=params,
|
||||
model=model,
|
||||
)
|
||||
logging.info(f"{subset}: PESQ-WB: {pesq_wb:.4f}, STOI: {stoi:.4f}")
|
||||
|
||||
logging.info(f"Wav files are saved to {params.save_wav_dir}")
|
||||
logging.info("Done!")
|
||||
|
Loading…
x
Reference in New Issue
Block a user