mirror of https://github.com/k2-fsa/icefall.git

commit 0e7f0a4ee9
parent 100e622c18

    updated
@@ -288,13 +288,19 @@ def get_model(params: AttributeDict) -> nn.Module:
     return model


-def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
+def prepare_input(
+    batch: dict,
+    tokenizer: Tokenizer,
+    device: torch.device,
+    speaker_map: Dict[str, int],
+):
     """Parse batch data"""
     audio = batch["audio"].to(device)
     features = batch["features"].to(device)
     audio_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)
     tokens = batch["tokens"]
+    speakers = torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).to(device)

     tokens = tokenizer.tokens_to_token_ids(tokens)
     tokens = k2.RaggedTensor(tokens)
@@ -305,7 +311,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     # a tensor of shape (B, T)
     tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)

-    return audio, audio_lens, features, features_lens, tokens, tokens_lens
+    return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers


 def train_one_epoch(
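Across these hunks the commit turns the single-speaker LJSpeech VITS training script into a multi-speaker VCTK one: prepare_input() now takes a speaker_map (speaker-ID string to integer index) and returns a per-utterance speakers tensor alongside its existing outputs. A minimal, self-contained sketch of just that lookup, with fabricated speaker IDs and batch contents:

import torch

# Hypothetical map, as VctkTtsDataModule.speakers() would build it from
# speakers.txt (see the data-module hunks further down).
speaker_map = {"p225": 0, "p226": 1, "p227": 2}

# A TTS batch carries one string speaker ID per cut.
batch = {"speakers": ["p226", "p225", "p226"]}

# Same expression as in the diff: look up each ID, build a tensor.
speakers = torch.Tensor([speaker_map[sid] for sid in batch["speakers"]])
print(speakers)  # tensor([1., 0., 1.])

Note that torch.Tensor(...) yields float32; if the model routes sid into an nn.Embedding it has to cast back to an integer dtype, so torch.tensor(..., dtype=torch.long) would arguably be the more direct construction.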
@@ -318,6 +324,7 @@ def train_one_epoch(
     scheduler_d: LRSchedulerType,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
+    speaker_map: Dict[str, int],
     scaler: GradScaler,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -384,9 +391,15 @@ def train_one_epoch(
         params.batch_idx_train += 1

         batch_size = len(batch["tokens"])
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
-            batch, tokenizer, device
-        )
+        (
+            audio,
+            audio_lens,
+            features,
+            features_lens,
+            tokens,
+            tokens_lens,
+            speakers,
+        ) = prepare_input(batch, tokenizer, device, speaker_map)

         loss_info = MetricsTracker()
         loss_info["samples"] = batch_size
@@ -401,6 +414,7 @@ def train_one_epoch(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=False,
                 )
             for k, v in stats_d.items():
@@ -419,6 +433,7 @@ def train_one_epoch(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=True,
                     return_sample=params.batch_idx_train % params.log_interval == 0,
                 )
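Passing sid=speakers into both the discriminator and the generator forward passes is what makes the model speaker-conditional. How the generator consumes it is outside this diff; the usual pattern is a learned per-speaker embedding added to the hidden states that condition synthesis. A toy sketch of that mechanism, under assumed shapes and a made-up module name (an illustration of the general idea, not the actual icefall VITS code):

import torch
import torch.nn as nn


class SpeakerConditioner(nn.Module):
    """Hypothetical module showing how a sid tensor typically conditions a decoder."""

    def __init__(self, num_speakers: int, hidden_dim: int):
        super().__init__()
        self.spk_emb = nn.Embedding(num_speakers, hidden_dim)

    def forward(self, hidden: torch.Tensor, sid: torch.Tensor) -> torch.Tensor:
        # hidden: (B, T, H) hidden states; sid: (B,) speaker indices.
        g = self.spk_emb(sid.long())  # (B, H) global speaker embedding
        return hidden + g.unsqueeze(1)  # broadcast over the time axis


cond = SpeakerConditioner(num_speakers=108, hidden_dim=192)  # speaker count illustrative
h = torch.randn(2, 50, 192)
sid = torch.tensor([3, 17])
print(cond(h, sid).shape)  # torch.Size([2, 50, 192])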
@@ -526,6 +541,7 @@ def train_one_epoch(
                     model=model,
                     tokenizer=tokenizer,
                     valid_dl=valid_dl,
+                    speaker_map=speaker_map,
                     world_size=world_size,
                 )
                 model.train()
@@ -562,6 +578,7 @@ def compute_validation_loss(
     model: Union[nn.Module, DDP],
     tokenizer: Tokenizer,
     valid_dl: torch.utils.data.DataLoader,
+    speaker_map: Dict[str, int],
     world_size: int = 1,
     rank: int = 0,
 ) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
@@ -583,7 +600,8 @@ def compute_validation_loss(
                 features_lens,
                 tokens,
                 tokens_lens,
-            ) = prepare_input(batch, tokenizer, device)
+                speakers,
+            ) = prepare_input(batch, tokenizer, device, speaker_map)

             loss_info = MetricsTracker()
             loss_info["samples"] = batch_size
@@ -596,6 +614,7 @@ def compute_validation_loss(
                 feats_lengths=features_lens,
                 speech=audio,
                 speech_lengths=audio_lens,
+                sid=speakers,
                 forward_generator=False,
             )
             assert loss_d.requires_grad is False
@@ -610,6 +629,7 @@ def compute_validation_loss(
                 feats_lengths=features_lens,
                 speech=audio,
                 speech_lengths=audio_lens,
+                sid=speakers,
                 forward_generator=True,
             )
             assert loss_g.requires_grad is False
@@ -653,6 +673,7 @@ def scan_pessimistic_batches_for_oom(
     tokenizer: Tokenizer,
     optimizer_g: torch.optim.Optimizer,
     optimizer_d: torch.optim.Optimizer,
+    speaker_map: Dict[str, int],
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
@@ -664,9 +685,15 @@ def scan_pessimistic_batches_for_oom(
     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
-            batch, tokenizer, device
-        )
+        (
+            audio,
+            audio_lens,
+            features,
+            features_lens,
+            tokens,
+            tokens_lens,
+            speakers,
+        ) = prepare_input(batch, tokenizer, device, speaker_map)
         try:
             # for discriminator
             with autocast(enabled=params.use_fp16):
@@ -677,6 +704,7 @@ def scan_pessimistic_batches_for_oom(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=False,
                 )
             optimizer_d.zero_grad()
@@ -690,6 +718,7 @@ def scan_pessimistic_batches_for_oom(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=True,
                 )
             optimizer_g.zero_grad()
@@ -803,9 +832,10 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)

-    ljspeech = VctkTtsDataModule(args)
+    vctk = VctkTtsDataModule(args)

-    train_cuts = ljspeech.train_cuts()
+    train_cuts = vctk.train_cuts()
+    speaker_map = vctk.speakers()

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -820,10 +850,10 @@ def run(rank, world_size, args):
         return True

     train_cuts = train_cuts.filter(remove_short_and_long_utt)
-    train_dl = ljspeech.train_dataloaders(train_cuts)
+    train_dl = vctk.train_dataloaders(train_cuts)

-    valid_cuts = ljspeech.valid_cuts()
-    valid_dl = ljspeech.valid_dataloaders(valid_cuts)
+    valid_cuts = vctk.valid_cuts()
+    valid_dl = vctk.valid_dataloaders(valid_cuts)

     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
@@ -832,6 +862,7 @@ def run(rank, world_size, args):
             tokenizer=tokenizer,
             optimizer_g=optimizer_g,
             optimizer_d=optimizer_d,
+            speaker_map=speaker_map,
             params=params,
         )

@@ -861,6 +892,7 @@ def run(rank, world_size, args):
             scheduler_d=scheduler_d,
             train_dl=train_dl,
             valid_dl=valid_dl,
+            speaker_map=speaker_map,
             scaler=scaler,
             tb_writer=tb_writer,
             world_size=world_size,
@@ -88,6 +88,12 @@ class VctkTtsDataModule:
             default=Path("data/spectrogram"),
             help="Path to directory with train/valid/test cuts.",
         )
+        group.add_argument(
+            "--speakers",
+            type=Path,
+            default=Path("data/speakers.txt"),
+            help="Path to speakers.txt file.",
+        )
         group.add_argument(
             "--max-duration",
             type=int,
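The new --speakers option points at a plain text file with one speaker ID per line; per the speakers() implementation in the next hunk, the line order fixes each speaker's integer index. A hypothetical round-trip with made-up file contents:

from pathlib import Path

# Fabricated example of what --speakers points at; real VCTK IDs look like "p225".
speakers_txt = Path("speakers.txt")
speakers_txt.write_text("p225\np226\np227\n")

# The exact parsing logic used by speakers() below.
with open(speakers_txt) as f:
    speaker_map = {line.strip(): i for i, line in enumerate(f)}
print(speaker_map)  # {'p225': 0, 'p226': 1, 'p227': 2}

Because the indices come from line order, the same speakers.txt presumably has to be reused at inference time for the sid values to keep their meaning.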
@@ -306,20 +312,21 @@ class VctkTtsDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz")

     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get validation cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz")

     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz")
+
+    @lru_cache()
+    def speakers(self) -> Dict[str, int]:
+        logging.info("About to get speakers")
+        with open(self.args.speakers) as f:
+            speakers = {line.strip(): i for i, line in enumerate(f)}
+        return speakers
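Like the cut getters, speakers() is wrapped in @lru_cache(), so the file is parsed once and every later caller (training loop, validation, the OOM scan) receives the same dict. A small stand-in class demonstrating that caching behavior (hypothetical names, not the data module itself):

from functools import lru_cache
from pathlib import Path
from typing import Dict


class SpeakerSource:
    """Hypothetical stand-in for the data module's speakers() method."""

    def __init__(self, path: Path):
        self.path = path

    @lru_cache()
    def speakers(self) -> Dict[str, int]:
        print("parsing", self.path)  # runs only on the first call
        with open(self.path) as f:
            return {line.strip(): i for i, line in enumerate(f)}


Path("speakers.txt").write_text("p225\np226\n")
src = SpeakerSource(Path("speakers.txt"))
first = src.speakers()   # prints "parsing speakers.txt"
second = src.speakers()  # cache hit: no re-parse, same dict object
assert first is second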