icefall/egs/zipvoice/local/evaluate_utmos.py
2025-06-16 09:45:34 +08:00

295 lines
8.9 KiB
Python

"""
Calculate UTMOS score with automatic Mean Opinion Score (MOS) prediction system
adapted from https://huggingface.co/spaces/sarulab-speech/UTMOS-demo
# Download model checkpoints
wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -P model/huggingface/utmos/utmos.pt
wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -P model/huggingface/utmos/wav2vec_small.pt
"""
import argparse
import logging
import os
import fairseq
import librosa
import numpy as np
import pytorch_lightning as pl
import soundfile as sf
import torch
import torch.nn as nn
from tqdm import tqdm
logging.basicConfig(level=logging.INFO)
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--wav-path", type=str, help="path of the evaluated speech directory"
)
parser.add_argument(
"--utmos-model-path",
type=str,
default="model/huggingface/utmos/utmos.pt",
help="path of the UTMOS model",
)
parser.add_argument(
"--ssl-model-path",
type=str,
default="model/huggingface/utmos/wav2vec_small.pt",
help="path of the wav2vec SSL model",
)
return parser
class UTMOSScore:
"""Predicting score for each audio clip."""
def __init__(self, utmos_model_path, ssl_model_path):
self.sample_rate = 16000
self.device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
self.model = (
BaselineLightningModule.load_from_checkpoint(
utmos_model_path, ssl_model_path=ssl_model_path
)
.eval()
.to(self.device)
)
def score(self, wavs: torch.Tensor) -> torch.Tensor:
"""
Args:
wavs: waveforms to be evaluated. When len(wavs) == 1 or 2,
the model processes the input as a single audio clip. The model
performs batch processing when len(wavs) == 3.
"""
if len(wavs.shape) == 1:
out_wavs = wavs.unsqueeze(0).unsqueeze(0)
elif len(wavs.shape) == 2:
out_wavs = wavs.unsqueeze(0)
elif len(wavs.shape) == 3:
out_wavs = wavs
else:
raise ValueError("Dimension of input tensor needs to be <= 3.")
bs = out_wavs.shape[0]
batch = {
"wav": out_wavs,
"domains": torch.zeros(bs, dtype=torch.int).to(self.device),
"judge_id": torch.ones(bs, dtype=torch.int).to(self.device) * 288,
}
with torch.no_grad():
output = self.model(batch)
return output.mean(dim=1).squeeze(1).cpu().detach() * 2 + 3
def score_dir(self, dir, dtype="float32"):
def _load_speech_task(fname, sample_rate):
wav_data, sr = sf.read(fname, dtype=dtype)
if sr != sample_rate:
wav_data = librosa.resample(
wav_data, orig_sr=sr, target_sr=self.sample_rate
)
wav_data = torch.from_numpy(wav_data)
return wav_data
score_lst = []
for fname in tqdm(os.listdir(dir)):
speech = _load_speech_task(os.path.join(dir, fname), self.sample_rate)
speech = speech.to(self.device)
with torch.no_grad():
score = self.score(speech)
score_lst.append(score.item())
return np.mean(score_lst)
def load_ssl_model(ckpt_path="wav2vec_small.pt"):
SSL_OUT_DIM = 768
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[ckpt_path]
)
ssl_model = model[0]
ssl_model.remove_pretraining_modules()
return SSL_model(ssl_model, SSL_OUT_DIM)
class BaselineLightningModule(pl.LightningModule):
def __init__(self, ssl_model_path):
super().__init__()
self.construct_model(ssl_model_path)
self.save_hyperparameters()
def construct_model(self, ssl_model_path):
self.feature_extractors = nn.ModuleList(
[
load_ssl_model(ckpt_path=ssl_model_path),
DomainEmbedding(3, 128),
]
)
output_dim = sum(
[
feature_extractor.get_output_dim()
for feature_extractor in self.feature_extractors
]
)
output_layers = [
LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)
]
output_dim = output_layers[-1].get_output_dim()
output_layers.append(
Projection(
hidden_dim=2048,
activation=torch.nn.ReLU(),
range_clipping=False,
input_dim=output_dim,
)
)
self.output_layers = nn.ModuleList(output_layers)
def forward(self, inputs):
outputs = {}
for feature_extractor in self.feature_extractors:
outputs.update(feature_extractor(inputs))
x = outputs
for output_layer in self.output_layers:
x = output_layer(x, inputs)
return x
class SSL_model(nn.Module):
def __init__(self, ssl_model, ssl_out_dim) -> None:
super(SSL_model, self).__init__()
self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim
def forward(self, batch):
wav = batch["wav"]
wav = wav.squeeze(1) # [batches, wav_len]
res = self.ssl_model(wav, mask=False, features_only=True)
x = res["x"]
return {"ssl-feature": x}
def get_output_dim(self):
return self.ssl_out_dim
class DomainEmbedding(nn.Module):
def __init__(self, n_domains, domain_dim) -> None:
super().__init__()
self.embedding = nn.Embedding(n_domains, domain_dim)
self.output_dim = domain_dim
def forward(self, batch):
return {"domain-feature": self.embedding(batch["domains"])}
def get_output_dim(self):
return self.output_dim
class LDConditioner(nn.Module):
"""
Conditions ssl output by listener embedding
"""
def __init__(self, input_dim, judge_dim, num_judges=None):
super().__init__()
self.input_dim = input_dim
self.judge_dim = judge_dim
self.num_judges = num_judges
assert num_judges != None
self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
# concat [self.output_layer, phoneme features]
self.decoder_rnn = nn.LSTM(
input_size=self.input_dim + self.judge_dim,
hidden_size=512,
num_layers=1,
batch_first=True,
bidirectional=True,
) # linear?
self.out_dim = self.decoder_rnn.hidden_size * 2
def get_output_dim(self):
return self.out_dim
def forward(self, x, batch):
judge_ids = batch["judge_id"]
if "phoneme-feature" in x.keys():
concatenated_feature = torch.cat(
(
x["ssl-feature"],
x["phoneme-feature"]
.unsqueeze(1)
.expand(-1, x["ssl-feature"].size(1), -1),
),
dim=2,
)
else:
concatenated_feature = x["ssl-feature"]
if "domain-feature" in x.keys():
concatenated_feature = torch.cat(
(
concatenated_feature,
x["domain-feature"]
.unsqueeze(1)
.expand(-1, concatenated_feature.size(1), -1),
),
dim=2,
)
if judge_ids != None:
concatenated_feature = torch.cat(
(
concatenated_feature,
self.judge_embedding(judge_ids)
.unsqueeze(1)
.expand(-1, concatenated_feature.size(1), -1),
),
dim=2,
)
decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
return decoder_output
class Projection(nn.Module):
def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
super(Projection, self).__init__()
self.range_clipping = range_clipping
output_dim = 1
if range_clipping:
self.proj = nn.Tanh()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
activation,
nn.Dropout(0.3),
nn.Linear(hidden_dim, output_dim),
)
self.output_dim = output_dim
def forward(self, x, batch):
output = self.net(x)
# range clipping
if self.range_clipping:
return self.proj(output) * 2.0 + 3
else:
return output
def get_output_dim(self):
return self.output_dim
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
UTMOS = UTMOSScore(
utmos_model_path=args.utmos_model_path, ssl_model_path=args.ssl_model_path
)
score = UTMOS.score_dir(args.wav_path)
logging.info(f"UTMOS score: {score:.2f}")