Merge remote-tracking branch 'dan/master' into ctc-ali

Fangjun Kuang 2021-10-18 14:07:20 +08:00
commit 1c603c3bce
25 changed files with 1133 additions and 477 deletions

106
.github/workflows/run-pretrained.yml vendored Normal file

@ -0,0 +1,106 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-pre-trained-conformer-ctc
on:
push:
branches:
- master
pull_request:
types: [labeled]
jobs:
run_pre_trained_conformer_ctc:
if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.6, 3.7, 3.8, 3.9]
torch: ["1.8.1"]
k2-version: ["1.9.dev20210919"]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip pytest
pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
python3 -m pip install kaldifeat
# We are in ./icefall, which contains a requirements.txt file
pip install -r requirements.txt
- name: Install graphviz
shell: bash
run: |
python3 -m pip install -qq graphviz
sudo apt-get -qq install graphviz
- name: Download pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
cd egs/librispeech/ASR
mkdir tmp
cd tmp
git lfs install
git clone https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500
cd ..
tree tmp
soxi tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/*.flac
ls -lh tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/*.flac
- name: Run CTC decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
cd egs/librispeech/ASR
./conformer_ctc/pretrained.py \
--num-classes 500 \
--checkpoint ./tmp/icefall-asr-conformer-ctc-bpe-500/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/bpe.model \
--method ctc-decoding \
./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1089-134686-0001.flac \
./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0001.flac \
./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0002.flac
- name: Run HLG decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
cd egs/librispeech/ASR
./conformer_ctc/pretrained.py \
--num-classes 500 \
--checkpoint ./tmp/icefall-asr-conformer-ctc-bpe-500/exp/pretrained.pt \
--words-file ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/words.txt \
--HLG ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/HLG.pt \
./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1089-134686-0001.flac \
./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0001.flac \
./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0002.flac


@ -92,3 +92,7 @@ jobs:
echo "lib_path: $lib_path"
export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
pytest ./test
# run tests for conformer ctc
cd egs/librispeech/ASR/conformer_ctc
pytest

1
.gitignore vendored

@ -1,3 +1,4 @@
icefall.egg-info/
data
__pycache__
path.sh


@ -55,7 +55,22 @@ The WER for this model is:
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing)
## Deployment with C++
Once you have trained a model in icefall, you may want to deploy it with C++,
without Python dependencies.
Please refer to the documentation
<https://icefall.readthedocs.io/en/latest/recipes/librispeech/conformer_ctc.html#deployment-with-c>
for how to do this.
We also provide a Colab notebook, showing you how to run a torch scripted model in [k2][k2] with C++.
Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1BIGLWzS36isskMXHKcqC9ysN6pspYXs_?usp=sharing)
[LibriSpeech_tdnn_lstm_ctc]: egs/librispeech/ASR/tdnn_lstm_ctc
[LibriSpeech_conformer_ctc]: egs/librispeech/ASR/conformer_ctc
[yesno]: egs/yesno/ASR
[librispeech]: egs/librispeech/ASR
[k2]: https://github.com/k2-fsa/k2


@ -1,4 +1,4 @@
Confromer CTC
Conformer CTC
=============
This tutorial shows you how to run a conformer ctc model
@ -20,6 +20,7 @@ In this tutorial, you will learn:
- (2) How to start the training, either with a single GPU or multiple GPUs
- (3) How to do decoding after training, with n-gram LM rescoring and attention decoder rescoring
- (4) How to use a pre-trained model, provided by us
- (5) How to deploy your trained model in C++, without Python dependencies
Data preparation
----------------
@ -292,16 +293,25 @@ The commonly used options are:
- ``--method``
This specifies the decoding method.
This specifies the decoding method. This script supports 7 decoding methods.
For ctc-decoding, a sentencepiece (BPE) model is used to convert word pieces into words;
it needs neither a lexicon nor an n-gram LM (see the short sketch after this option list).
The following command uses attention decoder for rescoring:
For example, the following command uses CTC topology for decoding:
.. code-block::
$ cd egs/librispeech/ASR
$ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --lattice-score-scale 0.5
$ ./conformer_ctc/decode.py --method ctc-decoding --max-duration 300
- ``--lattice-score-scale``
And the following command uses attention decoder for rescoring:
.. code-block::
$ cd egs/librispeech/ASR
$ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --nbest-scale 0.5
- ``--nbest-scale``
It is used to scale down lattice scores so that there are more unique
paths for rescoring.
@ -311,6 +321,61 @@ The commonly used options are:
It has the same meaning as the one during training. A larger
value may cause OOM.
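A minimal sketch of the word-piece-to-word conversion that ``ctc-decoding`` performs is given below; the model path and the token IDs are purely illustrative.
.. code-block:: python
import sentencepiece as spm
# Load the BPE model produced by the recipe (vocab size 500 here).
bpe_model = spm.SentencePieceProcessor()
bpe_model.load("data/lang_bpe_500/bpe.model")
# In decode.py the IDs come from the best path of the CTC decoding lattice
# via get_texts(best_path); the numbers below are placeholders.
token_ids = [[23, 7, 159], [5, 98]]
hyps = bpe_model.decode(token_ids)  # e.g., ['HELLO WORLD', 'FOO BAR']
hyps = [s.split() for s in hyps]    # e.g., [['HELLO', 'WORLD'], ['FOO', 'BAR']]
print(hyps)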
Here are some results for CTC decoding with a vocab size of 500:
Usage:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./conformer_ctc/decode.py \
--epoch 25 \
--avg 1 \
--max-duration 300 \
--exp-dir conformer_ctc/exp \
--lang-dir data/lang_bpe_500 \
--method ctc-decoding
The output is given below:
.. code-block:: bash
2021-09-26 12:44:31,033 INFO [decode.py:537] Decoding started
2021-09-26 12:44:31,033 INFO [decode.py:538]
{'lm_dir': PosixPath('data/lm'), 'subsampling_factor': 4, 'vgg_frontend': False, 'use_feat_batchnorm': True,
'feature_dim': 80, 'nhead': 8, 'attention_dim': 512, 'num_decoder_layers': 6, 'search_beam': 20, 'output_beam': 8,
'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True,
'epoch': 25, 'avg': 1, 'method': 'ctc-decoding', 'num_paths': 100, 'nbest_scale': 0.5,
'export': False, 'exp_dir': PosixPath('conformer_ctc/exp'), 'lang_dir': PosixPath('data/lang_bpe_500'), 'full_libri': False,
'feature_dir': PosixPath('data/fbank'), 'max_duration': 100, 'bucketing_sampler': False, 'num_buckets': 30,
'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False,
'shuffle': True, 'return_cuts': True, 'num_workers': 2}
2021-09-26 12:44:31,406 INFO [lexicon.py:113] Loading pre-compiled data/lang_bpe_500/Linv.pt
2021-09-26 12:44:31,464 INFO [decode.py:548] device: cuda:0
2021-09-26 12:44:36,171 INFO [checkpoint.py:92] Loading checkpoint from conformer_ctc/exp/epoch-25.pt
2021-09-26 12:44:36,776 INFO [decode.py:652] Number of model parameters: 109226120
2021-09-26 12:44:37,714 INFO [decode.py:473] batch 0/206, cuts processed until now is 12
2021-09-26 12:45:15,944 INFO [decode.py:473] batch 100/206, cuts processed until now is 1328
2021-09-26 12:45:54,443 INFO [decode.py:473] batch 200/206, cuts processed until now is 2563
2021-09-26 12:45:56,411 INFO [decode.py:494] The transcripts are stored in conformer_ctc/exp/recogs-test-clean-ctc-decoding.txt
2021-09-26 12:45:56,592 INFO [utils.py:331] [test-clean-ctc-decoding] %WER 3.26% [1715 / 52576, 163 ins, 128 del, 1424 sub ]
2021-09-26 12:45:56,807 INFO [decode.py:506] Wrote detailed error stats to conformer_ctc/exp/errs-test-clean-ctc-decoding.txt
2021-09-26 12:45:56,808 INFO [decode.py:522]
For test-clean, WER of different settings are:
ctc-decoding 3.26 best for test-clean
2021-09-26 12:45:57,362 INFO [decode.py:473] batch 0/203, cuts processed until now is 15
2021-09-26 12:46:35,565 INFO [decode.py:473] batch 100/203, cuts processed until now is 1477
2021-09-26 12:47:15,106 INFO [decode.py:473] batch 200/203, cuts processed until now is 2922
2021-09-26 12:47:16,131 INFO [decode.py:494] The transcripts are stored in conformer_ctc/exp/recogs-test-other-ctc-decoding.txt
2021-09-26 12:47:16,208 INFO [utils.py:331] [test-other-ctc-decoding] %WER 8.21% [4295 / 52343, 396 ins, 315 del, 3584 sub ]
2021-09-26 12:47:16,432 INFO [decode.py:506] Wrote detailed error stats to conformer_ctc/exp/errs-test-other-ctc-decoding.txt
2021-09-26 12:47:16,432 INFO [decode.py:522]
For test-other, WER of different settings are:
ctc-decoding 8.21 best for test-other
2021-09-26 12:47:16,433 INFO [decode.py:680] Done!
Pre-trained Model
-----------------
@ -381,7 +446,6 @@ After downloading, you will have the following files:
6 directories, 11 files
**File descriptions**:
- ``data/lang_bpe/HLG.pt``
It is the decoding graph.
@ -462,12 +526,58 @@ Usage
displays the help information.
It supports three decoding methods:
It supports 4 decoding methods:
- CTC decoding
- HLG decoding
- HLG + n-gram LM rescoring
- HLG + n-gram LM rescoring + attention decoder rescoring
CTC decoding
^^^^^^^^^^^^
CTC decoding uses the best path of the decoding lattice as the decoding result
without any LM or lexicon.
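Under the hood, ``ctc-decoding`` in ``conformer_ctc/pretrained.py`` roughly follows the steps sketched below; ``nnet_output``, ``supervision_segments`` and ``bpe_model`` are assumed to be prepared as in that script, and the beam values are only illustrative.
.. code-block:: python
import torch
import k2
from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import get_texts
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
max_token_id = 499  # vocab size 500 -> token IDs 0..499
# CTC topology over the BPE vocabulary; no lexicon and no LM are involved.
H = k2.ctc_topo(max_token=max_token_id, modified=False, device=device)
# Intersect the network output with H to obtain the decoding lattice.
lattice = get_lattice(
    nnet_output=nnet_output,              # (N, T, C) log-probs from the encoder
    decoding_graph=H,
    supervision_segments=supervision_segments,
    search_beam=20,
    output_beam=8,
    min_active_states=30,
    max_active_states=10000,
    subsampling_factor=4,
)
# The best path's aux_labels are BPE token IDs; convert them to words.
best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
token_ids = get_texts(best_path)
hyps = [s.split() for s in bpe_model.decode(token_ids)]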
The command to run CTC decoding is:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./conformer_ctc/pretrained.py \
--checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
--bpe-model ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/bpe.model \
--method ctc-decoding \
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac
The output is given below:
.. code-block::
2021-10-13 11:21:50,896 INFO [pretrained.py:236] device: cuda:0
2021-10-13 11:21:50,896 INFO [pretrained.py:238] Creating model
2021-10-13 11:21:56,669 INFO [pretrained.py:255] Constructing Fbank computer
2021-10-13 11:21:56,670 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-10-13 11:21:56,683 INFO [pretrained.py:271] Decoding started
2021-10-13 11:21:57,341 INFO [pretrained.py:290] Building CTC topology
2021-10-13 11:21:57,625 INFO [lexicon.py:113] Loading pre-compiled tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/Linv.pt
2021-10-13 11:21:57,679 INFO [pretrained.py:299] Loading BPE model
2021-10-13 11:22:00,076 INFO [pretrained.py:314] Use CTC decoding
2021-10-13 11:22:00,087 INFO [pretrained.py:400]
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED
BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-10-13 11:22:00,087 INFO [pretrained.py:402] Decoding Done
HLG decoding
^^^^^^^^^^^^
@ -490,14 +600,14 @@ The output is given below:
.. code-block::
2021-08-20 11:03:05,712 INFO [pretrained.py:217] device: cuda:0
2021-08-20 11:03:05,712 INFO [pretrained.py:219] Creating model
2021-08-20 11:03:11,345 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt
2021-08-20 11:03:18,442 INFO [pretrained.py:255] Constructing Fbank computer
2021-08-20 11:03:18,444 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-08-20 11:03:18,507 INFO [pretrained.py:271] Decoding started
2021-08-20 11:03:18,795 INFO [pretrained.py:300] Use HLG decoding
2021-08-20 11:03:19,149 INFO [pretrained.py:339]
2021-10-13 11:25:19,458 INFO [pretrained.py:236] device: cuda:0
2021-10-13 11:25:19,458 INFO [pretrained.py:238] Creating model
2021-10-13 11:25:25,342 INFO [pretrained.py:255] Constructing Fbank computer
2021-10-13 11:25:25,343 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-10-13 11:25:25,356 INFO [pretrained.py:271] Decoding started
2021-10-13 11:25:26,026 INFO [pretrained.py:327] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt
2021-10-13 11:25:33,735 INFO [pretrained.py:359] Use HLG decoding
2021-10-13 11:25:34,013 INFO [pretrained.py:400]
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
@ -508,7 +618,7 @@ The output is given below:
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-20 11:03:19,149 INFO [pretrained.py:341] Decoding Done
2021-10-13 11:25:34,014 INFO [pretrained.py:402] Decoding Done
HLG decoding + LM rescoring
^^^^^^^^^^^^^^^^^^^^^^^^^^^
@ -536,15 +646,15 @@ Its output is:
.. code-block::
2021-08-20 11:12:17,565 INFO [pretrained.py:217] device: cuda:0
2021-08-20 11:12:17,565 INFO [pretrained.py:219] Creating model
2021-08-20 11:12:23,728 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt
2021-08-20 11:12:30,035 INFO [pretrained.py:246] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt
2021-08-20 11:13:10,779 INFO [pretrained.py:255] Constructing Fbank computer
2021-08-20 11:13:10,787 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-08-20 11:13:10,798 INFO [pretrained.py:271] Decoding started
2021-08-20 11:13:11,085 INFO [pretrained.py:305] Use HLG decoding + LM rescoring
2021-08-20 11:13:11,736 INFO [pretrained.py:339]
2021-10-13 11:28:19,129 INFO [pretrained.py:236] device: cuda:0
2021-10-13 11:28:19,129 INFO [pretrained.py:238] Creating model
2021-10-13 11:28:23,531 INFO [pretrained.py:255] Constructing Fbank computer
2021-10-13 11:28:23,532 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-10-13 11:28:23,544 INFO [pretrained.py:271] Decoding started
2021-10-13 11:28:24,141 INFO [pretrained.py:327] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt
2021-10-13 11:28:30,752 INFO [pretrained.py:338] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt
2021-10-13 11:28:48,308 INFO [pretrained.py:364] Use HLG decoding + LM rescoring
2021-10-13 11:28:48,815 INFO [pretrained.py:400]
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
@ -555,7 +665,7 @@ Its output is:
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-20 11:13:11,737 INFO [pretrained.py:341] Decoding Done
2021-10-13 11:28:48,815 INFO [pretrained.py:402] Decoding Done
HLG decoding + LM rescoring + attention decoder rescoring
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@ -577,7 +687,7 @@ The command to run HLG decoding + LM rescoring + attention decoder rescoring is:
--G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 1.3 \
--attention-decoder-scale 1.2 \
--lattice-score-scale 0.5 \
--nbest-scale 0.5 \
--num-paths 100 \
--sos-id 1 \
--eos-id 1 \
@ -589,15 +699,15 @@ The output is below:
.. code-block::
2021-08-20 11:19:11,397 INFO [pretrained.py:217] device: cuda:0
2021-08-20 11:19:11,397 INFO [pretrained.py:219] Creating model
2021-08-20 11:19:17,354 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt
2021-08-20 11:19:24,615 INFO [pretrained.py:246] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt
2021-08-20 11:20:04,576 INFO [pretrained.py:255] Constructing Fbank computer
2021-08-20 11:20:04,584 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-08-20 11:20:04,595 INFO [pretrained.py:271] Decoding started
2021-08-20 11:20:04,854 INFO [pretrained.py:313] Use HLG + LM rescoring + attention decoder rescoring
2021-08-20 11:20:05,805 INFO [pretrained.py:339]
2021-10-13 11:29:50,106 INFO [pretrained.py:236] device: cuda:0
2021-10-13 11:29:50,106 INFO [pretrained.py:238] Creating model
2021-10-13 11:29:56,063 INFO [pretrained.py:255] Constructing Fbank computer
2021-10-13 11:29:56,063 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-10-13 11:29:56,077 INFO [pretrained.py:271] Decoding started
2021-10-13 11:29:56,770 INFO [pretrained.py:327] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt
2021-10-13 11:30:04,023 INFO [pretrained.py:338] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt
2021-10-13 11:30:18,163 INFO [pretrained.py:372] Use HLG + LM rescoring + attention decoder rescoring
2021-10-13 11:30:19,367 INFO [pretrained.py:400]
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
@ -608,7 +718,7 @@ The output is below:
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-20 11:20:05,805 INFO [pretrained.py:341] Decoding Done
2021-10-13 11:30:19,367 INFO [pretrained.py:402] Decoding Done
Colab notebook
--------------
@ -629,3 +739,119 @@ We do provide a colab notebook for this recipe showing how to use a pre-trained
**Congratulations!** You have finished the librispeech ASR recipe with
conformer CTC models in ``icefall``.
If you want to deploy your trained model in C++, please read the following section.
Deployment with C++
-------------------
This section describes how to deploy your trained model in C++, without
Python dependencies.
We assume you have run ``./prepare.sh`` and have the following directories available:
.. code-block:: bash
data
|-- lang_bpe
Also, we assume your checkpoints are saved in ``conformer_ctc/exp``.
If you know that averaging the 20 checkpoints ending at ``epoch-30.pt`` (i.e.,
``epoch-11.pt`` through ``epoch-30.pt``) yields the lowest WER, you can run the
following commands
.. code-block::
$ cd egs/librispeech/ASR
$ ./conformer_ctc/export.py \
--epoch 30 \
--avg 20 \
--jit 1 \
--lang-dir data/lang_bpe \
--exp-dir conformer_ctc/exp
to get a torch scripted model saved in ``conformer_ctc/exp/cpu_jit.pt``.
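If you want a quick sanity check (optional, not part of the recipe), the exported file can be loaded back with ``torch.jit.load``:
.. code-block:: python
import torch
# Load the TorchScript model exported by conformer_ctc/export.py
model = torch.jit.load("conformer_ctc/exp/cpu_jit.pt", map_location="cpu")
model.eval()
print(model)  # prints the structure of the scripted module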
Now you have all needed files ready. Let us compile k2 from source:
.. code-block:: bash
$ cd $HOME
$ git clone https://github.com/k2-fsa/k2
$ cd k2
$ git checkout v2.0-pre
.. CAUTION::
You have to switch to the branch ``v2.0-pre``!
.. code-block:: bash
$ mkdir build-release
$ cd build-release
$ cmake -DCMAKE_BUILD_TYPE=Release ..
$ make -j decode
# You will find an executable: `./bin/decode`
Now you are ready to go!
To view the usage of ``./bin/decode``, run:
.. code-block::
$ ./bin/decode
It will show you the following message:
.. code-block::
Please provide --jit_pt
(1) CTC decoding
./bin/decode \
--use_ctc_decoding true \
--jit_pt <path to exported torch script pt file> \
--bpe_model <path to pretrained BPE model> \
/path/to/foo.wav \
/path/to/bar.wav \
<more wave files if any>
(2) HLG decoding
./bin/decode \
--use_ctc_decoding false \
--jit_pt <path to exported torch script pt file> \
--hlg <path to HLG.pt> \
--word-table <path to words.txt> \
/path/to/foo.wav \
/path/to/bar.wav \
<more wave files if any>
--use_gpu false to use CPU
--use_gpu true to use GPU
``./bin/decode`` supports two types of decoding at present: CTC decoding and HLG decoding.
CTC decoding
^^^^^^^^^^^^
You need to provide:
- ``--jit_pt``: the file generated by ``conformer_ctc/export.py``. You can find it
in ``conformer_ctc/exp/cpu_jit.pt``.
- ``--bpe_model``: the sentencepiece model generated by ``prepare.sh``. You can find
it in ``data/lang_bpe/bpe.model``.
HLG decoding
^^^^^^^^^^^^
You need to provide:
- ``--jit_pt``: the same file as in CTC decoding.
- ``--hlg``: generated by ``prepare.sh``. You can find it in ``data/lang_bpe/HLG.pt``.
- ``--word-table``: generated by ``prepare.sh``. You can find it in ``data/lang_bpe/words.txt``.
We do provide a Colab notebook, showing you how to run a torch scripted model in C++.
Please see |librispeech asr conformer ctc torch script colab notebook|
.. |librispeech asr conformer ctc torch script colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/drive/1BIGLWzS36isskMXHKcqC9ysN6pspYXs_?usp=sharing


@ -41,7 +41,7 @@ python conformer_ctc/train.py --bucketing-sampler True \
--world-size 4 \
--lang-dir data/lang_bpe_5000
python conformer_ctc/decode.py --lattice-score-scale 0.5 \
python conformer_ctc/decode.py --nbest-scale 0.5 \
--epoch 34 \
--avg 20 \
--method attention-decoder \


@ -23,6 +23,7 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
@ -78,6 +79,9 @@ def get_parser():
default="attention-decoder",
help="""Decoding method.
Supported values are:
- (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM.
- (1) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path
@ -107,7 +111,7 @@ def get_parser():
)
parser.add_argument(
"--lattice-score-scale",
"--nbest-scale",
type=float,
default=0.5,
help="""The scale to be applied to `lattice.scores`.
@ -129,11 +133,18 @@ def get_parser():
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_5000",
help="lang directory",
default="data/lang_bpe",
help="The lang dir",
)
return parser
@ -142,7 +153,6 @@ def get_parser():
def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc/exp"),
"lm_dir": Path("data/lm"),
# parameters for conformer
"subsampling_factor": 4,
@ -167,13 +177,15 @@ def get_params() -> AttributeDict:
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
batch: dict,
word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[int]]]:
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -198,7 +210,11 @@ def decode_one_batch(
model:
The neural model.
HLG:
The decoding graph.
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.method is ctc-decoding.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
@ -217,7 +233,10 @@ def decode_one_batch(
Return the decoding result. See above description for the format of
the returned dict.
"""
if HLG is not None:
device = HLG.device
else:
device = H.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
@ -237,9 +256,17 @@ def decode_one_batch(
1,
).to(torch.int32)
if H is None:
assert HLG is not None
decoding_graph = HLG
else:
assert HLG is None
assert bpe_model is not None
decoding_graph = H
lattice = get_lattice(
nnet_output=nnet_output,
HLG=HLG,
decoding_graph=decoding_graph,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
@ -248,6 +275,24 @@ def decode_one_batch(
subsampling_factor=params.subsampling_factor,
)
if params.method == "ctc-decoding":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
# Note: `best_path.aux_labels` contains token IDs, not word IDs
# since we are using H, not HLG here.
#
# token_ids is a list-of-lists of IDs
token_ids = get_texts(best_path)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "ctc-decoding"
return {key: hyps}
if params.method == "nbest-oracle":
# Note: You can also pass rescored lattices to it.
# We choose the HLG decoded lattice for speed reasons
@ -258,12 +303,12 @@ def decode_one_batch(
num_paths=params.num_paths,
ref_texts=supervisions["text"],
word_table=word_table,
lattice_score_scale=params.lattice_score_scale,
nbest_scale=params.nbest_scale,
oov="<UNK>",
)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa
key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa
return {key: hyps}
if params.method in ["1best", "nbest"]:
@ -277,9 +322,9 @@ def decode_one_batch(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
lattice_score_scale=params.lattice_score_scale,
nbest_scale=params.nbest_scale,
)
key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
@ -301,7 +346,7 @@ def decode_one_batch(
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
lattice_score_scale=params.lattice_score_scale,
nbest_scale=params.nbest_scale,
)
elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice(
@ -327,7 +372,7 @@ def decode_one_batch(
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
lattice_score_scale=params.lattice_score_scale,
nbest_scale=params.nbest_scale,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
@ -348,12 +393,14 @@ def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[int], List[int]]]]:
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
@ -364,7 +411,11 @@ def decode_dataset(
model:
The neural model.
HLG:
The decoding graph.
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.method is ctc-decoding.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
word_table:
It is the word symbol table.
sos_id:
@ -399,6 +450,8 @@ def decode_dataset(
params=params,
model=model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
batch=batch,
word_table=word_table,
G=G,
@ -477,6 +530,8 @@ def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
@ -504,6 +559,18 @@ def main():
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id
if params.method == "ctc-decoding":
HLG = None
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
device=device,
)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
else:
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
)
@ -601,6 +668,8 @@ def main():
params=params,
model=model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
word_table=lexicon.word_table,
G=G,
sos_id=sos_id,


@ -0,0 +1,165 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
import argparse
import logging
from pathlib import Path
import torch
from conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=34,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=20,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe",
help="""It contains language related input files such as "lexicon.txt"
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=True,
help="""True to save a model after applying torch.jit.script.
""",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 80,
"subsampling_factor": 4,
"use_feat_batchnorm": True,
"attention_dim": 512,
"nhead": 8,
"num_decoder_layers": 6,
}
)
return params
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=False,
use_feat_batchnorm=params.use_feat_batchnorm,
)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
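# average_checkpoints() presumably returns an elementwise average of the
# "model" state_dicts stored in `filenames` (see icefall/checkpoint.py for
# the actual implementation); the averaged weights are loaded below.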
model.load_state_dict(average_checkpoints(filenames))
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()


@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -23,6 +24,7 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from conformer import Conformer
@ -54,12 +56,25 @@ def get_parser():
parser.add_argument(
"--words-file",
type=str,
required=True,
help="Path to words.txt",
help="""Path to words.txt.
Used only when method is not ctc-decoding.
""",
)
parser.add_argument(
"--HLG", type=str, required=True, help="Path to HLG.pt."
"--HLG",
type=str,
help="""Path to HLG.pt.
Used only when method is not ctc-decoding.
""",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
)
parser.add_argument(
@ -68,6 +83,10 @@ def get_parser():
default="1best",
help="""Decoding method.
Possible values are:
(0) ctc-decoding - Use CTC decoding. It uses a sentence
piece model, i.e., lang_dir/bpe.model, to convert
word pieces to words. It needs neither a lexicon
nor an n-gram LM.
(1) 1best - Use the best path as decoding output. Only
the transformer encoder output is used for decoding.
We call it HLG decoding.
@ -125,7 +144,7 @@ def get_parser():
)
parser.add_argument(
"--lattice-score-scale",
"--nbest-scale",
type=float,
default=0.5,
help="""
@ -139,7 +158,7 @@ def get_parser():
parser.add_argument(
"--sos-id",
type=float,
type=int,
default=1,
help="""
Used only when method is attention-decoder.
@ -147,9 +166,18 @@ def get_parser():
""",
)
parser.add_argument(
"--num-classes",
type=int,
default=5000,
help="""
Vocab size in the BPE model.
""",
)
parser.add_argument(
"--eos-id",
type=float,
type=int,
default=1,
help="""
Used only when method is attention-decoder.
@ -180,7 +208,6 @@ def get_params() -> AttributeDict:
"use_feat_batchnorm": True,
"feature_dim": 80,
"nhead": 8,
"num_classes": 5000,
"attention_dim": 512,
"num_decoder_layers": 6,
# parameters for decoding
@ -223,6 +250,11 @@ def main():
args = parser.parse_args()
params = get_params()
if args.method != "attention-decoder":
# to save memory as the attention decoder
# will not be used
params.num_decoder_layers = 0
params.update(vars(args))
params["env_info"] = get_env_info()
logging.info(f"{params}")
@ -246,27 +278,10 @@ def main():
)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
HLG.lm_scores = HLG.scores.clone()
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
logging.info(f"Loading G from {params.G}")
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = G.to(device)
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G.lm_scores = G.scores.clone()
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
@ -300,9 +315,63 @@ def main():
dtype=torch.int32,
)
if params.method == "ctc-decoding":
logging.info("Use CTC decoding")
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(params.bpe_model)
max_token_id = params.num_classes - 1
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
device=device,
)
lattice = get_lattice(
nnet_output=nnet_output,
HLG=HLG,
decoding_graph=H,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
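# Note: best_path.aux_labels contain BPE token IDs (we decoded with H, the
# CTC topology, not HLG). bpe_model.decode() turns each ID list into a
# sentence, which is split into a list of words below.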
token_ids = get_texts(best_path)
hyps = bpe_model.decode(token_ids)
hyps = [s.split() for s in hyps]
elif params.method in [
"1best",
"whole-lattice-rescoring",
"attention-decoder",
]:
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
HLG.lm_scores = HLG.scores.clone()
if params.method in [
"whole-lattice-rescoring",
"attention-decoder",
]:
logging.info(f"Loading G from {params.G}")
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = G.to(device)
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G.lm_scores = G.scores.clone()
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
@ -337,7 +406,7 @@ def main():
memory_key_padding_mask=memory_key_padding_mask,
sos_id=params.sos_id,
eos_id=params.eos_id,
lattice_score_scale=params.lattice_score_scale,
nbest_scale=params.nbest_scale,
ngram_lm_scale=params.ngram_lm_scale,
attention_scale=params.attention_decoder_scale,
)
@ -346,6 +415,8 @@ def main():
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
else:
raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):


@ -1,6 +1,7 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
# Wei Kang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -21,16 +22,16 @@ import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional
from typing import Optional, Tuple
import k2
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
@ -43,6 +44,7 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
get_env_info,
setup_logger,
@ -100,6 +102,26 @@ def get_parser():
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
return parser
@ -114,18 +136,6 @@ def get_params() -> AttributeDict:
Explanation of options saved in `params`:
- exp_dir: It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
- lr: It specifies the initial learning rate
- feature_dim: The model input dim. It has to match the one used
in computing features.
- weight_decay: The weight_decay for the optimizer.
- subsampling_factor: The subsampling factor for the model.
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
@ -176,10 +186,6 @@ def get_params() -> AttributeDict:
"""
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc/exp"),
"feature_dim": 80,
"weight_decay": 1e-6,
"subsampling_factor": 4,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
@ -304,7 +310,7 @@ def compute_loss(
batch: dict,
graph_compiler: BpeCtcTrainingGraphCompiler,
is_training: bool,
):
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
@ -384,15 +390,17 @@ def compute_loss(
loss = ctc_loss
att_loss = torch.tensor([0])
# train_frames and valid_frames are used for printing.
if is_training:
params.train_frames = supervision_segments[:, 2].sum().item()
else:
params.valid_frames = supervision_segments[:, 2].sum().item()
assert loss.requires_grad == is_training
return loss, ctc_loss.detach(), att_loss.detach()
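# MetricsTracker behaves like a dict of running sums: adding two trackers
# sums matching entries, so per-batch statistics can be accumulated, reduced
# across DDP ranks, and written to tensorboard.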
info = MetricsTracker()
info["frames"] = supervision_segments[:, 2].sum().item()
info["ctc_loss"] = ctc_loss.detach().cpu().item()
if params.att_rate != 0.0:
info["att_loss"] = att_loss.detach().cpu().item()
info["loss"] = loss.detach().cpu().item()
return loss, info
def compute_validation_loss(
@ -401,18 +409,14 @@ def compute_validation_loss(
graph_compiler: BpeCtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> None:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = 0.0
tot_ctc_loss = 0.0
tot_att_loss = 0.0
tot_frames = 0.0
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, ctc_loss, att_loss = compute_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
@ -420,36 +424,17 @@ def compute_validation_loss(
is_training=False,
)
assert loss.requires_grad is False
assert ctc_loss.requires_grad is False
assert att_loss.requires_grad is False
loss_cpu = loss.detach().cpu().item()
tot_loss += loss_cpu
tot_ctc_loss += ctc_loss.detach().cpu().item()
tot_att_loss += att_loss.detach().cpu().item()
tot_frames += params.valid_frames
tot_loss = tot_loss + loss_info
if world_size > 1:
s = torch.tensor(
[tot_loss, tot_ctc_loss, tot_att_loss, tot_frames],
device=loss.device,
)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
s = s.cpu().tolist()
tot_loss = s[0]
tot_ctc_loss = s[1]
tot_att_loss = s[2]
tot_frames = s[3]
tot_loss.reduce(loss.device)
params.valid_loss = tot_loss / tot_frames
params.valid_ctc_loss = tot_ctc_loss / tot_frames
params.valid_att_loss = tot_att_loss / tot_frames
if params.valid_loss < params.best_valid_loss:
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = params.valid_loss
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
@ -488,24 +473,21 @@ def train_one_epoch(
"""
model.train()
tot_loss = 0.0 # sum of losses over all batches
tot_ctc_loss = 0.0
tot_att_loss = 0.0
tot_loss = MetricsTracker()
tot_frames = 0.0 # sum of frames over all batches
params.tot_loss = 0.0
params.tot_frames = 0.0
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, ctc_loss, att_loss = compute_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
# summary stats: an exponentially decaying sum that forgets old batches with factor (1 - 1/reset_interval)
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
@ -515,75 +497,26 @@ def train_one_epoch(
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
loss_cpu = loss.detach().cpu().item()
ctc_loss_cpu = ctc_loss.detach().cpu().item()
att_loss_cpu = att_loss.detach().cpu().item()
tot_frames += params.train_frames
tot_loss += loss_cpu
tot_ctc_loss += ctc_loss_cpu
tot_att_loss += att_loss_cpu
params.tot_frames += params.train_frames
params.tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
tot_avg_ctc_loss = tot_ctc_loss / tot_frames
tot_avg_att_loss = tot_att_loss / tot_frames
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
f"total avg att loss: {tot_avg_att_loss:.4f}, "
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
tb_writer.add_scalar(
"train/current_ctc_loss",
ctc_loss_cpu / params.train_frames,
params.batch_idx_train,
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tb_writer.add_scalar(
"train/current_att_loss",
att_loss_cpu / params.train_frames,
params.batch_idx_train,
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tb_writer.add_scalar(
"train/current_loss",
loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_ctc_loss",
tot_avg_ctc_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_att_loss",
tot_avg_att_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
)
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
tot_loss = 0.0 # sum of losses over all batches
tot_ctc_loss = 0.0
tot_att_loss = 0.0
tot_frames = 0.0 # sum of frames over all batches
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss(
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
@ -591,33 +524,14 @@ def train_one_epoch(
world_size=world_size,
)
model.train()
logging.info(
f"Epoch {params.cur_epoch}, "
f"valid ctc loss {params.valid_ctc_loss:.4f},"
f"valid att loss {params.valid_att_loss:.4f},"
f"valid loss {params.valid_loss:.4f},"
f" best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
)
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
if tb_writer is not None:
tb_writer.add_scalar(
"train/valid_ctc_loss",
params.valid_ctc_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/valid_att_loss",
params.valid_att_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/valid_loss",
params.valid_loss,
params.batch_idx_train,
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
params.train_loss = params.tot_loss / params.tot_frames
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
@ -743,6 +657,8 @@ def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
world_size = args.world_size
assert world_size >= 1


@ -236,6 +236,7 @@ class Transformer(nn.Module):
x = nn.functional.log_softmax(x, dim=-1) # (N, T, C)
return x
@torch.jit.export
def decoder_forward(
self,
memory: torch.Tensor,
@ -264,11 +265,15 @@ class Transformer(nn.Module):
"""
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device
ys_in_pad = ys_in_pad.to(device)
@ -301,6 +306,7 @@ class Transformer(nn.Module):
return decoder_loss
@torch.jit.export
def decoder_nll(
self,
memory: torch.Tensor,
@ -331,11 +337,15 @@ class Transformer(nn.Module):
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
@ -649,7 +659,8 @@ class PositionalEncoding(nn.Module):
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = nn.Dropout(p=dropout)
self.pe = None
# not doing: self.pe = None because of errors thrown by torchscript
self.pe = torch.zeros(0, 0, dtype=torch.float32)
def extend_pe(self, x: torch.Tensor) -> None:
"""Extend the time t in the positional encoding if required.
@ -666,7 +677,6 @@ class PositionalEncoding(nn.Module):
"""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
@ -972,10 +982,7 @@ def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
Return a new list-of-list, where each sublist starts
with SOS ID.
"""
ans = []
for utt in token_ids:
ans.append([sos_id] + utt)
return ans
return [[sos_id] + utt for utt in token_ids]
def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
@ -992,7 +999,4 @@ def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
Return a new list-of-list, where each sublist ends
with EOS ID.
"""
ans = []
for utt in token_ids:
ans.append(utt + [eos_id])
return ans
return [utt + [eos_id] for utt in token_ids]


@ -43,6 +43,7 @@ vocab_sizes=(
5000
2000
1000
500
)
# All files generated by this script are saved in "data".
@ -58,13 +59,13 @@ log() {
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "stage -1: Download LM"
log "Stage -1: Download LM"
[ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm
./local/download_lm.py --out-dir=$dl_dir/lm
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: Download data"
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/LibriSpeech,
# you can create a symlink
@ -127,7 +128,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "State 6: Prepare BPE based lang"
log "Stage 6: Prepare BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}


@ -269,7 +269,7 @@ class LibriSpeechAsrDataModule(DataModule):
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = SingleCutSampler(
valid_sampler = BucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
@ -302,12 +302,15 @@ class LibriSpeechAsrDataModule(DataModule):
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = SingleCutSampler(
cuts_test, max_duration=self.args.max_duration
sampler = BucketingSampler(
cuts_test, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test, batch_size=None, sampler=sampler, num_workers=1
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
test_loaders.append(test_dl)


@ -98,7 +98,7 @@ def get_parser():
)
parser.add_argument(
"--lattice-score-scale",
"--nbest-scale",
type=float,
default=0.5,
help="""The scale to be applied to `lattice.scores`.
@ -148,7 +148,7 @@ def decode_one_batch(
batch: dict,
lexicon: Lexicon,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[int]]]:
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -212,7 +212,7 @@ def decode_one_batch(
lattice = get_lattice(
nnet_output=nnet_output,
HLG=HLG,
decoding_graph=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
@ -231,7 +231,7 @@ def decode_one_batch(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
lattice_score_scale=params.lattice_score_scale,
nbest_scale=params.nbest_scale,
)
key = f"no_rescore-{params.num_paths}"
hyps = get_texts(best_path)
@ -250,7 +250,7 @@ def decode_one_batch(
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
lattice_score_scale=params.lattice_score_scale,
nbest_scale=params.nbest_scale,
)
else:
best_path_dict = rescore_with_whole_lattice(
@ -274,7 +274,7 @@ def decode_dataset(
HLG: k2.Fsa,
lexicon: Lexicon,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[int], List[int]]]]:
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:


@ -233,7 +233,7 @@ def main():
lattice = get_lattice(
nnet_output=nnet_output,
HLG=HLG,
decoding_graph=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,


@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -20,17 +21,17 @@ import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional
from typing import Optional, Tuple
import k2
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from asr_datamodule import LibriSpeechAsrDataModule
from lhotse.utils import fix_random_seed
from model import TdnnLstm
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR
@ -43,6 +44,7 @@ from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
get_env_info,
setup_logger,
@ -269,7 +271,7 @@ def compute_loss(
batch: dict,
graph_compiler: CtcTrainingGraphCompiler,
is_training: bool,
):
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
@ -326,13 +328,11 @@ def compute_loss(
assert loss.requires_grad == is_training
# train_frames and valid_frames are used for printing.
if is_training:
params.train_frames = supervision_segments[:, 2].sum().item()
else:
params.valid_frames = supervision_segments[:, 2].sum().item()
info = MetricsTracker()
info["frames"] = supervision_segments[:, 2].sum().item()
info["loss"] = loss.detach().cpu().item()
return loss
return loss, info
def compute_validation_loss(
@ -341,16 +341,16 @@ def compute_validation_loss(
graph_compiler: CtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> None:
) -> MetricsTracker:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = 0.0
tot_frames = 0.0
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss = compute_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
@ -359,22 +359,18 @@ def compute_validation_loss(
)
assert loss.requires_grad is False
loss_cpu = loss.detach().cpu().item()
tot_loss += loss_cpu
tot_frames += params.valid_frames
tot_loss = tot_loss + loss_info
if world_size > 1:
s = torch.tensor([tot_loss, tot_frames], device=loss.device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
s = s.cpu().tolist()
tot_loss = s[0]
tot_frames = s[1]
tot_loss.reduce(loss.device)
params.valid_loss = tot_loss / tot_frames
loss_value = tot_loss["loss"] / tot_loss["frames"]
if params.valid_loss < params.best_valid_loss:
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = params.valid_loss
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
@ -413,67 +409,45 @@ def train_one_epoch(
"""
model.train()
tot_loss = 0.0 # reset after params.reset_interval of batches
tot_frames = 0.0 # reset after params.reset_interval of batches
params.tot_loss = 0.0
params.tot_frames = 0.0
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss = compute_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
# summary stats.
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
loss_cpu = loss.detach().cpu().item()
tot_frames += params.train_frames
tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
params.tot_frames += params.train_frames
params.tot_loss += loss_cpu
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
tb_writer.add_scalar(
"train/current_loss",
loss_cpu / params.train_frames,
params.batch_idx_train,
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
tot_loss = 0
tot_frames = 0
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss(
valid_info = compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
@ -481,13 +455,16 @@ def train_one_epoch(
world_size=world_size,
)
model.train()
logging.info(
f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
f" best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
if tb_writer is not None:
valid_info.write_summary(
tb_writer,
"train/valid_",
params.batch_idx_train,
)
params.train_loss = params.tot_loss / params.tot_frames
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
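
The accumulator tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info replaces the old hard reset of tot_loss/tot_frames every reset_interval batches with an exponentially decayed running sum. A toy sketch with plain floats standing in for MetricsTracker shows the statistic settles to roughly a reset_interval-batch window:

    reset_interval = 20          # plays the role of params.reset_interval
    tot, per_batch = 0.0, 1.0    # per_batch stands in for one batch's loss (or frames)
    for _ in range(200):
        tot = tot * (1 - 1 / reset_interval) + per_batch
    print(tot)                   # approaches 20.0, i.e. a decayed sum over ~20 batches

Because tot_loss["loss"] and tot_loss["frames"] decay at the same rate, the logged tot_loss is effectively an exponentially weighted per-frame loss rather than an average that resets to zero.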

View File

@ -24,7 +24,7 @@ log() {
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: Download data"
log "Stage 0: Download data"
mkdir -p $dl_dir
if [ ! -f $dl_dir/waves_yesno/.completed ]; then

View File

@ -20,19 +20,18 @@ from functools import lru_cache
from pathlib import Path
from typing import List
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
BucketingSampler,
CutConcatenate,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
class YesNoAsrDataModule(DataModule):
@ -198,7 +197,7 @@ class YesNoAsrDataModule(DataModule):
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
train_sampler = BucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
@ -226,12 +225,15 @@ class YesNoAsrDataModule(DataModule):
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = SingleCutSampler(
cuts_test, max_duration=self.args.max_duration
sampler = BucketingSampler(
cuts_test, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test, batch_size=None, sampler=sampler, num_workers=1
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl

View File

@ -125,7 +125,7 @@ def decode_one_batch(
lattice = get_lattice(
nnet_output=nnet_output,
HLG=HLG,
decoding_graph=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,

View File

@ -176,7 +176,7 @@ def main():
lattice = get_lattice(
nnet_output=nnet_output,
HLG=HLG,
decoding_graph=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,

View File

@ -4,17 +4,17 @@ import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional
from typing import Optional, Tuple
import k2
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from asr_datamodule import YesNoAsrDataModule
from lhotse.utils import fix_random_seed
from model import Tdnn
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
@ -24,7 +24,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, get_env_info, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
get_env_info,
setup_logger,
str2bool,
)
def get_parser():
@ -122,6 +128,8 @@ def get_params() -> AttributeDict:
- valid_interval: Run validation if batch_idx % valid_interval is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- beam_size: It is used in k2.ctc_loss
- reduction: It is used in k2.ctc_loss
@ -142,6 +150,7 @@ def get_params() -> AttributeDict:
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 10,
"reset_interval": 20,
"valid_interval": 10,
"beam_size": 10,
"reduction": "sum",
@ -245,7 +254,7 @@ def compute_loss(
batch: dict,
graph_compiler: CtcTrainingGraphCompiler,
is_training: bool,
):
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
@ -305,13 +314,11 @@ def compute_loss(
assert loss.requires_grad == is_training
# train_frames and valid_frames are used for printing.
if is_training:
params.train_frames = supervision_segments[:, 2].sum().item()
else:
params.valid_frames = supervision_segments[:, 2].sum().item()
info = MetricsTracker()
info["frames"] = supervision_segments[:, 2].sum().item()
info["loss"] = loss.detach().cpu().item()
return loss
return loss, info
def compute_validation_loss(
@ -320,16 +327,16 @@ def compute_validation_loss(
graph_compiler: CtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> None:
) -> MetricsTracker:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = 0.0
tot_frames = 0.0
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss = compute_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
@ -338,22 +345,18 @@ def compute_validation_loss(
)
assert loss.requires_grad is False
loss_cpu = loss.detach().cpu().item()
tot_loss += loss_cpu
tot_frames += params.valid_frames
tot_loss = tot_loss + loss_info
if world_size > 1:
s = torch.tensor([tot_loss, tot_frames], device=loss.device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
s = s.cpu().tolist()
tot_loss = s[0]
tot_frames = s[1]
tot_loss.reduce(loss.device)
params.valid_loss = tot_loss / tot_frames
loss_value = tot_loss["loss"] / tot_loss["frames"]
if params.valid_loss < params.best_valid_loss:
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = params.valid_loss
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
@ -392,57 +395,45 @@ def train_one_epoch(
"""
model.train()
tot_loss = 0.0 # sum of losses over all batches
tot_frames = 0.0 # sum of frames over all batches
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss = compute_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
# summary stats.
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
loss_cpu = loss.detach().cpu().item()
tot_frames += params.train_frames
tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
tb_writer.add_scalar(
"train/current_loss",
loss_cpu / params.train_frames,
params.batch_idx_train,
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss(
valid_info = compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
@ -450,19 +441,16 @@ def train_one_epoch(
world_size=world_size,
)
model.train()
logging.info(
f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
f" best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
)
logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
if tb_writer is not None:
tb_writer.add_scalar(
"train/valid_loss",
params.valid_loss,
valid_info.write_summary(
tb_writer,
"train/valid_",
params.batch_idx_train,
)
params.train_loss = tot_loss / tot_frames
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
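
In the validation loop above, the pair of scalars that used to be packed into a tensor for dist.all_reduce is now handled by tot_loss.reduce(loss.device). A single-process sketch of that call (gloo backend, world_size=1), shown only to illustrate its shape; in real training it runs inside each DDP worker:

    import os
    import torch
    import torch.distributed as dist
    from icefall.utils import MetricsTracker

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "12356")  # hypothetical free port
    dist.init_process_group("gloo", rank=0, world_size=1)

    tot_loss = MetricsTracker()
    tot_loss["frames"] = 500
    tot_loss["loss"] = 120.0
    tot_loss.reduce(torch.device("cpu"))          # all_reduce sums each key across processes
    print(tot_loss["loss"] / tot_loss["frames"])  # per-frame validation loss

    dist.destroy_process_group()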

View File

@ -66,7 +66,7 @@ def _intersect_device(
def get_lattice(
nnet_output: torch.Tensor,
HLG: k2.Fsa,
decoding_graph: k2.Fsa,
supervision_segments: torch.Tensor,
search_beam: float,
output_beam: float,
@ -79,8 +79,9 @@ def get_lattice(
Args:
nnet_output:
It is the output of a neural model of shape `(N, T, C)`.
HLG:
An Fsa, the decoding graph. See also `compile_HLG.py`.
decoding_graph:
An Fsa, the decoding graph. It can be either an HLG
(see `compile_HLG.py`) or an H (see `k2.ctc_topo`).
supervision_segments:
A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns.
Each row contains information for a supervision segment. Column 0
@ -117,7 +118,7 @@ def get_lattice(
)
lattice = k2.intersect_dense_pruned(
HLG,
decoding_graph,
dense_fsa_vec,
search_beam=search_beam,
output_beam=output_beam,
@ -180,7 +181,7 @@ class Nbest(object):
lattice: k2.Fsa,
num_paths: int,
use_double_scores: bool = True,
lattice_score_scale: float = 0.5,
nbest_scale: float = 0.5,
) -> "Nbest":
"""Construct an Nbest object by **sampling** `num_paths` from a lattice.
@ -206,7 +207,7 @@ class Nbest(object):
Return an Nbest instance.
"""
saved_scores = lattice.scores.clone()
lattice.scores *= lattice_score_scale
lattice.scores *= nbest_scale
# path is a ragged tensor with dtype torch.int32.
# It has three axes [utt][path][arc_pos]
path = k2.random_paths(
@ -446,7 +447,7 @@ def nbest_decoding(
lattice: k2.Fsa,
num_paths: int,
use_double_scores: bool = True,
lattice_score_scale: float = 1.0,
nbest_scale: float = 1.0,
) -> k2.Fsa:
"""It implements something like CTC prefix beam search using n-best lists.
@ -474,7 +475,7 @@ def nbest_decoding(
use_double_scores:
True to use double precision floating point in the computation.
False to use single precision.
lattice_score_scale:
nbest_scale:
It's the scale applied to the `lattice.scores`. A smaller value
leads to more unique paths at the risk of missing the correct path.
Returns:
@ -484,7 +485,7 @@ def nbest_decoding(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
lattice_score_scale=lattice_score_scale,
nbest_scale=nbest_scale,
)
# nbest.fsa.scores contains 0s
@ -505,7 +506,7 @@ def nbest_oracle(
ref_texts: List[str],
word_table: k2.SymbolTable,
use_double_scores: bool = True,
lattice_score_scale: float = 0.5,
nbest_scale: float = 0.5,
oov: str = "<UNK>",
) -> Dict[str, List[List[int]]]:
"""Select the best hypothesis given a lattice and a reference transcript.
@ -517,7 +518,7 @@ def nbest_oracle(
The decoding result returned from this function is the best result that
we can obtain using n-best decoding with all kinds of rescoring techniques.
This function is useful to tune the value of `lattice_score_scale`.
This function is useful to tune the value of `nbest_scale`.
Args:
lattice:
@ -533,7 +534,7 @@ def nbest_oracle(
use_double_scores:
True to use double precision for computation. False to use
single precision.
lattice_score_scale:
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
oov:
@ -549,7 +550,7 @@ def nbest_oracle(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
lattice_score_scale=lattice_score_scale,
nbest_scale=nbest_scale,
)
hyps = nbest.build_levenshtein_graphs()
@ -590,7 +591,7 @@ def rescore_with_n_best_list(
G: k2.Fsa,
num_paths: int,
lm_scale_list: List[float],
lattice_score_scale: float = 1.0,
nbest_scale: float = 1.0,
use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
"""Rescore an n-best list with an n-gram LM.
@ -607,7 +608,7 @@ def rescore_with_n_best_list(
Size of nbest list.
lm_scale_list:
A list of float representing LM score scales.
lattice_score_scale:
nbest_scale:
Scale to be applied to ``lattice.scores`` when sampling paths
using ``k2.random_paths``.
use_double_scores:
@ -631,7 +632,7 @@ def rescore_with_n_best_list(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
lattice_score_scale=lattice_score_scale,
nbest_scale=nbest_scale,
)
# nbest.fsa.scores are all 0s at this point
@ -769,7 +770,7 @@ def rescore_with_attention_decoder(
memory_key_padding_mask: Optional[torch.Tensor],
sos_id: int,
eos_id: int,
lattice_score_scale: float = 1.0,
nbest_scale: float = 1.0,
ngram_lm_scale: Optional[float] = None,
attention_scale: Optional[float] = None,
use_double_scores: bool = True,
@ -796,7 +797,7 @@ def rescore_with_attention_decoder(
The token ID for SOS.
eos_id:
The token ID for EOS.
lattice_score_scale:
nbest_scale:
It's the scale applied to `lattice.scores`. A smaller value
leads to more unique paths at the risk of missing the correct path.
ngram_lm_scale:
@ -812,7 +813,7 @@ def rescore_with_attention_decoder(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
lattice_score_scale=lattice_score_scale,
nbest_scale=nbest_scale,
)
# nbest.fsa.scores are all 0s at this point
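
The first argument of get_lattice is renamed from HLG to decoding_graph because, per the updated docstring, it may now be either an HLG or a plain H built with k2.ctc_topo. A minimal sketch of constructing such an H; the surrounding call is shown only schematically since nnet_output and supervision_segments come from a trained model:

    import k2

    num_classes = 500  # hypothetical token vocabulary size, blank assumed to be 0
    H = k2.ctc_topo(max_token=num_classes - 1, modified=False)

    # H then plays the role previously named HLG:
    # lattice = get_lattice(
    #     nnet_output=nnet_output,                    # (N, T, C) log-probs
    #     decoding_graph=H,
    #     supervision_segments=supervision_segments,  # (num_segments, 3), CPU int32
    #     search_beam=params.search_beam,
    #     output_beam=params.output_beam,
    # )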

View File

@ -1,4 +1,5 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../LICENSE for clarification regarding multiple authors
#
@ -16,6 +17,7 @@
import argparse
import collections
import logging
import os
import subprocess
@ -32,6 +34,7 @@ import kaldialign
import lhotse
import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
Pathlike = Union[str, Path]
@ -234,8 +237,8 @@ def encode_supervisions(
supervisions: dict, subsampling_factor: int
) -> Tuple[torch.Tensor, List[str]]:
"""
Encodes Lhotse's ``batch["supervisions"]`` dict into a pair of torch Tensor,
and a list of transcription strings.
Encodes Lhotse's ``batch["supervisions"]`` dict into
a pair of torch Tensor, and a list of transcription strings.
The supervision tensor has shape ``(batch_size, 3)``.
Its second dimension contains information about sequence index [0],
@ -407,13 +410,13 @@ def write_error_stats(
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
reference words (2337 correct)
- The difference between the reference transcript and predicted results.
- The difference between the reference transcript and predicted result.
An instance is given below::
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
The above example shows that the reference word is `EDISON`, but it is
predicted to `ADDISON` (a substitution error).
The above example shows that the reference word is `EDISON`,
but it is predicted to `ADDISON` (a substitution error).
Another example is::
@ -554,3 +557,76 @@ def write_error_stats(
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate)
class MetricsTracker(collections.defaultdict):
def __init__(self):
# Passing the type 'int' to the base-class constructor
# makes undefined items default to int() which is zero.
# This class serves as a metrics tracker: it can record many
# metrics, including but not limited to loss.
super(MetricsTracker, self).__init__(int)
def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
ans = MetricsTracker()
for k, v in self.items():
ans[k] = v
for k, v in other.items():
ans[k] = ans[k] + v
return ans
def __mul__(self, alpha: float) -> "MetricsTracker":
ans = MetricsTracker()
for k, v in self.items():
ans[k] = v * alpha
return ans
def __str__(self) -> str:
ans = ""
for k, v in self.norm_items():
norm_value = "%.4g" % v
ans += str(k) + "=" + str(norm_value) + ", "
frames = str(self["frames"])
ans += "over " + frames + " frames."
return ans
def norm_items(self) -> List[Tuple[str, float]]:
"""
Returns a list of pairs, like:
[('ctc_loss', 0.1), ('att_loss', 0.07)]
"""
num_frames = self["frames"] if "frames" in self else 1
ans = []
for k, v in self.items():
if k != "frames":
norm_value = float(v) / num_frames
ans.append((k, norm_value))
return ans
def reduce(self, device):
"""
Reduce using torch.distributed so that all processes
end up with the summed totals.
"""
keys = sorted(self.keys())
s = torch.tensor([float(self[k]) for k in keys], device=device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
for k, v in zip(keys, s.cpu().tolist()):
self[k] = v
def write_summary(
self,
tb_writer: SummaryWriter,
prefix: str,
batch_idx: int,
) -> None:
"""Add logging information to a TensorBoard writer.
Args:
tb_writer: a TensorBoard writer
prefix: a prefix for the name of the loss, e.g. "train/valid_",
or "train/current_"
batch_idx: The current batch index, used as the x-axis of the plot.
"""
for k, v in self.norm_items():
tb_writer.add_scalar(prefix + k, v, batch_idx)
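
A short usage sketch of the class above, assuming it is importable as icefall.utils.MetricsTracker (as the training scripts in this commit do):

    from icefall.utils import MetricsTracker

    a = MetricsTracker()
    a["frames"] = 1000
    a["loss"] = 250.0

    b = MetricsTracker()
    b["frames"] = 800
    b["loss"] = 180.0

    tot = a + b              # per-key sums: frames=1800, loss=430.0
    tot = tot * 0.5          # scale every metric, as the decayed running sum does
    print(tot.norm_items())  # [('loss', 0.2388...)] -- loss normalized per frame
    print(tot)               # "loss=0.2389, over 900.0 frames."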

31
setup.py Normal file
View File

@ -0,0 +1,31 @@
#!/usr/bin/env python3
from setuptools import find_packages, setup
from pathlib import Path
icefall_dir = Path(__file__).parent
install_requires = (icefall_dir / "requirements.txt").read_text().splitlines()
setup(
name="icefall",
version="1.0",
python_requires=">=3.6.0",
description="Speech processing recipes using k2 and Lhotse.",
author="The k2 and Lhotse Development Team",
license="Apache-2.0 License",
packages=find_packages(),
install_requires=install_requires,
classifiers=[
"Development Status :: 3 - Alpha",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Intended Audience :: Science/Research",
"Operating System :: POSIX :: Linux",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio :: Speech",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries :: Python Modules",
"Typing :: Typed",
],
)
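
With this setup.py in place, the repository can presumably be installed (for example in editable mode with pip install -e .) so that icefall is importable from outside egs/. A quick sanity check, assuming the install succeeded:

    # Hypothetical check run after installing the package.
    import icefall
    from icefall.utils import AttributeDict, MetricsTracker

    print(icefall.__file__)  # location of the installed package
    print(MetricsTracker())  # an empty tracker prints "over 0 frames."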

View File

@ -43,7 +43,7 @@ def test_nbest_from_lattice():
lattice=lattice,
num_paths=10,
use_double_scores=True,
lattice_score_scale=0.5,
nbest_scale=0.5,
)
# each lattice has only 4 distinct paths that have different word sequences:
# 10->30