minor updates

This commit is contained in:
zr_jin 2024-10-22 12:53:44 +08:00
parent ca3b495c4f
commit 3c3db1ae69
3 changed files with 20 additions and 15 deletions

View File

@ -152,7 +152,7 @@ def infer_dataset(
audio_lens = batch["audio_lens"].tolist()
cut_ids = [cut.id for cut in batch["cut"]]
sids = ["_".join(cut_id.split("_")[:2]) for cut_id in cut_ids]
speakers = (
spembs = (
torch.Tensor(np.array([speaker_map.read(sid) for sid in sids]))
.squeeze(1)
.to(device)
@ -161,7 +161,7 @@ def infer_dataset(
audio_pred, _, durations = model.inference_batch(
text=tokens,
text_lengths=tokens_lens,
spembs=speakers,
spembs=spembs,
)
audio_pred = audio_pred.detach().cpu()
# convert to samples

View File

@ -344,7 +344,7 @@ def prepare_input(
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
tokens = batch["tokens"]
speakers = (
spembs = (
torch.Tensor(np.array([speaker_map.read(sid) for sid in parse_sids(batch)]))
.squeeze(1)
.to(device)
@ -361,7 +361,7 @@ def prepare_input(
# a tensor of shape (B, T)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers
return audio, audio_lens, features, features_lens, tokens, tokens_lens, spembs
def train_one_epoch(
@ -449,7 +449,7 @@ def train_one_epoch(
features_lens,
tokens,
tokens_lens,
speakers,
spembs,
) = prepare_input(batch, tokenizer, device, train_speaker_map)
loss_info = MetricsTracker()
@ -465,7 +465,7 @@ def train_one_epoch(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
spembs=speakers,
spembs=spembs,
forward_generator=False,
)
for k, v in stats_d.items():
@ -484,7 +484,7 @@ def train_one_epoch(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
spembs=speakers,
spembs=spembs,
forward_generator=True,
return_sample=params.batch_idx_train % params.log_interval == 0,
)
@ -651,7 +651,7 @@ def compute_validation_loss(
features_lens,
tokens,
tokens_lens,
speakers,
spembs,
) = prepare_input(batch, tokenizer, device, dev_speaker_map)
loss_info = MetricsTracker()
@ -665,7 +665,7 @@ def compute_validation_loss(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
spembs=speakers,
spembs=spembs,
forward_generator=False,
)
assert loss_d.requires_grad is False
@ -680,7 +680,7 @@ def compute_validation_loss(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
spembs=speakers,
spembs=spembs,
forward_generator=True,
)
assert loss_g.requires_grad is False
@ -695,7 +695,7 @@ def compute_validation_loss(
inner_model = model.module if isinstance(model, DDP) else model
audio_pred, _, duration = inner_model.inference(
text=tokens[0, : tokens_lens[0].item()],
spembs=speakers[0],
spembs=spembs[0],
)
audio_pred = audio_pred.data.cpu().numpy()
audio_len_pred = (
@ -744,7 +744,7 @@ def scan_pessimistic_batches_for_oom(
features_lens,
tokens,
tokens_lens,
speakers,
spembs,
) = prepare_input(batch, tokenizer, device, train_speaker_map)
try:
# for discriminator
@ -756,7 +756,7 @@ def scan_pessimistic_batches_for_oom(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
spembs=speakers,
spembs=spembs,
forward_generator=False,
)
optimizer_d.zero_grad()
@ -770,7 +770,7 @@ def scan_pessimistic_batches_for_oom(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
spembs=speakers,
spembs=spembs,
forward_generator=True,
)
optimizer_g.zero_grad()

View File

@ -409,7 +409,12 @@ class VITSGenerator(torch.nn.Module):
g = self.global_emb(sids.view(-1)).unsqueeze(-1)
if self.spk_embed_dim is not None:
# (B, global_channels, 1)
g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
if spembs.ndim == 2:
g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1)
elif spembs.ndim == 1:
g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
else:
raise ValueError("spembs should be 1D or 2D (batch mode) tensor.")
if g is None:
g = g_
else: