Fix code style

This commit is contained in:
pkufool 2021-09-14 10:54:07 +08:00
parent 7f8e3a673a
commit 8ac632cec8
2 changed files with 17 additions and 10 deletions

View File

@@ -221,7 +221,8 @@ def nbest_decoding(
     if isinstance(lattice.aux_labels, torch.Tensor):
         word_seq = k2.ragged.index(lattice.aux_labels, path)
     else:
-        word_seq = lattice.aux_labels.index(path, remove_axis=True)
+        word_seq = lattice.aux_labels.index(path)
+        word_seq = word_seq.remove_axis(1)

     # Remove 0 (epsilon) and -1 from word_seq
     word_seq = word_seq.remove_values_leq(0)
@@ -300,7 +301,7 @@ def nbest_decoding(
     # lattice.aux_labels is a k2.RaggedTensor with 2 axes, so
     # aux_labels is also a k2.RaggedTensor with 2 axes
     aux_labels, _ = lattice.aux_labels.index(
-        indexes=best_path.data, axis=0, need_value_indexes=False
+        indexes=best_path.values, axis=0, need_value_indexes=False
     )

     best_path_fsa = k2.linear_fsa(labels)
@@ -430,7 +431,8 @@ def rescore_with_n_best_list(
     if isinstance(lattice.aux_labels, torch.Tensor):
         word_seq = k2.ragged.index(lattice.aux_labels, path)
     else:
-        word_seq = lattice.aux_labels.index(path, remove_axis=True)
+        word_seq = lattice.aux_labels.index(path)
+        word_seq = word_seq.remove_axis(1)

     # Remove epsilons and -1 from word_seq
     word_seq = word_seq.remove_values_leq(0)
@@ -520,7 +522,7 @@ def rescore_with_n_best_list(
     # aux_labels is also a k2.RaggedTensor with 2 axes
     aux_labels, _ = lattice.aux_labels.index(
-        indexes=best_path.data, axis=0, need_value_indexes=False
+        indexes=best_path.values, axis=0, need_value_indexes=False
     )

     best_path_fsa = k2.linear_fsa(labels)
@@ -666,7 +668,8 @@ def nbest_oracle(
     if isinstance(lattice.aux_labels, torch.Tensor):
         word_seq = k2.ragged.index(lattice.aux_labels, path)
     else:
-        word_seq = lattice.aux_labels.index(path, remove_axis=True)
+        word_seq = lattice.aux_labels.index(path)
+        word_seq = word_seq.remove_axis(1)

     word_seq = word_seq.remove_values_leq(0)
     unique_word_seq, _, _ = word_seq.unique(
@@ -757,7 +760,8 @@ def rescore_with_attention_decoder(
     if isinstance(lattice.aux_labels, torch.Tensor):
         word_seq = k2.ragged.index(lattice.aux_labels, path)
     else:
-        word_seq = lattice.aux_labels.index(path, remove_axis=True)
+        word_seq = lattice.aux_labels.index(path)
+        word_seq = word_seq.remove_axis(1)

     # Remove epsilons and -1 from word_seq
     word_seq = word_seq.remove_values_leq(0)
@@ -810,7 +814,8 @@ def rescore_with_attention_decoder(
     if isinstance(lattice.tokens, torch.Tensor):
         token_seq = k2.ragged.index(lattice.tokens, path)
     else:
-        token_seq = lattice.tokens.index(path, remove_axis=True)
+        token_seq = lattice.tokens.index(path)
+        token_seq = token_seq.remove_axis(1)

     # Remove epsilons and -1 from token_seq
     token_seq = token_seq.remove_values_leq(0)
@@ -890,10 +895,12 @@ def rescore_with_attention_decoder(
     labels = labels.remove_values_eq(-1)

     if isinstance(lattice.aux_labels, torch.Tensor):
-        aux_labels = k2.index_select(lattice.aux_labels, best_path.data)
+        aux_labels = k2.index_select(
+            lattice.aux_labels, best_path.values
+        )
     else:
         aux_labels, _ = lattice.aux_labels.index(
-            indexes=best_path.data, axis=0, need_value_indexes=False
+            indexes=best_path.values, axis=0, need_value_indexes=False
         )

     best_path_fsa = k2.linear_fsa(labels)

View File

@@ -207,7 +207,7 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
         # remove the states and arcs axes.
         aux_shape = aux_shape.remove_axis(1)
         aux_shape = aux_shape.remove_axis(1)
-        aux_labels = k2.RaggedTensor(aux_shape, aux_labels.data)
+        aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
     else:
         # remove axis corresponding to states.
         aux_shape = best_paths.arcs.shape().remove_axis(1)