diff --git a/icefall/decode.py b/icefall/decode.py index dfac5700e..29b76d973 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -222,7 +222,7 @@ def nbest_decoding( word_seq = k2.ragged.index(lattice.aux_labels, path) else: word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(1) + word_seq = word_seq.remove_axis(word_seq.num_axes - 2) # Remove 0 (epsilon) and -1 from word_seq word_seq = word_seq.remove_values_leq(0) @@ -432,7 +432,7 @@ def rescore_with_n_best_list( word_seq = k2.ragged.index(lattice.aux_labels, path) else: word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(1) + word_seq = word_seq.remove_axis(word_seq.num_axes - 2) # Remove epsilons and -1 from word_seq word_seq = word_seq.remove_values_leq(0) @@ -669,7 +669,7 @@ def nbest_oracle( word_seq = k2.ragged.index(lattice.aux_labels, path) else: word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(1) + word_seq = word_seq.remove_axis(word_seq.num_axes - 2) word_seq = word_seq.remove_values_leq(0) unique_word_seq, _, _ = word_seq.unique( @@ -761,7 +761,7 @@ def rescore_with_attention_decoder( word_seq = k2.ragged.index(lattice.aux_labels, path) else: word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(1) + word_seq = word_seq.remove_axis(word_seq.num_axes - 2) # Remove epsilons and -1 from word_seq word_seq = word_seq.remove_values_leq(0) @@ -815,7 +815,7 @@ def rescore_with_attention_decoder( token_seq = k2.ragged.index(lattice.tokens, path) else: token_seq = lattice.tokens.index(path) - token_seq = token_seq.remove_axis(1) + token_seq = token_seq.remove_axis(token_seq.num_axes - 2) # Remove epsilons and -1 from token_seq token_seq = token_seq.remove_values_leq(0)