mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
create model from parameters
This commit is contained in:
parent
f95ac12d70
commit
6fac3a3143
@ -4,9 +4,9 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from matcha.models.components.decoder import Decoder
|
||||
from matcha.utils.pylogger import get_pylogger
|
||||
# from matcha.utils.pylogger import get_pylogger
|
||||
|
||||
log = get_pylogger(__name__)
|
||||
# log = get_pylogger(__name__)
|
||||
|
||||
|
||||
class BASECFM(torch.nn.Module, ABC):
|
||||
|
@ -6,10 +6,10 @@ import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
|
||||
import matcha.utils as utils
|
||||
# import matcha.utils as utils
|
||||
from matcha.utils.model import sequence_mask
|
||||
|
||||
log = utils.get_pylogger(__name__)
|
||||
# log = utils.get_pylogger(__name__)
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
|
@ -5,8 +5,8 @@ import random
|
||||
import torch
|
||||
|
||||
import matcha.utils.monotonic_align as monotonic_align
|
||||
from matcha import utils
|
||||
from matcha.models.baselightningmodule import BaseLightningClass
|
||||
# from matcha import utils
|
||||
# from matcha.models.baselightningmodule import BaseLightningClass
|
||||
from matcha.models.components.flow_matching import CFM
|
||||
from matcha.models.components.text_encoder import TextEncoder
|
||||
from matcha.utils.model import (
|
||||
@ -17,10 +17,10 @@ from matcha.utils.model import (
|
||||
sequence_mask,
|
||||
)
|
||||
|
||||
log = utils.get_pylogger(__name__)
|
||||
# log = utils.get_pylogger(__name__)
|
||||
|
||||
|
||||
class MatchaTTS(BaseLightningClass): # 🍵
|
||||
class MatchaTTS(torch.nn.Module): # 🍵
|
||||
def __init__(
|
||||
self,
|
||||
n_vocab,
|
||||
@ -30,7 +30,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
||||
encoder,
|
||||
decoder,
|
||||
cfm,
|
||||
data_statistics,
|
||||
# data_statistics,
|
||||
out_size,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
@ -39,7 +39,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.save_hyperparameters(logger=False)
|
||||
# self.save_hyperparameters(logger=False)
|
||||
|
||||
self.n_vocab = n_vocab
|
||||
self.n_spks = n_spks
|
||||
@ -70,7 +70,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
||||
spk_emb_dim=spk_emb_dim,
|
||||
)
|
||||
|
||||
self.update_data_statistics(data_statistics)
|
||||
# self.update_data_statistics(data_statistics)
|
||||
|
||||
@torch.inference_mode()
|
||||
def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0):
|
||||
|
98
egs/ljspeech/TTS/matcha/train.py
Executable file
98
egs/ljspeech/TTS/matcha/train.py
Executable file
@ -0,0 +1,98 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
from icefall.utils import AttributeDict
|
||||
from matcha.models.matcha_tts import MatchaTTS
|
||||
|
||||
|
||||
def get_model(params):
|
||||
m = MatchaTTS(**params.model)
|
||||
return m
|
||||
|
||||
|
||||
def main():
|
||||
n_feats = 80
|
||||
filter_channels_dp = 256
|
||||
encoder_params_p_dropout = 0.1
|
||||
params = AttributeDict(
|
||||
{
|
||||
"model": AttributeDict(
|
||||
{
|
||||
"n_vocab": 178,
|
||||
"n_spks": 1, # for ljspeech.
|
||||
"spk_emb_dim": 64,
|
||||
"n_feats": n_feats,
|
||||
"out_size": None, # or use 172
|
||||
"prior_loss": True,
|
||||
"use_precomputed_durations": False,
|
||||
"encoder": AttributeDict(
|
||||
{
|
||||
"encoder_type": "RoPE Encoder", # not used
|
||||
"encoder_params": AttributeDict(
|
||||
{
|
||||
"n_feats": n_feats,
|
||||
"n_channels": 192,
|
||||
"filter_channels": 768,
|
||||
"filter_channels_dp": filter_channels_dp,
|
||||
"n_heads": 2,
|
||||
"n_layers": 6,
|
||||
"kernel_size": 3,
|
||||
"p_dropout": encoder_params_p_dropout,
|
||||
"spk_emb_dim": 64,
|
||||
"n_spks": 1,
|
||||
"prenet": True,
|
||||
}
|
||||
),
|
||||
"duration_predictor_params": AttributeDict(
|
||||
{
|
||||
"filter_channels_dp": filter_channels_dp,
|
||||
"kernel_size": 3,
|
||||
"p_dropout": encoder_params_p_dropout,
|
||||
}
|
||||
),
|
||||
}
|
||||
),
|
||||
"decoder": AttributeDict(
|
||||
{
|
||||
"channels": [256, 256],
|
||||
"dropout": 0.05,
|
||||
"attention_head_dim": 64,
|
||||
"n_blocks": 1,
|
||||
"num_mid_blocks": 2,
|
||||
"num_heads": 2,
|
||||
"act_fn": "snakebeta",
|
||||
}
|
||||
),
|
||||
"cfm": AttributeDict(
|
||||
{
|
||||
"name": "CFM",
|
||||
"solver": "euler",
|
||||
"sigma_min": 1e-4,
|
||||
}
|
||||
),
|
||||
"optimizer": AttributeDict(
|
||||
{
|
||||
"lr": 1e-4,
|
||||
"weight_decay": 0.0,
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
}
|
||||
)
|
||||
m = get_model(params)
|
||||
print(m)
|
||||
|
||||
num_param = sum([p.numel() for p in m.parameters()])
|
||||
print(f"Number of parameters: {num_param}")
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,5 +1,5 @@
|
||||
from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers
|
||||
from matcha.utils.logging_utils import log_hyperparameters
|
||||
from matcha.utils.pylogger import get_pylogger
|
||||
from matcha.utils.rich_utils import enforce_tags, print_config_tree
|
||||
from matcha.utils.utils import extras, get_metric_value, task_wrapper
|
||||
# from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers
|
||||
# from matcha.utils.logging_utils import log_hyperparameters
|
||||
# from matcha.utils.pylogger import get_pylogger
|
||||
# from matcha.utils.rich_utils import enforce_tags, print_config_tree
|
||||
# from matcha.utils.utils import extras, get_metric_value, task_wrapper
|
||||
|
3
egs/ljspeech/TTS/matcha/utils/monotonic_align/.gitignore
vendored
Normal file
3
egs/ljspeech/TTS/matcha/utils/monotonic_align/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
build
|
||||
core.c
|
||||
*.so
|
Loading…
x
Reference in New Issue
Block a user