Merge branch 'master' of https://github.com/k2-fsa/icefall into surt

This commit is contained in:
Desh Raj 2022-12-12 09:57:24 -05:00
commit 6892ac85fb
67 changed files with 7625 additions and 52 deletions

View File

@@ -1,7 +1,7 @@
[flake8]
show-source=true
statistics=true
max-line-length = 80
max-line-length = 88
per-file-ignores =
# line too long
icefall/diagnostics.py: E501,
@@ -12,6 +12,7 @@ per-file-ignores =
egs/librispeech/ASR/lstm_transducer_stateless*/*.py: E501, E203
egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
egs/librispeech/ASR/conformer_ctc*/*py: E501,
egs/librispeech/ASR/zipformer_mmi/*.py: E501, E203
egs/librispeech/ASR/RESULTS.md: E999,
# invalid escape sequence (caused by TeX formula), W605

View File

@@ -13,7 +13,6 @@ cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conformer-ctc3-2022-11-27
log "Downloading pre-trained model from $repo_url"
git lfs install
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
@@ -23,7 +22,12 @@ soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
git lfs pull --include "data/*"
git lfs pull --include "data/lang_bpe_500/HLG.pt"
git lfs pull --include "data/lang_bpe_500/L.pt"
git lfs pull --include "data/lang_bpe_500/LG.pt"
git lfs pull --include "data/lang_bpe_500/Linv.pt"
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "data/lm/G_4_gram.pt"
git lfs pull --include "exp/jit_trace.pt"
git lfs pull --include "exp/pretrained.pt"
ln -s pretrained.pt epoch-99.pt

View File

@@ -13,7 +13,6 @@ cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01
log "Downloading pre-trained model from $repo_url"
git lfs install
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
@@ -23,7 +22,12 @@ soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
git lfs pull --include "data/*"
git lfs pull --include "data/lang_bpe_500/HLG.pt"
git lfs pull --include "data/lang_bpe_500/L.pt"
git lfs pull --include "data/lang_bpe_500/LG.pt"
git lfs pull --include "data/lang_bpe_500/Linv.pt"
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "data/lm/G_4_gram.pt"
git lfs pull --include "exp/cpu_jit.pt"
git lfs pull --include "exp/pretrained.pt"
ln -s pretrained.pt epoch-99.pt

View File

@@ -0,0 +1,103 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08
log "Downloading pre-trained model from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
git lfs pull --include "data/lang_bpe_500/3gram.pt"
git lfs pull --include "data/lang_bpe_500/4gram.pt"
git lfs pull --include "data/lang_bpe_500/L.pt"
git lfs pull --include "data/lang_bpe_500/LG.pt"
git lfs pull --include "data/lang_bpe_500/Linv.pt"
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/cpu_jit.pt"
git lfs pull --include "exp/pretrained.pt"
ln -s pretrained.pt epoch-99.pt
ls -lh *.pt
popd
log "Export to torchscript model"
./zipformer_mmi/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--jit 1
ls -lh $repo/exp/*.pt
log "Decode with models exported by torch.jit.script()"
./zipformer_mmi/jit_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--nn-model-filename $repo/exp/cpu_jit.pt \
--lang-dir $repo/data/lang_bpe_500 \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
log "$method"
./zipformer_mmi/pretrained.py \
--method $method \
--checkpoint $repo/exp/pretrained.pt \
--lang-dir $repo/data/lang_bpe_500 \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
mkdir -p zipformer_mmi/exp
ln -s $PWD/$repo/exp/pretrained.pt zipformer_mmi/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
ls -lh zipformer_mmi/exp
log "Decoding test-clean and test-other"
# use a small value for decoding with CPU
max_duration=100
for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
log "Decoding with $method"
./zipformer_mmi/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--use-averaged-model 0 \
--nbest-scale 1.2 \
--hp-scale 1.0 \
--max-duration $max_duration \
--lang-dir $repo/data/lang_bpe_500 \
--exp-dir zipformer_mmi/exp
done
rm zipformer_mmi/exp/*.pt
fi

View File

@@ -0,0 +1,167 @@
# Copyright 2022 Zengwei Yao
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-librispeech-2022-12-08-zipformer-mmi
# zipformer
on:
push:
branches:
- master
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
concurrency:
group: run_librispeech_2022_12_08_zipformer-${{ github.ref }}
cancel-in-progress: true
jobs:
run_librispeech_2022_12_08_zipformer:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest]
python-version: [3.8]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
pip install --no-binary protobuf protobuf
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/install-kaldifeat.sh
- name: Cache LibriSpeech test-clean and test-other datasets
id: libri-test-clean-and-test-other-data
uses: actions/cache@v2
with:
path: |
~/tmp/download
key: cache-libri-test-clean-and-test-other
- name: Download LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
- name: Prepare manifests for LibriSpeech test-clean and test-other
shell: bash
run: |
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
- name: Cache LibriSpeech test-clean and test-other fbank features
id: libri-test-clean-and-test-other-fbank
uses: actions/cache@v2
with:
path: |
~/tmp/fbank-libri
key: cache-libri-fbank-test-clean-and-test-other-v2
- name: Compute fbank for LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
mkdir -p egs/librispeech/ASR/data
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
- name: Display decoding results for librispeech zipformer-mmi
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
shell: bash
run: |
cd egs/librispeech/ASR/
tree ./zipformer_mmi/exp
cd zipformer_mmi
echo "results for zipformer_mmi"
echo "===1best==="
find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===nbest==="
find exp/nbest -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/nbest -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===nbest-rescoring-LG==="
find exp/nbest-rescoring-LG -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/nbest-rescoring-LG -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===nbest-rescoring-3-gram==="
find exp/nbest-rescoring-3-gram -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/nbest-rescoring-3-gram -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===nbest-rescoring-4-gram==="
find exp/nbest-rescoring-4-gram -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/nbest-rescoring-4-gram -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for librispeech zipformer-mmi
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer_mmi-2022-12-08
path: egs/librispeech/ASR/zipformer_mmi/exp/

View File

@@ -7,3 +7,4 @@ LibriSpeech
tdnn_lstm_ctc
conformer_ctc
lstm_pruned_stateless_transducer
zipformer_mmi

View File

@@ -0,0 +1,422 @@
Zipformer MMI
===============
.. hint::
Please scroll down to the bottom of this page to find download links
for pretrained models if you don't want to train a model from scratch.
This tutorial shows you how to train a Zipformer MMI model
with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.
We use LF-MMI to compute the loss.
.. note::
You can find the document about LF-MMI training at the following address:
`<https://github.com/k2-fsa/next-gen-kaldi-wechat/blob/master/pdf/LF-MMI-training-and-decoding-in-k2-Part-I.pdf>`_
Data preparation
----------------
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./prepare.sh
The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
All you need to do is to run it.
.. note::
We encourage you to read ``./prepare.sh``.
The data preparation contains several stages. You can use the following two
options:
- ``--stage``
- ``--stop-stage``
to control which stage(s) should be run. By default, all stages are executed.
For example,
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./prepare.sh --stage 0 --stop-stage 0
means to run only stage 0.
To run stage 2 to stage 5, use:
.. code-block:: bash
$ ./prepare.sh --stage 2 --stop-stage 5
.. hint::
If you have pre-downloaded the `LibriSpeech <https://www.openslr.org/12>`_
dataset and the `musan <http://www.openslr.org/17/>`_ dataset, say,
they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify
the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that
``./prepare.sh`` won't re-download them.
.. note::
All files generated by ``./prepare.sh``, e.g., features, lexicon, etc.,
are saved in the ``./data`` directory.
We provide the following YouTube video showing how to run ``./prepare.sh``.
.. note::
To get the latest news of `next-gen Kaldi <https://github.com/k2-fsa>`_, please subscribe to
the following YouTube channel by `Nadira Povey <https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_:
`<https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_
.. youtube:: ofEIoJL-mGM
Training
--------
For training stability, the recipe uses the CTC loss for model warm-up and then switches to the LF-MMI loss.
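The switch can be pictured as a simple schedule over training batches. The
snippet below is only an illustrative sketch (the name ``warmup_batches`` and
the two loss tensors are placeholders, not the recipe's actual variables); see
``zipformer_mmi/train.py`` for the real implementation:

.. code-block:: python

    def choose_loss(batch_idx: int, ctc_loss, mmi_loss, warmup_batches: int = 2000):
        # Use the simpler CTC objective while the model is still warming up,
        # then switch to the LF-MMI objective for the rest of training.
        if batch_idx < warmup_batches:
            return ctc_loss
        return mmi_loss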
Configurable options
~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./zipformer_mmi/train.py --help
shows you the training options that can be passed from the commandline.
The following options are used quite often:
- ``--full-libri``
If it's True, the training part uses all the training data, i.e.,
960 hours. Otherwise, the training part uses only the subset
``train-clean-100``, which has 100 hours of training data.
.. CAUTION::
The training set is perturbed by speed with two factors: 0.9 and 1.1.
If ``--full-libri`` is True, each epoch actually processes
``3x960 == 2880`` hours of data.
- ``--num-epochs``
It is the number of epochs to train. For instance,
``./zipformer_mmi/train.py --num-epochs 30`` trains for 30 epochs
and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt``
in the folder ``./zipformer_mmi/exp``.
- ``--start-epoch``
It's used to resume training.
``./zipformer_mmi/train.py --start-epoch 10`` loads the
checkpoint ``./zipformer_mmi/exp/epoch-9.pt`` and starts
training from epoch 10, based on the state from epoch 9.
- ``--world-size``
It is used for multi-GPU single-machine DDP training.
- (a) If it is 1, then no DDP training is used.
- (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training.
The following shows some use cases with it.
**Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and
GPU 2 for training. You can do the following:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ export CUDA_VISIBLE_DEVICES="0,2"
$ ./zipformer_mmi/train.py --world-size 2
**Use case 2**: You have 4 GPUs and you want to use all of them
for training. You can do the following:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./zipformer_mmi/train.py --world-size 4
**Use case 3**: You have 4 GPUs but you only want to use GPU 3
for training. You can do the following:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ export CUDA_VISIBLE_DEVICES="3"
$ ./zipformer_mmi/train.py --world-size 1
.. caution::
Only multi-GPU single-machine DDP training is implemented at present.
Multi-GPU multi-machine DDP training will be added later.
- ``--max-duration``
It specifies the total duration in seconds of all utterances in a
batch, before **padding**.
If you encounter CUDA OOM, please reduce it.
.. HINT::
Due to padding, the number of seconds of all utterances in a
batch will usually be larger than ``--max-duration``.
A larger value for ``--max-duration`` may cause OOM during training,
while a smaller value may increase the training time. You have to
tune it.
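To see why padding matters, consider a toy batch (the numbers below are made
up purely for illustration):

.. code-block:: python

    durations = [4.0, 6.5, 9.0]    # seconds per utterance, before padding
    assert sum(durations) <= 20.0  # fits into --max-duration 20

    # After padding, every utterance occupies as many frames as the longest
    # one, so the batch effectively costs len(durations) * max(durations)
    # seconds of computation: 3 * 9.0 = 27.0 > 20.0.
    padded_seconds = len(durations) * max(durations)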
Pre-configured options
~~~~~~~~~~~~~~~~~~~~~~
There are some training options, e.g., weight decay,
number of warmup steps, results dir, etc,
that are not passed from the commandline.
They are pre-configured by the function ``get_params()`` in
`zipformer_mmi/train.py <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer_mmi/train.py>`_.
You don't need to change these pre-configured parameters. If you really need to change
them, please modify ``./zipformer_mmi/train.py`` directly.
Training logs
~~~~~~~~~~~~~
Training logs and checkpoints are saved in ``zipformer_mmi/exp``.
You will find the following files in that directory:
- ``epoch-1.pt``, ``epoch-2.pt``, ...
These are checkpoint files saved at the end of each epoch, containing model
``state_dict`` and optimizer ``state_dict``.
To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
.. code-block:: bash
$ ./zipformer_mmi/train.py --start-epoch 11
- ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ...
These are checkpoint files saved every ``--save-every-n`` batches,
containing model ``state_dict`` and optimizer ``state_dict``.
To resume training from some checkpoint, say ``checkpoint-436000``, you can use:
.. code-block:: bash
$ ./zipformer_mmi/train.py --start-batch 436000
- ``tensorboard/``
This folder contains TensorBoard logs. Training loss, validation loss, learning
rate, etc, are recorded in these logs. You can visualize them by:
.. code-block:: bash
$ cd zipformer_mmi/exp/tensorboard
$ tensorboard dev upload --logdir . --description "Zipformer MMI training for LibriSpeech with icefall"
It will print something like below:
.. code-block::
TensorFlow installation not found - running with reduced feature set.
Upload started and will continue reading any new data as it's added to the logdir.
To stop uploading, press Ctrl-C.
New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/xyOZUKpEQm62HBIlUD4uPA/
Note there is a URL in the above output. Click it and you will see
the TensorBoard page.
.. hint::
If you don't have access to Google, you can use the following command
to view the tensorboard log locally:
.. code-block:: bash
cd zipformer_mmi/exp/tensorboard
tensorboard --logdir . --port 6008
It will print the following message:
.. code-block::
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit)
Now start your browser and go to `<http://localhost:6008>`_ to view the tensorboard
logs.
- ``log/log-train-xxxx``
It is the detailed training log in text format, same as the one
you saw printed to the console during training.
Usage example
~~~~~~~~~~~~~
You can use the following command to start the training using 4 GPUs:
.. code-block:: bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer_mmi/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--full-libri 1 \
--exp-dir zipformer_mmi/exp \
--max-duration 500 \
--use-fp16 1 \
--num-workers 2
Decoding
--------
The decoding part uses checkpoints saved by the training part, so you have
to run the training part first.
.. hint::
There are two kinds of checkpoints:
- (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end
of each epoch. You can pass ``--epoch`` to
``zipformer_mmi/decode.py`` to use them.
- (2) ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ..., which are saved
every ``--save-every-n`` batches. You can pass ``--iter`` to
``zipformer_mmi/decode.py`` to use them.
We suggest that you try both types of checkpoints and choose the one
that produces the lowest WERs.
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./zipformer_mmi/decode.py --help
shows the options for decoding.
The following shows the example using ``epoch-*.pt``:
.. code-block:: bash
for m in nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
./zipformer_mmi/decode.py \
--epoch 30 \
--avg 10 \
--exp-dir ./zipformer_mmi/exp/ \
--max-duration 100 \
--lang-dir data/lang_bpe_500 \
--nbest-scale 1.2 \
--hp-scale 1.0 \
--decoding-method $m
done
Export models
-------------
`zipformer_mmi/export.py <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer_mmi/export.py>`_ supports exporting checkpoints from ``zipformer_mmi/exp`` in the following ways.
Export ``model.state_dict()``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Checkpoints saved by ``zipformer_mmi/train.py`` also include
``optimizer.state_dict()``. It is useful for resuming training. But after training,
we are interested only in ``model.state_dict()``. You can use the following
command to extract ``model.state_dict()``.
.. code-block:: bash
./zipformer_mmi/export.py \
--exp-dir ./zipformer_mmi/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 9 \
--jit 0
It will generate a file ``./zipformer_mmi/exp/pretrained.pt``.
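Conceptually, the export step does something like the following simplified
sketch (assuming the checkpoint stores the model parameters under the
``model`` key; the real logic, including checkpoint averaging with ``--avg``,
lives in ``zipformer_mmi/export.py``):

.. code-block:: python

    import torch

    checkpoint = torch.load("zipformer_mmi/exp/epoch-30.pt", map_location="cpu")
    # Keep only the model weights; drop the optimizer state used for resuming.
    torch.save({"model": checkpoint["model"]}, "zipformer_mmi/exp/pretrained.pt")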
.. hint::
To use the generated ``pretrained.pt`` for ``zipformer_mmi/decode.py``,
you can run:
.. code-block:: bash
cd zipformer_mmi/exp
ln -s pretrained.pt epoch-9999.pt
And then pass ``--epoch 9999 --avg 1 --use-averaged-model 0`` to
``./zipformer_mmi/decode.py``.
To use the exported model with ``./zipformer_mmi/pretrained.py``, you
can run:
.. code-block:: bash
./zipformer_mmi/pretrained.py \
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method 1best \
/path/to/foo.wav \
/path/to/bar.wav
Export model using ``torch.jit.script()``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
./zipformer_mmi/export.py \
--exp-dir ./zipformer_mmi/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later
load it by ``torch.jit.load("cpu_jit.pt")``.
Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python
are on CPU. You can use ``to("cuda")`` to move them to a CUDA device.
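For example, the exported model can be loaded with a few lines of Python
(a minimal sketch, using the paths from the command above):

.. code-block:: python

    import torch

    model = torch.jit.load("zipformer_mmi/exp/cpu_jit.pt")
    model.eval()
    if torch.cuda.is_available():
        model = model.to("cuda")  # move the parameters to a CUDA device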
To use the generated files with ``./zipformer_mmi/jit_pretrained.py``:
.. code-block:: bash
./zipformer_mmi/jit_pretrained.py \
--nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method 1best \
/path/to/foo.wav \
/path/to/bar.wav
Download pretrained models
--------------------------
If you don't want to train from scratch, you can download the pretrained models
by visiting the following links:
- `<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08>`_
See `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_
for the details of the above pretrained models.

View File

@@ -0,0 +1,38 @@
# Introduction
This recipe trains multi-domain ASR models for AliMeeting. By multi-domain, we mean that
we train a single model on both close-talk and far-field conditions. This recipe optionally
uses [GSS](https://github.com/desh2608/gss)-based enhancement for the far-field array microphones.
We pool data in the following 4 ways and train a single model on the pooled data:
(i) individual headset microphone (IHM)
(ii) IHM with simulated reverb
(iii) Single distant microphone (SDM)
(iv) GSS-enhanced array microphones
This is different from `alimeeting/ASR` since that recipe trains a model only on the
far-field audio. Additionally, we use text normalization here similar to the original
M2MeT challenge, so the results should be more comparable to those from Table 4 of
the [paper](https://arxiv.org/abs/2110.07393).
The following additional packages need to be installed to run this recipe:
* `pip install jieba`
* `pip install paddlepaddle`
* `pip install git+https://github.com/desh2608/gss.git`
[./RESULTS.md](./RESULTS.md) contains the latest results.
## Performance Record
### pruned_transducer_stateless7
The following are decoded using `modified_beam_search`:
| Evaluation set | eval WER | test WER |
|--------------------------|------------|---------|
| IHM | 9.58 | 11.53 |
| SDM | 23.37 | 25.85 |
| MDM (GSS-enhanced) | 11.82 | 14.22 |
See [RESULTS](/egs/alimeeting/ASR_v2/RESULTS.md) for details.

View File

@@ -0,0 +1,90 @@
## Results (CER)
#### 2022-12-09
#### Zipformer (pruned_transducer_stateless7)
Zipformer encoder + non-recurrent decoder. The decoder
contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
layer (to transform tensor dim).
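A rough sketch of such a stateless decoder is shown below (for illustration only;
the class and dimension names are made up and this is not the decoder implementation
actually used by the recipe):

```python
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    """Illustrative sketch: embedding + Conv1d (kernel size 2) + linear projection."""

    def __init__(self, vocab_size: int, decoder_dim: int, joiner_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, decoder_dim)
        # Kernel size 2 means the decoder only looks at the previous 2 tokens.
        self.conv = nn.Conv1d(decoder_dim, decoder_dim, kernel_size=2)
        self.out_proj = nn.Linear(decoder_dim, joiner_dim)  # transform tensor dim

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (batch, context) token ids of the previously emitted symbols.
        emb = self.embedding(y).permute(0, 2, 1)  # (batch, decoder_dim, context)
        out = self.conv(emb).permute(0, 2, 1)     # (batch, context - 1, decoder_dim)
        return self.out_proj(out)
```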
All the results below are using a single model that is trained by combining the following
data: IHM, IHM+reverb, SDM, and GSS-enhanced MDM. Speed perturbation and MUSAN noise
augmentation are applied on top of the pooled data.
**WERs for IHM:**
| | eval | test | comment |
|---------------------------|------------|------------|------------------------------------------|
| greedy search | 10.13 | 12.21 | --epoch 15 --avg 8 --max-duration 500 |
| modified beam search | 9.58 | 11.53 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 |
| fast beam search | 9.92 | 12.07 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
**WERs for SDM:**
| | eval | test | comment |
|---------------------------|------------|------------|------------------------------------------|
| greedy search | 23.70 | 26.41 | --epoch 15 --avg 8 --max-duration 500 |
| modified beam search | 23.37 | 25.85 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 |
| fast beam search | 23.60 | 26.38 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
**WERs for GSS-enhanced MDM:**
| | eval | test | comment |
|---------------------------|------------|------------|------------------------------------------|
| greedy search | 12.24 | 14.99 | --epoch 15 --avg 8 --max-duration 500 |
| modified beam search | 11.82 | 14.22 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 |
| fast beam search | 12.30 | 14.98 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
The training command for reproducing these results is given below:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless7/train.py \
--world-size 4 \
--num-epochs 15 \
--exp-dir pruned_transducer_stateless7/exp \
--max-duration 300 \
--max-cuts 100 \
--prune-range 5 \
--lr-factor 5 \
--lm-scale 0.25 \
--use-fp16 True
```
The decoding command is:
```
# greedy search
./pruned_transducer_stateless7/decode.py \
--epoch 15 \
--avg 8 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 500 \
--decoding-method greedy_search
# modified beam search
./pruned_transducer_stateless7/decode.py \
--epoch 15 \
--avg 8 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 500 \
--decoding-method modified_beam_search \
--beam-size 4
# fast beam search
./pruned_transducer_stateless7/decode.py \
--epoch 15 \
--avg 8 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
```
Pretrained model is available at <https://huggingface.co/desh2608/icefall-asr-alimeeting-pruned-transducer-stateless7>
The tensorboard training log can be found at
<https://tensorboard.dev/experiment/EzmVahMMTb2YfKWXwQ2dyQ/#scalars>

View File

View File

@@ -0,0 +1,193 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the AliMeeting dataset.
For the training data, we prepare IHM, reverberated IHM, SDM, and GSS-enhanced
audios. For the test data, we separately prepare IHM, SDM, and GSS-enhanced
parts (which are the 3 evaluation settings).
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
from pathlib import Path
import torch
import torch.multiprocessing
from lhotse import CutSet, LilcomChunkyWriter
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
def compute_fbank_alimeeting():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
sampling_rate = 16000
num_mel_bins = 80
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
logging.info("Reading manifests")
manifests_ihm = read_manifests_if_cached(
dataset_parts=["train", "eval", "test"],
output_dir=src_dir,
prefix="alimeeting-ihm",
suffix="jsonl.gz",
)
manifests_sdm = read_manifests_if_cached(
dataset_parts=["train", "eval", "test"],
output_dir=src_dir,
prefix="alimeeting-sdm",
suffix="jsonl.gz",
)
# The GSS-enhanced recordings and supervisions are prepared by local/prepare_alimeeting_enhanced.py.
manifests_gss = read_manifests_if_cached(
dataset_parts=["train", "eval", "test"],
output_dir=src_dir,
prefix="alimeeting-gss",
suffix="jsonl.gz",
)
def _extract_feats(cuts: CutSet, storage_path: Path, manifest_path: Path) -> None:
cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
_ = cuts.compute_and_store_features_batch(
extractor=extractor,
storage_path=storage_path,
manifest_path=manifest_path,
batch_duration=5000,
num_workers=8,
storage_type=LilcomChunkyWriter,
)
logging.info(
"Preparing training cuts: IHM + reverberated IHM + SDM + GSS (optional)"
)
logging.info("Processing train split IHM")
cuts_ihm = (
CutSet.from_manifests(**manifests_ihm["train"])
.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
.modify_ids(lambda x: x + "-ihm")
)
_extract_feats(
cuts_ihm,
output_dir / "feats_train_ihm",
src_dir / "cuts_train_ihm.jsonl.gz",
)
logging.info("Processing train split IHM + reverberated IHM")
cuts_ihm_rvb = cuts_ihm.reverb_rir()
_extract_feats(
cuts_ihm_rvb,
output_dir / "feats_train_ihm_rvb",
src_dir / "cuts_train_ihm_rvb.jsonl.gz",
)
logging.info("Processing train split SDM")
cuts_sdm = (
CutSet.from_manifests(**manifests_sdm["train"])
.trim_to_supervisions(keep_overlapping=False)
.modify_ids(lambda x: x + "-sdm")
)
_extract_feats(
cuts_sdm,
output_dir / "feats_train_sdm",
src_dir / "cuts_train_sdm.jsonl.gz",
)
logging.info("Processing train split GSS")
cuts_gss = (
CutSet.from_manifests(**manifests_gss["train"])
.trim_to_supervisions(keep_overlapping=False)
.modify_ids(lambda x: x + "-gss")
)
_extract_feats(
cuts_gss,
output_dir / "feats_train_gss",
src_dir / "cuts_train_gss.jsonl.gz",
)
logging.info("Preparing test cuts: IHM, SDM, GSS (optional)")
for split in ["eval", "test"]:
logging.info(f"Processing {split} IHM")
cuts_ihm = (
CutSet.from_manifests(**manifests_ihm[split])
.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_{split}_ihm",
manifest_path=src_dir / f"cuts_{split}_ihm.jsonl.gz",
batch_duration=500,
num_workers=4,
storage_type=LilcomChunkyWriter,
)
)
logging.info(f"Processing {split} SDM")
cuts_sdm = (
CutSet.from_manifests(**manifests_sdm[split])
.trim_to_supervisions(keep_overlapping=False)
.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_{split}_sdm",
manifest_path=src_dir / f"cuts_{split}_sdm.jsonl.gz",
batch_duration=500,
num_workers=4,
storage_type=LilcomChunkyWriter,
)
)
logging.info(f"Processing {split} GSS")
cuts_gss = (
CutSet.from_manifests(**manifests_gss[split])
.trim_to_supervisions(keep_overlapping=False)
.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_{split}_gss",
manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz",
batch_duration=500,
num_workers=4,
storage_type=LilcomChunkyWriter,
)
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_alimeeting()

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py

View File

@@ -0,0 +1,158 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Data preparation for AliMeeting GSS-enhanced dataset.
import logging
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from lhotse import Recording, RecordingSet, SupervisionSet
from lhotse.qa import fix_manifests
from lhotse.recipes.utils import read_manifests_if_cached
from lhotse.utils import fastcopy
from tqdm import tqdm
logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(message)s",
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
)
def get_args():
import argparse
parser = argparse.ArgumentParser(description="AliMeeting enhanced dataset preparation.")
parser.add_argument(
"manifests_dir",
type=Path,
help="Path to directory containing AliMeeting manifests.",
)
parser.add_argument(
"enhanced_dir",
type=Path,
help="Path to enhanced data directory.",
)
parser.add_argument(
"--num-jobs",
"-j",
type=int,
default=1,
help="Number of parallel jobs to run.",
)
parser.add_argument(
"--min-segment-duration",
"-d",
type=float,
default=0.0,
help="Minimum duration of a segment in seconds.",
)
return parser.parse_args()
def find_recording_and_create_new_supervision(enhanced_dir, supervision):
"""
Given a supervision (corresponding to an original AliMeeting recording), this function finds the
enhanced recording corresponding to the supervision, and returns this recording and
a new supervision whose start and end times are adjusted to match the enhanced recording.
"""
file_name = Path(
f"{supervision.recording_id}-{supervision.speaker}-{int(100*supervision.start):06d}_{int(100*supervision.end):06d}.flac"
)
save_path = enhanced_dir / f"{supervision.recording_id}" / file_name
if save_path.exists():
recording = Recording.from_file(save_path)
if recording.duration == 0:
logging.warning(f"Skipping {save_path} which has duration 0 seconds.")
return None
# The old supervision is w.r.t. the original recording; we create a new supervision
# w.r.t. the enhanced segment.
new_supervision = fastcopy(
supervision,
recording_id=recording.id,
start=0,
duration=recording.duration,
)
return recording, new_supervision
else:
logging.warning(f"{save_path} does not exist.")
return None
def main(args):
# Get arguments
manifests_dir = args.manifests_dir
enhanced_dir = args.enhanced_dir
# Load manifests from cache if they exist (saves time)
manifests = read_manifests_if_cached(
dataset_parts=["train", "eval", "test"],
output_dir=manifests_dir,
prefix="alimeeting-sdm",
suffix="jsonl.gz",
)
if not manifests:
raise ValueError(
"AliMeeting SDM manifests not found in {}".format(manifests_dir)
)
with ThreadPoolExecutor(args.num_jobs) as ex:
for part in ["train", "eval", "test"]:
logging.info(f"Processing {part}...")
supervisions_orig = manifests[part]["supervisions"].filter(
lambda s: s.duration >= args.min_segment_duration
)
futures = []
for supervision in tqdm(
supervisions_orig,
desc="Distributing tasks",
):
futures.append(
ex.submit(
find_recording_and_create_new_supervision,
enhanced_dir,
supervision,
)
)
recordings = []
supervisions = []
for future in tqdm(
futures,
total=len(futures),
desc="Processing tasks",
):
result = future.result()
if result is not None:
recording, new_supervision = result
recordings.append(recording)
supervisions.append(new_supervision)
# Remove duplicates from the recordings
recordings_nodup = {}
for recording in recordings:
if recording.id not in recordings_nodup:
recordings_nodup[recording.id] = recording
else:
logging.warning("Recording {} is duplicated.".format(recording.id))
recordings = RecordingSet.from_recordings(recordings_nodup.values())
supervisions = SupervisionSet.from_segments(supervisions)
recordings, supervisions = fix_manifests(
recordings=recordings, supervisions=supervisions
)
logging.info(f"Writing {part} enhanced manifests")
recordings.to_file(
manifests_dir / f"alimeeting-gss_recordings_{part}.jsonl.gz"
)
supervisions.to_file(
manifests_dir / f"alimeeting-gss_supervisions_{part}.jsonl.gz"
)
if __name__ == "__main__":
args = get_args()
main(args)

View File

@@ -0,0 +1,98 @@
#!/bin/bash
# This script is used to run GSS-based enhancement on AliMeeting data.
set -euo pipefail
nj=4
stage=0
. shared/parse_options.sh || exit 1
if [ $# != 2 ]; then
echo "Wrong #arguments ($#, expected 2)"
echo "Usage: local/prepare_alimeeting_gss.sh [options] <data-dir> <exp-dir>"
echo "e.g. local/prepare_alimeeting_gss.sh data/manifests exp/ami_gss"
echo "main options (for others, see top of script file)"
echo " --nj <nj> # number of parallel jobs"
echo " --stage <stage> # stage to start running from"
exit 1;
fi
DATA_DIR=$1
EXP_DIR=$2
mkdir -p $EXP_DIR
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 1 ]; then
log "Stage 1: Prepare cut sets"
for part in train eval test; do
lhotse cut simple \
-r $DATA_DIR/alimeeting-mdm_recordings_${part}.jsonl.gz \
-s $DATA_DIR/alimeeting-mdm_supervisions_${part}.jsonl.gz \
$EXP_DIR/cuts_${part}.jsonl.gz
done
fi
if [ $stage -le 2 ]; then
log "Stage 2: Trim cuts to supervisions (1 cut per supervision segment)"
for part in train eval test; do
lhotse cut trim-to-supervisions --discard-overlapping \
$EXP_DIR/cuts_${part}.jsonl.gz $EXP_DIR/cuts_per_segment_${part}.jsonl.gz
done
fi
if [ $stage -le 3 ]; then
log "Stage 3: Split manifests for multi-GPU processing (optional)"
for part in train eval test; do
gss utils split $nj $EXP_DIR/cuts_per_segment_${part}.jsonl.gz \
$EXP_DIR/cuts_per_segment_${part}_split$nj
done
fi
if [ $stage -le 4 ]; then
log "Stage 4: Enhance train segments using GSS (requires GPU)"
# for train, we use smaller context and larger batches to speed up processing
for JOB in $(seq $nj); do
gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \
$EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.${JOB}.jsonl.gz $EXP_DIR/enhanced \
--bss-iterations 10 \
--context-duration 5.0 \
--use-garbage-class \
--channels 0,1,2,3,4,5,6,7 \
--min-segment-length 0.05 \
--max-segment-length 25.0 \
--max-batch-duration 60.0 \
--num-buckets 4 \
--num-workers 4
done
fi
if [ $stage -le 5 ]; then
log "Stage 5: Enhance eval/test segments using GSS (using GPU)"
# for eval/test, we use larger context and smaller batches to get better quality
for part in eval test; do
for JOB in $(seq $nj); do
gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \
$EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.${JOB}.jsonl.gz \
$EXP_DIR/enhanced \
--bss-iterations 10 \
--context-duration 15.0 \
--use-garbage-class \
--channels 0,1,2,3,4,5,6,7 \
--min-segment-length 0.05 \
--max-segment-length 16.0 \
--max-batch-duration 45.0 \
--num-buckets 4 \
--num-workers 4
done
done
fi
if [ $stage -le 6 ]; then
log "Stage 6: Prepare manifests for GSS-enhanced data"
python local/prepare_alimeeting_enhanced.py $DATA_DIR $EXP_DIR/enhanced -j $nj --min-segment-duration 0.05
fi

View File

@@ -0,0 +1 @@
../../ASR/local/prepare_char.py

View File

@@ -0,0 +1 @@
../../ASR/local/prepare_words.py

View File

@@ -0,0 +1 @@
../../ASR/local/text2segments.py

View File

@@ -0,0 +1 @@
../../ASR/local/text2token.py

egs/alimeeting/ASR_v2/prepare.sh (new executable file, 125 lines)
View File

@@ -0,0 +1,125 @@
#!/usr/bin/env bash
set -eou pipefail
stage=-1
stop_stage=100
use_gss=true # Use GSS-based enhancement with MDM setting
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/alimeeting
# This directory contains the following files downloaded from
# https://openslr.org/119/
#
# - Train_Ali_far.tar.gz
# - Train_Ali_near.tar.gz
# - Test_Ali.tar.gz
# - Eval_Ali.tar.gz
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
if [ ! -f $dl_dir/alimeeting/Train_Ali_far.tar.gz ]; then
lhotse download ali-meeting $dl_dir/alimeeting
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare alimeeting manifest"
# We assume that you have downloaded the alimeeting corpus
# to $dl_dir/alimeeting
for part in ihm sdm mdm; do
mkdir -p data/manifests/alimeeting
lhotse prepare ali-meeting --mic $part --save-mono --normalize-text m2met \
$dl_dir/alimeeting data/manifests
done
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to data/musan
mkdir -p data/manifests
lhotse prepare musan $dl_dir/musan data/manifests
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ] && [ $use_gss = true ]; then
log "Stage 3: Apply GSS enhancement on MDM data (this stage requires a GPU)"
# We assume that you have installed the GSS package: https://github.com/desh2608/gss
local/prepare_alimeeting_gss.sh data/manifests exp/alimeeting_gss
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan"
mkdir -p data/fbank
python local/compute_fbank_musan.py
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute fbank for alimeeting"
mkdir -p data/fbank
python local/compute_fbank_alimeeting.py
log "Combine features from train splits"
lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\
gzip -c > data/manifests/cuts_train_all.jsonl.gz
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare char based lang"
lang_char_dir=data/lang_char
mkdir -p $lang_char_dir
# Prepare text.
# Note: in Linux, you can install jq with the following commands:
# wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
# chmod +x ./jq
gunzip -c data/manifests/alimeeting-sdm_supervisions_train.jsonl.gz \
| jq ".text" | sed 's/"//g' \
| ./local/text2token.py -t "char" > $lang_char_dir/text
# Prepare words segments
python ./local/text2segments.py \
--input $lang_char_dir/text \
--output $lang_char_dir/text_words_segmentation
cat $lang_char_dir/text_words_segmentation | sed "s/ /\n/g" \
| sort -u | sed "/^$/d" \
| uniq > $lang_char_dir/words_no_ids.txt
# Prepare words.txt
if [ ! -f $lang_char_dir/words.txt ]; then
./local/prepare_words.py \
--input-file $lang_char_dir/words_no_ids.txt \
--output-file $lang_char_dir/words.txt
fi
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
./local/prepare_char.py
fi
fi

View File

@@ -0,0 +1,419 @@
# Copyright 2021 Piotr Żelasko
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import re
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.cut import Cut
from lhotse.dataset import (
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from tqdm import tqdm
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class AlimeetingAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description=(
"These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc."
),
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/manifests"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help=(
"When enabled, select noise from MUSAN and mix it "
"with training dataset. "
),
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help=(
"When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding."
),
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help=(
"Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch."
),
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help=(
"The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used."
),
)
group.add_argument(
"--max-duration",
type=int,
default=100.0,
help=(
"Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM."
),
)
group.add_argument(
"--max-cuts", type=int, default=None, help="Maximum cuts in a single batch."
)
group.add_argument(
"--num-buckets",
type=int,
default=50,
help=(
"The number of buckets for the BucketingSampler"
"(you might want to increase it for larger datasets)."
),
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help=(
"When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available."
),
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help=(
"When enabled (=default), the examples will be "
"shuffled for each epoch."
),
)
group.add_argument(
"--num-workers",
type=int,
default=8,
help=(
"The number of training dataloader workers that " "collect the batches."
),
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help=(
"Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp."
),
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=2,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
if self.args.on_the_fly_feats:
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
)
else:
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
)
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
max_cuts=self.args.max_cuts,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=True,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=True,
)
sampler = DynamicBucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
def remove_short_cuts(self, cut: Cut) -> bool:
"""
See: https://github.com/k2-fsa/icefall/issues/500
Basically, the zipformer model subsamples the input using the following formula:
num_out_frames = ((num_in_frames - 7)//2 + 1)//2
For num_out_frames to be at least 1, num_in_frames must be at least 9.
"""
return cut.duration >= 0.09
@lru_cache()
def train_cuts(self, sp: Optional[Any] = None) -> CutSet:
logging.info("About to get AMI train cuts")
def _remove_short_and_long_utt(c: Cut):
if c.duration < 0.1 or c.duration > 25.0:
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 7) // 2 + 1) // 2
tokens = c.supervisions[0].text
return T >= len(tokens)
cuts_train = load_manifest_lazy(
self.args.manifest_dir / "cuts_train_all.jsonl.gz"
)
return cuts_train.filter(_remove_short_and_long_utt)
@lru_cache()
def eval_ihm_cuts(self) -> CutSet:
logging.info("About to get AliMeeting IHM eval cuts")
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_eval_ihm.jsonl.gz")
return cs.filter(self.remove_short_cuts)
@lru_cache()
def eval_sdm_cuts(self) -> CutSet:
logging.info("About to get AliMeeting SDM eval cuts")
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_eval_sdm.jsonl.gz")
return cs.filter(self.remove_short_cuts)
@lru_cache()
def eval_gss_cuts(self) -> CutSet:
if not (self.args.manifest_dir / "cuts_eval_gss.jsonl.gz").exists():
logging.info("No GSS dev cuts found")
return None
logging.info("About to get AliMeeting GSS-enhanced eval cuts")
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_eval_gss.jsonl.gz")
return cs.filter(self.remove_short_cuts)
@lru_cache()
def test_ihm_cuts(self) -> CutSet:
logging.info("About to get AliMeeting IHM test cuts")
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_ihm.jsonl.gz")
return cs.filter(self.remove_short_cuts)
@lru_cache()
def test_sdm_cuts(self) -> CutSet:
logging.info("About to get AliMeeting SDM test cuts")
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_sdm.jsonl.gz")
return cs.filter(self.remove_short_cuts)
@lru_cache()
def test_gss_cuts(self) -> CutSet:
if not (self.args.manifest_dir / "cuts_test_gss.jsonl.gz").exists():
logging.info("No GSS test cuts found")
return None
logging.info("About to get AliMeeting GSS-enhanced test cuts")
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_gss.jsonl.gz")
return cs.filter(self.remove_short_cuts)

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/beam_search.py

View File

@@ -0,0 +1,698 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless7/decode.py \
--epoch 15 \
--avg 8 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 500 \
--decoding-method greedy_search
(2) modified beam search
./pruned_transducer_stateless7/decode.py \
--epoch 15 \
--avg 8 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 500 \
--decoding-method modified_beam_search \
--beam-size 4
(3) fast beam search
./pruned_transducer_stateless7/decode.py \
--epoch 15 \
--avg 8 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import AlimeetingAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest_LG,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall import NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding_method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = model.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[idx] for idx in hyp])
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if a beam size of 7 is used.
Its value is a list of tuples. Each tuple contains the cut id,
the reference transcript, and the predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
else:
log_interval = 2
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list(str(text).replace(" ", "")) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
lexicon=lexicon,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
this_batch.append((cut_id, ref_text, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
AlimeetingAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest_LG",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
model.device = device
if "fast_beam_search" in params.decoding_method:
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
alimeeting = AlimeetingAsrDataModule(args)
eval_ihm_cuts = alimeeting.eval_ihm_cuts()
test_ihm_cuts = alimeeting.test_ihm_cuts()
eval_sdm_cuts = alimeeting.eval_sdm_cuts()
test_sdm_cuts = alimeeting.test_sdm_cuts()
eval_gss_cuts = alimeeting.eval_gss_cuts()
test_gss_cuts = alimeeting.test_gss_cuts()
eval_ihm_dl = alimeeting.test_dataloaders(eval_ihm_cuts)
test_ihm_dl = alimeeting.test_dataloaders(test_ihm_cuts)
eval_sdm_dl = alimeeting.test_dataloaders(eval_sdm_cuts)
test_sdm_dl = alimeeting.test_dataloaders(test_sdm_cuts)
if eval_gss_cuts is not None:
eval_gss_dl = alimeeting.test_dataloaders(eval_gss_cuts)
if test_gss_cuts is not None:
test_gss_dl = alimeeting.test_dataloaders(test_gss_cuts)
test_sets = {
"eval_ihm": (eval_ihm_dl, eval_ihm_cuts),
"test_ihm": (test_ihm_dl, test_ihm_cuts),
"eval_sdm": (eval_sdm_dl, eval_sdm_cuts),
"test_sdm": (test_sdm_dl, test_sdm_cuts),
}
if eval_gss_cuts is not None:
test_sets["eval_gss"] = (eval_gss_dl, eval_gss_cuts)
if test_gss_cuts is not None:
test_sets["test_gss"] = (test_gss_dl, test_gss_cuts)
for test_set in test_sets:
logging.info(f"Decoding {test_set}")
dl, cuts = test_sets[test_set]
results_dict = decode_dataset(
dl=dl,
params=params,
model=model,
lexicon=lexicon,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py

View File

@ -0,0 +1,320 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
(1) Export to torchscript model using torch.jit.script()
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("cpu_jit.pt")`.
Note: `cpu` in the name `cpu_jit.pt` means the parameters are on CPU when loaded
into Python. You can use `to("cuda")` to move them to a CUDA device.
Check
https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `pruned_transducer_stateless7/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./pruned_transducer_stateless7/decode.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=15,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=8,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="The lang dir",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named cpu_jit.pt
Check ./jit_pretrained.py for how to use it.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.script()")
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/optim.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/pretrained.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py

View File

@ -0,0 +1 @@
../../../egs/aishell/ASR/shared

View File

@ -1,5 +1,62 @@
## Results
### zipformer_mmi (zipformer with mmi loss)
See <https://github.com/k2-fsa/icefall/pull/746> for more details.
[zipformer_mmi](./zipformer_mmi)
The tensorboard log can be found at
<https://tensorboard.dev/experiment/xyOZUKpEQm62HBIlUD4uPA/>
You can find a pretrained model, training logs, decoding logs, and decoding
results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08>
Number of model parameters: 69136519, i.e., 69.14 M
| decoding method | test-clean | test-other | comment |
|--------------------------|------------|-------------|---------------------|
| 1best | 2.54 | 5.65 | --epoch 30 --avg 10 |
| nbest | 2.54 | 5.66 | --epoch 30 --avg 10 |
| nbest-rescoring-LG | 2.49 | 5.42 | --epoch 30 --avg 10 |
| nbest-rescoring-3-gram | 2.52 | 5.62 | --epoch 30 --avg 10 |
| nbest-rescoring-4-gram | 2.5 | 5.51 | --epoch 30 --avg 10 |
The training commands are:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer_mmi/train.py \
--world-size 4 \
--master-port 12345 \
--num-epochs 30 \
--start-epoch 1 \
--lang-dir data/lang_bpe_500 \
--max-duration 500 \
--full-libri 1 \
--use-fp16 1 \
--exp-dir zipformer_mmi/exp
```
The decoding commands are:
```bash
export CUDA_VISIBLE_DEVICES="5"
for m in nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
./zipformer_mmi/decode.py \
--epoch 30 \
--avg 10 \
--exp-dir ./zipformer_mmi/exp/ \
--max-duration 100 \
--lang-dir data/lang_bpe_500 \
--nbest-scale 1.2 \
--hp-scale 1.0 \
--decoding-method $m
done
```
### pruned_transducer_stateless7_ctc (zipformer with transducer loss and ctc loss)
See <https://github.com/k2-fsa/icefall/pull/683> for more details.

View File

@ -291,7 +291,10 @@ def main():
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
[
[i, 0, feature_lengths[i] // params.subsampling_factor]
for i in range(batch_size)
],
dtype=torch.int32,
)
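The hunk above (and the analogous hunks later in this commit) replaces the padded `nnet_output.shape[1]` with each utterance's true subsampled length when building `supervision_segments`, so trailing padding frames are excluded from decoding. A minimal sketch of the intended construction, using a hypothetical helper name:

```python
import torch


def make_supervision_segments(
    feature_lengths: torch.Tensor, subsampling_factor: int
) -> torch.Tensor:
    """Build one [sequence_idx, start_frame, num_frames] row per utterance
    (dtype int32), using each utterance's true subsampled length instead of
    the padded batch length."""
    return torch.tensor(
        [
            [i, 0, int(feature_lengths[i]) // subsampling_factor]
            for i in range(feature_lengths.size(0))
        ],
        dtype=torch.int32,
    )
```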

View File

@ -339,7 +339,10 @@ def main():
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
[
[i, 0, feature_lengths[i] // params.subsampling_factor]
for i in range(batch_size)
],
dtype=torch.int32,
)

View File

@ -660,14 +660,22 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip
# it inside the for loop. That is, use
#
# if test_set == 'test-clean': continue
#
test_sets = ["test-clean", "test-other"]
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
test_dls = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,

View File

@ -30,6 +30,8 @@ import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
@ -100,6 +102,41 @@ def get_parser():
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_mmi/exp-attn",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_500",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--use-pruned-intersect",
type=str2bool,
default=False,
help="""Whether to use `intersect_dense_pruned` to get denominator
lattice.""",
)
return parser
@ -114,12 +151,6 @@ def get_params() -> AttributeDict:
Explanation of options saved in `params`:
- exp_dir: It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
- lang_dir: It contains language related input files such as
"lexicon.txt"
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
@ -164,8 +195,6 @@ def get_params() -> AttributeDict:
"""
params = AttributeDict(
{
"exp_dir": Path("conformer_mmi/exp_500_with_attention"),
"lang_dir": Path("data/lang_bpe_500"),
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
@ -184,15 +213,12 @@ def get_params() -> AttributeDict:
"beam_size": 6, # will change it to 8 after some batches (see code)
"reduction": "sum",
"use_double_scores": True,
# "att_rate": 0.0,
# "num_decoder_layers": 0,
"att_rate": 0.7,
"num_decoder_layers": 6,
# parameters for Noam
"weight_decay": 1e-6,
"lr_factor": 5.0,
"warm_step": 80000,
"use_pruned_intersect": False,
"den_scale": 1.0,
# use alignments before this number of batches
"use_ali_until": 13000,
@ -661,7 +687,7 @@ def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
fix_random_seed(42)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@ -745,8 +771,29 @@ def run(rank, world_size, args):
valid_ali = None
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)
@ -796,6 +843,7 @@ def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1

View File

@ -30,6 +30,8 @@ import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
@ -100,6 +102,26 @@ def get_parser():
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_mmi/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_500",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--seed",
type=int,
@ -107,6 +129,14 @@ def get_parser():
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--use-pruned-intersect",
type=str2bool,
default=False,
help="""Whether to use `intersect_dense_pruned` to get denominator
lattice.""",
)
return parser
@ -121,12 +151,6 @@ def get_params() -> AttributeDict:
Explanation of options saved in `params`:
- exp_dir: It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
- lang_dir: It contains language related input files such as
"lexicon.txt"
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
@ -171,8 +195,6 @@ def get_params() -> AttributeDict:
"""
params = AttributeDict(
{
"exp_dir": Path("conformer_mmi/exp_500"),
"lang_dir": Path("data/lang_bpe_500"),
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
@ -193,13 +215,10 @@ def get_params() -> AttributeDict:
"use_double_scores": True,
"att_rate": 0.0,
"num_decoder_layers": 0,
# "att_rate": 0.7,
# "num_decoder_layers": 6,
# parameters for Noam
"weight_decay": 1e-6,
"lr_factor": 5.0,
"warm_step": 80000,
"use_pruned_intersect": False,
"den_scale": 1.0,
# use alignments before this number of batches
"use_ali_until": 13000,
@ -752,8 +771,29 @@ def run(rank, world_size, args):
valid_ali = None
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
@ -804,6 +844,7 @@ def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1

View File

@ -2,7 +2,7 @@
lang_dir=data/lang_bpe_500
for ngram in 2 3 5; do
for ngram in 2 3 4 5; do
if [ ! -f $lang_dir/${ngram}gram.arpa ]; then
./shared/make_kn_lm.py \
-ngram-order ${ngram} \

View File

@ -81,7 +81,6 @@ class Zipformer(EncoderInterface):
super(Zipformer, self).__init__()
self.num_features = num_features
self.encoder_unmasked_dims = encoder_unmasked_dims
assert 0 < encoder_dims[0] <= encoder_dims[1]
self.encoder_dims = encoder_dims
self.encoder_unmasked_dims = encoder_unmasked_dims

View File

@ -72,14 +72,14 @@ Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01/exp
"""
import argparse

View File

@ -304,7 +304,10 @@ def main():
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
[
[i, 0, feature_lengths[i] // params.subsampling_factor]
for i in range(batch_size)
],
dtype=torch.int32,
)

View File

@ -322,7 +322,10 @@ def main():
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
[
[i, 0, feature_lengths[i] // params.subsampling_factor]
for i in range(batch_size)
],
dtype=torch.int32,
)

View File

@ -0,0 +1,26 @@
This recipe implements the Zipformer-MMI model.
See https://k2-fsa.github.io/icefall/recipes/librispeech/zipformer_mmi.html for a detailed tutorial.
It uses **CTC loss for warm-up** and then switches to MMI loss during training.
For decoding, it uses HP (H is ctc_topo, P is a token-level bi-gram LM) as the decoding graph. The following decoding methods are supported (a minimal invocation sketch follows the list):
- **1best**. Extract the best path from the decoding lattice as the decoding result.
- **nbest**. Extract n paths from the decoding lattice; the path with the highest score is the decoding result.
- **nbest-rescoring-LG**. Extract n paths from the decoding lattice and rescore them with a word-level 3-gram LM; the path with the highest score is the decoding result.
- **nbest-rescoring-3-gram**. Extract n paths from the decoding lattice and rescore them with a token-level 3-gram LM; the path with the highest score is the decoding result.
- **nbest-rescoring-4-gram**. Extract n paths from the decoding lattice and rescore them with a token-level 4-gram LM; the path with the highest score is the decoding result.
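A minimal invocation sketch, mirroring the commands published in RESULTS.md (the `--epoch`/`--avg` values and the `--nbest-scale`/`--hp-scale` settings below are the ones reported there; adjust them to your own checkpoints):

```bash
# Sketch only: loops over the decoding methods listed above.
# --nbest-scale is only used by the nbest-based methods.
for m in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
  ./zipformer_mmi/decode.py \
    --epoch 30 \
    --avg 10 \
    --exp-dir ./zipformer_mmi/exp \
    --max-duration 100 \
    --lang-dir data/lang_bpe_500 \
    --nbest-scale 1.2 \
    --hp-scale 1.0 \
    --decoding-method $m
done
```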
Experimental results when training on train-clean-100 (epoch 30, avg 10), reported as WER (%) on test-clean & test-other:
- 1best: 6.43 & 17.44
- nbest (nbest-scale=1.2): 6.43 & 17.45
- nbest-rescoring-LG (nbest-scale=1.2): 5.87 & 16.35
- nbest-rescoring-3-gram (nbest-scale=1.2): 6.19 & 16.57
- nbest-rescoring-4-gram (nbest-scale=1.2): 5.87 & 16.07
Experimental results when training on full LibriSpeech (epoch 30, avg 10), reported as WER (%) on test-clean & test-other:
- 1best: 2.54 & 5.65
- nbest (nbest-scale=1.2): 2.54 & 5.66
- nbest-rescoring-LG (nbest-scale=1.2): 2.49 & 5.42
- nbest-rescoring-3-gram (nbest-scale=1.2): 2.52 & 5.62
- nbest-rescoring-4-gram (nbest-scale=1.2): 2.50 & 5.51

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/asr_datamodule.py

View File

@ -0,0 +1,736 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Liyong Guo,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) 1best
./zipformer_mmi/mmi_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer_mmi/exp \
--max-duration 100 \
--decoding-method 1best
(2) nbest
./zipformer_mmi/mmi_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer_mmi/exp \
--max-duration 100 \
--nbest-scale 1.0 \
--decoding-method nbest
(3) nbest-rescoring-LG
./zipformer_mmi/mmi_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer_mmi/exp \
--max-duration 100 \
--nbest-scale 1.0 \
--decoding-method nbest-rescoring-LG
(4) nbest-rescoring-3-gram
./zipformer_mmi/mmi_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer_mmi/exp \
--max-duration 100 \
--nbest-scale 1.0 \
--decoding-method nbest-rescoring-3-gram
(5) nbest-rescoring-4-gram
./zipformer_mmi/mmi_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer_mmi/exp \
--max-duration 100 \
--nbest-scale 1.0 \
--decoding-method nbest-rescoring-4-gram
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from train import add_model_arguments, get_ctc_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_rescore_with_LM,
one_best_decoding,
)
from icefall.lexicon import Lexicon
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer_mmi/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="1best",
help="""Decoding method. Use HP as decoding graph, where H is
ctc_topo and P is token-level bi-gram lm.
Supported values are:
- (1) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (4) nbest-rescoring-LG. Extract n paths from the decoding lattice,
rescore them with an word-level 3-gram LM, the path with the
highest score is the decoding result.
- (5) nbest-rescoring-3-gram. Extract n paths from the decoding
lattice, rescore them with an token-level 3-gram LM, the path with
the highest score is the decoding result.
- (6) nbest-rescoring-4-gram. Extract n paths from the decoding
lattice, rescore them with an token-level 4-gram LM, the path with
the highest score is the decoding result.
""",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, and nbest-oracle
""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=1.0,
help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kind of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring-LG, nbest-rescoring-3-gram, and nbest-rescoring-4-gram.
A smaller value results in more unique paths.
""",
)
parser.add_argument(
"--hp-scale",
type=float,
default=1.0,
help="""The scale to be applied to `ctc_topo_P.scores`.
""",
)
add_model_arguments(parser)
return parser
def get_decoding_params() -> AttributeDict:
"""Parameters for decoding."""
params = AttributeDict(
{
"frame_shift_ms": 10,
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
HP: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
batch: dict,
G: Optional[k2.Fsa] = None,
LG: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if no rescoring is used, the key is the string `no_rescore`.
If LM rescoring is used, the key is the string `lm_scale_xxx`,
where `xxx` is the value of `lm_scale`. An example key is
`lm_scale_0.7`
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
- params.decoding_method is "1best", it uses 1best decoding without LM rescoring.
- params.decoding_method is "nbest", it uses nbest decoding without LM rescoring.
- params.decoding_method is "nbest-rescoring-LG", it uses nbest rescoring with word-level 3-gram LM.
- params.decoding_method is "nbest-rescoring-3-gram", it uses nbest rescoring with token-level 3-gram LM.
- params.decoding_method is "nbest-rescoring-4-gram", it uses nbest rescoring with token-level 4-gram LM.
model:
The neural model.
HP:
The decoding graph. H is ctc_topo, P is token-level bi-gram LM.
bpe_model:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
LG:
An LM. L is the lexicon, G is a word-level 3-gram LM.
It is used when params.decoding_method is "nbest-rescoring-LG".
G:
A token-level 3-gram or 4-gram LM.
It is used when params.decoding_method is "nbest-rescoring-3-gram"
or "nbest-rescoring-4-gram".
Returns:
Return the decoding result. See above description for the format of
the returned dict. Note: If it decodes to nothing, then return None.
"""
device = HP.device
feature = batch["inputs"]
assert feature.ndim == 3, feature.shape
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
nnet_output, encoder_out_lens = model(x=feature, x_lens=feature_lens)
# nnet_output is (N, T, C)
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"] // params.subsampling_factor,
supervisions["num_frames"] // params.subsampling_factor,
),
1,
).to(torch.int32)
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=HP,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
method = params.decoding_method
if method in ["1best", "nbest"]:
if method == "1best":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
key = "no_rescore"
else:
best_path = nbest_decoding(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
nbest_scale=params.nbest_scale,
)
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
# Note: `best_path.aux_labels` contains token IDs, not word IDs
# since we are using HP, not HLG here.
#
# token_ids is a list-of-lists of IDs
token_ids = get_texts(best_path)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
return {key: hyps}
assert method in [
"nbest-rescoring-LG", # word-level 3-gram lm
"nbest-rescoring-3-gram", # token-level 3-gram lm
"nbest-rescoring-4-gram", # token-level 4-gram lm
]
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
if method == "nbest-rescoring-LG":
assert LG is not None
LM = LG
else:
assert G is not None
LM = G
best_path_dict = nbest_rescore_with_LM(
lattice=lattice,
LM=LM,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
nbest_scale=params.nbest_scale,
)
ans = dict()
suffix = f"-nbest-scale-{params.nbest_scale}-{params.num_paths}"
for lm_scale_str, best_path in best_path_dict.items():
token_ids = get_texts(best_path)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
ans[lm_scale_str + suffix] = hyps
return ans
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
HP: k2.Fsa,
bpe_model: spm.SentencePieceProcessor,
G: Optional[k2.Fsa] = None,
LG: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
HP:
The decoding graph. H is ctc_topo, P is token-level bi-gram LM.
bpe_model:
The BPE model.
LG:
An LM. L is the lexicon, G is a word-level 3-gram LM.
It is used when params.decoding_method is "nbest-rescoring-LG".
G:
A token-level 3-gram or 4-gram LM.
It is used when params.decoding_method is "nbest-rescoring-3-gram"
or "nbest-rescoring-4-gram".
Returns:
Return a dict, whose key may be "no-rescore" if no LM rescoring
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
HP=HP,
bpe_model=bpe_model,
batch=batch,
G=G,
LG=LG,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
# add decoding params
params.update(get_decoding_params())
params.update(vars(args))
assert params.decoding_method in (
"1best",
"nbest",
"nbest-rescoring-LG", # word-level 3-gram lm
"nbest-rescoring-3-gram", # token-level 3-gram lm
"nbest-rescoring-4-gram", # token-level 4-gram lm
), params.decoding_method
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
params.vocab_size = num_classes
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = 0
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
mmi_graph_compiler = MmiTrainingGraphCompiler(
params.lang_dir,
uniq_filename="lexicon.txt",
device=device,
oov="<UNK>",
sos_id=1,
eos_id=1,
)
HP = mmi_graph_compiler.ctc_topo_P
HP.scores *= params.hp_scale
if not hasattr(HP, "lm_scores"):
HP.lm_scores = HP.scores.clone()
LG = None
G = None
if params.decoding_method == "nbest-rescoring-LG":
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
LG = k2.Fsa.from_fsas([LG]).to(device)
LG.lm_scores = LG.scores.clone()
elif params.decoding_method in ["nbest-rescoring-3-gram", "nbest-rescoring-4-gram"]:
order = params.decoding_method[-6]
assert order in ("3", "4"), (params.decoding_method, order)
order = int(order)
if not (params.lang_dir / f"{order}gram.pt").is_file():
logging.info(f"Loading {order}gram.fst.txt")
logging.warning("It may take a few minutes.")
with open(params.lang_dir / f"{order}gram.fst.txt") as f:
first_token_disambig_id = lexicon.token_table["#0"]
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
# remove it here.
del G.aux_labels
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
G.labels[G.labels >= first_token_disambig_id] = 0
G = k2.Fsa.from_fsas([G]).to(device)
# G = k2.remove_epsilon(G)
G = k2.arc_sort(G)
# Save a dummy value so that it can be loaded in C++.
# See https://github.com/pytorch/pytorch/issues/67902
# for why we need to do this.
G.dummy = 1
torch.save(G.as_dict(), params.lang_dir / f"{order}gram.pt")
else:
logging.info(f"Loading pre-compiled {order}gram.pt")
d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
G = k2.Fsa.from_dict(d)
G.lm_scores = G.scores.clone()
logging.info("About to create model")
model = get_ctc_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
HP=HP,
bpe_model=bpe_model,
G=G,
LG=LG,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/encoder_interface.py

View File

@ -0,0 +1,307 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
(1) Export to torchscript model using torch.jit.script()
./zipformer_mmi/export.py \
--exp-dir ./zipformer_mmi/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("cpu_jit.pt")`.
Note: `cpu` in the name `cpu_jit.pt` means the parameters are on CPU when the
model is loaded into Python. You can use `to("cuda")` to move them to a CUDA device.
Check
https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
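For example, a minimal loading sketch (the path below is an assumption;
adjust it to your own exp_dir):

    import torch

    model = torch.jit.load("zipformer_mmi/exp/cpu_jit.pt")
    model.eval()
    # The scripted model keeps the forward() signature of CTCModel:
    # it takes features of shape (N, T, C) and x_lens of shape (N,), and
    # returns (ctc_log_probs, output_lens).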
(2) Export `model.state_dict()`
./zipformer_mmi/export.py \
--exp-dir ./zipformer_mmi/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `zipformer_mmi/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./zipformer_mmi/decode.py \
--exp-dir ./zipformer_mmi/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method 1best \
--bpe-model data/lang_bpe_500/bpe.model
Check ./pretrained.py for its usage.
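You can also load `pretrained.pt` directly in Python; a minimal sketch
(assuming `params` is filled in the same way as in ./decode.py):

    import torch
    from train import get_ctc_model, get_params

    params = get_params()
    # ... set params.vocab_size, params.blank_id, etc. as in ./decode.py ...
    model = get_ctc_model(params)
    checkpoint = torch.load("zipformer_mmi/exp/pretrained.pt", map_location="cpu")
    model.load_state_dict(checkpoint["model"])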
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08
# You will find the pre-trained model in icefall-asr-librispeech-zipformer-mmi-2022-12-08/exp
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_ctc_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer_mmi/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named cpu_jit.pt
Check ./jit_pretrained.py for how to use it.
""",
)
add_model_arguments(parser)
return parser
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_ctc_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.script()")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,391 @@
#!/usr/bin/env python3
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, exported by `torch.jit.script()`
and uses them to decode waves.
You can use the following command to get the exported models:
./zipformer_mmi/export.py \
--exp-dir ./zipformer_mmi/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--jit 1
Usage of this script:
(1) 1best
./zipformer_mmi/jit_pretrained.py \
--nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method 1best \
/path/to/foo.wav \
/path/to/bar.wav
(2) nbest
./zipformer_mmi/jit_pretrained.py \
--nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--nbest-scale 1.2 \
--method nbest \
/path/to/foo.wav \
/path/to/bar.wav
(3) nbest-rescoring-LG
./zipformer_mmi/jit_pretrained.py \
--nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--nbest-scale 1.2 \
--method nbest-rescoring-LG \
/path/to/foo.wav \
/path/to/bar.wav
(4) nbest-rescoring-3-gram
./zipformer_mmi/jit_pretrained.py \
--nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--nbest-scale 1.2 \
--method nbest-rescoring-3-gram \
/path/to/foo.wav \
/path/to/bar.wav
(5) nbest-rescoring-4-gram
./zipformer_mmi/jit_pretrained.py \
--nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--nbest-scale 1.2 \
--method nbest-rescoring-4-gram \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from pathlib import Path
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from decode import get_decoding_params
from torch.nn.utils.rnn import pad_sequence
from train import get_params
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_rescore_with_LM,
one_best_decoding,
)
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
from icefall.utils import get_texts
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model-filename",
type=str,
required=True,
help="Path to the torchscript model cpu_jit.pt",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"--method",
type=str,
default="1best",
help="""Decoding method. Use HP as decoding graph, where H is
ctc_topo and P is token-level bi-gram lm.
Supported values are:
- (1) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (3) nbest-rescoring-LG. Extract n paths from the decoding lattice,
rescore them with a word-level 3-gram LM, and use the path with the
highest score as the decoding result.
- (4) nbest-rescoring-3-gram. Extract n paths from the decoding
lattice, rescore them with a token-level 3-gram LM, and use the path
with the highest score as the decoding result.
- (5) nbest-rescoring-4-gram. Extract n paths from the decoding
lattice, rescore them with a token-level 4-gram LM, and use the path
with the highest score as the decoding result.
""",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, and nbest-oracle
""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=1.2,
help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kind of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring-LG, nbest-rescoring-3-gram, and nbest-rescoring-4-gram.
A smaller value results in more unique paths.
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.1,
help="""
Used when method is nbest-rescoring-LG, nbest-rescoring-3-gram,
and nbest-rescoring-4-gram.
It specifies the scale for n-gram LM scores.
(Note: You need to tune it on a dataset.)
""",
)
parser.add_argument(
"--hp-scale",
type=float,
default=1.0,
help="""The scale to be applied to `ctc_topo_P.scores`.
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
params = get_params()
# add decoding params
params.update(get_decoding_params())
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = torch.jit.load(params.nn_model_filename)
model.eval()
model.to(device)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, device=device)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
mmi_graph_compiler = MmiTrainingGraphCompiler(
params.lang_dir,
uniq_filename="lexicon.txt",
device=device,
oov="<UNK>",
sos_id=1,
eos_id=1,
)
HP = mmi_graph_compiler.ctc_topo_P
HP.scores *= params.hp_scale
if not hasattr(HP, "lm_scores"):
HP.lm_scores = HP.scores.clone()
method = params.method
assert method in (
"1best",
"nbest",
"nbest-rescoring-LG", # word-level 3-gram lm
"nbest-rescoring-3-gram", # token-level 3-gram lm
"nbest-rescoring-4-gram", # token-level 4-gram lm
)
# loading language model for rescoring
LM = None
if method == "nbest-rescoring-LG":
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
LG = k2.Fsa.from_fsas([LG]).to(device)
LG.lm_scores = LG.scores.clone()
LM = LG
elif method in ["nbest-rescoring-3-gram", "nbest-rescoring-4-gram"]:
order = method[-6]
assert order in ("3", "4")
order = int(order)
logging.info(f"Loading pre-compiled {order}gram.pt")
d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
G = k2.Fsa.from_dict(d)
G.lm_scores = G.scores.clone()
LM = G
# Encoder forward
nnet_output, encoder_out_lens = model(x=features, x_lens=feature_lengths)
batch_size = nnet_output.shape[0]
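# Each row of supervision_segments is
# (sequence_index, start_frame, num_frames_after_subsampling),
# which is the layout k2 expects when building the DenseFsaVec inside
# get_lattice().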
supervision_segments = torch.tensor(
[
[i, 0, feature_lengths[i] // params.subsampling_factor]
for i in range(batch_size)
],
dtype=torch.int32,
)
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=HP,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if method in ["1best", "nbest"]:
if method == "1best":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
else:
best_path = nbest_decoding(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
nbest_scale=params.nbest_scale,
)
else:
best_path_dict = nbest_rescore_with_LM(
lattice=lattice,
LM=LM,
num_paths=params.num_paths,
lm_scale_list=[params.ngram_lm_scale],
nbest_scale=params.nbest_scale,
)
best_path = next(iter(best_path_dict.values()))
# Note: `best_path.aux_labels` contains token IDs, not word IDs
# since we are using HP, not HLG here.
#
# token_ids is a list-of-lists of IDs
token_ids = get_texts(best_path)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,75 @@
# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
class CTCModel(nn.Module):
def __init__(
self,
encoder: EncoderInterface,
encoder_dim: int,
vocab_size: int,
):
"""
Args:
encoder:
It is the transcription network in the paper. It accepts
two inputs: `x` of shape (N, T, C) and `x_lens` of shape (N,).
It returns two tensors: the encoder output of shape (N, T', encoder_dim)
and the output lengths of shape (N,).
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
self.encoder = encoder
self.ctc_output = nn.Sequential(
nn.Dropout(p=0.1),
nn.Linear(encoder_dim, vocab_size),
nn.LogSoftmax(dim=-1),
)
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
Returns:
Return the ctc outputs and encoder output lengths.
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
# compute ctc log-probs
ctc_output = self.ctc_output(encoder_out)
return ctc_output, x_lens
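# A minimal usage sketch (the encoder constructor below is illustrative,
# not part of this recipe):
#
#   encoder = build_encoder(...)  # any EncoderInterface implementation
#   model = CTCModel(encoder=encoder, encoder_dim=384, vocab_size=500)
#   log_probs, out_lens = model(x, x_lens)  # log_probs: (N, T', vocab_size)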

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/optim.py

View File

@ -0,0 +1,410 @@
#!/usr/bin/env python3
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:
./zipformer_mmi/export.py \
--exp-dir ./zipformer_mmi/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
Usage of this script:
(1) 1best
./zipformer_mmi/pretrained.py \
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method 1best \
/path/to/foo.wav \
/path/to/bar.wav
(2) nbest
./zipformer_mmi/pretrained.py \
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--nbest-scale 1.2 \
--method nbest \
/path/to/foo.wav \
/path/to/bar.wav
(3) nbest-rescoring-LG
./zipformer_mmi/pretrained.py \
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--nbest-scale 1.2 \
--method nbest-rescoring-LG \
/path/to/foo.wav \
/path/to/bar.wav
(4) nbest-rescoring-3-gram
./zipformer_mmi/pretrained.py \
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--nbest-scale 1.2 \
--method nbest-rescoring-3-gram \
/path/to/foo.wav \
/path/to/bar.wav
(5) nbest-rescoring-4-gram
./zipformer_mmi/pretrained.py \
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--nbest-scale 1.2 \
--method nbest-rescoring-4-gram \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./zipformer_mmi/exp/epoch-xx.pt`.
Note: ./zipformer_mmi/exp/pretrained.pt is generated by
./zipformer_mmi/export.py
"""
import argparse
import logging
import math
from pathlib import Path
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from decode import get_decoding_params
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_ctc_model, get_params
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_rescore_with_LM,
one_best_decoding,
)
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
from icefall.utils import get_texts
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"--method",
type=str,
default="1best",
help="""Decoding method. Use HP as decoding graph, where H is
ctc_topo and P is token-level bi-gram lm.
Supported values are:
- (1) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (3) nbest-rescoring-LG. Extract n paths from the decoding lattice,
rescore them with a word-level 3-gram LM, and use the path with the
highest score as the decoding result.
- (4) nbest-rescoring-3-gram. Extract n paths from the decoding
lattice, rescore them with a token-level 3-gram LM, and use the path
with the highest score as the decoding result.
- (5) nbest-rescoring-4-gram. Extract n paths from the decoding
lattice, rescore them with a token-level 4-gram LM, and use the path
with the highest score as the decoding result.
""",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, and nbest-oracle
""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=1.2,
help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kind of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring-LG, nbest-rescoring-3-gram, and nbest-rescoring-4-gram.
A smaller value results in more unique paths.
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.1,
help="""
Used when method is nbest-rescoring-LG, nbest-rescoring-3-gram,
and nbest-rescoring-4-gram.
It specifies the scale for n-gram LM scores.
(Note: You need to tune it on a dataset.)
""",
)
parser.add_argument(
"--hp-scale",
type=float,
default=1.0,
help="""The scale to be applied to `ctc_topo_P.scores`.
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
# add decoding params
params.update(get_decoding_params())
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_ctc_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
mmi_graph_compiler = MmiTrainingGraphCompiler(
params.lang_dir,
uniq_filename="lexicon.txt",
device=device,
oov="<UNK>",
sos_id=1,
eos_id=1,
)
HP = mmi_graph_compiler.ctc_topo_P
HP.scores *= params.hp_scale
if not hasattr(HP, "lm_scores"):
HP.lm_scores = HP.scores.clone()
method = params.method
assert method in (
"1best",
"nbest",
"nbest-rescoring-LG", # word-level 3-gram lm
"nbest-rescoring-3-gram", # token-level 3-gram lm
"nbest-rescoring-4-gram", # token-level 4-gram lm
)
# loading language model for rescoring
LM = None
if method == "nbest-rescoring-LG":
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
LG = k2.Fsa.from_fsas([LG]).to(device)
LG.lm_scores = LG.scores.clone()
LM = LG
elif method in ["nbest-rescoring-3-gram", "nbest-rescoring-4-gram"]:
order = method[-6]
assert order in ("3", "4")
order = int(order)
logging.info(f"Loading pre-compiled {order}gram.pt")
d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
G = k2.Fsa.from_dict(d)
G.lm_scores = G.scores.clone()
LM = G
# Encoder forward
nnet_output, encoder_out_lens = model(x=features, x_lens=feature_lengths)
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[
[i, 0, feature_lengths[i] // params.subsampling_factor]
for i in range(batch_size)
],
dtype=torch.int32,
)
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=HP,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if method in ["1best", "nbest"]:
if method == "1best":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
else:
best_path = nbest_decoding(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
nbest_scale=params.nbest_scale,
)
else:
best_path_dict = nbest_rescore_with_LM(
lattice=lattice,
LM=LM,
num_paths=params.num_paths,
lm_scale_list=[params.ngram_lm_scale],
nbest_scale=params.nbest_scale,
)
best_path = next(iter(best_path_dict.values()))
# Note: `best_path.aux_labels` contains token IDs, not word IDs
# since we are using HP, not HLG here.
#
# token_ids is a list-of-lists of IDs
token_ids = get_texts(best_path)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/scaling.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/scaling_converter.py

View File

@ -0,0 +1,57 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./zipformer_mmi/test_model.py
"""
import torch
from train import get_ctc_model, get_params
def test_model():
params = get_params()
params.vocab_size = 500
params.num_encoder_layers = "2,4,3,2,4"
# params.feedforward_dims = "1024,1024,1536,1536,1024"
params.feedforward_dims = "1024,1024,2048,2048,1024"
params.nhead = "8,8,8,8,8"
params.encoder_dims = "384,384,384,384,384"
params.attention_dims = "192,192,192,192,192"
params.encoder_unmasked_dims = "256,256,256,256,256"
params.zipformer_downsampling_factors = "1,2,4,8,2"
params.cnn_module_kernels = "31,31,31,31,31"
model = get_ctc_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
features = torch.randn(2, 100, 80)
feature_lengths = torch.full((2,), 100)
model(x=features, x_lens=feature_lengths)
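# The call above returns (ctc_log_probs, output_lens): ctc_log_probs has
# shape (N, T', vocab_size), where T' is the number of frames left after
# the encoder's subsampling.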
def main():
test_model()
if __name__ == "__main__":
main()

File diff suppressed because it is too large

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/zipformer.py

View File

@ -861,15 +861,41 @@ def run(rank, world_size, args):
valid_cuts = wenetspeech.valid_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 15.0 seconds
# Keep only utterances with duration between 1 second and 10 seconds
#
# Caution: There is a reason to select 15.0 here. Please see
# Caution: There is a reason to select 10.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 15.0
if c.duration < 1.0 or c.duration > 10.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
)
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./conformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 1) // 2 - 1) // 2
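# For example, with the usual 10 ms frame shift a 10-second cut has
# c.num_frames = 1000, so T = ((1000 - 1) // 2 - 1) // 2 = 249 and the
# cut is kept only if its transcript has at most 249 tokens.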
tokens = c.supervisions[0].text.replace(" ", "")
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)

View File

@ -1006,15 +1006,41 @@ def run(rank, world_size, args):
valid_cuts = wenetspeech.valid_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 15.0 seconds
# Keep only utterances with duration between 1 second and 10 seconds
#
# Caution: There is a reason to select 15.0 here. Please see
# Caution: There is a reason to select 10.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 15.0
if c.duration < 1.0 or c.duration > 10.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
)
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./conformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 1) // 2 - 1) // 2
tokens = c.supervisions[0].text.replace(" ", "")
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)

View File

@ -717,6 +717,107 @@ def rescore_with_n_best_list(
return ans
def nbest_rescore_with_LM(
lattice: k2.Fsa,
LM: k2.Fsa,
num_paths: int,
lm_scale_list: List[float],
nbest_scale: float = 1.0,
use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
"""Rescore an n-best list with an n-gram LM.
The path with the maximum score is used as the decoding output.
Args:
lattice:
An FsaVec with axes [utt][state][arc]. It must have the following
attributes: ``aux_labels`` and ``lm_scores``. Its ``aux_labels`` are
token IDs and its ``lm_scores`` are inherited from the decoding graph.
LM:
An FsaVec containing only a single FSA. It is one of the following:
- LG, where L is the lexicon and G is a word-level n-gram LM.
- G, a token-level n-gram LM.
num_paths:
Size of nbest list.
lm_scale_list:
A list of floats representing LM score scales.
nbest_scale:
Scale to be applied to ``lattice.scores`` when sampling paths
using ``k2.random_paths``.
use_double_scores:
True to use double precision during computation. False to use
single precision.
Returns:
A dict of FsaVec, whose key is an lm_scale and the value is the
best decoding path for each utterance in the lattice.
"""
device = lattice.device
assert len(lattice.shape) == 3
assert hasattr(lattice, "aux_labels")
assert hasattr(lattice, "lm_scores")
assert LM.shape == (1, None, None)
assert LM.device == device
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# nbest.fsa.scores contains 0s
nbest = nbest.intersect(lattice)
# Now nbest.fsa has its scores set
assert hasattr(nbest.fsa, "lm_scores")
# am scores + bi-gram scores
hp_scores = nbest.tot_scores()
# Now start to intersect nbest with LG or G
inv_fsa = k2.invert(nbest.fsa)
if hasattr(LM, "aux_labels"):
# LM is LG here
# delete token IDs as they are not needed
del inv_fsa.aux_labels
inv_fsa.scores.zero_()
inv_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(inv_fsa)
path_to_utt_map = nbest.shape.row_ids(1)
LM = k2.arc_sort(LM)
path_lattice = k2.intersect_device(
LM,
inv_fsa_with_epsilon_loops,
b_to_a_map=torch.zeros_like(path_to_utt_map),
sorted_match_a=True,
)
# Its labels are token IDs.
# If LM is G, its aux_labels are token IDs;
# if LM is LG, its aux_labels are word IDs.
path_lattice = k2.top_sort(k2.connect(path_lattice))
one_best = k2.shortest_path(path_lattice, use_double_scores=use_double_scores)
lm_scores = one_best.get_tot_scores(
use_double_scores=use_double_scores,
log_semiring=True, # Note: we always use True
)
# If LM is LG, we might get empty paths
lm_scores[lm_scores == float("-inf")] = -1e9
ans = dict()
for lm_scale in lm_scale_list:
tot_scores = hp_scores.values / lm_scale + lm_scores
tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
key = f"lm_scale_{lm_scale}"
ans[key] = best_path
return ans
def rescore_with_whole_lattice(
lattice: k2.Fsa,
G_with_epsilon_loops: k2.Fsa,

View File

@ -112,8 +112,12 @@ def _compute_mmi_loss_exact_non_optimized(
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
# TODO: pass output_beam as function argument
num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=beam_size)
den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=beam_size)
num_lats = k2.intersect_dense(
num_graphs, dense_fsa_vec, output_beam=beam_size, max_arcs=2147483600
)
den_lats = k2.intersect_dense(
den_graphs, dense_fsa_vec, output_beam=beam_size, max_arcs=2147483600
)
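# Note: 2147483600 is just below INT32_MAX (2147483647); presumably the
# intent is to leave the arc budget effectively unbounded while keeping
# the arc count representable with 32-bit indexing.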
num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
@ -144,7 +148,7 @@ def _compute_mmi_loss_pruned(
"""
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False)
num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)
num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=8.0)
# the values for search_beam/output_beam/min_active_states/max_active_states
# are not tuned. You may want to tune them.