diff --git a/icefall/decode.py b/icefall/decode.py index 573c6bf78..e678e4622 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -128,43 +128,6 @@ def get_lattice( return lattice -# TODO(fangjun): Use Kangwei's C++ implementation that also -# supports List[List[int]] -def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa: - """Construct a graph to compute Levenshtein distance. - - An example graph for `levenshtein_graph([1, 2, 3]` can be found - at https://git.io/Ju7eW - (Assuming the symbol table is 1<->a, 2<->b, 3<->c, blk<->0) - - Args: - symbol_ids: - A list of symbol IDs (excluding 0 and -1) - Returns: - Return an Fsa (with 2 axes [state][arc]). - """ - assert 0 not in symbol_ids - assert -1 not in symbol_ids - final_state = len(symbol_ids) + 1 - arcs = [] - for i in range(final_state - 1): - arcs.append([i, i, 0, 0, -0.5]) - arcs.append([i, i + 1, 0, symbol_ids[i], -0.5]) - arcs.append([i, i + 1, symbol_ids[i], symbol_ids[i], 0]) - arcs.append([final_state - 1, final_state - 1, 0, 0, -0.5]) - arcs.append([final_state - 1, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - fsa = k2.arc_sort(fsa) - return fsa - - class Nbest(object): """ An Nbest object contains two fields: @@ -456,9 +419,8 @@ class Nbest(object): def build_levenshtein_graphs(self) -> k2.Fsa: """Return an FsaVec with axes [utt][state][arc].""" - word_ids = get_texts(self.fsa) - word_levenshtein_graphs = [levenshtein_graph(ids) for ids in word_ids] - return k2.Fsa.from_fsas(word_levenshtein_graphs) + word_ids = get_texts(self.fsa, return_ragged=True) + return k2.levenshtein_graph(word_ids) def one_best_decoding( @@ -590,7 +552,7 @@ def nbest_oracle( lattice_score_scale=lattice_score_scale, ) - hyps = nbest.build_levenshtein_graphs().to(device) + hyps = nbest.build_levenshtein_graphs() oov_id = word_table[oov] word_ids_list = [] @@ -603,24 +565,16 @@ def nbest_oracle( word_ids.append(oov_id) word_ids_list.append(word_ids) - levenshtein_graphs = [levenshtein_graph(ids) for ids in word_ids_list] - refs = k2.Fsa.from_fsas(levenshtein_graphs).to(device) + refs = k2.levenshtein_graph(word_ids_list, device=device) - # Now compute the edit distance between hyps and refs - hyps.rename_tensor_attribute_("aux_labels", "aux_labels2") - edit_dist_lattice = k2.intersect_device( - refs, - hyps, - b_to_a_map=nbest.shape.row_ids(1), - sorted_match_a=True, + levenshtein_alignment = k2.levenshtein_alignment( + refs=refs, + hyps=hyps, + hyp_to_ref_map=nbest.shape.row_ids(1), + sorted_match_ref=True, ) - edit_dist_lattice = k2.remove_epsilon_self_loops(edit_dist_lattice) - edit_dist_best_path = k2.shortest_path( - edit_dist_lattice, use_double_scores=True - ).invert_() - edit_dist_best_path.rename_tensor_attribute_("aux_labels2", "aux_labels") - tot_scores = edit_dist_best_path.get_tot_scores( + tot_scores = levenshtein_alignment.get_tot_scores( use_double_scores=False, log_semiring=False ) ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) diff --git a/icefall/utils.py b/icefall/utils.py index cc658ae32..2324201c3 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -186,7 +186,9 @@ def encode_supervisions( return supervision_segments, texts -def get_texts(best_paths: k2.Fsa) -> List[List[int]]: +def get_texts( + best_paths: k2.Fsa, return_ragged: bool = False +) -> Union[List[List[int]], k2.RaggedTensor]: """Extract the texts (as word IDs) from the best-path FSAs. Args: best_paths: @@ -194,6 +196,9 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]: containing multiple FSAs, which is expected to be the result of k2.shortest_path (otherwise the returned values won't be meaningful). + return_ragged: + True to return a ragged tensor with two axes [utt][word_id]. + False to return a list-of-list word IDs. Returns: Returns a list of lists of int, containing the label sequences we decoded. @@ -216,7 +221,10 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]: aux_labels = aux_labels.remove_values_leq(0) assert aux_labels.num_axes == 2 - return aux_labels.tolist() + if return_ragged: + return aux_labels + else: + return aux_labels.tolist() def store_transcripts(