create model from parameters

This commit is contained in:
Fangjun Kuang 2024-10-15 17:57:10 +08:00
parent f95ac12d70
commit 6fac3a3143
6 changed files with 117 additions and 16 deletions

View File

@ -4,9 +4,9 @@ import torch
import torch.nn.functional as F
from matcha.models.components.decoder import Decoder
from matcha.utils.pylogger import get_pylogger
# from matcha.utils.pylogger import get_pylogger
log = get_pylogger(__name__)
# log = get_pylogger(__name__)
class BASECFM(torch.nn.Module, ABC):

View File

@ -6,10 +6,10 @@ import torch
import torch.nn as nn
from einops import rearrange
import matcha.utils as utils
# import matcha.utils as utils
from matcha.utils.model import sequence_mask
log = utils.get_pylogger(__name__)
# log = utils.get_pylogger(__name__)
class LayerNorm(nn.Module):

View File

@ -5,8 +5,8 @@ import random
import torch
import matcha.utils.monotonic_align as monotonic_align
from matcha import utils
from matcha.models.baselightningmodule import BaseLightningClass
# from matcha import utils
# from matcha.models.baselightningmodule import BaseLightningClass
from matcha.models.components.flow_matching import CFM
from matcha.models.components.text_encoder import TextEncoder
from matcha.utils.model import (
@ -17,10 +17,10 @@ from matcha.utils.model import (
sequence_mask,
)
log = utils.get_pylogger(__name__)
# log = utils.get_pylogger(__name__)
class MatchaTTS(BaseLightningClass): # 🍵
class MatchaTTS(torch.nn.Module): # 🍵
def __init__(
self,
n_vocab,
@ -30,7 +30,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
encoder,
decoder,
cfm,
data_statistics,
# data_statistics,
out_size,
optimizer=None,
scheduler=None,
@ -39,7 +39,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
):
super().__init__()
self.save_hyperparameters(logger=False)
# self.save_hyperparameters(logger=False)
self.n_vocab = n_vocab
self.n_spks = n_spks
@ -70,7 +70,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
spk_emb_dim=spk_emb_dim,
)
self.update_data_statistics(data_statistics)
# self.update_data_statistics(data_statistics)
@torch.inference_mode()
def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0):

View File

@ -0,0 +1,98 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
import torch
from icefall.utils import AttributeDict
from matcha.models.matcha_tts import MatchaTTS
def get_model(params):
m = MatchaTTS(**params.model)
return m
def main():
n_feats = 80
filter_channels_dp = 256
encoder_params_p_dropout = 0.1
params = AttributeDict(
{
"model": AttributeDict(
{
"n_vocab": 178,
"n_spks": 1, # for ljspeech.
"spk_emb_dim": 64,
"n_feats": n_feats,
"out_size": None, # or use 172
"prior_loss": True,
"use_precomputed_durations": False,
"encoder": AttributeDict(
{
"encoder_type": "RoPE Encoder", # not used
"encoder_params": AttributeDict(
{
"n_feats": n_feats,
"n_channels": 192,
"filter_channels": 768,
"filter_channels_dp": filter_channels_dp,
"n_heads": 2,
"n_layers": 6,
"kernel_size": 3,
"p_dropout": encoder_params_p_dropout,
"spk_emb_dim": 64,
"n_spks": 1,
"prenet": True,
}
),
"duration_predictor_params": AttributeDict(
{
"filter_channels_dp": filter_channels_dp,
"kernel_size": 3,
"p_dropout": encoder_params_p_dropout,
}
),
}
),
"decoder": AttributeDict(
{
"channels": [256, 256],
"dropout": 0.05,
"attention_head_dim": 64,
"n_blocks": 1,
"num_mid_blocks": 2,
"num_heads": 2,
"act_fn": "snakebeta",
}
),
"cfm": AttributeDict(
{
"name": "CFM",
"solver": "euler",
"sigma_min": 1e-4,
}
),
"optimizer": AttributeDict(
{
"lr": 1e-4,
"weight_decay": 0.0,
}
),
}
)
}
)
m = get_model(params)
print(m)
num_param = sum([p.numel() for p in m.parameters()])
print(f"Number of parameters: {num_param}")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -1,5 +1,5 @@
from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers
from matcha.utils.logging_utils import log_hyperparameters
from matcha.utils.pylogger import get_pylogger
from matcha.utils.rich_utils import enforce_tags, print_config_tree
from matcha.utils.utils import extras, get_metric_value, task_wrapper
# from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers
# from matcha.utils.logging_utils import log_hyperparameters
# from matcha.utils.pylogger import get_pylogger
# from matcha.utils.rich_utils import enforce_tags, print_config_tree
# from matcha.utils.utils import extras, get_metric_value, task_wrapper

View File

@ -0,0 +1,3 @@
build
core.c
*.so