mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-13 19:14:20 +00:00
Merge branch 'k2-fsa:master' into master
This commit is contained in:
commit
f35eed4240
151
README.md
151
README.md
@ -28,14 +28,15 @@ We provide the following recipes:
|
||||
|
||||
- [yesno][yesno]
|
||||
- [LibriSpeech][librispeech]
|
||||
- [GigaSpeech][gigaspeech]
|
||||
- [Aishell][aishell]
|
||||
- [Aishell2][aishell2]
|
||||
- [Aishell4][aishell4]
|
||||
- [TIMIT][timit]
|
||||
- [TED-LIUM3][tedlium3]
|
||||
- [GigaSpeech][gigaspeech]
|
||||
- [Aidatatang_200zh][aidatatang_200zh]
|
||||
- [WenetSpeech][wenetspeech]
|
||||
- [Alimeeting][alimeeting]
|
||||
- [Aishell4][aishell4]
|
||||
- [TAL_CSASR][tal_csasr]
|
||||
|
||||
### yesno
|
||||
@ -46,9 +47,7 @@ Training takes less than 30 seconds and gives you the following WER:
|
||||
```
|
||||
[test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
|
||||
```
|
||||
We do provide a Colab notebook for this recipe.
|
||||
|
||||
[](https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing)
|
||||
We provide a Colab notebook for this recipe: [](https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing)
|
||||
|
||||
|
||||
### LibriSpeech
|
||||
@ -82,7 +81,7 @@ The WER for this model is:
|
||||
|-----|------------|------------|
|
||||
| WER | 6.59 | 17.69 |
|
||||
|
||||
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing)
|
||||
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing)
|
||||
|
||||
|
||||
#### Transducer: Conformer encoder + LSTM decoder
|
||||
@ -118,19 +117,54 @@ We provide a Colab notebook to run a pre-trained transducer conformer + stateles
|
||||
|
||||
| | test-clean | test-other |
|
||||
|-----|------------|------------|
|
||||
| WER | 2.57 | 5.95 |
|
||||
| WER | 2.15 | 5.20 |
|
||||
|
||||
Note: No auxiliary losses are used in the training and no LMs are used
|
||||
in the decoding.
|
||||
|
||||
#### k2 pruned RNN-T + GigaSpeech
|
||||
|
||||
| | test-clean | test-other |
|
||||
|-----|------------|------------|
|
||||
| WER | 2.00 | 4.63 |
|
||||
| WER | 1.78 | 4.08 |
|
||||
|
||||
Note: No auxiliary losses are used in the training and no LMs are used
|
||||
in the decoding.
|
||||
|
||||
#### k2 pruned RNN-T + GigaSpeech + CommonVoice
|
||||
|
||||
| | test-clean | test-other |
|
||||
|-----|------------|------------|
|
||||
| WER | 1.90 | 3.98 |
|
||||
|
||||
Note: No auxiliary losses are used in the training and no LMs are used
|
||||
in the decoding.
|
||||
|
||||
|
||||
### GigaSpeech
|
||||
|
||||
We provide two models for this recipe: [Conformer CTC model][GigaSpeech_conformer_ctc]
|
||||
and [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2].
|
||||
|
||||
#### Conformer CTC
|
||||
|
||||
| | Dev | Test |
|
||||
|-----|-------|-------|
|
||||
| WER | 10.47 | 10.58 |
|
||||
|
||||
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
|
||||
|
||||
| | Dev | Test |
|
||||
|----------------------|-------|-------|
|
||||
| greedy search | 10.51 | 10.73 |
|
||||
| fast beam search | 10.50 | 10.69 |
|
||||
| modified beam search | 10.40 | 10.51 |
|
||||
|
||||
|
||||
### Aishell
|
||||
|
||||
We provide two models for this recipe: [conformer CTC model][Aishell_conformer_ctc]
|
||||
and [TDNN LSTM CTC model][Aishell_tdnn_lstm_ctc].
|
||||
We provide three models for this recipe: [conformer CTC model][Aishell_conformer_ctc],
|
||||
[TDNN LSTM CTC model][Aishell_tdnn_lstm_ctc], and [Transducer Stateless Model][Aishell_pruned_transducer_stateless7],
|
||||
|
||||
#### Conformer CTC Model
|
||||
|
||||
@ -140,20 +174,6 @@ The best CER we currently have is:
|
||||
|-----|------|
|
||||
| CER | 4.26 |
|
||||
|
||||
|
||||
We provide a Colab notebook to run a pre-trained conformer CTC model: [
|
||||
|
||||
#### Transducer Stateless Model
|
||||
|
||||
The best CER we currently have is:
|
||||
|
||||
| | test |
|
||||
|-----|------|
|
||||
| CER | 4.68 |
|
||||
|
||||
|
||||
We provide a Colab notebook to run a pre-trained TransducerStateless model: [](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing)
|
||||
|
||||
#### TDNN LSTM CTC Model
|
||||
|
||||
The CER for this model is:
|
||||
@ -164,6 +184,46 @@ The CER for this model is:
|
||||
|
||||
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing)
|
||||
|
||||
#### Transducer Stateless Model
|
||||
|
||||
The best CER we currently have is:
|
||||
|
||||
| | test |
|
||||
|-----|------|
|
||||
| CER | 4.38 |
|
||||
|
||||
We provide a Colab notebook to run a pre-trained TransducerStateless model: [](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing)
|
||||
|
||||
|
||||
### Aishell2
|
||||
|
||||
We provide one model for this recipe: [Transducer Stateless Model][Aishell2_pruned_transducer_stateless5].
|
||||
|
||||
#### Transducer Stateless Model
|
||||
|
||||
The best WER we currently have is:
|
||||
|
||||
| | dev-ios | test-ios |
|
||||
|-----|------------|------------|
|
||||
| WER | 5.32 | 5.56 |
|
||||
|
||||
|
||||
### Aishell4
|
||||
|
||||
We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5].
|
||||
|
||||
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets)
|
||||
|
||||
The best CER we currently have is:
|
||||
|
||||
| | test |
|
||||
|-----|------------|
|
||||
| CER | 29.08 |
|
||||
|
||||
|
||||
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
|
||||
|
||||
|
||||
### TIMIT
|
||||
|
||||
We provide two models for this recipe: [TDNN LSTM CTC model][TIMIT_tdnn_lstm_ctc]
|
||||
@ -187,7 +247,8 @@ The PER for this model is:
|
||||
|--|--|
|
||||
|PER| 17.66% |
|
||||
|
||||
We provide a Colab notebook to run a pre-trained TDNN LiGRU CTC model: [](https://colab.research.google.com/drive/11IT-k4HQIgQngXz1uvWsEYktjqQt7Tmb?usp=sharing)
|
||||
We provide a Colab notebook to run a pre-trained TDNN LiGRU CTC model: [](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
|
||||
|
||||
|
||||
### TED-LIUM3
|
||||
|
||||
@ -215,24 +276,6 @@ The best WER using modified beam search with beam size 4 is:
|
||||
|
||||
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1je_1zGrOkGVVd4WLzgkXRHxl-I27yWtz?usp=sharing)
|
||||
|
||||
### GigaSpeech
|
||||
|
||||
We provide two models for this recipe: [Conformer CTC model][GigaSpeech_conformer_ctc]
|
||||
and [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2].
|
||||
|
||||
#### Conformer CTC
|
||||
|
||||
| | Dev | Test |
|
||||
|-----|-------|-------|
|
||||
| WER | 10.47 | 10.58 |
|
||||
|
||||
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
|
||||
|
||||
| | Dev | Test |
|
||||
|----------------------|-------|-------|
|
||||
| greedy search | 10.51 | 10.73 |
|
||||
| fast beam search | 10.50 | 10.69 |
|
||||
| modified beam search | 10.40 | 10.51 |
|
||||
|
||||
### Aidatatang_200zh
|
||||
|
||||
@ -248,6 +291,7 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
|
||||
|
||||
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1wNSnSj3T5oOctbh5IGCa393gKOoQw2GH?usp=sharing)
|
||||
|
||||
|
||||
### WenetSpeech
|
||||
|
||||
We provide some models for this recipe: [Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless2] and [Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless5].
|
||||
@ -284,20 +328,6 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
|
||||
|
||||
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing)
|
||||
|
||||
### Aishell4
|
||||
|
||||
We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5].
|
||||
|
||||
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets)
|
||||
|
||||
The best CER(%) results:
|
||||
| | test |
|
||||
|----------------------|--------|
|
||||
| greedy search | 29.89 |
|
||||
| fast beam search | 28.91 |
|
||||
| modified beam search | 29.08 |
|
||||
|
||||
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
|
||||
|
||||
### TAL_CSASR
|
||||
|
||||
@ -333,6 +363,9 @@ Please see: [ Install PyTorch and torchaudio
|
||||
- (1) Install k2
|
||||
- (2) Install lhotse
|
||||
- (0) Install CUDA toolkit and cuDNN
|
||||
- (1) Install PyTorch and torchaudio
|
||||
- (2) Install k2
|
||||
- (3) Install lhotse
|
||||
|
||||
.. caution::
|
||||
|
||||
99% users who have issues about the installation are using conda.
|
||||
|
||||
.. caution::
|
||||
|
||||
99% users who have issues about the installation are using conda.
|
||||
|
||||
.. caution::
|
||||
|
||||
99% users who have issues about the installation are using conda.
|
||||
|
||||
.. hint::
|
||||
|
||||
We suggest that you use ``pip install`` to install PyTorch.
|
||||
|
||||
You can use the following command to create a virutal environment in Python:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python3 -m venv ./my_env
|
||||
source ./my_env/bin/activate
|
||||
|
||||
.. caution::
|
||||
|
||||
Installation order matters.
|
||||
|
||||
(0) Install PyTorch and torchaudio
|
||||
(0) Install CUDA toolkit and cuDNN
|
||||
----------------------------------
|
||||
|
||||
Please refer to
|
||||
`<https://k2-fsa.github.io/k2/installation/cuda-cudnn.html>`_
|
||||
to install CUDA and cuDNN.
|
||||
|
||||
|
||||
(1) Install PyTorch and torchaudio
|
||||
----------------------------------
|
||||
|
||||
Please refer `<https://pytorch.org/>`_ to install PyTorch
|
||||
and torchaudio.
|
||||
|
||||
.. hint::
|
||||
|
||||
(1) Install k2
|
||||
You can also go to `<https://download.pytorch.org/whl/torch_stable.html>`_
|
||||
to download pre-compiled wheels and install them.
|
||||
|
||||
.. caution::
|
||||
|
||||
Please install torch and torchaudio at the same time.
|
||||
|
||||
|
||||
(2) Install k2
|
||||
--------------
|
||||
|
||||
Please refer to `<https://k2-fsa.github.io/k2/installation/index.html>`_
|
||||
to install ``k2``.
|
||||
|
||||
.. CAUTION::
|
||||
.. caution::
|
||||
|
||||
You need to install ``k2`` with a version at least **v1.9**.
|
||||
Please don't change your installed PyTorch after you have installed k2.
|
||||
|
||||
.. HINT::
|
||||
.. note::
|
||||
|
||||
If you have already installed PyTorch and don't want to replace it,
|
||||
please install a version of ``k2`` that is compiled against the version
|
||||
of PyTorch you are using.
|
||||
We suggest that you install k2 from source by following
|
||||
`<https://k2-fsa.github.io/k2/installation/from_source.html>`_
|
||||
or
|
||||
`<https://k2-fsa.github.io/k2/installation/for_developers.html>`_.
|
||||
|
||||
(2) Install lhotse
|
||||
.. hint::
|
||||
|
||||
Please always install the latest version of k2.
|
||||
|
||||
(3) Install lhotse
|
||||
------------------
|
||||
|
||||
Please refer to `<https://lhotse.readthedocs.io/en/latest/getting-started.html#installation>`_
|
||||
@ -75,8 +102,7 @@ to install ``lhotse``.
|
||||
|
||||
to install the latest version of lhotse.
|
||||
|
||||
|
||||
(3) Download icefall
|
||||
(4) Download icefall
|
||||
--------------------
|
||||
|
||||
``icefall`` is a collection of Python scripts; what you need is to download it
|
||||
@ -338,44 +364,42 @@ The log of running ``./prepare.sh`` is:
|
||||
|
||||
.. code-block::
|
||||
|
||||
2021-08-23 19:27:26 (prepare.sh:24:main) dl_dir: /tmp/icefall/egs/yesno/ASR/download
|
||||
2021-08-23 19:27:26 (prepare.sh:27:main) stage 0: Download data
|
||||
Downloading waves_yesno.tar.gz: 4.49MB [00:03, 1.39MB/s]
|
||||
2021-08-23 19:27:30 (prepare.sh:36:main) Stage 1: Prepare yesno manifest
|
||||
2021-08-23 19:27:31 (prepare.sh:42:main) Stage 2: Compute fbank for yesno
|
||||
2021-08-23 19:27:32,803 INFO [compute_fbank_yesno.py:52] Processing train
|
||||
Extracting and storing features: 100%|_______________________________________________________________| 90/90 [00:01<00:00, 80.57it/s]
|
||||
2021-08-23 19:27:34,085 INFO [compute_fbank_yesno.py:52] Processing test
|
||||
Extracting and storing features: 100%|______________________________________________________________| 30/30 [00:00<00:00, 248.21it/s]
|
||||
2021-08-23 19:27:34 (prepare.sh:48:main) Stage 3: Prepare lang
|
||||
2021-08-23 19:27:35 (prepare.sh:63:main) Stage 4: Prepare G
|
||||
/tmp/pip-install-fcordre9/kaldilm_6899d26f2d684ad48f21025950cd2866/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Rea
|
||||
d(std::istream&):79
|
||||
[I] Reading \data\ section.
|
||||
/tmp/pip-install-fcordre9/kaldilm_6899d26f2d684ad48f21025950cd2866/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Rea
|
||||
d(std::istream&):140
|
||||
[I] Reading \1-grams: section.
|
||||
2021-08-23 19:27:35 (prepare.sh:89:main) Stage 5: Compile HLG
|
||||
2021-08-23 19:27:35,928 INFO [compile_hlg.py:120] Processing data/lang_phone
|
||||
2021-08-23 19:27:35,929 INFO [lexicon.py:116] Converting L.pt to Linv.pt
|
||||
2021-08-23 19:27:35,931 INFO [compile_hlg.py:48] Building ctc_topo. max_token_id: 3
|
||||
2021-08-23 19:27:35,932 INFO [compile_hlg.py:52] Loading G.fst.txt
|
||||
2021-08-23 19:27:35,932 INFO [compile_hlg.py:62] Intersecting L and G
|
||||
2021-08-23 19:27:35,933 INFO [compile_hlg.py:64] LG shape: (4, None)
|
||||
2021-08-23 19:27:35,933 INFO [compile_hlg.py:66] Connecting LG
|
||||
2021-08-23 19:27:35,933 INFO [compile_hlg.py:68] LG shape after k2.connect: (4, None)
|
||||
2021-08-23 19:27:35,933 INFO [compile_hlg.py:70] <class 'torch.Tensor'>
|
||||
2021-08-23 19:27:35,933 INFO [compile_hlg.py:71] Determinizing LG
|
||||
2021-08-23 19:27:35,934 INFO [compile_hlg.py:74] <class '_k2.RaggedInt'>
|
||||
2021-08-23 19:27:35,934 INFO [compile_hlg.py:76] Connecting LG after k2.determinize
|
||||
2021-08-23 19:27:35,934 INFO [compile_hlg.py:79] Removing disambiguation symbols on LG
|
||||
2021-08-23 19:27:35,934 INFO [compile_hlg.py:87] LG shape after k2.remove_epsilon: (6, None)
|
||||
2021-08-23 19:27:35,935 INFO [compile_hlg.py:92] Arc sorting LG
|
||||
2021-08-23 19:27:35,935 INFO [compile_hlg.py:95] Composing H and LG
|
||||
2021-08-23 19:27:35,935 INFO [compile_hlg.py:102] Connecting LG
|
||||
2021-08-23 19:27:35,935 INFO [compile_hlg.py:105] Arc sorting LG
|
||||
2021-08-23 19:27:35,936 INFO [compile_hlg.py:107] HLG.shape: (8, None)
|
||||
2021-08-23 19:27:35,936 INFO [compile_hlg.py:123] Saving HLG.pt to data/lang_phone
|
||||
2023-05-12 17:55:21 (prepare.sh:27:main) dl_dir: /tmp/icefall/egs/yesno/ASR/download
|
||||
2023-05-12 17:55:21 (prepare.sh:30:main) Stage 0: Download data
|
||||
/tmp/icefall/egs/yesno/ASR/download/waves_yesno.tar.gz: 100%|_______________________________________________________________| 4.70M/4.70M [06:54<00:00, 11.4kB/s]
|
||||
2023-05-12 18:02:19 (prepare.sh:39:main) Stage 1: Prepare yesno manifest
|
||||
2023-05-12 18:02:21 (prepare.sh:45:main) Stage 2: Compute fbank for yesno
|
||||
2023-05-12 18:02:23,199 INFO [compute_fbank_yesno.py:65] Processing train
|
||||
Extracting and storing features: 100%|_______________________________________________________________| 90/90 [00:00<00:00, 212.60it/s]
|
||||
2023-05-12 18:02:23,640 INFO [compute_fbank_yesno.py:65] Processing test
|
||||
Extracting and storing features: 100%|_______________________________________________________________| 30/30 [00:00<00:00, 304.53it/s]
|
||||
2023-05-12 18:02:24 (prepare.sh:51:main) Stage 3: Prepare lang
|
||||
2023-05-12 18:02:26 (prepare.sh:66:main) Stage 4: Prepare G
|
||||
/project/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Read(std::istream&):79
|
||||
[I] Reading \data\ section.
|
||||
/project/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Read(std::istream&):140
|
||||
[I] Reading \1-grams: section.
|
||||
2023-05-12 18:02:26 (prepare.sh:92:main) Stage 5: Compile HLG
|
||||
2023-05-12 18:02:28,581 INFO [compile_hlg.py:124] Processing data/lang_phone
|
||||
2023-05-12 18:02:28,582 INFO [lexicon.py:171] Converting L.pt to Linv.pt
|
||||
2023-05-12 18:02:28,609 INFO [compile_hlg.py:48] Building ctc_topo. max_token_id: 3
|
||||
2023-05-12 18:02:28,610 INFO [compile_hlg.py:52] Loading G.fst.txt
|
||||
2023-05-12 18:02:28,611 INFO [compile_hlg.py:62] Intersecting L and G
|
||||
2023-05-12 18:02:28,613 INFO [compile_hlg.py:64] LG shape: (4, None)
|
||||
2023-05-12 18:02:28,613 INFO [compile_hlg.py:66] Connecting LG
|
||||
2023-05-12 18:02:28,614 INFO [compile_hlg.py:68] LG shape after k2.connect: (4, None)
|
||||
2023-05-12 18:02:28,614 INFO [compile_hlg.py:70] <class 'torch.Tensor'>
|
||||
2023-05-12 18:02:28,614 INFO [compile_hlg.py:71] Determinizing LG
|
||||
2023-05-12 18:02:28,615 INFO [compile_hlg.py:74] <class '_k2.ragged.RaggedTensor'>
|
||||
2023-05-12 18:02:28,615 INFO [compile_hlg.py:76] Connecting LG after k2.determinize
|
||||
2023-05-12 18:02:28,615 INFO [compile_hlg.py:79] Removing disambiguation symbols on LG
|
||||
2023-05-12 18:02:28,616 INFO [compile_hlg.py:91] LG shape after k2.remove_epsilon: (6, None)
|
||||
2023-05-12 18:02:28,617 INFO [compile_hlg.py:96] Arc sorting LG
|
||||
2023-05-12 18:02:28,617 INFO [compile_hlg.py:99] Composing H and LG
|
||||
2023-05-12 18:02:28,619 INFO [compile_hlg.py:106] Connecting LG
|
||||
2023-05-12 18:02:28,619 INFO [compile_hlg.py:109] Arc sorting LG
|
||||
2023-05-12 18:02:28,619 INFO [compile_hlg.py:111] HLG.shape: (8, None)
|
||||
2023-05-12 18:02:28,619 INFO [compile_hlg.py:127] Saving HLG.pt to data/lang_phone
|
||||
|
||||
|
||||
Training
|
||||
@ -408,49 +432,53 @@ The training log is given below:
|
||||
|
||||
.. code-block::
|
||||
|
||||
2021-08-23 19:30:31,072 INFO [train.py:465] Training started
|
||||
2021-08-23 19:30:31,072 INFO [train.py:466] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01,
|
||||
'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, '
|
||||
best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_doub
|
||||
le_scores': True, 'world_size': 1, 'master_port': 12354, 'tensorboard': True, 'num_epochs': 15, 'feature_dir': PosixPath('data/fbank'
|
||||
), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0
|
||||
, 'on_the_fly_feats': False, 'shuffle': True, 'return_cuts': True, 'num_workers': 2}
|
||||
2021-08-23 19:30:31,074 INFO [lexicon.py:113] Loading pre-compiled data/lang_phone/Linv.pt
|
||||
2021-08-23 19:30:31,098 INFO [asr_datamodule.py:146] About to get train cuts
|
||||
2021-08-23 19:30:31,098 INFO [asr_datamodule.py:240] About to get train cuts
|
||||
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:149] About to create train dataset
|
||||
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:200] Using SingleCutSampler.
|
||||
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:206] About to create train dataloader
|
||||
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:219] About to get test cuts
|
||||
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:246] About to get test cuts
|
||||
2021-08-23 19:30:31,357 INFO [train.py:416] Epoch 0, batch 0, batch avg loss 1.0789, total avg loss: 1.0789, batch size: 4
|
||||
2021-08-23 19:30:31,848 INFO [train.py:416] Epoch 0, batch 10, batch avg loss 0.5356, total avg loss: 0.7556, batch size: 4
|
||||
2021-08-23 19:30:32,301 INFO [train.py:432] Epoch 0, valid loss 0.9972, best valid loss: 0.9972 best valid epoch: 0
|
||||
2021-08-23 19:30:32,805 INFO [train.py:416] Epoch 0, batch 20, batch avg loss 0.2436, total avg loss: 0.5717, batch size: 3
|
||||
2021-08-23 19:30:33,109 INFO [train.py:432] Epoch 0, valid loss 0.4167, best valid loss: 0.4167 best valid epoch: 0
|
||||
2021-08-23 19:30:33,121 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-0.pt
|
||||
2021-08-23 19:30:33,325 INFO [train.py:416] Epoch 1, batch 0, batch avg loss 0.2214, total avg loss: 0.2214, batch size: 5
|
||||
2021-08-23 19:30:33,798 INFO [train.py:416] Epoch 1, batch 10, batch avg loss 0.0781, total avg loss: 0.1343, batch size: 5
|
||||
2021-08-23 19:30:34,065 INFO [train.py:432] Epoch 1, valid loss 0.0859, best valid loss: 0.0859 best valid epoch: 1
|
||||
2021-08-23 19:30:34,556 INFO [train.py:416] Epoch 1, batch 20, batch avg loss 0.0421, total avg loss: 0.0975, batch size: 3
|
||||
2021-08-23 19:30:34,810 INFO [train.py:432] Epoch 1, valid loss 0.0431, best valid loss: 0.0431 best valid epoch: 1
|
||||
2021-08-23 19:30:34,824 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-1.pt
|
||||
2023-05-12 18:04:59,759 INFO [train.py:481] Training started
|
||||
2023-05-12 18:04:59,759 INFO [train.py:482] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0,
|
||||
'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10,
|
||||
'reduction': 'sum', 'use_double_scores': True, 'world_size': 1, 'master_port': 12354, 'tensorboard': True, 'num_epochs': 15, 'seed': 42, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0,
|
||||
'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2,
|
||||
'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '3b7f09fa35e72589914f67089c0da9f196a92ca4', 'k2-git-date': 'Mon May 8 22:58:45 2023',
|
||||
'lhotse-version': '1.15.0.dev+git.6fcfced.clean', 'torch-version': '2.0.0+cu118', 'torch-cuda-available': False, 'torch-cuda-version': '11.8', 'python-version': '3.1', 'icefall-git-branch': 'master',
|
||||
'icefall-git-sha1': '30bde4b-clean', 'icefall-git-date': 'Thu May 11 17:37:47 2023', 'icefall-path': '/tmp/icefall',
|
||||
'k2-path': 'tmp/lib/python3.10/site-packages/k2-1.24.3.dev20230512+cuda11.8.torch2.0.0-py3.10-linux-x86_64.egg/k2/__init__.py',
|
||||
'lhotse-path': 'tmp/lib/python3.10/site-packages/lhotse/__init__.py', 'hostname': 'host', 'IP address': '0.0.0.0'}}
|
||||
2023-05-12 18:04:59,761 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
|
||||
2023-05-12 18:04:59,764 INFO [train.py:495] device: cpu
|
||||
2023-05-12 18:04:59,791 INFO [asr_datamodule.py:146] About to get train cuts
|
||||
2023-05-12 18:04:59,791 INFO [asr_datamodule.py:244] About to get train cuts
|
||||
2023-05-12 18:04:59,852 INFO [asr_datamodule.py:149] About to create train dataset
|
||||
2023-05-12 18:04:59,852 INFO [asr_datamodule.py:199] Using SingleCutSampler.
|
||||
2023-05-12 18:04:59,852 INFO [asr_datamodule.py:205] About to create train dataloader
|
||||
2023-05-12 18:04:59,853 INFO [asr_datamodule.py:218] About to get test cuts
|
||||
2023-05-12 18:04:59,853 INFO [asr_datamodule.py:252] About to get test cuts
|
||||
2023-05-12 18:04:59,986 INFO [train.py:422] Epoch 0, batch 0, loss[loss=1.065, over 2436.00 frames. ], tot_loss[loss=1.065, over 2436.00 frames. ], batch size: 4
|
||||
2023-05-12 18:05:00,352 INFO [train.py:422] Epoch 0, batch 10, loss[loss=0.4561, over 2828.00 frames. ], tot_loss[loss=0.7076, over 22192.90 frames. ], batch size: 4
|
||||
2023-05-12 18:05:00,691 INFO [train.py:444] Epoch 0, validation loss=0.9002, over 18067.00 frames.
|
||||
2023-05-12 18:05:00,996 INFO [train.py:422] Epoch 0, batch 20, loss[loss=0.2555, over 2695.00 frames. ], tot_loss[loss=0.484, over 34971.47 frames. ], batch size: 5
|
||||
2023-05-12 18:05:01,217 INFO [train.py:444] Epoch 0, validation loss=0.4688, over 18067.00 frames.
|
||||
2023-05-12 18:05:01,251 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-0.pt
|
||||
2023-05-12 18:05:01,389 INFO [train.py:422] Epoch 1, batch 0, loss[loss=0.2532, over 2436.00 frames. ], tot_loss[loss=0.2532, over 2436.00 frames. ], batch size: 4
|
||||
2023-05-12 18:05:01,637 INFO [train.py:422] Epoch 1, batch 10, loss[loss=0.1139, over 2828.00 frames. ], tot_loss[loss=0.1592, over 22192.90 frames. ], batch size: 4
|
||||
2023-05-12 18:05:01,859 INFO [train.py:444] Epoch 1, validation loss=0.1629, over 18067.00 frames.
|
||||
2023-05-12 18:05:02,094 INFO [train.py:422] Epoch 1, batch 20, loss[loss=0.0767, over 2695.00 frames. ], tot_loss[loss=0.118, over 34971.47 frames. ], batch size: 5
|
||||
2023-05-12 18:05:02,350 INFO [train.py:444] Epoch 1, validation loss=0.06778, over 18067.00 frames.
|
||||
2023-05-12 18:05:02,395 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-1.pt
|
||||
|
||||
... ...
|
||||
|
||||
2021-08-23 19:30:49,657 INFO [train.py:416] Epoch 13, batch 0, batch avg loss 0.0109, total avg loss: 0.0109, batch size: 5
|
||||
2021-08-23 19:30:49,984 INFO [train.py:416] Epoch 13, batch 10, batch avg loss 0.0093, total avg loss: 0.0096, batch size: 4
|
||||
2021-08-23 19:30:50,239 INFO [train.py:432] Epoch 13, valid loss 0.0104, best valid loss: 0.0101 best valid epoch: 12
|
||||
2021-08-23 19:30:50,569 INFO [train.py:416] Epoch 13, batch 20, batch avg loss 0.0092, total avg loss: 0.0096, batch size: 2
|
||||
2021-08-23 19:30:50,819 INFO [train.py:432] Epoch 13, valid loss 0.0101, best valid loss: 0.0101 best valid epoch: 13
|
||||
2021-08-23 19:30:50,835 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-13.pt
|
||||
2021-08-23 19:30:51,024 INFO [train.py:416] Epoch 14, batch 0, batch avg loss 0.0105, total avg loss: 0.0105, batch size: 5
|
||||
2021-08-23 19:30:51,317 INFO [train.py:416] Epoch 14, batch 10, batch avg loss 0.0099, total avg loss: 0.0097, batch size: 4
|
||||
2021-08-23 19:30:51,552 INFO [train.py:432] Epoch 14, valid loss 0.0108, best valid loss: 0.0101 best valid epoch: 13
|
||||
2021-08-23 19:30:51,869 INFO [train.py:416] Epoch 14, batch 20, batch avg loss 0.0096, total avg loss: 0.0097, batch size: 5
|
||||
2021-08-23 19:30:52,107 INFO [train.py:432] Epoch 14, valid loss 0.0102, best valid loss: 0.0101 best valid epoch: 13
|
||||
2021-08-23 19:30:52,126 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-14.pt
|
||||
2021-08-23 19:30:52,128 INFO [train.py:537] Done!
|
||||
2023-05-12 18:05:14,789 INFO [train.py:422] Epoch 13, batch 0, loss[loss=0.01056, over 2436.00 frames. ], tot_loss[loss=0.01056, over 2436.00 frames. ], batch size: 4
|
||||
2023-05-12 18:05:15,016 INFO [train.py:422] Epoch 13, batch 10, loss[loss=0.009022, over 2828.00 frames. ], tot_loss[loss=0.009985, over 22192.90 frames. ], batch size: 4
|
||||
2023-05-12 18:05:15,271 INFO [train.py:444] Epoch 13, validation loss=0.01088, over 18067.00 frames.
|
||||
2023-05-12 18:05:15,497 INFO [train.py:422] Epoch 13, batch 20, loss[loss=0.01174, over 2695.00 frames. ], tot_loss[loss=0.01077, over 34971.47 frames. ], batch size: 5
|
||||
2023-05-12 18:05:15,747 INFO [train.py:444] Epoch 13, validation loss=0.01087, over 18067.00 frames.
|
||||
2023-05-12 18:05:15,783 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-13.pt
|
||||
2023-05-12 18:05:15,921 INFO [train.py:422] Epoch 14, batch 0, loss[loss=0.01045, over 2436.00 frames. ], tot_loss[loss=0.01045, over 2436.00 frames. ], batch size: 4
|
||||
2023-05-12 18:05:16,146 INFO [train.py:422] Epoch 14, batch 10, loss[loss=0.008957, over 2828.00 frames. ], tot_loss[loss=0.009903, over 22192.90 frames. ], batch size: 4
|
||||
2023-05-12 18:05:16,374 INFO [train.py:444] Epoch 14, validation loss=0.01092, over 18067.00 frames.
|
||||
2023-05-12 18:05:16,598 INFO [train.py:422] Epoch 14, batch 20, loss[loss=0.01169, over 2695.00 frames. ], tot_loss[loss=0.01065, over 34971.47 frames. ], batch size: 5
|
||||
2023-05-12 18:05:16,824 INFO [train.py:444] Epoch 14, validation loss=0.01077, over 18067.00 frames.
|
||||
2023-05-12 18:05:16,862 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-14.pt
|
||||
2023-05-12 18:05:16,865 INFO [train.py:555] Done!
|
||||
|
||||
Decoding
|
||||
~~~~~~~~
|
||||
@ -465,22 +493,25 @@ The decoding log is:
|
||||
|
||||
.. code-block::
|
||||
|
||||
2021-08-23 19:35:30,192 INFO [decode.py:249] Decoding started
|
||||
2021-08-23 19:35:30,192 INFO [decode.py:250] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': True, 'return_cuts': True, 'num_workers': 2}
|
||||
2021-08-23 19:35:30,193 INFO [lexicon.py:113] Loading pre-compiled data/lang_phone/Linv.pt
|
||||
2021-08-23 19:35:30,213 INFO [decode.py:259] device: cpu
|
||||
2021-08-23 19:35:30,217 INFO [decode.py:279] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
|
||||
/tmp/icefall/icefall/checkpoint.py:146: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch.
|
||||
It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
|
||||
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at /pytorch/aten/src/ATen/native/BinaryOps.cpp:450.)
|
||||
avg[k] //= n
|
||||
2021-08-23 19:35:30,220 INFO [asr_datamodule.py:219] About to get test cuts
|
||||
2021-08-23 19:35:30,220 INFO [asr_datamodule.py:246] About to get test cuts
|
||||
2021-08-23 19:35:30,409 INFO [decode.py:190] batch 0/8, cuts processed until now is 4
|
||||
2021-08-23 19:35:30,571 INFO [decode.py:228] The transcripts are stored in tdnn/exp/recogs-test_set.txt
|
||||
2021-08-23 19:35:30,572 INFO [utils.py:317] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
|
||||
2021-08-23 19:35:30,573 INFO [decode.py:236] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
|
||||
2021-08-23 19:35:30,573 INFO [decode.py:299] Done!
|
||||
2023-05-12 18:08:30,482 INFO [decode.py:263] Decoding started
|
||||
2023-05-12 18:08:30,483 INFO [decode.py:264] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23,
|
||||
'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'export': False, 'feature_dir': PosixPath('data/fbank'),
|
||||
'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True,
|
||||
'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '3b7f09fa35e72589914f67089c0da9f196a92ca4', 'k2-git-date': 'Mon May 8 22:58:45 2023',
|
||||
'lhotse-version': '1.15.0.dev+git.6fcfced.clean', 'torch-version': '2.0.0+cu118', 'torch-cuda-available': False, 'torch-cuda-version': '11.8', 'python-version': '3.1', 'icefall-git-branch': 'master',
|
||||
'icefall-git-sha1': '30bde4b-clean', 'icefall-git-date': 'Thu May 11 17:37:47 2023', 'icefall-path': '/tmp/icefall',
|
||||
'k2-path': '/tmp/lib/python3.10/site-packages/k2-1.24.3.dev20230512+cuda11.8.torch2.0.0-py3.10-linux-x86_64.egg/k2/__init__.py',
|
||||
'lhotse-path': '/tmp/lib/python3.10/site-packages/lhotse/__init__.py', 'hostname': 'host', 'IP address': '0.0.0.0'}}
|
||||
2023-05-12 18:08:30,483 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
|
||||
2023-05-12 18:08:30,487 INFO [decode.py:273] device: cpu
|
||||
2023-05-12 18:08:30,513 INFO [decode.py:291] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
|
||||
2023-05-12 18:08:30,521 INFO [asr_datamodule.py:218] About to get test cuts
|
||||
2023-05-12 18:08:30,521 INFO [asr_datamodule.py:252] About to get test cuts
|
||||
2023-05-12 18:08:30,675 INFO [decode.py:204] batch 0/?, cuts processed until now is 4
|
||||
2023-05-12 18:08:30,923 INFO [decode.py:241] The transcripts are stored in tdnn/exp/recogs-test_set.txt
|
||||
2023-05-12 18:08:30,924 INFO [utils.py:558] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
|
||||
2023-05-12 18:08:30,925 INFO [decode.py:249] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
|
||||
2023-05-12 18:08:30,925 INFO [decode.py:316] Done!
|
||||
|
||||
**Congratulations!** You have successfully setup the environment and have run the first recipe in ``icefall``.
|
||||
|
||||
|
@ -3,6 +3,15 @@ Export to ONNX
|
||||
|
||||
In this section, we describe how to export models to `ONNX`_.
|
||||
|
||||
.. hint::
|
||||
|
||||
Before you continue, please run:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install onnx
|
||||
|
||||
|
||||
In each recipe, there is a file called ``export-onnx.py``, which is used
|
||||
to export trained models to `ONNX`_.
|
||||
|
||||
|
@ -2,6 +2,57 @@
|
||||
|
||||
### Aishell training result(Stateless Transducer)
|
||||
|
||||
#### Pruned transducer stateless 7 (zipformer)
|
||||
|
||||
See <https://github.com/k2-fsa/icefall/pull/986>
|
||||
|
||||
[./pruned_transducer_stateless7_bbpe](./pruned_transducer_stateless7_bbpe)
|
||||
|
||||
**Note**: The modeling units are byte level BPEs
|
||||
|
||||
The best results I have gotten are:
|
||||
|
||||
Vocab size | Greedy search(dev & test) | Modified beam search(dev & test) | Fast beam search (dev & test) | Fast beam search LG (dev & test) | comments
|
||||
-- | -- | -- | -- | -- | --
|
||||
500 | 4.31 & 4.59 | 4.25 & 4.54 | 4.27 & 4.55 | 4.07 & 4.38 | --epoch 48 --avg 29
|
||||
|
||||
The training command:
|
||||
|
||||
```
|
||||
export CUDA_VISIBLE_DEVICES="4,5,6,7"
|
||||
|
||||
./pruned_transducer_stateless7_bbpe/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 50 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--max-duration 800 \
|
||||
--bpe-model data/lang_bbpe_500/bbpe.model \
|
||||
--exp-dir pruned_transducer_stateless7_bbpe/exp \
|
||||
--lr-epochs 6 \
|
||||
--master-port 12535
|
||||
```
|
||||
|
||||
The decoding command:
|
||||
|
||||
```
|
||||
for m in greedy_search modified_beam_search fast_beam_search fast_beam_search_LG; do
|
||||
./pruned_transducer_stateless7_bbpe/decode.py \
|
||||
--epoch 48 \
|
||||
--avg 29 \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--max-sym-per-frame 1 \
|
||||
--ngram-lm-scale 0.25 \
|
||||
--ilme-scale 0.2 \
|
||||
--bpe-model data/lang_bbpe_500/bbpe.model \
|
||||
--max-duration 2000 \
|
||||
--decoding-method $m
|
||||
done
|
||||
```
|
||||
|
||||
The pretrained model is available at: https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe
|
||||
|
||||
|
||||
#### Pruned transducer stateless 3
|
||||
|
||||
See <https://github.com/k2-fsa/icefall/pull/436>
|
||||
@ -75,7 +126,7 @@ for epoch in 29; do
|
||||
done
|
||||
```
|
||||
|
||||
We provide the option of shallow fusion with a RNN language model. The pre-trained language model is
|
||||
We provide the option of shallow fusion with a RNN language model. The pre-trained language model is
|
||||
available at <https://huggingface.co/marcoyang/icefall-aishell-rnn-lm>. To decode with the language model,
|
||||
please use the following command:
|
||||
|
||||
|
1
egs/aishell/ASR/local/compile_lg.py
Symbolic link
1
egs/aishell/ASR/local/compile_lg.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/compile_lg.py
|
@ -33,6 +33,7 @@ and generates the following files in the directory `lang_dir`:
|
||||
- tokens.txt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
@ -189,8 +190,22 @@ def generate_tokens(text_file: str) -> Dict[str, int]:
|
||||
return tokens
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
It should contain the bpe.model and words.txt
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
lang_dir = Path("data/lang_char")
|
||||
args = get_args()
|
||||
lang_dir = Path(args.lang_dir)
|
||||
text_file = lang_dir / "text"
|
||||
|
||||
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
|
||||
|
267
egs/aishell/ASR/local/prepare_lang_bbpe.py
Executable file
267
egs/aishell/ASR/local/prepare_lang_bbpe.py
Executable file
@ -0,0 +1,267 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
||||
# Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
This script takes as input `lang_dir`, which should contain::
|
||||
|
||||
- lang_dir/bbpe.model,
|
||||
- lang_dir/words.txt
|
||||
|
||||
and generates the following files in the directory `lang_dir`:
|
||||
|
||||
- lexicon.txt
|
||||
- lexicon_disambig.txt
|
||||
- L.pt
|
||||
- L_disambig.pt
|
||||
- tokens.txt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from prepare_lang import (
|
||||
Lexicon,
|
||||
add_disambig_symbols,
|
||||
add_self_loops,
|
||||
write_lexicon,
|
||||
write_mapping,
|
||||
)
|
||||
|
||||
from icefall.byte_utils import byte_encode
|
||||
from icefall.utils import str2bool, tokenize_by_CJK_char
|
||||
|
||||
|
||||
def lexicon_to_fst_no_sil(
|
||||
lexicon: Lexicon,
|
||||
token2id: Dict[str, int],
|
||||
word2id: Dict[str, int],
|
||||
need_self_loops: bool = False,
|
||||
) -> k2.Fsa:
|
||||
"""Convert a lexicon to an FST (in k2 format).
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
The input lexicon. See also :func:`read_lexicon`
|
||||
token2id:
|
||||
A dict mapping tokens to IDs.
|
||||
word2id:
|
||||
A dict mapping words to IDs.
|
||||
need_self_loops:
|
||||
If True, add self-loop to states with non-epsilon output symbols
|
||||
on at least one arc out of the state. The input label for this
|
||||
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
|
||||
Returns:
|
||||
Return an instance of `k2.Fsa` representing the given lexicon.
|
||||
"""
|
||||
loop_state = 0 # words enter and leave from here
|
||||
next_state = 1 # the next un-allocated state, will be incremented as we go
|
||||
|
||||
arcs = []
|
||||
|
||||
# The blank symbol <blk> is defined in local/train_bpe_model.py
|
||||
assert token2id["<blk>"] == 0
|
||||
assert word2id["<eps>"] == 0
|
||||
|
||||
eps = 0
|
||||
|
||||
for word, pieces in lexicon:
|
||||
assert len(pieces) > 0, f"{word} has no pronunciations"
|
||||
cur_state = loop_state
|
||||
|
||||
word = word2id[word]
|
||||
pieces = [token2id[i] for i in pieces]
|
||||
|
||||
for i in range(len(pieces) - 1):
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, next_state, pieces[i], w, 0])
|
||||
|
||||
cur_state = next_state
|
||||
next_state += 1
|
||||
|
||||
# now for the last piece of this word
|
||||
i = len(pieces) - 1
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, loop_state, pieces[i], w, 0])
|
||||
|
||||
if need_self_loops:
|
||||
disambig_token = token2id["#0"]
|
||||
disambig_word = word2id["#0"]
|
||||
arcs = add_self_loops(
|
||||
arcs,
|
||||
disambig_token=disambig_token,
|
||||
disambig_word=disambig_word,
|
||||
)
|
||||
|
||||
final_state = next_state
|
||||
arcs.append([loop_state, final_state, -1, -1, 0])
|
||||
arcs.append([final_state])
|
||||
|
||||
arcs = sorted(arcs, key=lambda arc: arc[0])
|
||||
arcs = [[str(i) for i in arc] for arc in arcs]
|
||||
arcs = [" ".join(arc) for arc in arcs]
|
||||
arcs = "\n".join(arcs)
|
||||
|
||||
fsa = k2.Fsa.from_str(arcs, acceptor=False)
|
||||
return fsa
|
||||
|
||||
|
||||
def generate_lexicon(
|
||||
model_file: str, words: List[str], oov: str
|
||||
) -> Tuple[Lexicon, Dict[str, int]]:
|
||||
"""Generate a lexicon from a BPE model.
|
||||
|
||||
Args:
|
||||
model_file:
|
||||
Path to a sentencepiece model.
|
||||
words:
|
||||
A list of strings representing words.
|
||||
oov:
|
||||
The out of vocabulary word in lexicon.
|
||||
Returns:
|
||||
Return a tuple with two elements:
|
||||
- A dict whose keys are words and values are the corresponding
|
||||
word pieces.
|
||||
- A dict representing the token symbol, mapping from tokens to IDs.
|
||||
"""
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(str(model_file))
|
||||
|
||||
# Convert word to word piece IDs instead of word piece strings
|
||||
# to avoid OOV tokens.
|
||||
encode_words = [byte_encode(tokenize_by_CJK_char(w)) for w in words]
|
||||
words_pieces_ids: List[List[int]] = sp.encode(encode_words, out_type=int)
|
||||
|
||||
# Now convert word piece IDs back to word piece strings.
|
||||
words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids]
|
||||
|
||||
lexicon = []
|
||||
for word, pieces in zip(words, words_pieces):
|
||||
lexicon.append((word, pieces))
|
||||
|
||||
lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())]))
|
||||
|
||||
token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}
|
||||
|
||||
return lexicon, token2id
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
It should contain the bpe.model and words.txt
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--oov",
|
||||
type=str,
|
||||
default="<UNK>",
|
||||
help="The out of vocabulary word in lexicon.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True for debugging, which will generate
|
||||
a visualization of the lexicon FST.
|
||||
|
||||
Caution: If your lexicon contains hundreds of thousands
|
||||
of lines, please set it to False!
|
||||
|
||||
See "test/test_bpe_lexicon.py" for usage.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
lang_dir = Path(args.lang_dir)
|
||||
model_file = lang_dir / "bbpe.model"
|
||||
|
||||
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
|
||||
|
||||
words = word_sym_table.symbols
|
||||
|
||||
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", args.oov, "#0", "<s>", "</s>"]
|
||||
|
||||
for w in excluded:
|
||||
if w in words:
|
||||
words.remove(w)
|
||||
|
||||
lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov)
|
||||
|
||||
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
|
||||
|
||||
next_token_id = max(token_sym_table.values()) + 1
|
||||
for i in range(max_disambig + 1):
|
||||
disambig = f"#{i}"
|
||||
assert disambig not in token_sym_table
|
||||
token_sym_table[disambig] = next_token_id
|
||||
next_token_id += 1
|
||||
|
||||
word_sym_table.add("#0")
|
||||
word_sym_table.add("<s>")
|
||||
word_sym_table.add("</s>")
|
||||
|
||||
write_mapping(lang_dir / "tokens.txt", token_sym_table)
|
||||
|
||||
write_lexicon(lang_dir / "lexicon.txt", lexicon)
|
||||
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
|
||||
|
||||
L = lexicon_to_fst_no_sil(
|
||||
lexicon,
|
||||
token2id=token_sym_table,
|
||||
word2id=word_sym_table,
|
||||
)
|
||||
|
||||
L_disambig = lexicon_to_fst_no_sil(
|
||||
lexicon_disambig,
|
||||
token2id=token_sym_table,
|
||||
word2id=word_sym_table,
|
||||
need_self_loops=True,
|
||||
)
|
||||
torch.save(L.as_dict(), lang_dir / "L.pt")
|
||||
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
|
||||
|
||||
if args.debug:
|
||||
labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
|
||||
aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
|
||||
|
||||
L.labels_sym = labels_sym
|
||||
L.aux_labels_sym = aux_labels_sym
|
||||
L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
|
||||
|
||||
L_disambig.labels_sym = labels_sym
|
||||
L_disambig.aux_labels_sym = aux_labels_sym
|
||||
L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
113
egs/aishell/ASR/local/train_bbpe_model.py
Executable file
113
egs/aishell/ASR/local/train_bbpe_model.py
Executable file
@ -0,0 +1,113 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# You can install sentencepiece via:
|
||||
#
|
||||
# pip install sentencepiece
|
||||
#
|
||||
# Due to an issue reported in
|
||||
# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
|
||||
#
|
||||
# Please install a version >=0.1.96
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
from icefall import byte_encode, tokenize_by_CJK_char
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
The generated bpe.model is saved to this directory.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--transcript",
|
||||
type=str,
|
||||
help="Training transcript.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vocab-size",
|
||||
type=int,
|
||||
help="Vocabulary size for BPE training",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _convert_to_bchar(in_path: str, out_path: str):
|
||||
with open(out_path, "w") as f:
|
||||
for line in open(in_path, "r").readlines():
|
||||
f.write(byte_encode(tokenize_by_CJK_char(line)) + "\n")
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
vocab_size = args.vocab_size
|
||||
lang_dir = Path(args.lang_dir)
|
||||
|
||||
model_type = "unigram"
|
||||
|
||||
model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
|
||||
character_coverage = 1.0
|
||||
input_sentence_size = 100000000
|
||||
|
||||
user_defined_symbols = ["<blk>", "<sos/eos>"]
|
||||
unk_id = len(user_defined_symbols)
|
||||
# Note: unk_id is fixed to 2.
|
||||
# If you change it, you should also change other
|
||||
# places that are using it.
|
||||
|
||||
temp = tempfile.NamedTemporaryFile()
|
||||
train_text = temp.name
|
||||
|
||||
_convert_to_bchar(args.transcript, train_text)
|
||||
|
||||
model_file = Path(model_prefix + ".model")
|
||||
if not model_file.is_file():
|
||||
spm.SentencePieceTrainer.train(
|
||||
input=train_text,
|
||||
vocab_size=vocab_size,
|
||||
model_type=model_type,
|
||||
model_prefix=model_prefix,
|
||||
input_sentence_size=input_sentence_size,
|
||||
character_coverage=character_coverage,
|
||||
user_defined_symbols=user_defined_symbols,
|
||||
unk_id=unk_id,
|
||||
bos_id=-1,
|
||||
eos_id=-1,
|
||||
)
|
||||
else:
|
||||
print(f"{model_file} exists - skipping")
|
||||
return
|
||||
|
||||
shutil.copyfile(model_file, f"{lang_dir}/bbpe.model")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -35,6 +35,15 @@ dl_dir=$PWD/download
|
||||
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
# vocab size for sentence piece models.
|
||||
# It will generate data/lang_bbpe_xxx,
|
||||
# data/lang_bbpe_yyy if the array contains xxx, yyy
|
||||
vocab_sizes=(
|
||||
# 2000
|
||||
# 1000
|
||||
500
|
||||
)
|
||||
|
||||
# All files generated by this script are saved in "data".
|
||||
# You can safely remove "data" and rerun this script to regenerate it.
|
||||
mkdir -p data
|
||||
@ -47,20 +56,6 @@ log() {
|
||||
|
||||
log "dl_dir: $dl_dir"
|
||||
|
||||
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||
log "stage -1: Download LM"
|
||||
# We assume that you have installed the git-lfs, if not, you could install it
|
||||
# using: `sudo apt-get install git-lfs && git-lfs install`
|
||||
git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1)
|
||||
|
||||
if [ ! -f $dl_dir/lm/3-gram.unpruned.arpa ]; then
|
||||
git clone https://huggingface.co/pkufool/aishell_lm $dl_dir/lm
|
||||
pushd $dl_dir/lm
|
||||
git lfs pull --include "3-gram.unpruned.arpa"
|
||||
popd
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "stage 0: Download data"
|
||||
|
||||
@ -134,7 +129,6 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
fi
|
||||
|
||||
lang_phone_dir=data/lang_phone
|
||||
lang_char_dir=data/lang_char
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Prepare phone based lang"
|
||||
mkdir -p $lang_phone_dir
|
||||
@ -183,45 +177,107 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
lang_char_dir=data/lang_char
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Prepare char based lang"
|
||||
mkdir -p $lang_char_dir
|
||||
# We reuse words.txt from phone based lexicon
|
||||
# so that the two can share G.pt later.
|
||||
cp $lang_phone_dir/words.txt $lang_char_dir
|
||||
|
||||
# The transcripts in training set, generated in stage 5
|
||||
cp $lang_phone_dir/transcript_words.txt $lang_char_dir/transcript_words.txt
|
||||
|
||||
cat $dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt |
|
||||
cut -d " " -f 2- | sed -e 's/[ \t\r\n]*//g' > $lang_char_dir/text
|
||||
cut -d " " -f 2- > $lang_char_dir/text
|
||||
|
||||
(echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \
|
||||
> $lang_char_dir/words.txt
|
||||
|
||||
cat $lang_char_dir/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
|
||||
| awk '{print $1" "NR+3}' >> $lang_char_dir/words.txt
|
||||
|
||||
num_lines=$(< $lang_char_dir/words.txt wc -l)
|
||||
(echo "#0 $num_lines"; echo "<s> $(($num_lines + 1))"; echo "</s> $(($num_lines + 2))";) \
|
||||
>> $lang_char_dir/words.txt
|
||||
|
||||
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
|
||||
./local/prepare_char.py
|
||||
./local/prepare_char.py --lang-dir $lang_char_dir
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
log "Stage 7: Prepare G"
|
||||
# We assume you have install kaldilm, if not, please install
|
||||
# it using: pip install kaldilm
|
||||
log "Stage 7: Prepare Byte BPE based lang"
|
||||
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/lang_bbpe_${vocab_size}
|
||||
mkdir -p $lang_dir
|
||||
|
||||
cp $lang_char_dir/words.txt $lang_dir
|
||||
cp $lang_char_dir/text $lang_dir
|
||||
|
||||
if [ ! -f $lang_dir/bbpe.model ]; then
|
||||
./local/train_bbpe_model.py \
|
||||
--lang-dir $lang_dir \
|
||||
--vocab-size $vocab_size \
|
||||
--transcript $lang_dir/text
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/L_disambig.pt ]; then
|
||||
./local/prepare_lang_bbpe.py --lang-dir $lang_dir
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||
log "Stage 8: Prepare G"
|
||||
|
||||
mkdir -p data/lm
|
||||
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
|
||||
|
||||
# Train LM on transcripts
|
||||
if [ ! -f data/lm/3-gram.unpruned.arpa ]; then
|
||||
python3 ./shared/make_kn_lm.py \
|
||||
-ngram-order 3 \
|
||||
-text $lang_char_dir/transcript_words.txt \
|
||||
-lm data/lm/3-gram.unpruned.arpa
|
||||
fi
|
||||
|
||||
# We assume you have install kaldilm, if not, please install
|
||||
# it using: pip install kaldilm
|
||||
if [ ! -f data/lm/G_3_gram_char.fst.txt ]; then
|
||||
# It is used in building HLG
|
||||
python3 -m kaldilm \
|
||||
--read-symbol-table="$lang_phone_dir/words.txt" \
|
||||
--disambig-symbol='#0' \
|
||||
--max-order=3 \
|
||||
$dl_dir/lm/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
|
||||
data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_phone.fst.txt
|
||||
|
||||
python3 -m kaldilm \
|
||||
--read-symbol-table="$lang_char_dir/words.txt" \
|
||||
--disambig-symbol='#0' \
|
||||
--max-order=3 \
|
||||
data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_char.fst.txt
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||
log "Stage 8: Compile HLG"
|
||||
./local/compile_hlg.py --lang-dir $lang_phone_dir
|
||||
./local/compile_hlg.py --lang-dir $lang_char_dir
|
||||
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
log "Stage 9: Compile LG & HLG"
|
||||
./local/compile_hlg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
|
||||
./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/lang_bbpe_${vocab_size}
|
||||
./local/compile_hlg.py --lang-dir $lang_dir --lm G_3_gram_char
|
||||
done
|
||||
|
||||
./local/compile_lg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
|
||||
./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/lang_bbpe_${vocab_size}
|
||||
./local/compile_lg.py --lang-dir $lang_dir --lm G_3_gram_char
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
log "Stage 9: Generate LM training data"
|
||||
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
|
||||
log "Stage 10: Generate LM training data"
|
||||
|
||||
log "Processing char based data"
|
||||
out_dir=data/lm_training_char
|
||||
@ -267,8 +323,8 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
|
||||
log "Stage 10: Sort LM training data"
|
||||
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
||||
log "Stage 11: Sort LM training data"
|
||||
# Sort LM training data by sentence length in descending order
|
||||
# for ease of training.
|
||||
#
|
||||
@ -295,7 +351,7 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
|
||||
--out-statistics $out_dir/statistics-test.txt
|
||||
fi
|
||||
|
||||
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
||||
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
|
||||
log "Stage 11: Train RNN LM model"
|
||||
python ../../../icefall/rnn_lm/train.py \
|
||||
--start-epoch 0 \
|
||||
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/asr_datamodule.py
|
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/beam_search.py
|
819
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py
Executable file
819
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py
Executable file
@ -0,0 +1,819 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao,
|
||||
# Xiaoyu Yang,
|
||||
# Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7_bbpe/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./pruned_transducer_stateless7_bbpe/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless7_bbpe/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search (one best)
|
||||
./pruned_transducer_stateless7_bbpe/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
|
||||
(5) fast beam search (nbest)
|
||||
./pruned_transducer_stateless7_bbpe/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(6) fast beam search (nbest oracle WER)
|
||||
./pruned_transducer_stateless7_bbpe/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_oracle \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(7) fast beam search (with LG)
|
||||
./pruned_transducer_stateless7_bbpe/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_LG \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AishellAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall import (
|
||||
LmScorer,
|
||||
NgramLm,
|
||||
byte_encode,
|
||||
smart_byte_decode,
|
||||
tokenize_by_CJK_char,
|
||||
)
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bbpe_500/bbpe.model",
|
||||
help="Path to the byte BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_bbpe_500",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_LG
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
If you use fast_beam_search_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
fast_beam_search_nbest, fast_beam_search_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.25,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ilme-scale",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_LG.
|
||||
It specifies the scale for the internal language model estimation.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(smart_byte_decode(hyp).split())
|
||||
elif params.decoding_method == "fast_beam_search_LG":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
subtract_ilme=True,
|
||||
ilme_scale=params.ilme_scale,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(smart_byte_decode(hyp).split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
ref_texts = []
|
||||
for tx in supervisions["text"]:
|
||||
ref_texts.append(byte_encode(tokenize_by_CJK_char(tx)))
|
||||
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=sp.encode(ref_texts),
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(smart_byte_decode(hyp).split())
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(smart_byte_decode(hyp).split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(smart_byte_decode(hyp).split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(smart_byte_decode(sp.decode(hyp)).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
key += f"_ilme_scale_{params.ilme_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
LM:
|
||||
A neural network LM, used during shallow fusion
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
AishellAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_LG",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
params.suffix += f"-ilme-scale-{params.ilme_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bbpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_LG":
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
word_table = None
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aishell = AishellAsrDataModule(args)
|
||||
|
||||
test_cuts = aishell.test_cuts()
|
||||
dev_cuts = aishell.valid_cuts()
|
||||
|
||||
test_dl = aishell.test_dataloaders(test_cuts)
|
||||
dev_dl = aishell.test_dataloaders(dev_cuts)
|
||||
|
||||
test_sets = ["test", "dev"]
|
||||
test_dls = [test_dl, dev_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/encoder_interface.py
|
320
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py
Executable file
320
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py
Executable file
@ -0,0 +1,320 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
|
||||
Usage:
|
||||
|
||||
(1) Export to torchscript model using torch.jit.script()
|
||||
|
||||
./pruned_transducer_stateless7_bbpe/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--bpe-model data/lang_bbpe_500/bbpe.model \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--jit 1
|
||||
|
||||
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
|
||||
load it by `torch.jit.load("cpu_jit.pt")`.
|
||||
|
||||
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
|
||||
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
||||
|
||||
Check
|
||||
https://github.com/k2-fsa/sherpa
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
(2) Export `model.state_dict()`
|
||||
|
||||
./pruned_transducer_stateless7_bbpe/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--bpe-model data/lang_bbpe_500/bbpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
||||
load it by `icefall.checkpoint.load_checkpoint()`.
|
||||
|
||||
To use the generated file with `pruned_transducer_stateless7_bbpe/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/aishell/ASR
|
||||
./pruned_transducer_stateless7_bbpe/decode.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search \
|
||||
--bpe-model data/lang_bbpe_500/bbpe.model
|
||||
|
||||
Check ./pretrained.py for its usage.
|
||||
|
||||
Note: If you don't want to train a model from scratch, we have
|
||||
provided one for you. You can get it at
|
||||
|
||||
https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe
|
||||
|
||||
with the following commands:
|
||||
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
git clone https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe
|
||||
# You will find the pre-trained model in icefall_asr_aishell_pruned_transducer_stateless7_bbpe/exp
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7_bbpe/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bbpe_500/bbpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
It will generate a file named cpu_jit.pt
|
||||
|
||||
Check ./jit_pretrained.py for how to use it.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bbpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.jit is True:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torchscript. Export model.state_dict()")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
274
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py
Executable file
274
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py
Executable file
@ -0,0 +1,274 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads torchscript models, exported by `torch.jit.script()`
|
||||
and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
./pruned_transducer_stateless7_bbpe/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--bpe-model data/lang_bbpe_500/bbpe.model \
|
||||
--epoch 49 \
|
||||
--avg 28 \
|
||||
--jit 1
|
||||
|
||||
Usage of this script:
|
||||
|
||||
./pruned_transducer_stateless7_bbpe/jit_pretrained.py \
|
||||
--nn-model-filename ./pruned_transducer_stateless7_bbpe/exp/cpu_jit.pt \
|
||||
--bpe-model ./data/lang_bbpe_500/bbpe.model \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall import smart_byte_decode
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nn-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the torchscript model cpu_jit.pt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float = 16000
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: torch.jit.ScriptModule,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
) -> List[List[int]]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,).
|
||||
Returns:
|
||||
Return the decoded results for each utterance.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
device = encoder_out.device
|
||||
blank_id = 0 # hard-code to 0
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
context_size = model.decoder.context_size
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
) # (N, context_size)
|
||||
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=torch.tensor([False]),
|
||||
).squeeze(1)
|
||||
|
||||
offset = 0
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = packed_encoder_out.data[start:end]
|
||||
current_encoder_out = current_encoder_out
|
||||
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||
offset = end
|
||||
|
||||
decoder_out = decoder_out[:batch_size]
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
)
|
||||
# logits'shape (batch_size, vocab_size)
|
||||
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
hyps[i].append(v)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||
decoder_input = torch.tensor(
|
||||
decoder_input,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=torch.tensor([False]),
|
||||
)
|
||||
decoder_out = decoder_out.squeeze(1)
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
model = torch.jit.load(args.nn_model_filename)
|
||||
|
||||
model.eval()
|
||||
|
||||
model.to(device)
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features,
|
||||
x_lens=feature_lengths,
|
||||
)
|
||||
|
||||
hyps = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
s = "\n"
|
||||
for filename, hyp in zip(args.sound_files, hyps):
|
||||
words = smart_byte_decode(sp.decode(hyp))
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py
|
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/model.py
|
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/optim.py
|
345
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py
Executable file
345
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py
Executable file
@ -0,0 +1,345 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads a checkpoint and uses it to decode waves.
|
||||
You can generate the checkpoint with the following command:
|
||||
|
||||
./pruned_transducer_stateless7_bbpe/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--bpe-model data/lang_bbpe_500/bbpe.model \
|
||||
--epoch 48 \
|
||||
--avg 29
|
||||
|
||||
Usage of this script:
|
||||
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7_bbpe/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bbpe_500/bbpe.model \
|
||||
--method greedy_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) modified beam search
|
||||
./pruned_transducer_stateless7_bbpe/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bbpe_500/bbpe.model \
|
||||
--method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) fast beam search
|
||||
./pruned_transducer_stateless7_bbpe/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bbpe_500/bbpe.model \
|
||||
--method fast_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
Note: ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt is generated by
|
||||
./pruned_transducer_stateless7_bbpe/export.py
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall import smart_byte_decode
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame. Used only when
|
||||
--method is greedy_search.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyps = []
|
||||
msg = f"Using {params.method}"
|
||||
if params.method == "beam_search":
|
||||
msg += f" with beam size {params.beam_size}"
|
||||
logging.info(msg)
|
||||
|
||||
if params.method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(smart_byte_decode(hyp).split())
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(smart_byte_decode(hyp).split())
|
||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(smart_byte_decode(hyp).split())
|
||||
else:
|
||||
for i in range(num_waves):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {params.method}")
|
||||
|
||||
hyps.append(smart_byte_decode(sp.decode(hyp)).split())
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py
|
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py
|
1261
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py
Executable file
1261
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py
|
@ -21,7 +21,7 @@ import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import (
|
||||
@ -181,7 +181,16 @@ class AishellAsrDataModule:
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
|
||||
def train_dataloaders(
|
||||
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
|
||||
@ -277,6 +286,10 @@ class AishellAsrDataModule:
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
@ -325,7 +338,7 @@ class AishellAsrDataModule:
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
logging.info("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
|
@ -1,5 +1,76 @@
|
||||
## Results
|
||||
|
||||
### pruned_transducer_stateless7 (zipformer + multidataset(LibriSpeech + GigaSpeech + CommonVoice 13.0))
|
||||
|
||||
See <https://github.com/k2-fsa/icefall/pull/1010> for more details.
|
||||
|
||||
[pruned_transducer_stateless7](./pruned_transducer_stateless7)
|
||||
|
||||
The tensorboard log can be found at
|
||||
<https://tensorboard.dev/experiment/SwdJoHgZSZWn8ph9aJLb8g/>
|
||||
|
||||
You can find a pretrained model, training logs, decoding logs, and decoding
|
||||
results at:
|
||||
<https://huggingface.co/yfyeung/icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04>
|
||||
|
||||
You can use <https://github.com/k2-fsa/sherpa> to deploy it.
|
||||
|
||||
Number of model parameters: 70369391, i.e., 70.37 M
|
||||
|
||||
| decoding method | test-clean | test-other | comment |
|
||||
|----------------------|------------|------------|--------------------|
|
||||
| greedy_search | 1.91 | 4.06 | --epoch 30 --avg 7 |
|
||||
| modified_beam_search | 1.90 | 3.99 | --epoch 30 --avg 7 |
|
||||
| fast_beam_search | 1.90 | 3.98 | --epoch 30 --avg 7 |
|
||||
|
||||
|
||||
The training commands are:
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||
|
||||
./pruned_transducer_stateless7/train.py \
|
||||
--world-size 8 \
|
||||
--num-epochs 30 \
|
||||
--use-multidataset 1 \
|
||||
--use-fp16 1 \
|
||||
--max-duration 750 \
|
||||
--exp-dir pruned_transducer_stateless7/exp
|
||||
```
|
||||
|
||||
The decoding commands are:
|
||||
```bash
|
||||
# greedy_search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 7 \
|
||||
--use-averaged-model 1 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
# modified_beam_search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 7 \
|
||||
--use-averaged-model 1 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
# fast_beam_search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 7 \
|
||||
--use-averaged-model 1 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
```
|
||||
|
||||
### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer + Multi-Dataset)
|
||||
|
||||
#### [pruned_transducer_stateless7_streaming_multi](./pruned_transducer_stateless7_streaming_multi)
|
||||
@ -215,11 +286,12 @@ done
|
||||
We also support decoding with neural network LMs. After combining with language models, the WERs are
|
||||
| decoding method | chunk size | test-clean | test-other | comment | decoding mode |
|
||||
|----------------------|------------|------------|------------|---------------------|----------------------|
|
||||
| modified beam search | 320ms | 3.11 | 7.93 | --epoch 30 --avg 9 | simulated streaming |
|
||||
| modified beam search + RNNLM shallow fusion | 320ms | 2.58 | 6.65 | --epoch 30 --avg 9 | simulated streaming |
|
||||
| modified beam search + RNNLM nbest rescore | 320ms | 2.59 | 6.86 | --epoch 30 --avg 9 | simulated streaming |
|
||||
| `modified_beam_search` | 320ms | 3.11 | 7.93 | --epoch 30 --avg 9 | simulated streaming |
|
||||
| `modified_beam_search_lm_shallow_fusion` | 320ms | 2.58 | 6.65 | --epoch 30 --avg 9 | simulated streaming |
|
||||
| `modified_beam_search_lm_rescore` | 320ms | 2.59 | 6.86 | --epoch 30 --avg 9 | simulated streaming |
|
||||
| `modified_beam_search_lm_rescore_LODR` | 320ms | 2.52 | 6.73 | --epoch 30 --avg 9 | simulated streaming |
|
||||
|
||||
Please use the following command for RNNLM shallow fusion:
|
||||
Please use the following command for `modified_beam_search_lm_shallow_fusion`:
|
||||
```bash
|
||||
for lm_scale in $(seq 0.15 0.01 0.38); do
|
||||
for beam_size in 4 8 12; do
|
||||
@ -246,7 +318,7 @@ for lm_scale in $(seq 0.15 0.01 0.38); do
|
||||
done
|
||||
```
|
||||
|
||||
Please use the following command for RNNLM rescore:
|
||||
Please use the following command for `modified_beam_search_lm_rescore`:
|
||||
```bash
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 30 \
|
||||
@ -268,7 +340,32 @@ Please use the following command for RNNLM rescore:
|
||||
--lm-vocab-size 500
|
||||
```
|
||||
|
||||
A well-trained RNNLM can be found here: <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>.
|
||||
Please use the following command for `modified_beam_search_lm_rescore_LODR`:
|
||||
```bash
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--use-averaged-model True \
|
||||
--beam-size 8 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method modified_beam_search_lm_rescore_LODR \
|
||||
--use-shallow-fusion 0 \
|
||||
--lm-type rnn \
|
||||
--lm-exp-dir rnn_lm/exp \
|
||||
--lm-epoch 99 \
|
||||
--lm-avg 1 \
|
||||
--rnn-lm-embedding-dim 2048 \
|
||||
--rnn-lm-hidden-dim 2048 \
|
||||
--rnn-lm-num-layers 3 \
|
||||
--lm-vocab-size 500 \
|
||||
--tokens-ngram 2 \
|
||||
--backoff-id 500
|
||||
```
|
||||
|
||||
A well-trained RNNLM can be found here: <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>. The bi-gram used in LODR decoding
|
||||
can be found here: <https://huggingface.co/marcoyang/librispeech_bigram>.
|
||||
|
||||
|
||||
#### Smaller model
|
||||
|
@ -109,10 +109,11 @@ def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
|
||||
|
||||
logging.info("Removing disambiguation symbols on LG")
|
||||
|
||||
LG.labels[LG.labels >= first_token_disambig_id] = 0
|
||||
# See https://github.com/k2-fsa/k2/issues/874
|
||||
# for why we need to set LG.properties to None
|
||||
LG.__dict__["_properties"] = None
|
||||
# LG.labels[LG.labels >= first_token_disambig_id] = 0
|
||||
# see https://github.com/k2-fsa/k2/pull/1140
|
||||
labels = LG.labels
|
||||
labels[labels >= first_token_disambig_id] = 0
|
||||
LG.labels = labels
|
||||
|
||||
assert isinstance(LG.aux_labels, k2.RaggedTensor)
|
||||
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
|
||||
|
@ -45,11 +45,18 @@ def get_args():
|
||||
help="""Input and output directory.
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lm",
|
||||
type=str,
|
||||
default="G_3_gram",
|
||||
help="""Stem name for LM used in HLG compiling.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compile_LG(lang_dir: str) -> k2.Fsa:
|
||||
def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
|
||||
"""
|
||||
Args:
|
||||
lang_dir:
|
||||
@ -61,15 +68,15 @@ def compile_LG(lang_dir: str) -> k2.Fsa:
|
||||
lexicon = Lexicon(lang_dir)
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
|
||||
|
||||
if Path("data/lm/G_3_gram.pt").is_file():
|
||||
logging.info("Loading pre-compiled G_3_gram")
|
||||
d = torch.load("data/lm/G_3_gram.pt")
|
||||
if Path(f"data/lm/{lm}.pt").is_file():
|
||||
logging.info(f"Loading pre-compiled {lm}")
|
||||
d = torch.load(f"data/lm/{lm}.pt")
|
||||
G = k2.Fsa.from_dict(d)
|
||||
else:
|
||||
logging.info("Loading G_3_gram.fst.txt")
|
||||
with open("data/lm/G_3_gram.fst.txt") as f:
|
||||
logging.info(f"Loading {lm}.fst.txt")
|
||||
with open(f"data/lm/{lm}.fst.txt") as f:
|
||||
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
|
||||
torch.save(G.as_dict(), f"data/lm/{lm}.pt")
|
||||
|
||||
first_token_disambig_id = lexicon.token_table["#0"]
|
||||
first_word_disambig_id = lexicon.word_table["#0"]
|
||||
@ -96,10 +103,11 @@ def compile_LG(lang_dir: str) -> k2.Fsa:
|
||||
|
||||
logging.info("Removing disambiguation symbols on LG")
|
||||
|
||||
LG.labels[LG.labels >= first_token_disambig_id] = 0
|
||||
# See https://github.com/k2-fsa/k2/issues/874
|
||||
# for why we need to set LG.properties to None
|
||||
LG.__dict__["_properties"] = None
|
||||
# LG.labels[LG.labels >= first_token_disambig_id] = 0
|
||||
# see https://github.com/k2-fsa/k2/pull/1140
|
||||
labels = LG.labels
|
||||
labels[labels >= first_token_disambig_id] = 0
|
||||
LG.labels = labels
|
||||
|
||||
assert isinstance(LG.aux_labels, k2.RaggedTensor)
|
||||
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
|
||||
@ -126,7 +134,7 @@ def main():
|
||||
|
||||
logging.info(f"Processing {lang_dir}")
|
||||
|
||||
LG = compile_LG(lang_dir)
|
||||
LG = compile_LG(lang_dir, args.lm)
|
||||
logging.info(f"Saving LG.pt to {lang_dir}")
|
||||
torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")
|
||||
|
||||
|
@ -41,9 +41,57 @@ import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from lhotse.utils import urlretrieve_progress
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
# This function is copied from lhotse
|
||||
def tqdm_urlretrieve_hook(t):
|
||||
"""Wraps tqdm instance.
|
||||
Don't forget to close() or __exit__()
|
||||
the tqdm instance once you're done with it (easiest using `with` syntax).
|
||||
Example
|
||||
-------
|
||||
>>> from urllib.request import urlretrieve
|
||||
>>> with tqdm(...) as t:
|
||||
... reporthook = tqdm_urlretrieve_hook(t)
|
||||
... urlretrieve(..., reporthook=reporthook)
|
||||
|
||||
Source: https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
|
||||
"""
|
||||
last_b = [0]
|
||||
|
||||
def update_to(b=1, bsize=1, tsize=None):
|
||||
"""
|
||||
b : int, optional
|
||||
Number of blocks transferred so far [default: 1].
|
||||
bsize : int, optional
|
||||
Size of each block (in tqdm units) [default: 1].
|
||||
tsize : int, optional
|
||||
Total size (in tqdm units). If [default: None] or -1,
|
||||
remains unchanged.
|
||||
"""
|
||||
if tsize not in (None, -1):
|
||||
t.total = tsize
|
||||
displayed = t.update((b - last_b[0]) * bsize)
|
||||
last_b[0] = b
|
||||
return displayed
|
||||
|
||||
return update_to
|
||||
|
||||
|
||||
# This function is copied from lhotse
|
||||
def urlretrieve_progress(url, filename=None, data=None, desc=None):
|
||||
"""
|
||||
Works exactly like urllib.request.urlretrieve, but attaches a tqdm hook to
|
||||
display a progress bar of the download.
|
||||
Use "desc" argument to display a user-readable string that informs what is
|
||||
being downloaded.
|
||||
"""
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=desc) as t:
|
||||
reporthook = tqdm_urlretrieve_hook(t)
|
||||
return urlretrieve(url=url, filename=filename, reporthook=reporthook, data=data)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -184,7 +184,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
./shared/convert-k2-to-openfst.py \
|
||||
--olabels aux_labels \
|
||||
$lang_dir/L_disambig.pt \
|
||||
$lang_dir/disambig_L.fst
|
||||
$lang_dir/L_disambig.fst
|
||||
fi
|
||||
fi
|
||||
|
||||
|
@ -198,7 +198,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
./shared/convert-k2-to-openfst.py \
|
||||
--olabels aux_labels \
|
||||
$lang_dir/L_disambig.pt \
|
||||
$lang_dir/disambig_L.fst
|
||||
$lang_dir/L_disambig.fst
|
||||
fi
|
||||
fi
|
||||
|
||||
@ -328,46 +328,3 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
|
||||
./prepare_common_voice.sh
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
||||
log "Stage 11: Create multidataset"
|
||||
split_dir=data/fbank/multidataset_split_${num_splits}
|
||||
if [ ! -f data/fbank/multidataset_split/.multidataset.done ]; then
|
||||
mkdir -p $split_dir/multidataset
|
||||
log "Split LibriSpeech"
|
||||
if [ ! -f $split_dir/.librispeech_split.done ]; then
|
||||
lhotse split $num_splits ./data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz $split_dir
|
||||
touch $split_dir/.librispeech_split.done
|
||||
fi
|
||||
|
||||
if [[ "${multidataset[@]}" =~ "gigaspeech" ]]; then
|
||||
log "Split GigaSpeech XL"
|
||||
if [ ! -f $split_dir/.gigaspeech_XL_split.done ]; then
|
||||
cd $split_dir
|
||||
ln -sv ../gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz .
|
||||
cd ../../..
|
||||
touch $split_dir/.gigaspeech_XL_split.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ "${multidataset[@]}" =~ "commonvoice" ]]; then
|
||||
log "Split CommonVoice"
|
||||
if [ ! -f $split_dir/.cv-en_train_split.done ]; then
|
||||
lhotse split $num_splits ./data/en/fbank/cv-en_cuts_train.jsonl.gz $split_dir
|
||||
touch $split_dir/.cv-en_train_split.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ ! -f $split_dir/.multidataset_mix.done ]; then
|
||||
log "Mix multidataset"
|
||||
for ((seq=1; seq<=$num_splits; seq++)); do
|
||||
fseq=$(printf "%04d" $seq)
|
||||
gunzip -c $split_dir/*.*${fseq}.jsonl.gz | \
|
||||
shuf | gzip -c > $split_dir/multidataset/multidataset_cuts_train.${fseq}.jsonl.gz
|
||||
done
|
||||
touch $split_dir/.multidataset_mix.done
|
||||
fi
|
||||
|
||||
touch data/fbank/multidataset_split/.multidataset.done
|
||||
fi
|
||||
fi
|
||||
|
@ -47,6 +47,8 @@ def fast_beam_search_one_best(
|
||||
max_states: int,
|
||||
max_contexts: int,
|
||||
temperature: float = 1.0,
|
||||
subtract_ilme: bool = False,
|
||||
ilme_scale: float = 0.1,
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
@ -88,6 +90,8 @@ def fast_beam_search_one_best(
|
||||
max_states=max_states,
|
||||
max_contexts=max_contexts,
|
||||
temperature=temperature,
|
||||
subtract_ilme=subtract_ilme,
|
||||
ilme_scale=ilme_scale,
|
||||
)
|
||||
|
||||
best_path = one_best_decoding(lattice)
|
||||
@ -428,6 +432,8 @@ def fast_beam_search(
|
||||
max_states: int,
|
||||
max_contexts: int,
|
||||
temperature: float = 1.0,
|
||||
subtract_ilme: bool = False,
|
||||
ilme_scale: float = 0.1,
|
||||
) -> k2.Fsa:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
@ -498,6 +504,17 @@ def fast_beam_search(
|
||||
)
|
||||
logits = logits.squeeze(1).squeeze(1)
|
||||
log_probs = (logits / temperature).log_softmax(dim=-1)
|
||||
if subtract_ilme:
|
||||
ilme_logits = model.joiner(
|
||||
torch.zeros_like(
|
||||
current_encoder_out, device=current_encoder_out.device
|
||||
).unsqueeze(2),
|
||||
decoder_out.unsqueeze(1),
|
||||
project_input=False,
|
||||
)
|
||||
ilme_logits = ilme_logits.squeeze(1).squeeze(1)
|
||||
ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1)
|
||||
log_probs -= ilme_scale * ilme_log_probs
|
||||
decoding_streams.advance(log_probs)
|
||||
decoding_streams.terminate_and_flush_to_streams()
|
||||
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
|
||||
@ -1244,7 +1261,7 @@ def modified_beam_search_lm_rescore(
|
||||
|
||||
# get the best hyp with different lm_scale
|
||||
for lm_scale in lm_scale_list:
|
||||
key = f"nnlm_scale_{lm_scale}"
|
||||
key = f"nnlm_scale_{lm_scale:.2f}"
|
||||
tot_scores = am_scores.values + lm_scores * lm_scale
|
||||
ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores)
|
||||
max_indexes = ragged_tot_scores.argmax().tolist()
|
||||
@ -1257,6 +1274,222 @@ def modified_beam_search_lm_rescore(
|
||||
return ans
|
||||
|
||||
|
||||
def modified_beam_search_lm_rescore_LODR(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
LM: LmScorer,
|
||||
LODR_lm: NgramLm,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
lm_scale_list: List[int],
|
||||
beam: int = 4,
|
||||
temperature: float = 1.0,
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||
Rescore the final results with RNNLM and return the one with the highest score
|
||||
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, C).
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||
encoder_out before padding.
|
||||
beam:
|
||||
Number of active paths during the beam search.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
LM:
|
||||
A neural network language model
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
assert encoder_out.ndim == 3, encoder_out.shape
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
blank_id = model.decoder.blank_id
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
context_size = model.decoder.context_size
|
||||
device = next(model.parameters()).device
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
B = [HypothesisList() for _ in range(N)]
|
||||
for i in range(N):
|
||||
B[i].add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
timestamp=[],
|
||||
)
|
||||
)
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||
|
||||
offset = 0
|
||||
finalized_B = []
|
||||
for (t, batch_size) in enumerate(batch_size_list):
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = encoder_out.data[start:end]
|
||||
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
||||
offset = end
|
||||
|
||||
finalized_B = B[batch_size:] + finalized_B
|
||||
B = B[:batch_size]
|
||||
|
||||
hyps_shape = get_hyps_shape(B).to(device)
|
||||
|
||||
A = [list(b) for b in B]
|
||||
B = [HypothesisList() for _ in range(batch_size)]
|
||||
|
||||
ys_log_probs = torch.cat(
|
||||
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
|
||||
) # (num_hyps, 1)
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
) # (num_hyps, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
|
||||
|
||||
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
|
||||
# as index, so we use `to(torch.int64)` below.
|
||||
current_encoder_out = torch.index_select(
|
||||
current_encoder_out,
|
||||
dim=0,
|
||||
index=hyps_shape.row_ids(1).to(torch.int64),
|
||||
) # (num_hyps, 1, 1, encoder_out_dim)
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
project_input=False,
|
||||
) # (num_hyps, 1, 1, vocab_size)
|
||||
|
||||
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
|
||||
|
||||
log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)
|
||||
|
||||
log_probs.add_(ys_log_probs)
|
||||
|
||||
vocab_size = log_probs.size(-1)
|
||||
|
||||
log_probs = log_probs.reshape(-1)
|
||||
|
||||
row_splits = hyps_shape.row_splits(1) * vocab_size
|
||||
log_probs_shape = k2.ragged.create_ragged_shape2(
|
||||
row_splits=row_splits, cached_tot_size=log_probs.numel()
|
||||
)
|
||||
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
|
||||
|
||||
for i in range(batch_size):
|
||||
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||
|
||||
for k in range(len(topk_hyp_indexes)):
|
||||
hyp_idx = topk_hyp_indexes[k]
|
||||
hyp = A[i][hyp_idx]
|
||||
|
||||
new_ys = hyp.ys[:]
|
||||
new_token = topk_token_indexes[k]
|
||||
new_timestamp = hyp.timestamp[:]
|
||||
if new_token not in (blank_id, unk_id):
|
||||
new_ys.append(new_token)
|
||||
new_timestamp.append(t)
|
||||
|
||||
new_log_prob = topk_log_probs[k]
|
||||
new_hyp = Hypothesis(
|
||||
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
|
||||
)
|
||||
B[i].add(new_hyp)
|
||||
|
||||
B = B + finalized_B
|
||||
|
||||
# get the am_scores for n-best list
|
||||
hyps_shape = get_hyps_shape(B)
|
||||
am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b])
|
||||
am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device)
|
||||
|
||||
# now LM rescore
|
||||
# prepare input data to LM
|
||||
candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b]
|
||||
possible_seqs = k2.RaggedTensor(candidate_seqs)
|
||||
row_splits = possible_seqs.shape.row_splits(1)
|
||||
sentence_token_lengths = row_splits[1:] - row_splits[:-1]
|
||||
possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1)
|
||||
possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1)
|
||||
sentence_token_lengths += 1
|
||||
|
||||
x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id)
|
||||
y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id)
|
||||
x = x.to(device).to(torch.int64)
|
||||
y = y.to(device).to(torch.int64)
|
||||
sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64)
|
||||
|
||||
lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths)
|
||||
assert lm_scores.ndim == 2
|
||||
lm_scores = -1 * lm_scores.sum(dim=1)
|
||||
|
||||
# now LODR scores
|
||||
import math
|
||||
|
||||
LODR_scores = []
|
||||
for seq in candidate_seqs:
|
||||
tokens = " ".join(sp.id_to_piece(seq))
|
||||
LODR_scores.append(LODR_lm.score(tokens))
|
||||
LODR_scores = torch.tensor(LODR_scores).to(device) * math.log(
|
||||
10
|
||||
) # arpa scores are 10-based
|
||||
assert lm_scores.shape == LODR_scores.shape
|
||||
|
||||
ans = {}
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
|
||||
LODR_scale_list = [0.05 * i for i in range(1, 20)]
|
||||
# get the best hyp with different lm_scale and lodr_scale
|
||||
for lm_scale in lm_scale_list:
|
||||
for lodr_scale in LODR_scale_list:
|
||||
key = f"nnlm_scale_{lm_scale:.2f}_lodr_scale_{lodr_scale:.2f}"
|
||||
tot_scores = (
|
||||
am_scores.values / lm_scale + lm_scores - LODR_scores * lodr_scale
|
||||
)
|
||||
ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores)
|
||||
max_indexes = ragged_tot_scores.argmax().tolist()
|
||||
unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes]
|
||||
hyps = []
|
||||
for idx in unsorted_indices:
|
||||
hyps.append(unsorted_hyps[idx])
|
||||
|
||||
ans[key] = hyps
|
||||
return ans
|
||||
|
||||
|
||||
def _deprecated_modified_beam_search(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
|
681
egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py
Executable file
681
egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py
Executable file
@ -0,0 +1,681 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This script exports a transducer model from PyTorch to ONNX.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless5_20220729
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless5_20220729
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained-epoch-25-avg-5.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained-epoch-25-avg-5.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./pruned_transducer_stateless5/export-onnx-streaming.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--use-averaged-model 0 \
|
||||
--exp-dir $repo/exp \
|
||||
--num-encoder-layers 18 \
|
||||
--dim-feedforward 2048 \
|
||||
--nhead 8 \
|
||||
--encoder-dim 512 \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
See ./onnx_pretrained.py and ./onnx_check.py for how to
|
||||
use the exported ONNX models.
|
||||
|
||||
You can find the exported models in
|
||||
https://huggingface.co/csukuangfj/sherpa-onnx-streaming-conformer-en-2023-05-09
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import onnx
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from conformer import Conformer
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
from decoder import Decoder
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import setup_logger, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless5/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
"""Add meta data to an ONNX model. It is changed in-place.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename of the ONNX model to be changed.
|
||||
meta_data:
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = value
|
||||
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
class OnnxEncoder(nn.Module):
|
||||
"""A wrapper for Conformer and the encoder_proj from the joiner"""
|
||||
|
||||
def __init__(self, encoder: Conformer, encoder_proj: nn.Linear):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
A Conformer encoder.
|
||||
encoder_proj:
|
||||
The projection layer for encoder from the joiner.
|
||||
"""
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.encoder_proj = encoder_proj
|
||||
|
||||
self.num_encoder_layers = encoder.encoder_layers
|
||||
self.encoder_dim = encoder.d_model
|
||||
self.cnn_module_kernel = encoder.cnn_module_kernel
|
||||
|
||||
# Note you can tune these values
|
||||
self.left_context = 64 # after subsampling
|
||||
self.chunk_size = 16 # after subsampling
|
||||
self.right_context = 0 # after subsampling
|
||||
|
||||
subsampling_factor = 4
|
||||
self.pad_length = (self.right_context + 2) * subsampling_factor + 3
|
||||
|
||||
self.T = (self.chunk_size * subsampling_factor) + self.pad_length
|
||||
self.decode_chunk_len = self.chunk_size * subsampling_factor
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cached_attn: torch.Tensor,
|
||||
cached_conv: torch.Tensor,
|
||||
processed_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Please see the help information of Conformer.forward
|
||||
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, self.T, C)
|
||||
cached_attn:
|
||||
A 3-D tensor of shape
|
||||
(num_encoder_layers, self.left_context, N, self.encoder_dim)
|
||||
cached_conv:
|
||||
A 3-D tensor of shape
|
||||
(num_encoder_layers, self.cnn_module_kernel-1, N, self.encoder_dim)
|
||||
processed_lens:
|
||||
A 1-D tensor of shape (N,). It contains number of processed frames
|
||||
after subsampling. Its dtype is torch.int64.
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- encoder_out, A 3-D tensor of shape (N, self.chunk_size, joiner_dim)
|
||||
- new_cached_attn, it has the same shape as cached_attn
|
||||
- new_cached_conv, it has the same shape as cached_conv
|
||||
"""
|
||||
assert x.size(1) == self.T, (x.shape, self.T)
|
||||
N = x.size(0)
|
||||
x_lens = torch.full((N,), fill_value=self.T, device=x.device, dtype=torch.int64)
|
||||
|
||||
(
|
||||
encoder_out,
|
||||
_,
|
||||
[new_cached_attn, new_cached_conv],
|
||||
) = self.encoder.streaming_forward(
|
||||
x,
|
||||
x_lens,
|
||||
states=[cached_attn, cached_conv],
|
||||
processed_lens=processed_lens,
|
||||
left_context=self.left_context,
|
||||
right_context=self.right_context,
|
||||
chunk_size=self.chunk_size,
|
||||
)
|
||||
|
||||
encoder_out = self.encoder_proj(encoder_out)
|
||||
# Now encoder_out is of shape (N, T, joiner_dim)
|
||||
|
||||
return encoder_out, new_cached_attn, new_cached_conv
|
||||
|
||||
|
||||
class OnnxDecoder(nn.Module):
|
||||
"""A wrapper for Decoder and the decoder_proj from the joiner"""
|
||||
|
||||
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
|
||||
super().__init__()
|
||||
self.decoder = decoder
|
||||
self.decoder_proj = decoder_proj
|
||||
|
||||
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, context_size).
|
||||
Returns
|
||||
Return a 2-D tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
need_pad = False
|
||||
decoder_output = self.decoder(y, need_pad=need_pad)
|
||||
decoder_output = decoder_output.squeeze(1)
|
||||
output = self.decoder_proj(decoder_output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class OnnxJoiner(nn.Module):
|
||||
"""A wrapper for the joiner"""
|
||||
|
||||
def __init__(self, output_linear: nn.Linear):
|
||||
super().__init__()
|
||||
self.output_linear = output_linear
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
decoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, vocab_size)
|
||||
"""
|
||||
logit = encoder_out + decoder_out
|
||||
logit = self.output_linear(torch.tanh(logit))
|
||||
return logit
|
||||
|
||||
|
||||
def export_encoder_model_onnx(
|
||||
encoder_model: OnnxEncoder,
|
||||
encoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the given encoder model to ONNX format.
|
||||
The exported model has two inputs:
|
||||
|
||||
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
||||
|
||||
and it has two outputs:
|
||||
|
||||
- encoder_out, a tensor of shape (N, T', joiner_dim)
|
||||
- encoder_out_lens, a tensor of shape (N,)
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
N = 1
|
||||
x = torch.zeros(N, encoder_model.T, 80, dtype=torch.float32)
|
||||
cached_attn = torch.zeros(
|
||||
encoder_model.num_encoder_layers,
|
||||
encoder_model.left_context,
|
||||
N,
|
||||
encoder_model.encoder_dim,
|
||||
)
|
||||
cached_conv = torch.zeros(
|
||||
encoder_model.num_encoder_layers,
|
||||
encoder_model.cnn_module_kernel - 1,
|
||||
N,
|
||||
encoder_model.encoder_dim,
|
||||
)
|
||||
processed_lens = torch.zeros((N,), dtype=torch.int64)
|
||||
|
||||
torch.onnx.export(
|
||||
encoder_model,
|
||||
(x, cached_attn, cached_conv, processed_lens),
|
||||
encoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "cached_attn", "cached_conv", "processed_lens"],
|
||||
output_names=["encoder_out", "new_cached_attn", "new_cached_conv"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N"},
|
||||
"cached_attn": {2: "N"},
|
||||
"cached_conv": {2: "N"},
|
||||
"processed_lens": {0: "N"},
|
||||
"encoder_out": {0: "N"},
|
||||
"new_cached_attn": {2: "N"},
|
||||
"new_cached_conv": {2: "N"},
|
||||
},
|
||||
)
|
||||
|
||||
meta_data = {
|
||||
"model_type": "conformer",
|
||||
"version": "1",
|
||||
"model_author": "k2-fsa",
|
||||
"comment": "stateless5",
|
||||
"pad_length": str(encoder_model.pad_length),
|
||||
"decode_chunk_len": str(encoder_model.decode_chunk_len),
|
||||
"encoder_dim": str(encoder_model.encoder_dim),
|
||||
"num_encoder_layers": str(encoder_model.num_encoder_layers),
|
||||
"cnn_module_kernel": str(encoder_model.cnn_module_kernel),
|
||||
"left_context": str(encoder_model.left_context),
|
||||
"right_context": str(encoder_model.right_context),
|
||||
"chunk_size": str(encoder_model.chunk_size),
|
||||
"T": str(encoder_model.T),
|
||||
}
|
||||
logging.info(f"meta_data: {meta_data}")
|
||||
|
||||
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
def export_decoder_model_onnx(
|
||||
decoder_model: OnnxDecoder,
|
||||
decoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the decoder model to ONNX format.
|
||||
|
||||
The exported model has one input:
|
||||
|
||||
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||
|
||||
and has one output:
|
||||
|
||||
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The decoder model to be exported.
|
||||
decoder_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
context_size = decoder_model.decoder.context_size
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
decoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["y"],
|
||||
output_names=["decoder_out"],
|
||||
dynamic_axes={
|
||||
"y": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
|
||||
meta_data = {
|
||||
"context_size": str(context_size),
|
||||
"vocab_size": str(vocab_size),
|
||||
}
|
||||
add_meta_data(filename=decoder_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
def export_joiner_model_onnx(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the joiner model to ONNX format.
|
||||
The exported joiner model has two inputs:
|
||||
|
||||
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- logit: a tensor of shape (N, vocab_size)
|
||||
"""
|
||||
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
||||
logging.info(f"joiner dim: {joiner_dim}")
|
||||
|
||||
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||
|
||||
torch.onnx.export(
|
||||
joiner_model,
|
||||
(projected_encoder_out, projected_decoder_out),
|
||||
joiner_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=[
|
||||
"encoder_out",
|
||||
"decoder_out",
|
||||
],
|
||||
output_names=["logit"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
"logit": {0: "N"},
|
||||
},
|
||||
)
|
||||
meta_data = {
|
||||
"joiner_dim": str(joiner_dim),
|
||||
}
|
||||
add_meta_data(filename=joiner_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
if not params.causal_convolution:
|
||||
logging.info("Seting causal_convolution to True for exporting streaming models")
|
||||
params.causal_convolution = True
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum(p.numel() for p in model.parameters())
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
|
||||
encoder = OnnxEncoder(
|
||||
encoder=model.encoder,
|
||||
encoder_proj=model.joiner.encoder_proj,
|
||||
)
|
||||
|
||||
decoder = OnnxDecoder(
|
||||
decoder=model.decoder,
|
||||
decoder_proj=model.joiner.decoder_proj,
|
||||
)
|
||||
|
||||
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
|
||||
|
||||
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
|
||||
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
|
||||
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
|
||||
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||
logging.info(f"total parameters: {total_num_param}")
|
||||
|
||||
if params.iter > 0:
|
||||
suffix = f"iter-{params.iter}"
|
||||
else:
|
||||
suffix = f"epoch-{params.epoch}"
|
||||
|
||||
suffix += f"-avg-{params.avg}"
|
||||
|
||||
opset_version = 13
|
||||
|
||||
logging.info("Exporting encoder")
|
||||
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
|
||||
export_encoder_model_onnx(
|
||||
encoder,
|
||||
encoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported encoder to {encoder_filename}")
|
||||
|
||||
logging.info("Exporting decoder")
|
||||
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
|
||||
export_decoder_model_onnx(
|
||||
decoder,
|
||||
decoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported decoder to {decoder_filename}")
|
||||
|
||||
logging.info("Exporting joiner")
|
||||
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
|
||||
export_joiner_model_onnx(
|
||||
joiner,
|
||||
joiner_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported joiner to {joiner_filename}")
|
||||
|
||||
# Generate int8 quantization models
|
||||
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||
|
||||
logging.info("Generate int8 quantization models")
|
||||
|
||||
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=encoder_filename,
|
||||
model_output=encoder_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=decoder_filename,
|
||||
model_output=decoder_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=joiner_filename,
|
||||
model_output=joiner_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
main()
|
456
egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py
Executable file
456
egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py
Executable file
@ -0,0 +1,456 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This script loads ONNX models exported by ./export-onnx.py
|
||||
and uses them to decode waves.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless5_20220729
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless5_20220729
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained-epoch-25-avg-5.pt"
|
||||
cd exp
|
||||
ln -s pretrained-epoch-25-avg-5.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./pruned_transducer_stateless5/export-onnx-streaming.py \
|
||||
--bpe-model ./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/data/lang_bpe_500/bpe.model \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--use-averaged-model 0 \
|
||||
--exp-dir ./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/exp \
|
||||
--num-encoder-layers 18 \
|
||||
--dim-feedforward 2048 \
|
||||
--nhead 8 \
|
||||
--encoder-dim 512 \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512
|
||||
|
||||
It will generate the following 3 files in $repo/exp
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
3. Run this file with the exported ONNX models
|
||||
|
||||
./pruned_transducer_stateless5/onnx_pretrained-streaming.py \
|
||||
--encoder-model-filename ./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/exp/encoder-epoch-99-avg-1.onnx \
|
||||
--decoder-model-filename ./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/exp/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner-model-filename ./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/exp/joiner-epoch-99-avg-1.onnx \
|
||||
--tokens=./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/data/lang_bpe_500/tokens.txt \
|
||||
./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/test_waves/1221-135766-0001.wav
|
||||
|
||||
Note: Even though this script only supports decoding a single file,
|
||||
the exported ONNX models do support batch processing.
|
||||
|
||||
You can find the exported models in
|
||||
https://huggingface.co/csukuangfj/sherpa-onnx-streaming-conformer-en-2023-05-09
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torchaudio
|
||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="""Path to tokens.txt.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_file",
|
||||
type=str,
|
||||
help="The input sound file to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
encoder_model_filename: str,
|
||||
decoder_model_filename: str,
|
||||
joiner_model_filename: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
self.init_encoder(encoder_model_filename)
|
||||
self.init_decoder(decoder_model_filename)
|
||||
self.init_joiner(joiner_model_filename)
|
||||
|
||||
def init_encoder(self, encoder_model_filename: str):
|
||||
self.encoder = ort.InferenceSession(
|
||||
encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
self.init_encoder_states()
|
||||
|
||||
def init_encoder_states(self, batch_size: int = 1):
|
||||
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
|
||||
print(encoder_meta)
|
||||
|
||||
model_type = encoder_meta["model_type"]
|
||||
assert model_type == "conformer", model_type
|
||||
|
||||
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
|
||||
T = int(encoder_meta["T"])
|
||||
pad_length = int(encoder_meta["pad_length"])
|
||||
|
||||
encoder_dim = int(encoder_meta["encoder_dim"])
|
||||
cnn_module_kernel = int(encoder_meta["cnn_module_kernel"])
|
||||
left_context = int(encoder_meta["left_context"])
|
||||
num_encoder_layers = int(encoder_meta["num_encoder_layers"])
|
||||
|
||||
self.cached_attn = torch.zeros(
|
||||
num_encoder_layers,
|
||||
left_context,
|
||||
batch_size,
|
||||
encoder_dim,
|
||||
).numpy()
|
||||
self.cached_conv = torch.zeros(
|
||||
num_encoder_layers,
|
||||
cnn_module_kernel - 1,
|
||||
batch_size,
|
||||
encoder_dim,
|
||||
).numpy()
|
||||
|
||||
logging.info(f"decode_chunk_len: {decode_chunk_len}")
|
||||
logging.info(f"T: {T}")
|
||||
logging.info(f"pad_length: {pad_length}")
|
||||
logging.info(f"encoder_dim: {encoder_dim}")
|
||||
logging.info(f"cnn_module_kernel: {cnn_module_kernel}")
|
||||
logging.info(f"left_context: {left_context}")
|
||||
logging.info(f"num_encoder_layers: {num_encoder_layers}")
|
||||
|
||||
self.segment = T
|
||||
self.offset = decode_chunk_len
|
||||
|
||||
def init_decoder(self, decoder_model_filename: str):
|
||||
self.decoder = ort.InferenceSession(
|
||||
decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||
self.context_size = int(decoder_meta["context_size"])
|
||||
self.vocab_size = int(decoder_meta["vocab_size"])
|
||||
|
||||
logging.info(f"context_size: {self.context_size}")
|
||||
logging.info(f"vocab_size: {self.vocab_size}")
|
||||
|
||||
def init_joiner(self, joiner_model_filename: str):
|
||||
self.joiner = ort.InferenceSession(
|
||||
joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||
self.joiner_dim = int(joiner_meta["joiner_dim"])
|
||||
|
||||
logging.info(f"joiner_dim: {self.joiner_dim}")
|
||||
|
||||
def _build_encoder_input_output(
|
||||
self, x: torch.Tensor, processed_lens: int
|
||||
) -> Tuple[Dict[str, np.ndarray], List[str]]:
|
||||
assert x.size(0) == 1
|
||||
encoder_input = {
|
||||
"x": x.numpy(),
|
||||
"cached_attn": self.cached_attn,
|
||||
"cached_conv": self.cached_conv,
|
||||
"processed_lens": torch.full(
|
||||
(1,), fill_value=processed_lens, dtype=torch.int64
|
||||
).numpy(),
|
||||
}
|
||||
encoder_output = ["encoder_out", "new_cached_attn", "new_cached_conv"]
|
||||
|
||||
return encoder_input, encoder_output
|
||||
|
||||
def _update_states(self, states: List[np.ndarray]):
|
||||
self.cached_attn = states[0]
|
||||
self.cached_conv = states[1]
|
||||
|
||||
def run_encoder(self, x: torch.Tensor, num_processed_frames: int) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, self.T, C). It only implements N == 1
|
||||
num_processed_frames:
|
||||
Number of processed frames before subsampling.
|
||||
Returns:
|
||||
Return a 3-D tensor of shape (N, chunk_size, joiner_dim)
|
||||
"""
|
||||
# assume subsampling_factor is 4
|
||||
num_processed_frames = num_processed_frames // 4
|
||||
encoder_input, encoder_output_names = self._build_encoder_input_output(
|
||||
x, num_processed_frames
|
||||
)
|
||||
out = self.encoder.run(encoder_output_names, encoder_input)
|
||||
|
||||
self._update_states(out[1:])
|
||||
|
||||
return torch.from_numpy(out[0])
|
||||
|
||||
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
decoder_input:
|
||||
A 2-D tensor of shape (N, context_size)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
out = self.decoder.run(
|
||||
[self.decoder.get_outputs()[0].name],
|
||||
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(out)
|
||||
|
||||
def run_joiner(
|
||||
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
decoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, vocab_size)
|
||||
"""
|
||||
out = self.joiner.run(
|
||||
[self.joiner.get_outputs()[0].name],
|
||||
{
|
||||
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
|
||||
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
|
||||
},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(out)
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0].contiguous())
|
||||
return ans
|
||||
|
||||
|
||||
def create_streaming_feature_extractor() -> OnlineFeature:
|
||||
"""Create a CPU streaming feature extractor.
|
||||
|
||||
At present, we assume it returns a fbank feature extractor with
|
||||
fixed options. In the future, we will support passing in the options
|
||||
from outside.
|
||||
|
||||
Returns:
|
||||
Return a CPU streaming feature extractor.
|
||||
"""
|
||||
opts = FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
return OnlineFbank(opts)
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: OnnxModel,
|
||||
encoder_out: torch.Tensor,
|
||||
context_size: int,
|
||||
decoder_out: Optional[torch.Tensor] = None,
|
||||
hyp: Optional[List[int]] = None,
|
||||
) -> List[int]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (1, T, joiner_dim)
|
||||
context_size:
|
||||
The context size of the decoder model.
|
||||
decoder_out:
|
||||
Optional. Decoder output of the previous chunk.
|
||||
hyp:
|
||||
Decoding results for previous chunks.
|
||||
Returns:
|
||||
Return the decoded results so far.
|
||||
"""
|
||||
|
||||
blank_id = 0
|
||||
|
||||
if decoder_out is None:
|
||||
assert hyp is None, hyp
|
||||
hyp = [blank_id] * context_size
|
||||
decoder_input = torch.tensor([hyp], dtype=torch.int64)
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
else:
|
||||
assert hyp is not None, hyp
|
||||
|
||||
encoder_out = encoder_out.squeeze(0)
|
||||
T = encoder_out.size(0)
|
||||
for t in range(T):
|
||||
cur_encoder_out = encoder_out[t : t + 1]
|
||||
joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
|
||||
y = joiner_out.argmax(dim=0).item()
|
||||
if y != blank_id:
|
||||
hyp.append(y)
|
||||
decoder_input = hyp[-context_size:]
|
||||
decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
|
||||
return hyp, decoder_out
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
model = OnnxModel(
|
||||
encoder_model_filename=args.encoder_model_filename,
|
||||
decoder_model_filename=args.decoder_model_filename,
|
||||
joiner_model_filename=args.joiner_model_filename,
|
||||
)
|
||||
|
||||
sample_rate = 16000
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
online_fbank = create_streaming_feature_extractor()
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_file}")
|
||||
waves = read_sound_files(
|
||||
filenames=[args.sound_file],
|
||||
expected_sample_rate=sample_rate,
|
||||
)[0]
|
||||
|
||||
tail_padding = torch.zeros(int(1.0 * sample_rate), dtype=torch.float32)
|
||||
wave_samples = torch.cat([waves, tail_padding])
|
||||
|
||||
num_processed_frames = 0
|
||||
segment = model.segment
|
||||
offset = model.offset
|
||||
|
||||
context_size = model.context_size
|
||||
hyp = None
|
||||
decoder_out = None
|
||||
|
||||
chunk = int(1 * sample_rate) # 1 second
|
||||
start = 0
|
||||
while start < wave_samples.numel():
|
||||
end = min(start + chunk, wave_samples.numel())
|
||||
samples = wave_samples[start:end]
|
||||
start += chunk
|
||||
|
||||
online_fbank.accept_waveform(
|
||||
sampling_rate=sample_rate,
|
||||
waveform=samples,
|
||||
)
|
||||
|
||||
while online_fbank.num_frames_ready - num_processed_frames >= segment:
|
||||
frames = []
|
||||
for i in range(segment):
|
||||
frames.append(online_fbank.get_frame(num_processed_frames + i))
|
||||
num_processed_frames += offset
|
||||
frames = torch.cat(frames, dim=0)
|
||||
frames = frames.unsqueeze(0)
|
||||
encoder_out = model.run_encoder(frames, num_processed_frames)
|
||||
hyp, decoder_out = greedy_search(
|
||||
model,
|
||||
encoder_out,
|
||||
context_size,
|
||||
decoder_out,
|
||||
hyp,
|
||||
)
|
||||
|
||||
symbol_table = k2.SymbolTable.from_file(args.tokens)
|
||||
|
||||
text = ""
|
||||
for i in hyp[context_size:]:
|
||||
text += symbol_table[i]
|
||||
text = text.replace("▁", " ").strip()
|
||||
|
||||
logging.info(args.sound_file)
|
||||
logging.info(text)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -25,29 +25,53 @@ from lhotse import CutSet, load_manifest_lazy
|
||||
|
||||
|
||||
class MultiDataset:
|
||||
def __init__(self, manifest_dir: str):
|
||||
def __init__(self, manifest_dir: str, cv_manifest_dir: str):
|
||||
"""
|
||||
Args:
|
||||
manifest_dir:
|
||||
It is expected to contain the following files:
|
||||
|
||||
- multidataset_split_1998/multidataset/multidataset_cuts_train.*.jsonl.gz
|
||||
- librispeech_cuts_train-all-shuf.jsonl.gz
|
||||
- gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz
|
||||
|
||||
cv_manifest_dir:
|
||||
It is expected to contain the following files:
|
||||
|
||||
- cv-en_cuts_train.jsonl.gz
|
||||
"""
|
||||
self.manifest_dir = Path(manifest_dir)
|
||||
self.cv_manifest_dir = Path(cv_manifest_dir)
|
||||
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get multidataset train cuts")
|
||||
|
||||
filenames = glob.glob(
|
||||
f"{self.manifest_dir}/multidataset_split_1998/multidataset/multidataset_cuts_train.*.jsonl.gz"
|
||||
# LibriSpeech
|
||||
logging.info(f"Loading LibriSpeech in lazy mode")
|
||||
librispeech_cuts = load_manifest_lazy(
|
||||
self.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
|
||||
)
|
||||
|
||||
pattern = re.compile(r"multidataset_cuts_train.([0-9]+).jsonl.gz")
|
||||
# GigaSpeech
|
||||
filenames = glob.glob(
|
||||
f"{self.manifest_dir}/gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz"
|
||||
)
|
||||
|
||||
pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz")
|
||||
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
|
||||
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
|
||||
|
||||
sorted_filenames = [f[1] for f in idx_filenames]
|
||||
|
||||
logging.info(f"Loading {len(sorted_filenames)} splits")
|
||||
logging.info(f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode")
|
||||
|
||||
return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames)
|
||||
gigaspeech_cuts = lhotse.combine(
|
||||
lhotse.load_manifest_lazy(p) for p in sorted_filenames
|
||||
)
|
||||
|
||||
# CommonVoice
|
||||
logging.info(f"Loading CommonVoice in lazy mode")
|
||||
commonvoice_cuts = load_manifest_lazy(
|
||||
self.cv_manifest_dir / f"cv-en_cuts_train.jsonl.gz"
|
||||
)
|
||||
|
||||
return CutSet.mux(librispeech_cuts, gigaspeech_cuts, commonvoice_cuts)
|
||||
|
@ -60,13 +60,13 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from multidataset import MultiDataset
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from multidataset import MultiDataset
|
||||
from optim import Eden, ScaledAdam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
@ -1053,10 +1053,12 @@ def run(rank, world_size, args):
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
if params.use_multidataset:
|
||||
multidataset = MultiDataset(params.manifest_dir)
|
||||
multidataset = MultiDataset(params.manifest_dir, params.cv_manifest_dir)
|
||||
train_cuts = multidataset.train_cuts()
|
||||
else:
|
||||
if params.full_libri:
|
||||
if params.mini_libri:
|
||||
train_cuts = librispeech.train_clean_5_cuts()
|
||||
elif params.full_libri:
|
||||
train_cuts = librispeech.train_all_shuf_cuts()
|
||||
else:
|
||||
train_cuts = librispeech.train_clean_100_cuts()
|
||||
@ -1108,8 +1110,11 @@ def run(rank, world_size, args):
|
||||
train_cuts, sampler_state_dict=sampler_state_dict
|
||||
)
|
||||
|
||||
valid_cuts = librispeech.dev_clean_cuts()
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
if params.mini_libri:
|
||||
valid_cuts = librispeech.dev_clean_2_cuts()
|
||||
else:
|
||||
valid_cuts = librispeech.dev_clean_cuts()
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
|
||||
if not params.use_multidataset and not params.print_diagnostics:
|
||||
|
@ -123,10 +123,13 @@ from beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
modified_beam_search_lm_rescore,
|
||||
modified_beam_search_lm_rescore_LODR,
|
||||
modified_beam_search_lm_shallow_fusion,
|
||||
modified_beam_search_LODR,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall import LmScorer, NgramLm
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
@ -134,7 +137,6 @@ from icefall.checkpoint import (
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.lm_wrapper import LmScorer
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
@ -336,6 +338,21 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens-ngram",
|
||||
type=int,
|
||||
default=2,
|
||||
help="""The order of the ngram lm.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--backoff-id",
|
||||
type=int,
|
||||
default=500,
|
||||
help="ID of the backoff symbol in the ngram LM",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -349,6 +366,8 @@ def decode_one_batch(
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
LM: Optional[LmScorer] = None,
|
||||
ngram_lm=None,
|
||||
ngram_lm_scale: float = 0.0,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
@ -483,6 +502,18 @@ def decode_one_batch(
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search_LODR":
|
||||
hyp_tokens = modified_beam_search_LODR(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
LODR_lm=ngram_lm,
|
||||
LODR_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search_lm_rescore":
|
||||
lm_scale_list = [0.01 * i for i in range(10, 50)]
|
||||
ans_dict = modified_beam_search_lm_rescore(
|
||||
@ -493,6 +524,18 @@ def decode_one_batch(
|
||||
LM=LM,
|
||||
lm_scale_list=lm_scale_list,
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
|
||||
lm_scale_list = [0.02 * i for i in range(2, 30)]
|
||||
ans_dict = modified_beam_search_lm_rescore_LODR(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
LM=LM,
|
||||
LODR_lm=ngram_lm,
|
||||
sp=sp,
|
||||
lm_scale_list=lm_scale_list,
|
||||
)
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
@ -531,7 +574,10 @@ def decode_one_batch(
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
elif params.decoding_method == "modified_beam_search_lm_rescore":
|
||||
elif params.decoding_method in (
|
||||
"modified_beam_search_lm_rescore",
|
||||
"modified_beam_search_lm_rescore_LODR",
|
||||
):
|
||||
ans = dict()
|
||||
assert ans_dict is not None
|
||||
for key, hyps in ans_dict.items():
|
||||
@ -550,6 +596,8 @@ def decode_dataset(
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
LM: Optional[LmScorer] = None,
|
||||
ngram_lm=None,
|
||||
ngram_lm_scale: float = 0.0,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
@ -568,6 +616,8 @@ def decode_dataset(
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
ngram_lm:
|
||||
A n-gram LM to be used for LODR.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
@ -600,6 +650,8 @@ def decode_dataset(
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
LM=LM,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
@ -677,8 +729,10 @@ def main():
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
"modified_beam_search_LODR",
|
||||
"modified_beam_search_lm_shallow_fusion",
|
||||
"modified_beam_search_lm_rescore",
|
||||
"modified_beam_search_lm_rescore_LODR",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
@ -822,7 +876,12 @@ def main():
|
||||
model.eval()
|
||||
|
||||
# only load the neural network LM if required
|
||||
if params.use_shallow_fusion or "lm" in params.decoding_method:
|
||||
if params.use_shallow_fusion or params.decoding_method in (
|
||||
"modified_beam_search_lm_rescore",
|
||||
"modified_beam_search_lm_rescore_LODR",
|
||||
"modified_beam_search_lm_shallow_fusion",
|
||||
"modified_beam_search_LODR",
|
||||
):
|
||||
LM = LmScorer(
|
||||
lm_type=params.lm_type,
|
||||
params=params,
|
||||
@ -834,6 +893,35 @@ def main():
|
||||
else:
|
||||
LM = None
|
||||
|
||||
# only load N-gram LM when needed
|
||||
if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
|
||||
try:
|
||||
import kenlm
|
||||
except ImportError:
|
||||
print("Please install kenlm first. You can use")
|
||||
print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
|
||||
print("to install it")
|
||||
import sys
|
||||
|
||||
sys.exit(-1)
|
||||
ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
|
||||
logging.info(f"lm filename: {ngram_file_name}")
|
||||
ngram_lm = kenlm.Model(ngram_file_name)
|
||||
|
||||
elif params.decoding_method == "modified_beam_search_LODR":
|
||||
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
|
||||
logging.info(f"Loading token level lm: {lm_filename}")
|
||||
ngram_lm = NgramLm(
|
||||
str(params.lang_dir / lm_filename),
|
||||
backoff_id=params.backoff_id,
|
||||
is_binary=False,
|
||||
)
|
||||
logging.info(f"num states: {ngram_lm.lm.num_states}")
|
||||
ngram_lm_scale = params.ngram_lm_scale
|
||||
else:
|
||||
ngram_lm = None
|
||||
ngram_lm_scale = None
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
@ -866,8 +954,10 @@ def main():
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
import time
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
start = time.time()
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
@ -876,7 +966,10 @@ def main():
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
LM=LM,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
)
|
||||
logging.info(f"Elasped time for {test_set}: {time.time() - start}")
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
|
@ -1049,10 +1049,13 @@ def run(rank, world_size, args):
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
train_cuts = librispeech.train_clean_100_cuts()
|
||||
if params.full_libri:
|
||||
train_cuts += librispeech.train_clean_360_cuts()
|
||||
train_cuts += librispeech.train_other_500_cuts()
|
||||
if params.mini_libri:
|
||||
train_cuts = librispeech.train_clean_5_cuts()
|
||||
else:
|
||||
train_cuts = librispeech.train_clean_100_cuts()
|
||||
if params.full_libri:
|
||||
train_cuts += librispeech.train_clean_360_cuts()
|
||||
train_cuts += librispeech.train_other_500_cuts()
|
||||
|
||||
def remove_short_and_long_utt(c: Cut):
|
||||
# Keep only utterances with duration between 1 second and 20 seconds
|
||||
@ -1104,8 +1107,11 @@ def run(rank, world_size, args):
|
||||
train_cuts, sampler_state_dict=sampler_state_dict
|
||||
)
|
||||
|
||||
valid_cuts = librispeech.dev_clean_cuts()
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
if params.mini_libri:
|
||||
valid_cuts = librispeech.dev_clean_2_cuts()
|
||||
else:
|
||||
valid_cuts = librispeech.dev_clean_cuts()
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
|
||||
# if not params.print_diagnostics:
|
||||
|
@ -86,8 +86,16 @@ class LibriSpeechAsrDataModule:
|
||||
"--full-libri",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
|
||||
help="""Used only when --mini-libri is False.When enabled,
|
||||
use 960h LibriSpeech. Otherwise, use 100h subset.""",
|
||||
)
|
||||
group.add_argument(
|
||||
"--mini-libri",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="True for mini librispeech",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
@ -393,6 +401,13 @@ class LibriSpeechAsrDataModule:
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_5_cuts(self) -> CutSet:
|
||||
logging.info("mini_librispeech: About to get train-clean-5 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_100_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-100 cuts")
|
||||
@ -424,6 +439,13 @@ class LibriSpeechAsrDataModule:
|
||||
self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_clean_2_cuts(self) -> CutSet:
|
||||
logging.info("mini_librispeech: About to get dev-clean-2 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-clean cuts")
|
||||
|
6
egs/timit/ASR/local/compile_hlg.py
Normal file → Executable file
6
egs/timit/ASR/local/compile_hlg.py
Normal file → Executable file
@ -100,7 +100,11 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
|
||||
|
||||
logging.info("Removing disambiguation symbols on LG")
|
||||
|
||||
LG.labels[LG.labels >= first_token_disambig_id] = 0
|
||||
# LG.labels[LG.labels >= first_token_disambig_id] = 0
|
||||
# see https://github.com/k2-fsa/k2/pull/1140
|
||||
labels = LG.labels
|
||||
labels[labels >= first_token_disambig_id] = 0
|
||||
LG.labels = labels
|
||||
|
||||
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
|
||||
|
||||
|
0
egs/timit/ASR/local/compute_fbank_timit.py
Normal file → Executable file
0
egs/timit/ASR/local/compute_fbank_timit.py
Normal file → Executable file
0
egs/timit/ASR/local/prepare_lang.py
Normal file → Executable file
0
egs/timit/ASR/local/prepare_lang.py
Normal file → Executable file
0
egs/timit/ASR/local/prepare_lexicon.py
Normal file → Executable file
0
egs/timit/ASR/local/prepare_lexicon.py
Normal file → Executable file
@ -59,7 +59,9 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||
# using: `sudo apt-get install git-lfs && git-lfs install`
|
||||
[ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm
|
||||
git clone https://huggingface.co/luomingshuang/timit_lm $dl_dir/lm
|
||||
cd $dl_dir/lm && git lfs pull
|
||||
pushd $dl_dir/lm
|
||||
git lfs pull
|
||||
popd
|
||||
fi
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
|
@ -78,10 +78,11 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
|
||||
|
||||
logging.info("Removing disambiguation symbols on LG")
|
||||
|
||||
LG.labels[LG.labels >= first_token_disambig_id] = 0
|
||||
# See https://github.com/k2-fsa/k2/issues/874
|
||||
# for why we need to set LG.properties to None
|
||||
LG.__dict__["_properties"] = None
|
||||
# LG.labels[LG.labels >= first_token_disambig_id] = 0
|
||||
# see https://github.com/k2-fsa/k2/pull/1140
|
||||
labels = LG.labels
|
||||
labels[labels >= first_token_disambig_id] = 0
|
||||
LG.labels = labels
|
||||
|
||||
assert isinstance(LG.aux_labels, k2.RaggedTensor)
|
||||
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
|
||||
|
@ -8,6 +8,12 @@ from . import (
|
||||
utils
|
||||
)
|
||||
|
||||
from .byte_utils import (
|
||||
byte_decode,
|
||||
byte_encode,
|
||||
smart_byte_decode,
|
||||
)
|
||||
|
||||
from .checkpoint import (
|
||||
average_checkpoints,
|
||||
find_checkpoints,
|
||||
@ -49,6 +55,7 @@ from .utils import (
|
||||
get_alignments,
|
||||
get_executor,
|
||||
get_texts,
|
||||
is_cjk,
|
||||
is_jit_tracing,
|
||||
is_module_available,
|
||||
l1_norm,
|
||||
@ -64,6 +71,7 @@ from .utils import (
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
subsequent_chunk_mask,
|
||||
tokenize_by_CJK_char,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
311
icefall/byte_utils.py
Normal file
311
icefall/byte_utils.py
Normal file
@ -0,0 +1,311 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
# This file was copied and modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/encoders/byte_utils.py
|
||||
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
|
||||
WHITESPACE_NORMALIZER = re.compile(r"\s+")
|
||||
SPACE = chr(32)
|
||||
SPACE_ESCAPE = chr(9601)
|
||||
|
||||
PRINTABLE_BASE_CHARS = [
|
||||
256,
|
||||
257,
|
||||
258,
|
||||
259,
|
||||
260,
|
||||
261,
|
||||
262,
|
||||
263,
|
||||
264,
|
||||
265,
|
||||
266,
|
||||
267,
|
||||
268,
|
||||
269,
|
||||
270,
|
||||
271,
|
||||
272,
|
||||
273,
|
||||
274,
|
||||
275,
|
||||
276,
|
||||
277,
|
||||
278,
|
||||
279,
|
||||
280,
|
||||
281,
|
||||
282,
|
||||
283,
|
||||
284,
|
||||
285,
|
||||
286,
|
||||
287,
|
||||
32,
|
||||
33,
|
||||
34,
|
||||
35,
|
||||
36,
|
||||
37,
|
||||
38,
|
||||
39,
|
||||
40,
|
||||
41,
|
||||
42,
|
||||
43,
|
||||
44,
|
||||
45,
|
||||
46,
|
||||
47,
|
||||
48,
|
||||
49,
|
||||
50,
|
||||
51,
|
||||
52,
|
||||
53,
|
||||
54,
|
||||
55,
|
||||
56,
|
||||
57,
|
||||
58,
|
||||
59,
|
||||
60,
|
||||
61,
|
||||
62,
|
||||
63,
|
||||
64,
|
||||
65,
|
||||
66,
|
||||
67,
|
||||
68,
|
||||
69,
|
||||
70,
|
||||
71,
|
||||
72,
|
||||
73,
|
||||
74,
|
||||
75,
|
||||
76,
|
||||
77,
|
||||
78,
|
||||
79,
|
||||
80,
|
||||
81,
|
||||
82,
|
||||
83,
|
||||
84,
|
||||
85,
|
||||
86,
|
||||
87,
|
||||
88,
|
||||
89,
|
||||
90,
|
||||
91,
|
||||
92,
|
||||
93,
|
||||
94,
|
||||
95,
|
||||
96,
|
||||
97,
|
||||
98,
|
||||
99,
|
||||
100,
|
||||
101,
|
||||
102,
|
||||
103,
|
||||
104,
|
||||
105,
|
||||
106,
|
||||
107,
|
||||
108,
|
||||
109,
|
||||
110,
|
||||
111,
|
||||
112,
|
||||
113,
|
||||
114,
|
||||
115,
|
||||
116,
|
||||
117,
|
||||
118,
|
||||
119,
|
||||
120,
|
||||
121,
|
||||
122,
|
||||
123,
|
||||
124,
|
||||
125,
|
||||
126,
|
||||
288,
|
||||
289,
|
||||
290,
|
||||
291,
|
||||
292,
|
||||
293,
|
||||
294,
|
||||
295,
|
||||
296,
|
||||
297,
|
||||
298,
|
||||
299,
|
||||
300,
|
||||
301,
|
||||
302,
|
||||
303,
|
||||
304,
|
||||
305,
|
||||
308,
|
||||
309,
|
||||
310,
|
||||
311,
|
||||
312,
|
||||
313,
|
||||
314,
|
||||
315,
|
||||
316,
|
||||
317,
|
||||
318,
|
||||
321,
|
||||
322,
|
||||
323,
|
||||
324,
|
||||
325,
|
||||
326,
|
||||
327,
|
||||
328,
|
||||
330,
|
||||
331,
|
||||
332,
|
||||
333,
|
||||
334,
|
||||
335,
|
||||
336,
|
||||
337,
|
||||
338,
|
||||
339,
|
||||
340,
|
||||
341,
|
||||
342,
|
||||
343,
|
||||
344,
|
||||
345,
|
||||
346,
|
||||
347,
|
||||
348,
|
||||
349,
|
||||
350,
|
||||
351,
|
||||
352,
|
||||
353,
|
||||
354,
|
||||
355,
|
||||
356,
|
||||
357,
|
||||
358,
|
||||
359,
|
||||
360,
|
||||
361,
|
||||
362,
|
||||
363,
|
||||
364,
|
||||
365,
|
||||
366,
|
||||
367,
|
||||
368,
|
||||
369,
|
||||
370,
|
||||
371,
|
||||
372,
|
||||
373,
|
||||
374,
|
||||
375,
|
||||
376,
|
||||
377,
|
||||
378,
|
||||
379,
|
||||
380,
|
||||
381,
|
||||
382,
|
||||
384,
|
||||
385,
|
||||
386,
|
||||
387,
|
||||
388,
|
||||
389,
|
||||
390,
|
||||
391,
|
||||
392,
|
||||
393,
|
||||
394,
|
||||
395,
|
||||
396,
|
||||
397,
|
||||
398,
|
||||
399,
|
||||
400,
|
||||
401,
|
||||
402,
|
||||
403,
|
||||
404,
|
||||
405,
|
||||
406,
|
||||
407,
|
||||
408,
|
||||
409,
|
||||
410,
|
||||
411,
|
||||
412,
|
||||
413,
|
||||
414,
|
||||
415,
|
||||
416,
|
||||
417,
|
||||
418,
|
||||
419,
|
||||
420,
|
||||
421,
|
||||
422,
|
||||
]
|
||||
|
||||
for c in PRINTABLE_BASE_CHARS:
|
||||
assert unicodedata.normalize("NFKC", chr(c)) == chr(c), c
|
||||
|
||||
BYTE_TO_BCHAR = {b: chr(PRINTABLE_BASE_CHARS[b]) for b in range(256)}
|
||||
BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()}
|
||||
|
||||
|
||||
def byte_encode(x: str) -> str:
|
||||
normalized = WHITESPACE_NORMALIZER.sub(SPACE, x)
|
||||
return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")])
|
||||
|
||||
|
||||
def byte_decode(x: str) -> str:
|
||||
try:
|
||||
return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8")
|
||||
except ValueError:
|
||||
return ""
|
||||
|
||||
|
||||
def smart_byte_decode(x: str) -> str:
|
||||
output = byte_decode(x)
|
||||
if output == "":
|
||||
# DP the best recovery (max valid chars) if it's broken
|
||||
n_bytes = len(x)
|
||||
f = [0 for _ in range(n_bytes + 1)]
|
||||
pt = [0 for _ in range(n_bytes + 1)]
|
||||
for i in range(1, n_bytes + 1):
|
||||
f[i], pt[i] = f[i - 1], i - 1
|
||||
for j in range(1, min(4, i) + 1):
|
||||
if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0:
|
||||
f[i], pt[i] = f[i - j] + 1, i - j
|
||||
cur_pt = n_bytes
|
||||
while cur_pt > 0:
|
||||
if f[cur_pt] == f[pt[cur_pt]] + 1:
|
||||
output = byte_decode(x[pt[cur_pt] : cur_pt]) + output
|
||||
cur_pt = pt[cur_pt]
|
||||
return output
|
1
icefall/rnn_lm/.gitignore
vendored
Normal file
1
icefall/rnn_lm/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
icefall-librispeech-rnn-lm
|
129
icefall/rnn_lm/check-onnx-streaming.py
Executable file
129
icefall/rnn_lm/check-onnx-streaming.py
Executable file
@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation
|
||||
|
||||
"""
|
||||
Usage:
|
||||
|
||||
./check-onnx-streaming.py \
|
||||
--jit ./icefall-librispeech-rnn-lm/exp/cpu_jit.pt \
|
||||
--onnx ./icefall-librispeech-rnn-lm/exp/with-state-epoch-99-avg-1.onnx
|
||||
|
||||
Note: You can download pre-trained models from
|
||||
https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the torchscript model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx model",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(self, filename: str):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.model = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
meta_data = self.model.get_modelmeta().custom_metadata_map
|
||||
self.sos_id = int(meta_data["sos_id"])
|
||||
self.eos_id = int(meta_data["eos_id"])
|
||||
self.vocab_size = int(meta_data["vocab_size"])
|
||||
self.num_layers = int(meta_data["num_layers"])
|
||||
self.hidden_size = int(meta_data["hidden_size"])
|
||||
print(meta_data)
|
||||
|
||||
def __call__(
|
||||
self, x: torch.Tensor, y: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
out = self.model.run(
|
||||
[
|
||||
self.model.get_outputs()[0].name,
|
||||
self.model.get_outputs()[1].name,
|
||||
self.model.get_outputs()[2].name,
|
||||
],
|
||||
{
|
||||
self.model.get_inputs()[0].name: x.numpy(),
|
||||
self.model.get_inputs()[1].name: y.numpy(),
|
||||
self.model.get_inputs()[2].name: h0.numpy(),
|
||||
self.model.get_inputs()[3].name: c0.numpy(),
|
||||
},
|
||||
)
|
||||
return (
|
||||
torch.from_numpy(out[0]),
|
||||
torch.from_numpy(out[1]),
|
||||
torch.from_numpy(out[2]),
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
torch_model = torch.jit.load(args.jit).cpu()
|
||||
onnx_model = OnnxModel(args.onnx)
|
||||
N = torch.arange(1, 5).tolist()
|
||||
|
||||
num_layers = onnx_model.num_layers
|
||||
hidden_size = onnx_model.hidden_size
|
||||
|
||||
for n in N:
|
||||
L = torch.randint(low=1, high=100, size=(1,)).item()
|
||||
x = torch.randint(
|
||||
low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
|
||||
)
|
||||
h0 = torch.rand(num_layers, n, hidden_size)
|
||||
c0 = torch.rand(num_layers, n, hidden_size)
|
||||
|
||||
torch_log_prob, torch_h0, torch_c0 = torch_model.score_token_onnx(x, h0, c0)
|
||||
onnx_log_prob, onnx_h0, onnx_c0 = onnx_model(x, h0, c0)
|
||||
|
||||
for torch_v, onnx_v in zip(
|
||||
(torch_log_prob, torch_h0, torch_c0), (onnx_log_prob, onnx_h0, onnx_c0)
|
||||
):
|
||||
|
||||
assert torch.allclose(torch_v, onnx_v, atol=1e-5), (
|
||||
torch_v.shape,
|
||||
onnx_v.shape,
|
||||
(torch_v - onnx_v).abs().max(),
|
||||
)
|
||||
print(n, L, torch_v.sum(), onnx_v.sum())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(20230423)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
119
icefall/rnn_lm/check-onnx.py
Executable file
119
icefall/rnn_lm/check-onnx.py
Executable file
@ -0,0 +1,119 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation
|
||||
|
||||
"""
|
||||
Usage:
|
||||
|
||||
./check-onnx.py \
|
||||
--jit ./icefall-librispeech-rnn-lm/exp/cpu_jit.pt \
|
||||
--onnx ./icefall-librispeech-rnn-lm/exp/no-state-epoch-99-avg-1.onnx
|
||||
|
||||
Note: You can download pre-trained models from
|
||||
https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the torchscript model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx model",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(self, filename: str):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.model = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
meta_data = self.model.get_modelmeta().custom_metadata_map
|
||||
self.sos_id = int(meta_data["sos_id"])
|
||||
self.eos_id = int(meta_data["eos_id"])
|
||||
self.vocab_size = int(meta_data["vocab_size"])
|
||||
print(meta_data)
|
||||
|
||||
def __call__(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
|
||||
out = self.model.run(
|
||||
[
|
||||
self.model.get_outputs()[0].name,
|
||||
],
|
||||
{
|
||||
self.model.get_inputs()[0].name: x.numpy(),
|
||||
self.model.get_inputs()[1].name: x_lens.numpy(),
|
||||
},
|
||||
)
|
||||
return torch.from_numpy(out[0])
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
torch_model = torch.jit.load(args.jit).cpu()
|
||||
onnx_model = OnnxModel(args.onnx)
|
||||
N = torch.arange(1, 5).tolist()
|
||||
|
||||
for n in N:
|
||||
L = torch.randint(low=1, high=100, size=(1,)).item()
|
||||
x = torch.randint(
|
||||
low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
|
||||
)
|
||||
x_lens = torch.full((n,), fill_value=L, dtype=torch.int64)
|
||||
if n > 1:
|
||||
x_lens[0] = L // 2 + 1
|
||||
|
||||
sos = torch.full((1,), fill_value=onnx_model.sos_id).expand(n, 1)
|
||||
sos_x = torch.cat([sos, x], dim=1)
|
||||
|
||||
pad_col = torch.zeros((1,), dtype=x.dtype).expand(n, 1)
|
||||
x_eos = torch.cat([x, pad_col], dim=1)
|
||||
|
||||
row_index = torch.arange(0, n, dtype=x.dtype)
|
||||
x_eos[row_index, x_lens] = onnx_model.eos_id
|
||||
|
||||
torch_nll = torch_model(sos_x, x_eos, x_lens + 1).sum(dim=-1)
|
||||
onnx_nll = onnx_model(x, x_lens)
|
||||
# Note: For int8 models, the differences may be quite large,
|
||||
# e.g., within 0.9
|
||||
assert torch.allclose(torch_nll, onnx_nll), (
|
||||
torch_nll,
|
||||
onnx_nll,
|
||||
)
|
||||
print(n, L, torch_nll, onnx_nll)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(20230420)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
395
icefall/rnn_lm/export-onnx.py
Executable file
395
icefall/rnn_lm/export-onnx.py
Executable file
@ -0,0 +1,395 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
from model import RnnLmModel
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||
from icefall.utils import AttributeDict, str2bool
|
||||
from typing import Dict
|
||||
from train import get_params
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
"""Add meta data to an ONNX model. It is changed in-place.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename of the ONNX model to be changed.
|
||||
meta_data:
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = value
|
||||
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
# A wrapper for RnnLm model to simpily the C++ calling code
|
||||
# when exporting the model to ONNX.
|
||||
#
|
||||
# TODO(fangjun): The current wrapper works only for non-streaming ASR
|
||||
# since we don't expose the LM state and it is used to score
|
||||
# a complete sentence at once.
|
||||
class RnnLmModelWrapper(torch.nn.Module):
|
||||
def __init__(self, model: RnnLmModel, sos_id: int, eos_id: int):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.sos_id = sos_id
|
||||
self.eos_id = eos_id
|
||||
|
||||
def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 2-D tensor of shape (N, L) with dtype torch.int64.
|
||||
It does not contain SOS or EOS. We will add SOS and EOS inside
|
||||
this function.
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,) with dtype torch.int64. It contains
|
||||
number of valid tokens in ``x`` before padding.
|
||||
Returns:
|
||||
Return a 1-D tensor of shape (N,) containing negative loglikelihood.
|
||||
Its dtype is torch.float32
|
||||
"""
|
||||
N = x.size(0)
|
||||
|
||||
sos_tensor = torch.full((1,), fill_value=self.sos_id, dtype=x.dtype).expand(
|
||||
N, 1
|
||||
)
|
||||
sos_x = torch.cat([sos_tensor, x], dim=1)
|
||||
|
||||
pad_col = torch.zeros((1,), dtype=x.dtype).expand(N, 1)
|
||||
x_eos = torch.cat([x, pad_col], dim=1)
|
||||
|
||||
row_index = torch.arange(0, N, dtype=x.dtype)
|
||||
x_eos[row_index, x_lens] = self.eos_id
|
||||
|
||||
# use x_lens + 1 here since we prepended x with sos
|
||||
return (
|
||||
self.model(x=sos_x, y=x_eos, lengths=x_lens + 1)
|
||||
.to(torch.float32)
|
||||
.sum(dim=1)
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=29,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vocab-size",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Vocabulary size of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--embedding-dim",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Embedding dim of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hidden-dim",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Hidden dim of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-layers",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Number of RNN layers the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tie-weights",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="""True to share the weights between the input embedding layer and the
|
||||
last output linear layer
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="rnn_lm/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def export_without_state(
|
||||
model: RnnLmModel,
|
||||
filename: str,
|
||||
params: AttributeDict,
|
||||
opset_version: int,
|
||||
):
|
||||
model_wrapper = RnnLmModelWrapper(
|
||||
model,
|
||||
sos_id=params.sos_id,
|
||||
eos_id=params.eos_id,
|
||||
)
|
||||
|
||||
N = 1
|
||||
L = 20
|
||||
x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
|
||||
x_lens = torch.full((N,), fill_value=L, dtype=torch.int64)
|
||||
|
||||
# Note(fangjun): The following warnings can be ignored.
|
||||
# We can use ./check-onnx.py to validate the exported model with batch_size > 1
|
||||
"""
|
||||
torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX
|
||||
with a batch_size other than 1, with a variable length with LSTM can cause
|
||||
an error when running the ONNX model with a different batch size. Make sure
|
||||
to save the model with a batch size of 1, or define the initial states
|
||||
(h0/c0) as inputs of the model. warnings.warn("Exporting a model to ONNX
|
||||
with a batch_size other than 1, " +
|
||||
"""
|
||||
|
||||
torch.onnx.export(
|
||||
model_wrapper,
|
||||
(x, x_lens),
|
||||
filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "x_lens"],
|
||||
output_names=["nll"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "L"},
|
||||
"x_lens": {0: "N"},
|
||||
"nll": {0: "N"},
|
||||
},
|
||||
)
|
||||
|
||||
meta_data = {
|
||||
"model_type": "rnnlm",
|
||||
"version": "1",
|
||||
"model_author": "k2-fsa",
|
||||
"comment": "rnnlm without state",
|
||||
"sos_id": str(params.sos_id),
|
||||
"eos_id": str(params.eos_id),
|
||||
"vocab_size": str(params.vocab_size),
|
||||
"url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm",
|
||||
}
|
||||
logging.info(f"meta_data: {meta_data}")
|
||||
|
||||
add_meta_data(filename=filename, meta_data=meta_data)
|
||||
|
||||
|
||||
def export_with_state(
|
||||
model: RnnLmModel,
|
||||
filename: str,
|
||||
params: AttributeDict,
|
||||
opset_version: int,
|
||||
):
|
||||
N = 1
|
||||
L = 20
|
||||
num_layers = model.rnn.num_layers
|
||||
hidden_size = model.rnn.hidden_size
|
||||
embedding_dim = model.embedding_dim
|
||||
|
||||
x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
|
||||
h0 = torch.zeros(num_layers, N, hidden_size)
|
||||
c0 = torch.zeros(num_layers, N, hidden_size)
|
||||
|
||||
# Note(fangjun): The following warnings can be ignored.
|
||||
# We can use ./check-onnx.py to validate the exported model with batch_size > 1
|
||||
"""
|
||||
torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX
|
||||
with a batch_size other than 1, with a variable length with LSTM can cause
|
||||
an error when running the ONNX model with a different batch size. Make sure
|
||||
to save the model with a batch size of 1, or define the initial states
|
||||
(h0/c0) as inputs of the model. warnings.warn("Exporting a model to ONNX
|
||||
with a batch_size other than 1, " +
|
||||
"""
|
||||
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(x, h0, c0),
|
||||
filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "h0", "c0"],
|
||||
output_names=["log_softmax", "next_h0", "next_c0"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "L"},
|
||||
"h0": {1: "N"},
|
||||
"c0": {1: "N"},
|
||||
"log_softmax": {0: "N"},
|
||||
"next_h0": {1: "N"},
|
||||
"next_c0": {1: "N"},
|
||||
},
|
||||
)
|
||||
|
||||
meta_data = {
|
||||
"model_type": "rnnlm",
|
||||
"version": "1",
|
||||
"model_author": "k2-fsa",
|
||||
"comment": "rnnlm state",
|
||||
"sos_id": str(params.sos_id),
|
||||
"eos_id": str(params.eos_id),
|
||||
"vocab_size": str(params.vocab_size),
|
||||
"num_layers": str(num_layers),
|
||||
"hidden_size": str(hidden_size),
|
||||
"embedding_dim": str(embedding_dim),
|
||||
"url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm",
|
||||
}
|
||||
logging.info(f"meta_data: {meta_data}")
|
||||
|
||||
add_meta_data(filename=filename, meta_data=meta_data)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
logging.info(params)
|
||||
|
||||
device = torch.device("cpu")
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
model = RnnLmModel(
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.embedding_dim,
|
||||
hidden_dim=params.hidden_dim,
|
||||
num_layers=params.num_layers,
|
||||
tie_weights=params.tie_weights,
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device), strict=False
|
||||
)
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device), strict=False
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.iter > 0:
|
||||
suffix = f"iter-{params.iter}"
|
||||
else:
|
||||
suffix = f"epoch-{params.epoch}"
|
||||
|
||||
suffix += f"-avg-{params.avg}"
|
||||
|
||||
opset_version = 13
|
||||
|
||||
logging.info("Exporting model without state")
|
||||
filename = params.exp_dir / f"no-state-{suffix}.onnx"
|
||||
export_without_state(
|
||||
model=model,
|
||||
filename=filename,
|
||||
params=params,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
filename_int8 = params.exp_dir / f"no-state-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=filename,
|
||||
model_output=filename_int8,
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
# now for streaming export
|
||||
saved_forward = model.__class__.forward
|
||||
model.__class__.forward = model.__class__.score_token_onnx
|
||||
streaming_filename = params.exp_dir / f"with-state-{suffix}.onnx"
|
||||
export_with_state(
|
||||
model=model,
|
||||
filename=streaming_filename,
|
||||
params=params,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
model.__class__.forward = saved_forward
|
||||
|
||||
streaming_filename_int8 = params.exp_dir / f"with-state-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=streaming_filename,
|
||||
model_output=streaming_filename_int8,
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
26
icefall/rnn_lm/export-onnx.sh
Executable file
26
icefall/rnn_lm/export-onnx.sh
Executable file
@ -0,0 +1,26 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# We use the model from
|
||||
# https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main
|
||||
# as an example
|
||||
|
||||
export CUDA_VISIBLE_DEVICES=
|
||||
|
||||
if [ ! -f ./icefall-librispeech-rnn-lm/exp/pretrained.pt ]; then
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
|
||||
pushd icefall-librispeech-rnn-lm/exp
|
||||
git lfs pull --include "pretrained.pt"
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
fi
|
||||
|
||||
python3 ./export-onnx.py \
|
||||
--exp-dir ./icefall-librispeech-rnn-lm/exp \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--vocab-size 500 \
|
||||
--embedding-dim 2048 \
|
||||
--hidden-dim 2048 \
|
||||
--num-layers 3 \
|
||||
--tie-weights 1
|
||||
|
@ -26,7 +26,7 @@ import torch
|
||||
from model import RnnLmModel
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||
from icefall.utils import AttributeDict, load_averaged_model, str2bool
|
||||
from icefall.utils import AttributeDict, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -118,6 +118,7 @@ def get_parser():
|
||||
return parser
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
@ -180,6 +181,10 @@ def main():
|
||||
|
||||
if params.jit:
|
||||
logging.info("Using torch.jit.script")
|
||||
|
||||
model.__class__.score_token_onnx = torch.jit.export(
|
||||
model.__class__.score_token_onnx
|
||||
)
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
|
27
icefall/rnn_lm/export.sh
Executable file
27
icefall/rnn_lm/export.sh
Executable file
@ -0,0 +1,27 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# We use the model from
|
||||
# https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main
|
||||
# as an example
|
||||
|
||||
export CUDA_VISIBLE_DEVICES=
|
||||
|
||||
if [ ! -f ./icefall-librispeech-rnn-lm/exp/pretrained.pt ]; then
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
|
||||
pushd icefall-librispeech-rnn-lm/exp
|
||||
git lfs pull --include "pretrained.pt"
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
fi
|
||||
|
||||
python3 ./export.py \
|
||||
--exp-dir ./icefall-librispeech-rnn-lm/exp \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--vocab-size 500 \
|
||||
--embedding-dim 2048 \
|
||||
--hidden-dim 2048 \
|
||||
--num-layers 3 \
|
||||
--tie-weights 1 \
|
||||
--jit 1
|
||||
|
@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -47,6 +48,11 @@ class RnnLmModel(torch.nn.Module):
|
||||
and https://arxiv.org/abs/1611.01462
|
||||
"""
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.embedding_dim = embedding_dim
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_layers = num_layers
|
||||
self.tie_weights = tie_weights
|
||||
|
||||
self.input_embedding = torch.nn.Embedding(
|
||||
num_embeddings=vocab_size,
|
||||
@ -74,6 +80,46 @@ class RnnLmModel(torch.nn.Module):
|
||||
|
||||
self.cache = {}
|
||||
|
||||
def streaming_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
h0: torch.Tensor,
|
||||
c0: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 2-D tensor of shape (N, L). We won't prepend it with SOS.
|
||||
y:
|
||||
A 2-D tensor of shape (N, L). We won't append it with EOS.
|
||||
h0:
|
||||
A 3-D tensor of shape (num_layers, N, hidden_size).
|
||||
(If proj_size > 0, then it is (num_layers, N, proj_size))
|
||||
c0:
|
||||
A 3-D tensor of shape (num_layers, N, hidden_size).
|
||||
Returns:
|
||||
Return a tuple containing 3 tensors:
|
||||
- negative loglike (nll), a 1-D tensor of shape (N,)
|
||||
- next_h0, a 3-D tensor with the same shape as h0
|
||||
- next_c0, a 3-D tensor with the same shape as c0
|
||||
"""
|
||||
assert x.ndim == y.ndim == 2, (x.ndim, y.ndim)
|
||||
assert x.shape == y.shape, (x.shape, y.shape)
|
||||
|
||||
# embedding is of shape (N, L, embedding_dim)
|
||||
embedding = self.input_embedding(x)
|
||||
# Note: We use batch_first==True
|
||||
rnn_out, (next_h0, next_c0) = self.rnn(embedding, (h0, c0))
|
||||
logits = self.output_linear(rnn_out)
|
||||
nll_loss = F.cross_entropy(
|
||||
logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none"
|
||||
)
|
||||
|
||||
batch_size = x.size(0)
|
||||
nll_loss = nll_loss.reshape(batch_size, -1).sum(dim=1)
|
||||
return nll_loss, next_h0, next_c0
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
@ -188,6 +234,33 @@ class RnnLmModel(torch.nn.Module):
|
||||
|
||||
return logits[:, 0].log_softmax(-1), states
|
||||
|
||||
def score_token_onnx(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
state_h: torch.Tensor,
|
||||
state_c: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Score a batch of tokens, i.e each sample in the batch should be a
|
||||
single token. For example, x = torch.tensor([[5],[10],[20]])
|
||||
|
||||
|
||||
Args:
|
||||
x (torch.Tensor):
|
||||
A batch of tokens
|
||||
state_h:
|
||||
state h of RNN has the shape of (num_layers, bs, hidden_dim)
|
||||
state_c:
|
||||
state c of RNN has the shape of (num_layers, bs, hidden_dim)
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
embedding = self.input_embedding(x)
|
||||
rnn_out, (next_h0, next_c0) = self.rnn(embedding, (state_h, state_c))
|
||||
logits = self.output_linear(rnn_out)
|
||||
|
||||
return logits[:, 0].log_softmax(-1), next_h0, next_c0
|
||||
|
||||
def forward_with_state(
|
||||
self, tokens, token_lens, sos_id, eos_id, blank_id, state=None
|
||||
):
|
||||
|
@ -22,6 +22,7 @@ Usage:
|
||||
--world-size 2 \
|
||||
--num-epochs 1 \
|
||||
--use-fp16 0 \
|
||||
--tie-weights 0 \
|
||||
--embedding-dim 800 \
|
||||
--hidden-dim 200 \
|
||||
--num-layers 2 \
|
||||
|
@ -1306,6 +1306,31 @@ def tokenize_by_bpe_model(
|
||||
return txt_with_bpe
|
||||
|
||||
|
||||
def tokenize_by_CJK_char(line: str) -> str:
|
||||
"""
|
||||
Tokenize a line of text with CJK char.
|
||||
|
||||
Note: All return charaters will be upper case.
|
||||
|
||||
Example:
|
||||
input = "你好世界是 hello world 的中文"
|
||||
output = "你 好 世 界 是 HELLO WORLD 的 中 文"
|
||||
|
||||
Args:
|
||||
line:
|
||||
The input text.
|
||||
|
||||
Return:
|
||||
A new string tokenize by CJK char.
|
||||
"""
|
||||
# The CJK ranges is from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py
|
||||
pattern = re.compile(
|
||||
r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])"
|
||||
)
|
||||
chars = pattern.split(line.strip().upper())
|
||||
return " ".join([w.strip() for w in chars if w.strip()])
|
||||
|
||||
|
||||
def display_and_save_batch(
|
||||
batch: dict,
|
||||
params: AttributeDict,
|
||||
@ -1764,3 +1789,34 @@ def parse_fsa_timestamps_and_texts(
|
||||
utt_time_pairs.append(list(zip(start, end)))
|
||||
|
||||
return utt_time_pairs, utt_words
|
||||
|
||||
|
||||
# Copied from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py
|
||||
def is_cjk(character):
|
||||
"""
|
||||
Python port of Moses' code to check for CJK character.
|
||||
|
||||
>>> is_cjk(u'\u33fe')
|
||||
True
|
||||
>>> is_cjk(u'\uFE5F')
|
||||
False
|
||||
|
||||
:param character: The character that needs to be checked.
|
||||
:type character: char
|
||||
:return: bool
|
||||
"""
|
||||
return any(
|
||||
[
|
||||
start <= ord(character) <= end
|
||||
for start, end in [
|
||||
(4352, 4607),
|
||||
(11904, 42191),
|
||||
(43072, 43135),
|
||||
(44032, 55215),
|
||||
(63744, 64255),
|
||||
(65072, 65103),
|
||||
(65381, 65500),
|
||||
(131072, 196607),
|
||||
]
|
||||
]
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user