diff --git a/egs/vctk/TTS/vits/train.py b/egs/vctk/TTS/vits/train.py index 367f7c108..8cc597a8e 100755 --- a/egs/vctk/TTS/vits/train.py +++ b/egs/vctk/TTS/vits/train.py @@ -274,10 +274,48 @@ def get_model(params: AttributeDict) -> nn.Module: "frame_length": params.frame_length, "frame_shift": params.frame_shift, } + generator_params = { + "hidden_channels": 192, + "spks": params.num_spks, + "langs": None, + "spk_embed_dim": None, + "global_channels": -1, + "segment_size": 32, + "text_encoder_attention_heads": 2, + "text_encoder_ffn_expand": 4, + "text_encoder_cnn_module_kernel": 5, + "text_encoder_blocks": 6, + "text_encoder_dropout_rate": 0.1, + "decoder_kernel_size": 7, + "decoder_channels": 512, + "decoder_upsample_scales": [8, 8, 2, 2], + "decoder_upsample_kernel_sizes": [16, 16, 4, 4], + "decoder_resblock_kernel_sizes": [3, 7, 11], + "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "use_weight_norm_in_decoder": True, + "posterior_encoder_kernel_size": 5, + "posterior_encoder_layers": 16, + "posterior_encoder_stacks": 1, + "posterior_encoder_base_dilation": 1, + "posterior_encoder_dropout_rate": 0.0, + "use_weight_norm_in_posterior_encoder": True, + "flow_flows": 4, + "flow_kernel_size": 5, + "flow_base_dilation": 1, + "flow_layers": 4, + "flow_dropout_rate": 0.0, + "use_weight_norm_in_flow": True, + "use_only_mean_in_flow": True, + "stochastic_duration_predictor_kernel_size": 3, + "stochastic_duration_predictor_dropout_rate": 0.5, + "stochastic_duration_predictor_flows": 4, + "stochastic_duration_predictor_dds_conv_layers": 3, + } model = VITS( vocab_size=params.vocab_size, feature_dim=params.feature_dim, sampling_rate=params.sampling_rate, + generator_params=generator_params, mel_loss_params=mel_loss_params, lambda_adv=params.lambda_adv, lambda_mel=params.lambda_mel, @@ -775,6 +813,12 @@ def run(rank, world_size, args): params.oov_id = tokenizer.oov_id params.vocab_size = tokenizer.vocab_size + vctk = VctkTtsDataModule(args) + + train_cuts = vctk.train_cuts() + speaker_map = vctk.speakers() + params.num_spks = len(speaker_map) + logging.info(params) logging.info("About to create model") @@ -832,11 +876,6 @@ def run(rank, world_size, args): if params.inf_check: register_inf_check_hooks(model) - vctk = VctkTtsDataModule(args) - - train_cuts = vctk.train_cuts() - speaker_map = vctk.speakers() - def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds # You should use ../local/display_manifest_statistics.py to get