From b9fc46f432f39c61f9376635c332ace6bae3053c Mon Sep 17 00:00:00 2001
From: Fangjun Kuang <csukuangfj@gmail.com>
Date: Fri, 17 Sep 2021 15:46:23 +0800
Subject: [PATCH] Refactor nbest-oracle.

---
 egs/librispeech/ASR/conformer_ctc/decode.py | 37 +++++++++++++--------
 icefall/decode.py                           |  2 +-
 icefall/decode2.py                          | 16 +++++----
 test/test_decode.py                         |  5 ++-
 4 files changed, 38 insertions(+), 22 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index fc758f439..37fb58bbb 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -30,14 +30,15 @@ from conformer import Conformer
 
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.decode import get_lattice
 from icefall.decode import (
-    get_lattice,
     one_best_decoding,
     rescore_with_attention_decoder,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
+    nbest_oracle,
 )
-from icefall.decode2 import nbest_decoding, nbest_oracle
+from icefall.decode2 import nbest_decoding, nbest_oracle as nbest_oracle2
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -244,17 +245,27 @@ def decode_one_batch(
         # We choose the HLG decoded lattice for speed reasons
         # as HLG decoding is faster and the oracle WER
         # is slightly worse than that of rescored lattices.
-        best_path = nbest_oracle(
-            lattice=lattice,
-            num_paths=params.num_paths,
-            ref_texts=supervisions["text"],
-            word_table=word_table,
-            lattice_score_scale=params.lattice_score_scale,
-            oov="<UNK>",
-        )
-        hyps = get_texts(best_path)
-        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        return {"nbest-orcale": hyps}
+        if True:
+            # TODO: delete the `else` branch
+            best_path = nbest_oracle2(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                ref_texts=supervisions["text"],
+                word_table=word_table,
+                lattice_score_scale=params.lattice_score_scale,
+                oov="<UNK>",
+            )
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            return {"nbest-orcale": hyps}
+        else:
+            return nbest_oracle(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                ref_texts=supervisions["text"],
+                word_table=word_table,
+                scale=params.lattice_score_scale,
+            )
 
     if params.method in ["1best", "nbest"]:
         if params.method == "1best":
diff --git a/icefall/decode.py b/icefall/decode.py
index dfac5700e..61900683f 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -669,7 +669,7 @@ def nbest_oracle(
         word_seq = k2.ragged.index(lattice.aux_labels, path)
     else:
         word_seq = lattice.aux_labels.index(path)
-        word_seq = word_seq.remove_axis(1)
+        word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
 
     word_seq = word_seq.remove_values_leq(0)
     unique_word_seq, _, _ = word_seq.unique(
diff --git a/icefall/decode2.py b/icefall/decode2.py
index 2f21925bb..1c393a920 100644
--- a/icefall/decode2.py
+++ b/icefall/decode2.py
@@ -25,7 +25,8 @@ import torch
 
 from icefall.utils import get_texts
 
 
-# TODO(fangjun): Use Kangwei's C++ implementation that also List[List[int]]
+# TODO(fangjun): Use Kangwei's C++ implementation that also
+# supports List[List[int]]
 def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
     """Construct a graph to compute Levenshtein distance.
@@ -42,10 +43,10 @@ def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
     final_state = len(symbol_ids) + 1
     arcs = []
     for i in range(final_state - 1):
-        arcs.append([i, i, 0, 0, -1.01])
+        arcs.append([i, i, 0, 0, -0.5])
+        arcs.append([i, i + 1, 0, symbol_ids[i], -0.5])
         arcs.append([i, i + 1, symbol_ids[i], symbol_ids[i], 0])
-        arcs.append([i, i + 1, 0, symbol_ids[i], -1])
-    arcs.append([final_state - 1, final_state - 1, 0, 0, -1.01])
+    arcs.append([final_state - 1, final_state - 1, 0, 0, -0.5])
     arcs.append([final_state - 1, final_state, -1, -1, 0])
     arcs.append([final_state])
 
@@ -134,7 +135,8 @@ class Nbest(object):
         if isinstance(lattice.aux_labels, torch.Tensor):
             word_seq = k2.ragged.index(lattice.aux_labels, path)
         else:
-            word_seq = lattice.aux_labels.index(path, remove_axis=True)
+            word_seq = lattice.aux_labels.index(path)
+            word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
 
         # Each utterance has `num_paths` paths but some of them transduces
         # to the same word sequence, so we need to remove repeated word
@@ -170,11 +172,11 @@ class Nbest(object):
             # It has 2 axes [arc][word], so aux_labels is also a ragged tensor
             # with 2 axes [arc][word]
             aux_labels, _ = lattice.aux_labels.index(
-                indexes=kept_path.data, axis=0, need_value_indexes=False
+                indexes=kept_path.values, axis=0, need_value_indexes=False
             )
         else:
             assert isinstance(lattice.aux_labels, torch.Tensor)
-            aux_labels = k2.index_select(lattice.aux_labels, kept_path.data)
+            aux_labels = k2.index_select(lattice.aux_labels, kept_path.values)
 
         # aux_labels is a 1-D torch.Tensor. It also contains -1 and 0.
         fsa = k2.linear_fsa(labels)
diff --git a/test/test_decode.py b/test/test_decode.py
index 93bddf23f..ca89c6e02 100644
--- a/test/test_decode.py
+++ b/test/test_decode.py
@@ -40,7 +40,10 @@ def test_nbest_from_lattice():
     lattice = k2.Fsa.from_fsas([lattice, lattice])
 
     nbest = Nbest.from_lattice(
-        lattice=lattice, num_paths=10, use_double_scores=True, scale=0.5
+        lattice=lattice,
+        num_paths=10,
+        use_double_scores=True,
+        lattice_score_scale=0.5,
     )
     # each lattice has only 4 distinct paths that have different word sequences:
     # 10->30
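
Note on what "nbest oracle" measures: among the num_paths paths sampled from the
lattice, it picks the hypothesis with the smallest edit distance to the reference
transcript, so its WER is a lower bound on what rescoring that same n-best list
could achieve. The sketch below is a pure-Python illustration of that selection,
computing the same quantity the levenshtein_graph machinery in decode2.py is
meant to compute via FSA intersection; edit_distance and oracle_hyp are
hypothetical names for illustration, not icefall APIs.

from typing import List


def edit_distance(a: List[str], b: List[str]) -> int:
    # Standard dynamic-programming Levenshtein distance over words.
    m, n = len(a), len(b)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        dp[i][0] = i  # delete all of a[:i]
    for j in range(n + 1):
        dp[0][j] = j  # insert all of b[:j]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if a[i - 1] == b[j - 1] else 1
            dp[i][j] = min(
                dp[i - 1][j] + 1,         # deletion
                dp[i][j - 1] + 1,         # insertion
                dp[i - 1][j - 1] + cost,  # match / substitution
            )
    return dp[m][n]


def oracle_hyp(nbest: List[List[str]], ref: List[str]) -> List[str]:
    # The oracle hypothesis: the n-best entry with the fewest word errors.
    return min(nbest, key=lambda hyp: edit_distance(hyp, ref))


if __name__ == "__main__":
    ref = "the cat sat on the mat".split()
    nbest = [
        "the cat sat on a mat".split(),  # 1 substitution
        "the cat on the mat".split(),    # 1 deletion
        "a cat sat on a mat".split(),    # 2 substitutions
    ]
    assert edit_distance(oracle_hyp(nbest, ref), ref) == 1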
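
Note on lattice_score_scale (the parameter renamed from scale in test_decode.py
above): icefall scales the lattice scores before sampling paths, and a scale
below 1 flattens the path distribution, so the sampled n-best list is more
diverse and the oracle has more distinct candidates. A toy sketch of the effect,
under the simplifying assumption that paths are drawn roughly in proportion to
the softmax of their (scaled) total scores:

import torch

# Toy total scores (log-domain) of three competing paths.
path_scores = torch.tensor([9.0, 8.0, 3.0])

for scale in (1.0, 0.5):
    probs = torch.softmax(scale * path_scores, dim=0)
    print(f"scale={scale}: sampling probs = {probs.tolist()}")

# scale=1.0 puts nearly all probability mass on the two best paths
# (the third gets ~0.2%); scale=0.5 keeps ~3% on the third path, so
# low-scoring (but possibly lower-WER) paths survive into the n-best list.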
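
Note on the ragged-tensor changes (.data -> .values, and
remove_axis(word_seq.num_axes - 2) instead of remove_axis(1)): a small sketch of
the two operations, assuming the k2 RaggedTensor Python API of this period
(num_axes, remove_axis, .values); exact constructor syntax may differ across k2
versions.

import k2

# A 3-axis ragged tensor, e.g. [utt][path][word].
r = k2.RaggedTensor("[ [ [1 2] [3] ] [ [4 5] ] ]")
assert r.num_axes == 3

# remove_axis(num_axes - 2) always drops the second-to-last axis,
# leaving [utt][word]; hard-coding remove_axis(1) only happens to be
# right when the input has exactly 3 axes.
merged = r.remove_axis(r.num_axes - 2)
assert merged.num_axes == 2

# `.values` (the new name for the old `.data` attribute) is the flat
# 1-D torch.Tensor holding all elements.
print(merged.values)  # expected: tensor([1, 2, 3, 4, 5], dtype=torch.int32)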