icefall/egs/ljspeech/tts/vits/features.py
2023-10-22 23:14:00 +08:00

417 lines
14 KiB
Python

# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple
import librosa
import numpy as np
import torch
from torch import nn
from icefall.utils import make_pad_mask
# From https://github.com/espnet/espnet/blob/master/espnet2/layers/stft.py
class Stft(nn.Module):
    """Short-time Fourier transform layer.

    Uses ``torch.stft`` whenever a native FFT backend is usable and falls back
    to a ``librosa.stft`` based implementation (inference only) otherwise.

    Args:
        n_fft: FFT size.
        win_length: Window length; ``n_fft`` is used when ``None``.
        hop_length: Hop size between successive frames.
        window: Name of a ``torch.<name>_window`` function, or ``None`` for a
            rectangular window.
        center: Forwarded to the STFT backend; when True the frame count is
            computed as if ``n_fft // 2`` samples were padded on both sides.
        normalized: Scale the result by ``n_fft ** -0.5``.
        onesided: Return only the non-redundant half of the spectrum.

    Raises:
        ValueError: If ``window`` does not name a ``torch.<name>_window``
            function.
    """

    def __init__(
        self,
        n_fft: int = 512,
        win_length: Optional[int] = None,
        hop_length: int = 128,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.win_length = n_fft if win_length is None else win_length
        self.hop_length = hop_length
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        if window is not None and not hasattr(torch, f"{window}_window"):
            raise ValueError(f"{window} window is not implemented")
        self.window = window

    def extra_repr(self):
        return (
            f"n_fft={self.n_fft}, "
            f"win_length={self.win_length}, "
            f"hop_length={self.hop_length}, "
            f"center={self.center}, "
            f"normalized={self.normalized}, "
            f"onesided={self.onesided}"
        )

    @staticmethod
    def _torch_has_native_fft() -> bool:
        """Return True when the installed torch is >= 1.10.

        From that release on, torch ships native FFT/STFT support on all CPU
        targets (including ARM), so MKL is no longer required.
        """
        try:
            major, minor = (int(part) for part in torch.__version__.split(".")[:2])
        except ValueError:
            # Unparsable version string: be conservative.
            return False
        return (major, minor) >= (1, 10)

    def _librosa_stft(
        self, input: torch.Tensor, window: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Compute the STFT with librosa for devices without a torch FFT backend.

        Args:
            input: (Batch, Nsamples) CPU tensor.
            window: Window tensor of length ``win_length`` or ``None`` for a
                rectangular window.

        Returns:
            Tensor of shape (Batch, Freq, Frames, 2=real_imag).

        Raises:
            NotImplementedError: If the module is in training mode.
        """
        if self.training:
            raise NotImplementedError(
                "stft is implemented with librosa on this device, which does not "
                "support the training mode."
            )

        if window is None:
            # torch.stft treats a missing window as a rectangular one of
            # length win_length; mirror that here.
            window = input.new_ones(self.win_length)

        # librosa does not support a win_length that is < n_fft, so the window
        # is zero-padded (centered) to n_fft and passed as a numpy array.
        n_pad_left = (self.n_fft - window.shape[0]) // 2
        n_pad_right = self.n_fft - window.shape[0] - n_pad_left
        padded_window = torch.cat(
            [window.new_zeros(n_pad_left), window, window.new_zeros(n_pad_right)],
            0,
        ).numpy()

        stft_kwargs = dict(
            n_fft=self.n_fft,
            win_length=self.n_fft,
            hop_length=self.hop_length,
            center=self.center,
            window=padded_window,
            pad_mode="reflect",
        )

        # Iterate over the instances in the batch.
        frames = []
        for instance in input:
            stft = librosa.stft(instance.numpy(), **stft_kwargs)
            frames.append(torch.tensor(np.stack([stft.real, stft.imag], -1)))
        output = torch.stack(frames, 0)

        if not self.onesided:
            # Rebuild the redundant half from the complex conjugate.
            len_conj = self.n_fft - output.shape[1]
            conj = output[:, 1 : 1 + len_conj].flip(1)
            conj[..., -1] *= -1
            output = torch.cat([output, conj], 1)
        if self.normalized:
            output = output * (padded_window.shape[0] ** (-0.5))
        return output

    def forward(
        self, input: torch.Tensor, ilens: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """STFT forward function.

        Args:
            input: (Batch, Nsamples) or (Batch, Nsample, Channels)
            ilens: (Batch), number of valid samples per utterance (optional).

        Returns:
            output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
            olens: (Batch) number of valid frames, or ``None`` when ``ilens``
                is not given.
        """
        bs = input.size(0)
        multi_channel = input.dim() == 3
        if multi_channel:
            # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
            input = input.transpose(1, 2).reshape(-1, input.size(1))

        # NOTE(kamo):
        #   The default behaviour of torch.stft is compatible with librosa.stft
        #   about padding and scaling.
        #   Note that it's different from scipy.signal.stft
        if self.window is not None:
            window_func = getattr(torch, f"{self.window}_window")
            window = window_func(
                self.win_length, dtype=input.dtype, device=input.device
            )
        else:
            window = None

        # For the compatibility of ARM devices, which do not support
        # torch.stft() due to the lack of MKL (on older pytorch versions),
        # there is an alternative replacement implementation with librosa.
        # pytorch >= 1.10.0 has native support for FFT and STFT on all cpu
        # targets including ARM, so the fallback is only taken on older
        # versions without MKL.
        if (
            input.is_cuda
            or torch.backends.mkl.is_available()
            or self._torch_has_native_fft()
        ):
            output = torch.stft(
                input,
                n_fft=self.n_fft,
                win_length=self.win_length,
                hop_length=self.hop_length,
                center=self.center,
                window=window,
                normalized=self.normalized,
                onesided=self.onesided,
                return_complex=True,
            )
            output = torch.view_as_real(output)
        else:
            output = self._librosa_stft(input, window)

        # output: (Batch, Freq, Frames, 2=real_imag)
        # -> (Batch, Frames, Freq, 2=real_imag)
        output = output.transpose(1, 2)
        if multi_channel:
            # output: (Batch * Channel, Frames, Freq, 2=real_imag)
            # -> (Batch, Frame, Channel, Freq, 2=real_imag)
            output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
                1, 2
            )

        if ilens is None:
            return output, None

        if self.center:
            pad = self.n_fft // 2
            ilens = ilens + 2 * pad
        olens = (
            torch.div(ilens - self.n_fft, self.hop_length, rounding_mode="trunc") + 1
        )

        # Zero out frames beyond each utterance's length. The (Batch, Frames)
        # mask gets trailing singleton axes so that it broadcasts over the
        # remaining (Channels,) Freq and real/imag axes of ``output``.
        frame_idx = torch.arange(output.size(1), device=output.device)
        pad_mask = frame_idx.unsqueeze(0) >= olens.to(output.device).unsqueeze(1)
        pad_mask = pad_mask[(...,) + (None,) * (output.dim() - 2)]
        output.masked_fill_(pad_mask, 0.0)
        return output, olens
# From https://github.com/espnet/espnet/blob/master/espnet2/tts/feats_extract/linear_spectrogram.py
class LinearSpectrogram(nn.Module):
    """Linear-frequency amplitude spectrogram extractor.

    Pipeline: waveform -> Stft -> magnitude of the complex spectrum.

    Args:
        n_fft: FFT size handed to the inner ``Stft``.
        win_length: Window length handed to the inner ``Stft``.
        hop_length: Hop size handed to the inner ``Stft``.
        window: Window name handed to the inner ``Stft``.
        center: Forwarded to the inner ``Stft``.
        normalized: Forwarded to the inner ``Stft``.
        onesided: Forwarded to the inner ``Stft``.
    """

    def __init__(
        self,
        n_fft: int = 1024,
        win_length: int = None,
        hop_length: int = 256,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        super().__init__()
        # Kept as given (win_length may stay None) for get_parameters().
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window

        stft_config = {
            "n_fft": n_fft,
            "win_length": win_length,
            "hop_length": hop_length,
            "window": window,
            "center": center,
            "normalized": normalized,
            "onesided": onesided,
        }
        self.stft = Stft(**stft_config)

    def output_size(self) -> int:
        """Return ``n_fft // 2 + 1``, the number of one-sided frequency bins."""
        return self.n_fft // 2 + 1

    def get_parameters(self) -> Dict[str, Any]:
        """Return the parameters required by Vocoder."""
        return {
            "n_fft": self.n_fft,
            "n_shift": self.hop_length,
            "win_length": self.win_length,
            "window": self.window,
        }

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the amplitude spectrogram.

        Args:
            input: Waveform batch passed unchanged to the inner ``Stft``.
            input_lengths: Optional per-utterance sample counts, also passed
                to the inner ``Stft``.

        Returns:
            A pair ``(amplitude, lengths)`` where ``amplitude`` has the shape
            of the STFT output without its trailing real/imag axis and
            ``lengths`` is whatever the inner ``Stft`` returned.
        """
        # Step 1: time domain -> time-frequency domain.
        spec, feats_lens = self.stft(input, input_lengths)

        assert spec.dim() >= 4, spec.shape
        # The last axis stores the (real, imag) pair of each complex value.
        assert spec.shape[-1] == 2, spec.shape

        # Step 2: complex STFT -> power -> amplitude; (..., F, 2) -> (..., F).
        real_part = spec[..., 0]
        imag_part = spec[..., 1]
        power = real_part**2 + imag_part**2
        # Clamp before the square root so the result never drops below 1e-5.
        amplitude = power.clamp(min=1.0e-10).sqrt()
        return amplitude, feats_lens
# From https://github.com/espnet/espnet/blob/master/espnet2/layers/log_mel.py
class LogMel(nn.Module):
    """Convert STFT to fbank feats.

    The arguments are the same as those of librosa.filters.mel.

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz).
            If `None`, use `fmin = 0`
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        log_base: base of the logarithm; `None` selects the natural log
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: Optional[float] = None,
        fmax: Optional[float] = None,
        htk: bool = False,
        log_base: Optional[float] = None,
    ):
        super().__init__()

        fmin = 0 if fmin is None else fmin
        fmax = fs / 2 if fmax is None else fmax
        _mel_options = dict(
            sr=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        self.mel_options = _mel_options
        self.log_base = log_base

        # Note(kamo): The mel matrix of librosa is different from kaldi.
        melmat = librosa.filters.mel(**_mel_options)
        # melmat: (D2, D1) -> (D1, D2)
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

    def extra_repr(self):
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())

    def forward(
        self,
        feat: torch.Tensor,
        ilens: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project a spectrogram onto the mel basis and take its logarithm.

        Args:
            feat: (B, T, D1) spectrogram (extra axes between T and D1 are
                allowed).
            ilens: (B) number of valid frames per utterance. When omitted,
                every utterance is treated as full length.

        Returns:
            logmel_feat: (B, T, D2) log-mel features, with frames beyond
                `ilens` set to zero.
            ilens: (B) the given lengths, or a tensor filled with T.
        """
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)
        # Keep the argument of the logarithm strictly positive.
        mel_feat = torch.clamp(mel_feat, min=1e-10)

        if self.log_base is None:
            logmel_feat = mel_feat.log()
        elif self.log_base == 2.0:
            logmel_feat = mel_feat.log2()
        elif self.log_base == 10.0:
            logmel_feat = mel_feat.log10()
        else:
            # Change of base; log_base is a Python number, so torch.log
            # (which only accepts tensors) cannot be used on it.
            logmel_feat = mel_feat.log() / float(np.log(self.log_base))

        # Zero padding
        if ilens is not None:
            # Build a (B, T) mask and append singleton axes so it broadcasts
            # over the feature axes of ``logmel_feat``.
            frame_idx = torch.arange(logmel_feat.size(1), device=logmel_feat.device)
            pad_mask = frame_idx.unsqueeze(0) >= ilens.to(
                logmel_feat.device
            ).unsqueeze(1)
            pad_mask = pad_mask[(...,) + (None,) * (logmel_feat.dim() - 2)]
            logmel_feat = logmel_feat.masked_fill(pad_mask, 0.0)
        else:
            ilens = feat.new_full(
                [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
            )
        return logmel_feat, ilens
# From https://github.com/espnet/espnet/blob/master/espnet2/tts/feats_extract/log_mel_fbank.py
class LogMelFbank(nn.Module):
    """Conventional TTS frontend producing log-mel filterbank features.

    Pipeline: waveform -> Stft -> amplitude spectrum -> LogMel.

    Args:
        fs: Sampling rate passed to ``LogMel``.
        n_fft: FFT size shared by ``Stft`` and ``LogMel``.
        win_length: Window length passed to ``Stft``.
        hop_length: Hop size passed to ``Stft``.
        window: Window name passed to ``Stft``.
        center: Forwarded to ``Stft``.
        normalized: Forwarded to ``Stft``.
        onesided: Forwarded to ``Stft``.
        n_mels: Number of mel bands passed to ``LogMel``.
        fmin: Lowest mel frequency passed to ``LogMel``.
        fmax: Highest mel frequency passed to ``LogMel``.
        htk: Forwarded to ``LogMel``.
        log_base: Logarithm base passed to ``LogMel``.
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 1024,
        win_length: int = None,
        hop_length: int = 256,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        n_mels: int = 80,
        fmin: Optional[int] = 80,
        fmax: Optional[int] = 7600,
        htk: bool = False,
        log_base: Optional[float] = 10.0,
    ):
        super().__init__()

        # Plain configuration values, reported back by get_parameters().
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.fmin = fmin
        self.fmax = fmax

        stft_config = {
            "n_fft": n_fft,
            "win_length": win_length,
            "hop_length": hop_length,
            "window": window,
            "center": center,
            "normalized": normalized,
            "onesided": onesided,
        }
        self.stft = Stft(**stft_config)

        logmel_config = {
            "fs": fs,
            "n_fft": n_fft,
            "n_mels": n_mels,
            "fmin": fmin,
            "fmax": fmax,
            "htk": htk,
            "log_base": log_base,
        }
        self.logmel = LogMel(**logmel_config)

    def output_size(self) -> int:
        """Return the number of mel bands."""
        return self.n_mels

    def get_parameters(self) -> Dict[str, Any]:
        """Return the parameters required by Vocoder"""
        return {
            "fs": self.fs,
            "n_fft": self.n_fft,
            "n_shift": self.hop_length,
            "window": self.window,
            "n_mels": self.n_mels,
            "win_length": self.win_length,
            "fmin": self.fmin,
            "fmax": self.fmax,
        }

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract log-mel filterbank features.

        Args:
            input: Waveform batch passed unchanged to the inner ``Stft``.
            input_lengths: Optional per-utterance sample counts, also passed
                to the inner ``Stft``.

        Returns:
            A pair ``(features, lengths)``: the first output of the inner
            ``LogMel`` and the frame lengths reported by the inner ``Stft``.
        """
        # Step 1: domain conversion, time -> time-frequency.
        spec, feats_lens = self.stft(input, input_lengths)

        assert spec.dim() >= 4, spec.shape
        # The last axis stores the (real, imag) pair of each complex value.
        assert spec.shape[-1] == 2, spec.shape

        # NOTE(kamo): We use different definition for log-spec between TTS and ASR
        #   TTS: log_10(abs(stft))
        #   ASR: log_e(power(stft))

        # Step 2: (..., F, 2) -> (..., F) amplitude spectrum.
        real_part = spec[..., 0]
        imag_part = spec[..., 1]
        power = real_part**2 + imag_part**2
        amplitude = power.clamp(min=1.0e-10).sqrt()

        # Step 3: mel projection + log; the lengths LogMel returns are ignored
        # in favour of the ones computed by the STFT.
        feats, _ = self.logmel(amplitude, feats_lens)
        return feats, feats_lens