diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index 39a6a0e80..b4e266672 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -56,7 +56,7 @@ jobs: run: | python3 -m pip install --upgrade pip black flake8 python3 -m pip install -U pip - python3 -m pip install k2==1.4.dev20210822+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/ + python3 -m pip install k2==1.7.dev20210908+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/ python3 -m pip install torchaudio==0.7.2 python3 -m pip install git+https://github.com/lhotse-speech/lhotse diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9110e7db4..c853e3de1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,7 +32,8 @@ jobs: os: [ubuntu-18.04, macos-10.15] python-version: [3.6, 3.7, 3.8, 3.9] torch: ["1.8.1"] - k2-version: ["1.4.dev20210822"] + k2-version: ["1.7.dev20210908"] + fail-fast: false steps: diff --git a/.gitignore b/.gitignore index 839a1c34a..e6c84ca5e 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,4 @@ path.sh exp exp*/ *.pt -download/ +download diff --git a/docs/source/conf.py b/docs/source/conf.py index f97f72d54..599df8b3e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -16,7 +16,6 @@ import sphinx_rtd_theme - # -- Project information ----------------------------------------------------- project = "icefall" diff --git a/docs/source/installation/images/device-CPU_CUDA-orange.svg b/docs/source/installation/images/device-CPU_CUDA-orange.svg index b760102e3..a023a1283 100644 --- a/docs/source/installation/images/device-CPU_CUDA-orange.svg +++ b/docs/source/installation/images/device-CPU_CUDA-orange.svg @@ -1 +1 @@ -device: CPU | CUDAdeviceCPU | CUDA \ No newline at end of file +device: CPU | CUDAdeviceCPU | CUDA diff --git a/docs/source/installation/images/k2-v-1.7.svg b/docs/source/installation/images/k2-v-1.7.svg new file mode 100644 index 000000000..8a74d0b55 --- /dev/null +++ b/docs/source/installation/images/k2-v-1.7.svg @@ -0,0 +1 @@ +k2: >= v1.7k2>= v1.7 diff --git a/docs/source/installation/images/os-Linux_macOS-ff69b4.svg b/docs/source/installation/images/os-Linux_macOS-ff69b4.svg index 44c112747..178813ed4 100644 --- a/docs/source/installation/images/os-Linux_macOS-ff69b4.svg +++ b/docs/source/installation/images/os-Linux_macOS-ff69b4.svg @@ -1 +1 @@ -os: Linux | macOSosLinux | macOS \ No newline at end of file +os: Linux | macOSosLinux | macOS diff --git a/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg b/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg index 676feba2c..befc1e19e 100644 --- a/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg +++ b/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg @@ -1 +1 @@ -python: 3.6 | 3.7 | 3.8 | 3.9python3.6 | 3.7 | 3.8 | 3.9 \ No newline at end of file +python: 3.6 | 3.7 | 3.8 | 3.9python3.6 | 3.7 | 3.8 | 3.9 diff --git a/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg b/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg index d9b0efe1a..496e5a9ef 100644 --- a/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg +++ b/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg @@ -1 +1 @@ -torch: 1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0torch1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0 \ No newline at end of file +torch: 1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0torch1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0 diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst index bcef669c8..c11cbd1be 100644 --- a/docs/source/installation/index.rst +++ b/docs/source/installation/index.rst @@ -7,6 +7,7 @@ Installation - |device| - |python_versions| - |torch_versions| +- |k2_versions| .. |os| image:: ./images/os-Linux_macOS-ff69b4.svg :alt: Supported operating systems @@ -20,7 +21,10 @@ Installation .. |torch_versions| image:: ./images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg :alt: Supported PyTorch versions -icefall depends on `k2 `_ and +.. |k2_versions| image:: ./images/k2-v-1.7.svg + :alt: Supported k2 versions + +``icefall`` depends on `k2 `_ and `lhotse `_. We recommend you to install ``k2`` first, as ``k2`` is bound to @@ -32,12 +36,16 @@ installs its dependency PyTorch, which can be reused by ``lhotse``. -------------- Please refer to ``_ -to install `k2`. +to install ``k2``. + +.. CAUTION:: + + You need to install ``k2`` with a version at least **v1.7**. .. HINT:: If you have already installed PyTorch and don't want to replace it, - please install a version of k2 that is compiled against the version + please install a version of ``k2`` that is compiled against the version of PyTorch you are using. (2) Install lhotse @@ -50,10 +58,15 @@ to install ``lhotse``. Install ``lhotse`` also installs its dependency `torchaudio `_. +.. CAUTION:: + + If you have installed ``torchaudio``, please consider uninstalling it before + installing ``lhotse``. Otherwise, it may update your already installed PyTorch. + (3) Download icefall -------------------- -icefall is a collection of Python scripts, so you don't need to install it +``icefall`` is a collection of Python scripts, so you don't need to install it and we don't provide a ``setup.py`` to install it. What you need is to download it and set the environment variable ``PYTHONPATH`` @@ -367,7 +380,7 @@ Now let us run the training part: .. CAUTION:: - We use ``export CUDA_VISIBLE_DEVICES=""`` so that icefall uses CPU + We use ``export CUDA_VISIBLE_DEVICES=""`` so that ``icefall`` uses CPU even if there are GPUs available. The training log is given below: diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst index db34fdca5..36f8dfc39 100644 --- a/docs/source/recipes/index.rst +++ b/docs/source/recipes/index.rst @@ -15,4 +15,3 @@ We may add recipes for other tasks as well in the future. yesno librispeech - diff --git a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst index a59f34db7..64f0a6a08 100644 --- a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst +++ b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst @@ -209,7 +209,7 @@ After downloading, you will have the following files: |-- 1221-135766-0001.flac |-- 1221-135766-0002.flac `-- trans.txt - + 6 directories, 10 files @@ -256,14 +256,14 @@ The output is: 2021-08-24 16:57:28,098 INFO [pretrained.py:266] ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac: AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS - + ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac: GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN - + ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac: YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION - - + + 2021-08-24 16:57:28,099 INFO [pretrained.py:268] Decoding Done @@ -297,14 +297,14 @@ The decoding output is: 2021-08-24 16:39:54,010 INFO [pretrained.py:266] ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac: AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS - + ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac: GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN - + ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac: YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION - - + + 2021-08-24 16:39:54,010 INFO [pretrained.py:268] Decoding Done diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index dfc412672..d4acf9206 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -43,4 +43,3 @@ We searched the lm_score_scale for best results, the scales that produced the WE |--|--| |test-clean|0.8| |test-other|0.9| - diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index ff6374d73..cfdcff756 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -45,6 +45,7 @@ from icefall.utils import ( get_texts, setup_logger, store_transcripts, + str2bool, write_error_stats, ) @@ -116,6 +117,17 @@ def get_parser(): """, ) + parser.add_argument( + "--export", + type=str2bool, + default=False, + help="""When enabled, the averaged model is saved to + conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved. + pretrained.pt contains a dict {"model": model.state_dict()}, + which can be loaded by `icefall.checkpoint.load_checkpoint()`. + """, + ) + return parser @@ -541,6 +553,13 @@ def main(): logging.info(f"averaging {filenames}") model.load_state_dict(average_checkpoints(filenames)) + if params.export: + logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) + return + model.to(device) model.eval() num_param = sum([p.numel() for p in model.parameters()]) diff --git a/egs/librispeech/ASR/conformer_ctc/test_subsampling.py b/egs/librispeech/ASR/conformer_ctc/test_subsampling.py index e3361d0c9..81fa234dd 100755 --- a/egs/librispeech/ASR/conformer_ctc/test_subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/test_subsampling.py @@ -16,9 +16,8 @@ # limitations under the License. -from subsampling import Conv2dSubsampling -from subsampling import VggSubsampling import torch +from subsampling import Conv2dSubsampling, VggSubsampling def test_conv2d_subsampling(): diff --git a/egs/librispeech/ASR/conformer_ctc/test_transformer.py b/egs/librispeech/ASR/conformer_ctc/test_transformer.py index b90215274..667057c51 100644 --- a/egs/librispeech/ASR/conformer_ctc/test_transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/test_transformer.py @@ -17,17 +17,16 @@ import torch +from torch.nn.utils.rnn import pad_sequence from transformer import ( Transformer, + add_eos, + add_sos, + decoder_padding_mask, encoder_padding_mask, generate_square_subsequent_mask, - decoder_padding_mask, - add_sos, - add_eos, ) -from torch.nn.utils.rnn import pad_sequence - def test_encoder_padding_mask(): supervisions = { diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index 19a1ddd23..407fb7d88 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -102,14 +102,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: LG.labels[LG.labels >= first_token_disambig_id] = 0 - assert isinstance(LG.aux_labels, k2.RaggedInt) - LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0 + assert isinstance(LG.aux_labels, k2.RaggedTensor) + LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0 LG = k2.remove_epsilon(LG) logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") LG = k2.connect(LG) - LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) logging.info("Arc sorting LG") LG = k2.arc_sort(LG) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index afdebd12b..87e9cddb4 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -99,8 +99,10 @@ def get_params() -> AttributeDict: # - nbest-rescoring # - whole-lattice-rescoring "method": "whole-lattice-rescoring", + # "method": "1best", + # "method": "nbest", # num_paths is used when method is "nbest" and "nbest-rescoring" - "num_paths": 30, + "num_paths": 100, } ) return params @@ -424,6 +426,7 @@ def main(): torch.save( {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" ) + return model.to(device) model.eval() diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py old mode 100644 new mode 100755 diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py index f2fafd013..41a927455 100755 --- a/egs/yesno/ASR/local/compile_hlg.py +++ b/egs/yesno/ASR/local/compile_hlg.py @@ -80,14 +80,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: LG.labels[LG.labels >= first_token_disambig_id] = 0 - assert isinstance(LG.aux_labels, k2.RaggedInt) - LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0 + assert isinstance(LG.aux_labels, k2.RaggedTensor) + LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0 LG = k2.remove_epsilon(LG) logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") LG = k2.connect(LG) - LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) logging.info("Arc sorting LG") LG = k2.arc_sort(LG) diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py index aa7b07b98..54fdbb3cc 100755 --- a/egs/yesno/ASR/tdnn/decode.py +++ b/egs/yesno/ASR/tdnn/decode.py @@ -296,6 +296,7 @@ def main(): torch.save( {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" ) + return model.to(device) model.eval() diff --git a/icefall/decode.py b/icefall/decode.py index de3219401..3f6e5fc84 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -84,8 +84,8 @@ def _intersect_device( for start, end in splits: indexes = torch.arange(start, end).to(b_to_a_map) - fsas = k2.index(b_fsas, indexes) - b_to_a = k2.index(b_to_a_map, indexes) + fsas = k2.index_fsa(b_fsas, indexes) + b_to_a = k2.index_select(b_to_a_map, indexes) path_lattice = k2.intersect_device( a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a ) @@ -215,18 +215,16 @@ def nbest_decoding( scale=scale, ) - # word_seq is a k2.RaggedInt sharing the same shape as `path` + # word_seq is a k2.RaggedTensor sharing the same shape as `path` # but it contains word IDs. Note that it also contains 0s and -1s. # The last entry in each sublist is -1. - word_seq = k2.index(lattice.aux_labels, path) - # Note: the above operation supports also the case when - # lattice.aux_labels is a ragged tensor. In that case, - # `remove_axis=True` is used inside the pybind11 binding code, - # so the resulting `word_seq` still has 3 axes, like `path`. - # The 3 axes are [seq][path][word_id] + if isinstance(lattice.aux_labels, torch.Tensor): + word_seq = k2.ragged.index(lattice.aux_labels, path) + else: + word_seq = lattice.aux_labels.index(path, remove_axis=True) # Remove 0 (epsilon) and -1 from word_seq - word_seq = k2.ragged.remove_values_leq(word_seq, 0) + word_seq = word_seq.remove_values_leq(0) # Remove sequences with identical word sequences. # @@ -234,12 +232,12 @@ def nbest_decoding( # `new2old` is a 1-D torch.Tensor mapping from the output path index # to the input path index. # new2old.numel() == unique_word_seqs.tot_size(1) - unique_word_seq, _, new2old = k2.ragged.unique_sequences( - word_seq, need_num_repeats=False, need_new2old_indexes=True + unique_word_seq, _, new2old = word_seq.unique( + need_num_repeats=False, need_new2old_indexes=True ) # Note: unique_word_seq still has the same axes as word_seq - seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0) + seq_to_path_shape = unique_word_seq.shape.get_layer(0) # path_to_seq_map is a 1-D torch.Tensor. # path_to_seq_map[i] is the seq to which the i-th path belongs @@ -247,7 +245,7 @@ def nbest_decoding( # Remove the seq axis. # Now unique_word_seq has only two axes [path][word] - unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0) + unique_word_seq = unique_word_seq.remove_axis(0) # word_fsa is an FsaVec with axes [path][state][arc] word_fsa = k2.linear_fsa(unique_word_seq) @@ -275,35 +273,35 @@ def nbest_decoding( use_double_scores=use_double_scores, log_semiring=False ) - # RaggedFloat currently supports float32 only. - # If Ragged is wrapped, we can use k2.RaggedDouble here - ragged_tot_scores = k2.RaggedFloat( - seq_to_path_shape, tot_scores.to(torch.float32) - ) + ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores) - argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) + argmax_indexes = ragged_tot_scores.argmax() # Since we invoked `k2.ragged.unique_sequences`, which reorders # the index from `path`, we use `new2old` here to convert argmax_indexes # to the indexes into `path`. # # Use k2.index here since argmax_indexes' dtype is torch.int32 - best_path_indexes = k2.index(new2old, argmax_indexes) + best_path_indexes = k2.index_select(new2old, argmax_indexes) - path_2axes = k2.ragged.remove_axis(path, 0) + path_2axes = path.remove_axis(0) - # best_path is a k2.RaggedInt with 2 axes [path][arc_pos] - best_path = k2.index(path_2axes, best_path_indexes) + # best_path is a k2.RaggedTensor with 2 axes [path][arc_pos] + best_path, _ = path_2axes.index( + indexes=best_path_indexes, axis=0, need_value_indexes=False + ) - # labels is a k2.RaggedInt with 2 axes [path][token_id] + # labels is a k2.RaggedTensor with 2 axes [path][token_id] # Note that it contains -1s. - labels = k2.index(lattice.labels.contiguous(), best_path) + labels = k2.ragged.index(lattice.labels.contiguous(), best_path) - labels = k2.ragged.remove_values_eq(labels, -1) + labels = labels.remove_values_eq(-1) - # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so - # aux_labels is also a k2.RaggedInt with 2 axes - aux_labels = k2.index(lattice.aux_labels, best_path.values()) + # lattice.aux_labels is a k2.RaggedTensor with 2 axes, so + # aux_labels is also a k2.RaggedTensor with 2 axes + aux_labels, _ = lattice.aux_labels.index( + indexes=best_path.data, axis=0, need_value_indexes=False + ) best_path_fsa = k2.linear_fsa(labels) best_path_fsa.aux_labels = aux_labels @@ -426,33 +424,36 @@ def rescore_with_n_best_list( scale=scale, ) - # word_seq is a k2.RaggedInt sharing the same shape as `path` + # word_seq is a k2.RaggedTensor sharing the same shape as `path` # but it contains word IDs. Note that it also contains 0s and -1s. # The last entry in each sublist is -1. - word_seq = k2.index(lattice.aux_labels, path) + if isinstance(lattice.aux_labels, torch.Tensor): + word_seq = k2.ragged.index(lattice.aux_labels, path) + else: + word_seq = lattice.aux_labels.index(path, remove_axis=True) # Remove epsilons and -1 from word_seq - word_seq = k2.ragged.remove_values_leq(word_seq, 0) + word_seq = word_seq.remove_values_leq(0) # Remove paths that has identical word sequences. # - # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word] + # unique_word_seq is still a k2.RaggedTensor with 3 axes [seq][path][word] # except that there are no repeated paths with the same word_seq # within a sequence. # - # num_repeats is also a k2.RaggedInt with 2 axes containing the + # num_repeats is also a k2.RaggedTensor with 2 axes containing the # multiplicities of each path. - # num_repeats.num_elements() == unique_word_seqs.tot_size(1) + # num_repeats.numel() == unique_word_seqs.tot_size(1) # # Since k2.ragged.unique_sequences will reorder paths within a seq, # `new2old` is a 1-D torch.Tensor mapping from the output path index # to the input path index. # new2old.numel() == unique_word_seqs.tot_size(1) - unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences( - word_seq, need_num_repeats=True, need_new2old_indexes=True + unique_word_seq, num_repeats, new2old = word_seq.unique( + need_num_repeats=True, need_new2old_indexes=True ) - seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0) + seq_to_path_shape = unique_word_seq.shape.get_layer(0) # path_to_seq_map is a 1-D torch.Tensor. # path_to_seq_map[i] is the seq to which the i-th path @@ -461,7 +462,7 @@ def rescore_with_n_best_list( # Remove the seq axis. # Now unique_word_seq has only two axes [path][word] - unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0) + unique_word_seq = unique_word_seq.remove_axis(0) # word_fsa is an FsaVec with axes [path][state][arc] word_fsa = k2.linear_fsa(unique_word_seq) @@ -485,39 +486,42 @@ def rescore_with_n_best_list( use_double_scores=True, log_semiring=False ) - path_2axes = k2.ragged.remove_axis(path, 0) + path_2axes = path.remove_axis(0) ans = dict() for lm_scale in lm_scale_list: tot_scores = am_scores / lm_scale + lm_scores - # Remember that we used `k2.ragged.unique_sequences` to remove repeated + # Remember that we used `k2.RaggedTensor.unique` to remove repeated # paths to avoid redundant computation in `k2.intersect_device`. # Now we use `num_repeats` to correct the scores for each path. # # NOTE(fangjun): It is commented out as it leads to a worse WER # tot_scores = tot_scores * num_repeats.values() - ragged_tot_scores = k2.RaggedFloat( - seq_to_path_shape, tot_scores.to(torch.float32) - ) - argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) + ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores) + argmax_indexes = ragged_tot_scores.argmax() # Use k2.index here since argmax_indexes' dtype is torch.int32 - best_path_indexes = k2.index(new2old, argmax_indexes) + best_path_indexes = k2.index_select(new2old, argmax_indexes) # best_path is a k2.RaggedInt with 2 axes [path][arc_pos] - best_path = k2.index(path_2axes, best_path_indexes) + best_path, _ = path_2axes.index( + indexes=best_path_indexes, axis=0, need_value_indexes=False + ) - # labels is a k2.RaggedInt with 2 axes [path][phone_id] + # labels is a k2.RaggedTensor with 2 axes [path][phone_id] # Note that it contains -1s. - labels = k2.index(lattice.labels.contiguous(), best_path) + labels = k2.ragged.index(lattice.labels.contiguous(), best_path) - labels = k2.ragged.remove_values_eq(labels, -1) + labels = labels.remove_values_eq(-1) - # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so - # aux_labels is also a k2.RaggedInt with 2 axes - aux_labels = k2.index(lattice.aux_labels, best_path.values()) + # lattice.aux_labels is a k2.RaggedTensor tensor with 2 axes, so + # aux_labels is also a k2.RaggedTensor with 2 axes + + aux_labels, _ = lattice.aux_labels.index( + indexes=best_path.data, axis=0, need_value_indexes=False + ) best_path_fsa = k2.linear_fsa(labels) best_path_fsa.aux_labels = aux_labels @@ -659,12 +663,16 @@ def nbest_oracle( scale=scale, ) - word_seq = k2.index(lattice.aux_labels, path) - word_seq = k2.ragged.remove_values_leq(word_seq, 0) - unique_word_seq, _, _ = k2.ragged.unique_sequences( - word_seq, need_num_repeats=False, need_new2old_indexes=False + if isinstance(lattice.aux_labels, torch.Tensor): + word_seq = k2.ragged.index(lattice.aux_labels, path) + else: + word_seq = lattice.aux_labels.index(path, remove_axis=True) + + word_seq = word_seq.remove_values_leq(0) + unique_word_seq, _, _ = word_seq.unique( + need_num_repeats=False, need_new2old_indexes=False ) - unique_word_ids = k2.ragged.to_list(unique_word_seq) + unique_word_ids = unique_word_seq.tolist() assert len(unique_word_ids) == len(ref_texts) # unique_word_ids[i] contains all hypotheses of the i-th utterance @@ -743,33 +751,36 @@ def rescore_with_attention_decoder( scale=scale, ) - # word_seq is a k2.RaggedInt sharing the same shape as `path` + # word_seq is a k2.RaggedTensor sharing the same shape as `path` # but it contains word IDs. Note that it also contains 0s and -1s. # The last entry in each sublist is -1. - word_seq = k2.index(lattice.aux_labels, path) + if isinstance(lattice.aux_labels, torch.Tensor): + word_seq = k2.ragged.index(lattice.aux_labels, path) + else: + word_seq = lattice.aux_labels.index(path, remove_axis=True) # Remove epsilons and -1 from word_seq - word_seq = k2.ragged.remove_values_leq(word_seq, 0) + word_seq = word_seq.remove_values_leq(0) # Remove paths that has identical word sequences. # - # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word] + # unique_word_seq is still a k2.RaggedTensor with 3 axes [seq][path][word] # except that there are no repeated paths with the same word_seq # within a sequence. # - # num_repeats is also a k2.RaggedInt with 2 axes containing the + # num_repeats is also a k2.RaggedTensor with 2 axes containing the # multiplicities of each path. - # num_repeats.num_elements() == unique_word_seqs.tot_size(1) + # num_repeats.numel() == unique_word_seqs.tot_size(1) # # Since k2.ragged.unique_sequences will reorder paths within a seq, # `new2old` is a 1-D torch.Tensor mapping from the output path index # to the input path index. # new2old.numel() == unique_word_seq.tot_size(1) - unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences( - word_seq, need_num_repeats=True, need_new2old_indexes=True + unique_word_seq, num_repeats, new2old = word_seq.unique( + need_num_repeats=True, need_new2old_indexes=True ) - seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0) + seq_to_path_shape = unique_word_seq.shape.get_layer(0) # path_to_seq_map is a 1-D torch.Tensor. # path_to_seq_map[i] is the seq to which the i-th path @@ -778,7 +789,7 @@ def rescore_with_attention_decoder( # Remove the seq axis. # Now unique_word_seq has only two axes [path][word] - unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0) + unique_word_seq = unique_word_seq.remove_axis(0) # word_fsa is an FsaVec with axes [path][state][arc] word_fsa = k2.linear_fsa(unique_word_seq) @@ -796,20 +807,23 @@ def rescore_with_attention_decoder( # CAUTION: The "tokens" attribute is set in the file # local/compile_hlg.py - token_seq = k2.index(lattice.tokens, path) + if isinstance(lattice.tokens, torch.Tensor): + token_seq = k2.ragged.index(lattice.tokens, path) + else: + token_seq = lattice.tokens.index(path, remove_axis=True) # Remove epsilons and -1 from token_seq - token_seq = k2.ragged.remove_values_leq(token_seq, 0) + token_seq = token_seq.remove_values_leq(0) # Remove the seq axis. - token_seq = k2.ragged.remove_axis(token_seq, 0) + token_seq = token_seq.remove_axis(0) - token_seq, _ = k2.ragged.index( - token_seq, indexes=new2old, axis=0, need_value_indexes=False + token_seq, _ = token_seq.index( + indexes=new2old, axis=0, need_value_indexes=False ) # Now word in unique_word_seq has its corresponding token IDs. - token_ids = k2.ragged.to_list(token_seq) + token_ids = token_seq.tolist() num_word_seqs = new2old.numel() @@ -849,7 +863,7 @@ def rescore_with_attention_decoder( else: attention_scale_list = [attention_scale] - path_2axes = k2.ragged.remove_axis(path, 0) + path_2axes = path.remove_axis(0) ans = dict() for n_scale in ngram_lm_scale_list: @@ -859,23 +873,28 @@ def rescore_with_attention_decoder( + n_scale * ngram_lm_scores + a_scale * attention_scores ) - ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape, tot_scores) - argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) + ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores) + argmax_indexes = ragged_tot_scores.argmax() - best_path_indexes = k2.index(new2old, argmax_indexes) + best_path_indexes = k2.index_select(new2old, argmax_indexes) # best_path is a k2.RaggedInt with 2 axes [path][arc_pos] - best_path = k2.index(path_2axes, best_path_indexes) + best_path, _ = path_2axes.index( + indexes=best_path_indexes, axis=0, need_value_indexes=False + ) - # labels is a k2.RaggedInt with 2 axes [path][token_id] + # labels is a k2.RaggedTensor with 2 axes [path][token_id] # Note that it contains -1s. - labels = k2.index(lattice.labels.contiguous(), best_path) + labels = k2.ragged.index(lattice.labels.contiguous(), best_path) - labels = k2.ragged.remove_values_eq(labels, -1) + labels = labels.remove_values_eq(-1) - # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so - # aux_labels is also a k2.RaggedInt with 2 axes - aux_labels = k2.index(lattice.aux_labels, best_path.values()) + if isinstance(lattice.aux_labels, torch.Tensor): + aux_labels = k2.index_select(lattice.aux_labels, best_path.data) + else: + aux_labels, _ = lattice.aux_labels.index( + indexes=best_path.data, axis=0, need_value_indexes=False + ) best_path_fsa = k2.linear_fsa(labels) best_path_fsa.aux_labels = aux_labels diff --git a/icefall/lexicon.py b/icefall/lexicon.py index f1127c7cf..6730bac49 100644 --- a/icefall/lexicon.py +++ b/icefall/lexicon.py @@ -157,7 +157,7 @@ class BpeLexicon(Lexicon): lang_dir / "lexicon.txt" ) - def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedInt: + def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedTensor: """Read a BPE lexicon from file and convert it to a k2 ragged tensor. @@ -200,19 +200,18 @@ class BpeLexicon(Lexicon): ) values = torch.tensor(token_ids, dtype=torch.int32) - return k2.RaggedInt(shape, values) + return k2.RaggedTensor(shape, values) - def words_to_piece_ids(self, words: List[str]) -> k2.RaggedInt: + def words_to_piece_ids(self, words: List[str]) -> k2.RaggedTensor: """Convert a list of words to a ragged tensor contained word piece IDs. """ word_ids = [self.word_table[w] for w in words] word_ids = torch.tensor(word_ids, dtype=torch.int32) - ragged, _ = k2.ragged.index( - self.ragged_lexicon, + ragged, _ = self.ragged_lexicon.index( indexes=word_ids, - need_value_indexes=False, axis=0, + need_value_indexes=False, ) return ragged diff --git a/icefall/utils.py b/icefall/utils.py index 2994c2d47..1130d8947 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -26,7 +26,6 @@ from pathlib import Path from typing import Dict, Iterable, List, TextIO, Tuple, Union import k2 -import k2.ragged as k2r import kaldialign import torch import torch.distributed as dist @@ -199,26 +198,25 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]: Returns a list of lists of int, containing the label sequences we decoded. """ - if isinstance(best_paths.aux_labels, k2.RaggedInt): + if isinstance(best_paths.aux_labels, k2.RaggedTensor): # remove 0's and -1's. - aux_labels = k2r.remove_values_leq(best_paths.aux_labels, 0) - aux_shape = k2r.compose_ragged_shapes( - best_paths.arcs.shape(), aux_labels.shape() - ) + aux_labels = best_paths.aux_labels.remove_values_leq(0) + # TODO: change arcs.shape() to arcs.shape + aux_shape = best_paths.arcs.shape().compose(aux_labels.shape) # remove the states and arcs axes. - aux_shape = k2r.remove_axis(aux_shape, 1) - aux_shape = k2r.remove_axis(aux_shape, 1) - aux_labels = k2.RaggedInt(aux_shape, aux_labels.values()) + aux_shape = aux_shape.remove_axis(1) + aux_shape = aux_shape.remove_axis(1) + aux_labels = k2.RaggedTensor(aux_shape, aux_labels.data) else: # remove axis corresponding to states. - aux_shape = k2r.remove_axis(best_paths.arcs.shape(), 1) - aux_labels = k2.RaggedInt(aux_shape, best_paths.aux_labels) + aux_shape = best_paths.arcs.shape().remove_axis(1) + aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels) # remove 0's and -1's. - aux_labels = k2r.remove_values_leq(aux_labels, 0) + aux_labels = aux_labels.remove_values_leq(0) - assert aux_labels.num_axes() == 2 - return k2r.to_list(aux_labels) + assert aux_labels.num_axes == 2 + return aux_labels.tolist() def store_transcripts( diff --git a/test/test_bpe_graph_compiler.py b/test/test_bpe_graph_compiler.py index 67d300b7d..e58c4f1c6 100755 --- a/test/test_bpe_graph_compiler.py +++ b/test/test_bpe_graph_compiler.py @@ -16,9 +16,10 @@ # limitations under the License. +from pathlib import Path + from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.lexicon import BpeLexicon -from pathlib import Path def test(): diff --git a/test/test_utils.py b/test/test_utils.py index 2dd79689f..b4c9358fd 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -60,7 +60,7 @@ def test_get_texts_ragged(): 4 """ ) - fsa1.aux_labels = k2.RaggedInt("[ [1 3 0 2] [] [4 0 1] [-1]]") + fsa1.aux_labels = k2.RaggedTensor("[ [1 3 0 2] [] [4 0 1] [-1]]") fsa2 = k2.Fsa.from_str( """ @@ -70,7 +70,7 @@ def test_get_texts_ragged(): 3 """ ) - fsa2.aux_labels = k2.RaggedInt("[[3 0 5 0 8] [0 9 7 0] [-1]]") + fsa2.aux_labels = k2.RaggedTensor("[[3 0 5 0 8] [0 9 7 0] [-1]]") fsas = k2.Fsa.from_fsas([fsa1, fsa2]) texts = get_texts(fsas) assert texts == [[1, 3, 2, 4, 1], [3, 5, 8, 9, 7]]