Update train.py

This commit is contained in:
jinzr 2023-11-30 22:32:07 +08:00
parent 0e7f0a4ee9
commit 2eb06451ee

View File

@ -274,10 +274,48 @@ def get_model(params: AttributeDict) -> nn.Module:
"frame_length": params.frame_length,
"frame_shift": params.frame_shift,
}
generator_params = {
"hidden_channels": 192,
"spks": params.num_spks,
"langs": None,
"spk_embed_dim": None,
"global_channels": -1,
"segment_size": 32,
"text_encoder_attention_heads": 2,
"text_encoder_ffn_expand": 4,
"text_encoder_cnn_module_kernel": 5,
"text_encoder_blocks": 6,
"text_encoder_dropout_rate": 0.1,
"decoder_kernel_size": 7,
"decoder_channels": 512,
"decoder_upsample_scales": [8, 8, 2, 2],
"decoder_upsample_kernel_sizes": [16, 16, 4, 4],
"decoder_resblock_kernel_sizes": [3, 7, 11],
"decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"use_weight_norm_in_decoder": True,
"posterior_encoder_kernel_size": 5,
"posterior_encoder_layers": 16,
"posterior_encoder_stacks": 1,
"posterior_encoder_base_dilation": 1,
"posterior_encoder_dropout_rate": 0.0,
"use_weight_norm_in_posterior_encoder": True,
"flow_flows": 4,
"flow_kernel_size": 5,
"flow_base_dilation": 1,
"flow_layers": 4,
"flow_dropout_rate": 0.0,
"use_weight_norm_in_flow": True,
"use_only_mean_in_flow": True,
"stochastic_duration_predictor_kernel_size": 3,
"stochastic_duration_predictor_dropout_rate": 0.5,
"stochastic_duration_predictor_flows": 4,
"stochastic_duration_predictor_dds_conv_layers": 3,
}
model = VITS(
vocab_size=params.vocab_size,
feature_dim=params.feature_dim,
sampling_rate=params.sampling_rate,
generator_params=generator_params,
mel_loss_params=mel_loss_params,
lambda_adv=params.lambda_adv,
lambda_mel=params.lambda_mel,
@ -775,6 +813,12 @@ def run(rank, world_size, args):
params.oov_id = tokenizer.oov_id
params.vocab_size = tokenizer.vocab_size
vctk = VctkTtsDataModule(args)
train_cuts = vctk.train_cuts()
speaker_map = vctk.speakers()
params.num_spks = len(speaker_map)
logging.info(params)
logging.info("About to create model")
@ -832,11 +876,6 @@ def run(rank, world_size, args):
if params.inf_check:
register_inf_check_hooks(model)
vctk = VctkTtsDataModule(args)
train_cuts = vctk.train_cuts()
speaker_map = vctk.speakers()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
# You should use ../local/display_manifest_statistics.py to get