mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-16 04:32:19 +00:00
minor updates
This commit is contained in:
parent
ca3b495c4f
commit
3c3db1ae69
@ -152,7 +152,7 @@ def infer_dataset(
|
||||
audio_lens = batch["audio_lens"].tolist()
|
||||
cut_ids = [cut.id for cut in batch["cut"]]
|
||||
sids = ["_".join(cut_id.split("_")[:2]) for cut_id in cut_ids]
|
||||
speakers = (
|
||||
spembs = (
|
||||
torch.Tensor(np.array([speaker_map.read(sid) for sid in sids]))
|
||||
.squeeze(1)
|
||||
.to(device)
|
||||
@ -161,7 +161,7 @@ def infer_dataset(
|
||||
audio_pred, _, durations = model.inference_batch(
|
||||
text=tokens,
|
||||
text_lengths=tokens_lens,
|
||||
spembs=speakers,
|
||||
spembs=spembs,
|
||||
)
|
||||
audio_pred = audio_pred.detach().cpu()
|
||||
# convert to samples
|
||||
|
@ -344,7 +344,7 @@ def prepare_input(
|
||||
audio_lens = batch["audio_lens"].to(device)
|
||||
features_lens = batch["features_lens"].to(device)
|
||||
tokens = batch["tokens"]
|
||||
speakers = (
|
||||
spembs = (
|
||||
torch.Tensor(np.array([speaker_map.read(sid) for sid in parse_sids(batch)]))
|
||||
.squeeze(1)
|
||||
.to(device)
|
||||
@ -361,7 +361,7 @@ def prepare_input(
|
||||
# a tensor of shape (B, T)
|
||||
tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
|
||||
|
||||
return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers
|
||||
return audio, audio_lens, features, features_lens, tokens, tokens_lens, spembs
|
||||
|
||||
|
||||
def train_one_epoch(
|
||||
@ -449,7 +449,7 @@ def train_one_epoch(
|
||||
features_lens,
|
||||
tokens,
|
||||
tokens_lens,
|
||||
speakers,
|
||||
spembs,
|
||||
) = prepare_input(batch, tokenizer, device, train_speaker_map)
|
||||
|
||||
loss_info = MetricsTracker()
|
||||
@ -465,7 +465,7 @@ def train_one_epoch(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
spembs=speakers,
|
||||
spembs=spembs,
|
||||
forward_generator=False,
|
||||
)
|
||||
for k, v in stats_d.items():
|
||||
@ -484,7 +484,7 @@ def train_one_epoch(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
spembs=speakers,
|
||||
spembs=spembs,
|
||||
forward_generator=True,
|
||||
return_sample=params.batch_idx_train % params.log_interval == 0,
|
||||
)
|
||||
@ -651,7 +651,7 @@ def compute_validation_loss(
|
||||
features_lens,
|
||||
tokens,
|
||||
tokens_lens,
|
||||
speakers,
|
||||
spembs,
|
||||
) = prepare_input(batch, tokenizer, device, dev_speaker_map)
|
||||
|
||||
loss_info = MetricsTracker()
|
||||
@ -665,7 +665,7 @@ def compute_validation_loss(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
spembs=speakers,
|
||||
spembs=spembs,
|
||||
forward_generator=False,
|
||||
)
|
||||
assert loss_d.requires_grad is False
|
||||
@ -680,7 +680,7 @@ def compute_validation_loss(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
spembs=speakers,
|
||||
spembs=spembs,
|
||||
forward_generator=True,
|
||||
)
|
||||
assert loss_g.requires_grad is False
|
||||
@ -695,7 +695,7 @@ def compute_validation_loss(
|
||||
inner_model = model.module if isinstance(model, DDP) else model
|
||||
audio_pred, _, duration = inner_model.inference(
|
||||
text=tokens[0, : tokens_lens[0].item()],
|
||||
spembs=speakers[0],
|
||||
spembs=spembs[0],
|
||||
)
|
||||
audio_pred = audio_pred.data.cpu().numpy()
|
||||
audio_len_pred = (
|
||||
@ -744,7 +744,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
features_lens,
|
||||
tokens,
|
||||
tokens_lens,
|
||||
speakers,
|
||||
spembs,
|
||||
) = prepare_input(batch, tokenizer, device, train_speaker_map)
|
||||
try:
|
||||
# for discriminator
|
||||
@ -756,7 +756,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
spembs=speakers,
|
||||
spembs=spembs,
|
||||
forward_generator=False,
|
||||
)
|
||||
optimizer_d.zero_grad()
|
||||
@ -770,7 +770,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
spembs=speakers,
|
||||
spembs=spembs,
|
||||
forward_generator=True,
|
||||
)
|
||||
optimizer_g.zero_grad()
|
||||
|
@ -409,7 +409,12 @@ class VITSGenerator(torch.nn.Module):
|
||||
g = self.global_emb(sids.view(-1)).unsqueeze(-1)
|
||||
if self.spk_embed_dim is not None:
|
||||
# (B, global_channels, 1)
|
||||
g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
|
||||
if spembs.ndim == 2:
|
||||
g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1)
|
||||
elif spembs.ndim == 1:
|
||||
g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
|
||||
else:
|
||||
raise ValueError("spembs should be 1D or 2D (batch mode) tensor.")
|
||||
if g is None:
|
||||
g = g_
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user