diff --git a/egs/libritts/TTS/vits/infer.py b/egs/libritts/TTS/vits/infer.py
index fda07b791..675678606 100755
--- a/egs/libritts/TTS/vits/infer.py
+++ b/egs/libritts/TTS/vits/infer.py
@@ -152,7 +152,7 @@ def infer_dataset(
             audio_lens = batch["audio_lens"].tolist()
             cut_ids = [cut.id for cut in batch["cut"]]
             sids = ["_".join(cut_id.split("_")[:2]) for cut_id in cut_ids]
-            speakers = (
+            spembs = (
                 torch.Tensor(np.array([speaker_map.read(sid) for sid in sids]))
                 .squeeze(1)
                 .to(device)
@@ -161,7 +161,7 @@
             audio_pred, _, durations = model.inference_batch(
                 text=tokens,
                 text_lengths=tokens_lens,
-                spembs=speakers,
+                spembs=spembs,
             )
             audio_pred = audio_pred.detach().cpu()
             # convert to samples
diff --git a/egs/libritts/TTS/vits/train.py b/egs/libritts/TTS/vits/train.py
index d89de9608..67864bdce 100755
--- a/egs/libritts/TTS/vits/train.py
+++ b/egs/libritts/TTS/vits/train.py
@@ -344,7 +344,7 @@ def prepare_input(
     audio_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)
     tokens = batch["tokens"]
-    speakers = (
+    spembs = (
         torch.Tensor(np.array([speaker_map.read(sid) for sid in parse_sids(batch)]))
         .squeeze(1)
         .to(device)
@@ -361,7 +361,7 @@
     # a tensor of shape (B, T)
     tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)

-    return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers
+    return audio, audio_lens, features, features_lens, tokens, tokens_lens, spembs


 def train_one_epoch(
@@ -449,7 +449,7 @@
             features_lens,
             tokens,
             tokens_lens,
-            speakers,
+            spembs,
         ) = prepare_input(batch, tokenizer, device, train_speaker_map)

         loss_info = MetricsTracker()
@@ -465,7 +465,7 @@
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
-                    spembs=speakers,
+                    spembs=spembs,
                     forward_generator=False,
                 )
             for k, v in stats_d.items():
@@ -484,7 +484,7 @@
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
-                    spembs=speakers,
+                    spembs=spembs,
                     forward_generator=True,
                     return_sample=params.batch_idx_train % params.log_interval == 0,
                 )
@@ -651,7 +651,7 @@
                 features_lens,
                 tokens,
                 tokens_lens,
-                speakers,
+                spembs,
             ) = prepare_input(batch, tokenizer, device, dev_speaker_map)

             loss_info = MetricsTracker()
@@ -665,7 +665,7 @@
                 feats_lengths=features_lens,
                 speech=audio,
                 speech_lengths=audio_lens,
-                spembs=speakers,
+                spembs=spembs,
                 forward_generator=False,
             )
             assert loss_d.requires_grad is False
@@ -680,7 +680,7 @@
                 feats_lengths=features_lens,
                 speech=audio,
                 speech_lengths=audio_lens,
-                spembs=speakers,
+                spembs=spembs,
                 forward_generator=True,
             )
             assert loss_g.requires_grad is False
@@ -695,7 +695,7 @@
         inner_model = model.module if isinstance(model, DDP) else model
         audio_pred, _, duration = inner_model.inference(
             text=tokens[0, : tokens_lens[0].item()],
-            spembs=speakers[0],
+            spembs=spembs[0],
         )
         audio_pred = audio_pred.data.cpu().numpy()
         audio_len_pred = (
@@ -744,7 +744,7 @@
             features_lens,
             tokens,
             tokens_lens,
-            speakers,
+            spembs,
         ) = prepare_input(batch, tokenizer, device, train_speaker_map)
         try:
             # for discriminator
@@ -756,7 +756,7 @@
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
-                    spembs=speakers,
+                    spembs=spembs,
                     forward_generator=False,
                 )
             optimizer_d.zero_grad()
@@ -770,7 +770,7 @@ def scan_pessimistic_batches_for_oom(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
-                    spembs=speakers,
+                    spembs=spembs,
                     forward_generator=True,
                 )
             optimizer_g.zero_grad()
diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py
index b9add9e82..521b0121f 100644
--- a/egs/ljspeech/TTS/vits/generator.py
+++ b/egs/ljspeech/TTS/vits/generator.py
@@ -409,7 +409,12 @@ class VITSGenerator(torch.nn.Module):
             g = self.global_emb(sids.view(-1)).unsqueeze(-1)
         if self.spk_embed_dim is not None:
             # (B, global_channels, 1)
-            g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
+            if spembs.ndim == 2:
+                g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1)
+            elif spembs.ndim == 1:
+                g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
+            else:
+                raise ValueError("spembs should be 1D or 2D (batch mode) tensor.")
             if g is None:
                 g = g_
             else: