# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Any, Dict, Optional, Tuple

import librosa
import numpy as np
import torch
from torch import nn

from icefall.utils import make_pad_mask


def _make_pad_mask_like(lengths: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
    """Build a padding mask for dim 1 of ``like`` that broadcasts against it.

    Args:
      lengths:
        1-D tensor of shape (Batch,) holding, for each batch entry, the number
        of valid positions along dim 1 of ``like``.
      like:
        Tensor of shape (Batch, T, ...) the mask is meant for.

    Returns:
      Bool tensor of shape (Batch, T, 1, ..., 1), with the same rank as
      ``like``. It is True at padded positions, i.e. where ``t >= lengths[b]``.
    """
    positions = torch.arange(like.size(1), device=like.device)
    mask = positions.unsqueeze(0) >= lengths.to(like.device).unsqueeze(1)
    # Append singleton dims so the (Batch, T) mask broadcasts over the trailing
    # feature dims instead of being right-aligned with them.
    return mask.view(*mask.shape, *([1] * (like.dim() - 2)))


def _stft_to_amplitude(input_stft: torch.Tensor) -> torch.Tensor:
    """Return the magnitude of an STFT whose last dim holds (real, imag).

    input_stft: (..., F, 2) -> (..., F). The power is clamped to 1e-10 before
    the square root, so the amplitude is at least 1e-5. This keeps a later
    log finite and avoids the unbounded gradient of sqrt at 0.
    """
    input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
    return torch.sqrt(torch.clamp(input_power, min=1.0e-10))


# From https://github.com/espnet/espnet/blob/master/espnet2/layers/stft.py
class Stft(nn.Module):
    """Short-time Fourier transform with real/imag parts stacked in the last dim.

    Uses ``torch.stft`` where available, and falls back to ``librosa.stft``
    (inference only) on CPUs without MKL.
    """

    def __init__(
        self,
        n_fft: int = 512,
        win_length: Optional[int] = None,
        hop_length: int = 128,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        """
        Args:
          n_fft: FFT size.
          win_length: Window length; ``n_fft`` is used when None.
          hop_length: Number of samples between successive frames.
          window: Name of a torch window function (``torch.<window>_window``),
            or None for a rectangular (all-ones) window.
          center: Whether the signal is padded so that frames are centered.
          normalized: Whether to normalize the STFT (see ``torch.stft``).
          onesided: Whether to return only the non-negative frequencies.

        Raises:
          ValueError: If ``torch.<window>_window`` does not exist.
        """
        super().__init__()
        self.n_fft = n_fft
        if win_length is None:
            self.win_length = n_fft
        else:
            self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        if window is not None and not hasattr(torch, f"{window}_window"):
            raise ValueError(f"{window} window is not implemented")
        self.window = window

    def extra_repr(self):
        return (
            f"n_fft={self.n_fft}, "
            f"win_length={self.win_length}, "
            f"hop_length={self.hop_length}, "
            f"center={self.center}, "
            f"normalized={self.normalized}, "
            f"onesided={self.onesided}"
        )

    def forward(
        self, input: torch.Tensor, ilens: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """STFT forward function.

        Args:
          input: (Batch, Nsamples) or (Batch, Nsamples, Channels)
          ilens: (Batch), number of valid samples per batch entry.

        Returns:
          output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2).
            When ``ilens`` is given, frames past each entry's valid length
            are set to 0.
          olens: (Batch) number of valid frames, or None if ``ilens`` is None.
        """
        bs = input.size(0)
        if input.dim() == 3:
            multi_channel = True
            # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
            input = input.transpose(1, 2).reshape(-1, input.size(1))
        else:
            multi_channel = False

        if self.window is not None:
            window_func = getattr(torch, f"{self.window}_window")
            window = window_func(
                self.win_length, dtype=input.dtype, device=input.device
            )
        else:
            window = None

        # NOTE(kamo):
        # The default behaviour of torch.stft is compatible with librosa.stft
        # about padding and scaling.
        # Note that it's different from scipy.signal.stft
        #
        # For the compatibility of ARM devices, which do not support
        # torch.stft() due to the lack of MKL (on older pytorch versions),
        # there is an alternative replacement implementation with librosa.
        # Note: pytorch >= 1.10.0 now has native support for FFT and STFT
        # on all cpu targets including ARM.
        if input.is_cuda or torch.backends.mkl.is_available():
            output = torch.stft(
                input,
                n_fft=self.n_fft,
                win_length=self.win_length,
                hop_length=self.hop_length,
                center=self.center,
                window=window,
                normalized=self.normalized,
                onesided=self.onesided,
                return_complex=True,
            )
            # (Batch * Channels, Freq, Frames) complex
            # -> (Batch * Channels, Freq, Frames, 2=real_imag)
            output = torch.view_as_real(output)
        else:
            output = self._librosa_stft(input, window)

        # output: (Batch, Freq, Frames, 2=real_imag)
        # -> (Batch, Frames, Freq, 2=real_imag)
        output = output.transpose(1, 2)
        if multi_channel:
            # output: (Batch * Channel, Frames, Freq, 2=real_imag)
            # -> (Batch, Frame, Channel, Freq, 2=real_imag)
            output = output.view(
                bs, -1, output.size(1), output.size(2), 2
            ).transpose(1, 2)

        if ilens is None:
            return output, None

        if self.center:
            pad = self.n_fft // 2
            ilens = ilens + 2 * pad
        olens = (
            torch.div(ilens - self.n_fft, self.hop_length, rounding_mode="trunc")
            + 1
        )
        # The mask is (Batch, Frames, 1, ...) so it lines up with the frame axis
        # and covers exactly output.size(1) frames, even when max(olens) is
        # smaller than the padded frame count.
        output.masked_fill_(_make_pad_mask_like(olens, output), 0.0)
        return output, olens

    def _librosa_stft(
        self, input: torch.Tensor, window: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Compute the STFT with librosa, for devices where torch.stft is unusable.

        Args:
          input: (Batch, Nsamples) tensor; only reached when the input is not
            on CUDA, so it is presumably on CPU.
          window: Window of length ``self.win_length``, or None for rectangular.

        Returns:
          (Batch, Freq, Frames, 2=real_imag), the same layout as
          ``torch.view_as_real(torch.stft(...))``.

        Raises:
          NotImplementedError: In training mode, since librosa has no autograd.
        """
        if self.training:
            raise NotImplementedError(
                "stft is implemented with librosa on this device, which does not "
                "support the training mode."
            )

        # torch.stft treats window=None as a rectangular window of win_length
        # ones; build that explicitly so both code paths agree.
        if window is None:
            window = torch.ones(
                self.win_length, dtype=input.dtype, device=input.device
            )

        # librosa does not support a win_length that is < n_fft, so pad the
        # window to n_fft on both sides (the same centering torch.stft uses)
        # and hand librosa a numpy array, which is what it accepts.
        n_pad_left = (self.n_fft - window.shape[0]) // 2
        n_pad_right = self.n_fft - window.shape[0] - n_pad_left
        padded_window = torch.cat(
            [window.new_zeros(n_pad_left), window, window.new_zeros(n_pad_right)],
            0,
        ).numpy()

        frames = []
        # iterate over instances in a batch
        for instance in input:
            stft = librosa.stft(
                instance.detach().numpy(),
                n_fft=self.n_fft,
                win_length=self.n_fft,
                hop_length=self.hop_length,
                center=self.center,
                window=padded_window,
                pad_mode="reflect",
            )
            frames.append(torch.tensor(np.stack([stft.real, stft.imag], -1)))
        output = torch.stack(frames, 0)

        if not self.onesided:
            # Rebuild the negative frequencies as the complex conjugate of
            # bins 1..len_conj, in reverse order.
            len_conj = self.n_fft - output.shape[1]
            conj = output[:, 1 : 1 + len_conj].flip(1)
            conj[..., -1] *= -1
            output = torch.cat([output, conj], 1)
        if self.normalized:
            output = output * (padded_window.shape[0] ** (-0.5))
        return output


# From https://github.com/espnet/espnet/blob/master/espnet2/tts/feats_extract/linear_spectrogram.py
class LinearSpectrogram(nn.Module):
    """Linear amplitude spectrogram.

    Stft -> amplitude-spec
    """

    def __init__(
        self,
        n_fft: int = 1024,
        win_length: Optional[int] = None,
        hop_length: int = 256,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )

    def output_size(self) -> int:
        # Number of one-sided frequency bins.
        return self.n_fft // 2 + 1

    def get_parameters(self) -> Dict[str, Any]:
        """Return the parameters required by Vocoder."""
        return dict(
            n_fft=self.n_fft,
            n_shift=self.hop_length,
            win_length=self.win_length,
            window=self.window,
        )

    def forward(
        self, input: torch.Tensor, input_lengths: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute the amplitude spectrogram.

        Args:
          input: (Batch, Nsamples) or (Batch, Nsamples, Channels) waveforms.
          input_lengths: (Batch) number of valid samples, or None.

        Returns:
          The amplitude (Batch, Frames, [Channels,] Freq) and the frame lengths
          (None if ``input_lengths`` is None).
        """
        # 1. Stft: time -> time-freq
        input_stft, feats_lens = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        # "2" refers to the real/imag parts of Complex
        assert input_stft.shape[-1] == 2, input_stft.shape

        # STFT -> Power spectrum -> Amp spectrum
        # input_stft: (..., F, 2) -> (..., F)
        input_amp = _stft_to_amplitude(input_stft)
        return input_amp, feats_lens


# From https://github.com/espnet/espnet/blob/master/espnet2/layers/log_mel.py
class LogMel(nn.Module):
    """Convert STFT to fbank feats.

    The arguments are the same as those of librosa.filters.mel

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz).
            If `None`, use `fmin = 0`
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        log_base: base of the logarithm; `None` means the natural log.
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: Optional[float] = None,
        fmax: Optional[float] = None,
        htk: bool = False,
        log_base: Optional[float] = None,
    ):
        super().__init__()

        fmin = 0 if fmin is None else fmin
        fmax = fs / 2 if fmax is None else fmax
        _mel_options = dict(
            sr=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        self.mel_options = _mel_options
        self.log_base = log_base

        # Note(kamo): The mel matrix of librosa is different from kaldi.
        melmat = librosa.filters.mel(**_mel_options)
        # melmat: (D2, D1) -> (D1, D2)
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

    def extra_repr(self):
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())

    def forward(
        self,
        feat: torch.Tensor,
        ilens: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project a linear spectrogram onto the mel basis and take the log.

        Args:
          feat: (B, T, D1) spectrogram, with D1 == n_fft // 2 + 1.
          ilens: (B) number of valid frames, or None if all T frames are valid.

        Returns:
          The log-mel features (B, T, D2), with padded frames set to 0, and
          the frame lengths (all equal to T when ``ilens`` is None).
        """
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)
        mel_feat = torch.clamp(mel_feat, min=1e-10)

        if self.log_base is None:
            logmel_feat = mel_feat.log()
        elif self.log_base == 2.0:
            logmel_feat = mel_feat.log2()
        elif self.log_base == 10.0:
            logmel_feat = mel_feat.log10()
        else:
            # log_base is a Python number; torch.log only accepts tensors.
            logmel_feat = mel_feat.log() / math.log(self.log_base)

        # Zero padding
        if ilens is not None:
            logmel_feat = logmel_feat.masked_fill(
                _make_pad_mask_like(ilens, logmel_feat), 0.0
            )
        else:
            ilens = feat.new_full(
                [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
            )
        return logmel_feat, ilens


# From https://github.com/espnet/espnet/blob/master/espnet2/tts/feats_extract/log_mel_fbank.py
class LogMelFbank(nn.Module):
    """Conventional frontend structure for TTS.

    Stft -> amplitude-spec -> Log-Mel-Fbank
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 1024,
        win_length: Optional[int] = None,
        hop_length: int = 256,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        n_mels: int = 80,
        fmin: Optional[int] = 80,
        fmax: Optional[int] = 7600,
        htk: bool = False,
        log_base: Optional[float] = 10.0,
    ):
        super().__init__()

        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.fmin = fmin
        self.fmax = fmax

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )

        self.logmel = LogMel(
            fs=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
            log_base=log_base,
        )

    def output_size(self) -> int:
        return self.n_mels

    def get_parameters(self) -> Dict[str, Any]:
        """Return the parameters required by Vocoder"""
        return dict(
            fs=self.fs,
            n_fft=self.n_fft,
            n_shift=self.hop_length,
            window=self.window,
            n_mels=self.n_mels,
            win_length=self.win_length,
            fmin=self.fmin,
            fmax=self.fmax,
        )

    def forward(
        self, input: torch.Tensor, input_lengths: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute log-mel filterbank features from waveforms.

        Args:
          input: (Batch, Nsamples) or (Batch, Nsamples, Channels) waveforms.
          input_lengths: (Batch) number of valid samples, or None.

        Returns:
          The log-mel features (Batch, Frames, [Channels,] n_mels) and the frame
          lengths (None if ``input_lengths`` is None).
        """
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        input_stft, feats_lens = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        # "2" refers to the real/imag parts of Complex
        assert input_stft.shape[-1] == 2, input_stft.shape

        # NOTE(kamo): We use different definition for log-spec between TTS and ASR
        #   TTS: log_10(abs(stft))
        #   ASR: log_e(power(stft))

        # input_stft: (..., F, 2) -> (..., F)
        input_amp = _stft_to_amplitude(input_stft)
        input_feats, _ = self.logmel(input_amp, feats_lens)
        return input_feats, feats_lens