zr_jin 1c4dd464a0
Performed end to end testing on the matcha recipe (#1797)
* minor fixes to the `ljspeech/matcha` recipe
2024-12-08 03:18:15 +08:00

296 lines
11 KiB
Python

import datetime as dt
import math
import random
import monotonic_align as monotonic_align
import torch
from model import (
denormalize,
duration_loss,
fix_len_compatibility,
generate_path,
sequence_mask,
)
from models.components.flow_matching import CFM
from models.components.text_encoder import TextEncoder
class MatchaTTS(torch.nn.Module):  # 🍵
    """Matcha-TTS acoustic model.

    A ``TextEncoder`` maps token ids to per-token encoder outputs ``mu_x`` and
    log-scaled token durations ``logw``; the encoder outputs are expanded to
    mel-frame resolution through a monotonic alignment and refined by a
    conditional flow matching (``CFM``) decoder.
    """

    def __init__(
        self,
        n_vocab,
        n_spks,
        spk_emb_dim,
        n_feats,
        encoder,
        decoder,
        cfm,
        data_statistics,
        out_size,
        optimizer=None,
        scheduler=None,
        prior_loss=True,
        use_precomputed_durations=False,
    ):
        """Build the text encoder, the CFM decoder and, if needed, a speaker table.

        Args:
          n_vocab: size of the token vocabulary passed to ``TextEncoder``.
          n_spks: number of speakers; a speaker embedding table is created
            only when this is greater than 1.
          spk_emb_dim: dimension of the speaker embedding.
          n_feats: number of mel feature channels (used by MAS and the prior
            loss in ``forward``).
          encoder: config object providing ``encoder_type``,
            ``encoder_params`` (with ``n_feats``) and
            ``duration_predictor_params``.
          decoder: forwarded to ``CFM`` as ``decoder_params``.
          cfm: forwarded to ``CFM`` as ``cfm_params``.
          data_statistics: mapping with ``"mel_mean"`` and ``"mel_std"`` used to
            denormalize generated mels; if None, mean 0 and std 1 are used.
          out_size: segment length that ``get_losses`` passes to ``forward``;
            None disables segment cutting.
          optimizer, scheduler: not used by this class.
          prior_loss: whether ``forward`` computes the prior (encoder) loss.
          use_precomputed_durations: if True, ``forward`` builds the alignment
            from the supplied ``durations`` instead of running MAS.
        """
        super().__init__()

        self.n_vocab = n_vocab
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.n_feats = n_feats
        self.out_size = out_size
        self.prior_loss = prior_loss
        self.use_precomputed_durations = use_precomputed_durations

        if n_spks > 1:
            self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)

        self.encoder = TextEncoder(
            encoder.encoder_type,
            encoder.encoder_params,
            encoder.duration_predictor_params,
            n_vocab,
            n_spks,
            spk_emb_dim,
        )

        self.decoder = CFM(
            in_channels=2 * encoder.encoder_params.n_feats,
            out_channel=encoder.encoder_params.n_feats,
            cfm_params=cfm,
            decoder_params=decoder,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )

        # Buffers (not parameters) so they follow .to(device) and are saved in
        # the state dict without being trained.
        if data_statistics is not None:
            self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
            self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))
        else:
            self.register_buffer("mel_mean", torch.tensor(0.0))
            self.register_buffer("mel_std", torch.tensor(1.0))

    @torch.inference_mode()
    def synthesise(
        self,
        x,
        x_lengths,
        n_timesteps,
        temperature=1.0,
        spks=None,
        length_scale=1.0,
        *,
        sample_rate=22050,
        hop_length=256,
    ):
        """
        Generates mel-spectrogram from text. Returns:
            1. encoder outputs
            2. decoder outputs
            3. generated alignment

        Args:
            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
                shape: (batch_size, max_text_length)
            x_lengths (torch.Tensor): lengths of texts in batch.
                shape: (batch_size,)
            n_timesteps (int): number of steps to use for reverse diffusion in decoder.
            temperature (float, optional): controls variance of terminal distribution.
            spks (torch.Tensor, optional): speaker ids; only used when n_spks > 1.
                shape: (batch_size,)
            length_scale (float, optional): controls speech pace.
                Increase value to slow down generated speech and vice versa.
            sample_rate (int, optional, keyword-only): audio sample rate, used
                only to compute the real-time factor.
            hop_length (int, optional, keyword-only): samples per mel frame,
                used only to compute the real-time factor.

        Returns:
            dict: {
                "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
                # Average mel spectrogram generated by the encoder
                "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
                # Refined mel spectrogram improved by the CFM
                "attn": torch.Tensor, shape: (batch_size, 1, max_text_length, max_mel_length),
                # Alignment map between text and mel spectrogram
                "mel": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
                # Denormalized mel spectrogram
                "mel_lengths": torch.Tensor, shape: (batch_size,),
                # Lengths of mel spectrograms
                "rtf": float,
                # Real-time factor: synthesis time / duration of generated audio
            }
        """
        # For RTF computation
        t = dt.datetime.now()

        if self.n_spks > 1:
            # Get speaker embedding
            spks = self.spk_emb(spks.long())

        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)

        w = torch.exp(logw) * x_mask
        w_ceil = torch.ceil(w) * length_scale
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_max_length = y_lengths.max()
        # Padded length accepted by the decoder; outputs are cropped back below.
        y_max_length_ = fix_len_compatibility(y_max_length)

        # Using obtained durations `w` construct alignment map `attn`
        y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
        # (batch, 1, t_text, t_mel_padded)
        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)

        # Align encoded text and get mu_y
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)
        encoder_outputs = mu_y[:, :, :y_max_length]

        # Generate sample tracing the probability flow
        decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks)
        decoder_outputs = decoder_outputs[:, :, :y_max_length]

        t = (dt.datetime.now() - t).total_seconds()
        # Generated audio lasts frames * hop_length / sample_rate seconds.
        rtf = t * sample_rate / (decoder_outputs.shape[-1] * hop_length)

        return {
            "encoder_outputs": encoder_outputs,
            "decoder_outputs": decoder_outputs,
            # Crop the mel (last) axis; attn is 4-D, so dim 2 is the text axis.
            "attn": attn[..., :y_max_length],
            "mel": denormalize(decoder_outputs, self.mel_mean, self.mel_std),
            "mel_lengths": y_lengths,
            "rtf": rtf,
        }

    def forward(
        self,
        x,
        x_lengths,
        y,
        y_lengths,
        spks=None,
        out_size=None,
        cond=None,
        durations=None,
    ):
        """
        Computes 3 losses:
            1. duration loss: loss between predicted token durations and those extracted by
               Monotonic Alignment Search (MAS), or by the precomputed durations.
            2. prior loss: loss between mel-spectrogram and encoder outputs.
            3. flow matching loss: loss between mel-spectrogram and decoder outputs.

        Args:
            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
                shape: (batch_size, max_text_length)
            x_lengths (torch.Tensor): lengths of texts in batch.
                shape: (batch_size,)
            y (torch.Tensor): batch of corresponding mel-spectrograms.
                shape: (batch_size, n_feats, max_mel_length)
            y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
                shape: (batch_size,)
            spks (torch.Tensor, optional): speaker ids; only used when n_spks > 1.
                shape: (batch_size,)
            out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which
                decoder will be trained. Should be divisible by 2^{num of UNet downsamplings}.
                Needed to increase batch size.
            cond (optional): forwarded unchanged to ``self.decoder.compute_loss``.
            durations (torch.Tensor, optional): per-token frame counts; required when
                ``use_precomputed_durations`` is True. Squeezed on dim 1 before use,
                presumably (batch_size, 1, max_text_length) -- confirm against the data loader.

        Returns:
            tuple: (dur_loss, prior_loss, diff_loss, attn). ``prior_loss`` is 0 when
            prior loss is disabled. ``attn`` has shape (batch_size, max_text_length, T),
            where T is ``out_size`` if a segment was cut, else max_mel_length.

        Raises:
            ValueError: if ``use_precomputed_durations`` is True and ``durations`` is None.
        """
        if self.n_spks > 1:
            # Get speaker embedding
            spks = self.spk_emb(spks.long())

        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
        y_max_length = y.shape[-1]

        y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)

        if self.use_precomputed_durations:
            if durations is None:
                raise ValueError(
                    "durations must be provided when use_precomputed_durations=True"
                )
            attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1))
        else:
            # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
            with torch.no_grad():
                # log_prior[b, i, t] = -0.5 * sum_f (y[b, f, t] - mu_x[b, f, i])^2
                #                      - 0.5 * n_feats * log(2*pi),
                # expanded so it is computed with three matmuls.
                const = -0.5 * math.log(2 * math.pi) * self.n_feats
                factor = -0.5 * torch.ones(
                    mu_x.shape, dtype=mu_x.dtype, device=mu_x.device
                )
                y_square = torch.matmul(factor.transpose(1, 2), y**2)
                y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
                mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
                log_prior = y_square - y_mu_double + mu_square + const

                attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
                attn = attn.detach()  # b, t_text, T_mel

        # Duration loss: predicted log-scaled durations vs. the log of the
        # per-token frame counts read off the alignment.
        logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
        dur_loss = duration_loss(logw, logw_, x_lengths)

        # Cut a small segment of mel-spectrogram in order to increase batch size
        #   - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it
        #   - Do not need this hack for Matcha-TTS, but it works with it as well
        if out_size is not None:
            y, attn, y_cut_lengths = self._cut_segments(y, attn, y_lengths, out_size)
            # Pad the mask to out_size so it matches `y`/`attn` even when every
            # utterance in the batch is shorter than out_size.
            y_mask = sequence_mask(y_cut_lengths, out_size).unsqueeze(1).to(y_mask)

        # Align encoded text with mel-spectrogram and get mu_y segment.
        # `attn` is 3-D (batch, t_text, T_mel) here; squeezing dim 1 would drop
        # the text axis whenever the longest text in the batch has one token.
        mu_y = torch.matmul(attn.transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)

        # Compute loss of the decoder
        diff_loss, _ = self.decoder.compute_loss(
            x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond
        )

        if self.prior_loss:
            prior_loss = torch.sum(
                0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask
            )
            prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
        else:
            prior_loss = 0

        return dur_loss, prior_loss, diff_loss, attn

    @staticmethod
    def _cut_segments(y, attn, y_lengths, out_size):
        """Cut a random window of at most ``out_size`` frames from each utterance.

        Args:
          y: mel-spectrograms, shape (batch_size, n_feats, max_mel_length).
          attn: alignments, shape (batch_size, max_text_length, max_mel_length).
          y_lengths: valid frame count of each utterance, shape (batch_size,).
          out_size: number of frames in every output segment.

        Returns:
          (y_cut, attn_cut, y_cut_lengths): ``y`` and ``attn`` cut (and
          zero-padded) to ``out_size`` frames on the last axis, plus the number
          of valid frames per segment, i.e. ``min(y_length, out_size)``, with the
          dtype and device of ``y_lengths``.
        """
        # Largest start offset that still leaves out_size frames in the utterance.
        max_offset = (y_lengths - out_size).clamp(0)
        # Start offsets are drawn from [0, max_offset), as in the Grad-TTS
        # recipe; utterances not longer than out_size start at 0.
        offsets = [
            random.randrange(end) if end > 0 else 0 for end in max_offset.tolist()
        ]
        y_cut_lengths = y_lengths.clamp(max=out_size)

        y_cut = y.new_zeros(y.shape[0], y.shape[1], out_size)
        attn_cut = attn.new_zeros(attn.shape[0], attn.shape[1], out_size)
        for i, (lower, length) in enumerate(zip(offsets, y_cut_lengths.tolist())):
            upper = lower + length
            y_cut[i, :, :length] = y[i, :, lower:upper]
            attn_cut[i, :, :length] = attn[i, :, lower:upper]
        return y_cut, attn_cut, y_cut_lengths

    def get_losses(self, batch):
        """Run ``forward`` on a batch dict and return the three losses.

        Args:
          batch: mapping with keys "x", "x_lengths", "y", "y_lengths" and the
            optional keys "spks" and "durations" (treated as None if absent).

        Returns:
          dict with "dur_loss", "prior_loss" and "diff_loss".
        """
        dur_loss, prior_loss, diff_loss, *_ = self(
            x=batch["x"],
            x_lengths=batch["x_lengths"],
            y=batch["y"],
            y_lengths=batch["y_lengths"],
            spks=batch.get("spks"),
            out_size=self.out_size,
            durations=batch.get("durations"),
        )
        return {
            "dur_loss": dur_loss,
            "prior_loss": prior_loss,
            "diff_loss": diff_loss,
        }