Add the upgraded Zipformer model (#1058)

* add the zipformer codes, copied from branch from_dan_scaled_adam_exp1119

* support model export with torch.jit.script

* update RESULTS.md

* support exporting streaming model with torch.jit.script

* add results of streaming models, with some minor changes

* update README.md

* add CI test

* update k2 version in requirements-ci.txt

* update pyproject.toml
This commit is contained in:
Zengwei Yao 2023-05-19 16:47:59 +08:00 committed by GitHub
parent a5bbfc6f7e
commit f18b539fbc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 12381 additions and 52 deletions

View File

@ -13,6 +13,7 @@ per-file-ignores =
egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
egs/librispeech/ASR/conformer_ctc*/*py: E501,
egs/librispeech/ASR/zipformer_mmi/*.py: E501, E203
egs/librispeech/ASR/zipformer/*.py: E501, E203
egs/librispeech/ASR/RESULTS.md: E999,
# invalid escape sequence (cause by tex formular), W605

View File

@ -0,0 +1,115 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
log "Downloading pre-trained model from $repo_url"
git lfs install
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/jit_script_chunk_16_left_128.pt"
git lfs pull --include "exp/pretrained.pt"
ln -s pretrained.pt epoch-99.pt
ls -lh *.pt
popd
log "Export to torchscript model"
./zipformer/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--epoch 99 \
--avg 1 \
--jit 1
ls -lh $repo/exp/*.pt
log "Decode with models exported by torch.jit.script()"
./zipformer/jit_pretrained_streaming.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--nn-model-filename $repo/exp/jit_script_chunk_16_left_128.pt \
$repo/test_wavs/1089-134686-0001.wav
for method in greedy_search modified_beam_search fast_beam_search; do
log "$method"
./zipformer/pretrained.py \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
mkdir -p zipformer/exp
ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
ls -lh zipformer/exp
log "Decoding test-clean and test-other"
# use a small value for decoding with CPU
max_duration=100
for method in greedy_search fast_beam_search modified_beam_search; do
log "Simulated streaming decoding with $method"
./zipformer/decode.py \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--use-averaged-model 0 \
--max-duration $max_duration \
--exp-dir zipformer/exp
done
for method in greedy_search fast_beam_search modified_beam_search; do
log "Chunk-wise streaming decoding with $method"
./zipformer/streaming_decode.py \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--use-averaged-model 0 \
--max-duration $max_duration \
--exp-dir zipformer/exp
done
rm zipformer/exp/*.pt
fi

View File

@ -0,0 +1,93 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
log "Downloading pre-trained model from $repo_url"
git lfs install
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/jit_script.pt"
git lfs pull --include "exp/pretrained.pt"
ln -s pretrained.pt epoch-99.pt
ls -lh *.pt
popd
log "Export to torchscript model"
./zipformer/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--jit 1
ls -lh $repo/exp/*.pt
log "Decode with models exported by torch.jit.script()"
./zipformer/jit_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--nn-model-filename $repo/exp/jit_script.pt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
for method in greedy_search modified_beam_search fast_beam_search; do
log "$method"
./zipformer/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
mkdir -p zipformer/exp
ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
ls -lh zipformer/exp
log "Decoding test-clean and test-other"
# use a small value for decoding with CPU
max_duration=100
for method in greedy_search fast_beam_search modified_beam_search; do
log "Decoding with $method"
./zipformer/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--use-averaged-model 0 \
--max-duration $max_duration \
--exp-dir zipformer/exp
done
rm zipformer/exp/*.pt
fi

View File

@ -0,0 +1,174 @@
# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-librispeech-streaming-zipformer-2023-05-18
# zipformer
on:
push:
branches:
- master
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
concurrency:
group: run_librispeech_2023_05_18_streaming_zipformer-${{ github.ref }}
cancel-in-progress: true
jobs:
run_librispeech_2023_05_18_streaming_zipformer:
if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest]
python-version: [3.8]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/install-kaldifeat.sh
- name: Cache LibriSpeech test-clean and test-other datasets
id: libri-test-clean-and-test-other-data
uses: actions/cache@v2
with:
path: |
~/tmp/download
key: cache-libri-test-clean-and-test-other
- name: Download LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
- name: Prepare manifests for LibriSpeech test-clean and test-other
shell: bash
run: |
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
- name: Cache LibriSpeech test-clean and test-other fbank features
id: libri-test-clean-and-test-other-fbank
uses: actions/cache@v2
with:
path: |
~/tmp/fbank-libri
key: cache-libri-fbank-test-clean-and-test-other-v2
- name: Compute fbank for LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
mkdir -p egs/librispeech/ASR/data
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh
- name: Display decoding results for librispeech zipformer
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
shell: bash
run: |
cd egs/librispeech/ASR/
tree ./zipformer/exp
cd zipformer
echo "results for zipformer, simulated streaming decoding"
echo "===greedy search==="
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===fast_beam_search==="
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified beam search==="
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "results for zipformer, chunk-wise streaming decoding"
echo "===greedy search==="
find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===fast_beam_search==="
find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified beam search==="
find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for librispeech zipformer
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer-2022-11-11
path: egs/librispeech/ASR/zipformer/exp/

View File

@ -0,0 +1,159 @@
# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-librispeech-zipformer-2023-05-18
# zipformer
on:
push:
branches:
- master
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
concurrency:
group: run_librispeech_2023_05_18_zipformer-${{ github.ref }}
cancel-in-progress: true
jobs:
run_librispeech_2023_05_18_zipformer:
if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest]
python-version: [3.8]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/install-kaldifeat.sh
- name: Cache LibriSpeech test-clean and test-other datasets
id: libri-test-clean-and-test-other-data
uses: actions/cache@v2
with:
path: |
~/tmp/download
key: cache-libri-test-clean-and-test-other
- name: Download LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
- name: Prepare manifests for LibriSpeech test-clean and test-other
shell: bash
run: |
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
- name: Cache LibriSpeech test-clean and test-other fbank features
id: libri-test-clean-and-test-other-fbank
uses: actions/cache@v2
with:
path: |
~/tmp/fbank-libri
key: cache-libri-fbank-test-clean-and-test-other-v2
- name: Compute fbank for LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
mkdir -p egs/librispeech/ASR/data
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-librispeech-zipformer-2023-05-18.sh
- name: Display decoding results for librispeech zipformer
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
shell: bash
run: |
cd egs/librispeech/ASR/
tree ./zipformer/exp
cd zipformer
echo "results for zipformer"
echo "===greedy search==="
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===fast_beam_search==="
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified beam search==="
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for librispeech zipformer
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer-2022-11-11
path: egs/librispeech/ASR/zipformer/exp/

View File

@ -113,6 +113,7 @@ jobs:
cd ../pruned_transducer_stateless4
pytest -v -s
echo $PYTHONPATH
cd ../pruned_transducer_stateless7
pytest -v -s

View File

@ -55,12 +55,13 @@ We provide a Colab notebook for this recipe: [![Open In Colab](https://colab.res
Please see <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>
for the **latest** results.
We provide 4 models for this recipe:
We provide 5 models for this recipe:
- [conformer CTC model][LibriSpeech_conformer_ctc]
- [TDNN LSTM CTC model][LibriSpeech_tdnn_lstm_ctc]
- [Transducer: Conformer encoder + LSTM decoder][LibriSpeech_transducer]
- [Transducer: Conformer encoder + Embedding decoder][LibriSpeech_transducer_stateless]
- [Transducer: Zipformer encoder + Embedding decoder][LibriSpeech_zipformer]
#### Conformer CTC Model
@ -115,9 +116,11 @@ We provide a Colab notebook to run a pre-trained transducer conformer + stateles
#### k2 pruned RNN-T
| | test-clean | test-other |
|-----|------------|------------|
| WER | 2.15 | 5.20 |
| Encoder | Params | test-clean | test-other |
|-----------------|--------|------------|------------|
| zipformer | 65.5M | 2.21 | 4.91 |
| zipformer-small | 23.2M | 2.46 | 5.83 |
| zipformer-large | 148.4M | 2.11 | 4.77 |
Note: No auxiliary losses are used in the training and no LMs are used
in the decoding.
@ -361,6 +364,7 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
[LibriSpeech_conformer_ctc]: egs/librispeech/ASR/conformer_ctc
[LibriSpeech_transducer]: egs/librispeech/ASR/transducer
[LibriSpeech_transducer_stateless]: egs/librispeech/ASR/transducer_stateless
[LibriSpeech_zipformer]: egs/librispeech/ASR/zipformer
[Aishell_tdnn_lstm_ctc]: egs/aishell/ASR/tdnn_lstm_ctc
[Aishell_conformer_ctc]: egs/aishell/ASR/conformer_ctc
[Aishell_pruned_transducer_stateless7]: egs/aishell/ASR/pruned_transducer_stateless7_bbpe

View File

@ -34,6 +34,7 @@ The following table lists the differences among them.
| `lstm_transducer_stateless` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model |
| `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) |
| `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty |
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe |
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).

View File

@ -1,5 +1,249 @@
## Results
### zipformer (zipformer + pruned stateless transducer)
See <https://github.com/k2-fsa/icefall/pull/1058> for more details.
[zipformer](./zipformer)
#### Non-streaming
##### normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M
The tensorboard log can be found at
<https://tensorboard.dev/experiment/cBaoIabCQxSDsyZM7FzqZA/>
You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15>
You can use <https://github.com/k2-fsa/sherpa> to deploy it.
| decoding method | test-clean | test-other | comment |
|----------------------|------------|------------|--------------------|
| greedy_search | 2.27 | 5.1 | --epoch 30 --avg 9 |
| modified_beam_search | 2.25 | 5.06 | --epoch 30 --avg 9 |
| fast_beam_search | 2.25 | 5.04 | --epoch 30 --avg 9 |
| greedy_search | 2.23 | 4.96 | --epoch 40 --avg 16 |
| modified_beam_search | 2.21 | 4.91 | --epoch 40 --avg 16 |
| fast_beam_search | 2.24 | 4.93 | --epoch 40 --avg 16 |
The training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer/train.py \
--world-size 4 \
--num-epochs 40 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--causal 0 \
--full-libri 1 \
--max-duration 1000
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in greedy_search modified_beam_search fast_beam_search; do
./zipformer/decode.py \
--epoch 30 \
--avg 9 \
--use-averaged-model 1 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method $m
done
```
##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M
The tensorboard log can be found at
<https://tensorboard.dev/experiment/53P4tL22TpO0UdiL0kPaLg/>
You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-small-2023-05-16>
You can use <https://github.com/k2-fsa/sherpa> to deploy it.
| decoding method | test-clean | test-other | comment |
|----------------------|------------|------------|--------------------|
| greedy_search | 2.64 | 6.14 | --epoch 30 --avg 8 |
| modified_beam_search | 2.6 | 6.01 | --epoch 30 --avg 8 |
| fast_beam_search | 2.62 | 6.06 | --epoch 30 --avg 8 |
| greedy_search | 2.49 | 5.91 | --epoch 40 --avg 13 |
| modified_beam_search | 2.46 | 5.83 | --epoch 40 --avg 13 |
| fast_beam_search | 2.46 | 5.87 | --epoch 40 --avg 13 |
The training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
./zipformer/train.py \
--world-size 2 \
--num-epochs 40 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-small \
--causal 0 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,768,768,768,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--base-lr 0.04 \
--full-libri 1 \
--max-duration 1500
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in greedy_search modified_beam_search fast_beam_search; do
./zipformer/decode.py \
--epoch 40 \
--avg 13 \
--exp-dir zipformer/exp-small \
--max-duration 600 \
--causal 0 \
--decoding-method $m \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,768,768,768,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192
done
```
##### large-scaled model, number of model parameters: 148439574, i.e., 148.4 M
The tensorboard log can be found at
<https://tensorboard.dev/experiment/HJ74wWYpQAGSzETkmQnrmQ/>
You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-large-2023-05-16>
You can use <https://github.com/k2-fsa/sherpa> to deploy it.
| decoding method | test-clean | test-other | comment |
|----------------------|------------|------------|--------------------|
| greedy_search | 2.12 | 4.91 | --epoch 30 --avg 9 |
| modified_beam_search | 2.11 | 4.9 | --epoch 30 --avg 9 |
| fast_beam_search | 2.13 | 4.93 | --epoch 30 --avg 9 |
| greedy_search | 2.12 | 4.8 | --epoch 40 --avg 13 |
| modified_beam_search | 2.11 | 4.7 | --epoch 40 --avg 13 |
| fast_beam_search | 2.13 | 4.78 | --epoch 40 --avg 13 |
The training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer/train.py \
--world-size 4 \
--num-epochs 40 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-large \
--causal 0 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--full-libri 1 \
--max-duration 1000
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in greedy_search modified_beam_search fast_beam_search; do
./zipformer/decode.py \
--epoch 40 \
--avg 16 \
--exp-dir zipformer/exp-large \
--max-duration 600 \
--causal 0 \
--decoding-method $m \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192
done
```
#### streaming
##### normal-scaled model, number of model parameters: 66110931, i.e., 66.11 M
The tensorboard log can be found at
<https://tensorboard.dev/experiment/9rD0i6rMSWq1O61poWi71A>
You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17>
You can use <https://github.com/k2-fsa/sherpa> to deploy it.
| decoding method | chunk size | test-clean | test-other | decoding mode | comment |
|----------------------|------------|------------|------------|---------------------|--------------------|
| greedy_search | 320ms | 3.06 | 7.81 | simulated streaming | --epoch 30 --avg 8 --chunk-size 16 --left-context-frames 128 |
| greedy_search | 320ms | 3.06 | 7.79 | chunk-wise | --epoch 30 --avg 8 --chunk-size 16 --left-context-frames 128 |
| modified_beam_search | 320ms | 3.01 | 7.69 | simulated streaming | --epoch 30 --avg 8 --chunk-size 16 --left-context-frames 128 |
| modified_beam_search | 320ms | 3.05 | 7.69 | chunk-wise | --epoch 30 --avg 8 --chunk-size 16 --left-context-frames 128 |
| fast_beam_search | 320ms | 3.04 | 7.68 | simulated streaming | --epoch 30 --avg 8 --chunk-size 16 --left-context-frames 128 |
| fast_beam_search | 320ms | 3.07 | 7.69 | chunk-wise | --epoch 30 --avg 8 --chunk-size 16 --left-context-frames 128 |
| greedy_search | 640ms | 2.81 | 7.15 | simulated streaming | --epoch 30 --avg 8 --chunk-size 32 --left-context-frames 256 |
| greedy_search | 640ms | 2.84 | 7.16 | chunk-wise | --epoch 30 --avg 8 --chunk-size 32 --left-context-frames 256 |
| modified_beam_search | 640ms | 2.79 | 7.05 | simulated streaming | --epoch 30 --avg 8 --chunk-size 32 --left-context-frames 256 |
| modified_beam_search | 640ms | 2.81 | 7.11 | chunk-wise | --epoch 30 --avg 8 --chunk-size 32 --left-context-frames 256 |
| fast_beam_search | 640ms | 2.84 | 7.04 | simulated streaming | --epoch 30 --avg 8 --chunk-size 32 --left-context-frames 256 |
| fast_beam_search | 640ms | 2.83 | 7.1 | chunk-wise | --epoch 30 --avg 8 --chunk-size 32 --left-context-frames 256 |
Note: For decoding mode, `simulated streaming` indicates feeding full utterance during decoding using `decode.py`,
while `chunk-size` indicates feeding certain number of frames at each time using `streaming_decode.py`.
The training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer/train.py \
--world-size 4 \
--num-epochs 40 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-causal \
--causal 1 \
--full-libri 1 \
--max-duration 1000
```
The simulated streaming decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in greedy_search modified_beam_search fast_beam_search; do
./zipformer/decode.py \
--epoch 30 \
--avg 8 \
--use-averaged-model 1 \
--exp-dir ./zipformer/exp-causal \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--max-duration 600 \
--decoding-method $m
done
```
The chunk-wise streaming decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in greedy_search modified_beam_search fast_beam_search; do
./zipformer/streaming_decode.py \
--epoch 30 \
--avg 8 \
--use-averaged-model 1 \
--exp-dir ./zipformer/exp-causal \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--num-decode-streams 2000 \
--decoding-method $m
done
```
### pruned_transducer_stateless7 (zipformer + multidataset(LibriSpeech + GigaSpeech + CommonVoice 13.0))
See <https://github.com/k2-fsa/icefall/pull/1010> for more details.

View File

@ -0,0 +1 @@
../transducer/asr_datamodule.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1,834 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
if params.causal:
# this seems to cause insertions at the end of the utterance if used with zipformer.
pad_len = 30
feature_lens += pad_len
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, pad_len),
value=LOG_EPS,
)
x, x_lens = model.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = model.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,148 @@
# Copyright 2022 Xiaomi Corp. (authors: Wei Kang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple
import k2
import torch
from beam_search import Hypothesis, HypothesisList
from icefall.utils import AttributeDict
class DecodeStream(object):
def __init__(
self,
params: AttributeDict,
cut_id: str,
initial_states: List[torch.Tensor],
decoding_graph: Optional[k2.Fsa] = None,
device: torch.device = torch.device("cpu"),
) -> None:
"""
Args:
initial_states:
Initial decode states of the model, e.g. the return value of
`get_init_state` in conformer.py
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
Used only when decoding_method is fast_beam_search.
device:
The device to run this stream.
"""
if params.decoding_method == "fast_beam_search":
assert decoding_graph is not None
assert device == decoding_graph.device
self.params = params
self.cut_id = cut_id
self.LOG_EPS = math.log(1e-10)
self.states = initial_states
# It contains a 2-D tensors representing the feature frames.
self.features: torch.Tensor = None
self.num_frames: int = 0
# how many frames have been processed. (before subsampling).
# we only modify this value in `func:get_feature_frames`.
self.num_processed_frames: int = 0
self._done: bool = False
# The transcript of current utterance.
self.ground_truth: str = ""
# The decoding result (partial or final) of current utterance.
self.hyp: List = []
# how many frames have been processed, at encoder output
self.done_frames: int = 0
# The encoder_embed subsample features (T - 7) // 2
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
self.pad_length = 7 + 2 * 3
if params.decoding_method == "greedy_search":
self.hyp = [params.blank_id] * params.context_size
elif params.decoding_method == "modified_beam_search":
self.hyps = HypothesisList()
self.hyps.add(
Hypothesis(
ys=[params.blank_id] * params.context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
elif params.decoding_method == "fast_beam_search":
# The rnnt_decoding_stream for fast_beam_search.
self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream(
decoding_graph
)
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
@property
def done(self) -> bool:
"""Return True if all the features are processed."""
return self._done
@property
def id(self) -> str:
return self.cut_id
def set_features(
self,
features: torch.Tensor,
tail_pad_len: int = 0,
) -> None:
"""Set features tensor of current utterance."""
assert features.dim() == 2, features.dim()
self.features = torch.nn.functional.pad(
features,
(0, 0, 0, self.pad_length + tail_pad_len),
mode="constant",
value=self.LOG_EPS,
)
self.num_frames = self.features.size(0)
def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
"""Consume chunk_size frames of features"""
chunk_length = chunk_size + self.pad_length
ret_length = min(self.num_frames - self.num_processed_frames, chunk_length)
ret_features = self.features[
self.num_processed_frames : self.num_processed_frames + ret_length # noqa
]
self.num_processed_frames += chunk_size
if self.num_processed_frames >= self.num_frames:
self._done = True
return ret_features, ret_length
def decoding_result(self) -> List[int]:
"""Obtain current decoding result."""
if self.params.decoding_method == "greedy_search":
return self.hyp[self.params.context_size :] # noqa
elif self.params.decoding_method == "modified_beam_search":
best_hyp = self.hyps.get_most_probable(length_norm=True)
return best_hyp.ys[self.params.context_size :] # noqa
else:
assert self.params.decoding_method == "fast_beam_search"
return self.hyp

View File

@ -0,0 +1,123 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from scaling import Balancer
class Decoder(nn.Module):
"""This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
It removes the recurrent connection from the decoder, i.e., the prediction
network. Different from the above paper, it adds an extra Conv1d
right after the embedding layer.
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
"""
def __init__(
self,
vocab_size: int,
decoder_dim: int,
blank_id: int,
context_size: int,
):
"""
Args:
vocab_size:
Number of tokens of the modeling unit including blank.
decoder_dim:
Dimension of the input embedding, and of the decoder output.
blank_id:
The ID of the blank symbol.
context_size:
Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
"""
super().__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=decoder_dim,
padding_idx=blank_id,
)
# the balancers are to avoid any drift in the magnitude of the
# embeddings, which would interact badly with parameter averaging.
self.balancer = Balancer(decoder_dim, channel_dim=-1,
min_positive=0.0, max_positive=1.0,
min_abs=0.5, max_abs=1.0,
prob=0.05)
self.blank_id = blank_id
assert context_size >= 1, context_size
self.context_size = context_size
self.vocab_size = vocab_size
if context_size > 1:
self.conv = nn.Conv1d(
in_channels=decoder_dim,
out_channels=decoder_dim,
kernel_size=context_size,
padding=0,
groups=decoder_dim // 4, # group size == 4
bias=False,
)
self.balancer2 = Balancer(decoder_dim, channel_dim=-1,
min_positive=0.0, max_positive=1.0,
min_abs=0.5, max_abs=1.0,
prob=0.05)
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
need_pad:
True to left pad the input. Should be True during training.
False to not pad the input. Should be False during inference.
Returns:
Return a tensor of shape (N, U, decoder_dim).
"""
y = y.to(torch.int64)
# this stuff about clamp() is a temporary fix for a mismatch
# at utterance start, we use negative ids in beam_search.py
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
embedding_out = self.balancer(embedding_out)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(
embedding_out, pad=(self.context_size - 1, 0)
)
else:
# During inference time, there is no need to do extra padding
# as we only need one output
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = F.relu(embedding_out)
embedding_out = self.balancer2(embedding_out)
return embedding_out

View File

@ -0,0 +1 @@
../transducer_stateless/encoder_interface.py

View File

@ -0,0 +1,523 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
(1) Export to torchscript model using torch.jit.script()
- For non-streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("jit_script.pt")`.
Check ./jit_pretrained.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
- For streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`.
You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`.
Check ./jit_pretrained_streaming.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
- For non-streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 9
- For streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--causal 1 \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 9
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
- For non-streaming model:
To use the generated file with `zipformer/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./zipformer/decode.py \
--exp-dir ./zipformer/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
- For streaming model:
To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
# simulated streaming decoding
./zipformer/decode.py \
--exp-dir ./zipformer/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
# chunk-wise streaming decoding
./zipformer/streaming_decode.py \
--exp-dir ./zipformer/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
- non-streaming model:
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
- streaming model:
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
# You will find the pre-trained models in exp dir
"""
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import sentencepiece as spm
import torch
from torch import Tensor, nn
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, str2bool
from scaling_converter import convert_scaled_to_non_scaled
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named cpu_jit.pt.
Check ./jit_pretrained.py for how to use it.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
class EncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
def forward(
self, features: Tensor, feature_lengths: Tensor
) -> Tuple[Tensor, Tensor]:
"""
Args:
features: (N, T, C)
feature_lengths: (N,)
"""
x, x_lens = self.encoder_embed(features, feature_lengths)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return encoder_out, encoder_out_lens
class StreamingEncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
assert len(encoder.chunk_size) == 1, encoder.chunk_size
assert len(encoder.left_context_frames) == 1, encoder.left_context_frames
self.chunk_size = encoder.chunk_size[0]
self.left_context_len = encoder.left_context_frames[0]
# The encoder_embed subsample features (T - 7) // 2
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
self.pad_length = 7 + 2 * 3
self.encoder = encoder
self.encoder_embed = encoder_embed
def forward(
self, features: Tensor, feature_lengths: Tensor, states: List[Tensor]
) -> Tuple[Tensor, Tensor, List[Tensor]]:
"""Streaming forward for encoder_embed and encoder.
Args:
features: (N, T, C)
feature_lengths: (N,)
states: a list of Tensors
Returns encoder outputs, output lengths, and updated states.
"""
chunk_size = self.chunk_size
left_context_len = self.left_context_len
cached_embed_left_pad = states[-2]
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lengths,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = self.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, encoder_out_lens, new_states
@torch.jit.export
def get_init_states(
self,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
"""
states = self.encoder.get_init_states(batch_size, device)
embed_states = self.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
states.append(processed_lens)
return states
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
# if torch.cuda.is_available():
# device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.eval()
if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
# Wrap encoder and encoder_embed as a module
if params.causal:
model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed)
chunk_size = model.encoder.chunk_size
left_context_len = model.encoder.left_context_len
filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
else:
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
filename = "jit_script.pt"
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
model.save(str(params.exp_dir / filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,202 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) use the checkpoint exp_dir/epoch-xxx.pt
./zipformer/generate_averaged_model.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp
It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
You can later load it by `torch.load("epoch-28-avg-15.pt")`.
(2) use the checkpoint exp_dir/checkpoint-iter.pt
./zipformer/generate_averaged_model.py \
--iter 22000 \
--avg 5 \
--exp-dir ./zipformer/exp
It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
You can later load it by `torch.load("iter-22000-avg-5.pt")`.
"""
import argparse
from pathlib import Path
import sentencepiece as spm
import torch
from asr_datamodule import LibriSpeechAsrDataModule
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints_with_averaged_model,
find_checkpoints,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
print("Script started")
device = torch.device("cpu")
print(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
print("About to create model")
model = get_transducer_model(params)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
print(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
torch.save({"model": model.state_dict()}, filename)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
print(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
torch.save({"model": model.state_dict()}, filename)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
print("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,272 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, exported by `torch.jit.script()`
and uses them to decode waves.
You can use the following command to get the exported models:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 9 \
--jit 1
Usage of this script:
./zipformer/jit_pretrained.py \
--nn-model-filename ./zipformer/exp/cpu_jit.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model-filename",
type=str,
required=True,
help="Path to the torchscript model cpu_jit.pt",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(
model: torch.jit.ScriptModule,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
A 3-D tensor of shape (N, T, C)
encoder_out_lens:
A 1-D tensor of shape (N,).
Returns:
Return the decoded results for each utterance.
"""
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
device = encoder_out.device
blank_id = 0 # hard-code to 0
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
context_size = model.decoder.context_size
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (N, context_size)
decoder_out = model.decoder(
decoder_input,
need_pad=torch.tensor([False]),
).squeeze(1)
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = packed_encoder_out.data[start:end]
current_encoder_out = current_encoder_out
# current_encoder_out's shape: (batch_size, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = model.joiner(
current_encoder_out,
decoder_out,
)
# logits'shape (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=torch.tensor([False]),
)
decoder_out = decoder_out.squeeze(1)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = torch.jit.load(args.nn_model_filename)
model.eval()
model.to(device)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(
features=features,
feature_lengths=feature_lengths,
)
hyps = greedy_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = sp.decode(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,269 @@
#!/usr/bin/env python3
# flake8: noqa
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models exported by `torch.jit.script()`
and uses them to decode waves.
You can use the following command to get the exported models:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 9 \
--jit 1
Usage of this script:
./zipformer/jit_pretrained_streaming.py \
--nn-model-filename ./zipformer/exp-causal/jit_script_chunk_16_left_128.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
/path/to/foo.wav \
"""
import argparse
import logging
import math
from typing import List, Optional
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model-filename",
type=str,
required=True,
help="Path to the torchscript model cpu_jit.pt",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"sound_file",
type=str,
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(
decoder: torch.jit.ScriptModule,
joiner: torch.jit.ScriptModule,
encoder_out: torch.Tensor,
decoder_out: Optional[torch.Tensor] = None,
hyp: Optional[List[int]] = None,
device: torch.device = torch.device("cpu"),
):
assert encoder_out.ndim == 2
context_size = 2
blank_id = 0
if decoder_out is None:
assert hyp is None, hyp
hyp = [blank_id] * context_size
decoder_input = torch.tensor(hyp, dtype=torch.int32, device=device).unsqueeze(0)
# decoder_input.shape (1,, 1 context_size)
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
else:
assert decoder_out.ndim == 2
assert hyp is not None, hyp
T = encoder_out.size(0)
for i in range(T):
cur_encoder_out = encoder_out[i : i + 1]
joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0)
y = joiner_out.argmax(dim=0).item()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor(
decoder_input, dtype=torch.int32, device=device
).unsqueeze(0)
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
return hyp, decoder_out
def create_streaming_feature_extractor(sample_rate) -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 80
return OnlineFbank(opts)
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = torch.jit.load(args.nn_model_filename)
model.eval()
model.to(device)
encoder = model.encoder
decoder = model.decoder
joiner = model.joiner
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
logging.info("Constructing Fbank computer")
online_fbank = create_streaming_feature_extractor(args.sample_rate)
logging.info(f"Reading sound files: {args.sound_file}")
wave_samples = read_sound_files(
filenames=[args.sound_file],
expected_sample_rate=args.sample_rate,
)[0]
logging.info(wave_samples.shape)
logging.info("Decoding started")
chunk_length = encoder.chunk_size * 2
T = chunk_length + encoder.pad_length
logging.info(f"chunk_length: {chunk_length}")
logging.info(f"T: {T}")
states = encoder.get_init_states(device=device)
tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32)
wave_samples = torch.cat([wave_samples, tail_padding])
chunk = int(0.25 * args.sample_rate) # 0.2 second
num_processed_frames = 0
hyp = None
decoder_out = None
start = 0
while start < wave_samples.numel():
logging.info(f"{start}/{wave_samples.numel()}")
end = min(start + chunk, wave_samples.numel())
samples = wave_samples[start:end]
start += chunk
online_fbank.accept_waveform(
sampling_rate=args.sample_rate,
waveform=samples,
)
while online_fbank.num_frames_ready - num_processed_frames >= T:
frames = []
for i in range(T):
frames.append(online_fbank.get_frame(num_processed_frames + i))
frames = torch.cat(frames, dim=0).to(device).unsqueeze(0)
x_lens = torch.tensor([T], dtype=torch.int32, device=device)
encoder_out, out_lens, states = encoder(
features=frames,
feature_lengths=x_lens,
states=states,
)
num_processed_frames += chunk_length
hyp, decoder_out = greedy_search(
decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp, device=device
)
context_size = 2
logging.info(args.sound_file)
logging.info(sp.decode(hyp[context_size:]))
logging.info("Decoding Done")
torch.set_num_threads(4)
torch.set_num_interop_threads(1)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,66 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from scaling import ScaledLinear
class Joiner(nn.Module):
def __init__(
self,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
):
super().__init__()
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
self.output_linear = nn.Linear(joiner_dim, vocab_size)
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
project_input: bool = True,
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, s_range, C).
decoder_out:
Output from the decoder. Its shape is (N, T, s_range, C).
project_input:
If true, apply input projections encoder_proj and decoder_proj.
If this is false, it is the user's responsibility to do this
manually.
Returns:
Return a tensor of shape (N, T, s_range, C).
"""
assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape)
if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(
decoder_out
)
else:
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit

View File

@ -0,0 +1,217 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from icefall.utils import add_sos, make_pad_mask
from scaling import ScaledLinear
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder_embed: nn.Module,
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
):
"""
Args:
encoder_embed:
It is a Convolutional 2D subsampling module. It converts
an input of shape (N, T, idim) to an output of of shape
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
self.encoder_embed = encoder_embed
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
self.simple_am_proj = ScaledLinear(
encoder_dim,
vocab_size,
initial_scale=0.25,
)
self.simple_lm_proj = ScaledLinear(
decoder_dim,
vocab_size,
initial_scale=0.25,
)
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
Returns:
Return the transducer loss.
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
x, x_lens = self.encoder_embed(x, x_lens)
# logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return (simple_loss, pruned_loss)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,382 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:
- For non-streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 9
- For streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--causal 1 \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 9
Usage of this script:
- For non-streaming model:
(1) greedy search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) modified beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
(3) fast beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
- For streaming model:
(1) greedy search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) modified beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
(3) fast beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./zipformer/exp/epoch-xx.pt`.
Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
fast_beam_search_one_best,
greedy_search_batch,
modified_beam_search,
)
from icefall.utils import make_pad_mask
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
# model forward
x, x_lens = model.encoder_embed(features, feature_lengths)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = model.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
hyps = []
msg = f"Using {params.method}"
logging.info(msg)
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
raise ValueError(f"Unsupported method: {params.method}")
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,82 @@
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file replaces various modules in a model.
Specifically, ActivationBalancer is replaced with an identity operator;
Whiten is also replaced with an identity operator;
BasicNorm is replaced by a module with `exp` removed.
"""
import copy
from typing import List, Tuple
import torch
import torch.nn as nn
from scaling import Balancer, Dropout3, ScaleGrad, Whiten
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
# get_submodule was added to nn.Module at v1.9.0
def get_submodule(model, target):
if target == "":
return model
atoms: List[str] = target.split(".")
mod: torch.nn.Module = model
for item in atoms:
if not hasattr(mod, item):
raise AttributeError(
mod._get_name() + " has no " "attribute `" + item + "`"
)
mod = getattr(mod, item)
if not isinstance(mod, torch.nn.Module):
raise AttributeError("`" + item + "` is not " "an nn.Module")
return mod
def convert_scaled_to_non_scaled(
model: nn.Module,
inplace: bool = False,
is_pnnx: bool = False,
):
"""
Args:
model:
The model to be converted.
inplace:
If True, the input model is modified inplace.
If False, the input model is copied and we modify the copied version.
is_pnnx:
True if we are going to export the model for PNNX.
Return:
Return a model without scaled layers.
"""
if not inplace:
model = copy.deepcopy(model)
d = {}
for name, m in model.named_modules():
if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
d[name] = nn.Identity()
for k, v in d.items():
if "." in k:
parent, child = k.rsplit(".", maxsplit=1)
setattr(get_submodule(model, parent), child, v)
else:
setattr(model, k, v)
return model

View File

@ -0,0 +1,282 @@
# Copyright 2022 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import List
import k2
import torch
import torch.nn as nn
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from decode_stream import DecodeStream
from icefall.decode import one_best_decoding
from icefall.utils import get_texts
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
) -> None:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
streams:
A list of Stream objects.
"""
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, 1, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
# logits'shape (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
decoder_out = model.joiner.decoder_proj(decoder_out)
def modified_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
num_active_paths: int = 4,
) -> None:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The RNN-T model.
encoder_out:
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
the encoder model.
streams:
A list of stream objects.
num_active_paths:
Number of active paths during the beam search.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert len(streams) == encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size = len(streams)
T = encoder_out.size(1)
B = [stream.hyps for stream in streams]
for t in range(T):
current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.stack(
[hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, encoder_out_dim)
logits = model.joiner(current_encoder_out, decoder_out, project_input=False)
# logits is of shape (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
for i in range(batch_size):
streams[i].hyps = B[i]
def fast_beam_search_one_best(
model: nn.Module,
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
streams: List[DecodeStream],
beam: float,
max_states: int,
max_contexts: int,
) -> None:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first generated by Fsa-based beam search, then we get the
recognition by applying shortest path on the lattice.
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
processed_lens:
A tensor of shape (N,) containing the number of processed frames
in `encoder_out` before padding.
streams:
A list of stream objects.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
"""
assert encoder_out.ndim == 3
B, T, C = encoder_out.shape
assert B == len(streams)
context_size = model.decoder.context_size
vocab_size = model.decoder.vocab_size
config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_contexts=max_contexts,
max_states=max_states,
)
individual_streams = []
for i in range(B):
individual_streams.append(streams[i].rnnt_decoding_stream)
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyp_tokens = get_texts(best_path)
for i in range(B):
streams[i].hyp = hyp_tokens[i]

View File

@ -0,0 +1,876 @@
#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang,
# Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./zipformer/streaming_decode.py \
--epoch 28 \
--avg 15 \
--causal 1 \
--chunk-size 32 \
--left-context-frames 256 \
--exp-dir ./zipformer/exp \
--decoding-method greedy_search \
--num-decode-streams 2000
"""
import argparse
import logging
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
from asr_datamodule import LibriSpeechAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from streaming_beam_search import (
fast_beam_search_one_best,
greedy_search,
modified_beam_search,
)
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Supported decoding methods are:
greedy_search
modified_beam_search
fast_beam_search
""",
)
parser.add_argument(
"--num_active_paths",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=32,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel.",
)
add_model_arguments(parser)
return parser
def get_init_states(
model: nn.Module,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
"""
states = model.encoder.get_init_states(batch_size, device)
embed_states = model.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
states.append(processed_lens)
return states
def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
"""Stack list of zipformer states that correspond to separate utterances
into a single emformer state, so that it can be used as an input for
zipformer when those utterances are formed into a batch.
Args:
state_list:
Each element in state_list corresponding to the internal state
of the zipformer model for a single utterance. For element-n,
state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
cached_val2, cached_conv1, cached_conv2).
state_list[n][-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
state_list[n][-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
Note:
It is the inverse of :func:`unstack_states`.
"""
batch_size = len(state_list)
assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
tot_num_layers = (len(state_list[0]) - 2) // 6
batch_states = []
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key = torch.cat(
[state_list[i][layer_offset] for i in range(batch_size)], dim=1
)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn = torch.cat(
[state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
)
# cached_val1: (left_context_len, batch_size, value_dim)
cached_val1 = torch.cat(
[state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
)
# cached_val2: (left_context_len, batch_size, value_dim)
cached_val2 = torch.cat(
[state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
)
# cached_conv1: (#batch, channels, left_pad)
cached_conv1 = torch.cat(
[state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
)
# cached_conv2: (#batch, channels, left_pad)
cached_conv2 = torch.cat(
[state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
)
batch_states += [
cached_key,
cached_nonlin_attn,
cached_val1,
cached_val2,
cached_conv1,
cached_conv2,
]
cached_embed_left_pad = torch.cat(
[state_list[i][-2] for i in range(batch_size)], dim=0
)
batch_states.append(cached_embed_left_pad)
processed_lens = torch.cat(
[state_list[i][-1] for i in range(batch_size)], dim=0
)
batch_states.append(processed_lens)
return batch_states
def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
"""Unstack the zipformer state corresponding to a batch of utterances
into a list of states, where the i-th entry is the state from the i-th
utterance in the batch.
Note:
It is the inverse of :func:`stack_states`.
Args:
batch_states: A list of cached tensors of all encoder layers. For layer-i,
states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
cached_conv1, cached_conv2).
state_list[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
Returns:
state_list: A list of list. Each element in state_list corresponding to the internal state
of the zipformer model for a single utterance.
"""
assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
tot_num_layers = (len(batch_states) - 2) // 6
processed_lens = batch_states[-1]
batch_size = processed_lens.shape[0]
state_list = [[] for _ in range(batch_size)]
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key_list = batch_states[layer_offset].chunk(
chunks=batch_size, dim=1
)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
chunks=batch_size, dim=1
)
# cached_val1: (left_context_len, batch_size, value_dim)
cached_val1_list = batch_states[layer_offset + 2].chunk(
chunks=batch_size, dim=1
)
# cached_val2: (left_context_len, batch_size, value_dim)
cached_val2_list = batch_states[layer_offset + 3].chunk(
chunks=batch_size, dim=1
)
# cached_conv1: (#batch, channels, left_pad)
cached_conv1_list = batch_states[layer_offset + 4].chunk(
chunks=batch_size, dim=0
)
# cached_conv2: (#batch, channels, left_pad)
cached_conv2_list = batch_states[layer_offset + 5].chunk(
chunks=batch_size, dim=0
)
for i in range(batch_size):
state_list[i] += [
cached_key_list[i],
cached_nonlin_attn_list[i],
cached_val1_list[i],
cached_val2_list[i],
cached_conv1_list[i],
cached_conv2_list[i],
]
cached_embed_left_pad_list = batch_states[-2].chunk(
chunks=batch_size, dim=0
)
for i in range(batch_size):
state_list[i].append(cached_embed_left_pad_list[i])
processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
for i in range(batch_size):
state_list[i].append(processed_lens_list[i])
return state_list
def streaming_forward(
features: Tensor,
feature_lens: Tensor,
model: nn.Module,
states: List[Tensor],
chunk_size: int,
left_context_len: int,
) -> Tuple[Tensor, Tensor, List[Tensor]]:
"""
Returns encoder outputs, output lengths, and updated states.
"""
cached_embed_left_pad = states[-2]
(
x,
x_lens,
new_cached_embed_left_pad,
) = model.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lens,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat(
[processed_mask, src_key_padding_mask], dim=1
)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = model.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, encoder_out_lens, new_states
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
decode_streams: List[DecodeStream],
) -> List[int]:
"""Decode one chunk frames of features for each decode_streams and
return the indexes of finished streams in a List.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
decode_streams:
A List of DecodeStream, each belonging to a utterance.
Returns:
Return a List containing which DecodeStreams are finished.
"""
device = model.device
chunk_size = int(params.chunk_size)
left_context_len = int(params.left_context_frames)
features = []
feature_lens = []
states = []
processed_lens = [] # Used in fast-beam-search
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(chunk_size * 2)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
# Make sure the length after encoder_embed is at least 1.
# The encoder_embed subsample features (T - 7) // 2
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
tail_length = chunk_size * 2 + 7 + 2 * 3
if features.size(1) < tail_length:
pad_length = tail_length - features.size(1)
feature_lens += pad_length
features = torch.nn.functional.pad(
features,
(0, 0, 0, pad_length),
mode="constant",
value=LOG_EPS,
)
states = stack_states(states)
encoder_out, encoder_out_lens, new_states = streaming_forward(
features=features,
feature_lens=feature_lens,
model=model,
states=states,
chunk_size=chunk_size,
left_context_len=left_context_len,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
greedy_search(
model=model, encoder_out=encoder_out, streams=decode_streams
)
elif params.decoding_method == "fast_beam_search":
processed_lens = torch.tensor(processed_lens, device=device)
processed_lens = processed_lens + encoder_out_lens
fast_beam_search_one_best(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
streams=decode_streams,
beam=params.beam,
max_states=params.max_states,
max_contexts=params.max_contexts,
)
elif params.decoding_method == "modified_beam_search":
modified_beam_search(
model=model,
streams=decode_streams,
encoder_out=encoder_out,
num_active_paths=params.num_active_paths,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
states = unstack_states(new_states)
finished_streams = []
for i in range(len(decode_streams)):
decode_streams[i].states = states[i]
decode_streams[i].done_frames += encoder_out_lens[i]
if decode_streams[i].done:
finished_streams.append(i)
return finished_streams
def decode_dataset(
cuts: CutSet,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
cuts:
Lhotse Cutset containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
device = model.device
opts = FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 100
decode_results = []
# Contain decode streams currently running.
decode_streams = []
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
initial_states = get_init_states(
model=model, batch_size=1, device=device
)
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
feature = fbank(samples.to(device))
decode_stream.set_features(feature, tail_pad_len=30)
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
while len(decode_streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
# decode final chunks of last sequences
while len(decode_streams):
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
if params.decoding_method == "greedy_search":
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
elif params.decoding_method == "modified_beam_search":
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
assert params.causal, params.causal
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
model.device = device
decoding_graph = None
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_sets = ["test-clean", "test-other"]
test_cuts = [test_clean_cuts, test_other_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=test_cut,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,407 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import warnings
import torch
from torch import Tensor, nn
from scaling import (
Balancer,
BiasNorm,
Dropout3,
FloatLike,
Optional,
ScaledConv2d,
ScaleGrad,
ScheduledFloat,
SwooshL,
SwooshR,
Whiten,
)
class ConvNeXt(nn.Module):
"""
Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
"""
def __init__(
self,
channels: int,
hidden_ratio: int = 3,
kernel_size: Tuple[int, int] = (7, 7),
layerdrop_rate: FloatLike = None,
):
super().__init__()
self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
hidden_channels = channels * hidden_ratio
if layerdrop_rate is None:
layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
self.layerdrop_rate = layerdrop_rate
self.depthwise_conv = nn.Conv2d(
in_channels=channels,
out_channels=channels,
groups=channels,
kernel_size=kernel_size,
padding=self.padding,
)
self.pointwise_conv1 = nn.Conv2d(
in_channels=channels, out_channels=hidden_channels, kernel_size=1
)
self.hidden_balancer = Balancer(
hidden_channels,
channel_dim=1,
min_positive=0.3,
max_positive=1.0,
min_abs=0.75,
max_abs=5.0,
)
self.activation = SwooshL()
self.pointwise_conv2 = ScaledConv2d(
in_channels=hidden_channels,
out_channels=channels,
kernel_size=1,
initial_scale=0.01,
)
self.out_balancer = Balancer(
channels,
channel_dim=1,
min_positive=0.4,
max_positive=0.6,
min_abs=1.0,
max_abs=6.0,
)
self.out_whiten = Whiten(
num_groups=1,
whitening_limit=5.0,
prob=(0.025, 0.25),
grad_scale=0.01,
)
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or not self.training:
return self.forward_internal(x)
layerdrop_rate = float(self.layerdrop_rate)
if layerdrop_rate != 0.0:
batch_size = x.shape[0]
mask = (
torch.rand(
(batch_size, 1, 1, 1), dtype=x.dtype, device=x.device
)
> layerdrop_rate
)
else:
mask = None
# turns out this caching idea does not work with --world-size > 1
# return caching_eval(self.forward_internal, x, mask)
return self.forward_internal(x, mask)
def forward_internal(
self, x: Tensor, layer_skip_mask: Optional[Tensor] = None
) -> Tensor:
"""
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
The returned value has the same shape as x.
"""
bypass = x
x = self.depthwise_conv(x)
x = self.pointwise_conv1(x)
x = self.hidden_balancer(x)
x = self.activation(x)
x = self.pointwise_conv2(x)
if layer_skip_mask is not None:
x = x * layer_skip_mask
x = bypass + x
x = self.out_balancer(x)
x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last
x = self.out_whiten(x)
x = x.transpose(1, 3) # (N, C, H, W)
return x
def streaming_forward(
self,
x: Tensor,
cached_left_pad: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
cached_left_pad: (batch_size, num_channels, left_pad, num_freqs)
Returns:
- The returned value has the same shape as x.
- Updated cached_left_pad.
"""
padding = self.padding
# The length without right padding for depth-wise conv
T = x.size(2) - padding[0]
bypass = x[:, :, :T, :]
# Pad left side
assert cached_left_pad.size(2) == padding[0], (
cached_left_pad.size(2),
padding[0],
)
x = torch.cat([cached_left_pad, x], dim=2)
# Update cached left padding
cached_left_pad = x[:, :, T : padding[0] + T, :]
# depthwise_conv
x = torch.nn.functional.conv2d(
x,
weight=self.depthwise_conv.weight,
bias=self.depthwise_conv.bias,
padding=(0, padding[1]),
groups=self.depthwise_conv.groups,
)
x = self.pointwise_conv1(x)
x = self.hidden_balancer(x)
x = self.activation(x)
x = self.pointwise_conv2(x)
x = bypass + x
return x, cached_left_pad
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/2 length).
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = (T-3)//2 - 2 == (T-7)//2
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(
self,
in_channels: int,
out_channels: int,
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
dropout: FloatLike = 0.1,
) -> None:
"""
Args:
in_channels:
Number of channels in. The input shape is (N, T, in_channels).
Caution: It requires: T >=7, in_channels >=7
out_channels
Output dim. The output shape is (N, (T-3)//2, out_channels)
layer1_channels:
Number of channels in layer1
layer1_channels:
Number of channels in layer2
bottleneck:
bottleneck dimension for 1d squeeze-excite
"""
assert in_channels >= 7
super().__init__()
# The ScaleGrad module is there to prevent the gradients
# w.r.t. the weight or bias of the first Conv2d module in self.conv from
# exceeding the range of fp16 when using automatic mixed precision (amp)
# training. (The second one is necessary to stop its bias from getting
# a too-large gradient).
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=layer1_channels,
kernel_size=3,
padding=(0, 1), # (time, freq)
),
ScaleGrad(0.2),
Balancer(layer1_channels, channel_dim=1, max_abs=1.0),
SwooshR(),
nn.Conv2d(
in_channels=layer1_channels,
out_channels=layer2_channels,
kernel_size=3,
stride=2,
padding=0,
),
Balancer(layer2_channels, channel_dim=1, max_abs=4.0),
SwooshR(),
nn.Conv2d(
in_channels=layer2_channels,
out_channels=layer3_channels,
kernel_size=3,
stride=(1, 2), # (time, freq)
),
Balancer(layer3_channels, channel_dim=1, max_abs=4.0),
SwooshR(),
)
# just one convnext layer
self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
self.out_width = (((in_channels - 1) // 2) - 1) // 2
self.layer3_channels = layer3_channels
self.out = nn.Linear(self.out_width * layer3_channels, out_channels)
# use a larger than normal grad_scale on this whitening module; there is
# only one such module, so there is not a concern about adding together
# many copies of this extra gradient term.
self.out_whiten = Whiten(
num_groups=1,
whitening_limit=ScheduledFloat(
(0.0, 4.0), (20000.0, 8.0), default=4.0
),
prob=(0.025, 0.25),
grad_scale=0.02,
)
# max_log_eps=0.0 is to prevent both eps and the output of self.out from
# getting large, there is an unnecessary degree of freedom.
self.out_norm = BiasNorm(out_channels)
self.dropout = Dropout3(dropout, shared_dim=1)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
Returns:
- a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
- output lengths, of shape (batch_size,)
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
# scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
# training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
# gradients.
x = self.conv(x)
x = self.convnext(x)
# Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = x.transpose(1, 2).reshape(b, t, c * f)
# now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels))
x = self.out(x)
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
x = self.out_whiten(x)
x = self.out_norm(x)
x = self.dropout(x)
if torch.jit.is_scripting():
x_lens = (x_lens - 7) // 2
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
x_lens = (x_lens - 7) // 2
assert x.size(1) == x_lens.max().item()
return x, x_lens
def streaming_forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
cached_left_pad: Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
Returns:
- a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
- output lengths, of shape (batch_size,)
- updated cache
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
# T' = (T-7)//2
x = self.conv(x)
# T' = (T-7)//2-3
x, cached_left_pad = self.convnext.streaming_forward(
x, cached_left_pad=cached_left_pad
)
# Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = x.transpose(1, 2).reshape(b, t, c * f)
# now x: (N, T', out_width * layer3_channels))
x = self.out(x)
# Now x is of shape (N, T', odim)
x = self.out_norm(x)
if torch.jit.is_scripting() or torch.jit.is_tracing():
assert self.convnext.padding[0] == 3
# The ConvNeXt module needs 3 frames of right padding after subsampling
x_lens = (x_lens - 7) // 2 - 3
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# The ConvNeXt module needs 3 frames of right padding after subsampling
assert self.convnext.padding[0] == 3
x_lens = (x_lens - 7) // 2 - 3
assert x.size(1) == x_lens.max().item()
return x, x_lens, cached_left_pad
@torch.jit.export
def get_init_states(
self,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> Tensor:
"""Get initial states for Conv2dSubsampling module.
It is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
"""
left_pad = self.convnext.padding[0]
freq = self.out_width
channels = self.layer3_channels
cached_embed_left_pad = torch.zeros(
batch_size, channels, left_pad, freq
).to(device)
return cached_embed_left_pad

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -16,15 +16,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import Optional, Tuple, List
import torch
from torch import Tensor, nn
class TensorDiagnosticOptions(object):
"""Options object for tensor diagnostics:
@ -60,7 +58,8 @@ def get_tensor_stats(
"abs" -> take abs() before summing
"positive" -> take (x > 0) before summing
"rms" -> square before summing, we'll take sqrt later
"value -> just sum x itself
"value" -> just sum x itself
"max", "min" -> take the maximum or minimum [over all other dims but dim] instead of summing
Returns:
stats: a Tensor of shape (x.shape[dim],).
count: an integer saying how many items were counted in each element
@ -78,11 +77,11 @@ def get_tensor_stats(
elif stats_type == "abs":
x = x.abs()
elif stats_type == "rms":
x = x**2
x = x ** 2
elif stats_type == "positive":
x = (x > 0).to(dtype=torch.float)
else:
assert stats_type in ["value", "max", "min"]
assert stats_type in [ "value", "max", "min" ]
sum_dims = [d for d in range(x.ndim) if d != dim]
if len(sum_dims) > 0:
@ -94,7 +93,7 @@ def get_tensor_stats(
x = torch.min(x, dim=dim)[0]
else:
x = torch.sum(x, dim=sum_dims)
x = x.flatten()
x = x.flatten().clone()
return x, count
@ -106,7 +105,7 @@ class TensorAndCount:
class TensorDiagnostic(object):
"""This class is not directly used by the user, it is responsible for
collecting diagnostics for a single parameter tensor of a torch.nn.Module.
collecting diagnostics for a module or parameter tensor of a torch.nn.Module.
Args:
opts:
@ -121,9 +120,14 @@ class TensorDiagnostic(object):
self.name = name
self.class_name = None # will assign in accumulate()
self.stats = (
None # we'll later assign a list to this data member. It's a list of dict.
)
self.stats = None # we'll later assign a list to self.stats.
# It's a list of dicts, indexed by dim (i.e. by the
# axis of the tensor). The dicts, in turn, are
# indexed by `stats-type` which are strings in
# ["abs", "max", "min", "positive", "value", "rms"].
# scalar_stats contains some analysis of the activations and gradients,
self.scalar_stats = None
# the keys into self.stats[dim] are strings, whose values can be
# "abs", "max", "min" ,"value", "positive", "rms", "value".
@ -135,6 +139,7 @@ class TensorDiagnostic(object):
# only adding a new element to the list if there was a different dim.
# if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
def accumulate(self, x, class_name: Optional[str] = None):
"""
Accumulate tensors.
@ -178,20 +183,27 @@ class TensorDiagnostic(object):
if s.tensor.shape == stats.shape:
if stats_type == "max":
s.tensor = torch.maximum(s.tensor, stats)
elif stats_type == "min":
s.tensor = torch.minimum(s.tensor, stats)
else:
assert stats_type != "max"
s.tensor += stats
s.count += count
done = True
break
if not done:
if this_dim_stats[stats_type] != [] and stats_type == "eigs":
if (
this_dim_stats[stats_type] != []
and stats_type == "eigs"
):
# >1 size encountered on this dim, e.g. it's a batch or time dimension,
# don't accumulat "eigs" stats type, it uses too much memory
this_dim_stats[stats_type] = None
else:
this_dim_stats[stats_type].append(TensorAndCount(stats, count))
this_dim_stats[stats_type].append(
TensorAndCount(stats, count)
)
def print_diagnostics(self):
"""Print diagnostics for each dimension of the tensor."""
@ -199,14 +211,28 @@ class TensorDiagnostic(object):
print(f"Warning: the stats of {self.name} is None.")
return
for dim, this_dim_stats in enumerate(self.stats):
if "rms" in this_dim_stats and "value" in this_dim_stats:
# produce "stddev" stats, which is centered RMS.
rms_stats_list = this_dim_stats["rms"]
value_stats_list = this_dim_stats["value"]
if len(rms_stats_list) == len(value_stats_list):
stddev_stats_list = []
for r, v in zip(rms_stats_list, value_stats_list):
stddev_stats_list.append(
# r.count and v.count should be the same, but we don't check this.
TensorAndCount(r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20),
r.count))
this_dim_stats["stddev"] = stddev_stats_list
for stats_type, stats_list in this_dim_stats.items():
# stats_type could be "rms", "value", "abs", "eigs", "positive".
# stats_type could be "rms", "value", "abs", "eigs", "positive", "min" or "max".
# "stats_list" could be a list of TensorAndCount (one list per distinct tensor
# shape of the stats), or None
if stats_list is None:
assert stats_type == "eigs"
continue
def get_count(count):
return 1 if stats_type in ["max", "min"] else count
@ -224,20 +250,22 @@ class TensorDiagnostic(object):
eigs, _ = torch.symeig(stats)
stats = eigs.abs().sqrt()
except: # noqa
print("Error getting eigenvalues, trying another method.")
print(
"Error getting eigenvalues, trying another method."
)
eigs, _ = torch.eig(stats)
stats = eigs.abs().sqrt()
stats = eigs.norm(dim=1).sqrt()
# sqrt so it reflects data magnitude, like stddev- not variance
if stats_type == "rms":
if stats_type in [ "rms", "stddev" ]:
# we stored the square; after aggregation we need to take sqrt.
stats = stats.sqrt()
# if `summarize` we print percentiles of the stats; else,
# we print out individual elements.
summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(
stats.numel()
)
summarize = (
len(stats_list) > 1
) or self.opts.dim_is_summarized(stats.numel())
if summarize: # usually `summarize` will be true
# print out percentiles.
stats = stats.sort()[0]
@ -254,32 +282,192 @@ class TensorDiagnostic(object):
ans = stats.tolist()
ans = ["%.2g" % x for x in ans]
ans = "[" + " ".join(ans) + "]"
if stats_type in ["value", "rms", "eigs"]:
if stats_type in [ "value", "rms", "stddev", "eigs" ]:
# This norm is useful because it is strictly less than the largest
# sqrt(eigenvalue) of the variance, which we print out, and shows,
# speaking in an approximate way, how much of that largest eigenvalue
# can be attributed to the mean of the distribution.
norm = (stats**2).sum().sqrt().item()
norm = (stats ** 2).sum().sqrt().item()
ans += f", norm={norm:.2g}"
mean = stats.mean().item()
rms = (stats**2).mean().sqrt().item()
ans += f", mean={mean:.2g}, rms={rms:.2g}"
rms = (stats ** 2).mean().sqrt().item()
ans += f", mean={mean:.3g}, rms={rms:.3g}"
# OK, "ans" contains the actual stats, e.g.
# ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5"
sizes = [x.tensor.shape[0] for x in stats_list]
size_str = (
f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
)
maybe_class_name = (
f" type={self.class_name}," if self.class_name is not None else ""
f"{sizes[0]}"
if len(sizes) == 1
else f"{min(sizes)}..{max(sizes)}"
)
maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
print(
f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}"
)
class ScalarDiagnostic(object):
"""This class is not directly used by the user, it is responsible for
collecting diagnostics for a single module (subclass of torch.nn.Module) that
represents some kind of nonlinearity, e.g. ReLU, sigmoid, etc.
"""
def __init__(self, opts: TensorDiagnosticOptions, name: str):
self.opts = opts
self.name = name
self.class_name = None # will assign in accumulate()
self.is_forward_pass = True
self.tick_scale = None
self.saved_inputs = []
self.is_ok = True
self.counts = None
self.sum_grad = None
self.sum_gradsq = None
self.sum_abs_grad = None
def accumulate_input(self, x: Tensor, class_name: Optional[str] = None):
"""
Called in forward pass.
"""
if not self.is_forward_pass:
# in case we did a forward pass without a backward pass, for some reason.
self.saved_inputs = []
self.is_forward_pass = True
if class_name is not None:
self.class_name = class_name
if not self.is_ok:
return
limit = 10
if len(self.saved_inputs) > limit:
print(f"ERROR: forward pass called for this module over {limit} times with no backward pass. "
f" Will not accumulate scalar stats.")
self.is_ok = False
return
self.saved_inputs.append(x)
def accumulate_output_grad(self, grad: Tensor):
if not self.is_ok:
return
if self.is_forward_pass:
self.is_forward_pass = False
last_shape = 'n/a' if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape
if len(self.saved_inputs) == 0 or grad.shape != last_shape:
print(f"ERROR: shape mismatch or no forward activation present when backward "
f"pass called: grad shape ={tuple(grad.shape)}, num-saved-inputs={len(self.saved_inputs)}"
f", shape-of-last-saved-input={last_shape}")
self.is_ok = False
return
x = self.saved_inputs.pop()
self.process_input_and_grad(x, grad)
def process_input_and_grad(self, x: Tensor, grad: Tensor):
assert x.shape == grad.shape
x = x.flatten()
grad = grad.flatten()
num_ticks_per_side = 256
if self.tick_scale is None:
x_abs_sorted = x.abs().sort()[0]
# take the 98th percentile as the largest value we count separately.
index = int(x.numel() * 0.98)
self.tick_scale = float(x_abs_sorted[index] / num_ticks_per_side)
# integerize from tick * (-num ticks_per_side .. num_ticks_per_side - 1]
self.counts = torch.zeros(2 * num_ticks_per_side, dtype=torch.long, device=x.device)
self.sum_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
# sum_gradsq is for getting error bars.
self.sum_gradsq = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
self.sum_abs_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
# this will round down.
x = (x / self.tick_scale).to(torch.long)
x = x.clamp_(min=-num_ticks_per_side, max=num_ticks_per_side - 1)
x = x + num_ticks_per_side
self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x))
self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double))
self.sum_gradsq.index_add_(dim=0, index=x, source=(grad*grad).to(torch.double))
self.sum_abs_grad.index_add_(dim=0, index=x, source=grad.abs().to(torch.double))
def print_diagnostics(self):
"""Print diagnostics."""
if self.is_ok is False or self.counts is None:
print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}")
return
counts = self.counts.to('cpu')
sum_grad = self.sum_grad.to(device='cpu', dtype=torch.float32)
sum_gradsq = self.sum_gradsq.to(device='cpu', dtype=torch.float32)
sum_abs_grad = self.sum_abs_grad.to(device='cpu', dtype=torch.float32)
counts_cumsum = counts.cumsum(dim=0)
counts_tot = counts_cumsum[-1]
# subdivide the distribution up into `num_bins` intervals for analysis, for greater
# statistical significance. each bin corresponds to multiple of the original 'tick' intervals.
num_bins = 20
# integer division
counts_per_bin = (counts_tot // num_bins) + 1
bin_indexes = counts_cumsum // counts_per_bin
bin_indexes = bin_indexes.clamp(min=0, max=num_bins).to(torch.long)
bin_counts = torch.zeros(num_bins, dtype=torch.long)
bin_counts.index_add_(dim=0, index=bin_indexes, source=counts)
bin_grad = torch.zeros(num_bins)
bin_grad.index_add_(dim=0, index=bin_indexes, source=sum_grad)
bin_gradsq = torch.zeros(num_bins)
bin_gradsq.index_add_(dim=0, index=bin_indexes, source=sum_gradsq)
bin_abs_grad = torch.zeros(num_bins)
bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad)
avg_grad = (bin_grad / bin_counts)
avg_grad_stddev = (bin_gradsq / bin_counts).sqrt()
bin_boundary_counts = torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin
bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts)
# boundaries are the "x" values between the bins, e.g. corresponding to the
# locations of percentiles of the distribution.
num_ticks_per_side = counts.numel() // 2
bin_boundaries = (bin_tick_indexes - num_ticks_per_side) * self.tick_scale
bin_grad = bin_grad / (bin_counts + 1)
bin_conf_interval = bin_gradsq.sqrt() / (bin_counts + 1) # consider this a standard deviation.
# bin_grad / bin_abs_grad will give us a sense for how important in a practical sense,
# the gradients are.
bin_abs_grad = bin_abs_grad / (bin_counts + 1)
bin_rel_grad = bin_grad / (bin_abs_grad + 1.0e-20)
bin_conf = bin_grad / (bin_conf_interval + 1.0e-20)
def tensor_to_str(x: Tensor):
x = ["%.2g" % f for f in x]
x = "[" + " ".join(x) + "]"
return x
maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
print(
f"module={self.name},{maybe_class_name} bin-boundaries={tensor_to_str(bin_boundaries)}, "
f"rel_grad={tensor_to_str(bin_rel_grad)}, grad_conf={tensor_to_str(bin_conf)}"
)
class ModelDiagnostic(object):
"""This class stores diagnostics for all tensors in the torch.nn.Module.
@ -297,9 +485,11 @@ class ModelDiagnostic(object):
self.opts = opts
self.diagnostics = dict()
def __getitem__(self, name: str):
T = ScalarDiagnostic if name[-7:] == '.scalar' else TensorDiagnostic
if name not in self.diagnostics:
self.diagnostics[name] = TensorDiagnostic(self.opts, name)
self.diagnostics[name] = T(self.opts, name)
return self.diagnostics[name]
def print_diagnostics(self):
@ -332,41 +522,73 @@ def attach_diagnostics(
if name == "":
name = "<top-level>"
# Setting model_diagnostic=ans and n=name below, instead of trying to
# capture the variables, ensures that we use the current values.
# (matters for name, since the variable gets overwritten).
# (this matters for `name`, since the variable gets overwritten).
# These closures don't really capture by value, only by
# "the final value the variable got in the function" :-(
def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
def forward_hook(
_module, _input, _output, _model_diagnostic=ans, _name=name
):
if isinstance(_output, tuple) and len(_output) == 1:
_output = _output[0]
if isinstance(_output, Tensor):
_model_diagnostic[f"{_name}.output"].accumulate(
_output, class_name=type(_module).__name__
)
if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ):
_model_diagnostic[f"{_name}.output"].accumulate(_output,
class_name=type(_module).__name__)
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(
o, class_name=type(_module).__name__
)
if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
class_name=type(_module).__name__)
def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
def backward_hook(
_module, _input, _output, _model_diagnostic=ans, _name=name
):
if isinstance(_output, tuple) and len(_output) == 1:
_output = _output[0]
if isinstance(_output, Tensor):
_model_diagnostic[f"{_name}.grad"].accumulate(
_output, class_name=type(_module).__name__
)
if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ):
_model_diagnostic[f"{_name}.grad"].accumulate(_output,
class_name=type(_module).__name__)
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
o, class_name=type(_module).__name__
)
if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
class_name=type(_module).__name__)
module.register_forward_hook(forward_hook)
module.register_backward_hook(backward_hook)
if type(module).__name__ in ["Sigmoid", "Tanh", "ReLU", "TanSwish", "Swish", "DoubleSwish", "Swoosh"]:
# For these specific module types, accumulate some additional diagnostics
# that can help us improve the activation function. These require a lot of memory,
# to save the forward activations, so limit this to some select classes.
# Note: this will not work correctly for all model types.
def scalar_forward_hook(
_module, _input, _output, _model_diagnostic=ans, _name=name
):
if isinstance(_input, tuple):
_input, = _input
assert isinstance(_input, Tensor)
_model_diagnostic[f"{_name}.scalar"].accumulate_input(_input,
class_name=type(_module).__name__)
def scalar_backward_hook(
_module, _input, _output, _model_diagnostic=ans, _name=name
):
if isinstance(_output, tuple):
_output, = _output
assert isinstance(_output, Tensor)
_model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output)
module.register_forward_hook(scalar_forward_hook)
module.register_backward_hook(scalar_backward_hook)
for name, parameter in model.named_parameters():
def param_backward_hook(
@ -390,7 +612,7 @@ def _test_tensor_diagnostic():
diagnostic.print_diagnostics()
model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 80))
diagnostic = attach_diagnostics(model, opts)
for _ in range(10):

View File

@ -1195,6 +1195,64 @@ def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, floa
return norms
def get_parameter_groups_with_lrs(
model: nn.Module, lr: float, include_names: bool = False
) -> List[dict]:
"""
This is for use with the ScaledAdam optimizers (more recent versions that accept lists of
named-parameters; we can, if needed, create a version without the names).
It provides a way to specifiy learning-rate scales inside the module, so that if
any nn.Module in the hierarchy has a floating-point parameter 'lr_scale', it will
scale the LR of any parameters inside that module or its submodules. Note: you
can set module parameters outside the __init__ function, e.g.:
>>> a = nn.Linear(10, 10)
>>> a.lr_scale = 0.5
Returns: a list of dicts, of the following form:
if include_names == False:
[ { 'params': [ tensor1, tensor2, ... ], 'lr': 0.01 },
{ 'params': [ tensor3, tensor4, ... ], 'lr': 0.005 },
... ]
if include_names == true:
[ { 'named_params': [ (name1, tensor1, (name2, tensor2), ... ], 'lr': 0.01 },
{ 'named_params': [ (name3, tensor3), (name4, tensor4), ... ], 'lr': 0.005 },
... ]
"""
# flat_lr_scale just contains the lr_scale explicitly specified
# for each prefix of the name, e.g. 'encoder.layers.3', these need
# to be multiplied for all prefix of the name of any given parameter.
flat_lr_scale = defaultdict(lambda: 1.0)
names = []
for name, m in model.named_modules():
names.append(name)
if hasattr(m, "lr_scale"):
flat_lr_scale[name] = m.lr_scale
# lr_to_parames is a dict from learning rate (floating point) to: if
# include_names == true, a list of (name, parameter) for that learning rate;
# otherwise a list of parameters for that learning rate.
lr_to_params = defaultdict(list)
for name, parameter in model.named_parameters():
split_name = name.split(".")
# caution: as a special case, if the name is '', split_name will be [ '' ].
prefix = split_name[0]
cur_lr = lr * flat_lr_scale[prefix]
if prefix != "":
cur_lr *= flat_lr_scale[""]
for part in split_name[1:]:
prefix = ".".join([prefix, part])
cur_lr *= flat_lr_scale[prefix]
lr_to_params[cur_lr].append((name, parameter) if include_names else parameter)
if include_names:
return [{"named_params": pairs, "lr": lr} for lr, pairs in lr_to_params.items()]
else:
return [{"params": params, "lr": lr} for lr, params in lr_to_params.items()]
def optim_step_and_measure_param_change(
model: nn.Module,
old_parameters: Dict[str, nn.parameter.Parameter],

View File

@ -11,4 +11,6 @@ exclude = '''
)/
| make_kn_lm.py
| icefall\/__init__\.py
| icefall\/diagnostics\.py
| egs\/librispeech\/ASR\/zipformer
'''

View File

@ -11,7 +11,7 @@ graphviz==0.19.1
-f https://download.pytorch.org/whl/cpu/torch_stable.html torch==1.10.0+cpu
-f https://download.pytorch.org/whl/cpu/torch_stable.html torchaudio==0.10.0+cpu
-f https://k2-fsa.org/nightly/ k2==1.15.1.dev20220426+cpu.torch1.10.0
-f https://k2-fsa.org/nightly/ k2==1.23.4.dev20230316+cpu.torch1.10.0
git+https://github.com/lhotse-speech/lhotse
kaldilm==1.11