From 2cb438c3f081b9fa01e012998a8601cc78c56c7a Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 8 Sep 2021 19:32:14 +0800
Subject: [PATCH] Extract framewise alignment information using CTC decoding.

---
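Notes (this section sits between the "---" marker and the diffstat, so
"git am" drops it): as a usage sketch, running
"./conformer_ctc/ali.py --epoch 34 --avg 20" from egs/librispeech/ASR
prints, for the first utterance of the first test-clean batch, one line
per output frame (the frame start time in seconds followed by the
decoded token) and then exits. The --epoch/--avg flags are defined in
get_parser() below; the dataset options come from
LibriSpeechAsrDataModule.add_arguments(), whose --return-cuts option is
assumed to default to True (main() asserts it is True).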
", + ) + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "exp_dir": Path("conformer_ctc/exp"), + "lang_dir": Path("data/lang_bpe"), + "lm_dir": Path("data/lm"), + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "subsampling_factor": 4, + "num_decoder_layers": 6, + "vgg_frontend": False, + "is_espnet_structure": True, + "mmi_loss": False, + "use_feat_batchnorm": True, + "output_beam": 10, + "use_double_scores": True, + } + ) + return params + + +def compute_alignments( + model: torch.nn.Module, + dl: torch.utils.data.DataLoader, + params: AttributeDict, + graph_compiler: BpeCtcTrainingGraphCompiler, + token_table: k2.SymbolTable, +): + device = graph_compiler.device + for batch_idx, batch in enumerate(dl): + feature = batch["inputs"] + + # at entry, feature is [N, T, C] + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + # nnet_output is [N, T, C] + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + + token_ids = graph_compiler.texts_to_ids(texts) + decoding_graph = graph_compiler.compile(token_ids) + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + lattice = k2.intersect_dense( + decoding_graph, dense_fsa_vec, params.output_beam + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + + ali_ids = get_alignments(best_path) + ali_tokens = [[token_table[i] for i in ids] for ids in ali_ids] + + frame_shift = 0.01 # 10ms, i.e., 0.01 seconds + for i, ali in enumerate(ali_tokens[0]): + print(i * params.subsampling_factor * frame_shift, ali) + import sys + + sys.exit(0) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert args.return_cuts is True + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log/ali") + logging.info("Computing alignment - started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + is_espnet_structure=params.is_espnet_structure, + mmi_loss=params.mmi_loss, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to(device) + model.eval() + + librispeech = LibriSpeechAsrDataModule(args) + test_dl = librispeech.test_dataloaders() # a list + + enabled_datasets = { + "test_clean": test_dl[0], + "test_other": 
+        "test_other": test_dl[1],
+    }
+
+    compute_alignments(
+        model=model,
+        dl=enabled_datasets["test_clean"],
+        params=params,
+        graph_compiler=graph_compiler,
+        token_table=lexicon.token_table,
+    )
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/icefall/decode.py b/icefall/decode.py
index 3f6e5fc84..e3a875b16 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -878,7 +878,7 @@ def rescore_with_attention_decoder(
 
     best_path_indexes = k2.index_select(new2old, argmax_indexes)
 
-    # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
+    # best_path is a k2.RaggedTensor with 2 axes [path][arc_pos]
     best_path, _ = path_2axes.index(
         indexes=best_path_indexes, axis=0, need_value_indexes=False
     )
diff --git a/icefall/utils.py b/icefall/utils.py
index 1130d8947..1016bcd35 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -219,6 +219,28 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
     return aux_labels.tolist()
 
 
+def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
+    """Extract the token IDs (from best_paths.labels) from the best-path FSAs.
+
+    Args:
+      best_paths:
+        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
+        containing multiple FSAs, which is expected to be the result
+        of k2.shortest_path (otherwise the returned values won't
+        be meaningful).
+    Returns:
+      Returns a list of lists of int, containing the token sequences we
+      decoded. The length of `ans[i]` equals the number of frames
+      (after subsampling) of the i-th utterance in the batch.
+    """
+    # best_paths.arcs.shape() has axes [fsa][state][arc]; remove the "state" axis
+    label_shape = best_paths.arcs.shape().remove_axis(1)
+    # label_shape has axes [fsa][arc]
+    labels = k2.RaggedTensor(label_shape, best_paths.labels.contiguous())
+    labels = labels.remove_values_eq(-1)
+    return labels.tolist()
+
+
 def store_transcripts(
     filename: Pathlike, texts: Iterable[Tuple[str, str]]
 ) -> None:
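
A post-processing sketch (not part of the patch): ali.py prints one
(time, token) pair per output frame, including CTC blanks and repeats.
The helper below shows how such a framewise alignment might be collapsed
into token start times. "collapse_ctc_alignment" is a hypothetical name,
"<blk>" is assumed to be the blank symbol in the token table, and the
frame shift and subsampling factor match the values used in ali.py.

    from typing import List, Tuple


    def collapse_ctc_alignment(
        ali_tokens: List[str],
        subsampling_factor: int = 4,
        frame_shift: float = 0.01,  # seconds per input frame (10 ms)
        blank: str = "<blk>",
    ) -> List[Tuple[float, str]]:
        """Turn a per-frame CTC alignment into (start_time, token) pairs
        by dropping blanks and merging consecutive repeated tokens."""
        ans = []
        prev = None
        for i, tok in enumerate(ali_tokens):
            if tok != blank and tok != prev:
                ans.append((i * subsampling_factor * frame_shift, tok))
            prev = tok
        return ans


    # With a 40 ms output frame shift (4 x 10 ms), the frames below
    # collapse to [(0.08, 'HE'), (0.2, 'IS')].
    print(collapse_ctc_alignment(["<blk>", "<blk>", "HE", "HE", "<blk>", "IS"]))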