diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py
index 8d779e850..07390f7e7 100755
--- a/egs/librispeech/ASR/conformer_ctc/ali.py
+++ b/egs/librispeech/ASR/conformer_ctc/ali.py
@@ -18,6 +18,7 @@
 import argparse
 import logging
 from pathlib import Path
+from typing import List, Tuple
 
 import k2
 import torch
@@ -32,6 +33,7 @@ from icefall.utils import (
     AttributeDict,
     encode_supervisions,
     get_alignments,
+    save_alignments,
     setup_logger,
 )
 
@@ -56,14 +58,33 @@ def get_parser():
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
     )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--ali-dir",
+        type=str,
+        default="data/ali",
+        help="The dir to save alignments",
+    )
     return parser
 
 
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc/exp"),
-            "lang_dir": Path("data/lang_bpe"),
             "lm_dir": Path("data/lm"),
             "feature_dim": 80,
             "nhead": 8,
@@ -71,8 +92,6 @@
             "subsampling_factor": 4,
             "num_decoder_layers": 6,
             "vgg_frontend": False,
-            "is_espnet_structure": True,
-            "mmi_loss": False,
             "use_feat_batchnorm": True,
             "output_beam": 10,
             "use_double_scores": True,
@@ -86,9 +105,31 @@ def compute_alignments(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     graph_compiler: BpeCtcTrainingGraphCompiler,
-    token_table: k2.SymbolTable,
-):
+) -> List[Tuple[str, List[int]]]:
+    """Compute the framewise alignments of a dataset.
+
+    Args:
+      model:
+        The neural network model.
+      dl:
+        Dataloader containing the dataset.
+      params:
+        Parameters for computing alignments.
+      graph_compiler:
+        It converts token IDs to decoding graphs.
+    Returns:
+      Return a list of tuples. Each tuple contains two entries:
+        - Utterance ID
+        - Framewise alignments (token IDs) after subsampling
+    """
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+    num_cuts = 0
+
     device = graph_compiler.device
+    ans = []
 
     for batch_idx, batch in enumerate(dl):
         feature = batch["inputs"]
@@ -97,11 +138,23 @@
         feature = feature.to(device)
 
         supervisions = batch["supervisions"]
+
+        cut_ids = []
+        for cut in supervisions["cut"]:
+            assert len(cut.supervisions) == 1
+            cut_ids.append(cut.id)
+
         nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
         # nnet_output is [N, T, C]
         supervision_segments, texts = encode_supervisions(
             supervisions, subsampling_factor=params.subsampling_factor
         )
+        # We also need to sort cut_ids, since encode_supervisions()
+        # reorders "texts".
+        # In general, new2old is an identity map since lhotse sorts the returned
+        # cuts by duration in descending order.
+        new2old = supervision_segments[:, 0].tolist()
+        cut_ids = [cut_ids[i] for i in new2old]
 
         token_ids = graph_compiler.texts_to_ids(texts)
         decoding_graph = graph_compiler.compile(token_ids)
@@ -113,22 +166,30 @@
         )
 
         lattice = k2.intersect_dense(
-            decoding_graph, dense_fsa_vec, params.output_beam
+            decoding_graph,
+            dense_fsa_vec,
+            params.output_beam,
         )
 
         best_path = one_best_decoding(
-            lattice=lattice, use_double_scores=params.use_double_scores
+            lattice=lattice,
+            use_double_scores=params.use_double_scores,
         )
 
         ali_ids = get_alignments(best_path)
-        ali_tokens = [[token_table[i] for i in ids] for ids in ali_ids]
+        assert len(ali_ids) == len(cut_ids)
+        ans += list(zip(cut_ids, ali_ids))
 
-        frame_shift = 0.01  # 10ms, i.e., 0.01 seconds
-        for i, ali in enumerate(ali_tokens[0]):
-            print(i * params.subsampling_factor * frame_shift, ali)
-        import sys
+        num_cuts += len(ali_ids)
 
-        sys.exit(0)
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
+
+    return ans
 
 
 @torch.no_grad()
@@ -138,6 +199,7 @@ def main():
     args = parser.parse_args()
 
     assert args.return_cuts is True
+    assert args.concatenate_cuts is False
 
     params = get_params()
     params.update(vars(args))
@@ -169,9 +231,7 @@
         num_classes=num_classes,
         subsampling_factor=params.subsampling_factor,
         num_decoder_layers=params.num_decoder_layers,
-        vgg_frontend=False,
-        is_espnet_structure=params.is_espnet_structure,
-        mmi_loss=params.mmi_loss,
+        vgg_frontend=params.vgg_frontend,
         use_feat_batchnorm=params.use_feat_batchnorm,
     )
 
@@ -190,20 +250,40 @@
     model.eval()
 
     librispeech = LibriSpeechAsrDataModule(args)
+
+    train_dl = librispeech.train_dataloaders()
+    valid_dl = librispeech.valid_dataloaders()
     test_dl = librispeech.test_dataloaders()  # a list
 
+    ali_dir = Path(params.ali_dir)
+    ali_dir.mkdir(exist_ok=True)
+
     enabled_datasets = {
        "test_clean": test_dl[0],
        "test_other": test_dl[1],
+       "train-960": train_dl,
+       "valid": valid_dl,
     }
-
-    compute_alignments(
-        model=model,
-        dl=enabled_datasets["test_clean"],
-        params=params,
-        graph_compiler=graph_compiler,
-        token_table=lexicon.token_table,
-    )
+    for name, dl in enabled_datasets.items():
+        logging.info(f"Processing {name}")
+        alignments = compute_alignments(
+            model=model,
+            dl=dl,
+            params=params,
+            graph_compiler=graph_compiler,
+        )
+        num_utt = len(alignments)
+        alignments = dict(alignments)
+        assert num_utt == len(alignments)
+        filename = ali_dir / f"{name}.pt"
+        save_alignments(
+            alignments=alignments,
+            subsampling_factor=params.subsampling_factor,
+            filename=filename,
+        )
+        logging.info(
+            f"For dataset {name}, its alignments are saved to {filename}"
+        )
 
 
 torch.set_num_threads(1)
diff --git a/icefall/utils.py b/icefall/utils.py
index 36312f9af..cb6cc17c5 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -325,6 +325,51 @@ def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
     return labels.tolist()
 
 
+def save_alignments(
+    alignments: Dict[str, List[int]],
+    subsampling_factor: int,
+    filename: str,
+) -> None:
+    """Save alignments to a file.
+
+    Args:
+      alignments:
+        A dict containing alignments. Keys of the dict are utterance IDs and
+        values are the corresponding framewise alignments after subsampling.
+      subsampling_factor:
+        The subsampling factor of the model.
+      filename:
+        Path to save the alignments.
+    Returns:
+      Return None.
+    """
+    ali_dict = {
+        "subsampling_factor": subsampling_factor,
+        "alignments": alignments,
+    }
+    torch.save(ali_dict, filename)
+
+
+def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
+    """Load alignments from a file.
+
+    Args:
+      filename:
+        Path to the file containing alignment information.
+        The file should be saved by :func:`save_alignments`.
+    Returns:
+      Return a tuple containing:
+        - subsampling_factor: The subsampling factor used to compute
+          the alignments.
+        - alignments: A dict containing utterance IDs and their corresponding
+          framewise alignments after subsampling.
+    """
+    ali_dict = torch.load(filename)
+    subsampling_factor = ali_dict["subsampling_factor"]
+    alignments = ali_dict["alignments"]
+    return subsampling_factor, alignments
+
+
 def store_transcripts(
     filename: Pathlike, texts: Iterable[Tuple[str, str]]
 ) -> None: