#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
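"""Create a MatchaTTS model from the LJSpeech configuration below and print
the model together with its total parameter count.

Run this file directly (it is executable and has a shebang); the ``icefall``
and ``matcha`` packages must be importable.
"""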

import torch

from icefall.utils import AttributeDict
from matcha.models.matcha_tts import MatchaTTS


def get_model(params):
    m = MatchaTTS(**params.model)
    return m


def main():
    n_feats = 80
    filter_channels_dp = 256
    encoder_params_p_dropout = 0.1
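    # The values above (n_feats, filter_channels_dp, encoder_params_p_dropout)
    # are shared between several sub-configurations below.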
    params = AttributeDict(
        {
            "model": AttributeDict(
                {
                    "n_vocab": 178,
                    "n_spks": 1,  # for ljspeech.
                    "spk_emb_dim": 64,
                    "n_feats": n_feats,
                    "out_size": None,  # or use 172
                    "prior_loss": True,
                    "use_precomputed_durations": False,
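                    # Text encoder plus duration predictor configuration.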
"encoder": AttributeDict(
|
|
{
|
|
"encoder_type": "RoPE Encoder", # not used
|
|
"encoder_params": AttributeDict(
|
|
{
|
|
"n_feats": n_feats,
|
|
"n_channels": 192,
|
|
"filter_channels": 768,
|
|
"filter_channels_dp": filter_channels_dp,
|
|
"n_heads": 2,
|
|
"n_layers": 6,
|
|
"kernel_size": 3,
|
|
"p_dropout": encoder_params_p_dropout,
|
|
"spk_emb_dim": 64,
|
|
"n_spks": 1,
|
|
"prenet": True,
|
|
}
|
|
),
|
|
"duration_predictor_params": AttributeDict(
|
|
{
|
|
"filter_channels_dp": filter_channels_dp,
|
|
"kernel_size": 3,
|
|
"p_dropout": encoder_params_p_dropout,
|
|
}
|
|
),
|
|
}
|
|
),
|
|
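                    # Decoder: the flow-matching network that predicts the
                    # vector field used to generate mel spectrograms.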
"decoder": AttributeDict(
|
|
{
|
|
"channels": [256, 256],
|
|
"dropout": 0.05,
|
|
"attention_head_dim": 64,
|
|
"n_blocks": 1,
|
|
"num_mid_blocks": 2,
|
|
"num_heads": 2,
|
|
"act_fn": "snakebeta",
|
|
}
|
|
),
|
|
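                    # Conditional flow matching (CFM) settings: Euler ODE
                    # solver with minimum noise level sigma_min.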
"cfm": AttributeDict(
|
|
{
|
|
"name": "CFM",
|
|
"solver": "euler",
|
|
"sigma_min": 1e-4,
|
|
}
|
|
),
|
|
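                    # Optimizer hyper-parameters, passed through as part of
                    # the model configuration.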
"optimizer": AttributeDict(
|
|
{
|
|
"lr": 1e-4,
|
|
"weight_decay": 0.0,
|
|
}
|
|
),
|
|
}
|
|
)
|
|
}
|
|
)
|
|
    m = get_model(params)
    print(m)
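
    # numel() gives the element count of one parameter tensor; summing over
    # m.parameters() yields the total parameter count of the model.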
    num_param = sum([p.numel() for p in m.parameters()])
    print(f"Number of parameters: {num_param}")
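

# Note: set_num_interop_threads() must be called before any inter-op parallel
# work starts; both calls run at module level here, before main() executes.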
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()