Reformat code

Fangjun Kuang 2024-10-28 19:06:44 +08:00
parent a67d4b9a80
commit 7994684bf4
18 changed files with 268 additions and 79 deletions
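The changes below all rewrap long lines, which is consistent with a Black-style formatter at its default 88-character limit; the commit message does not name the tool, so that is an assumption. A minimal sketch of reproducing one such rewrap on a line similar to the ones touched here (assuming the `black` package is installed):

# Minimal sketch, assuming the reformatting came from a Black-style formatter
# (the commit message says only "Reformat code").
import black  # assumed dependency: pip install black

# A 93-character line similar to those touched by this commit.
src = (
    "conv_1 = torch.nn.Conv1d(in_channels, filter_channels, "
    "kernel_size, padding=kernel_size // 2)\n"
)

# Black's default line length is 88, so the call gets split across lines,
# matching the style of the hunks below.
print(black.format_str(src, mode=black.Mode(line_length=88)))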

View File

@@ -32,7 +32,10 @@ class BaseLightningClass(LightningModule, ABC):
         if self.hparams.scheduler not in (None, {}):
             scheduler_args = {}
             # Manage last epoch for exponential schedulers
-            if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters:
+            if (
+                "last_epoch"
+                in inspect.signature(self.hparams.scheduler.scheduler).parameters
+            ):
                 if hasattr(self, "ckpt_loaded_epoch"):
                     current_epoch = self.ckpt_loaded_epoch - 1
                 else:
@@ -74,7 +77,9 @@ class BaseLightningClass(LightningModule, ABC):
             }

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        self.ckpt_loaded_epoch = checkpoint["epoch"]  # pylint: disable=attribute-defined-outside-init
+        self.ckpt_loaded_epoch = checkpoint[
+            "epoch"
+        ]  # pylint: disable=attribute-defined-outside-init

     def training_step(self, batch: Any, batch_idx: int):
         loss_dict = self.get_losses(batch)
@@ -183,8 +188,14 @@ class BaseLightningClass(LightningModule, ABC):
             for i in range(2):
                 x = one_batch["x"][i].unsqueeze(0).to(self.device)
                 x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device)
-                spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None
-                output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks)
+                spks = (
+                    one_batch["spks"][i].unsqueeze(0).to(self.device)
+                    if one_batch["spks"] is not None
+                    else None
+                )
+                output = self.synthesise(
+                    x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks
+                )
                 y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"]
                 attn = output["attn"]
                 self.logger.experiment.add_image(
@@ -207,4 +218,6 @@ class BaseLightningClass(LightningModule, ABC):
                 )

     def on_before_optimizer_step(self, optimizer):
-        self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()})
+        self.log_dict(
+            {f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()}
+        )

View File

@@ -46,7 +46,9 @@ class Block1D(torch.nn.Module):
 class ResnetBlock1D(torch.nn.Module):
     def __init__(self, dim, dim_out, time_emb_dim, groups=8):
         super().__init__()
-        self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
+        self.mlp = torch.nn.Sequential(
+            nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)
+        )

         self.block1 = Block1D(dim, dim_out, groups=groups)
         self.block2 = Block1D(dim_out, dim_out, groups=groups)
@@ -131,7 +133,14 @@ class Upsample1D(nn.Module):
             number of output channels. Defaults to `channels`.
     """

-    def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
+    def __init__(
+        self,
+        channels,
+        use_conv=False,
+        use_conv_transpose=True,
+        out_channels=None,
+        name="conv",
+    ):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -235,7 +244,9 @@ class Decoder(nn.Module):
             input_channel = output_channel
             output_channel = channels[i]
             is_last = i == len(channels) - 1
-            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = ResnetBlock1D(
+                dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
+            )
             transformer_blocks = nn.ModuleList(
                 [
                     self.get_block(
@@ -250,16 +261,22 @@ class Decoder(nn.Module):
                 ]
             )
             downsample = (
-                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                Downsample1D(output_channel)
+                if not is_last
+                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )

-            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
+            self.down_blocks.append(
+                nn.ModuleList([resnet, transformer_blocks, downsample])
+            )

         for i in range(num_mid_blocks):
             input_channel = channels[-1]
             out_channels = channels[-1]

-            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = ResnetBlock1D(
+                dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
+            )
             transformer_blocks = nn.ModuleList(
                 [

View File

@@ -4,6 +4,7 @@ import torch
 import torch.nn.functional as F

 from matcha.models.components.decoder import Decoder
+
 # from matcha.utils.pylogger import get_pylogger

 # log = get_pylogger(__name__)
@@ -50,7 +51,9 @@ class BASECFM(torch.nn.Module, ABC):
         """
         z = torch.randn_like(mu) * temperature
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
+        return self.solve_euler(
+            z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond
+        )

     def solve_euler(self, x, t_span, mu, mask, spks, cond):
         """
@@ -112,14 +115,22 @@ class BASECFM(torch.nn.Module, ABC):
         y = (1 - (1 - self.sigma_min) * t) * z + t * x1
         u = x1 - (1 - self.sigma_min) * z

-        loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
-            torch.sum(mask) * u.shape[1]
-        )
+        loss = F.mse_loss(
+            self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum"
+        ) / (torch.sum(mask) * u.shape[1])
         return loss, y


 class CFM(BASECFM):
-    def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
+    def __init__(
+        self,
+        in_channels,
+        out_channel,
+        cfm_params,
+        decoder_params,
+        n_spks=1,
+        spk_emb_dim=64,
+    ):
         super().__init__(
             n_feats=in_channels,
             cfm_params=cfm_params,
@@ -129,4 +140,6 @@ class CFM(BASECFM):
         in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)

         # Just change the architecture of the estimator here
-        self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
+        self.estimator = Decoder(
+            in_channels=in_channels, out_channels=out_channel, **decoder_params
+        )

View File

@@ -34,7 +34,15 @@ class LayerNorm(nn.Module):


 class ConvReluNorm(nn.Module):
-    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+    def __init__(
+        self,
+        in_channels,
+        hidden_channels,
+        out_channels,
+        kernel_size,
+        n_layers,
+        p_dropout,
+    ):
         super().__init__()
         self.in_channels = in_channels
         self.hidden_channels = hidden_channels
@@ -45,12 +53,23 @@ class ConvReluNorm(nn.Module):

         self.conv_layers = torch.nn.ModuleList()
         self.norm_layers = torch.nn.ModuleList()
-        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+        self.conv_layers.append(
+            torch.nn.Conv1d(
+                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+            )
+        )
         self.norm_layers.append(LayerNorm(hidden_channels))
-        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
+        self.relu_drop = torch.nn.Sequential(
+            torch.nn.ReLU(), torch.nn.Dropout(p_dropout)
+        )
         for _ in range(n_layers - 1):
             self.conv_layers.append(
-                torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
+                torch.nn.Conv1d(
+                    hidden_channels,
+                    hidden_channels,
+                    kernel_size,
+                    padding=kernel_size // 2,
+                )
             )
             self.norm_layers.append(LayerNorm(hidden_channels))
         self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
@@ -75,9 +94,13 @@ class DurationPredictor(nn.Module):
         self.p_dropout = p_dropout

         self.drop = torch.nn.Dropout(p_dropout)
-        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.conv_1 = torch.nn.Conv1d(
+            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
+        )
         self.norm_1 = LayerNorm(filter_channels)
-        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.conv_2 = torch.nn.Conv1d(
+            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
+        )
         self.norm_2 = LayerNorm(filter_channels)
         self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
@@ -128,7 +151,9 @@ class RotaryPositionalEmbeddings(nn.Module):
         seq_len = x.shape[0]

         # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
-        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
+        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(
+            x.device
+        )

         # Create position indexes `[0, 1, ..., seq_len - 1]`
         seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
@@ -167,7 +192,9 @@ class RotaryPositionalEmbeddings(nn.Module):
         # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
         neg_half_x = self._neg_half(x_rope)

-        x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
+        x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (
+            neg_half_x * self.sin_cached[: x.shape[0]]
+        )

         return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")
@@ -236,7 +263,9 @@ class MultiHeadAttention(nn.Module):

         if self.proximal_bias:
             assert t_s == t_t, "Proximal bias is only available for self-attention."
-            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
+            scores = scores + self._attention_bias_proximal(t_s).to(
+                device=scores.device, dtype=scores.dtype
+            )
         if mask is not None:
             scores = scores.masked_fill(mask == 0, -1e4)
         p_attn = torch.nn.functional.softmax(scores, dim=-1)
@@ -253,7 +282,9 @@ class MultiHeadAttention(nn.Module):


 class FFN(nn.Module):
-    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
+    def __init__(
+        self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0
+    ):
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -261,8 +292,12 @@ class FFN(nn.Module):
         self.kernel_size = kernel_size
         self.p_dropout = p_dropout

-        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
-        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
+        self.conv_1 = torch.nn.Conv1d(
+            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
+        )
+        self.conv_2 = torch.nn.Conv1d(
+            filter_channels, out_channels, kernel_size, padding=kernel_size // 2
+        )
         self.drop = torch.nn.Dropout(p_dropout)

     def forward(self, x, x_mask):
@@ -298,7 +333,11 @@ class Encoder(nn.Module):
         self.ffn_layers = torch.nn.ModuleList()
         self.norm_layers_2 = torch.nn.ModuleList()
         for _ in range(self.n_layers):
-            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
+            self.attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
+                )
+            )
             self.norm_layers_1.append(LayerNorm(hidden_channels))
             self.ffn_layers.append(
                 FFN(
@@ -367,7 +406,9 @@ class TextEncoder(nn.Module):
             encoder_params.p_dropout,
         )

-        self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
+        self.proj_m = torch.nn.Conv1d(
+            self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1
+        )
         self.proj_w = DurationPredictor(
             self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
             duration_predictor_params.filter_channels_dp,

View File

@@ -32,7 +32,14 @@ class SnakeBeta(nn.Module):
         >>> x = a1(x)
     """

-    def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        alpha=1.0,
+        alpha_trainable=True,
+        alpha_logscale=True,
+    ):
         """
         Initialization.
         INPUT:
@@ -44,7 +51,9 @@ class SnakeBeta(nn.Module):
                 alpha will be trained along with the rest of your model.
         """
         super().__init__()
-        self.in_features = out_features if isinstance(out_features, list) else [out_features]
+        self.in_features = (
+            out_features if isinstance(out_features, list) else [out_features]
+        )
         self.proj = LoRACompatibleLinear(in_features, out_features)

         # initialize alpha
@@ -75,7 +84,9 @@ class SnakeBeta(nn.Module):
             alpha = self.alpha
             beta = self.beta

-        x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
+        x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
+            torch.sin(x * alpha), 2
+        )

         return x
@@ -176,8 +187,12 @@ class BasicTransformerBlock(nn.Module):
         super().__init__()
         self.only_cross_attention = only_cross_attention

-        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
-        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+        self.use_ada_layer_norm_zero = (
+            num_embeds_ada_norm is not None
+        ) and norm_type == "ada_norm_zero"
+        self.use_ada_layer_norm = (
+            num_embeds_ada_norm is not None
+        ) and norm_type == "ada_norm"

         if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
             raise ValueError(
@@ -215,7 +230,9 @@ class BasicTransformerBlock(nn.Module):
             )
             self.attn2 = Attention(
                 query_dim=dim,
-                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                cross_attention_dim=cross_attention_dim
+                if not double_self_attention
+                else None,
                 heads=num_attention_heads,
                 dim_head=attention_head_dim,
                 dropout=dropout,
@@ -229,7 +246,12 @@ class BasicTransformerBlock(nn.Module):

         # 3. Feed-forward
         self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
-        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+        self.ff = FeedForward(
+            dim,
+            dropout=dropout,
+            activation_fn=activation_fn,
+            final_dropout=final_dropout,
+        )

         # let chunk size default to None
         self._chunk_size = None
@@ -261,12 +283,18 @@ class BasicTransformerBlock(nn.Module):
         else:
             norm_hidden_states = self.norm1(hidden_states)

-        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+        cross_attention_kwargs = (
+            cross_attention_kwargs if cross_attention_kwargs is not None else {}
+        )

         attn_output = self.attn1(
             norm_hidden_states,
-            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
-            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
+            encoder_hidden_states=encoder_hidden_states
+            if self.only_cross_attention
+            else None,
+            attention_mask=encoder_attention_mask
+            if self.only_cross_attention
+            else attention_mask,
             **cross_attention_kwargs,
         )
         if self.use_ada_layer_norm_zero:
@@ -276,7 +304,9 @@ class BasicTransformerBlock(nn.Module):
         # 2. Cross-Attention
         if self.attn2 is not None:
             norm_hidden_states = (
-                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+                self.norm2(hidden_states, timestep)
+                if self.use_ada_layer_norm
+                else self.norm2(hidden_states)
             )

             attn_output = self.attn2(
@@ -291,7 +321,9 @@ class BasicTransformerBlock(nn.Module):
         norm_hidden_states = self.norm3(hidden_states)

         if self.use_ada_layer_norm_zero:
-            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+            norm_hidden_states = (
+                norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+            )

         if self._chunk_size is not None:
             # "feed_forward_chunk_size" can be used to save memory
@@ -302,7 +334,12 @@ class BasicTransformerBlock(nn.Module):

             num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
             ff_output = torch.cat(
-                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
+                [
+                    self.ff(hid_slice)
+                    for hid_slice in norm_hidden_states.chunk(
+                        num_chunks, dim=self._chunk_dim
+                    )
+                ],
                 dim=self._chunk_dim,
             )
         else:

View File

@@ -4,7 +4,9 @@ from matcha.text.symbols import symbols

 # Mappings from symbol to numeric ID and vice versa:
 _symbol_to_id = {s: i for i, s in enumerate(symbols)}
-_id_to_symbol = {i: s for i, s in enumerate(symbols)}  # pylint: disable=unnecessary-comprehension
+_id_to_symbol = {
+    i: s for i, s in enumerate(symbols)
+}  # pylint: disable=unnecessary-comprehension


 def text_to_sequence(text, cleaner_names):
@@ -20,13 +22,15 @@ def text_to_sequence(text, cleaner_names):
     clean_text = _clean_text(text, cleaner_names)
     for symbol in clean_text:
         try:
-            if symbol in '_()[]# ̃':
+            if symbol in "_()[]# ̃":
                 continue
             symbol_id = _symbol_to_id[symbol]
         except Exception as ex:
             print(text)
             print(clean_text)
-            raise RuntimeError(f'text: {text}, clean_text: {clean_text}, ex: {ex}, symbol: {symbol}')
+            raise RuntimeError(
+                f"text: {text}, clean_text: {clean_text}, ex: {ex}, symbol: {symbol}"
+            )
         sequence += [symbol_id]

     return sequence, clean_text

View File

@@ -75,11 +75,12 @@ def lowercase(text):

 def collapse_whitespace(text):
     return re.sub(_whitespace_re, " ", text)

+
 def remove_parentheses(text):
     text = text.replace("(", "")
     text = text.replace(")", "")
     text = text.replace("[", "")
     text = text.replace("]", "")
     return text

View File

@@ -56,7 +56,9 @@ def _expand_number(m):
         elif num % 100 == 0:
             return _inflect.number_to_words(num // 100) + " hundred"
         else:
-            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
+            return _inflect.number_to_words(
+                num, andword="", zero="oh", group=2
+            ).replace(", ", " ")
     else:
         return _inflect.number_to_words(num, andword="")

View File

@@ -5,9 +5,7 @@ Defines the set of symbols used in text input to the model.
 _pad = "_"
 _punctuation = ';:,.!?¡¿—…"«»“” '
 _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
-_letters_ipa = (
-    "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
-)
+_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"


 # Export all symbols:

View File

@@ -42,7 +42,9 @@ mel_basis = {}
 hann_window = {}


-def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+def mel_spectrogram(
+    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
+):
     if torch.min(y) < -1.0:
         print("min value is ", torch.min(y))
     if torch.max(y) > 1.0:
@@ -50,12 +52,18 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,

     global mel_basis, hann_window  # pylint: disable=global-statement
     if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
-        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
-        mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+        mel = librosa_mel_fn(
+            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
+        )
+        mel_basis[str(fmax) + "_" + str(y.device)] = (
+            torch.from_numpy(mel).float().to(y.device)
+        )
         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

     y = torch.nn.functional.pad(
-        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+        y.unsqueeze(1),
+        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+        mode="reflect",
     )
     y = y.squeeze(1)

View File

@@ -22,7 +22,9 @@ from matcha.utils.logging_utils import pylogger
 log = pylogger.get_pylogger(__name__)


-def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int):
+def compute_data_statistics(
+    data_loader: torch.utils.data.DataLoader, out_channels: int
+):
     """Generate data mean and standard deviation helpful in data normalisation

     Args:
@@ -42,7 +44,9 @@ def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channe
         total_mel_sq_sum += torch.sum(torch.pow(mels, 2))

     data_mean = total_mel_sum / (total_mel_len * out_channels)
-    data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2))
+    data_std = torch.sqrt(
+        (total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)
+    )

     return {"mel_mean": data_mean.item(), "mel_std": data_std.item()}
@@ -82,7 +86,9 @@ def main():
         sys.exit(1)

     with initialize(version_base="1.3", config_path="../../configs/data"):
-        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
+        cfg = compose(
+            config_name=args.input_config, return_hydra_config=True, overrides=[]
+        )

     root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")
@@ -93,8 +99,12 @@ def main():
     cfg["data_statistics"] = None
     cfg["seed"] = 1234
     cfg["batch_size"] = args.batch_size
-    cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
-    cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
+    cfg["train_filelist_path"] = str(
+        os.path.join(root_path, cfg["train_filelist_path"])
+    )
+    cfg["valid_filelist_path"] = str(
+        os.path.join(root_path, cfg["valid_filelist_path"])
+    )
     cfg["load_durations"] = False

     text_mel_datamodule = TextMelDataModule(**cfg)

View File

@@ -29,7 +29,12 @@ log = pylogger.get_pylogger(__name__)


 def save_durations_to_folder(
-    attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
+    attn: torch.Tensor,
+    x_length: int,
+    y_length: int,
+    filepath: str,
+    output_folder: Path,
+    text: str,
 ):
     durations = attn.squeeze().sum(1)[:x_length].numpy()
     durations_json = get_phoneme_durations(durations, text)
@@ -41,7 +46,12 @@ def save_durations_to_folder(


 @torch.inference_mode()
-def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder):
+def compute_durations(
+    data_loader: torch.utils.data.DataLoader,
+    model: nn.Module,
+    device: torch.device,
+    output_folder,
+):
     """Generate durations from the model for each datapoint and save it in a folder

     Args:
@@ -123,13 +133,17 @@ def main():
     )
     parser.add_argument(
-        "--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)"
+        "--cpu",
+        action="store_true",
+        help="Use CPU for inference, not recommended (default: use GPU if available)",
     )

     args = parser.parse_args()

     with initialize(version_base="1.3", config_path="../../configs/data"):
-        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
+        cfg = compose(
+            config_name=args.input_config, return_hydra_config=True, overrides=[]
+        )

     root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")
@@ -138,8 +152,12 @@ def main():
     del cfg["_target_"]
     cfg["seed"] = 1234
     cfg["batch_size"] = args.batch_size
-    cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
-    cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
+    cfg["train_filelist_path"] = str(
+        os.path.join(root_path, cfg["train_filelist_path"])
+    )
+    cfg["valid_filelist_path"] = str(
+        os.path.join(root_path, cfg["valid_filelist_path"])
+    )
     cfg["load_durations"] = False

     if args.output_folder is not None:
@@ -155,7 +173,9 @@ def main():
     output_folder.mkdir(parents=True, exist_ok=True)

-    print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}")
+    print(
+        f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}"
+    )

     print("Loading model...")
     device = get_device(args)
     model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device)

View File

@@ -27,7 +27,9 @@ def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:

     for _, cb_conf in callbacks_cfg.items():
         if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
-            log.info(f"Instantiating callback <{cb_conf._target_}>")  # pylint: disable=protected-access
+            log.info(
+                f"Instantiating callback <{cb_conf._target_}>"
+            )  # pylint: disable=protected-access
             callbacks.append(hydra.utils.instantiate(cb_conf))

     return callbacks
@@ -50,7 +52,9 @@ def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:

     for _, lg_conf in logger_cfg.items():
         if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
-            log.info(f"Instantiating logger <{lg_conf._target_}>")  # pylint: disable=protected-access
+            log.info(
+                f"Instantiating logger <{lg_conf._target_}>"
+            )  # pylint: disable=protected-access
             logger.append(hydra.utils.instantiate(lg_conf))

     return logger

View File

@@ -34,8 +34,12 @@ def log_hyperparameters(object_dict: Dict[str, Any]) -> None:

     # save number of model parameters
     hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
-    hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)
+    hparams["model/params/trainable"] = sum(
+        p.numel() for p in model.parameters() if p.requires_grad
+    )
+    hparams["model/params/non_trainable"] = sum(
+        p.numel() for p in model.parameters() if not p.requires_grad
+    )

     hparams["data"] = cfg["data"]
     hparams["trainer"] = cfg["trainer"]

View File

@@ -36,7 +36,12 @@ def generate_path(duration, mask):
     cum_duration_flat = cum_duration.view(b * t_x)
     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
     path = path.view(b, t_x, t_y)
-    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+    path = (
+        path
+        - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[
+            :, :-1
+        ]
+    )
     path = path * mask
     return path

View File

@@ -14,7 +14,15 @@ def get_pylogger(name: str = __name__) -> logging.Logger:

     # this ensures all logging levels get marked with the rank zero decorator
     # otherwise logs would get multiplied for each GPU process in multi-GPU setup
-    logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
+    logging_levels = (
+        "debug",
+        "info",
+        "warning",
+        "error",
+        "exception",
+        "fatal",
+        "critical",
+    )
     for level in logging_levels:
         setattr(logger, level, rank_zero_only(getattr(logger, level)))

View File

@@ -47,7 +47,9 @@ def print_config_tree(
         _ = (
             queue.append(field)
             if field in cfg
-            else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...")
+            else log.warning(
+                f"Field '{field}' not found in config. Skipping '{field}' config printing..."
+            )
         )

     # add all the other fields to queue (not specified in `print_order`)

View File

@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, Tuple
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
+
 # from omegaconf import DictConfig

 # from matcha.utils import pylogger, rich_utils
@@ -16,7 +17,7 @@ import torch
 # log = pylogger.get_pylogger(__name__)


-def extras(cfg: 'DictConfig') -> None:
+def extras(cfg: "DictConfig") -> None:
     """Applies optional utilities before the task is started.

     Utilities:
@@ -207,6 +208,7 @@ def get_user_data_dir(appname="matcha_tts"):
 def assert_model_downloaded(checkpoint_path, url, use_wget=True):
     import gdown
     import wget
+
     if Path(checkpoint_path).exists():
         log.debug(f"[+] Model already present at {checkpoint_path}!")
         print(f"[+] Model already present at {checkpoint_path}!")