Fix decode.py to remove the correct axis.

This commit is contained in:
Fangjun Kuang 2021-09-17 16:13:24 +08:00
parent 9a6e0489c8
commit e753b4509a

View File

@ -222,7 +222,7 @@ def nbest_decoding(
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(1)
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
# Remove 0 (epsilon) and -1 from word_seq
word_seq = word_seq.remove_values_leq(0)
@ -432,7 +432,7 @@ def rescore_with_n_best_list(
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(1)
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
# Remove epsilons and -1 from word_seq
word_seq = word_seq.remove_values_leq(0)
@ -669,7 +669,7 @@ def nbest_oracle(
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(1)
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
word_seq = word_seq.remove_values_leq(0)
unique_word_seq, _, _ = word_seq.unique(
@ -761,7 +761,7 @@ def rescore_with_attention_decoder(
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(1)
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
# Remove epsilons and -1 from word_seq
word_seq = word_seq.remove_values_leq(0)
@ -815,7 +815,7 @@ def rescore_with_attention_decoder(
token_seq = k2.ragged.index(lattice.tokens, path)
else:
token_seq = lattice.tokens.index(path)
token_seq = token_seq.remove_axis(1)
token_seq = token_seq.remove_axis(token_seq.num_axes - 2)
# Remove epsilons and -1 from token_seq
token_seq = token_seq.remove_values_leq(0)