From 664b87a57b476d0279e0c7efe751438b3846545e Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 14 Sep 2021 21:07:54 +0800
Subject: [PATCH] Compute edit distance with k2.

---
 egs/librispeech/ASR/conformer_ctc/decode.py |  13 +-
 icefall/decode2.py                          | 174 +++++++++++++++++++-
 icefall/graph_compiler.py                   |   2 +-
 3 files changed, 178 insertions(+), 11 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 97d0b63fd..fc758f439 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -32,13 +32,12 @@ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.decode import (
     get_lattice,
-    nbest_oracle,
     one_best_decoding,
     rescore_with_attention_decoder,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
-from icefall.decode2 import nbest_decoding
+from icefall.decode2 import nbest_decoding, nbest_oracle
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -245,13 +244,17 @@ def decode_one_batch(
         # We choose the HLG decoded lattice for speed reasons
         # as HLG decoding is faster and the oracle WER
         # is slightly worse than that of rescored lattices.
-        return nbest_oracle(
+        best_path = nbest_oracle(
             lattice=lattice,
             num_paths=params.num_paths,
             ref_texts=supervisions["text"],
             word_table=word_table,
-            scale=params.lattice_score_scale,
+            lattice_score_scale=params.lattice_score_scale,
+            oov="<UNK>",
         )
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        return {"nbest-oracle": hyps}
 
     if params.method in ["1best", "nbest"]:
         if params.method == "1best":
@@ -264,7 +267,7 @@ def decode_one_batch(
                 lattice=lattice,
                 num_paths=params.num_paths,
                 use_double_scores=params.use_double_scores,
-                scale=params.lattice_score_scale,
+                lattice_score_scale=params.lattice_score_scale,
             )
             key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}"  # noqa
 
diff --git a/icefall/decode2.py b/icefall/decode2.py
index ea76fb63b..2f21925bb 100644
--- a/icefall/decode2.py
+++ b/icefall/decode2.py
@@ -17,9 +17,47 @@
 # NOTE: This file is a refactor of decode.py
 # We will delete decode.py and rename this file to decode.py
 
+from typing import List
+
 import k2
 import torch
 
+from icefall.utils import get_texts
+
+
+# TODO(fangjun): Use Kangwei's C++ implementation that also supports List[List[int]]
+def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
+    """Construct a graph to compute Levenshtein distance.
+
+    An example graph for `levenshtein_graph([1, 2, 3])` can be found
+    at https://git.io/Ju7eW
+    (Assuming the symbol table is 1<->a, 2<->b, 3<->c, blk<->0)
+
+    Args:
+      symbol_ids:
+        A list of symbol IDs (excluding 0 and -1).
+    """
+    assert 0 not in symbol_ids
+    assert -1 not in symbol_ids
+    final_state = len(symbol_ids) + 1
+    arcs = []
+    for i in range(final_state - 1):
+        arcs.append([i, i, 0, 0, -1.01])
+        arcs.append([i, i + 1, symbol_ids[i], symbol_ids[i], 0])
+        arcs.append([i, i + 1, 0, symbol_ids[i], -1])
+    arcs.append([final_state - 1, final_state - 1, 0, 0, -1.01])
+    arcs.append([final_state - 1, final_state, -1, -1, 0])
+    arcs.append([final_state])
+
+    arcs = sorted(arcs, key=lambda arc: arc[0])
+    arcs = [[str(i) for i in arc] for arc in arcs]
+    arcs = [" ".join(arc) for arc in arcs]
+    arcs = "\n".join(arcs)
+
+    fsa = k2.Fsa.from_str(arcs, acceptor=False)
+    fsa = k2.arc_sort(fsa)
+    return fsa
+
 
 class Nbest(object):
     """
@@ -63,7 +101,7 @@ class Nbest(object):
         lattice: k2.Fsa,
         num_paths: int,
         use_double_scores: bool = True,
-        scale: float = 0.5,
+        lattice_score_scale: float = 0.5,
     ) -> "Nbest":
         """Construct an Nbest object by **sampling** `num_paths` from a lattice.
 
@@ -82,7 +120,7 @@ class Nbest(object):
             to sample the path with the best score.
         """
         saved_scores = lattice.scores.clone()
-        lattice.scores *= scale
+        lattice.scores *= lattice_score_scale
         # path is a ragged tensor with dtype torch.int32.
         # It has three axes [utt][path][arc_pos].
         path = k2.random_paths(
@@ -230,19 +268,52 @@ class Nbest(object):
         )
         return k2.RaggedTensor(self.shape, scores)
 
+    def build_levenshtein_graphs(self) -> k2.Fsa:
+        """Return an FsaVec with axes [utt][state][arc]."""
+        word_ids = get_texts(self.fsa)
+        word_levenshtein_graphs = [levenshtein_graph(ids) for ids in word_ids]
+        return k2.Fsa.from_fsas(word_levenshtein_graphs)
+
 
 def nbest_decoding(
     lattice: k2.Fsa,
     num_paths: int,
     use_double_scores: bool = True,
-    scale: float = 1.0,
+    lattice_score_scale: float = 1.0,
 ) -> k2.Fsa:
-    """ """
+    """It implements something like CTC prefix beam search using n-best lists.
+
+    The basic idea is to first extract `num_paths` paths from the given
+    lattice, build a word sequence from each path, and compute the total
+    score of each word sequence in the tropical semiring. The one with the
+    max score is used as the decoding output.
+
+    Caution:
+      Don't be confused by `best` in the name `n-best`. Paths are selected
+      **randomly**, not by ranking their scores.
+
+    Args:
+      lattice:
+        The decoding lattice, e.g., can be the return value of
+        :func:`get_lattice`. It has 3 axes [utt][state][arc].
+      num_paths:
+        It specifies the size `n` in n-best. Note: Paths are selected randomly
+        and those containing identical word sequences are removed and only one
+        of them is kept.
+      use_double_scores:
+        True to use double precision floating point in the computation.
+        False to use single precision.
+      lattice_score_scale:
+        It's the scale applied to the lattice.scores. A smaller value
+        yields more unique paths.
+    Returns:
+      An FsaVec containing linear FSAs. Its axes are [utt][state][arc].
+ """ nbest = Nbest.from_lattice( lattice=lattice, num_paths=num_paths, use_double_scores=use_double_scores, - scale=scale, + lattice_score_scale=lattice_score_scale, ) nbest = nbest.intersect(lattice) @@ -253,3 +324,96 @@ def nbest_decoding( best_path = k2.index_fsa(nbest.fsa, max_indexes) return best_path + + +def nbest_oracle( + lattice: k2.Fsa, + num_paths: int, + ref_texts: List[str], + word_table: k2.SymbolTable, + use_double_scores: bool = True, + lattice_score_scale: float = 0.5, + oov: str = "", +) -> Dict[str, List[List[int]]]: + """Select the best hypothesis given a lattice and a reference transcript. + + The basic idea is to extract n paths from the given lattice, unique them, + and select the one that has the minimum edit distance with the corresponding + reference transcript as the decoding output. + + The decoding result returned from this function is the best result that + we can obtain using n-best decoding with all kinds of rescoring techniques. + + This function is useful to tune the value of `lattice_score_scale`. + + Args: + lattice: + An FsaVec with axes [utt][state][arc]. + Note: We assume its aux_labels contain word IDs. + num_paths: + The size of `n` in n-best. + ref_texts: + A list of reference transcript. Each entry contains space(s) + separated words + word_table: + It is the word symbol table. + use_double_scores: + True to use double precision for computation. False to use + single precision. + lattice_score_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + oov: + The out of vocabulary word. + Return: + Return a dict. Its key contains the information about the parameters + when calling this function, while its value contains the decoding output. + `len(ans_dict) == len(ref_texts)` + """ + device = lattice.device + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + lattice_score_scale=lattice_score_scale, + ) + + hyps = nbest.build_levenshtein_graphs().to(device) + + oov_id = word_table[oov] + word_ids_list = [] + for text in ref_texts: + word_ids = [] + for word in text.split(): + if word in word_table: + word_ids.append(word_table[word]) + else: + word_ids.append(oov_id) + word_ids_list.append(word_ids) + levenshtein_graphs = [levenshtein_graph(ids) for ids in word_ids_list] + refs = k2.Fsa.from_fsas(levenshtein_graphs).to(device) + + # Now compute the edit distance between hyps and refs + hyps.rename_tensor_attribute_("aux_labels", "aux_labels2") + edit_dist_lattice = k2.intersect_device( + refs, + hyps, + b_to_a_map=nbest.shape.row_ids(1), + sorted_match_a=True, + ) + edit_dist_lattice = k2.remove_epsilon_self_loops(edit_dist_lattice) + edit_dist_best_path = k2.shortest_path( + edit_dist_lattice, use_double_scores=True + ).invert_() + edit_dist_best_path.rename_tensor_attribute_("aux_labels2", "aux_labels") + + tot_scores = edit_dist_best_path.get_tot_scores( + use_double_scores=False, log_semiring=False + ) + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + + max_indexes = ragged_tot_scores.argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + return best_path diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py index 23ac247e8..b4c87d964 100644 --- a/icefall/graph_compiler.py +++ b/icefall/graph_compiler.py @@ -106,7 +106,7 @@ class CtcTrainingGraphCompiler(object): word_ids_list = [] for text in texts: word_ids = [] - for word in text.split(" "): + for word in text.split(): if word in self.word_table: 
                     word_ids.append(self.word_table[word])
                 else:
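
Usage note (not part of the diff): the snippet below is a minimal sketch of how the
`levenshtein_graph` helper added to icefall/decode2.py can be used to score a single
hypothesis against a single reference, mirroring the recipe inside `nbest_oracle`.
The word IDs, variable names, and the expected output value are illustrative
assumptions, and `k2.create_fsa_vec` is used here in place of the patch's
`k2.Fsa.from_fsas`.

import k2
import torch

from icefall.decode2 import levenshtein_graph  # added by this patch

# Hypothetical word IDs (0 and -1 are reserved by k2 and must not appear).
ref_ids = [5, 7, 9]    # e.g. "the cat sat"
hyp_ids = [5, 9, 11]   # e.g. "the sat down"

ref = levenshtein_graph(ref_ids)
hyp = levenshtein_graph(hyp_ids)

# Both graphs carry aux_labels; rename one side so the attributes do not
# clash during intersection (the same trick used in nbest_oracle above).
hyp.rename_tensor_attribute_("aux_labels", "aux_labels2")

refs = k2.create_fsa_vec([ref])
hyps = k2.create_fsa_vec([hyp])

# Align the two graphs. intersect_device() treats label 0 as an ordinary
# symbol, which is exactly what the insertion/deletion arcs rely on.
alignment = k2.intersect_device(
    refs,
    hyps,
    b_to_a_map=torch.tensor([0], dtype=torch.int32),
    sorted_match_a=True,
)
alignment = k2.remove_epsilon_self_loops(alignment)
best = k2.shortest_path(alignment, use_double_scores=True)

# With the arc weights used in this patch a match costs 0, a substitution
# costs -2, and an insertion or deletion costs -2.01, so a larger (less
# negative) total score means a smaller edit distance.
tot = best.get_tot_scores(use_double_scores=True, log_semiring=False)
print(tot.item())  # about -4.0 here: two substitutions

Ranking the sampled paths by exactly this kind of total score (the `argmax` over
`ragged_tot_scores`) is how `nbest_oracle` picks its oracle hypothesis.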