"""
|
|
Calculate UTMOS score with automatic Mean Opinion Score (MOS) prediction system
|
|
adapted from https://huggingface.co/spaces/sarulab-speech/UTMOS-demo
|
|
|
|
# Download model checkpoints
|
|
wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -P model/huggingface/utmos/utmos.pt
|
|
wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -P model/huggingface/utmos/wav2vec_small.pt
|
|
"""

import argparse
import logging
import os

import fairseq
import librosa
import numpy as np
import pytorch_lightning as pl
import soundfile as sf
import torch
import torch.nn as nn
from tqdm import tqdm

logging.basicConfig(level=logging.INFO)


def get_parser():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--wav-path",
        type=str,
        help="path to the directory of speech files to evaluate",
    )
    parser.add_argument(
        "--utmos-model-path",
        type=str,
        default="model/huggingface/utmos/utmos.pt",
        help="path of the UTMOS model checkpoint",
    )
    parser.add_argument(
        "--ssl-model-path",
        type=str,
        default="model/huggingface/utmos/wav2vec_small.pt",
        help="path of the wav2vec 2.0 SSL model checkpoint",
    )
    return parser
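

# Example command-line invocation (a sketch; the script filename and the
# --wav-path value are assumptions for illustration, the model paths are
# the defaults defined above):
#
#   python utmos.py \
#       --wav-path /path/to/generated_wavs \
#       --utmos-model-path model/huggingface/utmos/utmos.pt \
#       --ssl-model-path model/huggingface/utmos/wav2vec_small.pt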


class UTMOSScore:
    """Predict a UTMOS score for each audio clip."""

    def __init__(self, utmos_model_path, ssl_model_path):
        self.sample_rate = 16000
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        self.model = (
            BaselineLightningModule.load_from_checkpoint(
                utmos_model_path, ssl_model_path=ssl_model_path
            )
            .eval()
            .to(self.device)
        )

    def score(self, wavs: torch.Tensor) -> torch.Tensor:
        """
        Args:
            wavs: waveforms to be evaluated. A 1-D or 2-D tensor is
                treated as a single audio clip; a 3-D tensor of shape
                (batch, 1, samples) is processed as a batch.
        """
        if len(wavs.shape) == 1:
            out_wavs = wavs.unsqueeze(0).unsqueeze(0)
        elif len(wavs.shape) == 2:
            out_wavs = wavs.unsqueeze(0)
        elif len(wavs.shape) == 3:
            out_wavs = wavs
        else:
            raise ValueError("Dimension of input tensor needs to be <= 3.")
        bs = out_wavs.shape[0]
        # Fixed domain and judge IDs, matching the UTMOS demo this is adapted from.
        batch = {
            "wav": out_wavs,
            "domains": torch.zeros(bs, dtype=torch.int).to(self.device),
            "judge_id": torch.ones(bs, dtype=torch.int).to(self.device) * 288,
        }
        with torch.no_grad():
            output = self.model(batch)

        # Average the per-frame scores and rescale to the 1-5 MOS range.
        return output.mean(dim=1).squeeze(1).cpu().detach() * 2 + 3
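
    # Illustrative input shapes for score() (a sketch; "scorer" is an assumed
    # UTMOSScore instance and the 16 kHz mono waveforms are placeholders):
    #   scorer.score(torch.randn(16000))        # 1-D: a single clip
    #   scorer.score(torch.randn(1, 16000))     # 2-D: a single clip
    #   scorer.score(torch.randn(4, 1, 16000))  # 3-D: a batch of 4 clips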

    def score_dir(self, wav_dir, dtype="float32"):
        def _load_speech_task(fname, sample_rate):
            wav_data, sr = sf.read(fname, dtype=dtype)
            # Resample to the model's expected sample rate if needed.
            if sr != sample_rate:
                wav_data = librosa.resample(
                    wav_data, orig_sr=sr, target_sr=sample_rate
                )
            return torch.from_numpy(wav_data)

        # Assumes every file in the directory is a readable audio file.
        score_lst = []
        for fname in tqdm(os.listdir(wav_dir)):
            speech = _load_speech_task(os.path.join(wav_dir, fname), self.sample_rate)
            speech = speech.to(self.device)
            with torch.no_grad():
                score = self.score(speech)
            score_lst.append(score.item())
        return np.mean(score_lst)
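

# Minimal programmatic usage sketch (assumes the checkpoints from the module
# docstring were downloaded to the default paths; "generated_wavs" is a
# placeholder directory name):
#
#   scorer = UTMOSScore(
#       utmos_model_path="model/huggingface/utmos/utmos.pt",
#       ssl_model_path="model/huggingface/utmos/wav2vec_small.pt",
#   )
#   mean_mos = scorer.score_dir("generated_wavs")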


def load_ssl_model(ckpt_path="wav2vec_small.pt"):
    # Output feature dimension of the wav2vec 2.0 "small" (base) model.
    SSL_OUT_DIM = 768
    models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [ckpt_path]
    )
    ssl_model = models[0]
    ssl_model.remove_pretraining_modules()
    return SSL_model(ssl_model, SSL_OUT_DIM)


class BaselineLightningModule(pl.LightningModule):
    def __init__(self, ssl_model_path):
        super().__init__()
        self.construct_model(ssl_model_path)
        self.save_hyperparameters()

    def construct_model(self, ssl_model_path):
        self.feature_extractors = nn.ModuleList(
            [
                load_ssl_model(ckpt_path=ssl_model_path),
                DomainEmbedding(3, 128),
            ]
        )
        output_dim = sum(
            feature_extractor.get_output_dim()
            for feature_extractor in self.feature_extractors
        )
        output_layers = [
            LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)
        ]
        output_dim = output_layers[-1].get_output_dim()
        output_layers.append(
            Projection(
                hidden_dim=2048,
                activation=torch.nn.ReLU(),
                range_clipping=False,
                input_dim=output_dim,
            )
        )

        self.output_layers = nn.ModuleList(output_layers)

    def forward(self, inputs):
        outputs = {}
        for feature_extractor in self.feature_extractors:
            outputs.update(feature_extractor(inputs))
        x = outputs
        for output_layer in self.output_layers:
            x = output_layer(x, inputs)
        return x
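
    # Data-flow sketch for forward() with the default configuration above:
    #   ssl-feature:    (B, T, 768)  from the wav2vec 2.0 small model
    #   domain-feature: (B, 128)     from DomainEmbedding(3, 128)
    #   LDConditioner:  broadcasts the domain and 128-dim judge embeddings
    #                   along T, concatenates, and runs a bidirectional
    #                   LSTM with hidden size 512 -> (B, T, 1024)
    #   Projection:     MLP producing a per-frame score -> (B, T, 1)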


class SSL_model(nn.Module):
    def __init__(self, ssl_model, ssl_out_dim) -> None:
        super().__init__()
        self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim

    def forward(self, batch):
        wav = batch["wav"]
        wav = wav.squeeze(1)  # [batches, wav_len]
        res = self.ssl_model(wav, mask=False, features_only=True)
        x = res["x"]
        return {"ssl-feature": x}

    def get_output_dim(self):
        return self.ssl_out_dim


class DomainEmbedding(nn.Module):
    def __init__(self, n_domains, domain_dim) -> None:
        super().__init__()
        self.embedding = nn.Embedding(n_domains, domain_dim)
        self.output_dim = domain_dim

    def forward(self, batch):
        return {"domain-feature": self.embedding(batch["domains"])}

    def get_output_dim(self):
        return self.output_dim


class LDConditioner(nn.Module):
    """
    Conditions the SSL output on a listener (judge) embedding.
    """

    def __init__(self, input_dim, judge_dim, num_judges=None):
        super().__init__()
        self.input_dim = input_dim
        self.judge_dim = judge_dim
        self.num_judges = num_judges
        assert num_judges is not None
        self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
        # concat [self.output_layer, phoneme features]

        self.decoder_rnn = nn.LSTM(
            input_size=self.input_dim + self.judge_dim,
            hidden_size=512,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )  # linear?
        self.out_dim = self.decoder_rnn.hidden_size * 2

    def get_output_dim(self):
        return self.out_dim

    def forward(self, x, batch):
        judge_ids = batch["judge_id"]
        if "phoneme-feature" in x.keys():
            concatenated_feature = torch.cat(
                (
                    x["ssl-feature"],
                    x["phoneme-feature"]
                    .unsqueeze(1)
                    .expand(-1, x["ssl-feature"].size(1), -1),
                ),
                dim=2,
            )
        else:
            concatenated_feature = x["ssl-feature"]
        if "domain-feature" in x.keys():
            concatenated_feature = torch.cat(
                (
                    concatenated_feature,
                    x["domain-feature"]
                    .unsqueeze(1)
                    .expand(-1, concatenated_feature.size(1), -1),
                ),
                dim=2,
            )
        if judge_ids is not None:
            concatenated_feature = torch.cat(
                (
                    concatenated_feature,
                    self.judge_embedding(judge_ids)
                    .unsqueeze(1)
                    .expand(-1, concatenated_feature.size(1), -1),
                ),
                dim=2,
            )
        decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
        return decoder_output


class Projection(nn.Module):
    def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
        super().__init__()
        self.range_clipping = range_clipping
        output_dim = 1
        if range_clipping:
            self.proj = nn.Tanh()

        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            activation,
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim),
        )
        self.output_dim = output_dim

    def forward(self, x, batch):
        output = self.net(x)

        # range clipping
        if self.range_clipping:
            return self.proj(output) * 2.0 + 3
        else:
            return output
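
    # With range_clipping=True, tanh bounds the raw output to [-1, 1] and the
    # affine map rescales it to the 1-5 MOS range. The model assembled in
    # construct_model() uses range_clipping=False, so the same rescaling is
    # applied in UTMOSScore.score() instead.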

    def get_output_dim(self):
        return self.output_dim


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    UTMOS = UTMOSScore(
        utmos_model_path=args.utmos_model_path, ssl_model_path=args.ssl_model_path
    )
    score = UTMOS.score_dir(args.wav_path)
    logging.info(f"UTMOS score: {score:.2f}")