jinzr 2023-11-30 22:11:44 +08:00
parent 100e622c18
commit 0e7f0a4ee9
2 changed files with 62 additions and 23 deletions

View File

@@ -288,13 +288,19 @@ def get_model(params: AttributeDict) -> nn.Module:
     return model


-def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
+def prepare_input(
+    batch: dict,
+    tokenizer: Tokenizer,
+    device: torch.device,
+    speaker_map: Dict[str, int],
+):
     """Parse batch data"""
     audio = batch["audio"].to(device)
     features = batch["features"].to(device)
     audio_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)
     tokens = batch["tokens"]
+    speakers = torch.tensor([speaker_map[sid] for sid in batch["speakers"]]).to(device)

     tokens = tokenizer.tokens_to_token_ids(tokens)
     tokens = k2.RaggedTensor(tokens)
@@ -305,7 +311,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     # a tensor of shape (B, T)
     tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)

-    return audio, audio_lens, features, features_lens, tokens, tokens_lens
+    return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers


 def train_one_epoch(
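With this change, `prepare_input` maps each utterance's speaker ID string through `speaker_map` and returns the resulting index tensor alongside the audio and feature tensors. A minimal sketch of that lookup, using a hypothetical mapping and batch fragment (both illustrative, not taken from the recipe):

```python
import torch

# Hypothetical speaker map, shaped like VctkTtsDataModule.speakers() output:
# speaker ID string -> integer index, in file order.
speaker_map = {"p225": 0, "p226": 1, "p227": 2}

# Hypothetical batch fragment; a real batch also carries audio, features,
# tokens, and their lengths.
batch = {"speakers": ["p226", "p225", "p226"]}

# Same lookup as in prepare_input: torch.tensor over a list of Python ints
# yields an int64 tensor, which is what an nn.Embedding lookup expects.
speakers = torch.tensor([speaker_map[sid] for sid in batch["speakers"]])
print(speakers)  # tensor([1, 0, 1])
```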
@@ -318,6 +324,7 @@ def train_one_epoch(
     scheduler_d: LRSchedulerType,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
+    speaker_map: Dict[str, int],
     scaler: GradScaler,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -384,9 +391,15 @@ def train_one_epoch(
         params.batch_idx_train += 1

         batch_size = len(batch["tokens"])
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
-            batch, tokenizer, device
-        )
+        (
+            audio,
+            audio_lens,
+            features,
+            features_lens,
+            tokens,
+            tokens_lens,
+            speakers,
+        ) = prepare_input(batch, tokenizer, device, speaker_map)

         loss_info = MetricsTracker()
         loss_info["samples"] = batch_size
@@ -401,6 +414,7 @@ def train_one_epoch(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=False,
                 )
             for k, v in stats_d.items():
@@ -419,6 +433,7 @@ def train_one_epoch(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=True,
                     return_sample=params.batch_idx_train % params.log_interval == 0,
                 )
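The new `sid=speakers` keyword is passed on both the discriminator pass (`forward_generator=False`) and the generator pass (`forward_generator=True`). In a multi-speaker VITS setup such integer IDs are typically consumed by a speaker embedding table inside the generator; a toy sketch of that pattern (illustrative only, not the recipe's actual module):

```python
import torch
import torch.nn as nn


class SpeakerConditioner(nn.Module):
    """Toy module: turns an integer speaker ID into a conditioning vector."""

    def __init__(self, num_speakers: int, embed_dim: int = 256):
        super().__init__()
        self.spk_embed = nn.Embedding(num_speakers, embed_dim)

    def forward(self, sid: torch.Tensor) -> torch.Tensor:
        # sid: int64 tensor of shape (B,) -> conditioning of shape (B, embed_dim)
        return self.spk_embed(sid)


# In the recipe, num_speakers would come from len(speaker_map); 108 is illustrative.
cond = SpeakerConditioner(num_speakers=108)
print(cond(torch.tensor([0, 5, 17])).shape)  # torch.Size([3, 256])
```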
@@ -526,6 +541,7 @@ def train_one_epoch(
                 model=model,
                 tokenizer=tokenizer,
                 valid_dl=valid_dl,
+                speaker_map=speaker_map,
                 world_size=world_size,
             )
             model.train()
@@ -562,6 +578,7 @@ def compute_validation_loss(
     model: Union[nn.Module, DDP],
     tokenizer: Tokenizer,
     valid_dl: torch.utils.data.DataLoader,
+    speaker_map: Dict[str, int],
     world_size: int = 1,
     rank: int = 0,
 ) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
@@ -583,7 +600,8 @@ def compute_validation_loss(
                 features_lens,
                 tokens,
                 tokens_lens,
-            ) = prepare_input(batch, tokenizer, device)
+                speakers,
+            ) = prepare_input(batch, tokenizer, device, speaker_map)

             loss_info = MetricsTracker()
             loss_info["samples"] = batch_size
@@ -596,6 +614,7 @@ def compute_validation_loss(
                 feats_lengths=features_lens,
                 speech=audio,
                 speech_lengths=audio_lens,
+                sid=speakers,
                 forward_generator=False,
             )
             assert loss_d.requires_grad is False
@@ -610,6 +629,7 @@ def compute_validation_loss(
                 feats_lengths=features_lens,
                 speech=audio,
                 speech_lengths=audio_lens,
+                sid=speakers,
                 forward_generator=True,
             )
             assert loss_g.requires_grad is False
@@ -653,6 +673,7 @@ def scan_pessimistic_batches_for_oom(
     tokenizer: Tokenizer,
     optimizer_g: torch.optim.Optimizer,
     optimizer_d: torch.optim.Optimizer,
+    speaker_map: Dict[str, int],
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
@@ -664,9 +685,15 @@ def scan_pessimistic_batches_for_oom(
     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
-            batch, tokenizer, device
-        )
+        (
+            audio,
+            audio_lens,
+            features,
+            features_lens,
+            tokens,
+            tokens_lens,
+            speakers,
+        ) = prepare_input(batch, tokenizer, device, speaker_map)
         try:
             # for discriminator
             with autocast(enabled=params.use_fp16):
@@ -677,6 +704,7 @@ def scan_pessimistic_batches_for_oom(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=False,
                 )
             optimizer_d.zero_grad()
@@ -690,6 +718,7 @@ def scan_pessimistic_batches_for_oom(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=True,
                 )
             optimizer_g.zero_grad()
@@ -803,9 +832,10 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)

-    ljspeech = VctkTtsDataModule(args)
+    vctk = VctkTtsDataModule(args)

-    train_cuts = ljspeech.train_cuts()
+    train_cuts = vctk.train_cuts()
+    speaker_map = vctk.speakers()

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -820,10 +850,10 @@ def run(rank, world_size, args):
         return True

     train_cuts = train_cuts.filter(remove_short_and_long_utt)
-    train_dl = ljspeech.train_dataloaders(train_cuts)
+    train_dl = vctk.train_dataloaders(train_cuts)

-    valid_cuts = ljspeech.valid_cuts()
-    valid_dl = ljspeech.valid_dataloaders(valid_cuts)
+    valid_cuts = vctk.valid_cuts()
+    valid_dl = vctk.valid_dataloaders(valid_cuts)

     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
@@ -832,6 +862,7 @@ def run(rank, world_size, args):
             tokenizer=tokenizer,
             optimizer_g=optimizer_g,
             optimizer_d=optimizer_d,
+            speaker_map=speaker_map,
             params=params,
         )
@@ -861,6 +892,7 @@ def run(rank, world_size, args):
             scheduler_d=scheduler_d,
             train_dl=train_dl,
             valid_dl=valid_dl,
+            speaker_map=speaker_map,
             scaler=scaler,
             tb_writer=tb_writer,
             world_size=world_size,

View File

@@ -88,6 +88,12 @@ class VctkTtsDataModule:
             default=Path("data/spectrogram"),
             help="Path to directory with train/valid/test cuts.",
         )
+        group.add_argument(
+            "--speakers",
+            type=Path,
+            default=Path("data/speakers.txt"),
+            help="Path to speakers.txt file.",
+        )
         group.add_argument(
             "--max-duration",
             type=int,
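The new `--speakers` flag points at a plain-text file with one speaker ID per line; line order determines each speaker's integer index. A hypothetical three-speaker file, parsed the same way as the `speakers()` method below:

```python
from pathlib import Path

# Hypothetical speakers file: one VCTK speaker ID per line, no blank lines.
Path("speakers.txt").write_text("p225\np226\np227\n")

# Same parsing as VctkTtsDataModule.speakers():
with open("speakers.txt") as f:
    speaker_map = {line.strip(): i for i, line in enumerate(f)}
print(speaker_map)  # {'p225': 0, 'p226': 1, 'p227': 2}
```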
@@ -306,20 +312,21 @@ class VctkTtsDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz")

     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get validation cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz")

     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz")
+
+    @lru_cache()
+    def speakers(self) -> Dict[str, int]:
+        logging.info("About to get speakers")
+        with open(self.args.speakers) as f:
+            speakers = {line.strip(): i for i, line in enumerate(f)}
+        return speakers
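`speakers()` is wrapped in `@lru_cache()`, so the file is read once and the same dict is handed to `train_one_epoch`, `compute_validation_loss`, and `scan_pessimistic_batches_for_oom`. One caveat worth noting: the comprehension assigns an index to every line, so a stray blank line would map the empty string to an index. A runnable check of that behavior (the `parse_speakers` helper is hypothetical, written only to exercise the same comprehension):

```python
import io


def parse_speakers(f):
    # Same comprehension as VctkTtsDataModule.speakers()
    return {line.strip(): i for i, line in enumerate(f)}


print(parse_speakers(io.StringIO("p225\np226\n")))
# {'p225': 0, 'p226': 1}

print(parse_speakers(io.StringIO("p225\n\np226\n")))
# {'p225': 0, '': 1, 'p226': 2}  <- blank line becomes a bogus entry
```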