Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-14 20:42:22 +00:00)
Fix decode.py to remove the correct axis.

parent 9a6e0489c8
commit e753b4509a
@@ -222,7 +222,7 @@ def nbest_decoding(
         word_seq = k2.ragged.index(lattice.aux_labels, path)
     else:
         word_seq = lattice.aux_labels.index(path)
-        word_seq = word_seq.remove_axis(1)
+        word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
 
     # Remove 0 (epsilon) and -1 from word_seq
     word_seq = word_seq.remove_values_leq(0)
@@ -432,7 +432,7 @@ def rescore_with_n_best_list(
         word_seq = k2.ragged.index(lattice.aux_labels, path)
     else:
         word_seq = lattice.aux_labels.index(path)
-        word_seq = word_seq.remove_axis(1)
+        word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
 
     # Remove epsilons and -1 from word_seq
     word_seq = word_seq.remove_values_leq(0)
@@ -669,7 +669,7 @@ def nbest_oracle(
         word_seq = k2.ragged.index(lattice.aux_labels, path)
     else:
         word_seq = lattice.aux_labels.index(path)
-        word_seq = word_seq.remove_axis(1)
+        word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
 
     word_seq = word_seq.remove_values_leq(0)
     unique_word_seq, _, _ = word_seq.unique(
@@ -761,7 +761,7 @@ def rescore_with_attention_decoder(
         word_seq = k2.ragged.index(lattice.aux_labels, path)
     else:
         word_seq = lattice.aux_labels.index(path)
-        word_seq = word_seq.remove_axis(1)
+        word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
 
     # Remove epsilons and -1 from word_seq
     word_seq = word_seq.remove_values_leq(0)
@@ -815,7 +815,7 @@ def rescore_with_attention_decoder(
         token_seq = k2.ragged.index(lattice.tokens, path)
     else:
         token_seq = lattice.tokens.index(path)
-        token_seq = token_seq.remove_axis(1)
+        token_seq = token_seq.remove_axis(token_seq.num_axes - 2)
 
     # Remove epsilons and -1 from token_seq
     token_seq = token_seq.remove_values_leq(0)
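
Why the fix matters: in the else-branch, lattice.aux_labels (or lattice.tokens) is itself a ragged tensor, so indexing it with an n-best path yields a 4-axis tensor rather than the 3-axis one the torch.Tensor branch produces. remove_axis(1) then deletes the per-path axis instead of the per-arc axis; num_axes - 2 always targets the second-to-last axis, which is the correct one. Below is a minimal sketch, not part of the commit, using made-up values; it assumes a k2 build whose RaggedTensor exposes remove_axis, num_axes, and remove_values_leq, i.e. the same methods the diff calls.

import k2

# When aux_labels is ragged, indexing with an n-best path gives a 4-axis
# tensor: [utterance][path][arc][word]. Values here are illustrative.
word_seq = k2.RaggedTensor([[[[0, 5], [7, -1]], [[0, 9], [-1]]]])
assert word_seq.num_axes == 4

# remove_axis(1) would merge the per-path lists inside each utterance.
# num_axes - 2 instead removes the arc axis, concatenating each arc's word
# list and leaving [utterance][path][word], matching the torch.Tensor branch.
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
assert word_seq.num_axes == 3

# decode.py then drops epsilons (0) and the trailing -1, as in the hunks above.
word_seq = word_seq.remove_values_leq(0)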