import logging
from typing import Optional

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


def window_sumsquare(
    window: torch.Tensor,
    n_samples: int,
    hop_length: int = 256,
    win_length: int = 1024,
) -> torch.Tensor:
"""
|
|
Compute the sum-square envelope of a window function at a given hop length.
|
|
This is used to estimate modulation effects induced by windowing
|
|
observations in short-time fourier transforms.
|
|
Parameters
|
|
----------
|
|
window : string, tuple, number, callable, or list-like
|
|
Window specification, as in `get_window`
|
|
n_samples : int > 0
|
|
The number of expected samples.
|
|
hop_length : int > 0
|
|
The number of samples to advance between frames
|
|
win_length :
|
|
The length of the window function.
|
|
Returns
|
|
-------
|
|
wss : torch.Tensor, The sum-squared envelope of the window function.
|
|
"""
    n_frames = (n_samples - win_length) // hop_length + 1
    output_size = (n_frames - 1) * hop_length + win_length

    # Square the window and overlap-add it at the hop interval via fold().
    window_sq = window.square().expand(1, n_frames, -1).transpose(1, 2)
    window_envelope = torch.nn.functional.fold(
        window_sq,
        output_size=(1, output_size),
        kernel_size=(1, win_length),
        stride=(1, hop_length),
    ).squeeze()
    window_envelope = torch.nn.functional.pad(
        window_envelope, (0, n_samples - output_size)
    )
    return window_envelope
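
# Illustrative shapes for ``window_sumsquare`` (a sketch with arbitrary values,
# assuming a 1024-sample Hann window hopped every 256 samples):
#
#   >>> win = torch.hann_window(1024, periodic=False)
#   >>> env = window_sumsquare(win, n_samples=4096, hop_length=256, win_length=1024)
#   >>> env.shape
#   torch.Size([4096])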


class ISTFT(torch.nn.Module):
    """Adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""

    def __init__(
        self,
        filter_length: int = 1024,
        hop_length: int = 256,
        win_length: int = 1024,
        padding: str = "none",
        window_type: str = "povey",
        max_samples: int = 1440000,  # 1440000 / 24000 = 60s
    ):
        super(ISTFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.padding = padding
        scale = self.filter_length / self.hop_length
        # Real DFT basis: an identity-matrix FFT, keeping only the
        # non-redundant bins and stacking real and imaginary parts.
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        cutoff = int(self.filter_length / 2 + 1)
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
        )

        assert filter_length >= win_length
        # Consistent with lhotse; search "create_frame_window" in
        # https://github.com/lhotse-speech/lhotse
        assert window_type in [
            "hanning",
            "povey",
        ], f"Only 'hanning' and 'povey' windows are supported, given {window_type}."
        fft_window = torch.hann_window(win_length, periodic=False)
        if window_type == "povey":
            fft_window = fft_window.pow(0.85)

        if filter_length > win_length:
            pad_size = (filter_length - win_length) // 2
            fft_window = torch.nn.functional.pad(fft_window, (pad_size, pad_size))

        # Precompute the overlap-add normalization envelope for up to
        # `max_samples` output samples.
        window_sum = window_sumsquare(
            window=fft_window,
            n_samples=max_samples,
            hop_length=hop_length,
            win_length=filter_length,
        )

        inverse_basis *= fft_window

        self.register_buffer("inverse_basis", inverse_basis.float())
        self.register_buffer("fft_window", fft_window)
        self.register_buffer("window_sum", window_sum)
        self.tiny = torch.finfo(torch.float16).tiny

    def forward(self, magnitude: torch.Tensor, phase: torch.Tensor) -> torch.Tensor:
        # magnitude, phase: (batch, filter_length // 2 + 1, frames)
        magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )
        inverse_transform = F.conv_transpose1d(
            magnitude_phase,
            self.inverse_basis,
            stride=self.hop_length,
            padding=0,
        )
        inverse_transform = inverse_transform.squeeze(1)

        window_sum = self.window_sum
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            if self.window_sum.size(-1) < inverse_transform.size(-1):
                logging.warning(
                    f"The precomputed `window_sumsquare` is too small; recomputing "
                    f"from {self.window_sum.size(-1)} to "
                    f"{inverse_transform.size(-1)} samples."
                )
                window_sum = window_sumsquare(
                    window=self.fft_window,
                    n_samples=inverse_transform.size(-1),
                    win_length=self.filter_length,
                    hop_length=self.hop_length,
                )
        window_sum = window_sum[: inverse_transform.size(-1)]
        # Undo the window overlap-add gain, skipping samples where the
        # envelope is too small to divide by safely.
        approx_nonzero_indices = (window_sum > self.tiny).nonzero().squeeze()

        inverse_transform[:, approx_nonzero_indices] /= window_sum[
            approx_nonzero_indices
        ]

        # Scale by the hop ratio.
        inverse_transform *= float(self.filter_length) / self.hop_length
        assert self.padding in ["none", "same", "center"]
        if self.padding == "center":
            pad_len = self.filter_length // 2
        elif self.padding == "same":
            pad_len = (self.filter_length - self.hop_length) // 2
        else:
            return inverse_transform
        return inverse_transform[:, pad_len:-pad_len]
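
# Illustrative shapes for ``ISTFT`` (a sketch; the 1024/256 configuration and
# the "same" padding choice are assumptions matching the defaults above):
#
#   >>> istft = ISTFT(filter_length=1024, hop_length=256, win_length=1024, padding="same")
#   >>> mag = torch.rand(2, 513, 100)    # (batch, filter_length // 2 + 1, frames)
#   >>> phase = torch.rand(2, 513, 100)
#   >>> istft(mag, phase).shape          # frames * hop_length samples with "same" padding
#   torch.Size([2, 25600])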


class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to a 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float, optional): Initial value for the layer scale.
            None means no scaling. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        layer_scale_init_value: Optional[float] = None,
    ):
        super().__init__()
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # pointwise/1x1 convs, implemented with linear layers
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value is not None and layer_scale_init_value > 0
            else None
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        x = residual + x
        return x
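
# Illustrative shapes for ``ConvNeXtBlock`` (a sketch; dim=512 and
# intermediate_dim=1536 are assumptions matching the Generator defaults below):
#
#   >>> block = ConvNeXtBlock(dim=512, intermediate_dim=1536, layer_scale_init_value=1 / 8)
#   >>> block(torch.randn(2, 512, 100)).shape   # (batch, channels, frames), shape-preserving
#   torch.Size([2, 512, 100])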


class Generator(torch.nn.Module):
    def __init__(
        self,
        feature_dim: int = 80,
        dim: int = 512,
        n_fft: int = 1024,
        hop_length: int = 256,
        intermediate_dim: int = 1536,
        num_layers: int = 8,
        padding: str = "none",
        max_samples: int = 1440000,  # 1440000 / 24000 = 60s
    ):
        super(Generator, self).__init__()
        self.feature_dim = feature_dim
        self.embed = nn.Conv1d(feature_dim, dim, kernel_size=7, padding=3)

        self.norm = nn.LayerNorm(dim, eps=1e-6)

        layer_scale_init_value = 1 / num_layers
        self.convnext = nn.ModuleList(
            [
                ConvNeXtBlock(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                )
                for _ in range(num_layers)
            ]
        )

        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
        self.apply(self._init_weights)

        # Project to n_fft + 2 channels: (n_fft // 2 + 1) magnitude bins
        # plus (n_fft // 2 + 1) phase bins.
        self.out_proj = torch.nn.Linear(dim, n_fft + 2)
        self.istft = ISTFT(
            filter_length=n_fft,
            hop_length=hop_length,
            win_length=n_fft,
            padding=padding,
            max_samples=max_samples,
        )

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, feature_dim, frames)
        x = self.embed(x)
        x = self.norm(x.transpose(1, 2))
        x = x.transpose(1, 2)
        for conv_block in self.convnext:
            x = conv_block(x)
        x = self.final_layer_norm(x.transpose(1, 2))
        x = self.out_proj(x).transpose(1, 2)
        mag, phase = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        # Safeguard to prevent excessively large magnitudes.
        mag = torch.clip(mag, max=1e2)
        audio = self.istft(mag, phase)
        return audio
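

if __name__ == "__main__":
    # Minimal smoke test (a sketch; the 80-dim feature input, frame count, and
    # "same" padding are illustrative assumptions, not a prescribed setup).
    generator = Generator(feature_dim=80, hop_length=256, padding="same")
    features = torch.randn(2, 80, 100)  # (batch, feature_dim, frames)
    audio = generator(features)
    print(audio.shape)  # -> torch.Size([2, 25600]), i.e. frames * hop_length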