mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
Update train.py
This commit is contained in:
parent
0e7f0a4ee9
commit
2eb06451ee
@ -274,10 +274,48 @@ def get_model(params: AttributeDict) -> nn.Module:
|
|||||||
"frame_length": params.frame_length,
|
"frame_length": params.frame_length,
|
||||||
"frame_shift": params.frame_shift,
|
"frame_shift": params.frame_shift,
|
||||||
}
|
}
|
||||||
|
generator_params = {
|
||||||
|
"hidden_channels": 192,
|
||||||
|
"spks": params.num_spks,
|
||||||
|
"langs": None,
|
||||||
|
"spk_embed_dim": None,
|
||||||
|
"global_channels": -1,
|
||||||
|
"segment_size": 32,
|
||||||
|
"text_encoder_attention_heads": 2,
|
||||||
|
"text_encoder_ffn_expand": 4,
|
||||||
|
"text_encoder_cnn_module_kernel": 5,
|
||||||
|
"text_encoder_blocks": 6,
|
||||||
|
"text_encoder_dropout_rate": 0.1,
|
||||||
|
"decoder_kernel_size": 7,
|
||||||
|
"decoder_channels": 512,
|
||||||
|
"decoder_upsample_scales": [8, 8, 2, 2],
|
||||||
|
"decoder_upsample_kernel_sizes": [16, 16, 4, 4],
|
||||||
|
"decoder_resblock_kernel_sizes": [3, 7, 11],
|
||||||
|
"decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||||
|
"use_weight_norm_in_decoder": True,
|
||||||
|
"posterior_encoder_kernel_size": 5,
|
||||||
|
"posterior_encoder_layers": 16,
|
||||||
|
"posterior_encoder_stacks": 1,
|
||||||
|
"posterior_encoder_base_dilation": 1,
|
||||||
|
"posterior_encoder_dropout_rate": 0.0,
|
||||||
|
"use_weight_norm_in_posterior_encoder": True,
|
||||||
|
"flow_flows": 4,
|
||||||
|
"flow_kernel_size": 5,
|
||||||
|
"flow_base_dilation": 1,
|
||||||
|
"flow_layers": 4,
|
||||||
|
"flow_dropout_rate": 0.0,
|
||||||
|
"use_weight_norm_in_flow": True,
|
||||||
|
"use_only_mean_in_flow": True,
|
||||||
|
"stochastic_duration_predictor_kernel_size": 3,
|
||||||
|
"stochastic_duration_predictor_dropout_rate": 0.5,
|
||||||
|
"stochastic_duration_predictor_flows": 4,
|
||||||
|
"stochastic_duration_predictor_dds_conv_layers": 3,
|
||||||
|
}
|
||||||
model = VITS(
|
model = VITS(
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
feature_dim=params.feature_dim,
|
feature_dim=params.feature_dim,
|
||||||
sampling_rate=params.sampling_rate,
|
sampling_rate=params.sampling_rate,
|
||||||
|
generator_params=generator_params,
|
||||||
mel_loss_params=mel_loss_params,
|
mel_loss_params=mel_loss_params,
|
||||||
lambda_adv=params.lambda_adv,
|
lambda_adv=params.lambda_adv,
|
||||||
lambda_mel=params.lambda_mel,
|
lambda_mel=params.lambda_mel,
|
||||||
@ -775,6 +813,12 @@ def run(rank, world_size, args):
|
|||||||
params.oov_id = tokenizer.oov_id
|
params.oov_id = tokenizer.oov_id
|
||||||
params.vocab_size = tokenizer.vocab_size
|
params.vocab_size = tokenizer.vocab_size
|
||||||
|
|
||||||
|
vctk = VctkTtsDataModule(args)
|
||||||
|
|
||||||
|
train_cuts = vctk.train_cuts()
|
||||||
|
speaker_map = vctk.speakers()
|
||||||
|
params.num_spks = len(speaker_map)
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
@ -832,11 +876,6 @@ def run(rank, world_size, args):
|
|||||||
if params.inf_check:
|
if params.inf_check:
|
||||||
register_inf_check_hooks(model)
|
register_inf_check_hooks(model)
|
||||||
|
|
||||||
vctk = VctkTtsDataModule(args)
|
|
||||||
|
|
||||||
train_cuts = vctk.train_cuts()
|
|
||||||
speaker_map = vctk.speakers()
|
|
||||||
|
|
||||||
def remove_short_and_long_utt(c: Cut):
|
def remove_short_and_long_utt(c: Cut):
|
||||||
# Keep only utterances with duration between 1 second and 20 seconds
|
# Keep only utterances with duration between 1 second and 20 seconds
|
||||||
# You should use ../local/display_manifest_statistics.py to get
|
# You should use ../local/display_manifest_statistics.py to get
|
||||||
|
Loading…
x
Reference in New Issue
Block a user