mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 10:16:14 +00:00)

Reformat code

Commit 7994684bf4, parent a67d4b9a80.
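The hunks below are whitespace-only rewraps of over-long lines. The wrapping style is consistent with an auto-formatter such as black at its default 88-character limit, although the exact tool, options, and target directory are not recorded in this commit and are assumptions. A minimal sketch of reproducing this kind of pass on a local checkout (formatter choice and path are hypothetical):

    # Hypothetical reproduction of a formatter pass; "black" and the target
    # directory are assumptions, not taken from this commit.
    import subprocess

    subprocess.run(
        ["black", "--line-length", "88", "path/to/matcha"],  # placeholder path
        check=True,
    )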
@@ -32,7 +32,10 @@ class BaseLightningClass(LightningModule, ABC):
         if self.hparams.scheduler not in (None, {}):
             scheduler_args = {}
             # Manage last epoch for exponential schedulers
-            if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters:
+            if (
+                "last_epoch"
+                in inspect.signature(self.hparams.scheduler.scheduler).parameters
+            ):
                 if hasattr(self, "ckpt_loaded_epoch"):
                     current_epoch = self.ckpt_loaded_epoch - 1
                 else:
@@ -74,7 +77,9 @@ class BaseLightningClass(LightningModule, ABC):
             }

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        self.ckpt_loaded_epoch = checkpoint["epoch"]  # pylint: disable=attribute-defined-outside-init
+        self.ckpt_loaded_epoch = checkpoint[
+            "epoch"
+        ]  # pylint: disable=attribute-defined-outside-init

     def training_step(self, batch: Any, batch_idx: int):
         loss_dict = self.get_losses(batch)
@@ -183,8 +188,14 @@ class BaseLightningClass(LightningModule, ABC):
             for i in range(2):
                 x = one_batch["x"][i].unsqueeze(0).to(self.device)
                 x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device)
-                spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None
-                output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks)
+                spks = (
+                    one_batch["spks"][i].unsqueeze(0).to(self.device)
+                    if one_batch["spks"] is not None
+                    else None
+                )
+                output = self.synthesise(
+                    x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks
+                )
                 y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"]
                 attn = output["attn"]
                 self.logger.experiment.add_image(
@@ -207,4 +218,6 @@ class BaseLightningClass(LightningModule, ABC):
                 )

     def on_before_optimizer_step(self, optimizer):
-        self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()})
+        self.log_dict(
+            {f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()}
+        )
@@ -46,7 +46,9 @@ class Block1D(torch.nn.Module):
 class ResnetBlock1D(torch.nn.Module):
     def __init__(self, dim, dim_out, time_emb_dim, groups=8):
         super().__init__()
-        self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
+        self.mlp = torch.nn.Sequential(
+            nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)
+        )

         self.block1 = Block1D(dim, dim_out, groups=groups)
         self.block2 = Block1D(dim_out, dim_out, groups=groups)
@@ -131,7 +133,14 @@ class Upsample1D(nn.Module):
             number of output channels. Defaults to `channels`.
     """

-    def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
+    def __init__(
+        self,
+        channels,
+        use_conv=False,
+        use_conv_transpose=True,
+        out_channels=None,
+        name="conv",
+    ):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -235,7 +244,9 @@ class Decoder(nn.Module):
             input_channel = output_channel
             output_channel = channels[i]
             is_last = i == len(channels) - 1
-            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = ResnetBlock1D(
+                dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
+            )
             transformer_blocks = nn.ModuleList(
                 [
                     self.get_block(
@@ -250,16 +261,22 @@ class Decoder(nn.Module):
                 ]
             )
             downsample = (
-                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                Downsample1D(output_channel)
+                if not is_last
+                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )

-            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
+            self.down_blocks.append(
+                nn.ModuleList([resnet, transformer_blocks, downsample])
+            )

         for i in range(num_mid_blocks):
             input_channel = channels[-1]
             out_channels = channels[-1]

-            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = ResnetBlock1D(
+                dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
+            )

             transformer_blocks = nn.ModuleList(
                 [
@@ -4,6 +4,7 @@ import torch
 import torch.nn.functional as F

 from matcha.models.components.decoder import Decoder
+
 # from matcha.utils.pylogger import get_pylogger

 # log = get_pylogger(__name__)
@@ -50,7 +51,9 @@ class BASECFM(torch.nn.Module, ABC):
         """
         z = torch.randn_like(mu) * temperature
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
+        return self.solve_euler(
+            z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond
+        )

     def solve_euler(self, x, t_span, mu, mask, spks, cond):
         """
@@ -112,14 +115,22 @@ class BASECFM(torch.nn.Module, ABC):
         y = (1 - (1 - self.sigma_min) * t) * z + t * x1
         u = x1 - (1 - self.sigma_min) * z

-        loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
-            torch.sum(mask) * u.shape[1]
-        )
+        loss = F.mse_loss(
+            self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum"
+        ) / (torch.sum(mask) * u.shape[1])
         return loss, y


 class CFM(BASECFM):
-    def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
+    def __init__(
+        self,
+        in_channels,
+        out_channel,
+        cfm_params,
+        decoder_params,
+        n_spks=1,
+        spk_emb_dim=64,
+    ):
         super().__init__(
             n_feats=in_channels,
             cfm_params=cfm_params,
@@ -129,4 +140,6 @@ class CFM(BASECFM):

         in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
         # Just change the architecture of the estimator here
-        self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
+        self.estimator = Decoder(
+            in_channels=in_channels, out_channels=out_channel, **decoder_params
+        )
@@ -34,7 +34,15 @@ class LayerNorm(nn.Module):


 class ConvReluNorm(nn.Module):
-    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+    def __init__(
+        self,
+        in_channels,
+        hidden_channels,
+        out_channels,
+        kernel_size,
+        n_layers,
+        p_dropout,
+    ):
         super().__init__()
         self.in_channels = in_channels
         self.hidden_channels = hidden_channels
@@ -45,12 +53,23 @@ class ConvReluNorm(nn.Module):

         self.conv_layers = torch.nn.ModuleList()
         self.norm_layers = torch.nn.ModuleList()
-        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+        self.conv_layers.append(
+            torch.nn.Conv1d(
+                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+            )
+        )
         self.norm_layers.append(LayerNorm(hidden_channels))
-        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
+        self.relu_drop = torch.nn.Sequential(
+            torch.nn.ReLU(), torch.nn.Dropout(p_dropout)
+        )
         for _ in range(n_layers - 1):
             self.conv_layers.append(
-                torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
+                torch.nn.Conv1d(
+                    hidden_channels,
+                    hidden_channels,
+                    kernel_size,
+                    padding=kernel_size // 2,
+                )
             )
             self.norm_layers.append(LayerNorm(hidden_channels))
         self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
@@ -75,9 +94,13 @@ class DurationPredictor(nn.Module):
         self.p_dropout = p_dropout

         self.drop = torch.nn.Dropout(p_dropout)
-        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.conv_1 = torch.nn.Conv1d(
+            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
+        )
         self.norm_1 = LayerNorm(filter_channels)
-        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.conv_2 = torch.nn.Conv1d(
+            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
+        )
         self.norm_2 = LayerNorm(filter_channels)
         self.proj = torch.nn.Conv1d(filter_channels, 1, 1)

@@ -128,7 +151,9 @@ class RotaryPositionalEmbeddings(nn.Module):
         seq_len = x.shape[0]

         # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
-        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
+        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(
+            x.device
+        )

         # Create position indexes `[0, 1, ..., seq_len - 1]`
         seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
@@ -167,7 +192,9 @@ class RotaryPositionalEmbeddings(nn.Module):
         # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
         neg_half_x = self._neg_half(x_rope)

-        x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
+        x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (
+            neg_half_x * self.sin_cached[: x.shape[0]]
+        )

         return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")

@@ -236,7 +263,9 @@ class MultiHeadAttention(nn.Module):

         if self.proximal_bias:
             assert t_s == t_t, "Proximal bias is only available for self-attention."
-            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
+            scores = scores + self._attention_bias_proximal(t_s).to(
+                device=scores.device, dtype=scores.dtype
+            )
         if mask is not None:
             scores = scores.masked_fill(mask == 0, -1e4)
         p_attn = torch.nn.functional.softmax(scores, dim=-1)
@@ -253,7 +282,9 @@ class MultiHeadAttention(nn.Module):


 class FFN(nn.Module):
-    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
+    def __init__(
+        self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0
+    ):
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -261,8 +292,12 @@ class FFN(nn.Module):
         self.kernel_size = kernel_size
         self.p_dropout = p_dropout

-        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
-        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
+        self.conv_1 = torch.nn.Conv1d(
+            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
+        )
+        self.conv_2 = torch.nn.Conv1d(
+            filter_channels, out_channels, kernel_size, padding=kernel_size // 2
+        )
         self.drop = torch.nn.Dropout(p_dropout)

     def forward(self, x, x_mask):
@@ -298,7 +333,11 @@ class Encoder(nn.Module):
         self.ffn_layers = torch.nn.ModuleList()
         self.norm_layers_2 = torch.nn.ModuleList()
         for _ in range(self.n_layers):
-            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
+            self.attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
+                )
+            )
             self.norm_layers_1.append(LayerNorm(hidden_channels))
             self.ffn_layers.append(
                 FFN(
@@ -367,7 +406,9 @@ class TextEncoder(nn.Module):
             encoder_params.p_dropout,
         )

-        self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
+        self.proj_m = torch.nn.Conv1d(
+            self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1
+        )
         self.proj_w = DurationPredictor(
             self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
             duration_predictor_params.filter_channels_dp,
@@ -32,7 +32,14 @@ class SnakeBeta(nn.Module):
         >>> x = a1(x)
     """

-    def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        alpha=1.0,
+        alpha_trainable=True,
+        alpha_logscale=True,
+    ):
         """
         Initialization.
         INPUT:
@@ -44,7 +51,9 @@ class SnakeBeta(nn.Module):
                 alpha will be trained along with the rest of your model.
         """
         super().__init__()
-        self.in_features = out_features if isinstance(out_features, list) else [out_features]
+        self.in_features = (
+            out_features if isinstance(out_features, list) else [out_features]
+        )
         self.proj = LoRACompatibleLinear(in_features, out_features)

         # initialize alpha
@@ -75,7 +84,9 @@ class SnakeBeta(nn.Module):
             alpha = self.alpha
             beta = self.beta

-        x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
+        x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
+            torch.sin(x * alpha), 2
+        )

         return x

@@ -176,8 +187,12 @@ class BasicTransformerBlock(nn.Module):
         super().__init__()
         self.only_cross_attention = only_cross_attention

-        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
-        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+        self.use_ada_layer_norm_zero = (
+            num_embeds_ada_norm is not None
+        ) and norm_type == "ada_norm_zero"
+        self.use_ada_layer_norm = (
+            num_embeds_ada_norm is not None
+        ) and norm_type == "ada_norm"

         if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
             raise ValueError(
@@ -215,7 +230,9 @@ class BasicTransformerBlock(nn.Module):
             )
             self.attn2 = Attention(
                 query_dim=dim,
-                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                cross_attention_dim=cross_attention_dim
+                if not double_self_attention
+                else None,
                 heads=num_attention_heads,
                 dim_head=attention_head_dim,
                 dropout=dropout,
@@ -229,7 +246,12 @@ class BasicTransformerBlock(nn.Module):

         # 3. Feed-forward
         self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
-        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+        self.ff = FeedForward(
+            dim,
+            dropout=dropout,
+            activation_fn=activation_fn,
+            final_dropout=final_dropout,
+        )

         # let chunk size default to None
         self._chunk_size = None
@@ -261,12 +283,18 @@ class BasicTransformerBlock(nn.Module):
         else:
             norm_hidden_states = self.norm1(hidden_states)

-        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+        cross_attention_kwargs = (
+            cross_attention_kwargs if cross_attention_kwargs is not None else {}
+        )

         attn_output = self.attn1(
             norm_hidden_states,
-            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
-            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
+            encoder_hidden_states=encoder_hidden_states
+            if self.only_cross_attention
+            else None,
+            attention_mask=encoder_attention_mask
+            if self.only_cross_attention
+            else attention_mask,
             **cross_attention_kwargs,
         )
         if self.use_ada_layer_norm_zero:
@@ -276,7 +304,9 @@ class BasicTransformerBlock(nn.Module):
         # 2. Cross-Attention
         if self.attn2 is not None:
             norm_hidden_states = (
-                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+                self.norm2(hidden_states, timestep)
+                if self.use_ada_layer_norm
+                else self.norm2(hidden_states)
             )

             attn_output = self.attn2(
@@ -291,7 +321,9 @@ class BasicTransformerBlock(nn.Module):
         norm_hidden_states = self.norm3(hidden_states)

         if self.use_ada_layer_norm_zero:
-            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+            norm_hidden_states = (
+                norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+            )

         if self._chunk_size is not None:
             # "feed_forward_chunk_size" can be used to save memory
@@ -302,7 +334,12 @@ class BasicTransformerBlock(nn.Module):

             num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
             ff_output = torch.cat(
-                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
+                [
+                    self.ff(hid_slice)
+                    for hid_slice in norm_hidden_states.chunk(
+                        num_chunks, dim=self._chunk_dim
+                    )
+                ],
                 dim=self._chunk_dim,
             )
         else:
@@ -4,7 +4,9 @@ from matcha.text.symbols import symbols

 # Mappings from symbol to numeric ID and vice versa:
 _symbol_to_id = {s: i for i, s in enumerate(symbols)}
-_id_to_symbol = {i: s for i, s in enumerate(symbols)}  # pylint: disable=unnecessary-comprehension
+_id_to_symbol = {
+    i: s for i, s in enumerate(symbols)
+}  # pylint: disable=unnecessary-comprehension


 def text_to_sequence(text, cleaner_names):
@@ -20,13 +22,15 @@ def text_to_sequence(text, cleaner_names):
     clean_text = _clean_text(text, cleaner_names)
     for symbol in clean_text:
         try:
-            if symbol in '_()[]# ̃':
+            if symbol in "_()[]# ̃":
                 continue
             symbol_id = _symbol_to_id[symbol]
         except Exception as ex:
             print(text)
             print(clean_text)
-            raise RuntimeError(f'text: {text}, clean_text: {clean_text}, ex: {ex}, symbol: {symbol}')
+            raise RuntimeError(
+                f"text: {text}, clean_text: {clean_text}, ex: {ex}, symbol: {symbol}"
+            )
         sequence += [symbol_id]
     return sequence, clean_text

@@ -75,6 +75,7 @@ def lowercase(text):
 def collapse_whitespace(text):
     return re.sub(_whitespace_re, " ", text)

+
 def remove_parentheses(text):
     text = text.replace("(", "")
     text = text.replace(")", "")
@@ -56,7 +56,9 @@ def _expand_number(m):
         elif num % 100 == 0:
             return _inflect.number_to_words(num // 100) + " hundred"
         else:
-            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
+            return _inflect.number_to_words(
+                num, andword="", zero="oh", group=2
+            ).replace(", ", " ")
     else:
         return _inflect.number_to_words(num, andword="")

@@ -5,9 +5,7 @@ Defines the set of symbols used in text input to the model.
 _pad = "_"
 _punctuation = ';:,.!?¡¿—…"«»“” '
 _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
-_letters_ipa = (
-    "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
-)
+_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"


 # Export all symbols:
@@ -42,7 +42,9 @@ mel_basis = {}
 hann_window = {}


-def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+def mel_spectrogram(
+    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
+):
     if torch.min(y) < -1.0:
         print("min value is ", torch.min(y))
     if torch.max(y) > 1.0:
@@ -50,12 +52,18 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,

     global mel_basis, hann_window  # pylint: disable=global-statement
     if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
-        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
-        mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+        mel = librosa_mel_fn(
+            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
+        )
+        mel_basis[str(fmax) + "_" + str(y.device)] = (
+            torch.from_numpy(mel).float().to(y.device)
+        )
         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

     y = torch.nn.functional.pad(
-        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+        y.unsqueeze(1),
+        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+        mode="reflect",
     )
     y = y.squeeze(1)

@@ -22,7 +22,9 @@ from matcha.utils.logging_utils import pylogger
 log = pylogger.get_pylogger(__name__)


-def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int):
+def compute_data_statistics(
+    data_loader: torch.utils.data.DataLoader, out_channels: int
+):
     """Generate data mean and standard deviation helpful in data normalisation

     Args:
@@ -42,7 +44,9 @@ def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channe
         total_mel_sq_sum += torch.sum(torch.pow(mels, 2))

     data_mean = total_mel_sum / (total_mel_len * out_channels)
-    data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2))
+    data_std = torch.sqrt(
+        (total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)
+    )

     return {"mel_mean": data_mean.item(), "mel_std": data_std.item()}

@@ -82,7 +86,9 @@ def main():
         sys.exit(1)

     with initialize(version_base="1.3", config_path="../../configs/data"):
-        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
+        cfg = compose(
+            config_name=args.input_config, return_hydra_config=True, overrides=[]
+        )

     root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")

@@ -93,8 +99,12 @@ def main():
     cfg["data_statistics"] = None
     cfg["seed"] = 1234
     cfg["batch_size"] = args.batch_size
-    cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
-    cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
+    cfg["train_filelist_path"] = str(
+        os.path.join(root_path, cfg["train_filelist_path"])
+    )
+    cfg["valid_filelist_path"] = str(
+        os.path.join(root_path, cfg["valid_filelist_path"])
+    )
     cfg["load_durations"] = False

     text_mel_datamodule = TextMelDataModule(**cfg)
@@ -29,7 +29,12 @@ log = pylogger.get_pylogger(__name__)


 def save_durations_to_folder(
-    attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
+    attn: torch.Tensor,
+    x_length: int,
+    y_length: int,
+    filepath: str,
+    output_folder: Path,
+    text: str,
 ):
     durations = attn.squeeze().sum(1)[:x_length].numpy()
     durations_json = get_phoneme_durations(durations, text)
@@ -41,7 +46,12 @@ def save_durations_to_folder(


 @torch.inference_mode()
-def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder):
+def compute_durations(
+    data_loader: torch.utils.data.DataLoader,
+    model: nn.Module,
+    device: torch.device,
+    output_folder,
+):
     """Generate durations from the model for each datapoint and save it in a folder

     Args:
@@ -123,13 +133,17 @@ def main():
     )

     parser.add_argument(
-        "--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)"
+        "--cpu",
+        action="store_true",
+        help="Use CPU for inference, not recommended (default: use GPU if available)",
     )

     args = parser.parse_args()

     with initialize(version_base="1.3", config_path="../../configs/data"):
-        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
+        cfg = compose(
+            config_name=args.input_config, return_hydra_config=True, overrides=[]
+        )

     root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")

@@ -138,8 +152,12 @@ def main():
     del cfg["_target_"]
     cfg["seed"] = 1234
     cfg["batch_size"] = args.batch_size
-    cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
-    cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
+    cfg["train_filelist_path"] = str(
+        os.path.join(root_path, cfg["train_filelist_path"])
+    )
+    cfg["valid_filelist_path"] = str(
+        os.path.join(root_path, cfg["valid_filelist_path"])
+    )
     cfg["load_durations"] = False

     if args.output_folder is not None:
@@ -155,7 +173,9 @@ def main():

     output_folder.mkdir(parents=True, exist_ok=True)

-    print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}")
+    print(
+        f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}"
+    )
     print("Loading model...")
     device = get_device(args)
     model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device)
@@ -27,7 +27,9 @@ def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:

     for _, cb_conf in callbacks_cfg.items():
         if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
-            log.info(f"Instantiating callback <{cb_conf._target_}>")  # pylint: disable=protected-access
+            log.info(
+                f"Instantiating callback <{cb_conf._target_}>"
+            )  # pylint: disable=protected-access
             callbacks.append(hydra.utils.instantiate(cb_conf))

     return callbacks
@@ -50,7 +52,9 @@ def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:

     for _, lg_conf in logger_cfg.items():
         if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
-            log.info(f"Instantiating logger <{lg_conf._target_}>")  # pylint: disable=protected-access
+            log.info(
+                f"Instantiating logger <{lg_conf._target_}>"
+            )  # pylint: disable=protected-access
             logger.append(hydra.utils.instantiate(lg_conf))

     return logger
@@ -34,8 +34,12 @@ def log_hyperparameters(object_dict: Dict[str, Any]) -> None:

     # save number of model parameters
     hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
-    hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)
+    hparams["model/params/trainable"] = sum(
+        p.numel() for p in model.parameters() if p.requires_grad
+    )
+    hparams["model/params/non_trainable"] = sum(
+        p.numel() for p in model.parameters() if not p.requires_grad
+    )

     hparams["data"] = cfg["data"]
     hparams["trainer"] = cfg["trainer"]
@@ -36,7 +36,12 @@ def generate_path(duration, mask):
     cum_duration_flat = cum_duration.view(b * t_x)
     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
     path = path.view(b, t_x, t_y)
-    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+    path = (
+        path
+        - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[
+            :, :-1
+        ]
+    )
     path = path * mask
     return path

@@ -14,7 +14,15 @@ def get_pylogger(name: str = __name__) -> logging.Logger:

     # this ensures all logging levels get marked with the rank zero decorator
     # otherwise logs would get multiplied for each GPU process in multi-GPU setup
-    logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
+    logging_levels = (
+        "debug",
+        "info",
+        "warning",
+        "error",
+        "exception",
+        "fatal",
+        "critical",
+    )
     for level in logging_levels:
         setattr(logger, level, rank_zero_only(getattr(logger, level)))

@@ -47,7 +47,9 @@ def print_config_tree(
         _ = (
             queue.append(field)
             if field in cfg
-            else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...")
+            else log.warning(
+                f"Field '{field}' not found in config. Skipping '{field}' config printing..."
+            )
         )

     # add all the other fields to queue (not specified in `print_order`)
@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, Tuple
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
+
 # from omegaconf import DictConfig

 # from matcha.utils import pylogger, rich_utils
@@ -16,7 +17,7 @@ import torch
 # log = pylogger.get_pylogger(__name__)


-def extras(cfg: 'DictConfig') -> None:
+def extras(cfg: "DictConfig") -> None:
     """Applies optional utilities before the task is started.

     Utilities:
@@ -207,6 +208,7 @@ def get_user_data_dir(appname="matcha_tts"):
 def assert_model_downloaded(checkpoint_path, url, use_wget=True):
     import gdown
     import wget
+
     if Path(checkpoint_path).exists():
         log.debug(f"[+] Model already present at {checkpoint_path}!")
         print(f"[+] Model already present at {checkpoint_path}!")