mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
Add low-order density ratio in RNNLM shallow fusion (#678)
* Support LODR in RNNLM shallow fusion * fix style * fix code style * update workflow and CI * update results * propagate changes to stateless3 * add decoding results for stateless3+giga * fix CI
This commit is contained in:
parent
1d5c03f85a
commit
4b5bc480e8
@ -16,6 +16,7 @@ log "Downloading pre-trained model from $repo_url"
|
|||||||
git lfs install
|
git lfs install
|
||||||
git clone $repo_url
|
git clone $repo_url
|
||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
abs_repo=$(realpath $repo)
|
||||||
|
|
||||||
log "Display test files"
|
log "Display test files"
|
||||||
tree $repo/
|
tree $repo/
|
||||||
@ -178,21 +179,27 @@ echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
|||||||
if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then
|
if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then
|
||||||
lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
|
lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
|
||||||
log "Download pre-trained RNN-LM model from ${lm_repo_url}"
|
log "Download pre-trained RNN-LM model from ${lm_repo_url}"
|
||||||
git clone $lm_repo_url
|
GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
|
||||||
lm_repo=$(basename $lm_repo_url)
|
lm_repo=$(basename $lm_repo_url)
|
||||||
pushd $lm_repo
|
pushd $lm_repo
|
||||||
git lfs pull --include "exp/pretrained.pt"
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
cd exp
|
mv exp/pretrained.pt exp/epoch-88.pt
|
||||||
ln -s pretrained.pt epoch-88.pt
|
|
||||||
popd
|
popd
|
||||||
|
|
||||||
|
mkdir -p lstm_transducer_stateless2/exp
|
||||||
|
ln -sf $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
|
||||||
|
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
||||||
|
|
||||||
|
ls -lh data
|
||||||
|
ls -lh lstm_transducer_stateless2/exp
|
||||||
|
|
||||||
|
log "Decoding test-clean and test-other"
|
||||||
|
|
||||||
./lstm_transducer_stateless2/decode.py \
|
./lstm_transducer_stateless2/decode.py \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 999 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir lstm_transducer_stateless2/exp \
|
||||||
--lang-dir $repo/data/lang_bpe_500 \
|
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method modified_beam_search_rnnlm_shallow_fusion \
|
--decoding-method modified_beam_search_rnnlm_shallow_fusion \
|
||||||
--beam 4 \
|
--beam 4 \
|
||||||
@ -204,6 +211,52 @@ if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then
|
|||||||
--rnn-lm-tie-weights 1
|
--rnn-lm-tie-weights 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"LODR" ]]; then
|
||||||
|
bigram_repo_url=https://huggingface.co/marcoyang/librispeech_bigram
|
||||||
|
log "Download bi-gram LM from ${bigram_repo_url}"
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url
|
||||||
|
bigramlm_repo=$(basename $bigram_repo_url)
|
||||||
|
pushd $bigramlm_repo
|
||||||
|
git lfs pull --include "2gram.fst.txt"
|
||||||
|
cp 2gram.fst.txt $abs_repo/data/lang_bpe_500/.
|
||||||
|
popd
|
||||||
|
|
||||||
|
lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
|
||||||
|
log "Download pre-trained RNN-LM model from ${lm_repo_url}"
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
|
||||||
|
lm_repo=$(basename $lm_repo_url)
|
||||||
|
pushd $lm_repo
|
||||||
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
mv exp/pretrained.pt exp/epoch-88.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
mkdir -p lstm_transducer_stateless2/exp
|
||||||
|
ln -sf $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
|
||||||
|
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
||||||
|
|
||||||
|
ls -lh data
|
||||||
|
ls -lh lstm_transducer_stateless2/exp
|
||||||
|
|
||||||
|
log "Decoding test-clean and test-other"
|
||||||
|
|
||||||
|
./lstm_transducer_stateless2/decode.py \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 999 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir lstm_transducer_stateless2/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method modified_beam_search_rnnlm_LODR \
|
||||||
|
--beam 4 \
|
||||||
|
--rnn-lm-scale 0.3 \
|
||||||
|
--rnn-lm-exp-dir $lm_repo/exp \
|
||||||
|
--rnn-lm-epoch 88 \
|
||||||
|
--rnn-lm-avg 1 \
|
||||||
|
--rnn-lm-num-layers 3 \
|
||||||
|
--rnn-lm-tie-weights 1 \
|
||||||
|
--tokens-ngram 2 \
|
||||||
|
--ngram-lm-scale -0.16
|
||||||
|
fi
|
||||||
|
|
||||||
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
|
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
|
||||||
mkdir -p lstm_transducer_stateless2/exp
|
mkdir -p lstm_transducer_stateless2/exp
|
||||||
ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
|
ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
|
||||||
|
@ -18,7 +18,7 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run_librispeech_lstm_transducer_stateless2_2022_09_03:
|
run_librispeech_lstm_transducer_stateless2_2022_09_03:
|
||||||
if: github.event.label.name == 'ready' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule'
|
if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
@ -139,9 +139,20 @@ jobs:
|
|||||||
find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
|
- name: Display decoding results for lstm_transducer_stateless2
|
||||||
|
if: github.event.label.name == 'LODR'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
tree lstm_transducer_stateless2/exp
|
||||||
|
cd lstm_transducer_stateless2/exp
|
||||||
|
echo "===modified_beam_search_rnnlm_LODR==="
|
||||||
|
find modified_beam_search_rnnlm_LODR -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
|
find modified_beam_search_rnnlm_LODR -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
- name: Upload decoding results for lstm_transducer_stateless2
|
- name: Upload decoding results for lstm_transducer_stateless2
|
||||||
uses: actions/upload-artifact@v2
|
uses: actions/upload-artifact@v2
|
||||||
if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion'
|
if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'LODR'
|
||||||
with:
|
with:
|
||||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03
|
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03
|
||||||
path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/
|
path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/
|
||||||
|
@ -318,6 +318,7 @@ The WERs are:
|
|||||||
| greedy search (max sym per frame 1) | 2.78 | 7.36 | --iter 468000 --avg 16 |
|
| greedy search (max sym per frame 1) | 2.78 | 7.36 | --iter 468000 --avg 16 |
|
||||||
| modified_beam_search | 2.73 | 7.15 | --iter 468000 --avg 16 |
|
| modified_beam_search | 2.73 | 7.15 | --iter 468000 --avg 16 |
|
||||||
| modified_beam_search + RNNLM shallow fusion | 2.42 | 6.46 | --iter 468000 --avg 16 |
|
| modified_beam_search + RNNLM shallow fusion | 2.42 | 6.46 | --iter 468000 --avg 16 |
|
||||||
|
| modified_beam_search + RNNLM shallow fusion | 2.28 | 5.94 | --iter 468000 --avg 16 |
|
||||||
| fast_beam_search | 2.76 | 7.31 | --iter 468000 --avg 16 |
|
| fast_beam_search | 2.76 | 7.31 | --iter 468000 --avg 16 |
|
||||||
| greedy search (max sym per frame 1) | 2.77 | 7.35 | --iter 472000 --avg 18 |
|
| greedy search (max sym per frame 1) | 2.77 | 7.35 | --iter 472000 --avg 18 |
|
||||||
| modified_beam_search | 2.75 | 7.08 | --iter 472000 --avg 18 |
|
| modified_beam_search | 2.75 | 7.08 | --iter 472000 --avg 18 |
|
||||||
@ -393,6 +394,32 @@ for iter in 472000; do
|
|||||||
done
|
done
|
||||||
done
|
done
|
||||||
|
|
||||||
|
You may also decode using LODR + RNNLM shallow fusion. This decoding method is proposed in <https://arxiv.org/pdf/2203.16776.pdf>.
|
||||||
|
It subtracts the internal language model score during shallow fusion, which is approximated by a bi-gram model. The bi-gram can be
|
||||||
|
generated by `generate-lm.sh`, or you may download it from <https://huggingface.co/marcoyang/librispeech_bigram>.
|
||||||
|
|
||||||
|
The decoding command is as follows:
|
||||||
|
|
||||||
|
for iter in 472000; do
|
||||||
|
for avg in 8 10 12 14 16 18; do
|
||||||
|
./lstm_transducer_stateless2/decode.py \
|
||||||
|
--iter $iter \
|
||||||
|
--avg $avg \
|
||||||
|
--exp-dir ./lstm_transducer_stateless2/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method modified_beam_search_rnnlm_LODR \
|
||||||
|
--beam 4 \
|
||||||
|
--rnn-lm-scale 0.4 \
|
||||||
|
--rnn-lm-exp-dir /path/to/RNNLM \
|
||||||
|
--rnn-lm-epoch 99 \
|
||||||
|
--rnn-lm-avg 1 \
|
||||||
|
--rnn-lm-num-layers 3 \
|
||||||
|
--rnn-lm-tie-weights 1 \
|
||||||
|
--token-ngram 2 \
|
||||||
|
--ngram-lm-scale -0.16
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
Pretrained models, training logs, decoding logs, and decoding results
|
Pretrained models, training logs, decoding logs, and decoding results
|
||||||
are available at
|
are available at
|
||||||
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03>
|
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03>
|
||||||
@ -1912,6 +1939,8 @@ subset so that the gigaspeech dataloader never exhausts.
|
|||||||
|-------------------------------------|------------|------------|---------------------------------------------|
|
|-------------------------------------|------------|------------|---------------------------------------------|
|
||||||
| greedy search (max sym per frame 1) | 2.03 | 4.70 | --iter 1224000 --avg 14 --max-duration 600 |
|
| greedy search (max sym per frame 1) | 2.03 | 4.70 | --iter 1224000 --avg 14 --max-duration 600 |
|
||||||
| modified beam search | 2.00 | 4.63 | --iter 1224000 --avg 14 --max-duration 600 |
|
| modified beam search | 2.00 | 4.63 | --iter 1224000 --avg 14 --max-duration 600 |
|
||||||
|
| modified beam search + rnnlm shallow fusion | 1.94 | 4.2 | --iter 1224000 --avg 14 --max-duration 600 |
|
||||||
|
| modified beam search + LODR | 1.83 | 4.03 | --iter 1224000 --avg 14 --max-duration 600 |
|
||||||
| fast beam search | 2.10 | 4.68 | --iter 1224000 --avg 14 --max-duration 600 |
|
| fast beam search | 2.10 | 4.68 | --iter 1224000 --avg 14 --max-duration 600 |
|
||||||
|
|
||||||
The training commands are:
|
The training commands are:
|
||||||
@ -1957,6 +1986,64 @@ for iter in 1224000; do
|
|||||||
done
|
done
|
||||||
done
|
done
|
||||||
```
|
```
|
||||||
|
You may also decode using shallow fusion with external RNNLM. To do so you need to
|
||||||
|
download a well-trained RNNLM from this link <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rnn_lm_scale=0.3
|
||||||
|
|
||||||
|
for iter in 1224000; do
|
||||||
|
for avg in 14; do
|
||||||
|
for method in modified_beam_search_rnnlm_shallow_fusion ; do
|
||||||
|
./pruned_transducer_stateless3/decode.py \
|
||||||
|
--iter $iter \
|
||||||
|
--avg $avg \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp-0.9/ \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method $method \
|
||||||
|
--max-sym-per-frame 1 \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 32 \
|
||||||
|
--rnn-lm-scale $rnn_lm_scale \
|
||||||
|
--rnn-lm-exp-dir /path/to/RNNLM \
|
||||||
|
--rnn-lm-epoch 99 \
|
||||||
|
--rnn-lm-avg 1 \
|
||||||
|
--rnn-lm-num-layers 3 \
|
||||||
|
--rnn-lm-tie-weights 1
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
If you want to try out with LODR decoding, use the following command. This assums you have a bi-gram LM trained on LibriSpeech text. You can also download the bi-gram LM from here <https://huggingface.co/marcoyang/librispeech_bigram/tree/main> and put it under the directory `data/lang_bpe_500`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rnn_lm_scale=0.4
|
||||||
|
|
||||||
|
for iter in 1224000; do
|
||||||
|
for avg in 14; do
|
||||||
|
for method in modified_beam_search_rnnlm_LODR ; do
|
||||||
|
./pruned_transducer_stateless3/decode.py \
|
||||||
|
--iter $iter \
|
||||||
|
--avg $avg \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp-0.9/ \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method $method \
|
||||||
|
--max-sym-per-frame 1 \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 32 \
|
||||||
|
--rnn-lm-scale $rnn_lm_scale \
|
||||||
|
--rnn-lm-exp-dir /path/to/RNNLM \
|
||||||
|
--rnn-lm-epoch 99 \
|
||||||
|
--rnn-lm-avg 1 \
|
||||||
|
--rnn-lm-num-layers 3 \
|
||||||
|
--rnn-lm-tie-weights 1 \
|
||||||
|
--tokens-ngram 2 \
|
||||||
|
--ngram-lm-scale -0.14
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
The pretrained models, training logs, decoding logs, and decoding results
|
The pretrained models, training logs, decoding logs, and decoding results
|
||||||
can be found at
|
can be found at
|
||||||
|
@ -107,8 +107,25 @@ Usage:
|
|||||||
--rnn-lm-avg 1 \
|
--rnn-lm-avg 1 \
|
||||||
--rnn-lm-num-layers 3 \
|
--rnn-lm-num-layers 3 \
|
||||||
--rnn-lm-tie-weights 1
|
--rnn-lm-tie-weights 1
|
||||||
"""
|
|
||||||
|
|
||||||
|
(9) modified beam search with RNNLM shallow fusion + LODR
|
||||||
|
./lstm_transducer_stateless2/decode.py \
|
||||||
|
--epoch 35 \
|
||||||
|
--avg 15 \
|
||||||
|
--max-duration 600 \
|
||||||
|
--exp-dir ./lstm_transducer_stateless2/exp \
|
||||||
|
--decoding-method modified_beam_search_rnnlm_LODR \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--rnn-lm-scale 0.4 \
|
||||||
|
--rnn-lm-exp-dir /path/to/RNNLM/exp \
|
||||||
|
--rnn-lm-epoch 99 \
|
||||||
|
--rnn-lm-avg 1 \
|
||||||
|
--rnn-lm-num-layers 3 \
|
||||||
|
--rnn-lm-tie-weights 1 \
|
||||||
|
--tokens-ngram 2 \
|
||||||
|
--ngram-lm-scale -0.16 \
|
||||||
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
@ -132,6 +149,7 @@ from beam_search import (
|
|||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
modified_beam_search_ngram_rescoring,
|
modified_beam_search_ngram_rescoring,
|
||||||
|
modified_beam_search_rnnlm_LODR,
|
||||||
modified_beam_search_rnnlm_shallow_fusion,
|
modified_beam_search_rnnlm_shallow_fusion,
|
||||||
)
|
)
|
||||||
from librispeech import LibriSpeech
|
from librispeech import LibriSpeech
|
||||||
@ -235,7 +253,8 @@ def get_parser():
|
|||||||
- fast_beam_search_nbest_oracle
|
- fast_beam_search_nbest_oracle
|
||||||
- fast_beam_search_nbest_LG
|
- fast_beam_search_nbest_LG
|
||||||
- modified_beam_search_ngram_rescoring
|
- modified_beam_search_ngram_rescoring
|
||||||
- modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
|
- modified_beam_search_rnnlm_shallow_fusion
|
||||||
|
- modified_beam_search_rnnlm_LODR
|
||||||
If you use fast_beam_search_nbest_LG, you have to specify
|
If you use fast_beam_search_nbest_LG, you have to specify
|
||||||
`--lang-dir`, which should contain `LG.pt`.
|
`--lang-dir`, which should contain `LG.pt`.
|
||||||
""",
|
""",
|
||||||
@ -394,7 +413,8 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=3,
|
default=3,
|
||||||
help="""Token Ngram used for rescoring.
|
help="""Token Ngram used for rescoring.
|
||||||
Used only when the decoding method is modified_beam_search_ngram_rescoring""",
|
Used only when the decoding method is
|
||||||
|
modified_beam_search_ngram_rescoring""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -402,7 +422,8 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=500,
|
default=500,
|
||||||
help="""ID of the backoff symbol.
|
help="""ID of the backoff symbol.
|
||||||
Used only when the decoding method is modified_beam_search_ngram_rescoring""",
|
Used only when the decoding method is
|
||||||
|
modified_beam_search_ngram_rescoring""",
|
||||||
)
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
@ -572,6 +593,20 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "modified_beam_search_rnnlm_LODR":
|
||||||
|
hyp_tokens = modified_beam_search_rnnlm_LODR(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
sp=sp,
|
||||||
|
LODR_lm=ngram_lm,
|
||||||
|
LODR_lm_scale=ngram_lm_scale,
|
||||||
|
rnnlm=rnnlm,
|
||||||
|
rnnlm_scale=rnnlm_scale,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
else:
|
else:
|
||||||
batch_size = encoder_out.size(0)
|
batch_size = encoder_out.size(0)
|
||||||
|
|
||||||
@ -760,6 +795,7 @@ def main():
|
|||||||
"fast_beam_search_nbest_LG",
|
"fast_beam_search_nbest_LG",
|
||||||
"fast_beam_search_nbest_oracle",
|
"fast_beam_search_nbest_oracle",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
|
"modified_beam_search_rnnlm_LODR",
|
||||||
"modified_beam_search_ngram_rescoring",
|
"modified_beam_search_ngram_rescoring",
|
||||||
"modified_beam_search_rnnlm_shallow_fusion",
|
"modified_beam_search_rnnlm_shallow_fusion",
|
||||||
)
|
)
|
||||||
@ -788,6 +824,9 @@ def main():
|
|||||||
if "rnnlm" in params.decoding_method:
|
if "rnnlm" in params.decoding_method:
|
||||||
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
|
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
|
||||||
|
|
||||||
|
if "LODR" in params.decoding_method:
|
||||||
|
params.suffix += "-LODR"
|
||||||
|
|
||||||
if params.use_averaged_model:
|
if params.use_averaged_model:
|
||||||
params.suffix += "-use-averaged-model"
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
@ -901,7 +940,7 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
# only load N-gram LM when needed
|
# only load N-gram LM when needed
|
||||||
if "ngram" in params.decoding_method:
|
if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
|
||||||
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
|
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
|
||||||
logging.info(f"lm filename: {lm_filename}")
|
logging.info(f"lm filename: {lm_filename}")
|
||||||
ngram_lm = NgramLm(
|
ngram_lm = NgramLm(
|
||||||
@ -910,6 +949,7 @@ def main():
|
|||||||
is_binary=False,
|
is_binary=False,
|
||||||
)
|
)
|
||||||
logging.info(f"num states: {ngram_lm.lm.num_states}")
|
logging.info(f"num states: {ngram_lm.lm.num_states}")
|
||||||
|
ngram_lm_scale = params.ngram_lm_scale
|
||||||
else:
|
else:
|
||||||
ngram_lm = None
|
ngram_lm = None
|
||||||
ngram_lm_scale = None
|
ngram_lm_scale = None
|
||||||
@ -933,7 +973,6 @@ def main():
|
|||||||
)
|
)
|
||||||
rnn_lm_model.to(device)
|
rnn_lm_model.to(device)
|
||||||
rnn_lm_model.eval()
|
rnn_lm_model.eval()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
rnn_lm_model = None
|
rnn_lm_model = None
|
||||||
rnn_lm_scale = 0.0
|
rnn_lm_scale = 0.0
|
||||||
|
@ -2083,3 +2083,267 @@ def modified_beam_search_rnnlm_shallow_fusion(
|
|||||||
tokens=ans,
|
tokens=ans,
|
||||||
timestamps=ans_timestamps,
|
timestamps=ans_timestamps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def modified_beam_search_rnnlm_LODR(
|
||||||
|
model: Transducer,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
sp: spm.SentencePieceProcessor,
|
||||||
|
LODR_lm: NgramLm,
|
||||||
|
LODR_lm_scale: float,
|
||||||
|
rnnlm: RnnLmModel,
|
||||||
|
rnnlm_scale: float,
|
||||||
|
beam: int = 4,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""This function implements LODR (https://arxiv.org/abs/2203.16776) with
|
||||||
|
`modified_beam_search`. It uses a bi-gram language model as the estimate
|
||||||
|
of the internal language model and subtracts its score during shallow fusion
|
||||||
|
with an external language model. This implementation uses a RNNLM as the
|
||||||
|
external language model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (Transducer):
|
||||||
|
The transducer model
|
||||||
|
encoder_out (torch.Tensor):
|
||||||
|
Encoder output in (N,T,C)
|
||||||
|
encoder_out_lens (torch.Tensor):
|
||||||
|
A 1-D tensor of shape (N,), containing the number of
|
||||||
|
valid frames in encoder_out before padding.
|
||||||
|
sp:
|
||||||
|
Sentence piece generator.
|
||||||
|
LODR_lm:
|
||||||
|
A low order n-gram LM
|
||||||
|
LODR_lm_scale:
|
||||||
|
The scale of the LODR_lm
|
||||||
|
rnnlm (RnnLmModel):
|
||||||
|
RNNLM, the external language model
|
||||||
|
rnnlm_scale (float):
|
||||||
|
scale of RNNLM in shallow fusion
|
||||||
|
beam (int, optional):
|
||||||
|
Beam size. Defaults to 4.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a list-of-list of token IDs. ans[i] is the decoding results
|
||||||
|
for the i-th utterance.
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert encoder_out.ndim == 3, encoder_out.shape
|
||||||
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
|
assert rnnlm is not None
|
||||||
|
lm_scale = rnnlm_scale
|
||||||
|
vocab_size = rnnlm.vocab_size
|
||||||
|
|
||||||
|
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||||
|
input=encoder_out,
|
||||||
|
lengths=encoder_out_lens.cpu(),
|
||||||
|
batch_first=True,
|
||||||
|
enforce_sorted=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
blank_id = model.decoder.blank_id
|
||||||
|
sos_id = sp.piece_to_id("<sos/eos>")
|
||||||
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
|
||||||
|
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||||
|
N = encoder_out.size(0)
|
||||||
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
|
|
||||||
|
# get initial lm score and lm state by scoring the "sos" token
|
||||||
|
sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
|
||||||
|
init_score, init_states = rnnlm.score_token(sos_token)
|
||||||
|
|
||||||
|
B = [HypothesisList() for _ in range(N)]
|
||||||
|
for i in range(N):
|
||||||
|
B[i].add(
|
||||||
|
Hypothesis(
|
||||||
|
ys=[blank_id] * context_size,
|
||||||
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
|
state=init_states, # state of the RNNLM
|
||||||
|
lm_score=init_score.reshape(-1),
|
||||||
|
state_cost=NgramLmStateCost(
|
||||||
|
LODR_lm
|
||||||
|
), # state of the source domain ngram
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
rnnlm.clean_cache()
|
||||||
|
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
finalized_B = []
|
||||||
|
for batch_size in batch_size_list:
|
||||||
|
start = offset
|
||||||
|
end = offset + batch_size
|
||||||
|
current_encoder_out = encoder_out.data[start:end] # get batch
|
||||||
|
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||||
|
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
||||||
|
offset = end
|
||||||
|
|
||||||
|
finalized_B = B[batch_size:] + finalized_B
|
||||||
|
B = B[:batch_size]
|
||||||
|
|
||||||
|
hyps_shape = get_hyps_shape(B).to(device)
|
||||||
|
|
||||||
|
A = [list(b) for b in B]
|
||||||
|
B = [HypothesisList() for _ in range(batch_size)]
|
||||||
|
|
||||||
|
ys_log_probs = torch.cat(
|
||||||
|
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
) # (num_hyps, context_size)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
||||||
|
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||||
|
|
||||||
|
current_encoder_out = torch.index_select(
|
||||||
|
current_encoder_out,
|
||||||
|
dim=0,
|
||||||
|
index=hyps_shape.row_ids(1).to(torch.int64),
|
||||||
|
) # (num_hyps, 1, 1, encoder_out_dim)
|
||||||
|
|
||||||
|
logits = model.joiner(
|
||||||
|
current_encoder_out,
|
||||||
|
decoder_out,
|
||||||
|
project_input=False,
|
||||||
|
) # (num_hyps, 1, 1, vocab_size)
|
||||||
|
|
||||||
|
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
|
||||||
|
|
||||||
|
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
|
||||||
|
|
||||||
|
log_probs.add_(ys_log_probs)
|
||||||
|
|
||||||
|
vocab_size = log_probs.size(-1)
|
||||||
|
|
||||||
|
log_probs = log_probs.reshape(-1)
|
||||||
|
|
||||||
|
row_splits = hyps_shape.row_splits(1) * vocab_size
|
||||||
|
log_probs_shape = k2.ragged.create_ragged_shape2(
|
||||||
|
row_splits=row_splits, cached_tot_size=log_probs.numel()
|
||||||
|
)
|
||||||
|
ragged_log_probs = k2.RaggedTensor(
|
||||||
|
shape=log_probs_shape, value=log_probs
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
for all hyps with a non-blank new token, score this token.
|
||||||
|
It is a little confusing here because this for-loop
|
||||||
|
looks very similar to the one below. Here, we go through all
|
||||||
|
top-k tokens and only add the non-blanks ones to the token_list.
|
||||||
|
The RNNLM will score those tokens given the LM states. Note that
|
||||||
|
the variable `scores` is the LM score after seeing the new
|
||||||
|
non-blank token.
|
||||||
|
"""
|
||||||
|
token_list = []
|
||||||
|
hs = []
|
||||||
|
cs = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||||
|
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||||
|
for k in range(len(topk_hyp_indexes)):
|
||||||
|
hyp_idx = topk_hyp_indexes[k]
|
||||||
|
hyp = A[i][hyp_idx]
|
||||||
|
|
||||||
|
new_token = topk_token_indexes[k]
|
||||||
|
if new_token not in (blank_id, unk_id):
|
||||||
|
assert new_token != 0, new_token
|
||||||
|
token_list.append([new_token])
|
||||||
|
# store the LSTM states
|
||||||
|
hs.append(hyp.state[0])
|
||||||
|
cs.append(hyp.state[1])
|
||||||
|
|
||||||
|
# forward RNNLM to get new states and scores
|
||||||
|
if len(token_list) != 0:
|
||||||
|
tokens_to_score = (
|
||||||
|
torch.tensor(token_list)
|
||||||
|
.to(torch.int64)
|
||||||
|
.to(device)
|
||||||
|
.reshape(-1, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
hs = torch.cat(hs, dim=1).to(device)
|
||||||
|
cs = torch.cat(cs, dim=1).to(device)
|
||||||
|
scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs))
|
||||||
|
|
||||||
|
count = 0 # index, used to locate score and lm states
|
||||||
|
for i in range(batch_size):
|
||||||
|
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||||
|
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||||
|
|
||||||
|
for k in range(len(topk_hyp_indexes)):
|
||||||
|
hyp_idx = topk_hyp_indexes[k]
|
||||||
|
hyp = A[i][hyp_idx]
|
||||||
|
|
||||||
|
ys = hyp.ys[:]
|
||||||
|
|
||||||
|
# current score of hyp
|
||||||
|
lm_score = hyp.lm_score
|
||||||
|
state = hyp.state
|
||||||
|
|
||||||
|
hyp_log_prob = topk_log_probs[k] # get score of current hyp
|
||||||
|
new_token = topk_token_indexes[k]
|
||||||
|
if new_token not in (blank_id, unk_id):
|
||||||
|
|
||||||
|
ys.append(new_token)
|
||||||
|
state_cost = hyp.state_cost.forward_one_step(new_token)
|
||||||
|
|
||||||
|
# calculate the score of the latest token
|
||||||
|
current_ngram_score = (
|
||||||
|
state_cost.lm_score - hyp.state_cost.lm_score
|
||||||
|
)
|
||||||
|
|
||||||
|
assert current_ngram_score <= 0.0, (
|
||||||
|
state_cost.lm_score,
|
||||||
|
hyp.state_cost.lm_score,
|
||||||
|
)
|
||||||
|
# score = score + RNNLM_score - LODR_score
|
||||||
|
# LODR_LM_scale is a negative number here
|
||||||
|
hyp_log_prob += (
|
||||||
|
lm_score[new_token] * lm_scale
|
||||||
|
+ LODR_lm_scale * current_ngram_score
|
||||||
|
) # add the lm score
|
||||||
|
|
||||||
|
lm_score = scores[count]
|
||||||
|
state = (
|
||||||
|
lm_states[0][:, count, :].unsqueeze(1),
|
||||||
|
lm_states[1][:, count, :].unsqueeze(1),
|
||||||
|
)
|
||||||
|
count += 1
|
||||||
|
else:
|
||||||
|
state_cost = hyp.state_cost
|
||||||
|
|
||||||
|
new_hyp = Hypothesis(
|
||||||
|
ys=ys,
|
||||||
|
log_prob=hyp_log_prob,
|
||||||
|
state=state,
|
||||||
|
lm_score=lm_score,
|
||||||
|
state_cost=state_cost,
|
||||||
|
)
|
||||||
|
B[i].add(new_hyp)
|
||||||
|
|
||||||
|
B = B + finalized_B
|
||||||
|
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||||
|
|
||||||
|
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
||||||
|
ans = []
|
||||||
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
|
for i in range(N):
|
||||||
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
|
|
||||||
|
return ans
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
#
|
#
|
||||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang
|
||||||
|
# Xiaoyu Yang)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -90,8 +91,40 @@ Usage:
|
|||||||
--beam 20.0 \
|
--beam 20.0 \
|
||||||
--max-contexts 8 \
|
--max-contexts 8 \
|
||||||
--max-states 64
|
--max-states 64
|
||||||
"""
|
|
||||||
|
|
||||||
|
(8) modified beam search (with RNNLM shallow fusion)
|
||||||
|
./pruned_transducer_stateless3/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method modified_beam_search_rnnlm_shallow_fusion \
|
||||||
|
--beam 4 \
|
||||||
|
--rnn-lm-scale 0.3 \
|
||||||
|
--rnn-lm-exp-dir /path/to/RNNLM \
|
||||||
|
--rnn-lm-epoch 99 \
|
||||||
|
--rnn-lm-avg 1 \
|
||||||
|
--rnn-lm-num-layers 3 \
|
||||||
|
--rnn-lm-tie-weights 1
|
||||||
|
|
||||||
|
(9) modified beam search with RNNLM shallow fusion + LODR
|
||||||
|
./pruned_transducer_stateless3/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--max-duration 600 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--decoding-method modified_beam_search_rnnlm_LODR \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--rnn-lm-scale 0.4 \
|
||||||
|
--rnn-lm-exp-dir /path/to/RNNLM/exp \
|
||||||
|
--rnn-lm-epoch 99 \
|
||||||
|
--rnn-lm-avg 1 \
|
||||||
|
--rnn-lm-num-layers 3 \
|
||||||
|
--rnn-lm-tie-weights 1 \
|
||||||
|
--tokens-ngram 2 \
|
||||||
|
--ngram-lm-scale -0.16 \
|
||||||
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
@ -116,10 +149,14 @@ from beam_search import (
|
|||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
|
modified_beam_search_ngram_rescoring,
|
||||||
|
modified_beam_search_rnnlm_LODR,
|
||||||
|
modified_beam_search_rnnlm_shallow_fusion,
|
||||||
)
|
)
|
||||||
from librispeech import LibriSpeech
|
from librispeech import LibriSpeech
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall import NgramLm
|
||||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.rnn_lm.model import RnnLmModel
|
from icefall.rnn_lm.model import RnnLmModel
|
||||||
@ -202,6 +239,9 @@ def get_parser():
|
|||||||
- fast_beam_search_nbest
|
- fast_beam_search_nbest
|
||||||
- fast_beam_search_nbest_oracle
|
- fast_beam_search_nbest_oracle
|
||||||
- fast_beam_search_nbest_LG
|
- fast_beam_search_nbest_LG
|
||||||
|
- modified_beam_search_ngram_rescoring
|
||||||
|
- modified_beam_search_rnnlm_shallow_fusion
|
||||||
|
- modified_beam_search_rnnlm_LODR
|
||||||
If you use fast_beam_search_nbest_LG, you have to specify
|
If you use fast_beam_search_nbest_LG, you have to specify
|
||||||
`--lang-dir`, which should contain `LG.pt`.
|
`--lang-dir`, which should contain `LG.pt`.
|
||||||
""",
|
""",
|
||||||
@ -263,6 +303,7 @@ def get_parser():
|
|||||||
default=2,
|
default=2,
|
||||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-sym-per-frame",
|
"--max-sym-per-frame",
|
||||||
type=int,
|
type=int,
|
||||||
@ -341,6 +382,15 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--rnn-lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="""Used only when --method is modified-beam-search_rnnlm_shallow_fusion.
|
||||||
|
It specifies the path to RNN LM exp dir.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--rnn-lm-exp-dir",
|
"--rnn-lm-exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
@ -397,6 +447,24 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens-ngram",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="""Token Ngram used for rescoring.
|
||||||
|
Used only when the decoding method is
|
||||||
|
modified_beam_search_ngram_rescoring""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--backoff-id",
|
||||||
|
type=int,
|
||||||
|
default=500,
|
||||||
|
help="""ID of the backoff symbol.
|
||||||
|
Used only when the decoding method is
|
||||||
|
modified_beam_search_ngram_rescoring""",
|
||||||
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -410,7 +478,10 @@ def decode_one_batch(
|
|||||||
word_table: Optional[k2.SymbolTable] = None,
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
G: Optional[k2.Fsa] = None,
|
G: Optional[k2.Fsa] = None,
|
||||||
rnn_lm_model: torch.nn.Module = None,
|
ngram_lm: Optional[NgramLm] = None,
|
||||||
|
ngram_lm_scale: float = 1.0,
|
||||||
|
rnn_lm_model: Optional[RnnLmModel] = None,
|
||||||
|
rnnlm_scale: float = 1.0,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -444,6 +515,14 @@ def decode_one_batch(
|
|||||||
fast_beam_search_nbest, fast_beam_search_nbest_oracle,
|
fast_beam_search_nbest, fast_beam_search_nbest_oracle,
|
||||||
or fast_beam_search_with_nbest_rescoring.
|
or fast_beam_search_with_nbest_rescoring.
|
||||||
It an FsaVec containing an acceptor.
|
It an FsaVec containing an acceptor.
|
||||||
|
rnn_lm_model:
|
||||||
|
A rnnlm which can be used for rescoring or shallow fusion
|
||||||
|
rnnlm_scale:
|
||||||
|
The scale of the rnnlm.
|
||||||
|
ngram_lm:
|
||||||
|
A ngram lm. Used in LODR decoding.
|
||||||
|
ngram_lm_scale:
|
||||||
|
The scale of the ngram language model.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
the returned dict.
|
the returned dict.
|
||||||
@ -607,6 +686,43 @@ def decode_one_batch(
|
|||||||
nbest_scale=params.nbest_scale,
|
nbest_scale=params.nbest_scale,
|
||||||
temperature=params.temperature,
|
temperature=params.temperature,
|
||||||
)
|
)
|
||||||
|
elif params.decoding_method == "modified_beam_search_ngram_rescoring":
|
||||||
|
hyp_tokens = modified_beam_search_ngram_rescoring(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
ngram_lm=ngram_lm,
|
||||||
|
ngram_lm_scale=ngram_lm_scale,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
|
||||||
|
hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
sp=sp,
|
||||||
|
rnnlm=rnn_lm_model,
|
||||||
|
rnnlm_scale=rnnlm_scale,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "modified_beam_search_rnnlm_LODR":
|
||||||
|
hyp_tokens = modified_beam_search_rnnlm_LODR(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
sp=sp,
|
||||||
|
LODR_lm=ngram_lm,
|
||||||
|
LODR_lm_scale=ngram_lm_scale,
|
||||||
|
rnnlm=rnn_lm_model,
|
||||||
|
rnnlm_scale=rnnlm_scale,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
else:
|
else:
|
||||||
batch_size = encoder_out.size(0)
|
batch_size = encoder_out.size(0)
|
||||||
|
|
||||||
@ -693,7 +809,10 @@ def decode_dataset(
|
|||||||
word_table: Optional[k2.SymbolTable] = None,
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
G: Optional[k2.Fsa] = None,
|
G: Optional[k2.Fsa] = None,
|
||||||
rnn_lm_model: torch.nn.Module = None,
|
ngram_lm: Optional[NgramLm] = None,
|
||||||
|
ngram_lm_scale: float = 1.0,
|
||||||
|
rnn_lm_model: Optional[RnnLmModel] = None,
|
||||||
|
rnnlm_scale: float = 1.0,
|
||||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -749,7 +868,10 @@ def decode_dataset(
|
|||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
G=G,
|
G=G,
|
||||||
|
ngram_lm=ngram_lm,
|
||||||
|
ngram_lm_scale=ngram_lm_scale,
|
||||||
rnn_lm_model=rnn_lm_model,
|
rnn_lm_model=rnn_lm_model,
|
||||||
|
rnnlm_scale=rnnlm_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
for name, hyps in hyps_dict.items():
|
for name, hyps in hyps_dict.items():
|
||||||
@ -900,6 +1022,9 @@ def main():
|
|||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
"fast_beam_search_with_nbest_rescoring",
|
"fast_beam_search_with_nbest_rescoring",
|
||||||
"fast_beam_search_with_nbest_rnn_rescoring",
|
"fast_beam_search_with_nbest_rnn_rescoring",
|
||||||
|
"modified_beam_search_rnnlm_LODR",
|
||||||
|
"modified_beam_search_ngram_rescoring",
|
||||||
|
"modified_beam_search_rnnlm_shallow_fusion",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
@ -930,6 +1055,13 @@ def main():
|
|||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
params.suffix += f"-temperature-{params.temperature}"
|
params.suffix += f"-temperature-{params.temperature}"
|
||||||
|
|
||||||
|
if "rnnlm" in params.decoding_method:
|
||||||
|
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
|
||||||
|
if "LODR" in params.decoding_method:
|
||||||
|
params.suffix += "-LODR"
|
||||||
|
if "ngram" in params.decoding_method:
|
||||||
|
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||||
|
|
||||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
|
|
||||||
@ -1048,6 +1180,44 @@ def main():
|
|||||||
word_table = None
|
word_table = None
|
||||||
rnn_lm_model = None
|
rnn_lm_model = None
|
||||||
|
|
||||||
|
# only load N-gram LM when needed
|
||||||
|
if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
|
||||||
|
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
|
||||||
|
logging.info(f"lm filename: {lm_filename}")
|
||||||
|
ngram_lm = NgramLm(
|
||||||
|
str(params.lang_dir / lm_filename),
|
||||||
|
backoff_id=params.backoff_id,
|
||||||
|
is_binary=False,
|
||||||
|
)
|
||||||
|
logging.info(f"num states: {ngram_lm.lm.num_states}")
|
||||||
|
ngram_lm_scale = params.ngram_lm_scale
|
||||||
|
else:
|
||||||
|
ngram_lm = None
|
||||||
|
ngram_lm_scale = None
|
||||||
|
|
||||||
|
# only load rnnlm if used
|
||||||
|
if "rnnlm" in params.decoding_method:
|
||||||
|
rnn_lm_scale = params.rnn_lm_scale
|
||||||
|
|
||||||
|
rnn_lm_model = RnnLmModel(
|
||||||
|
vocab_size=params.vocab_size,
|
||||||
|
embedding_dim=params.rnn_lm_embedding_dim,
|
||||||
|
hidden_dim=params.rnn_lm_hidden_dim,
|
||||||
|
num_layers=params.rnn_lm_num_layers,
|
||||||
|
tie_weights=params.rnn_lm_tie_weights,
|
||||||
|
)
|
||||||
|
assert params.rnn_lm_avg == 1
|
||||||
|
|
||||||
|
load_checkpoint(
|
||||||
|
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
|
||||||
|
rnn_lm_model,
|
||||||
|
)
|
||||||
|
rnn_lm_model.to(device)
|
||||||
|
rnn_lm_model.eval()
|
||||||
|
else:
|
||||||
|
rnn_lm_model = None
|
||||||
|
rnn_lm_scale = 0.0
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
@ -1074,7 +1244,10 @@ def main():
|
|||||||
word_table=word_table,
|
word_table=word_table,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
G=G,
|
G=G,
|
||||||
|
ngram_lm=ngram_lm,
|
||||||
|
ngram_lm_scale=ngram_lm_scale,
|
||||||
rnn_lm_model=rnn_lm_model,
|
rnn_lm_model=rnn_lm_model,
|
||||||
|
rnnlm_scale=rnn_lm_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user