Merge branch 'k2-fsa:master' into dev/lm_multi_zh-hans

This commit is contained in:
zr_jin 2023-11-09 11:08:28 +08:00 committed by GitHub
commit fb541ec60c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
114 changed files with 13074 additions and 165 deletions

View File

@ -18,8 +18,8 @@ log "Downloading pre-commputed fbank from $fbank_url"
git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests
ln -s $PWD/aishell-test-dev-manifests/data .
log "Downloading pre-trained model from $repo_url"
repo_url=https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20
log "Downloading pre-trained model from $repo_url"
git clone $repo_url
repo=$(basename $repo_url)

View File

@ -0,0 +1,103 @@
#!/usr/bin/env bash
#
# CI smoke test for the AIShell zipformer pre-trained models.
#
# Steps:
#   1. Clone the pre-computed fbank manifests and link their data/ directory
#      into the recipe directory.
#   2. For each of the large / medium / small model repositories: clone it,
#      list its test files, and run ./zipformer/pretrained.py on three test
#      wavs with three decoding methods.
#
# The script is expected to be started from the repository root (it does
# `cd egs/aishell/ASR`) and needs git-lfs and tree to be installed.

set -e

# Print a timestamped message tagged with the caller's file name, line number
# and function name.
log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/aishell/ASR

git lfs install

fbank_url=https://huggingface.co/csukuangfj/aishell-test-dev-manifests
log "Downloading pre-commputed fbank from $fbank_url"
# Clone from the variable that is logged above instead of repeating the URL.
git clone "$fbank_url"
ln -s "$PWD/aishell-test-dev-manifests/data" .

log "======================="
log "CI testing large model"
repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-large-2023-10-24/
log "Downloading pre-trained model from $repo_url"
git clone "$repo_url"
# basename ignores the trailing slash, so this is the cloned directory name.
repo=$(basename "$repo_url")

log "Display test files"
tree "$repo/"
ls -lh "$repo"/test_wavs/*.wav

for method in modified_beam_search greedy_search fast_beam_search; do
  log "$method"

  # The encoder layout flags describe the large model architecture.
  ./zipformer/pretrained.py \
    --method "$method" \
    --context-size 1 \
    --checkpoint "$repo/exp/pretrained.pt" \
    --tokens "$repo/data/lang_char/tokens.txt" \
    --num-encoder-layers 2,2,4,5,4,2 \
    --feedforward-dim 512,768,1536,2048,1536,768 \
    --encoder-dim 192,256,512,768,512,256 \
    --encoder-unmasked-dim 192,192,256,320,256,192 \
    "$repo/test_wavs/BAC009S0764W0121.wav" \
    "$repo/test_wavs/BAC009S0764W0122.wav" \
    "$repo/test_wavs/BAC009S0764W0123.wav"
done

log "======================="
log "CI testing medium model"
repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-2023-10-24/
log "Downloading pre-trained model from $repo_url"
git clone "$repo_url"
repo=$(basename "$repo_url")

log "Display test files"
tree "$repo/"
ls -lh "$repo"/test_wavs/*.wav

for method in modified_beam_search greedy_search fast_beam_search; do
  log "$method"

  # No encoder layout flags are passed for the medium model, so
  # pretrained.py falls back to whatever defaults it defines.
  ./zipformer/pretrained.py \
    --method "$method" \
    --context-size 1 \
    --checkpoint "$repo/exp/pretrained.pt" \
    --tokens "$repo/data/lang_char/tokens.txt" \
    "$repo/test_wavs/BAC009S0764W0121.wav" \
    "$repo/test_wavs/BAC009S0764W0122.wav" \
    "$repo/test_wavs/BAC009S0764W0123.wav"
done

log "======================="
log "CI testing small model"
repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-small-2023-10-24/
log "Downloading pre-trained model from $repo_url"
git clone "$repo_url"
repo=$(basename "$repo_url")

log "Display test files"
tree "$repo/"
ls -lh "$repo"/test_wavs/*.wav

for method in modified_beam_search greedy_search fast_beam_search; do
  log "$method"

  # The encoder layout flags describe the small model architecture.
  ./zipformer/pretrained.py \
    --method "$method" \
    --context-size 1 \
    --checkpoint "$repo/exp/pretrained.pt" \
    --tokens "$repo/data/lang_char/tokens.txt" \
    --num-encoder-layers 2,2,2,2,2,2 \
    --feedforward-dim 512,768,768,768,768,768 \
    --encoder-dim 192,256,256,256,256,256 \
    --encoder-unmasked-dim 192,192,192,192,192,192 \
    "$repo/test_wavs/BAC009S0764W0121.wav" \
    "$repo/test_wavs/BAC009S0764W0122.wav" \
    "$repo/test_wavs/BAC009S0764W0123.wav"
done

View File

@ -10,6 +10,7 @@ log() {
cd egs/multi_zh-hans/ASR
log "==== Test icefall-asr-multi-zh-hans-zipformer-2023-9-2 ===="
repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/
log "Downloading pre-trained model from $repo_url"
@ -49,3 +50,46 @@ for method in modified_beam_search fast_beam_search; do
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
done
log "==== Test icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 ===="
repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
ln -s epoch-20.pt epoch-99.pt
popd
ls -lh $repo/exp/*.pt
./zipformer/pretrained.py \
--checkpoint $repo/exp/epoch-99.pt \
--tokens $repo/data/lang_bpe_2000/tokens.txt \
--use-ctc 1 \
--method greedy_search \
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
for method in modified_beam_search fast_beam_search; do
log "$method"
./zipformer/pretrained.py \
--method $method \
--beam-size 4 \
--use-ctc 1 \
--checkpoint $repo/exp/epoch-99.pt \
--tokens $repo/data/lang_bpe_2000/tokens.txt \
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
done

View File

@ -0,0 +1,95 @@
# Copyright 2023 Zengrui Jin (Xiaomi Corp.)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Runs the AIShell zipformer pre-trained model smoke test
# (.github/scripts/run-aishell-zipformer-2023-10-24.sh).
name: run-aishell-zipformer-2023-10-24

on:
  push:
    branches:
      - master
  pull_request:
    # Only label events trigger the workflow for pull requests; the job-level
    # `if` below then filters on the label name.
    types: [labeled]

  schedule:
    # minute (0-59)
    # hour (0-23)
    # day of the month (1-31)
    # month (1-12)
    # day of the week (0-6)
    # nightly build at 15:50 UTC time every day
    - cron: "50 15 * * *"

concurrency:
  # A newer run on the same ref cancels the one still in progress.
  group: run_aishell_zipformer_2023_10_24-${{ github.ref }}
  cancel-in-progress: true

jobs:
  run_aishell_zipformer_2023_10_24:
    if: github.event.label.name == 'ready' || github.event.label.name == 'zipformer' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest]
        # Quoted so YAML keeps it a string; an unquoted version is parsed as
        # a float (e.g. 3.10 would silently become 3.1).
        python-version: ["3.8"]

      fail-fast: false

    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
          cache: 'pip'
          cache-dependency-path: '**/requirements-ci.txt'

      - name: Install Python dependencies
        run: |
          grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
          pip uninstall -y protobuf
          pip install --no-binary protobuf protobuf==3.20.*

      - name: Cache kaldifeat
        id: my-cache
        uses: actions/cache@v2
        with:
          path: |
            ~/tmp/kaldifeat
          key: cache-tmp-${{ matrix.python-version }}-2023-05-22

      # Build kaldifeat only when the cache step above did not restore it.
      - name: Install kaldifeat
        if: steps.my-cache.outputs.cache-hit != 'true'
        shell: bash
        run: |
          .github/scripts/install-kaldifeat.sh

      - name: Inference with pre-trained model
        shell: bash
        env:
          GITHUB_EVENT_NAME: ${{ github.event_name }}
          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
        run: |
          sudo apt-get -qq install git-lfs tree
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

          .github/scripts/run-aishell-zipformer-2023-10-24.sh

View File

@ -29,7 +29,7 @@ concurrency:
jobs:
run_multi-zh_hans_zipformer:
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans'
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans' || github.event.label.name == 'zipformer'
runs-on: ${{ matrix.os }}
strategy:
matrix:

View File

@ -118,11 +118,12 @@ We provide a Colab notebook to run a pre-trained transducer conformer + stateles
#### k2 pruned RNN-T
| Encoder | Params | test-clean | test-other |
|-----------------|--------|------------|------------|
| zipformer | 65.5M | 2.21 | 4.79 |
| zipformer-small | 23.2M | 2.42 | 5.73 |
| zipformer-large | 148.4M | 2.06 | 4.63 |
| Encoder | Params | test-clean | test-other | epochs | devices |
|-----------------|--------|------------|------------|---------|------------|
| zipformer | 65.5M | 2.21 | 4.79 | 50 | 4 32G-V100 |
| zipformer-small | 23.2M | 2.42 | 5.73 | 50 | 2 32G-V100 |
| zipformer-large | 148.4M | 2.06 | 4.63 | 50 | 4 32G-V100 |
| zipformer-large | 148.4M | 2.00 | 4.38 | 174 | 8 80G-A100 |
Note: No auxiliary losses are used in the training and no LMs are used
in the decoding.
@ -366,7 +367,7 @@ Once you have trained a model in icefall, you may want to deploy it with C++,
without Python dependencies.
Please refer to the documentation
<https://icefall.readthedocs.io/en/latest/recipes/librispeech/conformer_ctc.html#deployment-with-c>
<https://icefall.readthedocs.io/en/latest/recipes/Non-streaming-ASR/librispeech/conformer_ctc.html#deployment-with-c>
for how to do this.
We also provide a Colab notebook, showing you how to run a torch scripted model in [k2][k2] with C++.

View File

@ -1,39 +1,37 @@
# Contributing to Our Project
## Pre-commit hooks
Thank you for your interest in contributing to our project! We use Git pre-commit hooks to ensure code quality and consistency. Before contributing, please follow these guidelines to enable and use the pre-commit hooks.
We use [git][git] [pre-commit][pre-commit] [hooks][hooks] to check that files
going to be committed:
## Pre-Commit Hooks
- contain no trailing spaces
- are formatted with [black][black]
- are compatible to [PEP8][PEP8] (checked by [flake8][flake8])
- end in a newline and only a newline
- contain sorted `imports` (checked by [isort][isort])
We have set up pre-commit hooks to check that the files you're committing meet our coding and formatting standards. These checks include:
These hooks are disabled by default. Please use the following commands to enable them:
- Ensuring there are no trailing spaces.
- Formatting code with [black](https://github.com/psf/black).
- Checking compliance with PEP8 using [flake8](https://flake8.pycqa.org/).
- Verifying that files end with a newline character (and only a newline).
- Sorting imports using [isort](https://pycqa.github.io/isort/).
```bash
pip install pre-commit # run it only once
pre-commit install # run it only once, it will install all hooks
Please note that these hooks are disabled by default. To enable them, follow these steps:
# modify some files
git add <some files>
git commit # It runs all hooks automatically.
### Installation (Run only once)
# If all hooks run successfully, you can write the commit message now. Done!
#
# If any hook failed, your commit was not successful.
# Please read the error messages and make changes accordingly.
# And rerun
1. Install the `pre-commit` package using pip:
```bash
pip install pre-commit
```
2. Install the Git hooks using:
```bash
pre-commit install
```
### Making a Commit
Once you have enabled the pre-commit hooks, follow these steps when making a commit:
1. Make your changes to the codebase.
2. Stage your changes by using git add for the files you modified.
3. Commit your changes using git commit. The pre-commit hooks will run automatically at this point.
4. If all hooks run successfully, you can write your commit message, and your changes will be successfully committed.
5. If any hook fails, your commit will not be successful. Please read and follow the error messages provided, make the necessary changes, and then re-run git add and git commit.
git add <some files>
git commit
```
### Your Contribution
Your contributions are valuable to us, and by following these guidelines, you help maintain code consistency and quality in our project. We appreciate your dedication to ensuring high-quality code. If you have questions or need assistance, feel free to reach out to us. Thank you for being part of our open-source community!
[git]: https://git-scm.com/book/en/v2/Customizing-Git-Git-Hooks
[flake8]: https://github.com/PyCQA/flake8
[PEP8]: https://www.python.org/dev/peps/pep-0008/
[black]: https://github.com/psf/black
[hooks]: https://github.com/pre-commit/pre-commit-hooks
[pre-commit]: https://github.com/pre-commit/pre-commit
[isort]: https://github.com/PyCQA/isort

View File

@ -3,7 +3,7 @@ How to create a recipe
.. HINT::
Please read :ref:`follow the code style` to adjust your code sytle.
Please read :ref:`follow the code style` to adjust your code style.
.. CAUTION::

View File

@ -67,7 +67,7 @@ To run stage 2 to stage 5, use:
.. HINT::
A 3-gram language model will be downloaded from huggingface, we assume you have
intalled and initialized ``git-lfs``. If not, you could install ``git-lfs`` by
installed and initialized ``git-lfs``. If not, you could install ``git-lfs`` by
.. code-block:: bash

View File

@ -67,7 +67,7 @@ To run stage 2 to stage 5, use:
.. HINT::
A 3-gram language model will be downloaded from huggingface, we assume you have
intalled and initialized ``git-lfs``. If not, you could install ``git-lfs`` by
installed and initialized ``git-lfs``. If not, you could install ``git-lfs`` by
.. code-block:: bash

View File

@ -418,7 +418,7 @@ The following shows two examples (for two types of checkpoints):
- ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and
`espnet/nets/beam_search_transducer.py <https://github.com/espnet/espnet/blob/master/espnet/nets/beam_search_transducer.py#L247>`_
is used as a reference. Basicly, it keeps topk states for each frame, and expands the kept states with their own contexts to
is used as a reference. Basically, it keeps topk states for each frame, and expands the kept states with their own contexts to
next frame.
- ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it

View File

@ -1,6 +1,6 @@
.. _train_nnlm:
Train an RNN langugage model
Train an RNN language model
======================================
If you have enough text data, you can train a neural network language model (NNLM) to improve

View File

@ -32,7 +32,7 @@ In icefall, we implement the streaming conformer the way just like what `WeNet <
.. HINT::
If you want to modify a non-streaming conformer recipe to support both streaming and non-streaming, please refer
to `this pull request <https://github.com/k2-fsa/icefall/pull/454>`_. After adding the code needed by streaming training,
you have to re-train it with the extra arguments metioned in the docs above to get a streaming model.
you have to re-train it with the extra arguments mentioned in the docs above to get a streaming model.
Streaming Emformer

View File

@ -584,7 +584,7 @@ The following shows two examples (for the two types of checkpoints):
- ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and
`espnet/nets/beam_search_transducer.py <https://github.com/espnet/espnet/blob/master/espnet/nets/beam_search_transducer.py#L247>`_
is used as a reference. Basicly, it keeps topk states for each frame, and expands the kept states with their own contexts to
is used as a reference. Basically, it keeps topk states for each frame, and expands the kept states with their own contexts to
next frame.
- ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it
@ -648,7 +648,7 @@ command to extract ``model.state_dict()``.
.. caution::
``--streaming-model`` and ``--causal-convolution`` must be set to True to export
a streaming mdoel.
a streaming model.
It will generate a file ``./pruned_transducer_stateless4/exp/pretrained.pt``.
@ -697,7 +697,7 @@ Export model using ``torch.jit.script()``
.. caution::
``--streaming-model`` and ``--causal-convolution`` must be set to True to export
a streaming mdoel.
a streaming model.
It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later
load it by ``torch.jit.load("cpu_jit.pt")``.

View File

@ -7,6 +7,8 @@ set -eou pipefail
stage=-1
stop_stage=100
perturb_speed=true
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
@ -77,7 +79,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for aidatatang_200zh"
if [ ! -f data/fbank/.aidatatang_200zh.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aidatatang_200zh.py --perturb-speed True
./local/compute_fbank_aidatatang_200zh.py --perturb-speed ${perturb_speed}
touch data/fbank/.aidatatang_200zh.done
fi
fi

View File

@ -102,7 +102,7 @@ class Aidatatang_200zhAsrDataModule:
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
default=False,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
@ -289,6 +289,7 @@ class Aidatatang_200zhAsrDataModule:
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=True,
buffer_size=50000,
)
else:
logging.info("Using SimpleCutSampler.")

View File

@ -1,10 +1,12 @@
# Introduction
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/Non-streaming-ASR/aishell/index.html>
for how to run models in this recipe.
Please refer to <https://k2-fsa.github.io/icefall/recipes/Non-streaming-ASR/aishell/index.html> for how to run models in this recipe.
Aishell is an open-source Chinese Mandarin speech corpus published by Beijing Shell Shell Technology Co., Ltd.
400 people from different accent areas in China are invited to participate in the recording, which is conducted in a quiet indoor environment using high fidelity microphone and downsampled to 16kHz. The manual transcription accuracy is above 95%, through professional speech annotation and strict quality inspection. The data is free for academic use. We hope to provide moderate amount of data for new researchers in the field of speech recognition.
(From [Open Speech and Language Resources](https://www.openslr.org/33/))
# Transducers

View File

@ -1,6 +1,162 @@
## Results
### Aishell training result(Stateless Transducer)
### Aishell training result (Stateless Transducer)
#### Zipformer (Non-streaming)
[./zipformer](./zipformer)
It's the reworked Zipformer with pruned RNN-T loss.
**Caution**: It uses `--context-size=1`.
##### normal-scaled model, number of model parameters: 73412551, i.e., 73.41 M
| | test | dev | comment |
|------------------------|------|------|-----------------------------------------|
| greedy search | 4.67 | 4.37 | --epoch 55 --avg 17 |
| modified beam search | 4.40 | 4.13 | --epoch 55 --avg 17 |
| fast beam search | 4.60 | 4.31 | --epoch 55 --avg 17 |
Command for training is:
```bash
./prepare.sh
export CUDA_VISIBLE_DEVICES="0,1"
./zipformer/train.py \
--world-size 2 \
--num-epochs 60 \
--start-epoch 1 \
--use-fp16 1 \
--context-size 1 \
--enable-musan 0 \
--exp-dir zipformer/exp \
--max-duration 1000 \
--base-lr 0.045 \
--lr-batches 7500 \
--lr-epochs 18 \
--spec-aug-time-warp-factor 20
```
Command for decoding is:
```bash
for m in greedy_search modified_beam_search fast_beam_search ; do
./zipformer/decode.py \
--epoch 55 \
--avg 17 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--context-size 1 \
--decoding-method $m
done
```
Pretrained models, training logs, decoding logs, tensorboard and decoding results
are available at
<https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-2023-10-24>
##### small-scaled model, number of model parameters: 30167139, i.e., 30.17 M
| | test | dev | comment |
|------------------------|------|------|-----------------------------------------|
| greedy search | 4.97 | 4.67 | --epoch 55 --avg 21 |
| modified beam search | 4.67 | 4.40 | --epoch 55 --avg 21 |
| fast beam search | 4.85 | 4.61 | --epoch 55 --avg 21 |
Command for training is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
./zipformer/train.py \
--world-size 2 \
--num-epochs 60 \
--start-epoch 1 \
--use-fp16 1 \
--context-size 1 \
--exp-dir zipformer/exp-small \
--enable-musan 0 \
--base-lr 0.045 \
--lr-batches 7500 \
--lr-epochs 18 \
--spec-aug-time-warp-factor 20 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,768,768,768,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--max-duration 1200
```
Command for decoding is:
```bash
for m in greedy_search modified_beam_search fast_beam_search ; do
./zipformer/decode.py \
--epoch 55 \
--avg 21 \
--exp-dir ./zipformer/exp-small \
--lang-dir data/lang_char \
--context-size 1 \
--decoding-method $m \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,768,768,768,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192
done
```
Pretrained models, training logs, decoding logs, tensorboard and decoding results
are available at
<https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-small-2023-10-24/>
##### large-scaled model, number of model parameters: 157285130, i.e., 157.29 M
| | test | dev | comment |
|------------------------|------|------|-----------------------------------------|
| greedy search | 4.49 | 4.22 | --epoch 56 --avg 23 |
| modified beam search | 4.28 | 4.03 | --epoch 56 --avg 23 |
| fast beam search | 4.44 | 4.18 | --epoch 56 --avg 23 |
Command for training is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
./zipformer/train.py \
--world-size 2 \
--num-epochs 60 \
--use-fp16 1 \
--context-size 1 \
--exp-dir ./zipformer/exp-large \
--enable-musan 0 \
--lr-batches 7500 \
--lr-epochs 18 \
--spec-aug-time-warp-factor 20 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--max-duration 800
```
Command for decoding is:
```bash
for m in greedy_search modified_beam_search fast_beam_search ; do
./zipformer/decode.py \
--epoch 56 \
--avg 23 \
--exp-dir ./zipformer/exp-large \
--lang-dir data/lang_char \
--context-size 1 \
--decoding-method $m \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192
done
```
Pretrained models, training logs, decoding logs, tensorboard and decoding results
are available at
<https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-large-2023-10-24/>
#### Pruned transducer stateless 7 streaming
[./pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)

View File

@ -8,6 +8,7 @@ set -eou pipefail
nj=15
stage=-1
stop_stage=11
perturb_speed=true
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
@ -114,7 +115,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank for aishell"
if [ ! -f data/fbank/.aishell.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aishell.py --perturb-speed True
./local/compute_fbank_aishell.py --perturb-speed ${perturb_speed}
touch data/fbank/.aishell.done
fi
fi
@ -242,7 +243,7 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
-lm data/lm/3-gram.unpruned.arpa
fi
# We assume you have install kaldilm, if not, please install
# We assume you have installed kaldilm, if not, please install
# it using: pip install kaldilm
if [ ! -f data/lm/G_3_gram_char.fst.txt ]; then
# It is used in building HLG

View File

@ -70,6 +70,10 @@ class Decoder(nn.Module):
groups=embedding_dim,
bias=False,
)
else:
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
# when inference with torch.jit.script and context_size == 1
self.conv = nn.Identity()
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""

View File

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/asr_datamodule.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1,814 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./zipformer/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method greedy_search
(2) modified beam search
./zipformer/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(3) fast beam search (trivial_graph)
./zipformer/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(4) fast beam search (LG)
./zipformer/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest oracle WER)
./zipformer/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from lhotse.cut import Cut
from train import add_model_arguments, get_model, get_params
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this decoding script.

    The parser covers checkpoint selection (``--epoch``, ``--iter``,
    ``--avg``, ``--use-averaged-model``), the experiment and lang
    directories, the decoding method together with its search options, and
    the blank penalty. Model architecture options are registered last by
    ``add_model_arguments``.

    Returns:
      The configured parser. Defaults are shown in ``--help`` because
      ``argparse.ArgumentDefaultsHelpFormatter`` is used.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        # Adjacent literals are concatenated, so every piece must end with a
        # space to avoid words running together in the rendered help.
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging.",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--lang-dir",
        type=Path,
        # argparse applies `type` to string defaults, so this becomes a Path.
        default="data/lang_char",
        help="The lang dir containing word table and LG graph",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
          - greedy_search
          - modified_beam_search
          - fast_beam_search
          - fast_beam_search_LG
          - fast_beam_search_nbest_oracle
        If you use fast_beam_search_LG, you have to specify
        `--lang-dir`, which should contain `LG.pt`.
        """,
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --decoding-method is beam_search or
        modified_beam_search.""",
    )

    parser.add_argument(
        "--beam",
        type=float,
        default=20.0,
        help="""A floating point value to calculate the cutoff score during beam
        search (i.e., `cutoff = max-score - beam`), which is the same as the
        `beam` in Kaldi.
        Used only when --decoding-method is fast_beam_search,
        fast_beam_search_LG, and fast_beam_search_nbest_oracle
        """,
    )

    parser.add_argument(
        "--ngram-lm-scale",
        type=float,
        default=0.01,
        help="""
        Used only when --decoding-method is fast_beam_search_LG.
        It specifies the scale for n-gram LM scores.
        """,
    )

    parser.add_argument(
        "--ilme-scale",
        type=float,
        default=0.2,
        help="""
        Used only when --decoding-method is fast_beam_search_LG.
        It specifies the scale for the internal language model estimation.
        """,
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
        default=8,
        help="""Used only when --decoding-method is
        fast_beam_search, fast_beam_search_LG,
        and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--max-states",
        type=int,
        default=64,
        help="""Used only when --decoding-method is
        fast_beam_search, fast_beam_search_LG,
        and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="""Maximum number of symbols per frame.
        Used only when --decoding-method is greedy_search""",
    )

    parser.add_argument(
        "--num-paths",
        type=int,
        default=200,
        help="""Number of paths for nbest decoding.
        Used only when the decoding method is fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=0.5,
        help="""Scale applied to lattice scores when computing nbest paths.
        Used only when the decoding method is fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--blank-penalty",
        type=float,
        default=0.0,
        help="""
        The penalty applied on blank symbol during decoding.
        Note: It is a positive value that would be applied to logits like
        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
        [batch_size, vocab] and blank id is 0).
        """,
    )

    # Register the model architecture options provided by train.py.
    add_model_arguments(parser)

    return parser
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    lexicon: Lexicon,
    graph_compiler: CharCtcTrainingGraphCompiler,
    batch: dict,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Run the encoder on one batch and decode it with the configured method.

    Args:
      params:
        It's the return value of :func:`get_params`, updated with the
        command-line arguments.
      model:
        The neural model. Its `encoder_embed` and `encoder` are called here;
        the model itself is handed to the search routines.
      lexicon:
        Provides `token_table` (and `word_table` for fast_beam_search_LG)
        used to turn decoded ids into symbols.
      graph_compiler:
        Used only by fast_beam_search_nbest_oracle to convert the reference
        texts of the batch into token ids.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      decoding_graph:
        The decoding graph passed to the fast_beam_search* methods.
        Ignored by all other methods.

    Returns:
      A dict with exactly one entry. The key describes the decoding setting
      (e.g. it starts with "greedy_search_" for greedy search and with
      "beam_size_<n>_" for (modified) beam search). The value holds one list
      of symbols per utterance of the batch.
    """
    device = next(model.parameters()).device

    feature = batch["inputs"]
    assert feature.ndim == 3
    # at entry, feature is (N, T, C)
    feature = feature.to(device)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    if params.causal:
        # this seems to cause insertions at the end of the utterance if used with zipformer.
        pad_len = 30
        feature_lens += pad_len
        feature = torch.nn.functional.pad(
            feature,
            pad=(0, 0, 0, pad_len),
            value=LOG_EPS,
        )

    embedded, embedded_lens = model.encoder_embed(feature, feature_lens)
    padding_mask = make_pad_mask(embedded_lens)
    # The encoder works on (T, N, C); convert there and back.
    encoder_out, encoder_out_lens = model.encoder(
        embedded.permute(1, 0, 2), embedded_lens, padding_mask
    )
    encoder_out = encoder_out.permute(1, 0, 2)

    method = params.decoding_method
    num_utts = encoder_out.size(0)

    def ids_to_tokens(batch_ids):
        # Map the token ids of every utterance in the batch to symbols.
        return [
            [lexicon.token_table[idx] for idx in batch_ids[i]]
            for i in range(num_utts)
        ]

    if method in (
        "fast_beam_search",
        "fast_beam_search_LG",
        "fast_beam_search_nbest_oracle",
    ):
        # Arguments common to every fast_beam_search variant.
        shared = dict(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            blank_penalty=params.blank_penalty,
        )
        if method == "fast_beam_search":
            hyps = ids_to_tokens(fast_beam_search_one_best(**shared))
        elif method == "fast_beam_search_LG":
            word_ids = fast_beam_search_one_best(
                ilme_scale=params.ilme_scale, **shared
            )
            # Output ids are looked up in the word table; each joined
            # sentence is then split into single characters.
            hyps = [
                list("".join([lexicon.word_table[i] for i in hyp]))
                for hyp in word_ids
            ]
        else:
            hyps = ids_to_tokens(
                fast_beam_search_nbest_oracle(
                    num_paths=params.num_paths,
                    ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
                    nbest_scale=params.nbest_scale,
                    **shared,
                )
            )
    elif method == "greedy_search" and params.max_sym_per_frame == 1:
        hyps = ids_to_tokens(
            greedy_search_batch(
                model=model,
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                blank_penalty=params.blank_penalty,
            )
        )
    elif method == "modified_beam_search":
        hyps = ids_to_tokens(
            modified_beam_search(
                model=model,
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                blank_penalty=params.blank_penalty,
                beam=params.beam_size,
            )
        )
    else:
        # Remaining methods decode one utterance at a time.
        hyps = []
        for i in range(num_utts):
            # fmt: off
            encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
            # fmt: on
            if method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                    blank_penalty=params.blank_penalty,
                )
            elif method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                    blank_penalty=params.blank_penalty,
                )
            else:
                raise ValueError(
                    f"Unsupported decoding method: {params.decoding_method}"
                )
            hyps.append([lexicon.token_table[idx] for idx in hyp])

    # Build the key that names this decoding configuration.
    penalty_key = f"blank_penalty_{params.blank_penalty}"
    if method == "greedy_search":
        return {"greedy_search_" + penalty_key: hyps}

    if "fast_beam_search" in method:
        parts = [
            penalty_key,
            f"beam_{params.beam}",
            f"max_contexts_{params.max_contexts}",
            f"max_states_{params.max_states}",
        ]
        if "nbest" in method:
            parts.append(f"num_paths_{params.num_paths}")
            parts.append(f"nbest_scale_{params.nbest_scale}")
        if "LG" in method:
            parts.append(f"ilme_scale_{params.ilme_scale}")
            parts.append(f"ngram_lm_scale_{params.ngram_lm_scale}")
        return {"_".join(parts): hyps}

    return {f"beam_size_{params.beam_size}_" + penalty_key: hyps}
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    lexicon: Lexicon,
    graph_compiler: CharCtcTrainingGraphCompiler,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      lexicon:
        Forwarded unchanged to :func:`decode_one_batch`.
      graph_compiler:
        Forwarded unchanged to :func:`decode_one_batch`.
      decoding_graph:
        The decoding graph, forwarded unchanged to :func:`decode_one_batch`.
        Only the fast_beam_search* methods use it.

    Returns:
      Return a dict keyed by the decoding-setting name produced by
      :func:`decode_one_batch`. Each value is a list of 3-tuples
      `(cut_id, reference, hypothesis)`, where the reference is the
      transcript split into single characters with all whitespace removed
      and the hypothesis is the list of decoded symbols.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        # Not every dataloader can report its length.
        num_batches = "?"

    if params.decoding_method == "greedy_search":
        log_interval = 50
    else:
        log_interval = 20

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        # Drop all whitespace and score at the character level.
        texts = [list("".join(text.split())) for text in texts]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            lexicon=lexicon,
            graph_compiler=graph_compiler,
            decoding_graph=decoding_graph,
            batch=batch,
        )

        for name, hyps in hyps_dict.items():
            assert len(hyps) == len(texts)
            results[name].extend(
                (cut_id, ref_text, hyp_words)
                for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts)
            )

        num_cuts += len(texts)

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"
            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
    return results
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """Write transcripts, error statistics and a WER summary for one test set.

    Args:
      params:
        Must provide `res_dir` (output directory, joined with `/`) and
        `suffix` (appended to every output file name).
      test_set_name:
        Name of the test set; it becomes part of every output file name.
      results_dict:
        Maps a decoding-setting name to its list of
        `(cut_id, reference, hypothesis)` tuples.

    Returns:
      None. If `results_dict` is empty, a warning is logged and nothing is
      written.
    """
    if not results_dict:
        # Without this guard the summary file name below would reference a
        # loop variable that was never bound.
        logging.warning(f"No decoding results for {test_set_name}; nothing to save.")
        return

    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    # Best (lowest) WER first.
    sorted_wers = sorted(test_set_wers.items(), key=lambda x: x[1])

    # The summary file has always been named after the last setting in
    # `results_dict`; keep that name so existing tooling still finds it.
    summary_key = list(results_dict)[-1]
    errs_info = (
        params.res_dir
        / f"wer-summary-{test_set_name}-{summary_key}-{params.suffix}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for setting, val in sorted_wers:
            print("{}\t{}".format(setting, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for setting, val in sorted_wers:
        s += "{}\t{}{}\n".format(setting, val, note)
        # Only the first (best) line carries the note.
        note = ""
    logging.info(s)
@torch.no_grad()
def main():
    """Decode the AISHELL dev and test sets with a trained checkpoint.

    Steps: parse the command line, derive the result directory and a file
    name suffix from the decoding configuration, build the model and load
    (optionally averaged) weights, prepare the decoding graph for the
    fast_beam_search* methods, then decode each test set and save the
    results.
    """
    parser = get_parser()
    AishellAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    assert params.decoding_method in (
        "greedy_search",
        "beam_search",
        "modified_beam_search",
        "fast_beam_search",
        "fast_beam_search_LG",
        "fast_beam_search_nbest_oracle",
    )
    params.res_dir = params.exp_dir / params.decoding_method

    # The suffix encodes the whole decoding configuration so that result
    # files of different runs do not overwrite each other.
    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.causal:
        assert (
            "," not in params.chunk_size
        ), "chunk_size should be one value in decoding."
        assert (
            "," not in params.left_context_frames
        ), "left_context_frames should be one value in decoding."
        params.suffix += f"-chunk-{params.chunk_size}"
        params.suffix += f"-left-context-{params.left_context_frames}"

    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"
        if "nbest" in params.decoding_method:
            params.suffix += f"-nbest-scale-{params.nbest_scale}"
            params.suffix += f"-num-paths-{params.num_paths}"
        if "LG" in params.decoding_method:
            params.suffix += f"_ilme_scale_{params.ilme_scale}"
            params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
    params.suffix += f"-blank-penalty-{params.blank_penalty}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    lexicon = Lexicon(params.lang_dir)
    params.blank_id = lexicon.token_table["<blk>"]
    params.vocab_size = max(lexicon.tokens) + 1

    graph_compiler = CharCtcTrainingGraphCompiler(
        lexicon=lexicon,
        device=device,
    )

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    if not params.use_averaged_model:
        # Plain averaging of the parameters stored in the checkpoints.
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        # Only the two end-point checkpoints are passed on; one extra
        # checkpoint is needed because the start point is excluded.
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to(device)
    model.eval()

    if "fast_beam_search" in params.decoding_method:
        if "LG" in params.decoding_method:
            # `lexicon` was already loaded from params.lang_dir above, so it
            # is not read from disk a second time here.
            lg_filename = params.lang_dir / "LG.pt"
            logging.info(f"Loading {lg_filename}")
            decoding_graph = k2.Fsa.from_dict(
                torch.load(lg_filename, map_location=device)
            )
            decoding_graph.scores *= params.ngram_lm_scale
        else:
            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
    else:
        decoding_graph = None

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    aishell = AishellAsrDataModule(args)

    def remove_short_utt(c: Cut):
        # Cuts for which this frame-count formula is not positive are
        # dropped from decoding.
        T = ((c.num_frames - 7) // 2 + 1) // 2
        if T <= 0:
            logging.warning(
                f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
            )
        return T > 0

    dev_cuts = aishell.valid_cuts()
    dev_cuts = dev_cuts.filter(remove_short_utt)
    dev_dl = aishell.valid_dataloaders(dev_cuts)

    test_cuts = aishell.test_cuts()
    test_cuts = test_cuts.filter(remove_short_utt)
    test_dl = aishell.test_dataloaders(test_cuts)

    test_sets = ["dev", "test"]
    test_dls = [dev_dl, test_dl]

    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            lexicon=lexicon,
            graph_compiler=graph_compiler,
            decoding_graph=decoding_graph,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")
# Run decoding only when this file is executed as a script.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/decode_stream.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/decoder.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/encoder_interface.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export-onnx-streaming.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export-onnx.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/jit_pretrained.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_check.py

View File

@ -0,0 +1,286 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Xiaoyu Yang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX exported models and uses them to decode the test sets.
"""
import argparse
import logging
import time
from pathlib import Path
from typing import List, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from lhotse.cut import Cut
from onnx_pretrained import OnnxModel, greedy_search
from icefall.utils import setup_logger, store_transcripts, write_error_stats
def get_parser():
    """Build the command-line parser for ONNX decoding.

    Returns:
      An `argparse.ArgumentParser` with the three required ONNX model paths
      (encoder, decoder, joiner) plus `--exp-dir`, `--tokens` and
      `--decoding-method`.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # The three model files are declared identically apart from their name.
    for part in ("encoder", "decoder", "joiner"):
        parser.add_argument(
            f"--{part}-model-filename",
            type=str,
            required=True,
            help=f"Path to the {part} onnx model. ",
        )

    # NOTE(review): the default points at pruned_transducer_stateless7/exp
    # although this script lives in the zipformer recipe -- confirm whether
    # that is intended. Left unchanged for backward compatibility.
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless7/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/lang_char/tokens.txt",
        help="Path to the tokens.txt",
    )

    # main() asserts that the method is greedy_search, so the help text must
    # not advertise any other value.
    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="The decoding method. Only greedy_search is currently supported.",
    )

    return parser
def decode_one_batch(
    model: OnnxModel, token_table: k2.SymbolTable, batch: dict
) -> List[List[str]]:
    """Decode a single batch with greedy search.

    Greedy search is the only method implemented here.

    Args:
      model:
        The ONNX model wrapper; its `run_encoder` method is called here and
        the model is then handed to `greedy_search`.
      token_table:
        Mapping ids to tokens.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.

    Returns:
      One list of tokens per utterance in the batch.
    """
    features = batch["inputs"]
    # at entry, the features are (N, T, C)
    assert features.ndim == 3

    num_frames = batch["supervisions"]["num_frames"].to(dtype=torch.int64)

    encoder_out, encoder_out_lens = model.run_encoder(x=features, x_lens=num_frames)

    token_ids = greedy_search(
        model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
    )

    decoded = []
    for utt_ids in token_ids:
        decoded.append([token_table[tok] for tok in utt_ids])
    return decoded
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    model: OnnxModel,
    token_table: k2.SymbolTable,
) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      model:
        The ONNX model wrapper, forwarded to :func:`decode_one_batch`.
      token_table:
        Mapping ids to tokens.

    Returns:
      - A list of tuples. Each tuple contains three elements:
        - cut_id,
        - reference transcript as a list of characters, whitespace removed,
        - predicted result.
      - The total duration (in seconds) of the dataset.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        # Not every dataloader can report its length.
        num_batches = "?"

    log_interval = 10
    total_duration = 0

    results = []
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cuts = batch["supervisions"]["cut"]
        cut_ids = [cut.id for cut in cuts]
        total_duration += sum(cut.duration for cut in cuts)

        hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)

        assert len(hyps) == len(texts)
        for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
            # Whitespace in the transcript must not count as a reference
            # token (same normalisation as decode.py in this recipe),
            # otherwise every space is scored as a deletion.
            ref_words = list("".join(ref_text.split()))
            results.append((cut_id, ref_words, hyp_words))

        num_cuts += len(texts)

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"
            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")

    return results, total_duration
def save_results(
    res_dir: Path,
    test_set_name: str,
    results: List[Tuple[str, List[str], List[str]]],
):
    """Store transcripts, error statistics and the WER of one test set.

    Args:
      res_dir:
        Directory that receives the three output files.
      test_set_name:
        Name of the test set; it is part of every output file name.
      results:
        List of `(cut_id, reference, hypothesis)` tuples.
    """
    ordered = list(results)
    ordered.sort()

    transcript_file = res_dir / f"recogs-{test_set_name}.txt"
    store_transcripts(filename=transcript_file, texts=ordered)
    logging.info(f"The transcripts are stored in {transcript_file}")

    # The following prints out WERs, per-word error statistics and aligned
    # ref/hyp pairs.
    stats_file = res_dir / f"errs-{test_set_name}.txt"
    with open(stats_file, "w") as stats_out:
        wer = write_error_stats(
            stats_out, f"{test_set_name}", ordered, enable_log=True
        )
    logging.info("Wrote detailed error stats to {}".format(stats_file))

    summary_file = res_dir / f"wer-summary-{test_set_name}.txt"
    with open(summary_file, "w") as summary_out:
        # A header line followed by the WER on its own line.
        print("WER", wer, sep="\n", file=summary_out)

    logging.info("\nFor {}, WER is {}:\n".format(test_set_name, wer))
@torch.no_grad()
def main():
    """Decode the AISHELL dev and test sets with exported ONNX models.

    Parses the command line, loads the token table and the three ONNX
    models, then for each test set decodes it, logs the real time factor
    and saves the results under `<exp-dir>/onnx-<decoding-method>`.
    """
    parser = get_parser()
    AishellAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    assert (
        args.decoding_method == "greedy_search"
    ), "Only supports greedy_search currently."
    res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}"

    setup_logger(f"{res_dir}/log-decode")
    logging.info("Decoding started")

    device = torch.device("cpu")
    logging.info(f"Device: {device}")

    token_table = k2.SymbolTable.from_file(args.tokens)
    assert token_table[0] == "<blk>"

    logging.info(vars(args))

    logging.info("About to create model")
    model = OnnxModel(
        encoder_model_filename=args.encoder_model_filename,
        decoder_model_filename=args.decoder_model_filename,
        joiner_model_filename=args.joiner_model_filename,
    )

    # we need cut ids to display recognition results.
    args.return_cuts = True
    aishell = AishellAsrDataModule(args)

    def remove_short_utt(c: Cut):
        # Cuts for which this frame-count formula is not positive are
        # dropped from decoding.
        T = ((c.num_frames - 7) // 2 + 1) // 2
        if T <= 0:
            logging.warning(
                f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
            )
        return T > 0

    dev_cuts = aishell.valid_cuts()
    dev_cuts = dev_cuts.filter(remove_short_utt)
    dev_dl = aishell.valid_dataloaders(dev_cuts)

    # decode.py in this recipe obtains the test set through `test_cuts()`;
    # the previous `test_net_cuts()` call does not match that usage.
    test_cuts = aishell.test_cuts()
    test_cuts = test_cuts.filter(remove_short_utt)
    test_dl = aishell.test_dataloaders(test_cuts)

    test_sets = ["dev", "test"]
    test_dls = [dev_dl, test_dl]

    for test_set, dl in zip(test_sets, test_dls):
        start_time = time.time()
        results, total_duration = decode_dataset(
            dl=dl, model=model, token_table=token_table
        )
        end_time = time.time()
        elapsed_seconds = end_time - start_time
        # An empty test set has zero duration; report NaN rather than
        # crashing on a division by zero.
        if total_duration > 0:
            rtf = elapsed_seconds / total_duration
        else:
            rtf = float("nan")

        logging.info(f"Elapsed time: {elapsed_seconds:.3f} s")
        logging.info(f"Wave duration: {total_duration:.3f} s")
        logging.info(
            f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
        )

        save_results(res_dir=res_dir, test_set_name=test_set, results=results)

    logging.info("Done!")
# Run ONNX decoding only when this file is executed as a script.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_pretrained.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/pretrained.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling_converter.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/streaming_beam_search.py

View File

@ -0,0 +1,880 @@
#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang,
# Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./zipformer/streaming_decode.py \
--epoch 28 \
--avg 15 \
--causal 1 \
--chunk-size 16 \
--left-context-frames 256 \
--exp-dir ./zipformer/exp \
--decoding-method greedy_search \
--num-decode-streams 2000
"""
import argparse
import logging
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import torch
from asr_datamodule import AishellAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from streaming_beam_search import (
fast_beam_search_one_best,
greedy_search,
modified_beam_search,
)
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
# Natural logarithm of 1e-10.
# NOTE(review): presumably the fill value for padded log-domain feature
# frames -- confirm at the use sites.
LOG_EPS = math.log(1e-10)
def get_parser():
    """Build the command-line parser for streaming decoding.

    Besides the options declared here, the model options are registered via
    :func:`add_model_arguments`.

    Returns:
      The configured `argparse.ArgumentParser`.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_char",
        help="Path to the lang dir(containing lexicon, tokens, etc.)",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="""Supported decoding methods are:
        greedy_search
        modified_beam_search
        fast_beam_search
        """,
    )

    # Every other flag uses hyphens; the hyphenated spelling is offered and
    # the original underscore spelling stays valid as an alias, so existing
    # command lines keep working. The destination is unchanged.
    parser.add_argument(
        "--num-active-paths",
        "--num_active_paths",
        dest="num_active_paths",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --decoding-method is modified_beam_search.""",
    )

    parser.add_argument(
        "--beam",
        type=float,
        default=4,
        help="""A floating point value to calculate the cutoff score during beam
        search (i.e., `cutoff = max-score - beam`), which is the same as the
        `beam` in Kaldi.
        Used only when --decoding-method is fast_beam_search""",
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
        default=4,
        help="""Used only when --decoding-method is
        fast_beam_search""",
    )

    parser.add_argument(
        "--max-states",
        type=int,
        default=32,
        help="""Used only when --decoding-method is
        fast_beam_search""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    parser.add_argument(
        "--blank-penalty",
        type=float,
        default=0.0,
        help="""
        The penalty applied on blank symbol during decoding.
        Note: It is a positive value that would be applied to logits like
        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
        [batch_size, vocab] and blank id is 0).
        """,
    )

    parser.add_argument(
        "--num-decode-streams",
        type=int,
        default=2000,
        help="The number of streams that can be decoded in parallel.",
    )

    add_model_arguments(parser)

    return parser
def get_init_states(
    model: nn.Module,
    batch_size: int = 1,
    device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
    """Create the initial streaming state of the model for a batch.

    The returned list is laid out as follows:

    - First the cached tensors returned by `model.encoder.get_init_states`.
      For layer-i, states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn,
      cached_val1, cached_val2, cached_conv1, cached_conv2).
    - states[-2] is the state returned by
      `model.encoder_embed.get_init_states`, i.e. the cached left padding
      for the ConvNeXt module, of shape
      (batch_size, num_channels, left_pad, num_freqs).
    - states[-1] is processed_lens, an int32 tensor of zeros with shape
      (batch,), which records the number of processed frames (at 50hz frame
      rate, after encoder_embed) for each sample in batch.

    Args:
      model: Model providing `encoder` and `encoder_embed`.
      batch_size: Number of streams in the batch.
      device: Device on which the states are created.
    """
    states = model.encoder.get_init_states(batch_size, device)
    embed_states = model.encoder_embed.get_init_states(batch_size, device)
    # No frame has been processed yet for any stream.
    processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)

    states.extend((embed_states, processed_lens))
    return states
def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
    """Merge the zipformer states of separate utterances into one batched
    state, so that it can be used as an input for zipformer when those
    utterances are formed into a batch.

    Args:
      state_list:
        One entry per utterance. state_list[n] is the list of cached
        tensors of all encoder layers for utterance n. For layer-i,
        state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn,
        cached_val1, cached_val2, cached_conv1, cached_conv2).
        state_list[n][-2] is the cached left padding for ConvNeXt module,
        of shape (batch_size, num_channels, left_pad, num_freqs).
        state_list[n][-1] is processed_lens of shape (batch,), which records
        the number of processed frames (at 50hz frame rate, after
        encoder_embed) for each sample in batch.

    Returns:
      A list with the same layout as a single entry of `state_list`, in
      which every tensor is the concatenation over all utterances.

    Note:
      It is the inverse of :func:`unstack_states`.
    """
    num_entries = len(state_list[0])
    assert (num_entries - 2) % 6 == 0, num_entries
    num_layers = (num_entries - 2) // 6

    # Batch axis of the six per-layer caches, in storage order:
    #   cached_key:         (left_context_len, batch_size, key_dim)
    #   cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
    #   cached_val1:        (left_context_len, batch_size, value_dim)
    #   cached_val2:        (left_context_len, batch_size, value_dim)
    #   cached_conv1:       (#batch, channels, left_pad)
    #   cached_conv2:       (#batch, channels, left_pad)
    batch_dims = (1, 1, 1, 1, 0, 0)

    batch_states = []
    for layer in range(num_layers):
        base = layer * 6
        for slot, dim in enumerate(batch_dims):
            batch_states.append(
                torch.cat([state[base + slot] for state in state_list], dim=dim)
            )

    # The two trailing entries (embed left padding, processed_lens) both
    # have the batch on axis 0.
    for tail in (-2, -1):
        batch_states.append(
            torch.cat([state[tail] for state in state_list], dim=0)
        )

    return batch_states
def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
    """Split the batched zipformer streaming state into per-utterance states.

    Note:
      It is the inverse of :func:`stack_states`.

    Args:
      batch_states:
        A flat list of cached tensors covering all encoder layers. For layer-i,
        batch_states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn,
        cached_val1, cached_val2, cached_conv1, cached_conv2).
        batch_states[-2] is the cached left padding for the ConvNeXt module,
        of shape (batch_size, num_channels, left_pad, num_freqs).
        batch_states[-1] is processed_lens of shape (batch,), which records
        the number of processed frames (at 50hz frame rate, after
        encoder_embed) for each sample in the batch.

    Returns:
      A list with one entry per utterance. Each entry is a list laid out
      exactly like ``batch_states``, except that every tensor keeps only the
      slice (of size 1 along the batch axis) belonging to that utterance.
    """
    num_layer_tensors = len(batch_states) - 2
    assert num_layer_tensors % 6 == 0, len(batch_states)

    # The batch size is read from processed_lens, the last entry.
    batch_size = batch_states[-1].shape[0]

    # Batch axis of the six cached tensors of a layer, in storage order:
    #   cached_key:         (left_context_len, batch_size, key_dim)
    #   cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
    #   cached_val1:        (left_context_len, batch_size, value_dim)
    #   cached_val2:        (left_context_len, batch_size, value_dim)
    #   cached_conv1:       (#batch, channels, left_pad)
    #   cached_conv2:       (#batch, channels, left_pad)
    per_layer_batch_dims = (1, 1, 1, 1, 0, 0)

    state_list = [[] for _ in range(batch_size)]

    for layer_offset in range(0, num_layer_tensors, 6):
        # Split each of the six tensors of this layer into `batch_size` pieces.
        pieces = [
            batch_states[layer_offset + k].chunk(chunks=batch_size, dim=dim)
            for k, dim in enumerate(per_layer_batch_dims)
        ]
        for i, utt_state in enumerate(state_list):
            utt_state += [piece[i] for piece in pieces]

    # The two trailing entries (ConvNeXt left padding, then processed_lens)
    # both carry the batch on axis 0.
    for tail in (batch_states[-2], batch_states[-1]):
        tail_pieces = tail.chunk(chunks=batch_size, dim=0)
        for i, utt_state in enumerate(state_list):
            utt_state.append(tail_pieces[i])

    return state_list
def streaming_forward(
    features: Tensor,
    feature_lens: Tensor,
    model: nn.Module,
    states: List[Tensor],
    chunk_size: int,
    left_context_len: int,
) -> Tuple[Tensor, Tensor, List[Tensor]]:
    """Run encoder_embed and the encoder on one chunk in streaming mode.

    Args:
      features:
        The feature chunk, passed as ``x`` to
        ``model.encoder_embed.streaming_forward``.
      feature_lens:
        The lengths matching ``features``, passed as ``x_lens``.
      model:
        The neural model; it must provide ``encoder_embed`` and ``encoder``,
        each with a ``streaming_forward`` method.
      states:
        The batched streaming states. ``states[:-2]`` are the encoder caches,
        ``states[-2]`` is the cached left padding of encoder_embed and
        ``states[-1]`` is processed_lens of shape (batch,).
      chunk_size:
        The number of frames expected after encoder_embed; it is asserted.
      left_context_len:
        The number of left-context positions covered by the padding mask.

    Returns:
      Returns encoder outputs, output lengths, and updated states. The
      updated states use the same layout as ``states``.
    """
    embed_out = model.encoder_embed.streaming_forward(
        x=features,
        x_lens=feature_lens,
        cached_left_pad=states[-2],
    )
    x, x_lens, new_cached_embed_left_pad = embed_out
    assert x.size(1) == chunk_size, (x.size(1), chunk_size)

    chunk_padding_mask = make_pad_mask(x_lens)

    processed_lens = states[-1]  # (batch,)

    # Build the mask over the left context: it is used to mask out the
    # initial states, i.e. the left-context positions that are not backed by
    # already processed frames.
    context_positions = torch.arange(left_context_len, device=x.device).expand(
        x.size(0), left_context_len
    )
    # (batch, left_context_size)
    context_mask = (processed_lens.unsqueeze(1) <= context_positions).flip(1)

    # (batch, left_context_size + chunk_size)
    src_key_padding_mask = torch.cat([context_mask, chunk_padding_mask], dim=1)

    # Update processed lengths
    new_processed_lens = processed_lens + x_lens

    encoder_result = model.encoder.streaming_forward(
        x=x.permute(1, 0, 2),  # (N, T, C) -> (T, N, C)
        x_lens=x_lens,
        states=states[:-2],
        src_key_padding_mask=src_key_padding_mask,
    )
    encoder_out, encoder_out_lens, new_encoder_states = encoder_result

    new_states = new_encoder_states + [
        new_cached_embed_left_pad,
        new_processed_lens,
    ]
    # (T, N, C) -> (N, T, C)
    return encoder_out.permute(1, 0, 2), encoder_out_lens, new_states
def decode_one_chunk(
    params: AttributeDict,
    model: nn.Module,
    decode_streams: List[DecodeStream],
) -> List[int]:
    """Decode one chunk of feature frames for every stream in
    ``decode_streams`` and report which streams are finished.

    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The neural model.
      decode_streams:
        A List of DecodeStream, each belonging to a utterance. The ``states``
        and ``done_frames`` of every stream are updated in place.

    Returns:
      Return a List containing the indexes (into ``decode_streams``) of the
      DecodeStreams that are finished.
    """
    device = model.device
    chunk_size = int(params.chunk_size)
    left_context_len = int(params.left_context_frames)

    chunk_feats = []
    chunk_feat_lens = []
    stream_states = []
    done_frames = []  # Used in fast-beam-search

    # Gather one chunk from each stream, together with its cached state.
    for stream in decode_streams:
        feat, feat_len = stream.get_feature_frames(chunk_size * 2)
        chunk_feats.append(feat)
        chunk_feat_lens.append(feat_len)
        stream_states.append(stream.states)
        done_frames.append(stream.done_frames)

    feature_lens = torch.tensor(chunk_feat_lens, device=device)
    features = pad_sequence(chunk_feats, batch_first=True, padding_value=LOG_EPS)

    # Make sure the length after encoder_embed is at least 1.
    # The encoder_embed subsample features (T - 7) // 2
    # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding
    # after subsampling
    tail_length = chunk_size * 2 + 7 + 2 * 3
    pad_length = tail_length - features.size(1)
    if pad_length > 0:
        feature_lens += pad_length
        features = torch.nn.functional.pad(
            features,
            (0, 0, 0, pad_length),
            mode="constant",
            value=LOG_EPS,
        )

    encoder_out, encoder_out_lens, new_states = streaming_forward(
        features=features,
        feature_lens=feature_lens,
        model=model,
        states=stack_states(stream_states),
        chunk_size=chunk_size,
        left_context_len=left_context_len,
    )

    encoder_out = model.joiner.encoder_proj(encoder_out)

    method = params.decoding_method
    if method == "greedy_search":
        greedy_search(
            model=model,
            encoder_out=encoder_out,
            streams=decode_streams,
            blank_penalty=params.blank_penalty,
        )
    elif method == "fast_beam_search":
        # Frames processed by every stream once this chunk is included.
        processed_lens = torch.tensor(done_frames, device=device) + encoder_out_lens
        fast_beam_search_one_best(
            model=model,
            encoder_out=encoder_out,
            processed_lens=processed_lens,
            streams=decode_streams,
            beam=params.beam,
            max_states=params.max_states,
            max_contexts=params.max_contexts,
            blank_penalty=params.blank_penalty,
        )
    elif method == "modified_beam_search":
        modified_beam_search(
            model=model,
            streams=decode_streams,
            encoder_out=encoder_out,
            num_active_paths=params.num_active_paths,
            blank_penalty=params.blank_penalty,
        )
    else:
        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")

    # Hand the updated caches back to the individual streams.
    per_stream_states = unstack_states(new_states)

    finished_streams = []
    for i, stream in enumerate(decode_streams):
        stream.states = per_stream_states[i]
        stream.done_frames += encoder_out_lens[i]
        if stream.done:
            finished_streams.append(i)

    return finished_streams
def _pop_finished_results(
    decode_streams: List[DecodeStream],
    finished_streams: List[int],
    lexicon: Lexicon,
) -> List[Tuple[str, List[str], List[str]]]:
    """Remove the finished streams and return their (id, ref, hyp) results.

    The reference is the character sequence of the stripped ground truth so
    that it is comparable with the per-token hypothesis. Indexes are handled
    from the largest to the smallest so that deleting a stream does not shift
    the indexes that are still pending.
    """
    results = []
    for i in sorted(finished_streams, reverse=True):
        stream = decode_streams[i]
        results.append(
            (
                stream.id,
                list(stream.ground_truth.strip()),
                [lexicon.token_table[idx] for idx in stream.decoding_result()],
            )
        )
        del decode_streams[i]
    return results


def decode_dataset(
    cuts: CutSet,
    params: AttributeDict,
    model: nn.Module,
    lexicon: Lexicon,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode dataset.

    Args:
      cuts:
        Lhotse Cutset containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      lexicon:
        The Lexicon.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
        only when --decoding_method is fast_beam_search.

    Returns:
      Return a dict with a single key that encodes the decoding method and
      its settings (for example it starts with "greedy_search" if greedy
      search is used). Its value is a list of tuples. Each tuple contains
      the cut id, the reference transcript as a list of characters, and the
      predicted result as a list of tokens.

    Raises:
      ValueError: If params.decoding_method is not supported.
    """
    device = model.device

    opts = FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = 16000
    opts.mel_opts.num_bins = 80

    # The options never change, so one extractor serves all the cuts.
    fbank = Fbank(opts)

    log_interval = 100

    decode_results = []
    # Contain decode streams currently running.
    decode_streams = []
    for num, cut in enumerate(cuts):
        # each utterance has a DecodeStream.
        initial_states = get_init_states(model=model, batch_size=1, device=device)
        decode_stream = DecodeStream(
            params=params,
            cut_id=cut.id,
            initial_states=initial_states,
            decoding_graph=decoding_graph,
            device=device,
        )

        audio: np.ndarray = cut.load_audio()
        # audio.shape: (1, num_samples)
        assert len(audio.shape) == 2
        assert audio.shape[0] == 1, "Should be single channel"
        assert audio.dtype == np.float32, audio.dtype

        # The trained model is using normalized samples
        if audio.max() > 1:
            logging.warning(
                f"The audio should be normalized to [-1, 1], audio.max : {audio.max()}."
                f"Clipping to [-1, 1]."
            )
            audio = np.clip(audio, -1, 1)

        samples = torch.from_numpy(audio).squeeze(0)
        feature = fbank(samples.to(device))
        decode_stream.set_features(feature, tail_pad_len=30)
        decode_stream.ground_truth = cut.supervisions[0].text

        decode_streams.append(decode_stream)

        # Keep decoding until there is room for another stream.
        while len(decode_streams) >= params.num_decode_streams:
            finished_streams = decode_one_chunk(
                params=params, model=model, decode_streams=decode_streams
            )
            decode_results.extend(
                _pop_finished_results(decode_streams, finished_streams, lexicon)
            )

        if num % log_interval == 0:
            logging.info(f"Cuts processed until now is {num}.")

    # decode final chunks of last sequences
    while len(decode_streams):
        finished_streams = decode_one_chunk(
            params=params, model=model, decode_streams=decode_streams
        )
        # Use the same character-level reference as in the loop above;
        # splitting on whitespace here would not match the token-level
        # hypothesis.
        decode_results.extend(
            _pop_finished_results(decode_streams, finished_streams, lexicon)
        )

    key = f"blank_penalty_{params.blank_penalty}"
    if params.decoding_method == "greedy_search":
        key = f"greedy_search_{key}"
    elif params.decoding_method == "fast_beam_search":
        key = (
            f"beam_{params.beam}_"
            f"max_contexts_{params.max_contexts}_"
            f"max_states_{params.max_states}_{key}"
        )
    elif params.decoding_method == "modified_beam_search":
        key = f"num_active_paths_{params.num_active_paths}_{key}"
    else:
        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
    return {key: decode_results}
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
    """Write transcripts, error statistics and a WER summary for one test set.

    Args:
      params:
        The parameters; ``params.res_dir`` and ``params.suffix`` determine
        where the files are written.
      test_set_name:
        The name of the test set; it is used in file names and log messages.
      results_dict:
        Maps a decoding-setting key to its list of results.
    """
    wer_by_setting = dict()
    for key, results in results_dict.items():
        results = sorted(results)

        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        with open(errs_filename, "w") as f:
            wer_by_setting[key] = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    # Settings ordered from the lowest to the highest WER.
    ranked = sorted(wer_by_setting.items(), key=lambda item: item[1])

    # NOTE: `key` still holds the last key iterated in the loop above, so
    # the summary file is named after that setting.
    errs_info = (
        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for setting, wer in ranked:
            print("{}\t{}".format(setting, wer), file=f)

    summary = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    for rank, (setting, wer) in enumerate(ranked):
        # Only the first (lowest WER) line is annotated.
        note = "\tbest for {}".format(test_set_name) if rank == 0 else ""
        summary += "{}\t{}{}\n".format(setting, wer, note)
    logging.info(summary)
@torch.no_grad()
def main():
    """Entry point: parse the arguments, load the model and decode dev/test.

    The checkpoint(s) to load are selected by --iter / --epoch / --avg and
    --use-averaged-model. The "dev" and "test" cuts of the data module are
    decoded in streaming mode and the results are written under
    ``<exp_dir>/streaming/<decoding_method>``.

    Raises:
      ValueError: If no checkpoints, or fewer than required, are found for
        the given --iter and --avg.
    """
    parser = get_parser()
    AishellAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    params.res_dir = params.exp_dir / "streaming" / params.decoding_method

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    assert params.causal, params.causal
    assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
    assert (
        "," not in params.left_context_frames
    ), "left_context_frames should be one value in decoding."
    params.suffix += f"-chunk-{params.chunk_size}"
    params.suffix += f"-left-context-{params.left_context_frames}"
    params.suffix += f"-blank-penalty-{params.blank_penalty}"

    # for fast_beam_search
    if params.decoding_method == "fast_beam_search":
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    lexicon = Lexicon(params.lang_dir)
    params.blank_id = lexicon.token_table["<blk>"]
    params.vocab_size = max(lexicon.tokens) + 1

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            # Only epochs >= 1 are kept: the check has to be made per epoch
            # (not on `start`), otherwise epoch-0.pt is requested when
            # start == 0 and nothing is selected when start < 0.
            filenames = [
                f"{params.exp_dir}/epoch-{i}.pt"
                for i in range(start, params.epoch + 1)
                if i >= 1
            ]
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            # One extra checkpoint is needed: the oldest one is the
            # (excluded) start of the averaging range.
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to(device)
    model.eval()
    # decode_one_chunk() and decode_dataset() read the device from the model.
    model.device = device

    decoding_graph = None
    if params.decoding_method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    aishell = AishellAsrDataModule(args)

    dev_cuts = aishell.valid_cuts()
    test_cuts = aishell.test_cuts()

    test_sets = ["dev", "test"]
    test_cuts = [dev_cuts, test_cuts]

    for test_set, test_cut in zip(test_sets, test_cuts):
        results_dict = decode_dataset(
            cuts=test_cut,
            params=params,
            model=model,
            lexicon=lexicon,
            decoding_graph=decoding_graph,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")


if __name__ == "__main__":
    main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/subsampling.py

1350
egs/aishell/ASR/zipformer/train.py Executable file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/zipformer.py

View File

@ -1,7 +1,11 @@
# Introduction
This recipe includes some different ASR models trained with Aishell2.
This recipe contains various different ASR models trained with Aishell2.
In AISHELL-2, 1000 hours of clean read-speech data from iOS is published, which is free for academic usage. On top of AISHELL-2 corpus, an improved recipe is developed and released, containing key components for industrial applications, such as Chinese word segmentation, flexible vocabulary expansion and phone set transformation etc. Pipelines support various state-of-the-art techniques, such as time-delayed neural networks and Lattice-Free MMI objective function. In addition, we also release dev and test data from other channels (Android and Mic).
(From [AISHELL-2: Transforming Mandarin ASR Research Into Industrial Scale](https://arxiv.org/abs/1808.10583))
[./RESULTS.md](./RESULTS.md) contains the latest results.

View File

@ -1,8 +1,8 @@
## Results
### Aishell2 char-based training results (Pruned Transducer 5)
### Aishell2 char-based training results
#### 2022-07-11
#### Pruned transducer stateless 5
Using the codes from this commit https://github.com/k2-fsa/icefall/pull/465.
@ -41,9 +41,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
The decoding command is:
```bash
for method in greedy_search modified_beam_search \
fast_beam_search fast_beam_search_nbest \
fast_beam_search_nbest_oracle fast_beam_search_nbest_LG; do
for method in greedy_search modified_beam_search fast_beam_search fast_beam_search_nbest fast_beam_search_nbest_oracle fast_beam_search_nbest_LG; do
./pruned_transducer_stateless5/decode.py \
--epoch 25 \
--avg 5 \

View File

@ -7,7 +7,9 @@ set -eou pipefail
nj=30
stage=0
stop_stage=5
stop_stage=7
perturb_speed=true
# We assume dl_dir (download dir) contains the following
# directories and files. If not, you need to apply aishell2 through
@ -101,7 +103,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank for aishell2"
if [ ! -f data/fbank/.aishell2.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aishell2.py --perturb-speed True
./local/compute_fbank_aishell2.py --perturb-speed ${perturb_speed}
touch data/fbank/.aishell2.done
fi
fi
@ -157,7 +159,7 @@ fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare G"
# We assume you have install kaldilm, if not, please install
# We assume you have installed kaldilm, if not, please install
# it using: pip install kaldilm
if [ ! -f ${lang_char_dir}/3-gram.unpruned.arpa ]; then

View File

@ -1,7 +1,11 @@
# Introduction
This recipe includes some different ASR models trained with Aishell4 (including S, M and L three subsets).
This recipe contains various ASR models trained with Aishell4 (including S, M and L three subsets).
The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks.
(From [Open Speech and Language Resources](https://www.openslr.org/111/))
[./RESULTS.md](./RESULTS.md) contains the latest results.

View File

@ -7,6 +7,8 @@ set -eou pipefail
stage=-1
stop_stage=100
perturb_speed=true
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
@ -107,7 +109,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute fbank for aishell4"
if [ ! -f data/fbank/.aishell4.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aishell4.py --perturb-speed True
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
touch data/fbank/.aishell4.done
fi
fi

View File

@ -306,7 +306,7 @@ class Aishell4AsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=30000,
buffer_size=100000,
drop_last=self.args.drop_last,
)
else:

View File

@ -0,0 +1 @@
../local

View File

@ -293,7 +293,7 @@ fi
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 12: Prepare G"
# We assume you have install kaldilm, if not, please install
# We assume you have installed kaldilm, if not, please install
# it using: pip install kaldilm
mkdir -p data/lm

View File

@ -233,7 +233,7 @@ class GigaSpeechAsrDataModule:
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")

View File

@ -1164,7 +1164,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -95,6 +95,11 @@ class Decoder(nn.Module):
max_abs=1.0,
prob=0.05,
)
else:
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
# when inference with torch.jit.script and context_size == 1
self.conv = nn.Identity()
self.balancer2 = nn.Identity()
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""

View File

@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer):
yield tuples # <-- calling code will do the actual optimization here!
for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
for (stacked_params, _state, _names), batch in zip(tuples, batches):
for i, p in enumerate(batch): # batch is list of Parameter
p.copy_(stacked_params[i])
@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer):
size_update_period=4,
clipping_update_period=100,
):
defaults = dict(
lr=lr,
clipping_scale=clipping_scale,
@ -327,9 +326,7 @@ class ScaledAdam(BatchedOptimizer):
batch = True
for group, group_params_names in zip(self.param_groups, self.parameters_names):
with self.batched_params(group["params"], group_params_names) as batches:
# batches is list of pairs (stacked_param, state). stacked_param is like
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
# a stacking dim, it is not a real dim.
@ -429,7 +426,7 @@ class ScaledAdam(BatchedOptimizer):
clipping_update_period = group["clipping_update_period"]
tot_sumsq = torch.tensor(0.0, device=first_p.device)
for (p, state, param_names) in tuples:
for p, state, param_names in tuples:
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
@ -514,7 +511,7 @@ class ScaledAdam(BatchedOptimizer):
from tuples, we still pass it to save some time.
"""
all_sumsq_orig = {}
for (p, state, batch_param_names) in tuples:
for p, state, batch_param_names in tuples:
# p is a stacked batch parameters.
batch_grad = p.grad
if p.numel() == p.shape[0]: # a batch of scalars
@ -530,7 +527,6 @@ class ScaledAdam(BatchedOptimizer):
for name, sumsq_orig, rms, grad in zip(
batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
):
proportion_orig = sumsq_orig / tot_sumsq
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
@ -1106,7 +1102,7 @@ def _test_scaled_adam(hidden_dim: int):
# if epoch == 130:
# opts = diagnostics.TensorDiagnosticOptions(
# 2 ** 22
# 512
# ) # allow 4 megabytes per sub-module
# diagnostic = diagnostics.attach_diagnostics(m, opts)

View File

@ -1194,7 +1194,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -1565,7 +1565,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
args.max_duration = 100
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -47,6 +47,7 @@ We place an additional Conv1d layer right after the input embedding layer.
| `conformer-ctc` | Conformer | Use auxiliary attention head |
| `conformer-ctc2` | Reworked Conformer | Use auxiliary attention head |
| `conformer-ctc3` | Reworked Conformer | Streaming version + delay penalty |
| `zipformer-ctc` | Zipformer | Use auxiliary attention head |
| `zipformer` | Upgraded Zipformer | Use auxiliary transducer head (the latest recipe) |
# MMI

View File

@ -245,6 +245,58 @@ for m in greedy_search modified_beam_search fast_beam_search; do
done
```
##### large-scaled model, number of model parameters: 148439574, i.e., 148.4 M, trained on 8 80G-A100 GPUs
The tensorboard log can be found at
<https://tensorboard.dev/experiment/95TdNyEuQXaWK2PzFpD9yg/>
You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-large-2023-10-26-8-a100>
You can use <https://github.com/k2-fsa/sherpa> to deploy it.
| decoding method | test-clean | test-other | comment |
|----------------------|------------|------------|-----------------------|
| greedy_search | 2.00 | 4.47 | --epoch 174 --avg 172 |
| modified_beam_search | 2.00 | 4.38 | --epoch 174 --avg 172 |
| fast_beam_search | 2.00 | 4.42 | --epoch 174 --avg 172 |
The training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./zipformer/train.py \
--world-size 8 \
--num-epochs 174 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-large \
--causal 0 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--full-libri 1 \
--max-duration 2200
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in greedy_search modified_beam_search fast_beam_search; do
./zipformer/decode.py \
--epoch 174 \
--avg 172 \
--exp-dir zipformer/exp-large \
--max-duration 600 \
--causal 0 \
--decoding-method $m \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192
done
```
#### streaming
##### normal-scaled model, number of model parameters: 66110931, i.e., 66.11 M
@ -323,6 +375,55 @@ for m in greedy_search modified_beam_search fast_beam_search; do
done
```
### Zipformer CTC
#### [zipformer_ctc](./zipformer_ctc)
See <https://github.com/k2-fsa/icefall/pull/941> for more details.
You can find a pretrained model, training logs, decoding logs, and decoding
results at:
<https://huggingface.co/desh2608/icefall-asr-librispeech-zipformer-ctc>
Number of model parameters: 86083707, i.e., 86.08 M
| decoding method | test-clean | test-other | comment |
|-------------------------|------------|------------|---------------------|
| ctc-decoding | 2.50 | 5.86 | --epoch 30 --avg 9 |
| whole-lattice-rescoring | 2.44 | 5.38 | --epoch 30 --avg 9 |
| attention-rescoring | 2.35 | 5.16 | --epoch 30 --avg 9 |
| 1best | 2.01 | 4.61 | --epoch 30 --avg 9 |
The training commands are:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer_ctc/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer_ctc/exp \
--full-libri 1 \
--max-duration 1000 \
--master-port 12345
```
The tensorboard log can be found at:
<https://tensorboard.dev/experiment/IjPSJjHOQFKPYA5Z0Vf8wg>
The decoding command is:
```bash
./zipformer_ctc/decode.py \
--epoch 30 --avg 9 --use-averaged-model True \
--exp-dir zipformer_ctc/exp \
--lang-dir data/lang_bpe_500 \
--lm-dir data/lm \
--method ctc-decoding
```
### pruned_transducer_stateless7 (Fine-tune with mux)
See <https://github.com/k2-fsa/icefall/pull/1059> for more details.
@ -564,7 +665,6 @@ for m in greedy_search modified_beam_search fast_beam_search; do
done
```
#### Smaller model
We also provide a very small version (only 6.1M parameters) of this setup. The training command for the small model is:
@ -611,6 +711,7 @@ This small model achieves the following WERs on GigaSpeech test and dev sets:
You can find the tensorboard logs at <https://tensorboard.dev/experiment/tAc5iXxTQrCQxky5O5OLyw/#scalars>.
### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer)
#### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)

View File

@ -278,7 +278,7 @@ fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Prepare G"
# We assume you have install kaldilm, if not, please install
# We assume you have installed kaldilm, if not, please install
# it using: pip install kaldilm
mkdir -p data/lm

View File

@ -74,6 +74,10 @@ class Decoder(nn.Module):
groups=embedding_dim,
bias=False,
)
else:
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
# when inference with torch.jit.script and context_size == 1
self.conv = nn.Identity()
self.output_linear = nn.Linear(embedding_dim, vocab_size)
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:

View File

@ -0,0 +1,184 @@
## Introduction
This recipe is intended for streaming ASR on very low cost devices, with model parameters in the range of 1-2M. It uses a small convolutional net as the encoder. It is trained with combined transducer and CTC losses, and supports both phone and BPE lexicons. For phone lexicon, you can do transducer decoding using a method with LG, but the results were bad.
The encoder consists of 2 subsampling layers followed by a stack of Conv1d-batchnorm-activation-causal_squeeze_excite blocks, with optional skip connections. To reduce latency (at the cost of slightly higher WER), half of the blocks use causal convolution.
A few remarks & observations:
1. Phone lexicon works better than BPE for CTC decoding (with HLG) but worse for transducer decoding.
2. SpecAugment is not helpful for very small models as they tend to underfit rather than overfit. For the large model, a less aggressive SpecAugment (see asr_datamodule.py) improved the result a little.
3. Squeeze-and-excitation worked like a charm! It reduces WER quite a bit with marginal increase of parameters and MAC ops. To make it causal I changed the global average pooling layer to a moving average filter, so only historical context is used.
## Pretrained models
You can find pretrained models, training logs, decoding logs, and decoding results at:
<https://huggingface.co/wangtiance/tiny_transducer_ctc/tree/main>
## Results on full libri
I tried 3 different sizes of the encoder. The parameters are around 1M, 2M and 4M, respectively. For CTC decoding, whole-lattice-rescoring frequently causes OOM error so the result is not shown.
### Small encoder
The small encoder uses 10 layers of 1D convolution blocks with 256 channels, without skip connections. The encoder, decoder and joiner dims are all 256. Algorithmic latency is 280ms. Multiply-add ops for the encoder are 22.0Mops. It is more applicable to ASR products with a limited vocabulary (like a fixed set of phrases or short sentences).
#### CTC decoding with phone lexicon
Total parameters: 1073392
Parameters for CTC decoding: 865816
| | test-clean | test-other | comment |
|-----------------|------------|------------|----------------------|
| 1best | 9.68 | 24.9 | --epoch 30 --avg 2 |
| nbest-rescoring | 8.2 | 22.7 | --epoch 30 --avg 2 |
The training commands are:
```bash
./tiny_transducer_ctc/train.py \
--num-epochs 30 \
--full-libri 1 \
--max-duration 600 \
--exp-dir tiny_transducer_ctc/exp_small_phone \
--ctc-loss-scale 0.7 \
--enable-spec-aug 0 \
--lang-dir lang_phone \
--encoder-dim 256 \
--decoder-dim 256 \
--joiner-dim 256 \
--conv-layers 10 \
--channels 256 \
--skip-add 0
```
#### Transducer decoding with BPE 500 lexicon
Total parameters: 1623264
Parameters for transducer decoding: 1237764
| | test-clean | test-other | comment |
|--------------------|------------|------------|----------------------|
| greedy_search | 14.47 | 32.03 | --epoch 30 --avg 1 |
| fast_beam_search | 13.38 | 29.61 | --epoch 30 --avg 1 |
|modified_beam_search| 13.02 | 29.32 | --epoch 30 --avg 1 |
The training commands are:
```bash
./tiny_transducer_ctc/train.py \
--num-epochs 30 \
--full-libri 1 \
--max-duration 600 \
--exp-dir tiny_transducer_ctc/exp_small_bpe \
--ctc-loss-scale 0.2 \
--enable-spec-aug 0 \
--lang-dir lang_bpe_500 \
--encoder-dim 256 \
--decoder-dim 256 \
--joiner-dim 256 \
--conv-layers 10 \
--channels 256 \
--skip-add 0
```
### Middle encoder
The middle encoder uses 18 layers of 1D convolution blocks with 300 channels, with skip connections. The encoder, decoder and joiner dims are all 256. Algorithmic latency is 440ms. Multiply-add ops for the encoder are 50.1Mops. Note that the nbest-rescoring result is better than that of the tdnn_lstm_ctc recipe with whole-lattice-rescoring.
#### CTC decoding with phone lexicon
Total parameters: 2186242
Parameters for CTC decoding: 1978666
| | test-clean | test-other | comment |
|-----------------|------------|------------|----------------------|
| 1best | 7.48 | 18.94 | --epoch 30 --avg 1 |
| nbest-rescoring | 6.31 | 16.89 | --epoch 30 --avg 1 |
The training commands are:
```bash
./tiny_transducer_ctc/train.py \
--num-epochs 30 \
--full-libri 1 \
--max-duration 600 \
--exp-dir tiny_transducer_ctc/exp_middle_phone \
--ctc-loss-scale 0.7 \
--enable-spec-aug 0 \
--lang-dir lang_phone \
--encoder-dim 256 \
--decoder-dim 256 \
--joiner-dim 256 \
--conv-layers 18 \
--channels 300 \
--skip-add 1
```
#### Transducer decoding with BPE 500 lexicon
Total parameters: 2735794
Parameters for transducer decoding: 2350294
| | test-clean | test-other | comment |
|--------------------|------------|------------|----------------------|
| greedy_search | 10.26 | 25.13 | --epoch 30 --avg 2 |
| fast_beam_search | 9.69 | 23.58 | --epoch 30 --avg 2 |
|modified_beam_search| 9.43 | 23.53 | --epoch 30 --avg 2 |
The training commands are:
```bash
./tiny_transducer_ctc/train.py \
--num-epochs 30 \
--full-libri 1 \
--max-duration 600 \
--exp-dir tiny_transducer_ctc/exp_middle_bpe \
--ctc-loss-scale 0.2 \
--enable-spec-aug 0 \
--lang-dir lang_bpe_500 \
--encoder-dim 256 \
--decoder-dim 256 \
--joiner-dim 256 \
--conv-layers 18 \
--channels 300 \
--skip-add 1
```
### Large encoder
The large encoder uses 18 layers of 1D convolution blocks with 400 channels, with skip connections. The encoder, decoder and joiner dims are all 400. Algorithmic latency is 440ms. Multiply-add ops for the encoder are 88.8Mops. It would be interesting to see how large the gap is compared with simply scaling down more complicated models like Zipformer or Emformer.
#### Transducer decoding with BPE 500 lexicon
Total parameters: 4821330
Parameters for transducer decoding: 4219830
| | test-clean | test-other | comment |
|--------------------|------------|------------|----------------------|
| greedy_search | 8.29 | 21.11 | --epoch 30 --avg 1 |
| fast_beam_search | 7.91 | 20.1 | --epoch 30 --avg 1 |
|modified_beam_search| 7.74 | 19.89 | --epoch 30 --avg 1 |
The training commands are:
```bash
./tiny_transducer_ctc/train.py \
--num-epochs 30 \
--full-libri 1 \
--max-duration 600 \
--exp-dir tiny_transducer_ctc/exp_large_bpe \
--ctc-loss-scale 0.2 \
--enable-spec-aug 1 \
--lang-dir lang_bpe_500 \
--encoder-dim 400 \
--decoder-dim 400 \
--joiner-dim 400 \
--conv-layers 18 \
--channels 400 \
--skip-add 1
```

View File

@ -0,0 +1,454 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class LibriSpeechAsrDataModule:
    """
    DataModule for k2 ASR experiments.

    It assumes there is always one train and valid dataloader,
    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
    and test-other).

    It contains all the common data pipeline modules used in ASR
    experiments, e.g.:
        - dynamic batch size,
        - bucketing samplers,
        - cut concatenation,
        - augmentation,
        - on-the-fly feature extraction

    This class should be derived for specific corpora used in ASR tasks.
    """

    def __init__(self, args: argparse.Namespace):
        """
        Args:
          args:
            The parsed command line options; it is expected to contain the
            options registered by :meth:`add_arguments`.
        """
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        """Register the data related command line options on ``parser``."""
        group = parser.add_argument_group(
            title="ASR data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--full-libri",
            type=str2bool,
            default=True,
            help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
        )
        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/fbank"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            # A duration in seconds; float matches the default below and
            # still accepts integer values such as `--max-duration 600`.
            type=float,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )
        group.add_argument(
            "--concatenate-cuts",
            type=str2bool,
            default=False,
            help="When enabled, utterances (cuts) will be concatenated "
            "to minimize the amount of padding.",
        )
        group.add_argument(
            "--duration-factor",
            type=float,
            default=1.0,
            help="Determines the maximum duration of a concatenated cut "
            "relative to the duration of the longest cut in a batch.",
        )
        group.add_argument(
            "--gap",
            type=float,
            default=1.0,
            help="The amount of padding (in seconds) inserted between "
            "concatenated cuts. This padding is filled with noise when "
            "noise augmentation is used.",
        )
        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop last batch. Used by sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['supervisions']['cut'] with the cuts that "
            "were used to construct it.",
        )
        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )
        group.add_argument(
            "--enable-spec-aug",
            type=str2bool,
            default=False,
            help="When enabled, use SpecAugment for training dataset.",
        )
        group.add_argument(
            "--spec-aug-time-warp-factor",
            type=int,
            default=0,
            help="Used only when --enable-spec-aug is True. "
            "It specifies the factor for time warping in SpecAugment. "
            "Larger values mean more warping. "
            "A value less than 1 means to disable time warp.",
        )
        group.add_argument(
            "--enable-musan",
            type=str2bool,
            default=True,
            help="When enabled, select noise from MUSAN and mix it "
            "with training dataset. ",
        )
        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            help="AudioSamples or PrecomputedFeatures",
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """Create the training dataloader.

        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler. If not None, it is
            loaded into the sampler before the dataloader is created.
        Returns:
          Return a DataLoader yielding batches produced by the sampler
          (``batch_size=None``).
        """
        transforms = []
        if self.args.enable_musan:
            logging.info("Enable MUSAN")
            logging.info("About to get Musan cuts")
            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
            transforms.append(
                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
            )
        else:
            logging.info("Disable MUSAN")

        if self.args.concatenate_cuts:
            logging.info(
                f"Using cut concatenation with duration factor "
                f"{self.args.duration_factor} and gap {self.args.gap}."
            )
            # Cut concatenation should be the first transform in the list,
            # so that if we e.g. mix noise in, it will fill the gaps between
            # different utterances.
            transforms = [
                CutConcatenate(
                    duration_factor=self.args.duration_factor, gap=self.args.gap
                )
            ] + transforms

        input_transforms = []
        if self.args.enable_spec_aug:
            logging.info("Enable SpecAugment")
            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
            # Set the value of num_frame_masks according to Lhotse's version.
            # In different Lhotse's versions, the default of num_frame_masks is
            # different.
            num_frame_masks = 10
            num_frame_masks_parameter = inspect.signature(
                SpecAugment.__init__
            ).parameters["num_frame_masks"]
            if num_frame_masks_parameter.default == 1:
                num_frame_masks = 2
            logging.info(f"Num frame mask: {num_frame_masks}")
            input_transforms.append(
                SpecAugment(
                    time_warp_factor=self.args.spec_aug_time_warp_factor,
                    num_feature_masks=2,
                    features_mask_size=5,
                    # Use the value selected (and logged) above instead of a
                    # hard-coded constant, so that the log is truthful.
                    num_frame_masks=num_frame_masks,
                    frames_mask_size=5,
                    p=0.5,
                )
            )
        else:
            logging.info("Disable SpecAugment")

        logging.info("About to create train dataset")
        # NOTE(review): eval() turns the --input-strategy string into a class
        # (AudioSamples or PrecomputedFeatures). It executes arbitrary
        # expressions, so the option must only come from a trusted command
        # line.
        train = K2SpeechRecognitionDataset(
            input_strategy=eval(self.args.input_strategy)(),
            cut_transforms=transforms,
            input_transforms=input_transforms,
            return_cuts=self.args.return_cuts,
        )

        if self.args.on_the_fly_feats:
            # NOTE: the PerturbSpeed transform should be added only if we
            # remove it from data prep stage.
            # Add on-the-fly speed perturbation; since originally it would
            # have increased epoch size by 3, we will apply prob 2/3 and use
            # 3x more epochs.
            # Speed perturbation probably should come first before
            # concatenation, but in principle the transforms order doesn't have
            # to be strict (e.g. could be randomized)
            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms   # noqa
            # Drop feats to be on the safe side.
            train = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                input_transforms=input_transforms,
                return_cuts=self.args.return_cuts,
            )

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        """Create the validation dataloader.

        No augmentation is applied; cuts are only concatenated when
        ``--concatenate-cuts`` is enabled. The sampler never shuffles.

        Args:
          cuts_valid:
            CutSet for validation.
        """
        transforms = []
        if self.args.concatenate_cuts:
            transforms = [
                CutConcatenate(
                    duration_factor=self.args.duration_factor, gap=self.args.gap
                )
            ] + transforms

        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                return_cuts=self.args.return_cuts,
            )
        else:
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                return_cuts=self.args.return_cuts,
            )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        """Create a test dataloader for ``cuts`` (no shuffling, no transforms).

        Args:
          cuts:
            CutSet of the test set to decode.
        """
        logging.debug("About to create test dataset")
        # NOTE(review): see train_dataloaders() about eval() of the
        # --input-strategy option.
        test = K2SpeechRecognitionDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
            if self.args.on_the_fly_feats
            else eval(self.args.input_strategy)(),
            return_cuts=self.args.return_cuts,
        )
        sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.debug("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def train_clean_100_cuts(self) -> CutSet:
        """Lazily load the train-clean-100 cuts from the manifest dir."""
        logging.info("About to get train-clean-100 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
        )

    @lru_cache()
    def train_clean_360_cuts(self) -> CutSet:
        """Lazily load the train-clean-360 cuts from the manifest dir."""
        logging.info("About to get train-clean-360 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
        )

    @lru_cache()
    def train_other_500_cuts(self) -> CutSet:
        """Lazily load the train-other-500 cuts from the manifest dir."""
        logging.info("About to get train-other-500 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
        )

    @lru_cache()
    def train_all_shuf_cuts(self) -> CutSet:
        """Lazily load the shuffled union of all three training subsets."""
        # Implicit string concatenation (instead of a backslash continuation
        # inside the literal) keeps source indentation out of the message.
        logging.info(
            "About to get the shuffled train-clean-100, "
            "train-clean-360 and train-other-500 cuts"
        )
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
        )

    @lru_cache()
    def dev_clean_cuts(self) -> CutSet:
        """Lazily load the dev-clean cuts from the manifest dir."""
        logging.info("About to get dev-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
        )

    @lru_cache()
    def dev_other_cuts(self) -> CutSet:
        """Lazily load the dev-other cuts from the manifest dir."""
        logging.info("About to get dev-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
        )

    @lru_cache()
    def test_clean_cuts(self) -> CutSet:
        """Lazily load the test-clean cuts from the manifest dir."""
        logging.info("About to get test-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
        )

    @lru_cache()
    def test_other_cuts(self) -> CutSet:
        """Lazily load the test-other cuts from the manifest dir."""
        logging.info("About to get test-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
        )

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1,771 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Liyong Guo,
# Quandong Wang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
import pprint
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
    """Build the command line parser of the CTC decoding script.

    It registers the checkpoint selection options (--epoch, --iter, --avg,
    --use-averaged-model), the decoding options (--decoding-method,
    --num-paths, --nbest-scale, --hlg-scale), the resource locations
    (--exp-dir, --lang-dir, --lm-dir) and finally the model options added by
    ``add_model_arguments``.

    Returns:
      Return the configured :class:`argparse.ArgumentParser`.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="tiny_transducer_ctc/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=1,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--lang-dir",
        type=Path,
        default="data/lang_phone",
        help="The lang dir containing word table and LG graph",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="1best",
        help="""Decoding method.
        Supported values are:
        - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece
          model, i.e., lang_dir/bpe.model, to convert word pieces to words.
          It needs neither a lexicon nor an n-gram LM.
        - (2) 1best. Extract the best path from the decoding lattice as the
          decoding result.
        - (3) nbest. Extract n paths from the decoding lattice; the path
          with the highest score is the decoding result.
        - (4) nbest-rescoring. Extract n paths from the decoding lattice,
          rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
          the highest score is the decoding result.
        - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
          n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
          is the decoding result.
        - (6) nbest-oracle. Its WER is the lower bound of any n-best
          rescoring method can achieve. Useful for debugging n-best
          rescoring method.
        """,
    )

    parser.add_argument(
        "--num-paths",
        type=int,
        default=100,
        help="""Number of paths for n-best based decoding method.
        Used only when --decoding-method is one of the following values:
        nbest, nbest-rescoring, and nbest-oracle
        """,
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=1.0,
        help="""The scale to be applied to `lattice.scores`.
        It's needed if you use any kinds of n-best based rescoring.
        Used only when --decoding-method is one of the following values:
        nbest, nbest-rescoring, and nbest-oracle
        A smaller value results in more unique paths.
        """,
    )

    parser.add_argument(
        "--hlg-scale",
        type=float,
        default=0.7,
        help="""The scale to be applied to `hlg.scores`.
        """,
    )

    parser.add_argument(
        "--lm-dir",
        type=str,
        default="data/lm",
        help="""The n-gram LM dir.
        It should contain either G_4_gram.pt or G_4_gram.fst.txt
        """,
    )

    # Model related options are shared with the training script.
    add_model_arguments(parser)

    return parser
def get_decoding_params() -> AttributeDict:
    """Return the default decoding-time parameters.

    The returned dict holds the lattice search settings (search_beam,
    output_beam, min/max_active_states), the use_double_scores flag and
    context_size.
    """
    defaults = dict(
        search_beam=20,
        output_beam=8,
        min_active_states=30,
        max_active_states=10000,
        use_double_scores=True,
        context_size=2,
    )
    return AttributeDict(defaults)
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],
    batch: dict,
    word_table: k2.SymbolTable,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode a single batch with the CTC head of the model.

    The result is a dict with the following layout:

        - key: Identifies the decoding setting. It is `ctc-decoding` for
               CTC decoding, `no_rescore` for 1best decoding, a string
               starting with `no_rescore-nbest-scale-` for nbest decoding,
               a string starting with `oracle_` for nbest-oracle, and, for
               the LM rescoring methods, whatever keys the rescoring
               function returns (one per LM scale).
        - value: The decoding result. `len(value)` equals the batch size and
                 `value[i]` is the list of words decoded for the i-th
                 utterance of the batch.

    Args:
      params:
        The decoding configuration. `params.decoding_method` selects one of
        "ctc-decoding", "nbest-oracle", "1best", "nbest", "nbest-rescoring"
        and "whole-lattice-rescoring".
      model:
        The neural model. Its `encoder` and `ctc_output` are used.
      HLG:
        The decoding graph. Used only when params.decoding_method is NOT
        ctc-decoding. Must be None if `H` is given.
      H:
        The ctc topo. Used only when params.decoding_method is ctc-decoding.
      bpe_model:
        The BPE model. Used only when params.decoding_method is ctc-decoding.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      word_table:
        The word symbol table.
      G:
        An LM. It is not None when params.decoding_method is
        "nbest-rescoring" or "whole-lattice-rescoring". In general, the G in
        HLG is a 3-gram LM, while this G is a 4-gram LM.
    Returns:
      Return the decoding result as described above. Note: if LM rescoring
      decodes to nothing, None is returned.
    """
    # The graph in use decides on which device the computation runs.
    device = HLG.device if HLG is not None else H.device

    feature = batch["inputs"]
    assert feature.ndim == 3
    # at entry, feature is (N, T, C)
    feature = feature.to(device)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    encoder_out, _ = model.encoder(feature, feature_lens)
    # nnet_output is (N, T, C)
    nnet_output = model.ctc_output(encoder_out)

    def _subsampled(field: str) -> torch.Tensor:
        # Convert a frame count of the input into the output frame rate.
        return torch.div(
            supervisions[field],
            params.subsampling_factor,
            rounding_mode="trunc",
        )

    def _to_words(path) -> List[List[str]]:
        # Map the word IDs of each decoded path to word symbols.
        return [[word_table[i] for i in ids] for ids in get_texts(path)]

    segment_columns = (
        supervisions["sequence_idx"],
        _subsampled("start_frame"),
        _subsampled("num_frames"),
    )
    supervision_segments = torch.stack(segment_columns, 1).to(torch.int32)

    # Exactly one of H and HLG is expected to be provided.
    if H is not None:
        assert HLG is None
        assert bpe_model is not None
        decoding_graph = H
    else:
        assert HLG is not None
        decoding_graph = HLG

    lattice = get_lattice(
        nnet_output=nnet_output,
        decoding_graph=decoding_graph,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
        subsampling_factor=params.subsampling_factor,
    )

    method = params.decoding_method

    if method == "ctc-decoding":
        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        # Note: `best_path.aux_labels` contains token IDs, not word IDs
        # since we are using H, not HLG here.
        #
        # token_ids is a list-of-list of IDs
        token_ids = get_texts(best_path)

        # bpe_model.decode() gives a list of str, e.g., ['xxx yyy zzz', ...];
        # splitting yields a list of list of str,
        # e.g., [['xxx', 'yyy', 'zzz'], ... ]
        hyps = [s.split() for s in bpe_model.decode(token_ids)]
        return {"ctc-decoding": hyps}

    if method == "nbest-oracle":
        # Note: You can also pass rescored lattices to it.
        # We choose the HLG decoded lattice for speed reasons
        # as HLG decoding is faster and the oracle WER
        # is only slightly worse than that of rescored lattices.
        best_path = nbest_oracle(
            lattice=lattice,
            num_paths=params.num_paths,
            ref_texts=supervisions["text"],
            word_table=word_table,
            nbest_scale=params.nbest_scale,
            oov="<UNK>",
        )
        hyps = _to_words(best_path)
        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
        return {key: hyps}

    if method == "1best":
        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        key = "no_rescore"
        return {key: _to_words(best_path)}

    if method == "nbest":
        best_path = nbest_decoding(
            lattice=lattice,
            num_paths=params.num_paths,
            use_double_scores=params.use_double_scores,
            nbest_scale=params.nbest_scale,
        )
        key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
        return {key: _to_words(best_path)}

    assert method in [
        "nbest-rescoring",
        "whole-lattice-rescoring",
    ]

    # A wider sweep, e.g. 0.1 ... 2.0 in steps of 0.1, can be used here when
    # tuning the LM scale.
    lm_scale_list = [0.6, 0.7, 0.8, 0.9]

    if method == "nbest-rescoring":
        best_path_dict = rescore_with_n_best_list(
            lattice=lattice,
            G=G,
            num_paths=params.num_paths,
            lm_scale_list=lm_scale_list,
            nbest_scale=params.nbest_scale,
        )
    elif method == "whole-lattice-rescoring":
        best_path_dict = rescore_with_whole_lattice(
            lattice=lattice,
            G_with_epsilon_loops=G,
            lm_scale_list=lm_scale_list,
        )
    else:
        assert False, f"Unsupported decoding method: {params.decoding_method}"

    if best_path_dict is None:
        # Rescoring produced nothing for this batch.
        return None

    return {
        lm_scale_str: _to_words(best_path)
        for lm_scale_str, best_path in best_path_dict.items()
    }
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],
    word_table: k2.SymbolTable,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      HLG:
        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
      H:
        The ctc topo. Used only when params.decoding_method is ctc-decoding.
      bpe_model:
        The BPE model. Used only when params.decoding_method is ctc-decoding.
      word_table:
        It is the word symbol table.
      G:
        An LM. It is not None when params.decoding_method is "nbest-rescoring"
        or "whole-lattice-rescoring". In general, the G in HLG
        is a 3-gram LM, while this G is a 4-gram LM.
    Returns:
      Return a dict whose keys are the keys produced by
      :func:`decode_one_batch` (one per decoding setting).
      Its value is a list of tuples. Each tuple contains three elements:
      the cut ID, the reference transcript (a list of words) and the
      predicted result (a list of words).
    Raises:
      RuntimeError:
        If the very first batch decodes to nothing, since the decoding
        settings (i.e. the keys of the result) are not known at that point.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        # The dataloader has no known length.
        num_batches = "?"

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            HLG=HLG,
            H=H,
            bpe_model=bpe_model,
            batch=batch,
            word_table=word_table,
            G=G,
        )

        if hyps_dict is not None:
            for name, hyps in hyps_dict.items():
                assert len(hyps) == len(texts)
                results[name].extend(
                    (cut_id, ref_text.split(), hyp_words)
                    for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts)
                )
        else:
            # decode_one_batch() returns None when LM rescoring decodes to
            # nothing. Score such utterances with an empty hypothesis for
            # every setting seen so far instead of crashing.
            if not results:
                raise RuntimeError(
                    "It should not decode to empty in the first batch!"
                )
            this_batch = [
                (cut_id, ref_text.split(), [])
                for cut_id, ref_text in zip(cut_ids, texts)
            ]
            for name in results:
                results[name].extend(this_batch)

        num_cuts += len(texts)

        if batch_idx % 100 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
    return results
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """Write transcripts, error statistics and a WER summary for a test set.

    For every decoding setting (key of ``results_dict``) it writes, into
    ``params.res_dir``, a `recogs-*` file with the transcripts and an
    `errs-*` file with the detailed error statistics. A single
    `wer-summary-*` file then lists the WER of all settings, sorted from
    best to worst, and the same summary is logged.

    Args:
      params:
        Provides `res_dir` (the output directory) and `suffix` (appended to
        every file name).
      test_set_name:
        Name of the decoded test set; used in file names and log messages.
      results_dict:
        Maps a decoding setting to its list of
        (cut ID, reference words, hypothesis words) tuples.
    """
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        with open(errs_filename, "w") as f:
            wer = write_error_stats(f, f"{test_set_name}-{key}", results)
            test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    # The summary covers every setting, so its name must not depend on the
    # `wer`/`key` values left over from the last loop iteration above.
    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    # Only the first (i.e. best) entry is annotated.
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)
def _load_model_weights(params, model, device) -> None:
    """Load the weights chosen by --iter/--epoch/--avg into ``model``."""
    if params.use_averaged_model:
        if params.iter > 0:
            needed = params.avg + 1
            ckpts = find_checkpoints(params.exp_dir, iteration=-params.iter)[:needed]
            if not ckpts:
                raise ValueError(
                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                )
            if len(ckpts) < needed:
                raise ValueError(
                    f"Not enough checkpoints ({len(ckpts)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start, filename_end = ckpts[-1], ckpts[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
        model.to(device)
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=device,
            )
        )
        return

    if params.iter <= 0 and params.avg == 1:
        # Single epoch checkpoint: no averaging needed.
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        return

    if params.iter > 0:
        ckpts = find_checkpoints(params.exp_dir, iteration=-params.iter)[: params.avg]
        if not ckpts:
            raise ValueError(
                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
            )
        if len(ckpts) < params.avg:
            raise ValueError(
                f"Not enough checkpoints ({len(ckpts)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
    else:
        first_epoch = params.epoch - params.avg + 1
        # Epochs count from 1, so drop any non-positive index.
        ckpts = [
            f"{params.exp_dir}/epoch-{i}.pt"
            for i in range(first_epoch, params.epoch + 1)
            if i >= 1
        ]
    logging.info(f"averaging {ckpts}")
    model.to(device)
    model.load_state_dict(average_checkpoints(ckpts, device=device))


@torch.no_grad()
def main():
    """Decode the LibriSpeech test sets with the CTC head and save the results.

    Parses the command line, builds the decoding graphs required by
    ``--decoding-method`` (H for ctc-decoding, HLG otherwise, plus the 4-gram
    G for the rescoring methods), loads the model weights, then decodes
    test-clean and test-other and writes the results via `save_results`.
    """
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    for dir_attr in ("exp_dir", "lang_dir", "lm_dir"):
        setattr(args, dir_attr, Path(getattr(args, dir_attr)))

    params = get_params()
    # add decoding params
    params.update(get_decoding_params())
    params.update(vars(args))

    supported_methods = (
        "ctc-decoding",
        "1best",
        "nbest",
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "nbest-oracle",
    )
    assert params.decoding_method in supported_methods
    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:
        suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        suffix = f"epoch-{params.epoch}-avg-{params.avg}"
    suffix += f"-hlg-scale-{params.hlg_scale}"
    if params.use_averaged_model:
        suffix += "-uam"
    params.suffix = suffix

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    logging.info(f"Device: {device}")
    logging.info(pprint.pformat(params, indent=2))

    lexicon = Lexicon(params.lang_dir)
    max_token_id = max(lexicon.tokens)
    params.vocab_size = max_token_id + 1  # +1 for the blank
    # <blk> and <unk> are defined in local/train_bpe_model.py
    params.blank_id = 0

    if params.decoding_method == "ctc-decoding":
        assert "lang_bpe" in str(
            params.lang_dir
        ), "ctc-decoding only supports BPE lexicons."
        HLG = None
        H = k2.ctc_topo(
            max_token=max_token_id,
            modified=False,
            device=device,
        )
        bpe_model = spm.SentencePieceProcessor()
        bpe_model.load(str(params.lang_dir / "bpe.model"))
    else:
        H = None
        bpe_model = None
        hlg_dict = torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
        HLG = k2.Fsa.from_dict(hlg_dict)
        assert HLG.requires_grad is False

        HLG.scores *= params.hlg_scale
        if not hasattr(HLG, "lm_scores"):
            HLG.lm_scores = HLG.scores.clone()

    if params.decoding_method in ("nbest-rescoring", "whole-lattice-rescoring"):
        compiled_g = params.lm_dir / "G_4_gram.pt"
        if compiled_g.is_file():
            logging.info("Loading pre-compiled G_4_gram.pt")
            d = torch.load(compiled_g, map_location=device)
            G = k2.Fsa.from_dict(d)
        else:
            logging.info("Loading G_4_gram.fst.txt")
            logging.warning("It may take 8 minutes.")
            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
                first_word_disambig_id = lexicon.word_table["#0"]
                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
            # G.aux_labels is not needed in later computations, so
            # remove it here.
            del G.aux_labels
            # CAUTION: The following line is crucial.
            # Arcs entering the back-off state have label equal to #0.
            # We have to change it to 0 here.
            G.labels[G.labels >= first_word_disambig_id] = 0
            # See https://github.com/k2-fsa/k2/issues/874
            # for why we need to set G.properties to None
            G.__dict__["_properties"] = None
            G = k2.arc_sort(k2.Fsa.from_fsas([G]).to(device))
            # Save a dummy value so that it can be loaded in C++.
            # See https://github.com/pytorch/pytorch/issues/67902
            # for why we need to do this.
            G.dummy = 1
            torch.save(G.as_dict(), compiled_g)

        if params.decoding_method == "whole-lattice-rescoring":
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)
            G = k2.arc_sort(G)
            G = G.to(device)

        # G.lm_scores is used to replace HLG.lm_scores during
        # LM rescoring.
        G.lm_scores = G.scores.clone()
    else:
        G = None

    logging.info("About to create model")
    model = get_transducer_model(params)
    _load_model_weights(params, model, device)
    model.to(device)
    model.eval()

    num_param = sum(p.numel() for p in model.parameters())
    enc_param = sum(p.numel() for p in model.encoder.parameters())
    ctc_param = sum(p.numel() for p in model.ctc_output.parameters())
    logging.info(f"Number of model parameters: {num_param}")
    logging.info(f"Parameters for CTC decoding: {enc_param + ctc_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    librispeech = LibriSpeechAsrDataModule(args)

    test_clean_cuts = librispeech.test_clean_cuts()
    test_other_cuts = librispeech.test_other_cuts()
    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
    test_other_dl = librispeech.test_dataloaders(test_other_cuts)

    named_loaders = (("test-clean", test_clean_dl), ("test-other", test_other_dl))
    for test_set, loader in named_loaders:
        results_dict = decode_dataset(
            dl=loader,
            params=params,
            model=model,
            HLG=HLG,
            H=H,
            bpe_model=bpe_model,
            word_table=lexicon.word_table,
            G=G,
        )
        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,717 @@
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import pprint
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
# Natural logarithm of 1e-10.
LOG_EPS = math.log(1e-10)
def get_parser():
    """Build the command-line parser for transducer decoding.

    Registers the checkpoint-selection, decoding-method and search options
    below, then lets `add_model_arguments` append the model options.

    Returns:
      The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    add = parser.add_argument

    # ---- checkpoint selection ----
    add(
        "--exp-dir",
        type=str,
        default="tiny_transducer_ctc/exp",
        help="The experiment dir",
    )
    add(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
    )
    add(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
    )
    add(
        "--avg",
        type=int,
        default=1,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )
    add(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    # ---- language resources and decoding method ----
    add(
        "--lang-dir",
        type=Path,
        default="data/lang_bpe_500",
        help="The lang dir containing word table and LG graph",
    )
    add(
        "--decoding-method",
        type=str,
        default="fast_beam_search",
        help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
""",
    )

    # ---- search hyper-parameters ----
    add(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
    )
    add(
        "--beam",
        type=float,
        default=20.0,
        help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
    )
    add(
        "--ngram-lm-scale",
        type=float,
        default=0.1,
        help="""
Used only when --decoding_method is fast_beam_search_LG or
fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
    )
    add(
        "--max-contexts",
        type=int,
        default=8,
        help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )
    add(
        "--max-states",
        type=int,
        default=64,
        help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )
    add(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
    )
    add(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
    )
    add(
        "--num-paths",
        type=int,
        default=100,
        help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )
    add(
        "--nbest-scale",
        type=float,
        default=0.5,
        help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

    add_model_arguments(parser)
    return parser
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    batch: dict,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode a single batch and return the hypotheses under one setting key.

    The returned dict has exactly one entry:

      - key: names the decoding setting. It is "greedy_search" for greedy
        search, a string starting with "beam_" that encodes the search
        options for the fast_beam_search family, and
        "beam_size_<beam_size>" otherwise.
      - value: a list with one entry per utterance in the batch; each entry
        is the list of decoded words for that utterance.

    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      word_table:
        The word symbol table. Used to map ids to words for the LG methods.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
        only when --decoding_method is fast_beam_search, fast_beam_search_LG,
        fast_beam_search_nbest, fast_beam_search_nbest_oracle, and
        fast_beam_search_nbest_LG.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    """
    method = params.decoding_method
    device = next(model.parameters()).device

    feature = batch["inputs"]
    assert feature.ndim == 3
    # at entry, feature is (N, T, C)
    feature = feature.to(device)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)
    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)

    def bpe_to_words(token_ids):
        # Detokenize with the BPE model, then split each text into words.
        return [text.split() for text in sp.decode(token_ids)]

    def ids_to_words(token_ids):
        # LG methods output word ids; look them up in the word table.
        return [[word_table[i] for i in ids] for ids in token_ids]

    def fast_search_kwargs():
        # Arguments shared by every fast_beam_search variant.
        return dict(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )

    def nbest_kwargs():
        # Extra arguments for the n-best variants.
        return dict(
            fast_search_kwargs(),
            num_paths=params.num_paths,
            nbest_scale=params.nbest_scale,
        )

    if method in ("fast_beam_search", "fast_beam_search_LG"):
        hyp_tokens = fast_beam_search_one_best(**fast_search_kwargs())
        if method == "fast_beam_search":
            hyps = bpe_to_words(hyp_tokens)
        else:
            hyps = ids_to_words(hyp_tokens)
    elif method == "fast_beam_search_nbest_LG":
        hyps = ids_to_words(fast_beam_search_nbest_LG(**nbest_kwargs()))
    elif method == "fast_beam_search_nbest":
        hyps = bpe_to_words(fast_beam_search_nbest(**nbest_kwargs()))
    elif method == "fast_beam_search_nbest_oracle":
        hyp_tokens = fast_beam_search_nbest_oracle(
            ref_texts=sp.encode(supervisions["text"]),
            **nbest_kwargs(),
        )
        hyps = bpe_to_words(hyp_tokens)
    elif method == "greedy_search" and params.max_sym_per_frame == 1:
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        hyps = bpe_to_words(hyp_tokens)
    elif method == "modified_beam_search":
        hyp_tokens = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
        )
        hyps = bpe_to_words(hyp_tokens)
    else:
        # Remaining methods are decoded one utterance at a time.
        hyps = []
        for i in range(encoder_out.size(0)):
            # fmt: off
            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
            # fmt: on
            if method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                )
            elif method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                )
            else:
                raise ValueError(
                    f"Unsupported decoding method: {params.decoding_method}"
                )
            hyps.append(sp.decode(hyp).split())

    if method == "greedy_search":
        return {"greedy_search": hyps}
    if "fast_beam_search" not in method:
        return {f"beam_size_{params.beam_size}": hyps}

    key = (
        f"beam_{params.beam}_"
        f"max_contexts_{params.max_contexts}_"
        f"max_states_{params.max_states}"
    )
    if "nbest" in method:
        key += f"_num_paths_{params.num_paths}_nbest_scale_{params.nbest_scale}"
        # NOTE(review): the LG suffix is nested under the nbest check here;
        # confirm this nesting against the original file's indentation.
        if "LG" in method:
            key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
    return {key: hyps}
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode every batch of a dataloader and collect the results.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      word_table:
        The word symbol table.
      decoding_graph:
        The decoding graph; passed through to :func:`decode_one_batch`.
    Returns:
      Return a dict keyed by the setting name produced by
      :func:`decode_one_batch` (e.g. "greedy_search"). Each value is a list
      of ``(cut_id, ref_words, hyp_words)`` tuples, one per decoded cut.
    """
    try:
        num_batches = len(dl)
    except TypeError:
        # The dataloader does not expose a length.
        num_batches = "?"

    log_interval = 50 if params.decoding_method == "greedy_search" else 20

    num_cuts = 0
    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        supervisions = batch["supervisions"]
        texts = supervisions["text"]
        cut_ids = [cut.id for cut in supervisions["cut"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            sp=sp,
            decoding_graph=decoding_graph,
            word_table=word_table,
            batch=batch,
        )

        for name, hyps in hyps_dict.items():
            assert len(hyps) == len(texts)
            results[name].extend(
                [
                    (cut_id, ref_text.split(), hyp_words)
                    for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts)
                ]
            )

        num_cuts += len(texts)

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"
            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
    return results
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """Write transcripts, error statistics and a WER summary for one test set.

    For every decoding setting in ``results_dict`` this writes a
    ``recogs-*.txt`` transcript file and an ``errs-*.txt`` error-statistics
    file into ``params.res_dir``. It then writes one ``wer-summary-*.txt``
    file listing all settings sorted by WER and logs the same table.

    Args:
      params:
        Must provide ``res_dir`` (a path supporting ``/``) and ``suffix``.
      test_set_name:
        Name of the test set; used in file names and log messages.
      results_dict:
        Maps a decoding-setting name to a list of
        ``(cut_id, ref_words, hyp_words)`` tuples.
    """
    if not results_dict:
        # Nothing was decoded; previously this fell through to a NameError
        # on the loop variables used for the summary file name.
        logging.warning(f"No decoding results to save for {test_set_name}")
        return

    test_set_wers = dict()
    last_key = None
    for key, results in results_dict.items():
        last_key = key
        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
        test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])

    # The summary name used to start with the numeric WER of the last setting
    # (f"{wer}-..."), which made the file name unpredictable. Use a fixed
    # "wer-summary" prefix; the last setting name is kept in the file name.
    errs_info = (
        params.res_dir
        / f"wer-summary-{test_set_name}-{last_key}-{params.suffix}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        # Only the first (lowest-WER) row carries the "best" note.
        note = ""
    logging.info(s)
def _load_model_weights(params, model, device) -> None:
    """Load the weights chosen by --iter/--epoch/--avg into ``model``."""
    if params.use_averaged_model:
        if params.iter > 0:
            needed = params.avg + 1
            ckpts = find_checkpoints(params.exp_dir, iteration=-params.iter)[:needed]
            if not ckpts:
                raise ValueError(
                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                )
            if len(ckpts) < needed:
                raise ValueError(
                    f"Not enough checkpoints ({len(ckpts)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start, filename_end = ckpts[-1], ckpts[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
        model.to(device)
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=device,
            )
        )
        return

    if params.iter <= 0 and params.avg == 1:
        # Single epoch checkpoint: no averaging needed.
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        return

    if params.iter > 0:
        ckpts = find_checkpoints(params.exp_dir, iteration=-params.iter)[: params.avg]
        if not ckpts:
            raise ValueError(
                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
            )
        if len(ckpts) < params.avg:
            raise ValueError(
                f"Not enough checkpoints ({len(ckpts)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
    else:
        first_epoch = params.epoch - params.avg + 1
        # Epochs count from 1, so drop any non-positive index.
        ckpts = [
            f"{params.exp_dir}/epoch-{i}.pt"
            for i in range(first_epoch, params.epoch + 1)
            if i >= 1
        ]
    logging.info(f"averaging {ckpts}")
    model.to(device)
    model.load_state_dict(average_checkpoints(ckpts, device=device))


@torch.no_grad()
def main():
    """Decode the LibriSpeech test sets with the transducer and save results.

    Parses the command line, derives the result directory and file suffix
    from the decoding options, loads the BPE model or phone lexicon
    settings, loads the model weights and (for the fast_beam_search family)
    the decoding graph, then decodes test-clean and test-other and writes the
    results via `save_results`.
    """
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    method = params.decoding_method
    supported_methods = (
        "greedy_search",
        "beam_search",
        "fast_beam_search",
        "fast_beam_search_LG",
        "fast_beam_search_nbest",
        "fast_beam_search_nbest_LG",
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
    )
    assert method in supported_methods
    if "lang_phone" in str(params.lang_dir):
        assert method in (
            "fast_beam_search_LG",
            "fast_beam_search_nbest_LG",
        ), "For phone lexicon, use a decoding method with LG."
    params.res_dir = params.exp_dir / method

    if params.iter > 0:
        suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if "fast_beam_search" in method:
        suffix += (
            f"-beam-{params.beam}"
            f"-max-contexts-{params.max_contexts}"
            f"-max-states-{params.max_states}"
        )
        if "nbest" in method:
            suffix += f"-nbest-scale-{params.nbest_scale}"
            suffix += f"-num-paths-{params.num_paths}"
            # NOTE(review): the LG suffix is nested under the nbest check
            # here; confirm this nesting against the original indentation.
            if "LG" in method:
                suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
    elif "beam_search" in method:
        suffix += f"-{method}-beam-size-{params.beam_size}"
    else:
        suffix += f"-context-{params.context_size}"
        suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

    if params.use_averaged_model:
        suffix += "-uam"
    params.suffix = suffix

    setup_logger(f"{params.res_dir}/log-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    logging.info(f"Device: {device}")

    lexicon = Lexicon(params.lang_dir)
    if "lang_bpe" in str(params.lang_dir):
        sp = spm.SentencePieceProcessor()
        sp.load(str(params.lang_dir / "bpe.model"))
        # <blk> and <unk> are defined in local/train_bpe_model.py
        params.blank_id = sp.piece_to_id("<blk>")
        params.unk_id = sp.piece_to_id("<unk>")
        params.vocab_size = sp.get_piece_size()
    else:
        params.blank_id = lexicon.token_table.get("<eps>")
        params.unk_id = lexicon.token_table.get("SPN")
        params.vocab_size = max(lexicon.tokens) + 1
        sp = None

    logging.info(pprint.pformat(params, indent=2))

    logging.info("About to create model")
    model = get_transducer_model(params)
    _load_model_weights(params, model, device)
    model.to(device)
    model.eval()

    decoding_graph = None
    word_table = None
    if "fast_beam_search" in method:
        if "LG" in method:
            word_table = lexicon.word_table
            lg_filename = params.lang_dir / "LG.pt"
            logging.info(f"Loading {lg_filename}")
            decoding_graph = k2.Fsa.from_dict(
                torch.load(lg_filename, map_location=device)
            )
            decoding_graph.scores *= params.ngram_lm_scale
        else:
            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)

    num_param = sum(p.numel() for p in model.parameters())
    enc_param = sum(p.numel() for p in model.encoder.parameters())
    dec_param = sum(p.numel() for p in model.decoder.parameters())
    join_param = sum(p.numel() for p in model.joiner.parameters())
    logging.info(f"Number of model parameters: {num_param}")
    logging.info(
        f"Parameters for transducer decoding: {enc_param + dec_param + join_param}"
    )

    # we need cut ids to display recognition results.
    args.return_cuts = True
    librispeech = LibriSpeechAsrDataModule(args)

    test_clean_cuts = librispeech.test_clean_cuts()
    test_other_cuts = librispeech.test_other_cuts()
    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
    test_other_dl = librispeech.test_dataloaders(test_other_cuts)

    named_loaders = (("test-clean", test_clean_dl), ("test-other", test_other_dl))
    for test_set, loader in named_loaders:
        results_dict = decode_dataset(
            dl=loader,
            params=params,
            model=model,
            sp=sp,
            word_table=word_table,
            decoding_graph=decoding_graph,
        )
        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")


if __name__ == "__main__":
    main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/decoder.py

View File

@ -0,0 +1,378 @@
#!/usr/bin/env python3
# Copyright (c) 2022 Spacetouch Inc. (author: Tiance Wang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.nn.functional as F
from encoder_interface import EncoderInterface
from scaling import ActivationBalancer, DoubleSwish
from torch import Tensor, nn
class Conv1dNet(EncoderInterface):
    """
    1D Convolution network with causal squeeze and excitation
    module and optional skip connections.

    Latency: 80ms + (conv_layers+1) // 2 * 40ms, assuming 10ms stride.

    Args:
        output_dim (int): Number of output channels of the last layer.
        input_dim (int): Number of input features
        conv_layers (int): Number of convolution layers,
            excluding the subsampling layers.
        channels (int): Number of output channels for each layer,
            except the last layer.
        subsampling_factor (int): The subsampling factor for the model.
            Only 4 is accepted.
        skip_add (bool): Whether to use skip connection for each convolution layer.
        dscnn (bool): Whether to use depthwise-separated convolution.
        activation (str): Activation function type.
    """

    def __init__(
        self,
        output_dim: int,
        input_dim: int = 80,
        conv_layers: int = 10,
        channels: int = 256,
        subsampling_factor: int = 4,
        skip_add: bool = False,
        dscnn: bool = True,
        activation: str = "relu",
    ) -> None:
        super().__init__()
        assert subsampling_factor == 4, "Only support subsampling = 4"
        self.conv_layers = conv_layers
        self.skip_add = skip_add

        # 80ms latency for subsample_layer
        first_stage = conv1d_bn_block(
            input_dim, channels, 9, stride=2, activation=activation, dscnn=dscnn
        )
        second_stage = conv1d_bn_block(
            channels, channels, 5, stride=2, activation=activation, dscnn=dscnn
        )
        self.subsample_layer = nn.Sequential(first_stage, second_stage)

        self.conv_blocks = nn.ModuleList()
        last = conv_layers - 1
        # Use causal and standard convolution alternatively
        for ly in range(conv_layers):
            # Every layer takes `channels` inputs; only the final layer
            # projects to `output_dim`.
            out_ch = output_dim if ly == last else channels
            conv = conv1d_bn_block(
                channels,
                out_ch,
                3,
                activation=activation,
                dscnn=dscnn,
                causal=ly % 2,
            )
            squeeze_excite = CausalSqueezeExcite1d(out_ch, 16, 30)
            self.conv_blocks.append(nn.Sequential(conv, squeeze_excite))

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            The input tensor. Its shape is (batch_size, seq_len, feature_dim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
        Returns:
          Return a tuple containing 2 tensors:
            - embeddings: its shape is (batch_size, output_seq_len, encoder_dims)
            - lengths, a tensor of shape (batch_size,) containing the number
              of frames in `embeddings` before padding, computed as
              `x_lens >> 2`.
        """
        out = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
        out = self.subsample_layer(out)

        final_idx = self.conv_layers - 1
        for idx, layer in enumerate(self.conv_blocks):
            # The residual is only applied to the interior layers.
            use_residual = self.skip_add and 0 < idx < final_idx
            out = layer(out) + out if use_residual else layer(out)

        out = out.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
        return out, x_lens >> 2
def get_activation(
    name: str,
    channels: int,
    channel_dim: int = -1,
    min_val: int = 0,
    max_val: int = 1,
) -> nn.Module:
    """
    Get activation function from name in string.

    The lookup is case-insensitive. An empty name yields ``nn.Identity``.

    Args:
        name: activation function name
        channels: number of channels; used by PReLU (number of learnable
            slopes) and by the ActivationBalancer of "doubleswish".
        channel_dim: the axis/dimension corresponding to the channel,
            interpreted as an offset from the input's ndim if negative.
            e.g. for NCHW tensor, channel_dim = 1. Only used by "doubleswish".
        min_val: minimum value of hardtanh
        max_val: maximum value of hardtanh

    Returns:
        The activation function module

    Raises:
        ValueError: if ``name`` is not a known activation.
    """
    key = name.lower()

    # Activations that need constructor arguments.
    if key == "prelu":
        return nn.PReLU(channels)
    if key == "hardtanh":
        return nn.Hardtanh(min_val, max_val)
    if key == "doubleswish":
        return nn.Sequential(
            ActivationBalancer(num_channels=channels, channel_dim=channel_dim),
            DoubleSwish(),
        )

    # Activations constructed without arguments.
    simple_activations = {
        "relu": nn.ReLU,
        "relu6": nn.ReLU6,
        "swish": nn.SiLU,
        "silu": nn.SiLU,
        "elu": nn.ELU,
        "": nn.Identity,
    }
    try:
        factory = simple_activations[key]
    except KeyError:
        # ValueError is a subclass of Exception, so existing handlers that
        # catch Exception keep working.
        raise ValueError(f"Unknown activation function: {key}") from None
    return factory()
class CausalSqueezeExcite1d(nn.Module):
    """
    Causal squeeze and excitation module with input and output shape
    (batch, channels, time). The global average pooling in the original
    SE module is replaced by a causal filter, so
    the layer does not introduce any algorithmic latency.

    Args:
        channels (int): Number of channels
        reduction (int): channel reduction rate
        context_window (int): Context window size for the moving average operation.
            For EMA, the smoothing factor is 1 / context_window.
    """

    def __init__(
        self,
        channels: int,
        reduction: int = 16,
        context_window: int = 10,
    ) -> None:
        super().__init__()
        assert channels >= reduction
        self.context_window = context_window
        c_squeeze = channels // reduction
        self.linear1 = nn.Linear(channels, c_squeeze, bias=True)
        self.act1 = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(c_squeeze, channels, bias=True)
        self.act2 = nn.Sigmoid()
        # EMA worked better than MA empirically
        # self.avg_filter = self.moving_avg
        self.avg_filter = self.exponential_moving_avg
        # Cache of the EMA weight matrix; (re)built lazily in training mode.
        self.ema_matrix = torch.tensor([0])
        self.ema_matrix_size = 0

    def _precompute_ema_matrix(self, N: int, device: torch.device):
        """Build and cache the (N, N) EMA weight matrix on ``device``.

        After this call ``self.ema_matrix[j, n]`` is the weight of ``x[j]``
        in ``y[n]``, so that ``y = x @ self.ema_matrix``.
        """
        a = 1.0 / self.context_window  # smoothing factor
        steps = torch.arange(N)
        # lag[n, j] = n - j. Entries with j > n are discarded by tril()
        # below, so clamp them to 0 instead of raising (1 - a) to a negative
        # power, which overflows for long sequences.
        lag = (steps.unsqueeze(1) - steps.unsqueeze(0)).clamp(min=0)
        # Compute in float64 and store as float32.
        w = torch.pow(1.0 - a, lag.to(torch.float64)) * a
        w = w.to(torch.float32).to(device).tril()
        # y[0] = x[0], so x[0] carries weight (1 - a) ** n rather than
        # a * (1 - a) ** n.
        w[:, 0] *= self.context_window
        self.ema_matrix = w.T
        self.ema_matrix_size = N

    def exponential_moving_avg(self, x: Tensor) -> Tensor:
        """
        Exponential moving average filter, which is calculated as:
            y[t] = (1-a) * y[t-1] + a * x[t]
        where a = 1 / self.context_window is the smoothing factor.

        For training, the iterative version is too slow. A better way is
        to expand y[t] as a function of x[0..t] only and use matrix-vector multiplication.
        The weight matrix can be precomputed if the smoothing factor is fixed.

        Args:
            x: input whose last dimension is time.

        Returns:
            A tensor of the same shape as ``x``.
        """
        if self.training:
            # use matrix version to speed up training
            N = x.shape[-1]
            if N > self.ema_matrix_size or self.ema_matrix.device != x.device:
                # Rebuild when the cache is too small or lives on another
                # device; never shrink an existing cache.
                self._precompute_ema_matrix(max(N, self.ema_matrix_size), x.device)
            weights = self.ema_matrix[:N, :N]
            if weights.dtype != x.dtype:
                weights = weights.to(x.dtype)
            y = torch.matmul(x, weights)
        else:
            # use iterative version to save memory
            a = 1.0 / self.context_window
            y = torch.empty_like(x)
            y[:, :, 0] = x[:, :, 0]
            for t in range(1, y.shape[-1]):
                y[:, :, t] = (1 - a) * y[:, :, t - 1] + a * x[:, :, t]
        return y

    def moving_avg(self, x: Tensor) -> Tensor:
        """
        Simple moving average with context_window as window size.

        The first ``k - 1`` outputs average over the frames seen so far,
        where ``k = min(x.shape[2], context_window)``.
        """
        y = torch.empty_like(x)
        k = min(x.shape[2], self.context_window)
        if k > 1:
            # Row n-1 averages the first n frames. With k == 1 there is no
            # warm-up region, and the matmul below would fail on an empty
            # weight matrix.
            w = [[1 / n] * n + [0] * (k - n - 1) for n in range(1, k)]
            w = torch.tensor(w, device=x.device)
            y[:, :, : k - 1] = torch.matmul(x[:, :, : k - 1], w.T)
        y[:, :, k - 1 :] = F.avg_pool1d(x, k, 1)
        return y

    def forward(self, x: Tensor) -> Tensor:
        """Scale ``x`` by a causally-computed per-channel gate.

        Args:
            x: a 3-D tensor of shape (batch, channels, time).

        Returns:
            A tensor of the same shape as ``x``.
        """
        assert len(x.shape) == 3, "Input is not a 3D tensor!"
        y = self.exponential_moving_avg(x)
        y = y.permute(0, 2, 1)  # make channel last for squeeze op
        y = self.act1(self.linear1(y))
        y = self.act2(self.linear2(y))
        y = y.permute(0, 2, 1)  # back to original shape
        y = x * y
        return y
def conv1d_bn_block(
    in_channels: int,
    out_channels: int,
    kernel_size: int = 3,
    stride: int = 1,
    dilation: int = 1,
    activation: str = "relu",
    dscnn: bool = False,
    causal: bool = False,
) -> nn.Sequential:
    """
    Build a Conv1d -> BatchNorm1d -> activation stack.

    With ``dscnn`` the convolution is depthwise (in_channels -> in_channels,
    one group per channel) and is followed by a second
    pointwise-conv -> BatchNorm1d -> activation stage that maps to
    ``out_channels``.

    For the non-causal variant the padding is (kernel_size // 2) * dilation,
    so with stride 1 and dilation 1 an even kernel size yields
    output length = input length + 1, and an odd one keeps the length.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels
        kernel_size (int): kernel size
        stride (int): convolution stride
        dilation (int): convolution dilation rate
        activation (str): Activation function type (passed to get_activation).
        dscnn (bool): Use depthwise separated convolution.
        causal (bool): Use causal convolution
    Returns:
        The layers wrapped in an nn.Sequential.
    """
    # The first convolution is depthwise when dscnn is set, a regular
    # convolution otherwise.
    first_out = in_channels if dscnn else out_channels
    num_groups = in_channels if dscnn else 1

    if causal:
        first_conv = CausalConv1d(
            in_channels,
            first_out,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=num_groups,
            bias=False,
        )
    else:
        first_conv = nn.Conv1d(
            in_channels,
            first_out,
            kernel_size,
            stride=stride,
            padding=(kernel_size // 2) * dilation,
            dilation=dilation,
            groups=num_groups,
            bias=False,
        )

    layers = [
        first_conv,
        nn.BatchNorm1d(first_out),
        get_activation(activation, first_out),
    ]
    if dscnn:
        # Pointwise stage that mixes channels after the depthwise filter.
        layers.append(nn.Conv1d(in_channels, out_channels, 1, bias=False))
        layers.append(nn.BatchNorm1d(out_channels))
        layers.append(get_activation(activation, out_channels))
    return nn.Sequential(*layers)
class CausalConv1d(nn.Module):
    """
    1-D convolution made causal by padding and trimming.

    The wrapped nn.Conv1d pads both ends with dilation * (kernel_size - 1)
    frames; forward() then drops frames from the end of the output so that
    no output frame depends on later input frames.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ) -> None:
        super().__init__()
        assert kernel_size > 2

        self.padding = dilation * (kernel_size - 1)
        self.stride = stride
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the convolution and cut the trailing frames.

        Args:
            x: Input tensor; the last of its three axes is sliced.
        Returns:
            The convolution output without its last frames.
        """
        # Unary minus binds tighter than //, so this is a floor division of
        # the negated padding (a negative end index).
        end = (-self.padding) // self.stride
        out = self.conv(x)
        return out[:, :, :end]

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/encoder_interface.py

View File

@ -0,0 +1,305 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
(1) Export to torchscript model using torch.jit.script()
./tiny_transducer_ctc/export.py \
--exp-dir ./tiny_transducer_ctc/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 2 \
--jit 1
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("cpu_jit.pt")`.
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
Check
https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
./tiny_transducer_ctc/export.py \
--exp-dir ./tiny_transducer_ctc/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 2
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `tiny_transducer_ctc/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./tiny_transducer_ctc/decode.py \
--exp-dir ./tiny_transducer_ctc/exp \
--epoch 9999 \
--use-averaged-model 0 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--lang-dir data/lang_bpe_500
Check ./pretrained.py for its usage.
"""
import argparse
import logging
from pathlib import Path
import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import num_tokens, str2bool
def get_parser():
    """Create the command-line parser of the export script.

    Declares the checkpoint-selection options (--epoch, --iter, --avg,
    --use-averaged-model), the experiment directory, the path to tokens.txt,
    the --jit switch and the decoder context size, then lets
    `add_model_arguments` register its own options on the same parser.

    Returns:
      The configured argparse.ArgumentParser.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    add = parser.add_argument

    # Which checkpoint(s) to load.
    add(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
    )
    add(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
    )
    add(
        "--avg",
        type=int,
        default=1,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )
    add(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    # Where to read from / write to.
    add(
        "--exp-dir",
        type=str,
        default="tiny_transducer_ctc/exp_4m_bpe500_halfdelay_specaug",
        help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
    )
    add(
        "--tokens",
        type=str,
        default="data/lang_bpe_500/tokens.txt",
        help="Path to the tokens.txt",
    )

    # Export format and decoder configuration.
    add(
        "--jit",
        type=str2bool,
        default=False,
        help="""True to save a model after applying torch.jit.script.
It will generate a file named cpu_jit.pt
Check ./jit_pretrained.py for how to use it.
""",
    )
    add(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    add_model_arguments(parser)
    return parser
@torch.no_grad()
def main() -> None:
    """Load (optionally averaged) checkpoint weights and export the model.

    The model is built from the merged default/command-line parameters, its
    weights are loaded according to --use-averaged-model, --iter, --epoch and
    --avg, and it is then moved to CPU and saved in `exp_dir` either as
    `cpu_jit.pt` (when --jit is true) or as `pretrained.pt` containing
    `{"model": state_dict}`.

    Raises:
      ValueError: if --iter is positive and fewer checkpoints than required
        are found.
    """
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    # Command-line values override the defaults returned by get_params().
    params = get_params()
    params.update(vars(args))

    # Use the first GPU for loading/averaging when one is available.
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    token_table = k2.SymbolTable.from_file(params.tokens)
    params.blank_id = token_table["<blk>"]
    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    model.to(device)

    if not params.use_averaged_model:
        # Plain averaging of the parameters stored in the checkpoints.
        if params.iter > 0:
            # presumably find_checkpoints() returns the checkpoints up to
            # iteration `iter`, newest first; verify against icefall.checkpoint
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            # A single epoch checkpoint: load it directly.
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            # Average epochs (epoch - avg + 1) .. epoch, skipping epochs < 1.
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        # Use the averaged model stored in the checkpoints: only the two
        # end-point checkpoints are passed to the averaging helper.
        if params.iter > 0:
            # One more checkpoint than --avg is needed because the start
            # checkpoint is excluded from the range (see the log message below).
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            # epoch-{start}.pt must exist, so the range cannot begin before epoch 1.
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    # Export from CPU, in inference mode.
    model.to("cpu")
    model.eval()

    if params.jit is True:
        logging.info("Using torch.jit.script()")
        # We won't use the forward() method of the model in C++, so just ignore
        # it here.
        # Otherwise, one of its arguments is a ragged tensor and is not
        # torch scriptable.
        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
        logging.info("Using torch.jit.script")
        model = torch.jit.script(model)
        filename = params.exp_dir / "cpu_jit.pt"
        model.save(str(filename))
        logging.info(f"Saved to {filename}")
    else:
        logging.info("Not using torchscript. Export model.state_dict()")
        # Save it using a format so that it can be loaded
        # by :func:`load_checkpoint`
        filename = params.exp_dir / "pretrained.pt"
        torch.save({"model": model.state_dict()}, str(filename))
        logging.info(f"Saved to {filename}")


# Script entry point: configure INFO-level logging, then run the export.
if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()

View File

@ -0,0 +1,271 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, exported by `torch.jit.script()`
and uses them to decode waves.
You can use the following command to get the exported models:
./tiny_transducer_ctc/export.py \
--exp-dir ./tiny_transducer_ctc/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10 \
--jit 1
Usage of this script:
./tiny_transducer_ctc/jit_pretrained.py \
--nn-model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \
--bpe-model data/lang_bpe_500/bpe.model \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
def get_parser():
    """Create the command-line parser of the TorchScript decoding script.

    Options:
      --nn-model-filename  Path to the exported TorchScript model (required).
      --bpe-model          Path to the sentencepiece model used to turn token
                           ids into text (required: main() always loads it).
      sound_files          One or more audio files to transcribe.

    Returns:
      The configured argparse.ArgumentParser.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--nn-model-filename",
        type=str,
        required=True,
        help="Path to the torchscript model cpu_jit.pt",
    )

    # main() passes this value straight to SentencePieceProcessor.load(), so
    # a missing value is rejected here with a clear usage error instead of
    # failing later with None.
    parser.add_argument(
        "--bpe-model",
        type=str,
        required=True,
        help="""Path to bpe.model.""",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    return parser
def read_sound_files(
    filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
    """Load sound files with torchaudio and keep the first channel of each.

    Args:
      filenames:
        Paths of the sound files to read.
      expected_sample_rate:
        Sample rate every file must have; a file with a different rate
        triggers an AssertionError.
    Returns:
      A list with one tensor per file (channel 0 of what torchaudio.load()
      returned), in the order of `filenames`.
    """
    first_channels = []
    for path in filenames:
        wave, sample_rate = torchaudio.load(path)
        assert (
            sample_rate == expected_sample_rate
        ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        # Multi-channel audio is reduced to its first channel.
        first_channels.append(wave[0])
    return first_channels
def greedy_search(
    model: torch.jit.ScriptModule,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
) -> List[List[int]]:
    """Batched greedy search that emits at most one symbol per frame.

    The padded encoder output is packed so that, at each time step, only the
    utterances that still have frames left are fed to the joiner.

    Args:
      model:
        The transducer model; its `decoder` (with a `context_size`
        attribute) and `joiner` are used.
      encoder_out:
        A 3-D tensor of shape (N, T, C)
      encoder_out_lens:
        A 1-D tensor of shape (N,) with the number of valid frames of each
        utterance; every entry must be positive.
    Returns:
      The decoded token ids of each utterance, in the input order.
    """
    assert encoder_out.ndim == 3
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    packed = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    device = encoder_out.device
    blank_id = 0  # hard-code to 0

    active_per_step = packed.batch_sizes.tolist()
    N = encoder_out.size(0)

    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == active_per_step[0], (N, active_per_step)

    context_size = model.decoder.context_size
    # Every hypothesis starts with `context_size` blanks as left context;
    # hyps follows the (length-sorted) order used by the packed sequence.
    hyps = [[blank_id] * context_size for _ in range(N)]

    def run_decoder(contexts: List[List[int]]) -> torch.Tensor:
        # contexts: one list of `context_size` token ids per active utterance
        tokens = torch.tensor(contexts, device=device, dtype=torch.int64)
        out = model.decoder(tokens, need_pad=torch.tensor([False]))
        return out.squeeze(1)

    decoder_out = run_decoder(hyps)

    begin = 0
    for num_active in active_per_step:
        stop = begin + num_active
        # frames of the current time step: (num_active, encoder_out_dim)
        frames = packed.data[begin:stop]
        begin = stop

        # Utterances that already ended are at the tail; drop their rows.
        decoder_out = decoder_out[:num_active]

        # logits: (num_active, vocab_size)
        logits = model.joiner(frames, decoder_out)
        assert logits.ndim == 2, logits.shape

        emitted = False
        for idx, token in enumerate(logits.argmax(dim=1).tolist()):
            if token != blank_id:
                hyps[idx].append(token)
                emitted = True

        if emitted:
            # At least one hypothesis grew: refresh the decoder output.
            decoder_out = run_decoder([h[-context_size:] for h in hyps[:num_active]])

    # Strip the initial blank context, then restore the caller's order.
    sorted_ans = [h[context_size:] for h in hyps]
    return [sorted_ans[j] for j in packed.unsorted_indices.tolist()]
@torch.no_grad()
def main():
    """Transcribe the given sound files with an exported TorchScript model.

    Loads the model and the sentencepiece model, computes 80-bin fbank
    features for every file, runs the encoder and greedy search, and logs
    one transcript per input file.
    """
    args = get_parser().parse_args()
    logging.info(vars(args))

    # Prefer the first GPU when available.
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")
    logging.info(f"device: {device}")

    model = torch.jit.load(args.nn_model_filename)
    model.eval()
    model.to(device)

    sp = spm.SentencePieceProcessor()
    sp.load(args.bpe_model)

    logging.info("Constructing Fbank computer")
    fbank_opts = kaldifeat.FbankOptions()
    fbank_opts.device = device
    fbank_opts.frame_opts.dither = 0
    fbank_opts.frame_opts.snip_edges = False
    fbank_opts.frame_opts.samp_freq = 16000
    fbank_opts.mel_opts.num_bins = 80
    fbank = kaldifeat.Fbank(fbank_opts)

    logging.info(f"Reading sound files: {args.sound_files}")
    waves = [w.to(device) for w in read_sound_files(filenames=args.sound_files)]

    logging.info("Decoding started")
    features = fbank(waves)
    num_frames = [f.size(0) for f in features]

    # Pad to a common length with log(1e-10).
    features = pad_sequence(
        features,
        batch_first=True,
        padding_value=math.log(1e-10),
    )
    feature_lengths = torch.tensor(num_frames, device=device)

    encoder_out, encoder_out_lens = model.encoder(
        x=features,
        x_lens=feature_lengths,
    )

    hyps = greedy_search(
        model=model,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
    )

    # One "<filename>:\n<text>\n\n" entry per input file.
    report = "\n" + "".join(
        f"{filename}:\n{sp.decode(hyp)}\n\n"
        for filename, hyp in zip(args.sound_files, hyps)
    )
    logging.info(report)

    logging.info("Decoding Done")


# Script entry point: configure INFO-level logging, then decode.
if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    main()

View File

@ -0,0 +1,426 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, exported by `torch.jit.script()`
and uses them to decode waves.
You can use the following command to get the exported models:
./tiny_transducer_ctc/export.py \
--exp-dir ./tiny_transducer_ctc/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10 \
--jit 1
Usage of this script:
(1) ctc-decoding
./tiny_transducer_ctc/jit_pretrained_ctc.py \
--model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \
--bpe-model data/lang_bpe_500/bpe.model \
--method ctc-decoding \
--sample-rate 16000 \
/path/to/foo.wav \
/path/to/bar.wav
(2) 1best
./tiny_transducer_ctc/jit_pretrained_ctc.py \
--model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \
--method 1best \
--sample-rate 16000 \
/path/to/foo.wav \
/path/to/bar.wav
(3) nbest-rescoring
./tiny_transducer_ctc/jit_pretrained_ctc.py \
--model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \
--G data/lm/G_4_gram.pt \
--method nbest-rescoring \
--sample-rate 16000 \
/path/to/foo.wav \
/path/to/bar.wav
(4) whole-lattice-rescoring
./tiny_transducer_ctc/jit_pretrained_ctc.py \
--model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \
--G data/lm/G_4_gram.pt \
--method whole-lattice-rescoring \
--sample-rate 16000 \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from ctc_decode import get_decoding_params
from torch.nn.utils.rnn import pad_sequence
from train import get_params
from icefall.decode import (
get_lattice,
one_best_decoding,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.utils import get_texts
def get_parser():
    """Create the command-line parser of the TorchScript CTC decoding script.

    Besides the required --model-filename and the positional sound files, it
    declares the decoding --method and the resources each method needs
    (--bpe-model for ctc-decoding; --HLG and --words-file for the HLG based
    methods; --G for the two rescoring methods), plus the rescoring knobs
    (--num-paths, --ngram-lm-scale, --nbest-scale), --num-classes and
    --sample-rate.

    Returns:
      The configured argparse.ArgumentParser.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-filename",
        type=str,
        required=True,
        help="Path to the torchscript model.",
    )

    parser.add_argument(
        "--words-file",
        type=str,
        help="""Path to words.txt.
Used only when method is not ctc-decoding.
""",
    )

    parser.add_argument(
        "--HLG",
        type=str,
        help="""Path to HLG.pt.
Used only when method is not ctc-decoding.
""",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
    )

    parser.add_argument(
        "--method",
        type=str,
        default="1best",
        help="""Decoding method.
Possible values are:
(0) ctc-decoding - Use CTC decoding. It uses a sentence
piece model, i.e., lang_dir/bpe.model, to convert
word pieces to words. It needs neither a lexicon
nor an n-gram LM.
(1) 1best - Use the best path as decoding output. Only
the transformer encoder output is used for decoding.
We call it HLG decoding.
(2) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an LM, the path with
the highest score is the decoding result.
We call it HLG decoding + n-gram LM rescoring.
(3) whole-lattice-rescoring - Use an LM to rescore the
decoding lattice and then use 1best to decode the
rescored lattice.
We call it HLG decoding + n-gram LM rescoring.
""",
    )

    parser.add_argument(
        "--G",
        type=str,
        help="""An LM for rescoring.
Used only when method is
whole-lattice-rescoring or nbest-rescoring.
It's usually a 4-gram LM.
""",
    )

    # The value is only consumed by the nbest-rescoring branch of main();
    # this script has no attention-decoder method.
    parser.add_argument(
        "--num-paths",
        type=int,
        default=100,
        help="""
Used only when method is nbest-rescoring.
It specifies the size of n-best list.""",
    )

    parser.add_argument(
        "--ngram-lm-scale",
        type=float,
        default=1.3,
        help="""
Used only when method is whole-lattice-rescoring and nbest-rescoring.
It specifies the scale for n-gram LM scores.
(Note: You need to tune it on a dataset.)
""",
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=0.5,
        help="""
Used only when method is nbest-rescoring.
It specifies the scale for lattice.scores when
extracting n-best lists. A smaller value results in
more unique number of paths with the risk of missing
the best path.
""",
    )

    parser.add_argument(
        "--num-classes",
        type=int,
        default=500,
        help="""
Vocab size in the BPE model.
""",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    # The files are checked against --sample-rate when they are read, so the
    # help text refers to that option instead of a fixed 16kHz.
    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to match --sample-rate.",
    )

    return parser
def read_sound_files(
    filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
    """Load sound files with torchaudio and keep the first channel of each.

    Args:
      filenames:
        Paths of the sound files to read.
      expected_sample_rate:
        Sample rate every file must have; a file with a different rate
        triggers an AssertionError.
    Returns:
      A list with one tensor per file (channel 0 of what torchaudio.load()
      returned), in the order of `filenames`.
    """
    channels = []
    for sound_file in filenames:
        wave, sample_rate = torchaudio.load(sound_file)
        assert (
            sample_rate == expected_sample_rate
        ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        # Only channel 0 is kept, even for multi-channel audio.
        channels.append(wave[0])
    return channels
@torch.no_grad()
def main() -> None:
    """Decode sound files with the CTC head of an exported TorchScript model.

    Features are computed with kaldifeat, passed through `model.encoder` and
    `model.ctc_output`, and the result is decoded with the method selected
    by --method: "ctc-decoding" (CTC topology + sentencepiece), "1best"
    (HLG), "nbest-rescoring" or "whole-lattice-rescoring" (HLG followed by
    rescoring with G). One transcript per input file is logged.

    Raises:
      ValueError: if --method is not one of the four supported values.
    """
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()
    # add decoding params
    params.update(get_decoding_params())
    # Command-line values take precedence over both sets of defaults.
    params.update(vars(args))
    logging.info(f"{params}")

    # Use the first GPU when one is available.
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    model = torch.jit.load(args.model_filename)
    model.to(device)
    model.eval()

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = params.sample_rate
    opts.mel_opts.num_bins = params.feature_dim

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {params.sound_files}")
    waves = read_sound_files(
        filenames=params.sound_files, expected_sample_rate=params.sample_rate
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    # Number of frames of each utterance before padding.
    feature_lengths = [f.size(0) for f in features]

    # Pad to a common length with log(1e-10).
    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
    feature_lengths = torch.tensor(feature_lengths, device=device)

    # NOTE(review): encoder_out_lens is not used below; the supervision
    # segments are derived from feature_lengths instead.
    encoder_out, encoder_out_lens = model.encoder(
        x=features,
        x_lens=feature_lengths,
    )
    nnet_output = model.ctc_output(encoder_out)

    batch_size = nnet_output.shape[0]
    # One row [utterance index, start frame 0, frame count after dividing by
    # the subsampling factor] per utterance.
    supervision_segments = torch.tensor(
        [
            [i, 0, feature_lengths[i] // params.subsampling_factor]
            for i in range(batch_size)
        ],
        dtype=torch.int32,
    )

    if params.method == "ctc-decoding":
        logging.info("Use CTC decoding")
        bpe_model = spm.SentencePieceProcessor()
        bpe_model.load(params.bpe_model)
        max_token_id = params.num_classes - 1

        # Decoding graph is a plain CTC topology over the token set.
        H = k2.ctc_topo(
            max_token=max_token_id,
            modified=False,
            device=device,
        )

        lattice = get_lattice(
            nnet_output=nnet_output,
            decoding_graph=H,
            supervision_segments=supervision_segments,
            search_beam=params.search_beam,
            output_beam=params.output_beam,
            min_active_states=params.min_active_states,
            max_active_states=params.max_active_states,
            subsampling_factor=params.subsampling_factor,
        )

        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        # Token ids -> text via sentencepiece, then split into words so that
        # every hypothesis is a list of strings like in the HLG branch.
        token_ids = get_texts(best_path)
        hyps = bpe_model.decode(token_ids)
        hyps = [s.split() for s in hyps]
    elif params.method in [
        "1best",
        "nbest-rescoring",
        "whole-lattice-rescoring",
    ]:
        logging.info(f"Loading HLG from {params.HLG}")
        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
        HLG = HLG.to(device)
        if not hasattr(HLG, "lm_scores"):
            # For whole-lattice-rescoring and attention-decoder
            HLG.lm_scores = HLG.scores.clone()

        if params.method in [
            "nbest-rescoring",
            "whole-lattice-rescoring",
        ]:
            logging.info(f"Loading G from {params.G}")
            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
            G = G.to(device)
            if params.method == "whole-lattice-rescoring":
                # Add epsilon self-loops to G as we will compose
                # it with the whole lattice later
                G = k2.add_epsilon_self_loops(G)
                G = k2.arc_sort(G)

            # G.lm_scores is used to replace HLG.lm_scores during
            # LM rescoring.
            G.lm_scores = G.scores.clone()

        lattice = get_lattice(
            nnet_output=nnet_output,
            decoding_graph=HLG,
            supervision_segments=supervision_segments,
            search_beam=params.search_beam,
            output_beam=params.output_beam,
            min_active_states=params.min_active_states,
            max_active_states=params.max_active_states,
            subsampling_factor=params.subsampling_factor,
        )

        # Exactly one of the three branches below runs, because the
        # enclosing elif restricts params.method to these three values.
        if params.method == "1best":
            logging.info("Use HLG decoding")
            best_path = one_best_decoding(
                lattice=lattice, use_double_scores=params.use_double_scores
            )
        if params.method == "nbest-rescoring":
            logging.info("Use HLG decoding + LM rescoring")
            best_path_dict = rescore_with_n_best_list(
                lattice=lattice,
                G=G,
                num_paths=params.num_paths,
                lm_scale_list=[params.ngram_lm_scale],
                nbest_scale=params.nbest_scale,
            )
            # A single LM scale was passed; take the first (presumably the
            # only) entry of the returned dict.
            best_path = next(iter(best_path_dict.values()))
        elif params.method == "whole-lattice-rescoring":
            logging.info("Use HLG decoding + LM rescoring")
            best_path_dict = rescore_with_whole_lattice(
                lattice=lattice,
                G_with_epsilon_loops=G,
                lm_scale_list=[params.ngram_lm_scale],
            )
            best_path = next(iter(best_path_dict.values()))

        # Word ids -> word strings via words.txt.
        hyps = get_texts(best_path)
        word_sym_table = k2.SymbolTable.from_file(params.words_file)
        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
    else:
        raise ValueError(f"Unsupported decoding method: {params.method}")

    # One "<filename>:\n<words>\n\n" entry per input file.
    s = "\n"
    for filename, hyp in zip(params.sound_files, hyps):
        words = " ".join(hyp)
        s += f"{filename}:\n{words}\n\n"
    logging.info(s)

    logging.info("Decoding Done")


# Script entry point: configure INFO-level logging, then decode.
if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/joiner.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7_ctc/model.py

View File

@ -0,0 +1,357 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:
./tiny_transducer_ctc/export.py \
--exp-dir ./tiny_transducer_ctc/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
Usage of this script:
(1) greedy search
./tiny_transducer_ctc/pretrained.py \
--checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \
--lang-dir data/lang_bpe_500 \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./tiny_transducer_ctc/pretrained.py \
--checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \
--lang-dir data/lang_bpe_500 \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./tiny_transducer_ctc/pretrained.py \
--checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \
--lang-dir data/lang_bpe_500 \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./tiny_transducer_ctc/pretrained.py \
--checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \
--lang-dir data/lang_bpe_500 \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./tiny_transducer_ctc/exp/epoch-xx.pt`.
Note: ./tiny_transducer_ctc/exp/pretrained.pt is generated by
./tiny_transducer_ctc/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
def get_parser():
    """Create the command-line parser for this decoding script.

    The script's own options are listed in a table and registered in table
    order; `add_model_arguments` then adds the model-related options.

    Returns:
      The configured argparse.ArgumentParser.
    """
    option_table = [
        (
            "--checkpoint",
            dict(
                type=str,
                required=True,
                help="Path to the checkpoint. "
                "The checkpoint is assumed to be saved by "
                "icefall.checkpoint.save_checkpoint().",
            ),
        ),
        (
            "--lang-dir",
            dict(
                type=str,
                default="data/lang_bpe_500",
                help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
            ),
        ),
        (
            "--method",
            dict(
                type=str,
                default="greedy_search",
                help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
            ),
        ),
        (
            # Positional: one or more files to transcribe.
            "sound_files",
            dict(
                type=str,
                nargs="+",
                help="The input sound file(s) to transcribe. "
                "Supported formats are those supported by torchaudio.load(). "
                "For example, wav and flac are supported. "
                "The sample rate has to be 16kHz.",
            ),
        ),
        (
            "--sample-rate",
            dict(
                type=int,
                default=16000,
                help="The sample rate of the input sound file",
            ),
        ),
        (
            "--beam-size",
            dict(
                type=int,
                default=4,
                help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
            ),
        ),
        (
            "--beam",
            dict(
                type=float,
                default=4,
                help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
            ),
        ),
        (
            "--max-contexts",
            dict(
                type=int,
                default=4,
                help="""Used only when --method is fast_beam_search""",
            ),
        ),
        (
            "--max-states",
            dict(
                type=int,
                default=8,
                help="""Used only when --method is fast_beam_search""",
            ),
        ),
        (
            "--context-size",
            dict(
                type=int,
                default=2,
                help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
            ),
        ),
        (
            "--max-sym-per-frame",
            dict(
                type=int,
                default=1,
                help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
            ),
        ),
    ]

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    for name, options in option_table:
        parser.add_argument(name, **options)

    add_model_arguments(parser)

    return parser
def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Load sound files and keep the first channel of each one.

    Args:
      filenames:
        Paths of the sound files; each is opened with torchaudio.load().
      expected_sample_rate:
        The sample rate every file must have. A file with a different
        rate fails an assertion.
    Returns:
      One tensor per input file, in the order of `filenames`. Each entry
      is `wave[0]` of the waveform returned by torchaudio.load().
    """

    def _first_channel(path: str) -> torch.Tensor:
        wave, sample_rate = torchaudio.load(path)
        assert (
            sample_rate == expected_sample_rate
        ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        # We use only the first channel
        return wave[0]

    return [_first_channel(path) for path in filenames]
@torch.no_grad()
def main():
    """Transcribe the sound files given on the command line and log the text.

    In order, this function: parses the arguments into `params`, loads the
    BPE model from `<lang-dir>/bpe.model`, builds the transducer model and
    loads the checkpoint into it, computes fbank features with kaldifeat,
    runs the encoder, decodes with the search selected by `--method`, and
    logs one transcript per input file.

    Raises:
      ValueError: if `--method` is not a supported search method.
    """
    args = get_parser().parse_args()

    params = get_params()
    params.update(vars(args))

    sp = spm.SentencePieceProcessor()
    sp.load(params.lang_dir + "/bpe.model")

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(f"{params}")

    # Prefer the first GPU when one is available.
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")
    logging.info(f"device: {device}")

    logging.info("Creating model")
    model = get_transducer_model(params)

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["model"], strict=False)
    model.to(device)
    model.eval()
    model.device = device

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = params.sample_rate
    opts.mel_opts.num_bins = params.feature_dim
    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {params.sound_files}")
    waves = [
        wave.to(device)
        for wave in read_sound_files(
            filenames=params.sound_files,
            expected_sample_rate=params.sample_rate,
        )
    ]

    logging.info("Decoding started")
    features = fbank(waves)
    # Record the unpadded lengths before the features are padded to a batch.
    feature_lengths = torch.tensor([f.size(0) for f in features], device=device)
    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))

    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)

    msg = f"Using {params.method}"
    if params.method == "beam_search":
        msg += f" with beam size {params.beam_size}"
    logging.info(msg)

    def _to_words(token_ids):
        # Decode a batch of token-id lists into one list of words per wave.
        return [text.split() for text in sp.decode(token_ids)]

    if params.method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
        hyps = _to_words(
            fast_beam_search_one_best(
                model=model,
                decoding_graph=decoding_graph,
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                beam=params.beam,
                max_contexts=params.max_contexts,
                max_states=params.max_states,
            )
        )
    elif params.method == "modified_beam_search":
        hyps = _to_words(
            modified_beam_search(
                model=model,
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                beam=params.beam_size,
            )
        )
    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
        hyps = _to_words(
            greedy_search_batch(
                model=model,
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
            )
        )
    else:
        # The remaining searches are run on one utterance at a time.
        hyps = []
        for i in range(encoder_out.size(0)):
            # Drop the padded frames of the i-th utterance.
            encoder_out_i = encoder_out[i : i + 1, : encoder_out_lens[i]]
            if params.method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                )
            elif params.method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                )
            else:
                raise ValueError(f"Unsupported method: {params.method}")

            hyps.append(sp.decode(hyp).split())

    report = "\n" + "".join(
        f"{filename}:\n{' '.join(hyp)}\n\n"
        for filename, hyp in zip(params.sound_files, hyps)
    )
    logging.info(report)

    logging.info("Decoding Done")
if __name__ == "__main__":
    # INFO-level logging; each record carries a timestamp, the level and
    # the file/line that emitted it.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    main()

View File

@ -0,0 +1,444 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint into a freshly built model and uses the
model's CTC output to decode waves.
You can use the following command to get the exported models:
./tiny_transducer_ctc/export.py \
--exp-dir ./tiny_transducer_ctc/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
Usage of this script:
(1) ctc-decoding
./tiny_transducer_ctc/jit_pretrained_ctc.py \
--checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \
--bpe-model data/lang_bpe_500/bpe.model \
--method ctc-decoding \
--sample-rate 16000 \
/path/to/foo.wav \
/path/to/bar.wav
(2) 1best
./tiny_transducer_ctc/jit_pretrained_ctc.py \
--checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \
--method 1best \
--sample-rate 16000 \
/path/to/foo.wav \
/path/to/bar.wav
(3) nbest-rescoring
./tiny_transducer_ctc/jit_pretrained_ctc.py \
--checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \
--G data/lm/G_4_gram.pt \
--method nbest-rescoring \
--sample-rate 16000 \
/path/to/foo.wav \
/path/to/bar.wav
(4) whole-lattice-rescoring
./tiny_transducer_ctc/jit_pretrained_ctc.py \
--checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \
--G data/lm/G_4_gram.pt \
--method whole-lattice-rescoring \
--sample-rate 16000 \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from ctc_decode import get_decoding_params
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.decode import (
get_lattice,
one_best_decoding,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.utils import get_texts
def get_parser():
    """Create the command-line parser for CTC-based decoding.

    Besides the options declared here, `add_model_arguments` registers the
    model-related options on the same parser.

    Returns:
      The configured argparse.ArgumentParser.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the checkpoint. "
        "The checkpoint is assumed to be saved by "
        "icefall.checkpoint.save_checkpoint().",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
    )

    parser.add_argument(
        "--words-file",
        type=str,
        help="""Path to words.txt.
Used only when method is not ctc-decoding.
""",
    )

    parser.add_argument(
        "--HLG",
        type=str,
        help="""Path to HLG.pt.
Used only when method is not ctc-decoding.
""",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
    )

    parser.add_argument(
        "--method",
        type=str,
        default="1best",
        help="""Decoding method.
Possible values are:
(0) ctc-decoding - Use CTC decoding. It uses a sentence
piece model, i.e., lang_dir/bpe.model, to convert
word pieces to words. It needs neither a lexicon
nor an n-gram LM.
(1) 1best - Use the best path as decoding output. Only
the transformer encoder output is used for decoding.
We call it HLG decoding.
(2) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an LM, the path with
the highest score is the decoding result.
We call it HLG decoding + n-gram LM rescoring.
(3) whole-lattice-rescoring - Use an LM to rescore the
decoding lattice and then use 1best to decode the
rescored lattice.
We call it HLG decoding + n-gram LM rescoring.
""",
    )

    parser.add_argument(
        "--G",
        type=str,
        help="""An LM for rescoring.
Used only when method is
whole-lattice-rescoring or nbest-rescoring.
It's usually a 4-gram LM.
""",
    )

    # The n-best list size is consumed by nbest-rescoring; this script has
    # no attention-decoder method, so the help text must not refer to one.
    parser.add_argument(
        "--num-paths",
        type=int,
        default=100,
        help="""
Used only when method is nbest-rescoring.
It specifies the size of n-best list.""",
    )

    parser.add_argument(
        "--ngram-lm-scale",
        type=float,
        default=1.3,
        help="""
Used only when method is whole-lattice-rescoring or nbest-rescoring.
It specifies the scale for n-gram LM scores.
(Note: You need to tune it on a dataset.)
""",
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=0.5,
        help="""
Used only when method is nbest-rescoring.
It specifies the scale for lattice.scores when
extracting n-best lists. A smaller value results in
more unique number of paths with the risk of missing
the best path.
""",
    )

    parser.add_argument(
        "--num-classes",
        type=int,
        default=500,
        help="""
Vocab size in the BPE model.
""",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    add_model_arguments(parser)

    return parser
def read_sound_files(
    filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
    """Load sound files and return the first channel of each one.

    Args:
      filenames:
        Paths of the sound files; each is opened with torchaudio.load().
      expected_sample_rate:
        The sample rate every file must have. A file with a different
        rate fails an assertion.
    Returns:
      One tensor per input file, in the order of `filenames`. Each entry
      is `wave[0]` of the waveform returned by torchaudio.load().
    """

    def _load_first_channel(path: str) -> torch.Tensor:
        wave, sample_rate = torchaudio.load(path)
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
        )
        # We use only the first channel
        return wave[0]

    return list(map(_load_first_channel, filenames))
def _check_method_arguments(params) -> None:
    """Fail early if `--method` is unknown or a file it needs was not given.

    Raises:
      ValueError: for an unsupported method, or when an option required by
        the selected method is missing.
    """
    if params.method == "ctc-decoding":
        required = {"--bpe-model": params.bpe_model}
    elif params.method in ("1best", "nbest-rescoring", "whole-lattice-rescoring"):
        required = {"--HLG": params.HLG, "--words-file": params.words_file}
        if params.method != "1best":
            required["--G"] = params.G
    else:
        raise ValueError(f"Unsupported decoding method: {params.method}")

    missing = [flag for flag, value in required.items() if not value]
    if missing:
        raise ValueError(
            f"--method {params.method} requires {', '.join(missing)} to be given"
        )


@torch.no_grad()
def main():
    """Decode the sound files given on the command line with the CTC output.

    In order, this function: parses the arguments into `params`, validates
    that the files needed by `--method` were supplied, builds the model and
    loads the checkpoint into it, computes fbank features, runs the encoder
    and the CTC output layer, builds a lattice (with a CTC topology for
    ctc-decoding, with HLG otherwise), extracts the best path (optionally
    after n-gram LM rescoring with G), and logs one transcript per file.

    Raises:
      ValueError: if `--method` is unsupported or an option required by the
        selected method is missing.
    """
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()
    # add decoding params
    params.update(get_decoding_params())
    params.update(vars(args))

    params.vocab_size = params.num_classes
    params.blank_id = 0

    logging.info(f"{params}")

    # Validate before the expensive model construction and decoding, so a
    # missing file does not surface as an obscure error halfway through.
    _check_method_arguments(params)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    logging.info("Creating model")
    model = get_transducer_model(params)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["model"], strict=False)
    model.to(device)
    model.eval()

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = params.sample_rate
    opts.mel_opts.num_bins = params.feature_dim

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {params.sound_files}")
    waves = read_sound_files(
        filenames=params.sound_files, expected_sample_rate=params.sample_rate
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.size(0) for f in features]

    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
    feature_lengths = torch.tensor(feature_lengths, device=device)

    # The encoder's own output lengths are not used below; the supervision
    # segments are derived from the feature lengths instead.
    encoder_out, _ = model.encoder(
        x=features,
        x_lens=feature_lengths,
    )
    nnet_output = model.ctc_output(encoder_out)

    batch_size = nnet_output.shape[0]
    # One row per utterance: [utterance index, start frame, number of frames].
    # NOTE(review): the frame count is approximated as
    # feature_length // subsampling_factor - confirm this never exceeds the
    # encoder's real output length by more than get_lattice tolerates.
    supervision_segments = torch.tensor(
        [
            [i, 0, feature_lengths[i] // params.subsampling_factor]
            for i in range(batch_size)
        ],
        dtype=torch.int32,
    )

    if params.method == "ctc-decoding":
        logging.info("Use CTC decoding")
        bpe_model = spm.SentencePieceProcessor()
        bpe_model.load(params.bpe_model)
        max_token_id = params.num_classes - 1

        H = k2.ctc_topo(
            max_token=max_token_id,
            modified=False,
            device=device,
        )

        lattice = get_lattice(
            nnet_output=nnet_output,
            decoding_graph=H,
            supervision_segments=supervision_segments,
            search_beam=params.search_beam,
            output_beam=params.output_beam,
            min_active_states=params.min_active_states,
            max_active_states=params.max_active_states,
            subsampling_factor=params.subsampling_factor,
        )

        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        token_ids = get_texts(best_path)
        hyps = bpe_model.decode(token_ids)
        hyps = [s.split() for s in hyps]
    elif params.method in [
        "1best",
        "nbest-rescoring",
        "whole-lattice-rescoring",
    ]:
        logging.info(f"Loading HLG from {params.HLG}")
        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
        HLG = HLG.to(device)
        if not hasattr(HLG, "lm_scores"):
            # LM rescoring replaces HLG.lm_scores with G.lm_scores, so the
            # attribute has to exist on HLG.
            HLG.lm_scores = HLG.scores.clone()

        if params.method in [
            "nbest-rescoring",
            "whole-lattice-rescoring",
        ]:
            logging.info(f"Loading G from {params.G}")
            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
            G = G.to(device)
            if params.method == "whole-lattice-rescoring":
                # Add epsilon self-loops to G as we will compose
                # it with the whole lattice later
                G = k2.add_epsilon_self_loops(G)
                G = k2.arc_sort(G)

            # G.lm_scores is used to replace HLG.lm_scores during
            # LM rescoring.
            G.lm_scores = G.scores.clone()

        lattice = get_lattice(
            nnet_output=nnet_output,
            decoding_graph=HLG,
            supervision_segments=supervision_segments,
            search_beam=params.search_beam,
            output_beam=params.output_beam,
            min_active_states=params.min_active_states,
            max_active_states=params.max_active_states,
            subsampling_factor=params.subsampling_factor,
        )

        if params.method == "1best":
            logging.info("Use HLG decoding")
            best_path = one_best_decoding(
                lattice=lattice, use_double_scores=params.use_double_scores
            )
        elif params.method == "nbest-rescoring":
            logging.info("Use HLG decoding + LM rescoring")
            best_path_dict = rescore_with_n_best_list(
                lattice=lattice,
                G=G,
                num_paths=params.num_paths,
                lm_scale_list=[params.ngram_lm_scale],
                nbest_scale=params.nbest_scale,
            )
            # Only one LM scale was requested, so the dict has one entry.
            best_path = next(iter(best_path_dict.values()))
        else:
            # whole-lattice-rescoring
            logging.info("Use HLG decoding + LM rescoring")
            best_path_dict = rescore_with_whole_lattice(
                lattice=lattice,
                G_with_epsilon_loops=G,
                lm_scale_list=[params.ngram_lm_scale],
            )
            best_path = next(iter(best_path_dict.values()))

        # Map word IDs on the best path to their textual form.
        hyps = get_texts(best_path)
        word_sym_table = k2.SymbolTable.from_file(params.words_file)
        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
    else:
        # Already rejected by _check_method_arguments; kept as a safeguard.
        raise ValueError(f"Unsupported decoding method: {params.method}")

    s = "\n"
    for filename, hyp in zip(params.sound_files, hyps):
        words = " ".join(hyp)
        s += f"{filename}:\n{words}\n\n"
    logging.info(s)

    logging.info("Decoding Done")
if __name__ == "__main__":
    # INFO-level logging; each record carries a timestamp, the level and
    # the file/line that emitted it.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/scaling.py

File diff suppressed because it is too large Load Diff

View File

@ -71,6 +71,10 @@ class Decoder(nn.Module):
groups=embedding_dim,
bias=False,
)
else:
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
# when inference with torch.jit.script and context_size == 1
self.conv = nn.Identity()
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""

View File

@ -17,7 +17,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from scaling import Balancer
@ -95,6 +94,11 @@ class Decoder(nn.Module):
max_abs=1.0,
prob=0.05,
)
else:
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
# when inference with torch.jit.script and context_size == 1
self.conv = nn.Identity()
self.balancer2 = nn.Identity()
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""

View File

@ -22,7 +22,7 @@ from typing import Dict, List, Optional, Tuple, Union
import torch
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch import Tensor, nn
from torch.optim import Optimizer
@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer):
yield tuples # <-- calling code will do the actual optimization here!
for (stacked_params, _state, _names), batch in zip(tuples, batches):
for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
for i, p in enumerate(batch): # batch is list of Parameter
p.copy_(stacked_params[i])
@ -181,6 +181,7 @@ class ScaledAdam(BatchedOptimizer):
size_update_period=4,
clipping_update_period=100,
):
defaults = dict(
lr=lr,
clipping_scale=clipping_scale,
@ -326,7 +327,9 @@ class ScaledAdam(BatchedOptimizer):
batch = True
for group, group_params_names in zip(self.param_groups, self.parameters_names):
with self.batched_params(group["params"], group_params_names) as batches:
# batches is list of pairs (stacked_param, state). stacked_param is like
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
# a stacking dim, it is not a real dim.
@ -423,16 +426,19 @@ class ScaledAdam(BatchedOptimizer):
# parameters' state won't have been initialized yet.
return 1.0
clipping_update_period = group["clipping_update_period"]
scalar_lr_scale = group["scalar_lr_scale"]
tot_sumsq = torch.tensor(0.0, device=first_p.device)
for p, state, param_names in tuples:
for (p, state, param_names) in tuples:
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"ScaledAdam optimizer does not support sparse gradients"
)
if p.numel() == p.shape[0]: # a batch of scalars
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
tot_sumsq += (grad**2).sum() * (
scalar_lr_scale**2
) # sum() to change shape [1] to []
else:
tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
@ -443,64 +449,74 @@ class ScaledAdam(BatchedOptimizer):
)
first_state["model_norms"][step % clipping_update_period] = tot_norm
if step % clipping_update_period == 0:
irregular_estimate_steps = [
i for i in [10, 20, 40] if i < clipping_update_period
]
if step % clipping_update_period == 0 or step in irregular_estimate_steps:
# Print some stats.
# We don't reach here if step == 0 because we would have returned
# above.
sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
if step in irregular_estimate_steps:
sorted_norms = sorted_norms[-step:]
num_norms = sorted_norms.numel()
quartiles = []
for n in range(0, 5):
index = min(
clipping_update_period - 1, (clipping_update_period // 4) * n
)
index = min(num_norms - 1, (num_norms // 4) * n)
quartiles.append(sorted_norms[index].item())
median = quartiles[2]
if median - median != 0:
raise RuntimeError("Too many grads were not finite")
threshold = clipping_scale * median
if step in irregular_estimate_steps:
# use larger thresholds on first few steps of estimating threshold,
# as norm may be changing rapidly.
threshold = threshold * 2.0
first_state["model_norm_threshold"] = threshold
percent_clipped = (
first_state["num_clipped"] * 100.0 / clipping_update_period
first_state["num_clipped"] * 100.0 / num_norms
if "num_clipped" in first_state
else 0.0
)
first_state["num_clipped"] = 0
quartiles = " ".join(["%.3e" % x for x in quartiles])
logging.info(
logging.warn(
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
)
if step < clipping_update_period:
return 1.0 # We have not yet estimated a norm to clip to.
else:
try:
model_norm_threshold = first_state["model_norm_threshold"]
except KeyError:
logging.info(
"Warning: model_norm_threshold not in state: possibly "
"you changed config when restarting, adding clipping_scale option?"
)
return 1.0
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
if ans < 1.0:
first_state["num_clipped"] += 1
if ans < 0.1:
logging.warn(
f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
)
if self.show_dominant_parameters:
assert p.shape[0] == len(param_names)
self._show_gradient_dominating_parameter(tuples, tot_sumsq)
if ans != ans: # e.g. ans is nan
ans = 0.0
if ans == 0.0:
for p, state, param_names in tuples:
p.grad.zero_() # get rid of infinity()
try:
model_norm_threshold = first_state["model_norm_threshold"]
except KeyError:
return 1.0 # threshold has not yet been set.
return ans
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
if ans != ans: # e.g. ans is nan
ans = 0.0
if ans < 1.0:
first_state["num_clipped"] += 1
if ans < 0.1:
logging.warn(
f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
)
if self.show_dominant_parameters:
assert p.shape[0] == len(param_names)
self._show_gradient_dominating_parameter(
tuples, tot_sumsq, group["scalar_lr_scale"]
)
if ans == 0.0:
for (p, state, param_names) in tuples:
p.grad.zero_() # get rid of infinity()
return ans
def _show_gradient_dominating_parameter(
self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
self,
tuples: List[Tuple[Tensor, dict, List[str]]],
tot_sumsq: Tensor,
scalar_lr_scale: float,
):
"""
Show information of parameter which dominates tot_sumsq.
@ -516,29 +532,30 @@ class ScaledAdam(BatchedOptimizer):
from tuples, we still pass it to save some time.
"""
all_sumsq_orig = {}
for p, state, batch_param_names in tuples:
for (p, state, batch_param_names) in tuples:
# p is a stacked batch parameters.
batch_grad = p.grad
if p.numel() == p.shape[0]: # a batch of scalars
batch_sumsq_orig = batch_grad**2
# Dummy values used by following `zip` statement.
batch_rms_orig = torch.ones(p.shape[0])
batch_rms_orig = torch.full(
p.shape, scalar_lr_scale, device=batch_grad.device
)
else:
batch_rms_orig = state["param_rms"]
batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2
if batch_grad.ndim > 1:
# need to guard it with if-statement because sum() sums over
# all dims if dim == ().
batch_sumsq_orig = batch_sumsq_orig.sum(
dim=list(range(1, batch_grad.ndim))
)
for name, sumsq_orig, rms, grad in zip(
batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
):
proportion_orig = sumsq_orig / tot_sumsq
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
assert torch.isclose(
sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
torch.tensor(1.0),
)
sorted_by_proportion = {
k: v
for k, v in sorted(
@ -552,7 +569,7 @@ class ScaledAdam(BatchedOptimizer):
dominant_rms,
dominant_grad,
) = sorted_by_proportion[dominant_param_name]
logging.info(
logging.warn(
f"Parameter dominating tot_sumsq {dominant_param_name}"
f" with proportion {dominant_proportion:.2f},"
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
@ -826,7 +843,7 @@ class LRScheduler(object):
def print_lr(self, is_verbose, group, lr):
"""Display the current learning rate."""
if is_verbose:
logging.info(
logging.warn(
f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
f" of group {group} to {lr:.4e}."
)
@ -841,8 +858,14 @@ class Eden(LRScheduler):
where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches
and then stays constant at 1.
If you don't have the concept of epochs, or one epoch takes a very long time,
you can replace the notion of 'epoch' with some measure of the amount of data
processed, e.g. hours of data or frames of data, with 'lr_epochs' being set to
some measure representing "quite a lot of data": say, one fifth or one third
of an entire training run, but it doesn't matter much. You could also use
Eden2 which has only the notion of batches.
E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
We suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
Args:
optimizer: the optimizer to change the learning rates on
@ -888,6 +911,56 @@ class Eden(LRScheduler):
return [x * factor * warmup_factor for x in self.base_lrs]
class Eden2(LRScheduler):
    """
    Eden2 scheduler: a simpler variant of Eden whose learning rate depends
    only on the batch count, with no notion of epoch.

    For every base learning rate the scheduler returns:

      lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5 * warmup

    where `warmup` equals 1 once `batch >= warmup_batches`; before that it
    grows linearly with `batch`, starting from `warmup_start` at batch 0.

    E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam

    Args:
      optimizer: the optimizer to change the learning rates on
      lr_batches: the number of batches after which we start significantly
        decreasing the learning rate, suggest 5000.
      warmup_batches: the number of batches over which `warmup` ramps up.
      warmup_start: the value of `warmup` at batch 0; must be in [0, 1].
      verbose: passed through to LRScheduler.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        lr_batches: Union[int, float],
        warmup_batches: Union[int, float] = 500.0,
        warmup_start: float = 0.5,
        verbose: bool = False,
    ):
        super().__init__(optimizer, verbose)
        self.lr_batches = lr_batches
        self.warmup_batches = warmup_batches

        assert 0.0 <= warmup_start <= 1.0, warmup_start
        self.warmup_start = warmup_start

    def get_lr(self):
        """Return the current learning rate for each base learning rate."""
        # Decay term: 1 at batch 0, then shrinking as the batch count grows.
        decay = ((self.batch**2 + self.lr_batches**2) / self.lr_batches**2) ** -0.5

        if self.batch >= self.warmup_batches:
            warmup = 1.0
        else:
            # Linear ramp from warmup_start towards 1 during warm-up.
            warmup = self.warmup_start + (1.0 - self.warmup_start) * (
                self.batch / self.warmup_batches
            )

        return [base_lr * decay * warmup for base_lr in self.base_lrs]
def _test_eden():
m = torch.nn.Linear(100, 100)
optim = ScaledAdam(m.parameters(), lr=0.03)

View File

@ -0,0 +1 @@
../tdnn_lstm_ctc/asr_datamodule.py

View File

@ -0,0 +1,886 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from train import add_model_arguments, get_ctc_model, get_params
from transformer import encoder_padding_mask
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
rescore_with_rnn_lm,
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import (
AttributeDict,
get_texts,
load_averaged_model,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
    """Create the command-line parser for the decoding script.

    The options cover checkpoint selection and averaging, the decoding
    method and its n-best settings, the experiment/lang/LM directories and
    the RNN LM used by the rnn-lm method. `add_model_arguments` registers
    the model-related options on the same parser.

    Returns:
      The configured argparse.ArgumentParser.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Adjacent string literals are concatenated verbatim, so every piece
    # but the last must end with a space.
    parser.add_argument(
        "--epoch",
        type=int,
        default=77,
        help="It specifies the checkpoint to use for decoding. "
        "Note: Epoch counts from 0.",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=55,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--method",
        type=str,
        default="attention-decoder",
        help="""Decoding method.
Supported values are:
- (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM.
- (1) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (3) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
the highest score is the decoding result.
- (4) whole-lattice-rescoring. Rescore the decoding lattice with an
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
is the decoding result.
- (5) attention-decoder. Extract n paths from the LM rescored
lattice, the path with the highest score is the decoding result.
- (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume
you have trained an RNN LM using ./rnn_lm/train.py
- (7) nbest-oracle. Its WER is the lower bound of any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring method.
""",
    )

    parser.add_argument(
        "--num-paths",
        type=int,
        default=100,
        help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
""",
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=0.5,
        help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
A smaller value results in more unique paths.
""",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer_ctc/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_bpe_500",
        help="The lang dir",
    )

    parser.add_argument(
        "--lm-dir",
        type=str,
        default="data/lm",
        help="""The n-gram LM dir.
It should contain either G_4_gram.pt or G_4_gram.fst.txt
""",
    )

    parser.add_argument(
        "--rnn-lm-exp-dir",
        type=str,
        default="rnn_lm/exp",
        help="""Used only when --method is rnn-lm.
It specifies the path to RNN LM exp dir.
""",
    )

    parser.add_argument(
        "--rnn-lm-epoch",
        type=int,
        default=7,
        help="""Used only when --method is rnn-lm.
It specifies the checkpoint to use.
""",
    )

    parser.add_argument(
        "--rnn-lm-avg",
        type=int,
        default=2,
        help="""Used only when --method is rnn-lm.
It specifies the number of checkpoints to average.
""",
    )

    parser.add_argument(
        "--rnn-lm-embedding-dim",
        type=int,
        default=2048,
        help="Embedding dim of the model",
    )

    parser.add_argument(
        "--rnn-lm-hidden-dim",
        type=int,
        default=2048,
        help="Hidden dim of the model",
    )

    parser.add_argument(
        "--rnn-lm-num-layers",
        type=int,
        default=4,
        help="Number of RNN layers in the model",
    )

    parser.add_argument(
        "--rnn-lm-tie-weights",
        type=str2bool,
        default=False,
        help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
    )

    add_model_arguments(parser)

    return parser
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    rnn_lm_model: Optional[nn.Module],
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],
    batch: dict,
    word_table: k2.SymbolTable,
    sos_id: int,
    eos_id: int,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode a single batch and return the hypotheses keyed by setting.

    The returned dict maps a string describing the decoding setting to the
    hypotheses of the batch; ``value[i]`` is the word list decoded for the
    i-th utterance. The keys produced here are:

      - ``"ctc-decoding"`` for ``params.method == "ctc-decoding"``;
      - ``"oracle_<num_paths>_nbest_scale_<nbest_scale>"`` for
        ``"nbest-oracle"``;
      - ``"no_rescore"`` for ``"1best"``;
      - ``"no_rescore-nbest-scale-<nbest_scale>-<num_paths>"`` for
        ``"nbest"``;
      - for the rescoring methods (``"nbest-rescoring"``,
        ``"whole-lattice-rescoring"``, ``"attention-decoder"``,
        ``"rnn-lm"``) the keys are whatever the rescoring function returns
        (e.g., ``lm_scale_0.7``).

    Args:
      params:
        The decoding parameters. ``params.method`` selects the decoding
        method; the beams, ``subsampling_factor``, ``num_paths``,
        ``nbest_scale`` and ``use_double_scores`` are also read from it.
      model:
        The neural model. Its ``encoder``, ``ctc_output`` and ``decoder``
        attributes are used.
      rnn_lm_model:
        The RNN LM. Used only when ``params.method`` is ``"rnn-lm"``.
      HLG:
        The decoding graph. Must be given iff ``H`` is None.
      H:
        The CTC topology. When given, ``HLG`` must be None and
        ``bpe_model`` must not be None.
      bpe_model:
        The BPE model used to turn token IDs into text for ctc-decoding.
      batch:
        A dict with the keys ``"inputs"`` (a 3-D feature tensor) and
        ``"supervisions"``.
      word_table:
        The word symbol table used to map word IDs to words.
      sos_id:
        The token ID of the SOS.
      eos_id:
        The token ID of the EOS.
      G:
        The LM used by the rescoring methods.
    Returns:
      The dict described above, or None if the rescoring function returned
      None.
    """

    def ids_to_words(best_path) -> List[List[str]]:
        # Extract the word IDs of each path and look them up in word_table.
        return [[word_table[i] for i in ids] for ids in get_texts(best_path)]

    device = H.device if HLG is None else HLG.device

    feature = batch["inputs"]
    assert feature.ndim == 3
    # at entry, feature is (N, T, C)
    feature = feature.to(device)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    nnet_output, _ = model.encoder(feature, feature_lens)
    # nnet_output is (N, T, C)
    ctc_output = model.ctc_output(nnet_output)

    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
            supervisions["start_frame"] // params.subsampling_factor,
            supervisions["num_frames"] // params.subsampling_factor,
        ),
        1,
    ).to(torch.int32)

    # Exactly one of H and HLG is expected to be given.
    if H is not None:
        assert HLG is None
        assert bpe_model is not None
        decoding_graph = H
    else:
        assert HLG is not None
        decoding_graph = HLG

    lattice = get_lattice(
        nnet_output=ctc_output,
        decoding_graph=decoding_graph,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
        subsampling_factor=params.subsampling_factor,
    )

    if params.method == "ctc-decoding":
        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        # Note: `best_path.aux_labels` contains token IDs, not word IDs,
        # since we are using H, not HLG here.
        #
        # token_ids is a list-of-list of IDs
        token_ids = get_texts(best_path)
        # bpe_model.decode() gives one string per utterance,
        # e.g., ['xxx yyy zzz', ...]; split each one into words.
        hyps = [s.split() for s in bpe_model.decode(token_ids)]
        return {"ctc-decoding": hyps}

    if params.method == "nbest-oracle":
        # Note: You can also pass rescored lattices to it.
        # We choose the HLG decoded lattice for speed reasons
        # as HLG decoding is faster and the oracle WER
        # is only slightly worse than that of rescored lattices.
        best_path = nbest_oracle(
            lattice=lattice,
            num_paths=params.num_paths,
            ref_texts=supervisions["text"],
            word_table=word_table,
            nbest_scale=params.nbest_scale,
            oov="<UNK>",
        )
        hyps = ids_to_words(best_path)
        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
        return {key: hyps}

    if params.method == "1best":
        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        return {"no_rescore": ids_to_words(best_path)}

    if params.method == "nbest":
        best_path = nbest_decoding(
            lattice=lattice,
            num_paths=params.num_paths,
            use_double_scores=params.use_double_scores,
            nbest_scale=params.nbest_scale,
        )
        key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
        return {key: ids_to_words(best_path)}

    assert params.method in [
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "attention-decoder",
        "rnn-lm",
    ]

    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]

    nnet_output = nnet_output.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
    mask = encoder_padding_mask(nnet_output.size(0), supervisions)
    if mask is not None:
        mask = mask.to(nnet_output.device)

    # Unwrap the decoder if it is wrapped in something exposing `.module`.
    mmodel = getattr(model.decoder, "module", model.decoder)

    if params.method == "nbest-rescoring":
        best_path_dict = rescore_with_n_best_list(
            lattice=lattice,
            G=G,
            num_paths=params.num_paths,
            lm_scale_list=lm_scale_list,
            nbest_scale=params.nbest_scale,
        )
    elif params.method == "whole-lattice-rescoring":
        best_path_dict = rescore_with_whole_lattice(
            lattice=lattice,
            G_with_epsilon_loops=G,
            lm_scale_list=lm_scale_list,
        )
    elif params.method == "attention-decoder":
        # lattice uses a 3-gram LM. We rescore it with a 4-gram LM first.
        rescored_lattice = rescore_with_whole_lattice(
            lattice=lattice,
            G_with_epsilon_loops=G,
            lm_scale_list=None,
        )
        best_path_dict = rescore_with_attention_decoder(
            lattice=rescored_lattice,
            num_paths=params.num_paths,
            model=mmodel,
            memory=nnet_output,
            memory_key_padding_mask=mask,
            sos_id=sos_id,
            eos_id=eos_id,
            nbest_scale=params.nbest_scale,
        )
    elif params.method == "rnn-lm":
        # lattice uses a 3-gram LM. We rescore it with a 4-gram LM first.
        rescored_lattice = rescore_with_whole_lattice(
            lattice=lattice,
            G_with_epsilon_loops=G,
            lm_scale_list=None,
        )
        best_path_dict = rescore_with_rnn_lm(
            lattice=rescored_lattice,
            num_paths=params.num_paths,
            rnn_lm_model=rnn_lm_model,
            model=mmodel,
            memory=nnet_output,
            memory_key_padding_mask=mask,
            sos_id=sos_id,
            eos_id=eos_id,
            blank_id=0,
            nbest_scale=params.nbest_scale,
        )
    else:
        assert False, f"Unsupported decoding method: {params.method}"

    if best_path_dict is None:
        return None

    return {
        lm_scale_str: ids_to_words(best_path)
        for lm_scale_str, best_path in best_path_dict.items()
    }
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    rnn_lm_model: Optional[nn.Module],
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],
    word_table: k2.SymbolTable,
    sos_id: int,
    eos_id: int,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      rnn_lm_model:
        The neural model for RNN LM.
      HLG:
        The decoding graph. Used only when params.method is NOT ctc-decoding.
      H:
        The ctc topo. Used only when params.method is ctc-decoding.
      bpe_model:
        The BPE model. Used only when params.method is ctc-decoding.
      word_table:
        It is the word symbol table.
      sos_id:
        The token ID for SOS.
      eos_id:
        The token ID for EOS.
      G:
        An LM. It is not None when params.method is "nbest-rescoring"
        or "whole-lattice-rescoring". In general, the G in HLG
        is a 3-gram LM, while this G is a 4-gram LM.
    Returns:
      Return a dict, whose key may be "no_rescore" if no LM rescoring
      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
      Its value is a list of tuples. Each tuple contains three elements:
      the cut ID, the reference transcript (a list of words), and the
      predicted result (a list of words).
    """
    num_cuts = 0

    # Not every dataloader supports len(); fall back to a placeholder that
    # is only used in the progress log.
    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            rnn_lm_model=rnn_lm_model,
            HLG=HLG,
            H=H,
            bpe_model=bpe_model,
            batch=batch,
            word_table=word_table,
            G=G,
            sos_id=sos_id,
            eos_id=eos_id,
        )

        if hyps_dict is not None:
            for lm_scale, hyps in hyps_dict.items():
                assert len(hyps) == len(texts)
                results[lm_scale].extend(
                    (cut_id, ref_text.split(), hyp_words)
                    for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts)
                )
        else:
            # The batch decoded to nothing. The set of settings (keys) is
            # only known from earlier batches, so record an empty hypothesis
            # for every utterance under each setting seen so far.
            assert len(results) > 0, "It should not decode to empty in the first batch!"
            # Keep the same (cut_id, ref_words, hyp_words) layout as the
            # normal branch; entries of mixed length cannot be sorted or
            # unpacked consistently by the code that consumes the results.
            this_batch = [
                (cut_id, ref_text.split(), [])
                for cut_id, ref_text in zip(cut_ids, texts)
            ]
            for entries in results.values():
                entries.extend(this_batch)

        num_cuts += len(texts)

        if batch_idx % 100 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
    return results
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[int], List[int]]]],
):
    """Write transcripts, error statistics and a WER summary for a test set.

    For every setting (key) in ``results_dict`` this writes
    ``recogs-<test_set_name>-<key>.txt`` and ``errs-<test_set_name>-<key>.txt``
    into ``params.exp_dir``. It then writes ``wer-summary-<test_set_name>.txt``
    listing the settings sorted by the value returned from
    ``write_error_stats`` and logs the same summary.

    Args:
      params:
        Must provide ``method`` and ``exp_dir`` (``exp_dir`` is joined
        with ``/``).
      test_set_name:
        Name of the test set; used in file names and log messages.
      results_dict:
        Maps a decoding setting to its list of results.
    """
    # Set it to False for these methods since there are too many logs.
    enable_log = params.method not in ("attention-decoder", "rnn-lm")

    wer_by_key = dict()
    for key, results in results_dict.items():
        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        if enable_log:
            logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
        with open(errs_filename, "w") as f:
            wer_by_key[key] = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=enable_log
            )

        if enable_log:
            logging.info(f"Wrote detailed error stats to {errs_filename}")

    # Settings ordered by their error rate, lowest first.
    ranked = sorted(wer_by_key.items(), key=lambda x: x[1])

    errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in ranked:
            print(f"{key}\t{val}", file=f)

    pieces = [f"\nFor {test_set_name}, WER of different settings are:\n"]
    for rank, (key, val) in enumerate(ranked):
        # Only the first (lowest) entry is tagged as the best one.
        note = f"\tbest for {test_set_name}" if rank == 0 else ""
        pieces.append(f"{key}\t{val}{note}\n")
    logging.info("".join(pieces))
@torch.no_grad()
def main():
    """Entry point of the decoding script.

    Parses the command line, builds the decoding graph (``H`` for
    ctc-decoding, otherwise ``HLG`` loaded from ``lang_dir``), optionally
    loads or compiles the 4-gram ``G`` for the rescoring methods, loads the
    (possibly averaged) model checkpoint and, for ``rnn-lm``, the RNN LM.
    It then decodes the test-clean and test-other sets and saves the results.
    """
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
    args.lang_dir = Path(args.lang_dir)
    args.lm_dir = Path(args.lm_dir)

    params = get_params()
    params.update(vars(args))

    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
    logging.info("Decoding started")
    logging.info(params)

    lexicon = Lexicon(params.lang_dir)
    max_token_id = max(lexicon.tokens)
    num_classes = max_token_id + 1  # +1 for the blank
    params.vocab_size = num_classes

    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")

    logging.info(f"device: {device}")

    graph_compiler = BpeCtcTrainingGraphCompiler(
        params.lang_dir,
        device=device,
        sos_token="<sos/eos>",
        eos_token="<sos/eos>",
    )
    sos_id = graph_compiler.sos_id
    eos_id = graph_compiler.eos_id

    params.num_classes = num_classes
    params.sos_id = sos_id
    params.eos_id = eos_id

    if params.method == "ctc-decoding":
        HLG = None
        H = k2.ctc_topo(
            max_token=max_token_id,
            modified=False,
            device=device,
        )
        bpe_model = spm.SentencePieceProcessor()
        bpe_model.load(str(params.lang_dir / "bpe.model"))
    else:
        H = None
        bpe_model = None
        HLG = k2.Fsa.from_dict(
            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
        )
        assert HLG.requires_grad is False

        if not hasattr(HLG, "lm_scores"):
            HLG.lm_scores = HLG.scores.clone()

    needs_G = params.method in (
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "attention-decoder",
        "rnn-lm",
    )
    if needs_G:
        g_pt = params.lm_dir / "G_4_gram.pt"
        if g_pt.is_file():
            logging.info("Loading pre-compiled G_4_gram.pt")
            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
            G = k2.Fsa.from_dict(d)
        else:
            logging.info("Loading G_4_gram.fst.txt")
            logging.warning("It may take 8 minutes.")
            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
                first_word_disambig_id = lexicon.word_table["#0"]

                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
                # G.aux_labels is not needed in later computations, so
                # remove it here.
                del G.aux_labels
                # CAUTION: The following line is crucial.
                # Arcs entering the back-off state have label equal to #0.
                # We have to change it to 0 here.
                G.labels[G.labels >= first_word_disambig_id] = 0
                # See https://github.com/k2-fsa/k2/issues/874
                # for why we need to set G.properties to None
                G.__dict__["_properties"] = None
                G = k2.Fsa.from_fsas([G]).to(device)
                G = k2.arc_sort(G)
                # Save a dummy value so that it can be loaded in C++.
                # See https://github.com/pytorch/pytorch/issues/67902
                # for why we need to do this.
                G.dummy = 1

                # Cache the compiled G so later runs take the fast path.
                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")

        # Every method that needs G, except nbest-rescoring, composes G
        # with the whole lattice.
        if params.method != "nbest-rescoring":
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)
            G = k2.arc_sort(G)
            G = G.to(device)

        # G.lm_scores is used to replace HLG.lm_scores during
        # LM rescoring.
        G.lm_scores = G.scores.clone()
    else:
        G = None

    logging.info("About to create model")
    model = get_ctc_model(params)

    if params.use_averaged_model:
        if params.iter > 0:
            # One extra checkpoint is needed: the oldest one marks the
            # (excluded) start of the averaging range.
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            if len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
        model.to(device)
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=device,
            )
        )
    elif params.iter > 0:
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            : params.avg
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        if len(filenames) < params.avg:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))
    elif params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        # Epoch indices below 1 are skipped.
        filenames = [
            f"{params.exp_dir}/epoch-{i}.pt"
            for i in range(start, params.epoch + 1)
            if i >= 1
        ]
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))

    model.to(device)
    model.eval()
    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    rnn_lm_model = None
    if params.method == "rnn-lm":
        rnn_lm_model = RnnLmModel(
            vocab_size=params.num_classes,
            embedding_dim=params.rnn_lm_embedding_dim,
            hidden_dim=params.rnn_lm_hidden_dim,
            num_layers=params.rnn_lm_num_layers,
            tie_weights=params.rnn_lm_tie_weights,
        )
        if params.rnn_lm_avg == 1:
            load_checkpoint(
                f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
                rnn_lm_model,
            )
            rnn_lm_model.to(device)
        else:
            rnn_lm_model = load_averaged_model(
                params.rnn_lm_exp_dir,
                rnn_lm_model,
                params.rnn_lm_epoch,
                params.rnn_lm_avg,
                device,
            )
        rnn_lm_model.eval()

    # we need cut ids to display recognition results.
    args.return_cuts = True
    librispeech = LibriSpeechAsrDataModule(args)

    test_clean_cuts = librispeech.test_clean_cuts()
    test_other_cuts = librispeech.test_other_cuts()

    dataloaders = {
        "test-clean": librispeech.test_dataloaders(test_clean_cuts),
        "test-other": librispeech.test_dataloaders(test_other_cuts),
    }

    for test_set, dl in dataloaders.items():
        results_dict = decode_dataset(
            dl=dl,
            params=params,
            model=model,
            rnn_lm_model=rnn_lm_model,
            HLG=HLG,
            H=H,
            bpe_model=bpe_model,
            word_table=lexicon.word_table,
            G=G,
            sos_id=sos_id,
            eos_id=eos_id,
        )

        save_results(params=params, test_set_name=test_set, results_dict=results_dict)

    logging.info("Done!")
# Run the decoding entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1,298 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from label_smoothing import LabelSmoothingLoss
from torch.nn.utils.rnn import pad_sequence
from transformer import PositionalEncoding, TransformerDecoderLayer
class Decoder(nn.Module):
    """This class implements Transformer based decoder for an attention-based encoder-decoder
    model.

    It provides two entry points that share the same network:
    :func:`forward`, which returns the value of ``self.decoder_criterion``,
    and :func:`decoder_nll`, which returns a per-token negative
    log-likelihood.
    """

    def __init__(
        self,
        num_layers: int,
        num_classes: int,
        d_model: int = 256,
        nhead: int = 4,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        normalize_before: bool = True,
    ):
        """
        Args:
          num_layers:
            Number of layers. If it is not positive, no decoder sub-modules
            are created and only ``self.decoder_criterion`` is set (to None).
          num_classes:
            Number of tokens of the modeling unit including blank.
          d_model:
            Dimension of the input embedding, and of the decoder output.
          nhead:
            Passed to :class:`TransformerDecoderLayer` as ``nhead``.
          dim_feedforward:
            Passed to :class:`TransformerDecoderLayer` as ``dim_feedforward``.
          dropout:
            Passed to both :class:`PositionalEncoding` and
            :class:`TransformerDecoderLayer`.
          normalize_before:
            Passed to :class:`TransformerDecoderLayer`. If True, a
            ``nn.LayerNorm(d_model)`` is also given to
            ``nn.TransformerDecoder`` as its ``norm``.
        """
        super().__init__()

        if num_layers > 0:
            self.decoder_num_class = num_classes  # bpe model already has sos/eos symbol

            self.decoder_embed = nn.Embedding(
                num_embeddings=self.decoder_num_class, embedding_dim=d_model
            )
            self.decoder_pos = PositionalEncoding(d_model, dropout)

            decoder_layer = TransformerDecoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                normalize_before=normalize_before,
            )

            if normalize_before:
                decoder_norm = nn.LayerNorm(d_model)
            else:
                decoder_norm = None

            self.decoder = nn.TransformerDecoder(
                decoder_layer=decoder_layer,
                num_layers=num_layers,
                norm=decoder_norm,
            )

            # Projects decoder outputs to one score per class.
            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)

            self.decoder_criterion = LabelSmoothingLoss()
        else:
            # No attention decoder; none of the modules above are defined.
            self.decoder_criterion = None

    @torch.jit.export
    def forward(
        self,
        memory: torch.Tensor,
        memory_key_padding_mask: torch.Tensor,
        token_ids: List[List[int]],
        sos_id: int,
        eos_id: int,
    ) -> torch.Tensor:
        """
        Args:
          memory:
            It's the output of the encoder with shape (T, N, C)
          memory_key_padding_mask:
            The padding mask from the encoder.
          token_ids:
            A list-of-list IDs. Each sublist contains IDs for an utterance.
            The IDs can be either phone IDs or word piece IDs.
          sos_id:
            sos token id
          eos_id:
            eos token id
        Returns:
          The value returned by ``self.decoder_criterion`` for the padded
          predictions and targets: a scalar, the **sum** of label smoothing
          loss over utterances in the batch without any normalization.
        """
        # Decoder input: SOS followed by the tokens, padded with eos_id.
        ys_in = add_sos(token_ids, sos_id=sos_id)
        ys_in = [torch.tensor(y) for y in ys_in]
        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))

        # Decoder target: the tokens followed by EOS, padded with -1.
        ys_out = add_eos(token_ids, eos_id=eos_id)
        ys_out = [torch.tensor(y) for y in ys_out]
        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))

        device = memory.device
        ys_in_pad = ys_in_pad.to(device)
        ys_out_pad = ys_out_pad.to(device)

        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)

        # Positions equal to eos_id are treated as padding.
        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
        # TODO: Use length information to create the decoder padding mask
        # We set the first column to False since the first column in ys_in_pad
        # contains sos_id, which is the same as eos_id in our current setting.
        tgt_key_padding_mask[:, 0] = False

        tgt = self.decoder_embed(ys_in_pad)  # (N, T) -> (N, T, C)
        tgt = self.decoder_pos(tgt)
        tgt = tgt.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
        pred_pad = self.decoder(
            tgt=tgt,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )  # (T, N, C)
        pred_pad = pred_pad.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        pred_pad = self.decoder_output_layer(pred_pad)  # (N, T, C)

        decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)

        return decoder_loss

    @torch.jit.export
    def decoder_nll(
        self,
        memory: torch.Tensor,
        memory_key_padding_mask: torch.Tensor,
        token_ids: List[torch.Tensor],
        sos_id: int,
        eos_id: int,
    ) -> torch.Tensor:
        """
        Args:
          memory:
            It's the output of the encoder with shape (T, N, C)
          memory_key_padding_mask:
            The padding mask from the encoder.
          token_ids:
            A list-of-list IDs (e.g., word piece IDs).
            Each sublist represents an utterance. If the first entry is a
            tensor, every entry is converted to a list of ints first.
          sos_id:
            The token ID for SOS.
          eos_id:
            The token ID for EOS.
        Returns:
          A 2-D tensor of shape (len(token_ids), max_token_length)
          representing the cross entropy loss (i.e., negative log-likelihood).
          Padded target positions (target -1) are ignored by the loss.
        """
        # The common part between this function and `forward` could be
        # extracted as a separate function.
        if isinstance(token_ids[0], torch.Tensor):
            # This branch is executed by torchscript in C++.
            # See https://github.com/k2-fsa/k2/pull/870
            # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286
            token_ids = [tolist(t) for t in token_ids]

        # Decoder input: SOS followed by the tokens, padded with eos_id.
        ys_in = add_sos(token_ids, sos_id=sos_id)
        ys_in = [torch.tensor(y) for y in ys_in]
        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))

        # Decoder target: the tokens followed by EOS, padded with -1.
        ys_out = add_eos(token_ids, eos_id=eos_id)
        ys_out = [torch.tensor(y) for y in ys_out]
        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))

        device = memory.device
        ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
        ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)

        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)

        # Positions equal to eos_id are treated as padding.
        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
        # TODO: Use length information to create the decoder padding mask
        # We set the first column to False since the first column in ys_in_pad
        # contains sos_id, which is the same as eos_id in our current setting.
        tgt_key_padding_mask[:, 0] = False

        tgt = self.decoder_embed(ys_in_pad)  # (B, T) -> (B, T, F)
        tgt = self.decoder_pos(tgt)
        tgt = tgt.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
        pred_pad = self.decoder(
            tgt=tgt,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )  # (T, B, F)
        pred_pad = pred_pad.permute(1, 0, 2)  # (T, B, F) -> (B, T, F)
        pred_pad = self.decoder_output_layer(pred_pad)  # (B, T, F)
        # nll: negative log-likelihood
        # ignore_index=-1 matches the padding value used for ys_out_pad.
        nll = torch.nn.functional.cross_entropy(
            pred_pad.view(-1, self.decoder_num_class),
            ys_out_pad.view(-1),
            ignore_index=-1,
            reduction="none",
        )

        # Restore the per-utterance layout: one row per utterance.
        nll = nll.view(pred_pad.shape[0], -1)

        return nll
def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
    """Return a copy of ``token_ids`` with ``sos_id`` prepended to each utterance.

    Args:
      token_ids:
        A list-of-list of token IDs. Each sublist holds the token IDs
        (e.g., word piece IDs) of one utterance. It is not modified.
      sos_id:
        The ID of the SOS token.

    Return:
      A new list-of-list in which every sublist begins with ``sos_id``.
    """
    with_sos: List[List[int]] = []
    for utt in token_ids:
        with_sos.append([sos_id] + utt)
    return with_sos
def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
    """Return a copy of ``token_ids`` with ``eos_id`` appended to each utterance.

    Args:
      token_ids:
        A list-of-list of token IDs. Each sublist holds the token IDs
        (e.g., word piece IDs) of one utterance. It is not modified.
      eos_id:
        The ID of the EOS token.

    Return:
      A new list-of-list in which every sublist finishes with ``eos_id``.
    """
    with_eos: List[List[int]] = []
    for utt in token_ids:
        with_eos.append(utt + [eos_id])
    return with_eos
def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
    """Build a padding mask marking the entries equal to ``ignore_id``.

    Masked (padding) positions are True; all other positions are False.

    Args:
      ys_pad:
        padded tensor of dimension (batch_size, input_length).
      ignore_id:
        the ignored number (the padding number) in ys_pad

    Returns:
      Tensor:
        a bool tensor of the same shape as the input tensor.
    """
    return ys_pad.eq(ignore_id)
def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    """Build the causal mask used for masked self-attention.

    Entry (i, j) is float(0.0) when j <= i (position i may attend to
    position j) and float('-inf') otherwise.

    For instance, if sz is 3, it returns::

        tensor([[0., -inf, -inf],
                [0., 0., -inf],
                [0., 0., 0]])

    Args:
      sz: mask size

    Returns:
      A square mask of dimension (sz, sz)
    """
    # Upper triangle (diagonal included), transposed so that the allowed
    # positions form the lower triangle.
    upper = torch.triu(torch.ones(sz, sz)) == 1
    allowed = upper.transpose(0, 1)

    filled = allowed.float()
    filled = filled.masked_fill(allowed == 0, float("-inf"))
    filled = filled.masked_fill(allowed == 1, float(0.0))
    return filled
def tolist(t: torch.Tensor) -> List[int]:
    """Convert a tensor to a list of ints.

    Used by jit: the result of ``Tensor.tolist()`` is annotated as
    ``List[int]`` via ``torch.jit.annotate``.
    """
    values = torch.jit.annotate(List[int], t.tolist())
    return values

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/encoder_interface.py

View File

@ -0,0 +1,237 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
import argparse
import logging
from pathlib import Path
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_ctc_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import num_tokens, str2bool
def get_parser():
    """Build the command-line parser of the export script.

    It registers the checkpoint selection options (``--epoch``, ``--iter``,
    ``--avg``, ``--use-averaged-model``), ``--exp-dir``, ``--tokens`` and
    ``--jit``, and then lets :func:`add_model_arguments` add the model
    options.

    Returns:
      The configured :class:`argparse.ArgumentParser`.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # NOTE: adjacent string literals are concatenated verbatim, so every
    # piece of a multi-part help text must carry its own separating space.
    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="It specifies the checkpoint to use for decoding. "
        "Note: Epoch counts from 0.",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="If positive, --epoch is ignored and it\n"
        "will use the checkpoint exp_dir/checkpoint-iter.pt.\n"
        "You can specify --avg to use more checkpoints for model averaging.\n",
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=9,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer_ctc/exp",
        help="It specifies the directory where all training related\n"
        "files, e.g., checkpoints, log, etc, are saved\n",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/lang_bpe_500/tokens.txt",
        help="Path to the tokens.txt",
    )

    parser.add_argument(
        "--jit",
        type=str2bool,
        default=True,
        help="True to save a model after applying torch.jit.script.\n",
    )

    add_model_arguments(parser)

    return parser
def main():
    """Entry point of the export script.

    Loads a single checkpoint or an average of several, depending on
    ``--use-averaged-model``, ``--iter``, ``--epoch`` and ``--avg``. The
    model is then moved to CPU and saved in ``exp_dir``: as ``cpu_jit.pt``
    (TorchScript) when ``--jit`` is true, otherwise as ``pretrained.pt``
    (a dict holding the state dict under the key ``"model"``).
    """
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    token_table = k2.SymbolTable.from_file(params.tokens)
    params.blank_id = token_table["<blk>"]
    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

    logging.info(params)

    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")

    logging.info(f"device: {device}")

    model = get_ctc_model(params)

    if params.use_averaged_model:
        if params.iter > 0:
            # One extra checkpoint is needed: the oldest one marks the
            # (excluded) start of the averaging range.
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            if len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
        model.to(device)
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=device,
            )
        )
    elif params.iter > 0:
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            : params.avg
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        if len(filenames) < params.avg:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))
    elif params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        # Epoch indices below 1 are skipped.
        filenames = [
            f"{params.exp_dir}/epoch-{i}.pt"
            for i in range(start, params.epoch + 1)
            if i >= 1
        ]
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))

    model.to("cpu")
    model.eval()

    if not params.jit:
        logging.info("Not using torch.jit.script")
        # Save it using a format so that it can be loaded
        # by :func:`load_checkpoint`
        filename = params.exp_dir / "pretrained.pt"
        torch.save({"model": model.state_dict()}, str(filename))
        logging.info(f"Saved to {filename}")
        return

    logging.info("Using torch.jit.script")
    # We won't use the forward() method of the model in C++, so just ignore
    # it here.
    # Otherwise, one of its arguments is a ragged tensor and is not
    # torch scriptable.
    convert_scaled_to_non_scaled(model, inplace=True)
    model = torch.jit.script(model)
    filename = params.exp_dir / "cpu_jit.pt"
    model.save(str(filename))
    logging.info(f"Saved to {filename}")
# Configure logging and run the export only when executed as a script.
if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    main()

View File

@ -0,0 +1 @@
../conformer_ctc/label_smoothing.py

View File

@ -0,0 +1,158 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple

import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from transformer import encoder_padding_mask

from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.utils import encode_supervisions
class CTCModel(nn.Module):
    """It implements a CTC model with an auxiliary attention head."""

    def __init__(
        self,
        encoder: EncoderInterface,
        decoder: Optional[nn.Module],
        encoder_dim: int,
        vocab_size: int,
    ):
        """
        Args:
          encoder:
            An instance of `EncoderInterface`. The shared encoder for the CTC
            and attention branches.
          decoder:
            An instance of `nn.Module`. This is the decoder for the attention
            branch. It may be None, in which case `forward` skips the
            attention branch and returns a zero attention loss.
          encoder_dim:
            Dimension of the encoder output.
          vocab_size:
            Number of tokens of the modeling unit including blank.
        """
        super().__init__()
        assert isinstance(encoder, EncoderInterface), type(encoder)

        self.encoder = encoder
        # CTC head: dropout -> linear projection to the vocabulary
        # -> log-softmax over the last dimension.
        self.ctc_output = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(encoder_dim, vocab_size),
            nn.LogSoftmax(dim=-1),
        )
        self.decoder = decoder

    @torch.jit.ignore
    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        supervisions: Dict[str, Any],
        graph_compiler: BpeCtcTrainingGraphCompiler,
        subsampling_factor: int = 1,
        beam_size: int = 10,
        reduction: str = "sum",
        use_double_scores: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Args:
          x:
            Tensor of dimension (N, T, C) where N is the batch size,
            T is the number of frames, and C is the feature dimension.
          x_lens:
            Tensor of dimension (N,) where N is the batch size.
          supervisions:
            Supervisions are used in training. It is indexed with the key
            "text" here and is also passed to `encode_supervisions` and
            `encoder_padding_mask`.
          graph_compiler:
            It is used to compile a decoding graph from texts.
          subsampling_factor:
            It is used to compute the `supervisions` for the encoder.
          beam_size:
            Beam size used in `k2.ctc_loss`.
          reduction:
            Reduction method used in `k2.ctc_loss`.
          use_double_scores:
            If True, use double precision in `k2.ctc_loss`.
        Returns:
          Return a tuple `(ctc_loss, att_loss, num_frames)` containing the
          CTC loss, the attention loss, and the total number of frames.
          If there is no attention decoder, `att_loss` is a tensor holding
          a single 0, placed on the same device as `ctc_loss`.
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape

        nnet_output, x_lens = self.encoder(x, x_lens)
        assert torch.all(x_lens > 0)

        # compute ctc log-probs
        ctc_output = self.ctc_output(nnet_output)

        # NOTE: We need `encode_supervisions` to sort sequences with
        # different duration in decreasing order, required by
        # `k2.intersect_dense` called in `k2.ctc_loss`
        supervision_segments, texts = encode_supervisions(
            supervisions, subsampling_factor=subsampling_factor
        )
        num_frames = supervision_segments[:, 2].sum().item()

        # Works with a BPE model
        token_ids = graph_compiler.texts_to_ids(texts)
        decoding_graph = graph_compiler.compile(token_ids)

        dense_fsa_vec = k2.DenseFsaVec(
            ctc_output,
            supervision_segments.cpu(),
            allow_truncate=subsampling_factor - 1,
        )

        ctc_loss = k2.ctc_loss(
            decoding_graph=decoding_graph,
            dense_fsa_vec=dense_fsa_vec,
            output_beam=beam_size,
            reduction=reduction,
            use_double_scores=use_double_scores,
        )

        if self.decoder is not None:
            nnet_output = nnet_output.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

            # Unwrap the decoder if it is held by a wrapper that exposes the
            # real module through a `module` attribute.
            mmodel = (
                self.decoder.module if hasattr(self.decoder, "module") else self.decoder
            )
            # Note: We need to generate an unsorted version of token_ids
            # `encode_supervisions()` called above sorts text, but
            # encoder_memory and memory_mask are not sorted, so we
            # use an unsorted version `supervisions["text"]` to regenerate
            # the token_ids
            #
            # See https://github.com/k2-fsa/icefall/issues/97
            # for more details
            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
            mask = encoder_padding_mask(nnet_output.size(0), supervisions)
            mask = mask.to(nnet_output.device) if mask is not None else None
            att_loss = mmodel.forward(
                nnet_output,
                mask,
                token_ids=unsorted_token_ids,
                sos_id=graph_compiler.sos_id,
                eos_id=graph_compiler.eos_id,
            )
        else:
            # No attention branch. Create the placeholder on the same device
            # as `ctc_loss` so that callers can combine the two losses
            # without hitting a device mismatch.
            att_loss = torch.tensor([0], device=ctc_loss.device)

        return ctc_loss, att_loss, num_frames

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/optim.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/scaling.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/scaling_converter.py

View File

@ -0,0 +1 @@
../conformer_ctc/subsampling.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../conformer_ctc/transformer.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/zipformer.py

Some files were not shown because too many files have changed in this diff Show More