mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
create model from parameters
This commit is contained in:
parent
f95ac12d70
commit
6fac3a3143
@ -4,9 +4,9 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from matcha.models.components.decoder import Decoder
|
from matcha.models.components.decoder import Decoder
|
||||||
from matcha.utils.pylogger import get_pylogger
|
# from matcha.utils.pylogger import get_pylogger
|
||||||
|
|
||||||
log = get_pylogger(__name__)
|
# log = get_pylogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BASECFM(torch.nn.Module, ABC):
|
class BASECFM(torch.nn.Module, ABC):
|
||||||
|
@ -6,10 +6,10 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
import matcha.utils as utils
|
# import matcha.utils as utils
|
||||||
from matcha.utils.model import sequence_mask
|
from matcha.utils.model import sequence_mask
|
||||||
|
|
||||||
log = utils.get_pylogger(__name__)
|
# log = utils.get_pylogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(nn.Module):
|
class LayerNorm(nn.Module):
|
||||||
|
@ -5,8 +5,8 @@ import random
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import matcha.utils.monotonic_align as monotonic_align
|
import matcha.utils.monotonic_align as monotonic_align
|
||||||
from matcha import utils
|
# from matcha import utils
|
||||||
from matcha.models.baselightningmodule import BaseLightningClass
|
# from matcha.models.baselightningmodule import BaseLightningClass
|
||||||
from matcha.models.components.flow_matching import CFM
|
from matcha.models.components.flow_matching import CFM
|
||||||
from matcha.models.components.text_encoder import TextEncoder
|
from matcha.models.components.text_encoder import TextEncoder
|
||||||
from matcha.utils.model import (
|
from matcha.utils.model import (
|
||||||
@ -17,10 +17,10 @@ from matcha.utils.model import (
|
|||||||
sequence_mask,
|
sequence_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
log = utils.get_pylogger(__name__)
|
# log = utils.get_pylogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MatchaTTS(BaseLightningClass): # 🍵
|
class MatchaTTS(torch.nn.Module): # 🍵
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
n_vocab,
|
n_vocab,
|
||||||
@ -30,7 +30,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
|||||||
encoder,
|
encoder,
|
||||||
decoder,
|
decoder,
|
||||||
cfm,
|
cfm,
|
||||||
data_statistics,
|
# data_statistics,
|
||||||
out_size,
|
out_size,
|
||||||
optimizer=None,
|
optimizer=None,
|
||||||
scheduler=None,
|
scheduler=None,
|
||||||
@ -39,7 +39,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.save_hyperparameters(logger=False)
|
# self.save_hyperparameters(logger=False)
|
||||||
|
|
||||||
self.n_vocab = n_vocab
|
self.n_vocab = n_vocab
|
||||||
self.n_spks = n_spks
|
self.n_spks = n_spks
|
||||||
@ -70,7 +70,7 @@ class MatchaTTS(BaseLightningClass): # 🍵
|
|||||||
spk_emb_dim=spk_emb_dim,
|
spk_emb_dim=spk_emb_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.update_data_statistics(data_statistics)
|
# self.update_data_statistics(data_statistics)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0):
|
def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0):
|
||||||
|
98
egs/ljspeech/TTS/matcha/train.py
Executable file
98
egs/ljspeech/TTS/matcha/train.py
Executable file
@ -0,0 +1,98 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
from icefall.utils import AttributeDict
|
||||||
|
from matcha.models.matcha_tts import MatchaTTS
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(params):
|
||||||
|
m = MatchaTTS(**params.model)
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
n_feats = 80
|
||||||
|
filter_channels_dp = 256
|
||||||
|
encoder_params_p_dropout = 0.1
|
||||||
|
params = AttributeDict(
|
||||||
|
{
|
||||||
|
"model": AttributeDict(
|
||||||
|
{
|
||||||
|
"n_vocab": 178,
|
||||||
|
"n_spks": 1, # for ljspeech.
|
||||||
|
"spk_emb_dim": 64,
|
||||||
|
"n_feats": n_feats,
|
||||||
|
"out_size": None, # or use 172
|
||||||
|
"prior_loss": True,
|
||||||
|
"use_precomputed_durations": False,
|
||||||
|
"encoder": AttributeDict(
|
||||||
|
{
|
||||||
|
"encoder_type": "RoPE Encoder", # not used
|
||||||
|
"encoder_params": AttributeDict(
|
||||||
|
{
|
||||||
|
"n_feats": n_feats,
|
||||||
|
"n_channels": 192,
|
||||||
|
"filter_channels": 768,
|
||||||
|
"filter_channels_dp": filter_channels_dp,
|
||||||
|
"n_heads": 2,
|
||||||
|
"n_layers": 6,
|
||||||
|
"kernel_size": 3,
|
||||||
|
"p_dropout": encoder_params_p_dropout,
|
||||||
|
"spk_emb_dim": 64,
|
||||||
|
"n_spks": 1,
|
||||||
|
"prenet": True,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"duration_predictor_params": AttributeDict(
|
||||||
|
{
|
||||||
|
"filter_channels_dp": filter_channels_dp,
|
||||||
|
"kernel_size": 3,
|
||||||
|
"p_dropout": encoder_params_p_dropout,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"decoder": AttributeDict(
|
||||||
|
{
|
||||||
|
"channels": [256, 256],
|
||||||
|
"dropout": 0.05,
|
||||||
|
"attention_head_dim": 64,
|
||||||
|
"n_blocks": 1,
|
||||||
|
"num_mid_blocks": 2,
|
||||||
|
"num_heads": 2,
|
||||||
|
"act_fn": "snakebeta",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"cfm": AttributeDict(
|
||||||
|
{
|
||||||
|
"name": "CFM",
|
||||||
|
"solver": "euler",
|
||||||
|
"sigma_min": 1e-4,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"optimizer": AttributeDict(
|
||||||
|
{
|
||||||
|
"lr": 1e-4,
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
m = get_model(params)
|
||||||
|
print(m)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in m.parameters()])
|
||||||
|
print(f"Number of parameters: {num_param}")
|
||||||
|
|
||||||
|
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -1,5 +1,5 @@
|
|||||||
from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers
|
# from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers
|
||||||
from matcha.utils.logging_utils import log_hyperparameters
|
# from matcha.utils.logging_utils import log_hyperparameters
|
||||||
from matcha.utils.pylogger import get_pylogger
|
# from matcha.utils.pylogger import get_pylogger
|
||||||
from matcha.utils.rich_utils import enforce_tags, print_config_tree
|
# from matcha.utils.rich_utils import enforce_tags, print_config_tree
|
||||||
from matcha.utils.utils import extras, get_metric_value, task_wrapper
|
# from matcha.utils.utils import extras, get_metric_value, task_wrapper
|
||||||
|
3
egs/ljspeech/TTS/matcha/utils/monotonic_align/.gitignore
vendored
Normal file
3
egs/ljspeech/TTS/matcha/utils/monotonic_align/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
build
|
||||||
|
core.c
|
||||||
|
*.so
|
Loading…
x
Reference in New Issue
Block a user