Merge remote-tracking branch 'dan/master' into mmi

Fangjun Kuang 2021-10-18 14:38:50 +08:00
commit d7023c3c4b
28 changed files with 1624 additions and 479 deletions

.github/workflows/run-pretrained.yml (new file, 106 lines added)

@ -0,0 +1,106 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-pre-trained-conformer-ctc
on:
push:
branches:
- master
pull_request:
types: [labeled]
jobs:
run_pre_trained_conformer_ctc:
if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.6, 3.7, 3.8, 3.9]
torch: ["1.8.1"]
k2-version: ["1.9.dev20210919"]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip pytest
pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
python3 -m pip install kaldifeat
# We are in ./icefall, which contains a requirements.txt file
pip install -r requirements.txt
- name: Install graphviz
shell: bash
run: |
python3 -m pip install -qq graphviz
sudo apt-get -qq install graphviz
- name: Download pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
cd egs/librispeech/ASR
mkdir tmp
cd tmp
git lfs install
git clone https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500
cd ..
tree tmp
soxi tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/*.flac
ls -lh tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/*.flac
- name: Run CTC decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
cd egs/librispeech/ASR
./conformer_ctc/pretrained.py \
--num-classes 500 \
--checkpoint ./tmp/icefall-asr-conformer-ctc-bpe-500/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/bpe.model \
--method ctc-decoding \
./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1089-134686-0001.flac \
./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0001.flac \
./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0002.flac
- name: Run HLG decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
cd egs/librispeech/ASR
./conformer_ctc/pretrained.py \
--num-classes 500 \
--checkpoint ./tmp/icefall-asr-conformer-ctc-bpe-500/exp/pretrained.pt \
--words-file ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/words.txt \
--HLG ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/HLG.pt \
./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1089-134686-0001.flac \
./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0001.flac \
./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0002.flac


@ -46,10 +46,18 @@ jobs:
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Install libsndfile and libsox
if: startsWith(matrix.os, 'ubuntu')
run: |
sudo apt update
sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg
sudo apt install -q -y --fix-missing sox libsox-dev libsox-fmt-all
- name: Install Python dependencies - name: Install Python dependencies
run: | run: |
python3 -m pip install --upgrade pip pytest python3 -m pip install --upgrade pip pytest
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/ pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
pip install git+https://github.com/lhotse-speech/lhotse
# icefall requirements # icefall requirements
pip install -r requirements.txt pip install -r requirements.txt
@ -84,3 +92,7 @@ jobs:
echo "lib_path: $lib_path" echo "lib_path: $lib_path"
export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
pytest ./test pytest ./test
# run tests for conformer ctc
cd egs/librispeech/ASR/conformer_ctc
pytest

.gitignore (1 line added)

@ -1,3 +1,4 @@
icefall.egg-info/
data data
__pycache__ __pycache__
path.sh path.sh


@ -55,7 +55,22 @@ The WER for this model is:
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing) We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing)
## Deployment with C++
Once you have trained a model in icefall, you may want to deploy it with C++,
without Python dependencies.
Please refer to the documentation
<https://icefall.readthedocs.io/en/latest/recipes/librispeech/conformer_ctc.html#deployment-with-c>
for how to do this.
We also provide a Colab notebook, showing you how to run a torch scripted model in [k2][k2] with C++.
Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1BIGLWzS36isskMXHKcqC9ysN6pspYXs_?usp=sharing)
[LibriSpeech_tdnn_lstm_ctc]: egs/librispeech/ASR/tdnn_lstm_ctc [LibriSpeech_tdnn_lstm_ctc]: egs/librispeech/ASR/tdnn_lstm_ctc
[LibriSpeech_conformer_ctc]: egs/librispeech/ASR/conformer_ctc [LibriSpeech_conformer_ctc]: egs/librispeech/ASR/conformer_ctc
[yesno]: egs/yesno/ASR [yesno]: egs/yesno/ASR
[librispeech]: egs/librispeech/ASR [librispeech]: egs/librispeech/ASR
[k2]: https://github.com/k2-fsa/k2


@ -1,4 +1,4 @@
Confromer CTC Conformer CTC
============= =============
This tutorial shows you how to run a conformer ctc model This tutorial shows you how to run a conformer ctc model
@ -20,6 +20,7 @@ In this tutorial, you will learn:
- (2) How to start the training, either with a single GPU or multiple GPUs - (2) How to start the training, either with a single GPU or multiple GPUs
- (3) How to do decoding after training, with n-gram LM rescoring and attention decoder rescoring - (3) How to do decoding after training, with n-gram LM rescoring and attention decoder rescoring
- (4) How to use a pre-trained model, provided by us - (4) How to use a pre-trained model, provided by us
- (5) How to deploy your trained model in C++, without Python dependencies
Data preparation Data preparation
---------------- ----------------
@ -292,16 +293,25 @@ The commonly used options are:
- ``--method`` - ``--method``
This specifies the decoding method. This specifies the decoding method. This script supports 7 decoding methods.
For ctc-decoding, it uses a sentence piece model to convert word pieces to words,
and it needs neither a lexicon nor an n-gram LM.
The following command uses attention decoder for rescoring: For example, the following command uses CTC topology for decoding:
.. code-block:: .. code-block::
$ cd egs/librispeech/ASR $ cd egs/librispeech/ASR
$ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --lattice-score-scale 0.5 $ ./conformer_ctc/decode.py --method ctc-decoding --max-duration 300
- ``--lattice-score-scale`` And the following command uses attention decoder for rescoring:
.. code-block::
$ cd egs/librispeech/ASR
$ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --nbest-scale 0.5
- ``--nbest-scale``
It is used to scale down lattice scores so that there are more unique It is used to scale down lattice scores so that there are more unique
paths for rescoring. paths for rescoring.
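
Conceptually, the scale is applied to ``lattice.scores`` before n-best paths are
sampled, which flattens the distribution over paths and yields more diverse
candidates. A minimal sketch of the idea (not the exact icefall implementation;
it assumes ``k2.random_paths`` is used for the sampling step):

.. code-block:: python

   import k2

   def sample_nbest_paths(lattice: k2.Fsa, num_paths: int, nbest_scale: float):
       # Scale down the scores so that high-scoring paths dominate less
       # during sampling, producing more unique paths for rescoring.
       saved_scores = lattice.scores.clone()
       lattice.scores = saved_scores * nbest_scale
       paths = k2.random_paths(
           lattice, num_paths=num_paths, use_double_scores=True
       )
       # Restore the original scores before rescoring the sampled paths.
       lattice.scores = saved_scores
       return paths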
@ -311,6 +321,61 @@ The commonly used options are:
It has the same meaning as the one during training. A larger It has the same meaning as the one during training. A larger
value may cause OOM. value may cause OOM.
Here are some results for CTC decoding with a vocab size of 500:
Usage:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./conformer_ctc/decode.py \
--epoch 25 \
--avg 1 \
--max-duration 300 \
--exp-dir conformer_ctc/exp \
--lang-dir data/lang_bpe_500 \
--method ctc-decoding
The output is given below:
.. code-block:: bash
2021-09-26 12:44:31,033 INFO [decode.py:537] Decoding started
2021-09-26 12:44:31,033 INFO [decode.py:538]
{'lm_dir': PosixPath('data/lm'), 'subsampling_factor': 4, 'vgg_frontend': False, 'use_feat_batchnorm': True,
'feature_dim': 80, 'nhead': 8, 'attention_dim': 512, 'num_decoder_layers': 6, 'search_beam': 20, 'output_beam': 8,
'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True,
'epoch': 25, 'avg': 1, 'method': 'ctc-decoding', 'num_paths': 100, 'nbest_scale': 0.5,
'export': False, 'exp_dir': PosixPath('conformer_ctc/exp'), 'lang_dir': PosixPath('data/lang_bpe_500'), 'full_libri': False,
'feature_dir': PosixPath('data/fbank'), 'max_duration': 100, 'bucketing_sampler': False, 'num_buckets': 30,
'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False,
'shuffle': True, 'return_cuts': True, 'num_workers': 2}
2021-09-26 12:44:31,406 INFO [lexicon.py:113] Loading pre-compiled data/lang_bpe_500/Linv.pt
2021-09-26 12:44:31,464 INFO [decode.py:548] device: cuda:0
2021-09-26 12:44:36,171 INFO [checkpoint.py:92] Loading checkpoint from conformer_ctc/exp/epoch-25.pt
2021-09-26 12:44:36,776 INFO [decode.py:652] Number of model parameters: 109226120
2021-09-26 12:44:37,714 INFO [decode.py:473] batch 0/206, cuts processed until now is 12
2021-09-26 12:45:15,944 INFO [decode.py:473] batch 100/206, cuts processed until now is 1328
2021-09-26 12:45:54,443 INFO [decode.py:473] batch 200/206, cuts processed until now is 2563
2021-09-26 12:45:56,411 INFO [decode.py:494] The transcripts are stored in conformer_ctc/exp/recogs-test-clean-ctc-decoding.txt
2021-09-26 12:45:56,592 INFO [utils.py:331] [test-clean-ctc-decoding] %WER 3.26% [1715 / 52576, 163 ins, 128 del, 1424 sub ]
2021-09-26 12:45:56,807 INFO [decode.py:506] Wrote detailed error stats to conformer_ctc/exp/errs-test-clean-ctc-decoding.txt
2021-09-26 12:45:56,808 INFO [decode.py:522]
For test-clean, WER of different settings are:
ctc-decoding 3.26 best for test-clean
2021-09-26 12:45:57,362 INFO [decode.py:473] batch 0/203, cuts processed until now is 15
2021-09-26 12:46:35,565 INFO [decode.py:473] batch 100/203, cuts processed until now is 1477
2021-09-26 12:47:15,106 INFO [decode.py:473] batch 200/203, cuts processed until now is 2922
2021-09-26 12:47:16,131 INFO [decode.py:494] The transcripts are stored in conformer_ctc/exp/recogs-test-other-ctc-decoding.txt
2021-09-26 12:47:16,208 INFO [utils.py:331] [test-other-ctc-decoding] %WER 8.21% [4295 / 52343, 396 ins, 315 del, 3584 sub ]
2021-09-26 12:47:16,432 INFO [decode.py:506] Wrote detailed error stats to conformer_ctc/exp/errs-test-other-ctc-decoding.txt
2021-09-26 12:47:16,432 INFO [decode.py:522]
For test-other, WER of different settings are:
ctc-decoding 8.21 best for test-other
2021-09-26 12:47:16,433 INFO [decode.py:680] Done!
Pre-trained Model Pre-trained Model
----------------- -----------------
@ -381,7 +446,6 @@ After downloading, you will have the following files:
6 directories, 11 files 6 directories, 11 files
**File descriptions**: **File descriptions**:
- ``data/lang_bpe/HLG.pt`` - ``data/lang_bpe/HLG.pt``
It is the decoding graph. It is the decoding graph.
@ -462,12 +526,58 @@ Usage
displays the help information. displays the help information.
It supports three decoding methods: It supports 4 decoding methods:
- CTC decoding
- HLG decoding - HLG decoding
- HLG + n-gram LM rescoring - HLG + n-gram LM rescoring
- HLG + n-gram LM rescoring + attention decoder rescoring - HLG + n-gram LM rescoring + attention decoder rescoring
CTC decoding
^^^^^^^^^^^^
CTC decoding uses the best path of the decoding lattice as the decoding result
without any LM or lexicon.
The command to run CTC decoding is:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./conformer_ctc/pretrained.py \
--checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
--bpe-model ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/bpe.model \
--method ctc-decoding \
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac
The output is given below:
.. code-block::
2021-10-13 11:21:50,896 INFO [pretrained.py:236] device: cuda:0
2021-10-13 11:21:50,896 INFO [pretrained.py:238] Creating model
2021-10-13 11:21:56,669 INFO [pretrained.py:255] Constructing Fbank computer
2021-10-13 11:21:56,670 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-10-13 11:21:56,683 INFO [pretrained.py:271] Decoding started
2021-10-13 11:21:57,341 INFO [pretrained.py:290] Building CTC topology
2021-10-13 11:21:57,625 INFO [lexicon.py:113] Loading pre-compiled tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/Linv.pt
2021-10-13 11:21:57,679 INFO [pretrained.py:299] Loading BPE model
2021-10-13 11:22:00,076 INFO [pretrained.py:314] Use CTC decoding
2021-10-13 11:22:00,087 INFO [pretrained.py:400]
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED
BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-10-13 11:22:00,087 INFO [pretrained.py:402] Decoding Done
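
Under the hood, ``ctc-decoding`` builds a CTC topology ``H``, intersects it with
the network output to obtain a lattice, takes the one-best path, and converts the
resulting token IDs back to words with the BPE model. Below is a condensed sketch
of the corresponding ``pretrained.py`` code added in this commit; ``nnet_output``,
``supervision_segments``, ``max_token_id`` and ``device`` are assumed to be
prepared as in that script, and the beam settings mirror the defaults shown in the
logs above:

.. code-block:: python

   import k2
   import sentencepiece as spm

   from icefall.decode import get_lattice, one_best_decoding
   from icefall.utils import get_texts

   bpe_model = spm.SentencePieceProcessor()
   bpe_model.load("data/lang_bpe/bpe.model")

   # CTC topology over the BPE tokens; no lexicon or n-gram LM is involved.
   H = k2.ctc_topo(max_token=max_token_id, modified=False, device=device)

   lattice = get_lattice(
       nnet_output=nnet_output,                 # (N, T, C) log-probs from the model
       decoding_graph=H,
       supervision_segments=supervision_segments,
       search_beam=20,
       output_beam=8,
       min_active_states=30,
       max_active_states=10000,
       subsampling_factor=4,
   )

   best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
   token_ids = get_texts(best_path)             # token IDs, since H (not HLG) is used
   hyps = [s.split() for s in bpe_model.decode(token_ids)]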
HLG decoding HLG decoding
^^^^^^^^^^^^ ^^^^^^^^^^^^
@ -490,14 +600,14 @@ The output is given below:
.. code-block:: .. code-block::
2021-08-20 11:03:05,712 INFO [pretrained.py:217] device: cuda:0 2021-10-13 11:25:19,458 INFO [pretrained.py:236] device: cuda:0
2021-08-20 11:03:05,712 INFO [pretrained.py:219] Creating model 2021-10-13 11:25:19,458 INFO [pretrained.py:238] Creating model
2021-08-20 11:03:11,345 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt 2021-10-13 11:25:25,342 INFO [pretrained.py:255] Constructing Fbank computer
2021-08-20 11:03:18,442 INFO [pretrained.py:255] Constructing Fbank computer 2021-10-13 11:25:25,343 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-08-20 11:03:18,444 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'] 2021-10-13 11:25:25,356 INFO [pretrained.py:271] Decoding started
2021-08-20 11:03:18,507 INFO [pretrained.py:271] Decoding started 2021-10-13 11:25:26,026 INFO [pretrained.py:327] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt
2021-08-20 11:03:18,795 INFO [pretrained.py:300] Use HLG decoding 2021-10-13 11:25:33,735 INFO [pretrained.py:359] Use HLG decoding
2021-08-20 11:03:19,149 INFO [pretrained.py:339] 2021-10-13 11:25:34,013 INFO [pretrained.py:400]
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac: ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
@ -508,7 +618,7 @@ The output is given below:
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac: ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-20 11:03:19,149 INFO [pretrained.py:341] Decoding Done 2021-10-13 11:25:34,014 INFO [pretrained.py:402] Decoding Done
HLG decoding + LM rescoring HLG decoding + LM rescoring
^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^
@ -536,15 +646,15 @@ Its output is:
.. code-block:: .. code-block::
2021-08-20 11:12:17,565 INFO [pretrained.py:217] device: cuda:0 2021-10-13 11:28:19,129 INFO [pretrained.py:236] device: cuda:0
2021-08-20 11:12:17,565 INFO [pretrained.py:219] Creating model 2021-10-13 11:28:19,129 INFO [pretrained.py:238] Creating model
2021-08-20 11:12:23,728 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt 2021-10-13 11:28:23,531 INFO [pretrained.py:255] Constructing Fbank computer
2021-08-20 11:12:30,035 INFO [pretrained.py:246] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt 2021-10-13 11:28:23,532 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-08-20 11:13:10,779 INFO [pretrained.py:255] Constructing Fbank computer 2021-10-13 11:28:23,544 INFO [pretrained.py:271] Decoding started
2021-08-20 11:13:10,787 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'] 2021-10-13 11:28:24,141 INFO [pretrained.py:327] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt
2021-08-20 11:13:10,798 INFO [pretrained.py:271] Decoding started 2021-10-13 11:28:30,752 INFO [pretrained.py:338] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt
2021-08-20 11:13:11,085 INFO [pretrained.py:305] Use HLG decoding + LM rescoring 2021-10-13 11:28:48,308 INFO [pretrained.py:364] Use HLG decoding + LM rescoring
2021-08-20 11:13:11,736 INFO [pretrained.py:339] 2021-10-13 11:28:48,815 INFO [pretrained.py:400]
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac: ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
@ -555,7 +665,7 @@ Its output is:
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac: ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-20 11:13:11,737 INFO [pretrained.py:341] Decoding Done 2021-10-13 11:28:48,815 INFO [pretrained.py:402] Decoding Done
HLG decoding + LM rescoring + attention decoder rescoring HLG decoding + LM rescoring + attention decoder rescoring
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@ -577,7 +687,7 @@ The command to run HLG decoding + LM rescoring + attention decoder rescoring is:
--G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \ --G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 1.3 \ --ngram-lm-scale 1.3 \
--attention-decoder-scale 1.2 \ --attention-decoder-scale 1.2 \
--lattice-score-scale 0.5 \ --nbest-scale 0.5 \
--num-paths 100 \ --num-paths 100 \
--sos-id 1 \ --sos-id 1 \
--eos-id 1 \ --eos-id 1 \
@ -589,15 +699,15 @@ The output is below:
.. code-block:: .. code-block::
2021-08-20 11:19:11,397 INFO [pretrained.py:217] device: cuda:0 2021-10-13 11:29:50,106 INFO [pretrained.py:236] device: cuda:0
2021-08-20 11:19:11,397 INFO [pretrained.py:219] Creating model 2021-10-13 11:29:50,106 INFO [pretrained.py:238] Creating model
2021-08-20 11:19:17,354 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt 2021-10-13 11:29:56,063 INFO [pretrained.py:255] Constructing Fbank computer
2021-08-20 11:19:24,615 INFO [pretrained.py:246] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt 2021-10-13 11:29:56,063 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-08-20 11:20:04,576 INFO [pretrained.py:255] Constructing Fbank computer 2021-10-13 11:29:56,077 INFO [pretrained.py:271] Decoding started
2021-08-20 11:20:04,584 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'] 2021-10-13 11:29:56,770 INFO [pretrained.py:327] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt
2021-08-20 11:20:04,595 INFO [pretrained.py:271] Decoding started 2021-10-13 11:30:04,023 INFO [pretrained.py:338] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt
2021-08-20 11:20:04,854 INFO [pretrained.py:313] Use HLG + LM rescoring + attention decoder rescoring 2021-10-13 11:30:18,163 INFO [pretrained.py:372] Use HLG + LM rescoring + attention decoder rescoring
2021-08-20 11:20:05,805 INFO [pretrained.py:339] 2021-10-13 11:30:19,367 INFO [pretrained.py:400]
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac: ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
@ -608,7 +718,7 @@ The output is below:
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac: ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-20 11:20:05,805 INFO [pretrained.py:341] Decoding Done 2021-10-13 11:30:19,367 INFO [pretrained.py:402] Decoding Done
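
The rescoring chain mirrors what ``pretrained.py`` does internally: the HLG lattice
is first rescored with the whole 4-gram LM, and the rescored lattice is then
rescored again with the attention decoder. A condensed sketch based on the
``pretrained.py`` changes later in this commit (``lattice``, ``G``, ``model``,
``memory`` and ``memory_key_padding_mask`` are assumed to be prepared as in that
script; the scale values match the command shown above):

.. code-block:: python

   from icefall.decode import (
       rescore_with_attention_decoder,
       rescore_with_whole_lattice,
   )

   # Rescore the HLG lattice with the full 4-gram LM (G).
   rescored_lattice = rescore_with_whole_lattice(
       lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
   )

   # Sample n-best paths from the rescored lattice and rescore them
   # with the attention decoder.
   best_path_dict = rescore_with_attention_decoder(
       lattice=rescored_lattice,
       num_paths=100,
       model=model,
       memory=memory,
       memory_key_padding_mask=memory_key_padding_mask,
       sos_id=1,
       eos_id=1,
       nbest_scale=0.5,
       ngram_lm_scale=1.3,
       attention_scale=1.2,
   )
   best_path = next(iter(best_path_dict.values()))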
Colab notebook Colab notebook
-------------- --------------
@ -629,3 +739,119 @@ We do provide a colab notebook for this recipe showing how to use a pre-trained
**Congratulations!** You have finished the librispeech ASR recipe with **Congratulations!** You have finished the librispeech ASR recipe with
conformer CTC models in ``icefall``. conformer CTC models in ``icefall``.
If you want to deploy your trained model in C++, please read the following section.
Deployment with C++
-------------------
This section describes how to deploy your trained model in C++, without
Python dependencies.
We assume you have run ``./prepare.sh`` and have the following directories available:
.. code-block:: bash
data
|-- lang_bpe
Also, we assume your checkpoints are saved in ``conformer_ctc/exp``.
If you know that averaging 20 checkpoints starting from ``epoch-30.pt`` yields the
lowest WER, you can run the following commands
.. code-block::
$ cd egs/librispeech/ASR
$ ./conformer_ctc/export.py \
--epoch 30 \
--avg 20 \
--jit 1 \
--lang-dir data/lang_bpe \
--exp-dir conformer_ctc/exp
to get a torch scripted model saved in ``conformer_ctc/exp/cpu_jit.pt``.
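
Before compiling the C++ bits, you can sanity-check the exported file by loading it
back in Python (a quick check, not part of the recipe itself):

.. code-block:: python

   import torch

   model = torch.jit.load("conformer_ctc/exp/cpu_jit.pt")
   print(model)  # prints the structure of the scripted Conformer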
Now you have all needed files ready. Let us compile k2 from source:
.. code-block:: bash
$ cd $HOME
$ git clone https://github.com/k2-fsa/k2
$ cd k2
$ git checkout v2.0-pre
.. CAUTION::
You have to switch to the branch ``v2.0-pre``!
.. code-block:: bash
$ mkdir build-release
$ cd build-release
$ cmake -DCMAKE_BUILD_TYPE=Release ..
$ make -j decode
# You will find an executable: `./bin/decode`
Now you are ready to go!
To view the usage of ``./bin/decode``, run:
.. code-block::
$ ./bin/decode
It will show you the following message:
.. code-block::
Please provide --jit_pt
(1) CTC decoding
./bin/decode \
--use_ctc_decoding true \
--jit_pt <path to exported torch script pt file> \
--bpe_model <path to pretrained BPE model> \
/path/to/foo.wav \
/path/to/bar.wav \
<more wave files if any>
(2) HLG decoding
./bin/decode \
--use_ctc_decoding false \
--jit_pt <path to exported torch script pt file> \
--hlg <path to HLG.pt> \
--word-table <path to words.txt> \
/path/to/foo.wav \
/path/to/bar.wav \
<more wave files if any>
--use_gpu false to use CPU
--use_gpu true to use GPU
``./bin/decode`` supports two types of decoding at present: CTC decoding and HLG decoding.
CTC decoding
^^^^^^^^^^^^
You need to provide:
- ``--jit_pt``, this is the file generated by ``conformer_ctc/export.py``. You can find it
in ``conformer_ctc/exp/cpu_jit.pt``.
- ``--bpe_model``, this is a sentence piece model generated by ``prepare.sh``. You can find
it in ``data/lang_bpe/bpe.model``.
HLG decoding
^^^^^^^^^^^^
You need to provide:
- ``--jit_pt``, this is the same file as in CTC decoding.
- ``--hlg``, this file is generated by ``prepare.sh``. You can find it in ``data/lang_bpe/HLG.pt``.
- ``--word-table``, this file is generated by ``prepare.sh``. You can find it in ``data/lang_bpe/words.txt``.
We do provide a Colab notebook, showing you how to run a torch scripted model in C++.
Please see |librispeech asr conformer ctc torch script colab notebook|
.. |librispeech asr conformer ctc torch script colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/drive/1BIGLWzS36isskMXHKcqC9ysN6pspYXs_?usp=sharing


@ -38,14 +38,16 @@ python conformer_ctc/train.py --bucketing-sampler True \
--concatenate-cuts False \ --concatenate-cuts False \
--max-duration 200 \ --max-duration 200 \
--full-libri True \ --full-libri True \
--world-size 4 --world-size 4 \
--lang-dir data/lang_bpe_5000
python conformer_ctc/decode.py --lattice-score-scale 0.5 \ python conformer_ctc/decode.py --nbest-scale 0.5 \
--epoch 34 \ --epoch 34 \
--avg 20 \ --avg 20 \
--method attention-decoder \ --method attention-decoder \
--max-duration 20 \ --max-duration 20 \
--num-paths 100 --num-paths 100 \
--lang-dir data/lang_bpe_5000
``` ```
### LibriSpeech training results (Tdnn-Lstm) ### LibriSpeech training results (Tdnn-Lstm)


@ -1,3 +1,53 @@
## Introduction
Please visit Please visit
<https://icefall.readthedocs.io/en/latest/recipes/librispeech/conformer_ctc.html> <https://icefall.readthedocs.io/en/latest/recipes/librispeech/conformer_ctc.html>
for how to run this recipe. for how to run this recipe.
## How to compute framewise alignment information
### Step 1: Train a model
Please use `conformer_ctc/train.py` to train a model.
See <https://icefall.readthedocs.io/en/latest/recipes/librispeech/conformer_ctc.html>
for how to do it.
### Step 2: Compute framewise alignment
Run
```
# Choose a checkpoint and determine the number of checkpoints to average
epoch=30
avg=15
./conformer_ctc/ali.py \
--epoch $epoch \
--avg $avg \
--max-duration 500 \
--bucketing-sampler 0 \
--full-libri 1 \
--exp-dir conformer_ctc/exp \
--lang-dir data/lang_bpe_5000 \
--ali-dir data/ali_5000
```
and you will get four files inside the folder `data/ali_5000`:
```
$ ls -lh data/ali_5000
total 546M
-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:06 test_clean.pt
-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:07 test_other.pt
-rw-r--r-- 1 kuangfangjun root 542M Sep 28 11:36 train-960.pt
-rw-r--r-- 1 kuangfangjun root 2.1M Sep 28 11:38 valid.pt
```
**Note**: It can take more than 3 hours to compute the alignment
for the training dataset, which contains 960 * 3 = 2880 hours of data.
**Caution**: The model parameters in `conformer_ctc/ali.py` have to match those
in `conformer_ctc/train.py`.
**Caution**: You have to set the parameter `preserve_id` to `True` for `CutMix`.
Search `./conformer_ctc/asr_datamodule.py` for `preserve_id`.
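
For reference, the transform construction looks roughly like the following with the
flag set (a sketch; `cuts_musan` and the exact `CutMix` signature are assumptions
based on the note above and on the version of lhotse in use):

```python
from lhotse.dataset import CutMix

# preserve_id=True keeps the original cut IDs after mixing with noise,
# so the computed alignments can still be matched back to the right cuts.
transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)]
```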
**TODO:** Add doc about how to use the extracted alignment in the other pull-request.


@ -0,0 +1,314 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import k2
import torch
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import one_best_decoding
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
encode_supervisions,
get_alignments,
get_env_info,
save_alignments,
setup_logger,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=34,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=20,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_5000",
help="The lang dir",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="The experiment dir",
)
parser.add_argument(
"--ali-dir",
type=str,
default="data/ali_500",
help="The experiment dir",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"lm_dir": Path("data/lm"),
"feature_dim": 80,
"nhead": 8,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
"vgg_frontend": False,
"use_feat_batchnorm": True,
"output_beam": 10,
"use_double_scores": True,
"env_info": get_env_info(),
}
)
return params
def compute_alignments(
model: torch.nn.Module,
dl: torch.utils.data.DataLoader,
params: AttributeDict,
graph_compiler: BpeCtcTrainingGraphCompiler,
) -> List[Tuple[str, List[int]]]:
"""Compute the framewise alignments of a dataset.
Args:
model:
The neural network model.
dl:
Dataloader containing the dataset.
params:
Parameters for computing alignments.
graph_compiler:
It converts token IDs to decoding graphs.
Returns:
Return a list of tuples. Each tuple contains two entries:
- Utterance ID
- Framewise alignments (token IDs) after subsampling
"""
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
num_cuts = 0
device = graph_compiler.device
ans = []
for batch_idx, batch in enumerate(dl):
feature = batch["inputs"]
# at entry, feature is [N, T, C]
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
cut_ids = []
for cut in supervisions["cut"]:
assert len(cut.supervisions) == 1
cut_ids.append(cut.id)
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is [N, T, C]
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=params.subsampling_factor
)
# we need also to sort cut_ids as encode_supervisions()
# reorders "texts".
# In general, new2old is an identity map since lhotse sorts the returned
# cuts by duration in descending order
new2old = supervision_segments[:, 0].tolist()
cut_ids = [cut_ids[i] for i in new2old]
token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids)
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
lattice = k2.intersect_dense(
decoding_graph,
dense_fsa_vec,
params.output_beam,
)
best_path = one_best_decoding(
lattice=lattice,
use_double_scores=params.use_double_scores,
)
ali_ids = get_alignments(best_path)
assert len(ali_ids) == len(cut_ids)
ans += list(zip(cut_ids, ali_ids))
num_cuts += len(ali_ids)
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return ans
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
assert args.return_cuts is True
assert args.concatenate_cuts is False
if args.full_libri is False:
print("Changing --full-libri to True")
args.full_libri = True
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/ali")
logging.info("Computing alignment - started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
logging.info("About to create model")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
model.to(device)
model.eval()
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
test_dl = librispeech.test_dataloaders() # a list
ali_dir = Path(params.ali_dir)
ali_dir.mkdir(exist_ok=True)
enabled_datasets = {
"test_clean": test_dl[0],
"test_other": test_dl[1],
"train-960": train_dl,
"valid": valid_dl,
}
# For train-960, it takes about 3 hours 40 minutes, i.e., 3.67 hours to
# compute the alignments if you use --max-duration=500
#
# There are 960 * 3 = 2880 hours of data and it takes only
# 3 hours 40 minutes to get the alignment.
# The RTF is roughly: 3.67 / 2880 = 0.0012743
#
# At the end, you would see
# 2021-09-28 11:32:46,690 INFO [ali.py:188] batch 21000/?, cuts processed until now is 836270 # noqa
# 2021-09-28 11:33:45,084 INFO [ali.py:188] batch 21100/?, cuts processed until now is 840268 # noqa
for name, dl in enabled_datasets.items():
logging.info(f"Processing {name}")
if name == "train-960":
logging.info(
f"It will take about 3 hours 40 minutes for {name}, "
"which contains 960 * 3 = 2880 hours of data"
)
alignments = compute_alignments(
model=model,
dl=dl,
params=params,
graph_compiler=graph_compiler,
)
num_utt = len(alignments)
alignments = dict(alignments)
assert num_utt == len(alignments)
filename = ali_dir / f"{name}.pt"
save_alignments(
alignments=alignments,
subsampling_factor=params.subsampling_factor,
filename=filename,
)
logging.info(
f"For dataset {name}, its alignments are saved to {filename}"
)
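# Limit PyTorch to one intra-op and one inter-op thread.
# Note: torch.set_num_interop_threads() must be called before any
# parallel work is started.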
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()


@ -23,6 +23,7 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import k2 import k2
import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
@ -42,6 +43,7 @@ from icefall.decode import (
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
get_env_info,
get_texts, get_texts,
setup_logger, setup_logger,
store_transcripts, store_transcripts,
@ -77,6 +79,9 @@ def get_parser():
default="attention-decoder", default="attention-decoder",
help="""Decoding method. help="""Decoding method.
Supported values are: Supported values are:
- (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM.
- (1) 1best. Extract the best path from the decoding lattice as the - (1) 1best. Extract the best path from the decoding lattice as the
decoding result. decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path - (2) nbest. Extract n paths from the decoding lattice; the path
@ -106,7 +111,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--lattice-score-scale", "--nbest-scale",
type=float, type=float,
default=0.5, default=0.5,
help="""The scale to be applied to `lattice.scores`. help="""The scale to be applied to `lattice.scores`.
@ -128,14 +133,26 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_5000",
help="The lang dir",
)
return parser return parser
def get_params() -> AttributeDict: def get_params() -> AttributeDict:
params = AttributeDict( params = AttributeDict(
{ {
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_bpe"),
"lm_dir": Path("data/lm"), "lm_dir": Path("data/lm"),
# parameters for conformer # parameters for conformer
"subsampling_factor": 4, "subsampling_factor": 4,
@ -151,6 +168,7 @@ def get_params() -> AttributeDict:
"min_active_states": 30, "min_active_states": 30,
"max_active_states": 10000, "max_active_states": 10000,
"use_double_scores": True, "use_double_scores": True,
"env_info": get_env_info(),
} }
) )
return params return params
@ -159,13 +177,15 @@ def get_params() -> AttributeDict:
def decode_one_batch( def decode_one_batch(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
HLG: k2.Fsa, HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
batch: dict, batch: dict,
word_table: k2.SymbolTable, word_table: k2.SymbolTable,
sos_id: int, sos_id: int,
eos_id: int, eos_id: int,
G: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the """Decode one batch and return the result in a dict. The dict has the
following format: following format:
@ -190,7 +210,11 @@ def decode_one_batch(
model: model:
The neural model. The neural model.
HLG: HLG:
The decoding graph. The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.method is ctc-decoding.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
batch: batch:
It is the return value from iterating It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
@ -209,7 +233,10 @@ def decode_one_batch(
Return the decoding result. See above description for the format of Return the decoding result. See above description for the format of
the returned dict. the returned dict.
""" """
device = HLG.device if HLG is not None:
device = HLG.device
else:
device = H.device
feature = batch["inputs"] feature = batch["inputs"]
assert feature.ndim == 3 assert feature.ndim == 3
feature = feature.to(device) feature = feature.to(device)
@ -229,9 +256,17 @@ def decode_one_batch(
1, 1,
).to(torch.int32) ).to(torch.int32)
if H is None:
assert HLG is not None
decoding_graph = HLG
else:
assert HLG is None
assert bpe_model is not None
decoding_graph = H
lattice = get_lattice( lattice = get_lattice(
nnet_output=nnet_output, nnet_output=nnet_output,
HLG=HLG, decoding_graph=decoding_graph,
supervision_segments=supervision_segments, supervision_segments=supervision_segments,
search_beam=params.search_beam, search_beam=params.search_beam,
output_beam=params.output_beam, output_beam=params.output_beam,
@ -240,6 +275,24 @@ def decode_one_batch(
subsampling_factor=params.subsampling_factor, subsampling_factor=params.subsampling_factor,
) )
if params.method == "ctc-decoding":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
# Note: `best_path.aux_labels` contains token IDs, not word IDs
# since we are using H, not HLG here.
#
# token_ids is a list-of-lists of token IDs
token_ids = get_texts(best_path)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "ctc-decoding"
return {key: hyps}
if params.method == "nbest-oracle": if params.method == "nbest-oracle":
# Note: You can also pass rescored lattices to it. # Note: You can also pass rescored lattices to it.
# We choose the HLG decoded lattice for speed reasons # We choose the HLG decoded lattice for speed reasons
@ -250,12 +303,12 @@ def decode_one_batch(
num_paths=params.num_paths, num_paths=params.num_paths,
ref_texts=supervisions["text"], ref_texts=supervisions["text"],
word_table=word_table, word_table=word_table,
lattice_score_scale=params.lattice_score_scale, nbest_scale=params.nbest_scale,
oov="<UNK>", oov="<UNK>",
) )
hyps = get_texts(best_path) hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps] hyps = [[word_table[i] for i in ids] for ids in hyps]
key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa
return {key: hyps} return {key: hyps}
if params.method in ["1best", "nbest"]: if params.method in ["1best", "nbest"]:
@ -269,9 +322,9 @@ def decode_one_batch(
lattice=lattice, lattice=lattice,
num_paths=params.num_paths, num_paths=params.num_paths,
use_double_scores=params.use_double_scores, use_double_scores=params.use_double_scores,
lattice_score_scale=params.lattice_score_scale, nbest_scale=params.nbest_scale,
) )
key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
hyps = get_texts(best_path) hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps] hyps = [[word_table[i] for i in ids] for ids in hyps]
@ -293,7 +346,7 @@ def decode_one_batch(
G=G, G=G,
num_paths=params.num_paths, num_paths=params.num_paths,
lm_scale_list=lm_scale_list, lm_scale_list=lm_scale_list,
lattice_score_scale=params.lattice_score_scale, nbest_scale=params.nbest_scale,
) )
elif params.method == "whole-lattice-rescoring": elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice( best_path_dict = rescore_with_whole_lattice(
@ -319,7 +372,7 @@ def decode_one_batch(
memory_key_padding_mask=memory_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id, sos_id=sos_id,
eos_id=eos_id, eos_id=eos_id,
lattice_score_scale=params.lattice_score_scale, nbest_scale=params.nbest_scale,
) )
else: else:
assert False, f"Unsupported decoding method: {params.method}" assert False, f"Unsupported decoding method: {params.method}"
@ -340,12 +393,14 @@ def decode_dataset(
dl: torch.utils.data.DataLoader, dl: torch.utils.data.DataLoader,
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
HLG: k2.Fsa, HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
word_table: k2.SymbolTable, word_table: k2.SymbolTable,
sos_id: int, sos_id: int,
eos_id: int, eos_id: int,
G: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[int], List[int]]]]: ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset. """Decode dataset.
Args: Args:
@ -356,7 +411,11 @@ def decode_dataset(
model: model:
The neural model. The neural model.
HLG: HLG:
The decoding graph. The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.method is ctc-decoding.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
word_table: word_table:
It is the word symbol table. It is the word symbol table.
sos_id: sos_id:
@ -391,6 +450,8 @@ def decode_dataset(
params=params, params=params,
model=model, model=model,
HLG=HLG, HLG=HLG,
H=H,
bpe_model=bpe_model,
batch=batch, batch=batch,
word_table=word_table, word_table=word_table,
G=G, G=G,
@ -469,6 +530,8 @@ def main():
parser = get_parser() parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser) LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
@ -496,14 +559,26 @@ def main():
sos_id = graph_compiler.sos_id sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id eos_id = graph_compiler.eos_id
HLG = k2.Fsa.from_dict( if params.method == "ctc-decoding":
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") HLG = None
) H = k2.ctc_topo(
HLG = HLG.to(device) max_token=max_token_id,
assert HLG.requires_grad is False modified=False,
device=device,
)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
else:
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
)
HLG = HLG.to(device)
assert HLG.requires_grad is False
if not hasattr(HLG, "lm_scores"): if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone() HLG.lm_scores = HLG.scores.clone()
if params.method in ( if params.method in (
"nbest-rescoring", "nbest-rescoring",
@ -593,6 +668,8 @@ def main():
params=params, params=params,
model=model, model=model,
HLG=HLG, HLG=HLG,
H=H,
bpe_model=bpe_model,
word_table=lexicon.word_table, word_table=lexicon.word_table,
G=G, G=G,
sos_id=sos_id, sos_id=sos_id,


@ -0,0 +1,165 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
import argparse
import logging
from pathlib import Path
import torch
from conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=34,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=20,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_5000",
help="""It contains language related input files such as "lexicon.txt"
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=True,
help="""True to save a model after applying torch.jit.script.
""",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 80,
"subsampling_factor": 4,
"use_feat_batchnorm": True,
"attention_dim": 512,
"nhead": 8,
"num_decoder_layers": 6,
}
)
return params
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=False,
use_feat_batchnorm=params.use_feat_batchnorm,
)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()


@ -1,5 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Mingshuang Luo)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
# #
@ -23,6 +24,7 @@ from typing import List
import k2 import k2
import kaldifeat import kaldifeat
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from conformer import Conformer from conformer import Conformer
@ -34,7 +36,7 @@ from icefall.decode import (
rescore_with_attention_decoder, rescore_with_attention_decoder,
rescore_with_whole_lattice, rescore_with_whole_lattice,
) )
from icefall.utils import AttributeDict, get_texts from icefall.utils import AttributeDict, get_env_info, get_texts
def get_parser(): def get_parser():
@ -54,12 +56,25 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--words-file", "--words-file",
type=str, type=str,
required=True, help="""Path to words.txt.
help="Path to words.txt", Used only when method is not ctc-decoding.
""",
) )
parser.add_argument( parser.add_argument(
"--HLG", type=str, required=True, help="Path to HLG.pt." "--HLG",
type=str,
help="""Path to HLG.pt.
Used only when method is not ctc-decoding.
""",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
) )
parser.add_argument( parser.add_argument(
@ -68,6 +83,10 @@ def get_parser():
default="1best", default="1best",
help="""Decoding method. help="""Decoding method.
Possible values are: Possible values are:
(0) ctc-decoding - Use CTC decoding. It uses a sentence
piece model, i.e., lang_dir/bpe.model, to convert
word pieces to words. It needs neither a lexicon
nor an n-gram LM.
(1) 1best - Use the best path as decoding output. Only (1) 1best - Use the best path as decoding output. Only
the transformer encoder output is used for decoding. the transformer encoder output is used for decoding.
We call it HLG decoding. We call it HLG decoding.
@ -125,7 +144,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--lattice-score-scale", "--nbest-scale",
type=float, type=float,
default=0.5, default=0.5,
help=""" help="""
@ -139,7 +158,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--sos-id", "--sos-id",
type=float, type=int,
default=1, default=1,
help=""" help="""
Used only when method is attention-decoder. Used only when method is attention-decoder.
@ -147,9 +166,18 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--num-classes",
type=int,
default=5000,
help="""
Vocab size in the BPE model.
""",
)
parser.add_argument( parser.add_argument(
"--eos-id", "--eos-id",
type=float, type=int,
default=1, default=1,
help=""" help="""
Used only when method is attention-decoder. Used only when method is attention-decoder.
@ -180,7 +208,6 @@ def get_params() -> AttributeDict:
"use_feat_batchnorm": True, "use_feat_batchnorm": True,
"feature_dim": 80, "feature_dim": 80,
"nhead": 8, "nhead": 8,
"num_classes": 5000,
"attention_dim": 512, "attention_dim": 512,
"num_decoder_layers": 6, "num_decoder_layers": 6,
# parameters for decoding # parameters for decoding
@ -223,7 +250,13 @@ def main():
args = parser.parse_args() args = parser.parse_args()
params = get_params() params = get_params()
if args.method != "attention-decoder":
# to save memory as the attention decoder
# will not be used
params.num_decoder_layers = 0
params.update(vars(args)) params.update(vars(args))
params["env_info"] = get_env_info()
logging.info(f"{params}") logging.info(f"{params}")
device = torch.device("cpu") device = torch.device("cpu")
@@ -245,27 +278,10 @@ def main():
     )

     checkpoint = torch.load(args.checkpoint, map_location="cpu")
-    model.load_state_dict(checkpoint["model"])
+    model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()

-    logging.info(f"Loading HLG from {params.HLG}")
-    HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
-    HLG = HLG.to(device)
-    if not hasattr(HLG, "lm_scores"):
-        # For whole-lattice-rescoring and attention-decoder
-        HLG.lm_scores = HLG.scores.clone()
-
-    if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
-        logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
-        # Add epsilon self-loops to G as we will compose
-        # it with the whole lattice later
-        G = G.to(device)
-        G = k2.add_epsilon_self_loops(G)
-        G = k2.arc_sort(G)
-        G.lm_scores = G.scores.clone()
-
     logging.info("Constructing Fbank computer")
     opts = kaldifeat.FbankOptions()
     opts.device = device
@@ -299,52 +315,108 @@ def main():
         dtype=torch.int32,
     )

-    lattice = get_lattice(
-        nnet_output=nnet_output,
-        HLG=HLG,
-        supervision_segments=supervision_segments,
-        search_beam=params.search_beam,
-        output_beam=params.output_beam,
-        min_active_states=params.min_active_states,
-        max_active_states=params.max_active_states,
-        subsampling_factor=params.subsampling_factor,
-    )
-
-    if params.method == "1best":
-        logging.info("Use HLG decoding")
-        best_path = one_best_decoding(
-            lattice=lattice, use_double_scores=params.use_double_scores
-        )
-    elif params.method == "whole-lattice-rescoring":
-        logging.info("Use HLG decoding + LM rescoring")
-        best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice,
-            G_with_epsilon_loops=G,
-            lm_scale_list=[params.ngram_lm_scale],
-        )
-        best_path = next(iter(best_path_dict.values()))
-    elif params.method == "attention-decoder":
-        logging.info("Use HLG + LM rescoring + attention decoder rescoring")
-        rescored_lattice = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
-        )
-        best_path_dict = rescore_with_attention_decoder(
-            lattice=rescored_lattice,
-            num_paths=params.num_paths,
-            model=model,
-            memory=memory,
-            memory_key_padding_mask=memory_key_padding_mask,
-            sos_id=params.sos_id,
-            eos_id=params.eos_id,
-            lattice_score_scale=params.lattice_score_scale,
-            ngram_lm_scale=params.ngram_lm_scale,
-            attention_scale=params.attention_decoder_scale,
-        )
-        best_path = next(iter(best_path_dict.values()))
-
-    hyps = get_texts(best_path)
-    word_sym_table = k2.SymbolTable.from_file(params.words_file)
-    hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+    if params.method == "ctc-decoding":
+        logging.info("Use CTC decoding")
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(params.bpe_model)
+        max_token_id = params.num_classes - 1
+
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=H,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        token_ids = get_texts(best_path)
+        hyps = bpe_model.decode(token_ids)
+        hyps = [s.split() for s in hyps]
+    elif params.method in [
+        "1best",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ]:
+        logging.info(f"Loading HLG from {params.HLG}")
+        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = HLG.to(device)
+        if not hasattr(HLG, "lm_scores"):
+            # For whole-lattice-rescoring and attention-decoder
+            HLG.lm_scores = HLG.scores.clone()
+
+        if params.method in [
+            "whole-lattice-rescoring",
+            "attention-decoder",
+        ]:
+            logging.info(f"Loading G from {params.G}")
+            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+            # Add epsilon self-loops to G as we will compose
+            # it with the whole lattice later
+            G = G.to(device)
+            G = k2.add_epsilon_self_loops(G)
+            G = k2.arc_sort(G)
+            G.lm_scores = G.scores.clone()
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=HLG,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        if params.method == "1best":
+            logging.info("Use HLG decoding")
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        elif params.method == "whole-lattice-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_whole_lattice(
+                lattice=lattice,
+                G_with_epsilon_loops=G,
+                lm_scale_list=[params.ngram_lm_scale],
+            )
+            best_path = next(iter(best_path_dict.values()))
+        elif params.method == "attention-decoder":
+            logging.info("Use HLG + LM rescoring + attention decoder rescoring")
+            rescored_lattice = rescore_with_whole_lattice(
+                lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+            )
+            best_path_dict = rescore_with_attention_decoder(
+                lattice=rescored_lattice,
+                num_paths=params.num_paths,
+                model=model,
+                memory=memory,
+                memory_key_padding_mask=memory_key_padding_mask,
+                sos_id=params.sos_id,
+                eos_id=params.eos_id,
+                nbest_scale=params.nbest_scale,
+                ngram_lm_scale=params.ngram_lm_scale,
+                attention_scale=params.attention_decoder_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))
+
+        hyps = get_texts(best_path)
+        word_sym_table = k2.SymbolTable.from_file(params.words_file)
+        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+    else:
+        raise ValueError(f"Unsupported decoding method: {params.method}")

     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
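For orientation only (not part of the diff), here is a self-contained sketch of the new ctc-decoding path restored above: k2.ctc_topo + get_lattice + one_best_decoding + a SentencePiece model. A random tensor stands in for real encoder log-probs, the BPE-model path is a placeholder, and the beam settings simply mirror typical values from the script.

# Sketch of the ctc-decoding branch; random log-probs and a placeholder
# BPE path, so the printed hypotheses are meaningless.
import k2
import sentencepiece as spm
import torch

from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import get_texts

device = torch.device("cpu")
num_classes = 500  # must match the vocab size of the BPE model
nnet_output = torch.randn(1, 100, num_classes).log_softmax(dim=-1)
# one supervision per row: (sequence_index, start_frame, num_frames)
supervision_segments = torch.tensor([[0, 0, 100]], dtype=torch.int32)

# CTC topology over token IDs: no lexicon and no n-gram LM needed.
H = k2.ctc_topo(max_token=num_classes - 1, modified=False, device=device)

lattice = get_lattice(
    nnet_output=nnet_output,
    decoding_graph=H,
    supervision_segments=supervision_segments,
    search_beam=20,
    output_beam=8,
    min_active_states=30,
    max_active_states=10000,
    subsampling_factor=1,
)
best_path = one_best_decoding(lattice=lattice, use_double_scores=True)

bpe_model = spm.SentencePieceProcessor()
bpe_model.load("data/lang_bpe_500/bpe.model")  # placeholder path
hyps = bpe_model.decode(get_texts(best_path))
print(hyps)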
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
-#                                        Wei Kang)
+#                                        Wei Kang
+#                                        Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -21,16 +22,16 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Optional, Tuple

 import k2
 import torch
-import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
+from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
@@ -43,7 +44,9 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    MetricsTracker,
     encode_supervisions,
+    get_env_info,
     setup_logger,
     str2bool,
 )
@ -75,6 +78,13 @@ def get_parser():
help="Should various information be logged in tensorboard.", help="Should various information be logged in tensorboard.",
) )
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_5000",
help="lang directory",
)
parser.add_argument( parser.add_argument(
"--num-epochs", "--num-epochs",
type=int, type=int,
@ -92,6 +102,26 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_5000",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
return parser return parser
@ -106,12 +136,6 @@ def get_params() -> AttributeDict:
Explanation of options saved in `params`: Explanation of options saved in `params`:
- exp_dir: It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
- lang_dir: It contains language related input files such as
"lexicon.txt"
- best_train_loss: Best training loss so far. It is used to select - best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is the model that has the lowest training loss. It is
updated during the training. updated during the training.
@ -162,14 +186,12 @@ def get_params() -> AttributeDict:
""" """
params = AttributeDict( params = AttributeDict(
{ {
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_bpe"),
"best_train_loss": float("inf"), "best_train_loss": float("inf"),
"best_valid_loss": float("inf"), "best_valid_loss": float("inf"),
"best_train_epoch": -1, "best_train_epoch": -1,
"best_valid_epoch": -1, "best_valid_epoch": -1,
"batch_idx_train": 0, "batch_idx_train": 0,
"log_interval": 10, "log_interval": 50,
"reset_interval": 200, "reset_interval": 200,
"valid_interval": 3000, "valid_interval": 3000,
# parameters for conformer # parameters for conformer
@ -188,6 +210,7 @@ def get_params() -> AttributeDict:
"weight_decay": 1e-6, "weight_decay": 1e-6,
"lr_factor": 5.0, "lr_factor": 5.0,
"warm_step": 80000, "warm_step": 80000,
"env_info": get_env_info(),
} }
) )
@ -287,7 +310,7 @@ def compute_loss(
batch: dict, batch: dict,
graph_compiler: BpeCtcTrainingGraphCompiler, graph_compiler: BpeCtcTrainingGraphCompiler,
is_training: bool, is_training: bool,
): ) -> Tuple[Tensor, MetricsTracker]:
""" """
Compute CTC loss given the model and its inputs. Compute CTC loss given the model and its inputs.
@ -367,15 +390,17 @@ def compute_loss(
loss = ctc_loss loss = ctc_loss
att_loss = torch.tensor([0]) att_loss = torch.tensor([0])
# train_frames and valid_frames are used for printing.
if is_training:
params.train_frames = supervision_segments[:, 2].sum().item()
else:
params.valid_frames = supervision_segments[:, 2].sum().item()
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
return loss, ctc_loss.detach(), att_loss.detach() info = MetricsTracker()
info["frames"] = supervision_segments[:, 2].sum().item()
info["ctc_loss"] = ctc_loss.detach().cpu().item()
if params.att_rate != 0.0:
info["att_loss"] = att_loss.detach().cpu().item()
info["loss"] = loss.detach().cpu().item()
return loss, info
def compute_validation_loss( def compute_validation_loss(
@ -384,18 +409,14 @@ def compute_validation_loss(
graph_compiler: BpeCtcTrainingGraphCompiler, graph_compiler: BpeCtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader,
world_size: int = 1, world_size: int = 1,
) -> None: ) -> MetricsTracker:
"""Run the validation process. The validation loss """Run the validation process."""
is saved in `params.valid_loss`.
"""
model.eval() model.eval()
tot_loss = 0.0 tot_loss = MetricsTracker()
tot_ctc_loss = 0.0
tot_att_loss = 0.0
tot_frames = 0.0
for batch_idx, batch in enumerate(valid_dl): for batch_idx, batch in enumerate(valid_dl):
loss, ctc_loss, att_loss = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
batch=batch, batch=batch,
@ -403,36 +424,17 @@ def compute_validation_loss(
is_training=False, is_training=False,
) )
assert loss.requires_grad is False assert loss.requires_grad is False
assert ctc_loss.requires_grad is False tot_loss = tot_loss + loss_info
assert att_loss.requires_grad is False
loss_cpu = loss.detach().cpu().item()
tot_loss += loss_cpu
tot_ctc_loss += ctc_loss.detach().cpu().item()
tot_att_loss += att_loss.detach().cpu().item()
tot_frames += params.valid_frames
if world_size > 1: if world_size > 1:
s = torch.tensor( tot_loss.reduce(loss.device)
[tot_loss, tot_ctc_loss, tot_att_loss, tot_frames],
device=loss.device,
)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
s = s.cpu().tolist()
tot_loss = s[0]
tot_ctc_loss = s[1]
tot_att_loss = s[2]
tot_frames = s[3]
params.valid_loss = tot_loss / tot_frames loss_value = tot_loss["loss"] / tot_loss["frames"]
params.valid_ctc_loss = tot_ctc_loss / tot_frames if loss_value < params.best_valid_loss:
params.valid_att_loss = tot_att_loss / tot_frames
if params.valid_loss < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = params.valid_loss params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch( def train_one_epoch(
@ -471,24 +473,21 @@ def train_one_epoch(
""" """
model.train() model.train()
tot_loss = 0.0 # sum of losses over all batches tot_loss = MetricsTracker()
tot_ctc_loss = 0.0
tot_att_loss = 0.0
tot_frames = 0.0 # sum of frames over all batches
params.tot_loss = 0.0
params.tot_frames = 0.0
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
loss, ctc_loss, att_loss = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
batch=batch, batch=batch,
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
is_training=True, is_training=True,
) )
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances # NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far. # in the batch and there is no normalization to it so far.
@ -498,75 +497,26 @@ def train_one_epoch(
clip_grad_norm_(model.parameters(), 5.0, 2.0) clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step() optimizer.step()
loss_cpu = loss.detach().cpu().item()
ctc_loss_cpu = ctc_loss.detach().cpu().item()
att_loss_cpu = att_loss.detach().cpu().item()
tot_frames += params.train_frames
tot_loss += loss_cpu
tot_ctc_loss += ctc_loss_cpu
tot_att_loss += att_loss_cpu
params.tot_frames += params.train_frames
params.tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
tot_avg_ctc_loss = tot_ctc_loss / tot_frames
tot_avg_att_loss = tot_att_loss / tot_frames
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
logging.info( logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, " f"Epoch {params.cur_epoch}, "
f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, " f"batch {batch_idx}, loss[{loss_info}], "
f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, " f"tot_loss[{tot_loss}], batch size: {batch_size}"
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
f"total avg att loss: {tot_avg_att_loss:.4f}, "
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
) )
if batch_idx % params.log_interval == 0:
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar( loss_info.write_summary(
"train/current_ctc_loss", tb_writer, "train/current_", params.batch_idx_train
ctc_loss_cpu / params.train_frames,
params.batch_idx_train,
) )
tb_writer.add_scalar( tot_loss.write_summary(
"train/current_att_loss", tb_writer, "train/tot_", params.batch_idx_train
att_loss_cpu / params.train_frames,
params.batch_idx_train,
) )
tb_writer.add_scalar(
"train/current_loss",
loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_ctc_loss",
tot_avg_ctc_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_att_loss",
tot_avg_att_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
)
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
tot_loss = 0.0 # sum of losses over all batches
tot_ctc_loss = 0.0
tot_att_loss = 0.0
tot_frames = 0.0 # sum of frames over all batches
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss( logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params, params=params,
model=model, model=model,
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
@ -574,33 +524,14 @@ def train_one_epoch(
world_size=world_size, world_size=world_size,
) )
model.train() model.train()
logging.info( logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
f"Epoch {params.cur_epoch}, "
f"valid ctc loss {params.valid_ctc_loss:.4f},"
f"valid att loss {params.valid_att_loss:.4f},"
f"valid loss {params.valid_loss:.4f},"
f" best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
)
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar( valid_info.write_summary(
"train/valid_ctc_loss", tb_writer, "train/valid_", params.batch_idx_train
params.valid_ctc_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/valid_att_loss",
params.valid_att_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/valid_loss",
params.valid_loss,
params.batch_idx_train,
) )
params.train_loss = params.tot_loss / params.tot_frames loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss: if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss params.best_train_loss = params.train_loss
@ -726,6 +657,8 @@ def main():
parser = get_parser() parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser) LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
world_size = args.world_size world_size = args.world_size
assert world_size >= 1 assert world_size >= 1
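The training loop above drops its hand-rolled loss counters in favor of a MetricsTracker that is decayed by 1 - 1/reset_interval every batch instead of being reset to zero. The real class lives in icefall/utils.py and is not shown in this diff; the following is only a made-up minimal stand-in to illustrate that accumulation pattern (the real tracker additionally supports reduce() across DDP ranks and write_summary() to tensorboard).

# Toy illustration of the decayed running totals used above; numbers
# are arbitrary and the class is a simplified stand-in, not icefall's.
import collections


class ToyMetricsTracker(collections.defaultdict):
    def __init__(self):
        super().__init__(float)  # missing keys default to 0.0

    def __add__(self, other: "ToyMetricsTracker") -> "ToyMetricsTracker":
        ans = ToyMetricsTracker()
        for k in set(self) | set(other):
            ans[k] = self[k] + other[k]
        return ans

    def __mul__(self, alpha: float) -> "ToyMetricsTracker":
        ans = ToyMetricsTracker()
        for k, v in self.items():
            ans[k] = v * alpha
        return ans


reset_interval = 200
tot_loss = ToyMetricsTracker()
for batch_idx in range(3):
    loss_info = ToyMetricsTracker()
    loss_info["frames"] = 1000.0
    loss_info["loss"] = 123.4
    # Same update as in train_one_epoch(): old totals decay away over
    # roughly `reset_interval` batches instead of being zeroed.
    tot_loss = (tot_loss * (1 - 1 / reset_interval)) + loss_info

print(tot_loss["loss"] / tot_loss["frames"])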
@ -236,6 +236,7 @@ class Transformer(nn.Module):
x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) x = nn.functional.log_softmax(x, dim=-1) # (N, T, C)
return x return x
@torch.jit.export
def decoder_forward( def decoder_forward(
self, self,
memory: torch.Tensor, memory: torch.Tensor,
@ -264,11 +265,15 @@ class Transformer(nn.Module):
""" """
ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in] ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out] ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device device = memory.device
ys_in_pad = ys_in_pad.to(device) ys_in_pad = ys_in_pad.to(device)
@ -301,6 +306,7 @@ class Transformer(nn.Module):
return decoder_loss return decoder_loss
@torch.jit.export
def decoder_nll( def decoder_nll(
self, self,
memory: torch.Tensor, memory: torch.Tensor,
@ -331,11 +337,15 @@ class Transformer(nn.Module):
ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in] ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out] ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device device = memory.device
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
@@ -649,7 +659,8 @@ class PositionalEncoding(nn.Module):
         self.d_model = d_model
         self.xscale = math.sqrt(self.d_model)
         self.dropout = nn.Dropout(p=dropout)
-        self.pe = None
+        # not doing: self.pe = None because of errors thrown by torchscript
+        self.pe = torch.zeros(0, 0, dtype=torch.float32)

     def extend_pe(self, x: torch.Tensor) -> None:
         """Extend the time t in the positional encoding if required.
@@ -666,8 +677,7 @@ class PositionalEncoding(nn.Module):
         """
         if self.pe is not None:
             if self.pe.size(1) >= x.size(1):
-                if self.pe.dtype != x.dtype or self.pe.device != x.device:
-                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
@@ -972,10 +982,7 @@ def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
     Return a new list-of-list, where each sublist starts
     with SOS ID.
     """
-    ans = []
-    for utt in token_ids:
-        ans.append([sos_id] + utt)
-    return ans
+    return [[sos_id] + utt for utt in token_ids]


 def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
@@ -992,7 +999,4 @@ def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
     Return a new list-of-list, where each sublist ends
     with EOS ID.
     """
-    ans = []
-    for utt in token_ids:
-        ans.append(utt + [eos_id])
-    return ans
+    return [utt + [eos_id] for utt in token_ids]
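The two loops above are rewritten as list comprehensions with unchanged behaviour. A quick standalone check (the function bodies are copied from the diff so the snippet runs on its own; the token values are arbitrary):

# Tiny check that the list-comprehension versions behave as before.
from typing import List


def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
    return [[sos_id] + utt for utt in token_ids]


def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
    return [utt + [eos_id] for utt in token_ids]


token_ids = [[5, 9], [7]]
assert add_sos(token_ids, sos_id=1) == [[1, 5, 9], [1, 7]]
assert add_eos(token_ids, eos_id=1) == [[5, 9, 1], [7, 1]]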
@@ -40,9 +40,9 @@ dl_dir=$PWD/download
 # It will generate data/lang_bpe_xxx,
 # data/lang_bpe_yyy if the array contains xxx, yyy
 vocab_sizes=(
-  # 5000
-  # 2000
-  # 1000
+  5000
+  2000
+  1000
   500
 )
@@ -59,13 +59,13 @@ log() {
 log "dl_dir: $dl_dir"

 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
-  log "stage -1: Download LM"
+  log "Stage -1: Download LM"
   [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm
   ./local/download_lm.py --out-dir=$dl_dir/lm
 fi

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-  log "stage 0: Download data"
+  log "Stage 0: Download data"

   # If you have pre-downloaded it to /path/to/LibriSpeech,
   # you can create a symlink
@@ -130,7 +130,7 @@ fi
 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "State 6: Prepare BPE based lang"
+  log "Stage 6: Prepare BPE based lang"

   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}
@@ -162,7 +162,9 @@ class LibriSpeechAsrDataModule(DataModule):
         cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")
         logging.info("About to create train dataset")
-        transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
+        transforms = [
+            CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+        ]
         if self.args.concatenate_cuts:
             logging.info(
                 f"Using cut concatenation with duration factor "
@@ -267,7 +269,7 @@ class LibriSpeechAsrDataModule(DataModule):
             cut_transforms=transforms,
             return_cuts=self.args.return_cuts,
         )
-        valid_sampler = SingleCutSampler(
+        valid_sampler = BucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
@@ -300,12 +302,15 @@ class LibriSpeechAsrDataModule(DataModule):
                 else PrecomputedFeatures(),
                 return_cuts=self.args.return_cuts,
             )
-            sampler = SingleCutSampler(
-                cuts_test, max_duration=self.args.max_duration
+            sampler = BucketingSampler(
+                cuts_test, max_duration=self.args.max_duration, shuffle=False
             )
             logging.debug("About to create test dataloader")
             test_dl = DataLoader(
-                test, batch_size=None, sampler=sampler, num_workers=1
+                test,
+                batch_size=None,
+                sampler=sampler,
+                num_workers=self.args.num_workers,
             )
             test_loaders.append(test_dl)
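The valid and test loaders above switch from SingleCutSampler to lhotse's BucketingSampler, so each batch is drawn from a duration bucket of similar-length cuts. A condensed sketch of the pattern, with a placeholder manifest path, placeholder max_duration/num_workers values, and the lhotse API assumed to match the version this commit targets:

# Sketch of the sampler change; paths and sizes are placeholders.
from lhotse import load_manifest
from lhotse.dataset import (
    BucketingSampler,
    K2SpeechRecognitionDataset,
    PrecomputedFeatures,
)
from torch.utils.data import DataLoader

cuts_test = load_manifest("data/fbank/cuts_test-clean.json.gz")
test_set = K2SpeechRecognitionDataset(
    input_strategy=PrecomputedFeatures(),
    return_cuts=True,
)
sampler = BucketingSampler(cuts_test, max_duration=200.0, shuffle=False)
test_dl = DataLoader(
    test_set,
    batch_size=None,  # the sampler already yields whole batches
    sampler=sampler,
    num_workers=2,
)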
@ -39,6 +39,7 @@ from icefall.decode import (
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
get_env_info,
get_texts, get_texts,
setup_logger, setup_logger,
store_transcripts, store_transcripts,
@ -97,7 +98,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--lattice-score-scale", "--nbest-scale",
type=float, type=float,
default=0.5, default=0.5,
help="""The scale to be applied to `lattice.scores`. help="""The scale to be applied to `lattice.scores`.
@ -134,6 +135,7 @@ def get_params() -> AttributeDict:
"min_active_states": 30, "min_active_states": 30,
"max_active_states": 10000, "max_active_states": 10000,
"use_double_scores": True, "use_double_scores": True,
"env_info": get_env_info(),
} }
) )
return params return params
@ -146,7 +148,7 @@ def decode_one_batch(
batch: dict, batch: dict,
lexicon: Lexicon, lexicon: Lexicon,
G: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the """Decode one batch and return the result in a dict. The dict has the
following format: following format:
@ -210,7 +212,7 @@ def decode_one_batch(
lattice = get_lattice( lattice = get_lattice(
nnet_output=nnet_output, nnet_output=nnet_output,
HLG=HLG, decoding_graph=HLG,
supervision_segments=supervision_segments, supervision_segments=supervision_segments,
search_beam=params.search_beam, search_beam=params.search_beam,
output_beam=params.output_beam, output_beam=params.output_beam,
@ -229,7 +231,7 @@ def decode_one_batch(
lattice=lattice, lattice=lattice,
num_paths=params.num_paths, num_paths=params.num_paths,
use_double_scores=params.use_double_scores, use_double_scores=params.use_double_scores,
lattice_score_scale=params.lattice_score_scale, nbest_scale=params.nbest_scale,
) )
key = f"no_rescore-{params.num_paths}" key = f"no_rescore-{params.num_paths}"
hyps = get_texts(best_path) hyps = get_texts(best_path)
@ -248,7 +250,7 @@ def decode_one_batch(
G=G, G=G,
num_paths=params.num_paths, num_paths=params.num_paths,
lm_scale_list=lm_scale_list, lm_scale_list=lm_scale_list,
lattice_score_scale=params.lattice_score_scale, nbest_scale=params.nbest_scale,
) )
else: else:
best_path_dict = rescore_with_whole_lattice( best_path_dict = rescore_with_whole_lattice(
@ -272,7 +274,7 @@ def decode_dataset(
HLG: k2.Fsa, HLG: k2.Fsa,
lexicon: Lexicon, lexicon: Lexicon,
G: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[int], List[int]]]]: ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset. """Decode dataset.
Args: Args:
@ -34,7 +34,7 @@ from icefall.decode import (
one_best_decoding, one_best_decoding,
rescore_with_whole_lattice, rescore_with_whole_lattice,
) )
from icefall.utils import AttributeDict, get_texts from icefall.utils import AttributeDict, get_env_info, get_texts
def get_parser(): def get_parser():
@ -159,6 +159,7 @@ def main():
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
params["env_info"] = get_env_info()
logging.info(f"{params}") logging.info(f"{params}")
device = torch.device("cpu") device = torch.device("cpu")
@ -232,7 +233,7 @@ def main():
lattice = get_lattice( lattice = get_lattice(
nnet_output=nnet_output, nnet_output=nnet_output,
HLG=HLG, decoding_graph=HLG,
supervision_segments=supervision_segments, supervision_segments=supervision_segments,
search_beam=params.search_beam, search_beam=params.search_beam,
output_beam=params.output_beam, output_beam=params.output_beam,
@ -1,5 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
# #
@ -20,17 +21,17 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Optional from typing import Optional, Tuple
import k2 import k2
import torch import torch
import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import TdnnLstm from model import TdnnLstm
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR from torch.optim.lr_scheduler import StepLR
@ -43,7 +44,9 @@ from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
MetricsTracker,
encode_supervisions, encode_supervisions,
get_env_info,
setup_logger, setup_logger,
str2bool, str2bool,
) )
@ -168,6 +171,7 @@ def get_params() -> AttributeDict:
"beam_size": 10, "beam_size": 10,
"reduction": "sum", "reduction": "sum",
"use_double_scores": True, "use_double_scores": True,
"env_info": get_env_info(),
} }
) )
@ -267,7 +271,7 @@ def compute_loss(
batch: dict, batch: dict,
graph_compiler: CtcTrainingGraphCompiler, graph_compiler: CtcTrainingGraphCompiler,
is_training: bool, is_training: bool,
): ) -> Tuple[Tensor, MetricsTracker]:
""" """
Compute CTC loss given the model and its inputs. Compute CTC loss given the model and its inputs.
@ -324,13 +328,11 @@ def compute_loss(
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
# train_frames and valid_frames are used for printing. info = MetricsTracker()
if is_training: info["frames"] = supervision_segments[:, 2].sum().item()
params.train_frames = supervision_segments[:, 2].sum().item() info["loss"] = loss.detach().cpu().item()
else:
params.valid_frames = supervision_segments[:, 2].sum().item()
return loss return loss, info
def compute_validation_loss( def compute_validation_loss(
@ -339,16 +341,16 @@ def compute_validation_loss(
graph_compiler: CtcTrainingGraphCompiler, graph_compiler: CtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader,
world_size: int = 1, world_size: int = 1,
) -> None: ) -> MetricsTracker:
"""Run the validation process. The validation loss """Run the validation process. The validation loss
is saved in `params.valid_loss`. is saved in `params.valid_loss`.
""" """
model.eval() model.eval()
tot_loss = 0.0 tot_loss = MetricsTracker()
tot_frames = 0.0
for batch_idx, batch in enumerate(valid_dl): for batch_idx, batch in enumerate(valid_dl):
loss = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
batch=batch, batch=batch,
@ -357,22 +359,18 @@ def compute_validation_loss(
) )
assert loss.requires_grad is False assert loss.requires_grad is False
loss_cpu = loss.detach().cpu().item() tot_loss = tot_loss + loss_info
tot_loss += loss_cpu
tot_frames += params.valid_frames
if world_size > 1: if world_size > 1:
s = torch.tensor([tot_loss, tot_frames], device=loss.device) tot_loss.reduce(loss.device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
s = s.cpu().tolist()
tot_loss = s[0]
tot_frames = s[1]
params.valid_loss = tot_loss / tot_frames loss_value = tot_loss["loss"] / tot_loss["frames"]
if params.valid_loss < params.best_valid_loss: if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = params.valid_loss params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch( def train_one_epoch(
@ -411,67 +409,45 @@ def train_one_epoch(
""" """
model.train() model.train()
tot_loss = 0.0 # reset after params.reset_interval of batches tot_loss = MetricsTracker()
tot_frames = 0.0 # reset after params.reset_interval of batches
params.tot_loss = 0.0
params.tot_frames = 0.0
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
loss = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
batch=batch, batch=batch,
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
is_training=True, is_training=True,
) )
# summary stats.
# NOTE: We use reduction==sum and loss is computed over utterances tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# in the batch and there is no normalization to it so far.
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0) clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step() optimizer.step()
loss_cpu = loss.detach().cpu().item()
tot_frames += params.train_frames
tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
params.tot_frames += params.train_frames
params.tot_loss += loss_cpu
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
logging.info( logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, " f"Epoch {params.cur_epoch}, "
f"batch avg loss {loss_cpu/params.train_frames:.4f}, " f"batch {batch_idx}, loss[{loss_info}], "
f"total avg loss: {tot_avg_loss:.4f}, " f"tot_loss[{tot_loss}], batch size: {batch_size}"
f"batch size: {batch_size}"
) )
if batch_idx % params.log_interval == 0:
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar( loss_info.write_summary(
"train/current_loss", tb_writer, "train/current_", params.batch_idx_train
loss_cpu / params.train_frames,
params.batch_idx_train,
) )
tot_loss.write_summary(
tb_writer.add_scalar( tb_writer, "train/tot_", params.batch_idx_train
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
) )
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
tot_loss = 0
tot_frames = 0
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss( valid_info = compute_validation_loss(
params=params, params=params,
model=model, model=model,
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
@ -479,13 +455,16 @@ def train_one_epoch(
world_size=world_size, world_size=world_size,
) )
model.train() model.train()
logging.info( logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f}," if tb_writer is not None:
f" best valid loss: {params.best_valid_loss:.4f} " valid_info.write_summary(
f"best valid epoch: {params.best_valid_epoch}" tb_writer,
) "train/valid_",
params.batch_idx_train,
)
params.train_loss = params.tot_loss / params.tot_frames loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss: if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch params.best_train_epoch = params.cur_epoch
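Both training scripts now call tot_loss.reduce(loss.device) where they previously packed the running sums into a tensor and called dist.all_reduce by hand. The explicit version being replaced looks roughly like the sketch below; it is an illustration of the reduction step, not the MetricsTracker internals.

# Explicit form of the cross-rank reduction that reduce() stands for:
# sum each tracked value across all DDP ranks.
import torch
import torch.distributed as dist


def reduce_metrics(metrics: dict, device: torch.device) -> dict:
    """All-reduce a {name: value} dict across workers (SUM)."""
    keys = sorted(metrics)  # identical order on every rank
    packed = torch.tensor([float(metrics[k]) for k in keys], device=device)
    dist.all_reduce(packed, op=dist.ReduceOp.SUM)
    return dict(zip(keys, packed.cpu().tolist()))


# Usage inside compute_validation_loss(), assuming the process group
# has already been initialized by the training launcher:
# tot = reduce_metrics({"loss": 123.4, "frames": 1000.0}, loss.device)
# valid_loss = tot["loss"] / tot["frames"]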
@ -24,7 +24,7 @@ log() {
log "dl_dir: $dl_dir" log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: Download data" log "Stage 0: Download data"
mkdir -p $dl_dir mkdir -p $dl_dir
if [ ! -f $dl_dir/waves_yesno/.completed ]; then if [ ! -f $dl_dir/waves_yesno/.completed ]; then
@ -20,19 +20,18 @@ from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import List from typing import List
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import ( from lhotse.dataset import (
BucketingSampler, BucketingSampler,
CutConcatenate, CutConcatenate,
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler,
) )
from lhotse.dataset.input_strategies import OnTheFlyFeatures from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
class YesNoAsrDataModule(DataModule): class YesNoAsrDataModule(DataModule):
@ -198,7 +197,7 @@ class YesNoAsrDataModule(DataModule):
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler( train_sampler = BucketingSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,
@ -226,12 +225,15 @@ class YesNoAsrDataModule(DataModule):
else PrecomputedFeatures(), else PrecomputedFeatures(),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
sampler = SingleCutSampler( sampler = BucketingSampler(
cuts_test, max_duration=self.args.max_duration cuts_test, max_duration=self.args.max_duration, shuffle=False
) )
logging.debug("About to create test dataloader") logging.debug("About to create test dataloader")
test_dl = DataLoader( test_dl = DataLoader(
test, batch_size=None, sampler=sampler, num_workers=1 test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
) )
return test_dl return test_dl
@ -17,6 +17,7 @@ from icefall.decode import get_lattice, one_best_decoding
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
get_env_info,
get_texts, get_texts,
setup_logger, setup_logger,
store_transcripts, store_transcripts,
@ -124,7 +125,7 @@ def decode_one_batch(
lattice = get_lattice( lattice = get_lattice(
nnet_output=nnet_output, nnet_output=nnet_output,
HLG=HLG, decoding_graph=HLG,
supervision_segments=supervision_segments, supervision_segments=supervision_segments,
search_beam=params.search_beam, search_beam=params.search_beam,
output_beam=params.output_beam, output_beam=params.output_beam,
@ -256,6 +257,7 @@ def main():
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
params["env_info"] = get_env_info()
setup_logger(f"{params.exp_dir}/log/log-decode") setup_logger(f"{params.exp_dir}/log/log-decode")
logging.info("Decoding started") logging.info("Decoding started")
@ -29,7 +29,7 @@ from model import Tdnn
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from icefall.decode import get_lattice, one_best_decoding from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import AttributeDict, get_texts from icefall.utils import AttributeDict, get_env_info, get_texts
def get_parser(): def get_parser():
@ -116,6 +116,7 @@ def main():
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
params["env_info"] = get_env_info()
logging.info(f"{params}") logging.info(f"{params}")
device = torch.device("cpu") device = torch.device("cpu")
@ -175,7 +176,7 @@ def main():
lattice = get_lattice( lattice = get_lattice(
nnet_output=nnet_output, nnet_output=nnet_output,
HLG=HLG, decoding_graph=HLG,
supervision_segments=supervision_segments, supervision_segments=supervision_segments,
search_beam=params.search_beam, search_beam=params.search_beam,
output_beam=params.output_beam, output_beam=params.output_beam,
@ -4,17 +4,17 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Optional from typing import Optional, Tuple
import k2 import k2
import torch import torch
import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from asr_datamodule import YesNoAsrDataModule from asr_datamodule import YesNoAsrDataModule
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import Tdnn from model import Tdnn
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -24,7 +24,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, setup_logger, str2bool from icefall.utils import (
AttributeDict,
MetricsTracker,
get_env_info,
setup_logger,
str2bool,
)
def get_parser(): def get_parser():
@ -122,6 +128,8 @@ def get_params() -> AttributeDict:
- valid_interval: Run validation if batch_idx % valid_interval` is 0 - valid_interval: Run validation if batch_idx % valid_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- beam_size: It is used in k2.ctc_loss - beam_size: It is used in k2.ctc_loss
- reduction: It is used in k2.ctc_loss - reduction: It is used in k2.ctc_loss
@ -142,6 +150,7 @@ def get_params() -> AttributeDict:
"best_valid_epoch": -1, "best_valid_epoch": -1,
"batch_idx_train": 0, "batch_idx_train": 0,
"log_interval": 10, "log_interval": 10,
"reset_interval": 20,
"valid_interval": 10, "valid_interval": 10,
"beam_size": 10, "beam_size": 10,
"reduction": "sum", "reduction": "sum",
@ -245,7 +254,7 @@ def compute_loss(
batch: dict, batch: dict,
graph_compiler: CtcTrainingGraphCompiler, graph_compiler: CtcTrainingGraphCompiler,
is_training: bool, is_training: bool,
): ) -> Tuple[Tensor, MetricsTracker]:
""" """
Compute CTC loss given the model and its inputs. Compute CTC loss given the model and its inputs.
@ -305,13 +314,11 @@ def compute_loss(
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
# train_frames and valid_frames are used for printing. info = MetricsTracker()
if is_training: info["frames"] = supervision_segments[:, 2].sum().item()
params.train_frames = supervision_segments[:, 2].sum().item() info["loss"] = loss.detach().cpu().item()
else:
params.valid_frames = supervision_segments[:, 2].sum().item()
return loss return loss, info
def compute_validation_loss( def compute_validation_loss(
@ -320,16 +327,16 @@ def compute_validation_loss(
graph_compiler: CtcTrainingGraphCompiler, graph_compiler: CtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader,
world_size: int = 1, world_size: int = 1,
) -> None: ) -> MetricsTracker:
"""Run the validation process. The validation loss """Run the validation process. The validation loss
is saved in `params.valid_loss`. is saved in `params.valid_loss`.
""" """
model.eval() model.eval()
tot_loss = 0.0 tot_loss = MetricsTracker()
tot_frames = 0.0
for batch_idx, batch in enumerate(valid_dl): for batch_idx, batch in enumerate(valid_dl):
loss = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
batch=batch, batch=batch,
@ -338,22 +345,18 @@ def compute_validation_loss(
) )
assert loss.requires_grad is False assert loss.requires_grad is False
loss_cpu = loss.detach().cpu().item() tot_loss = tot_loss + loss_info
tot_loss += loss_cpu
tot_frames += params.valid_frames
if world_size > 1: if world_size > 1:
s = torch.tensor([tot_loss, tot_frames], device=loss.device) tot_loss.reduce(loss.device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
s = s.cpu().tolist()
tot_loss = s[0]
tot_frames = s[1]
params.valid_loss = tot_loss / tot_frames loss_value = tot_loss["loss"] / tot_loss["frames"]
if params.valid_loss < params.best_valid_loss: if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = params.valid_loss params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch( def train_one_epoch(
@ -392,57 +395,45 @@ def train_one_epoch(
""" """
model.train() model.train()
tot_loss = 0.0 # sum of losses over all batches tot_loss = MetricsTracker()
tot_frames = 0.0 # sum of frames over all batches
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
loss = compute_loss( loss, loss_info = compute_loss(
params=params, params=params,
model=model, model=model,
batch=batch, batch=batch,
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
is_training=True, is_training=True,
) )
# summary stats.
# NOTE: We use reduction==sum and loss is computed over utterances tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# in the batch and there is no normalization to it so far.
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0) clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step() optimizer.step()
loss_cpu = loss.detach().cpu().item()
tot_frames += params.train_frames
tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
logging.info( logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, " f"Epoch {params.cur_epoch}, "
f"batch avg loss {loss_cpu/params.train_frames:.4f}, " f"batch {batch_idx}, loss[{loss_info}], "
f"total avg loss: {tot_avg_loss:.4f}, " f"tot_loss[{tot_loss}], batch size: {batch_size}"
f"batch size: {batch_size}"
) )
if batch_idx % params.log_interval == 0:
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar( loss_info.write_summary(
"train/current_loss", tb_writer, "train/current_", params.batch_idx_train
loss_cpu / params.train_frames,
params.batch_idx_train,
) )
tot_loss.write_summary(
tb_writer.add_scalar( tb_writer, "train/tot_", params.batch_idx_train
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
) )
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss( valid_info = compute_validation_loss(
params=params, params=params,
model=model, model=model,
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
@ -450,19 +441,16 @@ def train_one_epoch(
world_size=world_size, world_size=world_size,
) )
model.train() model.train()
logging.info( logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
f" best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
)
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar( valid_info.write_summary(
"train/valid_loss", tb_writer,
params.valid_loss, "train/valid_",
params.batch_idx_train, params.batch_idx_train,
) )
params.train_loss = tot_loss / tot_frames loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss: if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch params.best_train_epoch = params.cur_epoch
@ -483,6 +471,7 @@ def run(rank, world_size, args):
""" """
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
params["env_info"] = get_env_info()
fix_random_seed(42) fix_random_seed(42)
if world_size > 1: if world_size > 1:
@@ -66,7 +66,7 @@ def _intersect_device(

 def get_lattice(
     nnet_output: torch.Tensor,
-    HLG: k2.Fsa,
+    decoding_graph: k2.Fsa,
     supervision_segments: torch.Tensor,
     search_beam: float,
     output_beam: float,
@@ -79,8 +79,9 @@ def get_lattice(
     Args:
       nnet_output:
         It is the output of a neural model of shape `(N, T, C)`.
-      HLG:
-        An Fsa, the decoding graph. See also `compile_HLG.py`.
+      decoding_graph:
+        An Fsa, the decoding graph. It can be either an HLG
+        (see `compile_HLG.py`) or an H (see `k2.ctc_topo`).
       supervision_segments:
         A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns.
         Each row contains information for a supervision segment. Column 0
@@ -117,7 +118,7 @@ def get_lattice(
     )

     lattice = k2.intersect_dense_pruned(
-        HLG,
+        decoding_graph,
         dense_fsa_vec,
         search_beam=search_beam,
         output_beam=output_beam,
@ -180,7 +181,7 @@ class Nbest(object):
lattice: k2.Fsa, lattice: k2.Fsa,
num_paths: int, num_paths: int,
use_double_scores: bool = True, use_double_scores: bool = True,
lattice_score_scale: float = 0.5, nbest_scale: float = 0.5,
) -> "Nbest": ) -> "Nbest":
"""Construct an Nbest object by **sampling** `num_paths` from a lattice. """Construct an Nbest object by **sampling** `num_paths` from a lattice.
@ -206,7 +207,7 @@ class Nbest(object):
Return an Nbest instance. Return an Nbest instance.
""" """
saved_scores = lattice.scores.clone() saved_scores = lattice.scores.clone()
lattice.scores *= lattice_score_scale lattice.scores *= nbest_scale
# path is a ragged tensor with dtype torch.int32. # path is a ragged tensor with dtype torch.int32.
# It has three axes [utt][path][arc_pos] # It has three axes [utt][path][arc_pos]
path = k2.random_paths( path = k2.random_paths(
@ -446,7 +447,7 @@ def nbest_decoding(
lattice: k2.Fsa, lattice: k2.Fsa,
num_paths: int, num_paths: int,
use_double_scores: bool = True, use_double_scores: bool = True,
lattice_score_scale: float = 1.0, nbest_scale: float = 1.0,
) -> k2.Fsa: ) -> k2.Fsa:
"""It implements something like CTC prefix beam search using n-best lists. """It implements something like CTC prefix beam search using n-best lists.
@ -474,7 +475,7 @@ def nbest_decoding(
use_double_scores: use_double_scores:
True to use double precision floating point in the computation. True to use double precision floating point in the computation.
False to use single precision. False to use single precision.
lattice_score_scale: nbest_scale:
It's the scale applied to the `lattice.scores`. A smaller value It's the scale applied to the `lattice.scores`. A smaller value
leads to more unique paths at the risk of missing the correct path. leads to more unique paths at the risk of missing the correct path.
Returns: Returns:
@ -484,7 +485,7 @@ def nbest_decoding(
lattice=lattice, lattice=lattice,
num_paths=num_paths, num_paths=num_paths,
use_double_scores=use_double_scores, use_double_scores=use_double_scores,
lattice_score_scale=lattice_score_scale, nbest_scale=nbest_scale,
) )
# nbest.fsa.scores contains 0s # nbest.fsa.scores contains 0s
@ -505,7 +506,7 @@ def nbest_oracle(
ref_texts: List[str], ref_texts: List[str],
word_table: k2.SymbolTable, word_table: k2.SymbolTable,
use_double_scores: bool = True, use_double_scores: bool = True,
lattice_score_scale: float = 0.5, nbest_scale: float = 0.5,
oov: str = "<UNK>", oov: str = "<UNK>",
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[List[int]]]:
"""Select the best hypothesis given a lattice and a reference transcript. """Select the best hypothesis given a lattice and a reference transcript.
@ -517,7 +518,7 @@ def nbest_oracle(
The decoding result returned from this function is the best result that The decoding result returned from this function is the best result that
we can obtain using n-best decoding with all kinds of rescoring techniques. we can obtain using n-best decoding with all kinds of rescoring techniques.
This function is useful to tune the value of `lattice_score_scale`. This function is useful to tune the value of `nbest_scale`.
Args: Args:
lattice: lattice:
@ -533,7 +534,7 @@ def nbest_oracle(
use_double_scores: use_double_scores:
True to use double precision for computation. False to use True to use double precision for computation. False to use
single precision. single precision.
lattice_score_scale: nbest_scale:
It's the scale applied to the lattice.scores. A smaller value It's the scale applied to the lattice.scores. A smaller value
yields more unique paths. yields more unique paths.
oov: oov:
@ -549,7 +550,7 @@ def nbest_oracle(
lattice=lattice, lattice=lattice,
num_paths=num_paths, num_paths=num_paths,
use_double_scores=use_double_scores, use_double_scores=use_double_scores,
lattice_score_scale=lattice_score_scale, nbest_scale=nbest_scale,
) )
hyps = nbest.build_levenshtein_graphs() hyps = nbest.build_levenshtein_graphs()
@ -590,7 +591,7 @@ def rescore_with_n_best_list(
G: k2.Fsa, G: k2.Fsa,
num_paths: int, num_paths: int,
lm_scale_list: List[float], lm_scale_list: List[float],
lattice_score_scale: float = 1.0, nbest_scale: float = 1.0,
use_double_scores: bool = True, use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]: ) -> Dict[str, k2.Fsa]:
"""Rescore an n-best list with an n-gram LM. """Rescore an n-best list with an n-gram LM.
@ -607,7 +608,7 @@ def rescore_with_n_best_list(
Size of nbest list. Size of nbest list.
lm_scale_list: lm_scale_list:
A list of float representing LM score scales. A list of float representing LM score scales.
lattice_score_scale: nbest_scale:
Scale to be applied to ``lattice.score`` when sampling paths Scale to be applied to ``lattice.score`` when sampling paths
using ``k2.random_paths``. using ``k2.random_paths``.
use_double_scores: use_double_scores:
@ -631,7 +632,7 @@ def rescore_with_n_best_list(
lattice=lattice, lattice=lattice,
num_paths=num_paths, num_paths=num_paths,
use_double_scores=use_double_scores, use_double_scores=use_double_scores,
lattice_score_scale=lattice_score_scale, nbest_scale=nbest_scale,
) )
# nbest.fsa.scores are all 0s at this point # nbest.fsa.scores are all 0s at this point
@ -769,7 +770,7 @@ def rescore_with_attention_decoder(
memory_key_padding_mask: Optional[torch.Tensor], memory_key_padding_mask: Optional[torch.Tensor],
sos_id: int, sos_id: int,
eos_id: int, eos_id: int,
lattice_score_scale: float = 1.0, nbest_scale: float = 1.0,
ngram_lm_scale: Optional[float] = None, ngram_lm_scale: Optional[float] = None,
attention_scale: Optional[float] = None, attention_scale: Optional[float] = None,
use_double_scores: bool = True, use_double_scores: bool = True,
@ -796,7 +797,7 @@ def rescore_with_attention_decoder(
The token ID for SOS. The token ID for SOS.
eos_id: eos_id:
The token ID for EOS. The token ID for EOS.
lattice_score_scale: nbest_scale:
It's the scale applied to `lattice.scores`. A smaller value It's the scale applied to `lattice.scores`. A smaller value
leads to more unique paths at the risk of missing the correct path. leads to more unique paths at the risk of missing the correct path.
ngram_lm_scale: ngram_lm_scale:
@ -812,7 +813,7 @@ def rescore_with_attention_decoder(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
-lattice_score_scale=lattice_score_scale,
+nbest_scale=nbest_scale,
)
# nbest.fsa.scores are all 0s at this point
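The changes in this file are a pure parameter rename (lattice_score_scale -> nbest_scale); the call pattern itself is unchanged. A minimal sketch of the renamed argument in use, assuming `lattice` is a decoding lattice produced earlier in the recipe and that `Nbest` is importable from icefall.decode:

    from icefall.decode import Nbest  # assumed import path

    # `lattice` is an assumed k2.Fsa holding one decoding lattice per utterance.
    nbest = Nbest.from_lattice(
        lattice=lattice,
        num_paths=100,
        use_double_scores=True,
        nbest_scale=0.5,  # smaller values flatten scores, yielding more unique paths
    )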
View File
@ -1,4 +1,5 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../LICENSE for clarification regarding multiple authors
#
@ -16,6 +17,7 @@
import argparse
import collections
import logging
import os
import subprocess
@ -27,10 +29,12 @@ from pathlib import Path
from typing import Any, Dict, Iterable, List, TextIO, Tuple, Union
import k2
import k2.version
import kaldialign
import lhotse
import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
Pathlike = Union[str, Path]
@ -233,8 +237,8 @@ def encode_supervisions(
supervisions: dict, subsampling_factor: int
) -> Tuple[torch.Tensor, List[str]]:
"""
-Encodes Lhotse's ``batch["supervisions"]`` dict into a pair of torch Tensor,
-and a list of transcription strings.
+Encodes Lhotse's ``batch["supervisions"]`` dict into
+a pair of torch Tensor, and a list of transcription strings.
The supervision tensor has shape ``(batch_size, 3)``.
Its second dimension contains information about sequence index [0],
@ -302,6 +306,73 @@ def get_texts(
return aux_labels.tolist()
def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
"""Extract the token IDs (from best_paths.labels) from the best-path FSAs.
Args:
best_paths:
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
containing multiple FSAs, which is expected to be the result
of k2.shortest_path (otherwise the returned values won't
be meaningful).
Returns:
Returns a list of lists of int, containing the token sequences we
decoded. For `ans[i]`, its length equals the number of frames
after subsampling of the i-th utterance in the batch.
"""
# arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
label_shape = best_paths.arcs.shape().remove_axis(1)
# label_shape has axes [fsa][arc]
labels = k2.RaggedTensor(label_shape, best_paths.labels.contiguous())
labels = labels.remove_values_eq(-1)
return labels.tolist()
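As a usage sketch (not part of this diff), get_alignments() is meant to be applied to the output of k2.shortest_path; `lattice` below is an assumed decoding lattice:

    # `lattice` is an assumed k2.Fsa with one decoding lattice per utterance.
    best_paths = k2.shortest_path(lattice, use_double_scores=True)

    # Framewise token IDs; len(ali[i]) equals the number of subsampled
    # frames of the i-th utterance in the batch.
    ali = get_alignments(best_paths)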
def save_alignments(
alignments: Dict[str, List[int]],
subsampling_factor: int,
filename: str,
) -> None:
"""Save alignments to a file.
Args:
alignments:
A dict containing alignments. Keys of the dict are utterances and
values are the corresponding framewise alignments after subsampling.
subsampling_factor:
The subsampling factor of the model.
filename:
Path to save the alignments.
Returns:
Return None.
"""
ali_dict = {
"subsampling_factor": subsampling_factor,
"alignments": alignments,
}
torch.save(ali_dict, filename)
def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
"""Load alignments from a file.
Args:
filename:
Path to the file containing alignment information.
The file should be saved by :func:`save_alignments`.
Returns:
Return a tuple containing:
- subsampling_factor: The subsampling_factor used to compute
the alignments.
- alignments: A dict containing utterances and their corresponding
framewise alignment, after subsampling.
"""
ali_dict = torch.load(filename)
subsampling_factor = ali_dict["subsampling_factor"]
alignments = ali_dict["alignments"]
return subsampling_factor, alignments
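A small round-trip sketch for the two helpers above; the utterance ID, token IDs, and file name are invented for illustration:

    # Hypothetical framewise alignments (e.g. from get_alignments()),
    # keyed by utterance ID.
    ali = {"utt-0001": [0, 0, 25, 25, 7, 0]}

    save_alignments(alignments=ali, subsampling_factor=4, filename="ali.pt")

    # Later, possibly in a different script:
    subsampling_factor, alignments = load_alignments("ali.pt")
    assert subsampling_factor == 4
    assert alignments["utt-0001"] == [0, 0, 25, 25, 7, 0]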
def store_transcripts(
filename: Pathlike, texts: Iterable[Tuple[str, str]]
) -> None:
@ -339,13 +410,13 @@ def write_error_stats(
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
reference words (2337 correct)
-- The difference between the reference transcript and predicted results.
+- The difference between the reference transcript and predicted result.
An instance is given below::
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
-The above example shows that the reference word is `EDISON`, but it is
-predicted to `ADDISON` (a substitution error).
+The above example shows that the reference word is `EDISON`,
+but it is predicted to `ADDISON` (a substitution error).
Another example is::
@ -486,3 +557,76 @@ def write_error_stats(
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate) return float(tot_err_rate)
class MetricsTracker(collections.defaultdict):
def __init__(self):
# Passing the type 'int' to the base-class constructor
# makes undefined items default to int() which is zero.
# This class serves as a metrics tracker.
# It can record many metrics, including but not limited to loss.
super(MetricsTracker, self).__init__(int)
def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
ans = MetricsTracker()
for k, v in self.items():
ans[k] = v
for k, v in other.items():
ans[k] = ans[k] + v
return ans
def __mul__(self, alpha: float) -> "MetricsTracker":
ans = MetricsTracker()
for k, v in self.items():
ans[k] = v * alpha
return ans
def __str__(self) -> str:
ans = ""
for k, v in self.norm_items():
norm_value = "%.4g" % v
ans += str(k) + "=" + str(norm_value) + ", "
frames = str(self["frames"])
ans += "over " + frames + " frames."
return ans
def norm_items(self) -> List[Tuple[str, float]]:
"""
Returns a list of pairs, like:
[('ctc_loss', 0.1), ('att_loss', 0.07)]
"""
num_frames = self["frames"] if "frames" in self else 1
ans = []
for k, v in self.items():
if k != "frames":
norm_value = float(v) / num_frames
ans.append((k, norm_value))
return ans
def reduce(self, device):
"""
Reduce using torch.distributed, which I believe ensures that
all processes get the total.
"""
keys = sorted(self.keys())
s = torch.tensor([float(self[k]) for k in keys], device=device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
for k, v in zip(keys, s.cpu().tolist()):
self[k] = v
def write_summary(
self,
tb_writer: SummaryWriter,
prefix: str,
batch_idx: int,
) -> None:
"""Add logging information to a TensorBoard writer.
Args:
tb_writer: a TensorBoard writer
prefix: a prefix for the name of the loss, e.g. "train/valid_",
or "train/current_"
batch_idx: The current batch index, used as the x-axis of the plot.
"""
for k, v in self.norm_items():
tb_writer.add_scalar(prefix + k, v, batch_idx)
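A sketch of how MetricsTracker is intended to be used; the loss names and values are invented, and the distributed/TensorBoard calls are shown only in comments since they need a live process group and SummaryWriter:

    # Missing keys default to 0 because the base class is defaultdict(int).
    info = MetricsTracker()
    info["frames"] = 320
    info["ctc_loss"] = 12.5
    info["att_loss"] = 8.0

    tot = MetricsTracker()
    tot = tot + info   # accumulate counters across batches
    half = tot * 0.5   # scale every counter, e.g. for a running average

    print(tot)  # e.g. "ctc_loss=0.03906, att_loss=0.025, over 320 frames."

    # In distributed training one would first sum over all ranks and then log:
    #     tot.reduce(device)
    #     tot.write_summary(tb_writer, "train/current_", batch_idx)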
31
setup.py Normal file
View File
@ -0,0 +1,31 @@
#!/usr/bin/env python3
from setuptools import find_packages, setup
from pathlib import Path
icefall_dir = Path(__file__).parent
install_requires = (icefall_dir / "requirements.txt").read_text().splitlines()
setup(
name="icefall",
version="1.0",
python_requires=">=3.6.0",
description="Speech processing recipes using k2 and Lhotse.",
author="The k2 and Lhotse Development Team",
license="Apache-2.0 License",
packages=find_packages(),
install_requires=install_requires,
classifiers=[
"Development Status :: 3 - Alpha",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Intended Audience :: Science/Research",
"Operating System :: POSIX :: Linux",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio :: Speech",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries :: Python Modules",
"Typing :: Typed",
],
)
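With this setup.py, the package can presumably be installed with `pip install .` (or `pip install -e .` for development) from the repository root, after which the icefall helpers are importable without adjusting PYTHONPATH. A minimal check, using only imports that also appear in the tests below:

    # Assumes the package has been installed into the current environment.
    from icefall.utils import AttributeDict, get_env_info

    s = AttributeDict({"a": 10})
    assert s.a == 10

    print(get_env_info())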
View File
@ -43,7 +43,7 @@ def test_nbest_from_lattice():
lattice=lattice,
num_paths=10,
use_double_scores=True,
-lattice_score_scale=0.5,
+nbest_scale=0.5,
)
# each lattice has only 4 distinct paths that have different word sequences:
# 10->30
View File
@ -20,7 +20,12 @@ import k2
import pytest
import torch
-from icefall.utils import AttributeDict, encode_supervisions, get_texts
+from icefall.utils import (
+    AttributeDict,
+    encode_supervisions,
+    get_env_info,
+    get_texts,
+)
@pytest.fixture @pytest.fixture
@ -108,6 +113,7 @@ def test_attribute_dict():
assert s["b"] == 20 assert s["b"] == 20
s.c = 100 s.c = 100
assert s["c"] == 100 assert s["c"] == 100
assert hasattr(s, "a") assert hasattr(s, "a")
assert hasattr(s, "b") assert hasattr(s, "b")
assert getattr(s, "a") == 10 assert getattr(s, "a") == 10
@ -119,3 +125,8 @@ def test_attribute_dict():
del s.a
except AttributeError as ex:
print(f"Caught exception: {ex}")
def test_get_env_info():
s = get_env_info()
print(s)