From 8ac632cec8c129af1c106d36d1abd895305cfb7d Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 14 Sep 2021 10:54:07 +0800 Subject: [PATCH] Fix code style --- icefall/decode.py | 25 ++++++++++++++++--------- icefall/utils.py | 2 +- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/icefall/decode.py b/icefall/decode.py index 3f6e5fc84..dfac5700e 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -221,7 +221,8 @@ def nbest_decoding( if isinstance(lattice.aux_labels, torch.Tensor): word_seq = k2.ragged.index(lattice.aux_labels, path) else: - word_seq = lattice.aux_labels.index(path, remove_axis=True) + word_seq = lattice.aux_labels.index(path) + word_seq = word_seq.remove_axis(1) # Remove 0 (epsilon) and -1 from word_seq word_seq = word_seq.remove_values_leq(0) @@ -300,7 +301,7 @@ def nbest_decoding( # lattice.aux_labels is a k2.RaggedTensor with 2 axes, so # aux_labels is also a k2.RaggedTensor with 2 axes aux_labels, _ = lattice.aux_labels.index( - indexes=best_path.data, axis=0, need_value_indexes=False + indexes=best_path.values, axis=0, need_value_indexes=False ) best_path_fsa = k2.linear_fsa(labels) @@ -430,7 +431,8 @@ def rescore_with_n_best_list( if isinstance(lattice.aux_labels, torch.Tensor): word_seq = k2.ragged.index(lattice.aux_labels, path) else: - word_seq = lattice.aux_labels.index(path, remove_axis=True) + word_seq = lattice.aux_labels.index(path) + word_seq = word_seq.remove_axis(1) # Remove epsilons and -1 from word_seq word_seq = word_seq.remove_values_leq(0) @@ -520,7 +522,7 @@ def rescore_with_n_best_list( # aux_labels is also a k2.RaggedTensor with 2 axes aux_labels, _ = lattice.aux_labels.index( - indexes=best_path.data, axis=0, need_value_indexes=False + indexes=best_path.values, axis=0, need_value_indexes=False ) best_path_fsa = k2.linear_fsa(labels) @@ -666,7 +668,8 @@ def nbest_oracle( if isinstance(lattice.aux_labels, torch.Tensor): word_seq = k2.ragged.index(lattice.aux_labels, path) else: - word_seq = lattice.aux_labels.index(path, remove_axis=True) + word_seq = lattice.aux_labels.index(path) + word_seq = word_seq.remove_axis(1) word_seq = word_seq.remove_values_leq(0) unique_word_seq, _, _ = word_seq.unique( @@ -757,7 +760,8 @@ def rescore_with_attention_decoder( if isinstance(lattice.aux_labels, torch.Tensor): word_seq = k2.ragged.index(lattice.aux_labels, path) else: - word_seq = lattice.aux_labels.index(path, remove_axis=True) + word_seq = lattice.aux_labels.index(path) + word_seq = word_seq.remove_axis(1) # Remove epsilons and -1 from word_seq word_seq = word_seq.remove_values_leq(0) @@ -810,7 +814,8 @@ def rescore_with_attention_decoder( if isinstance(lattice.tokens, torch.Tensor): token_seq = k2.ragged.index(lattice.tokens, path) else: - token_seq = lattice.tokens.index(path, remove_axis=True) + token_seq = lattice.tokens.index(path) + token_seq = token_seq.remove_axis(1) # Remove epsilons and -1 from token_seq token_seq = token_seq.remove_values_leq(0) @@ -890,10 +895,12 @@ def rescore_with_attention_decoder( labels = labels.remove_values_eq(-1) if isinstance(lattice.aux_labels, torch.Tensor): - aux_labels = k2.index_select(lattice.aux_labels, best_path.data) + aux_labels = k2.index_select( + lattice.aux_labels, best_path.values + ) else: aux_labels, _ = lattice.aux_labels.index( - indexes=best_path.data, axis=0, need_value_indexes=False + indexes=best_path.values, axis=0, need_value_indexes=False ) best_path_fsa = k2.linear_fsa(labels) diff --git a/icefall/utils.py b/icefall/utils.py index 1130d8947..cc658ae32 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -207,7 +207,7 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]: # remove the states and arcs axes. aux_shape = aux_shape.remove_axis(1) aux_shape = aux_shape.remove_axis(1) - aux_labels = k2.RaggedTensor(aux_shape, aux_labels.data) + aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values) else: # remove axis corresponding to states. aux_shape = best_paths.arcs.shape().remove_axis(1)