jinzr 2023-11-30 22:11:44 +08:00
parent 100e622c18
commit 0e7f0a4ee9
2 changed files with 62 additions and 23 deletions

View File

@@ -288,13 +288,19 @@ def get_model(params: AttributeDict) -> nn.Module:
     return model


-def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
+def prepare_input(
+    batch: dict,
+    tokenizer: Tokenizer,
+    device: torch.device,
+    speaker_map: Dict[str, int],
+):
     """Parse batch data"""
     audio = batch["audio"].to(device)
     features = batch["features"].to(device)
     audio_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)
     tokens = batch["tokens"]
+    speakers = torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).to(device)

     tokens = tokenizer.tokens_to_token_ids(tokens)
     tokens = k2.RaggedTensor(tokens)
@@ -305,7 +311,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     # a tensor of shape (B, T)
     tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)

-    return audio, audio_lens, features, features_lens, tokens, tokens_lens
+    return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers


 def train_one_epoch(
@@ -318,6 +324,7 @@ def train_one_epoch(
     scheduler_d: LRSchedulerType,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
+    speaker_map: Dict[str, int],
     scaler: GradScaler,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -384,9 +391,15 @@ def train_one_epoch(
         params.batch_idx_train += 1

         batch_size = len(batch["tokens"])
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
-            batch, tokenizer, device
-        )
+        (
+            audio,
+            audio_lens,
+            features,
+            features_lens,
+            tokens,
+            tokens_lens,
+            speakers,
+        ) = prepare_input(batch, tokenizer, device, speaker_map)

         loss_info = MetricsTracker()
         loss_info["samples"] = batch_size
@@ -401,6 +414,7 @@ def train_one_epoch(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=False,
                 )
             for k, v in stats_d.items():
@@ -419,6 +433,7 @@ def train_one_epoch(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=True,
                     return_sample=params.batch_idx_train % params.log_interval == 0,
                 )
@@ -526,6 +541,7 @@ def train_one_epoch(
                 model=model,
                 tokenizer=tokenizer,
                 valid_dl=valid_dl,
+                speaker_map=speaker_map,
                 world_size=world_size,
             )
             model.train()
@@ -562,6 +578,7 @@ def compute_validation_loss(
     model: Union[nn.Module, DDP],
     tokenizer: Tokenizer,
     valid_dl: torch.utils.data.DataLoader,
+    speaker_map: Dict[str, int],
     world_size: int = 1,
     rank: int = 0,
 ) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
@@ -583,7 +600,8 @@ def compute_validation_loss(
                 features_lens,
                 tokens,
                 tokens_lens,
-            ) = prepare_input(batch, tokenizer, device)
+                speakers,
+            ) = prepare_input(batch, tokenizer, device, speaker_map)

             loss_info = MetricsTracker()
             loss_info["samples"] = batch_size
@@ -596,6 +614,7 @@ def compute_validation_loss(
                 feats_lengths=features_lens,
                 speech=audio,
                 speech_lengths=audio_lens,
+                sid=speakers,
                 forward_generator=False,
             )
             assert loss_d.requires_grad is False
@@ -610,6 +629,7 @@ def compute_validation_loss(
                 feats_lengths=features_lens,
                 speech=audio,
                 speech_lengths=audio_lens,
+                sid=speakers,
                 forward_generator=True,
             )
             assert loss_g.requires_grad is False
@@ -653,6 +673,7 @@ def scan_pessimistic_batches_for_oom(
     tokenizer: Tokenizer,
     optimizer_g: torch.optim.Optimizer,
     optimizer_d: torch.optim.Optimizer,
+    speaker_map: Dict[str, int],
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
@@ -664,9 +685,15 @@ def scan_pessimistic_batches_for_oom(
     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
-            batch, tokenizer, device
-        )
+        (
+            audio,
+            audio_lens,
+            features,
+            features_lens,
+            tokens,
+            tokens_lens,
+            speakers,
+        ) = prepare_input(batch, tokenizer, device, speaker_map)
         try:
             # for discriminator
             with autocast(enabled=params.use_fp16):
@@ -677,6 +704,7 @@ def scan_pessimistic_batches_for_oom(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=False,
                 )
             optimizer_d.zero_grad()
@@ -690,6 +718,7 @@ def scan_pessimistic_batches_for_oom(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=True,
                 )
             optimizer_g.zero_grad()
@@ -803,9 +832,10 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)

-    ljspeech = VctkTtsDataModule(args)
+    vctk = VctkTtsDataModule(args)

-    train_cuts = ljspeech.train_cuts()
+    train_cuts = vctk.train_cuts()
+    speaker_map = vctk.speakers()

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -820,10 +850,10 @@ def run(rank, world_size, args):
         return True

     train_cuts = train_cuts.filter(remove_short_and_long_utt)
-    train_dl = ljspeech.train_dataloaders(train_cuts)
+    train_dl = vctk.train_dataloaders(train_cuts)

-    valid_cuts = ljspeech.valid_cuts()
-    valid_dl = ljspeech.valid_dataloaders(valid_cuts)
+    valid_cuts = vctk.valid_cuts()
+    valid_dl = vctk.valid_dataloaders(valid_cuts)

     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
@@ -832,6 +862,7 @@ def run(rank, world_size, args):
             tokenizer=tokenizer,
             optimizer_g=optimizer_g,
             optimizer_d=optimizer_d,
+            speaker_map=speaker_map,
             params=params,
         )
@@ -861,6 +892,7 @@ def run(rank, world_size, args):
             scheduler_d=scheduler_d,
             train_dl=train_dl,
             valid_dl=valid_dl,
+            speaker_map=speaker_map,
             scaler=scaler,
             tb_writer=tb_writer,
             world_size=world_size,
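
The net effect of the train.py changes above is that every batch now carries a tensor of integer speaker IDs, which is passed to the model as sid. A minimal sketch of the mapping step, with made-up speaker labels and a stubbed batch (only the speaker handling is shown; the real prepare_input also moves audio, features and tokens to the device):

    import torch

    # Hypothetical values for illustration; in the recipe, speaker_map is built by
    # VctkTtsDataModule.speakers() and batch comes from the training dataloader.
    speaker_map = {"p225": 0, "p226": 1, "p227": 2}
    batch = {"speakers": ["p226", "p225", "p227"]}
    device = torch.device("cpu")

    # Mirrors the line added to prepare_input(): look up each per-cut speaker
    # label and move the resulting IDs to the training device.
    speakers = torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).to(device)
    print(speakers)  # tensor([1., 0., 2.])

Note that torch.Tensor yields a float tensor, so the model's sid handling is presumably expected to cast the IDs back to integers before any speaker-embedding lookup.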

View File

@@ -88,6 +88,12 @@ class VctkTtsDataModule:
             default=Path("data/spectrogram"),
             help="Path to directory with train/valid/test cuts.",
         )
+        group.add_argument(
+            "--speakers",
+            type=Path,
+            default=Path("data/speakers.txt"),
+            help="Path to speakers.txt file.",
+        )
         group.add_argument(
             "--max-duration",
             type=int,
@@ -306,20 +312,21 @@ class VctkTtsDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz")

     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get validation cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz")

     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz")
+
+    @lru_cache()
+    def speakers(self) -> Dict[str, int]:
+        logging.info("About to get speakers")
+        with open(self.args.speakers) as f:
+            speakers = {line.strip(): i for i, line in enumerate(f)}
+        return speakers