diff --git a/egs/ljspeech/TTS/matcha/models/baselightningmodule.py b/egs/ljspeech/TTS/matcha/models/baselightningmodule.py
index f8abe7b44..e80d2a5c9 100644
--- a/egs/ljspeech/TTS/matcha/models/baselightningmodule.py
+++ b/egs/ljspeech/TTS/matcha/models/baselightningmodule.py
@@ -32,7 +32,10 @@ class BaseLightningClass(LightningModule, ABC):
         if self.hparams.scheduler not in (None, {}):
             scheduler_args = {}
             # Manage last epoch for exponential schedulers
-            if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters:
+            if (
+                "last_epoch"
+                in inspect.signature(self.hparams.scheduler.scheduler).parameters
+            ):
                 if hasattr(self, "ckpt_loaded_epoch"):
                     current_epoch = self.ckpt_loaded_epoch - 1
                 else:
@@ -74,7 +77,9 @@ class BaseLightningClass(LightningModule, ABC):
         }
 
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        self.ckpt_loaded_epoch = checkpoint["epoch"]  # pylint: disable=attribute-defined-outside-init
+        self.ckpt_loaded_epoch = checkpoint[
+            "epoch"
+        ]  # pylint: disable=attribute-defined-outside-init
 
     def training_step(self, batch: Any, batch_idx: int):
         loss_dict = self.get_losses(batch)
@@ -183,8 +188,14 @@ class BaseLightningClass(LightningModule, ABC):
             for i in range(2):
                 x = one_batch["x"][i].unsqueeze(0).to(self.device)
                 x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device)
-                spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None
-                output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks)
+                spks = (
+                    one_batch["spks"][i].unsqueeze(0).to(self.device)
+                    if one_batch["spks"] is not None
+                    else None
+                )
+                output = self.synthesise(
+                    x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks
+                )
                 y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"]
                 attn = output["attn"]
                 self.logger.experiment.add_image(
@@ -207,4 +218,6 @@ class BaseLightningClass(LightningModule, ABC):
             )
 
     def on_before_optimizer_step(self, optimizer):
-        self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()})
+        self.log_dict(
+            {f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()}
+        )
diff --git a/egs/ljspeech/TTS/matcha/models/components/decoder.py b/egs/ljspeech/TTS/matcha/models/components/decoder.py
index 1137cd700..5850f2639 100644
--- a/egs/ljspeech/TTS/matcha/models/components/decoder.py
+++ b/egs/ljspeech/TTS/matcha/models/components/decoder.py
@@ -46,7 +46,9 @@ class Block1D(torch.nn.Module):
 class ResnetBlock1D(torch.nn.Module):
     def __init__(self, dim, dim_out, time_emb_dim, groups=8):
         super().__init__()
-        self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
+        self.mlp = torch.nn.Sequential(
+            nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)
+        )
 
         self.block1 = Block1D(dim, dim_out, groups=groups)
         self.block2 = Block1D(dim_out, dim_out, groups=groups)
@@ -131,7 +133,14 @@ class Upsample1D(nn.Module):
             number of output channels. Defaults to `channels`.
     """
 
-    def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
+    def __init__(
+        self,
+        channels,
+        use_conv=False,
+        use_conv_transpose=True,
+        out_channels=None,
+        name="conv",
+    ):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -235,7 +244,9 @@ class Decoder(nn.Module):
             input_channel = output_channel
             output_channel = channels[i]
             is_last = i == len(channels) - 1
-            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = ResnetBlock1D(
+                dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
+            )
             transformer_blocks = nn.ModuleList(
                 [
                     self.get_block(
@@ -250,16 +261,22 @@ class Decoder(nn.Module):
                 ]
             )
             downsample = (
-                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                Downsample1D(output_channel)
+                if not is_last
+                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
 
-            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
+            self.down_blocks.append(
+                nn.ModuleList([resnet, transformer_blocks, downsample])
+            )
 
         for i in range(num_mid_blocks):
             input_channel = channels[-1]
             out_channels = channels[-1]
 
-            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = ResnetBlock1D(
+                dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
+            )
 
             transformer_blocks = nn.ModuleList(
                 [
diff --git a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py
index 552c4b383..5a7226b4f 100644
--- a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py
+++ b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py
@@ -4,6 +4,7 @@ import torch
 import torch.nn.functional as F
 
 from matcha.models.components.decoder import Decoder
+
 # from matcha.utils.pylogger import get_pylogger
 
 # log = get_pylogger(__name__)
@@ -50,7 +51,9 @@ class BASECFM(torch.nn.Module, ABC):
         """
         z = torch.randn_like(mu) * temperature
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
+        return self.solve_euler(
+            z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond
+        )
 
     def solve_euler(self, x, t_span, mu, mask, spks, cond):
         """
@@ -112,14 +115,22 @@ class BASECFM(torch.nn.Module, ABC):
         y = (1 - (1 - self.sigma_min) * t) * z + t * x1
         u = x1 - (1 - self.sigma_min) * z
 
-        loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
-            torch.sum(mask) * u.shape[1]
-        )
+        loss = F.mse_loss(
+            self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum"
+        ) / (torch.sum(mask) * u.shape[1])
         return loss, y
 
 
 class CFM(BASECFM):
-    def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
+    def __init__(
+        self,
+        in_channels,
+        out_channel,
+        cfm_params,
+        decoder_params,
+        n_spks=1,
+        spk_emb_dim=64,
+    ):
         super().__init__(
             n_feats=in_channels,
             cfm_params=cfm_params,
@@ -129,4 +140,6 @@ class CFM(BASECFM):
         in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
 
         # Just change the architecture of the estimator here
-        self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
+        self.estimator = Decoder(
+            in_channels=in_channels, out_channels=out_channel, **decoder_params
+        )
diff --git a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py
index efd225356..68f8ad864 100644
--- a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py
+++ b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py
@@ -34,7 +34,15 @@ class LayerNorm(nn.Module):
 
 
 class ConvReluNorm(nn.Module):
-    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+    def __init__(
+        self,
+        in_channels,
+        hidden_channels,
+        out_channels,
+        kernel_size,
+        n_layers,
+        p_dropout,
+    ):
         super().__init__()
         self.in_channels = in_channels
         self.hidden_channels = hidden_channels
@@ -45,12 +53,23 @@ class ConvReluNorm(nn.Module):
 
         self.conv_layers = torch.nn.ModuleList()
         self.norm_layers = torch.nn.ModuleList()
-        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+        self.conv_layers.append(
+            torch.nn.Conv1d(
+                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+            )
+        )
         self.norm_layers.append(LayerNorm(hidden_channels))
-        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
+        self.relu_drop = torch.nn.Sequential(
+            torch.nn.ReLU(), torch.nn.Dropout(p_dropout)
+        )
         for _ in range(n_layers - 1):
             self.conv_layers.append(
-                torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
+                torch.nn.Conv1d(
+                    hidden_channels,
+                    hidden_channels,
+                    kernel_size,
+                    padding=kernel_size // 2,
+                )
             )
             self.norm_layers.append(LayerNorm(hidden_channels))
         self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
@@ -75,9 +94,13 @@ class DurationPredictor(nn.Module):
         self.p_dropout = p_dropout
 
         self.drop = torch.nn.Dropout(p_dropout)
-        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.conv_1 = torch.nn.Conv1d(
+            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
+        )
         self.norm_1 = LayerNorm(filter_channels)
-        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.conv_2 = torch.nn.Conv1d(
+            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
+        )
         self.norm_2 = LayerNorm(filter_channels)
         self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
 
@@ -128,7 +151,9 @@ class RotaryPositionalEmbeddings(nn.Module):
         seq_len = x.shape[0]
 
         # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
-        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
+        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(
+            x.device
+        )
 
         # Create position indexes `[0, 1, ..., seq_len - 1]`
         seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
@@ -167,7 +192,9 @@ class RotaryPositionalEmbeddings(nn.Module):
 
         # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
         neg_half_x = self._neg_half(x_rope)
-        x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
+        x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (
+            neg_half_x * self.sin_cached[: x.shape[0]]
+        )
 
         return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")
 
@@ -236,7 +263,9 @@ class MultiHeadAttention(nn.Module):
 
         if self.proximal_bias:
             assert t_s == t_t, "Proximal bias is only available for self-attention."
-            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
+            scores = scores + self._attention_bias_proximal(t_s).to(
+                device=scores.device, dtype=scores.dtype
+            )
         if mask is not None:
             scores = scores.masked_fill(mask == 0, -1e4)
         p_attn = torch.nn.functional.softmax(scores, dim=-1)
@@ -253,7 +282,9 @@ class MultiHeadAttention(nn.Module):
 
 
 class FFN(nn.Module):
-    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
+    def __init__(
+        self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0
+    ):
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -261,8 +292,12 @@ class FFN(nn.Module):
         self.kernel_size = kernel_size
         self.p_dropout = p_dropout
 
-        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
-        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
+        self.conv_1 = torch.nn.Conv1d(
+            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
+        )
+        self.conv_2 = torch.nn.Conv1d(
+            filter_channels, out_channels, kernel_size, padding=kernel_size // 2
+        )
         self.drop = torch.nn.Dropout(p_dropout)
 
     def forward(self, x, x_mask):
@@ -298,7 +333,11 @@ class Encoder(nn.Module):
         self.ffn_layers = torch.nn.ModuleList()
         self.norm_layers_2 = torch.nn.ModuleList()
         for _ in range(self.n_layers):
-            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
+            self.attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
+                )
+            )
             self.norm_layers_1.append(LayerNorm(hidden_channels))
             self.ffn_layers.append(
                 FFN(
@@ -367,7 +406,9 @@ class TextEncoder(nn.Module):
             encoder_params.p_dropout,
         )
 
-        self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
+        self.proj_m = torch.nn.Conv1d(
+            self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1
+        )
         self.proj_w = DurationPredictor(
             self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
             duration_predictor_params.filter_channels_dp,
diff --git a/egs/ljspeech/TTS/matcha/models/components/transformer.py b/egs/ljspeech/TTS/matcha/models/components/transformer.py
index dd1afa3af..a82e560bc 100644
--- a/egs/ljspeech/TTS/matcha/models/components/transformer.py
+++ b/egs/ljspeech/TTS/matcha/models/components/transformer.py
@@ -32,7 +32,14 @@ class SnakeBeta(nn.Module):
         >>> x = a1(x)
     """
 
-    def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        alpha=1.0,
+        alpha_trainable=True,
+        alpha_logscale=True,
+    ):
         """
         Initialization.
         INPUT:
@@ -44,7 +51,9 @@ class SnakeBeta(nn.Module):
                 alpha will be trained along with the rest of your model.
         """
         super().__init__()
-        self.in_features = out_features if isinstance(out_features, list) else [out_features]
+        self.in_features = (
+            out_features if isinstance(out_features, list) else [out_features]
+        )
         self.proj = LoRACompatibleLinear(in_features, out_features)
 
         # initialize alpha
@@ -75,7 +84,9 @@ class SnakeBeta(nn.Module):
             alpha = self.alpha
             beta = self.beta
 
-        x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
+        x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
+            torch.sin(x * alpha), 2
+        )
 
         return x
 
@@ -176,8 +187,12 @@ class BasicTransformerBlock(nn.Module):
         super().__init__()
         self.only_cross_attention = only_cross_attention
 
-        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
-        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+        self.use_ada_layer_norm_zero = (
+            num_embeds_ada_norm is not None
+        ) and norm_type == "ada_norm_zero"
+        self.use_ada_layer_norm = (
+            num_embeds_ada_norm is not None
+        ) and norm_type == "ada_norm"
 
         if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
             raise ValueError(
@@ -215,7 +230,9 @@ class BasicTransformerBlock(nn.Module):
             )
             self.attn2 = Attention(
                 query_dim=dim,
-                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                cross_attention_dim=cross_attention_dim
+                if not double_self_attention
+                else None,
                 heads=num_attention_heads,
                 dim_head=attention_head_dim,
                 dropout=dropout,
@@ -229,7 +246,12 @@ class BasicTransformerBlock(nn.Module):
 
         # 3. Feed-forward
         self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
-        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+        self.ff = FeedForward(
+            dim,
+            dropout=dropout,
+            activation_fn=activation_fn,
+            final_dropout=final_dropout,
+        )
 
         # let chunk size default to None
         self._chunk_size = None
@@ -261,12 +283,18 @@ class BasicTransformerBlock(nn.Module):
         else:
             norm_hidden_states = self.norm1(hidden_states)
 
-        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+        cross_attention_kwargs = (
+            cross_attention_kwargs if cross_attention_kwargs is not None else {}
+        )
 
         attn_output = self.attn1(
             norm_hidden_states,
-            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
-            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
+            encoder_hidden_states=encoder_hidden_states
+            if self.only_cross_attention
+            else None,
+            attention_mask=encoder_attention_mask
+            if self.only_cross_attention
+            else attention_mask,
             **cross_attention_kwargs,
         )
         if self.use_ada_layer_norm_zero:
@@ -276,7 +304,9 @@ class BasicTransformerBlock(nn.Module):
         # 2. Cross-Attention
         if self.attn2 is not None:
             norm_hidden_states = (
-                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+                self.norm2(hidden_states, timestep)
+                if self.use_ada_layer_norm
+                else self.norm2(hidden_states)
             )
 
             attn_output = self.attn2(
@@ -291,7 +321,9 @@ class BasicTransformerBlock(nn.Module):
         norm_hidden_states = self.norm3(hidden_states)
 
         if self.use_ada_layer_norm_zero:
-            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+            norm_hidden_states = (
+                norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+            )
 
         if self._chunk_size is not None:
             # "feed_forward_chunk_size" can be used to save memory
@@ -302,7 +334,12 @@ class BasicTransformerBlock(nn.Module):
 
             num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
             ff_output = torch.cat(
-                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
+                [
+                    self.ff(hid_slice)
+                    for hid_slice in norm_hidden_states.chunk(
+                        num_chunks, dim=self._chunk_dim
+                    )
+                ],
                 dim=self._chunk_dim,
             )
         else:
diff --git a/egs/ljspeech/TTS/matcha/text/__init__.py b/egs/ljspeech/TTS/matcha/text/__init__.py
index dc3427f0b..78c8b1f18 100644
--- a/egs/ljspeech/TTS/matcha/text/__init__.py
+++ b/egs/ljspeech/TTS/matcha/text/__init__.py
@@ -4,7 +4,9 @@ from matcha.text.symbols import symbols
 
 # Mappings from symbol to numeric ID and vice versa:
 _symbol_to_id = {s: i for i, s in enumerate(symbols)}
-_id_to_symbol = {i: s for i, s in enumerate(symbols)}  # pylint: disable=unnecessary-comprehension
+_id_to_symbol = {
+    i: s for i, s in enumerate(symbols)
+}  # pylint: disable=unnecessary-comprehension
 
 
 def text_to_sequence(text, cleaner_names):
@@ -20,13 +22,15 @@ def text_to_sequence(text, cleaner_names):
     clean_text = _clean_text(text, cleaner_names)
     for symbol in clean_text:
         try:
-            if symbol in '_()[]# ̃':
+            if symbol in "_()[]# ̃":
                 continue
             symbol_id = _symbol_to_id[symbol]
         except Exception as ex:
             print(text)
             print(clean_text)
-            raise RuntimeError(f'text: {text}, clean_text: {clean_text}, ex: {ex}, symbol: {symbol}')
+            raise RuntimeError(
+                f"text: {text}, clean_text: {clean_text}, ex: {ex}, symbol: {symbol}"
+            )
         sequence += [symbol_id]
     return sequence, clean_text
 
diff --git a/egs/ljspeech/TTS/matcha/text/cleaners.py b/egs/ljspeech/TTS/matcha/text/cleaners.py
index 33cdc9fc6..0a1979afe 100644
--- a/egs/ljspeech/TTS/matcha/text/cleaners.py
+++ b/egs/ljspeech/TTS/matcha/text/cleaners.py
@@ -75,11 +75,12 @@ def lowercase(text):
     return text.lower()
 
 
 def collapse_whitespace(text):
     return re.sub(_whitespace_re, " ", text)
 
+
 def remove_parentheses(text):
-  text = text.replace("(", "")
-  text = text.replace(")", "")
-  text = text.replace("[", "")
-  text = text.replace("]", "")
+    text = text.replace("(", "")
+    text = text.replace(")", "")
+    text = text.replace("[", "")
+    text = text.replace("]", "")
     return text
diff --git a/egs/ljspeech/TTS/matcha/text/numbers.py b/egs/ljspeech/TTS/matcha/text/numbers.py
index f99a8686d..49c21d4e9 100644
--- a/egs/ljspeech/TTS/matcha/text/numbers.py
+++ b/egs/ljspeech/TTS/matcha/text/numbers.py
@@ -56,7 +56,9 @@ def _expand_number(m):
         elif num % 100 == 0:
             return _inflect.number_to_words(num // 100) + " hundred"
         else:
-            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
+            return _inflect.number_to_words(
+                num, andword="", zero="oh", group=2
+            ).replace(", ", " ")
     else:
         return _inflect.number_to_words(num, andword="")
 
diff --git a/egs/ljspeech/TTS/matcha/text/symbols.py b/egs/ljspeech/TTS/matcha/text/symbols.py
index 7018df549..b32c12430 100644
--- a/egs/ljspeech/TTS/matcha/text/symbols.py
+++ b/egs/ljspeech/TTS/matcha/text/symbols.py
@@ -5,9 +5,7 @@ Defines the set of symbols used in text input to the model.
 _pad = "_"
 _punctuation = ';:,.!?¡¿—…"«»“” '
 _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
-_letters_ipa = (
-    "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
-)
+_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
 
 
 # Export all symbols:
diff --git a/egs/ljspeech/TTS/matcha/utils/audio.py b/egs/ljspeech/TTS/matcha/utils/audio.py
index 0bcd74df4..0a9b8db2a 100644
--- a/egs/ljspeech/TTS/matcha/utils/audio.py
+++ b/egs/ljspeech/TTS/matcha/utils/audio.py
@@ -42,7 +42,9 @@ mel_basis = {}
 hann_window = {}
 
 
-def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+def mel_spectrogram(
+    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
+):
     if torch.min(y) < -1.0:
         print("min value is ", torch.min(y))
     if torch.max(y) > 1.0:
@@ -50,12 +52,18 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
 
     global mel_basis, hann_window  # pylint: disable=global-statement
     if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
-        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
-        mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+        mel = librosa_mel_fn(
+            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
+        )
+        mel_basis[str(fmax) + "_" + str(y.device)] = (
+            torch.from_numpy(mel).float().to(y.device)
+        )
         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
 
     y = torch.nn.functional.pad(
-        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+        y.unsqueeze(1),
+        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+        mode="reflect",
     )
     y = y.squeeze(1)
 
diff --git a/egs/ljspeech/TTS/matcha/utils/generate_data_statistics.py b/egs/ljspeech/TTS/matcha/utils/generate_data_statistics.py
index 3b8cd67c9..3028e7695 100644
--- a/egs/ljspeech/TTS/matcha/utils/generate_data_statistics.py
+++ b/egs/ljspeech/TTS/matcha/utils/generate_data_statistics.py
@@ -22,7 +22,9 @@ from matcha.utils.logging_utils import pylogger
 log = pylogger.get_pylogger(__name__)
 
 
-def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int):
+def compute_data_statistics(
+    data_loader: torch.utils.data.DataLoader, out_channels: int
+):
     """Generate data mean and standard deviation helpful in data normalisation
 
     Args:
@@ -42,7 +44,9 @@ def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int):
         total_mel_sq_sum += torch.sum(torch.pow(mels, 2))
 
     data_mean = total_mel_sum / (total_mel_len * out_channels)
-    data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2))
+    data_std = torch.sqrt(
+        (total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)
+    )
 
     return {"mel_mean": data_mean.item(), "mel_std": data_std.item()}
 
@@ -82,7 +86,9 @@ def main():
         sys.exit(1)
 
     with initialize(version_base="1.3", config_path="../../configs/data"):
-        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
+        cfg = compose(
+            config_name=args.input_config, return_hydra_config=True, overrides=[]
+        )
 
     root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")
 
@@ -93,8 +99,12 @@ def main():
         cfg["data_statistics"] = None
         cfg["seed"] = 1234
         cfg["batch_size"] = args.batch_size
-        cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
-        cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
+        cfg["train_filelist_path"] = str(
+            os.path.join(root_path, cfg["train_filelist_path"])
+        )
+        cfg["valid_filelist_path"] = str(
+            os.path.join(root_path, cfg["valid_filelist_path"])
+        )
         cfg["load_durations"] = False
 
     text_mel_datamodule = TextMelDataModule(**cfg)
diff --git a/egs/ljspeech/TTS/matcha/utils/get_durations_from_trained_model.py b/egs/ljspeech/TTS/matcha/utils/get_durations_from_trained_model.py
index 0fe2f35c4..acc7eabd9 100644
--- a/egs/ljspeech/TTS/matcha/utils/get_durations_from_trained_model.py
+++ b/egs/ljspeech/TTS/matcha/utils/get_durations_from_trained_model.py
@@ -29,7 +29,12 @@ log = pylogger.get_pylogger(__name__)
 
 
 def save_durations_to_folder(
-    attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
+    attn: torch.Tensor,
+    x_length: int,
+    y_length: int,
+    filepath: str,
+    output_folder: Path,
+    text: str,
 ):
     durations = attn.squeeze().sum(1)[:x_length].numpy()
     durations_json = get_phoneme_durations(durations, text)
@@ -41,7 +46,12 @@ def save_durations_to_folder(
 
 
 @torch.inference_mode()
-def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder):
+def compute_durations(
+    data_loader: torch.utils.data.DataLoader,
+    model: nn.Module,
+    device: torch.device,
+    output_folder,
+):
     """Generate durations from the model for each datapoint and save it in a folder
 
     Args:
@@ -123,13 +133,17 @@ def main():
     )
 
     parser.add_argument(
-        "--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)"
+        "--cpu",
+        action="store_true",
+        help="Use CPU for inference, not recommended (default: use GPU if available)",
     )
 
     args = parser.parse_args()
 
     with initialize(version_base="1.3", config_path="../../configs/data"):
-        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
+        cfg = compose(
+            config_name=args.input_config, return_hydra_config=True, overrides=[]
+        )
 
     root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")
 
@@ -138,8 +152,12 @@ def main():
         del cfg["_target_"]
         cfg["seed"] = 1234
         cfg["batch_size"] = args.batch_size
-        cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
-        cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
+        cfg["train_filelist_path"] = str(
+            os.path.join(root_path, cfg["train_filelist_path"])
+        )
+        cfg["valid_filelist_path"] = str(
+            os.path.join(root_path, cfg["valid_filelist_path"])
+        )
         cfg["load_durations"] = False
 
     if args.output_folder is not None:
@@ -155,7 +173,9 @@ def main():
 
     output_folder.mkdir(parents=True, exist_ok=True)
 
-    print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}")
+    print(
+        f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}"
+    )
     print("Loading model...")
     device = get_device(args)
     model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device)
diff --git a/egs/ljspeech/TTS/matcha/utils/instantiators.py b/egs/ljspeech/TTS/matcha/utils/instantiators.py
index 5547b4ed6..bde0c0d75 100644
--- a/egs/ljspeech/TTS/matcha/utils/instantiators.py
+++ b/egs/ljspeech/TTS/matcha/utils/instantiators.py
@@ -27,7 +27,9 @@ def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
 
     for _, cb_conf in callbacks_cfg.items():
         if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
-            log.info(f"Instantiating callback <{cb_conf._target_}>")  # pylint: disable=protected-access
+            log.info(
+                f"Instantiating callback <{cb_conf._target_}>"
+            )  # pylint: disable=protected-access
             callbacks.append(hydra.utils.instantiate(cb_conf))
 
     return callbacks
@@ -50,7 +52,9 @@ def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
 
     for _, lg_conf in logger_cfg.items():
         if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
-            log.info(f"Instantiating logger <{lg_conf._target_}>")  # pylint: disable=protected-access
+            log.info(
+                f"Instantiating logger <{lg_conf._target_}>"
+            )  # pylint: disable=protected-access
             logger.append(hydra.utils.instantiate(lg_conf))
 
     return logger
diff --git a/egs/ljspeech/TTS/matcha/utils/logging_utils.py b/egs/ljspeech/TTS/matcha/utils/logging_utils.py
index 1a12d1dda..2d2377eb2 100644
--- a/egs/ljspeech/TTS/matcha/utils/logging_utils.py
+++ b/egs/ljspeech/TTS/matcha/utils/logging_utils.py
@@ -34,8 +34,12 @@ def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
 
     # save number of model parameters
     hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
-    hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)
+    hparams["model/params/trainable"] = sum(
+        p.numel() for p in model.parameters() if p.requires_grad
+    )
+    hparams["model/params/non_trainable"] = sum(
+        p.numel() for p in model.parameters() if not p.requires_grad
+    )
 
     hparams["data"] = cfg["data"]
     hparams["trainer"] = cfg["trainer"]
diff --git a/egs/ljspeech/TTS/matcha/utils/model.py b/egs/ljspeech/TTS/matcha/utils/model.py
index 869cc6092..a488ab4e8 100644
--- a/egs/ljspeech/TTS/matcha/utils/model.py
+++ b/egs/ljspeech/TTS/matcha/utils/model.py
@@ -36,7 +36,12 @@ def generate_path(duration, mask):
     cum_duration_flat = cum_duration.view(b * t_x)
     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
     path = path.view(b, t_x, t_y)
-    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+    path = (
+        path
+        - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[
+            :, :-1
+        ]
+    )
     path = path * mask
     return path
 
diff --git a/egs/ljspeech/TTS/matcha/utils/pylogger.py b/egs/ljspeech/TTS/matcha/utils/pylogger.py
index 616006780..a7ed7a961 100644
--- a/egs/ljspeech/TTS/matcha/utils/pylogger.py
+++ b/egs/ljspeech/TTS/matcha/utils/pylogger.py
@@ -14,7 +14,15 @@ def get_pylogger(name: str = __name__) -> logging.Logger:
 
     # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
-    logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
+    logging_levels = (
+        "debug",
+        "info",
+        "warning",
+        "error",
+        "exception",
+        "fatal",
+        "critical",
+    )
     for level in logging_levels:
         setattr(logger, level, rank_zero_only(getattr(logger, level)))
 
diff --git a/egs/ljspeech/TTS/matcha/utils/rich_utils.py b/egs/ljspeech/TTS/matcha/utils/rich_utils.py
index f602f6e93..d7fcd1aae 100644
--- a/egs/ljspeech/TTS/matcha/utils/rich_utils.py
+++ b/egs/ljspeech/TTS/matcha/utils/rich_utils.py
@@ -47,7 +47,9 @@ def print_config_tree(
         _ = (
             queue.append(field)
             if field in cfg
-            else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...")
+            else log.warning(
+                f"Field '{field}' not found in config. Skipping '{field}' config printing..."
+            )
         )
 
     # add all the other fields to queue (not specified in `print_order`)
diff --git a/egs/ljspeech/TTS/matcha/utils/utils.py b/egs/ljspeech/TTS/matcha/utils/utils.py
index bc81c316e..a54554263 100644
--- a/egs/ljspeech/TTS/matcha/utils/utils.py
+++ b/egs/ljspeech/TTS/matcha/utils/utils.py
@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, Tuple
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
+
 # from omegaconf import DictConfig
 
 # from matcha.utils import pylogger, rich_utils
@@ -16,7 +17,7 @@ import torch
 # log = pylogger.get_pylogger(__name__)
 
 
-def extras(cfg: 'DictConfig') -> None:
+def extras(cfg: "DictConfig") -> None:
     """Applies optional utilities before the task is started.
 
     Utilities:
@@ -207,6 +208,7 @@ def get_user_data_dir(appname="matcha_tts"):
 def assert_model_downloaded(checkpoint_path, url, use_wget=True):
     import gdown
     import wget
+
     if Path(checkpoint_path).exists():
         log.debug(f"[+] Model already present at {checkpoint_path}!")
         print(f"[+] Model already present at {checkpoint_path}!")
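
Reviewer note (an addition for context, not part of the patch): every hunk above is a pure re-wrap to Black's default 88-column style — long calls, signatures, and conditional expressions are split across lines and single quotes become double quotes, with no intended change in behavior. One quick way to confirm that a patch like this is formatting-only is to compare each touched file's syntax tree against the previous revision. A minimal sketch, assuming the formatting commit is checked out and `HEAD~1` is the pre-formatting revision (the helper names and CLI usage are illustrative, not project tooling):

```python
# Sketch: verify a formatting-only patch by AST comparison.
import ast
import subprocess
import sys


def same_ast(old_src: str, new_src: str) -> bool:
    # ast.dump() normalizes whitespace, line breaks, and quote style,
    # so equal dumps mean the two sources parse to the same syntax tree.
    return ast.dump(ast.parse(old_src)) == ast.dump(ast.parse(new_src))


def check(path: str, base_rev: str = "HEAD~1") -> bool:
    # Read the pre-formatting version out of git and the current one from disk.
    old = subprocess.run(
        ["git", "show", f"{base_rev}:{path}"],
        capture_output=True,
        text=True,
        check=True,
    ).stdout
    with open(path, encoding="utf-8") as f:
        new = f.read()
    return same_ast(old, new)


if __name__ == "__main__":
    # Usage: python check_format_only.py <file> [<file> ...]
    for path in sys.argv[1:]:
        print(path, "OK" if check(path) else "AST CHANGED")
```

The quote-style switch in `text/__init__.py` and the re-indent of `remove_parentheses` in `text/cleaners.py` still compare equal under this check, since neither alters any AST node.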