Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 16:14:17 +00:00)

Commit 0e7f0a4ee9 ("updated"), parent 100e622c18
@@ -288,13 +288,19 @@ def get_model(params: AttributeDict) -> nn.Module:
     return model


-def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
+def prepare_input(
+    batch: dict,
+    tokenizer: Tokenizer,
+    device: torch.device,
+    speaker_map: Dict[str, int],
+):
     """Parse batch data"""
     audio = batch["audio"].to(device)
     features = batch["features"].to(device)
     audio_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)
     tokens = batch["tokens"]
+    speakers = torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).to(device)

     tokens = tokenizer.tokens_to_token_ids(tokens)
     tokens = k2.RaggedTensor(tokens)
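The heart of the change is this one lookup. A minimal, self-contained sketch of its behaviour, with hypothetical VCTK speaker IDs, assuming batch["speakers"] carries one speaker-ID string per cut:

    import torch

    speaker_map = {"p225": 0, "p226": 1, "p227": 2}  # hypothetical mapping
    batch_speakers = ["p226", "p225", "p226"]        # one ID per cut

    # As in the patch: torch.Tensor(...) builds a float32 tensor of indices.
    speakers = torch.Tensor([speaker_map[sid] for sid in batch_speakers])
    print(speakers)        # tensor([1., 0., 1.])
    print(speakers.dtype)  # torch.float32

Since torch.Tensor yields float32, any embedding lookup downstream has to cast the indices (e.g. sid.long()) before indexing.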
@@ -305,7 +311,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
     # a tensor of shape (B, T)
     tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)

-    return audio, audio_lens, features, features_lens, tokens, tokens_lens
+    return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers


 def train_one_epoch(
@@ -318,6 +324,7 @@ def train_one_epoch(
     scheduler_d: LRSchedulerType,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
+    speaker_map: Dict[str, int],
     scaler: GradScaler,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -384,9 +391,15 @@ def train_one_epoch(
         params.batch_idx_train += 1

         batch_size = len(batch["tokens"])
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
-            batch, tokenizer, device
-        )
+        (
+            audio,
+            audio_lens,
+            features,
+            features_lens,
+            tokens,
+            tokens_lens,
+            speakers,
+        ) = prepare_input(batch, tokenizer, device, speaker_map)

         loss_info = MetricsTracker()
         loss_info["samples"] = batch_size
@@ -401,6 +414,7 @@ def train_one_epoch(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=False,
                 )
             for k, v in stats_d.items():
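For context on why sid is threaded into every model call: a multi-speaker VITS generator typically maps the speaker index through an embedding table and injects it as a global conditioning vector. A toy stand-in to illustrate the mechanism (this is not the icefall VITS implementation):

    import torch
    import torch.nn as nn

    class ToySpeakerConditioner(nn.Module):
        def __init__(self, num_speakers: int, dim: int = 16):
            super().__init__()
            self.emb = nn.Embedding(num_speakers, dim)

        def forward(self, x: torch.Tensor, sid: torch.Tensor) -> torch.Tensor:
            g = self.emb(sid.long())   # (B, dim); cast because sid arrives as float
            return x + g.unsqueeze(1)  # broadcast over the time axis

    x = torch.zeros(2, 5, 16)          # (B, T, dim) dummy features
    sid = torch.Tensor([0, 2])         # float speaker indices, as in the patch
    out = ToySpeakerConditioner(num_speakers=3)(x, sid)
    print(out.shape)                   # torch.Size([2, 5, 16])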
@@ -419,6 +433,7 @@ def train_one_epoch(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=True,
                     return_sample=params.batch_idx_train % params.log_interval == 0,
                 )
@@ -526,6 +541,7 @@ def train_one_epoch(
                 model=model,
                 tokenizer=tokenizer,
                 valid_dl=valid_dl,
+                speaker_map=speaker_map,
                 world_size=world_size,
             )
             model.train()
@@ -562,6 +578,7 @@ def compute_validation_loss(
     model: Union[nn.Module, DDP],
     tokenizer: Tokenizer,
     valid_dl: torch.utils.data.DataLoader,
+    speaker_map: Dict[str, int],
     world_size: int = 1,
     rank: int = 0,
 ) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
@@ -583,7 +600,8 @@ def compute_validation_loss(
                 features_lens,
                 tokens,
                 tokens_lens,
-            ) = prepare_input(batch, tokenizer, device)
+                speakers,
+            ) = prepare_input(batch, tokenizer, device, speaker_map)

             loss_info = MetricsTracker()
             loss_info["samples"] = batch_size
@@ -596,6 +614,7 @@ def compute_validation_loss(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=False,
                 )
             assert loss_d.requires_grad is False
@@ -610,6 +629,7 @@ def compute_validation_loss(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=True,
                 )
             assert loss_g.requires_grad is False
@@ -653,6 +673,7 @@ def scan_pessimistic_batches_for_oom(
    tokenizer: Tokenizer,
    optimizer_g: torch.optim.Optimizer,
    optimizer_d: torch.optim.Optimizer,
+    speaker_map: Dict[str, int],
    params: AttributeDict,
):
    from lhotse.dataset import find_pessimistic_batches
@@ -664,9 +685,15 @@ def scan_pessimistic_batches_for_oom(
     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
-            batch, tokenizer, device
-        )
+        (
+            audio,
+            audio_lens,
+            features,
+            features_lens,
+            tokens,
+            tokens_lens,
+            speakers,
+        ) = prepare_input(batch, tokenizer, device, speaker_map)
         try:
             # for discriminator
             with autocast(enabled=params.use_fp16):
@@ -677,6 +704,7 @@ def scan_pessimistic_batches_for_oom(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=False,
                 )
             optimizer_d.zero_grad()
@@ -690,6 +718,7 @@ def scan_pessimistic_batches_for_oom(
                     feats_lengths=features_lens,
                     speech=audio,
                     speech_lengths=audio_lens,
+                    sid=speakers,
                     forward_generator=True,
                 )
             optimizer_g.zero_grad()
@@ -803,9 +832,10 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)

-    ljspeech = VctkTtsDataModule(args)
+    vctk = VctkTtsDataModule(args)

-    train_cuts = ljspeech.train_cuts()
+    train_cuts = vctk.train_cuts()
+    speaker_map = vctk.speakers()

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
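The wiring in run() is now: build the datamodule, read the speaker map once, and hand the same dict to every consumer. Deriving the mapping from a fixed speakers.txt, rather than scanning the cuts, presumably keeps indices stable across runs and DDP ranks. A stub sketch of the flow, with placeholder classes standing in for the real ones:

    # StubDataModule stands in for VctkTtsDataModule; names are illustrative.
    class StubDataModule:
        def speakers(self) -> dict:
            return {"p225": 0, "p226": 1}

    def stub_train_one_epoch(speaker_map: dict) -> list:
        # The real train_one_epoch forwards speaker_map on to prepare_input().
        return sorted(speaker_map)

    vctk = StubDataModule()
    speaker_map = vctk.speakers()
    print(stub_train_one_epoch(speaker_map=speaker_map))  # ['p225', 'p226']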
@@ -820,10 +850,10 @@ def run(rank, world_size, args):
         return True

     train_cuts = train_cuts.filter(remove_short_and_long_utt)
-    train_dl = ljspeech.train_dataloaders(train_cuts)
+    train_dl = vctk.train_dataloaders(train_cuts)

-    valid_cuts = ljspeech.valid_cuts()
-    valid_dl = ljspeech.valid_dataloaders(valid_cuts)
+    valid_cuts = vctk.valid_cuts()
+    valid_dl = vctk.valid_dataloaders(valid_cuts)

     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
@@ -832,6 +862,7 @@ def run(rank, world_size, args):
             tokenizer=tokenizer,
             optimizer_g=optimizer_g,
             optimizer_d=optimizer_d,
+            speaker_map=speaker_map,
             params=params,
         )

@@ -861,6 +892,7 @@ def run(rank, world_size, args):
             scheduler_d=scheduler_d,
             train_dl=train_dl,
             valid_dl=valid_dl,
+            speaker_map=speaker_map,
             scaler=scaler,
             tb_writer=tb_writer,
             world_size=world_size,
@@ -88,6 +88,12 @@ class VctkTtsDataModule:
             default=Path("data/spectrogram"),
             help="Path to directory with train/valid/test cuts.",
         )
+        group.add_argument(
+            "--speakers",
+            type=Path,
+            default=Path("data/speakers.txt"),
+            help="Path to speakers.txt file.",
+        )
         group.add_argument(
             "--max-duration",
             type=int,
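The new --speakers flag is a plain argparse Path option. A self-contained sketch of how it parses, using a bare ArgumentParser rather than the full VctkTtsDataModule wiring (the example path is hypothetical):

    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group("TTS data related options")
    group.add_argument(
        "--speakers",
        type=Path,
        default=Path("data/speakers.txt"),
        help="Path to speakers.txt file.",
    )
    args = parser.parse_args(["--speakers", "data/vctk/speakers.txt"])
    print(args.speakers)  # data/vctk/speakers.txt, as a Path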
@@ -306,20 +312,21 @@ class VctkTtsDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz")

     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get validation cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz")

     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz")
+
+    @lru_cache()
+    def speakers(self) -> Dict[str, int]:
+        logging.info("About to get speakers")
+        with open(self.args.speakers) as f:
+            speakers = {line.strip(): i for i, line in enumerate(f)}
+        return speakers
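speakers() assumes speakers.txt lists one speaker ID per line; line order fixes the integer index, so reordering the file renumbers every speaker. A toy round trip of the same comprehension, using a temporary file:

    import os
    import tempfile

    # Toy speakers.txt: one speaker ID per line; order fixes the index.
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
        f.write("p225\np226\np227\n")

    with open(f.name) as fin:
        speaker_map = {line.strip(): i for i, line in enumerate(fin)}
    os.unlink(f.name)

    print(speaker_map)  # {'p225': 0, 'p226': 1, 'p227': 2}

Because speakers() is wrapped in @lru_cache(), the file is read once per process, so every caller in run() sees the identical mapping.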