icefall/egs/libritts/TTS/vocos/generator.py
2024-12-13 19:39:55 +08:00

265 lines
8.8 KiB
Python

import logging
from typing import Optional
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F
def window_sumsquare(
    window: torch.Tensor,
    n_samples: int,
    hop_length: int = 256,
    win_length: int = 1024,
) -> torch.Tensor:
    """
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.

    Parameters
    ----------
    window : torch.Tensor
        The window values. The squared window is replicated once per frame
        and overlap-added, so it is expected to hold `win_length` samples.
    n_samples : int > 0
        The number of expected samples. Must be at least `win_length`.
    hop_length : int > 0
        The number of samples to advance between frames
    win_length : int > 0
        The length of the window function.

    Returns
    -------
    wss : torch.Tensor
        The sum-squared envelope of the window function, a 1-D tensor with
        `n_samples` elements. Samples past the last complete frame are zero.

    Raises
    ------
    ValueError
        If `hop_length` is not positive, or if `n_samples` is smaller than
        `win_length` (no complete frame fits).
    """
    if hop_length <= 0:
        raise ValueError(f"hop_length must be positive, given {hop_length}.")
    if n_samples < win_length:
        raise ValueError(
            f"n_samples ({n_samples}) must be at least win_length ({win_length})."
        )

    # Number of complete frames, and the span they cover once overlap-added.
    n_frames = (n_samples - win_length) // hop_length + 1
    output_size = (n_frames - 1) * hop_length + win_length

    # One copy of the squared window per frame, laid out as
    # (1, win_length, n_frames) so that `fold` performs the overlap-add.
    window_sq = window.square().expand(1, n_frames, -1).transpose(1, 2)
    window_envelope = torch.nn.functional.fold(
        window_sq,
        output_size=(1, output_size),
        kernel_size=(1, win_length),
        stride=(1, hop_length),
    ).reshape(-1)  # always 1-D, even when the envelope has a single sample

    # Zero-fill the tail that is not covered by any complete frame.
    window_envelope = torch.nn.functional.pad(
        window_envelope, (0, n_samples - output_size)
    )
    return window_envelope
class ISTFT(torch.nn.Module):
    """Inverse STFT implemented with a transposed 1-D convolution.

    adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft

    The synthesis basis (pseudo-inverse of the stacked real/imaginary Fourier
    basis, multiplied by the window) is stored as a buffer. `forward` performs
    the overlap-add with `conv_transpose1d` and divides the result by the
    sum-square envelope of the window.

    Args:
        filter_length (int): FFT size. Must be >= `win_length`.
        hop_length (int): Number of samples between successive frames.
        win_length (int): Length of the window before zero-padding to
            `filter_length`.
        padding (str): One of "none", "same" or "center" (checked in
            `forward`). "none" returns the full overlap-added signal,
            "center" trims `filter_length // 2` samples from each end and
            "same" trims `(filter_length - hop_length) // 2` from each end.
        window_type (str): "hanning" or "povey". "povey" is the Hann window
            raised to the power 0.85.
        max_samples (int): Number of samples for which the window envelope is
            precomputed. Longer outputs trigger a recomputation in `forward`
            (only outside of TorchScript scripting/tracing).
    """

    def __init__(
        self,
        filter_length: int = 1024,
        hop_length: int = 256,
        win_length: int = 1024,
        padding: str = "none",
        window_type: str = "povey",
        max_samples: int = 1440000,  # 1440000 / 24000 = 60s
    ):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.padding = padding

        # Stack the real and imaginary parts of the first N/2 + 1 DFT rows;
        # the pseudo-inverse of that matrix maps a spectrum back to a frame.
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )
        # Shape (2 * cutoff, 1, filter_length): the weight layout expected by
        # conv_transpose1d (in_channels, out_channels, kernel_size).
        inverse_basis = torch.from_numpy(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
        ).float()

        assert filter_length >= win_length
        # Consistence with lhotse, search "create_frame_window" in https://github.com/lhotse-speech/lhotse
        assert window_type in [
            "hanning",
            "povey",
        ], f"Only 'hanning' and 'povey' windows are supported, given {window_type}."
        fft_window = torch.hann_window(win_length, periodic=False)
        if window_type == "povey":
            fft_window = fft_window.pow(0.85)

        if filter_length > win_length:
            # Zero-pad the window up to exactly `filter_length` samples. When
            # the difference is odd the extra sample goes on the right, so the
            # padded window always matches the basis length.
            pad_total = filter_length - win_length
            pad_left = pad_total // 2
            fft_window = torch.nn.functional.pad(
                fft_window, (pad_left, pad_total - pad_left)
            )

        window_sum = window_sumsquare(
            window=fft_window,
            n_samples=max_samples,
            hop_length=hop_length,
            win_length=filter_length,
        )

        inverse_basis *= fft_window

        self.register_buffer("inverse_basis", inverse_basis)
        self.register_buffer("fft_window", fft_window)
        self.register_buffer("window_sum", window_sum)
        # Envelope values at or below this threshold are not divided out.
        self.tiny = torch.finfo(torch.float16).tiny

    def forward(self, magnitude: torch.Tensor, phase: torch.Tensor) -> torch.Tensor:
        """Reconstruct a waveform from magnitude and phase spectrograms.

        Args:
            magnitude: Magnitude spectrogram. Concatenated with itself along
                dim 1 and convolved with the basis, so it is expected to be
                (batch, filter_length // 2 + 1, frames).
            phase: Phase in radians, same shape as `magnitude`.

        Returns:
            The reconstructed signal, (batch, samples) after the channel
            dimension is squeezed and the configured padding is trimmed.
        """
        # Polar -> rectangular, real part first then imaginary part, matching
        # the row order of the stacked Fourier basis.
        magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )

        # Overlap-add of the windowed inverse transforms of every frame.
        inverse_transform = F.conv_transpose1d(
            magnitude_phase,
            self.inverse_basis,
            stride=self.hop_length,
            padding=0,
        )
        inverse_transform = inverse_transform.squeeze(1)

        window_sum = self.window_sum
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            if self.window_sum.size(-1) < inverse_transform.size(-1):
                logging.warning(
                    f"The precomputed `window_sumsquare` is too small, recomputing, "
                    f"from {self.window_sum.size(-1)} to {inverse_transform.size(-1)}"
                )
                window_sum = window_sumsquare(
                    window=self.fft_window,
                    n_samples=inverse_transform.size(-1),
                    win_length=self.filter_length,
                    hop_length=self.hop_length,
                )
        window_sum = window_sum[: inverse_transform.size(-1)]

        # Divide out the window envelope wherever it is not (nearly) zero.
        # squeeze(-1) keeps the index tensor 1-D even if only one sample
        # qualifies.
        approx_nonzero_indices = (window_sum > self.tiny).nonzero().squeeze(-1)
        inverse_transform[:, approx_nonzero_indices] /= window_sum[
            approx_nonzero_indices
        ]

        # scale by hop ratio
        inverse_transform *= float(self.filter_length) / self.hop_length

        assert self.padding in ["none", "same", "center"]
        if self.padding == "center":
            pad_len = self.filter_length // 2
        elif self.padding == "same":
            pad_len = (self.filter_length - self.hop_length) // 2
        else:
            return inverse_transform

        # Use an explicit end index: `[pad_len:-pad_len]` would yield an empty
        # tensor when pad_len == 0 because -0 == 0.
        end = inverse_transform.size(-1) - pad_len
        return inverse_transform[:, pad_len:end]
class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

    The block applies a depthwise convolution, a layer norm over the channel
    dimension, two pointwise (linear) projections with a GELU in between, an
    optional learnable per-channel scale, and a residual connection.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None
            (or a non-positive value) means no scaling. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        layer_scale_init_value: Optional[float] = None,
    ):
        super().__init__()
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # pointwise/1x1 convs, implemented with linear layers
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        # The None check must come first: comparing None with `> 0` raises
        # TypeError, which made the documented default unusable.
        if layer_scale_init_value is not None and layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True
            )
        else:
            self.gamma = None

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the block.

        Args:
            x: Input of shape (B, C, T) with C == `dim`.

        Returns:
            Tensor of the same shape as `x`.
        """
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        x = residual + x
        return x
class Generator(torch.nn.Module):
    """Feature-to-waveform generator: ConvNeXt backbone followed by an ISTFT.

    The input features are embedded with a 1-D convolution, passed through a
    stack of `ConvNeXtBlock`s, projected to `n_fft + 2` channels that are split
    into log-magnitude and phase, and finally converted to audio by `ISTFT`.

    Args:
        feature_dim (int): Number of channels of the input features.
        dim (int): Hidden channel size of the backbone.
        n_fft (int): FFT size used by the ISTFT (also used as window length).
        hop_length (int): Hop size used by the ISTFT.
        intermediate_dim (int): Hidden size inside each ConvNeXt block.
        num_layers (int): Number of ConvNeXt blocks; the layer scale of every
            block is initialised to `1 / num_layers`.
        padding (str): Padding mode forwarded to `ISTFT`.
        max_samples (int): Forwarded to `ISTFT`.
    """

    def __init__(
        self,
        feature_dim: int = 80,
        dim: int = 512,
        n_fft: int = 1024,
        hop_length: int = 256,
        intermediate_dim: int = 1536,
        num_layers: int = 8,
        padding: str = "none",
        max_samples: int = 1440000,  # 1440000 / 24000 = 60s
    ):
        super().__init__()
        self.feature_dim = feature_dim

        # Input embedding followed by a channel-wise layer norm.
        self.embed = nn.Conv1d(feature_dim, dim, kernel_size=7, padding=3)
        self.norm = nn.LayerNorm(dim, eps=1e-6)

        # Backbone: `num_layers` identical ConvNeXt blocks.
        layer_scale = 1 / num_layers
        blocks = []
        for _ in range(num_layers):
            blocks.append(
                ConvNeXtBlock(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale,
                )
            )
        self.convnext = nn.ModuleList(blocks)
        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)

        # Custom initialisation covers only the modules registered so far;
        # `out_proj` is created afterwards and keeps PyTorch's default init.
        self.apply(self._init_weights)

        # Output head: log-magnitude and phase, n_fft // 2 + 1 bins each.
        self.out_proj = torch.nn.Linear(dim, n_fft + 2)
        self.istft = ISTFT(
            filter_length=n_fft,
            hop_length=hop_length,
            win_length=n_fft,
            padding=padding,
            max_samples=max_samples,
        )

    def _init_weights(self, m):
        """Truncated-normal weights and zero bias for conv and linear layers."""
        if not isinstance(m, (nn.Conv1d, nn.Linear)):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Generate audio from input features.

        Args:
            x: Input features; fed to a Conv1d with `feature_dim` input
                channels, i.e. laid out as (B, feature_dim, T).

        Returns:
            The waveform produced by the ISTFT module.
        """
        hidden = self.embed(x)
        # LayerNorm acts on the last dim, so normalise in (B, T, C) layout and
        # switch back to (B, C, T) for the convolutional blocks.
        hidden = self.norm(hidden.transpose(1, 2)).transpose(1, 2)

        for layer in self.convnext:
            hidden = layer(hidden)

        hidden = self.final_layer_norm(hidden.transpose(1, 2))
        spectrum = self.out_proj(hidden).transpose(1, 2)

        # First half of the channels is log-magnitude, second half is phase.
        log_mag, phase = spectrum.chunk(2, dim=1)
        # safeguard to prevent excessively large magnitudes
        mag = torch.clip(torch.exp(log_mag), max=1e2)

        return self.istft(mag, phase)