From 0e7f0a4ee9550f5743a8486fa93c9ee82aed1959 Mon Sep 17 00:00:00 2001
From: jinzr
Date: Thu, 30 Nov 2023 22:11:44 +0800
Subject: [PATCH] Support multi-speaker training in the VCTK VITS recipe

Thread a speaker-to-index map through the training loop so that speaker
IDs are passed to the model as `sid`:

- `prepare_input()` looks up each utterance's speaker in `speaker_map`
  and returns an additional `speakers` tensor.
- `train_one_epoch()`, `compute_validation_loss()` and
  `scan_pessimistic_batches_for_oom()` take the map and forward
  `sid=speakers` to every model call.
- `VctkTtsDataModule` gains a `--speakers` argument (default:
  `data/speakers.txt`, one speaker ID per line) and a `speakers()`
  method that maps each speaker to its line index.
- Leftovers from the LJSpeech recipe (the `ljspeech` variable and the
  `ljspeech_cuts_*.jsonl.gz` manifest names) are renamed to their VCTK
  counterparts.
---
 egs/vctk/TTS/vits/train.py          | 60 ++++++++++++++++++++++-------
 egs/vctk/TTS/vits/tts_datamodule.py | 25 +++++++-----
 2 files changed, 62 insertions(+), 23 deletions(-)

diff --git a/egs/vctk/TTS/vits/train.py b/egs/vctk/TTS/vits/train.py
index 1dfe92685..367f7c108 100755
--- a/egs/vctk/TTS/vits/train.py
+++ b/egs/vctk/TTS/vits/train.py
@@ -288,13 +288,19 @@ def get_model(params: AttributeDict) -> nn.Module:
     return model
 
 
-def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
+def prepare_input(
+    batch: dict,
+    tokenizer: Tokenizer,
+    device: torch.device,
+    speaker_map: Dict[str, int],
+):
     """Parse batch data"""
     audio = batch["audio"].to(device)
     features = batch["features"].to(device)
     audio_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)
     tokens = batch["tokens"]
+    speakers = torch.tensor([speaker_map[sid] for sid in batch["speakers"]]).to(device)
 
     tokens = tokenizer.tokens_to_token_ids(tokens)
     tokens = k2.RaggedTensor(tokens)
@@ -305,7 +311,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     # a tensor of shape (B, T)
     tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
 
-    return audio, audio_lens, features, features_lens, tokens, tokens_lens
+    return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers
 
 
 def train_one_epoch(
@@ -318,6 +324,7 @@ def train_one_epoch(
     scheduler_d: LRSchedulerType,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
+    speaker_map: Dict[str, int],
     scaler: GradScaler,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -384,9 +391,15 @@ def train_one_epoch(
         params.batch_idx_train += 1
 
         batch_size = len(batch["tokens"])
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
-            batch, tokenizer, device
-        )
+        (
+            audio,
+            audio_lens,
+            features,
+            features_lens,
+            tokens,
+            tokens_lens,
+            speakers,
+        ) = prepare_input(batch, tokenizer, device, speaker_map)
 
         loss_info = MetricsTracker()
         loss_info["samples"] = batch_size
@@ -401,6 +414,7 @@ def train_one_epoch(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=False,
                 )
             for k, v in stats_d.items():
@@ -419,6 +433,7 @@ def train_one_epoch(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=True,
                     return_sample=params.batch_idx_train % params.log_interval == 0,
                 )
@@ -526,6 +541,7 @@ def train_one_epoch(
                 model=model,
                 tokenizer=tokenizer,
                 valid_dl=valid_dl,
+                speaker_map=speaker_map,
                 world_size=world_size,
             )
             model.train()
@@ -562,6 +578,7 @@ def compute_validation_loss(
     model: Union[nn.Module, DDP],
     tokenizer: Tokenizer,
     valid_dl: torch.utils.data.DataLoader,
+    speaker_map: Dict[str, int],
     world_size: int = 1,
     rank: int = 0,
 ) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
@@ -583,7 +600,8 @@ def compute_validation_loss(
                 features_lens,
                 tokens,
                 tokens_lens,
-            ) = prepare_input(batch, tokenizer, device)
+                speakers,
+            ) = prepare_input(batch, tokenizer, device, speaker_map)
 
             loss_info = MetricsTracker()
             loss_info["samples"] = batch_size
@@ -596,6 +614,7 @@ def compute_validation_loss(
                 feats_lengths=features_lens,
                 speech=audio,
                 speech_lengths=audio_lens,
+                sid=speakers,
                 forward_generator=False,
             )
             assert loss_d.requires_grad is False
@@ -610,6 +629,7 @@ def compute_validation_loss(
                 feats_lengths=features_lens,
                 speech=audio,
                 speech_lengths=audio_lens,
+                sid=speakers,
                 forward_generator=True,
             )
             assert loss_g.requires_grad is False
@@ -653,6 +673,7 @@ def scan_pessimistic_batches_for_oom(
     tokenizer: Tokenizer,
     optimizer_g: torch.optim.Optimizer,
     optimizer_d: torch.optim.Optimizer,
+    speaker_map: Dict[str, int],
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
@@ -664,9 +685,15 @@ def scan_pessimistic_batches_for_oom(
     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
-            batch, tokenizer, device
-        )
+        (
+            audio,
+            audio_lens,
+            features,
+            features_lens,
+            tokens,
+            tokens_lens,
+            speakers,
+        ) = prepare_input(batch, tokenizer, device, speaker_map)
         try:
             # for discriminator
             with autocast(enabled=params.use_fp16):
@@ -677,6 +704,7 @@ def scan_pessimistic_batches_for_oom(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=False,
                 )
             optimizer_d.zero_grad()
@@ -690,6 +718,7 @@ def scan_pessimistic_batches_for_oom(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=True,
                 )
             optimizer_g.zero_grad()
@@ -803,9 +832,10 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)
 
-    ljspeech = VctkTtsDataModule(args)
+    vctk = VctkTtsDataModule(args)
 
-    train_cuts = ljspeech.train_cuts()
+    train_cuts = vctk.train_cuts()
+    speaker_map = vctk.speakers()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -820,10 +850,10 @@ def run(rank, world_size, args):
         return True
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
-    train_dl = ljspeech.train_dataloaders(train_cuts)
+    train_dl = vctk.train_dataloaders(train_cuts)
 
-    valid_cuts = ljspeech.valid_cuts()
-    valid_dl = ljspeech.valid_dataloaders(valid_cuts)
+    valid_cuts = vctk.valid_cuts()
+    valid_dl = vctk.valid_dataloaders(valid_cuts)
 
     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
@@ -832,6 +862,7 @@ def run(rank, world_size, args):
             tokenizer=tokenizer,
             optimizer_g=optimizer_g,
             optimizer_d=optimizer_d,
+            speaker_map=speaker_map,
             params=params,
         )
 
@@ -861,6 +892,7 @@ def run(rank, world_size, args):
             scheduler_d=scheduler_d,
             train_dl=train_dl,
             valid_dl=valid_dl,
+            speaker_map=speaker_map,
             scaler=scaler,
             tb_writer=tb_writer,
             world_size=world_size,
diff --git a/egs/vctk/TTS/vits/tts_datamodule.py b/egs/vctk/TTS/vits/tts_datamodule.py
index 93f39e329..f7772d6d6 100644
--- a/egs/vctk/TTS/vits/tts_datamodule.py
+++ b/egs/vctk/TTS/vits/tts_datamodule.py
@@ -88,6 +88,12 @@ class VctkTtsDataModule:
             default=Path("data/spectrogram"),
             help="Path to directory with train/valid/test cuts.",
         )
+        group.add_argument(
+            "--speakers",
+            type=Path,
+            default=Path("data/speakers.txt"),
+            help="Path to speakers.txt file.",
+        )
         group.add_argument(
             "--max-duration",
             type=int,
@@ -306,20 +312,21 @@ class VctkTtsDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz")
 
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get validation cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz")
 
     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz")
+
+    @lru_cache()
+    def speakers(self) -> Dict[str, int]:
+        logging.info("About to get speakers")
+        with open(self.args.speakers) as f:
+            speakers = {line.strip(): i for i, line in enumerate(f)}
+        return speakers