diff --git a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py index 5cad7431e..552c4b383 100644 --- a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py +++ b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py @@ -4,9 +4,9 @@ import torch import torch.nn.functional as F from matcha.models.components.decoder import Decoder -from matcha.utils.pylogger import get_pylogger +# from matcha.utils.pylogger import get_pylogger -log = get_pylogger(__name__) +# log = get_pylogger(__name__) class BASECFM(torch.nn.Module, ABC): diff --git a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py index a388d05d6..efd225356 100644 --- a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py +++ b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py @@ -6,10 +6,10 @@ import torch import torch.nn as nn from einops import rearrange -import matcha.utils as utils +# import matcha.utils as utils from matcha.utils.model import sequence_mask -log = utils.get_pylogger(__name__) +# log = utils.get_pylogger(__name__) class LayerNorm(nn.Module): diff --git a/egs/ljspeech/TTS/matcha/models/matcha_tts.py b/egs/ljspeech/TTS/matcha/models/matcha_tts.py index 07f95ad2e..d4b1c57ab 100644 --- a/egs/ljspeech/TTS/matcha/models/matcha_tts.py +++ b/egs/ljspeech/TTS/matcha/models/matcha_tts.py @@ -5,8 +5,8 @@ import random import torch import matcha.utils.monotonic_align as monotonic_align -from matcha import utils -from matcha.models.baselightningmodule import BaseLightningClass +# from matcha import utils +# from matcha.models.baselightningmodule import BaseLightningClass from matcha.models.components.flow_matching import CFM from matcha.models.components.text_encoder import TextEncoder from matcha.utils.model import ( @@ -17,10 +17,10 @@ from matcha.utils.model import ( sequence_mask, ) -log = utils.get_pylogger(__name__) +# log = utils.get_pylogger(__name__) -class MatchaTTS(BaseLightningClass): # 🍵 +class MatchaTTS(torch.nn.Module): # 🍵 def __init__( self, n_vocab, @@ -30,7 +30,7 @@ class MatchaTTS(BaseLightningClass): # 🍵 encoder, decoder, cfm, - data_statistics, + # data_statistics, out_size, optimizer=None, scheduler=None, @@ -39,7 +39,7 @@ class MatchaTTS(BaseLightningClass): # 🍵 ): super().__init__() - self.save_hyperparameters(logger=False) + # self.save_hyperparameters(logger=False) self.n_vocab = n_vocab self.n_spks = n_spks @@ -70,7 +70,7 @@ class MatchaTTS(BaseLightningClass): # 🍵 spk_emb_dim=spk_emb_dim, ) - self.update_data_statistics(data_statistics) + # self.update_data_statistics(data_statistics) @torch.inference_mode() def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0): diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py new file mode 100755 index 000000000..1c5084204 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/train.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + + +import torch + + +from icefall.utils import AttributeDict +from matcha.models.matcha_tts import MatchaTTS + + +def get_model(params): + m = MatchaTTS(**params.model) + return m + + +def main(): + n_feats = 80 + filter_channels_dp = 256 + encoder_params_p_dropout = 0.1 + params = AttributeDict( + { + "model": AttributeDict( + { + "n_vocab": 178, + "n_spks": 1, # for ljspeech. + "spk_emb_dim": 64, + "n_feats": n_feats, + "out_size": None, # or use 172 + "prior_loss": True, + "use_precomputed_durations": False, + "encoder": AttributeDict( + { + "encoder_type": "RoPE Encoder", # not used + "encoder_params": AttributeDict( + { + "n_feats": n_feats, + "n_channels": 192, + "filter_channels": 768, + "filter_channels_dp": filter_channels_dp, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": encoder_params_p_dropout, + "spk_emb_dim": 64, + "n_spks": 1, + "prenet": True, + } + ), + "duration_predictor_params": AttributeDict( + { + "filter_channels_dp": filter_channels_dp, + "kernel_size": 3, + "p_dropout": encoder_params_p_dropout, + } + ), + } + ), + "decoder": AttributeDict( + { + "channels": [256, 256], + "dropout": 0.05, + "attention_head_dim": 64, + "n_blocks": 1, + "num_mid_blocks": 2, + "num_heads": 2, + "act_fn": "snakebeta", + } + ), + "cfm": AttributeDict( + { + "name": "CFM", + "solver": "euler", + "sigma_min": 1e-4, + } + ), + "optimizer": AttributeDict( + { + "lr": 1e-4, + "weight_decay": 0.0, + } + ), + } + ) + } + ) + m = get_model(params) + print(m) + + num_param = sum([p.numel() for p in m.parameters()]) + print(f"Number of parameters: {num_param}") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/TTS/matcha/utils/__init__.py b/egs/ljspeech/TTS/matcha/utils/__init__.py index 074db6461..2b74b40f5 100644 --- a/egs/ljspeech/TTS/matcha/utils/__init__.py +++ b/egs/ljspeech/TTS/matcha/utils/__init__.py @@ -1,5 +1,5 @@ -from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers -from matcha.utils.logging_utils import log_hyperparameters -from matcha.utils.pylogger import get_pylogger -from matcha.utils.rich_utils import enforce_tags, print_config_tree -from matcha.utils.utils import extras, get_metric_value, task_wrapper +# from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers +# from matcha.utils.logging_utils import log_hyperparameters +# from matcha.utils.pylogger import get_pylogger +# from matcha.utils.rich_utils import enforce_tags, print_config_tree +# from matcha.utils.utils import extras, get_metric_value, task_wrapper diff --git a/egs/ljspeech/TTS/matcha/utils/monotonic_align/.gitignore b/egs/ljspeech/TTS/matcha/utils/monotonic_align/.gitignore new file mode 100644 index 000000000..28bdad6b8 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/utils/monotonic_align/.gitignore @@ -0,0 +1,3 @@ +build +core.c +*.so