From e753b4509af7c9cb4cbdff86b08fe3cd24cc1ba7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 17 Sep 2021 16:13:24 +0800 Subject: [PATCH] Fix decode.py to remove the correct axis. --- icefall/decode.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/icefall/decode.py b/icefall/decode.py index dfac5700e..29b76d973 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -222,7 +222,7 @@ def nbest_decoding( word_seq = k2.ragged.index(lattice.aux_labels, path) else: word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(1) + word_seq = word_seq.remove_axis(word_seq.num_axes - 2) # Remove 0 (epsilon) and -1 from word_seq word_seq = word_seq.remove_values_leq(0) @@ -432,7 +432,7 @@ def rescore_with_n_best_list( word_seq = k2.ragged.index(lattice.aux_labels, path) else: word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(1) + word_seq = word_seq.remove_axis(word_seq.num_axes - 2) # Remove epsilons and -1 from word_seq word_seq = word_seq.remove_values_leq(0) @@ -669,7 +669,7 @@ def nbest_oracle( word_seq = k2.ragged.index(lattice.aux_labels, path) else: word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(1) + word_seq = word_seq.remove_axis(word_seq.num_axes - 2) word_seq = word_seq.remove_values_leq(0) unique_word_seq, _, _ = word_seq.unique( @@ -761,7 +761,7 @@ def rescore_with_attention_decoder( word_seq = k2.ragged.index(lattice.aux_labels, path) else: word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(1) + word_seq = word_seq.remove_axis(word_seq.num_axes - 2) # Remove epsilons and -1 from word_seq word_seq = word_seq.remove_values_leq(0) @@ -815,7 +815,7 @@ def rescore_with_attention_decoder( token_seq = k2.ragged.index(lattice.tokens, path) else: token_seq = lattice.tokens.index(path) - token_seq = token_seq.remove_axis(1) + token_seq = token_seq.remove_axis(token_seq.num_axes - 2) # Remove epsilons and -1 from token_seq token_seq = token_seq.remove_values_leq(0)