mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 20:12:24 +00:00
Fix decode.py to remove the correct axis.
This commit is contained in:
parent
9a6e0489c8
commit
e753b4509a
@ -222,7 +222,7 @@ def nbest_decoding(
|
||||
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
||||
else:
|
||||
word_seq = lattice.aux_labels.index(path)
|
||||
word_seq = word_seq.remove_axis(1)
|
||||
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
|
||||
|
||||
# Remove 0 (epsilon) and -1 from word_seq
|
||||
word_seq = word_seq.remove_values_leq(0)
|
||||
@ -432,7 +432,7 @@ def rescore_with_n_best_list(
|
||||
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
||||
else:
|
||||
word_seq = lattice.aux_labels.index(path)
|
||||
word_seq = word_seq.remove_axis(1)
|
||||
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
|
||||
|
||||
# Remove epsilons and -1 from word_seq
|
||||
word_seq = word_seq.remove_values_leq(0)
|
||||
@ -669,7 +669,7 @@ def nbest_oracle(
|
||||
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
||||
else:
|
||||
word_seq = lattice.aux_labels.index(path)
|
||||
word_seq = word_seq.remove_axis(1)
|
||||
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
|
||||
|
||||
word_seq = word_seq.remove_values_leq(0)
|
||||
unique_word_seq, _, _ = word_seq.unique(
|
||||
@ -761,7 +761,7 @@ def rescore_with_attention_decoder(
|
||||
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
||||
else:
|
||||
word_seq = lattice.aux_labels.index(path)
|
||||
word_seq = word_seq.remove_axis(1)
|
||||
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
|
||||
|
||||
# Remove epsilons and -1 from word_seq
|
||||
word_seq = word_seq.remove_values_leq(0)
|
||||
@ -815,7 +815,7 @@ def rescore_with_attention_decoder(
|
||||
token_seq = k2.ragged.index(lattice.tokens, path)
|
||||
else:
|
||||
token_seq = lattice.tokens.index(path)
|
||||
token_seq = token_seq.remove_axis(1)
|
||||
token_seq = token_seq.remove_axis(token_seq.num_axes - 2)
|
||||
|
||||
# Remove epsilons and -1 from token_seq
|
||||
token_seq = token_seq.remove_values_leq(0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user