mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00

Reformat code

parent a67d4b9a80
commit 7994684bf4
@@ -32,7 +32,10 @@ class BaseLightningClass(LightningModule, ABC):
         if self.hparams.scheduler not in (None, {}):
             scheduler_args = {}
             # Manage last epoch for exponential schedulers
-            if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters:
+            if (
+                "last_epoch"
+                in inspect.signature(self.hparams.scheduler.scheduler).parameters
+            ):
                 if hasattr(self, "ckpt_loaded_epoch"):
                     current_epoch = self.ckpt_loaded_epoch - 1
                 else:
@@ -74,7 +77,9 @@ class BaseLightningClass(LightningModule, ABC):
         }

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        self.ckpt_loaded_epoch = checkpoint["epoch"]  # pylint: disable=attribute-defined-outside-init
+        self.ckpt_loaded_epoch = checkpoint[
+            "epoch"
+        ]  # pylint: disable=attribute-defined-outside-init

     def training_step(self, batch: Any, batch_idx: int):
         loss_dict = self.get_losses(batch)
@@ -183,8 +188,14 @@ class BaseLightningClass(LightningModule, ABC):
             for i in range(2):
                 x = one_batch["x"][i].unsqueeze(0).to(self.device)
                 x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device)
-                spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None
-                output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks)
+                spks = (
+                    one_batch["spks"][i].unsqueeze(0).to(self.device)
+                    if one_batch["spks"] is not None
+                    else None
+                )
+                output = self.synthesise(
+                    x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks
+                )
                 y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"]
                 attn = output["attn"]
                 self.logger.experiment.add_image(
@@ -207,4 +218,6 @@ class BaseLightningClass(LightningModule, ABC):
             )

     def on_before_optimizer_step(self, optimizer):
-        self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()})
+        self.log_dict(
+            {f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()}
+        )
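Aside: the `last_epoch` check in the first hunk uses `inspect.signature` so that `last_epoch` is passed only to schedulers whose constructor accepts it when training resumes from a checkpoint. A minimal sketch of that pattern, assuming a stock PyTorch scheduler (the helper name is illustrative, not part of this diff):

import inspect

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR


def scheduler_kwargs(scheduler_cls, resumed_epoch=None):
    # Only pass last_epoch to schedulers whose signature declares it.
    kwargs = {}
    if "last_epoch" in inspect.signature(scheduler_cls).parameters:
        # PyTorch convention: last_epoch = completed epoch - 1; -1 means a fresh start.
        kwargs["last_epoch"] = resumed_epoch - 1 if resumed_epoch is not None else -1
    return kwargs


param = torch.zeros(1, requires_grad=True)
optimizer = SGD([param], lr=0.1)
scheduler = ExponentialLR(optimizer, gamma=0.9, **scheduler_kwargs(ExponentialLR))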
@@ -46,7 +46,9 @@ class Block1D(torch.nn.Module):
 class ResnetBlock1D(torch.nn.Module):
     def __init__(self, dim, dim_out, time_emb_dim, groups=8):
         super().__init__()
-        self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
+        self.mlp = torch.nn.Sequential(
+            nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)
+        )

         self.block1 = Block1D(dim, dim_out, groups=groups)
         self.block2 = Block1D(dim_out, dim_out, groups=groups)
@@ -131,7 +133,14 @@ class Upsample1D(nn.Module):
             number of output channels. Defaults to `channels`.
     """

-    def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
+    def __init__(
+        self,
+        channels,
+        use_conv=False,
+        use_conv_transpose=True,
+        out_channels=None,
+        name="conv",
+    ):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -235,7 +244,9 @@ class Decoder(nn.Module):
             input_channel = output_channel
             output_channel = channels[i]
             is_last = i == len(channels) - 1
-            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = ResnetBlock1D(
+                dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
+            )
             transformer_blocks = nn.ModuleList(
                 [
                     self.get_block(
@@ -250,16 +261,22 @@ class Decoder(nn.Module):
                 ]
             )
             downsample = (
-                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                Downsample1D(output_channel)
+                if not is_last
+                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )

-            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
+            self.down_blocks.append(
+                nn.ModuleList([resnet, transformer_blocks, downsample])
+            )

         for i in range(num_mid_blocks):
             input_channel = channels[-1]
             out_channels = channels[-1]

-            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = ResnetBlock1D(
+                dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
+            )

             transformer_blocks = nn.ModuleList(
                 [
@@ -4,6 +4,7 @@ import torch
 import torch.nn.functional as F
+
 from matcha.models.components.decoder import Decoder

 # from matcha.utils.pylogger import get_pylogger

 # log = get_pylogger(__name__)
@@ -50,7 +51,9 @@ class BASECFM(torch.nn.Module, ABC):
         """
         z = torch.randn_like(mu) * temperature
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
+        return self.solve_euler(
+            z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond
+        )

     def solve_euler(self, x, t_span, mu, mask, spks, cond):
         """
@@ -112,14 +115,22 @@ class BASECFM(torch.nn.Module, ABC):
         y = (1 - (1 - self.sigma_min) * t) * z + t * x1
         u = x1 - (1 - self.sigma_min) * z

-        loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
-            torch.sum(mask) * u.shape[1]
-        )
+        loss = F.mse_loss(
+            self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum"
+        ) / (torch.sum(mask) * u.shape[1])
         return loss, y


 class CFM(BASECFM):
-    def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
+    def __init__(
+        self,
+        in_channels,
+        out_channel,
+        cfm_params,
+        decoder_params,
+        n_spks=1,
+        spk_emb_dim=64,
+    ):
         super().__init__(
             n_feats=in_channels,
             cfm_params=cfm_params,
@@ -129,4 +140,6 @@ class CFM(BASECFM):

         in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
         # Just change the architecture of the estimator here
-        self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
+        self.estimator = Decoder(
+            in_channels=in_channels, out_channels=out_channel, **decoder_params
+        )
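Context for the hunks above: the reformatted `solve_euler` call integrates the learned vector field from t = 0 to t = 1 with fixed-step Euler updates. A self-contained sketch of that integration loop, with a toy field standing in for the real `estimator` (all names here are illustrative):

import torch


def euler_integrate(vector_field, z, n_timesteps=10):
    # Fixed-step Euler ODE solve from t=0 to t=1: x_{k+1} = x_k + dt * f(x_k, t_k).
    t_span = torch.linspace(0, 1, n_timesteps + 1)
    x = z
    for k in range(1, len(t_span)):
        dt = t_span[k] - t_span[k - 1]
        x = x + dt * vector_field(x, t_span[k - 1])
    return x


# Toy field whose flow contracts samples toward the origin.
x1 = euler_integrate(lambda x, t: -x, torch.randn(2, 80, 100))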
@@ -34,7 +34,15 @@ class LayerNorm(nn.Module):


 class ConvReluNorm(nn.Module):
-    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+    def __init__(
+        self,
+        in_channels,
+        hidden_channels,
+        out_channels,
+        kernel_size,
+        n_layers,
+        p_dropout,
+    ):
         super().__init__()
         self.in_channels = in_channels
         self.hidden_channels = hidden_channels
@@ -45,12 +53,23 @@ class ConvReluNorm(nn.Module):

         self.conv_layers = torch.nn.ModuleList()
         self.norm_layers = torch.nn.ModuleList()
-        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+        self.conv_layers.append(
+            torch.nn.Conv1d(
+                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+            )
+        )
         self.norm_layers.append(LayerNorm(hidden_channels))
-        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
+        self.relu_drop = torch.nn.Sequential(
+            torch.nn.ReLU(), torch.nn.Dropout(p_dropout)
+        )
         for _ in range(n_layers - 1):
             self.conv_layers.append(
-                torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
+                torch.nn.Conv1d(
+                    hidden_channels,
+                    hidden_channels,
+                    kernel_size,
+                    padding=kernel_size // 2,
+                )
             )
             self.norm_layers.append(LayerNorm(hidden_channels))
         self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
@@ -75,9 +94,13 @@ class DurationPredictor(nn.Module):
         self.p_dropout = p_dropout

         self.drop = torch.nn.Dropout(p_dropout)
-        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.conv_1 = torch.nn.Conv1d(
+            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
+        )
         self.norm_1 = LayerNorm(filter_channels)
-        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.conv_2 = torch.nn.Conv1d(
+            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
+        )
         self.norm_2 = LayerNorm(filter_channels)
         self.proj = torch.nn.Conv1d(filter_channels, 1, 1)

@@ -128,7 +151,9 @@ class RotaryPositionalEmbeddings(nn.Module):
         seq_len = x.shape[0]

         # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
-        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
+        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(
+            x.device
+        )

         # Create position indexes `[0, 1, ..., seq_len - 1]`
         seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
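The wrapped `theta` line implements the standard RoPE frequency schedule theta_i = 10000^(-2(i-1)/d), one frequency per rotated coordinate pair. A quick numeric check of the formula, assuming d = 8:

import torch

d, base = 8, 10000
theta = 1.0 / (base ** (torch.arange(0, d, 2).float() / d))
print(theta)  # tensor([1.0000, 0.1000, 0.0100, 0.0010]) -- d/2 = 4 frequencies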
@@ -167,7 +192,9 @@ class RotaryPositionalEmbeddings(nn.Module):
         # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
         neg_half_x = self._neg_half(x_rope)

-        x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
+        x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (
+            neg_half_x * self.sin_cached[: x.shape[0]]
+        )

         return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")

@@ -236,7 +263,9 @@ class MultiHeadAttention(nn.Module):

         if self.proximal_bias:
             assert t_s == t_t, "Proximal bias is only available for self-attention."
-            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
+            scores = scores + self._attention_bias_proximal(t_s).to(
+                device=scores.device, dtype=scores.dtype
+            )
         if mask is not None:
             scores = scores.masked_fill(mask == 0, -1e4)
         p_attn = torch.nn.functional.softmax(scores, dim=-1)
@@ -253,7 +282,9 @@ class MultiHeadAttention(nn.Module):


 class FFN(nn.Module):
-    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
+    def __init__(
+        self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0
+    ):
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -261,8 +292,12 @@ class FFN(nn.Module):
         self.kernel_size = kernel_size
         self.p_dropout = p_dropout

-        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
-        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
+        self.conv_1 = torch.nn.Conv1d(
+            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
+        )
+        self.conv_2 = torch.nn.Conv1d(
+            filter_channels, out_channels, kernel_size, padding=kernel_size // 2
+        )
         self.drop = torch.nn.Dropout(p_dropout)

     def forward(self, x, x_mask):
@@ -298,7 +333,11 @@ class Encoder(nn.Module):
         self.ffn_layers = torch.nn.ModuleList()
         self.norm_layers_2 = torch.nn.ModuleList()
         for _ in range(self.n_layers):
-            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
+            self.attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
+                )
+            )
             self.norm_layers_1.append(LayerNorm(hidden_channels))
             self.ffn_layers.append(
                 FFN(
@@ -367,7 +406,9 @@ class TextEncoder(nn.Module):
             encoder_params.p_dropout,
         )

-        self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
+        self.proj_m = torch.nn.Conv1d(
+            self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1
+        )
         self.proj_w = DurationPredictor(
             self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
             duration_predictor_params.filter_channels_dp,
@@ -32,7 +32,14 @@ class SnakeBeta(nn.Module):
         >>> x = a1(x)
     """

-    def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        alpha=1.0,
+        alpha_trainable=True,
+        alpha_logscale=True,
+    ):
         """
         Initialization.
         INPUT:
@@ -44,7 +51,9 @@ class SnakeBeta(nn.Module):
            alpha will be trained along with the rest of your model.
         """
         super().__init__()
-        self.in_features = out_features if isinstance(out_features, list) else [out_features]
+        self.in_features = (
+            out_features if isinstance(out_features, list) else [out_features]
+        )
         self.proj = LoRACompatibleLinear(in_features, out_features)

         # initialize alpha
@@ -75,7 +84,9 @@ class SnakeBeta(nn.Module):
             alpha = self.alpha
             beta = self.beta

-        x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
+        x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
+            torch.sin(x * alpha), 2
+        )

         return x
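The expression wrapped in the last hunk is the SnakeBeta activation, x + (1/beta) * sin^2(alpha * x), with a small epsilon guarding the division. A standalone sketch of the same formula (the function name and epsilon default are illustrative):

import torch


def snake_beta(x, alpha=1.0, beta=1.0, eps=1e-9):
    # x + (1/beta) * sin^2(alpha * x); eps plays the role of no_div_by_zero.
    return x + (1.0 / (beta + eps)) * torch.pow(torch.sin(x * alpha), 2)


y = snake_beta(torch.linspace(-1.0, 1.0, 5))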
@@ -176,8 +187,12 @@ class BasicTransformerBlock(nn.Module):
         super().__init__()
         self.only_cross_attention = only_cross_attention

-        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
-        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+        self.use_ada_layer_norm_zero = (
+            num_embeds_ada_norm is not None
+        ) and norm_type == "ada_norm_zero"
+        self.use_ada_layer_norm = (
+            num_embeds_ada_norm is not None
+        ) and norm_type == "ada_norm"

         if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
             raise ValueError(
@@ -215,7 +230,9 @@ class BasicTransformerBlock(nn.Module):
             )
             self.attn2 = Attention(
                 query_dim=dim,
-                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                cross_attention_dim=cross_attention_dim
+                if not double_self_attention
+                else None,
                 heads=num_attention_heads,
                 dim_head=attention_head_dim,
                 dropout=dropout,
@@ -229,7 +246,12 @@ class BasicTransformerBlock(nn.Module):

         # 3. Feed-forward
         self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
-        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+        self.ff = FeedForward(
+            dim,
+            dropout=dropout,
+            activation_fn=activation_fn,
+            final_dropout=final_dropout,
+        )

         # let chunk size default to None
         self._chunk_size = None
@@ -261,12 +283,18 @@ class BasicTransformerBlock(nn.Module):
         else:
             norm_hidden_states = self.norm1(hidden_states)

-        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+        cross_attention_kwargs = (
+            cross_attention_kwargs if cross_attention_kwargs is not None else {}
+        )

         attn_output = self.attn1(
             norm_hidden_states,
-            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
-            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
+            encoder_hidden_states=encoder_hidden_states
+            if self.only_cross_attention
+            else None,
+            attention_mask=encoder_attention_mask
+            if self.only_cross_attention
+            else attention_mask,
             **cross_attention_kwargs,
         )
         if self.use_ada_layer_norm_zero:
@@ -276,7 +304,9 @@ class BasicTransformerBlock(nn.Module):
         # 2. Cross-Attention
         if self.attn2 is not None:
             norm_hidden_states = (
-                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+                self.norm2(hidden_states, timestep)
+                if self.use_ada_layer_norm
+                else self.norm2(hidden_states)
             )

             attn_output = self.attn2(
@@ -291,7 +321,9 @@ class BasicTransformerBlock(nn.Module):
         norm_hidden_states = self.norm3(hidden_states)

         if self.use_ada_layer_norm_zero:
-            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+            norm_hidden_states = (
+                norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+            )

         if self._chunk_size is not None:
             # "feed_forward_chunk_size" can be used to save memory
@@ -302,7 +334,12 @@ class BasicTransformerBlock(nn.Module):

             num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
             ff_output = torch.cat(
-                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
+                [
+                    self.ff(hid_slice)
+                    for hid_slice in norm_hidden_states.chunk(
+                        num_chunks, dim=self._chunk_dim
+                    )
+                ],
                 dim=self._chunk_dim,
             )
         else:
@@ -4,7 +4,9 @@ from matcha.text.symbols import symbols

 # Mappings from symbol to numeric ID and vice versa:
 _symbol_to_id = {s: i for i, s in enumerate(symbols)}
-_id_to_symbol = {i: s for i, s in enumerate(symbols)}  # pylint: disable=unnecessary-comprehension
+_id_to_symbol = {
+    i: s for i, s in enumerate(symbols)
+}  # pylint: disable=unnecessary-comprehension


 def text_to_sequence(text, cleaner_names):
@@ -20,13 +22,15 @@ def text_to_sequence(text, cleaner_names):
     clean_text = _clean_text(text, cleaner_names)
     for symbol in clean_text:
         try:
-            if symbol in '_()[]# ̃':
+            if symbol in "_()[]# ̃":
                 continue
             symbol_id = _symbol_to_id[symbol]
         except Exception as ex:
             print(text)
             print(clean_text)
-            raise RuntimeError(f'text: {text}, clean_text: {clean_text}, ex: {ex}, symbol: {symbol}')
+            raise RuntimeError(
+                f"text: {text}, clean_text: {clean_text}, ex: {ex}, symbol: {symbol}"
+            )
         sequence += [symbol_id]
     return sequence, clean_text
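For orientation, `text_to_sequence` looks each cleaned character up in `_symbol_to_id`, skipping the characters listed in the guard above. A toy version of that mapping, assuming a tiny symbol set (the real one comes from matcha.text.symbols):

symbols = ["_", " ", "a", "b", "c"]
_symbol_to_id = {s: i for i, s in enumerate(symbols)}


def to_sequence(clean_text):
    # Skip the same filtered characters as the hunk above, then map to IDs.
    return [_symbol_to_id[ch] for ch in clean_text if ch not in "_()[]#"]


print(to_sequence("ab c"))  # [2, 3, 1, 4]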
@@ -75,6 +75,7 @@ def lowercase(text):
 def collapse_whitespace(text):
     return re.sub(_whitespace_re, " ", text)

+
 def remove_parentheses(text):
     text = text.replace("(", "")
     text = text.replace(")", "")
@@ -56,7 +56,9 @@ def _expand_number(m):
         elif num % 100 == 0:
             return _inflect.number_to_words(num // 100) + " hundred"
         else:
-            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
+            return _inflect.number_to_words(
+                num, andword="", zero="oh", group=2
+            ).replace(", ", " ")
     else:
         return _inflect.number_to_words(num, andword="")
@@ -5,9 +5,7 @@ Defines the set of symbols used in text input to the model.
 _pad = "_"
 _punctuation = ';:,.!?¡¿—…"«»“” '
 _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
-_letters_ipa = (
-    "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
-)
+_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"


 # Export all symbols:
@@ -42,7 +42,9 @@ mel_basis = {}
 hann_window = {}


-def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+def mel_spectrogram(
+    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
+):
     if torch.min(y) < -1.0:
         print("min value is ", torch.min(y))
     if torch.max(y) > 1.0:
@@ -50,12 +52,18 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,

     global mel_basis, hann_window  # pylint: disable=global-statement
     if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
-        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
-        mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+        mel = librosa_mel_fn(
+            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
+        )
+        mel_basis[str(fmax) + "_" + str(y.device)] = (
+            torch.from_numpy(mel).float().to(y.device)
+        )
         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

     y = torch.nn.functional.pad(
-        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+        y.unsqueeze(1),
+        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+        mode="reflect",
     )
     y = y.squeeze(1)
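The cache logic in the hunk above stores one mel filterbank per (fmax, device) pair so the filterbank is built only once per configuration. A minimal sketch of the same memoization pattern, with a zero matrix standing in for the `librosa_mel_fn` output (names are illustrative):

import torch

_mel_basis = {}


def get_mel_basis(fmax, device, num_mels=80, n_fft=1024):
    # Build the filterbank once per (fmax, device), then reuse it.
    key = f"{fmax}_{device}"
    if key not in _mel_basis:
        # Stand-in for librosa_mel_fn(...): any (num_mels, n_fft//2 + 1) matrix.
        _mel_basis[key] = torch.zeros(num_mels, n_fft // 2 + 1, device=device)
    return _mel_basis[key]


basis = get_mel_basis(8000, torch.device("cpu"))
assert basis is get_mel_basis(8000, torch.device("cpu"))  # cached, not rebuilt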
@@ -22,7 +22,9 @@ from matcha.utils.logging_utils import pylogger
 log = pylogger.get_pylogger(__name__)


-def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int):
+def compute_data_statistics(
+    data_loader: torch.utils.data.DataLoader, out_channels: int
+):
     """Generate data mean and standard deviation helpful in data normalisation

     Args:
@@ -42,7 +44,9 @@ def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channe
         total_mel_sq_sum += torch.sum(torch.pow(mels, 2))

     data_mean = total_mel_sum / (total_mel_len * out_channels)
-    data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2))
+    data_std = torch.sqrt(
+        (total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)
+    )

     return {"mel_mean": data_mean.item(), "mel_std": data_std.item()}
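The statistic above uses the moment identity std = sqrt(E[x^2] - (E[x])^2), accumulated over every mel bin. A quick sanity check of the identity on random data:

import torch

x = torch.randn(1000, 80)
mean = x.sum() / x.numel()
std = torch.sqrt(x.pow(2).sum() / x.numel() - mean.pow(2))
assert torch.allclose(std, x.std(unbiased=False), atol=1e-5)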
@@ -82,7 +86,9 @@ def main():
         sys.exit(1)

     with initialize(version_base="1.3", config_path="../../configs/data"):
-        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
+        cfg = compose(
+            config_name=args.input_config, return_hydra_config=True, overrides=[]
+        )

     root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")

@@ -93,8 +99,12 @@ def main():
     cfg["data_statistics"] = None
     cfg["seed"] = 1234
     cfg["batch_size"] = args.batch_size
-    cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
-    cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
+    cfg["train_filelist_path"] = str(
+        os.path.join(root_path, cfg["train_filelist_path"])
+    )
+    cfg["valid_filelist_path"] = str(
+        os.path.join(root_path, cfg["valid_filelist_path"])
+    )
     cfg["load_durations"] = False

     text_mel_datamodule = TextMelDataModule(**cfg)
@@ -29,7 +29,12 @@ log = pylogger.get_pylogger(__name__)


 def save_durations_to_folder(
-    attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
+    attn: torch.Tensor,
+    x_length: int,
+    y_length: int,
+    filepath: str,
+    output_folder: Path,
+    text: str,
 ):
     durations = attn.squeeze().sum(1)[:x_length].numpy()
     durations_json = get_phoneme_durations(durations, text)
@@ -41,7 +46,12 @@ def save_durations_to_folder(


 @torch.inference_mode()
-def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder):
+def compute_durations(
+    data_loader: torch.utils.data.DataLoader,
+    model: nn.Module,
+    device: torch.device,
+    output_folder,
+):
     """Generate durations from the model for each datapoint and save it in a folder

     Args:
@@ -123,13 +133,17 @@ def main():
     )

     parser.add_argument(
-        "--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)"
+        "--cpu",
+        action="store_true",
+        help="Use CPU for inference, not recommended (default: use GPU if available)",
     )

     args = parser.parse_args()

     with initialize(version_base="1.3", config_path="../../configs/data"):
-        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
+        cfg = compose(
+            config_name=args.input_config, return_hydra_config=True, overrides=[]
+        )

     root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")

@@ -138,8 +152,12 @@ def main():
     del cfg["_target_"]
     cfg["seed"] = 1234
     cfg["batch_size"] = args.batch_size
-    cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
-    cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
+    cfg["train_filelist_path"] = str(
+        os.path.join(root_path, cfg["train_filelist_path"])
+    )
+    cfg["valid_filelist_path"] = str(
+        os.path.join(root_path, cfg["valid_filelist_path"])
+    )
     cfg["load_durations"] = False

     if args.output_folder is not None:
@@ -155,7 +173,9 @@ def main():

     output_folder.mkdir(parents=True, exist_ok=True)

-    print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}")
+    print(
+        f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}"
+    )
     print("Loading model...")
     device = get_device(args)
     model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device)
@@ -27,7 +27,9 @@ def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:

     for _, cb_conf in callbacks_cfg.items():
         if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
-            log.info(f"Instantiating callback <{cb_conf._target_}>")  # pylint: disable=protected-access
+            log.info(
+                f"Instantiating callback <{cb_conf._target_}>"
+            )  # pylint: disable=protected-access
             callbacks.append(hydra.utils.instantiate(cb_conf))

     return callbacks
@@ -50,7 +52,9 @@ def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:

     for _, lg_conf in logger_cfg.items():
         if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
-            log.info(f"Instantiating logger <{lg_conf._target_}>")  # pylint: disable=protected-access
+            log.info(
+                f"Instantiating logger <{lg_conf._target_}>"
+            )  # pylint: disable=protected-access
             logger.append(hydra.utils.instantiate(lg_conf))

     return logger
@@ -34,8 +34,12 @@ def log_hyperparameters(object_dict: Dict[str, Any]) -> None:

     # save number of model parameters
     hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
-    hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)
+    hparams["model/params/trainable"] = sum(
+        p.numel() for p in model.parameters() if p.requires_grad
+    )
+    hparams["model/params/non_trainable"] = sum(
+        p.numel() for p in model.parameters() if not p.requires_grad
+    )

     hparams["data"] = cfg["data"]
     hparams["trainer"] = cfg["trainer"]
@@ -36,7 +36,12 @@ def generate_path(duration, mask):
     cum_duration_flat = cum_duration.view(b * t_x)
     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
     path = path.view(b, t_x, t_y)
-    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+    path = (
+        path
+        - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[
+            :, :-1
+        ]
+    )
     path = path * mask
     return path
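The padded-difference expression wrapped above turns cumulative durations into a binary alignment path: each row's cumulative mask minus the previous row's (zero-padded) mask leaves exactly the frames assigned to that token. A small worked example of the same trick, with an inline comparison standing in for `sequence_mask`:

import torch

duration = torch.tensor([[2, 1, 3]])              # frames per input token, (b, t_x)
cum = torch.cumsum(duration, dim=1)               # [[2, 3, 6]]
t_y = int(cum[0, -1])
# Row i is 1 for the first cum[i] frames (the sequence_mask step).
path = (torch.arange(t_y)[None, None, :] < cum[:, :, None]).float()
# Subtract the previous row, zero-padded, so each frame belongs to one token.
path = path - torch.nn.functional.pad(path, (0, 0, 1, 0))[:, :-1]
print(path)  # [[1,1,0,0,0,0], [0,0,1,0,0,0], [0,0,0,1,1,1]]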
@@ -14,7 +14,15 @@ def get_pylogger(name: str = __name__) -> logging.Logger:

     # this ensures all logging levels get marked with the rank zero decorator
     # otherwise logs would get multiplied for each GPU process in multi-GPU setup
-    logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
+    logging_levels = (
+        "debug",
+        "info",
+        "warning",
+        "error",
+        "exception",
+        "fatal",
+        "critical",
+    )
     for level in logging_levels:
         setattr(logger, level, rank_zero_only(getattr(logger, level)))
@@ -47,7 +47,9 @@ def print_config_tree(
         _ = (
             queue.append(field)
             if field in cfg
-            else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...")
+            else log.warning(
+                f"Field '{field}' not found in config. Skipping '{field}' config printing..."
+            )
         )

     # add all the other fields to queue (not specified in `print_order`)
@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, Tuple
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
+
 # from omegaconf import DictConfig

 # from matcha.utils import pylogger, rich_utils
@@ -16,7 +17,7 @@ import torch
 # log = pylogger.get_pylogger(__name__)


-def extras(cfg: 'DictConfig') -> None:
+def extras(cfg: "DictConfig") -> None:
     """Applies optional utilities before the task is started.

     Utilities:
@@ -207,6 +208,7 @@ def get_user_data_dir(appname="matcha_tts"):
 def assert_model_downloaded(checkpoint_path, url, use_wget=True):
     import gdown
     import wget

     if Path(checkpoint_path).exists():
         log.debug(f"[+] Model already present at {checkpoint_path}!")
+        print(f"[+] Model already present at {checkpoint_path}!")