mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
add f5
This commit is contained in:
parent
ad966fb81d
commit
604ab6f6b3
@ -17,6 +17,7 @@ class MatchaFbankConfig:
|
||||
win_length: int
|
||||
f_min: float
|
||||
f_max: float
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
@register_extractor
|
||||
@ -46,7 +47,7 @@ class MatchaFbank(FeatureExtractor):
|
||||
f"Mismatched sampling rate: extractor expects {expected_sr}, "
|
||||
f"got {sampling_rate}"
|
||||
)
|
||||
samples = torch.from_numpy(samples)
|
||||
samples = torch.from_numpy(samples).to(self.device)
|
||||
assert samples.ndim == 2, samples.shape
|
||||
assert samples.shape[0] == 1, samples.shape
|
||||
|
||||
@ -81,7 +82,7 @@ class MatchaFbank(FeatureExtractor):
|
||||
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
|
||||
).squeeze(0)
|
||||
|
||||
return mel.numpy()
|
||||
return mel.cpu().numpy()
|
||||
|
||||
@property
|
||||
def frame_shift(self) -> Seconds:
|
||||
|
7
egs/wenetspeech4tts/TTS/f5-tts/model/__init__.py
Normal file
7
egs/wenetspeech4tts/TTS/f5-tts/model/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# from f5_tts.model.backbones.dit import DiT
|
||||
# from f5_tts.model.backbones.mmdit import MMDiT
|
||||
# from f5_tts.model.backbones.unett import UNetT
|
||||
# from f5_tts.model.cfm import CFM
|
||||
# from f5_tts.model.trainer import Trainer
|
||||
|
||||
# __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
|
326
egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py
Normal file
326
egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py
Normal file
@ -0,0 +1,326 @@
|
||||
"""
|
||||
ein notation:
|
||||
b - batch
|
||||
n - sequence
|
||||
nt - text sequence
|
||||
nw - raw wave length
|
||||
d - dimension
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from random import random
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from model.modules import MelSpec
|
||||
from model.utils import (
|
||||
default,
|
||||
exists,
|
||||
lens_to_mask,
|
||||
list_str_to_idx,
|
||||
list_str_to_tensor,
|
||||
mask_from_frac_lengths,
|
||||
)
|
||||
from torch import nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from torchdiffeq import odeint
|
||||
|
||||
|
||||
class CFM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
transformer: nn.Module,
|
||||
sigma=0.0,
|
||||
odeint_kwargs: dict = dict(
|
||||
# atol = 1e-5,
|
||||
# rtol = 1e-5,
|
||||
method="euler" # 'midpoint'
|
||||
),
|
||||
audio_drop_prob=0.3,
|
||||
cond_drop_prob=0.2,
|
||||
num_channels=None,
|
||||
mel_spec_module: nn.Module | None = None,
|
||||
mel_spec_kwargs: dict = dict(),
|
||||
frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
|
||||
vocab_char_map: dict[str:int] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.frac_lengths_mask = frac_lengths_mask
|
||||
|
||||
# mel spec
|
||||
self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
|
||||
num_channels = default(num_channels, self.mel_spec.n_mel_channels)
|
||||
self.num_channels = num_channels
|
||||
|
||||
# classifier-free guidance
|
||||
self.audio_drop_prob = audio_drop_prob
|
||||
self.cond_drop_prob = cond_drop_prob
|
||||
|
||||
# transformer
|
||||
self.transformer = transformer
|
||||
dim = transformer.dim
|
||||
self.dim = dim
|
||||
|
||||
# conditional flow related
|
||||
self.sigma = sigma
|
||||
|
||||
# sampling related
|
||||
self.odeint_kwargs = odeint_kwargs
|
||||
|
||||
# vocab map for tokenization
|
||||
self.vocab_char_map = vocab_char_map
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
cond: float["b n d"] | float["b nw"], # noqa: F722
|
||||
text: int["b nt"] | list[str], # noqa: F722
|
||||
duration: int | int["b"], # noqa: F821
|
||||
*,
|
||||
lens: int["b"] | None = None, # noqa: F821
|
||||
steps=32,
|
||||
cfg_strength=1.0,
|
||||
sway_sampling_coef=None,
|
||||
seed: int | None = None,
|
||||
max_duration=4096,
|
||||
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
|
||||
no_ref_audio=False,
|
||||
duplicate_test=False,
|
||||
t_inter=0.1,
|
||||
edit_mask=None,
|
||||
):
|
||||
self.eval()
|
||||
# raw wave
|
||||
|
||||
if cond.ndim == 2:
|
||||
cond = self.mel_spec(cond)
|
||||
cond = cond.permute(0, 2, 1)
|
||||
assert cond.shape[-1] == self.num_channels
|
||||
|
||||
cond = cond.to(next(self.parameters()).dtype)
|
||||
|
||||
batch, cond_seq_len, device = *cond.shape[:2], cond.device
|
||||
if not exists(lens):
|
||||
lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
|
||||
|
||||
# text
|
||||
|
||||
if isinstance(text, list):
|
||||
if exists(self.vocab_char_map):
|
||||
text = list_str_to_idx(text, self.vocab_char_map).to(device)
|
||||
else:
|
||||
text = list_str_to_tensor(text).to(device)
|
||||
assert text.shape[0] == batch
|
||||
|
||||
if exists(text):
|
||||
text_lens = (text != -1).sum(dim=-1)
|
||||
lens = torch.maximum(
|
||||
text_lens, lens
|
||||
) # make sure lengths are at least those of the text characters
|
||||
|
||||
# duration
|
||||
|
||||
cond_mask = lens_to_mask(lens)
|
||||
if edit_mask is not None:
|
||||
cond_mask = cond_mask & edit_mask
|
||||
|
||||
if isinstance(duration, int):
|
||||
duration = torch.full((batch,), duration, device=device, dtype=torch.long)
|
||||
|
||||
duration = torch.maximum(
|
||||
lens + 1, duration
|
||||
) # just add one token so something is generated
|
||||
duration = duration.clamp(max=max_duration)
|
||||
max_duration = duration.amax()
|
||||
|
||||
# duplicate test corner for inner time step oberservation
|
||||
if duplicate_test:
|
||||
test_cond = F.pad(
|
||||
cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0
|
||||
)
|
||||
|
||||
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
|
||||
cond_mask = F.pad(
|
||||
cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False
|
||||
)
|
||||
cond_mask = cond_mask.unsqueeze(-1)
|
||||
step_cond = torch.where(
|
||||
cond_mask, cond, torch.zeros_like(cond)
|
||||
) # allow direct control (cut cond audio) with lens passed in
|
||||
|
||||
if batch > 1:
|
||||
mask = lens_to_mask(duration)
|
||||
else: # save memory and speed up, as single inference need no mask currently
|
||||
mask = None
|
||||
|
||||
# test for no ref audio
|
||||
if no_ref_audio:
|
||||
cond = torch.zeros_like(cond)
|
||||
|
||||
# neural ode
|
||||
|
||||
def fn(t, x):
|
||||
# at each step, conditioning is fixed
|
||||
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
|
||||
|
||||
# predict flow
|
||||
pred = self.transformer(
|
||||
x=x,
|
||||
cond=step_cond,
|
||||
text=text,
|
||||
time=t,
|
||||
mask=mask,
|
||||
drop_audio_cond=False,
|
||||
drop_text=False,
|
||||
)
|
||||
if cfg_strength < 1e-5:
|
||||
return pred
|
||||
|
||||
null_pred = self.transformer(
|
||||
x=x,
|
||||
cond=step_cond,
|
||||
text=text,
|
||||
time=t,
|
||||
mask=mask,
|
||||
drop_audio_cond=True,
|
||||
drop_text=True,
|
||||
)
|
||||
return pred + (pred - null_pred) * cfg_strength
|
||||
|
||||
# noise input
|
||||
# to make sure batch inference result is same with different batch size, and for sure single inference
|
||||
# still some difference maybe due to convolutional layers
|
||||
y0 = []
|
||||
for dur in duration:
|
||||
if exists(seed):
|
||||
torch.manual_seed(seed)
|
||||
y0.append(
|
||||
torch.randn(
|
||||
dur, self.num_channels, device=self.device, dtype=step_cond.dtype
|
||||
)
|
||||
)
|
||||
y0 = pad_sequence(y0, padding_value=0, batch_first=True)
|
||||
|
||||
t_start = 0
|
||||
|
||||
# duplicate test corner for inner time step oberservation
|
||||
if duplicate_test:
|
||||
t_start = t_inter
|
||||
y0 = (1 - t_start) * y0 + t_start * test_cond
|
||||
steps = int(steps * (1 - t_start))
|
||||
|
||||
t = torch.linspace(
|
||||
t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype
|
||||
)
|
||||
if sway_sampling_coef is not None:
|
||||
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
||||
|
||||
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
|
||||
|
||||
sampled = trajectory[-1]
|
||||
out = sampled
|
||||
out = torch.where(cond_mask, cond, out)
|
||||
|
||||
if exists(vocoder):
|
||||
out = out.permute(0, 2, 1)
|
||||
out = vocoder(out)
|
||||
|
||||
return out, trajectory
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
|
||||
text: int["b nt"] | list[str], # noqa: F722
|
||||
*,
|
||||
lens: int["b"] | None = None, # noqa: F821
|
||||
noise_scheduler: str | None = None,
|
||||
):
|
||||
# handle raw wave
|
||||
if inp.ndim == 2:
|
||||
inp = self.mel_spec(inp)
|
||||
inp = inp.permute(0, 2, 1)
|
||||
assert inp.shape[-1] == self.num_channels
|
||||
|
||||
batch, seq_len, dtype, device, _σ1 = (
|
||||
*inp.shape[:2],
|
||||
inp.dtype,
|
||||
self.device,
|
||||
self.sigma,
|
||||
)
|
||||
|
||||
# handle text as string
|
||||
if isinstance(text, list):
|
||||
if exists(self.vocab_char_map):
|
||||
text = list_str_to_idx(text, self.vocab_char_map).to(device)
|
||||
else:
|
||||
text = list_str_to_tensor(text).to(device)
|
||||
assert text.shape[0] == batch
|
||||
|
||||
# lens and mask
|
||||
if not exists(lens):
|
||||
lens = torch.full((batch,), seq_len, device=device)
|
||||
|
||||
mask = lens_to_mask(
|
||||
lens, length=seq_len
|
||||
) # useless here, as collate_fn will pad to max length in batch
|
||||
|
||||
# get a random span to mask out for training conditionally
|
||||
frac_lengths = (
|
||||
torch.zeros((batch,), device=self.device)
|
||||
.float()
|
||||
.uniform_(*self.frac_lengths_mask)
|
||||
)
|
||||
rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
|
||||
|
||||
if exists(mask):
|
||||
rand_span_mask &= mask
|
||||
|
||||
# mel is x1
|
||||
x1 = inp
|
||||
|
||||
# x0 is gaussian noise
|
||||
x0 = torch.randn_like(x1)
|
||||
|
||||
# time step
|
||||
time = torch.rand((batch,), dtype=dtype, device=self.device)
|
||||
# TODO. noise_scheduler
|
||||
|
||||
# sample xt (φ_t(x) in the paper)
|
||||
t = time.unsqueeze(-1).unsqueeze(-1)
|
||||
φ = (1 - t) * x0 + t * x1
|
||||
flow = x1 - x0
|
||||
|
||||
# only predict what is within the random mask span for infilling
|
||||
cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
|
||||
|
||||
# transformer and cfg training with a drop rate
|
||||
drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
|
||||
if random() < self.cond_drop_prob: # p_uncond in voicebox paper
|
||||
drop_audio_cond = True
|
||||
drop_text = True
|
||||
else:
|
||||
drop_text = False
|
||||
|
||||
# if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
|
||||
# adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
|
||||
pred = self.transformer(
|
||||
x=φ,
|
||||
cond=cond,
|
||||
text=text,
|
||||
time=time,
|
||||
drop_audio_cond=drop_audio_cond,
|
||||
drop_text=drop_text,
|
||||
)
|
||||
|
||||
# flow matching loss
|
||||
loss = F.mse_loss(pred, flow, reduction="none")
|
||||
loss = loss[rand_span_mask]
|
||||
|
||||
return loss.mean(), cond, pred
|
728
egs/wenetspeech4tts/TTS/f5-tts/model/modules.py
Normal file
728
egs/wenetspeech4tts/TTS/f5-tts/model/modules.py
Normal file
@ -0,0 +1,728 @@
|
||||
"""
|
||||
ein notation:
|
||||
b - batch
|
||||
n - sequence
|
||||
nt - text sequence
|
||||
nw - raw wave length
|
||||
d - dimension
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
from torch import nn
|
||||
from x_transformers.x_transformers import apply_rotary_pos_emb
|
||||
|
||||
# raw wav to mel spec
|
||||
|
||||
|
||||
mel_basis_cache = {}
|
||||
hann_window_cache = {}
|
||||
|
||||
|
||||
def get_bigvgan_mel_spectrogram(
|
||||
waveform,
|
||||
n_fft=1024,
|
||||
n_mel_channels=100,
|
||||
target_sample_rate=24000,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
fmin=0,
|
||||
fmax=None,
|
||||
center=False,
|
||||
): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
|
||||
device = waveform.device
|
||||
key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
|
||||
|
||||
if key not in mel_basis_cache:
|
||||
mel = librosa_mel_fn(
|
||||
sr=target_sample_rate,
|
||||
n_fft=n_fft,
|
||||
n_mels=n_mel_channels,
|
||||
fmin=fmin,
|
||||
fmax=fmax,
|
||||
)
|
||||
mel_basis_cache[key] = (
|
||||
torch.from_numpy(mel).float().to(device)
|
||||
) # TODO: why they need .float()?
|
||||
hann_window_cache[key] = torch.hann_window(win_length).to(device)
|
||||
|
||||
mel_basis = mel_basis_cache[key]
|
||||
hann_window = hann_window_cache[key]
|
||||
|
||||
padding = (n_fft - hop_length) // 2
|
||||
waveform = torch.nn.functional.pad(
|
||||
waveform.unsqueeze(1), (padding, padding), mode="reflect"
|
||||
).squeeze(1)
|
||||
|
||||
spec = torch.stft(
|
||||
waveform,
|
||||
n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
window=hann_window,
|
||||
center=center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=True,
|
||||
)
|
||||
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
|
||||
|
||||
mel_spec = torch.matmul(mel_basis, spec)
|
||||
mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
|
||||
|
||||
return mel_spec
|
||||
|
||||
|
||||
def get_vocos_mel_spectrogram(
|
||||
waveform,
|
||||
n_fft=1024,
|
||||
n_mel_channels=100,
|
||||
target_sample_rate=24000,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
):
|
||||
mel_stft = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=target_sample_rate,
|
||||
n_fft=n_fft,
|
||||
win_length=win_length,
|
||||
hop_length=hop_length,
|
||||
n_mels=n_mel_channels,
|
||||
power=1,
|
||||
center=True,
|
||||
normalized=False,
|
||||
norm=None,
|
||||
).to(waveform.device)
|
||||
if len(waveform.shape) == 3:
|
||||
waveform = waveform.squeeze(1) # 'b 1 nw -> b nw'
|
||||
|
||||
assert len(waveform.shape) == 2
|
||||
|
||||
mel = mel_stft(waveform)
|
||||
mel = mel.clamp(min=1e-5).log()
|
||||
return mel
|
||||
|
||||
|
||||
class MelSpec(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_fft=1024,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
n_mel_channels=100,
|
||||
target_sample_rate=24_000,
|
||||
mel_spec_type="vocos",
|
||||
):
|
||||
super().__init__()
|
||||
assert mel_spec_type in ["vocos", "bigvgan"], print(
|
||||
"We only support two extract mel backend: vocos or bigvgan"
|
||||
)
|
||||
|
||||
self.n_fft = n_fft
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
self.n_mel_channels = n_mel_channels
|
||||
self.target_sample_rate = target_sample_rate
|
||||
|
||||
if mel_spec_type == "vocos":
|
||||
self.extractor = get_vocos_mel_spectrogram
|
||||
elif mel_spec_type == "bigvgan":
|
||||
self.extractor = get_bigvgan_mel_spectrogram
|
||||
|
||||
self.register_buffer("dummy", torch.tensor(0), persistent=False)
|
||||
|
||||
def forward(self, wav):
|
||||
if self.dummy.device != wav.device:
|
||||
self.to(wav.device)
|
||||
|
||||
mel = self.extractor(
|
||||
waveform=wav,
|
||||
n_fft=self.n_fft,
|
||||
n_mel_channels=self.n_mel_channels,
|
||||
target_sample_rate=self.target_sample_rate,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
)
|
||||
|
||||
return mel
|
||||
|
||||
|
||||
# sinusoidal position embedding
|
||||
|
||||
|
||||
class SinusPositionEmbedding(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x, scale=1000):
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
|
||||
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
# convolutional position embedding
|
||||
|
||||
|
||||
class ConvPositionEmbedding(nn.Module):
|
||||
def __init__(self, dim, kernel_size=31, groups=16):
|
||||
super().__init__()
|
||||
assert kernel_size % 2 != 0
|
||||
self.conv1d = nn.Sequential(
|
||||
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
|
||||
nn.Mish(),
|
||||
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
|
||||
if mask is not None:
|
||||
mask = mask[..., None]
|
||||
x = x.masked_fill(~mask, 0.0)
|
||||
|
||||
x = x.permute(0, 2, 1)
|
||||
x = self.conv1d(x)
|
||||
out = x.permute(0, 2, 1)
|
||||
|
||||
if mask is not None:
|
||||
out = out.masked_fill(~mask, 0.0)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
# rotary positional embedding related
|
||||
|
||||
|
||||
def precompute_freqs_cis(
|
||||
dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0
|
||||
):
|
||||
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||
# has some connection to NTK literature
|
||||
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
|
||||
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device) # type: ignore
|
||||
freqs = torch.outer(t, freqs).float() # type: ignore
|
||||
freqs_cos = torch.cos(freqs) # real part
|
||||
freqs_sin = torch.sin(freqs) # imaginary part
|
||||
return torch.cat([freqs_cos, freqs_sin], dim=-1)
|
||||
|
||||
|
||||
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
|
||||
# length = length if isinstance(length, int) else length.max()
|
||||
scale = scale * torch.ones_like(
|
||||
start, dtype=torch.float32
|
||||
) # in case scale is a scalar
|
||||
pos = (
|
||||
start.unsqueeze(1)
|
||||
+ (
|
||||
torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0)
|
||||
* scale.unsqueeze(1)
|
||||
).long()
|
||||
)
|
||||
# avoid extra long error.
|
||||
pos = torch.where(pos < max_pos, pos, max_pos - 1)
|
||||
return pos
|
||||
|
||||
|
||||
# Global Response Normalization layer (Instance Normalization ?)
|
||||
|
||||
|
||||
class GRN(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
|
||||
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
|
||||
|
||||
def forward(self, x):
|
||||
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
|
||||
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
||||
return self.gamma * (x * Nx) + self.beta + x
|
||||
|
||||
|
||||
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
|
||||
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
|
||||
|
||||
|
||||
class ConvNeXtV2Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
intermediate_dim: int,
|
||||
dilation: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
padding = (dilation * (7 - 1)) // 2
|
||||
self.dwconv = nn.Conv1d(
|
||||
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
|
||||
) # depthwise conv
|
||||
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
||||
self.pwconv1 = nn.Linear(
|
||||
dim, intermediate_dim
|
||||
) # pointwise/1x1 convs, implemented with linear layers
|
||||
self.act = nn.GELU()
|
||||
self.grn = GRN(intermediate_dim)
|
||||
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
residual = x
|
||||
x = x.transpose(1, 2) # b n d -> b d n
|
||||
x = self.dwconv(x)
|
||||
x = x.transpose(1, 2) # b d n -> b n d
|
||||
x = self.norm(x)
|
||||
x = self.pwconv1(x)
|
||||
x = self.act(x)
|
||||
x = self.grn(x)
|
||||
x = self.pwconv2(x)
|
||||
return residual + x
|
||||
|
||||
|
||||
# AdaLayerNormZero
|
||||
# return with modulated x for attn input, and params for later mlp modulation
|
||||
|
||||
|
||||
class AdaLayerNormZero(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(dim, dim * 6)
|
||||
|
||||
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
def forward(self, x, emb=None):
|
||||
emb = self.linear(self.silu(emb))
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(
|
||||
emb, 6, dim=1
|
||||
)
|
||||
|
||||
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
# AdaLayerNormZero for final layer
|
||||
# return only with modulated x for attn input, cuz no more mlp modulation
|
||||
|
||||
|
||||
class AdaLayerNormZero_Final(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(dim, dim * 2)
|
||||
|
||||
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
def forward(self, x, emb):
|
||||
emb = self.linear(self.silu(emb))
|
||||
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
|
||||
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
return x
|
||||
|
||||
|
||||
# FeedForward
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
|
||||
activation = nn.GELU(approximate=approximate)
|
||||
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
|
||||
self.ff = nn.Sequential(
|
||||
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.ff(x)
|
||||
|
||||
|
||||
# Attention with possible joint part
|
||||
# modified from diffusers/src/diffusers/models/attention_processor.py
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
processor: JointAttnProcessor | AttnProcessor,
|
||||
dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
context_dim: Optional[int] = None, # if not None -> joint attention
|
||||
context_pre_only=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
self.processor = processor
|
||||
|
||||
self.dim = dim
|
||||
self.heads = heads
|
||||
self.inner_dim = dim_head * heads
|
||||
self.dropout = dropout
|
||||
|
||||
self.context_dim = context_dim
|
||||
self.context_pre_only = context_pre_only
|
||||
|
||||
self.to_q = nn.Linear(dim, self.inner_dim)
|
||||
self.to_k = nn.Linear(dim, self.inner_dim)
|
||||
self.to_v = nn.Linear(dim, self.inner_dim)
|
||||
|
||||
if self.context_dim is not None:
|
||||
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
|
||||
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
|
||||
if self.context_pre_only is not None:
|
||||
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(self.inner_dim, dim))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
|
||||
if self.context_pre_only is not None and not self.context_pre_only:
|
||||
self.to_out_c = nn.Linear(self.inner_dim, dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"], # noised input x # noqa: F722
|
||||
c: float["b n d"] = None, # context c # noqa: F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
rope=None, # rotary position embedding for x
|
||||
c_rope=None, # rotary position embedding for c
|
||||
) -> torch.Tensor:
|
||||
if c is not None:
|
||||
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
|
||||
else:
|
||||
return self.processor(self, x, mask=mask, rope=rope)
|
||||
|
||||
|
||||
# Attention processor
|
||||
|
||||
|
||||
class AttnProcessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
x: float["b n d"], # noised input x # noqa: F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
rope=None, # rotary position embedding
|
||||
) -> torch.FloatTensor:
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query = attn.to_q(x)
|
||||
key = attn.to_k(x)
|
||||
value = attn.to_v(x)
|
||||
|
||||
# apply rotary position embedding
|
||||
if rope is not None:
|
||||
freqs, xpos_scale = rope
|
||||
q_xpos_scale, k_xpos_scale = (
|
||||
(xpos_scale, xpos_scale**-1.0)
|
||||
if xpos_scale is not None
|
||||
else (1.0, 1.0)
|
||||
)
|
||||
|
||||
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
||||
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
||||
|
||||
# attention
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||
if mask is not None:
|
||||
attn_mask = mask
|
||||
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
||||
attn_mask = attn_mask.expand(
|
||||
batch_size, attn.heads, query.shape[-2], key.shape[-2]
|
||||
)
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
x = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
x = x.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
x = attn.to_out[0](x)
|
||||
# dropout
|
||||
x = attn.to_out[1](x)
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(-1)
|
||||
x = x.masked_fill(~mask, 0.0)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# Joint Attention processor for MM-DiT
|
||||
# modified from diffusers/src/diffusers/models/attention_processor.py
|
||||
|
||||
|
||||
class JointAttnProcessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
x: float["b n d"], # noised input x # noqa: F722
|
||||
c: float["b nt d"] = None, # context c, here text # noqa: F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
rope=None, # rotary position embedding for x
|
||||
c_rope=None, # rotary position embedding for c
|
||||
) -> torch.FloatTensor:
|
||||
residual = x
|
||||
|
||||
batch_size = c.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query = attn.to_q(x)
|
||||
key = attn.to_k(x)
|
||||
value = attn.to_v(x)
|
||||
|
||||
# `context` projections.
|
||||
c_query = attn.to_q_c(c)
|
||||
c_key = attn.to_k_c(c)
|
||||
c_value = attn.to_v_c(c)
|
||||
|
||||
# apply rope for context and noised input independently
|
||||
if rope is not None:
|
||||
freqs, xpos_scale = rope
|
||||
q_xpos_scale, k_xpos_scale = (
|
||||
(xpos_scale, xpos_scale**-1.0)
|
||||
if xpos_scale is not None
|
||||
else (1.0, 1.0)
|
||||
)
|
||||
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
||||
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
||||
if c_rope is not None:
|
||||
freqs, xpos_scale = c_rope
|
||||
q_xpos_scale, k_xpos_scale = (
|
||||
(xpos_scale, xpos_scale**-1.0)
|
||||
if xpos_scale is not None
|
||||
else (1.0, 1.0)
|
||||
)
|
||||
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
|
||||
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
|
||||
|
||||
# attention
|
||||
query = torch.cat([query, c_query], dim=1)
|
||||
key = torch.cat([key, c_key], dim=1)
|
||||
value = torch.cat([value, c_value], dim=1)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||
if mask is not None:
|
||||
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
|
||||
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
||||
attn_mask = attn_mask.expand(
|
||||
batch_size, attn.heads, query.shape[-2], key.shape[-2]
|
||||
)
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
x = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
x = x.to(query.dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
x, c = (
|
||||
x[:, : residual.shape[1]],
|
||||
x[:, residual.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
x = attn.to_out[0](x)
|
||||
# dropout
|
||||
x = attn.to_out[1](x)
|
||||
if not attn.context_pre_only:
|
||||
c = attn.to_out_c(c)
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(-1)
|
||||
x = x.masked_fill(~mask, 0.0)
|
||||
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
|
||||
|
||||
return x, c
|
||||
|
||||
|
||||
# DiT Block
|
||||
|
||||
|
||||
class DiTBlock(nn.Module):
|
||||
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
|
||||
super().__init__()
|
||||
|
||||
self.attn_norm = AdaLayerNormZero(dim)
|
||||
self.attn = Attention(
|
||||
processor=AttnProcessor(),
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = FeedForward(
|
||||
dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
|
||||
)
|
||||
|
||||
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
|
||||
# pre-norm & modulation for attention input
|
||||
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
|
||||
|
||||
# attention
|
||||
attn_output = self.attn(x=norm, mask=mask, rope=rope)
|
||||
|
||||
# process attention output for input x
|
||||
x = x + gate_msa.unsqueeze(1) * attn_output
|
||||
|
||||
norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
ff_output = self.ff(norm)
|
||||
x = x + gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# MMDiT Block https://arxiv.org/abs/2403.03206
|
||||
|
||||
|
||||
class MMDiTBlock(nn.Module):
|
||||
r"""
|
||||
modified from diffusers/src/diffusers/models/attention.py
|
||||
|
||||
notes.
|
||||
_c: context related. text, cond, etc. (left part in sd3 fig2.b)
|
||||
_x: noised input related. (right part)
|
||||
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.context_pre_only = context_pre_only
|
||||
|
||||
self.attn_norm_c = (
|
||||
AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
|
||||
)
|
||||
self.attn_norm_x = AdaLayerNormZero(dim)
|
||||
self.attn = Attention(
|
||||
processor=JointAttnProcessor(),
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
dropout=dropout,
|
||||
context_dim=dim,
|
||||
context_pre_only=context_pre_only,
|
||||
)
|
||||
|
||||
if not context_pre_only:
|
||||
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_c = FeedForward(
|
||||
dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
|
||||
)
|
||||
else:
|
||||
self.ff_norm_c = None
|
||||
self.ff_c = None
|
||||
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_x = FeedForward(
|
||||
dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x, c, t, mask=None, rope=None, c_rope=None
|
||||
): # x: noised input, c: context, t: time embedding
|
||||
# pre-norm & modulation for attention input
|
||||
if self.context_pre_only:
|
||||
norm_c = self.attn_norm_c(c, t)
|
||||
else:
|
||||
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(
|
||||
c, emb=t
|
||||
)
|
||||
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(
|
||||
x, emb=t
|
||||
)
|
||||
|
||||
# attention
|
||||
x_attn_output, c_attn_output = self.attn(
|
||||
x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope
|
||||
)
|
||||
|
||||
# process attention output for context c
|
||||
if self.context_pre_only:
|
||||
c = None
|
||||
else: # if not last layer
|
||||
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
|
||||
|
||||
norm_c = (
|
||||
self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
)
|
||||
c_ff_output = self.ff_c(norm_c)
|
||||
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
|
||||
|
||||
# process attention output for input x
|
||||
x = x + x_gate_msa.unsqueeze(1) * x_attn_output
|
||||
|
||||
norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
|
||||
x_ff_output = self.ff_x(norm_x)
|
||||
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
|
||||
|
||||
return c, x
|
||||
|
||||
|
||||
# time step conditioning embedding
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, dim, freq_embed_dim=256):
|
||||
super().__init__()
|
||||
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
||||
self.time_mlp = nn.Sequential(
|
||||
nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
|
||||
)
|
||||
|
||||
def forward(self, timestep: float["b"]): # noqa: F821
|
||||
time_hidden = self.time_embed(timestep)
|
||||
time_hidden = time_hidden.to(timestep.dtype)
|
||||
time = self.time_mlp(time_hidden) # b d
|
||||
return time
|
1192
egs/wenetspeech4tts/TTS/f5-tts/train.py
Executable file
1192
egs/wenetspeech4tts/TTS/f5-tts/train.py
Executable file
File diff suppressed because it is too large
Load Diff
346
egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py
Normal file
346
egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py
Normal file
@ -0,0 +1,346 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from fbank import MatchaFbank, MatchaFbankConfig
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
PrecomputedFeatures,
|
||||
SimpleCutSampler,
|
||||
SpeechSynthesisDataset,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
AudioSamples,
|
||||
OnTheFlyFeatures,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class TtsDataModule:
|
||||
"""
|
||||
DataModule for tts experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="TTS data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--input-strategy",
|
||||
type=str,
|
||||
default="PrecomputedFeatures",
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefix",
|
||||
type=str,
|
||||
default="wenetspeech4tts",
|
||||
help="prefix of the manifest file",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
logging.info("About to create train dataset")
|
||||
train = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = 22050
|
||||
config = MatchaFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=80,
|
||||
sampling_rate=sampling_rate,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
)
|
||||
train = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
buffer_size=self.args.num_buckets * 2000,
|
||||
shuffle_buffer_size=self.args.num_buckets * 5000,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=True,
|
||||
pin_memory=True,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = 22050
|
||||
config = MatchaFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=80,
|
||||
sampling_rate=sampling_rate,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
)
|
||||
validate = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
num_buckets=self.args.num_buckets,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create valid dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=True,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.info("About to create test dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = 22050
|
||||
config = MatchaFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=80,
|
||||
sampling_rate=sampling_rate,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
)
|
||||
test = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
test = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
test_sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
num_buckets=self.args.num_buckets,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=test_sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / f"{self.args.prefix}_cuts_train.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def valid_cuts(self) -> CutSet:
|
||||
logging.info("About to get validation cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / f"{self.args.prefix}_cuts_valid.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / f"{self.args.prefix}_cuts_test.jsonl.gz"
|
||||
)
|
122
egs/wenetspeech4tts/TTS/local/audio.py
Normal file
122
egs/wenetspeech4tts/TTS/local/audio.py
Normal file
@ -0,0 +1,122 @@
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import math
|
||||
import os
|
||||
import pathlib
|
||||
import random
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
from tqdm import tqdm
|
||||
|
||||
# from env import AttrDict
|
||||
|
||||
MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
|
||||
|
||||
|
||||
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
||||
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression(x, C=1):
|
||||
return np.exp(x) / C
|
||||
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression_torch(x, C=1):
|
||||
return torch.exp(x) / C
|
||||
|
||||
|
||||
def spectral_normalize_torch(magnitudes):
|
||||
return dynamic_range_compression_torch(magnitudes)
|
||||
|
||||
|
||||
def spectral_de_normalize_torch(magnitudes):
|
||||
return dynamic_range_decompression_torch(magnitudes)
|
||||
|
||||
|
||||
mel_basis_cache = {}
|
||||
hann_window_cache = {}
|
||||
|
||||
|
||||
def mel_spectrogram(
|
||||
y: torch.Tensor,
|
||||
n_fft: int = 1024,
|
||||
num_mels: int = 100,
|
||||
sampling_rate: int = 24_000,
|
||||
hop_size: int = 256,
|
||||
win_size: int = 1024,
|
||||
fmin: int = 0,
|
||||
fmax: int = None,
|
||||
center: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the mel spectrogram of an input signal.
|
||||
This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Input signal.
|
||||
n_fft (int): FFT size.
|
||||
num_mels (int): Number of mel bins.
|
||||
sampling_rate (int): Sampling rate of the input signal.
|
||||
hop_size (int): Hop size for STFT.
|
||||
win_size (int): Window size for STFT.
|
||||
fmin (int): Minimum frequency for mel filterbank.
|
||||
fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
|
||||
center (bool): Whether to pad the input to center the frames. Default is False.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Mel spectrogram.
|
||||
"""
|
||||
if torch.min(y) < -1.0:
|
||||
print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
|
||||
if torch.max(y) > 1.0:
|
||||
print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
|
||||
|
||||
device = y.device
|
||||
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
|
||||
|
||||
if key not in mel_basis_cache:
|
||||
mel = librosa_mel_fn(
|
||||
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
||||
)
|
||||
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
|
||||
hann_window_cache[key] = torch.hann_window(win_size).to(device)
|
||||
|
||||
mel_basis = mel_basis_cache[key]
|
||||
hann_window = hann_window_cache[key]
|
||||
|
||||
padding = (n_fft - hop_size) // 2
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1), (padding, padding), mode="reflect"
|
||||
).squeeze(1)
|
||||
|
||||
spec = torch.stft(
|
||||
y,
|
||||
n_fft,
|
||||
hop_length=hop_size,
|
||||
win_length=win_size,
|
||||
window=hann_window,
|
||||
center=center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=True,
|
||||
)
|
||||
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
|
||||
|
||||
mel_spec = torch.matmul(mel_basis, spec)
|
||||
mel_spec = spectral_normalize_torch(mel_spec)
|
||||
|
||||
return mel_spec
|
218
egs/wenetspeech4tts/TTS/local/compute_mel_feat.py
Executable file
218
egs/wenetspeech4tts/TTS/local/compute_mel_feat.py
Executable file
@ -0,0 +1,218 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the LJSpeech dataset.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from fbank import MatchaFbank, MatchaFbankConfig
|
||||
from lhotse import CutSet, LilcomChunkyWriter
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-jobs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--src-dir",
|
||||
type=Path,
|
||||
default=Path("data/manifests"),
|
||||
help="Path to the manifest files",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to the tokenized files",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dataset-parts",
|
||||
type=str,
|
||||
default="Basic",
|
||||
help="Space separated dataset parts",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prefix",
|
||||
type=str,
|
||||
default="wenetspeech4tts",
|
||||
help="prefix of the manifest file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--suffix",
|
||||
type=str,
|
||||
default="jsonl.gz",
|
||||
help="suffix of the manifest file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--split",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Split the cut_set into multiple parts",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--resample-to-24kHz",
|
||||
default=True,
|
||||
help="Resample the audio to 24kHz",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--extractor",
|
||||
type=str,
|
||||
choices=["bigvgan", "hifigan"],
|
||||
default="bigvgan",
|
||||
help="The type of extractor to use",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def compute_fbank(args):
|
||||
src_dir = Path(args.src_dir)
|
||||
output_dir = Path(args.output_dir)
|
||||
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
num_jobs = min(args.num_jobs, os.cpu_count())
|
||||
dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip().split(" ")
|
||||
|
||||
logging.info(f"num_jobs: {num_jobs}")
|
||||
logging.info(f"src_dir: {src_dir}")
|
||||
logging.info(f"output_dir: {output_dir}")
|
||||
logging.info(f"dataset_parts: {dataset_parts}")
|
||||
if args.extractor == "bigvgan":
|
||||
config = MatchaFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=100,
|
||||
sampling_rate=24_000,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
f_min=0,
|
||||
f_max=None,
|
||||
)
|
||||
elif args.extractor == "hifigan":
|
||||
config = MatchaFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=80,
|
||||
sampling_rate=22050,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Extractor {args.extractor} is not implemented")
|
||||
|
||||
extractor = MatchaFbank(config)
|
||||
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
output_dir=args.src_dir,
|
||||
prefix=args.prefix,
|
||||
suffix=args.suffix,
|
||||
types=["recordings", "supervisions", "cuts"],
|
||||
)
|
||||
|
||||
with get_executor() as ex:
|
||||
for partition, m in manifests.items():
|
||||
logging.info(
|
||||
f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}"
|
||||
)
|
||||
try:
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
except Exception:
|
||||
cut_set = m["cuts"]
|
||||
|
||||
if args.split > 1:
|
||||
cut_sets = cut_set.split(args.split)
|
||||
else:
|
||||
cut_sets = [cut_set]
|
||||
|
||||
for idx, part in enumerate(cut_sets):
|
||||
if args.split > 1:
|
||||
storage_path = f"{args.output_dir}/{args.prefix}_{args.extractor}_{partition}_{idx}"
|
||||
else:
|
||||
storage_path = (
|
||||
f"{args.output_dir}/{args.prefix}_{args.extractor}_{partition}"
|
||||
)
|
||||
|
||||
if args.resample_to_24kHz:
|
||||
part = part.resample(24000)
|
||||
|
||||
with torch.no_grad():
|
||||
part = part.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
storage_path=storage_path,
|
||||
num_jobs=num_jobs if ex is None else 64,
|
||||
executor=ex,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
|
||||
if args.split > 1:
|
||||
cuts_filename = (
|
||||
f"{args.prefix}_cuts_{partition}.{idx}.{args.suffix}"
|
||||
)
|
||||
else:
|
||||
cuts_filename = f"{args.prefix}_cuts_{partition}.{args.suffix}"
|
||||
|
||||
part.to_file(f"{args.output_dir}/{cuts_filename}")
|
||||
logging.info(f"Saved {cuts_filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
args = get_parser().parse_args()
|
||||
compute_fbank(args)
|
1
egs/wenetspeech4tts/TTS/local/fbank.py
Symbolic link
1
egs/wenetspeech4tts/TTS/local/fbank.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../ljspeech/TTS/matcha/fbank.py
|
@ -1,12 +1,14 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
|
||||
set -eou pipefail
|
||||
|
||||
|
||||
|
||||
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
stage=1
|
||||
stop_stage=4
|
||||
stage=7
|
||||
stop_stage=7
|
||||
|
||||
dl_dir=$PWD/download
|
||||
|
||||
@ -98,3 +100,57 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
fi
|
||||
python3 ./local/display_manifest_statistics.py --manifest-dir ${audio_feats_dir}
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: build monotonic_align lib (used by matcha recipes)"
|
||||
for recipe in matcha; do
|
||||
if [ ! -d $recipe/monotonic_align/build ]; then
|
||||
cd $recipe/monotonic_align
|
||||
python3 setup.py build_ext --inplace
|
||||
cd ../../
|
||||
else
|
||||
log "monotonic_align lib for $recipe already built"
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
subset="Basic"
|
||||
prefix="wenetspeech4tts"
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Generate fbank (used by ./matcha)"
|
||||
mkdir -p data/fbank
|
||||
if [ ! -e data/fbank/.${prefix}.done ]; then
|
||||
./local/compute_mel_feat.py --dataset-parts $subset --split 100
|
||||
touch data/fbank/.${prefix}.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
log "Stage 7: Split the ${prefix} cuts into train, valid and test sets (used by ./matcha)"
|
||||
if [ ! -f data/fbank/${prefix}_cuts_${subset}.jsonl.gz ]; then
|
||||
echo "Combining ${prefix} cuts"
|
||||
pieces=$(find data/fbank/ -name "${prefix}_cuts_${subset}.*.jsonl.gz")
|
||||
# lhotse combine $pieces data/fbank/${prefix}_cuts_${subset}.jsonl.gz
|
||||
fi
|
||||
if [ ! -e data/fbank/.${prefix}_split.done ]; then
|
||||
echo "Splitting ${prefix} cuts into train, valid and test sets"
|
||||
|
||||
# lhotse subset --last 800 \
|
||||
# data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
|
||||
# data/fbank/${prefix}_cuts_validtest.jsonl.gz
|
||||
# lhotse subset --first 400 \
|
||||
# data/fbank/${prefix}_cuts_validtest.jsonl.gz \
|
||||
# data/fbank/${prefix}_cuts_valid.jsonl.gz
|
||||
# lhotse subset --last 400 \
|
||||
# data/fbank/${prefix}_cuts_validtest.jsonl.gz \
|
||||
# data/fbank/${prefix}_cuts_test.jsonl.gz
|
||||
|
||||
# rm data/fbank/${prefix}_cuts_validtest.jsonl.gz
|
||||
|
||||
n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 ))
|
||||
lhotse subset --first $n \
|
||||
data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
|
||||
data/fbank/${prefix}_cuts_train.jsonl.gz
|
||||
touch data/fbank/.${prefix}_split.done
|
||||
fi
|
||||
fi
|
||||
|
Loading…
x
Reference in New Issue
Block a user