Merge branch 'master' of github.com:k2-fsa/icefall into lstm_aishell

This commit is contained in:
yaozengwei 2022-10-16 22:21:08 +08:00
commit 3d42e70029
115 changed files with 10327 additions and 543 deletions

View File

@ -9,7 +9,7 @@ per-file-ignores =
egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
egs/*/ASR/*/optim.py: E501,
egs/*/ASR/*/scaling.py: E501,
egs/librispeech/ASR/lstm_transducer_stateless/*.py: E501, E203
egs/librispeech/ASR/lstm_transducer_stateless*/*.py: E501, E203
egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
egs/librispeech/ASR/conformer_ctc2/*py: E501,
egs/librispeech/ASR/RESULTS.md: E999,

View File

@ -4,6 +4,8 @@
# The computed features are saved to ~/tmp/fbank-libri and are
# cached for later runs
set -e
export PYTHONPATH=$PWD:$PYTHONPATH
echo $PYTHONPATH

View File

@ -6,6 +6,8 @@
# You will find directories `~/tmp/giga-dev-dataset-fbank` after running
# this script.
set -e
mkdir -p ~/tmp
cd ~/tmp

View File

@ -7,6 +7,8 @@
# You will find directories ~/tmp/download/LibriSpeech after running
# this script.
set -e
mkdir ~/tmp/download
cd egs/librispeech/ASR
ln -s ~/tmp/download .

View File

@ -3,6 +3,8 @@
# This script installs kaldifeat into the directory ~/tmp/kaldifeat
# which is cached by GitHub actions for later runs.
set -e
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat

View File

@ -4,6 +4,8 @@
# to egs/librispeech/ASR/download/LibriSpeech and generates manifest
# files in egs/librispeech/ASR/data/manifests
set -e
cd egs/librispeech/ASR
[ ! -e download ] && ln -s ~/tmp/download .
mkdir -p data/manifests

View File

@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}

View File

@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}

View File

@ -0,0 +1,162 @@
#!/usr/bin/env bash
#
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
ln -s pretrained-iter-468000-avg-16.pt pretrained.pt
ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
popd
log "Install ncnn and pnnx"
# We are using a modified ncnn here. Will try to merge it to the official repo
# of ncnn
git clone https://github.com/csukuangfj/ncnn
pushd ncnn
git submodule init
git submodule update python/pybind11
python3 setup.py bdist_wheel
ls -lh dist/
pip install dist/*.whl
cd tools/pnnx
mkdir build
cd build
cmake ..
make -j4 pnnx
./src/pnnx || echo "pass"
popd
log "Test exporting to pnnx format"
./lstm_transducer_stateless2/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
--pnnx 1
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
./lstm_transducer_stateless2/ncnn-decode.py \
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
$repo/test_wavs/1089-134686-0001.wav
./lstm_transducer_stateless2/streaming-ncnn-decode.py \
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
$repo/test_wavs/1089-134686-0001.wav
log "Test exporting with torch.jit.trace()"
./lstm_transducer_stateless2/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
--jit-trace 1
log "Decode with models exported by torch.jit.trace()"
./lstm_transducer_stateless2/jit_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo/exp/encoder_jit_trace.pt \
--decoder-model-filename $repo/exp/decoder_jit_trace.pt \
--joiner-model-filename $repo/exp/joiner_jit_trace.pt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"
./lstm_transducer_stateless2/pretrained.py \
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
for method in modified_beam_search beam_search fast_beam_search; do
log "$method"
./lstm_transducer_stateless2/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"ncnn" ]]; then
mkdir -p lstm_transducer_stateless2/exp
ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
ls -lh lstm_transducer_stateless2/exp
log "Decoding test-clean and test-other"
# use a small value for decoding with CPU
max_duration=100
for method in greedy_search fast_beam_search modified_beam_search; do
log "Decoding with $method"
./lstm_transducer_stateless2/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--use-averaged-model 0 \
--max-duration $max_duration \
--exp-dir lstm_transducer_stateless2/exp
done
rm lstm_transducer_stateless2/exp/*.pt
fi

View File

@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}

View File

@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}

View File

@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}

View File

@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
@ -58,17 +60,17 @@ log "Decode with ONNX models"
--jit-filename $repo/exp/cpu_jit.pt \
--onnx-encoder-filename $repo/exp/encoder.onnx \
--onnx-decoder-filename $repo/exp/decoder.onnx \
--onnx-joiner-filename $repo/exp/joiner.onnx
./pruned_transducer_stateless3/onnx_check_all_in_one.py \
--jit-filename $repo/exp/cpu_jit.pt \
--onnx-all-in-one-filename $repo/exp/all_in_one.onnx
--onnx-joiner-filename $repo/exp/joiner.onnx \
--onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
--onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx
./pruned_transducer_stateless3/onnx_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo/exp/encoder.onnx \
--decoder-model-filename $repo/exp/decoder.onnx \
--joiner-model-filename $repo/exp/joiner.onnx \
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}

View File

@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}

View File

@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}

View File

@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
@ -10,7 +12,6 @@ cd egs/librispeech/ASR
repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500
git lfs install
git clone $repo
log "Downloading pre-trained model from $repo_url"
git clone $repo_url

View File

@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}

View File

@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}

View File

@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}

View File

@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}

View File

@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}

View File

@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}

View File

@ -0,0 +1,124 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/wenetspeech/ASR
repo_url=https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
ln -s pretrained_epoch_10_avg_2.pt pretrained.pt
ln -s pretrained_epoch_10_avg_2.pt epoch-99.pt
popd
log "Test exporting to ONNX format"
./pruned_transducer_stateless2/export.py \
--exp-dir $repo/exp \
--lang-dir $repo/data/lang_char \
--epoch 99 \
--avg 1 \
--onnx 1
log "Export to torchscript model"
./pruned_transducer_stateless2/export.py \
--exp-dir $repo/exp \
--lang-dir $repo/data/lang_char \
--epoch 99 \
--avg 1 \
--jit 1
./pruned_transducer_stateless2/export.py \
--exp-dir $repo/exp \
--lang-dir $repo/data/lang_char \
--epoch 99 \
--avg 1 \
--jit-trace 1
ls -lh $repo/exp/*.onnx
ls -lh $repo/exp/*.pt
log "Decode with ONNX models"
./pruned_transducer_stateless2/onnx_check.py \
--jit-filename $repo/exp/cpu_jit.pt \
--onnx-encoder-filename $repo/exp/encoder.onnx \
--onnx-decoder-filename $repo/exp/decoder.onnx \
--onnx-joiner-filename $repo/exp/joiner.onnx \
--onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
--onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx
./pruned_transducer_stateless2/onnx_pretrained.py \
--tokens $repo/data/lang_char/tokens.txt \
--encoder-model-filename $repo/exp/encoder.onnx \
--decoder-model-filename $repo/exp/decoder.onnx \
--joiner-model-filename $repo/exp/joiner.onnx \
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
log "Decode with models exported by torch.jit.trace()"
./pruned_transducer_stateless2/jit_pretrained.py \
--tokens $repo/data/lang_char/tokens.txt \
--encoder-model-filename $repo/exp/encoder_jit_trace.pt \
--decoder-model-filename $repo/exp/decoder_jit_trace.pt \
--joiner-model-filename $repo/exp/joiner_jit_trace.pt \
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
./pruned_transducer_stateless2/jit_pretrained.py \
--tokens $repo/data/lang_char/tokens.txt \
--encoder-model-filename $repo/exp/encoder_jit_script.pt \
--decoder-model-filename $repo/exp/decoder_jit_script.pt \
--joiner-model-filename $repo/exp/joiner_jit_script.pt \
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"
./pruned_transducer_stateless2/pretrained.py \
--checkpoint $repo/exp/epoch-99.pt \
--lang-dir $repo/data/lang_char \
--decoding-method greedy_search \
--max-sym-per-frame $sym \
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
done
for method in modified_beam_search beam_search fast_beam_search; do
log "$method"
./pruned_transducer_stateless2/pretrained.py \
--decoding-method $method \
--beam-size 4 \
--checkpoint $repo/exp/epoch-99.pt \
--lang-dir $repo/data/lang_char \
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
done

View File

@ -69,7 +69,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'

View File

@ -68,7 +68,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'

View File

@ -68,7 +68,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'

View File

@ -68,7 +68,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'

View File

@ -68,7 +68,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'

View File

@ -0,0 +1,136 @@
name: run-librispeech-lstm-transducer-2022-09-03
on:
push:
branches:
- master
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
jobs:
run_librispeech_pruned_transducer_stateless3_2022_05_13:
if: github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.8]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
pip install --no-binary protobuf protobuf
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/install-kaldifeat.sh
- name: Cache LibriSpeech test-clean and test-other datasets
id: libri-test-clean-and-test-other-data
uses: actions/cache@v2
with:
path: |
~/tmp/download
key: cache-libri-test-clean-and-test-other
- name: Download LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
- name: Prepare manifests for LibriSpeech test-clean and test-other
shell: bash
run: |
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
- name: Cache LibriSpeech test-clean and test-other fbank features
id: libri-test-clean-and-test-other-fbank
uses: actions/cache@v2
with:
path: |
~/tmp/fbank-libri
key: cache-libri-fbank-test-clean-and-test-other-v2
- name: Compute fbank for LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
mkdir -p egs/librispeech/ASR/data
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
- name: Display decoding results for lstm_transducer_stateless2
if: github.event_name == 'schedule' || github.event.label.name == 'ncnn'
shell: bash
run: |
cd egs/librispeech/ASR
tree lstm_transducer_stateless2/exp
cd lstm_transducer_stateless2/exp
echo "===greedy search==="
find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===fast_beam_search==="
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified beam search==="
find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for lstm_transducer_stateless2
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'ncnn'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03
path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/

View File

@ -68,7 +68,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'

View File

@ -68,7 +68,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'

View File

@ -68,7 +68,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'

View File

@ -58,7 +58,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'

View File

@ -67,7 +67,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'

View File

@ -67,7 +67,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'

View File

@ -58,7 +58,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'

View File

@ -58,7 +58,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'

View File

@ -67,7 +67,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'

View File

@ -58,7 +58,7 @@ jobs:
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'

View File

@ -0,0 +1,80 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-wenetspeech-pruned-transducer-stateless2
on:
push:
branches:
- master
pull_request:
types: [labeled]
jobs:
run_librispeech_pruned_transducer_stateless3_2022_05_13:
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'wenetspeech'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.8]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
pip install --no-binary protobuf protobuf
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/install-kaldifeat.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh

View File

@ -29,8 +29,8 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04, macos-latest]
python-version: [3.7, 3.9]
os: [ubuntu-latest]
python-version: [3.8]
fail-fast: false
steps:

2
.gitignore vendored
View File

@ -11,3 +11,5 @@ log
*.bak
*-bak
*bak.py
*.param
*.bin

View File

@ -1,24 +1,114 @@
# icefall dockerfile
We provide a dockerfile for some users, the configuration of dockerfile is : Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8-python3.8. You can use the dockerfile by following the steps:
2 sets of configuration are provided - (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8, and (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8.
## Building images locally
If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8.
Otherwise, since the older PyTorch images are not updated with the [apt-key rotation by NVIDIA](https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key), you have to go for case (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. Ensure that your NVDIA driver supports at least CUDA 11.0.
You can check the highest CUDA version within your NVIDIA driver's support with the `nvidia-smi` command below. In this example, the highest CUDA version is 11.0, i.e. case (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8.
```bash
cd docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8
docker build -t icefall/pytorch1.7.1:latest -f ./Dockerfile ./
$ nvidia-smi
Tue Sep 20 00:26:13 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.119.03 Driver Version: 450.119.03 CUDA Version: 11.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 TITAN RTX On | 00000000:03:00.0 Off | N/A |
| 41% 31C P8 4W / 280W | 16MiB / 24219MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 1 TITAN RTX On | 00000000:04:00.0 Off | N/A |
| 41% 30C P8 11W / 280W | 6MiB / 24220MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 2085 G /usr/lib/xorg/Xorg 9MiB |
| 0 N/A N/A 2240 G /usr/bin/gnome-shell 4MiB |
| 1 N/A N/A 2085 G /usr/lib/xorg/Xorg 4MiB |
+-----------------------------------------------------------------------------+
```
## Using built images
Sample usage of the GPU based images:
## Building images locally
If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly.
For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details.
```dockerfile
ENV http_proxy=http://aaa.bb.cc.net:8080 \
https_proxy=http://aaa.bb.cc.net:8080
```
Then, proceed with these commands.
### If you are case (a), i.e. your NVIDIA driver supports CUDA version >= 11.3:
```bash
cd docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8
docker build -t icefall/pytorch1.12.1 .
```
### If you are case (b), i.e. your NVIDIA driver can only support CUDA versions 11.0 <= x < 11.3:
```bash
cd docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8
docker build -t icefall/pytorch1.7.1 .
```
## Running your built local image
Sample usage of the GPU based images. These commands are written with case (a) in mind, so please make the necessary changes to your image name if you are case (b).
Note: use [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) to run the GPU images.
```bash
docker run -it --runtime=nvidia --name=icefall_username --gpus all icefall/pytorch1.7.1:latest
docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall/pytorch1.12.1
```
Sample usage of the CPU based images:
### Tips:
1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/docker}:{/path/in/host/machine}`.
2. Also, if your environment requires a proxy, this would be a good time to add it in too: `-e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080`.
Overall, your docker run command should look like this.
```bash
docker run -it icefall/pytorch1.7.1:latest /bin/bash
```
docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/docker}:{/path/in/host/machine} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1
```
You can explore more docker run options [here](https://docs.docker.com/engine/reference/commandline/run/) to suit your environment.
### Linking to icefall in your host machine
If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container.
Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine.
Warning: Check that the icefall in your host machine is visible from within your container before proceeding to the commands below.
Use these commands once you are inside the container.
```bash
rm -r /workspace/icefall
ln -s {/path/in/docker/to/icefall} /workspace/icefall
```
## Starting another session in the same running container.
```bash
docker exec -it icefall /bin/bash
```
## Restarting a killed container that has been run before.
```bash
docker start -ai icefall
```
## Sample usage of the CPU based images:
```bash
docker run -it icefall /bin/bash
```

View File

@ -0,0 +1,72 @@
FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel
# ENV http_proxy=http://aaa.bbb.cc.net:8080 \
# https_proxy=http://aaa.bbb.cc.net:8080
# install normal source
RUN apt-get update && \
apt-get install -y --no-install-recommends \
g++ \
make \
automake \
autoconf \
bzip2 \
unzip \
wget \
sox \
libtool \
git \
subversion \
zlib1g-dev \
gfortran \
ca-certificates \
patch \
ffmpeg \
valgrind \
libssl-dev \
vim \
curl
# cmake
RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
cd /opt && \
tar -zxvf cmake-3.18.0.tar.gz && \
cd cmake-3.18.0 && \
./bootstrap && \
make && \
make install && \
rm -rf cmake-3.18.0.tar.gz && \
find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
cd -
# flac
RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \
cd /opt && \
xz -d flac-1.3.2.tar.xz && \
tar -xvf flac-1.3.2.tar && \
cd flac-1.3.2 && \
./configure && \
make && make install && \
rm -rf flac-1.3.2.tar && \
find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
cd -
RUN pip install kaldiio graphviz && \
conda install -y -c pytorch torchaudio
#install k2 from source
RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
cd /opt/k2 && \
python3 setup.py install && \
cd -
# install lhotse
RUN pip install git+https://github.com/lhotse-speech/lhotse
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

View File

@ -1,7 +1,13 @@
FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel
# install normal source
# ENV http_proxy=http://aaa.bbb.cc.net:8080 \
# https_proxy=http://aaa.bbb.cc.net:8080
RUN rm /etc/apt/sources.list.d/cuda.list && \
rm /etc/apt/sources.list.d/nvidia-ml.list && \
apt-key del 7fa2af80
# install normal source
RUN apt-get update && \
apt-get install -y --no-install-recommends \
g++ \
@ -21,20 +27,25 @@ RUN apt-get update && \
patch \
ffmpeg \
valgrind \
libssl-dev \
vim && \
rm -rf /var/lib/apt/lists/*
libssl-dev \
vim \
curl
RUN mv /opt/conda/lib/libcufft.so.10 /opt/libcufft.so.10.bak && \
# Add new keys and reupdate
RUN curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub | apt-key add - && \
curl -fsSL https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \
echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \
rm -rf /var/lib/apt/lists/* && \
mv /opt/conda/lib/libcufft.so.10 /opt/libcufft.so.10.bak && \
mv /opt/conda/lib/libcurand.so.10 /opt/libcurand.so.10.bak && \
mv /opt/conda/lib/libcublas.so.11 /opt/libcublas.so.11.bak && \
mv /opt/conda/lib/libnvrtc.so.11.0 /opt/libnvrtc.so.11.1.bak && \
mv /opt/conda/lib/libnvToolsExt.so.1 /opt/libnvToolsExt.so.1.bak && \
mv /opt/conda/lib/libcudart.so.11.0 /opt/libcudart.so.11.0.bak
# mv /opt/conda/lib/libnvToolsExt.so.1 /opt/libnvToolsExt.so.1.bak && \
mv /opt/conda/lib/libcudart.so.11.0 /opt/libcudart.so.11.0.bak && \
apt-get update && apt-get -y upgrade
# cmake
RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
cd /opt && \
tar -zxvf cmake-3.18.0.tar.gz && \
@ -45,11 +56,7 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
rm -rf cmake-3.18.0.tar.gz && \
find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
cd -
#kaldiio
RUN pip install kaldiio
# flac
RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \
cd /opt && \
@ -62,15 +69,8 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz &&
find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
cd -
# graphviz
RUN pip install graphviz
# kaldifeat
RUN git clone https://github.com/csukuangfj/kaldifeat.git /opt/kaldifeat && \
cd /opt/kaldifeat && \
python setup.py install && \
cd -
RUN pip install kaldiio graphviz && \
conda install -y -c pytorch torchaudio=0.7.1
#install k2 from source
RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
@ -79,14 +79,13 @@ RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
cd -
# install lhotse
RUN pip install torchaudio==0.7.2
RUN pip install git+https://github.com/lhotse-speech/lhotse
#RUN pip install lhotse
RUN pip install git+https://github.com/lhotse-speech/lhotse
# install icefall
RUN git clone https://github.com/k2-fsa/icefall && \
cd icefall && \
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

View File

@ -74,7 +74,7 @@ html_context = {
"github_user": "k2-fsa",
"github_repo": "icefall",
"github_version": "master",
"conf_py_path": "/icefall/docs/source/",
"conf_py_path": "/docs/source/",
}
todo_include_todos = True

View File

@ -21,6 +21,7 @@ speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.
:caption: Contents:
installation/index
model-export/index
recipes/index
contributing/index
huggingface/index

View File

@ -0,0 +1,21 @@
2022-10-13 19:09:02,233 INFO [pretrained.py:265] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampling_factor': 4, 'encoder_dim': 512, 'nhead': 8, 'dim_feedforward': 2048, 'num_encoder_layers': 12, 'decoder_dim': 512, 'joiner_dim': 512, 'model_warm_step': 3000, 'env_info': {'k2-version': '1.21', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '4810e00d8738f1a21278b0156a42ff396a2d40ac', 'k2-git-date': 'Fri Oct 7 19:35:03 2022', 'lhotse-version': '1.3.0.dev+missing.version.file', 'torch-version': '1.10.0+cu102', 'torch-cuda-available': False, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'onnx-doc-1013', 'icefall-git-sha1': 'c39cba5-dirty', 'icefall-git-date': 'Thu Oct 13 15:17:20 2022', 'icefall-path': '/k2-dev/fangjun/open-source/icefall-master', 'k2-path': '/k2-dev/fangjun/open-source/k2-master/k2/python/k2/__init__.py', 'lhotse-path': '/ceph-fj/fangjun/open-source-2/lhotse-jsonl/lhotse/__init__.py', 'hostname': 'de-74279-k2-test-4-0324160024-65bfd8b584-jjlbn', 'IP address': '10.177.74.203'}, 'checkpoint': './icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/pretrained-iter-1224000-avg-14.pt', 'bpe_model': './icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/bpe.model', 'method': 'greedy_search', 'sound_files': ['./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav', './icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav', './icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav'], 'sample_rate': 16000, 'beam_size': 4, 'beam': 4, 'max_contexts': 4, 'max_states': 8, 'context_size': 2, 'max_sym_per_frame': 1, 'simulate_streaming': False, 'decode_chunk_size': 16, 'left_context': 64, 'dynamic_chunk_training': False, 'causal_convolution': False, 'short_chunk_size': 25, 'num_left_chunks': 4, 'blank_id': 0, 'unk_id': 2, 'vocab_size': 500}
2022-10-13 19:09:02,233 INFO [pretrained.py:271] device: cpu
2022-10-13 19:09:02,233 INFO [pretrained.py:273] Creating model
2022-10-13 19:09:02,612 INFO [train.py:458] Disable giga
2022-10-13 19:09:02,623 INFO [pretrained.py:277] Number of model parameters: 78648040
2022-10-13 19:09:02,951 INFO [pretrained.py:285] Constructing Fbank computer
2022-10-13 19:09:02,952 INFO [pretrained.py:295] Reading sound files: ['./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav', './icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav', './icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav']
2022-10-13 19:09:02,957 INFO [pretrained.py:301] Decoding started
2022-10-13 19:09:06,700 INFO [pretrained.py:329] Using greedy_search
2022-10-13 19:09:06,912 INFO [pretrained.py:388]
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2022-10-13 19:09:06,912 INFO [pretrained.py:390] Decoding Done

View File

@ -0,0 +1,135 @@
Export model.state_dict()
=========================
When to use it
--------------
During model training, we save checkpoints periodically to disk.
A checkpoint contains the following information:
- ``model.state_dict()``
- ``optimizer.state_dict()``
- and some other information related to training
When we need to resume the training process from some point, we need a checkpoint.
However, if we want to publish the model for inference, then only
``model.state_dict()`` is needed. In this case, we need to strip all other information
except ``model.state_dict()`` to reduce the file size of the published model.
How to export
-------------
Every recipe contains a file ``export.py`` that you can use to
export ``model.state_dict()`` by taking some checkpoints as inputs.
.. hint::
Each ``export.py`` contains well-documented usage information.
In the following, we use
`<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless3/export.py>`_
as an example.
.. note::
The steps for other recipes are almost the same.
.. code-block:: bash
cd egs/librispeech/ASR
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
will generate a file ``pruned_transducer_stateless3/exp/pretrained.pt``, which
is a dict containing ``{"model": model.state_dict()}`` saved by ``torch.save()``.
How to use the exported model
-----------------------------
For each recipe, we provide pretrained models hosted on huggingface.
You can find links to pretrained models in ``RESULTS.md`` of each dataset.
In the following, we demonstrate how to use the pretrained model from
`<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13>`_.
.. code-block:: bash
cd egs/librispeech/ASR
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
After cloning the repo with ``git lfs``, you will find several files in the folder
``icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp``
that have a prefix ``pretrained-``. Those files contain ``model.state_dict()``
exported by the above ``export.py``.
In each recipe, there is also a file ``pretrained.py``, which can use
``pretrained-xxx.pt`` to decode waves. The following is an example:
.. code-block:: bash
cd egs/librispeech/ASR
./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/pretrained-iter-1224000-avg-14.pt \
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/bpe.model \
--method greedy_search \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav \
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav
The above commands show how to use the exported model with ``pretrained.py`` to
decode multiple sound files. Its output is given as follows for reference:
.. literalinclude:: ./code/export-model-state-dict-pretrained-out.txt
Use the exported model to run decode.py
---------------------------------------
When we publish the model, we always note down its WERs on some test
dataset in ``RESULTS.md``. This section describes how to use the
pretrained model to reproduce the WER.
.. code-block:: bash
cd egs/librispeech/ASR
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
cd icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp
ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt
cd ../..
We create a symlink with name ``epoch-9999.pt`` to ``pretrained-iter-1224000-avg-14.pt``,
so that we can pass ``--epoch 9999 --avg 1`` to ``decode.py`` in the following
command:
.. code-block:: bash
./pruned_transducer_stateless3/decode.py \
--epoch 9999 \
--avg 1 \
--exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp \
--lang-dir ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500 \
--max-duration 600 \
--decoding-method greedy_search
You will find the decoding results in
``./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/greedy_search``.
.. caution::
For some recipes, you also need to pass ``--use-averaged-model False``
to ``decode.py``. The reason is that the exported pretrained model is already
the averaged one.
.. hint::
Before running ``decode.py``, we assume that you have already run
``prepare.sh`` to prepare the test dataset.

View File

@ -0,0 +1,12 @@
Export to ncnn
==============
We support exporting LSTM transducer models to `ncnn <https://github.com/tencent/ncnn>`_.
Please refer to :ref:`export-model-for-ncnn` for details.
We also provide `<https://github.com/k2-fsa/sherpa-ncnn>`_
performing speech recognition using ``ncnn`` with exported models.
It has been tested on Linux, macOS, Windows, and Raspberry Pi. The project is
self-contained and can be statically linked to produce a binary containing
everything needed.

View File

@ -0,0 +1,69 @@
Export to ONNX
==============
In this section, we describe how to export models to ONNX.
.. hint::
Only non-streaming conformer transducer models are tested.
When to use it
--------------
It you want to use an inference framework that supports ONNX
to run the pretrained model.
How to export
-------------
We use
`<https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless3>`_
as an example in the following.
.. code-block:: bash
cd egs/librispeech/ASR
epoch=14
avg=2
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch $epoch \
--avg $avg \
--onnx 1
It will generate the following files inside ``pruned_transducer_stateless3/exp``:
- ``encoder.onnx``
- ``decoder.onnx``
- ``joiner.onnx``
- ``joiner_encoder_proj.onnx``
- ``joiner_decoder_proj.onnx``
You can use ``./pruned_transducer_stateless3/exp/onnx_pretrained.py`` to decode
waves with the generated files:
.. code-block:: bash
./pruned_transducer_stateless3/onnx_pretrained.py \
--bpe-model ./data/lang_bpe_500/bpe.model \
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \
--joiner-encoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_encoder_proj.onnx \
--joiner-decoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_decoder_proj.onnx \
/path/to/foo.wav \
/path/to/bar.wav \
/path/to/baz.wav
How to use the exported model
-----------------------------
We also provide `<https://github.com/k2-fsa/sherpa-onnx>`_
performing speech recognition using `onnxruntime <https://github.com/microsoft/onnxruntime>`_
with exported models.
It has been tested on Linux, macOS, and Windows.

View File

@ -0,0 +1,58 @@
.. _export-model-with-torch-jit-script:
Export model with torch.jit.script()
===================================
In this section, we describe how to export a model via
``torch.jit.script()``.
When to use it
--------------
If we want to use our trained model with torchscript,
we can use ``torch.jit.script()``.
.. hint::
See :ref:`export-model-with-torch-jit-trace`
if you want to use ``torch.jit.trace()``.
How to export
-------------
We use
`<https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless3>`_
as an example in the following.
.. code-block:: bash
cd egs/librispeech/ASR
epoch=14
avg=1
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch $epoch \
--avg $avg \
--jit 1
It will generate a file ``cpu_jit.pt`` in ``pruned_transducer_stateless3/exp``.
.. caution::
Don't be confused by ``cpu`` in ``cpu_jit.pt``. We move all parameters
to CPU before saving it into a ``pt`` file; that's why we use ``cpu``
in the filename.
How to use the exported model
-----------------------------
Please refer to the following pages for usage:
- `<https://k2-fsa.github.io/sherpa/python/streaming_asr/emformer/index.html>`_
- `<https://k2-fsa.github.io/sherpa/python/streaming_asr/conv_emformer/index.html>`_
- `<https://k2-fsa.github.io/sherpa/python/streaming_asr/conformer/index.html>`_
- `<https://k2-fsa.github.io/sherpa/python/offline_asr/conformer/index.html>`_
- `<https://k2-fsa.github.io/sherpa/cpp/offline_asr/gigaspeech.html>`_
- `<https://k2-fsa.github.io/sherpa/cpp/offline_asr/wenetspeech.html>`_

View File

@ -0,0 +1,69 @@
.. _export-model-with-torch-jit-trace:
Export model with torch.jit.trace()
===================================
In this section, we describe how to export a model via
``torch.jit.trace()``.
When to use it
--------------
If we want to use our trained model with torchscript,
we can use ``torch.jit.trace()``.
.. hint::
See :ref:`export-model-with-torch-jit-script`
if you want to use ``torch.jit.script()``.
How to export
-------------
We use
`<https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/lstm_transducer_stateless2>`_
as an example in the following.
.. code-block:: bash
iter=468000
avg=16
cd egs/librispeech/ASR
./lstm_transducer_stateless2/export.py \
--exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--iter $iter \
--avg $avg \
--jit-trace 1
It will generate three files inside ``lstm_transducer_stateless2/exp``:
- ``encoder_jit_trace.pt``
- ``decoder_jit_trace.pt``
- ``joiner_jit_trace.pt``
You can use
`<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py>`_
to decode sound files with the following commands:
.. code-block:: bash
cd egs/librispeech/ASR
./lstm_transducer_stateless2/jit_pretrained.py \
--bpe-model ./data/lang_bpe_500/bpe.model \
--encoder-model-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace.pt \
--decoder-model-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace.pt \
--joiner-model-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace.pt \
/path/to/foo.wav \
/path/to/bar.wav \
/path/to/baz.wav
How to use the exported models
------------------------------
Please refer to
`<https://k2-fsa.github.io/sherpa/python/streaming_asr/lstm/index.html>`_
for its usage in `sherpa <https://k2-fsa.github.io/sherpa/python/streaming_asr/lstm/index.html>`_.
You can also find pretrained models there.

View File

@ -0,0 +1,14 @@
Model export
============
In this section, we describe various ways to export models.
.. toctree::
export-model-state-dict
export-with-torch-jit-trace
export-with-torch-jit-script
export-onnx
export-ncnn

View File

@ -422,7 +422,7 @@ The information of the test sound files is listed below:
.. code-block:: bash
$ soxi tmp/icefall_asr_aishell_conformer_ctc/test_wavs/*.wav
$ soxi tmp/icefall_asr_aishell_conformer_ctc/test_waves/*.wav
Input File : 'tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav'
Channels : 1
@ -485,9 +485,9 @@ The command to run CTC decoding is:
--checkpoint ./tmp/icefall_asr_aishell_conformer_ctc/exp/pretrained.pt \
--tokens-file ./tmp/icefall_asr_aishell_conformer_ctc/data/lang_char/tokens.txt \
--method ctc-decoding \
./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0123.wav
./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav \
./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav \
./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav
The output is given below:
@ -529,9 +529,9 @@ The command to run HLG decoding is:
--words-file ./tmp/icefall_asr_aishell_conformer_ctc/data/lang_char/words.txt \
--HLG ./tmp/icefall_asr_aishell_conformer_ctc/data/lang_char/HLG.pt \
--method 1best \
./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0123.wav
./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav \
./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav \
./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav
The output is given below:
@ -575,9 +575,9 @@ The command to run HLG decoding + attention decoder rescoring is:
--words-file ./tmp/icefall_asr_aishell_conformer_ctc/data/lang_char/words.txt \
--HLG ./tmp/icefall_asr_aishell_conformer_ctc/data/lang_char/HLG.pt \
--method attention-decoder \
./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0123.wav
./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav \
./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav \
./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav
The output is below:

View File

@ -402,7 +402,7 @@ The information of the test sound files is listed below:
.. code-block:: bash
$ soxi tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_wavs/*.wav
$ soxi tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/*.wav
Input File : 'tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0121.wav'
Channels : 1
@ -461,9 +461,9 @@ The command to run HLG decoding is:
--words-file ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/data/lang_phone/words.txt \
--HLG ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/data/lang_phone/HLG.pt \
--method 1best \
./tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_wavs/BAC009S0764W0123.wav
./tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0121.wav \
./tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0122.wav \
./tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0123.wav
The output is given below:

Binary file not shown.

After

Width:  |  Height:  |  Size: 413 KiB

View File

@ -6,3 +6,4 @@ LibriSpeech
tdnn_lstm_ctc
conformer_ctc
lstm_pruned_stateless_transducer

View File

@ -0,0 +1,636 @@
LSTM Transducer
===============
.. hint::
Please scroll down to the bottom of this page to find download links
for pretrained models if you don't want to train a model from scratch.
This tutorial shows you how to train an LSTM transducer model
with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.
We use pruned RNN-T to compute the loss.
.. note::
You can find the paper about pruned RNN-T at the following address:
`<https://arxiv.org/abs/2206.13236>`_
The transducer model consists of 3 parts:
- Encoder, a.k.a, the transcription network. We use an LSTM model
- Decoder, a.k.a, the prediction network. We use a stateless model consisting of
``nn.Embedding`` and ``nn.Conv1d``
- Joiner, a.k.a, the joint network.
.. caution::
Contrary to the conventional RNN-T models, we use a stateless decoder.
That is, it has no recurrent connections.
.. hint::
Since the encoder model is an LSTM, not Transformer/Conformer, the
resulting model is suitable for streaming/online ASR.
Which model to use
------------------
Currently, there are two folders about LSTM stateless transducer training:
- ``(1)`` `<https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/lstm_transducer_stateless>`_
This recipe uses only LibriSpeech during training.
- ``(2)`` `<https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/lstm_transducer_stateless2>`_
This recipe uses GigaSpeech + LibriSpeech during training.
``(1)`` and ``(2)`` use the same model architecture. The only difference is that ``(2)`` supports
multi-dataset. Since ``(2)`` uses more data, it has a lower WER than ``(1)`` but it needs
more training time.
We use ``lstm_transducer_stateless2`` as an example below.
.. note::
You need to download the `GigaSpeech <https://github.com/SpeechColab/GigaSpeech>`_ dataset
to run ``(2)``. If you have only ``LibriSpeech`` dataset available, feel free to use ``(1)``.
Data preparation
----------------
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./prepare.sh
# If you use (1), you can **skip** the following command
$ ./prepare_giga_speech.sh
The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
All you need to do is to run it.
.. note::
We encourage you to read ``./prepare.sh``.
The data preparation contains several stages. You can use the following two
options:
- ``--stage``
- ``--stop-stage``
to control which stage(s) should be run. By default, all stages are executed.
For example,
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./prepare.sh --stage 0 --stop-stage 0
means to run only stage 0.
To run stage 2 to stage 5, use:
.. code-block:: bash
$ ./prepare.sh --stage 2 --stop-stage 5
.. hint::
If you have pre-downloaded the `LibriSpeech <https://www.openslr.org/12>`_
dataset and the `musan <http://www.openslr.org/17/>`_ dataset, say,
they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify
the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that
``./prepare.sh`` won't re-download them.
.. note::
All generated files by ``./prepare.sh``, e.g., features, lexicon, etc,
are saved in ``./data`` directory.
We provide the following YouTube video showing how to run ``./prepare.sh``.
.. note::
To get the latest news of `next-gen Kaldi <https://github.com/k2-fsa>`_, please subscribe
the following YouTube channel by `Nadira Povey <https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_:
`<https://www.youtube.com/channel/UC_VaumpkmINz1pNkFXAN9mw>`_
.. youtube:: ofEIoJL-mGM
Training
--------
Configurable options
~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./lstm_transducer_stateless2/train.py --help
shows you the training options that can be passed from the commandline.
The following options are used quite often:
- ``--full-libri``
If it's True, the training part uses all the training data, i.e.,
960 hours. Otherwise, the training part uses only the subset
``train-clean-100``, which has 100 hours of training data.
.. CAUTION::
The training set is perturbed by speed with two factors: 0.9 and 1.1.
If ``--full-libri`` is True, each epoch actually processes
``3x960 == 2880`` hours of data.
- ``--num-epochs``
It is the number of epochs to train. For instance,
``./lstm_transducer_stateless2/train.py --num-epochs 30`` trains for 30 epochs
and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt``
in the folder ``./lstm_transducer_stateless2/exp``.
- ``--start-epoch``
It's used to resume training.
``./lstm_transducer_stateless2/train.py --start-epoch 10`` loads the
checkpoint ``./lstm_transducer_stateless2/exp/epoch-9.pt`` and starts
training from epoch 10, based on the state from epoch 9.
- ``--world-size``
It is used for multi-GPU single-machine DDP training.
- (a) If it is 1, then no DDP training is used.
- (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training.
The following shows some use cases with it.
**Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and
GPU 2 for training. You can do the following:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ export CUDA_VISIBLE_DEVICES="0,2"
$ ./lstm_transducer_stateless2/train.py --world-size 2
**Use case 2**: You have 4 GPUs and you want to use all of them
for training. You can do the following:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./lstm_transducer_stateless2/train.py --world-size 4
**Use case 3**: You have 4 GPUs but you only want to use GPU 3
for training. You can do the following:
.. code-block:: bash
$ cd egs/librispeech/ASR
$ export CUDA_VISIBLE_DEVICES="3"
$ ./lstm_transducer_stateless2/train.py --world-size 1
.. caution::
Only multi-GPU single-machine DDP training is implemented at present.
Multi-GPU multi-machine DDP training will be added later.
- ``--max-duration``
It specifies the number of seconds over all utterances in a
batch, before **padding**.
If you encounter CUDA OOM, please reduce it.
.. HINT::
Due to padding, the number of seconds of all utterances in a
batch will usually be larger than ``--max-duration``.
A larger value for ``--max-duration`` may cause OOM during training,
while a smaller value may increase the training time. You have to
tune it.
- ``--giga-prob``
The probability to select a batch from the ``GigaSpeech`` dataset.
Note: It is available only for ``(2)``.
Pre-configured options
~~~~~~~~~~~~~~~~~~~~~~
There are some training options, e.g., weight decay,
number of warmup steps, results dir, etc,
that are not passed from the commandline.
They are pre-configured by the function ``get_params()`` in
`lstm_transducer_stateless2/train.py <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/lstm_transducer_stateless2/train.py>`_
You don't need to change these pre-configured parameters. If you really need to change
them, please modify ``./lstm_transducer_stateless2/train.py`` directly.
Training logs
~~~~~~~~~~~~~
Training logs and checkpoints are saved in ``lstm_transducer_stateless2/exp``.
You will find the following files in that directory:
- ``epoch-1.pt``, ``epoch-2.pt``, ...
These are checkpoint files saved at the end of each epoch, containing model
``state_dict`` and optimizer ``state_dict``.
To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
.. code-block:: bash
$ ./lstm_transducer_stateless2/train.py --start-epoch 11
- ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ...
These are checkpoint files saved every ``--save-every-n`` batches,
containing model ``state_dict`` and optimizer ``state_dict``.
To resume training from some checkpoint, say ``checkpoint-436000``, you can use:
.. code-block:: bash
$ ./lstm_transducer_stateless2/train.py --start-batch 436000
- ``tensorboard/``
This folder contains tensorBoard logs. Training loss, validation loss, learning
rate, etc, are recorded in these logs. You can visualize them by:
.. code-block:: bash
$ cd lstm_transducer_stateless2/exp/tensorboard
$ tensorboard dev upload --logdir . --description "LSTM transducer training for LibriSpeech with icefall"
It will print something like below:
.. code-block::
TensorFlow installation not found - running with reduced feature set.
Upload started and will continue reading any new data as it's added to the logdir.
To stop uploading, press Ctrl-C.
New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/cj2vtPiwQHKN9Q1tx6PTpg/
[2022-09-20T15:50:50] Started scanning logdir.
Uploading 4468 scalars...
[2022-09-20T15:53:02] Total uploaded: 210171 scalars, 0 tensors, 0 binary objects
Listening for new data in logdir...
Note there is a URL in the above output. Click it and you will see
the following screenshot:
.. figure:: images/librispeech-lstm-transducer-tensorboard-log.png
:width: 600
:alt: TensorBoard screenshot
:align: center
:target: https://tensorboard.dev/experiment/lzGnETjwRxC3yghNMd4kPw/
TensorBoard screenshot.
.. hint::
If you don't have access to google, you can use the following command
to view the tensorboard log locally:
.. code-block:: bash
cd lstm_transducer_stateless2/exp/tensorboard
tensorboard --logdir . --port 6008
It will print the following message:
.. code-block::
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit)
Now start your browser and go to `<http://localhost:6008>`_ to view the tensorboard
logs.
- ``log/log-train-xxxx``
It is the detailed training log in text format, same as the one
you saw printed to the console during training.
Usage example
~~~~~~~~~~~~~
You can use the following command to start the training using 8 GPUs:
.. code-block:: bash
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./lstm_transducer_stateless2/train.py \
--world-size 8 \
--num-epochs 35 \
--start-epoch 1 \
--full-libri 1 \
--exp-dir lstm_transducer_stateless2/exp \
--max-duration 500 \
--use-fp16 0 \
--lr-epochs 10 \
--num-workers 2 \
--giga-prob 0.9
Decoding
--------
The decoding part uses checkpoints saved by the training part, so you have
to run the training part first.
.. hint::
There are two kinds of checkpoints:
- (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end
of each epoch. You can pass ``--epoch`` to
``lstm_transducer_stateless2/decode.py`` to use them.
- (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved
every ``--save-every-n`` batches. You can pass ``--iter`` to
``lstm_transducer_stateless2/decode.py`` to use them.
We suggest that you try both types of checkpoints and choose the one
that produces the lowest WERs.
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./lstm_transducer_stateless2/decode.py --help
shows the options for decoding.
The following shows two examples:
.. code-block:: bash
for m in greedy_search fast_beam_search modified_beam_search; do
for epoch in 17; do
for avg in 1 2; do
./lstm_transducer_stateless2/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir lstm_transducer_stateless2/exp \
--max-duration 600 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method $m \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--beam-size 4
done
done
done
.. code-block:: bash
for m in greedy_search fast_beam_search modified_beam_search; do
for iter in 474000; do
for avg in 8 10 12 14 16 18; do
./lstm_transducer_stateless2/decode.py \
--iter $iter \
--avg $avg \
--exp-dir lstm_transducer_stateless2/exp \
--max-duration 600 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method $m \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--beam-size 4
done
done
done
Export models
-------------
`lstm_transducer_stateless2/export.py <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/lstm_transducer_stateless2/export.py>`_ supports exporting checkpoints from ``lstm_transducer_stateless2/exp`` in the following ways.
Export ``model.state_dict()``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Checkpoints saved by ``lstm_transducer_stateless2/train.py`` also include
``optimizer.state_dict()``. It is useful for resuming training. But after training,
we are interested only in ``model.state_dict()``. You can use the following
command to extract ``model.state_dict()``.
.. code-block:: bash
# Assume that --iter 468000 --avg 16 produces the smallest WER
# (You can get such information after running ./lstm_transducer_stateless2/decode.py)
iter=468000
avg=16
./lstm_transducer_stateless2/export.py \
--exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--iter $iter \
--avg $avg
It will generate a file ``./lstm_transducer_stateless2/exp/pretrained.pt``.
.. hint::
To use the generated ``pretrained.pt`` for ``lstm_transducer_stateless2/decode.py``,
you can run:
.. code-block:: bash
cd lstm_transducer_stateless2/exp
ln -s pretrained epoch-9999.pt
And then pass ``--epoch 9999 --avg 1 --use-averaged-model 0`` to
``./lstm_transducer_stateless2/decode.py``.
To use the exported model with ``./lstm_transducer_stateless2/pretrained.py``, you
can run:
.. code-block:: bash
./lstm_transducer_stateless2/pretrained.py \
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
Export model using ``torch.jit.trace()``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
iter=468000
avg=16
./lstm_transducer_stateless2/export.py \
--exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--iter $iter \
--avg $avg \
--jit-trace 1
It will generate 3 files:
- ``./lstm_transducer_stateless2/exp/encoder_jit_trace.pt``
- ``./lstm_transducer_stateless2/exp/decoder_jit_trace.pt``
- ``./lstm_transducer_stateless2/exp/joiner_jit_trace.pt``
To use the generated files with ``./lstm_transducer_stateless2/jit_pretrained``:
.. code-block:: bash
./lstm_transducer_stateless2/jit_pretrained.py \
--bpe-model ./data/lang_bpe_500/bpe.model \
--encoder-model-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace.pt \
--decoder-model-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace.pt \
--joiner-model-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace.pt \
/path/to/foo.wav \
/path/to/bar.wav
.. hint::
Please see `<https://k2-fsa.github.io/sherpa/python/streaming_asr/lstm/english/server.html>`_
for how to use the exported models in ``sherpa``.
.. _export-model-for-ncnn:
Export model for ncnn
~~~~~~~~~~~~~~~~~~~~~
We support exporting pretrained LSTM transducer models to
`ncnn <https://github.com/tencent/ncnn>`_ using
`pnnx <https://github.com/Tencent/ncnn/tree/master/tools/pnnx>`_.
First, let us install a modified version of ``ncnn``:
.. code-block:: bash
git clone https://github.com/csukuangfj/ncnn
cd ncnn
git submodule update --recursive --init
python3 setup.py bdist_wheel
ls -lh dist/
pip install ./dist/*.whl
# now build pnnx
cd tools/pnnx
mkdir build
cd build
make -j4
export PATH=$PWD/src:$PATH
./src/pnnx
.. note::
We assume that you have added the path to the binary ``pnnx`` to the
environment variable ``PATH``.
Second, let us export the model using ``torch.jit.trace()`` that is suitable
for ``pnnx``:
.. code-block:: bash
iter=468000
avg=16
./lstm_transducer_stateless2/export.py \
--exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--iter $iter \
--avg $avg \
--pnnx 1
It will generate 3 files:
- ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt``
- ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt``
- ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt``
Third, convert torchscript model to ``ncnn`` format:
.. code-block::
pnnx ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt
pnnx ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt
pnnx ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt
It will generate the following files:
- ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param``
- ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin``
- ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param``
- ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin``
- ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param``
- ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin``
To use the above generated files, run:
.. code-block:: bash
./lstm_transducer_stateless2/ncnn-decode.py \
--bpe-model-filename ./data/lang_bpe_500/bpe.model \
--encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \
/path/to/foo.wav
.. code-block:: bash
./lstm_transducer_stateless2/streaming-ncnn-decode.py \
--bpe-model-filename ./data/lang_bpe_500/bpe.model \
--encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \
/path/to/foo.wav
To use the above generated files in C++, please see
`<https://github.com/k2-fsa/sherpa-ncnn>`_
It is able to generate a static linked executable that can be run on Linux, Windows,
macOS, Raspberry Pi, etc, without external dependencies.
Download pretrained models
--------------------------
If you don't want to train from scratch, you can download the pretrained models
by visiting the following links:
- `<https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03>`_
- `<https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18>`_
See `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_
for the details of the above pretrained models
You can find more usages of the pretrained models in
`<https://k2-fsa.github.io/sherpa/python/streaming_asr/lstm/index.html>`_

View File

@ -248,7 +248,9 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src = residual + self.dropout(self.conv_module(src))
src = residual + self.dropout(
self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
)
if not self.normalize_before:
src = self.norm_conv(src)
@ -879,11 +881,16 @@ class ConvolutionModule(nn.Module):
)
self.activation = Swish()
def forward(self, x: Tensor) -> Tensor:
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
@ -897,6 +904,8 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))

View File

@ -248,7 +248,9 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src = residual + self.dropout(self.conv_module(src))
src = residual + self.dropout(
self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
)
if not self.normalize_before:
src = self.norm_conv(src)
@ -879,11 +881,16 @@ class ConvolutionModule(nn.Module):
)
self.activation = Swish()
def forward(self, x: Tensor) -> Tensor:
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
@ -897,6 +904,8 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))

View File

@ -246,7 +246,9 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src = residual + self.dropout(self.conv_module(src))
src = residual + self.dropout(
self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
)
if not self.normalize_before:
src = self.norm_conv(src)
@ -877,11 +879,16 @@ class ConvolutionModule(nn.Module):
)
self.activation = Swish()
def forward(self, x: Tensor) -> Tensor:
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
@ -895,6 +902,8 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
# x is (batch, channels, time)
x = x.permute(0, 2, 1)

View File

@ -253,7 +253,9 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src = residual + self.dropout(self.conv_module(src))
src = residual + self.dropout(
self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
)
if not self.normalize_before:
src = self.norm_conv(src)
@ -890,11 +892,16 @@ class ConvolutionModule(nn.Module):
)
self.activation = Swish()
def forward(self, x: Tensor) -> Tensor:
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
@ -908,6 +915,8 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
if self.use_batchnorm:
x = self.norm(x)

View File

@ -1,12 +1,100 @@
## Results
#### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + multi-dataset)
### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + gradient filter)
[lstm_transducer_stateless2](./lstm_transducer_stateless2)
#### [lstm_transducer_stateless3](./lstm_transducer_stateless3)
It implements LSTM model with mechanisms in reworked model for streaming ASR.
Gradient filter is applied inside each lstm module to stabilize the training.
See <https://github.com/k2-fsa/icefall/pull/564> for more details.
##### training on full librispeech
This model contains 12 encoder layers (LSTM module + Feedforward module). The number of model parameters is 84689496.
The WERs are:
| | test-clean | test-other | comment | decoding mode |
|-------------------------------------|------------|------------|----------------------|----------------------|
| greedy search (max sym per frame 1) | 3.66 | 9.51 | --epoch 40 --avg 15 | simulated streaming |
| greedy search (max sym per frame 1) | 3.66 | 9.48 | --epoch 40 --avg 15 | streaming |
| fast beam search | 3.55 | 9.33 | --epoch 40 --avg 15 | simulated streaming |
| fast beam search | 3.57 | 9.25 | --epoch 40 --avg 15 | streaming |
| modified beam search | 3.55 | 9.28 | --epoch 40 --avg 15 | simulated streaming |
| modified beam search | 3.54 | 9.25 | --epoch 40 --avg 15 | streaming |
Note: `simulated streaming` indicates feeding full utterance during decoding, while `streaming` indicates feeding certain number of frames at each time.
The training command is:
```bash
./lstm_transducer_stateless3/train.py \
--world-size 4 \
--num-epochs 40 \
--start-epoch 1 \
--exp-dir lstm_transducer_stateless3/exp \
--full-libri 1 \
--max-duration 500 \
--master-port 12325 \
--num-encoder-layers 12 \
--grad-norm-threshold 25.0 \
--rnn-hidden-size 1024
```
The tensorboard log can be found at
<https://tensorboard.dev/experiment/caNPyr5lT8qAl9qKsXEeEQ/>
The simulated streaming decoding command using greedy search, fast beam search, and modified beam search is:
```bash
for decoding_method in greedy_search fast_beam_search modified_beam_search; do
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 15 \
--exp-dir lstm_transducer_stateless3/exp \
--max-duration 600 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method $decoding_method \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--beam-size 4
done
```
The streaming decoding command using greedy search, fast beam search, and modified beam search is:
```bash
for decoding_method in greedy_search fast_beam_search modified_beam_search; do
./lstm_transducer_stateless3/streaming_decode.py \
--epoch 40 \
--avg 15 \
--exp-dir lstm_transducer_stateless3/exp \
--max-duration 600 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method $decoding_method \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--beam-size 4
done
```
Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless3-2022-09-28>
### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + multi-dataset)
#### [lstm_transducer_stateless2](./lstm_transducer_stateless2)
See <https://github.com/k2-fsa/icefall/pull/558> for more details.
The WERs are:
| | test-clean | test-other | comment |
@ -18,6 +106,7 @@ The WERs are:
| modified_beam_search | 2.75 | 7.08 | --iter 472000 --avg 18 |
| fast_beam_search | 2.77 | 7.29 | --iter 472000 --avg 18 |
The training command is:
```bash
@ -70,15 +159,16 @@ Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03>
#### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T)
[lstm_transducer_stateless](./lstm_transducer_stateless)
### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T)
#### [lstm_transducer_stateless](./lstm_transducer_stateless)
It implements LSTM model with mechanisms in reworked model for streaming ASR.
See <https://github.com/k2-fsa/icefall/pull/479> for more details.
#### training on full librispeech
##### training on full librispeech
This model contains 12 encoder layers (LSTM module + Feedforward module). The number of model parameters is 84689496.
@ -165,7 +255,7 @@ It is modified from [torchaudio](https://github.com/pytorch/audio).
See <https://github.com/k2-fsa/icefall/pull/440> for more details.
#### With lower latency setup, training on full librispeech
##### With lower latency setup, training on full librispeech
In this model, the lengths of chunk and right context are 32 frames (i.e., 0.32s) and 8 frames (i.e., 0.08s), respectively.
@ -316,7 +406,7 @@ Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05>
#### With higher latency setup, training on full librispeech
##### With higher latency setup, training on full librispeech
In this model, the lengths of chunk and right context are 64 frames (i.e., 0.64s) and 16 frames (i.e., 0.16s), respectively.
@ -851,14 +941,14 @@ Pre-trained models, training and decoding logs, and decoding results are availab
### LibriSpeech BPE training results (Pruned Stateless Conv-Emformer RNN-T)
[conv_emformer_transducer_stateless](./conv_emformer_transducer_stateless)
#### [conv_emformer_transducer_stateless](./conv_emformer_transducer_stateless)
It implements [Emformer](https://arxiv.org/abs/2010.10759) augmented with convolution module for streaming ASR.
It is modified from [torchaudio](https://github.com/pytorch/audio).
See <https://github.com/k2-fsa/icefall/pull/389> for more details.
#### Training on full librispeech
##### Training on full librispeech
In this model, the lengths of chunk and right context are 32 frames (i.e., 0.32s) and 8 frames (i.e., 0.08s), respectively.
@ -1011,7 +1101,7 @@ are available at
### LibriSpeech BPE training results (Pruned Stateless Emformer RNN-T)
[pruned_stateless_emformer_rnnt2](./pruned_stateless_emformer_rnnt2)
#### [pruned_stateless_emformer_rnnt2](./pruned_stateless_emformer_rnnt2)
Use <https://github.com/k2-fsa/icefall/pull/390>.
@ -1079,7 +1169,7 @@ results at:
### LibriSpeech BPE training results (Pruned Stateless Transducer 5)
[pruned_transducer_stateless5](./pruned_transducer_stateless5)
#### [pruned_transducer_stateless5](./pruned_transducer_stateless5)
Same as `Pruned Stateless Transducer 2` but with more layers.
@ -1092,7 +1182,7 @@ The notations `large` and `medium` below are from the [Conformer](https://arxiv.
paper, where the large model has about 118 M parameters and the medium model
has 30.8 M parameters.
#### Large
##### Large
Number of model parameters 118129516 (i.e, 118.13 M).
@ -1152,7 +1242,7 @@ results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07>
#### Medium
##### Medium
Number of model parameters 30896748 (i.e, 30.9 M).
@ -1212,7 +1302,7 @@ results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-M-2022-07-07>
#### Baseline-2
##### Baseline-2
It has 88.98 M parameters. Compared to the model in pruned_transducer_stateless2, its has more
layers (24 v.s 12) but a narrower model (1536 feedforward dim and 384 encoder dim vs 2048 feed forward dim and 512 encoder dim).
@ -1273,13 +1363,13 @@ results at:
### LibriSpeech BPE training results (Pruned Stateless Transducer 4)
[pruned_transducer_stateless4](./pruned_transducer_stateless4)
#### [pruned_transducer_stateless4](./pruned_transducer_stateless4)
This version saves averaged model during training, and decodes with averaged model.
See <https://github.com/k2-fsa/icefall/issues/337> for details about the idea of model averaging.
#### Training on full librispeech
##### Training on full librispeech
See <https://github.com/k2-fsa/icefall/pull/344>
@ -1355,7 +1445,7 @@ Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless4-2022-06-03>
#### Training on train-clean-100
##### Training on train-clean-100
See <https://github.com/k2-fsa/icefall/pull/344>
@ -1392,7 +1482,7 @@ The tensorboard log can be found at
### LibriSpeech BPE training results (Pruned Stateless Transducer 3, 2022-04-29)
[pruned_transducer_stateless3](./pruned_transducer_stateless3)
#### [pruned_transducer_stateless3](./pruned_transducer_stateless3)
Same as `Pruned Stateless Transducer 2` but using the XL subset from
[GigaSpeech](https://github.com/SpeechColab/GigaSpeech) as extra training data.
@ -1606,10 +1696,10 @@ can be found at
### LibriSpeech BPE training results (Pruned Transducer 2)
[pruned_transducer_stateless2](./pruned_transducer_stateless2)
#### [pruned_transducer_stateless2](./pruned_transducer_stateless2)
This is with a reworked version of the conformer encoder, with many changes.
#### Training on fulll librispeech
##### Training on full librispeech
Using commit `34aad74a2c849542dd5f6359c9e6b527e8782fd6`.
See <https://github.com/k2-fsa/icefall/pull/288>
@ -1658,7 +1748,7 @@ can be found at
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29>
#### Training on train-clean-100:
##### Training on train-clean-100:
Trained with 1 job:
```

View File

@ -253,7 +253,9 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src = residual + self.dropout(self.conv_module(src))
src = residual + self.dropout(
self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
)
if not self.normalize_before:
src = self.norm_conv(src)
@ -890,11 +892,16 @@ class ConvolutionModule(nn.Module):
)
self.activation = Swish()
def forward(self, x: Tensor) -> Tensor:
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
@ -908,6 +915,8 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
if self.use_batchnorm:
x = self.norm(x)

View File

@ -268,7 +268,9 @@ class ConformerEncoderLayer(nn.Module):
src = src + self.dropout(src_att)
# convolution module
src = src + self.dropout(self.conv_module(src))
src = src + self.dropout(
self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
)
# feed forward module
src = src + self.dropout(self.feed_forward(src))
@ -921,11 +923,16 @@ class ConvolutionModule(nn.Module):
initial_scale=0.25,
)
def forward(self, x: Tensor) -> Tensor:
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
@ -941,6 +948,8 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
x = self.deriv_balancer2(x)

View File

@ -247,7 +247,9 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src = residual + self.dropout(self.conv_module(src))
src = residual + self.dropout(
self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
)
if not self.normalize_before:
src = self.norm_conv(src)
@ -878,11 +880,16 @@ class ConvolutionModule(nn.Module):
)
self.activation = Swish()
def forward(self, x: Tensor) -> Tensor:
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
@ -896,6 +903,8 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))

View File

@ -116,6 +116,8 @@ class RNN(EncoderInterface):
Period of auxiliary layers used for random combiner during training.
If set to 0, will not use the random combiner (Default).
You can set a positive integer to use the random combiner, e.g., 3.
is_pnnx:
True to make this class exportable via PNNX.
"""
def __init__(
@ -129,6 +131,7 @@ class RNN(EncoderInterface):
dropout: float = 0.1,
layer_dropout: float = 0.075,
aux_layer_period: int = 0,
is_pnnx: bool = False,
) -> None:
super(RNN, self).__init__()
@ -142,7 +145,13 @@ class RNN(EncoderInterface):
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_embed = Conv2dSubsampling(
num_features,
d_model,
is_pnnx=is_pnnx,
)
self.is_pnnx = is_pnnx
self.num_encoder_layers = num_encoder_layers
self.d_model = d_model
@ -209,7 +218,13 @@ class RNN(EncoderInterface):
# lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning
#
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 3) >> 1) - 1) >> 1
if not self.is_pnnx:
lengths = (((x_lens - 3) >> 1) - 1) >> 1
else:
lengths1 = torch.floor((x_lens - 3) / 2)
lengths = torch.floor((lengths1 - 1) / 2)
lengths = lengths.to(x_lens)
if not torch.jit.is_tracing():
assert x.size(0) == lengths.max().item()
@ -359,7 +374,7 @@ class RNNEncoderLayer(nn.Module):
# for cell state
assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
src_lstm, new_states = self.lstm(src, states)
src = src + self.dropout(src_lstm)
src = self.dropout(src_lstm) + src
# feed forward module
src = src + self.dropout(self.feed_forward(src))
@ -505,6 +520,7 @@ class Conv2dSubsampling(nn.Module):
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
is_pnnx: bool = False,
) -> None:
"""
Args:
@ -517,6 +533,9 @@ class Conv2dSubsampling(nn.Module):
Number of channels in layer1
layer1_channels:
Number of channels in layer2
is_pnnx:
True if we are converting the model to PNNX format.
False otherwise.
"""
assert in_channels >= 9
super().__init__()
@ -559,6 +578,10 @@ class Conv2dSubsampling(nn.Module):
channel_dim=-1, min_positive=0.45, max_positive=0.55
)
# ncnn supports only batch size == 1
self.is_pnnx = is_pnnx
self.conv_out_dim = self.out.weight.shape[1]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
@ -572,9 +595,15 @@ class Conv2dSubsampling(nn.Module):
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x)
# Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if torch.jit.is_tracing() and self.is_pnnx:
x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
x = self.out(x)
else:
# Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape (N, ((T-3)//2-1))//2, odim)
x = self.out_norm(x)
x = self.out_balancer(x)

View File

@ -169,6 +169,18 @@ def get_parser():
""",
)
parser.add_argument(
"--pnnx",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.trace for later
converting to PNNX. It will generate 3 files:
- encoder_jit_trace-pnnx.pt
- decoder_jit_trace-pnnx.pt
- joiner_jit_trace-pnnx.pt
""",
)
parser.add_argument(
"--context-size",
type=int,
@ -277,6 +289,10 @@ def main():
logging.info(params)
if params.pnnx:
params.is_pnnx = params.pnnx
logging.info("For PNNX")
logging.info("About to create model")
model = get_transducer_model(params, enable_giga=False)
@ -371,7 +387,18 @@ def main():
model.to("cpu")
model.eval()
if params.jit_trace is True:
if params.pnnx:
convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.trace()")
encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
elif params.jit_trace is True:
convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.trace()")
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"

View File

@ -0,0 +1,295 @@
#!/usr/bin/env python3
# flake8: noqa
#
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./lstm_transducer_stateless2/ncnn-decode.py \
--bpe-model-filename ./data/lang_bpe_500/bpe.model \
--encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
--encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
--decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
--decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
--joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
--joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
./test_wavs/1089-134686-0001.wav
"""
import argparse
import logging
from typing import List
import kaldifeat
import ncnn
import sentencepiece as spm
import torch
import torchaudio
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe-model-filename",
type=str,
help="Path to bpe.model",
)
parser.add_argument(
"--encoder-param-filename",
type=str,
help="Path to encoder.ncnn.param",
)
parser.add_argument(
"--encoder-bin-filename",
type=str,
help="Path to encoder.ncnn.bin",
)
parser.add_argument(
"--decoder-param-filename",
type=str,
help="Path to decoder.ncnn.param",
)
parser.add_argument(
"--decoder-bin-filename",
type=str,
help="Path to decoder.ncnn.bin",
)
parser.add_argument(
"--joiner-param-filename",
type=str,
help="Path to joiner.ncnn.param",
)
parser.add_argument(
"--joiner-bin-filename",
type=str,
help="Path to joiner.ncnn.bin",
)
parser.add_argument(
"sound_filename",
type=str,
help="Path to foo.wav",
)
return parser.parse_args()
class Model:
def __init__(self, args):
self.init_encoder(args)
self.init_decoder(args)
self.init_joiner(args)
def init_encoder(self, args):
encoder_net = ncnn.Net()
encoder_net.opt.use_packing_layout = False
encoder_net.opt.use_fp16_storage = False
encoder_param = args.encoder_param_filename
encoder_model = args.encoder_bin_filename
encoder_net.load_param(encoder_param)
encoder_net.load_model(encoder_model)
self.encoder_net = encoder_net
def init_decoder(self, args):
decoder_param = args.decoder_param_filename
decoder_model = args.decoder_bin_filename
decoder_net = ncnn.Net()
decoder_net.opt.use_packing_layout = False
decoder_net.load_param(decoder_param)
decoder_net.load_model(decoder_model)
self.decoder_net = decoder_net
def init_joiner(self, args):
joiner_param = args.joiner_param_filename
joiner_model = args.joiner_bin_filename
joiner_net = ncnn.Net()
joiner_net.opt.use_packing_layout = False
joiner_net.load_param(joiner_param)
joiner_net.load_model(joiner_model)
self.joiner_net = joiner_net
def run_encoder(self, x, states):
with self.encoder_net.create_extractor() as ex:
ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(x.numpy()).clone())
x_lens = torch.tensor([x.size(0)], dtype=torch.float32)
ex.input("in1", ncnn.Mat(x_lens.numpy()).clone())
ex.input("in2", ncnn.Mat(states[0].numpy()).clone())
ex.input("in3", ncnn.Mat(states[1].numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
ret, ncnn_out1 = ex.extract("out1")
assert ret == 0, ret
ret, ncnn_out2 = ex.extract("out2")
assert ret == 0, ret
ret, ncnn_out3 = ex.extract("out3")
assert ret == 0, ret
encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(
torch.int32
)
hx = torch.from_numpy(ncnn_out2.numpy()).clone()
cx = torch.from_numpy(ncnn_out3.numpy()).clone()
return encoder_out, encoder_out_lens, hx, cx
def run_decoder(self, decoder_input):
assert decoder_input.dtype == torch.int32
with self.decoder_net.create_extractor() as ex:
ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
return decoder_out
def run_joiner(self, encoder_out, decoder_out):
with self.joiner_net.create_extractor() as ex:
ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone()
return joiner_out
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(model: Model, encoder_out: torch.Tensor):
assert encoder_out.ndim == 2
T = encoder_out.size(0)
context_size = 2
blank_id = 0 # hard-code to 0
hyp = [blank_id] * context_size
decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size)
decoder_out = model.run_decoder(decoder_input).squeeze(0)
# print(decoder_out.shape) # (512,)
for t in range(T):
encoder_out_t = encoder_out[t]
joiner_out = model.run_joiner(encoder_out_t, decoder_out)
# print(joiner_out.shape) # [500]
y = joiner_out.argmax(dim=0).tolist()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor(decoder_input, dtype=torch.int32)
decoder_out = model.run_decoder(decoder_input).squeeze(0)
return hyp[context_size:]
def main():
args = get_args()
logging.info(vars(args))
model = Model(args)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model_filename)
sound_file = args.sound_filename
sample_rate = 16000
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {sound_file}")
wave_samples = read_sound_files(
filenames=[sound_file],
expected_sample_rate=sample_rate,
)[0]
logging.info("Decoding started")
features = fbank(wave_samples)
num_encoder_layers = 12
d_model = 512
rnn_hidden_size = 1024
states = (
torch.zeros(num_encoder_layers, d_model),
torch.zeros(
num_encoder_layers,
rnn_hidden_size,
),
)
encoder_out, encoder_out_lens, hx, cx = model.run_encoder(features, states)
hyp = greedy_search(model, encoder_out)
logging.info(sound_file)
logging.info(sp.decode(hyp))
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,353 @@
#!/usr/bin/env python3
# flake8: noqa
#
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from typing import List, Optional
import ncnn
import sentencepiece as spm
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe-model-filename",
type=str,
help="Path to bpe.model",
)
parser.add_argument(
"--encoder-param-filename",
type=str,
help="Path to encoder.ncnn.param",
)
parser.add_argument(
"--encoder-bin-filename",
type=str,
help="Path to encoder.ncnn.bin",
)
parser.add_argument(
"--decoder-param-filename",
type=str,
help="Path to decoder.ncnn.param",
)
parser.add_argument(
"--decoder-bin-filename",
type=str,
help="Path to decoder.ncnn.bin",
)
parser.add_argument(
"--joiner-param-filename",
type=str,
help="Path to joiner.ncnn.param",
)
parser.add_argument(
"--joiner-bin-filename",
type=str,
help="Path to joiner.ncnn.bin",
)
parser.add_argument(
"sound_filename",
type=str,
help="Path to foo.wav",
)
return parser.parse_args()
class Model:
def __init__(self, args):
self.init_encoder(args)
self.init_decoder(args)
self.init_joiner(args)
def init_encoder(self, args):
encoder_net = ncnn.Net()
encoder_net.opt.use_packing_layout = False
encoder_net.opt.use_fp16_storage = False
encoder_param = args.encoder_param_filename
encoder_model = args.encoder_bin_filename
encoder_net.load_param(encoder_param)
encoder_net.load_model(encoder_model)
self.encoder_net = encoder_net
def init_decoder(self, args):
decoder_param = args.decoder_param_filename
decoder_model = args.decoder_bin_filename
decoder_net = ncnn.Net()
decoder_net.opt.use_packing_layout = False
decoder_net.load_param(decoder_param)
decoder_net.load_model(decoder_model)
self.decoder_net = decoder_net
def init_joiner(self, args):
joiner_param = args.joiner_param_filename
joiner_model = args.joiner_bin_filename
joiner_net = ncnn.Net()
joiner_net.opt.use_packing_layout = False
joiner_net.load_param(joiner_param)
joiner_net.load_model(joiner_model)
self.joiner_net = joiner_net
def run_encoder(self, x, states):
with self.encoder_net.create_extractor() as ex:
# ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(x.numpy()).clone())
x_lens = torch.tensor([x.size(0)], dtype=torch.float32)
ex.input("in1", ncnn.Mat(x_lens.numpy()).clone())
ex.input("in2", ncnn.Mat(states[0].numpy()).clone())
ex.input("in3", ncnn.Mat(states[1].numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
ret, ncnn_out1 = ex.extract("out1")
assert ret == 0, ret
ret, ncnn_out2 = ex.extract("out2")
assert ret == 0, ret
ret, ncnn_out3 = ex.extract("out3")
assert ret == 0, ret
encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(
torch.int32
)
hx = torch.from_numpy(ncnn_out2.numpy()).clone()
cx = torch.from_numpy(ncnn_out3.numpy()).clone()
return encoder_out, encoder_out_lens, hx, cx
def run_decoder(self, decoder_input):
assert decoder_input.dtype == torch.int32
with self.decoder_net.create_extractor() as ex:
# ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
return decoder_out
def run_joiner(self, encoder_out, decoder_out):
with self.joiner_net.create_extractor() as ex:
# ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone()
return joiner_out
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def create_streaming_feature_extractor() -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return OnlineFbank(opts)
def greedy_search(
model: Model,
encoder_out: torch.Tensor,
decoder_out: Optional[torch.Tensor] = None,
hyp: Optional[List[int]] = None,
):
assert encoder_out.ndim == 1
context_size = 2
blank_id = 0
if decoder_out is None:
assert hyp is None, hyp
hyp = [blank_id] * context_size
decoder_input = torch.tensor(
hyp, dtype=torch.int32
) # (1, context_size)
decoder_out = model.run_decoder(decoder_input).squeeze(0)
else:
assert decoder_out.ndim == 1
assert hyp is not None, hyp
joiner_out = model.run_joiner(encoder_out, decoder_out)
y = joiner_out.argmax(dim=0).tolist()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor(decoder_input, dtype=torch.int32)
decoder_out = model.run_decoder(decoder_input).squeeze(0)
return hyp, decoder_out
def main():
args = get_args()
logging.info(vars(args))
model = Model(args)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model_filename)
sound_file = args.sound_filename
sample_rate = 16000
logging.info("Constructing Fbank computer")
online_fbank = create_streaming_feature_extractor()
logging.info(f"Reading sound files: {sound_file}")
wave_samples = read_sound_files(
filenames=[sound_file],
expected_sample_rate=sample_rate,
)[0]
logging.info(wave_samples.shape)
num_encoder_layers = 12
batch_size = 1
d_model = 512
rnn_hidden_size = 1024
states = (
torch.zeros(num_encoder_layers, batch_size, d_model),
torch.zeros(
num_encoder_layers,
batch_size,
rnn_hidden_size,
),
)
hyp = None
decoder_out = None
num_processed_frames = 0
segment = 9
offset = 4
chunk = 3200 # 0.2 second
start = 0
while start < wave_samples.numel():
end = min(start + chunk, wave_samples.numel())
samples = wave_samples[start:end]
start += chunk
online_fbank.accept_waveform(
sampling_rate=sample_rate,
waveform=samples,
)
while online_fbank.num_frames_ready - num_processed_frames >= segment:
frames = []
for i in range(segment):
frames.append(online_fbank.get_frame(num_processed_frames + i))
num_processed_frames += offset
frames = torch.cat(frames, dim=0)
encoder_out, encoder_out_lens, hx, cx = model.run_encoder(
frames, states
)
states = (hx, cx)
hyp, decoder_out = greedy_search(
model, encoder_out.squeeze(0), decoder_out, hyp
)
online_fbank.accept_waveform(
sampling_rate=sample_rate, waveform=torch.zeros(8000, dtype=torch.int32)
)
online_fbank.input_finished()
while online_fbank.num_frames_ready - num_processed_frames >= segment:
frames = []
for i in range(segment):
frames.append(online_fbank.get_frame(num_processed_frames + i))
num_processed_frames += offset
frames = torch.cat(frames, dim=0)
encoder_out, encoder_out_lens, hx, cx = model.run_encoder(
frames, states
)
states = (hx, cx)
hyp, decoder_out = greedy_search(
model, encoder_out.squeeze(0), decoder_out, hyp
)
context_size = 2
logging.info(sound_file)
logging.info(sp.decode(hyp[context_size:]))
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -406,6 +406,8 @@ def get_params() -> AttributeDict:
"decoder_dim": 512,
# parameters for joiner
"joiner_dim": 512,
# True to generate a model that can be exported via PNNX
"is_pnnx": False,
# parameters for Noam
"model_warm_step": 3000, # arg given to model, not for lrate
"env_info": get_env_info(),
@ -424,6 +426,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
aux_layer_period=params.aux_layer_period,
is_pnnx=params.is_pnnx,
)
return encoder

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/__init__.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/asr_datamodule.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1,818 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./lstm_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./lstm_transducer_stateless2/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./lstm_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./lstm_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./lstm_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./lstm_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./lstm_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="lstm_transducer_stateless/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
# tail padding here to alleviate the tail deletion problem
num_tail_padded_frames = 35
feature = torch.nn.functional.pad(
feature,
(0, 0, 0, num_tail_padded_frames),
mode="constant",
value=LOG_EPS,
)
feature_lens += num_tail_padded_frames
encoder_out, encoder_out_lens, _ = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/decoder.py

View File

@ -0,0 +1 @@
../transducer_stateless/encoder_interface.py

View File

@ -0,0 +1,388 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
(1) Export to torchscript model using torch.jit.trace()
./lstm_transducer_stateless3/export.py \
--exp-dir ./lstm_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 40 \
--avg 20 \
--jit-trace 1
It will generate 3 files: `encoder_jit_trace.pt`,
`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`.
(2) Export `model.state_dict()`
./lstm_transducer_stateless3/export.py \
--exp-dir ./lstm_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 40 \
--avg 20
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `lstm_transducer_stateless3/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./lstm_transducer_stateless3/decode.py \
--exp-dir ./lstm_transducer_stateless3/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18
# You will find the pre-trained model in icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18/exp
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless3/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit-trace",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.trace.
It will generate 3 files:
- encoder_jit_trace.pt
- decoder_jit_trace.pt
- joiner_jit_trace.pt
Check ./jit_pretrained.py for how to use them.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
add_model_arguments(parser)
return parser
def export_encoder_model_jit_trace(
encoder_model: nn.Module,
encoder_filename: str,
) -> None:
"""Export the given encoder model with torch.jit.trace()
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported model.
"""
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
states = encoder_model.get_init_states()
traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
traced_model.save(encoder_filename)
logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_jit_trace(
decoder_model: nn.Module,
decoder_filename: str,
) -> None:
"""Export the given decoder model with torch.jit.trace()
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The input decoder model
decoder_filename:
The filename to save the exported model.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = torch.tensor([False])
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
traced_model.save(decoder_filename)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_jit_trace(
joiner_model: nn.Module,
joiner_filename: str,
) -> None:
"""Export the given joiner model with torch.jit.trace()
Note: The argument project_input is fixed to True. A user should not
project the encoder_out/decoder_out by himself/herself. The exported joiner
will do that for the user.
Args:
joiner_model:
The input joiner model
joiner_filename:
The filename to save the exported model.
"""
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
traced_model.save(joiner_filename)
logging.info(f"Saved to {joiner_filename}")
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
if params.jit_trace is True:
convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.trace()")
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
else:
logging.info("Not using torchscript")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,322 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, either exported by `torch.jit.trace()`
or by `torch.jit.script()`, and uses them to decode waves.
You can use the following command to get the exported models:
./lstm_transducer_stateless3/export.py \
--exp-dir ./lstm_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 40 \
--avg 15 \
--jit-trace 1
Usage of this script:
./lstm_transducer_stateless3/jit_pretrained.py \
--encoder-model-filename ./lstm_transducer_stateless3/exp/encoder_jit_trace.pt \
--decoder-model-filename ./lstm_transducer_stateless3/exp/decoder_jit_trace.pt \
--joiner-model-filename ./lstm_transducer_stateless3/exp/joiner_jit_trace.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder torchscript model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder torchscript model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner torchscript model. ",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="Context size of the decoder model",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(
decoder: torch.jit.ScriptModule,
joiner: torch.jit.ScriptModule,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
context_size: int,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
decoder:
The decoder model.
joiner:
The joiner model.
encoder_out:
A 3-D tensor of shape (N, T, C)
encoder_out_lens:
A 1-D tensor of shape (N,).
context_size:
The context size of the decoder model.
Returns:
Return the decoded results for each utterance.
"""
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
device = encoder_out.device
blank_id = 0 # hard-code to 0
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (N, context_size)
decoder_out = decoder(
decoder_input,
need_pad=torch.tensor([False]),
).squeeze(1)
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = packed_encoder_out.data[start:end]
current_encoder_out = current_encoder_out
# current_encoder_out's shape: (batch_size, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = joiner(
current_encoder_out,
decoder_out,
)
# logits'shape (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
dtype=torch.int64,
)
decoder_out = decoder(
decoder_input,
need_pad=torch.tensor([False]),
)
decoder_out = decoder_out.squeeze(1)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
encoder = torch.jit.load(args.encoder_model_filename)
decoder = torch.jit.load(args.decoder_model_filename)
joiner = torch.jit.load(args.joiner_model_filename)
encoder.eval()
decoder.eval()
joiner.eval()
encoder.to(device)
decoder.to(device)
joiner.to(device)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
expected_sample_rate=args.sample_rate,
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, device=device)
states = encoder.get_init_states(batch_size=features.size(0), device=device)
encoder_out, encoder_out_lens, _ = encoder(
x=features,
x_lens=feature_lengths,
states=states,
)
hyps = greedy_search(
decoder=decoder,
joiner=joiner,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
context_size=args.context_size,
)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = sp.decode(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/joiner.py

View File

@ -0,0 +1,860 @@
# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
from typing import List, Optional, Tuple
import torch
from encoder_interface import EncoderInterface
from scaling import (
ActivationBalancer,
BasicNorm,
DoubleSwish,
ScaledConv2d,
ScaledLinear,
ScaledLSTM,
)
from torch import nn
LOG_EPSILON = math.log(1e-10)
def unstack_states(
states: Tuple[torch.Tensor, torch.Tensor]
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""
Unstack the lstm states corresponding to a batch of utterances into a list
of states, where the i-th entry is the state from the i-th utterance.
Args:
states:
A tuple of 2 elements.
``states[0]`` is the lstm hidden states, of a batch of utterance.
``states[1]`` is the lstm cell states, of a batch of utterances.
Returns:
A list of states.
``states[i]`` is a tuple of 2 elememts of i-th utterance.
``states[i][0]`` is the lstm hidden states of i-th utterance.
``states[i][1]`` is the lstm cell states of i-th utterance.
"""
hidden_states, cell_states = states
list_hidden_states = hidden_states.unbind(dim=1)
list_cell_states = cell_states.unbind(dim=1)
ans = [
(h.unsqueeze(1), c.unsqueeze(1))
for (h, c) in zip(list_hidden_states, list_cell_states)
]
return ans
def stack_states(
states_list: List[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Stack list of lstm states corresponding to separate utterances into a single
lstm state so that it can be used as an input for lstm when those utterances
are formed into a batch.
Args:
state_list:
Each element in state_list corresponds to the lstm state for a single
utterance.
``states[i]`` is a tuple of 2 elememts of i-th utterance.
``states[i][0]`` is the lstm hidden states of i-th utterance.
``states[i][1]`` is the lstm cell states of i-th utterance.
Returns:
A new state corresponding to a batch of utterances.
It is a tuple of 2 elements.
``states[0]`` is the lstm hidden states, of a batch of utterance.
``states[1]`` is the lstm cell states, of a batch of utterances.
"""
hidden_states = torch.cat([s[0] for s in states_list], dim=1)
cell_states = torch.cat([s[1] for s in states_list], dim=1)
ans = (hidden_states, cell_states)
return ans
class RNN(EncoderInterface):
"""
Args:
num_features (int):
Number of input features.
subsampling_factor (int):
Subsampling factor of encoder (convolution layers before lstm layers) (default=4). # noqa
d_model (int):
Output dimension (default=512).
dim_feedforward (int):
Feedforward dimension (default=2048).
rnn_hidden_size (int):
Hidden dimension for lstm layers (default=1024).
grad_norm_threshold:
For each sequence element in batch, its gradient will be
filtered out if the gradient norm is larger than
`grad_norm_threshold * median`, where `median` is the median
value of gradient norms of all elememts in batch.
num_encoder_layers (int):
Number of encoder layers (default=12).
dropout (float):
Dropout rate (default=0.1).
layer_dropout (float):
Dropout value for model-level warmup (default=0.075).
aux_layer_period (int):
Period of auxiliary layers used for random combiner during training.
If set to 0, will not use the random combiner (Default).
You can set a positive integer to use the random combiner, e.g., 3.
"""
def __init__(
self,
num_features: int,
subsampling_factor: int = 4,
d_model: int = 512,
dim_feedforward: int = 2048,
rnn_hidden_size: int = 1024,
grad_norm_threshold: float = 10.0,
num_encoder_layers: int = 12,
dropout: float = 0.1,
layer_dropout: float = 0.075,
aux_layer_period: int = 0,
) -> None:
super(RNN, self).__init__()
self.num_features = num_features
self.subsampling_factor = subsampling_factor
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
# self.encoder_embed converts the input of shape (N, T, num_features)
# to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.num_encoder_layers = num_encoder_layers
self.d_model = d_model
self.rnn_hidden_size = rnn_hidden_size
encoder_layer = RNNEncoderLayer(
d_model=d_model,
dim_feedforward=dim_feedforward,
rnn_hidden_size=rnn_hidden_size,
grad_norm_threshold=grad_norm_threshold,
dropout=dropout,
layer_dropout=layer_dropout,
)
self.encoder = RNNEncoder(
encoder_layer,
num_encoder_layers,
aux_layers=list(
range(
num_encoder_layers // 3,
num_encoder_layers - 1,
aux_layer_period,
)
)
if aux_layer_period > 0
else None,
)
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
warmup: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Args:
x:
The input tensor. Its shape is (N, T, C), where N is the batch size,
T is the sequence length, C is the feature dimension.
x_lens:
A tensor of shape (N,), containing the number of frames in `x`
before padding.
states:
A tuple of 2 tensors (optional). It is for streaming inference.
states[0] is the hidden states of all layers,
with shape of (num_layers, N, d_model);
states[1] is the cell states of all layers,
with shape of (num_layers, N, rnn_hidden_size).
warmup:
A floating point value that gradually increases from 0 throughout
training; when it is >= 1.0 we are "fully warmed up". It is used
to turn modules on sequentially.
Returns:
A tuple of 3 tensors:
- embeddings: its shape is (N, T', d_model), where T' is the output
sequence lengths.
- lengths: a tensor of shape (batch_size,) containing the number of
frames in `embeddings` before padding.
- updated states, whose shape is the same as the input states.
"""
x = self.encoder_embed(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning
#
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 3) >> 1) - 1) >> 1
if not torch.jit.is_tracing():
assert x.size(0) == lengths.max().item()
if states is None:
x = self.encoder(x, warmup=warmup)[0]
# torch.jit.trace requires returned types to be the same as annotated # noqa
new_states = (torch.empty(0), torch.empty(0))
else:
assert not self.training
assert len(states) == 2
if not torch.jit.is_tracing():
# for hidden state
assert states[0].shape == (
self.num_encoder_layers,
x.size(1),
self.d_model,
)
# for cell state
assert states[1].shape == (
self.num_encoder_layers,
x.size(1),
self.rnn_hidden_size,
)
x, new_states = self.encoder(x, states)
x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
return x, lengths, new_states
@torch.jit.export
def get_init_states(
self, batch_size: int = 1, device: torch.device = torch.device("cpu")
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get model initial states."""
# for rnn hidden states
hidden_states = torch.zeros(
(self.num_encoder_layers, batch_size, self.d_model), device=device
)
cell_states = torch.zeros(
(self.num_encoder_layers, batch_size, self.rnn_hidden_size),
device=device,
)
return (hidden_states, cell_states)
class RNNEncoderLayer(nn.Module):
"""
RNNEncoderLayer is made up of lstm and feedforward networks.
For stable training, in each lstm module, gradient filter
is applied to filter out extremely large elements in batch gradients
and also the module parameters with soft masks.
Args:
d_model:
The number of expected features in the input (required).
dim_feedforward:
The dimension of feedforward network model (default=2048).
rnn_hidden_size:
The hidden dimension of rnn layer.
grad_norm_threshold:
For each sequence element in batch, its gradient will be
filtered out if the gradient norm is larger than
`grad_norm_threshold * median`, where `median` is the median
value of gradient norms of all elememts in batch.
dropout:
The dropout value (default=0.1).
layer_dropout:
The dropout value for model-level warmup (default=0.075).
"""
def __init__(
self,
d_model: int,
dim_feedforward: int,
rnn_hidden_size: int,
grad_norm_threshold: float = 10.0,
dropout: float = 0.1,
layer_dropout: float = 0.075,
) -> None:
super(RNNEncoderLayer, self).__init__()
self.layer_dropout = layer_dropout
self.d_model = d_model
self.rnn_hidden_size = rnn_hidden_size
assert rnn_hidden_size >= d_model, (rnn_hidden_size, d_model)
self.lstm = ScaledLSTM(
input_size=d_model,
hidden_size=rnn_hidden_size,
proj_size=d_model if rnn_hidden_size > d_model else 0,
num_layers=1,
dropout=0.0,
grad_norm_threshold=grad_norm_threshold,
)
self.feed_forward = nn.Sequential(
ScaledLinear(d_model, dim_feedforward),
ActivationBalancer(channel_dim=-1),
DoubleSwish(),
nn.Dropout(dropout),
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
)
self.norm_final = BasicNorm(d_model)
# try to ensure the output is close to zero-mean (or at least, zero-median). # noqa
self.balancer = ActivationBalancer(
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
)
self.dropout = nn.Dropout(dropout)
def forward(
self,
src: torch.Tensor,
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
warmup: float = 1.0,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Pass the input through the encoder layer.
Args:
src:
The sequence to the encoder layer (required).
Its shape is (S, N, E), where S is the sequence length,
N is the batch size, and E is the feature number.
states:
A tuple of 2 tensors (optional). It is for streaming inference.
states[0] is the hidden states of all layers,
with shape of (1, N, d_model);
states[1] is the cell states of all layers,
with shape of (1, N, rnn_hidden_size).
warmup:
It controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
"""
src_orig = src
warmup_scale = min(0.1 + warmup, 1.0)
# alpha = 1.0 means fully use this encoder layer, 0.0 would mean
# completely bypass it.
if self.training:
alpha = (
warmup_scale
if torch.rand(()).item() <= (1.0 - self.layer_dropout)
else 0.1
)
else:
alpha = 1.0
# lstm module
if states is None:
src_lstm = self.lstm(src)[0]
# torch.jit.trace requires returned types be the same as annotated
new_states = (torch.empty(0), torch.empty(0))
else:
assert not self.training
assert len(states) == 2
if not torch.jit.is_tracing():
# for hidden state
assert states[0].shape == (1, src.size(1), self.d_model)
# for cell state
assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
src_lstm, new_states = self.lstm(src, states)
src = src + self.dropout(src_lstm)
# feed forward module
src = src + self.dropout(self.feed_forward(src))
src = self.norm_final(self.balancer(src))
if alpha != 1.0:
src = alpha * src + (1 - alpha) * src_orig
return src, new_states
class RNNEncoder(nn.Module):
"""
RNNEncoder is a stack of N encoder layers.
Args:
encoder_layer:
An instance of the RNNEncoderLayer() class (required).
num_layers:
The number of sub-encoder-layers in the encoder (required).
"""
def __init__(
self,
encoder_layer: nn.Module,
num_layers: int,
aux_layers: Optional[List[int]] = None,
) -> None:
super(RNNEncoder, self).__init__()
self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
)
self.num_layers = num_layers
self.d_model = encoder_layer.d_model
self.rnn_hidden_size = encoder_layer.rnn_hidden_size
self.aux_layers: List[int] = []
self.combiner: Optional[nn.Module] = None
if aux_layers is not None:
assert len(set(aux_layers)) == len(aux_layers)
assert num_layers - 1 not in aux_layers
self.aux_layers = aux_layers + [num_layers - 1]
self.combiner = RandomCombine(
num_inputs=len(self.aux_layers),
final_weight=0.5,
pure_prob=0.333,
stddev=2.0,
)
def forward(
self,
src: torch.Tensor,
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
warmup: float = 1.0,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Pass the input through the encoder layer in turn.
Args:
src:
The sequence to the encoder layer (required).
Its shape is (S, N, E), where S is the sequence length,
N is the batch size, and E is the feature number.
states:
A tuple of 2 tensors (optional). It is for streaming inference.
states[0] is the hidden states of all layers,
with shape of (num_layers, N, d_model);
states[1] is the cell states of all layers,
with shape of (num_layers, N, rnn_hidden_size).
warmup:
It controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
"""
if states is not None:
assert not self.training
assert len(states) == 2
if not torch.jit.is_tracing():
# for hidden state
assert states[0].shape == (
self.num_layers,
src.size(1),
self.d_model,
)
# for cell state
assert states[1].shape == (
self.num_layers,
src.size(1),
self.rnn_hidden_size,
)
output = src
outputs = []
new_hidden_states = []
new_cell_states = []
for i, mod in enumerate(self.layers):
if states is None:
output = mod(output, warmup=warmup)[0]
else:
layer_state = (
states[0][i : i + 1, :, :], # h: (1, N, d_model)
states[1][i : i + 1, :, :], # c: (1, N, rnn_hidden_size)
)
output, (h, c) = mod(output, layer_state)
new_hidden_states.append(h)
new_cell_states.append(c)
if self.combiner is not None and i in self.aux_layers:
outputs.append(output)
if self.combiner is not None:
output = self.combiner(outputs)
if states is None:
new_states = (torch.empty(0), torch.empty(0))
else:
new_states = (
torch.cat(new_hidden_states, dim=0),
torch.cat(new_cell_states, dim=0),
)
return output, new_states
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = ((T-3)//2-1)//2, which approximates T' == T//4
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(
self,
in_channels: int,
out_channels: int,
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
) -> None:
"""
Args:
in_channels:
Number of channels in. The input shape is (N, T, in_channels).
Caution: It requires: T >= 9, in_channels >= 9.
out_channels
Output dim. The output shape is (N, ((T-3)//2-1)//2, out_channels)
layer1_channels:
Number of channels in layer1
layer1_channels:
Number of channels in layer2
"""
assert in_channels >= 9
super().__init__()
self.conv = nn.Sequential(
ScaledConv2d(
in_channels=1,
out_channels=layer1_channels,
kernel_size=3,
padding=0,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
ScaledConv2d(
in_channels=layer1_channels,
out_channels=layer2_channels,
kernel_size=3,
stride=2,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
ScaledConv2d(
in_channels=layer2_channels,
out_channels=layer3_channels,
kernel_size=3,
stride=2,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
)
self.out = ScaledLinear(
layer3_channels * (((in_channels - 3) // 2 - 1) // 2), out_channels
)
# set learn_eps=False because out_norm is preceded by `out`, and `out`
# itself has learned scale, so the extra degree of freedom is not
# needed.
self.out_norm = BasicNorm(out_channels, learn_eps=False)
# constrain median of output to be close to zero.
self.out_balancer = ActivationBalancer(
channel_dim=-1, min_positive=0.45, max_positive=0.55
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
Returns:
Return a tensor of shape (N, ((T-3)//2-1)//2, odim)
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x)
# Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape (N, ((T-3)//2-1))//2, odim)
x = self.out_norm(x)
x = self.out_balancer(x)
return x
class RandomCombine(nn.Module):
"""
This module combines a list of Tensors, all with the same shape, to
produce a single output of that same shape which, in training time,
is a random combination of all the inputs; but which in test time
will be just the last input.
The idea is that the list of Tensors will be a list of outputs of multiple
conformer layers. This has a similar effect as iterated loss. (See:
DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
NETWORKS).
"""
def __init__(
self,
num_inputs: int,
final_weight: float = 0.5,
pure_prob: float = 0.5,
stddev: float = 2.0,
) -> None:
"""
Args:
num_inputs:
The number of tensor inputs, which equals the number of layers'
outputs that are fed into this module. E.g. in an 18-layer neural
net if we output layers 16, 12, 18, num_inputs would be 3.
final_weight:
The amount of weight or probability we assign to the
final layer when randomly choosing layers or when choosing
continuous layer weights.
pure_prob:
The probability, on each frame, with which we choose
only a single layer to output (rather than an interpolation)
stddev:
A standard deviation that we add to log-probs for computing
randomized weights.
The method of choosing which layers, or combinations of layers, to use,
is conceptually as follows::
With probability `pure_prob`::
With probability `final_weight`: choose final layer,
Else: choose random non-final layer.
Else::
Choose initial log-weights that correspond to assigning
weight `final_weight` to the final layer and equal
weights to other layers; then add Gaussian noise
with variance `stddev` to these log-weights, and normalize
to weights (note: the average weight assigned to the
final layer here will not be `final_weight` if stddev>0).
"""
super().__init__()
assert 0 <= pure_prob <= 1, pure_prob
assert 0 < final_weight < 1, final_weight
assert num_inputs >= 1
self.num_inputs = num_inputs
self.final_weight = final_weight
self.pure_prob = pure_prob
self.stddev = stddev
self.final_log_weight = (
torch.tensor(
(final_weight / (1 - final_weight)) * (self.num_inputs - 1)
)
.log()
.item()
)
def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
"""Forward function.
Args:
inputs:
A list of Tensor, e.g. from various layers of a transformer.
All must be the same shape, of (*, num_channels)
Returns:
A Tensor of shape (*, num_channels). In test mode
this is just the final input.
"""
num_inputs = self.num_inputs
assert len(inputs) == num_inputs
if not self.training or torch.jit.is_scripting():
return inputs[-1]
# Shape of weights: (*, num_inputs)
num_channels = inputs[0].shape[-1]
num_frames = inputs[0].numel() // num_channels
ndim = inputs[0].ndim
# stacked_inputs: (num_frames, num_channels, num_inputs)
stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
(num_frames, num_channels, num_inputs)
)
# weights: (num_frames, num_inputs)
weights = self._get_random_weights(
inputs[0].dtype, inputs[0].device, num_frames
)
weights = weights.reshape(num_frames, num_inputs, 1)
# ans: (num_frames, num_channels, 1)
ans = torch.matmul(stacked_inputs, weights)
# ans: (*, num_channels)
ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,))
# The following if causes errors for torch script in torch 1.6.0
# if __name__ == "__main__":
# # for testing only...
# print("Weights = ", weights.reshape(num_frames, num_inputs))
return ans
def _get_random_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
) -> torch.Tensor:
"""Return a tensor of random weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired
Returns:
A tensor of shape (num_frames, self.num_inputs), such that
`ans.sum(dim=1)` is all ones.
"""
pure_prob = self.pure_prob
if pure_prob == 0.0:
return self._get_random_mixed_weights(dtype, device, num_frames)
elif pure_prob == 1.0:
return self._get_random_pure_weights(dtype, device, num_frames)
else:
p = self._get_random_pure_weights(dtype, device, num_frames)
m = self._get_random_mixed_weights(dtype, device, num_frames)
return torch.where(
torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m
)
def _get_random_pure_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
):
"""Return a tensor of random one-hot weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired.
Returns:
A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
exactly one weight equal to 1.0 on each frame.
"""
final_prob = self.final_weight
# final contains self.num_inputs - 1 in all elements
final = torch.full((num_frames,), self.num_inputs - 1, device=device)
# nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa
nonfinal = torch.randint(
self.num_inputs - 1, (num_frames,), device=device
)
indexes = torch.where(
torch.rand(num_frames, device=device) < final_prob, final, nonfinal
)
ans = torch.nn.functional.one_hot(
indexes, num_classes=self.num_inputs
).to(dtype=dtype)
return ans
def _get_random_mixed_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
):
"""Return a tensor of random one-hot weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired.
Returns:
A tensor of shape (num_frames, self.num_inputs), which elements
in [0..1] that sum to one over the second axis, i.e.
`ans.sum(dim=1)` is all ones.
"""
logprobs = (
torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
* self.stddev
)
logprobs[:, -1] += self.final_log_weight
return logprobs.softmax(dim=1)
def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
print(
f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" # noqa
)
num_inputs = 3
num_channels = 50
m = RandomCombine(
num_inputs=num_inputs,
final_weight=final_weight,
pure_prob=pure_prob,
stddev=stddev,
)
x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
y = m(x)
assert y.shape == x[0].shape
assert torch.allclose(y, x[0]) # .. since actually all ones.
def _test_random_combine_main():
_test_random_combine(0.999, 0, 0.0)
_test_random_combine(0.5, 0, 0.0)
_test_random_combine(0.999, 0, 0.0)
_test_random_combine(0.5, 0, 0.3)
_test_random_combine(0.5, 1, 0.3)
_test_random_combine(0.5, 0.5, 0.3)
feature_dim = 50
c = RNN(num_features=feature_dim, d_model=128)
batch_size = 5
seq_len = 20
# Just make sure the forward pass runs.
f = c(
torch.randn(batch_size, seq_len, feature_dim),
torch.full((batch_size,), seq_len, dtype=torch.int64),
)
f # to remove flake8 warnings
if __name__ == "__main__":
feature_dim = 80
m = RNN(
num_features=feature_dim,
d_model=512,
rnn_hidden_size=1024,
dim_feedforward=2048,
num_encoder_layers=12,
)
batch_size = 5
seq_len = 20
# Just make sure the forward pass runs.
f = m(
torch.randn(batch_size, seq_len, feature_dim),
torch.full((batch_size,), seq_len, dtype=torch.int64),
warmup=0.5,
)
num_param = sum([p.numel() for p in m.parameters()])
print(f"Number of model parameters: {num_param}")
_test_random_combine_main()

View File

@ -0,0 +1 @@
../lstm_transducer_stateless/model.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/optim.py

View File

@ -0,0 +1,352 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./lstm_transducer_stateless3/pretrained.py \
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./lstm_transducer_stateless3/pretrained.py \
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./lstm_transducer_stateless3/pretrained.py \
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./lstm_transducer_stateless3/pretrained.py \
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./lstm_transducer_stateless3/exp/epoch-xx.pt`.
Note: ./lstm_transducer_stateless3/exp/pretrained.pt is generated by
./lstm_transducer_stateless3/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens, _ = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/scaling.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless3/scaling_converter.py

View File

@ -0,0 +1 @@
../lstm_transducer_stateless/stream.py

View File

@ -0,0 +1,968 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./lstm_transducer_stateless3/streaming_decode.py \
--epoch 40 \
--avg 20 \
--exp-dir lstm_transducer_stateless3/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method greedy_search \
--use-averaged-model True
(2) modified beam search
./lstm_transducer_stateless3/streaming_decode.py \
--epoch 40 \
--avg 20 \
--exp-dir lstm_transducer_stateless3/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method modified_beam_search \
--use-averaged-model True \
--beam-size 4
(3) fast beam search
./lstm_transducer_stateless3/streaming_decode.py \
--epoch 40 \
--avg 20 \
--exp-dir lstm_transducer_stateless3/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method fast_beam_search \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from lstm import LOG_EPSILON, stack_states, unstack_states
from stream import Stream
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import one_best_decoding
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=40,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=20,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="lstm_transducer_stateless3/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--sampling-rate",
type=float,
default=16000,
help="Sample rate of the audio",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded in parallel",
)
add_model_arguments(parser)
return parser
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[Stream],
) -> None:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
streams:
A list of Stream objects.
"""
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = next(model.parameters()).device
T = encoder_out.size(1)
encoder_out = model.joiner.encoder_proj(encoder_out)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (batch_size, 1, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
# logits'shape (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
decoder_out = model.joiner.decoder_proj(decoder_out)
def modified_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[Stream],
beam: int = 4,
):
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The RNN-T model.
encoder_out:
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
the encoder model.
streams:
A list of stream objects.
beam:
Number of active paths during the beam search.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert len(streams) == encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size = len(streams)
T = encoder_out.size(1)
B = [stream.hyps for stream in streams]
encoder_out = model.joiner.encoder_proj(encoder_out)
for t in range(T):
current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.stack(
[hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, encoder_out_dim)
logits = model.joiner(
current_encoder_out, decoder_out, project_input=False
)
# logits is of shape (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
for i in range(batch_size):
streams[i].hyps = B[i]
def fast_beam_search_one_best(
model: nn.Module,
streams: List[Stream],
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
) -> None:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then
the shortest path within the lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
streams:
A list of stream objects.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
processed_lens:
A tensor of shape (N,) containing the number of processed frames
in `encoder_out` before padding.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
"""
assert encoder_out.ndim == 3
context_size = model.decoder.context_size
vocab_size = model.decoder.vocab_size
B, T, C = encoder_out.shape
assert B == len(streams)
config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_contexts=max_contexts,
max_states=max_states,
)
individual_streams = []
for i in range(B):
individual_streams.append(streams[i].rnnt_decoding_stream)
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
encoder_out = model.joiner.encoder_proj(encoder_out)
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
for i in range(B):
streams[i].hyp = hyps[i]
def decode_one_chunk(
model: nn.Module,
streams: List[Stream],
params: AttributeDict,
decoding_graph: Optional[k2.Fsa] = None,
) -> List[int]:
"""
Args:
model:
The Transducer model.
streams:
A list of Stream objects.
params:
It is returned by :func:`get_params`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search.
Returns:
A list of indexes indicating the finished streams.
"""
device = next(model.parameters()).device
feature_list = []
feature_len_list = []
state_list = []
num_processed_frames_list = []
for stream in streams:
# We should first get `stream.num_processed_frames`
# before calling `stream.get_feature_chunk()`
# since `stream.num_processed_frames` would be updated
num_processed_frames_list.append(stream.num_processed_frames)
feature = stream.get_feature_chunk()
feature_len = feature.size(0)
feature_list.append(feature)
feature_len_list.append(feature_len)
state_list.append(stream.states)
features = pad_sequence(
feature_list, batch_first=True, padding_value=LOG_EPSILON
).to(device)
feature_lens = torch.tensor(feature_len_list, device=device)
num_processed_frames = torch.tensor(
num_processed_frames_list, device=device
)
# Make sure it has at least 1 frame after subsampling
tail_length = params.subsampling_factor + 5
if features.size(1) < tail_length:
pad_length = tail_length - features.size(1)
feature_lens += pad_length
features = torch.nn.functional.pad(
features,
(0, 0, 0, pad_length),
mode="constant",
value=LOG_EPSILON,
)
# Stack states of all streams
states = stack_states(state_list)
encoder_out, encoder_out_lens, states = model.encoder(
x=features,
x_lens=feature_lens,
states=states,
)
if params.decoding_method == "greedy_search":
greedy_search(
model=model,
streams=streams,
encoder_out=encoder_out,
)
elif params.decoding_method == "modified_beam_search":
modified_beam_search(
model=model,
streams=streams,
encoder_out=encoder_out,
beam=params.beam_size,
)
elif params.decoding_method == "fast_beam_search":
# feature_len is needed to get partial results.
# The rnnt_decoding_stream for fast_beam_search.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
processed_lens = (
num_processed_frames // params.subsampling_factor
+ encoder_out_lens
)
fast_beam_search_one_best(
model=model,
streams=streams,
encoder_out=encoder_out,
processed_lens=processed_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
# Update cached states of each stream
state_list = unstack_states(states)
for i, s in enumerate(state_list):
streams[i].states = s
finished_streams = [i for i, stream in enumerate(streams) if stream.done]
return finished_streams
def create_streaming_feature_extractor() -> Fbank:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return Fbank(opts)
def decode_dataset(
cuts: CutSet,
model: nn.Module,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
):
"""Decode dataset.
Args:
cuts:
Lhotse Cutset containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The Transducer model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
device = next(model.parameters()).device
log_interval = 300
fbank = create_streaming_feature_extractor()
decode_results = []
streams = []
for num, cut in enumerate(cuts):
# Each utterance has a Stream.
stream = Stream(
params=params,
cut_id=cut.id,
decoding_graph=decoding_graph,
device=device,
LOG_EPS=LOG_EPSILON,
)
stream.states = model.encoder.get_init_states(device=device)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
samples = torch.from_numpy(audio).squeeze(0)
feature = fbank(samples)
stream.set_feature(feature)
stream.ground_truth = cut.supervisions[0].text
streams.append(stream)
while len(streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
model=model,
streams=streams,
params=params,
decoding_graph=decoding_graph,
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
streams[i].id,
streams[i].ground_truth.split(),
sp.decode(streams[i].decoding_result()).split(),
)
)
del streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
while len(streams) > 0:
finished_streams = decode_one_chunk(
model=model,
streams=streams,
params=params,
decoding_graph=decoding_graph,
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
streams[i].id,
streams[i].ground_truth.split(),
sp.decode(streams[i].decoding_result()).split(),
)
)
del streams[i]
if params.decoding_method == "greedy_search":
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
else:
key = f"beam_size_{params.beam_size}"
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=sorted(results))
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-streaming-decode")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.device = device
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.eval()
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_sets = ["test-clean", "test-other"]
test_cuts = [test_clean_cuts, test_other_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=test_cut,
model=model,
params=params,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
torch.manual_seed(20220810)
main()

View File

@ -0,0 +1,92 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./lstm_transducer_stateless/test_model.py
"""
import os
from pathlib import Path
import torch
from export import (
export_decoder_model_jit_trace,
export_encoder_model_jit_trace,
export_joiner_model_jit_trace,
)
from lstm import stack_states, unstack_states
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
params.encoder_dim = 512
params.rnn_hidden_size = 1024
params.num_encoder_layers = 12
params.aux_layer_period = 0
params.exp_dir = Path("exp_test_model")
model = get_transducer_model(params)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
convert_scaled_to_non_scaled(model, inplace=True)
if not os.path.exists(params.exp_dir):
os.path.mkdir(params.exp_dir)
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
print("The model has been successfully exported using jit.trace.")
def test_states_stack_and_unstack():
layer, batch, hidden, cell = 12, 100, 512, 1024
states = (
torch.randn(layer, batch, hidden),
torch.randn(layer, batch, cell),
)
states2 = stack_states(unstack_states(states))
assert torch.allclose(states[0], states2[0])
assert torch.allclose(states[1], states2[1])
def main():
test_model()
test_states_stack_and_unstack()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,257 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./lstm_transducer_stateless/test_scaling_converter.py
"""
import copy
import torch
from scaling import (
ScaledConv1d,
ScaledConv2d,
ScaledEmbedding,
ScaledLinear,
ScaledLSTM,
)
from scaling_converter import (
convert_scaled_to_non_scaled,
scaled_conv1d_to_conv1d,
scaled_conv2d_to_conv2d,
scaled_embedding_to_embedding,
scaled_linear_to_linear,
scaled_lstm_to_lstm,
)
from train import get_params, get_transducer_model
def get_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
params.encoder_dim = 512
params.rnn_hidden_size = 1024
params.num_encoder_layers = 12
params.aux_layer_period = -1
model = get_transducer_model(params)
return model
def test_scaled_linear_to_linear():
N = 5
in_features = 10
out_features = 20
for bias in [True, False]:
scaled_linear = ScaledLinear(
in_features=in_features,
out_features=out_features,
bias=bias,
)
linear = scaled_linear_to_linear(scaled_linear)
x = torch.rand(N, in_features)
y1 = scaled_linear(x)
y2 = linear(x)
assert torch.allclose(y1, y2)
jit_scaled_linear = torch.jit.script(scaled_linear)
jit_linear = torch.jit.script(linear)
y3 = jit_scaled_linear(x)
y4 = jit_linear(x)
assert torch.allclose(y3, y4)
assert torch.allclose(y1, y4)
def test_scaled_conv1d_to_conv1d():
in_channels = 3
for bias in [True, False]:
scaled_conv1d = ScaledConv1d(
in_channels,
6,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
conv1d = scaled_conv1d_to_conv1d(scaled_conv1d)
x = torch.rand(20, in_channels, 10)
y1 = scaled_conv1d(x)
y2 = conv1d(x)
assert torch.allclose(y1, y2)
jit_scaled_conv1d = torch.jit.script(scaled_conv1d)
jit_conv1d = torch.jit.script(conv1d)
y3 = jit_scaled_conv1d(x)
y4 = jit_conv1d(x)
assert torch.allclose(y3, y4)
assert torch.allclose(y1, y4)
def test_scaled_conv2d_to_conv2d():
in_channels = 1
for bias in [True, False]:
scaled_conv2d = ScaledConv2d(
in_channels=in_channels,
out_channels=3,
kernel_size=3,
padding=1,
bias=bias,
)
conv2d = scaled_conv2d_to_conv2d(scaled_conv2d)
x = torch.rand(20, in_channels, 10, 20)
y1 = scaled_conv2d(x)
y2 = conv2d(x)
assert torch.allclose(y1, y2)
jit_scaled_conv2d = torch.jit.script(scaled_conv2d)
jit_conv2d = torch.jit.script(conv2d)
y3 = jit_scaled_conv2d(x)
y4 = jit_conv2d(x)
assert torch.allclose(y3, y4)
assert torch.allclose(y1, y4)
def test_scaled_embedding_to_embedding():
scaled_embedding = ScaledEmbedding(
num_embeddings=500,
embedding_dim=10,
padding_idx=0,
)
embedding = scaled_embedding_to_embedding(scaled_embedding)
for s in [10, 100, 300, 500, 800, 1000]:
x = torch.randint(low=0, high=500, size=(s,))
scaled_y = scaled_embedding(x)
y = embedding(x)
assert torch.equal(scaled_y, y)
def test_scaled_lstm_to_lstm():
input_size = 512
batch_size = 20
for bias in [True, False]:
for hidden_size in [512, 1024]:
scaled_lstm = ScaledLSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=1,
bias=bias,
proj_size=0 if hidden_size == input_size else input_size,
)
lstm = scaled_lstm_to_lstm(scaled_lstm)
x = torch.rand(200, batch_size, input_size)
h0 = torch.randn(1, batch_size, input_size)
c0 = torch.randn(1, batch_size, hidden_size)
y1, (h1, c1) = scaled_lstm(x, (h0, c0))
y2, (h2, c2) = lstm(x, (h0, c0))
assert torch.allclose(y1, y2)
assert torch.allclose(h1, h2)
assert torch.allclose(c1, c2)
jit_scaled_lstm = torch.jit.trace(lstm, (x, (h0, c0)))
y3, (h3, c3) = jit_scaled_lstm(x, (h0, c0))
assert torch.allclose(y1, y3)
assert torch.allclose(h1, h3)
assert torch.allclose(c1, c3)
def test_convert_scaled_to_non_scaled():
for inplace in [False, True]:
model = get_model()
model.eval()
orig_model = copy.deepcopy(model)
converted_model = convert_scaled_to_non_scaled(model, inplace=inplace)
model = orig_model
# test encoder
N = 2
T = 100
vocab_size = model.decoder.vocab_size
x = torch.randn(N, T, 80, dtype=torch.float32)
x_lens = torch.full((N,), x.size(1))
e1, e1_lens, _ = model.encoder(x, x_lens)
e2, e2_lens, _ = converted_model.encoder(x, x_lens)
assert torch.all(torch.eq(e1_lens, e2_lens))
assert torch.allclose(e1, e2), (e1 - e2).abs().max()
# test decoder
U = 50
y = torch.randint(low=1, high=vocab_size - 1, size=(N, U))
d1 = model.decoder(y)
d2 = model.decoder(y)
assert torch.allclose(d1, d2)
# test simple projection
lm1 = model.simple_lm_proj(d1)
am1 = model.simple_am_proj(e1)
lm2 = converted_model.simple_lm_proj(d2)
am2 = converted_model.simple_am_proj(e2)
assert torch.allclose(lm1, lm2)
assert torch.allclose(am1, am2)
# test joiner
e = torch.rand(2, 3, 4, 512)
d = torch.rand(2, 3, 4, 512)
j1 = model.joiner(e, d)
j2 = converted_model.joiner(e, d)
assert torch.allclose(j1, j2)
@torch.no_grad()
def main():
test_scaled_linear_to_linear()
test_scaled_conv1d_to_conv1d()
test_scaled_conv2d_to_conv2d()
test_scaled_embedding_to_embedding()
test_scaled_lstm_to_lstm()
test_convert_scaled_to_non_scaled()
if __name__ == "__main__":
torch.manual_seed(20220730)
main()

File diff suppressed because it is too large Load Diff

View File

@ -476,8 +476,8 @@ class ConformerEncoderLayer(nn.Module):
self,
src: Tensor,
pos_emb: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
src_mask: Optional[Tensor] = None,
warmup: float = 1.0,
) -> Tensor:
"""
@ -486,8 +486,8 @@ class ConformerEncoderLayer(nn.Module):
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
src_mask: the mask for the src sequence (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
Shape:
@ -527,7 +527,9 @@ class ConformerEncoderLayer(nn.Module):
src = src + self.dropout(src_att)
# convolution module
conv, _ = self.conv_module(src)
conv, _ = self.conv_module(
src, src_key_padding_mask=src_key_padding_mask
)
src = src + self.dropout(conv)
# feed forward module
@ -661,8 +663,8 @@ class ConformerEncoder(nn.Module):
self,
src: Tensor,
pos_emb: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
warmup: float = 1.0,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
@ -670,8 +672,8 @@ class ConformerEncoder(nn.Module):
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
mask: the mask for the src sequence (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
@ -930,7 +932,7 @@ class RelPositionMultiheadAttention(nn.Module):
value: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
need_weights: bool = False,
attn_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tuple[Tensor, Optional[Tensor]]:
@ -1057,7 +1059,7 @@ class RelPositionMultiheadAttention(nn.Module):
out_proj_bias: Tensor,
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
need_weights: bool = False,
attn_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tuple[Tensor, Optional[Tensor]]:
@ -1457,6 +1459,7 @@ class ConvolutionModule(nn.Module):
x: Tensor,
cache: Optional[Tensor] = None,
right_context: int = 0,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
"""Compute convolution module.
@ -1467,6 +1470,7 @@ class ConvolutionModule(nn.Module):
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
src_key_padding_mask: the mask for the src keys per batch (optional).
of right context, some have more.
Returns:
@ -1486,6 +1490,8 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
if self.causal and self.lorder > 0:
if cache is None:
# Make depthwise_conv causal by

View File

@ -16,6 +16,7 @@
import collections
import random
from itertools import repeat
from typing import Optional, Tuple
@ -111,6 +112,76 @@ class ActivationBalancerFunction(torch.autograd.Function):
return x_grad - neg_delta_grad, None, None, None, None, None, None
class GradientFilterFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: Tensor,
batch_dim: int, # e.g., 1
threshold: float, # e.g., 10.0
*params: Tensor, # module parameters
) -> Tuple[Tensor, ...]:
if x.requires_grad:
if batch_dim < 0:
batch_dim += x.ndim
ctx.batch_dim = batch_dim
ctx.threshold = threshold
return (x,) + params
@staticmethod
def backward(
ctx,
x_grad: Tensor,
*param_grads: Tensor,
) -> Tuple[Tensor, ...]:
eps = 1.0e-20
dim = ctx.batch_dim
norm_dims = [d for d in range(x_grad.ndim) if d != dim]
norm_of_batch = (x_grad ** 2).mean(dim=norm_dims, keepdim=True).sqrt()
median_norm = norm_of_batch.median()
cutoff = median_norm * ctx.threshold
inv_mask = (cutoff + norm_of_batch) / (cutoff + eps)
mask = 1.0 / (inv_mask + eps)
x_grad = x_grad * mask
avg_mask = 1.0 / (inv_mask.mean() + eps)
param_grads = [avg_mask * g for g in param_grads]
return (x_grad, None, None) + tuple(param_grads)
class GradientFilter(torch.nn.Module):
"""This is used to filter out elements that have extremely large gradients
in batch and the module parameters with soft masks.
Args:
batch_dim (int):
The batch dimension.
threshold (float):
For each element in batch, its gradient will be
filtered out if the gradient norm is larger than
`grad_norm_threshold * median`, where `median` is the median
value of gradient norms of all elememts in batch.
"""
def __init__(self, batch_dim: int = 1, threshold: float = 10.0):
super(GradientFilter, self).__init__()
self.batch_dim = batch_dim
self.threshold = threshold
def forward(self, x: Tensor, *params: Tensor) -> Tuple[Tensor, ...]:
if torch.jit.is_scripting() or is_jit_tracing():
return (x,) + params
else:
return GradientFilterFunction.apply(
x,
self.batch_dim,
self.threshold,
*params,
)
class BasicNorm(torch.nn.Module):
"""
This is intended to be a simpler, and hopefully cheaper, replacement for
@ -195,7 +266,7 @@ class ScaledLinear(nn.Linear):
*args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs
**kwargs,
):
super(ScaledLinear, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
@ -242,7 +313,7 @@ class ScaledConv1d(nn.Conv1d):
*args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs
**kwargs,
):
super(ScaledConv1d, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
@ -314,7 +385,7 @@ class ScaledConv2d(nn.Conv2d):
*args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs
**kwargs,
):
super(ScaledConv2d, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
@ -389,7 +460,8 @@ class ScaledLSTM(nn.LSTM):
*args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs
grad_norm_threshold: float = 10.0,
**kwargs,
):
if "bidirectional" in kwargs:
assert kwargs["bidirectional"] is False
@ -404,6 +476,10 @@ class ScaledLSTM(nn.LSTM):
setattr(self, scale_name, param)
self._scales.append(param)
self.grad_filter = GradientFilter(
batch_dim=1, threshold=grad_norm_threshold
)
self._reset_parameters(
initial_speed
) # Overrides the reset_parameters in base class
@ -513,10 +589,14 @@ class ScaledLSTM(nn.LSTM):
hx = (h_zeros, c_zeros)
self.check_forward_args(input, hx, None)
flat_weights = self._get_flat_weights()
input, *flat_weights = self.grad_filter(input, *flat_weights)
result = _VF.lstm(
input,
hx,
self._get_flat_weights(),
flat_weights,
self.bias,
self.num_layers,
self.dropout,
@ -557,6 +637,7 @@ class ActivationBalancer(torch.nn.Module):
max_abs: the maximum average-absolute-value per channel, which
we allow, before we start to modify the derivatives to prevent
this.
balance_prob: the probability to apply the ActivationBalancer.
"""
def __init__(
@ -567,6 +648,7 @@ class ActivationBalancer(torch.nn.Module):
max_factor: float = 0.01,
min_abs: float = 0.2,
max_abs: float = 100.0,
balance_prob: float = 0.25,
):
super(ActivationBalancer, self).__init__()
self.channel_dim = channel_dim
@ -575,9 +657,11 @@ class ActivationBalancer(torch.nn.Module):
self.max_factor = max_factor
self.min_abs = min_abs
self.max_abs = max_abs
assert 0 < balance_prob <= 1, balance_prob
self.balance_prob = balance_prob
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or is_jit_tracing():
if random.random() >= self.balance_prob:
return x
else:
return ActivationBalancerFunction.apply(
@ -585,7 +669,7 @@ class ActivationBalancer(torch.nn.Module):
self.channel_dim,
self.min_positive,
self.max_positive,
self.max_factor,
self.max_factor / self.balance_prob,
self.min_abs,
self.max_abs,
)
@ -891,9 +975,54 @@ def _test_scaled_lstm():
assert c.shape == (1, N, dim_hidden)
def _test_grad_filter():
threshold = 50.0
time, batch, channel = 200, 5, 128
grad_filter = GradientFilter(batch_dim=1, threshold=threshold)
for i in range(2):
x = torch.randn(time, batch, channel, requires_grad=True)
w = nn.Parameter(torch.ones(5))
b = nn.Parameter(torch.zeros(5))
x_out, w_out, b_out = grad_filter(x, w, b)
w_out_grad = torch.randn_like(w)
b_out_grad = torch.randn_like(b)
x_out_grad = torch.rand_like(x)
if i % 2 == 1:
# The gradient norm of the first element must be larger than
# `threshold * median`, where `median` is the median value
# of gradient norms of all elements in batch.
x_out_grad[:, 0, :] = torch.full((time, channel), threshold)
torch.autograd.backward(
[x_out, w_out, b_out], [x_out_grad, w_out_grad, b_out_grad]
)
print(
"_test_grad_filter: for gradient norms, the first element > median * threshold ", # noqa
i % 2 == 1,
)
print(
"_test_grad_filter: x_out_grad norm = ",
(x_out_grad ** 2).mean(dim=(0, 2)).sqrt(),
)
print(
"_test_grad_filter: x.grad norm = ",
(x.grad ** 2).mean(dim=(0, 2)).sqrt(),
)
print("_test_grad_filter: w_out_grad = ", w_out_grad)
print("_test_grad_filter: w.grad = ", w.grad)
print("_test_grad_filter: b_out_grad = ", b_out_grad)
print("_test_grad_filter: b.grad = ", b.grad)
if __name__ == "__main__":
_test_activation_balancer_sign()
_test_activation_balancer_magnitude()
_test_basic_norm()
_test_double_swish_deriv()
_test_scaled_lstm()
_test_grad_filter()

View File

@ -62,13 +62,20 @@ It will generates 3 files: `encoder_jit_trace.pt`,
--avg 10 \
--onnx 1
It will generate the following three files in the given `exp_dir`.
It will generate the following files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
Please see ./onnx_pretrained.py for usage of the generated files
Check
https://github.com/k2-fsa/sherpa-onnx
for how to use the exported models outside of icefall.
(4) Export `model.state_dict()`
@ -115,7 +122,6 @@ import argparse
import logging
from pathlib import Path
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
@ -213,13 +219,15 @@ def get_parser():
type=str2bool,
default=False,
help="""If True, --jit is ignored and it exports the model
to onnx format. Three files will be generated:
to onnx format. It will generate the following files:
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
Check ./onnx_check.py and ./onnx_pretrained.py for how to use them.
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
""",
)
@ -476,65 +484,99 @@ def export_joiner_model_onnx(
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported model has two inputs:
The exported joiner model has two inputs:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
- projected_decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
The exported encoder_proj model has one input:
- encoder_out: a tensor of shape (N, encoder_out_dim)
and produces one output:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
The exported decoder_proj model has one input:
- decoder_out: a tensor of shape (N, decoder_out_dim)
and has one output:
and produces one output:
- joiner_out: a tensor of shape (N, vocab_size)
Note: The argument project_input is fixed to True. A user should not
project the encoder_out/decoder_out by himself/herself. The exported joiner
will do that for the user.
- projected_decoder_out: a tensor of shape (N, joiner_dim)
"""
encoder_proj_filename = str(joiner_filename).replace(
".onnx", "_encoder_proj.onnx"
)
decoder_proj_filename = str(joiner_filename).replace(
".onnx", "_decoder_proj.onnx"
)
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
project_input = True
projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
project_input = False
# Note: It uses torch.jit.trace() internally
torch.onnx.export(
joiner_model,
(encoder_out, decoder_out, project_input),
(projected_encoder_out, projected_decoder_out, project_input),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out", "decoder_out", "project_input"],
input_names=[
"projected_encoder_out",
"projected_decoder_out",
"project_input",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"projected_encoder_out": {0: "N"},
"projected_decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
logging.info(f"Saved to {joiner_filename}")
def export_all_in_one_onnx(
encoder_filename: str,
decoder_filename: str,
joiner_filename: str,
all_in_one_filename: str,
):
encoder_onnx = onnx.load(encoder_filename)
decoder_onnx = onnx.load(decoder_filename)
joiner_onnx = onnx.load(joiner_filename)
encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")
combined_model = onnx.compose.merge_models(
encoder_onnx, decoder_onnx, io_map={}
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.encoder_proj,
encoder_out,
encoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out"],
output_names=["projected_encoder_out"],
dynamic_axes={
"encoder_out": {0: "N"},
"projected_encoder_out": {0: "N"},
},
)
combined_model = onnx.compose.merge_models(
combined_model, joiner_onnx, io_map={}
logging.info(f"Saved to {encoder_proj_filename}")
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.decoder_proj,
decoder_out,
decoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["decoder_out"],
output_names=["projected_decoder_out"],
dynamic_axes={
"decoder_out": {0: "N"},
"projected_decoder_out": {0: "N"},
},
)
onnx.save(combined_model, all_in_one_filename)
logging.info(f"Saved to {all_in_one_filename}")
logging.info(f"Saved to {decoder_proj_filename}")
@torch.no_grad()
@ -628,14 +670,6 @@ def main():
joiner_filename,
opset_version=opset_version,
)
all_in_one_filename = params.exp_dir / "all_in_one.onnx"
export_all_in_one_onnx(
encoder_filename,
decoder_filename,
joiner_filename,
all_in_one_filename,
)
elif params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.script()")

View File

@ -63,6 +63,20 @@ def get_parser():
help="Path to the onnx joiner model",
)
parser.add_argument(
"--onnx-joiner-encoder-proj-filename",
required=True,
type=str,
help="Path to the onnx joiner encoder projection model",
)
parser.add_argument(
"--onnx-joiner-decoder-proj-filename",
required=True,
type=str,
help="Path to the onnx joiner decoder projection model",
)
return parser
@ -70,11 +84,13 @@ def test_encoder(
model: torch.jit.ScriptModule,
encoder_session: ort.InferenceSession,
):
encoder_inputs = encoder_session.get_inputs()
assert encoder_inputs[0].name == "x"
assert encoder_inputs[1].name == "x_lens"
assert encoder_inputs[0].shape == ["N", "T", 80]
assert encoder_inputs[1].shape == ["N"]
inputs = encoder_session.get_inputs()
outputs = encoder_session.get_outputs()
input_names = [n.name for n in inputs]
output_names = [n.name for n in outputs]
assert inputs[0].shape == ["N", "T", 80]
assert inputs[1].shape == ["N"]
for N in [1, 5]:
for T in [12, 25]:
@ -84,11 +100,11 @@ def test_encoder(
x_lens[0] = T
encoder_inputs = {
"x": x.numpy(),
"x_lens": x_lens.numpy(),
input_names[0]: x.numpy(),
input_names[1]: x_lens.numpy(),
}
encoder_out, encoder_out_lens = encoder_session.run(
["encoder_out", "encoder_out_lens"],
output_names,
encoder_inputs,
)
@ -96,7 +112,9 @@ def test_encoder(
encoder_out = torch.from_numpy(encoder_out)
assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
(encoder_out - torch_encoder_out).abs().max()
(encoder_out - torch_encoder_out).abs().max(),
encoder_out.shape,
torch_encoder_out.shape,
)
@ -104,15 +122,18 @@ def test_decoder(
model: torch.jit.ScriptModule,
decoder_session: ort.InferenceSession,
):
decoder_inputs = decoder_session.get_inputs()
assert decoder_inputs[0].name == "y"
assert decoder_inputs[0].shape == ["N", 2]
inputs = decoder_session.get_inputs()
outputs = decoder_session.get_outputs()
input_names = [n.name for n in inputs]
output_names = [n.name for n in outputs]
assert inputs[0].shape == ["N", 2]
for N in [1, 5, 10]:
y = torch.randint(low=1, high=500, size=(10, 2))
decoder_inputs = {"y": y.numpy()}
decoder_inputs = {input_names[0]: y.numpy()}
decoder_out = decoder_session.run(
["decoder_out"],
output_names,
decoder_inputs,
)[0]
decoder_out = torch.from_numpy(decoder_out)
@ -126,34 +147,92 @@ def test_decoder(
def test_joiner(
model: torch.jit.ScriptModule,
joiner_session: ort.InferenceSession,
joiner_encoder_proj_session: ort.InferenceSession,
joiner_decoder_proj_session: ort.InferenceSession,
):
joiner_inputs = joiner_session.get_inputs()
assert joiner_inputs[0].name == "encoder_out"
assert joiner_inputs[0].shape == ["N", 512]
joiner_outputs = joiner_session.get_outputs()
joiner_input_names = [n.name for n in joiner_inputs]
joiner_output_names = [n.name for n in joiner_outputs]
assert joiner_inputs[1].name == "decoder_out"
assert joiner_inputs[0].shape == ["N", 512]
assert joiner_inputs[1].shape == ["N", 512]
joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs()
encoder_proj_input_name = joiner_encoder_proj_inputs[0].name
assert joiner_encoder_proj_inputs[0].shape == ["N", 512]
joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs()
encoder_proj_output_name = joiner_encoder_proj_outputs[0].name
joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs()
decoder_proj_input_name = joiner_decoder_proj_inputs[0].name
assert joiner_decoder_proj_inputs[0].shape == ["N", 512]
joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs()
decoder_proj_output_name = joiner_decoder_proj_outputs[0].name
for N in [1, 5, 10]:
encoder_out = torch.rand(N, 512)
decoder_out = torch.rand(N, 512)
projected_encoder_out = torch.rand(N, 512)
projected_decoder_out = torch.rand(N, 512)
joiner_inputs = {
"encoder_out": encoder_out.numpy(),
"decoder_out": decoder_out.numpy(),
joiner_input_names[0]: projected_encoder_out.numpy(),
joiner_input_names[1]: projected_decoder_out.numpy(),
}
joiner_out = joiner_session.run(["logit"], joiner_inputs)[0]
joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0]
joiner_out = torch.from_numpy(joiner_out)
torch_joiner_out = model.joiner(
encoder_out,
decoder_out,
project_input=True,
projected_encoder_out,
projected_decoder_out,
project_input=False,
)
assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), (
(joiner_out - torch_joiner_out).abs().max()
)
# Now test encoder_proj
joiner_encoder_proj_inputs = {
encoder_proj_input_name: encoder_out.numpy()
}
joiner_encoder_proj_out = joiner_encoder_proj_session.run(
[encoder_proj_output_name], joiner_encoder_proj_inputs
)[0]
joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out)
torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out)
assert torch.allclose(
joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
), (
(joiner_encoder_proj_out - torch_joiner_encoder_proj_out)
.abs()
.max()
)
# Now test decoder_proj
joiner_decoder_proj_inputs = {
decoder_proj_input_name: decoder_out.numpy()
}
joiner_decoder_proj_out = joiner_decoder_proj_session.run(
[decoder_proj_output_name], joiner_decoder_proj_inputs
)[0]
joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out)
torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out)
assert torch.allclose(
joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
), (
(joiner_decoder_proj_out - torch_joiner_decoder_proj_out)
.abs()
.max()
)
@torch.no_grad()
def main():
@ -185,7 +264,20 @@ def main():
args.onnx_joiner_filename,
sess_options=options,
)
test_joiner(model, joiner_session)
joiner_encoder_proj_session = ort.InferenceSession(
args.onnx_joiner_encoder_proj_filename,
sess_options=options,
)
joiner_decoder_proj_session = ort.InferenceSession(
args.onnx_joiner_decoder_proj_filename,
sess_options=options,
)
test_joiner(
model,
joiner_session,
joiner_encoder_proj_session,
joiner_decoder_proj_session,
)
logging.info("Finished checking ONNX models")

View File

@ -1,284 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corporation (Author: Yunus Emre Ozkose)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks that exported onnx models produce the same output
with the given torchscript model for the same input.
"""
import argparse
import logging
import os
import onnx
import onnx_graphsurgeon as gs
import onnxruntime
import onnxruntime as ort
import torch
ort.set_default_logger_severity(3)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--jit-filename",
required=True,
type=str,
help="Path to the torchscript model",
)
parser.add_argument(
"--onnx-all-in-one-filename",
required=True,
type=str,
help="Path to the onnx all in one model",
)
return parser
def test_encoder(
model: torch.jit.ScriptModule,
encoder_session: ort.InferenceSession,
):
encoder_inputs = encoder_session.get_inputs()
assert encoder_inputs[0].shape == ["N", "T", 80]
assert encoder_inputs[1].shape == ["N"]
encoder_input_names = [i.name for i in encoder_inputs]
encoder_output_names = [i.name for i in encoder_session.get_outputs()]
for N in [1, 5]:
for T in [12, 25]:
print("N, T", N, T)
x = torch.rand(N, T, 80, dtype=torch.float32)
x_lens = torch.randint(low=10, high=T + 1, size=(N,))
x_lens[0] = T
encoder_inputs = {
encoder_input_names[0]: x.numpy(),
encoder_input_names[1]: x_lens.numpy(),
}
encoder_out, encoder_out_lens = encoder_session.run(
[encoder_output_names[1], encoder_output_names[0]],
encoder_inputs,
)
torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)
encoder_out = torch.from_numpy(encoder_out)
assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
(encoder_out - torch_encoder_out).abs().max()
)
def test_decoder(
model: torch.jit.ScriptModule,
decoder_session: ort.InferenceSession,
):
decoder_inputs = decoder_session.get_inputs()
assert decoder_inputs[0].shape == ["N", 2]
decoder_input_names = [i.name for i in decoder_inputs]
decoder_output_names = [i.name for i in decoder_session.get_outputs()]
for N in [1, 5, 10]:
y = torch.randint(low=1, high=500, size=(10, 2))
decoder_inputs = {decoder_input_names[0]: y.numpy()}
decoder_out = decoder_session.run(
[decoder_output_names[0]],
decoder_inputs,
)[0]
decoder_out = torch.from_numpy(decoder_out)
torch_decoder_out = model.decoder(y, need_pad=False)
assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), (
(decoder_out - torch_decoder_out).abs().max()
)
def test_joiner(
model: torch.jit.ScriptModule,
joiner_session: ort.InferenceSession,
):
joiner_inputs = joiner_session.get_inputs()
assert joiner_inputs[0].shape == ["N", 512]
assert joiner_inputs[1].shape == ["N", 512]
joiner_input_names = [i.name for i in joiner_inputs]
joiner_output_names = [i.name for i in joiner_session.get_outputs()]
for N in [1, 5, 10]:
encoder_out = torch.rand(N, 512)
decoder_out = torch.rand(N, 512)
joiner_inputs = {
joiner_input_names[0]: encoder_out.numpy(),
joiner_input_names[1]: decoder_out.numpy(),
}
joiner_out = joiner_session.run(
[joiner_output_names[0]], joiner_inputs
)[0]
joiner_out = torch.from_numpy(joiner_out)
torch_joiner_out = model.joiner(
encoder_out,
decoder_out,
project_input=True,
)
assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), (
(joiner_out - torch_joiner_out).abs().max()
)
def extract_sub_model(
onnx_graph: onnx.ModelProto,
input_op_names: list,
output_op_names: list,
non_verbose=False,
):
onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph)
graph = gs.import_onnx(onnx_graph)
graph.cleanup().toposort()
# Extraction of input OP and output OP
graph_node_inputs = [
graph_nodes
for graph_nodes in graph.nodes
for graph_nodes_input in graph_nodes.inputs
if graph_nodes_input.name in input_op_names
]
graph_node_outputs = [
graph_nodes
for graph_nodes in graph.nodes
for graph_nodes_output in graph_nodes.outputs
if graph_nodes_output.name in output_op_names
]
# Init graph INPUT/OUTPUT
graph.inputs.clear()
graph.outputs.clear()
# Update graph INPUT/OUTPUT
graph.inputs = [
graph_node_input
for graph_node in graph_node_inputs
for graph_node_input in graph_node.inputs
if graph_node_input.shape
]
graph.outputs = [
graph_node_output
for graph_node in graph_node_outputs
for graph_node_output in graph_node.outputs
]
# Cleanup
graph.cleanup().toposort()
# Shape Estimation
extracted_graph = None
try:
extracted_graph = onnx.shape_inference.infer_shapes(
gs.export_onnx(graph)
)
except Exception:
extracted_graph = gs.export_onnx(graph)
if not non_verbose:
print(
"WARNING: "
+ "The input shape of the next OP does not match the output shape. "
+ "Be sure to open the .onnx file to verify the certainty of the geometry."
)
return extracted_graph
def extract_encoder(onnx_model: onnx.ModelProto):
encoder_ = extract_sub_model(
onnx_model,
["encoder/x", "encoder/x_lens"],
["encoder/encoder_out", "encoder/encoder_out_lens"],
False,
)
onnx.save(encoder_, "tmp_encoder.onnx")
onnx.checker.check_model(encoder_)
sess = onnxruntime.InferenceSession("tmp_encoder.onnx")
os.remove("tmp_encoder.onnx")
return sess
def extract_decoder(onnx_model: onnx.ModelProto):
decoder_ = extract_sub_model(
onnx_model, ["decoder/y"], ["decoder/decoder_out"], False
)
onnx.save(decoder_, "tmp_decoder.onnx")
onnx.checker.check_model(decoder_)
sess = onnxruntime.InferenceSession("tmp_decoder.onnx")
os.remove("tmp_decoder.onnx")
return sess
def extract_joiner(onnx_model: onnx.ModelProto):
joiner_ = extract_sub_model(
onnx_model,
["joiner/encoder_out", "joiner/decoder_out"],
["joiner/logit"],
False,
)
onnx.save(joiner_, "tmp_joiner.onnx")
onnx.checker.check_model(joiner_)
sess = onnxruntime.InferenceSession("tmp_joiner.onnx")
os.remove("tmp_joiner.onnx")
return sess
@torch.no_grad()
def main():
args = get_parser().parse_args()
logging.info(vars(args))
model = torch.jit.load(args.jit_filename)
onnx_model = onnx.load(args.onnx_all_in_one_filename)
options = ort.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 1
logging.info("Test encoder")
encoder_session = extract_encoder(onnx_model)
test_encoder(model, encoder_session)
logging.info("Test decoder")
decoder_session = extract_decoder(onnx_model)
test_decoder(model, decoder_session)
logging.info("Test joiner")
joiner_session = extract_joiner(onnx_model)
test_joiner(model, joiner_session)
logging.info("Finished checking ONNX models")
if __name__ == "__main__":
torch.manual_seed(20220727)
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -27,10 +27,12 @@ You can use the following command to get the exported models:
Usage of this script:
./pruned_transducer_stateless3/jit_trace_pretrained.py \
./pruned_transducer_stateless3/onnx_pretrained.py \
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \
--joiner-encoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_encoder_proj.onnx \
--joiner-decoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_decoder_proj.onnx \
--bpe-model ./data/lang_bpe_500/bpe.model \
/path/to/foo.wav \
/path/to/bar.wav
@ -59,21 +61,35 @@ def get_parser():
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder torchscript model. ",
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder torchscript model. ",
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner torchscript model. ",
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--joiner-encoder-proj-model-filename",
type=str,
required=True,
help="Path to the joiner encoder_proj onnx model. ",
)
parser.add_argument(
"--joiner-decoder-proj-model-filename",
type=str,
required=True,
help="Path to the joiner decoder_proj onnx model. ",
)
parser.add_argument(
@ -136,6 +152,8 @@ def read_sound_files(
def greedy_search(
decoder: ort.InferenceSession,
joiner: ort.InferenceSession,
joiner_encoder_proj: ort.InferenceSession,
joiner_decoder_proj: ort.InferenceSession,
encoder_out: np.ndarray,
encoder_out_lens: np.ndarray,
context_size: int,
@ -146,6 +164,10 @@ def greedy_search(
The decoder model.
joiner:
The joiner model.
joiner_encoder_proj:
The joiner encoder projection model.
joiner_decoder_proj:
The joiner decoder projection model.
encoder_out:
A 3-D tensor of shape (N, T, C)
encoder_out_lens:
@ -167,6 +189,15 @@ def greedy_search(
enforce_sorted=False,
)
projected_encoder_out = joiner_encoder_proj.run(
[joiner_encoder_proj.get_outputs()[0].name],
{
joiner_encoder_proj.get_inputs()[
0
].name: packed_encoder_out.data.numpy()
},
)[0]
blank_id = 0 # hard-code to 0
batch_size_list = packed_encoder_out.batch_sizes.tolist()
@ -194,26 +225,31 @@ def greedy_search(
decoder_input_nodes[0].name: decoder_input.numpy(),
},
)[0].squeeze(1)
projected_decoder_out = joiner_decoder_proj.run(
[joiner_decoder_proj.get_outputs()[0].name],
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
)[0]
projected_decoder_out = torch.from_numpy(projected_decoder_out)
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = packed_encoder_out.data[start:end]
current_encoder_out = current_encoder_out
current_encoder_out = projected_encoder_out[start:end]
# current_encoder_out's shape: (batch_size, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
projected_decoder_out = projected_decoder_out[:batch_size]
logits = joiner.run(
[joiner_output_nodes[0].name],
{
joiner_input_nodes[0].name: current_encoder_out.numpy(),
joiner_input_nodes[1].name: decoder_out,
joiner_input_nodes[0].name: current_encoder_out,
joiner_input_nodes[1].name: projected_decoder_out.numpy(),
},
)[0]
logits = torch.from_numpy(logits)
logits = torch.from_numpy(logits).squeeze(1).squeeze(1)
# logits'shape (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
@ -236,6 +272,11 @@ def greedy_search(
decoder_input_nodes[0].name: decoder_input.numpy(),
},
)[0].squeeze(1)
projected_decoder_out = joiner_decoder_proj.run(
[joiner_decoder_proj.get_outputs()[0].name],
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
)[0]
projected_decoder_out = torch.from_numpy(projected_decoder_out)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
@ -271,6 +312,16 @@ def main():
sess_options=session_opts,
)
joiner_encoder_proj = ort.InferenceSession(
args.joiner_encoder_proj_model_filename,
sess_options=session_opts,
)
joiner_decoder_proj = ort.InferenceSession(
args.joiner_decoder_proj_model_filename,
sess_options=session_opts,
)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
@ -315,6 +366,8 @@ def main():
hyps = greedy_search(
decoder=decoder,
joiner=joiner,
joiner_encoder_proj=joiner_encoder_proj,
joiner_decoder_proj=joiner_decoder_proj,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
context_size=args.context_size,

View File

@ -271,7 +271,7 @@ def main():
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
model = get_transducer_model(params, enable_giga=False)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")

Some files were not shown because too many files have changed in this diff Show More