Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

Commit 1e65a976d0 (parent c43977ea05): added pesq and stoi for reconstruction performance evaluation
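For context: PESQ (Perceptual Evaluation of Speech Quality, wide-band here) and STOI (Short-Time Objective Intelligibility) are reference-based metrics that compare a ground-truth waveform against the codec reconstruction. Below is a minimal standalone sketch of the two library calls this commit relies on; the file names and the assumption of 16 kHz mono audio are illustrative, not part of the commit.

# Minimal sketch of the pesq / pystoi calls (illustrative paths, assumed 16 kHz mono audio).
import torchaudio
from pesq import pesq
from pystoi import stoi

ref, sr = torchaudio.load("ground_truth.wav")   # shape (1, T)
gen, _ = torchaudio.load("reconstruction.wav")  # same length as ref

ref = ref.squeeze(0).numpy()
gen = gen.squeeze(0).numpy()

# Wide-band PESQ expects 16 kHz input; STOI is given the signal's own rate.
pesq_wb = pesq(fs=16000, ref=ref, deg=gen, mode="wb")
stoi_score = stoi(x=ref, y=gen, fs_sig=sr, extended=False)
print(f"PESQ-WB: {pesq_wb:.4f}, STOI: {stoi_score:.4f}")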
@@ -30,12 +30,16 @@ import argparse
 import logging
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
-from typing import Dict, List
+from statistics import mean
+from typing import List, Tuple
 
+import numpy as np
 import torch
-import torch.nn.functional as F
 import torchaudio
 from codec_datamodule import LibriTTSCodecDataModule
+from pesq import pesq
+from pystoi import stoi
+from scipy import signal
 from torch import nn
 from train import get_model, get_params
 
@@ -105,12 +109,25 @@ def remove_encodec_weight_norm(model) -> None:
             remove_weight_norm(decoder._modules[key].conv.conv)
 
 
+def compute_pesq(ref_wav: np.ndarray, gen_wav: np.ndarray) -> float:
+    """Compute PESQ score between reference and generated audio."""
+    DEFAULT_SAMPLING_RATE = 16000
+    ref = signal.resample(ref_wav, DEFAULT_SAMPLING_RATE)
+    deg = signal.resample(gen_wav, DEFAULT_SAMPLING_RATE)
+    return pesq(fs=DEFAULT_SAMPLING_RATE, ref=ref, deg=deg, mode="wb")
+
+
+def compute_stoi(ref_wav: np.ndarray, gen_wav: np.ndarray, sampling_rate: int) -> float:
+    """Compute STOI score between reference and generated audio."""
+    return stoi(x=ref_wav, y=gen_wav, fs_sig=sampling_rate, extended=False)
+
+
 def infer_dataset(
     dl: torch.utils.data.DataLoader,
     subset: str,
     params: AttributeDict,
     model: nn.Module,
-) -> None:
+) -> Tuple[float, float]:
     """Decode dataset.
     The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
 
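One behaviour of compute_pesq above worth keeping in mind: scipy.signal.resample(x, num) takes a target number of samples, not a sampling rate, so passing DEFAULT_SAMPLING_RATE resamples every waveform to exactly 16000 samples (one second when PESQ reads them at 16 kHz), whatever its original duration. A small illustration, with an assumed 24 kHz source rate and a synthetic two-second tone:

# Illustration of scipy.signal.resample semantics (assumed 24 kHz source, 2 s tone).
import numpy as np
from scipy import signal

src_rate = 24000
t = np.arange(2 * src_rate) / src_rate
wav = np.sin(2 * np.pi * 220.0 * t)              # 48000 samples

# As in compute_pesq: a fixed 16000 samples, regardless of input length.
fixed = signal.resample(wav, 16000)
print(fixed.shape)                               # (16000,)

# A duration-preserving 16 kHz conversion would scale the sample count instead.
to_16k = signal.resample(wav, int(len(wav) * 16000 / src_rate))
print(to_16k.shape)                              # (32000,)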
@@ -123,6 +140,9 @@ def infer_dataset(
         It is returned by :func:`get_params`.
       model:
         The neural model.
+
+    Returns:
+      The average PESQ and STOI scores.
     """
 
     # Background worker save audios to disk.
@@ -150,6 +170,9 @@ def infer_dataset(
     num_cuts = 0
     log_interval = 5
 
+    pesq_wb_scores = []
+    stoi_scores = []
+
     try:
         num_batches = len(dl)
     except TypeError:
@@ -169,6 +192,25 @@ def infer_dataset(
             )
             audio_hats = audio_hats.squeeze(1).cpu()
 
+            for cut_id, audio, audio_hat, audio_len in zip(
+                cut_ids, audios, audio_hats, audio_lens
+            ):
+                try:
+                    pesq_wb = compute_pesq(
+                        ref_wav=audio[:audio_len].numpy(),
+                        gen_wav=audio_hat[:audio_len].numpy(),
+                    )
+                    pesq_wb_scores.append(pesq_wb)
+                except Exception as e:
+                    logging.error(f"Error while computing PESQ for cut {cut_id}: {e}")
+
+                stoi_score = compute_stoi(
+                    ref_wav=audio[:audio_len].numpy(),
+                    gen_wav=audio_hat[:audio_len].numpy(),
+                    sampling_rate=params.sampling_rate,
+                )
+                stoi_scores.append(stoi_score)
+
             futures.append(
                 executor.submit(
                     _save_worker,
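The loop above trims each row of the batch back to its true audio_len before scoring: batches from the dataloader are padded to a common length, and the metrics expect padding-free, equal-length pairs. A self-contained sketch of that trimming pattern with hypothetical shapes:

# Hypothetical padded batch; in infer_dataset these come from the dataloader and the model.
import torch

cut_ids = ["cut-0", "cut-1"]
audios = torch.zeros(2, 24000)        # ground-truth batch, padded to the longest cut
audio_hats = torch.zeros(2, 24000)    # reconstructions, same padded shape
audio_lens = [24000, 18000]           # true number of samples per cut

for cut_id, audio, audio_hat, audio_len in zip(cut_ids, audios, audio_hats, audio_lens):
    ref = audio[:audio_len].numpy()       # drop the padding before scoring
    gen = audio_hat[:audio_len].numpy()
    assert ref.shape == gen.shape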
@@ -192,6 +234,7 @@ def infer_dataset(
         # return results
         for f in futures:
             f.result()
+    return mean(pesq_wb_scores), mean(stoi_scores)
 
 
 @torch.no_grad()
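A hedged caveat on the aggregation above: statistics.mean raises StatisticsError on an empty sequence, and the loop only logs PESQ failures, so infer_dataset would raise if no PESQ score was ever collected. A defensive variant, purely as a sketch and not part of the commit:

# Hypothetical defensive aggregation; the commit calls mean() directly.
from statistics import mean
from typing import List

def safe_mean(scores: List[float]) -> float:
    """Mean of the collected scores, or NaN when nothing was collected."""
    return mean(scores) if scores else float("nan")

# return safe_mean(pesq_wb_scores), safe_mean(stoi_scores)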
@@ -285,12 +328,13 @@ def main():
 
         logging.info(f"Processing {subset} set, saving to {save_wav_dir}")
 
-        infer_dataset(
+        pesq_wb, stoi = infer_dataset(
             dl=dl,
             subset=subset,
             params=params,
             model=model,
         )
+        logging.info(f"{subset}: PESQ-WB: {pesq_wb:.4f}, STOI: {stoi:.4f}")
 
     logging.info(f"Wav files are saved to {params.save_wav_dir}")
     logging.info("Done!")