diff --git a/icefall/decode.py b/icefall/decode.py deleted file mode 100644 index 52a134dfc..000000000 --- a/icefall/decode.py +++ /dev/null @@ -1,913 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import Dict, List, Optional, Tuple, Union - -import k2 -import kaldialign -import torch -import torch.nn as nn - - -def _get_random_paths( - lattice: k2.Fsa, - num_paths: int, - use_double_scores: bool = True, - scale: float = 1.0, -): - """ - Args: - lattice: - The decoding lattice, returned by :func:`get_lattice`. - num_paths: - It specifies the size `n` in n-best. Note: Paths are selected randomly - and those containing identical word sequences are remove dand only one - of them is kept. - use_double_scores: - True to use double precision floating point in the computation. - False to use single precision. - scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - Returns: - Return a k2.RaggedInt with 3 axes [seq][path][arc_pos] - """ - saved_scores = lattice.scores.clone() - lattice.scores *= scale - path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True) - lattice.scores = saved_scores - return path - - -def _intersect_device( - a_fsas: k2.Fsa, - b_fsas: k2.Fsa, - b_to_a_map: torch.Tensor, - sorted_match_a: bool, - batch_size: int = 50, -) -> k2.Fsa: - """This is a wrapper of k2.intersect_device and its purpose is to split - b_fsas into several batches and process each batch separately to avoid - CUDA OOM error. - - The arguments and return value of this function are the same as - k2.intersect_device. - """ - num_fsas = b_fsas.shape[0] - if num_fsas <= batch_size: - return k2.intersect_device( - a_fsas, b_fsas, b_to_a_map=b_to_a_map, sorted_match_a=sorted_match_a - ) - - num_batches = (num_fsas + batch_size - 1) // batch_size - splits = [] - for i in range(num_batches): - start = i * batch_size - end = min(start + batch_size, num_fsas) - splits.append((start, end)) - - ans = [] - for start, end in splits: - indexes = torch.arange(start, end).to(b_to_a_map) - - fsas = k2.index_fsa(b_fsas, indexes) - b_to_a = k2.index_select(b_to_a_map, indexes) - path_lattice = k2.intersect_device( - a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a - ) - ans.append(path_lattice) - - return k2.cat(ans) - - -def get_lattice( - nnet_output: torch.Tensor, - HLG: k2.Fsa, - supervision_segments: torch.Tensor, - search_beam: float, - output_beam: float, - min_active_states: int, - max_active_states: int, - subsampling_factor: int = 1, -) -> k2.Fsa: - """Get the decoding lattice from a decoding graph and neural - network output. - - Args: - nnet_output: - It is the output of a neural model of shape `[N, T, C]`. - HLG: - An Fsa, the decoding graph. See also `compile_HLG.py`. - supervision_segments: - A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns. - Each row contains information for a supervision segment. Column 0 - is the `sequence_index` indicating which sequence this segment - comes from; column 1 specifies the `start_frame` of this segment - within the sequence; column 2 contains the `duration` of this - segment. - search_beam: - Decoding beam, e.g. 20. Smaller is faster, larger is more exact - (less pruning). This is the default value; it may be modified by - `min_active_states` and `max_active_states`. - output_beam: - Beam to prune output, similar to lattice-beam in Kaldi. Relative - to best path of output. - min_active_states: - Minimum number of FSA states that are allowed to be active on any given - frame for any given intersection/composition task. This is advisory, - in that it will try not to have fewer than this number active. - Set it to zero if there is no constraint. - max_active_states: - Maximum number of FSA states that are allowed to be active on any given - frame for any given intersection/composition task. This is advisory, - in that it will try not to exceed that but may not always succeed. - You can use a very large number if no constraint is needed. - subsampling_factor: - The subsampling factor of the model. - Returns: - A lattice containing the decoding result. - """ - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, supervision_segments, allow_truncate=subsampling_factor - 1 - ) - - lattice = k2.intersect_dense_pruned( - HLG, - dense_fsa_vec, - search_beam=search_beam, - output_beam=output_beam, - min_active_states=min_active_states, - max_active_states=max_active_states, - ) - - return lattice - - -def one_best_decoding( - lattice: k2.Fsa, use_double_scores: bool = True -) -> k2.Fsa: - """Get the best path from a lattice. - - Args: - lattice: - The decoding lattice returned by :func:`get_lattice`. - use_double_scores: - True to use double precision floating point in the computation. - False to use single precision. - Return: - An FsaVec containing linear paths. - """ - best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores) - return best_path - - -def nbest_decoding( - lattice: k2.Fsa, - num_paths: int, - use_double_scores: bool = True, - scale: float = 1.0, -) -> k2.Fsa: - """It implements something like CTC prefix beam search using n-best lists. - - The basic idea is to first extra n-best paths from the given lattice, - build a word seqs from these paths, and compute the total scores - of these sequences in the log-semiring. The one with the max score - is used as the decoding output. - - Caution: - Don't be confused by `best` in the name `n-best`. Paths are selected - randomly, not by ranking their scores. - - Args: - lattice: - The decoding lattice, returned by :func:`get_lattice`. - num_paths: - It specifies the size `n` in n-best. Note: Paths are selected randomly - and those containing identical word sequences are remove dand only one - of them is kept. - use_double_scores: - True to use double precision floating point in the computation. - False to use single precision. - scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - Returns: - An FsaVec containing linear FSAs. - """ - path = _get_random_paths( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - scale=scale, - ) - - # word_seq is a k2.RaggedTensor sharing the same shape as `path` - # but it contains word IDs. Note that it also contains 0s and -1s. - # The last entry in each sublist is -1. - if isinstance(lattice.aux_labels, torch.Tensor): - word_seq = k2.ragged.index(lattice.aux_labels, path) - else: - word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(word_seq.num_axes - 2) - - # Remove 0 (epsilon) and -1 from word_seq - word_seq = word_seq.remove_values_leq(0) - - # Remove sequences with identical word sequences. - # - # k2.ragged.unique_sequences will reorder paths within a seq. - # `new2old` is a 1-D torch.Tensor mapping from the output path index - # to the input path index. - # new2old.numel() == unique_word_seqs.tot_size(1) - unique_word_seq, _, new2old = word_seq.unique( - need_num_repeats=False, need_new2old_indexes=True - ) - # Note: unique_word_seq still has the same axes as word_seq - - seq_to_path_shape = unique_word_seq.shape.get_layer(0) - - # path_to_seq_map is a 1-D torch.Tensor. - # path_to_seq_map[i] is the seq to which the i-th path belongs - path_to_seq_map = seq_to_path_shape.row_ids(1) - - # Remove the seq axis. - # Now unique_word_seq has only two axes [path][word] - unique_word_seq = unique_word_seq.remove_axis(0) - - # word_fsa is an FsaVec with axes [path][state][arc] - word_fsa = k2.linear_fsa(unique_word_seq) - - # add epsilon self loops since we will use - # k2.intersect_device, which treats epsilon as a normal symbol - word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa) - - # lattice has token IDs as labels and word IDs as aux_labels. - # inv_lattice has word IDs as labels and token IDs as aux_labels - inv_lattice = k2.invert(lattice) - inv_lattice = k2.arc_sort(inv_lattice) - - path_lattice = _intersect_device( - inv_lattice, - word_fsa_with_epsilon_loops, - b_to_a_map=path_to_seq_map, - sorted_match_a=True, - ) - # path_lat has word IDs as labels and token IDs as aux_labels - - path_lattice = k2.top_sort(k2.connect(path_lattice)) - - tot_scores = path_lattice.get_tot_scores( - use_double_scores=use_double_scores, log_semiring=False - ) - - ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores) - - argmax_indexes = ragged_tot_scores.argmax() - - # Since we invoked `k2.ragged.unique_sequences`, which reorders - # the index from `path`, we use `new2old` here to convert argmax_indexes - # to the indexes into `path`. - # - # Use k2.index here since argmax_indexes' dtype is torch.int32 - best_path_indexes = k2.index_select(new2old, argmax_indexes) - - path_2axes = path.remove_axis(0) - - # best_path is a k2.RaggedTensor with 2 axes [path][arc_pos] - best_path, _ = path_2axes.index( - indexes=best_path_indexes, axis=0, need_value_indexes=False - ) - - # labels is a k2.RaggedTensor with 2 axes [path][token_id] - # Note that it contains -1s. - labels = k2.ragged.index(lattice.labels.contiguous(), best_path) - - labels = labels.remove_values_eq(-1) - - # lattice.aux_labels is a k2.RaggedTensor with 2 axes, so - # aux_labels is also a k2.RaggedTensor with 2 axes - aux_labels, _ = lattice.aux_labels.index( - indexes=best_path.values, axis=0, need_value_indexes=False - ) - - best_path_fsa = k2.linear_fsa(labels) - best_path_fsa.aux_labels = aux_labels - return best_path_fsa - - -def compute_am_and_lm_scores( - lattice: k2.Fsa, - word_fsa_with_epsilon_loops: k2.Fsa, - path_to_seq_map: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute AM scores of n-best lists (represented as word_fsas). - - Args: - lattice: - An FsaVec, e.g., the return value of :func:`get_lattice` - It must have the attribute `lm_scores`. - word_fsa_with_epsilon_loops: - An FsaVec representing an n-best list. Note that it has been processed - by `k2.add_epsilon_self_loops`. - path_to_seq_map: - A 1-D torch.Tensor with dtype torch.int32. path_to_seq_map[i] indicates - which sequence the i-th Fsa in word_fsa_with_epsilon_loops belongs to. - path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0(). - Returns: - Return a tuple containing two 1-D torch.Tensors: (am_scores, lm_scores). - Each tensor's `numel()' equals to `word_fsas_with_epsilon_loops.shape[0]` - """ - assert len(lattice.shape) == 3 - assert hasattr(lattice, "lm_scores") - - # k2.compose() currently does not support b_to_a_map. To void - # replicating `lats`, we use k2.intersect_device here. - # - # lattice has token IDs as `labels` and word IDs as aux_labels, so we - # need to invert it here. - inv_lattice = k2.invert(lattice) - - # Now the `labels` of inv_lattice are word IDs (a 1-D torch.Tensor) - # and its `aux_labels` are token IDs ( a k2.RaggedInt with 2 axes) - - # Remove its `aux_labels` since it is not needed in the - # following computation - del inv_lattice.aux_labels - inv_lattice = k2.arc_sort(inv_lattice) - - path_lattice = _intersect_device( - inv_lattice, - word_fsa_with_epsilon_loops, - b_to_a_map=path_to_seq_map, - sorted_match_a=True, - ) - - path_lattice = k2.top_sort(k2.connect(path_lattice)) - - # The `scores` of every arc consists of `am_scores` and `lm_scores` - path_lattice.scores = path_lattice.scores - path_lattice.lm_scores - - am_scores = path_lattice.get_tot_scores( - use_double_scores=True, log_semiring=False - ) - - path_lattice.scores = path_lattice.lm_scores - - lm_scores = path_lattice.get_tot_scores( - use_double_scores=True, log_semiring=False - ) - - return am_scores.to(torch.float32), lm_scores.to(torch.float32) - - -def rescore_with_n_best_list( - lattice: k2.Fsa, - G: k2.Fsa, - num_paths: int, - lm_scale_list: List[float], - scale: float = 1.0, -) -> Dict[str, k2.Fsa]: - """Decode using n-best list with LM rescoring. - - `lattice` is a decoding lattice with 3 axes. This function first - extracts `num_paths` paths from `lattice` for each sequence using - `k2.random_paths`. The `am_scores` of these paths are computed. - For each path, its `lm_scores` is computed using `G` (which is an LM). - The final `tot_scores` is the sum of `am_scores` and `lm_scores`. - The path with the largest `tot_scores` within a sequence is used - as the decoding output. - - Args: - lattice: - An FsaVec. It can be the return value of :func:`get_lattice`. - G: - An FsaVec representing the language model (LM). Note that it - is an FsaVec, but it contains only one Fsa. - num_paths: - It is the size `n` in `n-best` list. - lm_scale_list: - A list containing lm_scale values. - scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - Returns: - A dict of FsaVec, whose key is an lm_scale and the value is the - best decoding path for each sequence in the lattice. - """ - device = lattice.device - - assert len(lattice.shape) == 3 - assert hasattr(lattice, "aux_labels") - assert hasattr(lattice, "lm_scores") - - assert G.shape == (1, None, None) - assert G.device == device - assert hasattr(G, "aux_labels") is False - - path = _get_random_paths( - lattice=lattice, - num_paths=num_paths, - use_double_scores=True, - scale=scale, - ) - - # word_seq is a k2.RaggedTensor sharing the same shape as `path` - # but it contains word IDs. Note that it also contains 0s and -1s. - # The last entry in each sublist is -1. - if isinstance(lattice.aux_labels, torch.Tensor): - word_seq = k2.ragged.index(lattice.aux_labels, path) - else: - word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(word_seq.num_axes - 2) - - # Remove epsilons and -1 from word_seq - word_seq = word_seq.remove_values_leq(0) - - # Remove paths that has identical word sequences. - # - # unique_word_seq is still a k2.RaggedTensor with 3 axes [seq][path][word] - # except that there are no repeated paths with the same word_seq - # within a sequence. - # - # num_repeats is also a k2.RaggedTensor with 2 axes containing the - # multiplicities of each path. - # num_repeats.numel() == unique_word_seqs.tot_size(1) - # - # Since k2.ragged.unique_sequences will reorder paths within a seq, - # `new2old` is a 1-D torch.Tensor mapping from the output path index - # to the input path index. - # new2old.numel() == unique_word_seqs.tot_size(1) - unique_word_seq, num_repeats, new2old = word_seq.unique( - need_num_repeats=True, need_new2old_indexes=True - ) - - seq_to_path_shape = unique_word_seq.shape.get_layer(0) - - # path_to_seq_map is a 1-D torch.Tensor. - # path_to_seq_map[i] is the seq to which the i-th path - # belongs. - path_to_seq_map = seq_to_path_shape.row_ids(1) - - # Remove the seq axis. - # Now unique_word_seq has only two axes [path][word] - unique_word_seq = unique_word_seq.remove_axis(0) - - # word_fsa is an FsaVec with axes [path][state][arc] - word_fsa = k2.linear_fsa(unique_word_seq) - - word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa) - - am_scores, _ = compute_am_and_lm_scores( - lattice, word_fsa_with_epsilon_loops, path_to_seq_map - ) - - # Now compute lm_scores - b_to_a_map = torch.zeros_like(path_to_seq_map) - lm_path_lattice = _intersect_device( - G, - word_fsa_with_epsilon_loops, - b_to_a_map=b_to_a_map, - sorted_match_a=True, - ) - lm_path_lattice = k2.top_sort(k2.connect(lm_path_lattice)) - lm_scores = lm_path_lattice.get_tot_scores( - use_double_scores=True, log_semiring=False - ) - - path_2axes = path.remove_axis(0) - - ans = dict() - for lm_scale in lm_scale_list: - tot_scores = am_scores / lm_scale + lm_scores - - # Remember that we used `k2.RaggedTensor.unique` to remove repeated - # paths to avoid redundant computation in `k2.intersect_device`. - # Now we use `num_repeats` to correct the scores for each path. - # - # NOTE(fangjun): It is commented out as it leads to a worse WER - # tot_scores = tot_scores * num_repeats.values() - - ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores) - argmax_indexes = ragged_tot_scores.argmax() - - # Use k2.index here since argmax_indexes' dtype is torch.int32 - best_path_indexes = k2.index_select(new2old, argmax_indexes) - - # best_path is a k2.RaggedInt with 2 axes [path][arc_pos] - best_path, _ = path_2axes.index( - indexes=best_path_indexes, axis=0, need_value_indexes=False - ) - - # labels is a k2.RaggedTensor with 2 axes [path][phone_id] - # Note that it contains -1s. - labels = k2.ragged.index(lattice.labels.contiguous(), best_path) - - labels = labels.remove_values_eq(-1) - - # lattice.aux_labels is a k2.RaggedTensor tensor with 2 axes, so - # aux_labels is also a k2.RaggedTensor with 2 axes - - aux_labels, _ = lattice.aux_labels.index( - indexes=best_path.values, axis=0, need_value_indexes=False - ) - - best_path_fsa = k2.linear_fsa(labels) - best_path_fsa.aux_labels = aux_labels - - key = f"lm_scale_{lm_scale}" - ans[key] = best_path_fsa - - return ans - - -def rescore_with_whole_lattice( - lattice: k2.Fsa, - G_with_epsilon_loops: k2.Fsa, - lm_scale_list: Optional[List[float]] = None, -) -> Union[k2.Fsa, Dict[str, k2.Fsa]]: - """Use whole lattice to rescore. - - Args: - lattice: - An FsaVec It can be the return value of :func:`get_lattice`. - G_with_epsilon_loops: - An FsaVec representing the language model (LM). Note that it - is an FsaVec, but it contains only one Fsa. - lm_scale_list: - A list containing lm_scale values or None. - Returns: - If lm_scale_list is not None, return a dict of FsaVec, whose key - is a lm_scale and the value represents the best decoding path for - each sequence in the lattice. - If lm_scale_list is not None, return a lattice that is rescored - with the given LM. - """ - assert len(lattice.shape) == 3 - assert hasattr(lattice, "lm_scores") - assert G_with_epsilon_loops.shape == (1, None, None) - - device = lattice.device - lattice.scores = lattice.scores - lattice.lm_scores - # We will use lm_scores from G, so remove lats.lm_scores here - del lattice.lm_scores - assert hasattr(lattice, "lm_scores") is False - - assert hasattr(G_with_epsilon_loops, "lm_scores") - - # Now, lattice.scores contains only am_scores - - # inv_lattice has word IDs as labels. - # Its aux_labels are token IDs, which is a ragged tensor k2.RaggedInt - inv_lattice = k2.invert(lattice) - num_seqs = lattice.shape[0] - - b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32) - while True: - try: - rescoring_lattice = k2.intersect_device( - G_with_epsilon_loops, - inv_lattice, - b_to_a_map, - sorted_match_a=True, - ) - rescoring_lattice = k2.top_sort(k2.connect(rescoring_lattice)) - break - except RuntimeError as e: - logging.info(f"Caught exception:\n{e}\n") - logging.info( - f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}" - ) - - # NOTE(fangjun): The choice of the threshold 1e-7 is arbitrary here - # to avoid OOM. We may need to fine tune it. - inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-7, True) - logging.info( - f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}" - ) - - # lat has token IDs as labels - # and word IDs as aux_labels. - lat = k2.invert(rescoring_lattice) - - if lm_scale_list is None: - return lat - - ans = dict() - # - # The following implements - # scores = (scores - lm_scores)/lm_scale + lm_scores - # = scores/lm_scale + lm_scores*(1 - 1/lm_scale) - # - saved_am_scores = lat.scores - lat.lm_scores - for lm_scale in lm_scale_list: - am_scores = saved_am_scores / lm_scale - lat.scores = am_scores + lat.lm_scores - - best_path = k2.shortest_path(lat, use_double_scores=True) - key = f"lm_scale_{lm_scale}" - ans[key] = best_path - return ans - - -def nbest_oracle( - lattice: k2.Fsa, - num_paths: int, - ref_texts: List[str], - word_table: k2.SymbolTable, - scale: float = 1.0, -) -> Dict[str, List[List[int]]]: - """Select the best hypothesis given a lattice and a reference transcript. - - The basic idea is to extract n paths from the given lattice, unique them, - and select the one that has the minimum edit distance with the corresponding - reference transcript as the decoding output. - - The decoding result returned from this function is the best result that - we can obtain using n-best decoding with all kinds of rescoring techniques. - - Args: - lattice: - An FsaVec. It can be the return value of :func:`get_lattice`. - Note: We assume its aux_labels contain word IDs. - num_paths: - The size of `n` in n-best. - ref_texts: - A list of reference transcript. Each entry contains space(s) - separated words - word_table: - It is the word symbol table. - scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - Return: - Return a dict. Its key contains the information about the parameters - when calling this function, while its value contains the decoding output. - `len(ans_dict) == len(ref_texts)` - """ - path = _get_random_paths( - lattice=lattice, - num_paths=num_paths, - use_double_scores=True, - scale=scale, - ) - - if isinstance(lattice.aux_labels, torch.Tensor): - word_seq = k2.ragged.index(lattice.aux_labels, path) - else: - word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(word_seq.num_axes - 2) - - word_seq = word_seq.remove_values_leq(0) - unique_word_seq, _, _ = word_seq.unique( - need_num_repeats=False, need_new2old_indexes=False - ) - unique_word_ids = unique_word_seq.tolist() - assert len(unique_word_ids) == len(ref_texts) - # unique_word_ids[i] contains all hypotheses of the i-th utterance - - results = [] - for hyps, ref in zip(unique_word_ids, ref_texts): - # Note hyps is a list-of-list ints - # Each sublist contains a hypothesis - ref_words = ref.strip().split() - # CAUTION: We don't convert ref_words to ref_words_ids - # since there may exist OOV words in ref_words - best_hyp_words = None - min_error = float("inf") - for hyp_words in hyps: - hyp_words = [word_table[i] for i in hyp_words] - this_error = kaldialign.edit_distance(ref_words, hyp_words)["total"] - if this_error < min_error: - min_error = this_error - best_hyp_words = hyp_words - results.append(best_hyp_words) - - return {f"nbest_{num_paths}_scale_{scale}_oracle": results} - - -def rescore_with_attention_decoder( - lattice: k2.Fsa, - num_paths: int, - model: nn.Module, - memory: torch.Tensor, - memory_key_padding_mask: Optional[torch.Tensor], - sos_id: int, - eos_id: int, - scale: float = 1.0, - ngram_lm_scale: Optional[float] = None, - attention_scale: Optional[float] = None, -) -> Dict[str, k2.Fsa]: - """This function extracts n paths from the given lattice and uses - an attention decoder to rescore them. The path with the highest - score is used as the decoding output. - - Args: - lattice: - An FsaVec. It can be the return value of :func:`get_lattice`. - num_paths: - Number of paths to extract from the given lattice for rescoring. - model: - A transformer model. See the class "Transformer" in - conformer_ctc/transformer.py for its interface. - memory: - The encoder memory of the given model. It is the output of - the last torch.nn.TransformerEncoder layer in the given model. - Its shape is `[T, N, C]`. - memory_key_padding_mask: - The padding mask for memory with shape [N, T]. - sos_id: - The token ID for SOS. - eos_id: - The token ID for EOS. - scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - ngram_lm_scale: - Optional. It specifies the scale for n-gram LM scores. - attention_scale: - Optional. It specifies the scale for attention decoder scores. - Returns: - A dict of FsaVec, whose key contains a string - ngram_lm_scale_attention_scale and the value is the - best decoding path for each sequence in the lattice. - """ - # First, extract `num_paths` paths for each sequence. - # path is a k2.RaggedInt with axes [seq][path][arc_pos] - path = _get_random_paths( - lattice=lattice, - num_paths=num_paths, - use_double_scores=True, - scale=scale, - ) - - # word_seq is a k2.RaggedTensor sharing the same shape as `path` - # but it contains word IDs. Note that it also contains 0s and -1s. - # The last entry in each sublist is -1. - if isinstance(lattice.aux_labels, torch.Tensor): - word_seq = k2.ragged.index(lattice.aux_labels, path) - else: - word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(word_seq.num_axes - 2) - - # Remove epsilons and -1 from word_seq - word_seq = word_seq.remove_values_leq(0) - - # Remove paths that has identical word sequences. - # - # unique_word_seq is still a k2.RaggedTensor with 3 axes [seq][path][word] - # except that there are no repeated paths with the same word_seq - # within a sequence. - # - # num_repeats is also a k2.RaggedTensor with 2 axes containing the - # multiplicities of each path. - # num_repeats.numel() == unique_word_seqs.tot_size(1) - # - # Since k2.ragged.unique_sequences will reorder paths within a seq, - # `new2old` is a 1-D torch.Tensor mapping from the output path index - # to the input path index. - # new2old.numel() == unique_word_seq.tot_size(1) - unique_word_seq, num_repeats, new2old = word_seq.unique( - need_num_repeats=True, need_new2old_indexes=True - ) - - seq_to_path_shape = unique_word_seq.shape.get_layer(0) - - # path_to_seq_map is a 1-D torch.Tensor. - # path_to_seq_map[i] is the seq to which the i-th path - # belongs. - path_to_seq_map = seq_to_path_shape.row_ids(1) - - # Remove the seq axis. - # Now unique_word_seq has only two axes [path][word] - unique_word_seq = unique_word_seq.remove_axis(0) - - # word_fsa is an FsaVec with axes [path][state][arc] - word_fsa = k2.linear_fsa(unique_word_seq) - - word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa) - - am_scores, ngram_lm_scores = compute_am_and_lm_scores( - lattice, word_fsa_with_epsilon_loops, path_to_seq_map - ) - # Now we use the attention decoder to compute another - # score: attention_scores. - # - # To do that, we have to get the input and output for the attention - # decoder. - - # CAUTION: The "tokens" attribute is set in the file - # local/compile_hlg.py - if isinstance(lattice.tokens, torch.Tensor): - token_seq = k2.ragged.index(lattice.tokens, path) - else: - token_seq = lattice.tokens.index(path) - token_seq = token_seq.remove_axis(token_seq.num_axes - 2) - - # Remove epsilons and -1 from token_seq - token_seq = token_seq.remove_values_leq(0) - - # Remove the seq axis. - token_seq = token_seq.remove_axis(0) - - token_seq, _ = token_seq.index( - indexes=new2old, axis=0, need_value_indexes=False - ) - - # Now word in unique_word_seq has its corresponding token IDs. - token_ids = token_seq.tolist() - - num_word_seqs = new2old.numel() - - path_to_seq_map_long = path_to_seq_map.to(torch.long) - expanded_memory = memory.index_select(1, path_to_seq_map_long) - - if memory_key_padding_mask is not None: - expanded_memory_key_padding_mask = memory_key_padding_mask.index_select( - 0, path_to_seq_map_long - ) - else: - expanded_memory_key_padding_mask = None - - nll = model.decoder_nll( - memory=expanded_memory, - memory_key_padding_mask=expanded_memory_key_padding_mask, - token_ids=token_ids, - sos_id=sos_id, - eos_id=eos_id, - ) - assert nll.ndim == 2 - assert nll.shape[0] == num_word_seqs - - attention_scores = -nll.sum(dim=1) - assert attention_scores.ndim == 1 - assert attention_scores.numel() == num_word_seqs - - if ngram_lm_scale is None: - ngram_lm_scale_list = [0.01, 0.05, 0.08] - ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] - ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] - else: - ngram_lm_scale_list = [ngram_lm_scale] - - if attention_scale is None: - attention_scale_list = [0.01, 0.05, 0.08] - attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] - attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] - else: - attention_scale_list = [attention_scale] - - path_2axes = path.remove_axis(0) - - ans = dict() - for n_scale in ngram_lm_scale_list: - for a_scale in attention_scale_list: - tot_scores = ( - am_scores - + n_scale * ngram_lm_scores - + a_scale * attention_scores - ) - ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores) - argmax_indexes = ragged_tot_scores.argmax() - - best_path_indexes = k2.index_select(new2old, argmax_indexes) - - # best_path is a k2.RaggedInt with 2 axes [path][arc_pos] - best_path, _ = path_2axes.index( - indexes=best_path_indexes, axis=0, need_value_indexes=False - ) - - # labels is a k2.RaggedTensor with 2 axes [path][token_id] - # Note that it contains -1s. - labels = k2.ragged.index(lattice.labels.contiguous(), best_path) - - labels = labels.remove_values_eq(-1) - - if isinstance(lattice.aux_labels, torch.Tensor): - aux_labels = k2.index_select( - lattice.aux_labels, best_path.values - ) - else: - aux_labels, _ = lattice.aux_labels.index( - indexes=best_path.values, axis=0, need_value_indexes=False - ) - - best_path_fsa = k2.linear_fsa(labels) - best_path_fsa.aux_labels = aux_labels - - key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}" - ans[key] = best_path_fsa - return ans diff --git a/icefall/decode2.py b/icefall/decode2.py index 16e4dcd25..c5c127cbb 100644 --- a/icefall/decode2.py +++ b/icefall/decode2.py @@ -14,9 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# NOTE: This file is a refactor of decode.py -# We will delete decode.py and rename this file to decode.py - import logging from typing import Dict, List, Optional, Union @@ -38,7 +35,7 @@ def _intersect_device( CUDA OOM error. The arguments and return value of this function are the same as - k2.intersect_device. + :func:`k2.intersect_device`. """ num_fsas = b_fsas.shape[0] if num_fsas <= batch_size: @@ -107,7 +104,9 @@ class Nbest(object): An Nbest object contains two fields: (1) fsa. It is an FsaVec containing a vector of **linear** FSAs. + Its axes are [path][state][arc] (2) shape. Its type is :class:`k2.RaggedShape`. + Its axes are [utt][path] The field `shape` has two axes [utt][path]. `shape.dim0` contains the number of utterances, which is also the number of rows in the @@ -116,11 +115,19 @@ class Nbest(object): Caution: Don't be confused by the name `Nbest`. The best in the name `Nbest` - has nothing to do with the `best scores`. The important part is + has nothing to do with `best scores`. The important part is `N` in `Nbest`, not `best`. """ def __init__(self, fsa: k2.Fsa, shape: k2.RaggedShape) -> None: + """ + Args: + fsa: + An FsaVec with axes [path][state][arc]. It is expected to contain + a list of **linear** FSAs. + shape: + A ragged shape with two axes [utt][path]. + """ assert len(fsa.shape) == 3, f"fsa.shape: {fsa.shape}" assert shape.num_axes == 2, f"num_axes: {shape.num_axes}" @@ -135,8 +142,8 @@ class Nbest(object): def __str__(self): s = "Nbest(" - s += f"num_seqs:{self.shape.dim0}, " - s += f"num_fsas:{self.fsa.shape[0]})" + s += f"Number of utterances:{self.shape.dim0}, " + s += f"Number of Paths:{self.fsa.shape[0]})" return s @staticmethod @@ -148,6 +155,11 @@ class Nbest(object): ) -> "Nbest": """Construct an Nbest object by **sampling** `num_paths` from a lattice. + Each sampled path is a linear FSA. + + We assume `lattice.labels` contains token IDs and `lattice.aux_labels` + contains word IDs. + Args: lattice: An FsaVec with axes [utt][state][arc]. @@ -159,13 +171,15 @@ class Nbest(object): False to use single precision. scale: Scale `lattice.score` before passing it to :func:`k2.random_paths`. - A smaller value leads to more unique paths with the risk being not + A smaller value leads to more unique paths at the risk of being not to sample the path with the best score. + Returns: + Return an Nbest instance. """ saved_scores = lattice.scores.clone() lattice.scores *= lattice_score_scale # path is a ragged tensor with dtype torch.int32. - # It has three axes [utt][path][arc_pos + # It has three axes [utt][path][arc_pos] path = k2.random_paths( lattice, num_paths=num_paths, use_double_scores=use_double_scores ) @@ -174,6 +188,7 @@ class Nbest(object): # word_seq is a k2.RaggedTensor sharing the same shape as `path` # but it contains word IDs. Note that it also contains 0s and -1s. # The last entry in each sublist is -1. + # It axes is [utt][path][word_id] if isinstance(lattice.aux_labels, torch.Tensor): word_seq = k2.ragged.index(lattice.aux_labels, path) else: @@ -224,28 +239,28 @@ class Nbest(object): fsa = k2.linear_fsa(labels) fsa.aux_labels = aux_labels # Caution: fsa.scores are all 0s. + # `fsa` has only one extra attribute: aux_labels. return Nbest(fsa=fsa, shape=utt_to_path_shape) def intersect(self, lattice: k2.Fsa, use_double_scores=True) -> "Nbest": - """Intersect this Nbest object with a lattice and get 1-best - path from the resulting FsaVec. + """Intersect this Nbest object with a lattice, get 1-best + path from the resulting FsaVec, and return a new Nbest object. The purpose of this function is to attach scores to an Nbest. - Args: lattice: An FsaVec with axes [utt][state][arc]. If it has `aux_labels`, then we assume its `labels` are token IDs and `aux_labels` are word IDs. - If it has only `labels`, we assume it `labels` are word IDs. - + If it has only `labels`, we assume its `labels` are word IDs. use_double_scores: True to use double precision when computing shortest path. False to use single precision. Returns: Return a new Nbest. This new Nbest shares the same shape with `self`, while its `fsa` is the 1-best path from intersecting `self.fsa` and - `lattice`. + `lattice`. Also, its `fsa` has non-zero scores and inherits attributes + for `lattice`. """ # Note: We view each linear FSA as a word sequence # and we use the passed lattice to give each word sequence a score. @@ -308,7 +323,8 @@ class Nbest(object): an utterance). Hint: - `self.fsa.scores` contains two parts: am scores and lm scores. + `self.fsa.scores` contains two parts: acoustic scores (AM scores) + and n-gram language model scores (LM scores). Caution: We require that ``self.fsa`` has an attribute ``lm_scores``. @@ -334,7 +350,8 @@ class Nbest(object): an utterance). Hint: - `self.fsa.scores` contains two parts: am scores and lm scores. + `self.fsa.scores` contains two parts: acoustic scores (AM scores) + and n-gram language model scores (LM scores). Caution: We require that ``self.fsa`` has an attribute ``lm_scores``. @@ -348,9 +365,6 @@ class Nbest(object): # The `scores` of every arc consists of `am_scores` and `lm_scores` self.fsa.scores = self.fsa.lm_scores - # Caution: self.fsa.lm_scores is per arc - # while lm_scores in the following is per path - # lm_scores = self.fsa.get_tot_scores( use_double_scores=True, log_semiring=False ) @@ -359,17 +373,16 @@ class Nbest(object): return k2.RaggedTensor(self.shape, lm_scores) def tot_scores(self) -> k2.RaggedTensor: - """Get total scores of the FSAs in this Nbest. + """Get total scores of FSAs in this Nbest. Note: - Since FSAs in Nbest are just linear FSAs, log-semirng and tropical - semiring produce the same total scores. + Since FSAs in Nbest are just linear FSAs, log-semiring + and tropical semiring produce the same total scores. Returns: Return a ragged tensor with two axes [utt][path_scores]. Its dtype is torch.float64. """ - # Use single precision since there are only additions. scores = self.fsa.get_tot_scores( use_double_scores=True, log_semiring=False ) @@ -382,6 +395,25 @@ class Nbest(object): return k2.Fsa.from_fsas(word_levenshtein_graphs) +def one_best_decoding( + lattice: k2.Fsa, + use_double_scores: bool = True, +) -> k2.Fsa: + """Get the best path from a lattice. + + Args: + lattice: + The decoding lattice returned by :func:`get_lattice`. + use_double_scores: + True to use double precision floating point in the computation. + False to use single precision. + Return: + An FsaVec containing linear paths. + """ + best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores) + return best_path + + def nbest_decoding( lattice: k2.Fsa, num_paths: int, @@ -390,7 +422,7 @@ def nbest_decoding( ) -> k2.Fsa: """It implements something like CTC prefix beam search using n-best lists. - The basic idea is to first extra `num_paths` paths from the given lattice, + The basic idea is to first extract `num_paths` paths from the given lattice, build a word sequence from these paths, and compute the total scores of the word sequence in the tropical semiring. The one with the max score is used as the decoding output. @@ -399,6 +431,10 @@ def nbest_decoding( Don't be confused by `best` in the name `n-best`. Paths are selected **randomly**, not by ranking their scores. + Hint: + This decoding method is for demonstration only and it does + not produce a lower WER than :func:`one_best_decoding`. + Args: lattice: The decoding lattice, e.g., can be the return value of @@ -411,10 +447,10 @@ def nbest_decoding( True to use double precision floating point in the computation. False to use single precision. lattice_score_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. + It's the scale applied to the `lattice.scores`. A smaller value + leads to more unique paths at the risk of missing the correct path. Returns: - An FsaVec containing linear FSAs. It axes are [utt][state][arc]. + An FsaVec containing **linear** FSAs. It axes are [utt][state][arc]. """ nbest = Nbest.from_lattice( lattice=lattice, @@ -446,9 +482,9 @@ def nbest_oracle( ) -> Dict[str, List[List[int]]]: """Select the best hypothesis given a lattice and a reference transcript. - The basic idea is to extract n paths from the given lattice, unique them, - and select the one that has the minimum edit distance with the corresponding - reference transcript as the decoding output. + The basic idea is to extract `num_paths` paths from the given lattice, + unique them, and select the one that has the minimum edit distance with + the corresponding reference transcript as the decoding output. The decoding result returned from this function is the best result that we can obtain using n-best decoding with all kinds of rescoring techniques. @@ -458,7 +494,7 @@ def nbest_oracle( Args: lattice: An FsaVec with axes [utt][state][arc]. - Note: We assume its aux_labels contain word IDs. + Note: We assume its `aux_labels` contains word IDs. num_paths: The size of `n` in n-best. ref_texts: @@ -500,6 +536,7 @@ def nbest_oracle( else: word_ids.append(oov_id) word_ids_list.append(word_ids) + levenshtein_graphs = [levenshtein_graph(ids) for ids in word_ids_list] refs = k2.Fsa.from_fsas(levenshtein_graphs).to(device) @@ -536,8 +573,8 @@ def rescore_with_n_best_list( lattice_score_scale: float = 1.0, use_double_scores: bool = True, ) -> Dict[str, k2.Fsa]: - """Rescore a nbest list with an n-gram LM. - The path with a maximum score is used as the decoding output. + """Rescore an n-best list with an n-gram LM. + The path with the maximum score is used as the decoding output. Args: lattice: @@ -605,7 +642,38 @@ def rescore_with_whole_lattice( lm_scale_list: Optional[List[float]] = None, use_double_scores: bool = True, ) -> Union[k2.Fsa, Dict[str, k2.Fsa]]: - # This is not an Nbest based coding method + """Intersect the lattice with an n-gram LM and use shortest path + to decode. + + The input lattice is obtained by intersecting `HLG` with + a DenseFsaVec, where the `G` in `HLG` is in general a 3-gram LM. + The input `G_with_epsilon_loops` is usually a 4-gram LM. You can consider + this function as a second pass decoding. In the first pass decoding, we + use a small G, while we use a larger G in the second pass decoding. + + Args: + lattice: + An FsaVec with axes [utt][state][arc]. Its `aux_lables` are word IDs. + It must have an attribute `lm_scores`. + G_with_epsilon_loops: + An FsaVec containing only a single FSA. It contains epsilon self-loops. + It is an acceptor and its labels are word IDs. + lm_scale_list: + Optional. If none, return the intersection of `lattice` and + `G_with_epsilon_loops`. + If not None, it contains a list of values to scale LM scores. + For each scale, there is a corresponding decoding result contained in + the resulting dict. + use_double_scores: + True to use double precision in the computation. + False to use single precision. + Returns: + If `lm_scale_list` is None, return a new lattice which is the intersection + result of `lattice` and `G_with_epsilon_loops`. + Otherwise, return a dict whose key is an entry in `lm_scale_list` and the + value is the decoding result (i.e., an FsaVec containing linear FSAs). + """ + # Nbest is not used in this function assert hasattr(lattice, "lm_scores") assert G_with_epsilon_loops.shape == (1, None, None) @@ -619,7 +687,7 @@ def rescore_with_whole_lattice( # Now, lattice.scores contains only am_scores # inv_lattice has word IDs as labels. - # Its aux_labels are token IDs + # Its `aux_labels` is token IDs inv_lattice = k2.invert(lattice) num_seqs = lattice.shape[0] @@ -668,7 +736,7 @@ def rescore_with_whole_lattice( lat.scores = am_scores + lat.lm_scores best_path = k2.shortest_path(lat, use_double_scores=use_double_scores) - key = f"lm_scale_{lm_scale}_yy" + key = f"lm_scale_{lm_scale}" ans[key] = best_path return ans @@ -686,6 +754,40 @@ def rescore_with_attention_decoder( attention_scale: Optional[float] = None, use_double_scores: bool = True, ) -> Dict[str, k2.Fsa]: + """This function extracts `num_paths` paths from the given lattice and uses + an attention decoder to rescore them. The path with the highest score is + the decoding output. + + Args: + lattice: + An FsaVec with axes [utt][state][arc]. + num_paths: + Number of paths to extract from the given lattice for rescoring. + model: + A transformer model. See the class "Transformer" in + conformer_ctc/transformer.py for its interface. + memory: + The encoder memory of the given model. It is the output of + the last torch.nn.TransformerEncoder layer in the given model. + Its shape is `(T, N, C)`. + memory_key_padding_mask: + The padding mask for memory with shape `(N, T)`. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + lattice_score_scale: + It's the scale applied to `lattice.scores`. A smaller value + leads to more unique paths at the risk of missing the correct path. + ngram_lm_scale: + Optional. It specifies the scale for n-gram LM scores. + attention_scale: + Optional. It specifies the scale for attention decoder scores. + Returns: + A dict of FsaVec, whose key contains a string + ngram_lm_scale_attention_scale and the value is the + best decoding path for each utterance in the lattice. + """ nbest = Nbest.from_lattice( lattice=lattice, num_paths=num_paths,