From 6f9fe5b9061371be306e10192679c6a6a4f85d4b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 24 Jul 2021 22:23:50 +0800 Subject: [PATCH] Refactor decoding code. --- .github/workflows/test.yml | 1 + egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 28 ++- icefall/decode.py | 245 ++++++++++++++++++++ icefall/utils.py | 2 +- 4 files changed, 267 insertions(+), 9 deletions(-) create mode 100644 icefall/decode.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7da954790..5af8a9ee6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -55,6 +55,7 @@ jobs: git clone --depth 1 https://github.com/lhotse-speech/lhotse cd lhotse sed -i.bak "/torch/d" requirements.txt + pip install -r ./requirements.txt - name: Run tests diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 9d7d2597b..885ebb1fd 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -13,6 +13,7 @@ from model import TdnnLstm from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.dataset.librispeech import LibriSpeechAsrDataModule +from icefall.decode import get_lattice, nbest_decoding, one_best_decoding from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -48,7 +49,7 @@ def get_parser(): def get_params() -> AttributeDict: params = AttributeDict( { - "exp_dir": Path("tdnn_lstm_ctc/exp3/"), + "exp_dir": Path("tdnn_lstm_ctc/exp/"), "lang_dir": Path("data/lang"), "feature_dim": 80, "subsampling_factor": 3, @@ -56,6 +57,9 @@ def get_params() -> AttributeDict: "output_beam": 8, "min_active_states": 30, "max_active_states": 10000, + "use_double_scores": True, + "method": "1best", # [1best, nbest] + "num_paths": 30, # used when method is nbest } ) return params @@ -100,20 +104,28 @@ def decode_one_batch( 1, ).to(torch.int32) - dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) - - lattices = k2.intersect_dense_pruned( - HLG, - dense_fsa_vec, + lattice = get_lattice( + nnet_output=nnet_output, + HLG=HLG, + supervision_segments=supervision_segments, search_beam=params.search_beam, output_beam=params.output_beam, min_active_states=params.min_active_states, max_active_states=params.max_active_states, ) - best_paths = k2.shortest_path(lattices, use_double_scores=True) + if params.method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + ) - hyps = get_texts(best_paths) + hyps = get_texts(best_path) hyps = [[lexicon.words[i] for i in ids] for ids in hyps] texts = supervisions["text"] diff --git a/icefall/decode.py b/icefall/decode.py new file mode 100644 index 000000000..ed663bce8 --- /dev/null +++ b/icefall/decode.py @@ -0,0 +1,245 @@ +import k2 +import torch + + +def _intersect_device( + a_fsas: k2.Fsa, + b_fsas: k2.Fsa, + b_to_a_map: torch.Tensor, + sorted_match_a: bool, + batch_size: int = 50, +): + """This is a wrapper of k2.intersect_device and its purpose is to split + b_fsas into several batches and process each batch separately to avoid + CUDA OOM error. + + The arguments and return value of this function are the same as + k2.intersect_device. 
+    """
+    num_fsas = b_fsas.shape[0]
+    if num_fsas <= batch_size:
+        return k2.intersect_device(
+            a_fsas, b_fsas, b_to_a_map=b_to_a_map, sorted_match_a=sorted_match_a
+        )
+
+    num_batches = (num_fsas + batch_size - 1) // batch_size
+    splits = []
+    for i in range(num_batches):
+        start = i * batch_size
+        end = min(start + batch_size, num_fsas)
+        splits.append((start, end))
+
+    ans = []
+    for start, end in splits:
+        indexes = torch.arange(start, end).to(b_to_a_map)
+
+        fsas = k2.index(b_fsas, indexes)
+        b_to_a = k2.index(b_to_a_map, indexes)
+        path_lattice = k2.intersect_device(
+            a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
+        )
+        ans.append(path_lattice)
+
+    return k2.cat(ans)
+
+
+def get_lattice(
+    nnet_output: torch.Tensor,
+    HLG: k2.Fsa,
+    supervision_segments: torch.Tensor,
+    search_beam: float,
+    output_beam: float,
+    min_active_states: int,
+    max_active_states: int,
+):
+    """Get the decoding lattice from a decoding graph and neural
+    network output.
+
+    Args:
+      nnet_output:
+        It is the output of a neural model of shape `[N, T, C]`.
+      HLG:
+        An Fsa, the decoding graph. See also `compile_HLG.py`.
+      supervision_segments:
+        A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns.
+        Each row contains information for a supervision segment. Column 0
+        is the `sequence_index` indicating which sequence this segment
+        comes from; column 1 specifies the `start_frame` of this segment
+        within the sequence; column 2 contains the `duration` of this
+        segment.
+      search_beam:
+        Decoding beam, e.g. 20. Smaller is faster, larger is more exact
+        (less pruning). This is the default value; it may be modified by
+        `min_active_states` and `max_active_states`.
+      output_beam:
+        Beam to prune output, similar to lattice-beam in Kaldi. Relative
+        to best path of output.
+      min_active_states:
+        Minimum number of FSA states that are allowed to be active on any given
+        frame for any given intersection/composition task. This is advisory,
+        in that it will try not to have fewer than this number active.
+        Set it to zero if there is no constraint.
+      max_active_states:
+        Maximum number of FSA states that are allowed to be active on any given
+        frame for any given intersection/composition task. This is advisory,
+        in that it will try not to exceed that but may not always succeed.
+        You can use a very large number if no constraint is needed.
+    Returns:
+      A lattice containing the decoding result.
+    """
+    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
+
+    lattice = k2.intersect_dense_pruned(
+        HLG,
+        dense_fsa_vec,
+        search_beam=search_beam,
+        output_beam=output_beam,
+        min_active_states=min_active_states,
+        max_active_states=max_active_states,
+    )
+
+    return lattice
+
+
+def one_best_decoding(
+    lattice: k2.Fsa, use_double_scores: bool = True
+) -> k2.Fsa:
+    """Get the best path from a lattice.
+
+    Args:
+      lattice:
+        The decoding lattice returned by :func:`get_lattice`.
+      use_double_scores:
+        True to use double precision floating point in the computation.
+        False to use single precision.
+    Returns:
+      An FsaVec containing linear paths.
+    """
+    best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
+    return best_path
+
+
+def nbest_decoding(
+    lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True
+):
+    """It implements something like CTC prefix beam search using n-best lists.
+
+    The basic idea is to first extract n-best paths from the given lattice,
+    build word sequences from these paths, and compute the total scores
+    of these sequences in the log-semiring. The one with the max score
+    is used as the decoding output.
+
+    Caution:
+      Don't be confused by `best` in the name `n-best`. Paths are selected
+      randomly, not by ranking their scores.
+
+    Args:
+      lattice:
+        The decoding lattice, returned by :func:`get_lattice`.
+      num_paths:
+        It specifies the size `n` in n-best. Note: Paths are selected randomly
+        and those containing identical word sequences are removed and only one
+        of them is kept.
+      use_double_scores:
+        True to use double precision floating point in the computation.
+        False to use single precision.
+    Returns:
+      An FsaVec containing linear FSAs.
+    """
+    # First, extract `num_paths` paths for each sequence.
+    # path is a k2.RaggedInt with axes [seq][path][arc_pos]
+    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
+
+    # word_seq is a k2.RaggedInt sharing the same shape as `path`
+    # but it contains word IDs. Note that it also contains 0s and -1s.
+    # The last entry in each sublist is -1.
+    word_seq = k2.index(lattice.aux_labels, path)
+    # Note: the above operation also supports the case when
+    # lattice.aux_labels is a ragged tensor. In that case,
+    # `remove_axis=True` is used inside the pybind11 binding code,
+    # so the resulting `word_seq` still has 3 axes, like `path`.
+    # The 3 axes are [seq][path][word_id]
+
+    # Remove 0 (epsilon) and -1 from word_seq
+    word_seq = k2.ragged.remove_values_leq(word_seq, 0)
+
+    # Remove paths with identical word sequences.
+    #
+    # k2.ragged.unique_sequences will reorder paths within a seq.
+    # `new2old` is a 1-D torch.Tensor mapping from the output path index
+    # to the input path index.
+    # new2old.numel() == unique_word_seq.tot_size(1)
+    unique_word_seq, _, new2old = k2.ragged.unique_sequences(
+        word_seq, need_num_repeats=False, need_new2old_indexes=True
+    )
+    # Note: unique_word_seq still has the same axes as word_seq
+
+    seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)
+
+    # path_to_seq_map is a 1-D torch.Tensor.
+    # path_to_seq_map[i] is the seq to which the i-th path belongs
+    path_to_seq_map = seq_to_path_shape.row_ids(1)
+
+    # Remove the seq axis.
+    # Now unique_word_seq has only two axes [path][word]
+    unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)
+
+    # word_fsa is an FsaVec with axes [path][state][arc]
+    word_fsa = k2.linear_fsa(unique_word_seq)
+
+    # Add epsilon self-loops since we will use k2.intersect_device,
+    # which treats epsilon as a normal symbol.
+    word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
+
+    # lattice has token IDs as labels and word IDs as aux_labels.
+    # inv_lattice has word IDs as labels and token IDs as aux_labels
+    inv_lattice = k2.invert(lattice)
+    inv_lattice = k2.arc_sort(inv_lattice)
+
+    path_lattice = _intersect_device(
+        inv_lattice,
+        word_fsa_with_epsilon_loops,
+        b_to_a_map=path_to_seq_map,
+        sorted_match_a=True,
+    )
+    # path_lattice has word IDs as labels and token IDs as aux_labels
+
+    path_lattice = k2.top_sort(k2.connect(path_lattice))
+
+    tot_scores = path_lattice.get_tot_scores(
+        use_double_scores=use_double_scores, log_semiring=False
+    )
+
+    # RaggedFloat currently supports float32 only.
+ # If Ragged is wrapped, we can use k2.RaggedDouble here + ragged_tot_scores = k2.RaggedFloat( + seq_to_path_shape, tot_scores.to(torch.float32) + ) + + argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) + + # Since we invoked `k2.ragged.unique_sequences`, which reorders + # the index from `path`, we use `new2old` here to convert argmax_indexes + # to the indexes into `path`. + # + # Use k2.index here since argmax_indexes' dtype is torch.int32 + best_path_indexes = k2.index(new2old, argmax_indexes) + + path_2axes = k2.ragged.remove_axis(path, 0) + + # best_path is a k2.RaggedInt with 2 axes [path][arc_pos] + best_path = k2.index(path_2axes, best_path_indexes) + + # labels is a k2.RaggedInt with 2 axes [path][token_id] + # Note that it contains -1s. + labels = k2.index(lattice.labels.contiguous(), best_path) + + labels = k2.ragged.remove_values_eq(labels, -1) + + # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so + # aux_labels is also a k2.RaggedInt with 2 axes + aux_labels = k2.index(lattice.aux_labels, best_path.values()) + + best_path_fsa = k2.linear_fsa(labels) + best_path_fsa.aux_labels = aux_labels + return best_path_fsa diff --git a/icefall/utils.py b/icefall/utils.py index 813246132..4d1ca6cff 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -139,7 +139,7 @@ class AttributeDict(dict): def encode_supervisions( - supervisions: Dict[str, torch.Tensor], subsampling_factor: int + supervisions: dict, subsampling_factor: int ) -> Tuple[torch.Tensor, List[str]]: """ Encodes Lhotse's ``batch["supervisions"]`` dict into a pair of torch Tensor,
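-- 

Usage sketch (not part of the patch): the helpers added in icefall/decode.py are
intended to be called the way decode_one_batch() in tdnn_lstm_ctc/decode.py now
calls them. The hypothetical decode_batch() below only illustrates that flow; its
arguments stand in for values produced elsewhere in decode.py (nnet_output is a
[N, T, C] tensor of log-probs, HLG the arc-sorted decoding graph,
supervision_segments a CPU int32 tensor with 3 columns), and the beam and
active-state numbers simply mirror get_params(). get_texts is assumed to come
from icefall.utils, as used in decode.py.

    from icefall.decode import get_lattice, nbest_decoding, one_best_decoding
    from icefall.utils import get_texts


    def decode_batch(nnet_output, HLG, supervision_segments, method="1best"):
        """Sketch of the decoding flow introduced by this patch."""
        lattice = get_lattice(
            nnet_output=nnet_output,
            HLG=HLG,
            supervision_segments=supervision_segments,
            search_beam=20,
            output_beam=8,
            min_active_states=30,
            max_active_states=10000,
        )

        if method == "1best":
            # Shortest path through the lattice.
            best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
        else:
            # Sample paths, deduplicate identical word sequences, rescore them
            # against the lattice and keep the highest-scoring one.
            best_path = nbest_decoding(
                lattice=lattice, num_paths=30, use_double_scores=True
            )

        # One list of word IDs per supervision segment.
        return get_texts(best_path)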