Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-16 20:52:18 +00:00
Commit 3c3db1ae69 (parent ca3b495c4f)

    minor updates
@@ -152,7 +152,7 @@ def infer_dataset(
         audio_lens = batch["audio_lens"].tolist()
         cut_ids = [cut.id for cut in batch["cut"]]
         sids = ["_".join(cut_id.split("_")[:2]) for cut_id in cut_ids]
-        speakers = (
+        spembs = (
             torch.Tensor(np.array([speaker_map.read(sid) for sid in sids]))
             .squeeze(1)
             .to(device)
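Note: the rename from speakers to spembs makes the local variable match the spembs keyword it is passed to in the next hunk. For reference, a minimal self-contained sketch of the shape handling in this hunk, assuming each speaker_map.read(sid) returns a speaker embedding of shape (1, spk_embed_dim); the stand-in dictionary and sizes below are hypothetical:

import numpy as np
import torch

spk_embed_dim = 192  # hypothetical embedding size
sids = ["p225_001", "p226_002"]
# stand-in for speaker_map.read(sid); assumed to yield shape (1, spk_embed_dim)
fake_map = {sid: np.random.randn(1, spk_embed_dim).astype(np.float32) for sid in sids}

spembs = (
    torch.Tensor(np.array([fake_map[sid] for sid in sids]))  # (B, 1, spk_embed_dim)
    .squeeze(1)                                              # (B, spk_embed_dim)
)
print(spembs.shape)  # torch.Size([2, 192])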
@@ -161,7 +161,7 @@ def infer_dataset(
         audio_pred, _, durations = model.inference_batch(
             text=tokens,
             text_lengths=tokens_lens,
-            spembs=speakers,
+            spembs=spembs,
         )
         audio_pred = audio_pred.detach().cpu()
         # convert to samples
@@ -344,7 +344,7 @@ def prepare_input(
     audio_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)
     tokens = batch["tokens"]
-    speakers = (
+    spembs = (
         torch.Tensor(np.array([speaker_map.read(sid) for sid in parse_sids(batch)]))
         .squeeze(1)
         .to(device)
@@ -361,7 +361,7 @@ def prepare_input(
     # a tensor of shape (B, T)
     tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)

-    return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers
+    return audio, audio_lens, features, features_lens, tokens, tokens_lens, spembs


 def train_one_epoch(
@@ -449,7 +449,7 @@ def train_one_epoch(
             features_lens,
             tokens,
             tokens_lens,
-            speakers,
+            spembs,
         ) = prepare_input(batch, tokenizer, device, train_speaker_map)

         loss_info = MetricsTracker()
@@ -465,7 +465,7 @@ def train_one_epoch(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
-                    spembs=speakers,
+                    spembs=spembs,
                     forward_generator=False,
                 )
                 for k, v in stats_d.items():
@@ -484,7 +484,7 @@ def train_one_epoch(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
-                    spembs=speakers,
+                    spembs=spembs,
                     forward_generator=True,
                     return_sample=params.batch_idx_train % params.log_interval == 0,
                 )
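Note: the surrounding training loop calls the same model twice per step, once with forward_generator=False to get the discriminator loss and once with forward_generator=True to get the generator loss. A self-contained toy sketch of that alternating update follows; everything except the forward_generator flag and the optimizer_d/optimizer_g pattern visible in this diff is hypothetical:

import torch

class ToyGAN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gen = torch.nn.Linear(8, 8)
        self.disc = torch.nn.Linear(8, 1)

    def forward(self, x, forward_generator: bool):
        fake = self.gen(torch.randn_like(x))
        if forward_generator:
            # generator pass: try to fool the discriminator
            return -self.disc(fake).mean(), {"pass": "generator"}
        # discriminator pass: score real above the detached fake
        return (self.disc(fake.detach()) - self.disc(x)).mean(), {"pass": "discriminator"}

model = ToyGAN()
optimizer_d = torch.optim.Adam(model.disc.parameters())
optimizer_g = torch.optim.Adam(model.gen.parameters())
x = torch.randn(4, 8)

loss_d, stats_d = model(x, forward_generator=False)
optimizer_d.zero_grad()
loss_d.backward()
optimizer_d.step()

loss_g, stats_g = model(x, forward_generator=True)
optimizer_g.zero_grad()
loss_g.backward()
optimizer_g.step()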
@@ -651,7 +651,7 @@ def compute_validation_loss(
                 features_lens,
                 tokens,
                 tokens_lens,
-                speakers,
+                spembs,
             ) = prepare_input(batch, tokenizer, device, dev_speaker_map)

             loss_info = MetricsTracker()
@@ -665,7 +665,7 @@ def compute_validation_loss(
                 feats_lengths=features_lens,
                 speech=audio,
                 speech_lengths=audio_lens,
-                spembs=speakers,
+                spembs=spembs,
                 forward_generator=False,
             )
             assert loss_d.requires_grad is False
@@ -680,7 +680,7 @@ def compute_validation_loss(
                 feats_lengths=features_lens,
                 speech=audio,
                 speech_lengths=audio_lens,
-                spembs=speakers,
+                spembs=spembs,
                 forward_generator=True,
             )
             assert loss_g.requires_grad is False
@@ -695,7 +695,7 @@ def compute_validation_loss(
             inner_model = model.module if isinstance(model, DDP) else model
             audio_pred, _, duration = inner_model.inference(
                 text=tokens[0, : tokens_lens[0].item()],
-                spembs=speakers[0],
+                spembs=spembs[0],
             )
             audio_pred = audio_pred.data.cpu().numpy()
             audio_len_pred = (
@@ -744,7 +744,7 @@ def scan_pessimistic_batches_for_oom(
             features_lens,
             tokens,
             tokens_lens,
-            speakers,
+            spembs,
         ) = prepare_input(batch, tokenizer, device, train_speaker_map)
         try:
             # for discriminator
@@ -756,7 +756,7 @@ def scan_pessimistic_batches_for_oom(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
-                    spembs=speakers,
+                    spembs=spembs,
                     forward_generator=False,
                 )
                 optimizer_d.zero_grad()
@@ -770,7 +770,7 @@ def scan_pessimistic_batches_for_oom(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
-                    spembs=speakers,
+                    spembs=spembs,
                     forward_generator=True,
                 )
                 optimizer_g.zero_grad()
@@ -409,7 +409,12 @@ class VITSGenerator(torch.nn.Module):
             g = self.global_emb(sids.view(-1)).unsqueeze(-1)
         if self.spk_embed_dim is not None:
             # (B, global_channels, 1)
-            g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
+            if spembs.ndim == 2:
+                g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1)
+            elif spembs.ndim == 1:
+                g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
+            else:
+                raise ValueError("spembs should be 1D or 2D (batch mode) tensor.")
             if g is None:
                 g = g_
             else:
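Note: the removed line unconditionally applied spembs.unsqueeze(0), which is only correct for a single unbatched embedding of shape (spk_embed_dim,), as passed by compute_validation_loss via spembs[0]. inference_batch passes a (B, spk_embed_dim) tensor, which the old path would normalize across the batch dimension and project into a misshapen (1, B, global_channels, 1) result; the new 2D branch fixes that. A self-contained sketch of the dispatch, assuming spemb_proj is a Linear(spk_embed_dim, global_channels) projection as in ESPnet-style VITS; the sizes are hypothetical:

import torch
import torch.nn.functional as F

spk_embed_dim, global_channels = 256, 192  # hypothetical sizes
spemb_proj = torch.nn.Linear(spk_embed_dim, global_channels)

def project(spembs: torch.Tensor) -> torch.Tensor:
    # mirrors the branch added above; output is (B, global_channels, 1)
    if spembs.ndim == 2:  # batched: (B, spk_embed_dim)
        return spemb_proj(F.normalize(spembs)).unsqueeze(-1)
    elif spembs.ndim == 1:  # single utterance: (spk_embed_dim,)
        return spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
    raise ValueError("spembs should be 1D or 2D (batch mode) tensor.")

print(project(torch.randn(4, spk_embed_dim)).shape)  # torch.Size([4, 192, 1])
print(project(torch.randn(spk_embed_dim)).shape)     # torch.Size([1, 192, 1])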