Shallow fusion & LODR documentation (#1142)

* add shallow fusion documentation

* add documentation for LODR

* upload docs for LM rescoring
This commit is contained in:
marcoyang1998 2023-07-06 19:11:01 +08:00 committed by GitHub
parent 6fd674312c
commit 11523c5b89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 652 additions and 22 deletions

View File

@ -86,6 +86,7 @@ rst_epilog = """
.. _git-lfs: https://git-lfs.com/ .. _git-lfs: https://git-lfs.com/
.. _ncnn: https://github.com/tencent/ncnn .. _ncnn: https://github.com/tencent/ncnn
.. _LibriSpeech: https://www.openslr.org/12 .. _LibriSpeech: https://www.openslr.org/12
.. _Gigaspeech: https://github.com/SpeechColab/GigaSpeech
.. _musan: http://www.openslr.org/17/ .. _musan: http://www.openslr.org/17/
.. _ONNX: https://github.com/onnx/onnx .. _ONNX: https://github.com/onnx/onnx
.. _onnxruntime: https://github.com/microsoft/onnxruntime .. _onnxruntime: https://github.com/microsoft/onnxruntime

View File

@ -0,0 +1,184 @@
.. _LODR:
LODR for RNN Transducer
=======================
As a type of E2E model, neural transducers are usually considered as having an internal
language model, which learns the language level information on the training corpus.
In real-life scenario, there is often a mismatch between the training corpus and the target corpus space.
This mismatch can be a problem when decoding for neural transducer models with language models as its internal
language can act "against" the external LM. In this tutorial, we show how to use
`Low-order Density Ratio <https://arxiv.org/abs/2203.16776>`_ to alleviate this effect to further improve the performance
of langugae model integration.
.. note::
This tutorial is based on the recipe
`pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
which is a streaming transducer model trained on `LibriSpeech`_.
However, you can easily apply LODR to other recipes.
If you encounter any problems, please open an issue here `icefall <https://github.com/k2-fsa/icefall/issues>`__.
.. note::
For simplicity, the training and testing corpus in this tutorial are the same (`LibriSpeech`_). However,
you can change the testing set to any other domains (e.g `GigaSpeech`_) and prepare the language models
using that corpus.
First, let's have a look at some background information. As the predecessor of LODR, Density Ratio (DR) is first proposed `here <https://arxiv.org/abs/2002.11268>`_
to address the language information mismatch between the training
corpus (source domain) and the testing corpus (target domain). Assuming that the source domain and the test domain
are acoustically similar, DR derives the following formular for decoding with Bayes' theorem:
.. math::
\text{score}\left(y_u|\mathit{x},y\right) =
\log p\left(y_u|\mathit{x},y_{1:u-1}\right) +
\lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -
\lambda_2 \log p_{\text{Source LM}}\left(y_u|\mathit{x},y_{1:u-1}\right)
where :math:`\lambda_1` and :math:`\lambda_2` are the weights of LM scores for target domain and source domain respectively.
Here, the source domain LM is trained on the training corpus. The only difference in the above formular compared to
shallow fusion is the subtraction of the source domain LM.
Some works treat the predictor and the joiner of the neural transducer as its internal LM. However, the LM is
considered to be weak and can only capture low-level language information. Therefore, `LODR <https://arxiv.org/abs/2203.16776>`__ proposed to use
a low-order n-gram LM as an approximation of the ILM of the neural transducer. This leads to the following formula
during decoding for transducer model:
.. math::
\text{score}\left(y_u|\mathit{x},y\right) =
\log p_{rnnt}\left(y_u|\mathit{x},y_{1:u-1}\right) +
\lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -
\lambda_2 \log p_{\text{bi-gram}}\left(y_u|\mathit{x},y_{1:u-1}\right)
In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Comared to DR,
the only difference lies in the choice of source domain LM. According to the original `paper <https://arxiv.org/abs/2203.16776>`_,
LODR achieves similar performance compared DR in both intra-domain and cross-domain settings.
As a bi-gram is much faster to evaluate, LODR is usually much faster.
Now, we will show you how to use LODR in ``icefall``.
For illustration purpose, we will use a pre-trained ASR model from this `link <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`_.
If you want to train your model from scratch, please have a look at :ref:`non_streaming_librispeech_pruned_transducer_stateless`.
The testing scenario here is intra-domain (we decode the model trained on `LibriSpeech`_ on `LibriSpeech`_ testing sets).
As the initial step, let's download the pre-trained model.
.. code-block:: bash
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
To test the model, let's have a look at the decoding results **without** using LM. This can be done via the following command:
.. code-block:: bash
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/
$ ./pruned_transducer_stateless7_streaming/decode.py \
--epoch 99 \
--avg 1 \
--use-averaged-model False \
--exp-dir $exp_dir \
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method modified_beam_search
The following WERs are achieved on test-clean and test-other:
.. code-block:: text
$ For test-clean, WER of different settings are:
$ beam_size_4 3.11 best for test-clean
$ For test-other, WER of different settings are:
$ beam_size_4 7.93 best for test-other
Then, we download the external language model and bi-gram LM that are necessary for LODR.
Note that the bi-gram is estimated on the LibriSpeech 960 hours' text.
.. code-block:: bash
$ # download the external LM
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
$ # create a symbolic link so that the checkpoint can be loaded
$ pushd icefall-librispeech-rnn-lm/exp
$ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt
$ popd
$
$ # download the bi-gram
$ git lfs install
$ git clone https://huggingface.co/marcoyang/librispeech_bigram
$ pushd data/lang_bpe_500
$ ln -s ../../librispeech_bigram/2gram.fst.txt .
$ popd
Then, we perform LODR decoding by setting ``--decoding-method`` to ``modified_beam_search_lm_LODR``:
.. code-block:: bash
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ lm_dir=./icefall-librispeech-rnn-lm/exp
$ lm_scale=0.42
$ LODR_scale=-0.24
$ ./pruned_transducer_stateless7_streaming/decode.py \
--epoch 99 \
--avg 1 \
--use-averaged-model False \
--beam-size 4 \
--exp-dir $exp_dir \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method modified_beam_search_lm_LODR \
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
--use-shallow-fusion 1 \
--lm-type rnn \
--lm-exp-dir $lm_dir \
--lm-epoch 99 \
--lm-scale $lm_scale \
--lm-avg 1 \
--rnn-lm-embedding-dim 2048 \
--rnn-lm-hidden-dim 2048 \
--rnn-lm-num-layers 3 \
--lm-vocab-size 500 \
--tokens-ngram 2 \
--ngram-lm-scale $LODR_scale
There are two extra arguments that need to be given when doing LODR. ``--tokens-ngram`` specifies the order of n-gram. As we
are using a bi-gram, we set it to 2. ``--ngram-lm-scale`` is the scale of the bi-gram, it should be a negative number
as we are subtracting the bi-gram's score during decoding.
The decoding results obtained with the above command are shown below:
.. code-block:: text
$ For test-clean, WER of different settings are:
$ beam_size_4 2.61 best for test-clean
$ For test-other, WER of different settings are:
$ beam_size_4 6.74 best for test-other
Recall that the lowest WER we obtained in :ref:`shallow_fusion` with beam size of 4 is ``2.77/7.08``, LODR
indeed **further improves** the WER. We can do even better if we increase ``--beam-size``:
.. list-table:: WER of LODR with different beam sizes
:widths: 25 25 50
:header-rows: 1
* - Beam size
- test-clean
- test-other
* - 4
- 2.61
- 6.74
* - 8
- 2.45
- 6.38
* - 12
- 2.4
- 6.23

View File

@ -0,0 +1,12 @@
Decoding with language models
=============================
This section describes how to use external langugage models
during decoding to improve the WER of transducer models.
.. toctree::
:maxdepth: 2
shallow-fusion
LODR
rescoring

View File

@ -0,0 +1,252 @@
.. _rescoring:
LM rescoring for Transducer
=================================
LM rescoring is a commonly used approach to incorporate external LM information. Unlike shallow-fusion-based
methods (see :ref:`shallow-fusion`, :ref:`LODR`), rescoring is usually performed to re-rank the n-best hypotheses after beam search.
Rescoring is usually more efficient than shallow fusion since less computation is performed on the external LM.
In this tutorial, we will show you how to use external LM to rescore the n-best hypotheses decoded from neural transducer models in
`icefall <https://github.com/k2-fsa/icefall>`__.
.. note::
This tutorial is based on the recipe
`pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
which is a streaming transducer model trained on `LibriSpeech`_.
However, you can easily apply shallow fusion to other recipes.
If you encounter any problems, please open an issue `here <https://github.com/k2-fsa/icefall/issues>`_.
.. note::
For simplicity, the training and testing corpus in this tutorial is the same (`LibriSpeech`_). However, you can change the testing set
to any other domains (e.g `GigaSpeech`_) and use an external LM trained on that domain.
.. HINT::
We recommend you to use a GPU for decoding.
For illustration purpose, we will use a pre-trained ASR model from this `link <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`__.
If you want to train your model from scratch, please have a look at :ref:`non_streaming_librispeech_pruned_transducer_stateless`.
As the initial step, let's download the pre-trained model.
.. code-block:: bash
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
As usual, we first test the model's performance without external LM. This can be done via the following command:
.. code-block:: bash
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/
$ ./pruned_transducer_stateless7_streaming/decode.py \
--epoch 99 \
--avg 1 \
--use-averaged-model False \
--exp-dir $exp_dir \
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method modified_beam_search
The following WERs are achieved on test-clean and test-other:
.. code-block:: text
$ For test-clean, WER of different settings are:
$ beam_size_4 3.11 best for test-clean
$ For test-other, WER of different settings are:
$ beam_size_4 7.93 best for test-other
Now, we will try to improve the above WER numbers via external LM rescoring. We will download
a pre-trained LM from this `link <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm>`__.
.. note::
This is an RNN LM trained on the LibriSpeech text corpus. So it might not be ideal for other corpus.
You may also train a RNN LM from scratch. Please refer to this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py>`__
for training a RNN LM and this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/transformer_lm/train.py>`__ to train a transformer LM.
.. code-block:: bash
$ # download the external LM
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
$ # create a symbolic link so that the checkpoint can be loaded
$ pushd icefall-librispeech-rnn-lm/exp
$ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt
$ popd
With the RNNLM available, we can rescore the n-best hypotheses generated from `modified_beam_search`. Here,
`n` should be the number of beams, i.e ``--beam-size``. The command for LM rescoring is
as follows. Note that the ``--decoding-method`` is set to `modified_beam_search_lm_rescore` and ``--use-shallow-fusion``
is set to `False`.
.. code-block:: bash
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ lm_dir=./icefall-librispeech-rnn-lm/exp
$ lm_scale=0.43
$ ./pruned_transducer_stateless7_streaming/decode.py \
--epoch 99 \
--avg 1 \
--use-averaged-model False \
--beam-size 4 \
--exp-dir $exp_dir \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method modified_beam_search_lm_rescore \
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
--use-shallow-fusion 0 \
--lm-type rnn \
--lm-exp-dir $lm_dir \
--lm-epoch 99 \
--lm-scale $lm_scale \
--lm-avg 1 \
--rnn-lm-embedding-dim 2048 \
--rnn-lm-hidden-dim 2048 \
--rnn-lm-num-layers 3 \
--lm-vocab-size 500
.. code-block:: text
$ For test-clean, WER of different settings are:
$ beam_size_4 2.93 best for test-clean
$ For test-other, WER of different settings are:
$ beam_size_4 7.6 best for test-other
Great! We made some improvements! Increasing the size of the n-best hypotheses will further boost the performance,
see the following table:
.. list-table:: WERs of LM rescoring with different beam sizes
:widths: 25 25 25
:header-rows: 1
* - Beam size
- test-clean
- test-other
* - 4
- 2.93
- 7.6
* - 8
- 2.67
- 7.11
* - 12
- 2.59
- 6.86
In fact, we can also apply LODR (see :ref:`LODR`) when doing LM rescoring. To do so, we need to
download the bi-gram required by LODR:
.. code-block:: bash
$ # download the bi-gram
$ git lfs install
$ git clone https://huggingface.co/marcoyang/librispeech_bigram
$ pushd data/lang_bpe_500
$ ln -s ../../librispeech_bigram/2gram.arpa .
$ popd
Then we can performn LM rescoring + LODR by changing the decoding method to `modified_beam_search_lm_rescore_LODR`.
.. note::
This decoding method requires the dependency of `kenlm <https://github.com/kpu/kenlm>`_. You can install it
via this command: `pip install https://github.com/kpu/kenlm/archive/master.zip`.
.. code-block:: bash
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ lm_dir=./icefall-librispeech-rnn-lm/exp
$ lm_scale=0.43
$ ./pruned_transducer_stateless7_streaming/decode.py \
--epoch 99 \
--avg 1 \
--use-averaged-model False \
--beam-size 4 \
--exp-dir $exp_dir \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method modified_beam_search_lm_rescore_LODR \
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
--use-shallow-fusion 0 \
--lm-type rnn \
--lm-exp-dir $lm_dir \
--lm-epoch 99 \
--lm-scale $lm_scale \
--lm-avg 1 \
--rnn-lm-embedding-dim 2048 \
--rnn-lm-hidden-dim 2048 \
--rnn-lm-num-layers 3 \
--lm-vocab-size 500
You should see the following WERs after executing the commands above:
.. code-block:: text
$ For test-clean, WER of different settings are:
$ beam_size_4 2.9 best for test-clean
$ For test-other, WER of different settings are:
$ beam_size_4 7.57 best for test-other
It's slightly better than LM rescoring. If we further increase the beam size, we will see
further improvements from LM rescoring + LODR:
.. list-table:: WERs of LM rescoring + LODR with different beam sizes
:widths: 25 25 25
:header-rows: 1
* - Beam size
- test-clean
- test-other
* - 4
- 2.9
- 7.57
* - 8
- 2.63
- 7.04
* - 12
- 2.52
- 6.73
As mentioned earlier, LM rescoring is usually faster than shallow-fusion based methods.
Here, we benchmark the WERs and decoding speed of them:
.. list-table:: LM-rescoring-based methods vs shallow-fusion-based methods (The numbers in each field is WER on test-clean, WER on test-other and decoding time on test-clean)
:widths: 25 25 25 25
:header-rows: 1
* - Decoding method
- beam=4
- beam=8
- beam=12
* - `modified_beam_search`
- 3.11/7.93; 132s
- 3.1/7.95; 177s
- 3.1/7.96; 210s
* - `modified_beam_search_lm_shallow_fusion`
- 2.77/7.08; 262s
- 2.62/6.65; 352s
- 2.58/6.65; 488s
* - LODR
- 2.61/6.74; 400s
- 2.45/6.38; 610s
- 2.4/6.23; 870s
* - `modified_beam_search_lm_rescore`
- 2.93/7.6; 156s
- 2.67/7.11; 203s
- 2.59/6.86; 255s
* - `modified_beam_search_lm_rescore_LODR`
- 2.9/7.57; 160s
- 2.63/7.04; 203s
- 2.52/6.73; 263s
.. note::
Decoding is performed with a single 32G V100, we set ``--max-duration`` to 600.
Decoding time here is only for reference and it may vary.

View File

@ -0,0 +1,176 @@
.. _shallow_fusion:
Shallow fusion for Transducer
=================================
External language models (LM) are commonly used to improve WERs for E2E ASR models.
This tutorial shows you how to perform ``shallow fusion`` with an external LM
to improve the word-error-rate of a transducer model.
.. note::
This tutorial is based on the recipe
`pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
which is a streaming transducer model trained on `LibriSpeech`_.
However, you can easily apply shallow fusion to other recipes.
If you encounter any problems, please open an issue here `icefall <https://github.com/k2-fsa/icefall/issues>`_.
.. note::
For simplicity, the training and testing corpus in this tutorial is the same (`LibriSpeech`_). However, you can change the testing set
to any other domains (e.g `GigaSpeech`_) and use an external LM trained on that domain.
.. HINT::
We recommend you to use a GPU for decoding.
For illustration purpose, we will use a pre-trained ASR model from this `link <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`__.
If you want to train your model from scratch, please have a look at :ref:`non_streaming_librispeech_pruned_transducer_stateless`.
As the initial step, let's download the pre-trained model.
.. code-block:: bash
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
To test the model, let's have a look at the decoding results without using LM. This can be done via the following command:
.. code-block:: bash
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/
$ ./pruned_transducer_stateless7_streaming/decode.py \
--epoch 99 \
--avg 1 \
--use-averaged-model False \
--exp-dir $exp_dir \
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method modified_beam_search
The following WERs are achieved on test-clean and test-other:
.. code-block:: text
$ For test-clean, WER of different settings are:
$ beam_size_4 3.11 best for test-clean
$ For test-other, WER of different settings are:
$ beam_size_4 7.93 best for test-other
These are already good numbers! But we can further improve it by using shallow fusion with external LM.
Training a language model usually takes a long time, we can download a pre-trained LM from this `link <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm>`__.
.. code-block:: bash
$ # download the external LM
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
$ # create a symbolic link so that the checkpoint can be loaded
$ pushd icefall-librispeech-rnn-lm/exp
$ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt
$ popd
.. note::
This is an RNN LM trained on the LibriSpeech text corpus. So it might not be ideal for other corpus.
You may also train a RNN LM from scratch. Please refer to this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py>`__
for training a RNN LM and this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/transformer_lm/train.py>`__ to train a transformer LM.
To use shallow fusion for decoding, we can execute the following command:
.. code-block:: bash
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ lm_dir=./icefall-librispeech-rnn-lm/exp
$ lm_scale=0.29
$ ./pruned_transducer_stateless7_streaming/decode.py \
--epoch 99 \
--avg 1 \
--use-averaged-model False \
--beam-size 4 \
--exp-dir $exp_dir \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method modified_beam_search_lm_shallow_fusion \
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
--use-shallow-fusion 1 \
--lm-type rnn \
--lm-exp-dir $lm_dir \
--lm-epoch 99 \
--lm-scale $lm_scale \
--lm-avg 1 \
--rnn-lm-embedding-dim 2048 \
--rnn-lm-hidden-dim 2048 \
--rnn-lm-num-layers 3 \
--lm-vocab-size 500
Note that we set ``--decoding-method modified_beam_search_lm_shallow_fusion`` and ``--use-shallow-fusion True``
to use shallow fusion. ``--lm-type`` specifies the type of neural LM we are going to use, you can either choose
between ``rnn`` or ``transformer``. The following three arguments are associated with the rnn:
- ``--rnn-lm-embedding-dim``
The embedding dimension of the RNN LM
- ``--rnn-lm-hidden-dim``
The hidden dimension of the RNN LM
- ``--rnn-lm-num-layers``
The number of RNN layers in the RNN LM.
The decoding result obtained with the above command are shown below.
.. code-block:: text
$ For test-clean, WER of different settings are:
$ beam_size_4 2.77 best for test-clean
$ For test-other, WER of different settings are:
$ beam_size_4 7.08 best for test-other
The improvement of shallow fusion is very obvious! The relative WER reduction on test-other is around 10.5%.
A few parameters can be tuned to further boost the performance of shallow fusion:
- ``--lm-scale``
Controls the scale of the LM. If too small, the external language model may not be fully utilized; if too large,
the LM score may dominant during decoding, leading to bad WER. A typical value of this is around 0.3.
- ``--beam-size``
The number of active paths in the search beam. It controls the trade-off between decoding efficiency and accuracy.
Here, we also show how `--beam-size` effect the WER and decoding time:
.. list-table:: WERs and decoding time (on test-clean) of shallow fusion with different beam sizes
:widths: 25 25 25 25
:header-rows: 1
* - Beam size
- test-clean
- test-other
- Decoding time on test-clean (s)
* - 4
- 2.77
- 7.08
- 262
* - 8
- 2.62
- 6.65
- 352
* - 12
- 2.58
- 6.65
- 488
As we see, a larger beam size during shallow fusion improves the WER, but is also slower.

View File

@ -34,3 +34,8 @@ speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.
contributing/index contributing/index
huggingface/index huggingface/index
.. toctree::
:maxdepth: 2
decoding-with-langugage-models/index

View File

@ -1,7 +1,7 @@
Distillation with HuBERT Distillation with HuBERT
======================== ========================
This tutorial shows you how to perform knowledge distillation in `icefall`_ This tutorial shows you how to perform knowledge distillation in `icefall <https://github.com/k2-fsa/icefall>`_
with the `LibriSpeech`_ dataset. The distillation method with the `LibriSpeech`_ dataset. The distillation method
used here is called "Multi Vector Quantization Knowledge Distillation" (MVQ-KD). used here is called "Multi Vector Quantization Knowledge Distillation" (MVQ-KD).
Please have a look at our paper `Predicting Multi-Codebook Vector Quantization Indexes for Knowledge Distillation <https://arxiv.org/abs/2211.00508>`_ Please have a look at our paper `Predicting Multi-Codebook Vector Quantization Indexes for Knowledge Distillation <https://arxiv.org/abs/2211.00508>`_
@ -13,7 +13,7 @@ for more details about MVQ-KD.
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`_. `pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`_.
Currently, we only implement MVQ-KD in this recipe. However, MVQ-KD is theoretically applicable to all recipes Currently, we only implement MVQ-KD in this recipe. However, MVQ-KD is theoretically applicable to all recipes
with only minor changes needed. Feel free to try out MVQ-KD in different recipes. If you with only minor changes needed. Feel free to try out MVQ-KD in different recipes. If you
encounter any problems, please open an issue here `icefall <https://github.com/k2-fsa/icefall/issues>`_. encounter any problems, please open an issue here `icefall <https://github.com/k2-fsa/icefall/issues>`__.
.. note:: .. note::
@ -217,7 +217,7 @@ the following command.
--exp-dir $exp_dir \ --exp-dir $exp_dir \
--enable-distillation True --enable-distillation True
You should get similar results as `here <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS-100hours.md#distillation-with-hubert>`_. You should get similar results as `here <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS-100hours.md#distillation-with-hubert>`__.
That's all! Feel free to experiment with your own setups and report your results. That's all! Feel free to experiment with your own setups and report your results.
If you encounter any problems during training, please open up an issue `here <https://github.com/k2-fsa/icefall/issues>`_. If you encounter any problems during training, please open up an issue `here <https://github.com/k2-fsa/icefall/issues>`__.

View File

@ -8,10 +8,10 @@ with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.
.. Note:: .. Note::
The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`_, The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`__,
`pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`_, `pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`__,
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`_, `pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`__,
`pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`_, `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`__,
We will take pruned_transducer_stateless4 as an example in this tutorial. We will take pruned_transducer_stateless4 as an example in this tutorial.
.. HINT:: .. HINT::
@ -237,7 +237,7 @@ them, please modify ``./pruned_transducer_stateless4/train.py`` directly.
.. NOTE:: .. NOTE::
The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`_ are a little different from The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`__ are a little different from
other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from commandline, so that you can train models with different size with pruned_transducer_stateless5. other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from commandline, so that you can train models with different size with pruned_transducer_stateless5.
@ -529,13 +529,13 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models If you don't want to train from scratch, you can download the pretrained models
by visiting the following links: by visiting the following links:
- `pruned_transducer_stateless <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12>`_ - `pruned_transducer_stateless <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12>`__
- `pruned_transducer_stateless2 <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29>`_ - `pruned_transducer_stateless2 <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29>`__
- `pruned_transducer_stateless4 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless4-2022-06-03>`_ - `pruned_transducer_stateless4 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless4-2022-06-03>`__
- `pruned_transducer_stateless5 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07>`_ - `pruned_transducer_stateless5 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07>`__
See `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_ See `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_
for the details of the above pretrained models for the details of the above pretrained models

View File

@ -45,9 +45,9 @@ the input features.
We have three variants of Emformer models in ``icefall``. We have three variants of Emformer models in ``icefall``.
- ``pruned_stateless_emformer_rnnt2`` using Emformer from torchaudio, see `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2>`_. - ``pruned_stateless_emformer_rnnt2`` using Emformer from torchaudio, see `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2>`__.
- ``conv_emformer_transducer_stateless`` using ConvEmformer implemented by ourself. Different from the Emformer in torchaudio, - ``conv_emformer_transducer_stateless`` using ConvEmformer implemented by ourself. Different from the Emformer in torchaudio,
ConvEmformer has a convolution in each layer and uses the mechanisms in our reworked conformer model. ConvEmformer has a convolution in each layer and uses the mechanisms in our reworked conformer model.
See `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conv_emformer_transducer_stateless>`_. See `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conv_emformer_transducer_stateless>`__.
- ``conv_emformer_transducer_stateless2`` using ConvEmformer implemented by ourself. The only difference from the above one is that - ``conv_emformer_transducer_stateless2`` using ConvEmformer implemented by ourself. The only difference from the above one is that
it uses a simplified memory bank. See `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conv_emformer_transducer_stateless2>`_. it uses a simplified memory bank. See `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conv_emformer_transducer_stateless2>`_.

View File

@ -6,10 +6,10 @@ with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.
.. Note:: .. Note::
The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`_, The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`__,
`pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`_, `pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`__,
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`_, `pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`__,
`pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`_, `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`__,
We will take pruned_transducer_stateless4 as an example in this tutorial. We will take pruned_transducer_stateless4 as an example in this tutorial.
.. HINT:: .. HINT::
@ -264,7 +264,7 @@ them, please modify ``./pruned_transducer_stateless4/train.py`` directly.
.. NOTE:: .. NOTE::
The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`_ are a little different from The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`__ are a little different from
other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from commandline, so that you can train models with different size with pruned_transducer_stateless5. other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from commandline, so that you can train models with different size with pruned_transducer_stateless5.

View File

@ -6,7 +6,7 @@ with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.
.. Note:: .. Note::
The tutorial is suitable for `pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_, The tutorial is suitable for `pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`__,
.. HINT:: .. HINT::
@ -642,7 +642,7 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models If you don't want to train from scratch, you can download the pretrained models
by visiting the following links: by visiting the following links:
- `pruned_transducer_stateless7_streaming <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`_ - `pruned_transducer_stateless7_streaming <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`__
See `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_ See `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_
for the details of the above pretrained models for the details of the above pretrained models