mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 20:12:24 +00:00
Fix code style
This commit is contained in:
parent
7f8e3a673a
commit
8ac632cec8
@ -221,7 +221,8 @@ def nbest_decoding(
|
||||
if isinstance(lattice.aux_labels, torch.Tensor):
|
||||
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
||||
else:
|
||||
word_seq = lattice.aux_labels.index(path, remove_axis=True)
|
||||
word_seq = lattice.aux_labels.index(path)
|
||||
word_seq = word_seq.remove_axis(1)
|
||||
|
||||
# Remove 0 (epsilon) and -1 from word_seq
|
||||
word_seq = word_seq.remove_values_leq(0)
|
||||
@ -300,7 +301,7 @@ def nbest_decoding(
|
||||
# lattice.aux_labels is a k2.RaggedTensor with 2 axes, so
|
||||
# aux_labels is also a k2.RaggedTensor with 2 axes
|
||||
aux_labels, _ = lattice.aux_labels.index(
|
||||
indexes=best_path.data, axis=0, need_value_indexes=False
|
||||
indexes=best_path.values, axis=0, need_value_indexes=False
|
||||
)
|
||||
|
||||
best_path_fsa = k2.linear_fsa(labels)
|
||||
@ -430,7 +431,8 @@ def rescore_with_n_best_list(
|
||||
if isinstance(lattice.aux_labels, torch.Tensor):
|
||||
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
||||
else:
|
||||
word_seq = lattice.aux_labels.index(path, remove_axis=True)
|
||||
word_seq = lattice.aux_labels.index(path)
|
||||
word_seq = word_seq.remove_axis(1)
|
||||
|
||||
# Remove epsilons and -1 from word_seq
|
||||
word_seq = word_seq.remove_values_leq(0)
|
||||
@ -520,7 +522,7 @@ def rescore_with_n_best_list(
|
||||
# aux_labels is also a k2.RaggedTensor with 2 axes
|
||||
|
||||
aux_labels, _ = lattice.aux_labels.index(
|
||||
indexes=best_path.data, axis=0, need_value_indexes=False
|
||||
indexes=best_path.values, axis=0, need_value_indexes=False
|
||||
)
|
||||
|
||||
best_path_fsa = k2.linear_fsa(labels)
|
||||
@ -666,7 +668,8 @@ def nbest_oracle(
|
||||
if isinstance(lattice.aux_labels, torch.Tensor):
|
||||
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
||||
else:
|
||||
word_seq = lattice.aux_labels.index(path, remove_axis=True)
|
||||
word_seq = lattice.aux_labels.index(path)
|
||||
word_seq = word_seq.remove_axis(1)
|
||||
|
||||
word_seq = word_seq.remove_values_leq(0)
|
||||
unique_word_seq, _, _ = word_seq.unique(
|
||||
@ -757,7 +760,8 @@ def rescore_with_attention_decoder(
|
||||
if isinstance(lattice.aux_labels, torch.Tensor):
|
||||
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
||||
else:
|
||||
word_seq = lattice.aux_labels.index(path, remove_axis=True)
|
||||
word_seq = lattice.aux_labels.index(path)
|
||||
word_seq = word_seq.remove_axis(1)
|
||||
|
||||
# Remove epsilons and -1 from word_seq
|
||||
word_seq = word_seq.remove_values_leq(0)
|
||||
@ -810,7 +814,8 @@ def rescore_with_attention_decoder(
|
||||
if isinstance(lattice.tokens, torch.Tensor):
|
||||
token_seq = k2.ragged.index(lattice.tokens, path)
|
||||
else:
|
||||
token_seq = lattice.tokens.index(path, remove_axis=True)
|
||||
token_seq = lattice.tokens.index(path)
|
||||
token_seq = token_seq.remove_axis(1)
|
||||
|
||||
# Remove epsilons and -1 from token_seq
|
||||
token_seq = token_seq.remove_values_leq(0)
|
||||
@ -890,10 +895,12 @@ def rescore_with_attention_decoder(
|
||||
labels = labels.remove_values_eq(-1)
|
||||
|
||||
if isinstance(lattice.aux_labels, torch.Tensor):
|
||||
aux_labels = k2.index_select(lattice.aux_labels, best_path.data)
|
||||
aux_labels = k2.index_select(
|
||||
lattice.aux_labels, best_path.values
|
||||
)
|
||||
else:
|
||||
aux_labels, _ = lattice.aux_labels.index(
|
||||
indexes=best_path.data, axis=0, need_value_indexes=False
|
||||
indexes=best_path.values, axis=0, need_value_indexes=False
|
||||
)
|
||||
|
||||
best_path_fsa = k2.linear_fsa(labels)
|
||||
|
@ -207,7 +207,7 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
|
||||
# remove the states and arcs axes.
|
||||
aux_shape = aux_shape.remove_axis(1)
|
||||
aux_shape = aux_shape.remove_axis(1)
|
||||
aux_labels = k2.RaggedTensor(aux_shape, aux_labels.data)
|
||||
aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
|
||||
else:
|
||||
# remove axis corresponding to states.
|
||||
aux_shape = best_paths.arcs.shape().remove_axis(1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user