Merge branch 'k2-fsa:master' into master

Author: lishaojie, 2023-05-15 21:49:25 +08:00 (committed by GitHub)
Commit: f35eed4240
59 changed files with 6730 additions and 322 deletions

README.md

@ -28,14 +28,15 @@ We provide the following recipes:
- [yesno][yesno]
- [LibriSpeech][librispeech]
- [GigaSpeech][gigaspeech]
- [Aishell][aishell]
- [Aishell2][aishell2]
- [Aishell4][aishell4]
- [TIMIT][timit]
- [TED-LIUM3][tedlium3]
- [GigaSpeech][gigaspeech]
- [Aidatatang_200zh][aidatatang_200zh]
- [WenetSpeech][wenetspeech]
- [Alimeeting][alimeeting]
- [Aishell4][aishell4]
- [TAL_CSASR][tal_csasr]
### yesno
@ -46,9 +47,7 @@ Training takes less than 30 seconds and gives you the following WER:
```
[test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
```
We do provide a Colab notebook for this recipe.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing)
We provide a Colab notebook for this recipe: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing)
### LibriSpeech
@ -82,7 +81,7 @@ The WER for this model is:
|-----|------------|------------|
| WER | 6.59 | 17.69 |
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing)
#### Transducer: Conformer encoder + LSTM decoder
@ -118,19 +117,54 @@ We provide a Colab notebook to run a pre-trained transducer conformer + stateles
| | test-clean | test-other |
|-----|------------|------------|
| WER | 2.57 | 5.95 |
| WER | 2.15 | 5.20 |
Note: No auxiliary losses are used in the training and no LMs are used
in the decoding.
#### k2 pruned RNN-T + GigaSpeech
| | test-clean | test-other |
|-----|------------|------------|
| WER | 2.00 | 4.63 |
| WER | 1.78 | 4.08 |
Note: No auxiliary losses are used in the training and no LMs are used
in the decoding.
#### k2 pruned RNN-T + GigaSpeech + CommonVoice
| | test-clean | test-other |
|-----|------------|------------|
| WER | 1.90 | 3.98 |
Note: No auxiliary losses are used in the training and no LMs are used
in the decoding.
### GigaSpeech
We provide two models for this recipe: [Conformer CTC model][GigaSpeech_conformer_ctc]
and [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2].
#### Conformer CTC
| | Dev | Test |
|-----|-------|-------|
| WER | 10.47 | 10.58 |
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
| | Dev | Test |
|----------------------|-------|-------|
| greedy search | 10.51 | 10.73 |
| fast beam search | 10.50 | 10.69 |
| modified beam search | 10.40 | 10.51 |
### Aishell
We provide two models for this recipe: [conformer CTC model][Aishell_conformer_ctc]
and [TDNN LSTM CTC model][Aishell_tdnn_lstm_ctc].
We provide three models for this recipe: [conformer CTC model][Aishell_conformer_ctc],
[TDNN LSTM CTC model][Aishell_tdnn_lstm_ctc], and [Transducer Stateless Model][Aishell_pruned_transducer_stateless7].
#### Conformer CTC Model
@ -140,20 +174,6 @@ The best CER we currently have is:
|-----|------|
| CER | 4.26 |
We provide a Colab notebook to run a pre-trained conformer CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1WnG17io5HEZ0Gn_cnh_VzK5QYOoiiklC?usp=sharing)
#### Transducer Stateless Model
The best CER we currently have is:
| | test |
|-----|------|
| CER | 4.68 |
We provide a Colab notebook to run a pre-trained TransducerStateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing)
#### TDNN LSTM CTC Model
The CER for this model is:
@ -164,6 +184,46 @@ The CER for this model is:
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing)
#### Transducer Stateless Model
The best CER we currently have is:
| | test |
|-----|------|
| CER | 4.38 |
We provide a Colab notebook to run a pre-trained TransducerStateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing)
### Aishell2
We provide one model for this recipe: [Transducer Stateless Model][Aishell2_pruned_transducer_stateless5].
#### Transducer Stateless Model
The best WER we currently have is:
| | dev-ios | test-ios |
|-----|------------|------------|
| WER | 5.32 | 5.56 |
### Aishell4
We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5].
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets)
The best CER we currently have is:
| | test |
|-----|------------|
| CER | 29.08 |
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
### TIMIT
We provide two models for this recipe: [TDNN LSTM CTC model][TIMIT_tdnn_lstm_ctc]
@ -187,7 +247,8 @@ The PER for this model is:
|--|--|
|PER| 17.66% |
We provide a Colab notebook to run a pre-trained TDNN LiGRU CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/11IT-k4HQIgQngXz1uvWsEYktjqQt7Tmb?usp=sharing)
We provide a Colab notebook to run a pre-trained TDNN LiGRU CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
### TED-LIUM3
@ -215,24 +276,6 @@ The best WER using modified beam search with beam size 4 is:
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1je_1zGrOkGVVd4WLzgkXRHxl-I27yWtz?usp=sharing)
### GigaSpeech
We provide two models for this recipe: [Conformer CTC model][GigaSpeech_conformer_ctc]
and [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2].
#### Conformer CTC
| | Dev | Test |
|-----|-------|-------|
| WER | 10.47 | 10.58 |
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
| | Dev | Test |
|----------------------|-------|-------|
| greedy search | 10.51 | 10.73 |
| fast beam search | 10.50 | 10.69 |
| modified beam search | 10.40 | 10.51 |
### Aidatatang_200zh
@ -248,6 +291,7 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1wNSnSj3T5oOctbh5IGCa393gKOoQw2GH?usp=sharing)
### WenetSpeech
We provide some models for this recipe: [Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless2] and [Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless5].
@ -284,20 +328,6 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing)
### Aishell4
We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5].
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets)
The best CER(%) results:
| | test |
|----------------------|--------|
| greedy search | 29.89 |
| fast beam search | 28.91 |
| modified beam search | 29.08 |
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
### TAL_CSASR
@ -333,6 +363,9 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
[LibriSpeech_transducer_stateless]: egs/librispeech/ASR/transducer_stateless
[Aishell_tdnn_lstm_ctc]: egs/aishell/ASR/tdnn_lstm_ctc
[Aishell_conformer_ctc]: egs/aishell/ASR/conformer_ctc
[Aishell_pruned_transducer_stateless7]: egs/aishell/ASR/pruned_transducer_stateless7_bbpe
[Aishell2_pruned_transducer_stateless5]: egs/aishell2/ASR/pruned_transducer_stateless5
[Aishell4_pruned_transducer_stateless5]: egs/aishell4/ASR/pruned_transducer_stateless5
[TIMIT_tdnn_lstm_ctc]: egs/timit/ASR/tdnn_lstm_ctc
[TIMIT_tdnn_ligru_ctc]: egs/timit/ASR/tdnn_ligru_ctc
[TED-LIUM3_transducer_stateless]: egs/tedlium3/ASR/transducer_stateless
@ -343,17 +376,17 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
[WenetSpeech_pruned_transducer_stateless2]: egs/wenetspeech/ASR/pruned_transducer_stateless2
[WenetSpeech_pruned_transducer_stateless5]: egs/wenetspeech/ASR/pruned_transducer_stateless5
[Alimeeting_pruned_transducer_stateless2]: egs/alimeeting/ASR/pruned_transducer_stateless2
[Aishell4_pruned_transducer_stateless5]: egs/aishell4/ASR/pruned_transducer_stateless5
[TAL_CSASR_pruned_transducer_stateless5]: egs/tal_csasr/ASR/pruned_transducer_stateless5
[yesno]: egs/yesno/ASR
[librispeech]: egs/librispeech/ASR
[aishell]: egs/aishell/ASR
[aishell2]: egs/aishell2/ASR
[aishell4]: egs/aishell4/ASR
[timit]: egs/timit/ASR
[tedlium3]: egs/tedlium3/ASR
[gigaspeech]: egs/gigaspeech/ASR
[aidatatang_200zh]: egs/aidatatang_200zh/ASR
[wenetspeech]: egs/wenetspeech/ASR
[alimeeting]: egs/alimeeting/ASR
[aishell4]: egs/aishell4/ASR
[tal_csasr]: egs/tal_csasr/ASR
[k2]: https://github.com/k2-fsa/k2


@ -3,64 +3,91 @@
Installation
============
- |os|
- |device|
- |python_versions|
- |torch_versions|
- |k2_versions|
.. |os| image:: ./images/os-Linux_macOS-ff69b4.svg
:alt: Supported operating systems
.. |device| image:: ./images/device-CPU_CUDA-orange.svg
:alt: Supported devices
.. |python_versions| image:: ./images/python-gt-v3.6-blue.svg
:alt: Supported python versions
.. |torch_versions| image:: ./images/torch-gt-v1.6.0-green.svg
:alt: Supported PyTorch versions
.. |k2_versions| image:: ./images/k2-gt-v1.9-blueviolet.svg
:alt: Supported k2 versions
``icefall`` depends on `k2 <https://github.com/k2-fsa/k2>`_ and
`lhotse <https://github.com/lhotse-speech/lhotse>`_.
We recommend you to use the following steps to install the dependencies.
We recommend that you use the following steps to install the dependencies.
- (0) Install PyTorch and torchaudio
- (1) Install k2
- (2) Install lhotse
- (0) Install CUDA toolkit and cuDNN
- (1) Install PyTorch and torchaudio
- (2) Install k2
- (3) Install lhotse
.. caution::
99% of users who have issues with the installation are using conda.
.. caution::
99% of users who have issues with the installation are using conda.
.. caution::
99% of users who have issues with the installation are using conda.
.. hint::
We suggest that you use ``pip install`` to install PyTorch.
You can use the following command to create a virtual environment in Python:
.. code-block:: bash
python3 -m venv ./my_env
source ./my_env/bin/activate
.. caution::
Installation order matters.
(0) Install PyTorch and torchaudio
(0) Install CUDA toolkit and cuDNN
----------------------------------
Please refer to
`<https://k2-fsa.github.io/k2/installation/cuda-cudnn.html>`_
to install CUDA and cuDNN.
(1) Install PyTorch and torchaudio
----------------------------------
Please refer to `<https://pytorch.org/>`_ to install PyTorch
and torchaudio.
.. hint::
(1) Install k2
You can also go to `<https://download.pytorch.org/whl/torch_stable.html>`_
to download pre-compiled wheels and install them.
.. caution::
Please install torch and torchaudio at the same time.
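After this step, a quick sanity check can catch a broken install early. The
snippet below is only a sketch; the versions it prints depend on what you
installed:

.. code-block:: python

   import torch
   import torchaudio

   # Both packages should import cleanly and come from a matching release pair.
   print("torch:", torch.__version__)
   print("torchaudio:", torchaudio.__version__)
   print("CUDA available:", torch.cuda.is_available())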
(2) Install k2
--------------
Please refer to `<https://k2-fsa.github.io/k2/installation/index.html>`_
to install ``k2``.
.. CAUTION::
.. caution::
You need to install ``k2`` with a version at least **v1.9**.
Please don't change your installed PyTorch after you have installed k2.
.. HINT::
.. note::
If you have already installed PyTorch and don't want to replace it,
please install a version of ``k2`` that is compiled against the version
of PyTorch you are using.
We suggest that you install k2 from source by following
`<https://k2-fsa.github.io/k2/installation/from_source.html>`_
or
`<https://k2-fsa.github.io/k2/installation/for_developers.html>`_.
(2) Install lhotse
.. hint::
Please always install the latest version of k2.
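As a rough check that ``k2`` works with the PyTorch you installed, you can try
building a tiny FSA. This is only a sketch, not part of the official
installation steps:

.. code-block:: python

   import k2

   # An acceptor in k2's text format: "src dst label score" per arc;
   # label -1 marks the arc entering the final state, and the last
   # entry gives the final state id.
   s = "0 1 10 0.1\n1 2 -1 0.2\n2"
   fsa = k2.Fsa.from_str(s)
   print(fsa)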
(3) Install lhotse
------------------
Please refer to `<https://lhotse.readthedocs.io/en/latest/getting-started.html#installation>`_
@ -75,8 +102,7 @@ to install ``lhotse``.
to install the latest version of lhotse.
(3) Download icefall
(4) Download icefall
--------------------
``icefall`` is a collection of Python scripts; all you need to do is download it
@ -338,44 +364,42 @@ The log of running ``./prepare.sh`` is:
.. code-block::
2021-08-23 19:27:26 (prepare.sh:24:main) dl_dir: /tmp/icefall/egs/yesno/ASR/download
2021-08-23 19:27:26 (prepare.sh:27:main) stage 0: Download data
Downloading waves_yesno.tar.gz: 4.49MB [00:03, 1.39MB/s]
2021-08-23 19:27:30 (prepare.sh:36:main) Stage 1: Prepare yesno manifest
2021-08-23 19:27:31 (prepare.sh:42:main) Stage 2: Compute fbank for yesno
2021-08-23 19:27:32,803 INFO [compute_fbank_yesno.py:52] Processing train
Extracting and storing features: 100%|_______________________________________________________________| 90/90 [00:01<00:00, 80.57it/s]
2021-08-23 19:27:34,085 INFO [compute_fbank_yesno.py:52] Processing test
Extracting and storing features: 100%|______________________________________________________________| 30/30 [00:00<00:00, 248.21it/s]
2021-08-23 19:27:34 (prepare.sh:48:main) Stage 3: Prepare lang
2021-08-23 19:27:35 (prepare.sh:63:main) Stage 4: Prepare G
/tmp/pip-install-fcordre9/kaldilm_6899d26f2d684ad48f21025950cd2866/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Rea
d(std::istream&):79
[I] Reading \data\ section.
/tmp/pip-install-fcordre9/kaldilm_6899d26f2d684ad48f21025950cd2866/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Rea
d(std::istream&):140
[I] Reading \1-grams: section.
2021-08-23 19:27:35 (prepare.sh:89:main) Stage 5: Compile HLG
2021-08-23 19:27:35,928 INFO [compile_hlg.py:120] Processing data/lang_phone
2021-08-23 19:27:35,929 INFO [lexicon.py:116] Converting L.pt to Linv.pt
2021-08-23 19:27:35,931 INFO [compile_hlg.py:48] Building ctc_topo. max_token_id: 3
2021-08-23 19:27:35,932 INFO [compile_hlg.py:52] Loading G.fst.txt
2021-08-23 19:27:35,932 INFO [compile_hlg.py:62] Intersecting L and G
2021-08-23 19:27:35,933 INFO [compile_hlg.py:64] LG shape: (4, None)
2021-08-23 19:27:35,933 INFO [compile_hlg.py:66] Connecting LG
2021-08-23 19:27:35,933 INFO [compile_hlg.py:68] LG shape after k2.connect: (4, None)
2021-08-23 19:27:35,933 INFO [compile_hlg.py:70] <class 'torch.Tensor'>
2021-08-23 19:27:35,933 INFO [compile_hlg.py:71] Determinizing LG
2021-08-23 19:27:35,934 INFO [compile_hlg.py:74] <class '_k2.RaggedInt'>
2021-08-23 19:27:35,934 INFO [compile_hlg.py:76] Connecting LG after k2.determinize
2021-08-23 19:27:35,934 INFO [compile_hlg.py:79] Removing disambiguation symbols on LG
2021-08-23 19:27:35,934 INFO [compile_hlg.py:87] LG shape after k2.remove_epsilon: (6, None)
2021-08-23 19:27:35,935 INFO [compile_hlg.py:92] Arc sorting LG
2021-08-23 19:27:35,935 INFO [compile_hlg.py:95] Composing H and LG
2021-08-23 19:27:35,935 INFO [compile_hlg.py:102] Connecting LG
2021-08-23 19:27:35,935 INFO [compile_hlg.py:105] Arc sorting LG
2021-08-23 19:27:35,936 INFO [compile_hlg.py:107] HLG.shape: (8, None)
2021-08-23 19:27:35,936 INFO [compile_hlg.py:123] Saving HLG.pt to data/lang_phone
2023-05-12 17:55:21 (prepare.sh:27:main) dl_dir: /tmp/icefall/egs/yesno/ASR/download
2023-05-12 17:55:21 (prepare.sh:30:main) Stage 0: Download data
/tmp/icefall/egs/yesno/ASR/download/waves_yesno.tar.gz: 100%|_______________________________________________________________| 4.70M/4.70M [06:54<00:00, 11.4kB/s]
2023-05-12 18:02:19 (prepare.sh:39:main) Stage 1: Prepare yesno manifest
2023-05-12 18:02:21 (prepare.sh:45:main) Stage 2: Compute fbank for yesno
2023-05-12 18:02:23,199 INFO [compute_fbank_yesno.py:65] Processing train
Extracting and storing features: 100%|_______________________________________________________________| 90/90 [00:00<00:00, 212.60it/s]
2023-05-12 18:02:23,640 INFO [compute_fbank_yesno.py:65] Processing test
Extracting and storing features: 100%|_______________________________________________________________| 30/30 [00:00<00:00, 304.53it/s]
2023-05-12 18:02:24 (prepare.sh:51:main) Stage 3: Prepare lang
2023-05-12 18:02:26 (prepare.sh:66:main) Stage 4: Prepare G
/project/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Read(std::istream&):79
[I] Reading \data\ section.
/project/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Read(std::istream&):140
[I] Reading \1-grams: section.
2023-05-12 18:02:26 (prepare.sh:92:main) Stage 5: Compile HLG
2023-05-12 18:02:28,581 INFO [compile_hlg.py:124] Processing data/lang_phone
2023-05-12 18:02:28,582 INFO [lexicon.py:171] Converting L.pt to Linv.pt
2023-05-12 18:02:28,609 INFO [compile_hlg.py:48] Building ctc_topo. max_token_id: 3
2023-05-12 18:02:28,610 INFO [compile_hlg.py:52] Loading G.fst.txt
2023-05-12 18:02:28,611 INFO [compile_hlg.py:62] Intersecting L and G
2023-05-12 18:02:28,613 INFO [compile_hlg.py:64] LG shape: (4, None)
2023-05-12 18:02:28,613 INFO [compile_hlg.py:66] Connecting LG
2023-05-12 18:02:28,614 INFO [compile_hlg.py:68] LG shape after k2.connect: (4, None)
2023-05-12 18:02:28,614 INFO [compile_hlg.py:70] <class 'torch.Tensor'>
2023-05-12 18:02:28,614 INFO [compile_hlg.py:71] Determinizing LG
2023-05-12 18:02:28,615 INFO [compile_hlg.py:74] <class '_k2.ragged.RaggedTensor'>
2023-05-12 18:02:28,615 INFO [compile_hlg.py:76] Connecting LG after k2.determinize
2023-05-12 18:02:28,615 INFO [compile_hlg.py:79] Removing disambiguation symbols on LG
2023-05-12 18:02:28,616 INFO [compile_hlg.py:91] LG shape after k2.remove_epsilon: (6, None)
2023-05-12 18:02:28,617 INFO [compile_hlg.py:96] Arc sorting LG
2023-05-12 18:02:28,617 INFO [compile_hlg.py:99] Composing H and LG
2023-05-12 18:02:28,619 INFO [compile_hlg.py:106] Connecting LG
2023-05-12 18:02:28,619 INFO [compile_hlg.py:109] Arc sorting LG
2023-05-12 18:02:28,619 INFO [compile_hlg.py:111] HLG.shape: (8, None)
2023-05-12 18:02:28,619 INFO [compile_hlg.py:127] Saving HLG.pt to data/lang_phone
Training
@ -408,49 +432,53 @@ The training log is given below:
.. code-block::
2021-08-23 19:30:31,072 INFO [train.py:465] Training started
2021-08-23 19:30:31,072 INFO [train.py:466] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01,
'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, '
best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_doub
le_scores': True, 'world_size': 1, 'master_port': 12354, 'tensorboard': True, 'num_epochs': 15, 'feature_dir': PosixPath('data/fbank'
), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0
, 'on_the_fly_feats': False, 'shuffle': True, 'return_cuts': True, 'num_workers': 2}
2021-08-23 19:30:31,074 INFO [lexicon.py:113] Loading pre-compiled data/lang_phone/Linv.pt
2021-08-23 19:30:31,098 INFO [asr_datamodule.py:146] About to get train cuts
2021-08-23 19:30:31,098 INFO [asr_datamodule.py:240] About to get train cuts
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:149] About to create train dataset
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:200] Using SingleCutSampler.
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:206] About to create train dataloader
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:219] About to get test cuts
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:246] About to get test cuts
2021-08-23 19:30:31,357 INFO [train.py:416] Epoch 0, batch 0, batch avg loss 1.0789, total avg loss: 1.0789, batch size: 4
2021-08-23 19:30:31,848 INFO [train.py:416] Epoch 0, batch 10, batch avg loss 0.5356, total avg loss: 0.7556, batch size: 4
2021-08-23 19:30:32,301 INFO [train.py:432] Epoch 0, valid loss 0.9972, best valid loss: 0.9972 best valid epoch: 0
2021-08-23 19:30:32,805 INFO [train.py:416] Epoch 0, batch 20, batch avg loss 0.2436, total avg loss: 0.5717, batch size: 3
2021-08-23 19:30:33,109 INFO [train.py:432] Epoch 0, valid loss 0.4167, best valid loss: 0.4167 best valid epoch: 0
2021-08-23 19:30:33,121 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-0.pt
2021-08-23 19:30:33,325 INFO [train.py:416] Epoch 1, batch 0, batch avg loss 0.2214, total avg loss: 0.2214, batch size: 5
2021-08-23 19:30:33,798 INFO [train.py:416] Epoch 1, batch 10, batch avg loss 0.0781, total avg loss: 0.1343, batch size: 5
2021-08-23 19:30:34,065 INFO [train.py:432] Epoch 1, valid loss 0.0859, best valid loss: 0.0859 best valid epoch: 1
2021-08-23 19:30:34,556 INFO [train.py:416] Epoch 1, batch 20, batch avg loss 0.0421, total avg loss: 0.0975, batch size: 3
2021-08-23 19:30:34,810 INFO [train.py:432] Epoch 1, valid loss 0.0431, best valid loss: 0.0431 best valid epoch: 1
2021-08-23 19:30:34,824 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-1.pt
2023-05-12 18:04:59,759 INFO [train.py:481] Training started
2023-05-12 18:04:59,759 INFO [train.py:482] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0,
'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10,
'reduction': 'sum', 'use_double_scores': True, 'world_size': 1, 'master_port': 12354, 'tensorboard': True, 'num_epochs': 15, 'seed': 42, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0,
'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2,
'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '3b7f09fa35e72589914f67089c0da9f196a92ca4', 'k2-git-date': 'Mon May 8 22:58:45 2023',
'lhotse-version': '1.15.0.dev+git.6fcfced.clean', 'torch-version': '2.0.0+cu118', 'torch-cuda-available': False, 'torch-cuda-version': '11.8', 'python-version': '3.1', 'icefall-git-branch': 'master',
'icefall-git-sha1': '30bde4b-clean', 'icefall-git-date': 'Thu May 11 17:37:47 2023', 'icefall-path': '/tmp/icefall',
'k2-path': 'tmp/lib/python3.10/site-packages/k2-1.24.3.dev20230512+cuda11.8.torch2.0.0-py3.10-linux-x86_64.egg/k2/__init__.py',
'lhotse-path': 'tmp/lib/python3.10/site-packages/lhotse/__init__.py', 'hostname': 'host', 'IP address': '0.0.0.0'}}
2023-05-12 18:04:59,761 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
2023-05-12 18:04:59,764 INFO [train.py:495] device: cpu
2023-05-12 18:04:59,791 INFO [asr_datamodule.py:146] About to get train cuts
2023-05-12 18:04:59,791 INFO [asr_datamodule.py:244] About to get train cuts
2023-05-12 18:04:59,852 INFO [asr_datamodule.py:149] About to create train dataset
2023-05-12 18:04:59,852 INFO [asr_datamodule.py:199] Using SingleCutSampler.
2023-05-12 18:04:59,852 INFO [asr_datamodule.py:205] About to create train dataloader
2023-05-12 18:04:59,853 INFO [asr_datamodule.py:218] About to get test cuts
2023-05-12 18:04:59,853 INFO [asr_datamodule.py:252] About to get test cuts
2023-05-12 18:04:59,986 INFO [train.py:422] Epoch 0, batch 0, loss[loss=1.065, over 2436.00 frames. ], tot_loss[loss=1.065, over 2436.00 frames. ], batch size: 4
2023-05-12 18:05:00,352 INFO [train.py:422] Epoch 0, batch 10, loss[loss=0.4561, over 2828.00 frames. ], tot_loss[loss=0.7076, over 22192.90 frames. ], batch size: 4
2023-05-12 18:05:00,691 INFO [train.py:444] Epoch 0, validation loss=0.9002, over 18067.00 frames.
2023-05-12 18:05:00,996 INFO [train.py:422] Epoch 0, batch 20, loss[loss=0.2555, over 2695.00 frames. ], tot_loss[loss=0.484, over 34971.47 frames. ], batch size: 5
2023-05-12 18:05:01,217 INFO [train.py:444] Epoch 0, validation loss=0.4688, over 18067.00 frames.
2023-05-12 18:05:01,251 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-0.pt
2023-05-12 18:05:01,389 INFO [train.py:422] Epoch 1, batch 0, loss[loss=0.2532, over 2436.00 frames. ], tot_loss[loss=0.2532, over 2436.00 frames. ], batch size: 4
2023-05-12 18:05:01,637 INFO [train.py:422] Epoch 1, batch 10, loss[loss=0.1139, over 2828.00 frames. ], tot_loss[loss=0.1592, over 22192.90 frames. ], batch size: 4
2023-05-12 18:05:01,859 INFO [train.py:444] Epoch 1, validation loss=0.1629, over 18067.00 frames.
2023-05-12 18:05:02,094 INFO [train.py:422] Epoch 1, batch 20, loss[loss=0.0767, over 2695.00 frames. ], tot_loss[loss=0.118, over 34971.47 frames. ], batch size: 5
2023-05-12 18:05:02,350 INFO [train.py:444] Epoch 1, validation loss=0.06778, over 18067.00 frames.
2023-05-12 18:05:02,395 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-1.pt
... ...
2021-08-23 19:30:49,657 INFO [train.py:416] Epoch 13, batch 0, batch avg loss 0.0109, total avg loss: 0.0109, batch size: 5
2021-08-23 19:30:49,984 INFO [train.py:416] Epoch 13, batch 10, batch avg loss 0.0093, total avg loss: 0.0096, batch size: 4
2021-08-23 19:30:50,239 INFO [train.py:432] Epoch 13, valid loss 0.0104, best valid loss: 0.0101 best valid epoch: 12
2021-08-23 19:30:50,569 INFO [train.py:416] Epoch 13, batch 20, batch avg loss 0.0092, total avg loss: 0.0096, batch size: 2
2021-08-23 19:30:50,819 INFO [train.py:432] Epoch 13, valid loss 0.0101, best valid loss: 0.0101 best valid epoch: 13
2021-08-23 19:30:50,835 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-13.pt
2021-08-23 19:30:51,024 INFO [train.py:416] Epoch 14, batch 0, batch avg loss 0.0105, total avg loss: 0.0105, batch size: 5
2021-08-23 19:30:51,317 INFO [train.py:416] Epoch 14, batch 10, batch avg loss 0.0099, total avg loss: 0.0097, batch size: 4
2021-08-23 19:30:51,552 INFO [train.py:432] Epoch 14, valid loss 0.0108, best valid loss: 0.0101 best valid epoch: 13
2021-08-23 19:30:51,869 INFO [train.py:416] Epoch 14, batch 20, batch avg loss 0.0096, total avg loss: 0.0097, batch size: 5
2021-08-23 19:30:52,107 INFO [train.py:432] Epoch 14, valid loss 0.0102, best valid loss: 0.0101 best valid epoch: 13
2021-08-23 19:30:52,126 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-14.pt
2021-08-23 19:30:52,128 INFO [train.py:537] Done!
2023-05-12 18:05:14,789 INFO [train.py:422] Epoch 13, batch 0, loss[loss=0.01056, over 2436.00 frames. ], tot_loss[loss=0.01056, over 2436.00 frames. ], batch size: 4
2023-05-12 18:05:15,016 INFO [train.py:422] Epoch 13, batch 10, loss[loss=0.009022, over 2828.00 frames. ], tot_loss[loss=0.009985, over 22192.90 frames. ], batch size: 4
2023-05-12 18:05:15,271 INFO [train.py:444] Epoch 13, validation loss=0.01088, over 18067.00 frames.
2023-05-12 18:05:15,497 INFO [train.py:422] Epoch 13, batch 20, loss[loss=0.01174, over 2695.00 frames. ], tot_loss[loss=0.01077, over 34971.47 frames. ], batch size: 5
2023-05-12 18:05:15,747 INFO [train.py:444] Epoch 13, validation loss=0.01087, over 18067.00 frames.
2023-05-12 18:05:15,783 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-13.pt
2023-05-12 18:05:15,921 INFO [train.py:422] Epoch 14, batch 0, loss[loss=0.01045, over 2436.00 frames. ], tot_loss[loss=0.01045, over 2436.00 frames. ], batch size: 4
2023-05-12 18:05:16,146 INFO [train.py:422] Epoch 14, batch 10, loss[loss=0.008957, over 2828.00 frames. ], tot_loss[loss=0.009903, over 22192.90 frames. ], batch size: 4
2023-05-12 18:05:16,374 INFO [train.py:444] Epoch 14, validation loss=0.01092, over 18067.00 frames.
2023-05-12 18:05:16,598 INFO [train.py:422] Epoch 14, batch 20, loss[loss=0.01169, over 2695.00 frames. ], tot_loss[loss=0.01065, over 34971.47 frames. ], batch size: 5
2023-05-12 18:05:16,824 INFO [train.py:444] Epoch 14, validation loss=0.01077, over 18067.00 frames.
2023-05-12 18:05:16,862 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-14.pt
2023-05-12 18:05:16,865 INFO [train.py:555] Done!
Decoding
~~~~~~~~
@ -465,22 +493,25 @@ The decoding log is:
.. code-block::
2021-08-23 19:35:30,192 INFO [decode.py:249] Decoding started
2021-08-23 19:35:30,192 INFO [decode.py:250] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': True, 'return_cuts': True, 'num_workers': 2}
2021-08-23 19:35:30,193 INFO [lexicon.py:113] Loading pre-compiled data/lang_phone/Linv.pt
2021-08-23 19:35:30,213 INFO [decode.py:259] device: cpu
2021-08-23 19:35:30,217 INFO [decode.py:279] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
/tmp/icefall/icefall/checkpoint.py:146: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch.
It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at /pytorch/aten/src/ATen/native/BinaryOps.cpp:450.)
avg[k] //= n
2021-08-23 19:35:30,220 INFO [asr_datamodule.py:219] About to get test cuts
2021-08-23 19:35:30,220 INFO [asr_datamodule.py:246] About to get test cuts
2021-08-23 19:35:30,409 INFO [decode.py:190] batch 0/8, cuts processed until now is 4
2021-08-23 19:35:30,571 INFO [decode.py:228] The transcripts are stored in tdnn/exp/recogs-test_set.txt
2021-08-23 19:35:30,572 INFO [utils.py:317] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
2021-08-23 19:35:30,573 INFO [decode.py:236] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
2021-08-23 19:35:30,573 INFO [decode.py:299] Done!
2023-05-12 18:08:30,482 INFO [decode.py:263] Decoding started
2023-05-12 18:08:30,483 INFO [decode.py:264] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23,
'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'export': False, 'feature_dir': PosixPath('data/fbank'),
'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True,
'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '3b7f09fa35e72589914f67089c0da9f196a92ca4', 'k2-git-date': 'Mon May 8 22:58:45 2023',
'lhotse-version': '1.15.0.dev+git.6fcfced.clean', 'torch-version': '2.0.0+cu118', 'torch-cuda-available': False, 'torch-cuda-version': '11.8', 'python-version': '3.1', 'icefall-git-branch': 'master',
'icefall-git-sha1': '30bde4b-clean', 'icefall-git-date': 'Thu May 11 17:37:47 2023', 'icefall-path': '/tmp/icefall',
'k2-path': '/tmp/lib/python3.10/site-packages/k2-1.24.3.dev20230512+cuda11.8.torch2.0.0-py3.10-linux-x86_64.egg/k2/__init__.py',
'lhotse-path': '/tmp/lib/python3.10/site-packages/lhotse/__init__.py', 'hostname': 'host', 'IP address': '0.0.0.0'}}
2023-05-12 18:08:30,483 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
2023-05-12 18:08:30,487 INFO [decode.py:273] device: cpu
2023-05-12 18:08:30,513 INFO [decode.py:291] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
2023-05-12 18:08:30,521 INFO [asr_datamodule.py:218] About to get test cuts
2023-05-12 18:08:30,521 INFO [asr_datamodule.py:252] About to get test cuts
2023-05-12 18:08:30,675 INFO [decode.py:204] batch 0/?, cuts processed until now is 4
2023-05-12 18:08:30,923 INFO [decode.py:241] The transcripts are stored in tdnn/exp/recogs-test_set.txt
2023-05-12 18:08:30,924 INFO [utils.py:558] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
2023-05-12 18:08:30,925 INFO [decode.py:249] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
2023-05-12 18:08:30,925 INFO [decode.py:316] Done!
**Congratulations!** You have successfully set up the environment and have run the first recipe in ``icefall``.


@ -3,6 +3,15 @@ Export to ONNX
In this section, we describe how to export models to `ONNX`_.
.. hint::
Before you continue, please run:
.. code-block:: bash
pip install onnx
In each recipe, there is a file called ``export-onnx.py``, which is used
to export trained models to `ONNX`_.
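After exporting, one quick way to inspect the result is to load it with
``onnxruntime`` (installable via ``pip install onnxruntime``). The sketch below
is generic; ``model.onnx`` is only a placeholder for whatever file
``export-onnx.py`` produced in your recipe:

.. code-block:: python

   import onnxruntime as ort

   # Load the exported model and list its input/output tensors.
   session = ort.InferenceSession("model.onnx")  # placeholder file name
   for node in session.get_inputs():
       print("input:", node.name, node.shape, node.type)
   for node in session.get_outputs():
       print("output:", node.name, node.shape, node.type)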


@ -2,6 +2,57 @@
### Aishell training result (Stateless Transducer)
#### Pruned transducer stateless 7 (zipformer)
See <https://github.com/k2-fsa/icefall/pull/986>
[./pruned_transducer_stateless7_bbpe](./pruned_transducer_stateless7_bbpe)
**Note**: The modeling units are byte-level BPEs.
The best results I have gotten are:
Vocab size | Greedy search (dev & test) | Modified beam search (dev & test) | Fast beam search (dev & test) | Fast beam search LG (dev & test) | comments
-- | -- | -- | -- | -- | --
500 | 4.31 & 4.59 | 4.25 & 4.54 | 4.27 & 4.55 | 4.07 & 4.38 | --epoch 48 --avg 29
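To make the byte-level BPE note above concrete: the transcript is split by CJK character and re-encoded as printable byte symbols before the BPE model ever sees it, using the same helpers the scripts in this recipe import (`byte_encode`, `smart_byte_decode`, and `tokenize_by_CJK_char` from `icefall`). A rough sketch (the sample sentence is arbitrary):
```
from icefall import byte_encode, smart_byte_decode, tokenize_by_CJK_char

text = "你好 icefall"                 # arbitrary example sentence
chars = tokenize_by_CJK_char(text)    # insert spaces around CJK characters
encoded = byte_encode(chars)          # re-encode UTF-8 bytes as printable symbols
print(encoded)                        # what the byte-level BPE trainer/decoder operates on
print(smart_byte_decode(encoded))     # maps the byte symbols back to readable text
```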
The training command:
```
export CUDA_VISIBLE_DEVICES="4,5,6,7"
./pruned_transducer_stateless7_bbpe/train.py \
--world-size 4 \
--num-epochs 50 \
--start-epoch 1 \
--use-fp16 1 \
--max-duration 800 \
--bpe-model data/lang_bbpe_500/bbpe.model \
--exp-dir pruned_transducer_stateless7_bbpe/exp \
--lr-epochs 6 \
--master-port 12535
```
The decoding command:
```
for m in greedy_search modified_beam_search fast_beam_search fast_beam_search_LG; do
./pruned_transducer_stateless7_bbpe/decode.py \
--epoch 48 \
--avg 29 \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--max-sym-per-frame 1 \
--ngram-lm-scale 0.25 \
--ilme-scale 0.2 \
--bpe-model data/lang_bbpe_500/bbpe.model \
--max-duration 2000 \
--decoding-method $m
done
```
The pretrained model is available at: https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe
#### Pruned transducer stateless 3
See <https://github.com/k2-fsa/icefall/pull/436>
@ -75,7 +126,7 @@ for epoch in 29; do
done
```
We provide the option of shallow fusion with an RNN language model. The pre-trained language model is
available at <https://huggingface.co/marcoyang/icefall-aishell-rnn-lm>. To decode with the language model,
please use the following command:


@ -0,0 +1 @@
../../../librispeech/ASR/local/compile_lg.py


@ -33,6 +33,7 @@ and generates the following files in the directory `lang_dir`:
- tokens.txt
"""
import argparse
import re
from pathlib import Path
from typing import Dict, List
@ -189,8 +190,22 @@ def generate_tokens(text_file: str) -> Dict[str, int]:
return tokens
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain the bpe.model and words.txt
""",
)
return parser.parse_args()
def main():
lang_dir = Path("data/lang_char")
args = get_args()
lang_dir = Path(args.lang_dir)
text_file = lang_dir / "text"
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")


@ -0,0 +1,267 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input `lang_dir`, which should contain::
- lang_dir/bbpe.model,
- lang_dir/words.txt
and generates the following files in the directory `lang_dir`:
- lexicon.txt
- lexicon_disambig.txt
- L.pt
- L_disambig.pt
- tokens.txt
"""
import argparse
from pathlib import Path
from typing import Dict, List, Tuple
import k2
import sentencepiece as spm
import torch
from prepare_lang import (
Lexicon,
add_disambig_symbols,
add_self_loops,
write_lexicon,
write_mapping,
)
from icefall.byte_utils import byte_encode
from icefall.utils import str2bool, tokenize_by_CJK_char
def lexicon_to_fst_no_sil(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format).
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
loop_state = 0 # words enter and leave from here
next_state = 1 # the next un-allocated state, will be incremented as we go
arcs = []
# The blank symbol <blk> is defined in local/train_bpe_model.py
assert token2id["<blk>"] == 0
assert word2id["<eps>"] == 0
eps = 0
for word, pieces in lexicon:
assert len(pieces) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
pieces = [token2id[i] for i in pieces]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, pieces[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last piece of this word
i = len(pieces) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, pieces[i], w, 0])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def generate_lexicon(
model_file: str, words: List[str], oov: str
) -> Tuple[Lexicon, Dict[str, int]]:
"""Generate a lexicon from a BPE model.
Args:
model_file:
Path to a sentencepiece model.
words:
A list of strings representing words.
oov:
The out of vocabulary word in lexicon.
Returns:
Return a tuple with two elements:
- A dict whose keys are words and values are the corresponding
word pieces.
- A dict representing the token symbol, mapping from tokens to IDs.
"""
sp = spm.SentencePieceProcessor()
sp.load(str(model_file))
# Convert word to word piece IDs instead of word piece strings
# to avoid OOV tokens.
encode_words = [byte_encode(tokenize_by_CJK_char(w)) for w in words]
words_pieces_ids: List[List[int]] = sp.encode(encode_words, out_type=int)
# Now convert word piece IDs back to word piece strings.
words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids]
lexicon = []
for word, pieces in zip(words, words_pieces):
lexicon.append((word, pieces))
lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())]))
token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}
return lexicon, token2id
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain the bpe.model and words.txt
""",
)
parser.add_argument(
"--oov",
type=str,
default="<UNK>",
help="The out of vocabulary word in lexicon.",
)
parser.add_argument(
"--debug",
type=str2bool,
default=False,
help="""True for debugging, which will generate
a visualization of the lexicon FST.
Caution: If your lexicon contains hundreds of thousands
of lines, please set it to False!
See "test/test_bpe_lexicon.py" for usage.
""",
)
return parser.parse_args()
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
model_file = lang_dir / "bbpe.model"
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
words = word_sym_table.symbols
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", args.oov, "#0", "<s>", "</s>"]
for w in excluded:
if w in words:
words.remove(w)
lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
next_token_id = max(token_sym_table.values()) + 1
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in token_sym_table
token_sym_table[disambig] = next_token_id
next_token_id += 1
word_sym_table.add("#0")
word_sym_table.add("<s>")
word_sym_table.add("</s>")
write_mapping(lang_dir / "tokens.txt", token_sym_table)
write_lexicon(lang_dir / "lexicon.txt", lexicon)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst_no_sil(
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
)
L_disambig = lexicon_to_fst_no_sil(
lexicon_disambig,
token2id=token_sym_table,
word2id=word_sym_table,
need_self_loops=True,
)
torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if args.debug:
labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
L.labels_sym = labels_sym
L.aux_labels_sym = aux_labels_sym
L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
L_disambig.labels_sym = labels_sym
L_disambig.aux_labels_sym = aux_labels_sym
L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
if __name__ == "__main__":
main()


@ -0,0 +1,113 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# You can install sentencepiece via:
#
# pip install sentencepiece
#
# Due to an issue reported in
# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
#
# Please install a version >=0.1.96
import argparse
import re
import shutil
import tempfile
from pathlib import Path
import sentencepiece as spm
from icefall import byte_encode, tokenize_by_CJK_char
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
The generated bpe.model is saved to this directory.
""",
)
parser.add_argument(
"--transcript",
type=str,
help="Training transcript.",
)
parser.add_argument(
"--vocab-size",
type=int,
help="Vocabulary size for BPE training",
)
return parser.parse_args()
def _convert_to_bchar(in_path: str, out_path: str):
with open(out_path, "w") as f:
for line in open(in_path, "r").readlines():
f.write(byte_encode(tokenize_by_CJK_char(line)) + "\n")
def main():
args = get_args()
vocab_size = args.vocab_size
lang_dir = Path(args.lang_dir)
model_type = "unigram"
model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
character_coverage = 1.0
input_sentence_size = 100000000
user_defined_symbols = ["<blk>", "<sos/eos>"]
unk_id = len(user_defined_symbols)
# Note: unk_id is fixed to 2.
# If you change it, you should also change other
# places that are using it.
temp = tempfile.NamedTemporaryFile()
train_text = temp.name
_convert_to_bchar(args.transcript, train_text)
model_file = Path(model_prefix + ".model")
if not model_file.is_file():
spm.SentencePieceTrainer.train(
input=train_text,
vocab_size=vocab_size,
model_type=model_type,
model_prefix=model_prefix,
input_sentence_size=input_sentence_size,
character_coverage=character_coverage,
user_defined_symbols=user_defined_symbols,
unk_id=unk_id,
bos_id=-1,
eos_id=-1,
)
else:
print(f"{model_file} exists - skipping")
return
shutil.copyfile(model_file, f"{lang_dir}/bbpe.model")
if __name__ == "__main__":
main()


@ -35,6 +35,15 @@ dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
# It will generate data/lang_bbpe_xxx,
# data/lang_bbpe_yyy if the array contains xxx, yyy
vocab_sizes=(
# 2000
# 1000
500
)
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
@ -47,20 +56,6 @@ log() {
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "stage -1: Download LM"
# We assume that you have installed the git-lfs, if not, you could install it
# using: `sudo apt-get install git-lfs && git-lfs install`
git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1)
if [ ! -f $dl_dir/lm/3-gram.unpruned.arpa ]; then
git clone https://huggingface.co/pkufool/aishell_lm $dl_dir/lm
pushd $dl_dir/lm
git lfs pull --include "3-gram.unpruned.arpa"
popd
fi
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: Download data"
@ -134,7 +129,6 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi
lang_phone_dir=data/lang_phone
lang_char_dir=data/lang_char
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare phone based lang"
mkdir -p $lang_phone_dir
@ -183,45 +177,107 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
fi
fi
lang_char_dir=data/lang_char
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare char based lang"
mkdir -p $lang_char_dir
# We reuse words.txt from phone based lexicon
# so that the two can share G.pt later.
cp $lang_phone_dir/words.txt $lang_char_dir
# The transcripts in training set, generated in stage 5
cp $lang_phone_dir/transcript_words.txt $lang_char_dir/transcript_words.txt
cat $dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt |
cut -d " " -f 2- | sed -e 's/[ \t\r\n]*//g' > $lang_char_dir/text
cut -d " " -f 2- > $lang_char_dir/text
(echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \
> $lang_char_dir/words.txt
cat $lang_char_dir/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
| awk '{print $1" "NR+3}' >> $lang_char_dir/words.txt
num_lines=$(< $lang_char_dir/words.txt wc -l)
(echo "#0 $num_lines"; echo "<s> $(($num_lines + 1))"; echo "</s> $(($num_lines + 2))";) \
>> $lang_char_dir/words.txt
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
./local/prepare_char.py
./local/prepare_char.py --lang-dir $lang_char_dir
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Prepare G"
# We assume you have install kaldilm, if not, please install
# it using: pip install kaldilm
log "Stage 7: Prepare Byte BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bbpe_${vocab_size}
mkdir -p $lang_dir
cp $lang_char_dir/words.txt $lang_dir
cp $lang_char_dir/text $lang_dir
if [ ! -f $lang_dir/bbpe.model ]; then
./local/train_bbpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/text
fi
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bbpe.py --lang-dir $lang_dir
fi
done
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Prepare G"
mkdir -p data/lm
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
# Train LM on transcripts
if [ ! -f data/lm/3-gram.unpruned.arpa ]; then
python3 ./shared/make_kn_lm.py \
-ngram-order 3 \
-text $lang_char_dir/transcript_words.txt \
-lm data/lm/3-gram.unpruned.arpa
fi
# We assume you have installed kaldilm; if not, please install
# it using: pip install kaldilm
if [ ! -f data/lm/G_3_gram_char.fst.txt ]; then
# It is used in building HLG
python3 -m kaldilm \
--read-symbol-table="$lang_phone_dir/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
$dl_dir/lm/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_phone.fst.txt
python3 -m kaldilm \
--read-symbol-table="$lang_char_dir/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_char.fst.txt
fi
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Compile HLG"
./local/compile_hlg.py --lang-dir $lang_phone_dir
./local/compile_hlg.py --lang-dir $lang_char_dir
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Compile LG & HLG"
./local/compile_hlg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bbpe_${vocab_size}
./local/compile_hlg.py --lang-dir $lang_dir --lm G_3_gram_char
done
./local/compile_lg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bbpe_${vocab_size}
./local/compile_lg.py --lang-dir $lang_dir --lm G_3_gram_char
done
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Generate LM training data"
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 10: Generate LM training data"
log "Processing char based data"
out_dir=data/lm_training_char
@ -267,8 +323,8 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
fi
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 10: Sort LM training data"
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Sort LM training data"
# Sort LM training data by sentence length in descending order
# for ease of training.
#
@ -295,7 +351,7 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
--out-statistics $out_dir/statistics-test.txt
fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 11: Train RNN LM model"
python ../../../icefall/rnn_lm/train.py \
--start-epoch 0 \


@ -0,0 +1 @@
../pruned_transducer_stateless2/asr_datamodule.py


@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py


@ -0,0 +1,819 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Xiaoyu Yang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless7_bbpe/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./pruned_transducer_stateless7_bbpe/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless7_bbpe/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./pruned_transducer_stateless7_bbpe/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./pruned_transducer_stateless7_bbpe/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless7_bbpe/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless7_bbpe/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--max-duration 600 \
--decoding-method fast_beam_search_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall import (
LmScorer,
NgramLm,
byte_encode,
smart_byte_decode,
tokenize_by_CJK_char,
)
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bbpe_500/bbpe.model",
help="Path to the byte BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bbpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
If you use fast_beam_search_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.25,
help="""
Used only when --decoding_method is fast_beam_search_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--ilme-scale",
type=float,
default=0.2,
help="""
Used only when --decoding_method is fast_beam_search_LG.
It specifies the scale for the internal language model estimation.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search";
if beam search with a beam size of 7 is used, it would be
"beam_7".
- value: It contains the decoding result. `len(value)` equals the
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG. Used
only when --decoding-method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, or fast_beam_search_nbest_oracle.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(smart_byte_decode(hyp).split())
elif params.decoding_method == "fast_beam_search_LG":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
subtract_ilme=True,
ilme_scale=params.ilme_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(smart_byte_decode(hyp).split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
ref_texts = []
for tx in supervisions["text"]:
ref_texts.append(byte_encode(tokenize_by_CJK_char(tx)))
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(ref_texts),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(smart_byte_decode(hyp).split())
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(smart_byte_decode(hyp).split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(smart_byte_decode(hyp).split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(smart_byte_decode(sp.decode(hyp)).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
key += f"_ilme_scale_{params.ilme_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG. Used
only when --decoding-method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, or fast_beam_search_nbest_oracle.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if a beam size of 7 is used.
Its value is a list of tuples. Each tuple contains three elements:
the cut id, the reference transcript, and the predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
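# A hypothetical example of the wer-summary file produced by save_results():
#   settings        WER
#   beam_size_4     4.85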
@torch.no_grad()
def main():
parser = get_parser()
AishellAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
params.suffix += f"-ilme-scale-{params.ilme_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
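    # Illustrative (hypothetical) suffix for the defaults above with
    # fast_beam_search:
    #   epoch-30-avg-9-beam-20.0-max-contexts-8-max-states-64-use-averaged-model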
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bbpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
aishell = AishellAsrDataModule(args)
test_cuts = aishell.test_cuts()
dev_cuts = aishell.valid_cuts()
test_dl = aishell.test_dataloaders(test_cuts)
dev_dl = aishell.test_dataloaders(dev_cuts)
test_sets = ["test", "dev"]
test_dls = [test_dl, dev_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/encoder_interface.py

View File

@ -0,0 +1,320 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
(1) Export to torchscript model using torch.jit.script()
./pruned_transducer_stateless7_bbpe/export.py \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--bpe-model data/lang_bbpe_500/bbpe.model \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("cpu_jit.pt")`.
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
Check
https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
./pruned_transducer_stateless7_bbpe/export.py \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--bpe-model data/lang_bbpe_500/bbpe.model \
--epoch 20 \
--avg 10
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `pruned_transducer_stateless7_bbpe/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/aishell/ASR
./pruned_transducer_stateless7_bbpe/decode.py \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bbpe_500/bbpe.model
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe
# You will find the pre-trained model in icefall_asr_aishell_pruned_transducer_stateless7_bbpe/exp
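A minimal loading sketch for the torchscript export (paths here are
illustrative):

    import torch
    model = torch.jit.load("pruned_transducer_stateless7_bbpe/exp/cpu_jit.pt")
    model.eval()
    # Parameters are on CPU after loading; optionally move them to GPU:
    # model.to("cuda")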
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7_bbpe/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bbpe_500/bbpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named cpu_jit.pt
Check ./jit_pretrained.py for how to use it.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bbpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,274 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, exported by `torch.jit.script()`
and uses them to decode waves.
You can use the following command to get the exported models:
./pruned_transducer_stateless7_bbpe/export.py \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--bpe-model data/lang_bbpe_500/bbpe.model \
--epoch 49 \
--avg 28 \
--jit 1
Usage of this script:
./pruned_transducer_stateless7_bbpe/jit_pretrained.py \
--nn-model-filename ./pruned_transducer_stateless7_bbpe/exp/cpu_jit.pt \
--bpe-model ./data/lang_bbpe_500/bbpe.model \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from icefall import smart_byte_decode
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model-filename",
type=str,
required=True,
help="Path to the torchscript model cpu_jit.pt",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(
model: torch.jit.ScriptModule,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
A 3-D tensor of shape (N, T, C)
encoder_out_lens:
A 1-D tensor of shape (N,).
Returns:
Return the decoded results for each utterance.
"""
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
device = encoder_out.device
blank_id = 0 # hard-code to 0
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
context_size = model.decoder.context_size
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (N, context_size)
decoder_out = model.decoder(
decoder_input,
need_pad=torch.tensor([False]),
).squeeze(1)
offset = 0
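    # Note on the loop below: pack_padded_sequence arranges frames time-major,
    # and batch_size_list[t] is the number of utterances still active at frame
    # t, so only the first `batch_size` hypotheses are advanced at each step.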
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = packed_encoder_out.data[start:end]
# current_encoder_out's shape: (batch_size, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = model.joiner(
current_encoder_out,
decoder_out,
)
# logits' shape is (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=torch.tensor([False]),
)
decoder_out = decoder_out.squeeze(1)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = torch.jit.load(args.nn_model_filename)
model.eval()
model.to(device)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(
x=features,
x_lens=feature_lengths,
)
hyps = greedy_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = smart_byte_decode(sp.decode(hyp))
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/optim.py

View File

@ -0,0 +1,345 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:
./pruned_transducer_stateless7_bbpe/export.py \
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
--bpe-model data/lang_bbpe_500/bbpe.model \
--epoch 48 \
--avg 29
Usage of this script:
(1) greedy search
./pruned_transducer_stateless7_bbpe/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \
--bpe-model ./data/lang_bbpe_500/bbpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) modified beam search
./pruned_transducer_stateless7_bbpe/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \
--bpe-model ./data/lang_bbpe_500/bbpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) fast beam search
./pruned_transducer_stateless7_bbpe/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \
--bpe-model ./data/lang_bbpe_500/bbpe.model \
--method fast_beam_search \
--beam 4 \
/path/to/foo.wav \
/path/to/bar.wav
Note: ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt is generated by
./pruned_transducer_stateless7_bbpe/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall import smart_byte_decode
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bbpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(smart_byte_decode(hyp).split())
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(smart_byte_decode(hyp).split())
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(smart_byte_decode(hyp).split())
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(smart_byte_decode(sp.decode(hyp)).split())
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py

View File

@ -21,7 +21,7 @@ import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import List
from typing import Any, Dict, List, Optional
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
@ -181,7 +181,16 @@ class AishellAsrDataModule:
"with training dataset. ",
)
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
def train_dataloaders(
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
@ -277,6 +286,10 @@ class AishellAsrDataModule:
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
train_dl = DataLoader(
train,
sampler=train_sampler,
@ -325,7 +338,7 @@ class AishellAsrDataModule:
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
logging.info("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats

View File

@ -1,5 +1,76 @@
## Results
### pruned_transducer_stateless7 (zipformer + multidataset (LibriSpeech + GigaSpeech + CommonVoice 13.0))
See <https://github.com/k2-fsa/icefall/pull/1010> for more details.
[pruned_transducer_stateless7](./pruned_transducer_stateless7)
The tensorboard log can be found at
<https://tensorboard.dev/experiment/SwdJoHgZSZWn8ph9aJLb8g/>
You can find a pretrained model, training logs, decoding logs, and decoding
results at:
<https://huggingface.co/yfyeung/icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04>
You can use <https://github.com/k2-fsa/sherpa> to deploy it.
Number of model parameters: 70369391, i.e., 70.37 M
| decoding method | test-clean | test-other | comment |
|----------------------|------------|------------|--------------------|
| greedy_search | 1.91 | 4.06 | --epoch 30 --avg 7 |
| modified_beam_search | 1.90 | 3.99 | --epoch 30 --avg 7 |
| fast_beam_search | 1.90 | 3.98 | --epoch 30 --avg 7 |
The training commands are:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./pruned_transducer_stateless7/train.py \
--world-size 8 \
--num-epochs 30 \
--use-multidataset 1 \
--use-fp16 1 \
--max-duration 750 \
--exp-dir pruned_transducer_stateless7/exp
```
The decoding commands are:
```bash
# greedy_search
./pruned_transducer_stateless7/decode.py \
--epoch 30 \
--avg 7 \
--use-averaged-model 1 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method greedy_search
# modified_beam_search
./pruned_transducer_stateless7/decode.py \
--epoch 30 \
--avg 7 \
--use-averaged-model 1 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
# fast_beam_search
./pruned_transducer_stateless7/decode.py \
--epoch 30 \
--avg 7 \
--use-averaged-model 1 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
```
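To deploy the model with sherpa (linked above), you would normally export it
first. A hypothetical export command, assuming the usual
`pruned_transducer_stateless7/export.py` options:

```bash
./pruned_transducer_stateless7/export.py \
  --exp-dir ./pruned_transducer_stateless7/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --epoch 30 \
  --avg 7 \
  --use-averaged-model 1 \
  --jit 1
```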
### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer + Multi-Dataset)
#### [pruned_transducer_stateless7_streaming_multi](./pruned_transducer_stateless7_streaming_multi)
@ -215,11 +286,12 @@ done
We also support decoding with neural network LMs. After combining with language models, the WERs are
| decoding method | chunk size | test-clean | test-other | comment | decoding mode |
|----------------------|------------|------------|------------|---------------------|----------------------|
| modified beam search | 320ms | 3.11 | 7.93 | --epoch 30 --avg 9 | simulated streaming |
| modified beam search + RNNLM shallow fusion | 320ms | 2.58 | 6.65 | --epoch 30 --avg 9 | simulated streaming |
| modified beam search + RNNLM nbest rescore | 320ms | 2.59 | 6.86 | --epoch 30 --avg 9 | simulated streaming |
| `modified_beam_search` | 320ms | 3.11 | 7.93 | --epoch 30 --avg 9 | simulated streaming |
| `modified_beam_search_lm_shallow_fusion` | 320ms | 2.58 | 6.65 | --epoch 30 --avg 9 | simulated streaming |
| `modified_beam_search_lm_rescore` | 320ms | 2.59 | 6.86 | --epoch 30 --avg 9 | simulated streaming |
| `modified_beam_search_lm_rescore_LODR` | 320ms | 2.52 | 6.73 | --epoch 30 --avg 9 | simulated streaming |
Please use the following command for RNNLM shallow fusion:
Please use the following command for `modified_beam_search_lm_shallow_fusion`:
```bash
for lm_scale in $(seq 0.15 0.01 0.38); do
for beam_size in 4 8 12; do
@ -246,7 +318,7 @@ for lm_scale in $(seq 0.15 0.01 0.38); do
done
```
Please use the following command for RNNLM rescore:
Please use the following command for `modified_beam_search_lm_rescore`:
```bash
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 30 \
@ -268,7 +340,32 @@ Please use the following command for RNNLM rescore:
--lm-vocab-size 500
```
A well-trained RNNLM can be found here: <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>.
Please use the following command for `modified_beam_search_lm_rescore_LODR`:
```bash
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 30 \
--avg 9 \
--use-averaged-model True \
--beam-size 8 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method modified_beam_search_lm_rescore_LODR \
--use-shallow-fusion 0 \
--lm-type rnn \
--lm-exp-dir rnn_lm/exp \
--lm-epoch 99 \
--lm-avg 1 \
--rnn-lm-embedding-dim 2048 \
--rnn-lm-hidden-dim 2048 \
--rnn-lm-num-layers 3 \
--lm-vocab-size 500 \
--tokens-ngram 2 \
--backoff-id 500
```
A well-trained RNNLM can be found here: <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>. The bi-gram used in LODR decoding
can be found here: <https://huggingface.co/marcoyang/librispeech_bigram>.
#### Smaller model

View File

@ -109,10 +109,11 @@ def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
logging.info("Removing disambiguation symbols on LG")
LG.labels[LG.labels >= first_token_disambig_id] = 0
# See https://github.com/k2-fsa/k2/issues/874
# for why we need to set LG.properties to None
LG.__dict__["_properties"] = None
# LG.labels[LG.labels >= first_token_disambig_id] = 0
# see https://github.com/k2-fsa/k2/pull/1140
labels = LG.labels
labels[labels >= first_token_disambig_id] = 0
LG.labels = labels
assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0

View File

@ -45,11 +45,18 @@ def get_args():
help="""Input and output directory.
""",
)
parser.add_argument(
"--lm",
type=str,
default="G_3_gram",
help="""Stem name for LM used in HLG compiling.
""",
)
return parser.parse_args()
def compile_LG(lang_dir: str) -> k2.Fsa:
def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
"""
Args:
lang_dir:
@ -61,15 +68,15 @@ def compile_LG(lang_dir: str) -> k2.Fsa:
lexicon = Lexicon(lang_dir)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
if Path("data/lm/G_3_gram.pt").is_file():
logging.info("Loading pre-compiled G_3_gram")
d = torch.load("data/lm/G_3_gram.pt")
if Path(f"data/lm/{lm}.pt").is_file():
logging.info(f"Loading pre-compiled {lm}")
d = torch.load(f"data/lm/{lm}.pt")
G = k2.Fsa.from_dict(d)
else:
logging.info("Loading G_3_gram.fst.txt")
with open("data/lm/G_3_gram.fst.txt") as f:
logging.info(f"Loading {lm}.fst.txt")
with open(f"data/lm/{lm}.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
torch.save(G.as_dict(), f"data/lm/{lm}.pt")
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
@ -96,10 +103,11 @@ def compile_LG(lang_dir: str) -> k2.Fsa:
logging.info("Removing disambiguation symbols on LG")
LG.labels[LG.labels >= first_token_disambig_id] = 0
# See https://github.com/k2-fsa/k2/issues/874
# for why we need to set LG.properties to None
LG.__dict__["_properties"] = None
# LG.labels[LG.labels >= first_token_disambig_id] = 0
# see https://github.com/k2-fsa/k2/pull/1140
labels = LG.labels
labels[labels >= first_token_disambig_id] = 0
LG.labels = labels
assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
@ -126,7 +134,7 @@ def main():
logging.info(f"Processing {lang_dir}")
LG = compile_LG(lang_dir)
LG = compile_LG(lang_dir, args.lm)
logging.info(f"Saving LG.pt to {lang_dir}")
torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")

View File

@ -41,9 +41,57 @@ import os
import shutil
from pathlib import Path
from lhotse.utils import urlretrieve_progress
from tqdm.auto import tqdm
# This function is copied from lhotse
def tqdm_urlretrieve_hook(t):
"""Wraps tqdm instance.
Don't forget to close() or __exit__()
the tqdm instance once you're done with it (easiest using `with` syntax).
Example
-------
>>> from urllib.request import urlretrieve
>>> with tqdm(...) as t:
... reporthook = tqdm_urlretrieve_hook(t)
... urlretrieve(..., reporthook=reporthook)
Source: https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
"""
last_b = [0]
def update_to(b=1, bsize=1, tsize=None):
"""
b : int, optional
Number of blocks transferred so far [default: 1].
bsize : int, optional
Size of each block (in tqdm units) [default: 1].
tsize : int, optional
Total size (in tqdm units). If [default: None] or -1,
remains unchanged.
"""
if tsize not in (None, -1):
t.total = tsize
displayed = t.update((b - last_b[0]) * bsize)
last_b[0] = b
return displayed
return update_to
# This function is copied from lhotse
def urlretrieve_progress(url, filename=None, data=None, desc=None):
"""
Works exactly like urllib.request.urlretrieve, but attaches a tqdm hook to
display a progress bar of the download.
Use "desc" argument to display a user-readable string that informs what is
being downloaded.
"""
from urllib.request import urlretrieve
with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=desc) as t:
reporthook = tqdm_urlretrieve_hook(t)
return urlretrieve(url=url, filename=filename, reporthook=reporthook, data=data)
def get_args():
parser = argparse.ArgumentParser()

View File

@ -184,7 +184,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L_disambig.pt \
$lang_dir/disambig_L.fst
$lang_dir/L_disambig.fst
fi
fi

View File

@ -198,7 +198,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L_disambig.pt \
$lang_dir/disambig_L.fst
$lang_dir/L_disambig.fst
fi
fi
@ -328,46 +328,3 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
./prepare_common_voice.sh
fi
fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Create multidataset"
split_dir=data/fbank/multidataset_split_${num_splits}
if [ ! -f data/fbank/multidataset_split/.multidataset.done ]; then
mkdir -p $split_dir/multidataset
log "Split LibriSpeech"
if [ ! -f $split_dir/.librispeech_split.done ]; then
lhotse split $num_splits ./data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz $split_dir
touch $split_dir/.librispeech_split.done
fi
if [[ "${multidataset[@]}" =~ "gigaspeech" ]]; then
log "Split GigaSpeech XL"
if [ ! -f $split_dir/.gigaspeech_XL_split.done ]; then
cd $split_dir
ln -sv ../gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz .
cd ../../..
touch $split_dir/.gigaspeech_XL_split.done
fi
fi
if [[ "${multidataset[@]}" =~ "commonvoice" ]]; then
log "Split CommonVoice"
if [ ! -f $split_dir/.cv-en_train_split.done ]; then
lhotse split $num_splits ./data/en/fbank/cv-en_cuts_train.jsonl.gz $split_dir
touch $split_dir/.cv-en_train_split.done
fi
fi
if [ ! -f $split_dir/.multidataset_mix.done ]; then
log "Mix multidataset"
for ((seq=1; seq<=$num_splits; seq++)); do
fseq=$(printf "%04d" $seq)
gunzip -c $split_dir/*.*${fseq}.jsonl.gz | \
shuf | gzip -c > $split_dir/multidataset/multidataset_cuts_train.${fseq}.jsonl.gz
done
touch $split_dir/.multidataset_mix.done
fi
touch data/fbank/multidataset_split/.multidataset.done
fi
fi

View File

@ -47,6 +47,8 @@ def fast_beam_search_one_best(
max_states: int,
max_contexts: int,
temperature: float = 1.0,
subtract_ilme: bool = False,
ilme_scale: float = 0.1,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@ -88,6 +90,8 @@ def fast_beam_search_one_best(
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
subtract_ilme=subtract_ilme,
ilme_scale=ilme_scale,
)
best_path = one_best_decoding(lattice)
@ -428,6 +432,8 @@ def fast_beam_search(
max_states: int,
max_contexts: int,
temperature: float = 1.0,
subtract_ilme: bool = False,
ilme_scale: float = 0.1,
) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1.
@ -498,6 +504,17 @@ def fast_beam_search(
)
logits = logits.squeeze(1).squeeze(1)
log_probs = (logits / temperature).log_softmax(dim=-1)
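# ILME (internal LM estimation): approximate the internal LM of the
# transducer by feeding a zeroed encoder output to the joiner and subtract
# ilme_scale times its log-probs from the regular log-probs.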
if subtract_ilme:
ilme_logits = model.joiner(
torch.zeros_like(
current_encoder_out, device=current_encoder_out.device
).unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
ilme_logits = ilme_logits.squeeze(1).squeeze(1)
ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1)
log_probs -= ilme_scale * ilme_log_probs
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
@ -1244,7 +1261,7 @@ def modified_beam_search_lm_rescore(
# get the best hyp with different lm_scale
for lm_scale in lm_scale_list:
key = f"nnlm_scale_{lm_scale}"
key = f"nnlm_scale_{lm_scale:.2f}"
tot_scores = am_scores.values + lm_scores * lm_scale
ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores)
max_indexes = ragged_tot_scores.argmax().tolist()
@ -1257,6 +1274,222 @@ def modified_beam_search_lm_rescore(
return ans
def modified_beam_search_lm_rescore_LODR(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
LM: LmScorer,
LODR_lm: NgramLm,
sp: spm.SentencePieceProcessor,
lm_scale_list: List[int],
beam: int = 4,
temperature: float = 1.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Rescore the final results with RNNLM and a low-order n-gram LM (LODR) and
return the one with the highest score.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C).
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
beam:
Number of active paths during the beam search.
temperature:
Softmax temperature.
LM:
A neural network language model.
LODR_lm:
A low-order n-gram LM (typically a bi-gram) used for LODR rescoring.
sp:
The BPE model.
lm_scale_list:
A list of LM scales to try during rescoring.
return_timestamps:
Whether to return timestamps.
Returns:
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing
decoded result and corresponding timestamps.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
timestamp=[],
)
)
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
new_timestamp.append(t)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
)
B[i].add(new_hyp)
B = B + finalized_B
# get the am_scores for n-best list
hyps_shape = get_hyps_shape(B)
am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b])
am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device)
# now LM rescore
# prepare input data to LM
candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b]
possible_seqs = k2.RaggedTensor(candidate_seqs)
row_splits = possible_seqs.shape.row_splits(1)
sentence_token_lengths = row_splits[1:] - row_splits[:-1]
possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1)
possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1)
sentence_token_lengths += 1
x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id)
y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id)
x = x.to(device).to(torch.int64)
y = y.to(device).to(torch.int64)
sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64)
lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths)
assert lm_scores.ndim == 2
lm_scores = -1 * lm_scores.sum(dim=1)
# now LODR scores
import math
LODR_scores = []
for seq in candidate_seqs:
tokens = " ".join(sp.id_to_piece(seq))
LODR_scores.append(LODR_lm.score(tokens))
LODR_scores = torch.tensor(LODR_scores).to(device) * math.log(
10
) # arpa scores are 10-based
assert lm_scores.shape == LODR_scores.shape
ans = {}
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
LODR_scale_list = [0.05 * i for i in range(1, 20)]
# get the best hyp with different lm_scale and lodr_scale
for lm_scale in lm_scale_list:
for lodr_scale in LODR_scale_list:
key = f"nnlm_scale_{lm_scale:.2f}_lodr_scale_{lodr_scale:.2f}"
tot_scores = (
am_scores.values / lm_scale + lm_scores - LODR_scores * lodr_scale
)
ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores)
max_indexes = ragged_tot_scores.argmax().tolist()
unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes]
hyps = []
for idx in unsorted_indices:
hyps.append(unsorted_hyps[idx])
ans[key] = hyps
return ans
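The final combination above divides the acoustic score by the neural LM scale, adds the neural LM score, and subtracts the scaled LODR n-gram score. A toy, hedged example with made-up numbers showing how one key/score pair is formed:

```
import math

am_score = -35.2                        # transducer log-probability of one hypothesis
nnlm_score = -28.4                      # neural LM log-probability
lodr_log10 = -12.1                      # ARPA (log10) score from the n-gram LM
lodr_score = lodr_log10 * math.log(10)  # convert to natural log, as done above

lm_scale = 0.5
lodr_scale = 0.25

key = f"nnlm_scale_{lm_scale:.2f}_lodr_scale_{lodr_scale:.2f}"
tot_score = am_score / lm_scale + nnlm_score - lodr_score * lodr_scale
print(key, round(tot_score, 2))
```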
def _deprecated_modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,

View File

@ -0,0 +1,681 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
"""
This script exports a transducer model from PyTorch to ONNX.
We use the pre-trained model from
https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless5_20220729
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless5_20220729
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-epoch-25-avg-5.pt"
cd exp
ln -s pretrained-epoch-25-avg-5.pt epoch-99.pt
popd
2. Export the model to ONNX
./pruned_transducer_stateless5/export-onnx-streaming.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
--exp-dir $repo/exp \
--num-encoder-layers 18 \
--dim-feedforward 2048 \
--nhead 8 \
--encoder-dim 512 \
--decoder-dim 512 \
--joiner-dim 512
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
See ./onnx_pretrained.py and ./onnx_check.py for how to
use the exported ONNX models.
You can find the exported models in
https://huggingface.co/csukuangfj/sherpa-onnx-streaming-conformer-en-2023-05-09
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from conformer import Conformer
from onnxruntime.quantization import QuantType, quantize_dynamic
from decoder import Decoder
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless5/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
class OnnxEncoder(nn.Module):
"""A wrapper for Conformer and the encoder_proj from the joiner"""
def __init__(self, encoder: Conformer, encoder_proj: nn.Linear):
"""
Args:
encoder:
A Conformer encoder.
encoder_proj:
The projection layer for encoder from the joiner.
"""
super().__init__()
self.encoder = encoder
self.encoder_proj = encoder_proj
self.num_encoder_layers = encoder.encoder_layers
self.encoder_dim = encoder.d_model
self.cnn_module_kernel = encoder.cnn_module_kernel
# Note you can tune these values
self.left_context = 64 # after subsampling
self.chunk_size = 16 # after subsampling
self.right_context = 0 # after subsampling
subsampling_factor = 4
self.pad_length = (self.right_context + 2) * subsampling_factor + 3
self.T = (self.chunk_size * subsampling_factor) + self.pad_length
self.decode_chunk_len = self.chunk_size * subsampling_factor
def forward(
self,
x: torch.Tensor,
cached_attn: torch.Tensor,
cached_conv: torch.Tensor,
processed_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Please see the help information of Conformer.forward
Args:
x:
A 3-D tensor of shape (N, self.T, C)
cached_attn:
A 3-D tensor of shape
(num_encoder_layers, self.left_context, N, self.encoder_dim)
cached_conv:
A 3-D tensor of shape
(num_encoder_layers, self.cnn_module_kernel-1, N, self.encoder_dim)
processed_lens:
A 1-D tensor of shape (N,). It contains number of processed frames
after subsampling. Its dtype is torch.int64.
Returns:
Return a tuple containing:
- encoder_out, A 3-D tensor of shape (N, self.chunk_size, joiner_dim)
- new_cached_attn, it has the same shape as cached_attn
- new_cached_conv, it has the same shape as cached_conv
"""
assert x.size(1) == self.T, (x.shape, self.T)
N = x.size(0)
x_lens = torch.full((N,), fill_value=self.T, device=x.device, dtype=torch.int64)
(
encoder_out,
_,
[new_cached_attn, new_cached_conv],
) = self.encoder.streaming_forward(
x,
x_lens,
states=[cached_attn, cached_conv],
processed_lens=processed_lens,
left_context=self.left_context,
right_context=self.right_context,
chunk_size=self.chunk_size,
)
encoder_out = self.encoder_proj(encoder_out)
# Now encoder_out is of shape (N, T, joiner_dim)
return encoder_out, new_cached_attn, new_cached_conv
class OnnxDecoder(nn.Module):
"""A wrapper for Decoder and the decoder_proj from the joiner"""
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
super().__init__()
self.decoder = decoder
self.decoder_proj = decoder_proj
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, context_size).
Returns
Return a 2-D tensor of shape (N, joiner_dim)
"""
need_pad = False
decoder_output = self.decoder(y, need_pad=need_pad)
decoder_output = decoder_output.squeeze(1)
output = self.decoder_proj(decoder_output)
return output
class OnnxJoiner(nn.Module):
"""A wrapper for the joiner"""
def __init__(self, output_linear: nn.Linear):
super().__init__()
self.output_linear = output_linear
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit
def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has two inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
and it has two outputs:
- encoder_out, a tensor of shape (N, T', joiner_dim)
- encoder_out_lens, a tensor of shape (N,)
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
N = 1
x = torch.zeros(N, encoder_model.T, 80, dtype=torch.float32)
cached_attn = torch.zeros(
encoder_model.num_encoder_layers,
encoder_model.left_context,
N,
encoder_model.encoder_dim,
)
cached_conv = torch.zeros(
encoder_model.num_encoder_layers,
encoder_model.cnn_module_kernel - 1,
N,
encoder_model.encoder_dim,
)
processed_lens = torch.zeros((N,), dtype=torch.int64)
torch.onnx.export(
encoder_model,
(x, cached_attn, cached_conv, processed_lens),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "cached_attn", "cached_conv", "processed_lens"],
output_names=["encoder_out", "new_cached_attn", "new_cached_conv"],
dynamic_axes={
"x": {0: "N"},
"cached_attn": {2: "N"},
"cached_conv": {2: "N"},
"processed_lens": {0: "N"},
"encoder_out": {0: "N"},
"new_cached_attn": {2: "N"},
"new_cached_conv": {2: "N"},
},
)
meta_data = {
"model_type": "conformer",
"version": "1",
"model_author": "k2-fsa",
"comment": "stateless5",
"pad_length": str(encoder_model.pad_length),
"decode_chunk_len": str(encoder_model.decode_chunk_len),
"encoder_dim": str(encoder_model.encoder_dim),
"num_encoder_layers": str(encoder_model.num_encoder_layers),
"cnn_module_kernel": str(encoder_model.cnn_module_kernel),
"left_context": str(encoder_model.left_context),
"right_context": str(encoder_model.right_context),
"chunk_size": str(encoder_model.chunk_size),
"T": str(encoder_model.T),
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=encoder_filename, meta_data=meta_data)
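As a quick sanity check, the metadata written above can be read back through onnxruntime, which is also how the streaming pretrained script consumes it. A hedged sketch; the file name is an assumption:

```
import onnxruntime as ort

sess = ort.InferenceSession("encoder-epoch-99-avg-1.onnx")
meta = sess.get_modelmeta().custom_metadata_map
print(meta["decode_chunk_len"], meta["T"], meta["left_context"], meta["chunk_size"])
```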
def export_decoder_model_onnx(
decoder_model: OnnxDecoder,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
torch.onnx.export(
decoder_model,
y,
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
meta_data = {
"context_size": str(context_size),
"vocab_size": str(vocab_size),
}
add_meta_data(filename=decoder_filename, meta_data=meta_data)
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- encoder_out: a tensor of shape (N, joiner_dim)
- decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
"""
joiner_dim = joiner_model.output_linear.weight.shape[1]
logging.info(f"joiner dim: {joiner_dim}")
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
meta_data = {
"joiner_dim": str(joiner_dim),
}
add_meta_data(filename=joiner_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
if not params.causal_convolution:
logging.info("Seting causal_convolution to True for exporting streaming models")
params.causal_convolution = True
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
num_param = sum(p.numel() for p in model.parameters())
logging.info(f"Number of model parameters: {num_param}")
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)
encoder = OnnxEncoder(
encoder=model.encoder,
encoder_proj=model.joiner.encoder_proj,
)
decoder = OnnxDecoder(
decoder=model.decoder,
decoder_proj=model.joiner.decoder_proj,
)
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
opset_version = 13
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
export_encoder_model_onnx(
encoder,
encoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported encoder to {encoder_filename}")
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
export_decoder_model_onnx(
decoder,
decoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported decoder to {decoder_filename}")
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
export_joiner_model_onnx(
joiner,
joiner_filename,
opset_version=opset_version,
)
logging.info(f"Exported joiner to {joiner_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
quantize_dynamic(
model_input=joiner_filename,
model_output=joiner_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,456 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This script loads ONNX models exported by ./export-onnx.py
and uses them to decode waves.
We use the pre-trained model from
https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless5_20220729
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless5_20220729
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-epoch-25-avg-5.pt"
cd exp
ln -s pretrained-epoch-25-avg-5.pt epoch-99.pt
popd
2. Export the model to ONNX
./pruned_transducer_stateless5/export-onnx-streaming.py \
--bpe-model ./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
--exp-dir ./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/exp \
--num-encoder-layers 18 \
--dim-feedforward 2048 \
--nhead 8 \
--encoder-dim 512 \
--decoder-dim 512 \
--joiner-dim 512
It will generate the following 3 files in $repo/exp
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
3. Run this file with the exported ONNX models
./pruned_transducer_stateless5/onnx_pretrained-streaming.py \
--encoder-model-filename ./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/exp/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename ./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/exp/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename ./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/exp/joiner-epoch-99-avg-1.onnx \
--tokens=./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/data/lang_bpe_500/tokens.txt \
./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/test_waves/1221-135766-0001.wav
Note: Even though this script only supports decoding a single file,
the exported ONNX models do support batch processing.
You can find the exported models in
https://huggingface.co/csukuangfj/sherpa-onnx-streaming-conformer-en-2023-05-09
"""
import argparse
import logging
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import onnxruntime as ort
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"sound_file",
type=str,
help="The input sound file to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
class OnnxModel:
def __init__(
self,
encoder_model_filename: str,
decoder_model_filename: str,
joiner_model_filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.init_encoder(encoder_model_filename)
self.init_decoder(decoder_model_filename)
self.init_joiner(joiner_model_filename)
def init_encoder(self, encoder_model_filename: str):
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
)
self.init_encoder_states()
def init_encoder_states(self, batch_size: int = 1):
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
print(encoder_meta)
model_type = encoder_meta["model_type"]
assert model_type == "conformer", model_type
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
T = int(encoder_meta["T"])
pad_length = int(encoder_meta["pad_length"])
encoder_dim = int(encoder_meta["encoder_dim"])
cnn_module_kernel = int(encoder_meta["cnn_module_kernel"])
left_context = int(encoder_meta["left_context"])
num_encoder_layers = int(encoder_meta["num_encoder_layers"])
self.cached_attn = torch.zeros(
num_encoder_layers,
left_context,
batch_size,
encoder_dim,
).numpy()
self.cached_conv = torch.zeros(
num_encoder_layers,
cnn_module_kernel - 1,
batch_size,
encoder_dim,
).numpy()
logging.info(f"decode_chunk_len: {decode_chunk_len}")
logging.info(f"T: {T}")
logging.info(f"pad_length: {pad_length}")
logging.info(f"encoder_dim: {encoder_dim}")
logging.info(f"cnn_module_kernel: {cnn_module_kernel}")
logging.info(f"left_context: {left_context}")
logging.info(f"num_encoder_layers: {num_encoder_layers}")
self.segment = T
self.offset = decode_chunk_len
def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
self.context_size = int(decoder_meta["context_size"])
self.vocab_size = int(decoder_meta["vocab_size"])
logging.info(f"context_size: {self.context_size}")
logging.info(f"vocab_size: {self.vocab_size}")
def init_joiner(self, joiner_model_filename: str):
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
self.joiner_dim = int(joiner_meta["joiner_dim"])
logging.info(f"joiner_dim: {self.joiner_dim}")
def _build_encoder_input_output(
self, x: torch.Tensor, processed_lens: int
) -> Tuple[Dict[str, np.ndarray], List[str]]:
assert x.size(0) == 1
encoder_input = {
"x": x.numpy(),
"cached_attn": self.cached_attn,
"cached_conv": self.cached_conv,
"processed_lens": torch.full(
(1,), fill_value=processed_lens, dtype=torch.int64
).numpy(),
}
encoder_output = ["encoder_out", "new_cached_attn", "new_cached_conv"]
return encoder_input, encoder_output
def _update_states(self, states: List[np.ndarray]):
self.cached_attn = states[0]
self.cached_conv = states[1]
def run_encoder(self, x: torch.Tensor, num_processed_frames: int) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, self.T, C). It only implements N == 1
num_processed_frames:
Number of processed frames before subsampling.
Returns:
Return a 3-D tensor of shape (N, chunk_size, joiner_dim)
"""
# assume subsampling_factor is 4
num_processed_frames = num_processed_frames // 4
encoder_input, encoder_output_names = self._build_encoder_input_output(
x, num_processed_frames
)
out = self.encoder.run(encoder_output_names, encoder_input)
self._update_states(out[1:])
return torch.from_numpy(out[0])
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
"""
Args:
decoder_input:
A 2-D tensor of shape (N, context_size)
Returns:
Return a 2-D tensor of shape (N, joiner_dim)
"""
out = self.decoder.run(
[self.decoder.get_outputs()[0].name],
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
)[0]
return torch.from_numpy(out)
def run_joiner(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
out = self.joiner.run(
[self.joiner.get_outputs()[0].name],
{
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
},
)[0]
return torch.from_numpy(out)
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
def create_streaming_feature_extractor() -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return OnlineFbank(opts)
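A minimal usage sketch for the extractor returned by the helper above; the one-second silent waveform is an assumption, purely for illustration:

```
import torch

fbank = create_streaming_feature_extractor()
samples = torch.zeros(16000, dtype=torch.float32)  # 1 s of silence at 16 kHz
fbank.accept_waveform(sampling_rate=16000, waveform=samples)
print(fbank.num_frames_ready)  # number of 80-dim fbank frames available so far
frame = fbank.get_frame(0)     # a single (1, 80) fbank frame
```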
def greedy_search(
model: OnnxModel,
encoder_out: torch.Tensor,
context_size: int,
decoder_out: Optional[torch.Tensor] = None,
hyp: Optional[List[int]] = None,
) -> Tuple[List[int], torch.Tensor]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
A 3-D tensor of shape (1, T, joiner_dim)
context_size:
The context size of the decoder model.
decoder_out:
Optional. Decoder output of the previous chunk.
hyp:
Decoding results for previous chunks.
Returns:
Return the decoded results so far and the decoder output to reuse for the next chunk.
"""
blank_id = 0
if decoder_out is None:
assert hyp is None, hyp
hyp = [blank_id] * context_size
decoder_input = torch.tensor([hyp], dtype=torch.int64)
decoder_out = model.run_decoder(decoder_input)
else:
assert hyp is not None, hyp
encoder_out = encoder_out.squeeze(0)
T = encoder_out.size(0)
for t in range(T):
cur_encoder_out = encoder_out[t : t + 1]
joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
y = joiner_out.argmax(dim=0).item()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
decoder_out = model.run_decoder(decoder_input)
return hyp, decoder_out
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
model = OnnxModel(
encoder_model_filename=args.encoder_model_filename,
decoder_model_filename=args.decoder_model_filename,
joiner_model_filename=args.joiner_model_filename,
)
sample_rate = 16000
logging.info("Constructing Fbank computer")
online_fbank = create_streaming_feature_extractor()
logging.info(f"Reading sound files: {args.sound_file}")
waves = read_sound_files(
filenames=[args.sound_file],
expected_sample_rate=sample_rate,
)[0]
tail_padding = torch.zeros(int(1.0 * sample_rate), dtype=torch.float32)
wave_samples = torch.cat([waves, tail_padding])
num_processed_frames = 0
segment = model.segment
offset = model.offset
context_size = model.context_size
hyp = None
decoder_out = None
chunk = int(1 * sample_rate) # 1 second
start = 0
while start < wave_samples.numel():
end = min(start + chunk, wave_samples.numel())
samples = wave_samples[start:end]
start += chunk
online_fbank.accept_waveform(
sampling_rate=sample_rate,
waveform=samples,
)
while online_fbank.num_frames_ready - num_processed_frames >= segment:
frames = []
for i in range(segment):
frames.append(online_fbank.get_frame(num_processed_frames + i))
num_processed_frames += offset
frames = torch.cat(frames, dim=0)
frames = frames.unsqueeze(0)
encoder_out = model.run_encoder(frames, num_processed_frames)
hyp, decoder_out = greedy_search(
model,
encoder_out,
context_size,
decoder_out,
hyp,
)
symbol_table = k2.SymbolTable.from_file(args.tokens)
text = ""
for i in hyp[context_size:]:
text += symbol_table[i]
text = text.replace("", " ").strip()
logging.info(args.sound_file)
logging.info(text)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -25,29 +25,53 @@ from lhotse import CutSet, load_manifest_lazy
class MultiDataset:
def __init__(self, manifest_dir: str):
def __init__(self, manifest_dir: str, cv_manifest_dir: str):
"""
Args:
manifest_dir:
It is expected to contain the following files:
- multidataset_split_1998/multidataset/multidataset_cuts_train.*.jsonl.gz
- librispeech_cuts_train-all-shuf.jsonl.gz
- gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz
cv_manifest_dir:
It is expected to contain the following files:
- cv-en_cuts_train.jsonl.gz
"""
self.manifest_dir = Path(manifest_dir)
self.cv_manifest_dir = Path(cv_manifest_dir)
def train_cuts(self) -> CutSet:
logging.info("About to get multidataset train cuts")
filenames = glob.glob(
f"{self.manifest_dir}/multidataset_split_1998/multidataset/multidataset_cuts_train.*.jsonl.gz"
# LibriSpeech
logging.info(f"Loading LibriSpeech in lazy mode")
librispeech_cuts = load_manifest_lazy(
self.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
)
pattern = re.compile(r"multidataset_cuts_train.([0-9]+).jsonl.gz")
# GigaSpeech
filenames = glob.glob(
f"{self.manifest_dir}/gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz"
)
pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
sorted_filenames = [f[1] for f in idx_filenames]
logging.info(f"Loading {len(sorted_filenames)} splits")
logging.info(f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode")
return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames)
gigaspeech_cuts = lhotse.combine(
lhotse.load_manifest_lazy(p) for p in sorted_filenames
)
# CommonVoice
logging.info(f"Loading CommonVoice in lazy mode")
commonvoice_cuts = load_manifest_lazy(
self.cv_manifest_dir / f"cv-en_cuts_train.jsonl.gz"
)
return CutSet.mux(librispeech_cuts, gigaspeech_cuts, commonvoice_cuts)
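CutSet.mux lazily interleaves the three corpora instead of materializing one combined manifest. A hedged sketch of the same call in isolation; the manifest paths and the optional weights are assumptions, not part of this change:

```
from lhotse import CutSet, load_manifest_lazy

librispeech = load_manifest_lazy("data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz")
gigaspeech = load_manifest_lazy("data/fbank/gigaspeech_XL_split_2000/gigaspeech_cuts_XL.1.jsonl.gz")
commonvoice = load_manifest_lazy("data/en/fbank/cv-en_cuts_train.jsonl.gz")

# optional weights control how often each source is sampled while iterating
mixed = CutSet.mux(librispeech, gigaspeech, commonvoice, weights=[0.6, 0.3, 0.1])
for cut in mixed:
    print(cut.id)
    break
```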

View File

@ -60,13 +60,13 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from multidataset import MultiDataset
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from multidataset import MultiDataset
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
@ -1053,10 +1053,12 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
if params.use_multidataset:
multidataset = MultiDataset(params.manifest_dir)
multidataset = MultiDataset(params.manifest_dir, params.cv_manifest_dir)
train_cuts = multidataset.train_cuts()
else:
if params.full_libri:
if params.mini_libri:
train_cuts = librispeech.train_clean_5_cuts()
elif params.full_libri:
train_cuts = librispeech.train_all_shuf_cuts()
else:
train_cuts = librispeech.train_clean_100_cuts()
@ -1108,8 +1110,11 @@ def run(rank, world_size, args):
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
if params.mini_libri:
valid_cuts = librispeech.dev_clean_2_cuts()
else:
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
if not params.use_multidataset and not params.print_diagnostics:

View File

@ -123,10 +123,13 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
modified_beam_search_lm_rescore,
modified_beam_search_lm_rescore_LODR,
modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall import LmScorer, NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@ -134,7 +137,6 @@ from icefall.checkpoint import (
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.lm_wrapper import LmScorer
from icefall.utils import (
AttributeDict,
setup_logger,
@ -336,6 +338,21 @@ def get_parser():
""",
)
parser.add_argument(
"--tokens-ngram",
type=int,
default=2,
help="""The order of the ngram lm.
""",
)
parser.add_argument(
"--backoff-id",
type=int,
default=500,
help="ID of the backoff symbol in the ngram LM",
)
add_model_arguments(parser)
return parser
@ -349,6 +366,8 @@ def decode_one_batch(
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
LM: Optional[LmScorer] = None,
ngram_lm=None,
ngram_lm_scale: float = 0.0,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -483,6 +502,18 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_LODR":
hyp_tokens = modified_beam_search_LODR(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale,
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_lm_rescore":
lm_scale_list = [0.01 * i for i in range(10, 50)]
ans_dict = modified_beam_search_lm_rescore(
@ -493,6 +524,18 @@ def decode_one_batch(
LM=LM,
lm_scale_list=lm_scale_list,
)
elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
lm_scale_list = [0.02 * i for i in range(2, 30)]
ans_dict = modified_beam_search_lm_rescore_LODR(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LM=LM,
LODR_lm=ngram_lm,
sp=sp,
lm_scale_list=lm_scale_list,
)
else:
batch_size = encoder_out.size(0)
@ -531,7 +574,10 @@ def decode_one_batch(
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
elif params.decoding_method == "modified_beam_search_lm_rescore":
elif params.decoding_method in (
"modified_beam_search_lm_rescore",
"modified_beam_search_lm_rescore_LODR",
):
ans = dict()
assert ans_dict is not None
for key, hyps in ans_dict.items():
@ -550,6 +596,8 @@ def decode_dataset(
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
LM: Optional[LmScorer] = None,
ngram_lm=None,
ngram_lm_scale: float = 0.0,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@ -568,6 +616,8 @@ def decode_dataset(
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
ngram_lm:
A n-gram LM to be used for LODR.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -600,6 +650,8 @@ def decode_dataset(
word_table=word_table,
batch=batch,
LM=LM,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
)
for name, hyps in hyps_dict.items():
@ -677,8 +729,10 @@ def main():
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
"modified_beam_search_LODR",
"modified_beam_search_lm_shallow_fusion",
"modified_beam_search_lm_rescore",
"modified_beam_search_lm_rescore_LODR",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -822,7 +876,12 @@ def main():
model.eval()
# only load the neural network LM if required
if params.use_shallow_fusion or "lm" in params.decoding_method:
if params.use_shallow_fusion or params.decoding_method in (
"modified_beam_search_lm_rescore",
"modified_beam_search_lm_rescore_LODR",
"modified_beam_search_lm_shallow_fusion",
"modified_beam_search_LODR",
):
LM = LmScorer(
lm_type=params.lm_type,
params=params,
@ -834,6 +893,35 @@ def main():
else:
LM = None
# only load N-gram LM when needed
if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
try:
import kenlm
except ImportError:
print("Please install kenlm first. You can use")
print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
print("to install it")
import sys
sys.exit(-1)
ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
logging.info(f"lm filename: {ngram_file_name}")
ngram_lm = kenlm.Model(ngram_file_name)
elif params.decoding_method == "modified_beam_search_LODR":
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
logging.info(f"Loading token level lm: {lm_filename}")
ngram_lm = NgramLm(
str(params.lang_dir / lm_filename),
backoff_id=params.backoff_id,
is_binary=False,
)
logging.info(f"num states: {ngram_lm.lm.num_states}")
ngram_lm_scale = params.ngram_lm_scale
else:
ngram_lm = None
ngram_lm_scale = None
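For reference, a hedged sketch of how the kenlm model loaded above scores a token sequence; the path and the BPE pieces are assumptions. ARPA scores are log10, which is why the rescoring code multiplies by ln 10:

```
import math
import kenlm

lm = kenlm.Model("data/lang_bpe_500/2gram.arpa")  # assumed path
pieces = "▁THE ▁QUICK ▁BROWN ▁FOX"                # space-separated BPE pieces
log10_score = lm.score(pieces)
natural_log_score = log10_score * math.log(10)
print(natural_log_score)
```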
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
@ -866,8 +954,10 @@ def main():
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
import time
for test_set, test_dl in zip(test_sets, test_dl):
start = time.time()
results_dict = decode_dataset(
dl=test_dl,
params=params,
@ -876,7 +966,10 @@ def main():
word_table=word_table,
decoding_graph=decoding_graph,
LM=LM,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
)
logging.info(f"Elasped time for {test_set}: {time.time() - start}")
save_results(
params=params,

View File

@ -1049,10 +1049,13 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
if params.mini_libri:
train_cuts = librispeech.train_clean_5_cuts()
else:
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
@ -1104,8 +1107,11 @@ def run(rank, world_size, args):
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
if params.mini_libri:
valid_cuts = librispeech.dev_clean_2_cuts()
else:
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
# if not params.print_diagnostics:

View File

@ -86,8 +86,16 @@ class LibriSpeechAsrDataModule:
"--full-libri",
type=str2bool,
default=True,
help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
help="""Used only when --mini-libri is False.When enabled,
use 960h LibriSpeech. Otherwise, use 100h subset.""",
)
group.add_argument(
"--mini-libri",
type=str2bool,
default=False,
help="True for mini librispeech",
)
group.add_argument(
"--manifest-dir",
type=Path,
@ -393,6 +401,13 @@ class LibriSpeechAsrDataModule:
)
return test_dl
@lru_cache()
def train_clean_5_cuts(self) -> CutSet:
logging.info("mini_librispeech: About to get train-clean-5 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
)
@lru_cache()
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
@ -424,6 +439,13 @@ class LibriSpeechAsrDataModule:
self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
)
@lru_cache()
def dev_clean_2_cuts(self) -> CutSet:
logging.info("mini_librispeech: About to get dev-clean-2 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz"
)
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")

6
egs/timit/ASR/local/compile_hlg.py Normal file → Executable file
View File

@ -100,7 +100,11 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
logging.info("Removing disambiguation symbols on LG")
LG.labels[LG.labels >= first_token_disambig_id] = 0
# LG.labels[LG.labels >= first_token_disambig_id] = 0
# see https://github.com/k2-fsa/k2/pull/1140
labels = LG.labels
labels[labels >= first_token_disambig_id] = 0
LG.labels = labels
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
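The pattern above (read the labels tensor, modify it, write it back through the property) is used because modifying LG.labels in place may not propagate correctly; assigning through the setter also refreshes the FSA's cached properties (see the linked k2 PR). A minimal hedged sketch on a toy FSA; the FSA string and the cutoff id are assumptions:

```
import k2

s = """
0 1 3 0.5
1 2 -1 0.0
2
"""
fsa = k2.Fsa.from_str(s)

labels = fsa.labels
labels[labels >= 3] = 0  # pretend 3 is the first disambiguation id
fsa.labels = labels
print(fsa.labels)        # the label that was >= 3 is now 0
```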

0
egs/timit/ASR/local/compute_fbank_timit.py Normal file → Executable file
View File

0
egs/timit/ASR/local/prepare_lang.py Normal file → Executable file
View File

0
egs/timit/ASR/local/prepare_lexicon.py Normal file → Executable file
View File

View File

@ -59,7 +59,9 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
# using: `sudo apt-get install git-lfs && git-lfs install`
[ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm
git clone https://huggingface.co/luomingshuang/timit_lm $dl_dir/lm
cd $dl_dir/lm && git lfs pull
pushd $dl_dir/lm
git lfs pull
popd
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then

View File

@ -78,10 +78,11 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
logging.info("Removing disambiguation symbols on LG")
LG.labels[LG.labels >= first_token_disambig_id] = 0
# See https://github.com/k2-fsa/k2/issues/874
# for why we need to set LG.properties to None
LG.__dict__["_properties"] = None
# LG.labels[LG.labels >= first_token_disambig_id] = 0
# see https://github.com/k2-fsa/k2/pull/1140
labels = LG.labels
labels[labels >= first_token_disambig_id] = 0
LG.labels = labels
assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0

View File

@ -8,6 +8,12 @@ from . import (
utils
)
from .byte_utils import (
byte_decode,
byte_encode,
smart_byte_decode,
)
from .checkpoint import (
average_checkpoints,
find_checkpoints,
@ -49,6 +55,7 @@ from .utils import (
get_alignments,
get_executor,
get_texts,
is_cjk,
is_jit_tracing,
is_module_available,
l1_norm,
@ -64,6 +71,7 @@ from .utils import (
store_transcripts,
str2bool,
subsequent_chunk_mask,
tokenize_by_CJK_char,
write_error_stats,
)

311
icefall/byte_utils.py Normal file
View File

@ -0,0 +1,311 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# This file was copied and modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/encoders/byte_utils.py
import re
import unicodedata
WHITESPACE_NORMALIZER = re.compile(r"\s+")
SPACE = chr(32)
SPACE_ESCAPE = chr(9601)
PRINTABLE_BASE_CHARS = [
256,
257,
258,
259,
260,
261,
262,
263,
264,
265,
266,
267,
268,
269,
270,
271,
272,
273,
274,
275,
276,
277,
278,
279,
280,
281,
282,
283,
284,
285,
286,
287,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
51,
52,
53,
54,
55,
56,
57,
58,
59,
60,
61,
62,
63,
64,
65,
66,
67,
68,
69,
70,
71,
72,
73,
74,
75,
76,
77,
78,
79,
80,
81,
82,
83,
84,
85,
86,
87,
88,
89,
90,
91,
92,
93,
94,
95,
96,
97,
98,
99,
100,
101,
102,
103,
104,
105,
106,
107,
108,
109,
110,
111,
112,
113,
114,
115,
116,
117,
118,
119,
120,
121,
122,
123,
124,
125,
126,
288,
289,
290,
291,
292,
293,
294,
295,
296,
297,
298,
299,
300,
301,
302,
303,
304,
305,
308,
309,
310,
311,
312,
313,
314,
315,
316,
317,
318,
321,
322,
323,
324,
325,
326,
327,
328,
330,
331,
332,
333,
334,
335,
336,
337,
338,
339,
340,
341,
342,
343,
344,
345,
346,
347,
348,
349,
350,
351,
352,
353,
354,
355,
356,
357,
358,
359,
360,
361,
362,
363,
364,
365,
366,
367,
368,
369,
370,
371,
372,
373,
374,
375,
376,
377,
378,
379,
380,
381,
382,
384,
385,
386,
387,
388,
389,
390,
391,
392,
393,
394,
395,
396,
397,
398,
399,
400,
401,
402,
403,
404,
405,
406,
407,
408,
409,
410,
411,
412,
413,
414,
415,
416,
417,
418,
419,
420,
421,
422,
]
for c in PRINTABLE_BASE_CHARS:
assert unicodedata.normalize("NFKC", chr(c)) == chr(c), c
BYTE_TO_BCHAR = {b: chr(PRINTABLE_BASE_CHARS[b]) for b in range(256)}
BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()}
def byte_encode(x: str) -> str:
normalized = WHITESPACE_NORMALIZER.sub(SPACE, x)
return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")])
def byte_decode(x: str) -> str:
try:
return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8")
except ValueError:
return ""
def smart_byte_decode(x: str) -> str:
output = byte_decode(x)
if output == "":
# DP the best recovery (max valid chars) if it's broken
n_bytes = len(x)
f = [0 for _ in range(n_bytes + 1)]
pt = [0 for _ in range(n_bytes + 1)]
for i in range(1, n_bytes + 1):
f[i], pt[i] = f[i - 1], i - 1
for j in range(1, min(4, i) + 1):
if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0:
f[i], pt[i] = f[i - j] + 1, i - j
cur_pt = n_bytes
while cur_pt > 0:
if f[cur_pt] == f[pt[cur_pt]] + 1:
output = byte_decode(x[pt[cur_pt] : cur_pt]) + output
cur_pt = pt[cur_pt]
return output
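A hedged usage example for the helpers above (assuming icefall is importable; the sample string is arbitrary): byte_encode maps each UTF-8 byte to a printable character, byte_decode inverts it, and smart_byte_decode recovers as much as possible from a truncated encoding.

```
from icefall.byte_utils import byte_decode, byte_encode, smart_byte_decode

s = "hello 世界"
enc = byte_encode(s)
assert byte_decode(enc) == s
# drop the last byte-character so the final multi-byte character is broken
print(smart_byte_decode(enc[:-1]))  # -> "hello 世"
```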

1
icefall/rnn_lm/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
icefall-librispeech-rnn-lm

View File

@ -0,0 +1,129 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation
"""
Usage:
./check-onnx-streaming.py \
--jit ./icefall-librispeech-rnn-lm/exp/cpu_jit.pt \
--onnx ./icefall-librispeech-rnn-lm/exp/with-state-epoch-99-avg-1.onnx
Note: You can download pre-trained models from
https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
"""
import argparse
import logging
from typing import Tuple
import onnxruntime as ort
import torch
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--jit",
required=True,
type=str,
help="Path to the torchscript model",
)
parser.add_argument(
"--onnx",
required=True,
type=str,
help="Path to the onnx model",
)
return parser
class OnnxModel:
def __init__(self, filename: str):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.model = ort.InferenceSession(
filename,
sess_options=session_opts,
)
meta_data = self.model.get_modelmeta().custom_metadata_map
self.sos_id = int(meta_data["sos_id"])
self.eos_id = int(meta_data["eos_id"])
self.vocab_size = int(meta_data["vocab_size"])
self.num_layers = int(meta_data["num_layers"])
self.hidden_size = int(meta_data["hidden_size"])
print(meta_data)
def __call__(
self, x: torch.Tensor, y: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
out = self.model.run(
[
self.model.get_outputs()[0].name,
self.model.get_outputs()[1].name,
self.model.get_outputs()[2].name,
],
{
self.model.get_inputs()[0].name: x.numpy(),
self.model.get_inputs()[1].name: y.numpy(),
self.model.get_inputs()[2].name: h0.numpy(),
self.model.get_inputs()[3].name: c0.numpy(),
},
)
return (
torch.from_numpy(out[0]),
torch.from_numpy(out[1]),
torch.from_numpy(out[2]),
)
@torch.no_grad()
def main():
args = get_parser().parse_args()
logging.info(vars(args))
torch_model = torch.jit.load(args.jit).cpu()
onnx_model = OnnxModel(args.onnx)
N = torch.arange(1, 5).tolist()
num_layers = onnx_model.num_layers
hidden_size = onnx_model.hidden_size
for n in N:
L = torch.randint(low=1, high=100, size=(1,)).item()
x = torch.randint(
low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
)
h0 = torch.rand(num_layers, n, hidden_size)
c0 = torch.rand(num_layers, n, hidden_size)
torch_log_prob, torch_h0, torch_c0 = torch_model.score_token_onnx(x, h0, c0)
onnx_log_prob, onnx_h0, onnx_c0 = onnx_model(x, h0, c0)
for torch_v, onnx_v in zip(
(torch_log_prob, torch_h0, torch_c0), (onnx_log_prob, onnx_h0, onnx_c0)
):
assert torch.allclose(torch_v, onnx_v, atol=1e-5), (
torch_v.shape,
onnx_v.shape,
(torch_v - onnx_v).abs().max(),
)
print(n, L, torch_v.sum(), onnx_v.sum())
if __name__ == "__main__":
torch.manual_seed(20230423)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

119
icefall/rnn_lm/check-onnx.py Executable file
View File

@ -0,0 +1,119 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation
"""
Usage:
./check-onnx.py \
--jit ./icefall-librispeech-rnn-lm/exp/cpu_jit.pt \
--onnx ./icefall-librispeech-rnn-lm/exp/no-state-epoch-99-avg-1.onnx
Note: You can download pre-trained models from
https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
"""
import argparse
import logging
import onnxruntime as ort
import torch
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--jit",
required=True,
type=str,
help="Path to the torchscript model",
)
parser.add_argument(
"--onnx",
required=True,
type=str,
help="Path to the onnx model",
)
return parser
class OnnxModel:
def __init__(self, filename: str):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.model = ort.InferenceSession(
filename,
sess_options=session_opts,
)
meta_data = self.model.get_modelmeta().custom_metadata_map
self.sos_id = int(meta_data["sos_id"])
self.eos_id = int(meta_data["eos_id"])
self.vocab_size = int(meta_data["vocab_size"])
print(meta_data)
def __call__(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
out = self.model.run(
[
self.model.get_outputs()[0].name,
],
{
self.model.get_inputs()[0].name: x.numpy(),
self.model.get_inputs()[1].name: x_lens.numpy(),
},
)
return torch.from_numpy(out[0])
@torch.no_grad()
def main():
args = get_parser().parse_args()
logging.info(vars(args))
torch_model = torch.jit.load(args.jit).cpu()
onnx_model = OnnxModel(args.onnx)
N = torch.arange(1, 5).tolist()
for n in N:
L = torch.randint(low=1, high=100, size=(1,)).item()
x = torch.randint(
low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
)
x_lens = torch.full((n,), fill_value=L, dtype=torch.int64)
if n > 1:
x_lens[0] = L // 2 + 1
sos = torch.full((1,), fill_value=onnx_model.sos_id).expand(n, 1)
sos_x = torch.cat([sos, x], dim=1)
pad_col = torch.zeros((1,), dtype=x.dtype).expand(n, 1)
x_eos = torch.cat([x, pad_col], dim=1)
row_index = torch.arange(0, n, dtype=x.dtype)
x_eos[row_index, x_lens] = onnx_model.eos_id
torch_nll = torch_model(sos_x, x_eos, x_lens + 1).sum(dim=-1)
onnx_nll = onnx_model(x, x_lens)
# Note: For int8 models, the differences may be quite large,
# e.g., within 0.9
assert torch.allclose(torch_nll, onnx_nll), (
torch_nll,
onnx_nll,
)
print(n, L, torch_nll, onnx_nll)
if __name__ == "__main__":
torch.manual_seed(20230420)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

395
icefall/rnn_lm/export-onnx.py Executable file
View File

@ -0,0 +1,395 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation
import argparse
import logging
from pathlib import Path
import onnx
import torch
from model import RnnLmModel
from onnxruntime.quantization import QuantType, quantize_dynamic
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import AttributeDict, str2bool
from typing import Dict
from train import get_params
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
# A wrapper for the RnnLm model to simplify the C++ calling code
# when exporting the model to ONNX.
#
# TODO(fangjun): The current wrapper works only for non-streaming ASR
# since we don't expose the LM state and it is used to score
# a complete sentence at once.
class RnnLmModelWrapper(torch.nn.Module):
def __init__(self, model: RnnLmModel, sos_id: int, eos_id: int):
super().__init__()
self.model = model
self.sos_id = sos_id
self.eos_id = eos_id
def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
A 2-D tensor of shape (N, L) with dtype torch.int64.
It does not contain SOS or EOS. We will add SOS and EOS inside
this function.
x_lens:
A 1-D tensor of shape (N,) with dtype torch.int64. It contains
number of valid tokens in ``x`` before padding.
Returns:
Return a 1-D tensor of shape (N,) containing the negative log-likelihood.
Its dtype is torch.float32.
"""
N = x.size(0)
sos_tensor = torch.full((1,), fill_value=self.sos_id, dtype=x.dtype).expand(
N, 1
)
sos_x = torch.cat([sos_tensor, x], dim=1)
pad_col = torch.zeros((1,), dtype=x.dtype).expand(N, 1)
x_eos = torch.cat([x, pad_col], dim=1)
row_index = torch.arange(0, N, dtype=x.dtype)
x_eos[row_index, x_lens] = self.eos_id
# use x_lens + 1 here since we prepended x with sos
return (
self.model(x=sos_x, y=x_eos, lengths=x_lens + 1)
.to(torch.float32)
.sum(dim=1)
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=29,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=5,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--vocab-size",
type=int,
default=500,
help="Vocabulary size of the model",
)
parser.add_argument(
"--embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--num-layers",
type=int,
default=3,
help="Number of RNN layers the model",
)
parser.add_argument(
"--tie-weights",
type=str2bool,
default=True,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="rnn_lm/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, logs, etc., are saved
""",
)
return parser
def export_without_state(
model: RnnLmModel,
filename: str,
params: AttributeDict,
opset_version: int,
):
model_wrapper = RnnLmModelWrapper(
model,
sos_id=params.sos_id,
eos_id=params.eos_id,
)
N = 1
L = 20
x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
x_lens = torch.full((N,), fill_value=L, dtype=torch.int64)
# Note(fangjun): The following warnings can be ignored.
# We can use ./check-onnx.py to validate the exported model with batch_size > 1
"""
torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX
with a batch_size other than 1, with a variable length with LSTM can cause
an error when running the ONNX model with a different batch size. Make sure
to save the model with a batch size of 1, or define the initial states
(h0/c0) as inputs of the model. warnings.warn("Exporting a model to ONNX
with a batch_size other than 1, " +
"""
torch.onnx.export(
model_wrapper,
(x, x_lens),
filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["nll"],
dynamic_axes={
"x": {0: "N", 1: "L"},
"x_lens": {0: "N"},
"nll": {0: "N"},
},
)
meta_data = {
"model_type": "rnnlm",
"version": "1",
"model_author": "k2-fsa",
"comment": "rnnlm without state",
"sos_id": str(params.sos_id),
"eos_id": str(params.eos_id),
"vocab_size": str(params.vocab_size),
"url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm",
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=filename, meta_data=meta_data)
def export_with_state(
model: RnnLmModel,
filename: str,
params: AttributeDict,
opset_version: int,
):
N = 1
L = 20
num_layers = model.rnn.num_layers
hidden_size = model.rnn.hidden_size
embedding_dim = model.embedding_dim
x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
h0 = torch.zeros(num_layers, N, hidden_size)
c0 = torch.zeros(num_layers, N, hidden_size)
# Note(fangjun): The following warnings can be ignored.
# We can use ./check-onnx.py to validate the exported model with batch_size > 1
"""
torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX
with a batch_size other than 1, with a variable length with LSTM can cause
an error when running the ONNX model with a different batch size. Make sure
to save the model with a batch size of 1, or define the initial states
(h0/c0) as inputs of the model. warnings.warn("Exporting a model to ONNX
with a batch_size other than 1, " +
"""
torch.onnx.export(
model,
(x, h0, c0),
filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "h0", "c0"],
output_names=["log_softmax", "next_h0", "next_c0"],
dynamic_axes={
"x": {0: "N", 1: "L"},
"h0": {1: "N"},
"c0": {1: "N"},
"log_softmax": {0: "N"},
"next_h0": {1: "N"},
"next_c0": {1: "N"},
},
)
meta_data = {
"model_type": "rnnlm",
"version": "1",
"model_author": "k2-fsa",
"comment": "rnnlm state",
"sos_id": str(params.sos_id),
"eos_id": str(params.eos_id),
"vocab_size": str(params.vocab_size),
"num_layers": str(num_layers),
"hidden_size": str(hidden_size),
"embedding_dim": str(embedding_dim),
"url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm",
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
logging.info(params)
device = torch.device("cpu")
logging.info(f"device: {device}")
model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
hidden_dim=params.hidden_dim,
num_layers=params.num_layers,
tie_weights=params.tie_weights,
)
model.to(device)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
model.to("cpu")
model.eval()
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
opset_version = 13
logging.info("Exporting model without state")
filename = params.exp_dir / f"no-state-{suffix}.onnx"
export_without_state(
model=model,
filename=filename,
params=params,
opset_version=opset_version,
)
filename_int8 = params.exp_dir / f"no-state-{suffix}.int8.onnx"
quantize_dynamic(
model_input=filename,
model_output=filename_int8,
weight_type=QuantType.QInt8,
)
# now for streaming export
saved_forward = model.__class__.forward
model.__class__.forward = model.__class__.score_token_onnx
streaming_filename = params.exp_dir / f"with-state-{suffix}.onnx"
export_with_state(
model=model,
filename=streaming_filename,
params=params,
opset_version=opset_version,
)
model.__class__.forward = saved_forward
streaming_filename_int8 = params.exp_dir / f"with-state-{suffix}.int8.onnx"
quantize_dynamic(
model_input=streaming_filename,
model_output=streaming_filename_int8,
weight_type=QuantType.QInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

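For reference, a minimal inference sketch for the exported no-state model with onnxruntime could look like the following. The input names x and x_lens, the output name nll, and the metadata keys match the export code above; the file path assumes the pretrained model exported by export-onnx.sh below and is only an example.

# Sketch: score full sentences with the no-state ONNX export (path is an example).
import onnxruntime as ort
import torch

session = ort.InferenceSession(
    "./icefall-librispeech-rnn-lm/exp/no-state-epoch-99-avg-1.onnx"
)
meta = session.get_modelmeta().custom_metadata_map
vocab_size = int(meta["vocab_size"])

x = torch.randint(low=1, high=vocab_size, size=(2, 8), dtype=torch.int64)
x_lens = torch.tensor([8, 5], dtype=torch.int64)  # valid tokens per utterance

(nll,) = session.run(["nll"], {"x": x.numpy(), "x_lens": x_lens.numpy()})
print(nll)  # shape (2,): total negative log-likelihood per utterance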
icefall/rnn_lm/export-onnx.sh Executable file

@ -0,0 +1,26 @@
#!/usr/bin/env bash
# We use the model from
# https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main
# as an example
export CUDA_VISIBLE_DEVICES=
if [ ! -f ./icefall-librispeech-rnn-lm/exp/pretrained.pt ]; then
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
pushd icefall-librispeech-rnn-lm/exp
git lfs pull --include "pretrained.pt"
ln -s pretrained.pt epoch-99.pt
popd
fi
python3 ./export-onnx.py \
--exp-dir ./icefall-librispeech-rnn-lm/exp \
--epoch 99 \
--avg 1 \
--vocab-size 500 \
--embedding-dim 2048 \
--hidden-dim 2048 \
--num-layers 3 \
--tie-weights 1

icefall/rnn_lm/export.py

@ -26,7 +26,7 @@ import torch
from model import RnnLmModel
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import AttributeDict, load_averaged_model, str2bool
from icefall.utils import AttributeDict, str2bool
def get_parser():
@ -118,6 +118,7 @@ def get_parser():
return parser
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
@ -180,6 +181,10 @@ def main():
if params.jit:
logging.info("Using torch.jit.script")
model.__class__.score_token_onnx = torch.jit.export(
model.__class__.score_token_onnx
)
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))

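Because score_token_onnx is marked with torch.jit.export before scripting, the saved cpu_jit.pt keeps it as a callable method. A sketch of stepping the scripted LM token by token could look like this; the path and the token ids are illustrative, and the layer/hidden sizes follow the values used by export.sh below.

# Sketch: incremental scoring with the torchscript export (illustrative values).
import torch

model = torch.jit.load("./icefall-librispeech-rnn-lm/exp/cpu_jit.pt")
num_layers, hidden_dim, n = 3, 2048, 3  # sizes used by export.sh; n hypotheses

h = torch.zeros(num_layers, n, hidden_dim)
c = torch.zeros(num_layers, n, hidden_dim)
tokens = torch.tensor([[5], [10], [20]])  # one token per hypothesis, shape (n, 1)

log_probs, h, c = model.score_token_onnx(tokens, h, c)
print(log_probs.shape)  # (n, vocab_size): log-probabilities for the next token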
icefall/rnn_lm/export.sh Executable file

@ -0,0 +1,27 @@
#!/usr/bin/env bash
# We use the model from
# https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main
# as an example
export CUDA_VISIBLE_DEVICES=
if [ ! -f ./icefall-librispeech-rnn-lm/exp/pretrained.pt ]; then
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
pushd icefall-librispeech-rnn-lm/exp
git lfs pull --include "pretrained.pt"
ln -s pretrained.pt epoch-99.pt
popd
fi
python3 ./export.py \
--exp-dir ./icefall-librispeech-rnn-lm/exp \
--epoch 99 \
--avg 1 \
--vocab-size 500 \
--embedding-dim 2048 \
--hidden-dim 2048 \
--num-layers 3 \
--tie-weights 1 \
--jit 1

icefall/rnn_lm/model.py

@ -15,6 +15,7 @@
# limitations under the License.
import logging
from typing import Tuple
import torch
import torch.nn.functional as F
@ -47,6 +48,11 @@ class RnnLmModel(torch.nn.Module):
and https://arxiv.org/abs/1611.01462
"""
super().__init__()
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.tie_weights = tie_weights
self.input_embedding = torch.nn.Embedding(
num_embeddings=vocab_size,
@ -74,6 +80,46 @@ class RnnLmModel(torch.nn.Module):
self.cache = {}
def streaming_forward(
self,
x: torch.Tensor,
y: torch.Tensor,
h0: torch.Tensor,
c0: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 2-D tensor of shape (N, L). We won't prepend it with SOS.
y:
A 2-D tensor of shape (N, L). We won't append it with EOS.
h0:
A 3-D tensor of shape (num_layers, N, hidden_size).
(If proj_size > 0, then it is (num_layers, N, proj_size))
c0:
A 3-D tensor of shape (num_layers, N, hidden_size).
Returns:
Return a tuple containing 3 tensors:
- negative log-likelihood (nll), a 1-D tensor of shape (N,)
- next_h0, a 3-D tensor with the same shape as h0
- next_c0, a 3-D tensor with the same shape as c0
"""
assert x.ndim == y.ndim == 2, (x.ndim, y.ndim)
assert x.shape == y.shape, (x.shape, y.shape)
# embedding is of shape (N, L, embedding_dim)
embedding = self.input_embedding(x)
# Note: We use batch_first==True
rnn_out, (next_h0, next_c0) = self.rnn(embedding, (h0, c0))
logits = self.output_linear(rnn_out)
nll_loss = F.cross_entropy(
logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none"
)
batch_size = x.size(0)
nll_loss = nll_loss.reshape(batch_size, -1).sum(dim=1)
return nll_loss, next_h0, next_c0
def forward(
self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor
) -> torch.Tensor:
@ -188,6 +234,33 @@ class RnnLmModel(torch.nn.Module):
return logits[:, 0].log_softmax(-1), states
def score_token_onnx(
self,
x: torch.Tensor,
state_h: torch.Tensor,
state_c: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Score a batch of tokens, i.e each sample in the batch should be a
single token. For example, x = torch.tensor([[5],[10],[20]])
Args:
x (torch.Tensor):
A batch of tokens
state_h:
state h of RNN has the shape of (num_layers, bs, hidden_dim)
state_c:
state c of RNN has the shape of (num_layers, bs, hidden_dim)
Returns:
_type_: _description_
"""
embedding = self.input_embedding(x)
rnn_out, (next_h0, next_c0) = self.rnn(embedding, (state_h, state_c))
logits = self.output_linear(rnn_out)
return logits[:, 0].log_softmax(-1), next_h0, next_c0
def forward_with_state(
self, tokens, token_lens, sos_id, eos_id, blank_id, state=None
):

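Since the with-state ONNX export uses score_token_onnx as its forward, the matching onnxruntime call (inputs x, h0, c0; outputs log_softmax, next_h0, next_c0, as declared in export-onnx.py) can be sketched as follows; the file path is only an example and the sizes are read from the metadata written at export time.

# Sketch: step the with-state ONNX export one token at a time (path is an example).
import onnxruntime as ort
import torch

session = ort.InferenceSession(
    "./icefall-librispeech-rnn-lm/exp/with-state-epoch-99-avg-1.onnx"
)
meta = session.get_modelmeta().custom_metadata_map
num_layers = int(meta["num_layers"])
hidden_size = int(meta["hidden_size"])

n = 2  # two hypotheses scored in parallel
h0 = torch.zeros(num_layers, n, hidden_size)
c0 = torch.zeros(num_layers, n, hidden_size)
x = torch.tensor([[5], [10]], dtype=torch.int64)  # one token per hypothesis

log_softmax, next_h0, next_c0 = session.run(
    ["log_softmax", "next_h0", "next_c0"],
    {"x": x.numpy(), "h0": h0.numpy(), "c0": c0.numpy()},
)
print(log_softmax.shape)  # (n, vocab_size)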
icefall/rnn_lm/train.py

@ -22,6 +22,7 @@ Usage:
--world-size 2 \
--num-epochs 1 \
--use-fp16 0 \
--tie-weights 0 \
--embedding-dim 800 \
--hidden-dim 200 \
--num-layers 2 \

icefall/utils.py

@ -1306,6 +1306,31 @@ def tokenize_by_bpe_model(
return txt_with_bpe
def tokenize_by_CJK_char(line: str) -> str:
"""
Tokenize a line of text by CJK characters.
Note: All returned characters will be upper case.
Example:
input = "你好世界是 hello world 的中文"
output = "你 好 世 界 是 HELLO WORLD 的 中 文"
Args:
line:
The input text.
Return:
A new string tokenized by CJK characters.
"""
# The CJK ranges are from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py
pattern = re.compile(
r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])"
)
chars = pattern.split(line.strip().upper())
return " ".join([w.strip() for w in chars if w.strip()])
def display_and_save_batch(
batch: dict,
params: AttributeDict,
@ -1764,3 +1789,34 @@ def parse_fsa_timestamps_and_texts(
utt_time_pairs.append(list(zip(start, end)))
return utt_time_pairs, utt_words
# Copied from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py
def is_cjk(character):
"""
Python port of Moses' code to check for CJK character.
>>> is_cjk(u'\u33fe')
True
>>> is_cjk(u'\uFE5F')
False
:param character: The character that needs to be checked.
:type character: char
:return: bool
"""
return any(
[
start <= ord(character) <= end
for start, end in [
(4352, 4607),
(11904, 42191),
(43072, 43135),
(44032, 55215),
(63744, 64255),
(65072, 65103),
(65381, 65500),
(131072, 196607),
]
]
)
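A small usage sketch for the two helpers added above, assuming they are importable as icefall.utils.tokenize_by_CJK_char and icefall.utils.is_cjk:

# Sketch: using the CJK helpers added above.
from icefall.utils import is_cjk, tokenize_by_CJK_char

print(tokenize_by_CJK_char("你好世界是 hello world 的中文"))
# -> "你 好 世 界 是 HELLO WORLD 的 中 文"

print(is_cjk("好"))  # True: code point 22909 falls in the (11904, 42191) range
print(is_cjk("a"))   # False: ASCII letters are outside all CJK ranges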