mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
179 lines
6.3 KiB
Python
179 lines
6.3 KiB
Python
from typing import Optional
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
|
|
|
|
from spectral_ops import IMDCT, ISTFT
|
|
from modules import symexp
|
|
|
|
|
|
class FourierHead(nn.Module):
    """Common parent of the inverse-Fourier output heads.

    This class carries no parameters of its own; concrete heads are expected
    to override :meth:`forward`.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Turn hidden features into a time-domain signal.

        Args:
            x (Tensor): Features of shape (B, L, H): batch size, sequence
                length and model dimension.

        Returns:
            Tensor: Reconstructed audio of shape (B, T), T being the number of
                output samples.

        Raises:
            NotImplementedError: Always; the base class has no implementation.
        """
        message = "Subclasses must implement the forward method."
        raise NotImplementedError(message)
|
|
|
|
|
|
class ISTFTHead(FourierHead):
|
|
"""
|
|
ISTFT Head module for predicting STFT complex coefficients.
|
|
|
|
Args:
|
|
dim (int): Hidden dimension of the model.
|
|
n_fft (int): Size of Fourier transform.
|
|
hop_length (int): The distance between neighboring sliding window frames, which should align with
|
|
the resolution of the input features.
|
|
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
|
"""
|
|
|
|
def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
|
|
super().__init__()
|
|
out_dim = n_fft + 2
|
|
self.out = torch.nn.Linear(dim, out_dim)
|
|
self.istft = ISTFT(
|
|
n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
|
|
)
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
Forward pass of the ISTFTHead module.
|
|
|
|
Args:
|
|
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
|
L is the sequence length, and H denotes the model dimension.
|
|
|
|
Returns:
|
|
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
|
"""
|
|
x = self.out(x).transpose(1, 2)
|
|
mag, p = x.chunk(2, dim=1)
|
|
mag = torch.exp(mag)
|
|
mag = torch.clip(
|
|
mag, max=1e2
|
|
) # safeguard to prevent excessively large magnitudes
|
|
# wrapping happens here. These two lines produce real and imaginary value
|
|
x = torch.cos(p)
|
|
y = torch.sin(p)
|
|
# recalculating phase here does not produce anything new
|
|
# only costs time
|
|
# phase = torch.atan2(y, x)
|
|
# S = mag * torch.exp(phase * 1j)
|
|
# better directly produce the complex value
|
|
S = mag * (x + 1j * y)
|
|
audio = self.istft(S)
|
|
return audio
|
|
|
|
|
|
class IMDCTSymExpHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients with symmetric exponential function

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be
            initialized based on perceptual scaling. Defaults to None.
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0].
            Defaults to False.
    """

    def __init__(
        self,
        dim: int,
        mdct_frame_len: int,
        padding: str = "same",
        sample_rate: Optional[int] = None,
        clip_audio: bool = False,
    ):
        super().__init__()
        # One linear output per MDCT coefficient: half the frame length.
        out_dim = mdct_frame_len // 2
        self.out = nn.Linear(dim, out_dim)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
        self.clip_audio = clip_audio

        if sample_rate is not None:
            # Optionally init the last layer following the mel scale: sample
            # out_dim points evenly on the mel axis up to Nyquist, map them back
            # to Hz, and shrink each output row's weights as its frequency grows.
            m_max = _hz_to_mel(sample_rate // 2)
            m_pts = torch.linspace(0, m_max, out_dim)
            f_pts = _mel_to_hz(m_pts)
            scale = 1 - (f_pts / f_pts.max())

            with torch.no_grad():
                self.out.weight.mul_(scale.view(-1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTSymExpHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is
                the length of the output signal.
        """
        x = self.out(x)
        x = symexp(x)
        x = torch.clip(
            x, min=-1e2, max=1e2
        )  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(x)
        if self.clip_audio:
            # Clip the reconstructed waveform, not the MDCT coefficients `x`;
            # clipping `x` would throw away the IMDCT result.
            audio = torch.clip(audio, min=-1.0, max=1.0)

        return audio
|
|
|
|
|
|
class IMDCTCosHead(FourierHead):
|
|
"""
|
|
IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
|
|
|
|
Args:
|
|
dim (int): Hidden dimension of the model.
|
|
mdct_frame_len (int): Length of the MDCT frame.
|
|
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
|
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
dim: int,
|
|
mdct_frame_len: int,
|
|
padding: str = "same",
|
|
clip_audio: bool = False,
|
|
):
|
|
super().__init__()
|
|
self.clip_audio = clip_audio
|
|
self.out = nn.Linear(dim, mdct_frame_len)
|
|
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
Forward pass of the IMDCTCosHead module.
|
|
|
|
Args:
|
|
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
|
L is the sequence length, and H denotes the model dimension.
|
|
|
|
Returns:
|
|
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
|
"""
|
|
x = self.out(x)
|
|
m, p = x.chunk(2, dim=2)
|
|
m = torch.exp(m).clip(
|
|
max=1e2
|
|
) # safeguard to prevent excessively large magnitudes
|
|
audio = self.imdct(m * torch.cos(p))
|
|
if self.clip_audio:
|
|
audio = torch.clip(x, min=-1.0, max=1.0)
|
|
return audio
|