diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py
index 25a8fdfc5..7bd97dfe7 100755
--- a/egs/librispeech/ASR/conformer_ctc/ali.py
+++ b/egs/librispeech/ASR/conformer_ctc/ali.py
@@ -24,13 +24,12 @@ Usage:
         --avg 10 \
         --max-duration 300 \
         --dataset train-clean-100 \
-        --out-dir data/token-ali
+        --out-dir data/ali
 """

 import argparse
 import logging
 from pathlib import Path
-from typing import List, Tuple

 import k2
 import numpy as np
@@ -49,7 +48,6 @@ from icefall.utils import (
     encode_supervisions,
     get_alignments,
     get_env_info,
-    save_alignments,
     setup_logger,
 )

@@ -94,16 +92,21 @@ def get_parser():
         type=str,
         required=True,
         help="""Output directory.
-        It contains the following generated files:
+        It contains 3 generated files:

-        - xxx.h5
+        - labels_xxx.h5
+        - aux_labels_xxx.h5
         - cuts_xxx.json.gz

         where xxx is the value of `--dataset`. For instance, if
-        `--dataset` is `train-clean-100`, it will contain two files:
+        `--dataset` is `train-clean-100`, it will contain 3 files:

-        - `train-clean-100.h5`
+        - `labels_train-clean-100.h5`
+        - `aux_labels_train-clean-100.h5`
         - `cuts_train-clean-100.json.gz`
+
+        Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise
+        alignments. The difference is that labels_xxx.h5 contains repeats.
         """,
     )

@@ -149,7 +152,8 @@ def get_params() -> AttributeDict:
 def compute_alignments(
     model: torch.nn.Module,
     dl: torch.utils.data.DataLoader,
-    writer: FeaturesWriter,
+    labels_writer: FeaturesWriter,
+    aux_labels_writer: FeaturesWriter,
     params: AttributeDict,
     graph_compiler: BpeCtcTrainingGraphCompiler,
 ) -> CutSet:
@@ -165,8 +169,10 @@ def compute_alignments(
       graph_compiler:
         It converts token IDs to decoding graphs.
     Returns:
-      Return a CutSet. Each cut has a custom field `token_alignment`
-      of type `lhotse.array.TemporalArray`.
+      Return a CutSet. Each cut has two custom fields: labels_alignment
+      and aux_labels_alignment, containing framewise alignment information.
+      Both are of type `lhotse.array.TemporalArray`. The difference between
+      the two alignments is that `labels_alignment` contains repeats.
     """
     try:
         num_batches = len(dl)
     except TypeError:
         num_batches = "?"
@@ -204,7 +210,6 @@ def compute_alignments(
         token_ids = graph_compiler.texts_to_ids(texts)

         decoding_graph = graph_compiler.compile(token_ids)
-        decoding_graph.tokens = decoding_graph.aux_labels.clone()

         dense_fsa_vec = k2.DenseFsaVec(
             nnet_output,
@@ -223,21 +228,32 @@
             use_double_scores=params.use_double_scores,
         )

-        ali_ids = get_alignments(best_path)
-        assert len(ali_ids) == len(cut_list)
-        for cut, ali in zip(cut_list, ali_ids):
-
-            cut.token_alignment = writer.store_array(
+        labels_ali = get_alignments(best_path, kind="labels")
+        aux_labels_ali = get_alignments(best_path, kind="aux_labels")
+        assert len(labels_ali) == len(aux_labels_ali) == len(cut_list)
+        for cut, labels, aux_labels in zip(
+            cut_list, labels_ali, aux_labels_ali
+        ):
+            cut.labels_alignment = labels_writer.store_array(
                 key=cut.id,
-                value=np.asarray(ali, dtype=np.int32),
-                frame_shift=0.04,  # frame shift is 0.01s, subsampling_factor is 4
+                value=np.asarray(labels, dtype=np.int32),
+                # frame shift is 0.01s, subsampling_factor is 4
+                frame_shift=0.04,
+                temporal_dim=0,
+                start=0,
+            )
+            cut.aux_labels_alignment = aux_labels_writer.store_array(
+                key=cut.id,
+                value=np.asarray(aux_labels, dtype=np.int32),
+                # frame shift is 0.01s, subsampling_factor is 4
+                frame_shift=0.04,
                 temporal_dim=0,
                 start=0,
             )

         cuts += cut_list
-        num_cuts += len(ali_ids)
+        num_cuts += len(cut_list)

         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
@@ -271,15 +287,18 @@ def main():
     out_dir = Path(params.out_dir)
     out_dir.mkdir(exist_ok=True)

-    out_ali_filename = out_dir / f"{params.dataset}.h5"
-    if out_ali_filename.exists():
-        logging.info(f"{out_ali_filename} exists - skipping")
-        return
-
+    out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5"
+    out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5"
     out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"
-    if out_manifest_filename.exists():
-        logging.info(f"{out_manifest_filename} exists - skipping")
-        return
+
+    for f in (
+        out_labels_ali_filename,
+        out_aux_labels_ali_filename,
+        out_manifest_filename,
+    ):
+        if f.exists():
+            logging.info(f"{f} exists - skipping")
+            return

     lexicon = Lexicon(params.lang_dir)
     max_token_id = max(lexicon.tokens)
@@ -352,21 +371,24 @@ def main():
         dl = librispeech.valid_dataloaders(dev_other_cuts)

     logging.info(f"Processing {params.dataset}")
-    with NumpyHdf5Writer(out_ali_filename) as writer:
-        cut_set = compute_alignments(
-            model=model,
-            dl=dl,
-            writer=writer,
-            params=params,
-            graph_compiler=graph_compiler,
-        )
+    with NumpyHdf5Writer(out_labels_ali_filename) as labels_writer:
+        with NumpyHdf5Writer(out_aux_labels_ali_filename) as aux_labels_writer:
+            cut_set = compute_alignments(
+                model=model,
+                dl=dl,
+                labels_writer=labels_writer,
+                aux_labels_writer=aux_labels_writer,
+                params=params,
+                graph_compiler=graph_compiler,
+            )

-    cut_set.to_json(out_manifest_filename)
+    cut_set.to_file(out_manifest_filename)

     logging.info(
-        f"For dataset {params.dataset}, its alignments are "
-        f"saved to {out_ali_filename} and the cut manifest file "
-        f"is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
+        f"For dataset {params.dataset}, its alignments with repeats are "
+        f"saved to {out_labels_ali_filename}, the alignments without repeats "
+        f"are saved to {out_aux_labels_ali_filename}, and the cut manifest "
+        f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
     )
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index c11726c39..e075a2d03 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -19,7 +19,6 @@ import argparse
 import logging
 from functools import lru_cache
 from pathlib import Path
-from typing import List, Union

 from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (
diff --git a/icefall/utils.py b/icefall/utils.py
index 7f092fa32..da8c3b920 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -305,11 +305,8 @@ def get_texts(
     return aux_labels.tolist()


-def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
-    """Extract the token IDs (from best_paths.tokens) from the best-path FSAs.
-
-    Caution:
-      There are no repeats in `best_paths.tokens`.
+def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
+    """Extract labels or aux_labels from the best-path FSAs.

     Args:
       best_paths:
@@ -317,15 +314,21 @@
         containing multiple FSAs, which is expected to be the result
         of k2.shortest_path (otherwise the returned values won't
         be meaningful).
+      kind:
+        Possible values are: "labels" and "aux_labels". Caution: When it is
+        "labels", the resulting alignments contain repeats.
     Returns:
       Returns a list of lists of int, containing the token sequences
       we decoded. For `ans[i]`, its length equals to the number of
       frames after subsampling of the i-th utterance in the batch.
     """
+    assert kind in ("labels", "aux_labels")
     # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
     token_shape = best_paths.arcs.shape().remove_axis(1)
     # token_shape has axes [fsa][arc]
-    tokens = k2.RaggedTensor(token_shape, best_paths.tokens)
+    tokens = k2.RaggedTensor(
+        token_shape, getattr(best_paths, kind).contiguous()
+    )
     tokens = tokens.remove_values_eq(-1)
     return tokens.tolist()
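
Note (toy example, not part of the patch): the new `kind` argument selects which arc attribute is read off the best path. `k2.linear_fsa` and `k2.create_fsa_vec` are used here only to fabricate a tiny best path; real callers obtain `best_paths` from `k2.shortest_path`, as in ali.py above.

import torch
import k2

from icefall.utils import get_alignments

# A 3-frame linear "best path": labels carry one token per frame (note the
# repeat), while aux_labels carry each token once, with 0 elsewhere.
# linear_fsa appends a final arc labeled -1; get_alignments strips the -1s.
fsa = k2.linear_fsa([1, 1, 2])
fsa.aux_labels = torch.tensor([1, 0, 2, -1], dtype=torch.int32)
best_paths = k2.create_fsa_vec([fsa])

print(get_alignments(best_paths, kind="labels"))      # [[1, 1, 2]]
print(get_alignments(best_paths, kind="aux_labels"))  # [[1, 0, 2]]
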
diff --git a/test/test_ali.py b/test/test_ali.py
index 6de701d86..69dd82ab2 100755
--- a/test/test_ali.py
+++ b/test/test_ali.py
@@ -25,26 +25,15 @@

 from pathlib import Path

-import k2
-import torch
-from lhotse import load_manifest
+from lhotse import CutSet, load_manifest
 from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
-from torch.nn.utils.rnn import pad_sequence
+from lhotse.dataset.collation import collate_custom_field
 from torch.utils.data import DataLoader

-from icefall.ali import (
-    convert_alignments_to_tensor,
-    load_alignments,
-    lookup_alignments,
-)
-from icefall.decode import get_lattice, one_best_decoding
-from icefall.lexicon import Lexicon
-from icefall.utils import get_texts
-
 ICEFALL_DIR = Path(__file__).resolve().parent.parent
 egs_dir = ICEFALL_DIR / "egs/librispeech/ASR"
 lang_dir = egs_dir / "data/lang_bpe_500"
-cuts_json = egs_dir / "data/token_ali/cuts_test-clean.json.gz"
+cuts_json = egs_dir / "data/ali/cuts_dev-clean.json.gz"


 def data_exists():
@@ -53,10 +42,11 @@

 def get_dataloader():
     cuts = load_manifest(cuts_json)
+    print(cuts[0])
     cuts = cuts.with_features_path_prefix(egs_dir)
     sampler = SingleCutSampler(
         cuts,
-        max_duration=40,
+        max_duration=10,
         shuffle=False,
     )

@@ -75,14 +65,24 @@ def get_dataloader():
     return dl

 def test():
     if not data_exists():
         return
-    device = torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda", 0)
     dl = get_dataloader()
     for batch in dl:
         supervisions = batch["supervisions"]
         cuts = supervisions["cut"]
-        print(cuts)
+        labels_alignment, labels_alignment_length = collate_custom_field(
+            CutSet.from_cuts(cuts), "labels_alignment"
+        )
+
+        (
+            aux_labels_alignment,
+            aux_labels_alignment_length,
+        ) = collate_custom_field(CutSet.from_cuts(cuts), "aux_labels_alignment")
+
+        print(labels_alignment)
+        print(aux_labels_alignment)
+        print(labels_alignment_length)
+        print(aux_labels_alignment_length)
+        # print(cuts)
         break
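
Note (hypothetical helper, not part of the patch or of icefall): `collate_custom_field` pads the per-cut arrays to a common length and returns the true lengths, so consumers must mask with them. The helper below turns one collated `aux_labels_alignment` row into (token_id, seconds) pairs, assuming that 0 marks frames which emit no new token and using the 0.04 s frame shift stored by ali.py.

from typing import List, Tuple

import torch


def to_timestamps(
    aux_labels_row: torch.Tensor,
    num_frames: int,
    frame_shift: float = 0.04,  # matches the frame_shift stored in ali.py
) -> List[Tuple[int, float]]:
    row = aux_labels_row[:num_frames]  # drop padding via the returned length
    return [
        (int(tok), i * frame_shift)
        for i, tok in enumerate(row.tolist())
        if tok != 0  # assumption: 0 = no new token on this frame
    ]

# Inside test() above, one could then do:
#   print(to_timestamps(aux_labels_alignment[0],
#                       int(aux_labels_alignment_length[0])))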