diff --git a/.gitignore b/.gitignore
index 839a1c34a..e6c84ca5e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,4 @@ path.sh
 exp
 exp*/
 *.pt
-download/
+download
diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index ff6374d73..cfdcff756 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -45,6 +45,7 @@
 from icefall.utils import (
     get_texts,
     setup_logger,
     store_transcripts,
+    str2bool,
     write_error_stats,
 )
@@ -116,6 +117,17 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--export",
+        type=str2bool,
+        default=False,
+        help="""When enabled, the averaged model is saved to
+        conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
+        pretrained.pt contains a dict {"model": model.state_dict()},
+        which can be loaded by `icefall.checkpoint.load_checkpoint()`.
+        """,
+    )
+
     return parser


@@ -541,6 +553,13 @@ def main():
         logging.info(f"averaging {filenames}")
         model.load_state_dict(average_checkpoints(filenames))

+    if params.export:
+        logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
+        return
+
     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py
index 19a1ddd23..407fb7d88 100755
--- a/egs/librispeech/ASR/local/compile_hlg.py
+++ b/egs/librispeech/ASR/local/compile_hlg.py
@@ -102,14 +102,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:

     LG.labels[LG.labels >= first_token_disambig_id] = 0

-    assert isinstance(LG.aux_labels, k2.RaggedInt)
-    LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0
+    assert isinstance(LG.aux_labels, k2.RaggedTensor)
+    LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0

     LG = k2.remove_epsilon(LG)
     logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

     LG = k2.connect(LG)
-    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)
+    LG.aux_labels = LG.aux_labels.remove_values_eq(0)

     logging.info("Arc sorting LG")
     LG = k2.arc_sort(LG)
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index afdebd12b..87e9cddb4 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -99,8 +99,10 @@ def get_params() -> AttributeDict:
             # - nbest-rescoring
             # - whole-lattice-rescoring
             "method": "whole-lattice-rescoring",
+            # "method": "1best",
+            # "method": "nbest",
             # num_paths is used when method is "nbest" and "nbest-rescoring"
-            "num_paths": 30,
+            "num_paths": 100,
         }
     )
     return params
@@ -424,6 +426,7 @@ def main():
         torch.save(
             {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
         )
+        return

     model.to(device)
     model.eval()
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
old mode 100644
new mode 100755
diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py
index f2fafd013..41a927455 100755
--- a/egs/yesno/ASR/local/compile_hlg.py
+++ b/egs/yesno/ASR/local/compile_hlg.py
@@ -80,14 +80,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:

     LG.labels[LG.labels >= first_token_disambig_id] = 0

-    assert isinstance(LG.aux_labels, k2.RaggedInt)
-    LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0
+    assert isinstance(LG.aux_labels, k2.RaggedTensor)
+    LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0

     LG = k2.remove_epsilon(LG)
     logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

     LG = k2.connect(LG)
-    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)
+    LG.aux_labels = LG.aux_labels.remove_values_eq(0)

     logging.info("Arc sorting LG")
     LG = k2.arc_sort(LG)
diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py
index aa7b07b98..54fdbb3cc 100755
--- a/egs/yesno/ASR/tdnn/decode.py
+++ b/egs/yesno/ASR/tdnn/decode.py
@@ -296,6 +296,7 @@ def main():
         torch.save(
             {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
         )
+        return

     model.to(device)
     model.eval()
diff --git a/icefall/decode.py b/icefall/decode.py
index de3219401..3f6e5fc84 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -84,8 +84,8 @@ def _intersect_device(
     for start, end in splits:
         indexes = torch.arange(start, end).to(b_to_a_map)

-        fsas = k2.index(b_fsas, indexes)
-        b_to_a = k2.index(b_to_a_map, indexes)
+        fsas = k2.index_fsa(b_fsas, indexes)
+        b_to_a = k2.index_select(b_to_a_map, indexes)
         path_lattice = k2.intersect_device(
             a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
         )
@@ -215,18 +215,16 @@ def nbest_decoding(
         scale=scale,
     )

-    # word_seq is a k2.RaggedInt sharing the same shape as `path`
+    # word_seq is a k2.RaggedTensor sharing the same shape as `path`
    # but it contains word IDs. Note that it also contains 0s and -1s.
     # The last entry in each sublist is -1.
-    word_seq = k2.index(lattice.aux_labels, path)
-    # Note: the above operation supports also the case when
-    # lattice.aux_labels is a ragged tensor. In that case,
-    # `remove_axis=True` is used inside the pybind11 binding code,
-    # so the resulting `word_seq` still has 3 axes, like `path`.
-    # The 3 axes are [seq][path][word_id]
+    if isinstance(lattice.aux_labels, torch.Tensor):
+        word_seq = k2.ragged.index(lattice.aux_labels, path)
+    else:
+        word_seq = lattice.aux_labels.index(path, remove_axis=True)

     # Remove 0 (epsilon) and -1 from word_seq
-    word_seq = k2.ragged.remove_values_leq(word_seq, 0)
+    word_seq = word_seq.remove_values_leq(0)

     # Remove sequences with identical word sequences.
     #
@@ -234,12 +232,12 @@ def nbest_decoding(
     # `new2old` is a 1-D torch.Tensor mapping from the output path index
     # to the input path index.
     # new2old.numel() == unique_word_seqs.tot_size(1)
-    unique_word_seq, _, new2old = k2.ragged.unique_sequences(
-        word_seq, need_num_repeats=False, need_new2old_indexes=True
+    unique_word_seq, _, new2old = word_seq.unique(
+        need_num_repeats=False, need_new2old_indexes=True
     )

     # Note: unique_word_seq still has the same axes as word_seq
-    seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)
+    seq_to_path_shape = unique_word_seq.shape.get_layer(0)

     # path_to_seq_map is a 1-D torch.Tensor.
     # path_to_seq_map[i] is the seq to which the i-th path belongs
@@ -247,7 +245,7 @@ def nbest_decoding(

     # Remove the seq axis.
     # Now unique_word_seq has only two axes [path][word]
-    unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)
+    unique_word_seq = unique_word_seq.remove_axis(0)

     # word_fsa is an FsaVec with axes [path][state][arc]
     word_fsa = k2.linear_fsa(unique_word_seq)
@@ -275,35 +273,35 @@ def nbest_decoding(
         use_double_scores=use_double_scores, log_semiring=False
     )

-    # RaggedFloat currently supports float32 only.
-    # If Ragged is wrapped, we can use k2.RaggedDouble here
-    ragged_tot_scores = k2.RaggedFloat(
-        seq_to_path_shape, tot_scores.to(torch.float32)
-    )
+    ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores)

-    argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
+    argmax_indexes = ragged_tot_scores.argmax()

     # Since we invoked `k2.ragged.unique_sequences`, which reorders
     # the index from `path`, we use `new2old` here to convert argmax_indexes
     # to the indexes into `path`.
     #
     # Use k2.index here since argmax_indexes' dtype is torch.int32
-    best_path_indexes = k2.index(new2old, argmax_indexes)
+    best_path_indexes = k2.index_select(new2old, argmax_indexes)

-    path_2axes = k2.ragged.remove_axis(path, 0)
+    path_2axes = path.remove_axis(0)

-    # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
-    best_path = k2.index(path_2axes, best_path_indexes)
+    # best_path is a k2.RaggedTensor with 2 axes [path][arc_pos]
+    best_path, _ = path_2axes.index(
+        indexes=best_path_indexes, axis=0, need_value_indexes=False
+    )

-    # labels is a k2.RaggedInt with 2 axes [path][token_id]
+    # labels is a k2.RaggedTensor with 2 axes [path][token_id]
     # Note that it contains -1s.
-    labels = k2.index(lattice.labels.contiguous(), best_path)
+    labels = k2.ragged.index(lattice.labels.contiguous(), best_path)

-    labels = k2.ragged.remove_values_eq(labels, -1)
+    labels = labels.remove_values_eq(-1)

-    # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
-    # aux_labels is also a k2.RaggedInt with 2 axes
-    aux_labels = k2.index(lattice.aux_labels, best_path.values())
+    # lattice.aux_labels is a k2.RaggedTensor with 2 axes, so
+    # aux_labels is also a k2.RaggedTensor with 2 axes
+    aux_labels, _ = lattice.aux_labels.index(
+        indexes=best_path.data, axis=0, need_value_indexes=False
+    )

     best_path_fsa = k2.linear_fsa(labels)
     best_path_fsa.aux_labels = aux_labels
@@ -426,33 +424,36 @@ def rescore_with_n_best_list(
         scale=scale,
     )

-    # word_seq is a k2.RaggedInt sharing the same shape as `path`
+    # word_seq is a k2.RaggedTensor sharing the same shape as `path`
     # but it contains word IDs. Note that it also contains 0s and -1s.
     # The last entry in each sublist is -1.
-    word_seq = k2.index(lattice.aux_labels, path)
+    if isinstance(lattice.aux_labels, torch.Tensor):
+        word_seq = k2.ragged.index(lattice.aux_labels, path)
+    else:
+        word_seq = lattice.aux_labels.index(path, remove_axis=True)

     # Remove epsilons and -1 from word_seq
-    word_seq = k2.ragged.remove_values_leq(word_seq, 0)
+    word_seq = word_seq.remove_values_leq(0)

     # Remove paths that has identical word sequences.
     #
-    # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word]
+    # unique_word_seq is still a k2.RaggedTensor with 3 axes [seq][path][word]
     # except that there are no repeated paths with the same word_seq
     # within a sequence.
     #
-    # num_repeats is also a k2.RaggedInt with 2 axes containing the
+    # num_repeats is also a k2.RaggedTensor with 2 axes containing the
     # multiplicities of each path.
-    # num_repeats.num_elements() == unique_word_seqs.tot_size(1)
+    # num_repeats.numel() == unique_word_seqs.tot_size(1)
     #
     # Since k2.ragged.unique_sequences will reorder paths within a seq,
     # `new2old` is a 1-D torch.Tensor mapping from the output path index
     # to the input path index.
     # new2old.numel() == unique_word_seqs.tot_size(1)
-    unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
-        word_seq, need_num_repeats=True, need_new2old_indexes=True
+    unique_word_seq, num_repeats, new2old = word_seq.unique(
+        need_num_repeats=True, need_new2old_indexes=True
     )

-    seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)
+    seq_to_path_shape = unique_word_seq.shape.get_layer(0)

     # path_to_seq_map is a 1-D torch.Tensor.
     # path_to_seq_map[i] is the seq to which the i-th path
@@ -461,7 +462,7 @@ def rescore_with_n_best_list(

     # Remove the seq axis.
     # Now unique_word_seq has only two axes [path][word]
-    unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)
+    unique_word_seq = unique_word_seq.remove_axis(0)

     # word_fsa is an FsaVec with axes [path][state][arc]
     word_fsa = k2.linear_fsa(unique_word_seq)
@@ -485,39 +486,42 @@ def rescore_with_n_best_list(
         use_double_scores=True, log_semiring=False
     )

-    path_2axes = k2.ragged.remove_axis(path, 0)
+    path_2axes = path.remove_axis(0)

     ans = dict()
     for lm_scale in lm_scale_list:
         tot_scores = am_scores / lm_scale + lm_scores

-        # Remember that we used `k2.ragged.unique_sequences` to remove repeated
+        # Remember that we used `k2.RaggedTensor.unique` to remove repeated
         # paths to avoid redundant computation in `k2.intersect_device`.
         # Now we use `num_repeats` to correct the scores for each path.
         #
         # NOTE(fangjun): It is commented out as it leads to a worse WER
         # tot_scores = tot_scores * num_repeats.values()

-        ragged_tot_scores = k2.RaggedFloat(
-            seq_to_path_shape, tot_scores.to(torch.float32)
-        )
-        argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
+        ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores)
+        argmax_indexes = ragged_tot_scores.argmax()

         # Use k2.index here since argmax_indexes' dtype is torch.int32
-        best_path_indexes = k2.index(new2old, argmax_indexes)
+        best_path_indexes = k2.index_select(new2old, argmax_indexes)

         # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
-        best_path = k2.index(path_2axes, best_path_indexes)
+        best_path, _ = path_2axes.index(
+            indexes=best_path_indexes, axis=0, need_value_indexes=False
+        )

-        # labels is a k2.RaggedInt with 2 axes [path][phone_id]
+        # labels is a k2.RaggedTensor with 2 axes [path][phone_id]
         # Note that it contains -1s.
-        labels = k2.index(lattice.labels.contiguous(), best_path)
+        labels = k2.ragged.index(lattice.labels.contiguous(), best_path)

-        labels = k2.ragged.remove_values_eq(labels, -1)
+        labels = labels.remove_values_eq(-1)

-        # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
-        # aux_labels is also a k2.RaggedInt with 2 axes
-        aux_labels = k2.index(lattice.aux_labels, best_path.values())
+        # lattice.aux_labels is a k2.RaggedTensor tensor with 2 axes, so
+        # aux_labels is also a k2.RaggedTensor with 2 axes
+
+        aux_labels, _ = lattice.aux_labels.index(
+            indexes=best_path.data, axis=0, need_value_indexes=False
+        )

         best_path_fsa = k2.linear_fsa(labels)
         best_path_fsa.aux_labels = aux_labels
@@ -659,12 +663,16 @@ def nbest_oracle(
         scale=scale,
     )

-    word_seq = k2.index(lattice.aux_labels, path)
-    word_seq = k2.ragged.remove_values_leq(word_seq, 0)
-    unique_word_seq, _, _ = k2.ragged.unique_sequences(
-        word_seq, need_num_repeats=False, need_new2old_indexes=False
+    if isinstance(lattice.aux_labels, torch.Tensor):
+        word_seq = k2.ragged.index(lattice.aux_labels, path)
+    else:
+        word_seq = lattice.aux_labels.index(path, remove_axis=True)
+
+    word_seq = word_seq.remove_values_leq(0)
+    unique_word_seq, _, _ = word_seq.unique(
+        need_num_repeats=False, need_new2old_indexes=False
     )
-    unique_word_ids = k2.ragged.to_list(unique_word_seq)
+    unique_word_ids = unique_word_seq.tolist()

     assert len(unique_word_ids) == len(ref_texts)
     # unique_word_ids[i] contains all hypotheses of the i-th utterance
@@ -743,33 +751,36 @@ def rescore_with_attention_decoder(
         scale=scale,
     )

-    # word_seq is a k2.RaggedInt sharing the same shape as `path`
+    # word_seq is a k2.RaggedTensor sharing the same shape as `path`
     # but it contains word IDs. Note that it also contains 0s and -1s.
     # The last entry in each sublist is -1.
-    word_seq = k2.index(lattice.aux_labels, path)
+    if isinstance(lattice.aux_labels, torch.Tensor):
+        word_seq = k2.ragged.index(lattice.aux_labels, path)
+    else:
+        word_seq = lattice.aux_labels.index(path, remove_axis=True)

     # Remove epsilons and -1 from word_seq
-    word_seq = k2.ragged.remove_values_leq(word_seq, 0)
+    word_seq = word_seq.remove_values_leq(0)

     # Remove paths that has identical word sequences.
     #
-    # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word]
+    # unique_word_seq is still a k2.RaggedTensor with 3 axes [seq][path][word]
     # except that there are no repeated paths with the same word_seq
     # within a sequence.
     #
-    # num_repeats is also a k2.RaggedInt with 2 axes containing the
+    # num_repeats is also a k2.RaggedTensor with 2 axes containing the
     # multiplicities of each path.
-    # num_repeats.num_elements() == unique_word_seqs.tot_size(1)
+    # num_repeats.numel() == unique_word_seqs.tot_size(1)
     #
     # Since k2.ragged.unique_sequences will reorder paths within a seq,
     # `new2old` is a 1-D torch.Tensor mapping from the output path index
     # to the input path index.
     # new2old.numel() == unique_word_seq.tot_size(1)
-    unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
-        word_seq, need_num_repeats=True, need_new2old_indexes=True
+    unique_word_seq, num_repeats, new2old = word_seq.unique(
+        need_num_repeats=True, need_new2old_indexes=True
    )

-    seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)
+    seq_to_path_shape = unique_word_seq.shape.get_layer(0)

     # path_to_seq_map is a 1-D torch.Tensor.
     # path_to_seq_map[i] is the seq to which the i-th path
@@ -778,7 +789,7 @@ def rescore_with_attention_decoder(

     # Remove the seq axis.
     # Now unique_word_seq has only two axes [path][word]
-    unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)
+    unique_word_seq = unique_word_seq.remove_axis(0)

     # word_fsa is an FsaVec with axes [path][state][arc]
     word_fsa = k2.linear_fsa(unique_word_seq)
@@ -796,20 +807,23 @@ def rescore_with_attention_decoder(

     # CAUTION: The "tokens" attribute is set in the file
     # local/compile_hlg.py
-    token_seq = k2.index(lattice.tokens, path)
+    if isinstance(lattice.tokens, torch.Tensor):
+        token_seq = k2.ragged.index(lattice.tokens, path)
+    else:
+        token_seq = lattice.tokens.index(path, remove_axis=True)

     # Remove epsilons and -1 from token_seq
-    token_seq = k2.ragged.remove_values_leq(token_seq, 0)
+    token_seq = token_seq.remove_values_leq(0)

     # Remove the seq axis.
-    token_seq = k2.ragged.remove_axis(token_seq, 0)
+    token_seq = token_seq.remove_axis(0)

-    token_seq, _ = k2.ragged.index(
-        token_seq, indexes=new2old, axis=0, need_value_indexes=False
+    token_seq, _ = token_seq.index(
+        indexes=new2old, axis=0, need_value_indexes=False
     )

     # Now word in unique_word_seq has its corresponding token IDs.
-    token_ids = k2.ragged.to_list(token_seq)
+    token_ids = token_seq.tolist()

     num_word_seqs = new2old.numel()

@@ -849,7 +863,7 @@ def rescore_with_attention_decoder(
     else:
         attention_scale_list = [attention_scale]

-    path_2axes = k2.ragged.remove_axis(path, 0)
+    path_2axes = path.remove_axis(0)

     ans = dict()
     for n_scale in ngram_lm_scale_list:
@@ -859,23 +873,28 @@ def rescore_with_attention_decoder(
                 + n_scale * ngram_lm_scores
                 + a_scale * attention_scores
             )
-            ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape, tot_scores)
-            argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
+            ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores)
+            argmax_indexes = ragged_tot_scores.argmax()

-            best_path_indexes = k2.index(new2old, argmax_indexes)
+            best_path_indexes = k2.index_select(new2old, argmax_indexes)

             # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
-            best_path = k2.index(path_2axes, best_path_indexes)
+            best_path, _ = path_2axes.index(
+                indexes=best_path_indexes, axis=0, need_value_indexes=False
+            )

-            # labels is a k2.RaggedInt with 2 axes [path][token_id]
+            # labels is a k2.RaggedTensor with 2 axes [path][token_id]
             # Note that it contains -1s.
-            labels = k2.index(lattice.labels.contiguous(), best_path)
+            labels = k2.ragged.index(lattice.labels.contiguous(), best_path)

-            labels = k2.ragged.remove_values_eq(labels, -1)
+            labels = labels.remove_values_eq(-1)

-            # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
-            # aux_labels is also a k2.RaggedInt with 2 axes
-            aux_labels = k2.index(lattice.aux_labels, best_path.values())
+            if isinstance(lattice.aux_labels, torch.Tensor):
+                aux_labels = k2.index_select(lattice.aux_labels, best_path.data)
+            else:
+                aux_labels, _ = lattice.aux_labels.index(
+                    indexes=best_path.data, axis=0, need_value_indexes=False
+                )

             best_path_fsa = k2.linear_fsa(labels)
             best_path_fsa.aux_labels = aux_labels
diff --git a/icefall/lexicon.py b/icefall/lexicon.py
index f1127c7cf..6730bac49 100644
--- a/icefall/lexicon.py
+++ b/icefall/lexicon.py
@@ -157,7 +157,7 @@ class BpeLexicon(Lexicon):
             lang_dir / "lexicon.txt"
         )

-    def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedInt:
+    def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedTensor:
         """Read a BPE lexicon from file and convert it
         to a k2 ragged tensor.

@@ -200,19 +200,18 @@ class BpeLexicon(Lexicon):
         )

         values = torch.tensor(token_ids, dtype=torch.int32)

-        return k2.RaggedInt(shape, values)
+        return k2.RaggedTensor(shape, values)

-    def words_to_piece_ids(self, words: List[str]) -> k2.RaggedInt:
+    def words_to_piece_ids(self, words: List[str]) -> k2.RaggedTensor:
         """Convert a list of words to a ragged tensor contained word piece IDs.
         """
         word_ids = [self.word_table[w] for w in words]
         word_ids = torch.tensor(word_ids, dtype=torch.int32)

-        ragged, _ = k2.ragged.index(
-            self.ragged_lexicon,
+        ragged, _ = self.ragged_lexicon.index(
             indexes=word_ids,
-            need_value_indexes=False,
             axis=0,
+            need_value_indexes=False,
         )
         return ragged
diff --git a/icefall/utils.py b/icefall/utils.py
index 2994c2d47..b78dfe8ad 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -199,26 +199,25 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
       Returns a list of lists of int, containing the label sequences we
       decoded.
     """
-    if isinstance(best_paths.aux_labels, k2.RaggedInt):
+    if isinstance(best_paths.aux_labels, k2.RaggedTensor):
         # remove 0's and -1's.
-        aux_labels = k2r.remove_values_leq(best_paths.aux_labels, 0)
-        aux_shape = k2r.compose_ragged_shapes(
-            best_paths.arcs.shape(), aux_labels.shape()
-        )
+        aux_labels = best_paths.aux_labels.remove_values_leq(0)
+        # TODO: change arcs.shape() to arcs.shape
+        aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)

         # remove the states and arcs axes.
-        aux_shape = k2r.remove_axis(aux_shape, 1)
-        aux_shape = k2r.remove_axis(aux_shape, 1)
-        aux_labels = k2.RaggedInt(aux_shape, aux_labels.values())
+        aux_shape = aux_shape.remove_axis(1)
+        aux_shape = aux_shape.remove_axis(1)
+        aux_labels = k2.RaggedTensor(aux_shape, aux_labels.data)
     else:
         # remove axis corresponding to states.
-        aux_shape = k2r.remove_axis(best_paths.arcs.shape(), 1)
-        aux_labels = k2.RaggedInt(aux_shape, best_paths.aux_labels)
+        aux_shape = best_paths.arcs.shape().remove_axis(1)
+        aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
         # remove 0's and -1's.
-        aux_labels = k2r.remove_values_leq(aux_labels, 0)
+        aux_labels = aux_labels.remove_values_leq(0)

-    assert aux_labels.num_axes() == 2
-    return k2r.to_list(aux_labels)
+    assert aux_labels.num_axes == 2
+    return aux_labels.tolist()


 def store_transcripts(
diff --git a/test/test_utils.py b/test/test_utils.py
index 2dd79689f..b4c9358fd 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -60,7 +60,7 @@ def test_get_texts_ragged():
             4
         """
     )
-    fsa1.aux_labels = k2.RaggedInt("[ [1 3 0 2] [] [4 0 1] [-1]]")
+    fsa1.aux_labels = k2.RaggedTensor("[ [1 3 0 2] [] [4 0 1] [-1]]")

     fsa2 = k2.Fsa.from_str(
         """
@@ -70,7 +70,7 @@ def test_get_texts_ragged():
             3
         """
     )
-    fsa2.aux_labels = k2.RaggedInt("[[3 0 5 0 8] [0 9 7 0] [-1]]")
+    fsa2.aux_labels = k2.RaggedTensor("[[3 0 5 0 8] [0 9 7 0] [-1]]")
     fsas = k2.Fsa.from_fsas([fsa1, fsa2])
     texts = get_texts(fsas)
     assert texts == [[1, 3, 2, 4, 1], [3, 5, 8, 9, 7]]