This commit is contained in:
yuekaiz 2024-12-22 15:22:26 +08:00
parent ad966fb81d
commit 604ab6f6b3
10 changed files with 3002 additions and 5 deletions

View File

@ -17,6 +17,7 @@ class MatchaFbankConfig:
win_length: int
f_min: float
f_max: float
device: str = "cuda"
@register_extractor
@ -46,7 +47,7 @@ class MatchaFbank(FeatureExtractor):
f"Mismatched sampling rate: extractor expects {expected_sr}, "
f"got {sampling_rate}"
)
samples = torch.from_numpy(samples)
samples = torch.from_numpy(samples).to(self.device)
assert samples.ndim == 2, samples.shape
assert samples.shape[0] == 1, samples.shape
@ -81,7 +82,7 @@ class MatchaFbank(FeatureExtractor):
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
).squeeze(0)
return mel.numpy()
return mel.cpu().numpy()
@property
def frame_shift(self) -> Seconds:

View File

@ -0,0 +1,7 @@
# from f5_tts.model.backbones.dit import DiT
# from f5_tts.model.backbones.mmdit import MMDiT
# from f5_tts.model.backbones.unett import UNetT
# from f5_tts.model.cfm import CFM
# from f5_tts.model.trainer import Trainer
# __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]

View File

@ -0,0 +1,326 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
from random import random
from typing import Callable
import torch
import torch.nn.functional as F
from model.modules import MelSpec
from model.utils import (
default,
exists,
lens_to_mask,
list_str_to_idx,
list_str_to_tensor,
mask_from_frac_lengths,
)
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torchdiffeq import odeint
class CFM(nn.Module):
def __init__(
self,
transformer: nn.Module,
sigma=0.0,
odeint_kwargs: dict = dict(
# atol = 1e-5,
# rtol = 1e-5,
method="euler" # 'midpoint'
),
audio_drop_prob=0.3,
cond_drop_prob=0.2,
num_channels=None,
mel_spec_module: nn.Module | None = None,
mel_spec_kwargs: dict = dict(),
frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
vocab_char_map: dict[str:int] | None = None,
):
super().__init__()
self.frac_lengths_mask = frac_lengths_mask
# mel spec
self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
num_channels = default(num_channels, self.mel_spec.n_mel_channels)
self.num_channels = num_channels
# classifier-free guidance
self.audio_drop_prob = audio_drop_prob
self.cond_drop_prob = cond_drop_prob
# transformer
self.transformer = transformer
dim = transformer.dim
self.dim = dim
# conditional flow related
self.sigma = sigma
# sampling related
self.odeint_kwargs = odeint_kwargs
# vocab map for tokenization
self.vocab_char_map = vocab_char_map
@property
def device(self):
return next(self.parameters()).device
@torch.no_grad()
def sample(
self,
cond: float["b n d"] | float["b nw"], # noqa: F722
text: int["b nt"] | list[str], # noqa: F722
duration: int | int["b"], # noqa: F821
*,
lens: int["b"] | None = None, # noqa: F821
steps=32,
cfg_strength=1.0,
sway_sampling_coef=None,
seed: int | None = None,
max_duration=4096,
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
no_ref_audio=False,
duplicate_test=False,
t_inter=0.1,
edit_mask=None,
):
self.eval()
# raw wave
if cond.ndim == 2:
cond = self.mel_spec(cond)
cond = cond.permute(0, 2, 1)
assert cond.shape[-1] == self.num_channels
cond = cond.to(next(self.parameters()).dtype)
batch, cond_seq_len, device = *cond.shape[:2], cond.device
if not exists(lens):
lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
# text
if isinstance(text, list):
if exists(self.vocab_char_map):
text = list_str_to_idx(text, self.vocab_char_map).to(device)
else:
text = list_str_to_tensor(text).to(device)
assert text.shape[0] == batch
if exists(text):
text_lens = (text != -1).sum(dim=-1)
lens = torch.maximum(
text_lens, lens
) # make sure lengths are at least those of the text characters
# duration
cond_mask = lens_to_mask(lens)
if edit_mask is not None:
cond_mask = cond_mask & edit_mask
if isinstance(duration, int):
duration = torch.full((batch,), duration, device=device, dtype=torch.long)
duration = torch.maximum(
lens + 1, duration
) # just add one token so something is generated
duration = duration.clamp(max=max_duration)
max_duration = duration.amax()
# duplicate test corner for inner time step oberservation
if duplicate_test:
test_cond = F.pad(
cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0
)
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
cond_mask = F.pad(
cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False
)
cond_mask = cond_mask.unsqueeze(-1)
step_cond = torch.where(
cond_mask, cond, torch.zeros_like(cond)
) # allow direct control (cut cond audio) with lens passed in
if batch > 1:
mask = lens_to_mask(duration)
else: # save memory and speed up, as single inference need no mask currently
mask = None
# test for no ref audio
if no_ref_audio:
cond = torch.zeros_like(cond)
# neural ode
def fn(t, x):
# at each step, conditioning is fixed
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
# predict flow
pred = self.transformer(
x=x,
cond=step_cond,
text=text,
time=t,
mask=mask,
drop_audio_cond=False,
drop_text=False,
)
if cfg_strength < 1e-5:
return pred
null_pred = self.transformer(
x=x,
cond=step_cond,
text=text,
time=t,
mask=mask,
drop_audio_cond=True,
drop_text=True,
)
return pred + (pred - null_pred) * cfg_strength
# noise input
# to make sure batch inference result is same with different batch size, and for sure single inference
# still some difference maybe due to convolutional layers
y0 = []
for dur in duration:
if exists(seed):
torch.manual_seed(seed)
y0.append(
torch.randn(
dur, self.num_channels, device=self.device, dtype=step_cond.dtype
)
)
y0 = pad_sequence(y0, padding_value=0, batch_first=True)
t_start = 0
# duplicate test corner for inner time step oberservation
if duplicate_test:
t_start = t_inter
y0 = (1 - t_start) * y0 + t_start * test_cond
steps = int(steps * (1 - t_start))
t = torch.linspace(
t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype
)
if sway_sampling_coef is not None:
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
sampled = trajectory[-1]
out = sampled
out = torch.where(cond_mask, cond, out)
if exists(vocoder):
out = out.permute(0, 2, 1)
out = vocoder(out)
return out, trajectory
def forward(
self,
inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
text: int["b nt"] | list[str], # noqa: F722
*,
lens: int["b"] | None = None, # noqa: F821
noise_scheduler: str | None = None,
):
# handle raw wave
if inp.ndim == 2:
inp = self.mel_spec(inp)
inp = inp.permute(0, 2, 1)
assert inp.shape[-1] == self.num_channels
batch, seq_len, dtype, device, _σ1 = (
*inp.shape[:2],
inp.dtype,
self.device,
self.sigma,
)
# handle text as string
if isinstance(text, list):
if exists(self.vocab_char_map):
text = list_str_to_idx(text, self.vocab_char_map).to(device)
else:
text = list_str_to_tensor(text).to(device)
assert text.shape[0] == batch
# lens and mask
if not exists(lens):
lens = torch.full((batch,), seq_len, device=device)
mask = lens_to_mask(
lens, length=seq_len
) # useless here, as collate_fn will pad to max length in batch
# get a random span to mask out for training conditionally
frac_lengths = (
torch.zeros((batch,), device=self.device)
.float()
.uniform_(*self.frac_lengths_mask)
)
rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
if exists(mask):
rand_span_mask &= mask
# mel is x1
x1 = inp
# x0 is gaussian noise
x0 = torch.randn_like(x1)
# time step
time = torch.rand((batch,), dtype=dtype, device=self.device)
# TODO. noise_scheduler
# sample xt (φ_t(x) in the paper)
t = time.unsqueeze(-1).unsqueeze(-1)
φ = (1 - t) * x0 + t * x1
flow = x1 - x0
# only predict what is within the random mask span for infilling
cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
# transformer and cfg training with a drop rate
drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
if random() < self.cond_drop_prob: # p_uncond in voicebox paper
drop_audio_cond = True
drop_text = True
else:
drop_text = False
# if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
# adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
pred = self.transformer(
x=φ,
cond=cond,
text=text,
time=time,
drop_audio_cond=drop_audio_cond,
drop_text=drop_text,
)
# flow matching loss
loss = F.mse_loss(pred, flow, reduction="none")
loss = loss[rand_span_mask]
return loss.mean(), cond, pred

View File

@ -0,0 +1,728 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
import math
from typing import Optional
import torch
import torch.nn.functional as F
import torchaudio
from librosa.filters import mel as librosa_mel_fn
from torch import nn
from x_transformers.x_transformers import apply_rotary_pos_emb
# raw wav to mel spec
mel_basis_cache = {}
hann_window_cache = {}
def get_bigvgan_mel_spectrogram(
waveform,
n_fft=1024,
n_mel_channels=100,
target_sample_rate=24000,
hop_length=256,
win_length=1024,
fmin=0,
fmax=None,
center=False,
): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
device = waveform.device
key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
if key not in mel_basis_cache:
mel = librosa_mel_fn(
sr=target_sample_rate,
n_fft=n_fft,
n_mels=n_mel_channels,
fmin=fmin,
fmax=fmax,
)
mel_basis_cache[key] = (
torch.from_numpy(mel).float().to(device)
) # TODO: why they need .float()?
hann_window_cache[key] = torch.hann_window(win_length).to(device)
mel_basis = mel_basis_cache[key]
hann_window = hann_window_cache[key]
padding = (n_fft - hop_length) // 2
waveform = torch.nn.functional.pad(
waveform.unsqueeze(1), (padding, padding), mode="reflect"
).squeeze(1)
spec = torch.stft(
waveform,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window,
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
mel_spec = torch.matmul(mel_basis, spec)
mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
return mel_spec
def get_vocos_mel_spectrogram(
waveform,
n_fft=1024,
n_mel_channels=100,
target_sample_rate=24000,
hop_length=256,
win_length=1024,
):
mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=target_sample_rate,
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
n_mels=n_mel_channels,
power=1,
center=True,
normalized=False,
norm=None,
).to(waveform.device)
if len(waveform.shape) == 3:
waveform = waveform.squeeze(1) # 'b 1 nw -> b nw'
assert len(waveform.shape) == 2
mel = mel_stft(waveform)
mel = mel.clamp(min=1e-5).log()
return mel
class MelSpec(nn.Module):
def __init__(
self,
n_fft=1024,
hop_length=256,
win_length=1024,
n_mel_channels=100,
target_sample_rate=24_000,
mel_spec_type="vocos",
):
super().__init__()
assert mel_spec_type in ["vocos", "bigvgan"], print(
"We only support two extract mel backend: vocos or bigvgan"
)
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.n_mel_channels = n_mel_channels
self.target_sample_rate = target_sample_rate
if mel_spec_type == "vocos":
self.extractor = get_vocos_mel_spectrogram
elif mel_spec_type == "bigvgan":
self.extractor = get_bigvgan_mel_spectrogram
self.register_buffer("dummy", torch.tensor(0), persistent=False)
def forward(self, wav):
if self.dummy.device != wav.device:
self.to(wav.device)
mel = self.extractor(
waveform=wav,
n_fft=self.n_fft,
n_mel_channels=self.n_mel_channels,
target_sample_rate=self.target_sample_rate,
hop_length=self.hop_length,
win_length=self.win_length,
)
return mel
# sinusoidal position embedding
class SinusPositionEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x, scale=1000):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
# convolutional position embedding
class ConvPositionEmbedding(nn.Module):
def __init__(self, dim, kernel_size=31, groups=16):
super().__init__()
assert kernel_size % 2 != 0
self.conv1d = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
)
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.0)
x = x.permute(0, 2, 1)
x = self.conv1d(x)
out = x.permute(0, 2, 1)
if mask is not None:
out = out.masked_fill(~mask, 0.0)
return out
# rotary positional embedding related
def precompute_freqs_cis(
dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0
):
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cos = torch.cos(freqs) # real part
freqs_sin = torch.sin(freqs) # imaginary part
return torch.cat([freqs_cos, freqs_sin], dim=-1)
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
# length = length if isinstance(length, int) else length.max()
scale = scale * torch.ones_like(
start, dtype=torch.float32
) # in case scale is a scalar
pos = (
start.unsqueeze(1)
+ (
torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0)
* scale.unsqueeze(1)
).long()
)
# avoid extra long error.
pos = torch.where(pos < max_pos, pos, max_pos - 1)
return pos
# Global Response Normalization layer (Instance Normalization ?)
class GRN(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
class ConvNeXtV2Block(nn.Module):
def __init__(
self,
dim: int,
intermediate_dim: int,
dilation: int = 1,
):
super().__init__()
padding = (dilation * (7 - 1)) // 2
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
) # depthwise conv
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(
dim, intermediate_dim
) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.grn = GRN(intermediate_dim)
self.pwconv2 = nn.Linear(intermediate_dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = x.transpose(1, 2) # b n d -> b d n
x = self.dwconv(x)
x = x.transpose(1, 2) # b d n -> b n d
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.grn(x)
x = self.pwconv2(x)
return residual + x
# AdaLayerNormZero
# return with modulated x for attn input, and params for later mlp modulation
class AdaLayerNormZero(nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, dim * 6)
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb=None):
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(
emb, 6, dim=1
)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
# AdaLayerNormZero for final layer
# return only with modulated x for attn input, cuz no more mlp modulation
class AdaLayerNormZero_Final(nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, dim * 2)
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(self.silu(emb))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
# FeedForward
class FeedForward(nn.Module):
def __init__(
self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
activation = nn.GELU(approximate=approximate)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
self.ff = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.ff(x)
# Attention with possible joint part
# modified from diffusers/src/diffusers/models/attention_processor.py
class Attention(nn.Module):
def __init__(
self,
processor: JointAttnProcessor | AttnProcessor,
dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
context_dim: Optional[int] = None, # if not None -> joint attention
context_pre_only=None,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
self.processor = processor
self.dim = dim
self.heads = heads
self.inner_dim = dim_head * heads
self.dropout = dropout
self.context_dim = context_dim
self.context_pre_only = context_pre_only
self.to_q = nn.Linear(dim, self.inner_dim)
self.to_k = nn.Linear(dim, self.inner_dim)
self.to_v = nn.Linear(dim, self.inner_dim)
if self.context_dim is not None:
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
if self.context_pre_only is not None:
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, dim))
self.to_out.append(nn.Dropout(dropout))
if self.context_pre_only is not None and not self.context_pre_only:
self.to_out_c = nn.Linear(self.inner_dim, dim)
def forward(
self,
x: float["b n d"], # noised input x # noqa: F722
c: float["b n d"] = None, # context c # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.Tensor:
if c is not None:
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
else:
return self.processor(self, x, mask=mask, rope=rope)
# Attention processor
class AttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding
) -> torch.FloatTensor:
batch_size = x.shape[0]
# `sample` projections.
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# apply rotary position embedding
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (
(xpos_scale, xpos_scale**-1.0)
if xpos_scale is not None
else (1.0, 1.0)
)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
# attention
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = mask
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(
batch_size, attn.heads, query.shape[-2], key.shape[-2]
)
else:
attn_mask = None
x = F.scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# linear proj
x = attn.to_out[0](x)
# dropout
x = attn.to_out[1](x)
if mask is not None:
mask = mask.unsqueeze(-1)
x = x.masked_fill(~mask, 0.0)
return x
# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py
class JointAttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
c: float["b nt d"] = None, # context c, here text # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.FloatTensor:
residual = x
batch_size = c.shape[0]
# `sample` projections.
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# `context` projections.
c_query = attn.to_q_c(c)
c_key = attn.to_k_c(c)
c_value = attn.to_v_c(c)
# apply rope for context and noised input independently
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (
(xpos_scale, xpos_scale**-1.0)
if xpos_scale is not None
else (1.0, 1.0)
)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
if c_rope is not None:
freqs, xpos_scale = c_rope
q_xpos_scale, k_xpos_scale = (
(xpos_scale, xpos_scale**-1.0)
if xpos_scale is not None
else (1.0, 1.0)
)
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
# attention
query = torch.cat([query, c_query], dim=1)
key = torch.cat([key, c_key], dim=1)
value = torch.cat([value, c_value], dim=1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(
batch_size, attn.heads, query.shape[-2], key.shape[-2]
)
else:
attn_mask = None
x = F.scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# Split the attention outputs.
x, c = (
x[:, : residual.shape[1]],
x[:, residual.shape[1] :],
)
# linear proj
x = attn.to_out[0](x)
# dropout
x = attn.to_out[1](x)
if not attn.context_pre_only:
c = attn.to_out_c(c)
if mask is not None:
mask = mask.unsqueeze(-1)
x = x.masked_fill(~mask, 0.0)
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
return x, c
# DiT Block
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
super().__init__()
self.attn_norm = AdaLayerNormZero(dim)
self.attn = Attention(
processor=AttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
)
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(
dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
)
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
# pre-norm & modulation for attention input
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
# attention
attn_output = self.attn(x=norm, mask=mask, rope=rope)
# process attention output for input x
x = x + gate_msa.unsqueeze(1) * attn_output
norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm)
x = x + gate_mlp.unsqueeze(1) * ff_output
return x
# MMDiT Block https://arxiv.org/abs/2403.03206
class MMDiTBlock(nn.Module):
r"""
modified from diffusers/src/diffusers/models/attention.py
notes.
_c: context related. text, cond, etc. (left part in sd3 fig2.b)
_x: noised input related. (right part)
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
"""
def __init__(
self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False
):
super().__init__()
self.context_pre_only = context_pre_only
self.attn_norm_c = (
AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
)
self.attn_norm_x = AdaLayerNormZero(dim)
self.attn = Attention(
processor=JointAttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
context_dim=dim,
context_pre_only=context_pre_only,
)
if not context_pre_only:
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_c = FeedForward(
dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
)
else:
self.ff_norm_c = None
self.ff_c = None
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_x = FeedForward(
dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
)
def forward(
self, x, c, t, mask=None, rope=None, c_rope=None
): # x: noised input, c: context, t: time embedding
# pre-norm & modulation for attention input
if self.context_pre_only:
norm_c = self.attn_norm_c(c, t)
else:
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(
c, emb=t
)
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(
x, emb=t
)
# attention
x_attn_output, c_attn_output = self.attn(
x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope
)
# process attention output for context c
if self.context_pre_only:
c = None
else: # if not last layer
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
norm_c = (
self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
)
c_ff_output = self.ff_c(norm_c)
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
# process attention output for input x
x = x + x_gate_msa.unsqueeze(1) * x_attn_output
norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
x_ff_output = self.ff_x(norm_x)
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
return c, x
# time step conditioning embedding
class TimestepEmbedding(nn.Module):
def __init__(self, dim, freq_embed_dim=256):
super().__init__()
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.time_mlp = nn.Sequential(
nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
)
def forward(self, timestep: float["b"]): # noqa: F821
time_hidden = self.time_embed(timestep)
time_hidden = time_hidden.to(timestep.dtype)
time = self.time_mlp(time_hidden) # b d
return time

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,346 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class TtsDataModule:
"""
DataModule for tts experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="TTS data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=False,
help="When enabled, each batch will have the "
"field: batch['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
parser.add_argument(
"--prefix",
type=str,
default="wenetspeech4tts",
help="prefix of the manifest file",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=True,
pin_memory=True,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
)
else:
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create valid dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=True,
pin_memory=True,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
)
else:
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
test_sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=test_sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / f"{self.args.prefix}_cuts_train.jsonl.gz"
)
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get validation cuts")
return load_manifest_lazy(
self.args.manifest_dir / f"{self.args.prefix}_cuts_valid.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / f"{self.args.prefix}_cuts_test.jsonl.gz"
)

View File

@ -0,0 +1,122 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import math
import os
import pathlib
import random
from typing import List, Optional, Tuple
import librosa
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from tqdm import tqdm
# from env import AttrDict
MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
def dynamic_range_compression(x, C=1, clip_val=1e-5):
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
def dynamic_range_decompression(x, C=1):
return np.exp(x) / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
return dynamic_range_compression_torch(magnitudes)
def spectral_de_normalize_torch(magnitudes):
return dynamic_range_decompression_torch(magnitudes)
mel_basis_cache = {}
hann_window_cache = {}
def mel_spectrogram(
y: torch.Tensor,
n_fft: int = 1024,
num_mels: int = 100,
sampling_rate: int = 24_000,
hop_size: int = 256,
win_size: int = 1024,
fmin: int = 0,
fmax: int = None,
center: bool = False,
) -> torch.Tensor:
"""
Calculate the mel spectrogram of an input signal.
This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
Args:
y (torch.Tensor): Input signal.
n_fft (int): FFT size.
num_mels (int): Number of mel bins.
sampling_rate (int): Sampling rate of the input signal.
hop_size (int): Hop size for STFT.
win_size (int): Window size for STFT.
fmin (int): Minimum frequency for mel filterbank.
fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
center (bool): Whether to pad the input to center the frames. Default is False.
Returns:
torch.Tensor: Mel spectrogram.
"""
if torch.min(y) < -1.0:
print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
if torch.max(y) > 1.0:
print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
device = y.device
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
if key not in mel_basis_cache:
mel = librosa_mel_fn(
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
hann_window_cache[key] = torch.hann_window(win_size).to(device)
mel_basis = mel_basis_cache[key]
hann_window = hann_window_cache[key]
padding = (n_fft - hop_size) // 2
y = torch.nn.functional.pad(
y.unsqueeze(1), (padding, padding), mode="reflect"
).squeeze(1)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window,
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
mel_spec = torch.matmul(mel_basis, spec)
mel_spec = spectral_normalize_torch(mel_spec)
return mel_spec

View File

@ -0,0 +1,218 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the LJSpeech dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--num-jobs",
type=int,
default=1,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)
parser.add_argument(
"--src-dir",
type=Path,
default=Path("data/manifests"),
help="Path to the manifest files",
)
parser.add_argument(
"--output-dir",
type=Path,
default=Path("data/fbank"),
help="Path to the tokenized files",
)
parser.add_argument(
"--dataset-parts",
type=str,
default="Basic",
help="Space separated dataset parts",
)
parser.add_argument(
"--prefix",
type=str,
default="wenetspeech4tts",
help="prefix of the manifest file",
)
parser.add_argument(
"--suffix",
type=str,
default="jsonl.gz",
help="suffix of the manifest file",
)
parser.add_argument(
"--split",
type=int,
default=100,
help="Split the cut_set into multiple parts",
)
parser.add_argument(
"--resample-to-24kHz",
default=True,
help="Resample the audio to 24kHz",
)
parser.add_argument(
"--extractor",
type=str,
choices=["bigvgan", "hifigan"],
default="bigvgan",
help="The type of extractor to use",
)
return parser
def compute_fbank(args):
src_dir = Path(args.src_dir)
output_dir = Path(args.output_dir)
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
num_jobs = min(args.num_jobs, os.cpu_count())
dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip().split(" ")
logging.info(f"num_jobs: {num_jobs}")
logging.info(f"src_dir: {src_dir}")
logging.info(f"output_dir: {output_dir}")
logging.info(f"dataset_parts: {dataset_parts}")
if args.extractor == "bigvgan":
config = MatchaFbankConfig(
n_fft=1024,
n_mels=100,
sampling_rate=24_000,
hop_length=256,
win_length=1024,
f_min=0,
f_max=None,
)
elif args.extractor == "hifigan":
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=22050,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
else:
raise NotImplementedError(f"Extractor {args.extractor} is not implemented")
extractor = MatchaFbank(config)
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=args.src_dir,
prefix=args.prefix,
suffix=args.suffix,
types=["recordings", "supervisions", "cuts"],
)
with get_executor() as ex:
for partition, m in manifests.items():
logging.info(
f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}"
)
try:
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
except Exception:
cut_set = m["cuts"]
if args.split > 1:
cut_sets = cut_set.split(args.split)
else:
cut_sets = [cut_set]
for idx, part in enumerate(cut_sets):
if args.split > 1:
storage_path = f"{args.output_dir}/{args.prefix}_{args.extractor}_{partition}_{idx}"
else:
storage_path = (
f"{args.output_dir}/{args.prefix}_{args.extractor}_{partition}"
)
if args.resample_to_24kHz:
part = part.resample(24000)
with torch.no_grad():
part = part.compute_and_store_features(
extractor=extractor,
storage_path=storage_path,
num_jobs=num_jobs if ex is None else 64,
executor=ex,
storage_type=LilcomChunkyWriter,
)
if args.split > 1:
cuts_filename = (
f"{args.prefix}_cuts_{partition}.{idx}.{args.suffix}"
)
else:
cuts_filename = f"{args.prefix}_cuts_{partition}.{args.suffix}"
part.to_file(f"{args.output_dir}/{cuts_filename}")
logging.info(f"Saved {cuts_filename}")
if __name__ == "__main__":
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_parser().parse_args()
compute_fbank(args)

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/fbank.py

View File

@ -1,12 +1,14 @@
#!/usr/bin/env bash
export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
set -eou pipefail
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
stage=1
stop_stage=4
stage=7
stop_stage=7
dl_dir=$PWD/download
@ -98,3 +100,57 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi
python3 ./local/display_manifest_statistics.py --manifest-dir ${audio_feats_dir}
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: build monotonic_align lib (used by matcha recipes)"
for recipe in matcha; do
if [ ! -d $recipe/monotonic_align/build ]; then
cd $recipe/monotonic_align
python3 setup.py build_ext --inplace
cd ../../
else
log "monotonic_align lib for $recipe already built"
fi
done
fi
subset="Basic"
prefix="wenetspeech4tts"
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Generate fbank (used by ./matcha)"
mkdir -p data/fbank
if [ ! -e data/fbank/.${prefix}.done ]; then
./local/compute_mel_feat.py --dataset-parts $subset --split 100
touch data/fbank/.${prefix}.done
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Split the ${prefix} cuts into train, valid and test sets (used by ./matcha)"
if [ ! -f data/fbank/${prefix}_cuts_${subset}.jsonl.gz ]; then
echo "Combining ${prefix} cuts"
pieces=$(find data/fbank/ -name "${prefix}_cuts_${subset}.*.jsonl.gz")
# lhotse combine $pieces data/fbank/${prefix}_cuts_${subset}.jsonl.gz
fi
if [ ! -e data/fbank/.${prefix}_split.done ]; then
echo "Splitting ${prefix} cuts into train, valid and test sets"
# lhotse subset --last 800 \
# data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
# data/fbank/${prefix}_cuts_validtest.jsonl.gz
# lhotse subset --first 400 \
# data/fbank/${prefix}_cuts_validtest.jsonl.gz \
# data/fbank/${prefix}_cuts_valid.jsonl.gz
# lhotse subset --last 400 \
# data/fbank/${prefix}_cuts_validtest.jsonl.gz \
# data/fbank/${prefix}_cuts_test.jsonl.gz
# rm data/fbank/${prefix}_cuts_validtest.jsonl.gz
n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 ))
lhotse subset --first $n \
data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
data/fbank/${prefix}_cuts_train.jsonl.gz
touch data/fbank/.${prefix}_split.done
fi
fi