Decode zipformer with external LMs (#1193)

* update some documentation

* support decoding with LMs in zipformer recipe

* update RESULTS.md
marcoyang1998 2023-08-03 15:50:35 +08:00 committed by GitHub
parent bcabaf896c
commit 1ee251c8b3
6 changed files with 238 additions and 56 deletions


@@ -4,59 +4,59 @@ LODR for RNN Transducer

=======================

As a type of E2E model, neural transducers are usually considered to have an internal
language model, which learns language-level information from the training corpus.
In real-life scenarios, there is often a mismatch between the training corpus and the target corpus.
This mismatch can be a problem when decoding with an external LM, as the
internal language model of the neural transducer can act "against" the external LM. In this tutorial, we show how to use
`Low-order Density Ratio <https://arxiv.org/abs/2203.16776>`_ to alleviate this effect and further improve the performance
of language model integration.
.. note::

    This tutorial is based on the recipe
    `pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
    which is a streaming transducer model trained on `LibriSpeech`_.
    However, you can easily apply LODR to other recipes.
    If you encounter any problems, please open an issue at `icefall <https://github.com/k2-fsa/icefall/issues>`__.

.. note::

    For simplicity, the training and testing corpus in this tutorial are the same (`LibriSpeech`_). However,
    you can change the testing set to any other domain (e.g. `GigaSpeech`_) and prepare the language models
    using that corpus.
First, let's have a look at some background information. As the predecessor of LODR, Density Ratio (DR) was first proposed `here <https://arxiv.org/abs/2002.11268>`_
to address the language information mismatch between the training
corpus (source domain) and the testing corpus (target domain). Assuming that the source domain and the test domain
are acoustically similar, DR derives the following formula for decoding with Bayes' theorem:

.. math::

    \text{score}\left(y_u|\mathit{x},y\right) =
    \log p\left(y_u|\mathit{x},y_{1:u-1}\right) +
    \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -
    \lambda_2 \log p_{\text{Source LM}}\left(y_u|\mathit{x},y_{1:u-1}\right)

where :math:`\lambda_1` and :math:`\lambda_2` are the weights of the LM scores for the target and source domains, respectively.
Here, the source-domain LM is trained on the training corpus. The only difference of the above formula compared to
shallow fusion is the subtraction of the source-domain LM.
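
To make the combination concrete, below is a minimal Python sketch of how the DR score of one
candidate token could be computed during beam search. The function name and the numbers are
hypothetical and merely restate the formula above; the search routines actually used by the
recipes live in ``beam_search.py``.

.. code-block:: python

    def dr_score(
        logp_transducer: float,  # log p(y_u | x, y_{1:u-1}) from the transducer
        logp_target_lm: float,  # log-prob of the token under the target-domain LM
        logp_source_lm: float,  # log-prob of the token under the source-domain LM
        lambda1: float = 0.5,
        lambda2: float = 0.3,
    ) -> float:
        """Combine the three scores according to the DR formula."""
        return logp_transducer + lambda1 * logp_target_lm - lambda2 * logp_source_lm

    # A token the transducer likes (-0.7); the target LM assigns -1.2 and
    # the source LM -0.9. The source-domain preference is discounted:
    print(dr_score(-0.7, -1.2, -0.9))  # -0.7 + 0.5 * (-1.2) - 0.3 * (-0.9) = -1.03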
Some works treat the predictor and the joiner of the neural transducer as its internal LM. However, this LM is
considered to be weak and can only capture low-level language information. Therefore, `LODR <https://arxiv.org/abs/2203.16776>`__ proposed to use
a low-order n-gram LM as an approximation of the ILM of the neural transducer. This leads to the following formula
during decoding for the transducer model:

.. math::

    \text{score}\left(y_u|\mathit{x},y\right) =
    \log p_{rnnt}\left(y_u|\mathit{x},y_{1:u-1}\right) +
    \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -
    \lambda_2 \log p_{\text{bi-gram}}\left(y_u|\mathit{x},y_{1:u-1}\right)

In LODR, an additional bi-gram LM estimated on the source domain (e.g. the training corpus) is required. Compared to DR,
the only difference lies in the choice of the source-domain LM. According to the original `paper <https://arxiv.org/abs/2203.16776>`_,
LODR achieves performance similar to DR in both intra-domain and cross-domain settings.
As a bi-gram is much faster to evaluate, LODR is usually much faster.
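
The bi-gram used by LODR is a token-level LM, so it can be estimated directly from the tokenized
training transcripts. The sketch below (with hypothetical helper names and illustrative scale
values) shows the idea of estimating such a bi-gram and plugging it into the LODR formula with a
negative weight; the recipe itself loads an FST-based token-level bi-gram (``2gram.fst.txt``)
through ``NgramLm`` instead.

.. code-block:: python

    import math
    from collections import Counter, defaultdict

    def train_bigram(token_seqs):
        """Estimate a token-level bi-gram LM with add-one smoothing."""
        unigram, bigram, vocab = Counter(), defaultdict(Counter), set()
        for seq in token_seqs:
            vocab.update(seq)
            for prev, cur in zip(seq[:-1], seq[1:]):
                unigram[prev] += 1
                bigram[prev][cur] += 1

        def logp(cur, prev):
            # add-one smoothed log p(cur | prev)
            return math.log((bigram[prev][cur] + 1) / (unigram[prev] + len(vocab)))

        return logp

    logp_bigram = train_bigram([["▁HE", "LLO", "▁WO", "RLD"]])
    # LODR score of one candidate token: the bi-gram log-prob is *subtracted*
    score = -0.7 + 0.42 * (-1.2) - 0.24 * logp_bigram("LLO", "▁HE")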
@@ -85,7 +85,7 @@ To test the model, let's have a look at the decoding results **without** using L

        --avg 1 \
        --use-averaged-model False \
        --exp-dir $exp_dir \
        --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
        --max-duration 600 \
        --decode-chunk-len 32 \
        --decoding-method modified_beam_search
@@ -99,17 +99,17 @@ The following WERs are achieved on test-clean and test-other:

    $ For test-other, WER of different settings are:
    $ beam_size_4    7.93    best for test-other

Then, we download the external language model and bi-gram LM that are necessary for LODR.
Note that the bi-gram is estimated on the LibriSpeech 960-hour text.

.. code-block:: bash

    $ # download the external LM
    $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
    $ # create a symbolic link so that the checkpoint can be loaded
    $ pushd icefall-librispeech-rnn-lm/exp
    $ git lfs pull --include "pretrained.pt"
    $ ln -s pretrained.pt epoch-99.pt
    $ popd
    $
    $ # download the bi-gram
@@ -122,7 +122,7 @@ Note that the bi-gram is estimated on the LibriSpeech 960-hour text.

Then, we perform LODR decoding by setting ``--decoding-method`` to ``modified_beam_search_LODR``:

.. code-block:: bash

    $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
    $ lm_dir=./icefall-librispeech-rnn-lm/exp
    $ lm_scale=0.42
@@ -135,8 +135,8 @@ Then, we perform LODR decoding by setting ``--decoding-method`` to ``modified_be

        --exp-dir $exp_dir \
        --max-duration 600 \
        --decode-chunk-len 32 \
        --decoding-method modified_beam_search_LODR \
        --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
        --use-shallow-fusion 1 \
        --lm-type rnn \
        --lm-exp-dir $lm_dir \
@@ -181,4 +181,4 @@ indeed **further improves** the WER. We can do even better if we increase ``--be

      - 6.38
    * - 12
      - 2.4
      - 6.23


@@ -48,7 +48,7 @@ As usual, we first test the model's performance without external LM. This can be

        --avg 1 \
        --use-averaged-model False \
        --exp-dir $exp_dir \
        --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
        --max-duration 600 \
        --decode-chunk-len 32 \
        --decoding-method modified_beam_search
@@ -101,7 +101,7 @@ is set to `False`.

        --max-duration 600 \
        --decode-chunk-len 32 \
        --decoding-method modified_beam_search_lm_rescore \
        --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
        --use-shallow-fusion 0 \
        --lm-type rnn \
        --lm-exp-dir $lm_dir \
@@ -173,7 +173,7 @@ Then we can perform LM rescoring + LODR by changing the decoding method to `mod

        --max-duration 600 \
        --decode-chunk-len 32 \
        --decoding-method modified_beam_search_lm_rescore_LODR \
        --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
        --use-shallow-fusion 0 \
        --lm-type rnn \
        --lm-exp-dir $lm_dir \


@@ -46,7 +46,7 @@ To test the model, let's have a look at the decoding results without using LM. T

        --avg 1 \
        --use-averaged-model False \
        --exp-dir $exp_dir \
        --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
        --max-duration 600 \
        --decode-chunk-len 32 \
        --decoding-method modified_beam_search
@@ -95,7 +95,7 @@ To use shallow fusion for decoding, we can execute the following command:

        --max-duration 600 \
        --decode-chunk-len 32 \
        --decoding-method modified_beam_search_lm_shallow_fusion \
        --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
        --use-shallow-fusion 1 \
        --lm-type rnn \
        --lm-exp-dir $lm_dir \


@@ -90,6 +90,11 @@ You can use <https://github.com/k2-fsa/sherpa> to deploy it.

| greedy_search | 2.23 | 4.96 | --epoch 40 --avg 16 |
| modified_beam_search | 2.21 | 4.91 | --epoch 40 --avg 16 |
| fast_beam_search | 2.24 | 4.93 | --epoch 40 --avg 16 |
| modified_beam_search_shallow_fusion | 2.01 | 4.37 | --epoch 40 --avg 16 --beam-size 12 --lm-scale 0.3 |
| modified_beam_search_LODR | 1.94 | 4.17 | --epoch 40 --avg 16 --beam-size 12 --lm-scale 0.52 --LODR-scale -0.26 |
| modified_beam_search_rescore | 2.04 | 4.39 | --epoch 40 --avg 16 --beam-size 12 |
| modified_beam_search_rescore_LODR | 2.01 | 4.33 | --epoch 40 --avg 16 --beam-size 12 |
The training command is:

```bash

@@ -119,6 +124,8 @@ for m in greedy_search modified_beam_search fast_beam_search; do

done
```
To decode with external language models, please refer to the documentation [here](https://k2-fsa.github.io/icefall/decoding-with-langugage-models/index.html).
##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M

The tensorboard log can be found at


@@ -396,6 +396,12 @@ def decode_one_batch(

        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
      LM:
        A neural network language model.
      ngram_lm:
        An n-gram language model.
      ngram_lm_scale:
        The scale for the n-gram language model.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
@@ -907,6 +913,7 @@ def main():

        ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
        logging.info(f"lm filename: {ngram_file_name}")
        ngram_lm = kenlm.Model(ngram_file_name)
        ngram_lm_scale = None  # use a list to search
    elif params.decoding_method == "modified_beam_search_LODR":
        lm_filename = f"{params.tokens_ngram}gram.fst.txt"


@@ -115,9 +115,14 @@ from beam_search import (

    greedy_search,
    greedy_search_batch,
    modified_beam_search,
    modified_beam_search_lm_rescore,
    modified_beam_search_lm_rescore_LODR,
    modified_beam_search_lm_shallow_fusion,
    modified_beam_search_LODR,
)
from train import add_model_arguments, get_model, get_params

from icefall import LmScorer, NgramLm
from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
@@ -273,8 +278,7 @@ def get_parser():

        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
    )

    parser.add_argument(
        "--max-sym-per-frame",
@@ -302,6 +306,47 @@ def get_parser():

        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--use-shallow-fusion",
        type=str2bool,
        default=False,
        help="""Use neural network LM for shallow fusion.
        If you want to use LODR, you will also need to set this to true
        """,
    )

    parser.add_argument(
        "--lm-type",
        type=str,
        default="rnn",
        help="Type of NN lm",
        choices=["rnn", "transformer"],
    )

    parser.add_argument(
        "--lm-scale",
        type=float,
        default=0.3,
        help="""The scale of the neural network LM
        Used only when `--use-shallow-fusion` is set to True.
        """,
    )

    parser.add_argument(
        "--tokens-ngram",
        type=int,
        default=2,
        help="""The order of the ngram lm.
        """,
    )

    parser.add_argument(
        "--backoff-id",
        type=int,
        default=500,
        help="ID of the backoff symbol in the ngram LM",
    )

    add_model_arguments(parser)

    return parser
@@ -314,6 +359,9 @@ def decode_one_batch(

    batch: dict,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
    LM: Optional[LmScorer] = None,
    ngram_lm=None,
    ngram_lm_scale: float = 0.0,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:
@@ -342,6 +390,12 @@ def decode_one_batch(

        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
      LM:
        A neural network language model.
      ngram_lm:
        An n-gram language model.
      ngram_lm_scale:
        The scale for the n-gram language model.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
@@ -425,10 +479,7 @@ def decode_one_batch(

        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
@@ -445,6 +496,50 @@ def decode_one_batch(

        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
        hyp_tokens = modified_beam_search_lm_shallow_fusion(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
            LM=LM,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "modified_beam_search_LODR":
        hyp_tokens = modified_beam_search_LODR(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
            LODR_lm=ngram_lm,
            LODR_lm_scale=ngram_lm_scale,
            LM=LM,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "modified_beam_search_lm_rescore":
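        # run the beam search once, then rescore the collected hypotheses with
        # the neural LM for scales 0.10-0.49 (step 0.01); one result set per scale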
        lm_scale_list = [0.01 * i for i in range(10, 50)]
        ans_dict = modified_beam_search_lm_rescore(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
            LM=LM,
            lm_scale_list=lm_scale_list,
        )
    elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
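        # as above, but the LODR bi-gram score is also subtracted during
        # rescoring; neural LM scales 0.04-0.58 (step 0.02) are searched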
        lm_scale_list = [0.02 * i for i in range(2, 30)]
        ans_dict = modified_beam_search_lm_rescore_LODR(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
            LM=LM,
            LODR_lm=ngram_lm,
            sp=sp,
            lm_scale_list=lm_scale_list,
        )
    else:
        batch_size = encoder_out.size(0)
@@ -483,6 +578,16 @@ def decode_one_batch(

            key += f"_ngram_lm_scale_{params.ngram_lm_scale}"

        return {key: hyps}
    elif params.decoding_method in (
        "modified_beam_search_lm_rescore",
        "modified_beam_search_lm_rescore_LODR",
    ):
        ans = dict()
        assert ans_dict is not None
        for key, hyps in ans_dict.items():
            hyps = [sp.decode(hyp).split() for hyp in hyps]
            ans[f"beam_size_{params.beam_size}_{key}"] = hyps
        return ans
    else:
        return {f"beam_size_{params.beam_size}": hyps}
@@ -494,6 +599,9 @@ def decode_dataset(

    sp: spm.SentencePieceProcessor,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
    LM: Optional[LmScorer] = None,
    ngram_lm=None,
    ngram_lm_scale: float = 0.0,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.
@@ -543,6 +651,9 @@ def decode_dataset(

            decoding_graph=decoding_graph,
            word_table=word_table,
            batch=batch,
            LM=LM,
            ngram_lm=ngram_lm,
            ngram_lm_scale=ngram_lm_scale,
        )

        for name, hyps in hyps_dict.items():
@@ -559,9 +670,7 @@ def decode_dataset(

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")

    return results
@@ -594,8 +703,7 @@ def save_results(

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = (
        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
@@ -614,6 +722,7 @@

def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    LmScorer.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
@@ -628,6 +737,10 @@ def main():

        "fast_beam_search_nbest_LG",
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
        "modified_beam_search_LODR",
        "modified_beam_search_lm_shallow_fusion",
        "modified_beam_search_lm_rescore",
        "modified_beam_search_lm_rescore_LODR",
    )
    params.res_dir = params.exp_dir / params.decoding_method
@@ -656,13 +769,19 @@ def main():

    if "LG" in params.decoding_method:
        params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

    if params.use_shallow_fusion:
        params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"

    if "LODR" in params.decoding_method:
        params.suffix += (
            f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
        )

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"
@@ -690,9 +809,9 @@ def main():

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
@@ -719,9 +838,9 @@ def main():

        model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
@@ -768,6 +887,54 @@ def main():

    model.to(device)
    model.eval()

    # only load the neural network LM if required
    if params.use_shallow_fusion or params.decoding_method in (
        "modified_beam_search_lm_rescore",
        "modified_beam_search_lm_rescore_LODR",
        "modified_beam_search_lm_shallow_fusion",
        "modified_beam_search_LODR",
    ):
        LM = LmScorer(
            lm_type=params.lm_type,
            params=params,
            device=device,
            lm_scale=params.lm_scale,
        )
        LM.to(device)
        LM.eval()
    else:
        LM = None

    # only load N-gram LM when needed
    if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
        try:
            import kenlm
        except ImportError:
            print("Please install kenlm first. You can use")
            print("  pip install https://github.com/kpu/kenlm/archive/master.zip")
            print("to install it")
            import sys

            sys.exit(-1)
        ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
        logging.info(f"lm filename: {ngram_file_name}")
        ngram_lm = kenlm.Model(ngram_file_name)
        ngram_lm_scale = None  # use a list to search
    elif params.decoding_method == "modified_beam_search_LODR":
        lm_filename = f"{params.tokens_ngram}gram.fst.txt"
        logging.info(f"Loading token level lm: {lm_filename}")
        ngram_lm = NgramLm(
            str(params.lang_dir / lm_filename),
            backoff_id=params.backoff_id,
            is_binary=False,
        )
        logging.info(f"num states: {ngram_lm.lm.num_states}")
        ngram_lm_scale = params.ngram_lm_scale
    else:
        ngram_lm = None
        ngram_lm_scale = None

    if "fast_beam_search" in params.decoding_method:
        if params.decoding_method == "fast_beam_search_nbest_LG":
            lexicon = Lexicon(params.lang_dir)
@@ -780,9 +947,7 @@ def main():

            decoding_graph.scores *= params.ngram_lm_scale
        else:
            word_table = None
            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
    else:
        decoding_graph = None
        word_table = None
@@ -811,6 +976,9 @@ def main():

            sp=sp,
            word_table=word_table,
            decoding_graph=decoding_graph,
            LM=LM,
            ngram_lm=ngram_lm,
            ngram_lm_scale=ngram_lm_scale,
        )

        save_results(